In [None]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from scipy.stats import pearsonr, spearmanr
from transformers import AutoTokenizer, EsmModel
import time

# -------------------------------------------------------------------
# 1. PyTorch Dataset with ESM Tokenizer
# -------------------------------------------------------------------
class ProteinStabilityESMDataset(Dataset):
    """Paired wild-type / mutant protein sequences with their ddG labels.

    Each CSV row must provide ``wt_seq`` (wild-type sequence), ``aa_seq``
    (mutant sequence) and ``ddG_ML`` (regression target). Both sequences are
    tokenized on the fly with identical settings and padded/truncated to
    ``max_len`` tokens (special tokens included). Every item therefore has
    fixed-length 1-D tensors that the default DataLoader collate can stack.

    Args:
        csv_file: Path to the CSV file.
        tokenizer: A Hugging Face-style tokenizer callable (e.g. the ESM
            tokenizer) that accepts ``return_tensors='pt'``.
        max_len: Total token length per sequence, including special tokens.
            Longer sequences are truncated silently.

    Raises:
        KeyError: If any required column is missing from the CSV.
    """

    REQUIRED_COLUMNS = ('wt_seq', 'aa_seq', 'ddG_ML')

    def __init__(self, csv_file, tokenizer, max_len=120):
        print(f"Loading dataset from {csv_file}...")
        self.data = pd.read_csv(csv_file)
        self.tokenizer = tokenizer
        self.max_len = max_len

        # Fail fast with a clear message instead of a KeyError on the first batch.
        missing = [col for col in self.REQUIRED_COLUMNS if col not in self.data.columns]
        if missing:
            raise KeyError(f"{csv_file} is missing required column(s): {missing}")

        # Pull the columns into plain Python lists once. DataFrame.iloc[idx]
        # builds a new Series on every call, which is a significant
        # per-sample cost when __getitem__ runs hundreds of thousands of
        # times per epoch.
        self._wt_seqs = self.data['wt_seq'].tolist()
        self._mut_seqs = self.data['aa_seq'].tolist()
        self._ddgs = self.data['ddG_ML'].tolist()
        print(f"Loaded {len(self.data)} samples.")

    def __len__(self):
        return len(self.data)

    def _encode(self, seq):
        """Tokenize one sequence; return 1-D ``(input_ids, attention_mask)`` of length ``max_len``."""
        encoded = self.tokenizer(
            seq,
            add_special_tokens=True,
            max_length=self.max_len,
            padding='max_length',
            truncation=True,
            return_tensors='pt',
        )
        # return_tensors='pt' adds a leading batch dimension of size 1; drop it.
        return encoded['input_ids'].squeeze(0), encoded['attention_mask'].squeeze(0)

    def __getitem__(self, idx):
        wt_ids, wt_mask = self._encode(self._wt_seqs[idx])
        mut_ids, mut_mask = self._encode(self._mut_seqs[idx])
        return {
            'wt_input_ids': wt_ids,
            'wt_attention_mask': wt_mask,
            'mut_input_ids': mut_ids,
            'mut_attention_mask': mut_mask,
            'target': torch.tensor(self._ddgs[idx], dtype=torch.float32),
        }

