In [1]:
import sys
from tqdm import tqdm
import os
import json
import time
import argparse

import pandas as pd
import numpy as np

from Bio import SeqIO
import esm

import torch
from torch import nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, Subset

from torchinfo import summary

from sklearn.metrics import f1_score, roc_auc_score, roc_curve, auc, precision_recall_curve
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler, MinMaxScaler


# Get the path of the Python script
sys.path.append('./../../../src/')

from utils import *
from utils_torch import * 
from MHCCBM import *
# from TAPPredictor import *
from TAPPredictor_CNN import *


seed = 42
set_seed(42)
print("Seed: ", seed)



Seed:  42


In [2]:
# Best config [128], 5, 32, 1e-05, 0.1

config_dict = { "config" : {
                    "hidden_channels" : [256],
                    "kernel_size": 5,
                    "pool_kernel_size":5,
                    "epochs" : 2000,
                    "batch_size" : 32,
                    "lr" : 1e-5,
                    "dropout_p" : 0.0,
                    "architecture" : "CNN"
                }
        }
config = config_dict['config']

In [3]:

# load X and y
with open('./../../../data/TAP/X.pkl','rb') as f:
    X = pickle.load(f)
f.close()

with open('./../../../data/TAP/y.pkl','rb') as f:
    y = pickle.load(f)
f.close()

# Scale the data
scaler = MinMaxScaler()
X = scaler.fit_transform(X.squeeze())
X = torch.tensor(X, dtype=torch.float32)

# Split the data
train_sequences, test_sequences, train_labels, test_labels = train_test_split(X, y, test_size=0.2, 
                                                                              random_state=seed, stratify=y)

# Create dataset and dataloaders
train_dataset = ProteinSequenceDataset(train_sequences.reshape(train_sequences.shape[0],1,-1), train_labels)
test_dataset = ProteinSequenceDataset(test_sequences.reshape(test_sequences.shape[0],1,-1), test_labels)

train_loader = DataLoader(train_dataset, batch_size=config['batch_size'], shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=config['batch_size'], shuffle=False)

# Calculate class weights
labels_tensor = torch.tensor(train_labels)
class_counts = torch.bincount(labels_tensor)
pos_weight = class_counts[0]/class_counts[1].to(dtype=torch.float32)



In [4]:

#### model training
input_size = X.shape[1] #embedding size for esm2_t33_650M_UR50D (allele + peptide)

model = TAPPredictor(input_size, config['hidden_channels'], config['kernel_size'], config['pool_kernel_size'], config['dropout_p'])


start = time.time()
model.train_loop(train_loader=train_loader, valid_loader=test_loader, test_loader=None, config_dict=config_dict, pos_weight=pos_weight)
end = time.time()

time_elapsed = end-start
print("Time taken: ", time_elapsed)

result = {'time_elapsed': [time_elapsed]}
result['random_seed'] = [seed] 

#### performance
f1, auroc, auprc = model.eval_dataset(train_loader)
result['f1_train'], result['auroc_train'], result['auprc_train'] = [f1], [auroc], [auprc]

f1, auroc, auprc = model.eval_dataset(test_loader)
result['f1_test'], result['auroc_test'], result['auprc_test'] = [f1], [auroc], [auprc]

print(result)

epoch:  0
[Epoch 1, Batch 10] loss: 0.438
[Epoch 1, Batch 20] loss: 0.433
epoch:  1 val_f1:  0.5901463330834583 val_auroc:  0.5582410728208326 val_auprc:  0.7594384686523318
epoch:  1
[Epoch 2, Batch 10] loss: 0.422
[Epoch 2, Batch 20] loss: 0.424
epoch:  2 val_f1:  0.5861024429095563 val_auroc:  0.6156245127085607 val_auprc:  0.7988013665388077
Validation loss improved to 2.508080
epoch:  2
[Epoch 3, Batch 10] loss: 0.409
[Epoch 3, Batch 20] loss: 0.420
epoch:  3 val_f1:  0.6093087460358488 val_auroc:  0.6295025728987993 val_auprc:  0.8043088069230493
Validation loss improved to 2.503404
epoch:  3
[Epoch 4, Batch 10] loss: 0.411
[Epoch 4, Batch 20] loss: 0.425
epoch:  4 val_f1:  0.5787484035759899 val_auroc:  0.6401060346171838 val_auprc:  0.8136486333621866
Validation loss improved to 2.498994
epoch:  4
[Epoch 5, Batch 10] loss: 0.425
[Epoch 5, Batch 20] loss: 0.411
epoch:  5 val_f1:  0.5572386425834701 val_auroc:  0.6527366287229066 val_auprc:  0.8197128822183968
Validation loss imp

