## Imports

In [48]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from icecream import ic
import pandas as pd
import pytorch_lightning as pl
import ast
import wandb

## Load dataset

In [49]:
residue_seqs_csv_path = '/workspace/protein_lm/evaluation/binding_site_prediction/data/residue_seqs_processed.csv'
residue_seqs_df = pd.read_csv(residue_seqs_csv_path)
residue_seqs_df['aligned_labels_with_gaps'] = residue_seqs_df['aligned_labels_with_gaps'].apply(ast.literal_eval)
residue_seqs_df['aligned_labels'] = residue_seqs_df['aligned_labels'].apply(ast.literal_eval)
residue_seqs_df['labels'] = residue_seqs_df['labels'].apply(ast.literal_eval)
residue_seqs_df.head()

Unnamed: 0,AC_ID,wt_seq,ptm_seq,pdb_id_with_chain_name,aligned_labels_with_gaps,aligned_labels,labels,ppbs_seq_alignment,ptm_seq_alignment,is_test,is_val,is_train,esm_650m_embedding_path,esm_650m_embedding_padding_mask_path,mamba_with_ptms_embedding_path,mamba_with_ptms_embedding_padding_mask_path,mamba_without_ptms_embedding_path,mamba_without_ptms_embedding_padding_mask_path,ppbs_embedding_path
0,C4YMW2,MSTNKITFLLNWEAAPYHIPVYLANIKGYFKDENLDIAILEPSNPS...,MSTNKITFLLNWEAAPYHIPVYLANIKGYFKDENLDIAILEPSNPS...,4esw_A,"[1, 1, 1, 0, 1, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0, ...","[1, 1, 1, 0, 1, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0, ...","[0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 0, 0, 0, ...",GSHMSTNKITFLLNWEAAPYHIPVYLANIKGYFKDENLDIAILEPS...,---MSTNKITFLLNWEAAPYHIPVYLANIKGYFKDENLDIAILEPS...,0,1,0,/workspace/protein_lm/evaluation/binding_site_...,/workspace/protein_lm/evaluation/binding_site_...,/workspace/protein_lm/evaluation/binding_site_...,/workspace/protein_lm/evaluation/binding_site_...,/workspace/protein_lm/evaluation/binding_site_...,/workspace/protein_lm/evaluation/binding_site_...,/workspace/protein_lm/evaluation/binding_site_...
1,P97291,MPERLAETLMDLWTPLIILWITLPSCVYTAPMNQAHVLTTGSPLEL...,MPERLAETLMDLWTPLIILWITLPSCVYTAPMNQAHVLTTGSPLEL...,1zxk_A,"[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 0, 0, 1, 0, 0, ...",------------------------S---------------------...,MPERLAETLMDLWTPLIILWITLPSCVYTAPMNQAHVLTTGSPLEL...,0,0,1,/workspace/protein_lm/evaluation/binding_site_...,/workspace/protein_lm/evaluation/binding_site_...,/workspace/protein_lm/evaluation/binding_site_...,/workspace/protein_lm/evaluation/binding_site_...,/workspace/protein_lm/evaluation/binding_site_...,/workspace/protein_lm/evaluation/binding_site_...,/workspace/protein_lm/evaluation/binding_site_...
2,P93114,MAVPMDTISGPWGNNGGNFWSFRPVNKINQIVISYGGGGNNPIALT...,M<N-acetylalanine>VPMDTISGPWGNNGGNFWSFRPVNKINQ...,1ouw_B,"[-1, -1, 1, 1, 1, 1, 0, 0, 0, 0, 1, 0, 1, 1, 1...","[0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 1, 0, 1, 1, 1, ...","[1, 1, 1, 1, 0, 0, 0, 0, 1, 0, 1, 1, 1, 0, 0, ...",--VPMDTISGPWGNNGGNFWSFRPVNKINQIVISYGGGGNNPIALT...,MAVPMDTISGPWGNNGGNFWSFRPVNKINQIVISYGGGGNNPIALT...,1,0,0,/workspace/protein_lm/evaluation/binding_site_...,/workspace/protein_lm/evaluation/binding_site_...,/workspace/protein_lm/evaluation/binding_site_...,/workspace/protein_lm/evaluation/binding_site_...,/workspace/protein_lm/evaluation/binding_site_...,/workspace/protein_lm/evaluation/binding_site_...,/workspace/protein_lm/evaluation/binding_site_...
3,P42262,MQKIMHISVLLSPVLWGLIFGVSSNSIQIGGLFPRGADQEYSAFRV...,MQKIMHISVLLSPVLWGLIFGVSSNSIQIGGLFPRGADQEYSAFRV...,2xhd_A,"[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",------------------------N---------------------...,MQKIMHISVLLSPVLWGLIFGVSSNSIQIGGLFPRGADQEYSAFRV...,0,0,1,/workspace/protein_lm/evaluation/binding_site_...,/workspace/protein_lm/evaluation/binding_site_...,/workspace/protein_lm/evaluation/binding_site_...,/workspace/protein_lm/evaluation/binding_site_...,/workspace/protein_lm/evaluation/binding_site_...,/workspace/protein_lm/evaluation/binding_site_...,/workspace/protein_lm/evaluation/binding_site_...
4,O75208,MAAAAVSGALGRAGWRLLQLRCLPVARCRQALVPRAFHASAVGLRS...,MAAAAVSGALGRAGWRLLQLRCLPVARCRQALVPRAFHASAVGLRS...,4rhp_B,"[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",----------------------------------------------...,MAAAAVSGALGRAGWRLLQLRCLPVARCRQALVPRAFHASAVGLRS...,0,0,1,/workspace/protein_lm/evaluation/binding_site_...,/workspace/protein_lm/evaluation/binding_site_...,/workspace/protein_lm/evaluation/binding_site_...,/workspace/protein_lm/evaluation/binding_site_...,/workspace/protein_lm/evaluation/binding_site_...,/workspace/protein_lm/evaluation/binding_site_...,/workspace/protein_lm/evaluation/binding_site_...