# -------------------------------------------------------------------
# 2. The Siamese ESM Model
# -------------------------------------------------------------------
class SiameseESM(nn.Module):
    """Siamese ddG regressor built on a shared ESM encoder.

    The same ESM encoder embeds the wild-type and mutant sequences. Each
    embedding is pooled to a single vector, and a small MLP maps
    ``mutant - wild_type`` to a scalar ddG prediction.

    Args:
        model_name: Hugging Face checkpoint name passed to ``EsmModel.from_pretrained``.
        freeze_esm: If True (default), ESM weights get ``requires_grad=False``
            and the encoder is kept in eval mode even when ``model.train()``
            is called, so only the prediction head learns.
        pooling: How per-token hidden states become one vector per sequence.
            ``'cls'`` (default, original behaviour) takes the first token.
            ``'mean'`` averages the hidden states of all attended positions
            (``attention_mask == 1``, which includes the special tokens).

    Raises:
        ValueError: If ``pooling`` is not one of ``POOLING_MODES``.
    """

    POOLING_MODES = ('cls', 'mean')

    def __init__(self, model_name='facebook/esm2_t6_8M_UR50D', freeze_esm=True, pooling='cls'):
        super(SiameseESM, self).__init__()

        # Validate before the (slow, possibly network-bound) checkpoint download.
        if pooling not in self.POOLING_MODES:
            raise ValueError(f"pooling must be one of {self.POOLING_MODES}, got {pooling!r}")
        self.pooling = pooling
        self.freeze_esm = freeze_esm

        # Load pre-trained ESM-2 Model.
        # NOTE(review): the load-time warning about newly initialized
        # 'esm.pooler.*' weights is harmless here, because forward() only reads
        # last_hidden_state and never the pooler output.
        print(f"Loading pre-trained ESM model: {model_name}")
        self.esm = EsmModel.from_pretrained(model_name)

        # Freezing keeps the pre-trained representations intact and avoids
        # gradients / optimizer state for the encoder's parameters.
        if freeze_esm:
            for param in self.esm.parameters():
                param.requires_grad = False
            self.esm.eval()

        # ESM hidden size (320 for the 8M model)
        hidden_size = self.esm.config.hidden_size

        # Prediction head: pooled embedding difference -> scalar ddG.
        self.fc = nn.Sequential(
            nn.Linear(hidden_size, 128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, 1)
        )

    def train(self, mode=True):
        """Set train/eval mode, but keep a frozen ESM encoder in eval mode.

        A frozen backbone receives no updates. Running it in training mode
        would only enable training-time-only behaviour, such as dropout if
        the checkpoint config enables it. That would make its embeddings
        differ between the training and validation passes.
        """
        super().train(mode)
        if self.freeze_esm:
            self.esm.eval()
        return self

    def _embed(self, input_ids, attention_mask):
        """Encode a batch with ESM and pool it to shape (batch, hidden_size)."""
        hidden = self.esm(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
        if self.pooling == 'cls':
            # Position 0 is the <cls> token prepended by the tokenizer.
            return hidden[:, 0, :]
        # Masked mean: padding positions contribute nothing. The clamp guards
        # against an all-zero mask.
        mask = attention_mask.unsqueeze(-1).to(hidden.dtype)
        return (hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1.0)

    def forward(self, wt_ids, wt_mask, mut_ids, mut_mask):
        """Return predicted ddG, one value per pair, shape (batch,)."""
        wt_vec = self._embed(wt_ids, wt_mask)
        mut_vec = self._embed(mut_ids, mut_mask)
        # The head sees only the mutant-vs-wild-type difference.
        ddg_prediction = self.fc(mut_vec - wt_vec)
        return ddg_prediction.squeeze(1)

# -------------------------------------------------------------------
# 3. Training and Validation Loop
# -------------------------------------------------------------------
def _unpack_batch(batch, device):
    """Move one collated dataset batch to *device*.

    Returns ``(wt_ids, wt_mask, mut_ids, mut_mask, target)``. The first four
    are in the order ``SiameseESM.forward`` expects.
    """
    return (
        batch['wt_input_ids'].to(device, non_blocking=True),
        batch['wt_attention_mask'].to(device, non_blocking=True),
        batch['mut_input_ids'].to(device, non_blocking=True),
        batch['mut_attention_mask'].to(device, non_blocking=True),
        batch['target'].to(device, non_blocking=True),
    )


def _predict(model, loader, device):
    """Run *model* over *loader* without gradients; return ``(preds, targets)`` as 1-D CPU tensors."""
    model.eval()
    preds, targets = [], []
    with torch.no_grad():
        for batch in loader:
            *inputs, target = _unpack_batch(batch, device)
            preds.append(model(*inputs).cpu())
            targets.append(target.cpu())
    return torch.cat(preds), torch.cat(targets)


def train_and_evaluate(
    *,
    train_csv='data/mega_train.csv',
    val_csv='data/mega_val.csv',
    model_name='facebook/esm2_t6_8M_UR50D',
    max_len=120,
    batch_size=64,
    lr=0.001,
    epochs=10,
    log_every=500,
):
    """Train the SiameseESM prediction head on ddG data and report validation metrics per epoch.

    All arguments are keyword-only and default to the values the script
    originally hard-coded, so a bare ``train_and_evaluate()`` call uses the
    same data, model and hyperparameters.

    Args:
        train_csv: Path to the training CSV.
        val_csv: Path to the validation CSV.
        model_name: ESM checkpoint used for both the tokenizer and the encoder.
        max_len: Token length per sequence (special tokens included).
        batch_size: Batch size for both loaders.
        lr: Adam learning rate for the trainable (head) parameters.
        epochs: Number of training epochs.
        log_every: Print the current batch MSE every this many batches (0 disables).

    Returns:
        ``(model, history)``: the trained model, and one metrics dict per
        epoch with keys ``epoch``, ``train_mse``, ``val_mse``, ``val_rmse``,
        ``pcc`` and ``scc``.

    Raises:
        ValueError: If either CSV yields no samples.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print("--- Initialization ---")
    print(f"Using compute device: {device}")

    # 1. ESM tokenizer (must match the encoder checkpoint).
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    # 2. Data
    train_dataset = ProteinStabilityESMDataset(train_csv, tokenizer, max_len=max_len)
    val_dataset = ProteinStabilityESMDataset(val_csv, tokenizer, max_len=max_len)
    if len(train_dataset) == 0 or len(val_dataset) == 0:
        raise ValueError("Training and validation CSVs must each contain at least one row.")

    # Pinned host memory lets the non_blocking .to(device) copies overlap on CUDA.
    pin = device.type == 'cuda'
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, pin_memory=pin)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, pin_memory=pin)

    # 3. Model. Only the unfrozen parameters (the FC head) go to the optimizer.
    model = SiameseESM(model_name=model_name, freeze_esm=True).to(device)
    criterion = nn.MSELoss()
    optimizer = torch.optim.Adam(
        [p for p in model.parameters() if p.requires_grad],
        lr=lr,
    )

    history = []
    print("\n--- Starting Training ---")

    for epoch in range(epochs):
        start_time = time.time()

        # --- TRAINING PHASE ---
        model.train()
        sum_sq_err = 0.0
        for batch_idx, batch in enumerate(train_loader):
            *inputs, target = _unpack_batch(batch, device)

            optimizer.zero_grad()
            predictions = model(*inputs)
            loss = criterion(predictions, target)
            loss.backward()
            optimizer.step()

            # MSELoss averages within a batch. Re-weight by batch size so the
            # (usually smaller) final batch doesn't skew the epoch mean.
            sum_sq_err += loss.item() * target.size(0)

            if log_every and (batch_idx + 1) % log_every == 0:
                print(f"   Epoch {epoch+1} | Batch {batch_idx+1}/{len(train_loader)} | Current Batch MSE: {loss.item():.4f}")

        avg_train_loss = sum_sq_err / len(train_dataset)

        # --- VALIDATION PHASE ---
        # Metrics are computed over all predictions at once, not averaged
        # per batch, so every sample carries equal weight.
        preds, targets = _predict(model, val_loader, device)
        avg_val_loss = criterion(preds.double(), targets.double()).item()
        rmse = float(np.sqrt(avg_val_loss))
        preds_np, targets_np = preds.numpy(), targets.numpy()
        pcc, _ = pearsonr(targets_np, preds_np)
        scc, _ = spearmanr(targets_np, preds_np)

        epoch_time = time.time() - start_time

        # --- EXTENSIVE EPOCH REPORT ---
        print("\n==================================================")
        print(f" Epoch {epoch+1}/{epochs} Completed in {epoch_time:.1f} seconds")
        print("--------------------------------------------------")
        print(" LOSS:")
        print(f"   -> Train MSE: {avg_train_loss:.4f}")
        print(f"   -> Val MSE:   {avg_val_loss:.4f}")
        print(f"   -> Val RMSE:  {rmse:.4f} kcal/mol")
        print(" METRICS:")
        print(f"   -> Pearson Corr (PCC):  {pcc:.4f}")
        print(f"   -> Spearman Corr (SCC): {scc:.4f}")
        print("==================================================\n")

        history.append({
            'epoch': epoch + 1,
            'train_mse': avg_train_loss,
            'val_mse': avg_val_loss,
            'val_rmse': rmse,
            'pcc': float(pcc),
            'scc': float(scc),
        })

    return model, history

# Entry point: runs the full training + validation loop when executed as a
# script (in a notebook cell, __name__ is also "__main__", so it runs there too).
if __name__ == "__main__":
    train_and_evaluate()

--- Initialization ---
Using compute device: cuda
Loading dataset from data/mega_train.csv...
Loaded 216919 samples.
Loading dataset from data/mega_val.csv...
Loaded 27481 samples.
Loading pre-trained ESM model: facebook/esm2_t6_8M_UR50D


Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t6_8M_UR50D and are newly initialized: ['esm.pooler.dense.bias', 'esm.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.



--- Starting Training ---
   Epoch 1 | Batch 500/3390 | Current Batch MSE: 0.7536
