In [1]:
import torch
import esm
import pandas as pd

# --- Model ---------------------------------------------------------------
# Pretrained ESM-2 checkpoint (esm2_t33_650M_UR50D) together with the
# alphabet that defines its token vocabulary. The batch converter turns
# (label, sequence) pairs into a padded token tensor.
model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
batch_converter = alphabet.get_batch_converter()
model.eval()  # evaluation mode: dropout off, so repeated runs give the same embeddings

# --- Test data -----------------------------------------------------------
# The file carries a .csv suffix but is read as tab-separated (sep="\t").
# `dna` holds ids such as "1_1534010_CAGA_C" (presumably chrom_pos_ref_alt
# -- TODO confirm); `Protein_Sequence` holds the corresponding protein.
data_test_file = "./dataset/protein/test.csv"
data_test_ori = pd.read_csv(data_test_file, sep="\t")

# Narrow the frame to the id and sequence columns and show it.
data_test = data_test_ori[["dna", "Protein_Sequence"]]
data_test

  from .autonotebook import tqdm as notebook_tqdm


Unnamed: 0,dna,Protein_Sequence
0,1_1534010_CAGA_C,MSWLFGINKGPKGEGAGPPPPLPPAQPGAEGGGDRGLGDRPAPKDK...
1,1_8361034_GGGGATGTGGCGA_G,MFKPVKEEDDGLSGKHSMRTRRSRGSMSTLRSGRKKQPASPDGRTS...
2,1_12001402_T_TGCG,MSLLFSRCNSIVTVKKNKRHMAEVNASPLKHFVTAKKKINGIFEQL...
3,1_17024026_G_GGCAGCTGGTGCT,MAAVVALSLRRRLPATTLGGACLQASRGAQTAAATAPRIKKFAIYR...
4,1_21577482_TGCTGCACGGCGTCCACGA_T,MPWSFRSSTPTWLRMSSCSWEMTYNTNAQVPDSAGTATAYLCGVKA...
...,...,...
1255,X_153935318_C_CGGGGGA,MSKGLPARQDMEKERETLQAWKERVGQELDRVVAFWMEHSHDQEHG...
1256,X_153950407_G_GCCAGCTGGGCCGGGGTGGAGC,MASAVSPANLPAVLLQPRWKRVVGWSGPVPRPRHGHRAVAIKELIV...
1257,X_154030493_CGTGGCGGCG_C,MAAAAAAAPSGGGGGGEEERLEEKSEDQDLQGLKDKPLKFKKVKKD...
1258,X_154030630_TGGGGTCCTCGGAGCTCTC_T,MVAGMLGLREEKSEDQDLQGLKDKPLKFKKVKKDKKEEKEGKHEPV...


In [2]:
# Hand-written (label, sequence) batch used to sanity-check tokenization
# and embeddings:
#   * "protein1", "protein2"  -- plain example sequences
#   * "protein2 with mask"    -- protein2 with residue 17 (an H) replaced by
#                                the <mask> token
#   * "protein3"              -- short sequence written with spaces between
#                                residues (NOTE(review): presumably the ESM
#                                tokenizer ignores the whitespace -- confirm)
data = [
    (
        "protein1",
        "MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLAGG",
    ),
    (
        "protein2",
        "KALTARQQEVFDLIRDHISQTGMPPTRAEIAQRLGFRSPNAAEEHLKALARKGVIEIVSGASRGIRLLQEE",
    ),
    (
        "protein2 with mask",
        "KALTARQQEVFDLIRD<mask>ISQTGMPPTRAEIAQRLGFRSPNAAEEHLKALARKGVIEIVSGASRGIRLLQEE",
    ),
    (
        "protein3",
        "K A <mask> I S Q",
    ),
]

# The converter returns the labels, the raw strings and a token tensor with
# one row per sequence: BOS, residue tokens, EOS, then padding up to the
# longest row.
batch_labels, batch_strs, batch_tokens = batch_converter(data)
batch_tokens