[Epoch 40, Batch 10] loss: 0.380
[Epoch 40, Batch 20] loss: 0.394
epoch:  40 val_f1:  0.6520976080269605 val_auroc:  0.7740527054420708 val_auprc:  0.8827524839247265
Validation loss improved to 2.290646
epoch:  40
[Epoch 41, Batch 10] loss: 0.387
[Epoch 41, Batch 20] loss: 0.383
epoch:  41 val_f1:  0.6520976080269605 val_auroc:  0.7753001715265866 val_auprc:  0.8833808537439765
Validation loss improved to 2.286018
epoch:  41
[Epoch 42, Batch 10] loss: 0.394
[Epoch 42, Batch 20] loss: 0.378
epoch:  42 val_f1:  0.6632983744534352 val_auroc:  0.77576797130828 val_auprc:  0.8836851540943131
Validation loss improved to 2.281099
epoch:  42
[Epoch 43, Batch 10] loss: 0.372
[Epoch 43, Batch 20] loss: 0.394
epoch:  43 val_f1:  0.6632983744534352 val_auroc:  0.7760798378294089 val_auprc:  0.8839723416639156
Validation loss improved to 2.275784
epoch:  43
[Epoch 44, Batch 10] loss: 0.390
[Epoch 44, Batch 20] loss: 0.381
epoch:  44 val_f1:  0.6283399968083326 val_auroc:  0.7770154373927959 val_au

epoch:  78 val_f1:  0.7128216474828565 val_auroc:  0.8061749571183534 val_auprc:  0.8972606204613077
Validation loss improved to 2.128097
epoch:  78
[Epoch 79, Batch 10] loss: 0.356
[Epoch 79, Batch 20] loss: 0.348
epoch:  79 val_f1:  0.7128216474828565 val_auroc:  0.8080461562451271 val_auprc:  0.8982147774919189
Validation loss improved to 2.125435
epoch:  79
[Epoch 80, Batch 10] loss: 0.349
[Epoch 80, Batch 20] loss: 0.358
epoch:  80 val_f1:  0.6964502492933143 val_auroc:  0.808669889287385 val_auprc:  0.8984254233899613
Validation loss improved to 2.121533
epoch:  80
[Epoch 81, Batch 10] loss: 0.357
[Epoch 81, Batch 20] loss: 0.353
epoch:  81 val_f1:  0.734051724137931 val_auroc:  0.8096054888507719 val_auprc:  0.8987480960852318
Validation loss improved to 2.118173
epoch:  81
[Epoch 82, Batch 10] loss: 0.351
[Epoch 82, Batch 20] loss: 0.357
epoch:  82 val_f1:  0.7179805988236051 val_auroc:  0.8096054888507719 val_auprc:  0.8987195627573292
Validation loss improved to 2.113707
epoc

epoch:  116 val_f1:  0.7562140804597702 val_auroc:  0.8264462809917356 val_auprc:  0.9035540426099098
Validation loss improved to 2.006737
epoch:  116
[Epoch 117, Batch 10] loss: 0.329
[Epoch 117, Batch 20] loss: 0.317
epoch:  117 val_f1:  0.7711879124332689 val_auroc:  0.826914080773429 val_auprc:  0.9036581347832403
EarlyStopping counter: 1 out of 50
epoch:  117
[Epoch 118, Batch 10] loss: 0.324
[Epoch 118, Batch 20] loss: 0.332
epoch:  118 val_f1:  0.7504739275354663 val_auroc:  0.8267581475128646 val_auprc:  0.9036116145516921
Validation loss improved to 2.001630
epoch:  118
[Epoch 119, Batch 10] loss: 0.321
[Epoch 119, Batch 20] loss: 0.333
epoch:  119 val_f1:  0.7659044407870053 val_auroc:  0.8270700140339935 val_auprc:  0.903674763259597
Validation loss improved to 2.000200
epoch:  119
[Epoch 120, Batch 10] loss: 0.318
[Epoch 120, Batch 20] loss: 0.327
epoch:  120 val_f1:  0.7606155187988662 val_auroc:  0.827225947294558 val_auprc:  0.903732058944333
Validation loss improved to 