In [50]:
max_seq_len = max(residue_seqs_df['wt_seq'].apply(len))
ic(max_seq_len)

ic| max_seq_len: 6486


6486

## Create dataloaders

In [72]:
class ResiduePredictionDataset(Dataset):
	def __init__(self, df, embedding_column, padding_mask_column, max_seq_len, cutoff_seq_len=1000, label_column='aligned_labels'):
		self.df = df
		self.embedding_column = embedding_column
		self.padding_mask_column = padding_mask_column
		self.max_seq_len = max_seq_len
		self.cutoff_seq_len = cutoff_seq_len
		self.label_column = label_column

	def __len__(self):
		return len(self.df)

	def __getitem__(self, idx):
		row = self.df.iloc[idx]
		ppbs_seq = row['ppbs_seq_alignment'].replace('-', '')
		embedding_path = row[self.embedding_column]
		embedding = torch.load(embedding_path)
		labels = torch.tensor(row[self.label_column], dtype=torch.float32)
		# ic(embedding.shape, labels.shape, len(ppbs_seq))
		# ic(embedding, labels, ppbs_seq)
		return {'embedding': embedding, 'labels': labels}


test_dataset = ResiduePredictionDataset(residue_seqs_df, 'esm_650m_embedding_path', 'esm_650m_embedding_padding_mask_path', max_seq_len)

embedding, labels = next(iter(test_dataset))

In [73]:
from torch.nn.utils.rnn import pad_sequence

def crop_seq(input_ids, max_seq_len):
    """
    randomly crop sequences to max_seq_len
    Args:
        input_ids: tensor of shape (seq_len)
        max_seq_len: int
    """
    seq_len = len(input_ids)
    if seq_len <= max_seq_len:
        return input_ids
    else:
        start_idx = torch.randint(0, seq_len - max_seq_len + 1, (1,)).item()
        return input_ids[start_idx : start_idx + max_seq_len]
