In [35]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from scipy.stats import pearsonr, spearmanr
from transformers import AutoTokenizer, EsmModel
import time
import lightning as L
import os
import scipy
import scipy.stats
import sklearn.metrics as skmetrics
import matplotlib.pyplot as plt
import seaborn as sns
# Use the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')



In [36]:
class ProtEmbeddingDataset(Dataset):
    """Dataset pairing precomputed embeddings stored on disk with ddG labels.

    Every row of ``csv_file`` provides a ``name`` (used to locate
    ``<tensor_folder>/<name>.pt``) and a ``ddG_ML`` regression target.
    Each ``.pt`` file must deserialize to a mapping whose
    ``'mean_representations'`` entry is indexable by ``layer``.

    Args:
        tensor_folder: Directory holding one ``<name>.pt`` file per CSV row.
        csv_file: Path to a CSV containing at least the ``name`` and
            ``ddG_ML`` columns.
        layer: Key used to pick an entry out of ``mean_representations``.
            Defaults to 6, the value that was previously hard-coded.
    """

    def __init__(self, tensor_folder, csv_file, layer=6):
        self.tensor_folder = tensor_folder
        self.layer = layer
        self.df = pd.read_csv(csv_file)
        # Optional filter (disabled): drop wild-type rows before training.
        #self.df = self.df[self.df.mut_type != 'wt'].reset_index(drop=True)
        self.labels = torch.tensor(self.df['ddG_ML'].values)
        self.ids = self.df['name'].values

    def __len__(self):
        """Return the number of labelled rows read from the CSV."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Load the embedding for row ``idx`` and return ``(embedding, label)``.

        The label is returned as a float32 scalar tensor.
        """
        # str() keeps the path join working even if pandas parsed the
        # ``name`` column as a non-string dtype (e.g. purely numeric ids).
        tensor_path = os.path.join(self.tensor_folder, str(self.ids[idx]) + '.pt')
        # map_location='cpu': tensors that were saved from a GPU would
        # otherwise be restored onto CUDA, which breaks inside DataLoader
        # worker processes and on CPU-only machines. The training framework
        # moves the batch to the right device later.
        # NOTE(review): torch.load unpickles the file - only point this at
        # embedding files you generated yourself or otherwise trust.
        record = torch.load(tensor_path, map_location='cpu')
        embedding = record['mean_representations'][self.layer]
        label = self.labels[idx]
        return embedding, label.float()

In [37]:
# Training split: embedding folder plus the CSV with names and labels.
dataset_train = ProtEmbeddingDataset(
    tensor_folder='data/mega_train_embeddings',
    csv_file='data/mega_train.csv',
)

# Validation split, laid out the same way as the training split.
dataset_val = ProtEmbeddingDataset(
    tensor_folder='data/mega_val_embeddings',
    csv_file='data/mega_val.csv',
)

# Only the training loader shuffles; validation keeps the CSV row order.
loader_train = DataLoader(
    dataset_train,
    batch_size=1024,
    shuffle=True,
    num_workers=4,
)
loader_val = DataLoader(
    dataset_val,
    batch_size=512,
    shuffle=False,
    num_workers=4,
)

In [44]:

from torchmetrics.regression import PearsonCorrCoef, SpearmanCorrCoef

class StabModel(L.LightningModule):
    """One-hidden-layer MLP regressor trained with MSE loss.

    The network is ``Linear(input_dim, hidden_dim) -> ReLU ->
    Linear(hidden_dim, 1)``. During validation the Pearson and Spearman
    correlations between predictions and targets are accumulated over the
    epoch and logged alongside the loss.

    Args:
        input_dim: Size of the last dimension of the input embeddings.
        lr: Learning rate passed to the Adam optimizer.
        hidden_dim: Width of the hidden layer. Defaults to 256, the value
            that was previously hard-coded.
    """

    def __init__(self, input_dim=768, lr=1e-3, hidden_dim=256):
        super().__init__()
        # Store the constructor arguments in checkpoints so that
        # ``StabModel.load_from_checkpoint`` can rebuild the module even
        # when non-default values were used.
        self.save_hyperparameters()

        self.model = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
        )
        self.loss_fn = nn.MSELoss()
        self.lr = lr

        # Epoch-level validation metrics
        self.val_pearson = PearsonCorrCoef()
        self.val_spearman = SpearmanCorrCoef()

    def forward(self, x):
        """Return one prediction per row of ``x``.

        The trailing size-1 output dimension is squeezed away so the result
        lines up with the 1-D target vector used by the loss.
        """
        return self.model(x).squeeze(1)

    def training_step(self, batch, batch_idx):
        """Compute the MSE loss for one training batch and log it."""
        x, y = batch
        preds = self(x)
        loss = self.loss_fn(preds, y)

        # Only the epoch average is logged (on_step=False) and shown in the
        # progress bar.
        self.log("train_loss", loss, on_step=False, on_epoch=True, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Log validation loss and accumulate the correlation metrics."""
        x, y = batch
        preds = self(x)
        loss = self.loss_fn(preds, y)

        # ``update`` only accumulates state. Calling the metric object
        # directly would additionally compute a per-batch value that is never
        # logged here, since only epoch-level values are requested.
        self.val_pearson.update(preds, y)
        self.val_spearman.update(preds, y)

        # Passing the metric objects lets Lightning call compute()/reset()
        # at the end of the validation epoch.
        self.log("val_loss", loss, on_epoch=True, prog_bar=True)
        self.log("val_pear", self.val_pearson, on_epoch=True, prog_bar=True)
        self.log("val_spear", self.val_spearman, on_epoch=True, prog_bar=True)

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters using ``self.lr``."""
        return torch.optim.Adam(self.parameters(), lr=self.lr)

In [49]:
# Seed the Python, NumPy and PyTorch RNGs before the model is built so that
# weight initialisation and batch shuffling - and therefore the logged
# metrics - are repeatable from one run of this cell to the next.
L.seed_everything(42, workers=True)

model = StabModel(lr=1e-3)

# Short run: a single device, two passes over the training data.
trainer = L.Trainer(devices=1, max_epochs=2)
trainer.fit(model, loader_train, loader_val)


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
💡 Tip: For seamless cloud logging and experiment tracking, try installing [litlogger](https://pypi.org/project/litlogger/) to enable LitLogger, which logs metrics and artifacts automatically to the Lightning Experiments platform.
💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name         | Type             | Params | Mode  | FLOPs
------------------------------------------------------------------
0 | model        | Sequential       | 197 K  | train | 0    
1 | loss_fn      | MSELoss          | 0      | train | 0    
2 | val_pearson  | PearsonCorrCoef  | 0      | train | 0    
3 | val_spearman | SpearmanCorrCoef | 0      | train | 0    
----------------------------------------------------------------

Sanity Checking: |                                        | 0/? [00:00<?, ?it/s]

/home/course/bc_NN/.venv/lib/python3.11/site-packages/lightning/pytorch/utilities/_pytree.py:21: `isinstance(treespec, LeafSpec)` is deprecated, use `isinstance(treespec, TreeSpec) and treespec.is_leaf()` instead.


Training: |                                               | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=2` reached.


In [50]:
%reload_ext tensorboard
%tensorboard --logdir=lightning_logs/val_spear=0.221, train 


Reusing TensorBoard on port 6009 (pid 199461), started 0:00:36 ago. (Use '!kill 199461' to kill it.)