# LLM features
Here I'll experiment with obtaining features from protein language models (LLMs). First, I want to create a few simplified functions that take only the sequence as input and output an array with the corresponding embeddings at the amino acid (AA) level. I'll try a few models such as ESM, ProtBERT and ProtT5, starting with ESM.

## ESM
First we need to install ESM through `pip`. The package is called `fair-esm`, so the command is `pip install fair-esm`. Note that you can run this on the worker, which will install a cached version.

In [1]:
!pip install fair-esm

Looking in indexes: https://soft-proxy.scicore.unibas.ch/repository/python-all/simple


Next, load the model. The weights have to be downloaded first, so run this cell once on `login-transfer` to fetch them.
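Before loading the model on a worker, it can help to confirm that the checkpoint is already in the torch hub cache. The sketch below assumes that the weights are stored as `<model_name>.pt` in the `checkpoints` folder of the torch hub directory (by default `~/.cache/torch/hub`, which `TORCH_HOME` or `XDG_CACHE_HOME` can override). The helper name `expected_checkpoint_path` is hypothetical and not part of any library.

```python
import os
from pathlib import Path


def expected_checkpoint_path(model_name: str) -> Path:
    """Guess where torch hub would cache the weights for `model_name`.

    Assumption: weights live at <torch home>/hub/checkpoints/<model_name>.pt,
    where <torch home> is $TORCH_HOME, else $XDG_CACHE_HOME/torch,
    else ~/.cache/torch.
    """
    xdg_cache = os.environ.get("XDG_CACHE_HOME", str(Path.home() / ".cache"))
    torch_home = os.environ.get("TORCH_HOME", str(Path(xdg_cache) / "torch"))
    return Path(torch_home) / "hub" / "checkpoints" / f"{model_name}.pt"


checkpoint = expected_checkpoint_path("esm2_t33_650M_UR50D")
print(checkpoint, "found" if checkpoint.exists() else "missing")
```

If the file is reported as missing, the model most likely still needs to be downloaded on `login-transfer`.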

In [2]:
import torch
import esm

# Load ESM-2 model
model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
batch_converter = alphabet.get_batch_converter()
model.eval()  # disables dropout for deterministic results

ESM2(
  (embed_tokens): Embedding(33, 1280, padding_idx=1)
  (layers): ModuleList(
    (0-32): 33 x TransformerLayer(
      (self_attn): MultiheadAttention(
        (k_proj): Linear(in_features=1280, out_features=1280, bias=True)
        (v_proj): Linear(in_features=1280, out_features=1280, bias=True)
        (q_proj): Linear(in_features=1280, out_features=1280, bias=True)
        (out_proj): Linear(in_features=1280, out_features=1280, bias=True)
        (rot_emb): RotaryEmbedding()
      )
      (self_attn_layer_norm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
      (fc1): Linear(in_features=1280, out_features=5120, bias=True)
      (fc2): Linear(in_features=5120, out_features=1280, bias=True)
      (final_layer_norm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
    )
  )
  (contact_head): ContactPredictionHead(
    (regression): Linear(in_features=660, out_features=1, bias=True)
    (activation): Sigmoid()
  )
  (emb_layer_norm_after): LayerNorm((1280,), eps=1

The data is stored as a list of tuples with the structure `[(seq_name, sequence)]`. `batch_converter()` tokenizes the sequences and pads them into a single batch tensor, and `batch_lens` then counts the non-padding tokens of each sequence.

In [3]:
# Prepare data (first 2 sequences from ESMStructuralSplitDataset superfamily / 4)
data = [
    ("protein1", "MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLAGG"),
    ("protein2", "KALTARQQEVFDLIRDHISQTGMPPTRAEIAQRLGFRSPNAAEEHLKALARKGVIEIVSGASRGIRLLQEE"),
    ("protein2 with mask","KALTARQQEVFDLIRD<mask>ISQTGMPPTRAEIAQRLGFRSPNAAEEHLKALARKGVIEIVSGASRGIRLLQEE"),
    ("protein3",  "K A <mask> I S Q"),
]
batch_labels, batch_strs, batch_tokens = batch_converter(data)
batch_lens = (batch_tokens != alphabet.padding_idx).sum(1)
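
To make the role of padding concrete, here is a toy, pure-Python version of what a batch converter does conceptually. It is not ESM's implementation: the vocabulary and the helper `toy_batch_converter` are made up for illustration, and `<mask>` tokens are not handled. Each sequence gets a start and an end token and is right-padded to the longest one, so counting the non-padding tokens gives the sequence length plus two.

```python
# Hypothetical toy vocabulary; the real ESM token ids differ.
PAD, BOS, EOS = 0, 1, 2
VOCAB = {aa: i + 3 for i, aa in enumerate("ACDEFGHIKLMNPQRSTVWY")}


def toy_batch_converter(data):
    """Tokenize [(name, sequence), ...] and right-pad to a common length."""
    labels = [name for name, _ in data]
    tokens = [[BOS] + [VOCAB[aa] for aa in seq] + [EOS] for _, seq in data]
    max_len = max(len(t) for t in tokens)
    padded = [t + [PAD] * (max_len - len(t)) for t in tokens]
    return labels, padded


labels, batch = toy_batch_converter([("p1", "MKTV"), ("p2", "KA")])
# Same idea as (batch_tokens != alphabet.padding_idx).sum(1)
batch_lens = [sum(tok != PAD for tok in row) for row in batch]
print(batch_lens)  # [6, 4]
```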

Finally, the following obtains the embeddings from the last layer (layer 33):

In [4]:
# Extract per-residue representations (on CPU)
with torch.no_grad():
    results = model(batch_tokens, repr_layers=[33], return_contacts=True)
token_representations = results["representations"][33]

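This gives us a tensor of shape `proteins x max_len x embedding_size`. Here `max_len` is 73 because the longest sequence has 71 residues and ESM adds a start and an end token.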

In [5]:
token_representations.shape

torch.Size([4, 73, 1280])
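
To keep only the residue vectors of each protein, every row can be sliced with its token count, which is the slicing idea used in the ESM README. The sketch below assumes one start token, one end token and right-padding. The helper name `strip_special_tokens` is hypothetical. It uses plain indexing, so it should work on the tensor above as well as on nested lists, which the toy example uses to stay dependency-free.

```python
def strip_special_tokens(token_reprs, batch_lens):
    """Return one per-residue slice per protein, without BOS/EOS/padding.

    Assumes position 0 is the start token and position n-1 the end token,
    where n is the number of non-padding tokens of that protein.
    """
    return [token_reprs[i][1:int(n) - 1] for i, n in enumerate(batch_lens)]


# Toy stand-in for a (proteins x max_len x embedding) tensor, embedding size 1
toy_reprs = [
    [[0.0], [1.0], [2.0], [3.0], [9.0]],    # BOS, 3 residues, EOS
    [[0.0], [5.0], [9.0], [-1.0], [-1.0]],  # BOS, 1 residue, EOS, 2 pads
]
per_residue = strip_special_tokens(toy_reprs, [5, 3])
print([len(p) for p in per_residue])  # [3, 1]
```

With the real data, the call would be `strip_special_tokens(token_representations, batch_lens)`.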

Pretty simple, right? Let's turn this into a function.

In [6]:
from typing import Union

import torch
import esm

def get_esm_embedding(sequence: Union[str, list, tuple]) -> torch.Tensor:
    """Return ESM-2 per-token embeddings (proteins x max_len x 1280).

    Accepts a single sequence string, a list of sequence strings, a single
    (name, sequence) tuple, or a list of (name, sequence) tuples.
    """
    global esm_model, alphabet, batch_converter

    # Load the model only on the first call and cache it in globals
    try:
        esm_model
    except NameError:
        esm_model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
        batch_converter = alphabet.get_batch_converter()
        esm_model.eval()  # disables dropout for deterministic results

    # Convert input to [(name, sequence), ...]
    if isinstance(sequence, str):
        data = [('protein0', sequence)]
    elif isinstance(sequence, tuple) and len(sequence) == 2:
        data = [sequence]
    elif isinstance(sequence, list) and all(isinstance(s, str) for s in sequence):
        data = [(f'protein{i}', s) for i, s in enumerate(sequence)]
    elif isinstance(sequence, list) and all(isinstance(s, tuple) and len(s) == 2 for s in sequence):
        data = sequence
    else:
        raise TypeError('Expected a sequence string, a (name, sequence) tuple, or a list of either')

    # Tokenize and pad the batch
    batch_labels, batch_strs, batch_tokens = batch_converter(data)

    # Extract per-residue representations (on CPU); contacts are not needed here
    with torch.no_grad():
        results = esm_model(batch_tokens, repr_layers=[33])
    token_representations = results["representations"][33]

    return token_representations

get_esm_embedding(('protein0', 'MMMMMMMMM'))

tensor([[[ 8.4402e-02,  3.1896e-02, -2.2837e-04,  ..., -3.0202e-01,
           1.6993e-01, -9.5368e-02],
         [ 5.3387e-02, -1.4939e-01, -6.0084e-03,  ..., -1.4319e-01,
           5.8913e-02,  4.5721e-02],
         [ 5.7838e-02, -1.7693e-01,  2.1494e-01,  ..., -1.3017e-01,
           1.6030e-01, -2.4029e-02],
         ...,
         [ 1.1289e-01, -1.0321e-01,  2.4431e-01,  ..., -7.9042e-02,
           7.4088e-02,  4.6810e-02],
         [ 1.6301e-01, -1.2352e-01,  2.2617e-01,  ..., -1.2443e-01,
           5.8596e-02,  1.2715e-02],
         [ 1.6387e-01, -9.8342e-02,  1.2440e-01,  ..., -2.0597e-01,
           1.6841e-01, -1.1368e-01]]])

Perfect, that works! Now let's move on to ProtBERT.

## protBERT

Just like with ESM, we first need to install the library, in this case `transformers`. This time we can do it with `mamba`.

In [7]:
!mamba install -c conda-forge transformers -y


Looking for: ['transformers']

The Hugging Face model card gives the following code to obtain the internal representations (note that you need to download the model first on `login-transfer`).

In [8]:
from transformers import BertModel, BertTokenizer
import re
tokenizer = BertTokenizer.from_pretrained("Rostlab/prot_bert", do_lower_case=False )
model = BertModel.from_pretrained("Rostlab/prot_bert")
sequence_Example = "A E T C Z A O"
sequence_Example = re.sub(r"[UZOB]", "X", sequence_Example)
encoded_input = tokenizer(sequence_Example, return_tensors='pt')
output = model(**encoded_input)


  from .autonotebook import tqdm as notebook_tqdm
Some weights of the model checkpoint at Rostlab/prot_bert were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.seq_relationship.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


This is pretty simple. It first loads the tokenizer and the model, replaces the rare amino acids U, Z, O and B with X, and then tokenizes the sequence. The model is then run on the encoded input, and `output` holds its hidden states.
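
The ProtBERT tokenizer expects residues separated by spaces, which is why the example sequence is written as `A E T C Z A O`. A small helper can turn a plain sequence string into that format. This is only a sketch: `prepare_protbert_sequence` is a hypothetical name, and the whitespace removal and upper-casing are added assumptions that are not part of the model card code.

```python
import re


def prepare_protbert_sequence(sequence: str) -> str:
    """Format a protein sequence for the ProtBERT tokenizer."""
    # Drop any whitespace and upper-case (assumption: input may be messy)
    residues = re.sub(r"\s+", "", sequence).upper()
    # Map the rare amino acids U, Z, O and B to X, as in the example above
    residues = re.sub(r"[UZOB]", "X", residues)
    # ProtBERT expects one space between residues
    return " ".join(residues)


print(prepare_protbert_sequence("AETCZAO"))  # A E T C X A X
```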

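The output should have shape `protein x max_len x embedding_size`. Here `max_len` is 9 because the tokenizer adds `[CLS]` and `[SEP]` around the 7 residues.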

In [9]:
print(output.last_hidden_state)
print(output.last_hidden_state.shape)

tensor([[[ 0.0454,  0.1140, -0.0117,  ..., -0.0875, -0.1143,  0.0204],
         [ 0.0923,  0.1391, -0.0524,  ..., -0.1395, -0.0428,  0.0743],
         [ 0.1151,  0.0200, -0.0863,  ..., -0.0095, -0.1873,  0.1317],
         ...,
         [ 0.1079,  0.0977, -0.0583,  ..., -0.1277, -0.0649,  0.1289],
         [ 0.0546,  0.0364, -0.0782,  ..., -0.0302, -0.0602,  0.0890],
         [ 0.0515,  0.0571, -0.0693,  ..., -0.0394, -0.0663,  0.0977]]],
       grad_fn=<NativeLayerNormBackward0>)
torch.Size([1, 9, 1024])


In [10]:
def get_protbert_embeddings(sequence: str) -> torch.Tensor:
    """Return ProtBERT per-token embeddings for one space-separated sequence."""
    global tokenizer, protbert_model

    # Load the tokenizer and model only once and cache them in globals
    try:
        tokenizer, protbert_model
    except NameError:
        tokenizer = BertTokenizer.from_pretrained("Rostlab/prot_bert", do_lower_case=False)
        protbert_model = BertModel.from_pretrained("Rostlab/prot_bert")

    # Use the function argument here (the earlier version used sequence_Example by mistake)
    sequence = re.sub(r"[UZOB]", "X", sequence)
    encoded_input = tokenizer(sequence, return_tensors='pt')
    output = protbert_model(**encoded_input)

    return output.last_hidden_state

# Same example sequence as before, so the output matches the previous cell
get_protbert_embeddings("A E T C Z A O")

Some weights of the model checkpoint at Rostlab/prot_bert were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.seq_relationship.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


tensor([[[ 0.0454,  0.1140, -0.0117,  ..., -0.0875, -0.1143,  0.0204],
         [ 0.0923,  0.1391, -0.0524,  ..., -0.1395, -0.0428,  0.0743],
         [ 0.1151,  0.0200, -0.0863,  ..., -0.0095, -0.1873,  0.1317],
         ...,
         [ 0.1079,  0.0977, -0.0583,  ..., -0.1277, -0.0649,  0.1289],
         [ 0.0546,  0.0364, -0.0782,  ..., -0.0302, -0.0602,  0.0890],
         [ 0.0515,  0.0571, -0.0693,  ..., -0.0394, -0.0663,  0.0977]]],
       grad_fn=<NativeLayerNormBackward0>)

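I first assumed ProtBERT only gives protein-level embeddings, but the `[1, 9, 1024]` output above has one 1024-dimensional vector per token (7 residues plus `[CLS]` and `[SEP]`), so residue-level features can be read from it as well. I'll leave it here for now.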

Now let's move on to ProtT5.

## ProtT5
do this later

## Built-in functions from graphein
It turns out that graphein has built-in functions for getting ESM embeddings at the residue level, and it stores them directly in the graph as well. Let's see how it works.
First we have to construct a graph from a protein structure.

In [11]:
from graphein.protein.config import ProteinGraphConfig
from graphein.protein.graphs import construct_graph

config = ProteinGraphConfig()
g = construct_graph(config=config, path="../structures/1a9m.pdb")
g

<networkx.classes.graph.Graph at 0x2aaecf8415a0>

Now, to obtain the ESM embeddings I use the `graphein.protein.features.sequence.embeddings.esm_residue_embedding` function. I probably need to download the model again...

In [13]:
from graphein.protein.features.sequence.embeddings import esm_residue_embedding

g_m = esm_residue_embedding(g, model_name = 'esm1b_t33_650M_UR50S')

g_m.shape

Using cache found in /scicore/home/schwede/goetze0000/.cache/torch/hub/facebookresearch_esm_main


ImportError: cannot import name 'esmfold_structure_module_only_8M' from 'esm.pretrained' (/scicore/home/schwede/goetze0000/mambaforge/envs/hackathon/lib/python3.10/site-packages/esm/pretrained.py)
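
The traceback suggests that the ESM code found in the torch hub cache expects a name that the installed `esm` package does not provide. That could point to a version mismatch between the cached `facebookresearch/esm` code and the installed `fair-esm`, although this is not confirmed here. A quick way to check is to look up the installed version and test for the missing name. This is only a diagnostic sketch: `describe_package` is a hypothetical helper, and it deliberately catches import failures so that it also runs where the package is absent.

```python
import importlib
from importlib import metadata


def describe_package(dist_name: str, module_name: str, attr: str) -> dict:
    """Report the installed version of a distribution and whether
    `module_name` exposes `attr`. Missing packages yield None / False."""
    try:
        version = metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        version = None
    try:
        module = importlib.import_module(module_name)
        has_attr = hasattr(module, attr)
    except ImportError:
        has_attr = False
    return {"version": version, "has_attr": has_attr}


print(describe_package("fair-esm", "esm.pretrained",
                       "esmfold_structure_module_only_8M"))
```

If `has_attr` is `False`, the installed package is likely older than the hub code expects.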