[Epoch 154, Batch 10] loss: 0.324
[Epoch 154, Batch 20] loss: 0.285
epoch:  154 val_f1:  0.7823494776804266 val_auroc:  0.8353344768439108 val_auprc:  0.9049338621184001
Validation loss improved to 1.922055
epoch:  154
[Epoch 155, Batch 10] loss: 0.296
[Epoch 155, Batch 20] loss: 0.306
epoch:  155 val_f1:  0.787016661160945 val_auroc:  0.8359582098861686 val_auprc:  0.905145985424394
EarlyStopping counter: 1 out of 50
epoch:  155
[Epoch 156, Batch 10] loss: 0.308
[Epoch 156, Batch 20] loss: 0.303
epoch:  156 val_f1:  0.7615639752005565 val_auroc:  0.8362700764072977 val_auprc:  0.9053655247372976
Validation loss improved to 1.917398
epoch:  156
[Epoch 157, Batch 10] loss: 0.309
[Epoch 157, Batch 20] loss: 0.299
epoch:  157 val_f1:  0.7823494776804266 val_auroc:  0.8361141431467332 val_auprc:  0.9051762590809185
EarlyStopping counter: 1 out of 50
epoch:  157
[Epoch 158, Batch 10] loss: 0.314
[Epoch 158, Batch 20] loss: 0.295
epoch:  158 val_f1:  0.7823494776804266 val_auroc:  0.83580227

[Epoch 192, Batch 10] loss: 0.286
[Epoch 192, Batch 20] loss: 0.293
epoch:  192 val_f1:  0.787016661160945 val_auroc:  0.8358022766256042 val_auprc:  0.9032138730827298
EarlyStopping counter: 1 out of 50
epoch:  192
[Epoch 193, Batch 10] loss: 0.275
[Epoch 193, Batch 20] loss: 0.296
epoch:  193 val_f1:  0.7876436781609196 val_auroc:  0.8356463433650397 val_auprc:  0.9031792554410683
Validation loss improved to 1.865552
epoch:  193
[Epoch 194, Batch 10] loss: 0.278
[Epoch 194, Batch 20] loss: 0.293
epoch:  194 val_f1:  0.7876436781609196 val_auroc:  0.8359582098861688 val_auprc:  0.9032515051293016
Validation loss improved to 1.863360
epoch:  194
[Epoch 195, Batch 10] loss: 0.290
[Epoch 195, Batch 20] loss: 0.281
epoch:  195 val_f1:  0.787016661160945 val_auroc:  0.8364260096678622 val_auprc:  0.9033929850905027
EarlyStopping counter: 1 out of 50
epoch:  195
[Epoch 196, Batch 10] loss: 0.278
[Epoch 196, Batch 20] loss: 0.285
epoch:  196 val_f1:  0.7876436781609196 val_auroc:  0.83611414

[Epoch 230, Batch 10] loss: 0.280
[Epoch 230, Batch 20] loss: 0.263
epoch:  230 val_f1:  0.8028367790353751 val_auroc:  0.8406362077031031 val_auprc:  0.9048238682398909
Validation loss improved to 1.832485
epoch:  230
[Epoch 231, Batch 10] loss: 0.270
[Epoch 231, Batch 20] loss: 0.275
epoch:  231 val_f1:  0.7817429034556151 val_auroc:  0.8406362077031031 val_auprc:  0.9048238682398909
Validation loss improved to 1.826854
epoch:  231
[Epoch 232, Batch 10] loss: 0.276
[Epoch 232, Batch 20] loss: 0.274
epoch:  232 val_f1:  0.8028367790353751 val_auroc:  0.8406362077031031 val_auprc:  0.9047973467375668
EarlyStopping counter: 1 out of 50
epoch:  232
[Epoch 233, Batch 10] loss: 0.275
[Epoch 233, Batch 20] loss: 0.270
epoch:  233 val_f1:  0.8081137119249642 val_auroc:  0.8406362077031031 val_auprc:  0.904604412772711
EarlyStopping counter: 2 out of 50
epoch:  233
[Epoch 234, Batch 10] loss: 0.261
[Epoch 234, Batch 20] loss: 0.292
epoch:  234 val_f1:  0.8028367790353751 val_auroc:  0.8409480

