<a href="https://colab.research.google.com/github/MahdieRah/protein-stability-ml/blob/main/Protein_Stability_Prediction_with_ESM_2_%2B_MLP.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# 🧬 Protein Stability Prediction with ESM-2 + MLP

# ✅ Step 1: Install required libraries
!pip install fair-esm torch pandas scikit-learn -q

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m93.1/93.1 kB[0m [31m5.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m363.4/363.4 MB[0m [31m2.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.8/13.8 MB[0m [31m59.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m24.6/24.6 MB[0m [31m31.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m883.7/883.7 kB[0m [31m33.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m664.8/664.8 MB[0m [31m834.0 kB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m211.5/211.5 MB[0m [31m4.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m56.3/56.3 MB[0m [31m9.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [None]:
# ✅ Step 2: Import libraries
import torch
import torch.nn as nn
import torch.optim as optim
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
import esm

In [None]:
# ✅ Step 3: Load ESM-2 model
esm_model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
batch_converter = alphabet.get_batch_converter()
esm_model.eval()

Downloading: "https://dl.fbaipublicfiles.com/fair-esm/models/esm2_t33_650M_UR50D.pt" to /root/.cache/torch/hub/checkpoints/esm2_t33_650M_UR50D.pt
Downloading: "https://dl.fbaipublicfiles.com/fair-esm/regression/esm2_t33_650M_UR50D-contact-regression.pt" to /root/.cache/torch/hub/checkpoints/esm2_t33_650M_UR50D-contact-regression.pt


ESM2(
  (embed_tokens): Embedding(33, 1280, padding_idx=1)
  (layers): ModuleList(
    (0-32): 33 x TransformerLayer(
      (self_attn): MultiheadAttention(
        (k_proj): Linear(in_features=1280, out_features=1280, bias=True)
        (v_proj): Linear(in_features=1280, out_features=1280, bias=True)
        (q_proj): Linear(in_features=1280, out_features=1280, bias=True)
        (out_proj): Linear(in_features=1280, out_features=1280, bias=True)
        (rot_emb): RotaryEmbedding()
      )
      (self_attn_layer_norm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
      (fc1): Linear(in_features=1280, out_features=5120, bias=True)
      (fc2): Linear(in_features=5120, out_features=1280, bias=True)
      (final_layer_norm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
    )
  )
  (contact_head): ContactPredictionHead(
    (regression): Linear(in_features=660, out_features=1, bias=True)
    (activation): Sigmoid()
  )
  (emb_layer_norm_after): LayerNorm((1280,), eps=1

In [None]:
# ✅ Step 4: Define embedding extractor
def extract_esm2_embedding(sequence):
    data = [("protein", sequence)]
    _, _, batch_tokens = batch_converter(data)
    with torch.no_grad():
        results = esm_model(batch_tokens, repr_layers=[33])
    token_representations = results["representations"][33]
    embedding = token_representations[0, 1:-1].mean(0).numpy()
    return embedding

In [None]:
# ✅ Step 5: Define the MLP model
class ProteinMLP(nn.Module):
    def __init__(self, input_dim):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(input_dim, 128),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 1)
        )

    def forward(self, x):
        return self.model(x)


In [None]:
# ✅ Step 6: Create sample dataset (you can replace with real data)
data = {
    "sequence": [
        "MVKVYAPASSANMSVGFDVLGAAVTPVDGALLGDVVTVEAAETFSLNNLGQKL",
        "GLSDGEWQLVLNVWGKVEADIPGHGQEVLIRLFKGH",
        "MKWVTFISLLFLFSSAYSRGVFRRDAHKSEVAHRFKDLGE",
        "GAVLIGTAAQIVATAGTNLVHSYDDGKSWTYLWEVQKAF",
        "MPTFISLLFLFSSAYSAVETALFNAQEQDGRQAK"
    ],
    "stability_score": [0.8, 0.75, 0.9, 0.65, 0.85]
}
df = pd.DataFrame(data)


In [None]:
# ✅ Step 7: Extract ESM embeddings
df["features"] = df["sequence"].apply(extract_esm2_embedding)
X = np.vstack(df["features"])
y = df["stability_score"].values

In [None]:
# ✅ Step 8: Train/test split
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# ✅ Step 9: Convert to tensors
X_train_tensor = torch.tensor(X_train, dtype=torch.float32)
y_train_tensor = torch.tensor(y_train, dtype=torch.float32).view(-1, 1)
X_test_tensor = torch.tensor(X_test, dtype=torch.float32)
y_test_tensor = torch.tensor(y_test, dtype=torch.float32).view(-1, 1)


In [None]:
# ✅ Step 10: Train the model
model = ProteinMLP(input_dim=1280)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
loss_fn = nn.MSELoss()

for epoch in range(30):
    model.train()
    optimizer.zero_grad()
    predictions = model(X_train_tensor)
    loss = loss_fn(predictions, y_train_tensor)
    loss.backward()
    optimizer.step()
    if (epoch + 1) % 5 == 0:
        print(f"Epoch {epoch+1} - Loss: {loss.item():.4f}")

Epoch 5 - Loss: 0.3100
Epoch 10 - Loss: 0.0246
Epoch 15 - Loss: 0.1064
Epoch 20 - Loss: 0.0024
Epoch 25 - Loss: 0.0131
Epoch 30 - Loss: 0.0229


In [None]:
# ✅ Step 11: Evaluate
model.eval()
with torch.no_grad():
    test_preds = model(X_test_tensor)
    test_loss = loss_fn(test_preds, y_test_tensor)
    print(f"\nTest Loss: {test_loss.item():.4f}")


Test Loss: 0.0069


In [None]:
# ✅ Step 12: Predict on new sequences
new_sequences = ["GLSDGEWQLVLNVWGKVEADIPGHGQEVLIRLFKGH"]
new_features = np.vstack([extract_esm2_embedding(seq) for seq in new_sequences])
new_tensor = torch.tensor(new_features, dtype=torch.float32)

with torch.no_grad():
    prediction = model(new_tensor)
    print("\nPredicted Stability Score:", prediction.numpy().flatten()[0])


Predicted Stability Score: 0.6667696


In [1]:
import pandas as pd

# Upload the file using Colab's UI
from google.colab import files
uploaded = files.upload()

# Load dataset (adjust filename if needed)
df = pd.read_csv("protherm.csv", sep=",")  # or sep="\t" if it's TSV
df.head()



Saving protherm.csv to protherm.csv


UnicodeDecodeError: 'utf-8' codec can't decode byte 0x81 in position 23145: invalid start byte

In [2]:
import pandas as pd

# Try with ISO-8859-1 encoding
df = pd.read_csv("protherm.csv", encoding='ISO-8859-1')
df.head()


  df = pd.read_csv("protherm.csv", encoding='ISO-8859-1')


Unnamed: 0,ID,NO.,***** Sequence and structural information*****,PROTEIN,SOURCE,LENGTH,MOL-WEIGHT,PIR_ID,SWISSPROT_ID,SWISSPROT_ID_LINK,...,ACTIVITY_Kd,***** Literature *****,KEY_WORDS,REFERENCE,REFERENCE_PMID,AUTHOR,REMARKS,RELATED_ENTRIES,RELATED_ENTRIES.1,e_c_number_link
0,142452,1.0,,Phospholipase A2,Bovine,130.0,14536.12,PSBOA,PA21_BOVIN (P00593),P00593,...,,,catalytic triad; PLA2; conformational stabilit...,"J AM CHEM SOC 115, 8523-8526 (1993) PMID:",,Li Y. & Tsai M.-D.,"additive : EDTA(0.1 mM),","2, 3, 4","<a href=""/protherm/show/ 2""> 2</a>, <a href=""/...",
1,142453,2.0,,Phospholipase A2,Bovine,130.0,14513.08,PSBOA,PA21_BOVIN (P00593),P00593,...,,,catalytic triad; PLA2; conformational stabilit...,"J AM CHEM SOC 115, 8523-8526 (1993) PMID:",,Li Y. & Tsai M.-D.,"additive : EDTA(0.1 mM),","1, 3, 4","<a href=""/protherm/show/ 1""> 1</a>, <a href=""/...",
2,142454,3.0,,Phospholipase A2,Bovine,130.0,14527.11,PSBOA,PA21_BOVIN (P00593),P00593,...,,,catalytic triad; PLA2; conformational stabilit...,"J AM CHEM SOC 115, 8523-8526 (1993) PMID:",,Li Y. & Tsai M.-D.,"additive : EDTA(0.1 mM),","1, 2, 4","<a href=""/protherm/show/ 1""> 1</a>, <a href=""/...",
3,142455,4.0,,Phospholipase A2,Bovine,130.0,14470.06,PSBOA,PA21_BOVIN (P00593),P00593,...,,,catalytic triad; PLA2; conformational stabilit...,"J AM CHEM SOC 115, 8523-8526 (1993) PMID:",,Li Y. & Tsai M.-D.,"additive : EDTA(0.1 mM),","1, 2, 3","<a href=""/protherm/show/ 1""> 1</a>, <a href=""/...",
4,142456,5.0,,Ribonuclease HI,Escherichia coli,155.0,17597.0,NRECH,RNH_ECOLI (P00647),P00647,...,,,structural stability; mutagenesis; free energy...,"J BIOL CHEM 267, 22014-22017 (1992) PMID: 1...",1331044.0,"Kimura S., Kanaya S. & Nakamura H.",,"6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 2143,...","<a href=""/protherm/show/ 6""> 6</a>, <a href=""/...",3.1.27.3


In [3]:
print(df.columns)

Index(['ID', 'NO.', '***** Sequence and structural information*****',
       'PROTEIN', 'SOURCE', 'LENGTH', 'MOL-WEIGHT', 'PIR_ID', 'SWISSPROT_ID',
       'SWISSPROT_ID_LINK', 'E.C.NUMBER', 'PMD.NO', 'PDB_wild', 'PDB_mutant',
       'MUTATION', 'MUTATED_CHAIN', 'NO_MOLECULE', 'SEC.STR.', 'ASA',
       '***** Experimental condition *****', 'T', 'pH', 'BUFFER_NAME',
       'BUFFER_CONC', 'ION_NAME_1', 'ION_CONC_1', 'PROTEIN_CONC', 'MEASURE',
       'METHOD', '***** Thermodynamic data *****', 'dG_H2O', 'ddG_H2O', 'dG',
       'ddG', 'Tm', 'dTm', 'dHvH', 'dHcal', 'm', 'Cm', 'dCp', 'STATE',
       'REVERSIBILITY', 'ACTIVITY', 'ACTIVITY_Km', 'ACTIVITY_Kcat',
       'ACTIVITY_Kd', '***** Literature *****', 'KEY_WORDS', 'REFERENCE',
       'REFERENCE_PMID', 'AUTHOR', 'REMARKS', 'RELATED_ENTRIES',
       'RELATED_ENTRIES.1', 'e_c_number_link'],
      dtype='object')


In [4]:
print(df.columns.tolist())
df.head(3)


['ID', 'NO.', '***** Sequence and structural information*****', 'PROTEIN', 'SOURCE', 'LENGTH', 'MOL-WEIGHT', 'PIR_ID', 'SWISSPROT_ID', 'SWISSPROT_ID_LINK', 'E.C.NUMBER', 'PMD.NO', 'PDB_wild', 'PDB_mutant', 'MUTATION', 'MUTATED_CHAIN', 'NO_MOLECULE', 'SEC.STR.', 'ASA', '***** Experimental condition *****', 'T', 'pH', 'BUFFER_NAME', 'BUFFER_CONC', 'ION_NAME_1', 'ION_CONC_1', 'PROTEIN_CONC', 'MEASURE', 'METHOD', '***** Thermodynamic data *****', 'dG_H2O', 'ddG_H2O', 'dG', 'ddG', 'Tm', 'dTm', 'dHvH', 'dHcal', 'm', 'Cm', 'dCp', 'STATE', 'REVERSIBILITY', 'ACTIVITY', 'ACTIVITY_Km', 'ACTIVITY_Kcat', 'ACTIVITY_Kd', '***** Literature *****', 'KEY_WORDS', 'REFERENCE', 'REFERENCE_PMID', 'AUTHOR', 'REMARKS', 'RELATED_ENTRIES', 'RELATED_ENTRIES.1', 'e_c_number_link']


Unnamed: 0,ID,NO.,***** Sequence and structural information*****,PROTEIN,SOURCE,LENGTH,MOL-WEIGHT,PIR_ID,SWISSPROT_ID,SWISSPROT_ID_LINK,...,ACTIVITY_Kd,***** Literature *****,KEY_WORDS,REFERENCE,REFERENCE_PMID,AUTHOR,REMARKS,RELATED_ENTRIES,RELATED_ENTRIES.1,e_c_number_link
0,142452,1.0,,Phospholipase A2,Bovine,130.0,14536.12,PSBOA,PA21_BOVIN (P00593),P00593,...,,,catalytic triad; PLA2; conformational stabilit...,"J AM CHEM SOC 115, 8523-8526 (1993) PMID:",,Li Y. & Tsai M.-D.,"additive : EDTA(0.1 mM),","2, 3, 4","<a href=""/protherm/show/ 2""> 2</a>, <a href=""/...",
1,142453,2.0,,Phospholipase A2,Bovine,130.0,14513.08,PSBOA,PA21_BOVIN (P00593),P00593,...,,,catalytic triad; PLA2; conformational stabilit...,"J AM CHEM SOC 115, 8523-8526 (1993) PMID:",,Li Y. & Tsai M.-D.,"additive : EDTA(0.1 mM),","1, 3, 4","<a href=""/protherm/show/ 1""> 1</a>, <a href=""/...",
2,142454,3.0,,Phospholipase A2,Bovine,130.0,14527.11,PSBOA,PA21_BOVIN (P00593),P00593,...,,,catalytic triad; PLA2; conformational stabilit...,"J AM CHEM SOC 115, 8523-8526 (1993) PMID:",,Li Y. & Tsai M.-D.,"additive : EDTA(0.1 mM),","1, 2, 4","<a href=""/protherm/show/ 1""> 1</a>, <a href=""/...",


In [5]:
# Load with encoding
df = pd.read_csv("protherm.csv", encoding='ISO-8859-1')

# Keep only rows with mutation info + ddG
df_filtered = df[["MUTATION", "ddG"]].dropna()

# Remove outliers/extreme ddG values if desired
df_filtered = df_filtered[df_filtered["ddG"].between(-10, 10)]

# Reset index
df_filtered = df_filtered.reset_index(drop=True)

# Display sample
print(df_filtered.head())


  df = pd.read_csv("protherm.csv", encoding='ISO-8859-1')


TypeError: '>=' not supported between instances of 'str' and 'int'

In [6]:
import pandas as pd

# Load with correct options
df = pd.read_csv("protherm.csv", encoding='ISO-8859-1', low_memory=False)

# Keep only columns we care about
df_filtered = df[["MUTATION", "ddG"]].dropna()

# Convert ddG to numeric, force errors to NaN, then drop them
df_filtered["ddG"] = pd.to_numeric(df_filtered["ddG"], errors="coerce")
df_filtered = df_filtered.dropna(subset=["ddG"])

# Remove outliers/extreme ddG values if desired
df_filtered = df_filtered[df_filtered["ddG"].between(-10, 10)]

# Reset index
df_filtered = df_filtered.reset_index(drop=True)

# Show sample
print(df_filtered.head())


Empty DataFrame
Columns: [MUTATION, ddG]
Index: []


In [7]:
# Reload the full file (in case it's modified)
df = pd.read_csv("protherm.csv", encoding='ISO-8859-1', low_memory=False)

# See if ddG exists and what's inside
print("All column names:\n", df.columns.tolist())
print("\nSample ddG values:\n", df["ddG"].head(20))


All column names:
 ['ID', 'NO.', '***** Sequence and structural information*****', 'PROTEIN', 'SOURCE', 'LENGTH', 'MOL-WEIGHT', 'PIR_ID', 'SWISSPROT_ID', 'SWISSPROT_ID_LINK', 'E.C.NUMBER', 'PMD.NO', 'PDB_wild', 'PDB_mutant', 'MUTATION', 'MUTATED_CHAIN', 'NO_MOLECULE', 'SEC.STR.', 'ASA', '***** Experimental condition *****', 'T', 'pH', 'BUFFER_NAME', 'BUFFER_CONC', 'ION_NAME_1', 'ION_CONC_1', 'PROTEIN_CONC', 'MEASURE', 'METHOD', '***** Thermodynamic data *****', 'dG_H2O', 'ddG_H2O', 'dG', 'ddG', 'Tm', 'dTm', 'dHvH', 'dHcal', 'm', 'Cm', 'dCp', 'STATE', 'REVERSIBILITY', 'ACTIVITY', 'ACTIVITY_Km', 'ACTIVITY_Kcat', 'ACTIVITY_Kd', '***** Literature *****', 'KEY_WORDS', 'REFERENCE', 'REFERENCE_PMID', 'AUTHOR', 'REMARKS', 'RELATED_ENTRIES', 'RELATED_ENTRIES.1', 'e_c_number_link']

Sample ddG values:
 0               NaN
1               NaN
2               NaN
3               NaN
4               NaN
5               NaN
6               NaN
7               NaN
8               NaN
9               

In [10]:
df_filtered = df[["MUTATION", "ddG_H2O"]].dropna()
df_filtered = df_filtered.rename(columns={"ddG_H2O": "ddG"})
df_filtered["ddG"] = pd.to_numeric(df_filtered["ddG"], errors="coerce")
df_filtered = df_filtered.dropna(subset=["ddG"])
df_filtered = df_filtered[df_filtered["ddG"].between(-10, 10)]
print("\nSample ddG values:\n", df["ddG_H2O"].head(20))


Sample ddG values:
 0                NaN
1     -3.00 kcal/mol
2     -0.60 kcal/mol
3     -3.10 kcal/mol
4                NaN
5                NaN
6                NaN
7                NaN
8                NaN
9                NaN
10               NaN
11               NaN
12               NaN
13               NaN
14               NaN
15               NaN
16               NaN
17    -2.10 kcal/mol
18    -3.10 kcal/mol
19    -2.80 kcal/mol
Name: ddG_H2O, dtype: object


In [11]:
# Clean the "ddG_H2O" column by removing 'kcal/mol' and trimming whitespace
df["ddG_H2O_clean"] = df["ddG_H2O"].str.replace("kcal/mol", "", regex=False).str.strip()

# Convert cleaned column to numeric
df["ddG"] = pd.to_numeric(df["ddG_H2O_clean"], errors="coerce")

# Filter rows where both 'MUTATION' and 'ddG' are valid
df_filtered = df[["MUTATION", "ddG"]].dropna()
df_filtered = df_filtered[df_filtered["ddG"].between(-10, 10)]

# Show sample
print(df_filtered.head())


   MUTATION  ddG
1    H 48 N -3.0
2    H 48 Q -0.6
3    H 48 A -3.1
17    T 9 A -2.1
18    T 9 L -3.1


In [12]:
df_filtered.to_csv("clean_protherm_ddg.csv", index=False)


In [13]:
# Get rows from original df that match mutations we kept
matched_mutations = df[df["MUTATION"].isin(df_filtered["MUTATION"])][["MUTATION", "PROTEIN", "SWISSPROT_ID"]]
print(matched_mutations.drop_duplicates())


      MUTATION                                 PROTEIN         SWISSPROT_ID
0         wild                        Phospholipase A2  PA21_BOVIN (P00593)
1       H 48 N                        Phospholipase A2  PA21_BOVIN (P00593)
2       H 48 Q                        Phospholipase A2  PA21_BOVIN (P00593)
3       H 48 A                        Phospholipase A2  PA21_BOVIN (P00593)
4         wild                         Ribonuclease HI   RNH_ECOLI (P00647)
...        ...                                     ...                  ...
24164     wild                      VL domain antibody                  NaN
24231     wild                   Aacetylcholinesterase                  NaN
24239     wild  Tryptophan zipper beta-hairpin peptide                  NaN
24249     wild                                HIV gp41                  NaN
24256     wild                                 Barnase                  NaN

[3850 rows x 3 columns]


In [14]:
import requests

# UniProt ID for Phospholipase A2 (bovine)
uniprot_id = "P00593"
url = f"https://www.uniprot.org/uniprot/{uniprot_id}.fasta"

response = requests.get(url)
fasta = response.text

# Parse sequence from FASTA
sequence = "".join(line.strip() for line in fasta.split("\n") if not line.startswith(">"))
print("Wild-type sequence (first 60 aa):", sequence[:60])
print("Sequence length:", len(sequence))


Wild-type sequence (first 60 aa): MRLLVLAALLTVGAGQAGLNSRALWQFNGMIKCKIPSSEPLLDFNNYGCYCGLGGSGTPV
Sequence length: 145


In [15]:
import pandas as pd

# Your wild-type sequence
wt_sequence = sequence  # From your earlier code

# Ensure df_filtered is already defined and contains 'MUTATION' and 'ddG' columns
# If not, reload and preprocess it as before

# Mutation function
def apply_mutation(wt_seq, mutation_str):
    try:
        original_aa, pos, mutant_aa = mutation_str.strip().split()
        pos = int(pos) - 1  # Convert to 0-based index

        if pos < 0 or pos >= len(wt_seq):
            return None  # Position out of range

        if wt_seq[pos] != original_aa:
            print(f"⚠️ Mismatch at position {pos+1}: expected {original_aa}, found {wt_seq[pos]}")
            return None

        return wt_seq[:pos] + mutant_aa + wt_seq[pos+1:]
    except Exception as e:
        print(f"❌ Error parsing mutation '{mutation_str}': {e}")
        return None

# Apply the mutation column to generate sequences
df_filtered["mutated_sequence"] = df_filtered["MUTATION"].apply(lambda x: apply_mutation(wt_sequence, x))

# Drop any rows where mutation failed
df_mutants = df_filtered.dropna(subset=["mutated_sequence"])

# Show first few rows
df_mutants = df_mutants[["MUTATION", "ddG", "mutated_sequence"]]
print("\n✅ Mutant sequences:")
print(df_mutants.head())


⚠️ Mismatch at position 48: expected H, found G
⚠️ Mismatch at position 48: expected H, found G
⚠️ Mismatch at position 48: expected H, found G
⚠️ Mismatch at position 9: expected T, found L
⚠️ Mismatch at position 9: expected T, found L
⚠️ Mismatch at position 9: expected T, found L
⚠️ Mismatch at position 9: expected T, found L
⚠️ Mismatch at position 97: expected K, found Y
⚠️ Mismatch at position 97: expected K, found Y
⚠️ Mismatch at position 97: expected K, found Y
⚠️ Mismatch at position 4: expected I, found L
⚠️ Mismatch at position 4: expected I, found L
⚠️ Mismatch at position 5: expected N, found V
⚠️ Mismatch at position 6: expected T, found L
⚠️ Mismatch at position 6: expected T, found L
⚠️ Mismatch at position 8: expected D, found A
⚠️ Mismatch at position 10: expected V, found L
⚠️ Mismatch at position 10: expected V, found L
⚠️ Mismatch at position 12: expected D, found V
⚠️ Mismatch at position 13: expected Y, found G
⚠️ Mismatch at position 14: expected L, found A
⚠️

In [16]:
!pip install fair-esm


Collecting fair-esm
  Downloading fair_esm-2.0.0-py3-none-any.whl.metadata (37 kB)
Downloading fair_esm-2.0.0-py3-none-any.whl (93 kB)
[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/93.1 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m93.1/93.1 kB[0m [31m3.8 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: fair-esm
Successfully installed fair-esm-2.0.0


In [17]:
import torch
import esm
from tqdm import tqdm

# Load ESM-2 model (650M parameter model)
model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
batch_converter = alphabet.get_batch_converter()
model.eval()  # Set to eval mode

# Function to get ESM-2 embedding for one sequence
def get_esm_embedding(sequence):
    batch_labels = [("protein", sequence)]
    batch_tokens = batch_converter(batch_labels)[2]

    with torch.no_grad():
        results = model(batch_tokens, repr_layers=[33], return_contacts=False)
    # Get per-sequence embedding from layer 33
    token_representations = results["representations"][33]
    sequence_embedding = token_representations[0, 1:len(sequence)+1].mean(0)  # Mean over residues
    return sequence_embedding.numpy()

# Generate embeddings for all mutant sequences
embedding_list = []
ddg_list = []

print("⚙️ Generating embeddings...")
for i, row in tqdm(df_mutants.iterrows(), total=len(df_mutants)):
    seq = row["mutated_sequence"]
    ddg = row["ddG"]
    emb = get_esm_embedding(seq)
    embedding_list.append(emb)
    ddg_list.append(ddg)

# Convert to DataFrame
embedding_df = pd.DataFrame(embedding_list)
embedding_df["ddG"] = ddg_list

print("✅ Embedding DataFrame shape:", embedding_df.shape)
embedding_df.head()


Downloading: "https://dl.fbaipublicfiles.com/fair-esm/models/esm2_t33_650M_UR50D.pt" to /root/.cache/torch/hub/checkpoints/esm2_t33_650M_UR50D.pt
Downloading: "https://dl.fbaipublicfiles.com/fair-esm/regression/esm2_t33_650M_UR50D-contact-regression.pt" to /root/.cache/torch/hub/checkpoints/esm2_t33_650M_UR50D-contact-regression.pt


⚙️ Generating embeddings...


100%|██████████| 117/117 [09:24<00:00,  4.83s/it]

✅ Embedding DataFrame shape: (117, 1281)





Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,1271,1272,1273,1274,1275,1276,1277,1278,1279,ddG
0,-0.009243,-0.125319,-0.103301,0.173982,-0.036075,-0.079967,0.076539,-0.119213,-0.006076,0.104057,...,-0.076498,-0.105887,0.052079,-0.030775,-0.059344,0.072604,-0.10978,-0.140911,0.070614,-2.14
1,-0.001386,-0.115838,-0.104058,0.175939,-0.030668,-0.079796,0.068999,-0.123076,-0.013256,0.096661,...,-0.071938,-0.099808,0.046403,-0.032829,-0.056876,0.075581,-0.106996,-0.139728,0.070735,0.4
2,-0.003175,-0.122614,-0.104667,0.174426,-0.032397,-0.082828,0.078625,-0.118985,-0.004894,0.105656,...,-0.076756,-0.098936,0.055323,-0.031343,-0.060806,0.072341,-0.105408,-0.147368,0.071722,-5.3
3,-0.002807,-0.126627,-0.106243,0.177825,-0.034716,-0.082411,0.078574,-0.117067,-0.004128,0.105049,...,-0.074676,-0.099539,0.055198,-0.031039,-0.061126,0.073545,-0.105512,-0.146921,0.075111,-6.7
4,-0.010176,-0.123098,-0.103274,0.178761,-0.037183,-0.080673,0.075554,-0.113988,-0.009087,0.112106,...,-0.075986,-0.117337,0.051234,-0.030255,-0.063539,0.064882,-0.107388,-0.139839,0.066852,-0.1


In [18]:
embedding_df.to_csv("mutant_esm_embeddings.csv", index=False)
