In [1]:
def get_mean_embeds(alleles, weights=None, hla_retriever=None):
    """Return the (optionally weighted) mean HLA embedding of a set of alleles.

    For each allele, the first row index in
    ``hla_retriever.dataset.allele_idxes_dict[allele]`` selects one row of
    ``hla_retriever.hla_embeds``. Those rows are then averaged.

    Args:
        alleles: iterable of allele names used as keys of ``allele_idxes_dict``
            (e.g. ``"A30_01"``).
        weights: optional per-allele weights, same length as ``alleles``.
            They are normalised to sum to 1 before averaging.
        hla_retriever: object exposing ``hla_embeds`` and
            ``dataset.allele_idxes_dict``. Defaults to the notebook-global
            ``retriever`` (created in a later cell), looked up at call time.

    Returns:
        A 1-D array (one value per embedding dimension) for both the weighted
        and the unweighted path.

    Raises:
        ValueError: if ``alleles`` is empty, ``weights`` has the wrong length,
            or the weights sum to zero.
        KeyError: if an allele is missing from ``allele_idxes_dict``.
    """
    import numpy as np

    if hla_retriever is None:
        hla_retriever = retriever  # notebook global; must exist before the first call

    alleles = list(alleles)
    if not alleles:
        raise ValueError("alleles must contain at least one allele")

    idx_dict = hla_retriever.dataset.allele_idxes_dict
    row_idx = [idx_dict[allele][0] for allele in alleles]
    # Integer-list (fancy) indexing already returns a new array, so no .copy() is needed.
    selected_embeds = hla_retriever.hla_embeds[row_idx]

    if weights is None:
        return np.mean(selected_embeds, axis=0)

    w = np.asarray(weights, dtype=np.float32)
    if w.shape != (len(alleles),):
        raise ValueError(
            f"weights must have one entry per allele ({len(alleles)}), got shape {w.shape}"
        )
    total = w.sum()
    if total == 0:
        raise ValueError("weights must not sum to zero")
    # 1-D weights @ (n_alleles, dim) -> (dim,), the same shape as the unweighted branch.
    return (w / total) @ selected_embeds

In [2]:
import pandas as pd

# Cluster membership table: one row per cluster, with ";"-separated allele names in "alleles".
CLUSTER_INFO_PATH = "data/fig3a_cluster_info.tsv"
cluster_df = pd.read_csv(CLUSTER_INFO_PATH, sep="\t")
cluster_df

Unnamed: 0,cluster_id,alleles_num,alleles
0,1,13,A30_01;A03_01;A11_01;A03_02;A11_02;A74_01;A31_...
1,2,7,B58_02;B57_01;B57_03;B58_01;B15_13;A32_01;B15_17
2,3,1,B15_03
3,4,15,A01_01;A36_01;A30_02;A80_01;A29_02;B15_01;B15_...
4,5,16,A68_02;A69_01;A02_07;A02_17;A02_19;A02_06;A02_...
5,6,4,B52_01;C15_02;B13_01;B13_02
6,7,11,B46_01;C12_04;C12_02;C12_03;C02_02;C16_02;C15_...
7,8,11,C08_01;C17_01;C03_04;C03_03;C04_03;C05_01;C08_...
8,9,2,B73_01;B39_06
9,10,12,B27_05;B27_03;B27_09;B27_04;B39_05;B38_01;B38_...


In [3]:
import pickle

# Precomputed HLA table and embeddings.
# pickle.load can execute arbitrary code: only load this file from a trusted source.
HLA_EMBEDS_PATH = "embeds/hla_v0819_embeds.pkl"
with open(HLA_EMBEDS_PATH, "rb") as pkl_file:
    data_dict = pickle.load(pkl_file)

hla_df = data_dict["protein_df"]
hla_embeds = data_dict["embeds"]

hla_df

Unnamed: 0,sequence,allele,allele_detail
0,AHSMRYFYTAVSRPGRGEPHFIAVGYVDDTQFVRFDSDAASPRGEP...,C03_159,C*03:159
1,ALALTETWAGSHSMRYFYTAMSRPGRGEPRFIAVGYVDDTQFVRFD...,B15_128,B*15:128
2,ALALTETWAGSHSMRYFYTSVSRPGRGEPRFISVGYVDDTQFVRFD...,B14_12,B*14:12
3,APRTLLLLLSGALALTQTWAGSHSMRYFYTSVSRPGRGEPRFIAVG...,A03_12,A*03:12
4,APRTLVLLLSGALALTQTWAGSHSMRYFSTSVSRPGRGEPRFIAVG...,A24_79,A*24:79
...,...,...,...
16449,TLLLLLSGALALTQTWAGSHSMRYFSTSVSRPGRGEPRFIAVGYVD...,A24_26,A*24:26
16450,TLVLLLSGALALTQTWAGSHSMRYFFTSVSRPGRGEPRFIAVGYVD...,A02_01,A*02:01:03
16451,VTAPRTLLLLLWGAVALTETWAGSHSMRYFYTAMSRPGRGEPRFIT...,B44_43,B*44:43:01
16452,VTAPRTVLLLLSGALALTETWAGSHSMRYFYTAMSRPGRGEPRFIA...,B15_57,B*15:57


