In [1]:
import sys
from tqdm import tqdm
import os
import json
import time
import argparse

import pandas as pd
import numpy as np

from Bio import SeqIO
import esm

import torch
from torch import nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, Subset

from torchinfo import summary

from sklearn.metrics import f1_score, roc_auc_score, roc_curve, auc, precision_recall_curve
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

import wandb


# Get the path of the Python script
sys.path.append('./../../../src/')

from utils import *
from utils_torch import * 
from MHCCBM import *
from TAPPredictor_CNN import *


In [2]:
seed = 42
set_seed(42)
print("Seed: ", seed)

# load json file
with open('./../config/final_configs_5runs/run1.json') as jsonfile:
    config_dict = json.load(jsonfile)
config = config_dict['config']
print("config: ", config_dict)

Seed:  42
config:  {'project': 'MHCCBM', 'name': 'TAPCNNRun1', 'config': {'hidden_channels': [1024, 512, 256, 128, 16, 4], 'epochs': 200, 'classes': 2, 'batch_size': 8, 'lr': 1e-05, 'dataset': 'X.pkl, y.pkl (embedded peptide seq from classification', 'dropout_p': 0.0, 'architecture': 'CNN', 'seed': 42}}


In [15]:


seed = config['seed']
set_seed(seed)
print("Seed: ", seed)

# load X and y
with open('./../../../data/TAP/X.pkl','rb') as f:
    X = pickle.load(f)
f.close()

with open('./../../../data/TAP/y.pkl','rb') as f:
    y = pickle.load(f)
f.close()

# Scale the data
scaler = StandardScaler()
X = scaler.fit_transform(X.squeeze())
X = torch.tensor(X, dtype=torch.float32).reshape(X.shape[0],1,X.shape[1])

# Split the data
train_sequences, temp_sequences, train_labels, temp_labels = train_test_split(X, y, test_size=0.2, random_state=seed, stratify=y)
valid_sequences, test_sequences, valid_labels, test_labels = train_test_split(temp_sequences, temp_labels, test_size=0.5, random_state=seed, stratify=temp_labels)

# Create dataset and dataloaders
train_dataset = ProteinSequenceDataset(train_sequences, train_labels)
valid_dataset = ProteinSequenceDataset(valid_sequences, valid_labels)
test_dataset = ProteinSequenceDataset(test_sequences, test_labels)

train_loader = DataLoader(train_dataset, batch_size=config['batch_size'], shuffle=True)
valid_loader = DataLoader(valid_dataset, batch_size=config['batch_size'], shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=config['batch_size'], shuffle=False)


# Calculate class weights
labels_tensor = torch.tensor(train_labels)
class_counts = torch.bincount(labels_tensor)
pos_weight = class_counts[0]/class_counts[1]
print(pos_weight)


Seed:  42
tensor(0.4369)


In [16]:
train_sequences.shape

torch.Size([694, 1, 1280])

In [17]:
#### model init
input_size = train_sequences.shape[2] #embedding size for esm2_t33_650M_UR50D (allele + peptide)
model = TAPPredictor(input_size, 
                     hidden_channels=config['hidden_channels'], 
                     dropout_p=config['dropout_p'])


In [18]:
summary(model, input_size=(config['batch_size'],1,input_size), 
        col_names=[ "input_size","output_size", "num_params"])

Layer (type:depth-idx)                   Input Shape               Output Shape              Param #
TAPPredictor                             [8, 1, 1280]              [8, 1]                    --
├─Sequential: 1-1                        [8, 1, 1280]              [8, 1]                    --
│    └─Conv1d: 2-1                       [8, 1, 1280]              [8, 1024, 1280]           4,096
│    └─Tanh: 2-2                         [8, 1024, 1280]           [8, 1024, 1280]           --
│    └─MaxPool1d: 2-3                    [8, 1024, 1280]           [8, 1024, 640]            --
│    └─Conv1d: 2-4                       [8, 1024, 640]            [8, 512, 640]             1,573,376
│    └─Tanh: 2-5                         [8, 512, 640]             [8, 512, 640]             --
│    └─MaxPool1d: 2-6                    [8, 512, 640]             [8, 512, 320]             --
│    └─Conv1d: 2-7                       [8, 512, 320]             [8, 256, 320]             393,472
│    └─Tanh: 2-8    

In [19]:
# model training
start = time.time()
model.train_loop(train_loader=train_loader, valid_loader=valid_loader, pos_weight=pos_weight,
                 test_loader=None, config_dict=config_dict)
end = time.time()

time_elapsed = end-start
print("Time taken: ", time_elapsed)

result = {'time_elapsed': [time_elapsed]}
result['random_seed'] = [seed] 

#### performance
f1, auroc, auprc = model.eval_dataset(train_loader)
result['f1_train'], result['auroc_train'], result['auprc_train'] = [f1], [auroc], [auprc]


f1, auroc, auprc = model.eval_dataset(test_loader)
result['f1_test'], result['auroc_test'], result['auprc_test'] = [f1], [auroc], [auprc]

print(result)

VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016668877129753432, max=1.0…

epoch:  0
[Epoch 1, Batch 10] loss: 0.458
[Epoch 1, Batch 20] loss: 0.443
[Epoch 1, Batch 30] loss: 0.393
[Epoch 1, Batch 40] loss: 0.482
[Epoch 1, Batch 50] loss: 0.448
[Epoch 1, Batch 60] loss: 0.391
[Epoch 1, Batch 70] loss: 0.404
[Epoch 1, Batch 80] loss: 0.415
epoch:  1
[Epoch 2, Batch 10] loss: 0.413
[Epoch 2, Batch 20] loss: 0.433



KeyboardInterrupt

