## ESM2

Reference: https://doi.org/10.1101/2022.07.20.500902

Code: https://huggingface.co/docs/transformers/en/model_doc/esm and https://github.com/facebookresearch/esm/tree/main


In [None]:
import torch
# Load the "esm2_t33_650M_UR50D" entry point from the facebookresearch/esm
# repository (main branch) through torch.hub; the call returns the model and
# its alphabet as a pair.
# NOTE(review): `model` and `alphabet` are assigned again further down via
# esm.pretrained.esm2_t33_650M_UR50D(), so this load looks redundant — confirm
# whether both loading paths are really needed.
model, alphabet = torch.hub.load("facebookresearch/esm:main", "esm2_t33_650M_UR50D")

In [None]:
import torch
import esm

In [None]:
import pandas as pd
import matplotlib.pyplot as plt

In [None]:


# Compute one mean-pooled ESM-2 embedding and one predicted contact map per
# sequence in the input CSV, then save the embeddings to a CSV file.

# Load the ESM-2 model and the batch converter that tokenizes (id, sequence) pairs.
model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
batch_converter = alphabet.get_batch_converter()
model.eval()  # inference mode (disables dropout)

# Load your protein sequence dataset
dataset_path = 'ExampleData.csv'  # Update with your CSV file path
sequences_df = pd.read_csv(dataset_path)

# Prepare a list of tuples with sequence IDs and sequences
sequences = [(f"seq_{i}", seq) for i, seq in enumerate(sequences_df['Sequence'])]
attention_maps = {}

batch_size = 10   # number of sequences per forward pass
repr_layer = 33   # layer whose token representations are pooled

# Generate embeddings for each sequence
embeddings = []
for i in range(0, len(sequences), batch_size):
    batch_data = sequences[i:i + batch_size]
    batch_labels, batch_strs, batch_tokens = batch_converter(batch_data)
    # Count of non-padding tokens per row. For this model the count includes
    # the BOS token prepended and the EOS token appended by the converter.
    batch_lens = (batch_tokens != alphabet.padding_idx).sum(1).tolist()

    with torch.no_grad():
        results = model(batch_tokens, repr_layers=[repr_layer], return_contacts=True)
    token_representations = results["representations"][repr_layer]
    contacts = results["contacts"]

    for row, ((label, seq), tokens_len) in enumerate(zip(batch_data, batch_lens)):
        # Average over the residue positions only: token 0 is BOS, token
        # tokens_len - 1 is EOS, and everything after that is padding.
        # Pooling over the whole row would make the embedding depend on how
        # much padding the longest sequence in the batch forces onto it.
        mean_embedding = token_representations[row, 1:tokens_len - 1].mean(0).cpu().numpy()
        embeddings.append(mean_embedding)

        # The contact head already drops BOS/EOS, so the valid region of the
        # map is (number of residues) x (number of residues).
        seq_len = tokens_len - 2
        attention_maps[label] = (seq, contacts[row, :seq_len, :seq_len].cpu().numpy())

# Create a DataFrame to store embeddings alongside their sequences.
# The sequences are assigned positionally (as a list) so the rows line up with
# `embeddings` regardless of the index carried by `sequences_df`.
embedding_df = pd.DataFrame(embeddings)
embedding_df['Sequence'] = [seq for _, seq in sequences]

# Save the embeddings to a CSV file
output_file = 'esm_embeddings.csv'
embedding_df.to_csv(output_file, index=False)

print(f"ESM embeddings saved successfully to {output_file}.")


In [None]:
# Draw one heat map per entry of `attention_maps`, which maps a label to a
# (sequence, 2-D contact array) pair.
for label, (sequence, attention_contacts) in attention_maps.items():
    # Preview at most the first 30 residues in the title, and only add the
    # ellipsis when the sequence was actually truncated.
    preview = sequence[:30]
    if len(sequence) > 30:
        preview += "..."

    plt.figure(figsize=(6, 6))
    # fignum=0 makes matshow draw into the figure created above instead of
    # opening a new one.
    plt.matshow(attention_contacts, cmap='viridis', fignum=0)
    plt.title(f"Attention Contacts: {label}\n{preview}")
    plt.colorbar(label='Contact Probability')
    plt.xlabel('Residue Position')
    plt.ylabel('Residue Position')
    plt.show()
    

In [None]:
# Read the CSV file "esm_embeddings.csv" from the working directory into a DataFrame.
out=pd.read_csv("esm_embeddings.csv")

In [None]:
# Bare expression: evaluated as the last statement of a notebook cell, this
# shows `out` as the cell's output.
out