[Epoch 268, Batch 10] loss: 0.244
[Epoch 268, Batch 20] loss: 0.268
epoch:  268 val_f1:  0.8194011406352603 val_auroc:  0.8435989396538283 val_auprc:  0.9064610233468257
EarlyStopping counter: 5 out of 50
epoch:  268
[Epoch 269, Batch 10] loss: 0.265
[Epoch 269, Batch 20] loss: 0.263
epoch:  269 val_f1:  0.8147446423308492 val_auroc:  0.8442226726960861 val_auprc:  0.9072260483479022
Validation loss improved to 1.806771
epoch:  269
[Epoch 270, Batch 10] loss: 0.263
[Epoch 270, Batch 20] loss: 0.262
epoch:  270 val_f1:  0.8186802482535972 val_auroc:  0.8440667394355216 val_auprc:  0.9071005075672119
EarlyStopping counter: 1 out of 50
epoch:  270
[Epoch 271, Batch 10] loss: 0.264
[Epoch 271, Batch 20] loss: 0.245
epoch:  271 val_f1:  0.8186802482535972 val_auroc:  0.8442226726960861 val_auprc:  0.9071792618234199
EarlyStopping counter: 2 out of 50
epoch:  271
[Epoch 272, Batch 10] loss: 0.253
[Epoch 272, Batch 20] loss: 0.270
epoch:  272 val_f1:  0.8147446423308492 val_auroc:  0.84437860

[Epoch 306, Batch 10] loss: 0.246
[Epoch 306, Batch 20] loss: 0.249
epoch:  306 val_f1:  0.8306885693455566 val_auroc:  0.8462498050834243 val_auprc:  0.909077248767101
EarlyStopping counter: 6 out of 50
epoch:  306
[Epoch 307, Batch 10] loss: 0.252
[Epoch 307, Batch 20] loss: 0.249
epoch:  307 val_f1:  0.8306885693455566 val_auroc:  0.8460938718228599 val_auprc:  0.9090879501602469
EarlyStopping counter: 7 out of 50
epoch:  307
[Epoch 308, Batch 10] loss: 0.257
[Epoch 308, Batch 20] loss: 0.239
epoch:  308 val_f1:  0.8247038785258097 val_auroc:  0.8465616716045533 val_auprc:  0.9092654781804328
EarlyStopping counter: 8 out of 50
epoch:  308
[Epoch 309, Batch 10] loss: 0.253
[Epoch 309, Batch 20] loss: 0.247
epoch:  309 val_f1:  0.8306885693455566 val_auroc:  0.8462498050834244 val_auprc:  0.9093744597119982
EarlyStopping counter: 9 out of 50
epoch:  309
[Epoch 310, Batch 10] loss: 0.233
[Epoch 310, Batch 20] loss: 0.265
epoch:  310 val_f1:  0.8306885693455566 val_auroc:  0.84640573834

