In [2]:
!pip install fair-esm==2.0.0

Collecting fair-esm==2.0.0
  Downloading fair_esm-2.0.0-py3-none-any.whl.metadata (37 kB)
Downloading fair_esm-2.0.0-py3-none-any.whl (93 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m93.1/93.1 kB[0m [31m2.6 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: fair-esm
Successfully installed fair-esm-2.0.0


In [6]:
import esm
import torch
import collections
import pandas as pd

def esm_embeddings(peptide_sequence_list, model_name):
    """Compute one fixed-length ESM-2 embedding per peptide sequence.

    All sequences are tokenized into a single batch and run through the
    selected pretrained model on CPU. Each sequence is then summarized as the
    mean of its last-layer per-token representations over positions
    ``1 .. tokens_len - 2``, i.e. skipping the first and last non-padding
    token of each row (the begin/end special tokens under the ESM-2 alphabet).

    Parameters
    ----------
    peptide_sequence_list : iterable of (str, str)
        ``(name, sequence)`` pairs, the input format of the alphabet's batch
        converter.
    model_name : str
        One of ``'esm2_t6_8M_UR50D'``, ``'esm2_t12_35M_UR50D'``,
        ``'esm2_t30_150M_UR50D'`` or ``'esm2_t33_650M_UR50D'``.

    Returns
    -------
    pandas.DataFrame
        One row per input sequence (input order, index ``0..n-1``) and one
        column per embedding dimension. An empty input returns an empty
        DataFrame without loading any model.

    Raises
    ------
    ValueError
        If ``model_name`` is not one of the supported names.
    """
    # Maps each supported name to (loader function, index of its last layer).
    model_dict = {
        'esm2_t6_8M_UR50D': (esm.pretrained.esm2_t6_8M_UR50D, 6),
        'esm2_t12_35M_UR50D': (esm.pretrained.esm2_t12_35M_UR50D, 12),
        'esm2_t30_150M_UR50D': (esm.pretrained.esm2_t30_150M_UR50D, 30),
        'esm2_t33_650M_UR50D': (esm.pretrained.esm2_t33_650M_UR50D, 33),
    }

    # Reject unknown model names before doing any expensive work.
    if model_name not in model_dict:
        raise ValueError(f"Invalid model name '{model_name}'. Please choose from {list(model_dict.keys())}.")

    # Materialize the input so any iterable is accepted, and short-circuit on
    # an empty batch before the checkpoint is loaded (or downloaded).
    peptide_sequence_list = list(peptide_sequence_list)
    if not peptide_sequence_list:
        return pd.DataFrame()

    model_func, num_layers = model_dict[model_name]
    model, alphabet = model_func()
    model.eval()  # inference mode

    # Tokenize every (name, sequence) pair into one padded batch.
    batch_converter = alphabet.get_batch_converter()
    _, _, batch_tokens = batch_converter(peptide_sequence_list)
    # Number of non-padding tokens in each row.
    batch_lens = (batch_tokens != alphabet.padding_idx).sum(1).tolist()

    # Extract per-token representations (on CPU). Only "representations" is
    # read below, so contact prediction is not requested.
    with torch.no_grad():
        results = model(batch_tokens, repr_layers=[num_layers], return_contacts=False)
    token_representations = results["representations"][num_layers]

    # Mean-pool each sequence over its inner tokens, leaving out the first and
    # last non-padding positions.
    sequence_representations = [
        token_representations[i, 1 : tokens_len - 1].mean(0)
        for i, tokens_len in enumerate(batch_lens)
    ]

    # Rows = sequences, columns = embedding dimensions (float64 after tolist()).
    return pd.DataFrame(torch.stack(sequence_representations).tolist())

In [None]:
# Example usage

# Load the peptide table; na_filter=False turns off pandas' missing-value
# detection, so blank cells are not converted to NaN.
dataset = pd.read_excel('MRSA-25-3-2024.xlsx', na_filter = False)
sequence_list = dataset['seq']

# The ESM batch converter takes (name, sequence) pairs; each peptide is used
# as its own name here.
peptide_sequence_list = [(peptide, peptide) for peptide in sequence_list]

# Compute the embeddings. Other checkpoints accepted by esm_embeddings:
#   'esm2_t6_8M_UR50D', 'esm2_t12_35M_UR50D', 'esm2_t30_150M_UR50D'
embeddings_results = esm_embeddings(peptide_sequence_list,'esm2_t33_650M_UR50D')

# Wrap the result as a pandas DataFrame and show it as the cell output
embeddings_results = pd.DataFrame(embeddings_results)
embeddings_results

Downloading: "https://dl.fbaipublicfiles.com/fair-esm/models/esm2_t33_650M_UR50D.pt" to /root/.cache/torch/hub/checkpoints/esm2_t33_650M_UR50D.pt
Downloading: "https://dl.fbaipublicfiles.com/fair-esm/regression/esm2_t33_650M_UR50D-contact-regression.pt" to /root/.cache/torch/hub/checkpoints/esm2_t33_650M_UR50D-contact-regression.pt
