# Generate ESM embeddings using ESM-2 from Meta
- Loads your .fasta file
- Feeds each sequence into ESM-2
- Extracts [CLS] token embedding (or mean pooled)
- Saves to a .csv



In [3]:
## Import necessary libraries
import torch
import esm
import os
import pandas as pd
from Bio import SeqIO
from tqdm import tqdm
print("Libraries imported successfully")

Libraries imported successfully


In [None]:
# Function to generate ESM-2 embeddings for protein sequences
def generate_esm2_embeddings(fasta_path, esm2_model_path, output_csv="../data/step4.1_esm_protein_embeddings.csv"):
    # Load pretrained ESM-2 model locally
    print("Loading local ESM-2 model...")
    
    # Fix for PyTorch 2.6+ weights_only default change
    import torch.serialization
    import argparse
    torch.serialization.add_safe_globals([argparse.Namespace])
    
    try:
        model, alphabet = esm.pretrained.load_model_and_alphabet_local(esm2_model_path)
    except Exception as e:
        print(f"Failed to load with weights_only=True, trying with weights_only=False...")
        # Monkey patch torch.load to use weights_only=False for ESM models
        original_load = torch.load
        def patched_load(*args, **kwargs):
            kwargs['weights_only'] = False
            return original_load(*args, **kwargs)
        torch.load = patched_load
        
        model, alphabet = esm.pretrained.load_model_and_alphabet_local(esm2_model_path)
        
        # Restore original torch.load
        torch.load = original_load
    
    batch_converter = alphabet.get_batch_converter()
    model.eval()

    # Get the number of layers from the model
    num_layers = model.num_layers

    # Read sequences
    sequences = list(SeqIO.parse(fasta_path, "fasta"))
    print(f"Found {len(sequences)} protein sequences.")

    records = []

    for record in tqdm(sequences, desc="Generating embeddings"):
        name = record.id
        sequence = str(record.seq)

        if len(sequence) > 4096:
            print(f"Skipping {name}: sequence too long for ESM-2")
            continue

        batch_labels, batch_strs, batch_tokens = batch_converter([(name, sequence)])
        with torch.no_grad():
            results = model(batch_tokens, repr_layers=[num_layers], return_contacts=False)

        token_representations = results["representations"][num_layers]
        cls_embedding = token_representations[0, 0, :].numpy()

        records.append({
            "id": name,
            **{f"feat_{i}": cls_embedding[i] for i in range(len(cls_embedding))}
        })

    df = pd.DataFrame(records)
    os.makedirs("data", exist_ok=True)
    df.to_csv(output_csv, index=False)
    print(f"Embeddings saved to '{output_csv}'")

In [7]:
# Run the embedding pipeline on the kinase target sequences
if __name__ == "__main__":
    fasta_path = "../data/step3_kinase_target_sequences.fasta"

    # Locate the checkpoint in the torch hub cache of the current user instead
    # of a machine-specific absolute path. Set the ESM2_MODEL_PATH environment
    # variable to point at a checkpoint stored elsewhere.
    default_checkpoint = os.path.join(
        torch.hub.get_dir(), "checkpoints", "esm2_t33_650M_UR50D.pt"
    )
    esm2_model_path = os.environ.get("ESM2_MODEL_PATH", default_checkpoint)

    generate_esm2_embeddings(fasta_path, esm2_model_path)


Loading local ESM-2 model...
Found 188 protein sequences.
Found 188 protein sequences.


Generating embeddings:   4%|▎         | 7/188 [01:00<15:58,  5.30s/it]  

Skipping P78527|CHEMBL3142|DNA-dependent: sequence too long for ESM-2


Generating embeddings: 100%|██████████| 188/188 [42:16<00:00, 13.49s/it]   



Embeddings saved to 'data/step4.1_esm_protein_embeddings.csv'