[Epoch 344, Batch 10] loss: 0.255
[Epoch 344, Batch 20] loss: 0.233
epoch:  344 val_f1:  0.8306885693455566 val_auroc:  0.8470294713862467 val_auprc:  0.9098279483668203
EarlyStopping counter: 2 out of 50
epoch:  344
[Epoch 345, Batch 10] loss: 0.249
[Epoch 345, Batch 20] loss: 0.227
epoch:  345 val_f1:  0.8306885693455566 val_auroc:  0.8473413379073756 val_auprc:  0.9099704076568158
EarlyStopping counter: 3 out of 50
epoch:  345
[Epoch 346, Batch 10] loss: 0.243
[Epoch 346, Batch 20] loss: 0.238
epoch:  346 val_f1:  0.8306885693455566 val_auroc:  0.8470294713862466 val_auprc:  0.9096004003268017
EarlyStopping counter: 4 out of 50
epoch:  346
[Epoch 347, Batch 10] loss: 0.235
[Epoch 347, Batch 20] loss: 0.245
epoch:  347 val_f1:  0.8306885693455566 val_auroc:  0.8471854046468111 val_auprc:  0.9099094151275697
EarlyStopping counter: 5 out of 50
epoch:  347
[Epoch 348, Batch 10] loss: 0.246
[Epoch 348, Batch 20] loss: 0.246
epoch:  348 val_f1:  0.8306885693455566 val_auroc:  0.8473413379

[Epoch 382, Batch 10] loss: 0.239
[Epoch 382, Batch 20] loss: 0.227
epoch:  382 val_f1:  0.8306885693455566 val_auroc:  0.848432870731327 val_auprc:  0.9100134370716564
EarlyStopping counter: 12 out of 50
epoch:  382
[Epoch 383, Batch 10] loss: 0.249
[Epoch 383, Batch 20] loss: 0.226
epoch:  383 val_f1:  0.8306885693455566 val_auroc:  0.8485888039918914 val_auprc:  0.910070827058665
EarlyStopping counter: 13 out of 50
epoch:  383
[Epoch 384, Batch 10] loss: 0.244
[Epoch 384, Batch 20] loss: 0.233
epoch:  384 val_f1:  0.8413452172218976 val_auroc:  0.8485888039918915 val_auprc:  0.9100691772204348
EarlyStopping counter: 14 out of 50
epoch:  384
[Epoch 385, Batch 10] loss: 0.226
[Epoch 385, Batch 20] loss: 0.232
epoch:  385 val_f1:  0.8306885693455566 val_auroc:  0.8482769374707625 val_auprc:  0.9097468630334794
EarlyStopping counter: 15 out of 50
epoch:  385
[Epoch 386, Batch 10] loss: 0.221
[Epoch 386, Batch 20] loss: 0.242
epoch:  386 val_f1:  0.8306885693455566 val_auroc:  0.84858880

[Epoch 420, Batch 10] loss: 0.239
[Epoch 420, Batch 20] loss: 0.221
epoch:  420 val_f1:  0.8360133057176928 val_auroc:  0.8506159363792297 val_auprc:  0.9109974338275393
EarlyStopping counter: 1 out of 50
epoch:  420
[Epoch 421, Batch 10] loss: 0.218
[Epoch 421, Batch 20] loss: 0.232
epoch:  421 val_f1:  0.8466856205834817 val_auroc:  0.8507718696397942 val_auprc:  0.9110611294374333
EarlyStopping counter: 2 out of 50
epoch:  421
[Epoch 422, Batch 10] loss: 0.226
[Epoch 422, Batch 20] loss: 0.225
epoch:  422 val_f1:  0.8413452172218976 val_auroc:  0.8507718696397941 val_auprc:  0.9111254281681695
EarlyStopping counter: 3 out of 50
epoch:  422
[Epoch 423, Batch 10] loss: 0.223
[Epoch 423, Batch 20] loss: 0.227
epoch:  423 val_f1:  0.8360133057176928 val_auroc:  0.8503040698581008 val_auprc:  0.910826745752119
EarlyStopping counter: 4 out of 50
epoch:  423
[Epoch 424, Batch 10] loss: 0.226
[Epoch 424, Batch 20] loss: 0.220
epoch:  424 val_f1:  0.8413452172218976 val_auroc:  0.85061593637

