In [None]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import math
from torch.utils.data import Dataset, DataLoader
from scipy.stats import pearsonr, spearmanr
import time

# -------------------------------------------------------------------
# 1. Tokenizer
# -------------------------------------------------------------------
AMINO_ACIDS = 'ACDEFGHIKLMNPQRSTVWY'
# Residue letter -> token id. Ids start at 1 so that 0 remains free for padding.
AA_TO_IDX = dict(zip(AMINO_ACIDS, range(1, len(AMINO_ACIDS) + 1)))
PAD_IDX = 0   # id appended to sequences shorter than max_len
UNK_IDX = 21  # id given to any character that is not in AMINO_ACIDS


def encode_sequence(seq, max_len=120):
    """Turn a residue string into a list of integer token ids.

    Each character of ``seq`` is looked up in ``AA_TO_IDX``; characters that
    are not found map to ``UNK_IDX``. Sequences longer than ``max_len`` are cut
    at ``max_len`` and shorter ones are right-padded with ``PAD_IDX``, so for a
    non-negative ``max_len`` the result always has exactly ``max_len`` entries.

    Args:
        seq: Iterable of single-character residue codes (typically a ``str``).
        max_len: Length of the returned list.

    Returns:
        A new ``list`` of ``int`` token ids.
    """
    token_ids = [AA_TO_IDX.get(residue, UNK_IDX) for residue in seq]
    # Slicing is a no-op for short inputs, and the pad count is <= 0 (hence an
    # empty list) for long ones, so one straight-line path covers both cases.
    token_ids = token_ids[:max_len]
    token_ids.extend([PAD_IDX] * (max_len - len(token_ids)))
    return token_ids

# -------------------------------------------------------------------
# 2. PyTorch Dataset
# -------------------------------------------------------------------
class ProteinStabilityDataset(Dataset):
    """Map-style dataset yielding (wild-type, mutant, ddG) training examples.

    The backing table is a CSV read with pandas and kept in ``self.data``.
    Items read the columns ``wt_seq``, ``aa_seq`` and ``ddG_ML`` from it.
    """

    def __init__(self, csv_file, max_len=120):
        """Read ``csv_file`` into memory and remember the encoding length.

        Args:
            csv_file: Anything accepted by ``pandas.read_csv``.
            max_len: Length passed to ``encode_sequence`` for every sequence.
        """
        print(f"Loading dataset from {csv_file}...")
        frame = pd.read_csv(csv_file)
        self.data = frame
        self.max_len = max_len
        print(f"Loaded {len(frame)} samples.")

    def __len__(self):
        """Return the number of rows in the loaded table."""
        return len(self.data)

    def _as_token_tensor(self, sequence):
        """Encode one residue string as a ``torch.long`` tensor of token ids."""
        return torch.tensor(encode_sequence(sequence, self.max_len), dtype=torch.long)

    def __getitem__(self, idx):
        """Return ``(wt_tokens, mut_tokens, ddg)`` for the row at position ``idx``.

        The two token tensors have dtype ``torch.long``; the target is a
        scalar ``torch.float32`` tensor built from the ``ddG_ML`` column.
        """
        record = self.data.iloc[idx]
        # Read every column up front so a missing column fails before any
        # encoding work is done.
        wild_type, mutant, ddg = record['wt_seq'], record['aa_seq'], record['ddG_ML']
        return (
            self._as_token_tensor(wild_type),
            self._as_token_tensor(mutant),
            torch.tensor(ddg, dtype=torch.float32),
        )