In [4]:
# Collect the longest available sequence of every HLA-A allele in the candidate cluster(s).
# candidate_cluster holds row labels of cluster_df (label 0 is the row with cluster_id 1).
test_align_fasta_dict = {}
allele_list = []
candidate_cluster = [0]
for cluster_idx in candidate_cluster:
    for allele in cluster_df.loc[cluster_idx, "alleles"].split(";"):
        # HLA-A alleles only; startswith avoids matching an "A" that is not the locus prefix.
        if not allele.startswith("A"):
            continue
        allele_seqs = hla_df.loc[hla_df["allele"] == allele, "sequence"]
        if allele_seqs.empty:
            raise KeyError(f"allele {allele!r} not found in hla_df")
        # max(..., key=len) keeps the first of several equally long sequences.
        test_align_fasta_dict[allele] = max(allele_seqs, key=len)
        allele_list.append(allele)
allele_list

['A30_01',
 'A03_01',
 'A11_01',
 'A03_02',
 'A11_02',
 'A74_01',
 'A31_01',
 'A34_02',
 'A33_01',
 'A33_03',
 'A68_01',
 'A34_01',
 'A66_01']

In [5]:
from fennet.mhc.mhc_binding_model import *
from fennet.mhc.mhc_binding_retriever import MHCBindingRetriever

# Load the pretrained HLA and peptide encoders on CPU, then build the retriever
# from them, the HLA table/embeddings loaded above, and the human proteome FASTA.
MODEL_DIR = "../../models"
DEVICE = "cpu"

hla_model_pt = "HLA_model_v0819.pt"
pept_model_pt = "pept_model_v0819.pt"

pept_encoder = ModelSeqEncoder().to(DEVICE)
hla_encoder = ModelHlaEncoder().to(DEVICE)

hla_encoder.load_state_dict(
    torch.load(f"{MODEL_DIR}/{hla_model_pt}", map_location=DEVICE)
)
pept_encoder.load_state_dict(
    torch.load(f"{MODEL_DIR}/{pept_model_pt}", map_location=DEVICE)
)
# NOTE(review): the encoders are not switched to .eval() here; confirm that
# MHCBindingRetriever does so before it computes any embeddings.

fasta_list = ["uniprotkb_UP000005640_AND_reviewed_true_2024_03_01.fasta"]

# NOTE(review): later displays of hla_df rows show an extra "hla_id" column.
# Presumably this constructor adds it to hla_df in place — verify.
retriever = MHCBindingRetriever(
    hla_encoder,
    pept_encoder,
    hla_df,
    hla_embeds,
    fasta_list,
)


In [None]:
# Mean embedding over ALL alleles of the candidate cluster(s), then its distance to every HLA entry.
# A separate name is used so the HLA-A-only `allele_list` built earlier is not overwritten.
query_alleles = [
    allele
    for cluster_idx in candidate_cluster
    for allele in cluster_df.loc[cluster_idx, "alleles"].split(";")
]
mean_embeds = get_mean_embeds(query_alleles)
dist = retriever.get_embedding_distances(hla_embeds, [mean_embeds])
# One query embedding, so squeeze presumably leaves one distance per hla_df row — TODO confirm.
dist = dist.squeeze()

In [7]:
# The TOP_K HLA entries closest to the cluster mean embedding, ordered nearest first.
# np.argpartition only guarantees the k smallest values come first (in arbitrary order),
# and kth must be < len(dist); hence the clamp and the follow-up argsort.
TOP_K = 10
n_top = min(TOP_K, len(dist))
top_ten_indices = np.argpartition(dist, n_top - 1)[:n_top]
top_ten_indices = top_ten_indices[np.argsort(dist[top_ten_indices], kind="stable")].tolist()
return_df = hla_df.iloc[top_ten_indices, :].copy()
return_df

Unnamed: 0,sequence,allele,allele_detail,hla_id
12817,SHSMRYFTTSVSRPGRGEPRFIAVGYVDDTQFVRFYSDAASQRMEP...,A31_217,A*31:217,12817
2400,MAVMAPRTLVLLLSGALALTQTWAGSHSMRYFYTSVSRPGRGEPRF...,A68_175,A*68:175:01,2400
1234,MAVMAPRTLLLLLSGALALTQTWAGSHSMRYFYTSVSRPGRGEPRF...,A11_310N,A*11:310N,1234
3,APRTLLLLLSGALALTQTWAGSHSMRYFYTSVSRPGRGEPRFIAVG...,A03_12,A*03:12,3
15822,SHSMRYFYTSVSRPGRGEPRFIAVGYVDDTQFVRFDSDAASQRMEP...,A66_06,A*66:06,15822
15339,SHSMRYFYTSASRPGRGEPRFIAVGYVDDTQFVRFDSDAASQRMEP...,A11_260,A*11:260,15339
874,MAVMAPRTLLLLLSGALALTQTWAGSHSMRYFFTSVSRPGRGEPRF...,A03_357N,A*03:357N,874
10799,SHSMRYFFTSVSRPGRGEPRFIAVGYVDDTQFVRFDSDAASQRMEP...,A03_32,A*03:32,10799
62,MAPRTLLLLLSGALALTQTWAGSHSMRYFFTSVSRPGRGEPRFIAV...,A03_350,A*03:350,62
9357,SHSMGYFYTSVSRPGRGEPRFIAVGYVDDTQFVRFDSDAASQRMEP...,A11_200,A*11:200,9357