[Epoch 458, Batch 10] loss: 0.219
[Epoch 458, Batch 20] loss: 0.226
epoch:  458 val_f1:  0.8360133057176928 val_auroc:  0.8518634024637455 val_auprc:  0.9111436885073868
EarlyStopping counter: 9 out of 50
epoch:  458
[Epoch 459, Batch 10] loss: 0.228
[Epoch 459, Batch 20] loss: 0.202
epoch:  459 val_f1:  0.8466856205834817 val_auroc:  0.8520193357243101 val_auprc:  0.9112204917900797
EarlyStopping counter: 10 out of 50
epoch:  459
[Epoch 460, Batch 10] loss: 0.228
[Epoch 460, Batch 20] loss: 0.214
epoch:  460 val_f1:  0.8413452172218976 val_auroc:  0.8518634024637456 val_auprc:  0.9109516525587102
EarlyStopping counter: 11 out of 50
epoch:  460
[Epoch 461, Batch 10] loss: 0.226
[Epoch 461, Batch 20] loss: 0.212
epoch:  461 val_f1:  0.8360133057176928 val_auroc:  0.8520193357243101 val_auprc:  0.9111878016204444
EarlyStopping counter: 12 out of 50
epoch:  461
[Epoch 462, Batch 10] loss: 0.237
[Epoch 462, Batch 20] loss: 0.205
epoch:  462 val_f1:  0.8413452172218976 val_auroc:  0.8520193

[Epoch 496, Batch 10] loss: 0.202
[Epoch 496, Batch 20] loss: 0.223
epoch:  496 val_f1:  0.8360133057176928 val_auroc:  0.852331202245439 val_auprc:  0.9106636791937086
EarlyStopping counter: 32 out of 50
epoch:  496
[Epoch 497, Batch 10] loss: 0.210
[Epoch 497, Batch 20] loss: 0.209
epoch:  497 val_f1:  0.8520358613664902 val_auroc:  0.8521752689848746 val_auprc:  0.9106114382752247
EarlyStopping counter: 33 out of 50
epoch:  497
[Epoch 498, Batch 10] loss: 0.217
[Epoch 498, Batch 20] loss: 0.203
epoch:  498 val_f1:  0.8413452172218976 val_auroc:  0.8518634024637455 val_auprc:  0.9106957604979932
EarlyStopping counter: 34 out of 50
epoch:  498
[Epoch 499, Batch 10] loss: 0.212
[Epoch 499, Batch 20] loss: 0.204
epoch:  499 val_f1:  0.8466856205834817 val_auroc:  0.8524871355060035 val_auprc:  0.9110655209892341
EarlyStopping counter: 35 out of 50
epoch:  499
[Epoch 500, Batch 10] loss: 0.216
[Epoch 500, Batch 20] loss: 0.208
epoch:  500 val_f1:  0.8466856205834817 val_auroc:  0.8524871

In [5]:
pd.DataFrame(result)

Unnamed: 0,time_elapsed,random_seed,f1_train,auroc_train,auprc_train,f1_test,auroc_test,auprc_test
0,1944.024949,42,0.86994,0.94094,0.974364,0.841345,0.852175,0.910623


In [8]:
pd.DataFrame(result)

Unnamed: 0,time_elapsed,random_seed,f1_train,auroc_train,auprc_train,f1_test,auroc_test,auprc_test
0,532.493232,42,0.829846,0.91168,0.961068,0.823972,0.848433,0.912476


In [6]:
summary(model, input_size=(config['batch_size'],1,input_size))

Layer (type:depth-idx)                   Output Shape              Param #
TAPPredictor                             [32, 1]                   --
├─Sequential: 1-1                        [32, 1]                   --
│    └─Conv1d: 2-1                       [32, 256, 1276]           1,536
│    └─BatchNorm1d: 2-2                  [32, 256, 1276]           512
│    └─Tanh: 2-3                         [32, 256, 1276]           --
│    └─MaxPool1d: 2-4                    [32, 256, 255]            --
│    └─Dropout: 2-5                      [32, 256, 255]            --
│    └─Flatten: 2-6                      [32, 65280]               --
│    └─Linear: 2-7                       [32, 1]                   65,281
Total params: 67,329
Trainable params: 67,329
Non-trainable params: 0
Total mult-adds (M): 64.82
Input size (MB): 0.16
Forward/backward pass size (MB): 167.25
Params size (MB): 0.27
Estimated Total Size (MB): 167.68

In [51]:
torch.save(model.state_dict(),'./../results/final_model/TAP.pt')