In [32]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from scipy.stats import pearsonr, spearmanr
from transformers import AutoTokenizer, EsmModel
import time
import lightning as L
import os
import scipy
import scipy.stats
import sklearn.metrics as skmetrics
import matplotlib.pyplot as plt
import seaborn as sns
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

aa_alphabet = 'ACDEFGHIKLMNPQRSTVWY' # amino acid alphabet
aa_to_int = {aa: i for i, aa in enumerate(aa_alphabet)} # mapping from amino acid to number

# function to one hot encode sequence
def one_hot_encode(sequence):
    # initialize a zero matrix of shape (len(sequence), len(amino_acids))
    one_hot = torch.zeros(len(sequence), len(aa_alphabet))
    for i, aa in enumerate(sequence):
        # set the column corresponding to the amino acid to 1
        one_hot[i].scatter_(0, torch.tensor([aa_to_int[aa]]), 1)
    return one_hot


    
class SequenceTransformer(L.LightningModule):
    def __init__(self, lr=1e-3, d_model=64, nhead=4, num_layers=2):
        super().__init__()
        self.lr = lr
        
        # project one-hot (20) to d_model dimensions
        self.input_proj = nn.Linear(20, d_model)
        
        # transformer encoder
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=256,
            dropout=0.1,
            batch_first=True,
            activation='gelu'
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        
        # output projection: d_model â†’ 20 (one ddG per possible mutation)
        self.output_proj = nn.Linear(d_model, 20)
        
        self.loss_fn = nn.MSELoss()

    def forward(self, x):
        # x: (1, L, 20)
        x = x.squeeze(0)
        out = self.input_proj(x)      # (1, L, d_model)
        out = self.transformer(out)   # (1, L, d_model)
        out = self.output_proj(out)   # (1, L, 20)
        return out

    def training_step(self, batch, batch_idx):
        x      = batch['sequence'].squeeze(0)  # (1, L, 20)
        mask   = batch['mask'].squeeze(0)       # (1, L, 20)
        target = batch['labels'].squeeze(0)     # (1, L, 20)
        pred   = self(x)             # (1, L, 20)
        loss   = self.loss_fn(pred[mask==1], target[mask==1])
        self.log('train_loss', loss, prog_bar=True, on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x      = batch['sequence'].squeeze(0)
        mask   = batch['mask'].squeeze(0)
        target = batch['labels'].squeeze(0)
        pred   = self(x)
        loss   = self.loss_fn(pred[mask==1], target[mask==1])
        self.log('val_loss', loss, prog_bar=True, on_epoch=True)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr, weight_decay=1e-4)
        scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.95)
        return {'optimizer': optimizer, 'lr_scheduler': scheduler}

In [33]:
class SequenceData(Dataset):
    def __init__(self, csv_file, label_col="ddG_ML"):
        self.df = pd.read_csv(csv_file, sep=",")
        self.label_col = label_col
        self.df = self.df[self.df.mut_type!="wt"]
        self.df["mutation_pos"] = self.df["mut_type"].apply(lambda x: int(x[1:-1])-1)
        self.df["mutation_to"] = self.df["mut_type"].apply(lambda x: aa_to_int[x[-1]])
        self.df = self.df.groupby("WT_name").agg(list)
        self.wt_names = self.df.index.values
        self.encoded_seqs = {}
        for wt_name in self.wt_names:
            mut_row = self.df.loc[wt_name]
            seq = mut_row["wt_seq"][0]
            self.encoded_seqs[wt_name] = one_hot_encode(seq)

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        wt_name = self.wt_names[idx]
        mut_row = self.df.loc[wt_name]
        sequence_encoding = self.encoded_seqs[wt_name]
        mask = torch.zeros((1, len(sequence_encoding), 20))
        target = torch.zeros((1, len(sequence_encoding), 20))
        positions = torch.tensor(mut_row["mutation_pos"])
        amino_acids = torch.tensor(mut_row["mutation_to"])
        labels = torch.tensor(mut_row[self.label_col])
        for i in range(len(sequence_encoding)):
            mask[0,i,amino_acids[positions==i]] = 1
            target[0,i,amino_acids[positions==i]] = labels[positions==i]
        return {"sequence": sequence_encoding[None,:,:].float(), "mask": mask, "labels": target}

In [34]:
# use notebook 3's dataloader
dataset_train = SequenceData('data/mega_train.csv')
dataset_val   = SequenceData('data/mega_val.csv')
loader_train  = DataLoader(dataset_train, batch_size=1, shuffle=True)
loader_val    = DataLoader(dataset_val,   batch_size=1, shuffle=False)

model   = SequenceTransformer(lr=1e-3, d_model=64, nhead=4, num_layers=2)
trainer = L.Trainer(devices=1, max_epochs=20)
trainer.fit(model, loader_train, loader_val)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
ðŸ’¡ Tip: For seamless cloud logging and experiment tracking, try installing [litlogger](https://pypi.org/project/litlogger/) to enable LitLogger, which logs metrics and artifacts automatically to the Lightning Experiments platform.
ðŸ’¡ Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name        | Type               | Params  | Mode  | FLOPs
--------------------------------------------------------------------
0 | input_proj  | Linear             | 1.3 K   | train | 0    
1 | transformer | TransformerEncoder | 100.0 K | train | 0    
2 | output_proj | Linear             | 1.3 K   | train | 0    
3 | loss_fn     | MSELoss            | 0       | train | 0    
----------------------------------------------------

Sanity Checking: |                                        | 0/? [00:00<?, ?it/s]

/home/course/bc_NN/.venv/lib/python3.11/site-packages/lightning/pytorch/utilities/_pytree.py:21: `isinstance(treespec, LeafSpec)` is deprecated, use `isinstance(treespec, TreeSpec) and treespec.is_leaf()` instead.
/home/course/bc_NN/.venv/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:434: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=15` in the `DataLoader` to improve performance.


IndexError: too many indices for tensor of dimension 2