def collate_fn(batch):
    labels = [item['labels'] for item in batch]
    min_seq_len = min([len(label) for label in labels])
    embeddings = [item['embedding'] for item in batch]
    embeddings = torch.stack([crop_seq(embedding, min_seq_len) for embedding in embeddings])

    labels = torch.stack([crop_seq(label, min_seq_len) for label in labels])
    # labels = pad_sequence(
    #         labels,
    #         batch_first=True,
    #         padding_value=-1,
    # )
    # embeddings = pad_sequence(
    #     embeddings, 
    #     batch_first=True, 
    #     padding_value=0.0  # Assuming 0.0 is an appropriate padding value for your embeddings
    # )
    pad_mask = (labels != -1)
    return embeddings, labels, pad_mask
    # embeddings is a list of tensors, each of shape (seq_len, embedding_dim) please pad them to the same length
	

In [74]:
train_df = residue_seqs_df[residue_seqs_df['is_train'] == True]
val_df = residue_seqs_df[residue_seqs_df['is_val'] == True]
test_df = residue_seqs_df[residue_seqs_df['is_test'] == True]

ic(len(train_df), len(val_df), len(test_df))

ic| len(train_df): 2544, len(val_df): 263, len(test_df): 243


(2544, 263, 243)

In [75]:
esm_train_dataset = ResiduePredictionDataset(train_df, 'esm_650m_embedding_path', 'esm_650m_embedding_padding_mask_path', max_seq_len)
esm_val_dataset = ResiduePredictionDataset(val_df, 'esm_650m_embedding_path', 'esm_650m_embedding_padding_mask_path', max_seq_len)
esm_test_dataset = ResiduePredictionDataset(test_df, 'esm_650m_embedding_path', 'esm_650m_embedding_padding_mask_path', max_seq_len)

mamba_wt_train_dataset = ResiduePredictionDataset(train_df, 'mamba_without_ptms_embedding_path', 'mamba_without_ptms_embedding_padding_mask_path', max_seq_len)
mamba_wt_val_dataset = ResiduePredictionDataset(val_df, 'mamba_without_ptms_embedding_path', 'mamba_without_ptms_embedding_padding_mask_path', max_seq_len)
mamba_wt_test_dataset = ResiduePredictionDataset(test_df, 'mamba_without_ptms_embedding_path', 'mamba_without_ptms_embedding_padding_mask_path', max_seq_len)

mamba_ptm_train_dataset = ResiduePredictionDataset(train_df, 'mamba_with_ptms_embedding_path', 'mamba_with_ptms_embedding_padding_mask_path', max_seq_len)
mamba_ptm_val_dataset = ResiduePredictionDataset(val_df, 'mamba_with_ptms_embedding_path', 'mamba_with_ptms_embedding_padding_mask_path', max_seq_len)
mamba_ptm_test_dataset = ResiduePredictionDataset(test_df, 'mamba_with_ptms_embedding_path', 'mamba_with_ptms_embedding_padding_mask_path', max_seq_len)

ppbs_train_dataset = ResiduePredictionDataset(train_df, 'ppbs_embedding_path', '', max_seq_len, label_column='labels')
ppbs_val_dataset = ResiduePredictionDataset(val_df, 'ppbs_embedding_path', '', max_seq_len, label_column='labels')
ppbs_test_dataset = ResiduePredictionDataset(test_df, 'ppbs_embedding_path', '', max_seq_len, label_column='labels')

