## Basic ESM exploration
__Feb 2024__
__Keith Cheveralls__

This starts from the example code in the ESM readme. 

In [2]:
import torch
import esm
import numpy as np

In [12]:
esm.__file__

'/home/keith/projects/esm/esm/__init__.py'

In [8]:
import orfipy

In [9]:
orfipy

<module 'orfipy' from '/home/keith/miniforge3/envs/esm-py311-env/lib/python3.11/site-packages/orfipy/__init__.py'>

In [3]:
# for some reason these import statements must be called explicitly, and in this order
import esm.data
import esm.pretrained

In [4]:
import matplotlib.pyplot as plt

In [5]:
torch.cuda.is_available()

True

In [6]:
model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
batch_converter = alphabet.get_batch_converter()
model.eval()

data = [
    ("protein1", "MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLAGG"),
    ("protein2", "KALTARQQEVFDLIRDHISQTGMPPTRAEIAQRLGFRSPNAAEEHLKALARKGVIEIVSGASRGIRLLQEE"),
]

In [7]:
batch_labels, batch_strs, batch_tokens = batch_converter(data)

# this is the sequence length 
# (in tokens, not amino acids; they appear to be only approx equal)
batch_lens = (batch_tokens != alphabet.padding_idx).sum(1)

In [8]:
# Extract per-residue representations
with torch.no_grad():
    results = model(batch_tokens, repr_layers=[33], return_contacts=True)

token_representations = results["representations"][33]

In [19]:
token_representations.shape

torch.Size([4, 73, 1280])

In [None]:
# Generate per-sequence representations via averaging
# NOTE: token 0 is always a beginning-of-sequence token, so the first residue is token 1.
sequence_representations = []
for i, tokens_len in enumerate(batch_lens):
    sequence_representations.append(token_representations[i, 1 : tokens_len - 1].mean(0))

In [25]:
np.array(sequence_representations).shape

(4, 1280)

In [None]:
# Look at the unsupervised self-attention map contact predictions
for (_, seq), tokens_len, attention_contacts in zip(data, batch_lens, results["contacts"]):
    plt.matshow(attention_contacts[: tokens_len, : tokens_len])
    plt.title(seq)
    plt.show()

In [10]:
x = np.load('../tmp-embed.npy')
x.shape

(913, 320)

In [11]:
x

array([[-0.03010962, -0.3098611 ,  0.21479861, ...,  0.18719758,
         0.13682021,  0.2104919 ],
       [ 0.08378489, -0.23092562,  0.26359186, ...,  0.05840705,
         0.17555174,  0.12873918],
       [-0.05495199, -0.20750748,  0.08342266, ...,  0.15368064,
         0.19967808,  0.07173502],
       ...,
       [-0.01662284, -0.10068058,  0.03329022, ...,  0.2388579 ,
        -0.0273346 , -0.0422569 ],
       [-0.00831649, -0.06697933,  0.08197283, ...,  0.04316371,
         0.06036634,  0.03868119],
       [ 0.04358827, -0.02912711,  0.08511164, ...,  0.00952161,
         0.0476016 ,  0.03955315]], dtype=float32)