In [1]:
import pandas as pd

# Which MHC protein set to embed; the input table is read from "<mhc_type>_prot.tsv".
mhc_type = "hla1"
# Tab-separated table; later cells read its `sequence` column.
protein_df = pd.read_csv(f"{mhc_type}_prot.tsv", sep="\t")

In [2]:
import torch
import esm

# Cap the size of torch's intra-op CPU thread pool.
torch.set_num_threads(20)

# Load the pretrained ESM-2 checkpoint esm2_t12_35M_UR50D with its token alphabet,
# move the model to the GPU, and put it in eval (inference) mode.
esm2_model, alphabet = esm.pretrained.esm2_t12_35M_UR50D()
esm2_model = esm2_model.to("cuda").eval()  # .to() / .eval() return the same module

# Callable mapping [(label, sequence), ...] -> (labels, strs, token tensor).
batch_converter = alphabet.get_batch_converter()

In [3]:
# Run one full batch of dummy sequences, each as long as the longest protein,
# through the model before the real embedding loop.
# NOTE(review): presumably a worst-case memory check for the GPU — confirm.
max_nAA = protein_df.sequence.str.len().max()
batch_size = 100  # also the batch size used by the embedding loop
data = [("_", "A" * max_nAA) for _ in range(batch_size)]
with torch.no_grad():
    batch_labels, batch_strs, batch_tokens = batch_converter(data)
    results = esm2_model(
        batch_tokens.to("cuda"),
        repr_layers=[12],
        return_contacts=False,
    )
max_nAA

384

In [4]:
import tqdm

# Embed every protein in `protein_df` (in row order) in batches of `batch_size`,
# keeping the final-layer representation of each sequence's own residues only.
repr_layer = esm2_model.num_layers  # last layer (12 for esm2_t12_35M_UR50D)
embedding_list = []
with torch.no_grad():
    for i in tqdm.tqdm(range(0, len(protein_df), batch_size)):
        sequences = protein_df.sequence.values[i:i+batch_size]
        data = list(zip(["_"]*len(sequences), sequences))
        batch_labels, batch_strs, batch_tokens = batch_converter(data)
        # Non-padding token count per row (BOS + residues + EOS for the ESM-2 alphabet).
        batch_lens = (batch_tokens != alphabet.padding_idx).sum(1)
        results = esm2_model(batch_tokens.to("cuda"), repr_layers=[repr_layer], return_contacts=False)
        token_reprs = results["representations"][repr_layer].cpu().numpy()
        # Rows are padded to the longest sequence in the batch, so a blanket
        # [:, 1:-1] slice would keep the EOS + padding positions of every shorter
        # sequence (and make embedding length depend on batch composition).
        # Trim each row to its own residues instead.
        for row, n_tokens in enumerate(batch_lens.tolist()):
            # .copy() so each entry does not keep the whole batch array alive.
            embedding_list.append(token_reprs[row, 1:n_tokens - 1].copy())

100%|██████████| 59/59 [00:46<00:00,  1.28it/s]


In [5]:
# Sanity check: shape of the first embedding and how many embeddings were produced.
n_embeddings = len(embedding_list)
(embedding_list[0].shape, n_embeddings)

((372, 480), 5874)

In [6]:
import pickle

# Persist the protein table together with the embedding list in a single pickle.
payload = {"protein_df": protein_df, "embedding_list": embedding_list}
output_path = f"{mhc_type}_esm_embeds.pkl"
with open(output_path, 'wb') as fh:
    pickle.dump(payload, fh, protocol=pickle.HIGHEST_PROTOCOL)