In [1]:
import argparse
import json
import os
import pickle
import sys
import time

from tqdm import tqdm

import pandas as pd
import numpy as np

from Bio import SeqIO
import esm

import torch
from torch import nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, Subset

from torchinfo import summary

from sklearn.metrics import f1_score, roc_auc_score, roc_curve, auc, precision_recall_curve
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler, MinMaxScaler


# Get the path of the Python script
sys.path.append('./../../../src/')

from utils import *
from utils_torch import * 
from MHCCBM import *
# from TAPPredictor import *
from TAPPredictor_CNN import *


# Global random seed for this run. It is passed to set_seed (from the project utils)
# and recorded in the results table, so both always agree.
seed = 42
set_seed(seed)  # use the variable, not a literal, so changing `seed` actually reseeds
print("Seed: ", seed)

Seed:  42


In [2]:
# DeepTap train test
# Load the DeepTAP train/test split (CSV files with at least 'peptide' and 'label' columns).
deeptap_train_df = pd.read_csv('./../../../data/TAP/DeepTAP_train_test_split/train.csv')
deeptap_test_df = pd.read_csv('./../../../data/TAP/DeepTAP_train_test_split/test.csv')


# Embeddings
# Precomputed per-peptide embeddings, looked up below by peptide sequence.
# NOTE(review): pickle.load can execute arbitrary code; only load this file from a trusted source.
# The `with` statement closes the file on exit, so no explicit close() is needed.
with open('./../../../data/TAP/classification_peptides_esm1b.pkl', 'rb') as f:
    peptide_embeddings_dict = pickle.load(f)

# Attach each peptide's embedding (raises KeyError if a peptide has no precomputed embedding).
deeptap_train_df['embeddings'] = [peptide_embeddings_dict[p] for p in deeptap_train_df['peptide']]
deeptap_test_df['embeddings'] = [peptide_embeddings_dict[p] for p in deeptap_test_df['peptide']]

# Train test
# Stack the per-peptide embedding tensors into one tensor per split.
train_sequences = torch.vstack(deeptap_train_df['embeddings'].to_list())
train_labels = deeptap_train_df['label'].to_numpy()

test_sequences = torch.vstack(deeptap_test_df['embeddings'].to_list())
test_labels = deeptap_test_df['label'].to_numpy()

# Calculate class weights
# pos_weight = (#label-0 examples) / (#label-1 examples), later passed to model.train_loop.
# minlength=2 guarantees both counts exist even if one class is absent, and the assert
# turns a missing positive class into a clear error instead of IndexError / inf.
labels_tensor = torch.tensor(train_labels)
class_counts = torch.bincount(labels_tensor, minlength=2)
assert class_counts[1] > 0, "training labels contain no positive (label 1) examples"
pos_weight = class_counts[0].to(dtype=torch.float32) / class_counts[1].to(dtype=torch.float32)


In [3]:
# Scale the data
# Min-max scale each embedding feature using statistics fitted on the training split only,
# then apply the same fitted transform to the test split.
# Both splits are flattened to (n_samples, n_features) in exactly the same way, so train and
# test go through identical preprocessing (and a single-sample split stays 2-D).
# NOTE(review): this cell overwrites train_sequences/test_sequences, so re-running it
# rescales already-scaled data; re-run from the loading cell instead.
scaler = MinMaxScaler()
train_sequences = torch.tensor(
    scaler.fit_transform(train_sequences.reshape(train_sequences.shape[0], -1)),
    dtype=torch.float32,
)
test_sequences = torch.tensor(
    scaler.transform(test_sequences.reshape(test_sequences.shape[0], -1)),
    dtype=torch.float32,
)


In [4]:
# Hyperparameter search notes. Fields are presumably
# (hidden_channels, kernel_size, batch_size, lr, dropout_p); TODO confirm:
#   best by AUPRC: [256], 5, 32, 1e-05, 0.0
#   best by AUROC: [128], 3, 32, 0.001, 0
# NOTE(review): the config below combines hidden_channels/kernel_size from the AUROC row
# with lr from the AUPRC row; confirm this mix is intentional.

# Hyperparameters for this run; consumed by TAPPredictor, the DataLoaders and train_loop.
config = {
    "hidden_channels": [128],
    "kernel_size": 3,
    "pool_kernel_size": 3,
    "epochs": 5000,
    "batch_size": 32,
    "lr": 1e-5,
    "dropout_p": 0.0,
    "architecture": "CNN",
}
# train_loop receives the wrapped form; config_dict['config'] is the same dict object as config.
config_dict = {"config": config}

In [5]:
# Create dataset and dataloaders
def _as_single_channel(embeddings):
    """Reshape (n_samples, ...) into (n_samples, 1, n_flat_features).

    Presumably a one-channel 1-D signal for the CNN model; TODO confirm against TAPPredictor.
    """
    return embeddings.reshape(embeddings.shape[0], 1, -1)


train_dataset = ProteinSequenceDataset(_as_single_channel(train_sequences), train_labels)
test_dataset = ProteinSequenceDataset(_as_single_channel(test_sequences), test_labels)

# Training batches are reshuffled every epoch; evaluation order stays fixed.
train_loader, test_loader = (
    DataLoader(dataset, batch_size=config['batch_size'], shuffle=shuffle)
    for dataset, shuffle in ((train_dataset, True), (test_dataset, False))
)


In [6]:
#### model training
# Input width = number of scaled embedding features per peptide. The embeddings loaded above
# come from classification_peptides_esm1b.pkl (ESM-1b, peptide only), not ESM-2 allele+peptide.
input_size = train_sequences.shape[1]

model = TAPPredictor(input_size, config['hidden_channels'],
                     config['kernel_size'], config['pool_kernel_size'], config['dropout_p'])

# NOTE(review): test_loader is passed as valid_loader, so the held-out test set drives early
# stopping; the "test" metrics below are then likely optimistic. Consider carving a validation
# split out of the training data (e.g. train_test_split(..., stratify=train_labels)) and
# keeping test_loader for the final evaluation only.
start = time.perf_counter()  # monotonic, high-resolution clock intended for measuring durations
model.train_loop(train_loader=train_loader, valid_loader=test_loader, test_loader=None,
                 config_dict=config_dict, pos_weight=pos_weight)
time_elapsed = time.perf_counter() - start
print("Time taken: ", time_elapsed)

#### performance
# One-row summary: every value is wrapped in a list so pd.DataFrame(result) yields a single row.
# Column order: time_elapsed, random_seed, then f1/auroc/auprc for train, then for test.
result = {'time_elapsed': [time_elapsed], 'random_seed': [seed]}
for split, loader in (('train', train_loader), ('test', test_loader)):
    f1, auroc, auprc = model.eval_dataset(loader)
    result[f'f1_{split}'] = [f1]
    result[f'auroc_{split}'] = [auroc]
    result[f'auprc_{split}'] = [auprc]

print(result)


epoch:  0
[Epoch 1, Batch 10] loss: 0.416
[Epoch 1, Batch 20] loss: 0.416
epoch:  1 val_f1:  0.570465614650302 val_auroc:  0.6048651177296117 val_auprc:  0.7589731329092564
epoch:  1
[Epoch 2, Batch 10] loss: 0.423
[Epoch 2, Batch 20] loss: 0.402
epoch:  2 val_f1:  0.6055905850655791 val_auroc:  0.6394823015749259 val_auprc:  0.7846536556789132
Validation loss improved to 2.701782
epoch:  2
[Epoch 3, Batch 10] loss: 0.421
[Epoch 3, Batch 20] loss: 0.406
epoch:  3 val_f1:  0.6369518944231588 val_auroc:  0.6519569624200843 val_auprc:  0.8018989550123801
Validation loss improved to 2.652542
epoch:  3
[Epoch 4, Batch 10] loss: 0.411
[Epoch 4, Batch 20] loss: 0.413
epoch:  4 val_f1:  0.624036754171402 val_auroc:  0.6614688913145174 val_auprc:  0.813936752881644
Validation loss improved to 2.628526
epoch:  4
[Epoch 5, Batch 10] loss: 0.402
[Epoch 5, Batch 20] loss: 0.415
epoch:  5 val_f1:  0.6183053767054943 val_auroc:  0.6645875565258069 val_auprc:  0.8233252472376495
Validation loss improv

[Epoch 40, Batch 10] loss: 0.374
[Epoch 40, Batch 20] loss: 0.354
epoch:  40 val_f1:  0.645753339546443 val_auroc:  0.7706221737096524 val_auprc:  0.8955853499174712
Validation loss improved to 2.347325
epoch:  40
[Epoch 41, Batch 10] loss: 0.371
[Epoch 41, Batch 20] loss: 0.351
epoch:  41 val_f1:  0.657506010066862 val_auroc:  0.7721815063152971 val_auprc:  0.8964842593405846
Validation loss improved to 2.344561
epoch:  41
[Epoch 42, Batch 10] loss: 0.356
[Epoch 42, Batch 20] loss: 0.355
epoch:  42 val_f1:  0.6798464668886421 val_auroc:  0.7746764384843288 val_auprc:  0.8978699331625848
EarlyStopping counter: 1 out of 200
epoch:  42
[Epoch 43, Batch 10] loss: 0.363
[Epoch 43, Batch 20] loss: 0.357
epoch:  43 val_f1:  0.6452618135376755 val_auroc:  0.7759239045688445 val_auprc:  0.8984515645743397
Validation loss improved to 2.327671
epoch:  43
[Epoch 44, Batch 10] loss: 0.357
[Epoch 44, Batch 20] loss: 0.354
epoch:  44 val_f1:  0.6743377720123389 val_auroc:  0.7762357710899734 val_aup

[Epoch 79, Batch 10] loss: 0.320
[Epoch 79, Batch 20] loss: 0.318
epoch:  79 val_f1:  0.7295347598068578 val_auroc:  0.8149072197099642 val_auprc:  0.9180891414272784
Validation loss improved to 2.190667
epoch:  79
[Epoch 80, Batch 10] loss: 0.325
[Epoch 80, Batch 20] loss: 0.311
epoch:  80 val_f1:  0.729467139148214 val_auroc:  0.8147512864493998 val_auprc:  0.9179121355565357
EarlyStopping counter: 1 out of 200
epoch:  80
[Epoch 81, Batch 10] loss: 0.301
[Epoch 81, Batch 20] loss: 0.335
epoch:  81 val_f1:  0.7122851418327534 val_auroc:  0.8163106190550444 val_auprc:  0.9185972138398149
EarlyStopping counter: 2 out of 200
epoch:  81
[Epoch 82, Batch 10] loss: 0.324
[Epoch 82, Batch 20] loss: 0.309
epoch:  82 val_f1:  0.7295347598068578 val_auroc:  0.8161546857944799 val_auprc:  0.9184716602081419
Validation loss improved to 2.181303
epoch:  82
[Epoch 83, Batch 10] loss: 0.312
[Epoch 83, Batch 20] loss: 0.326
epoch:  83 val_f1:  0.7406029403051959 val_auroc:  0.8175580851395603 val_aup

[Epoch 117, Batch 10] loss: 0.296
[Epoch 117, Batch 20] loss: 0.295
epoch:  117 val_f1:  0.7403420339139749 val_auroc:  0.8315920785903633 val_auprc:  0.9232725569139848
EarlyStopping counter: 4 out of 200
epoch:  117
[Epoch 118, Batch 10] loss: 0.284
[Epoch 118, Batch 20] loss: 0.298
epoch:  118 val_f1:  0.7459576112348344 val_auroc:  0.8315920785903633 val_auprc:  0.9232482565920238
EarlyStopping counter: 5 out of 200
epoch:  118
[Epoch 119, Batch 10] loss: 0.297
[Epoch 119, Batch 20] loss: 0.279
epoch:  119 val_f1:  0.7403420339139749 val_auroc:  0.8319039451114922 val_auprc:  0.9232973204928933
EarlyStopping counter: 6 out of 200
epoch:  119
[Epoch 120, Batch 10] loss: 0.304
[Epoch 120, Batch 20] loss: 0.275
epoch:  120 val_f1:  0.7405092967340015 val_auroc:  0.8322158116326213 val_auprc:  0.9232980196110578
Validation loss improved to 2.105159
epoch:  120
[Epoch 121, Batch 10] loss: 0.286
[Epoch 121, Batch 20] loss: 0.299
epoch:  121 val_f1:  0.7400997675433223 val_auroc:  0.83268

[Epoch 155, Batch 10] loss: 0.268
[Epoch 155, Batch 20] loss: 0.280
epoch:  155 val_f1:  0.7459576112348344 val_auroc:  0.8345548105410884 val_auprc:  0.9219802099685287
Validation loss improved to 2.071971
epoch:  155
[Epoch 156, Batch 10] loss: 0.279
[Epoch 156, Batch 20] loss: 0.275
epoch:  156 val_f1:  0.7457573822516352 val_auroc:  0.834087010759395 val_auprc:  0.9218318528097299
EarlyStopping counter: 1 out of 200
epoch:  156
[Epoch 157, Batch 10] loss: 0.275
[Epoch 157, Batch 20] loss: 0.267
epoch:  157 val_f1:  0.7457573822516352 val_auroc:  0.8345548105410884 val_auprc:  0.9220426520251603
EarlyStopping counter: 2 out of 200
epoch:  157
[Epoch 158, Batch 10] loss: 0.283
[Epoch 158, Batch 20] loss: 0.261
epoch:  158 val_f1:  0.7400997675433223 val_auroc:  0.833775144238266 val_auprc:  0.9218062538248003
EarlyStopping counter: 3 out of 200
epoch:  158
[Epoch 159, Batch 10] loss: 0.261
[Epoch 159, Batch 20] loss: 0.272
epoch:  159 val_f1:  0.7457573822516352 val_auroc:  0.8340870

[Epoch 193, Batch 10] loss: 0.260
[Epoch 193, Batch 20] loss: 0.260
epoch:  193 val_f1:  0.7562140804597702 val_auroc:  0.8368938094495556 val_auprc:  0.9230239160115489
EarlyStopping counter: 11 out of 200
epoch:  193
[Epoch 194, Batch 10] loss: 0.267
[Epoch 194, Batch 20] loss: 0.245
epoch:  194 val_f1:  0.7562140804597702 val_auroc:  0.837673475752378 val_auprc:  0.9232188939818868
EarlyStopping counter: 12 out of 200
epoch:  194
[Epoch 195, Batch 10] loss: 0.255
[Epoch 195, Batch 20] loss: 0.261
epoch:  195 val_f1:  0.7676149222495315 val_auroc:  0.8375175424918135 val_auprc:  0.9229876808752535
EarlyStopping counter: 13 out of 200
epoch:  195
[Epoch 196, Batch 10] loss: 0.237
[Epoch 196, Batch 20] loss: 0.272
epoch:  196 val_f1:  0.7622189362630395 val_auroc:  0.8370497427101201 val_auprc:  0.9230062586787255
EarlyStopping counter: 14 out of 200
epoch:  196
[Epoch 197, Batch 10] loss: 0.252
[Epoch 197, Batch 20] loss: 0.262
epoch:  197 val_f1:  0.7622189362630395 val_auroc:  0.837

[Epoch 231, Batch 10] loss: 0.238
[Epoch 231, Batch 20] loss: 0.251
epoch:  231 val_f1:  0.7622189362630395 val_auroc:  0.8387650085763294 val_auprc:  0.9228929166911543
EarlyStopping counter: 5 out of 200
epoch:  231
[Epoch 232, Batch 10] loss: 0.234
[Epoch 232, Batch 20] loss: 0.259
epoch:  232 val_f1:  0.7622189362630395 val_auroc:  0.8389209418368938 val_auprc:  0.9230057739866425
EarlyStopping counter: 6 out of 200
epoch:  232
[Epoch 233, Batch 10] loss: 0.261
[Epoch 233, Batch 20] loss: 0.234
epoch:  233 val_f1:  0.772654022865647 val_auroc:  0.8389209418368938 val_auprc:  0.92291125183563
EarlyStopping counter: 7 out of 200
epoch:  233
[Epoch 234, Batch 10] loss: 0.256
[Epoch 234, Batch 20] loss: 0.238
epoch:  234 val_f1:  0.7568114091102596 val_auroc:  0.8390768750974582 val_auprc:  0.9229149725734429
EarlyStopping counter: 8 out of 200
epoch:  234
[Epoch 235, Batch 10] loss: 0.236
[Epoch 235, Batch 20] loss: 0.256
epoch:  235 val_f1:  0.7676149222495315 val_auroc:  0.838920941

[Epoch 269, Batch 10] loss: 0.255
[Epoch 269, Batch 20] loss: 0.217
epoch:  269 val_f1:  0.7622189362630395 val_auroc:  0.8421955403087478 val_auprc:  0.9245538804566205
EarlyStopping counter: 5 out of 200
epoch:  269
[Epoch 270, Batch 10] loss: 0.235
[Epoch 270, Batch 20] loss: 0.240
epoch:  270 val_f1:  0.7622189362630395 val_auroc:  0.8417277405270543 val_auprc:  0.9242383910138489
EarlyStopping counter: 6 out of 200
epoch:  270
[Epoch 271, Batch 10] loss: 0.220
[Epoch 271, Batch 20] loss: 0.256
epoch:  271 val_f1:  0.7676149222495315 val_auroc:  0.8418836737876189 val_auprc:  0.9243084824432884
EarlyStopping counter: 7 out of 200
epoch:  271
[Epoch 272, Batch 10] loss: 0.237
[Epoch 272, Batch 20] loss: 0.228
epoch:  272 val_f1:  0.7676149222495315 val_auroc:  0.8418836737876189 val_auprc:  0.9243084824432884
EarlyStopping counter: 8 out of 200
epoch:  272
[Epoch 273, Batch 10] loss: 0.227
[Epoch 273, Batch 20] loss: 0.232
epoch:  273 val_f1:  0.7676149222495315 val_auroc:  0.841883

[Epoch 307, Batch 10] loss: 0.240
[Epoch 307, Batch 20] loss: 0.209
epoch:  307 val_f1:  0.7783764367816092 val_auroc:  0.8451582722594729 val_auprc:  0.9256325350238261
EarlyStopping counter: 43 out of 200
epoch:  307
[Epoch 308, Batch 10] loss: 0.223
[Epoch 308, Batch 20] loss: 0.241
epoch:  308 val_f1:  0.7837440705307375 val_auroc:  0.8451582722594729 val_auprc:  0.925624675047971
EarlyStopping counter: 44 out of 200
epoch:  308
[Epoch 309, Batch 10] loss: 0.227
[Epoch 309, Batch 20] loss: 0.231
epoch:  309 val_f1:  0.7568114091102596 val_auroc:  0.8453142055200376 val_auprc:  0.9256767570352401
EarlyStopping counter: 45 out of 200
epoch:  309
[Epoch 310, Batch 10] loss: 0.216
[Epoch 310, Batch 10] loss: 0.216


In [10]:
# Single-row summary table of this run (timing, seed, train/test F1, AUROC, AUPRC).
pd.DataFrame.from_dict(result)

Unnamed: 0,time_elapsed,random_seed,f1_train,auroc_train,auprc_train,f1_test,auroc_test,auprc_test
0,417.845965,42,0.879441,0.954304,0.97988,0.783348,0.84547,0.922774


In [37]:
# NOTE(review): duplicate of the In [10] cell above. Its saved output (time_elapsed ~220 s,
# different metrics) comes from a separate execution (In [37]), so the two saved tables
# disagree. Restart & Run All, or delete one of the two cells.
pd.DataFrame(result)

Unnamed: 0,time_elapsed,random_seed,f1_train,auroc_train,auprc_train,f1_test,auroc_test,auprc_test
0,220.537996,42,0.83155,0.9033,0.954688,0.773,0.841884,0.925619


In [8]:
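# Optional (disabled): persist the trained model. Prefer saving model.state_dict() over the
# whole module, which pickles the TAPPredictor class itself.
# torch.save(model, "./../results/final_model/TAP.pt")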

DeepTAP reference performance (for comparison with the results above):

1. AUROC: 0.91330110712615
2. F1: 0.870425521932911
3. AUPRC: 0.9571925284874283

