In [1]:
# ── Imports ───────────────────────────────────────────────────────────────────
import os, math, random
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from scipy.stats import pearsonr, spearmanr
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
import esm
from datetime import datetime

# ── Setup ─────────────────────────────────────────────────────────────────────
# Prefer CUDA, then Apple-silicon MPS, and fall back to CPU.
DEVICE = torch.device(
    "cuda"
    if torch.cuda.is_available()
    else ("mps" if torch.backends.mps.is_available() else "cpu")
)
# Seed Python, NumPy and PyTorch RNGs so runs are reproducible.
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)
print(f"Device: {DEVICE}")

# Data locations and hyper-parameters shared by every sweep run.
CFG = {
    "train_csv": "project_data/mega_train.csv",
    "val_csv": "project_data/mega_val.csv",
    "label_col": "ddG_ML",
    "epochs": 6,
    "batch_size": 10,
    "lr_head": 1e-3,    # learning rate for the regression head parameters
    "lr_esm": 5e-6,     # learning rate for the trainable ESM backbone parameters
    "patience": 5,      # epochs without a val-RMSE improvement before early stopping
    "grad_clip": 1.0,   # max global L2 norm of trainable gradients; larger norms are scaled down
    "lr_factor": 0.5,   # plateau scheduler: multiply the learning rates by this factor
    "lr_patience": 2,   # plateau scheduler: epochs without val-RMSE improvement before reducing
}

# ── Dataset ───────────────────────────────────────────────────────────────────
class ProteinPairDataset(Dataset):
    """Paired wild-type / mutant sequences, each with one regression label.

    Reads a CSV containing the columns ``wt_seq``, ``aa_seq`` (mutant sequence),
    ``name`` and ``label_col``. Indexing yields
    ``(wt_seq, mutant_seq, label_as_python_float, name)``.
    """

    def __init__(self, csv_path, label_col="ddG_ML"):
        frame = pd.read_csv(csv_path)
        # Materialise plain lists / a float array once so __getitem__ stays cheap.
        self.wt = frame["wt_seq"].tolist()                   # wild-type sequences
        self.mut = frame["aa_seq"].tolist()                  # mutant sequences
        self.y = frame[label_col].astype(float).to_numpy()   # regression targets
        self.ids = frame["name"].tolist()                    # row identifiers

    def __len__(self):
        return len(self.y)

    def __getitem__(self, index):
        wild_type, mutant = self.wt[index], self.mut[index]
        return wild_type, mutant, float(self.y[index]), self.ids[index]


def make_collate(batch_converter):
    """Build a DataLoader ``collate_fn`` that tokenizes WT and mutant sequences.

    Args:
        batch_converter: callable taking a list of ``(label, sequence)`` pairs and
            returning a 3-tuple whose last element is the token tensor (the shape
            returned by an ESM alphabet's ``get_batch_converter()``).

    Returns:
        ``collate(batch)`` mapping a list of ``(wt, mut, y, id)`` items to
        ``(wt_tokens, mut_tokens, float32 label tensor, list_of_ids)``.
    """
    def _tokenize(names, sequences):
        # Only the token tensor is needed; the echoed labels/strings are dropped.
        _labels, _strs, tokens = batch_converter(list(zip(names, sequences)))
        return tokens

    def collate(batch):
        wt_seqs, mut_seqs, labels, names = zip(*batch)  # transpose list-of-items
        wt_tokens = _tokenize(names, wt_seqs)
        mut_tokens = _tokenize(names, mut_seqs)
        targets = torch.tensor(labels, dtype=torch.float32)
        return wt_tokens, mut_tokens, targets, list(names)

    return collate

# ── Model ─────────────────────────────────────────────────────────────────────
    
