In [64]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from scipy.stats import pearsonr, spearmanr
from transformers import AutoTokenizer, EsmModel
import time
import lightning as L
import os
import scipy
import scipy.stats
import sklearn.metrics as skmetrics
import matplotlib.pyplot as plt
import seaborn as sns
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

aa_alphabet = 'ACDEFGHIKLMNPQRSTVWY' # amino acid alphabet
aa_to_int = {aa: i for i, aa in enumerate(aa_alphabet)} # mapping from amino acid to number

# function to one hot encode sequence
def one_hot_encode(sequence):
    # initialize a zero matrix of shape (len(sequence), len(amino_acids))
    one_hot = torch.zeros(len(sequence), len(aa_alphabet))
    for i, aa in enumerate(sequence):
        # set the column corresponding to the amino acid to 1
        one_hot[i].scatter_(0, torch.tensor([aa_to_int[aa]]), 1)
    return one_hot


from torchmetrics.regression import PearsonCorrCoef, SpearmanCorrCoef
    
class SequenceTransformer(L.LightningModule):
    def __init__(self, lr=1e-3, d_model=64, nhead=4, num_layers=2):
        super().__init__()
        self.lr = lr
        
        # project one-hot (20) to d_model dimensions
        self.input_proj = nn.Linear(20, d_model)
        
        # transformer encoder
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=256,
            dropout=0.1,
            batch_first=True,
            activation='gelu'
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        
        # output projection: d_model â†’ 20 (one ddG per possible mutation)
        self.output_proj = nn.Linear(d_model, 20)
        
        self.loss_fn = nn.MSELoss()
        # Metrics
        self.val_pearson = PearsonCorrCoef()
        self.val_spearman = SpearmanCorrCoef()

    def forward(self, x):
        # x: (1, L, 20) - NO squeeze here
        out = self.input_proj(x)      # (1, L, d_model)
        out = self.transformer(out)   # (1, L, d_model)
        out = self.output_proj(out)   # (1, L, 20)
        return out
    
    def training_step(self, batch, batch_idx):
        x      = batch['sequence'].squeeze(1)  # squeeze dim 1 not 0
        mask   = batch['mask'].squeeze(1)
        target = batch['labels'].squeeze(1)
        pred   = self(x)
        # flatten before masking
        pred_flat   = pred.reshape(-1)  
        target_flat = target.reshape(-1)
        mask_flat   = mask.reshape(-1)
        loss = self.loss_fn(pred_flat[mask_flat==1], target_flat[mask_flat==1])
        self.log('train_loss', loss, prog_bar=True, on_epoch=True)
        # Update correlation metrics
        self.val_pearson(pred, target)
        self.val_spearman(pred, target)
        # Print Val stats to console/progress bar
        self.log("val_loss", loss,  on_epoch=True, prog_bar=True)
        self.log("val_pear", self.val_pearson, on_epoch=True, prog_bar=True)
        self.log("val_spear", self.val_spearman, on_epoch=True, prog_bar=True)
        return loss
    
    def validation_step(self, batch, batch_idx):
        x      = batch['sequence'].squeeze(1)
        mask   = batch['mask'].squeeze(1)
        target = batch['labels'].squeeze(1)
        pred   = self(x)
        pred_flat   = pred.reshape(-1)
        target_flat = target.reshape(-1)
        mask_flat   = mask.reshape(-1)
        
        pred_masked   = pred_flat[mask_flat==1]
        target_masked = target_flat[mask_flat==1]
        
        loss = self.loss_fn(pred_masked, target_masked)
        self.log('val_loss', loss, prog_bar=True, on_epoch=True, on_step=False)
        return {'loss': loss, 'preds': pred_masked.detach(), 'targets': target_masked.detach()}
    
    def validation_epoch_end(self, outputs):
        # concatenate all predictions and targets across batches
        preds   = torch.cat([o['preds']   for o in outputs]).cpu().numpy()
        targets = torch.cat([o['targets'] for o in outputs]).cpu().numpy()
        
        pearson  = scipy.stats.pearsonr(preds, targets)[0]
        spearman = scipy.stats.spearmanr(preds, targets)[0]
        
        self.log('val_pearson',  pearson,  prog_bar=True, on_epoch=True)
        self.log('val_spearman', spearman, prog_bar=True, on_epoch=True)

         
        
    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr = self.lr)