# -------------------------------------------------------------------
# 3. Positional Encoding
# -------------------------------------------------------------------
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position signals to batch-first embeddings.

    A ``(max_len, d_model)`` table is precomputed once and stored as the
    non-trainable buffer ``pe``: even feature columns hold ``sin`` and odd
    columns hold ``cos`` of ``position * 10000 ** (-2i / d_model)``.

    Args:
        d_model: Feature size of the embeddings. Both even and odd sizes are
            supported.
        dropout: Dropout probability applied after the table is added.
        max_len: Number of positions precomputed; inputs passed to ``forward``
            must not be longer than this.
    """

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        angles = position * div_term  # (max_len, ceil(d_model / 2))

        # Built directly as (max_len, d_model) so it adds onto batch_first tensors.
        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(angles)
        # When d_model is odd there is one fewer odd column than even column,
        # so the last frequency has no cosine slot and must be dropped here;
        # for even d_model the slice keeps every column.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Return ``dropout(x + pe[:seq_len])``.

        Args:
            x: Tensor of shape ``(batch_size, seq_len, d_model)``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        x = x + self.pe[:x.size(1), :].unsqueeze(0)
        return self.dropout(x)

# -------------------------------------------------------------------
# 4. The Siamese Transformer Model
# -------------------------------------------------------------------
class SiameseTransformer(nn.Module):
    """Shared-weight Transformer encoder that scores a (wild-type, mutant) pair.

    Both sequences go through the same embedding, positional encoding and
    Transformer encoder stack. Each is reduced to one vector by a
    padding-aware mean, and the difference ``mutant - wild_type`` is fed to a
    small MLP that outputs one value per pair.
    """

    def __init__(self, vocab_size=22, d_model=128, nhead=4, num_layers=3, dim_feedforward=256, max_len=120):
        super().__init__()
        self.d_model = d_model

        # Token embedding (PAD_IDX registered as the padding row) followed by
        # sinusoidal position information.
        self.embedding = nn.Embedding(vocab_size, d_model, padding_idx=PAD_IDX)
        self.pos_encoder = PositionalEncoding(d_model, max_len=max_len)

        # Stack of identical batch-first encoder layers.
        layer_kwargs = dict(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            batch_first=True,
            dropout=0.1,
        )
        self.transformer_encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(**layer_kwargs),
            num_layers=num_layers,
        )

        # Regression head applied to the pooled feature difference.
        self.fc = nn.Sequential(
            nn.Linear(d_model, 64),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(64, 1),
        )

    def forward_one_branch(self, x):
        """Encode one batch of token ids into pooled sequence features.

        Args:
            x: Token-id tensor of shape ``(batch_size, seq_len)``.

        Returns:
            Tensor of shape ``(batch_size, d_model)``: the mean of the encoder
            outputs over the positions whose token is not ``PAD_IDX``.
        """
        # True at padded positions; attention is told to ignore those keys.
        is_pad = x.eq(PAD_IDX)

        # Scale the embeddings by sqrt(d_model) before adding positions.
        hidden = self.pos_encoder(self.embedding(x) * math.sqrt(self.d_model))
        hidden = self.transformer_encoder(hidden, src_key_padding_mask=is_pad)

        # 1.0 for real tokens, 0.0 for pads, broadcast over the feature axis.
        keep = is_pad.logical_not().float().unsqueeze(-1)
        # The clamp avoids a division by zero for a row made only of padding.
        token_counts = keep.sum(dim=1).clamp(min=1e-9)
        return (hidden * keep).sum(dim=1) / token_counts

    def forward(self, wt_seq, mut_seq):
        """Predict one value per (wild-type, mutant) pair.

        Args:
            wt_seq: Wild-type token ids, shape ``(batch_size, seq_len)``.
            mut_seq: Mutant token ids, shape ``(batch_size, seq_len)``.

        Returns:
            Tensor of shape ``(batch_size,)``.
        """
        # The wild-type branch runs first, then the mutant branch.
        wt_features = self.forward_one_branch(wt_seq)
        mut_features = self.forward_one_branch(mut_seq)

        # Siamese difference -> regression head -> drop the trailing unit axis.
        return self.fc(mut_features - wt_features).squeeze(1)

# -------------------------------------------------------------------
# 5. Training and Validation Loop
# -------------------------------------------------------------------
def _train_one_epoch(model, loader, criterion, optimizer, device, epoch):
    """Run one optimisation pass over ``loader`` and return the per-sample mean loss."""
    model.train()
    loss_sum = 0.0
    n_seen = 0
    n_batches = len(loader)

    for batch_idx, (wt, mut, target) in enumerate(loader, start=1):
        wt, mut, target = wt.to(device), mut.to(device), target.to(device)

        optimizer.zero_grad()
        loss = criterion(model(wt, mut), target)
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        batch_size = target.size(0)
        # Assumes ``criterion`` returns a batch mean (nn.MSELoss default).
        # Weighting by the batch size keeps a short final batch from counting
        # as much as a full one in the epoch average.
        loss_sum += batch_loss * batch_size
        n_seen += batch_size

        if batch_idx % 200 == 0:
            print(f"   Epoch {epoch+1} | Batch {batch_idx}/{n_batches} | Current Batch MSE: {batch_loss:.4f}")

    return loss_sum / n_seen


def _evaluate(model, loader, criterion, device):
    """Return ``(per-sample mean loss, targets, predictions)`` over ``loader`` without gradients."""
    model.eval()
    loss_sum = 0.0
    pred_chunks = []
    target_chunks = []

    with torch.no_grad():
        for wt, mut, target in loader:
            wt, mut, target = wt.to(device), mut.to(device), target.to(device)

            predictions = model(wt, mut)
            # Same sample-count weighting as in training (see _train_one_epoch).
            loss_sum += criterion(predictions, target).item() * target.size(0)

            pred_chunks.append(predictions.cpu().numpy())
            target_chunks.append(target.cpu().numpy())

    all_targets = np.concatenate(target_chunks)
    all_preds = np.concatenate(pred_chunks)
    return loss_sum / len(all_targets), all_targets, all_preds


def train_and_evaluate(train_csv='data/mega_train.csv', val_csv='data/mega_val.csv', *,
                       epochs=10, batch_size=128, lr=0.0005, max_len=120):
    """Train a ``SiameseTransformer`` and print validation metrics after every epoch.

    Called with no arguments this uses the same files and hyperparameters as
    before. The reported train/validation MSE (and the RMSE derived from it)
    is the mean over samples, not the mean of per-batch means.

    Args:
        train_csv: CSV path for the training split.
        val_csv: CSV path for the validation split.
        epochs: Number of passes over the training data.
        batch_size: Batch size for both data loaders.
        lr: Adam learning rate.
        max_len: Sequence length used for encoding and for the model's
            positional-encoding table.

    Raises:
        ValueError: If either split contains no rows.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print("--- Initialization ---")
    print(f"Using compute device: {device}")

    train_dataset = ProteinStabilityDataset(train_csv, max_len=max_len)
    val_dataset = ProteinStabilityDataset(val_csv, max_len=max_len)
    if len(train_dataset) == 0 or len(val_dataset) == 0:
        # Fail early with a clear message instead of dividing by zero later.
        raise ValueError("Training and validation CSVs must each contain at least one sample.")

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    # From-scratch Transformer; max_len is forwarded so the positional table
    # always covers the encoded sequence length.
    model = SiameseTransformer(d_model=128, nhead=4, num_layers=3, max_len=max_len).to(device)
    criterion = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    print("\n--- Starting Training ---")

    for epoch in range(epochs):
        start_time = time.time()

        avg_train_loss = _train_one_epoch(model, train_loader, criterion, optimizer, device, epoch)
        avg_val_loss, all_targets, all_preds = _evaluate(model, val_loader, criterion, device)

        rmse = np.sqrt(avg_val_loss)
        pcc, _ = pearsonr(all_targets, all_preds)
        scc, _ = spearmanr(all_targets, all_preds)

        epoch_time = time.time() - start_time

        # --- EXTENSIVE EPOCH REPORT ---
        print("\n==================================================")
        print(f" Epoch {epoch+1}/{epochs} Completed in {epoch_time:.1f} seconds")
        print("--------------------------------------------------")
        print(" LOSS:")
        print(f"   -> Train MSE: {avg_train_loss:.4f}")
        print(f"   -> Val MSE:   {avg_val_loss:.4f}")
        print(f"   -> Val RMSE:  {rmse:.4f} kcal/mol")
        print(" METRICS:")
        print(f"   -> Pearson Corr (PCC):  {pcc:.4f}")
        print(f"   -> Spearman Corr (SCC): {scc:.4f}")
        print("==================================================\n")

# Entry point: run the full training/validation loop only when this file is
# executed directly, not when it is imported as a module.
if __name__ == "__main__":
    train_and_evaluate()