Generate embeddings from protein sequence data using lightweight ESM2 model

In [3]:
from transformers import EsmModel, EsmTokenizer
import torch

#Suppress warnings
import warnings
from transformers import logging

# Suppress Hugging Face warnings
logging.set_verbosity_error()

# Suppress all warnings (optional)
warnings.filterwarnings("ignore")

In [4]:


# Load large ESM-2 model
model_name = "facebook/esm2_t6_8M_UR50D"
tokenizer = EsmTokenizer.from_pretrained(model_name)
model = EsmModel.from_pretrained(model_name)

#Example protein sequence
sequence = "MKTLLVLLLGAAGG"
tokens = tokenizer(sequence, return_tensors="pt")

# Generate embeddings quickly
embeddings = model(**tokens).last_hidden_state
print(f"Embedding shape: {embeddings.shape}")

Embedding shape: torch.Size([1, 16, 320])


Protein mutation effect analysis

In [None]:
# Load mid-weight ESM-2 model
model_name = "facebook/esm2_t12_35M_UR50D" 
tokenizer = EsmTokenizer.from_pretrained(model_name)
model = EsmModel.from_pretrained(model_name)

# Example wild-type sequence
sequence = "MKTLLVLLLGAAGG"

# Introduce mutation (G -> A at position 13)
mutated_sequence = "MKTLLVLLLGAAGG".replace("G", "A", 1)

# Tokenize sequences
tokens_wt = tokenizer(sequence, return_tensors="pt")
tokens_mut = tokenizer(mutated_sequence, return_tensors="pt")

# Generate embeddings
embeddings_wt = model(**tokens_wt).last_hidden_state
embeddings_mut = model(**tokens_mut).last_hidden_state

# Compute similarity (cosine similarity, lower values means proteins phenotype is affected, closer to 1 means minimal or no effect)
similarity = torch.nn.functional.cosine_similarity(embeddings_wt.mean(dim=1), embeddings_mut.mean(dim=1))
print(f"Mutation Impact Score: {similarity.item()}")


Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t12_35M_UR50D and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Mutation Impact Score: 0.9933467507362366


Fine-Tuning ESM-2 for Mutation Effect Prediction
High-quality datasets for ESM-based mutation effect analysis is available from several trusted sources, especially those based on Deep Mutational Scanning (DMS).
Curated DMS datasets sources:
- ProteinGym — 2.5M mutants across 217 assays
- GSK DMS fine-tuning benchmark — includes normalized log-odds scores
- Hugging Face mutation scoring tutorial — includes example data and scoring methods
For this tutorial, a synthetic mutational effect dataset  is used to demonstrate the steps involved. The file contains the following columns:
- sequence: wild-type protein sequence
- mutation: e.g., "G13A"
- score: experimental fitness or activity score

Interpretation
- The R² score reflects how well the model predicts mutation effects. Since this example uses synthetic dummy data with randomly generated scores, the R² value may be misleading or even negative. This is expected—there’s no real biological signal for the model to learn from. For meaningful results, consider using a curated dataset (e.g., from ProteinGym or a deep mutational scanning study) and experimenting with different model architectures. This example is intended purely to demonstrate the workflow and pipeline setup.

- Predicted score: Close to 1 → likely no significant impact; close to 0 → likely disruptive.




In [None]:
import pandas as pd

def apply_mutation(seq, mut):
    """
    Function to generate mutated sequence given wild type sequece and mutation
    """
    pos = int(mut[1:-1]) - 1
    return seq[:pos] + mut[-1] + seq[pos+1:]

df = pd.read_csv("../data/mutations.csv")
df["mutated_sequence"] = df.apply(lambda row: apply_mutation(row["sequence"], row["mutation"]), axis=1)
print(df)

          sequence mutation  score mutated_sequence
