In [1]:
from fennet.mhc.mhc_binding_model import (
    ModelHlaEncoder,
    ModelSeqEncoder,
)

# Checkpoint tag. It selects model/HLA_model_<tag>.pt and model/pept_model_<tag>.pt
# and is also embedded in the name of the exported embedding file.
model_version = "v0613"

In [2]:
# Instantiate the HLA-side and peptide-side encoders directly on the CUDA device.
hla_encoder, pept_encoder = (
    ModelHlaEncoder().to("cuda"),
    ModelSeqEncoder().to("cuda"),
)

In [3]:
import torch


def _load_weights(model: torch.nn.Module, path: str) -> None:
    """Load the state dict saved at ``path`` into ``model`` and switch it to eval mode.

    The checkpoint tensors are mapped onto the device that ``model`` already lives on,
    so this cell does not repeat a hard-coded device string.

    ``weights_only=True`` restricts unpickling to tensors and plain containers. That
    is all a ``state_dict`` contains, and it prevents a tampered checkpoint from
    executing code. It also silences the torch>=2.4 FutureWarning.
    """
    device = next(model.parameters()).device
    state_dict = torch.load(path, map_location=device, weights_only=True)
    # strict=True (the default) raises on any missing/unexpected key, so a partial
    # load cannot go unnoticed.
    model.load_state_dict(state_dict)
    # The notebook only runs inference, so make any dropout / batch-norm layers
    # behave deterministically.
    model.eval()


_load_weights(hla_encoder, f"model/HLA_model_{model_version}.pt")
_load_weights(pept_encoder, f"model/pept_model_{model_version}.pt")

<All keys matched successfully>

In [4]:
from fennet.mhc.mhc_utils import load_esm_pkl

# Load the HLA allele table (columns: sequence, allele, allele_detail) together
# with its companion ESM feature list, then display the table.
(hla_df, hla_esm_list) = load_esm_pkl()
hla_df

Unnamed: 0,sequence,allele,allele_detail
0,AHSMRYFYTAVSRPGRGEPHFIAVGYVDDTQFVRFDSDAASPRGEP...,C03_159,C*03:159
1,ALALTETWAGSHSMRYFYTAMSRPGRGEPRFIAVGYVDDTQFVRFD...,B15_128,B*15:128
2,ALALTETWAGSHSMRYFYTSVSRPGRGEPRFISVGYVDDTQFVRFD...,B14_12,B*14:12
3,APRTLLLLLSGALALTQTWAGSHSMRYFYTSVSRPGRGEPRFIAVG...,A03_12,A*03:12
4,APRTLVLLLSGALALTQTWAGSHSMRYFSTSVSRPGRGEPRFIAVG...,A24_79,A*24:79
...,...,...,...
16449,TLLLLLSGALALTQTWAGSHSMRYFSTSVSRPGRGEPRFIAVGYVD...,A24_26,A*24:26
16450,TLVLLLSGALALTQTWAGSHSMRYFFTSVSRPGRGEPRFIAVGYVD...,A02_01,A*02:01:03
16451,VTAPRTLLLLLWGAVALTETWAGSHSMRYFYTAMSRPGRGEPRFIT...,B44_43,B*44:43:01
16452,VTAPRTVLLLLSGALALTETWAGSHSMRYFYTAMSRPGRGEPRFIA...,B15_57,B*15:57


In [5]:
from fennet.mhc.mhc_binding_model import embed_hla_esm_list

# Run the loaded ESM features through the trained HLA encoder.
hla_embeds = embed_hla_esm_list(hla_encoder, hla_esm_list)

# hla_df and hla_embeds are handed to the retriever together and pickled side by
# side ("protein_df" / "embeds"). Fail fast here if they are not row-aligned,
# rather than exporting a silently mismatched bundle.
assert len(hla_embeds) == len(hla_df), (
    f"expected one embedding per allele row: "
    f"{len(hla_embeds)} embeddings vs {len(hla_df)} rows"
)

In [6]:
from fennet.mhc.mhc_binding_retriever import MHCBindingRetriever

# Protein database for the retriever. Per the file name, this is the UniProtKB
# proteome UP000005640, reviewed entries only, downloaded 2024-03-01.
fasta_list = [
    "uniprotkb_UP000005640_AND_reviewed_true_2024_03_01.fasta",
]
retriever = MHCBindingRetriever(
    hla_encoder,
    pept_encoder,
    hla_df,
    hla_embeds,
    fasta_list,
)

In [10]:
import numpy as np
from fennet.mhc.mhc_binding_model import embed_peptides
from fennet.mhc.plotting_utils import fit_hla_umap_reducer

# Seed NumPy's global RNG before sampling, so the random peptide draw is repeatable
# (assumes get_random_pept_df draws from np.random; TODO confirm).
np.random.seed(1337)

# Draw background peptides, presumably from the retriever's digested FASTA, and
# embed them with the peptide encoder. Then fit one UMAP reducer on the HLA
# embeddings stacked on top of the peptide embeddings.
rnd_df = retriever.dataset.digest.get_random_pept_df(10000)
rnd_embeds = embed_peptides(
    pept_encoder,
    rnd_df["sequence"].values.astype("U"),
)
umap_reducer = fit_hla_umap_reducer(
    np.concatenate([hla_embeds, rnd_embeds], axis=0)
)

In [11]:
import os
import pickle

# NOTE(review): the literal "1" after the version tag gives, with
# model_version = "v0613", the file "HLA_model_v06131.embed". Confirm that this
# suffix is intentional before downstream consumers depend on the name.
embed_path = f"embeds/HLA_model_{model_version}1.embed"

# Create the output directory on a fresh checkout. Otherwise open() fails with
# FileNotFoundError only after the expensive embedding and UMAP steps have run.
os.makedirs(os.path.dirname(embed_path), exist_ok=True)

# Bundle the allele table, its embeddings and the fitted UMAP reducer into one
# pickle. Only load the result from trusted locations: unpickling can run code.
with open(embed_path, "wb") as f:
    pickle.dump(
        {
            "protein_df": hla_df,
            "embeds": hla_embeds,
            "umap_reducer": umap_reducer,
        },
        f,
        protocol=pickle.HIGHEST_PROTOCOL,
    )