tensor([[ 0, 20, 15, 11,  7, 10, 16,  9, 10,  4, 15,  8, 12,  7, 10, 12,  4,  9,
         10,  8, 15,  9, 14,  7,  8,  6,  5, 16,  4,  5,  9,  9,  4,  8,  7,  8,
         10, 16,  7, 12,  7, 16, 13, 12,  5, 19,  4, 10,  8,  4,  6, 19, 17, 12,
          7,  5, 11, 14, 10,  6, 19,  7,  4,  5,  6,  6,  2,  1,  1,  1,  1,  1,
          1],
        [ 0, 15,  5,  4, 11,  5, 10, 16, 16,  9,  7, 18, 13,  4, 12, 10, 13, 21,
         12,  8, 16, 11,  6, 20, 14, 14, 11, 10,  5,  9, 12,  5, 16, 10,  4,  6,
         18, 10,  8, 14, 17,  5,  5,  9,  9, 21,  4, 15,  5,  4,  5, 10, 15,  6,
          7, 12,  9, 12,  7,  8,  6,  5,  8, 10,  6, 12, 10,  4,  4, 16,  9,  9,
          2],
        [ 0, 15,  5,  4, 11,  5, 10, 16, 16,  9,  7, 18, 13,  4, 12, 10, 13, 32,
         12,  8, 16, 11,  6, 20, 14, 14, 11, 10,  5,  9, 12,  5, 16, 10,  4,  6,
         18, 10,  8, 14, 17,  5,  5,  9,  9, 21,  4, 15,  5,  4,  5, 10, 15,  6,
          7, 12,  9, 12,  7,  8,  6,  5,  8, 10,  6, 12, 10,  4,  4, 16,  9,  9,


In [3]:
# Embed the toy batch with ESM-2 and mean-pool residue embeddings per sequence.
#
# Layer whose hidden states are pooled: the model's last layer. For
# esm2_t33_650M_UR50D, `model.num_layers` is 33, the value this cell used to
# hard-code. Reading it from the model keeps the cell consistent with
# whichever checkpoint is loaded in cell 1.
repr_layer = model.num_layers

# Non-padding tokens per sequence, i.e. BOS + residues + EOS.
batch_lens = (batch_tokens != alphabet.padding_idx).sum(1)
# A row with only BOS/EOS has no residues; averaging an empty slice would
# silently produce a NaN embedding, so fail loudly instead.
assert (batch_lens > 2).all(), "every sequence must contain at least one residue"

# Forward pass (on CPU); no_grad skips building the autograd graph.
# return_contacts=True additionally fills results["contacts"] with the
# model's attention-based contact predictions (not used in this cell).
with torch.no_grad():
    results = model(batch_tokens, repr_layers=[repr_layer], return_contacts=True)
token_representations = results["representations"][repr_layer]

# Per-sequence embedding = mean over residue positions only.
# Token 0 is always the beginning-of-sequence token and token tokens_len - 1
# is the end-of-sequence token, so the slice 1 : tokens_len - 1 keeps just
# the residues; padding lies beyond tokens_len and is never included.
# A <mask> token inside a sequence is averaged like any other position.
sequence_representations = [
    token_representations[i, 1 : tokens_len - 1].mean(0)
    for i, tokens_len in enumerate(batch_lens)
]

In [4]:
# Show the pooled embeddings: a list with one 1-D tensor per sequence in the
# batch. In the rendered output, protein2 and its single-<mask> variant
# (entries 2 and 3) come out close to each other.
sequence_representations

[tensor([ 0.0614, -0.0687,  0.0430,  ..., -0.1642, -0.0678,  0.0446]),
 tensor([ 0.0553, -0.0757,  0.0414,  ..., -0.3117, -0.0026,  0.1683]),
 tensor([ 0.0618, -0.0769,  0.0405,  ..., -0.3037, -0.0013,  0.1741]),
 tensor([ 0.0084,  0.1425,  0.0506,  ...,  0.0403, -0.1063,  0.0079])]