# LLM features
Here I'll be experimenting with obtaining LLM features. First, I want to create a few simplified functions that take only the sequence as input and output an array with the corresponding embeddings at the amino-acid (AA) level. I'll try a few LLMs such as ESM, protBERT and protT5. I will start with ESM.

## ESM
First we need to install ESM through `pip`. The package is called `fair-esm` on PyPI, so the command is `pip install fair-esm` (not `pip install esm`). Note that you can run this on the worker, which will install a cached version.

In [3]:
!pip install fair-esm

Looking in indexes: https://soft-proxy.scicore.unibas.ch/repository/python-all/simple
Collecting fair-esm
  Using cached https://soft-proxy.scicore.unibas.ch/repository/python-all/packages/fair-esm/2.0.0/fair_esm-2.0.0-py3-none-any.whl (93 kB)
Installing collected packages: fair-esm
Successfully installed fair-esm-2.0.0


Next, load the model. The weights have to be downloaded the first time they are used, so run this on `login-transfer` first to download them.
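Before running the load cell on a node without internet access, it can help to check whether the weights are already cached. The sketch below is not part of `fair-esm`; `torch_hub_checkpoint_dir` and `missing_checkpoints` are hypothetical helpers. It assumes that `fair-esm` downloads through `torch.hub`, that PyTorch uses its default cache layout (`$TORCH_HOME/hub/checkpoints`, falling back to `~/.cache/torch/hub/checkpoints`), and that the two files are named as shown. Check the names against your own cache directory if they differ.

```python
import os
from pathlib import Path


def torch_hub_checkpoint_dir() -> Path:
    """Guess the directory where torch.hub stores downloaded checkpoints.

    Assumption: default PyTorch layout, honouring TORCH_HOME and XDG_CACHE_HOME.
    """
    torch_home = os.environ.get("TORCH_HOME")
    if torch_home is None:
        cache_home = os.environ.get("XDG_CACHE_HOME", "~/.cache")
        torch_home = os.path.join(cache_home, "torch")
    return Path(torch_home).expanduser() / "hub" / "checkpoints"


def missing_checkpoints(model_name: str = "esm2_t33_650M_UR50D") -> list:
    """Return the expected checkpoint files that are not in the cache yet.

    Assumption: the model weights and the contact-regression weights are
    stored as two separate .pt files named after the model.
    """
    expected = [f"{model_name}.pt", f"{model_name}-contact-regression.pt"]
    ckpt_dir = torch_hub_checkpoint_dir()
    return [name for name in expected if not (ckpt_dir / name).is_file()]


# An empty list means both files are already cached.
print(missing_checkpoints())
```

If the list is not empty, run the load cell once on `login-transfer` and check again.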

In [32]:
import torch
import esm

# Load ESM-2 model
model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
batch_converter = alphabet.get_batch_converter()
model.eval()  # disables dropout for deterministic results

ESM2(
  (embed_tokens): Embedding(33, 1280, padding_idx=1)
  (layers): ModuleList(
    (0-32): 33 x TransformerLayer(
      (self_attn): MultiheadAttention(
        (k_proj): Linear(in_features=1280, out_features=1280, bias=True)
        (v_proj): Linear(in_features=1280, out_features=1280, bias=True)
        (q_proj): Linear(in_features=1280, out_features=1280, bias=True)
        (out_proj): Linear(in_features=1280, out_features=1280, bias=True)
        (rot_emb): RotaryEmbedding()
      )
      (self_attn_layer_norm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
      (fc1): Linear(in_features=1280, out_features=5120, bias=True)
      (fc2): Linear(in_features=5120, out_features=1280, bias=True)
      (final_layer_norm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
    )
  )
  (contact_head): ContactPredictionHead(
    (regression): Linear(in_features=660, out_features=1, bias=True)
    (activation): Sigmoid()
  )
  (emb_layer_norm_after): LayerNorm((1280,), eps=1

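The data is stored as a list of tuples with the structure `[(seq_name, sequence)]`. `batch_converter()` then tokenizes the sequences and pads them into a single batch tensor. After that we calculate the length of each sequence as its number of non-padding tokens, which includes the start and end tokens added by the tokenizer.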

In [33]:
# Prepare data (first 2 sequences from ESMStructuralSplitDataset superfamily / 4)
data = [
    ("protein1", "MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLAGG"),
    ("protein2", "KALTARQQEVFDLIRDHISQTGMPPTRAEIAQRLGFRSPNAAEEHLKALARKGVIEIVSGASRGIRLLQEE"),
    ("protein2 with mask","KALTARQQEVFDLIRD<mask>ISQTGMPPTRAEIAQRLGFRSPNAAEEHLKALARKGVIEIVSGASRGIRLLQEE"),
    ("protein3",  "K A <mask> I S Q"),
]
batch_labels, batch_strs, batch_tokens = batch_converter(data)
batch_lens = (batch_tokens != alphabet.padding_idx).sum(1)

Finally, the following obtains the embeddings:

In [34]:
# Extract per-residue representations (on CPU)
with torch.no_grad():
    results = model(batch_tokens, repr_layers=[33], return_contacts=True)
token_representations = results["representations"][33]

Which gives us a tensor of `proteins x max_len x embedding size`. Here `max_len` is 73: the longest sequence has 71 residues, plus one start and one end token.

In [8]:
token_representations.shape

torch.Size([4, 73, 1280])
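Since the goal is embeddings at the AA level, the special tokens and the padding have to be removed before using the rows. A minimal sketch of the index arithmetic, in plain Python so it runs without the model: `residue_slices` is a hypothetical helper, and the toy token ids are made up apart from `padding_idx=1`, which matches the model printout above.

```python
def residue_slices(token_rows, padding_idx):
    """For each tokenized row, return the slice that covers real residues only.

    Each row is assumed to look like: start token, residues, end token, padding.
    """
    slices = []
    for row in token_rows:
        # Non-padding tokens = start token + residues + end token
        n_tokens = sum(1 for tok in row if tok != padding_idx)
        # Skip position 0 (start token) and the last non-padding position (end token)
        slices.append(slice(1, n_tokens - 1))
    return slices


# Toy batch: 0 = start token, 2 = end token, 1 = padding, others = residues
toy_tokens = [
    [0, 5, 6, 7, 2],  # 3 residues, no padding
    [0, 5, 2, 1, 1],  # 1 residue, 2 padding positions
]
for sl in residue_slices(toy_tokens, padding_idx=1):
    print(sl.start, sl.stop)
```

With the tensors from the cells above, this would be used as `residue_slices(batch_tokens.tolist(), alphabet.padding_idx)`, and `token_representations[i, sl]` then has one row per residue of protein `i`.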

Pretty simple, right? Let's turn this into a function.

In [29]:
from typing import Union
import numpy as np

def get_esm_embedding(sequence: Union[str, list, tuple]) -> torch.Tensor:
    """Return last-layer ESM-2 embeddings, shape (n_proteins, max_len, 1280)."""

    # Load the model (note: this reloads the weights on every call)
    esm_model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
    batch_converter = alphabet.get_batch_converter()
    esm_model.eval()  # disables dropout for deterministic results

    # Convert input to be [(name, sequence), ...]
    # Can be a single string, a (name, sequence) tuple, a list of strings,
    # or a list of (name, sequence) tuples.
    if isinstance(sequence, str):
        data = [('protein0', sequence)]
    elif isinstance(sequence, tuple) and len(sequence) == 2:
        data = [sequence]
    elif isinstance(sequence, list) and all(isinstance(s, str) for s in sequence):
        data = [(f'protein{i}', s) for i, s in enumerate(sequence)]
    elif isinstance(sequence, list) and all(
        isinstance(s, tuple) and len(s) == 2 for s in sequence
    ):
        data = sequence
    else:
        raise TypeError("Expected a str, a (name, sequence) tuple, or a list of either.")

    # Tokenize and pad the batch
    batch_labels, batch_strs, batch_tokens = batch_converter(data)

    # Extract per-residue representations (on CPU).
    # Use the model loaded in this function, not the global `model`.
    with torch.no_grad():
        results = esm_model(batch_tokens, repr_layers=[33], return_contacts=True)
    token_representations = results["representations"][33]

    return token_representations

get_esm_embedding(('protein0', 'MMMMMMMMM'))

tensor([[[ 8.4402e-02,  3.1896e-02, -2.2837e-04,  ..., -3.0202e-01,
           1.6993e-01, -9.5368e-02],
         [ 5.3387e-02, -1.4939e-01, -6.0084e-03,  ..., -1.4319e-01,
           5.8913e-02,  4.5721e-02],
         [ 5.7838e-02, -1.7693e-01,  2.1494e-01,  ..., -1.3017e-01,
           1.6030e-01, -2.4029e-02],
         ...,
         [ 1.1289e-01, -1.0321e-01,  2.4431e-01,  ..., -7.9042e-02,
           7.4088e-02,  4.6810e-02],
         [ 1.6301e-01, -1.2352e-01,  2.2617e-01,  ..., -1.2443e-01,
           5.8596e-02,  1.2715e-02],
         [ 1.6387e-01, -9.8342e-02,  1.2440e-01,  ..., -2.0597e-01,
           1.6841e-01, -1.1368e-01]]])
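The input handling is the part most likely to break, and it can be checked without loading the 650M model. Below is a sketch of that step pulled out into its own function; `normalize_sequences` is a hypothetical name and is not part of `fair-esm`. One assumption to be aware of: a tuple of exactly two strings is always read as `(name, sequence)`, never as two unnamed sequences.

```python
from typing import List, Tuple


def normalize_sequences(sequence) -> List[Tuple[str, str]]:
    """Convert the accepted input shapes into [(name, sequence), ...]."""
    # A single bare sequence
    if isinstance(sequence, str):
        return [("protein0", sequence)]
    # A single (name, sequence) pair
    if (
        isinstance(sequence, tuple)
        and len(sequence) == 2
        and all(isinstance(part, str) for part in sequence)
    ):
        return [sequence]
    # A list whose items are bare sequences or (name, sequence) pairs
    data = []
    for i, item in enumerate(sequence):
        if isinstance(item, str):
            data.append((f"protein{i}", item))
        elif isinstance(item, tuple) and len(item) == 2:
            data.append((item[0], item[1]))
        else:
            raise TypeError(f"Item {i} is neither a str nor a (name, sequence) tuple")
    return data


print(normalize_sequences("MKTV"))
print(normalize_sequences(["MKTV", ("named", "KALT")]))
```

Unlike the `if`/`elif` chain, this version also accepts lists that mix bare strings and named tuples.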

Perfect, that works! Now let's move on to protBERT.

## protBERT