In [8]:
# Add the nearest-neighbour alleles from return_df that are not yet in the alignment set.
# The sequence is taken from return_df itself, not hla_df. An allele hit more than once
# keeps its longest retrieved sequence (the first one on ties).
for hit_allele in return_df["allele"]:
    if hit_allele in test_align_fasta_dict:
        continue
    hit_seqs = return_df.loc[return_df["allele"] == hit_allele, "sequence"]
    test_align_fasta_dict[hit_allele] = max(hit_seqs, key=len)

In [9]:
import random
random.seed(1314)  # fixed seed so the sampled representatives are reproducible

# For every other cluster with at least one HLA-A allele, randomly pick one HLA-A
# allele as its representative and add that allele's longest sequence from hla_df.
for cluster_idx, alleles_str in cluster_df["alleles"].items():
    if cluster_idx in candidate_cluster:
        continue
    cluster_alleles = alleles_str.split(";")
    # HLA-A alleles only; startswith avoids matching an "A" that is not the locus prefix.
    if not any(allele.startswith("A") for allele in cluster_alleles):
        continue
    # Rejection sampling over the full list (rather than choosing from a pre-filtered
    # list) keeps the same sequence of random.choice draws as the original cell.
    while True:
        selected_alleles = random.choice(cluster_alleles)
        if selected_alleles.startswith("A"):
            break
    allele_seqs = hla_df.loc[hla_df["allele"] == selected_alleles, "sequence"]
    if allele_seqs.empty:
        raise KeyError(f"allele {selected_alleles!r} not found in hla_df")
    # max(..., key=len) keeps the first of several equally long sequences.
    test_align_fasta_dict[selected_alleles] = max(allele_seqs, key=len)

test_align_fasta_dict

{'A30_01': 'MAVMAPRTLLLLLSGALALTQTWAGSHSMRYFSTSVSRPGSGEPRFIAVGYVDDTQFVRFDSDAASQRMEPRAPWIEQERPEYWDQETRNVKAQSQTDRVDLGTLRGYYNQSEAGSHTIQIMYGCDVGSDGRFLRGYEQHAYDGKDYIALNEDLRSWTAADMAAQITQRKWEAARWAEQLRAYLEGTCVEWLRRYLENGKETLQRTDPPKTHMTHHPISDHEATLRCWALGFYPAEITLTWQRDGEDQTQDTELVETRPAGDGTFQKWAAVVVPSGEEQRYTCHVQHEGLPKPLTLRWELSSQPTIPIVGIIAGLVLLGAVITGAVVAAVMWRRKSSDRKGGSYTQAASSDSAQGSDVSLTACKV',
 'A03_01': 'MAVMAPRTLLLLLSGALALTQTWAGSHSMRYFFTSVSRPGRGEPRFIAVGYVDDTQFVRFDSDAASQRMEPRAPWIEQEGPEYWDQETRNVKAQSQTDRVDLGTLRGYYNQSEAGSHTIQIMYGCDVGSDGRFLRGYRQDAYDGKDYIALNEDLRSWTAADMAAQITKRKWEAAHEAEQLRAYLDGTCVEWLRRYLENGKETLQRTDPPKTHMTHHPISDHEATLRCWALGFYPAEITLTWQRDGEDQTQDTELVETRPAGDGTFQKWAAVVVPSGEEQRYTCHVQHEGLPKPLTLRWELSSQPTIPIVGIIAGLVLLGAVITGAVVAAVMWRRKSSDRKGGSYTQAASSDSAQGSDVSLTACKV',
 'A11_01': 'MAVMAPRTLLLLLSGALALTQTWAGSHSMRYFYTSVSRPGRGEPRFIAVGYVDDTQFVRFDSDAASQRMEPRAPWIEQEGPEYWDQETRNVKAQSQTDRVDLGTLRGYYNQSEDGSHTIQIMYGCDVGPDGRFLRGYRQDAYDGKDYIALNEDLRSWTAADMAAQITKRKWEAAHAAEQQRAYLEGRCVEWLRRYLENGKETLQRTDPPKTHMTHHPISDHEATLRCW

In [10]:
# Write the collected sequences as a multi-FASTA file (one record per allele, in dict order).
# `with` guarantees the file is closed even if a write fails.
with open("data/cluster1_HLA_A_combined_multi_seq.fasta", "w") as fasta_out:
    for allele, seq in test_align_fasta_dict.items():
        fasta_out.write(f">{allele}\n")
        fasta_out.write(f"{seq}\n")