In [76]:
batch_size = 64
esm_train_loader = DataLoader(esm_train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
esm_val_loader = DataLoader(esm_val_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)
esm_test_loader = DataLoader(esm_test_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)

mamba_wt_train_loader = DataLoader(mamba_wt_train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
mamba_wt_val_loader = DataLoader(mamba_wt_val_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)
mamba_wt_test_loader = DataLoader(mamba_wt_test_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)

mamba_ptm_train_loader = DataLoader(mamba_ptm_train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
mamba_ptm_val_loader = DataLoader(mamba_ptm_val_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)
mamba_ptm_test_loader = DataLoader(mamba_ptm_test_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)

ppbs_train_loader = DataLoader(ppbs_train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
ppbs_val_loader = DataLoader(ppbs_val_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)
ppbs_test_loader = DataLoader(ppbs_test_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)

In [77]:
element = next(iter(ppbs_train_dataset))
embedding = element['embedding']
labels = element['labels']

ic(embedding.shape, labels.shape)
ic(embedding, labels)

ic| embedding.shape: torch.Size([96, 1280])
    labels.shape: torch.

Size([96])
ic| embedding: tensor([[ 0.0678,  0.1634, -0.0771,  ...,  0.1612, -0.0652, -0.0625],
                       [-0.0928, -0.0198, -0.1412,  ..., -0.0006, -0.1817, -0.0120],
                       [-0.0697,  0.1247, -0.0773,  ...,  0.0706, -0.0806, -0.0698],
                       ...,
                       [-0.1904,  0.1109,  0.0012,  ..., -0.3720, -0.0700,  0.0319],
                       [ 0.1291,  0.0144,  0.2626,  ...,  0.1445,  0.0790, -0.1674],
                       [ 0.0236,  0.0195,  0.0934,  ..., -0.3019, -0.1133,  0.3432]])
    labels: tensor([1., 1., 1., 1., 1., 1., 0., 1., 0., 1., 0., 0., 1., 0., 0., 0., 0., 0.,
                    1., 1., 0., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 1., 0.,
                    0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
                    0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
                    1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 0., 1.

(tensor([[ 0.0678,  0.1634, -0.0771,  ...,  0.1612, -0.0652, -0.0625],
         [-0.0928, -0.0198, -0.1412,  ..., -0.0006, -0.1817, -0.0120],
         [-0.0697,  0.1247, -0.0773,  ...,  0.0706, -0.0806, -0.0698],
         ...,
         [-0.1904,  0.1109,  0.0012,  ..., -0.3720, -0.0700,  0.0319],
         [ 0.1291,  0.0144,  0.2626,  ...,  0.1445,  0.0790, -0.1674],
         [ 0.0236,  0.0195,  0.0934,  ..., -0.3019, -0.1133,  0.3432]]),
 tensor([1., 1., 1., 1., 1., 1., 0., 1., 0., 1., 0., 0., 1., 0., 0., 0., 0., 0.,
         1., 1., 0., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 1., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 0., 1., 0., 1.,
         0., 0., 1., 0., 0., 0.]))

In [78]:
embedding_path = '/workspace/protein_lm/evaluation/binding_site_prediction/data/embeddings/A0A0B4J1L0_mamba_with_ptms.pt'
embedding = torch.load(embedding_path)

## Define model

In [91]:
from pytorch_lightning.utilities.types import EVAL_DATALOADERS
from sklearn.metrics import f1_score, matthews_corrcoef, recall_score
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import pytorch_lightning as pl

class Model(nn.Module):
	def __init__(self, input_dim, hidden_dim):
		super().__init__()
		self.linear1 = nn.Linear(input_dim, hidden_dim)
		self.linear2 = nn.Linear(hidden_dim, 2)
	
	def forward(self, x):
		x = F.relu(self.linear1(x))
		x = self.linear2(x)
		return x  # Removed softmax here

class ResiduePredictionModel(pl.LightningModule):
	def __init__(self, input_dim, hidden_dim, test_loader, lr=1e-3):
		super().__init__()
		self.model = Model(input_dim, hidden_dim)
		self.lr = lr
		self.test_loader = test_loader

	def training_step(self, batch, batch_idx):
		embeddings, labels, padding_mask = batch
		padding_mask = padding_mask.to(torch.bool)
		labels = labels.to(torch.long).unsqueeze(-1)
		labels = labels.expand(-1, -1, 2)
		labels[:, :, 1] = 1 - labels[:, :, 0]
		labels = labels.float()
		predictions = self.model(embeddings)
		# ic(labels.shape, labels.dtype) # (64, 38)
		# ic(predictions.shape, predictions.dtype) # (64, 38, 2)
		# ic(labels)
		# ic(predictions)
		# loss_fn = nn.BCEWithLogitsLoss()
		# loss = loss_fn(predictions, labels)

		# Corrected accuracy calculation
		true_labels = labels.argmax(dim=-1)
		pred_labels = predictions.argmax(dim=-1)
		acc = (true_labels == pred_labels).float().mean()
		# f1 = f1_score(labels.cpu().numpy(), predictions.argmax(dim=-1).cpu().numpy())
		# mcc = matthews_corrcoef(labels.cpu().numpy(), predictions.argmax(dim=-1).cpu().numpy())
		# recall = recall_score(labels.cpu().numpy(), predictions.argmax(dim=-1).cpu().numpy())
		self.log('train_loss', loss, prog_bar=True)
		self.log('train_acc', acc, prog_bar=True)
		# self.log('train_f1', f1, prog_bar=True)
		# self.log('train_mcc', mcc, prog_bar=True)
		# self.log('train_recall', recall, prog_bar=True)
		return loss
	
	def configure_optimizers(self):
		return optim.Adam(self.parameters(), lr=self.lr)
	
	# def validation_step(self, batch, batch_idx):
	# 	ic('validation step')
	# 	embeddings, labels, padding_mask = batch
	# 	padding_mask = padding_mask.to(torch.bool)
	# 	labels = labels.to(torch.long)
	# 	predictions = self.model(embeddings)
	# 	masked_predictions = predictions#[padding_mask]
	# 	masked_labels = labels#[padding_mask]
	# 	loss = F.cross_entropy(masked_predictions, masked_labels)
	# 	# Corrected accuracy calculation
	# 	accuracy = (masked_predictions.argmax(dim=-1) == masked_labels).float().mean()
	# 	self.log('val_loss', loss, prog_bar=True)
	# 	self.log('val_accuracy', accuracy, prog_bar=True)
	# 	f1 = f1_score(masked_labels.cpu().numpy(), masked_predictions.argmax(dim=-1).cpu().numpy())
	# 	self.log('val_f1', f1, prog_bar=True)
	# 	return loss
	
	# def test_step(self, batch, batch_idx):
	# 	embeddings, labels, padding_mask = batch
	# 	padding_mask = padding_mask.to(torch.bool)
	# 	labels = labels.to(torch.long)
	# 	predictions = self.model(embeddings)
	# 	masked_predictions = predictions[padding_mask]
	# 	masked_labels = labels[padding_mask]
	# 	loss = F.cross_entropy(masked_predictions, masked_labels)
	# 	# Corrected accuracy calculation
	# 	accuracy = (masked_predictions.argmax(dim=-1) == masked_labels).float().mean()
	# 	f1 = f1_score(masked_labels.cpu().numpy(), masked_predictions.argmax(dim=-1).cpu().numpy())
	# 	self.log('test_loss', loss, prog_bar=True)
	# 	self.log('test_accuracy', accuracy, prog_bar=True)
	# 	return loss
	
	def test_dataloader(self):
		return self.test_loader
		

# Assuming esm_650m_embedding_dim, input_dim, hidden_dim, and learning rate are defined elsewhere
esm_650m_embedding_dim = 1280
input_dim = esm_650m_embedding_dim
hidden_dim = 512
model = ResiduePredictionModel(input_dim, hidden_dim, esm_test_loader, lr=2e-2)
trainer = pl.Trainer(max_epochs=20, devices=[5],overfit_batches=0)  # Ensure overfit_batches is set appropriately
# Assuming esm_train_loader and esm_val_loader are defined elsewhere
trainer.fit(model, ppbs_train_loader, ppbs_val_loader)


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/configuration_validator.py:72: You passed in a `val_dataloader` but have no `validation_step`. Skipping val loop.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]

  | Name  | Type  | Params
--------------------------------
0 | model | Model | 656 K 
--------------------------------
656 K     Trainable params
0         Non-trainable params
656 K     Total params
2.628     Total estimated model params size (MB)
/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=255` in the `DataLoader` to improve performance.
/usr/local/lib/python3.10/dist-packages/pytorch_lightning/

Training:   0%|          | 0/40 [00:00<?, ?it/s]

Epoch 13:  22%|██▎       | 9/40 [00:00<00:01, 21.01it/s, v_num=90, train_loss=0.457, train_acc=0.683]   

/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/call.py:54: Detected KeyboardInterrupt, attempting graceful shutdown...