class DDGPredictor(nn.Module):
    """Siamese ESM regressor that predicts ΔΔG from a (wild-type, mutant) pair.

    Both sequences are encoded by the *same* ESM model (shared weights), each is
    mean-pooled over its residue tokens, and the pair feature
    ``[w, m, m - w, m * w]`` (width ``4 * embed_dim``) is mapped to one scalar
    by a small MLP head.

    Args:
        esm_model: ESM model exposing ``num_layers``, ``embed_dim`` and a
            ``forward(tokens, repr_layers=..., return_contacts=...)`` that returns
            a dict whose ``"representations"`` entry is indexed by layer number.
        alphabet: ESM alphabet; only ``padding_idx`` and ``eos_idx`` are read.
        hidden: width of the MLP hidden layer.
        dropout: dropout probability after the hidden activation.
    """

    def __init__(self, esm_model, alphabet, hidden=256, dropout=0.2):
        super().__init__()
        self.esm, self.alphabet = esm_model, alphabet
        self.repr_layer = esm_model.num_layers  # pool the output of the final layer
        d = esm_model.embed_dim  # per-token embedding width (320 for esm2_t6_8M_UR50D)
        self.head = nn.Sequential(
            # The four pair features live on different scales (e.g. m*w vs m-w),
            # so the concatenated (B, 4*d) vector is normalised first.
            nn.LayerNorm(4 * d),
            nn.Linear(4 * d, hidden), nn.GELU(), nn.Dropout(dropout),
            nn.Linear(hidden, 1),  # one scalar ΔΔG per pair
        )

    def encode(self, tokens):
        """Mean-pool final-layer ESM embeddings over residue positions.

        Args:
            tokens: integer token ids, presumably ``(B, L)`` from the ESM batch
                converter with a leading BOS/CLS token — TODO confirm for other
                alphabets.

        Returns:
            ``(B, d)`` tensor. A row with no residue tokens (empty sequence)
            pools to a zero vector instead of NaN.
        """
        out = self.esm(tokens, repr_layers=[self.repr_layer], return_contacts=False)
        h = out["representations"][self.repr_layer]  # (B, L, d)
        # Keep residue positions only: drop padding, EOS and the leading CLS token.
        mask = (tokens != self.alphabet.padding_idx) & (tokens != self.alphabet.eos_idx)
        mask[:, 0] = False
        weights = mask.unsqueeze(-1).to(h.dtype)  # (B, L, 1), broadcast over d
        # clamp(min=1) prevents 0/0 -> NaN for sequences with no residues; their
        # numerator is all zeros, so they pool to a zero vector.
        counts = mask.sum(1, keepdim=True).clamp(min=1).to(h.dtype)
        return (h * weights).sum(1) / counts

    def forward(self, wt, mut):
        """Return predicted ΔΔG, shape ``(B,)``, for token batches ``wt`` and ``mut``."""
        w, m = self.encode(wt), self.encode(mut)  # shared encoder => Siamese
        pair = torch.cat([w, m, m - w, m * w], dim=-1)  # (B, 4*d)
        return self.head(pair).squeeze(-1)


def count_params(model):
    """Print a parameter table for ``model`` and its trainable/frozen totals.

    One row is printed per direct child module (total and trainable counts),
    followed by overall trainable, non-trainable and total counts and the
    parameter memory in MiB.

    The size line uses each parameter's real ``element_size()``, so it is
    correct for fp16/bf16/fp64 as well as fp32 (4 bytes per element).
    """
    params = list(model.parameters())
    trainable = sum(p.numel() for p in params if p.requires_grad)
    non_trainable = sum(p.numel() for p in params if not p.requires_grad)
    size_bytes = sum(p.numel() * p.element_size() for p in params)

    print(f"{'Module':<20} {'Params':>12} {'Trainable':>12}")
    print("-" * 46)
    for name, module in model.named_children():
        total = sum(x.numel() for x in module.parameters())
        learnable = sum(x.numel() for x in module.parameters() if x.requires_grad)
        print(f"{name:<20} {total:>12,} {learnable:>12,}")
    print("-" * 46)
    print(f"{'Trainable':<20} {trainable:>12,}")
    print(f"{'Non-trainable':<20} {non_trainable:>12,}")
    print(f"{'Total':<20} {trainable+non_trainable:>12,}")
    print(f"{'Size (MB)':<20} {size_bytes/1024**2:>11.3f}")

# ── Train ─────────────────────────────────────────────────────────────────────
def _set_backbone_trainable(esm_model, unfreeze_last_n):
    """Freeze all ESM weights, then unfreeze its last ``unfreeze_last_n`` layers and every nn.LayerNorm."""
    for p in esm_model.parameters():
        p.requires_grad = False
    if unfreeze_last_n > 0:
        for layer in esm_model.layers[-unfreeze_last_n:]:
            for p in layer.parameters():
                p.requires_grad = True
    # LayerNorm scale/shift parameters are always trained, even when no full
    # layer is unfrozen (unfreeze_last_n == 0), so the backbone can re-normalise
    # its features for the new task.
    for module in esm_model.modules():
        if isinstance(module, nn.LayerNorm):
            for p in module.parameters():
                p.requires_grad = True


