<a href="https://colab.research.google.com/github/MahdieRah/protein-stability-ml/blob/main/Protein_Stability_Prediction_with_ESM_2_%2B_MLP.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [3]:
# 🧬 Protein Stability Prediction with ESM-2 + MLP

# ✅ Step 1: Install required libraries
!pip install fair-esm torch pandas scikit-learn -q

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m93.1/93.1 kB[0m [31m5.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m363.4/363.4 MB[0m [31m2.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.8/13.8 MB[0m [31m59.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m24.6/24.6 MB[0m [31m31.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m883.7/883.7 kB[0m [31m33.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m664.8/664.8 MB[0m [31m834.0 kB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m211.5/211.5 MB[0m [31m4.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m56.3/56.3 MB[0m [31m9.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [4]:
# ✅ Step 2: Import libraries
import torch
import torch.nn as nn
import torch.optim as optim
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
import esm

In [5]:
# ✅ Step 3: Load ESM-2 model
# Load the pretrained ESM-2 checkpoint (esm2_t33_650M_UR50D) together with its
# token alphabet. On first use the weights are downloaded into the torch hub
# cache, which is what the "Downloading: ..." lines in the cell output show.
esm_model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
# The batch converter turns a list of (label, sequence) pairs into the token
# tensor that the model is called with in extract_esm2_embedding.
batch_converter = alphabet.get_batch_converter()
# Put the model in evaluation mode. eval() returns the module itself, so as the
# last expression of the cell its (long) architecture repr is displayed.
esm_model.eval()

Downloading: "https://dl.fbaipublicfiles.com/fair-esm/models/esm2_t33_650M_UR50D.pt" to /root/.cache/torch/hub/checkpoints/esm2_t33_650M_UR50D.pt
Downloading: "https://dl.fbaipublicfiles.com/fair-esm/regression/esm2_t33_650M_UR50D-contact-regression.pt" to /root/.cache/torch/hub/checkpoints/esm2_t33_650M_UR50D-contact-regression.pt


ESM2(
  (embed_tokens): Embedding(33, 1280, padding_idx=1)
  (layers): ModuleList(
    (0-32): 33 x TransformerLayer(
      (self_attn): MultiheadAttention(
        (k_proj): Linear(in_features=1280, out_features=1280, bias=True)
        (v_proj): Linear(in_features=1280, out_features=1280, bias=True)
        (q_proj): Linear(in_features=1280, out_features=1280, bias=True)
        (out_proj): Linear(in_features=1280, out_features=1280, bias=True)
        (rot_emb): RotaryEmbedding()
      )
      (self_attn_layer_norm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
      (fc1): Linear(in_features=1280, out_features=5120, bias=True)
      (fc2): Linear(in_features=5120, out_features=1280, bias=True)
      (final_layer_norm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
    )
  )
  (contact_head): ContactPredictionHead(
    (regression): Linear(in_features=660, out_features=1, bias=True)
    (activation): Sigmoid()
  )
  (emb_layer_norm_after): LayerNorm((1280,), eps=1

In [7]:
# ✅ Step 4: Define embedding extractor
def extract_esm2_embedding(sequence):
    """Return a fixed-size ESM-2 embedding for a single protein sequence.

    The sequence is tokenised with the notebook-level ``batch_converter``, run
    through the notebook-level ``esm_model`` and the layer-33 per-token
    representations are averaged over the sequence positions.

    Args:
        sequence: Amino-acid sequence as a non-empty string.

    Returns:
        A 1-D ``numpy`` array holding the mean of the layer-33 token
        representations (1280 values for the model loaded in this notebook).

    Raises:
        TypeError: If ``sequence`` is not a string.
        ValueError: If ``sequence`` is empty; averaging over zero residues
            would otherwise silently produce an all-NaN embedding.
    """
    if not isinstance(sequence, str):
        raise TypeError(
            f"sequence must be a str, got {type(sequence).__name__}"
        )
    if not sequence:
        raise ValueError("sequence must be a non-empty amino-acid string")

    data = [("protein", sequence)]
    _, _, batch_tokens = batch_converter(data)

    # Feed the tokens on whatever device the model currently lives on, so the
    # function keeps working if the model has been moved to a GPU.
    device = next(esm_model.parameters()).device
    with torch.no_grad():
        results = esm_model(batch_tokens.to(device), repr_layers=[33])
    token_representations = results["representations"][33]

    # Drop the first and last token positions (presumably the BOS/EOS special
    # tokens added by the converter) and average the rest. .cpu() is required
    # before .numpy() when the tensor is not on the CPU.
    embedding = token_representations[0, 1:-1].mean(0).cpu().numpy()
    return embedding

In [8]:
# ✅ Step 5: Define the MLP model
class ProteinMLP(nn.Module):
    """Small feed-forward regressor mapping an embedding to one score.

    Architecture: ``input_dim -> 128 -> 64 -> 1`` with ReLU activations and a
    dropout of 0.2 after the first hidden layer. The output layer has no
    activation.
    """

    def __init__(self, input_dim):
        """Build the layer stack for inputs with ``input_dim`` features."""
        super().__init__()
        layers = [
            nn.Linear(input_dim, 128),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return the network output for the batch ``x``, one value per row."""
        scores = self.model(x)
        return scores


In [9]:
# ✅ Step 6: Create sample dataset (you can replace with real data)
# Toy dataset: five sequences, each paired with a stability score.
# The two lists are aligned by position.
data = dict(
    sequence=[
        "MVKVYAPASSANMSVGFDVLGAAVTPVDGALLGDVVTVEAAETFSLNNLGQKL",
        "GLSDGEWQLVLNVWGKVEADIPGHGQEVLIRLFKGH",
        "MKWVTFISLLFLFSSAYSRGVFRRDAHKSEVAHRFKDLGE",
        "GAVLIGTAAQIVATAGTNLVHSYDDGKSWTYLWEVQKAF",
        "MPTFISLLFLFSSAYSAVETALFNAQEQDGRQAK",
    ],
    stability_score=[0.8, 0.75, 0.9, 0.65, 0.85],
)
# One row per sequence; columns follow the dict's key order.
df = pd.DataFrame.from_dict(data)


In [10]:
# ✅ Step 7: Extract ESM embeddings
# Embed every sequence; each cell of the new "features" column holds one
# embedding array.
df["features"] = df["sequence"].apply(extract_esm2_embedding)

# Stack the per-sequence embeddings row-wise into the feature matrix and pull
# the regression targets out as a plain array.
X = np.vstack(df["features"].tolist())
y = df["stability_score"].to_numpy()

In [11]:
# ✅ Step 8: Train/test split
# Hold out 20% of the samples for testing; the fixed random_state keeps the
# split identical between runs.
X_train, X_test, y_train, y_test = train_test_split(
    X,
    y,
    test_size=0.2,
    random_state=42,
)

# ✅ Step 9: Convert to tensors
# Features become float32 tensors; targets are additionally reshaped into a
# single column so they line up with the model's one-value-per-row output.
X_train_tensor = torch.tensor(X_train, dtype=torch.float32)
X_test_tensor = torch.tensor(X_test, dtype=torch.float32)
y_train_tensor = torch.tensor(y_train, dtype=torch.float32).reshape(-1, 1)
y_test_tensor = torch.tensor(y_test, dtype=torch.float32).reshape(-1, 1)


In [12]:
# ✅ Step 10: Train the model
# Seed PyTorch's global RNG so that weight initialisation and dropout masks,
# and therefore the printed losses, are the same on every fresh run.
torch.manual_seed(42)

# Size the input layer from the feature matrix instead of hard-coding the
# embedding width, so the cell keeps working if the embedding model changes.
model = ProteinMLP(input_dim=X_train_tensor.shape[1])
optimizer = optim.Adam(model.parameters(), lr=1e-3)
loss_fn = nn.MSELoss()

# Training mode enables dropout; nothing inside the loop changes the mode, so
# it only needs to be set once.
model.train()
for epoch in range(30):
    # Full-batch gradient step: the whole training set is used every epoch.
    optimizer.zero_grad()
    predictions = model(X_train_tensor)
    loss = loss_fn(predictions, y_train_tensor)
    loss.backward()
    optimizer.step()
    # Report the training loss every fifth epoch.
    if (epoch + 1) % 5 == 0:
        print(f"Epoch {epoch+1} - Loss: {loss.item():.4f}")

Epoch 5 - Loss: 0.3100
Epoch 10 - Loss: 0.0246
Epoch 15 - Loss: 0.1064
Epoch 20 - Loss: 0.0024
Epoch 25 - Loss: 0.0131
Epoch 30 - Loss: 0.0229


In [13]:
# ✅ Step 11: Evaluate
# Evaluation mode turns dropout off so the test predictions are deterministic.
model.eval()

# No gradients are needed for scoring the held-out samples.
with torch.no_grad():
    test_preds = model(X_test_tensor)
    test_loss = loss_fn(test_preds, y_test_tensor)

# Same loss function as in training (MSE), reported to four decimals.
print(f"\nTest Loss: {test_loss.item():.4f}")


Test Loss: 0.0069


In [14]:
# ✅ Step 12: Predict on new sequences
# Sequences to score; each one is embedded exactly like the training data.
new_sequences = ["GLSDGEWQLVLNVWGKVEADIPGHGQEVLIRLFKGH"]
new_features = np.vstack([extract_esm2_embedding(seq) for seq in new_sequences])
new_tensor = torch.tensor(new_features, dtype=torch.float32)

# Make sure dropout is disabled here rather than relying on an earlier cell
# having left the model in evaluation mode; if the training cell was the last
# one run, the model would still be in training mode and the prediction would
# vary from call to call.
model.eval()
with torch.no_grad():
    prediction = model(new_tensor)
    # Only the score of the first sequence is printed.
    print("\nPredicted Stability Score:", prediction.numpy().flatten()[0])


Predicted Stability Score: 0.6667696
