In [12]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from scipy.stats import pearsonr, spearmanr
from transformers import AutoTokenizer, EsmModel
import time
import lightning as L
import os
import scipy
import scipy.stats
import sklearn.metrics as skmetrics
import matplotlib.pyplot as plt
import seaborn as sns


class ProtEmbeddingDataset(Dataset):
    """Map-style dataset pairing precomputed embeddings with ``ddG_ML`` labels.

    Rows whose ``mut_type`` equals ``'wt'`` are dropped. Every remaining row
    yields ``(embedding, label)``, where the embedding is read from
    ``<tensor_folder>/<name>.pt`` and the label is the row's ``ddG_ML`` value.

    Args:
        tensor_folder: Directory holding one ``.pt`` file per row ``name``.
        csv_file: CSV with at least the columns ``name``, ``mut_type`` and
            ``ddG_ML``.
        repr_layer: Key into each file's ``'mean_representations'`` dict that
            selects which stored representation to return. Defaults to 6, the
            value that was previously hard-coded.
    """

    def __init__(self, tensor_folder, csv_file, repr_layer=6):
        self.tensor_folder = tensor_folder
        self.repr_layer = repr_layer
        self.df = pd.read_csv(csv_file)
        # Keep only non-wild-type rows and renumber them 0..n-1 so positional
        # indices line up with self.labels / self.ids.
        self.df = self.df[self.df.mut_type != 'wt'].reset_index(drop=True)
        # Convert to float32 once instead of casting on every __getitem__ call.
        self.labels = torch.tensor(self.df['ddG_ML'].values, dtype=torch.float32)
        # astype(str): pandas parses purely numeric names as ints, which would
        # break the "<name>.pt" string concatenation below.
        self.ids = self.df['name'].astype(str).values

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        """Return ``(embedding, label)`` for row ``idx``.

        ``label`` is a 0-d float32 tensor. ``embedding`` is whatever tensor is
        stored under ``['mean_representations'][repr_layer]`` in the file.
        Presumably it is 1-D of size 768 to match ``LinearModel``'s
        ``nn.Linear(768, 1)``. TODO: confirm against how the files were produced.
        """
        tensor_path = os.path.join(self.tensor_folder, self.ids[idx] + '.pt')
        # map_location='cpu': tensors saved from a GPU would otherwise be
        # restored onto CUDA inside DataLoader worker processes. Load on CPU
        # and let the training loop move batches to the device.
        record = torch.load(tensor_path, map_location='cpu')
        embedding = record['mean_representations'][self.repr_layer]
        return embedding, self.labels[idx]

In [13]:
# Data location and shuffle seed are named here so a fresh-kernel re-run uses
# the same paths and reproduces the same training-batch order.
DATA_DIR = 'data'
LOADER_SEED = 42

dataset_train = ProtEmbeddingDataset(
    os.path.join(DATA_DIR, 'mega_train_embeddings'),
    os.path.join(DATA_DIR, 'mega_train.csv'),
)
dataset_val = ProtEmbeddingDataset(
    os.path.join(DATA_DIR, 'mega_val_embeddings'),
    os.path.join(DATA_DIR, 'mega_val.csv'),
)

# A seeded generator makes the shuffled training order reproducible. The
# validation loader never shuffles, so it needs no generator.
loader_train = DataLoader(
    dataset_train, batch_size=1024, shuffle=True, num_workers=4,
    generator=torch.Generator().manual_seed(LOADER_SEED),
)
loader_val = DataLoader(dataset_val, batch_size=512, shuffle=False, num_workers=4)

In [14]:
class LinearModel(L.LightningModule):
    def __init__(self, lr=1e-3):
        super().__init__()
        self.linear  = nn.Linear(768, 1)
        self.loss_fn = nn.MSELoss()
        self.lr = lr

    def forward(self, x):
        return self.linear(x).squeeze(1)