def _predict_loader(model, loader):
    """Return ``(predictions, targets)`` as 1-D numpy arrays over every batch of ``loader``.

    Switches ``model`` to eval mode and runs without gradient tracking.
    """
    model.eval()
    preds, trues = [], []
    with torch.no_grad():
        for wt, mut, y, _ in loader:
            preds.append(model(wt.to(DEVICE), mut.to(DEVICE)).cpu().numpy())
            trues.append(y.cpu().numpy())
    return np.concatenate(preds), np.concatenate(trues)


def _train_one_epoch(model, loader, opt, trainable, grad_clip, train_esm, desc):
    """Run one optimisation pass over ``loader`` and return the sample-weighted train RMSE.

    ``trainable`` is the (precomputed) list of parameters whose global gradient
    norm is clipped to ``grad_clip``; ``train_esm`` selects train/eval mode for
    the ESM backbone.
    """
    model.train()
    # The backbone runs in training mode only when whole layers are fine-tuned.
    model.esm.train(train_esm)
    sse, count = 0.0, 0
    pbar = tqdm(loader, desc=desc, leave=False)
    for wt, mut, y, _ in pbar:
        wt, mut, y = wt.to(DEVICE), mut.to(DEVICE), y.to(DEVICE)
        loss = nn.functional.mse_loss(model(wt, mut), y)  # mean over the batch
        opt.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(trainable, grad_clip)
        opt.step()
        batch_loss = loss.item()
        bs = y.numel()
        sse += batch_loss * bs  # batch mean * batch size = batch sum of squared errors
        count += bs
        pbar.set_postfix({"loss": f"{batch_loss:.4f}"})
    return math.sqrt(sse / count)


