## ESM2 encoding

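ESM2 (Evolutionary Scale Model) is a **pre-trained protein language model** from Meta AI. Trained on millions of protein sequences, it learns representations that capture evolutionary and structural information. The ESM2 transformer used here is `esm2_t33_650M_UR50D` (33 layers, about 650 million parameters, trained on UniRef50 sequences), which produces a 1280-dimensional representation for each residue.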

Reference: Lin Z, Akin H, Rao R, Hie B, Zhu Z, Lu W, et al. Evolutionary-scale prediction of atomic level protein structure with a language model. bioRxiv (Cold Spring Harbor Laboratory). 2022 Jul 21. Available from: https://doi.org/10.1101/2022.07.20.500902

Documentation can be found at: https://huggingface.co/docs/transformers/en/model_doc/esm and https://github.com/facebookresearch/esm/tree/main
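The checkpoint name encodes the architecture: `t33` is the number of transformer layers, `650M` the approximate parameter count, and `UR50D` the UniRef50-based training set. The layer count matters below, because the final-layer representation is requested with `repr_layers=[33]`. The helper in this sketch, `parse_esm_checkpoint_name`, is hypothetical (it is not part of the `esm` package) and assumes every checkpoint you pick follows the `<family>_t<layers>_<params>_<dataset>` pattern.

```python
import re
from typing import NamedTuple


class ESMCheckpointInfo(NamedTuple):
    family: str        # e.g. "esm2"
    num_layers: int    # number of transformer layers, e.g. 33
    num_params: str    # approximate parameter count, e.g. "650M"
    training_set: str  # training dataset tag, e.g. "UR50D"


def parse_esm_checkpoint_name(name: str) -> ESMCheckpointInfo:
    """Split a name such as 'esm2_t33_650M_UR50D' into its parts.

    Hypothetical helper: assumes the '<family>_t<layers>_<params>_<dataset>'
    naming pattern and raises ValueError for anything else.
    """
    match = re.fullmatch(r"(esm\w*?)_t(\d+)_(\d+[MB])_(\w+)", name)
    if match is None:
        raise ValueError(f"Unrecognised ESM checkpoint name: {name!r}")
    family, layers, params, dataset = match.groups()
    return ESMCheckpointInfo(family, int(layers), params, dataset)


info = parse_esm_checkpoint_name("esm2_t33_650M_UR50D")
print(info.num_layers)  # 33: the index of the final layer to pass in repr_layers
```

Once the model is loaded, its `num_layers` attribute should report the same value, which is a safer source than string parsing if you switch checkpoints.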


In [None]:
import pandas as pd
import numpy as np
import torch

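# Load the pre-trained ESM2 model and its alphabet (tokenizer) from PyTorch Hub
model, alphabet = torch.hub.load("facebookresearch/esm:main", "esm2_t33_650M_UR50D")  # choose model
batch_converter = alphabet.get_batch_converter()  # turns (label, sequence) tuples into padded token tensors
model.eval()  # disable dropout for deterministic inference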

In [None]:
'''
# Alternative to torch.hub: requires the `esm` package (install it with `pip install fair-esm`)
import esm

# Load the ESM2 model and its alphabet directly
model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()  # choose model
batch_converter = alphabet.get_batch_converter()
model.eval()
'''

In [None]:
# Load the dataset (it must contain a 'sequence' column)
protein_sequences_file = 'Example_Data.csv'  # CSV file path
df = pd.read_csv(protein_sequences_file)

# Prepare (label, sequence) tuples, the input format expected by batch_converter
sequences = [(f"seq_{i}", seq) for i, seq in enumerate(df['sequence'])]
attention_maps = {}  # label -> (sequence, predicted contact map)
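Before the embedding loop, it helps to see what `batch_converter` does with these `(label, sequence)` tuples: each sequence gets a beginning-of-sequence token and an end-of-sequence token, and shorter sequences are padded to the length of the longest one in the batch. The sketch below is a hypothetical pure-Python imitation; `toy_batch_convert` and `TOY_VOCAB` are invented for illustration, the amino-acid indices are not the real ESM alphabet, and only the padding index 1 mirrors the value the loop relies on. It shows why the number of non-padding tokens in a row equals the sequence length plus 2.

```python
from typing import Dict, List, Tuple

# Toy special-token indices; padding uses 1, matching the check `ne(1)` in the loop
CLS, PAD, EOS = 0, 1, 2
# Invented residue indices for illustration only (not the real ESM alphabet)
TOY_VOCAB: Dict[str, int] = {aa: i + 4 for i, aa in enumerate("ACDEFGHIKLMNPQRSTVWY")}


def toy_batch_convert(batch: List[Tuple[str, str]]) -> Tuple[List[str], List[List[int]]]:
    """Mimic the layout of ESM batch conversion: <cls> + residues + <eos>, right-padded."""
    labels = [label for label, _ in batch]
    max_len = max(len(seq) for _, seq in batch) + 2  # +2 for the BOS and EOS tokens
    tokens = []
    for _, seq in batch:
        row = [CLS] + [TOY_VOCAB[aa] for aa in seq] + [EOS]
        row += [PAD] * (max_len - len(row))  # pad shorter sequences to the batch maximum
        tokens.append(row)
    return labels, tokens


labels, tokens = toy_batch_convert([("seq_0", "MKT"), ("seq_1", "MA")])
lengths = [sum(t != PAD for t in row) for row in tokens]  # analogous to batch_tokens.ne(1).sum(1)
print(tokens)   # [[0, 14, 12, 20, 2], [0, 14, 4, 2, 1]]
print(lengths)  # [5, 4]
```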

In [None]:
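# Generate one embedding per sequence: mean of the final-layer representations over residue positions
batch_size = 10
repr_layer = 33  # final (33rd) layer of esm2_t33_650M_UR50D

embeddings = []
for i in range(0, len(sequences), batch_size):  # process the sequences in batches
    batch_data = sequences[i:i + batch_size]
    batch_labels, batch_strs, batch_tokens = batch_converter(batch_data)
    # Real (non-padding) tokens per sequence: BOS + residues + EOS
    batch_lens = (batch_tokens != alphabet.padding_idx).sum(1)

    with torch.no_grad():
        results = model(batch_tokens, repr_layers=[repr_layer], return_contacts=True)
    token_representations = results["representations"][repr_layer]

    # Average over residue positions only, skipping BOS (index 0), EOS and any padding
    for j, tokens_len in enumerate(batch_lens):
        mean_embedding = token_representations[j, 1:tokens_len - 1].mean(0).cpu().numpy()
        embeddings.append(mean_embedding)

    # Predicted contact maps already exclude BOS/EOS, so keep the first (tokens_len - 2) positions
    for (label, seq), tokens_len, attention_contacts in zip(batch_data, batch_lens, results["contacts"]):
        n_res = int(tokens_len) - 2
        attention_maps[label] = (seq, attention_contacts[:n_res, :n_res].cpu().numpy())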
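The loop above averages the token representations of each sequence into a single vector. In a padded batch, each row of `results["representations"][33]` also holds the BOS and EOS tokens and, for sequences shorter than the longest in the batch, padding positions; averaging the whole row lets those positions, and therefore the batch composition, shift the embedding. Restricting the mean to residue positions (indices 1 through `tokens_len - 2`) avoids this. The NumPy sketch below uses a made-up array in place of real ESM2 output, and `mean_pool_residues` is a hypothetical helper.

```python
import numpy as np


def mean_pool_residues(token_reprs: np.ndarray, tokens_len: int) -> np.ndarray:
    """Average per-token vectors over residue positions only.

    token_reprs: (padded_length, embed_dim) array for one sequence, laid out as
    [BOS, residue_1 .. residue_n, EOS, PAD, ...]; tokens_len counts BOS + residues + EOS.
    """
    return token_reprs[1:tokens_len - 1].mean(axis=0)


# Synthetic example: 2 residues padded to length 6, embed_dim = 2
reprs = np.array([
    [9.0, 9.0],  # BOS
    [1.0, 2.0],  # residue 1
    [3.0, 4.0],  # residue 2
    [9.0, 9.0],  # EOS
    [0.0, 0.0],  # PAD
    [0.0, 0.0],  # PAD
])
print(mean_pool_residues(reprs, tokens_len=4))  # [2. 3.]
print(reprs.mean(axis=0))  # naive mean over every position, pulled away by BOS/EOS/PAD
```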

In [None]:
# Build a DataFrame with one row per sequence and one column per embedding dimension (1280 for esm2_t33_650M_UR50D)
embedding_df = pd.DataFrame(embeddings)
# Optional: uncomment to keep the input columns next to the embeddings (rows are in the same order as df)
# embedding_df['Sequence'] = df['sequence']
# embedding_df['Fitness'] = df['fitness']  # assumes the CSV has a 'fitness' column

# Save the ESM2 embeddings as a CSV file
embedding_df.to_csv('esm2_encoded.csv', index=False)
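The loop also fills `attention_maps`, a dictionary of `label -> (sequence, predicted contact map)`, but nothing in the notebook saves it. One option, sketched here as an assumption rather than part of the original pipeline, is a compressed NumPy archive. `save_contact_maps`, `load_contact_maps`, and the `<label>__contacts` / `<label>__sequence` key scheme are hypothetical names chosen for this example, and the demo uses small made-up arrays instead of real ESM2 output.

```python
import os
import tempfile

import numpy as np


def save_contact_maps(contact_maps: dict, path: str) -> None:
    """Store label -> (sequence, contact_map) pairs in one compressed .npz file."""
    arrays = {}
    for label, (seq, contacts) in contact_maps.items():
        arrays[f"{label}__contacts"] = np.asarray(contacts, dtype=np.float32)
        arrays[f"{label}__sequence"] = np.array(seq)  # 0-d unicode array, no pickling needed
    np.savez_compressed(path, **arrays)


def load_contact_maps(path: str) -> dict:
    """Inverse of save_contact_maps: rebuild the label -> (sequence, contact_map) dict."""
    with np.load(path) as data:
        labels = {key.rsplit("__", 1)[0] for key in data.files}
        return {
            label: (data[f"{label}__sequence"].item(), data[f"{label}__contacts"])
            for label in labels
        }


# Usage with made-up maps standing in for attention_maps
maps = {"seq_0": ("MKT", np.eye(3)), "seq_1": ("MA", np.ones((2, 2)))}
path = os.path.join(tempfile.mkdtemp(), "esm2_contacts.npz")
save_contact_maps(maps, path)
restored = load_contact_maps(path)
print(restored["seq_0"][0], restored["seq_0"][1].shape)  # MKT (3, 3)
```

Keeping the sequence next to each map means a map can be checked against its sequence length after reloading, without going back to the input CSV.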