0   MKTLLVLLLGAAGG     G13A   0.92   MKTLLVLLLGAAAG
1   MKTLLVLLLGAAGG     G13D   0.45   MKTLLVLLLGAADG
2   MKTLLVLLLGAAGG     G13V   0.12   MKTLLVLLLGAAVG
3   MKTLLVLLLGAAGG      L5P   0.78   MKTLPVLLLGAAGG
4   MKTLLVLLLGAAGG      L5F   0.33   MKTLFVLLLGAAGG
5   MKTLLVLLLGAAGG      L5S   0.15   MKTLSVLLLGAAGG
6   MKTLLVLLLGAAGG     A11T   0.88   MKTLLVLLLGTAGG
7   MKTLLVLLLGAAGG     A11G   0.51   MKTLLVLLLGGAGG
8   MKTLLVLLLGAAGG     A11C   0.09   MKTLLVLLLGCAGG
9   MKTLLVLLLGAAGG      K2R   0.95   MRTLLVLLLGAAGG
10  MKTLLVLLLGAAGG      K2E   0.40   METLLVLLLGAAGG
11  MKTLLVLLLGAAGG      K2Q   0.18   MQTLLVLLLGAAGG


In [7]:
#Generate embeddings
from tqdm import tqdm

def get_embedding(sequence):
    tokens = tokenizer(sequence, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**tokens)
    return outputs.last_hidden_state.mean(dim=1).squeeze().numpy()

df["embedding"] = [get_embedding(seq) for seq in tqdm(df["mutated_sequence"])]
df

100%|██████████| 12/12 [00:00<00:00, 100.89it/s]


Unnamed: 0,sequence,mutation,score,mutated_sequence,embedding
0,MKTLLVLLLGAAGG,G13A,0.92,MKTLLVLLLGAAAG,"[0.016369378, 0.20922367, 0.24752869, 0.189788..."
1,MKTLLVLLLGAAGG,G13D,0.45,MKTLLVLLLGAADG,"[-0.021203855, 0.035817057, 0.25660914, 0.1874..."
2,MKTLLVLLLGAAGG,G13V,0.12,MKTLLVLLLGAAVG,"[0.01610994, 0.23487134, 0.30792505, 0.2103723..."
3,MKTLLVLLLGAAGG,L5P,0.78,MKTLPVLLLGAAGG,"[-0.07309258, 0.019085575, 0.22213294, 0.19425..."
4,MKTLLVLLLGAAGG,L5F,0.33,MKTLFVLLLGAAGG,"[-0.025348786, 0.14206602, 0.24159305, 0.19302..."
5,MKTLLVLLLGAAGG,L5S,0.15,MKTLSVLLLGAAGG,"[-0.061086595, 0.11706587, 0.22230922, 0.15267..."
6,MKTLLVLLLGAAGG,A11T,0.88,MKTLLVLLLGTAGG,"[0.0010373108, 0.12275866, 0.27234566, 0.21511..."
7,MKTLLVLLLGAAGG,A11G,0.51,MKTLLVLLLGGAGG,"[-0.024962382, 0.05335851, 0.2446322, 0.269901..."
8,MKTLLVLLLGAAGG,A11C,0.09,MKTLLVLLLGCAGG,"[-0.013253654, 0.063041046, 0.24657904, 0.2604..."
9,MKTLLVLLLGAAGG,K2R,0.95,MRTLLVLLLGAAGG,"[-0.0016255928, 0.07183409, 0.26255524, 0.2034..."


In [8]:
#Train a regression model
from sklearn.linear_model import Ridge
from sklearn.model_selection import train_test_split

X = list(df["embedding"])
y = df["score"]

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)

regressor = Ridge()
regressor.fit(X_train, y_train)

r2 = regressor.score(X_test, y_test)
print(f"R² on test set: {r2:.2f}")

R² on test set: -1.13


In [9]:
test_seq = apply_mutation("MKTLLVLLLGAAGG", "G13A")
embedding = get_embedding(test_seq).reshape(1, -1)
predicted_score = regressor.predict(embedding)[0]
print(f"Predicted impact score: {predicted_score:.2f}")

Predicted impact score: 0.70