def train(cfg):
    """Fine-tune a DDGPredictor on top of ESM-2 (t6, 8M) and return the best model.

    Config keys read: ``train_csv``, ``val_csv``, ``epochs``, ``batch_size``,
    ``lr_head``, ``lr_esm``, ``patience`` and, optionally, ``label_col``
    (default "ddG_ML"), ``unfreeze_last_n`` (0), ``grad_clip`` (1.0),
    ``lr_factor`` (0.5) and ``lr_patience`` (2).

    Side effects: creates ``checkpoints_esm_finetune_<timestamp>/``, saves the
    running history there after every epoch, and saves the best checkpoint
    (model, optimizer and scheduler state) to ``best_path``.

    Returns:
        ``(model, alphabet, history, best_path, best_rmse)``. ``model`` has the
        best-validation weights loaded. ``best_path`` is written only when some
        epoch improves on the initial ``inf``, so it may be missing if every
        validation RMSE was NaN.
    """
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    save_dir = f"checkpoints_esm_finetune_{timestamp}"
    os.makedirs(save_dir, exist_ok=True)
    best_path = os.path.join(save_dir, f"ddg_best_{timestamp}.pt")

    esm_model, alphabet = esm.pretrained.esm2_t6_8M_UR50D()
    unfreeze_last_n = int(cfg.get("unfreeze_last_n", 0))
    _set_backbone_trainable(esm_model, unfreeze_last_n)

    model = DDGPredictor(esm_model, alphabet).to(DEVICE)
    collate = make_collate(alphabet.get_batch_converter())  # tokenizes on the fly per batch

    # Honour a custom label column instead of silently using the dataset default.
    label_col = cfg.get("label_col", "ddG_ML")
    train_dl = DataLoader(ProteinPairDataset(cfg["train_csv"], label_col), cfg["batch_size"], shuffle=True, collate_fn=collate)
    val_dl = DataLoader(ProteinPairDataset(cfg["val_csv"], label_col), cfg["batch_size"], shuffle=False, collate_fn=collate)

    # Two parameter groups: the freshly initialised head uses lr_head, the
    # trainable pretrained backbone parameters use the much smaller lr_esm.
    head_params = [p for name, p in model.named_parameters() if p.requires_grad and not name.startswith("esm.")]
    esm_params = [p for name, p in model.named_parameters() if p.requires_grad and name.startswith("esm.")]
    opt = torch.optim.AdamW(
        [
            {"params": head_params, "lr": cfg["lr_head"]},
            {"params": esm_params, "lr": cfg["lr_esm"]},
        ],
        weight_decay=1e-2,
    )
    # Multiply both learning rates by lr_factor after lr_patience epochs without
    # a validation-RMSE improvement.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        opt,
        mode="min",
        factor=cfg.get("lr_factor", 0.5),
        patience=cfg.get("lr_patience", 2),
    )
    # requires_grad flags are fixed from here on, so build the clip list once.
    trainable = [p for p in model.parameters() if p.requires_grad]

    count_params(model)

    history = {"train_rmse": [], "val_rmse": [], "pearson": [], "spearman": []}

    # Start from the initial weights so best_state is always defined.
    best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
    best_rmse = float("inf")
    patience_left = cfg["patience"]

    for epoch in range(1, cfg["epochs"] + 1):  # 1-based for readable logs
        train_rmse = _train_one_epoch(
            model, train_dl, opt, trainable,
            cfg.get("grad_clip", 1.0), unfreeze_last_n > 0,
            f"Epoch {epoch}/{cfg['epochs']}",
        )

        p, t = _predict_loader(model, val_dl)
        val_rmse = float(np.sqrt(np.mean((p - t) ** 2)))
        r = float(pearsonr(p, t)[0])
        rho = float(spearmanr(p, t)[0])
        scheduler.step(val_rmse)
        lr_head = opt.param_groups[0]["lr"]
        lr_esm = opt.param_groups[1]["lr"]

        history["train_rmse"].append(train_rmse)
        history["val_rmse"].append(val_rmse)
        history["pearson"].append(r)
        history["spearman"].append(rho)

        # Persist the running history every epoch so an interrupted run keeps its curves.
        history_path = os.path.join(save_dir, f"history_epoch_{epoch}.pth")
        torch.save(history, history_path)
        print(f"  → Saved history to: {history_path}")

        print(f"Epoch {epoch:2d} | Train RMSE: {train_rmse:.4f} | Val RMSE: {val_rmse:.4f} | r: {r:.3f} | ρ: {rho:.3f}| lr_head: {lr_head:.2e} | lr_esm: {lr_esm:.2e}")

        if val_rmse < best_rmse - 1e-4:  # improvement must exceed a small tolerance
            best_rmse, patience_left = val_rmse, cfg["patience"]
            best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}

            checkpoint = {
                "epoch": epoch,
                "best_rmse": best_rmse,
                "cfg": cfg,
                "model_state": best_state,
                "optimizer_state": opt.state_dict(),
                "scheduler_state": scheduler.state_dict(),
            }
            torch.save(checkpoint, best_path)
            print(f"  → Saved best checkpoint to: {best_path}")
        else:
            patience_left -= 1
            if patience_left == 0:
                print(f"Early stopping. Best RMSE: {best_rmse:.4f}")
                break

    model.load_state_dict(best_state)
    return model, alphabet, history, best_path, best_rmse

# ── Plot ──────────────────────────────────────────────────────────────────────
def plot_history(history, save_dir=None):
    """Plot per-epoch train/val RMSE and validation correlations.

    Args:
        history: dict with equal-length lists under "train_rmse", "val_rmse",
            "pearson" and "spearman" (as built by ``train``).
        save_dir: directory for ``training_history.pdf``; when None the file is
            written to the current working directory.
    """
    epochs = range(1, len(history["train_rmse"]) + 1)
    fig, axes = plt.subplots(1, 2, figsize=(10, 3.5))

    axes[0].plot(epochs, history["train_rmse"], label="Train")
    axes[0].plot(epochs, history["val_rmse"],   label="Val")
    axes[0].set(xlabel="Epoch", ylabel="RMSE (kcal/mol)", title="Learning Curves")
    axes[0].legend(frameon=False)

    axes[1].plot(epochs, history["pearson"],  label="Pearson r")
    axes[1].plot(epochs, history["spearman"], label="Spearman ρ")
    axes[1].set(xlabel="Epoch", ylabel="Correlation", title="Validation Correlations")
    axes[1].legend(frameon=False)

    for ax in axes:
        ax.spines["top"].set_visible(False)
        ax.spines["right"].set_visible(False)

    plt.tight_layout()
    # os.path.join("", name) == name, so save_dir=None keeps the cwd behaviour.
    plt.savefig(os.path.join(save_dir or "", "training_history.pdf"), bbox_inches="tight")
    plt.show()