In [65]:
class SequenceData(Dataset):
    def __init__(self, csv_file, label_col="ddG_ML"):
        self.label_col = label_col
        self.df = pd.read_csv(csv_file, sep=",")
        self.df = self.df[self.df.mut_type != "wt"]
        self.df["mutation_pos"] = self.df["mut_type"].apply(lambda x: int(x[1:-1])-1)
        self.df["mutation_to"]  = self.df["mut_type"].apply(lambda x: aa_to_int[x[-1]])
        self.df = self.df.groupby("WT_name").agg(list)
        self.wt_names = self.df.index.values

        # preload everything into RAM
        print("Loading sequences into RAM...")
        self.data = []
        for wt_name in self.wt_names:
            mut_row  = self.df.loc[wt_name]
            seq      = mut_row["wt_seq"][0]
            seq_enc  = one_hot_encode(seq)
            L        = len(seq_enc)

            mask   = torch.zeros((1, L, 20))
            target = torch.zeros((1, L, 20))

            positions   = torch.tensor(mut_row["mutation_pos"])
            amino_acids = torch.tensor(mut_row["mutation_to"])
            labels      = torch.tensor(mut_row[label_col])

            for i in range(L):
                mask[0,   i, amino_acids[positions==i]] = 1
                target[0, i, amino_acids[positions==i]] = labels[positions==i]

            self.data.append({
                "sequence": seq_enc[None,:,:].float(),
                "mask":     mask,
                "labels":   target
            })
        print("Done!")

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

In [66]:
# use notebook 3's dataloader
dataset_train = SequenceData('data/mega_train.csv')
dataset_val   = SequenceData('data/mega_val.csv')
loader_train  = DataLoader(dataset_train, batch_size=1, shuffle=True)
loader_val    = DataLoader(dataset_val,   batch_size=1, shuffle=False)

model   = SequenceTransformer(lr=1e-5, d_model=64, nhead=4, num_layers=2)
trainer = L.Trainer(devices=1, max_epochs=20)
trainer.fit(model, loader_train, loader_val)

Loading sequences into RAM...
Done!


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
ðŸ’¡ Tip: For seamless cloud logging and experiment tracking, try installing [litlogger](https://pypi.org/project/litlogger/) to enable LitLogger, which logs metrics and artifacts automatically to the Lightning Experiments platform.
ðŸ’¡ Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name         | Type               | Params  | Mode  | FLOPs
---------------------------------------------------------------------
0 | input_proj   | Linear             | 1.3 K   | train | 0    
1 | transformer  | TransformerEncoder | 100.0 K | train | 0    
2 | output_proj  | Linear             | 1.3 K   | train | 0    
3 | loss_fn      | MSELoss            | 0       | train | 0    
4 | val_pearson  | PearsonCorrCoef    | 0     

Loading sequences into RAM...
Done!


Sanity Checking: |                                        | 0/? [00:00<?, ?it/s]

/home/course/bc_NN/.venv/lib/python3.11/site-packages/lightning/pytorch/utilities/_pytree.py:21: `isinstance(treespec, LeafSpec)` is deprecated, use `isinstance(treespec, TreeSpec) and treespec.is_leaf()` instead.
/home/course/bc_NN/.venv/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:434: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=15` in the `DataLoader` to improve performance.
/home/course/bc_NN/.venv/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:434: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=15` in the `DataLoader` to improve performance.


Training: |                                               | 0/? [00:00<?, ?it/s]

ValueError: Expected both predictions and target to be either 1- or 2-dimensional tensors, but got 3 and 3.

In [None]:
%reload_ext tensorboard
%tensorboard --logdir=lightning_logs