def plot_predictions(model, alphabet, cfg, save_dir=None):
    """Hexbin plot of predicted vs measured ΔΔG on the validation set.

    The plot is annotated with RMSE and Pearson r.

    Args:
        model: trained DDGPredictor.
        alphabet: ESM alphabet used to build the tokenizing collate function.
        cfg: config dict; reads ``val_csv``, ``batch_size`` and optionally ``label_col``.
        save_dir: directory for ``predictions.pdf``; when None the file is
            written to the current working directory.
    """
    collate = make_collate(alphabet.get_batch_converter())
    dataset = ProteinPairDataset(cfg["val_csv"], cfg.get("label_col", "ddG_ML"))
    val_dl  = DataLoader(dataset, cfg["batch_size"], collate_fn=collate)
    preds, trues = [], []
    model.eval()
    with torch.no_grad():
        for wt, mut, y, _ in val_dl:
            preds.append(model(wt.to(DEVICE), mut.to(DEVICE)).cpu().numpy())
            trues.append(y.cpu().numpy())
    p, t = np.concatenate(preds), np.concatenate(trues)

    fig, ax = plt.subplots(figsize=(4.5, 4.5))
    ax.hexbin(p, t, gridsize=50, cmap="Blues", mincnt=1)
    lim = [min(p.min(), t.min()), max(p.max(), t.max())]
    ax.plot(lim, lim, "k--", lw=1, alpha=0.5)  # y = x reference line
    ax.set(xlabel="Predicted ΔΔG", ylabel="Measured ΔΔG", aspect="equal")
    ax.text(0.05, 0.95,
            f"RMSE = {np.sqrt(np.mean((p-t)**2)):.3f}\nr = {pearsonr(p,t)[0]:.3f}",
            transform=ax.transAxes, va="top",
            bbox=dict(boxstyle="round", fc="white", ec="gray", alpha=0.8))
    ax.spines["top"].set_visible(False); ax.spines["right"].set_visible(False)
    plt.tight_layout()
    plt.savefig(os.path.join(save_dir or "", "predictions.pdf"), bbox_inches="tight")
    plt.show()

# ── Run ───────────────────────────────────────────────────────────────────────
sweep = [0, 1, 2, 3]  # number of final ESM layers to unfreeze in each run
results = []
all_histories = {}

for n in sweep:
    cfg_run = dict(CFG)
    cfg_run["unfreeze_last_n"] = n

    print(f"\n=== unfreeze_last_n = {n} ===")

    # train() returns a 5-tuple; its run directory is the checkpoint's parent folder.
    model, alphabet, history, best_path, best_rmse = train(cfg_run)
    save_dir = os.path.dirname(best_path)
    plot_history(history, save_dir=save_dir)
    plot_predictions(model, alphabet, cfg_run, save_dir=save_dir)
    final_history_path = os.path.join(save_dir, f"final_history_unfreeze_{n}.pth")
    torch.save(history, final_history_path)
    results.append((n, best_rmse, best_path))
    all_histories[n] = history

print("\nSummary:")
for n, rmse, path in sorted(results, key=lambda x: x[1]):
    print(f"n={n} | best_rmse={rmse:.4f} | {path}")

# Sweep comparison plot: it spans every run, so it is saved in the current
# working directory rather than in any single run's folder.
fig, ax = plt.subplots(figsize=(9, 5))
for n, h in all_histories.items():
    ax.plot(range(1, len(h["val_rmse"]) + 1), h["val_rmse"], label=f"unfreeze={n}")
ax.set_xlabel("Epoch")
ax.set_ylabel("Val RMSE")
ax.legend()
ax.set_title("Validation RMSE Comparison Across Unfreeze Levels")
plt.tight_layout()
plt.savefig("sweep_comparison.pdf", bbox_inches="tight")
plt.show()

Device: mps

=== unfreeze_last_n = 0 ===
Module                     Params    Trainable
----------------------------------------------
esm                     7,512,474        8,960
head                      330,753      330,753
----------------------------------------------
Trainable                 339,713
Non-trainable           7,503,514
Total                   7,843,227
Size (MB)                 29.920


Epoch 1/6:   0%|          | 0/21692 [00:00<?, ?it/s]

KeyboardInterrupt: 