In [1]:
import sys
from tqdm import tqdm
import os
import json
import time
import argparse

import pandas as pd
import numpy as np

from Bio import SeqIO
import esm

import torch
from torch import nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, Subset

from torchinfo import summary

from sklearn.metrics import f1_score, roc_auc_score, roc_curve, auc, precision_recall_curve
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler, MinMaxScaler


# Get the path of the Python script
sys.path.append('./../../../src/')

from utils import *
from utils_torch import * 
from MHCCBM import *
# from TAPPredictor import *
from TAPPredictor_CNN import *


seed = 42
set_seed(42)
print("Seed: ", seed)



Seed:  42


In [2]:
# Best config [128], 5, 32, 1e-05, 0.1

config_dict = { "config" : {
                    "hidden_channels" : [256],
                    "kernel_size": 5,
                    "pool_kernel_size":5,
                    "epochs" : 2000,
                    "batch_size" : 32,
                    "lr" : 1e-5,
                    "dropout_p" : 0.0,
                    "architecture" : "CNN"
                }
        }
config = config_dict['config']

In [6]:

# load X and y
with open('./../../../data/TAP/X.pkl','rb') as f:
    X = pickle.load(f)
f.close()

with open('./../../../data/TAP/y.pkl','rb') as f:
    y = pickle.load(f)
f.close()

# Scale the data
scaler = MinMaxScaler()
X = scaler.fit_transform(X.squeeze())
X = torch.tensor(X, dtype=torch.float32)

# Split the data
train_sequences, test_sequences, train_labels, test_labels = train_test_split(X, y, test_size=0.2, 
                                                                              random_state=seed, stratify=y)

# Create dataset and dataloaders
train_dataset = ProteinSequenceDataset(train_sequences.reshape(train_sequences.shape[0],1,-1), train_labels)
test_dataset = ProteinSequenceDataset(test_sequences.reshape(test_sequences.shape[0],1,-1), test_labels)

train_loader = DataLoader(train_dataset, batch_size=config['batch_size'], shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=config['batch_size'], shuffle=False)

# Calculate class weights
labels_tensor = torch.tensor(train_labels)
class_counts = torch.bincount(labels_tensor)
pos_weight = class_counts[0]/class_counts[1].to(dtype=torch.float32)
# class_weights = 1.0 / class_counts.float()
# class_weights = class_weights / class_weights.sum()  # Normalize

# # Convert to a tensor
# class_weights = class_weights.to(dtype=torch.float32)
# pos_weight = 1/class_weights[1]


In [7]:

#### model training
input_size = X.shape[1] #embedding size for esm2_t33_650M_UR50D (allele + peptide)

model = TAPPredictor(input_size, config['hidden_channels'], config['kernel_size'], config['pool_kernel_size'], config['dropout_p'])


start = time.time()
model.train_loop(train_loader=train_loader, valid_loader=test_loader, test_loader=None, config_dict=config_dict, pos_weight=pos_weight)
end = time.time()

time_elapsed = end-start
print("Time taken: ", time_elapsed)

result = {'time_elapsed': [time_elapsed]}
result['random_seed'] = [seed] 

#### performance
f1, auroc, auprc = model.eval_dataset(train_loader)
result['f1_train'], result['auroc_train'], result['auprc_train'] = [f1], [auroc], [auprc]

f1, auroc, auprc = model.eval_dataset(test_loader)
result['f1_test'], result['auroc_test'], result['auprc_test'] = [f1], [auroc], [auprc]

print(result)

epoch:  0
[Epoch 1, Batch 10] loss: 0.431
[Epoch 1, Batch 20] loss: 0.415
epoch:  1 val_f1:  0.570465614650302 val_auroc:  0.5930141899267113 val_auprc:  0.7777137223580354
epoch:  1
[Epoch 2, Batch 10] loss: 0.427
[Epoch 2, Batch 20] loss: 0.417
epoch:  2 val_f1:  0.570465614650302 val_auroc:  0.6061125838141276 val_auprc:  0.787313485400223
Validation loss improved to 2.529491
epoch:  2
[Epoch 3, Batch 10] loss: 0.417
[Epoch 3, Batch 20] loss: 0.423
epoch:  3 val_f1:  0.570465614650302 val_auroc:  0.6153126461874318 val_auprc:  0.7940365848138103
Validation loss improved to 2.523044
epoch:  3
[Epoch 4, Batch 10] loss: 0.427
[Epoch 4, Batch 20] loss: 0.428
epoch:  4 val_f1:  0.570465614650302 val_auroc:  0.6243567753001714 val_auprc:  0.799523669144061
Validation loss improved to 2.517004
epoch:  4
[Epoch 5, Batch 10] loss: 0.428
[Epoch 5, Batch 20] loss: 0.415
epoch:  5 val_f1:  0.5676753460004692 val_auroc:  0.6326212381100889 val_auprc:  0.8074259842459306
Validation loss improved 

[Epoch 40, Batch 10] loss: 0.409
[Epoch 40, Batch 20] loss: 0.419
epoch:  40 val_f1:  0.627680981027837 val_auroc:  0.7746764384843288 val_auprc:  0.8755411476665728
Validation loss improved to 2.438113
epoch:  40
[Epoch 41, Batch 10] loss: 0.407
[Epoch 41, Batch 20] loss: 0.409
epoch:  41 val_f1:  0.6394343252618153 val_auroc:  0.7757679713082801 val_auprc:  0.8752875397699793
Validation loss improved to 2.435473
epoch:  41
[Epoch 42, Batch 10] loss: 0.412
[Epoch 42, Batch 20] loss: 0.408
epoch:  42 val_f1:  0.627680981027837 val_auroc:  0.7770154373927959 val_auprc:  0.8759638569602267
Validation loss improved to 2.433008
epoch:  42
[Epoch 43, Batch 10] loss: 0.398
[Epoch 43, Batch 20] loss: 0.417
epoch:  43 val_f1:  0.6568235476856167 val_auroc:  0.7784188367378763 val_auprc:  0.8768075364673459
Validation loss improved to 2.430515
epoch:  43
[Epoch 44, Batch 10] loss: 0.413
[Epoch 44, Batch 20] loss: 0.403
epoch:  44 val_f1:  0.615789564686151 val_auroc:  0.7791985030406986 val_aup

[Epoch 79, Batch 10] loss: 0.407
[Epoch 79, Batch 20] loss: 0.394
epoch:  79 val_f1:  0.615789564686151 val_auroc:  0.7927646967098081 val_auprc:  0.8855420264778884
Validation loss improved to 2.344365
epoch:  79
[Epoch 80, Batch 10] loss: 0.394
[Epoch 80, Batch 20] loss: 0.395
epoch:  80 val_f1:  0.6283399968083326 val_auroc:  0.7938562295337596 val_auprc:  0.8859831679769014
Validation loss improved to 2.341691
epoch:  80
[Epoch 81, Batch 10] loss: 0.399
[Epoch 81, Batch 20] loss: 0.397
epoch:  81 val_f1:  0.6217530543872527 val_auroc:  0.7947918290971464 val_auprc:  0.8867226164776095
Validation loss improved to 2.339379
epoch:  81
[Epoch 82, Batch 10] loss: 0.402
[Epoch 82, Batch 20] loss: 0.385
epoch:  82 val_f1:  0.6224705853303208 val_auroc:  0.7951036956182754 val_auprc:  0.8868312321458757
Validation loss improved to 2.337178
epoch:  82
[Epoch 83, Batch 10] loss: 0.396
[Epoch 83, Batch 20] loss: 0.394
epoch:  83 val_f1:  0.6224705853303208 val_auroc:  0.7954155621394043 val_a

[Epoch 117, Batch 10] loss: 0.386
[Epoch 117, Batch 20] loss: 0.386
epoch:  117 val_f1:  0.657506010066862 val_auroc:  0.7988460938718229 val_auprc:  0.8906851628949225
Validation loss improved to 2.264057
epoch:  117
[Epoch 118, Batch 10] loss: 0.379
[Epoch 118, Batch 20] loss: 0.387
epoch:  118 val_f1:  0.657506010066862 val_auroc:  0.7991579603929518 val_auprc:  0.8907696987376259
Validation loss improved to 2.262321
epoch:  118
[Epoch 119, Batch 10] loss: 0.384
[Epoch 119, Batch 20] loss: 0.383
epoch:  119 val_f1:  0.6628956457731725 val_auroc:  0.7994698269140808 val_auprc:  0.8912236419566704
Validation loss improved to 2.260239
epoch:  119
[Epoch 120, Batch 10] loss: 0.376
[Epoch 120, Batch 20] loss: 0.386
epoch:  120 val_f1:  0.657506010066862 val_auroc:  0.8000935599563388 val_auprc:  0.8917651832807107
Validation loss improved to 2.258265
epoch:  120
[Epoch 121, Batch 10] loss: 0.386
[Epoch 121, Batch 20] loss: 0.377
epoch:  121 val_f1:  0.6514962135137123 val_auroc:  0.80056

[Epoch 155, Batch 10] loss: 0.364
[Epoch 155, Batch 20] loss: 0.382
epoch:  155 val_f1:  0.6963406663908527 val_auroc:  0.8047715577732731 val_auprc:  0.8957210768074502
Validation loss improved to 2.196147
epoch:  155
[Epoch 156, Batch 10] loss: 0.361
[Epoch 156, Batch 20] loss: 0.381
epoch:  156 val_f1:  0.6632983744534352 val_auroc:  0.8052393575549666 val_auprc:  0.8960864668059584
Validation loss improved to 2.193864
epoch:  156
[Epoch 157, Batch 10] loss: 0.376
[Epoch 157, Batch 20] loss: 0.362
epoch:  157 val_f1:  0.6688613111026905 val_auroc:  0.8047715577732731 val_auprc:  0.8958780139806782
Validation loss improved to 2.192289
epoch:  157
[Epoch 158, Batch 10] loss: 0.367
[Epoch 158, Batch 20] loss: 0.367
epoch:  158 val_f1:  0.6799171842650105 val_auroc:  0.805083424294402 val_auprc:  0.8960337796789319
Validation loss improved to 2.190558
epoch:  158
[Epoch 159, Batch 10] loss: 0.370
[Epoch 159, Batch 20] loss: 0.365
epoch:  159 val_f1:  0.6632983744534352 val_auroc:  0.805

[Epoch 193, Batch 10] loss: 0.366
[Epoch 193, Batch 20] loss: 0.345
epoch:  193 val_f1:  0.7179805988236051 val_auroc:  0.812880087322626 val_auprc:  0.9009513521521064
Validation loss improved to 2.136427
epoch:  193
[Epoch 194, Batch 10] loss: 0.364
[Epoch 194, Batch 20] loss: 0.353
epoch:  194 val_f1:  0.7125953016757615 val_auroc:  0.8130360205831904 val_auprc:  0.9009779607186321
Validation loss improved to 2.134470
epoch:  194
[Epoch 195, Batch 10] loss: 0.366
[Epoch 195, Batch 20] loss: 0.350
epoch:  195 val_f1:  0.7125953016757615 val_auroc:  0.8131919538437549 val_auprc:  0.9010875601025659
Validation loss improved to 2.133104
epoch:  195
[Epoch 196, Batch 10] loss: 0.359
[Epoch 196, Batch 20] loss: 0.360
epoch:  196 val_f1:  0.723351097916109 val_auroc:  0.8133478871043194 val_auprc:  0.9011431918798914
Validation loss improved to 2.132100
epoch:  196
[Epoch 197, Batch 10] loss: 0.360
[Epoch 197, Batch 20] loss: 0.352
epoch:  197 val_f1:  0.7125953016757615 val_auroc:  0.8138

[Epoch 231, Batch 10] loss: 0.357
[Epoch 231, Batch 20] loss: 0.338
epoch:  231 val_f1:  0.7292927587613782 val_auroc:  0.822392016217059 val_auprc:  0.9055288260519823
Validation loss improved to 2.082256
epoch:  231
[Epoch 232, Batch 10] loss: 0.351
[Epoch 232, Batch 20] loss: 0.337
epoch:  232 val_f1:  0.7397809537410333 val_auroc:  0.8225479494776238 val_auprc:  0.9056705050278372
Validation loss improved to 2.081244
epoch:  232
[Epoch 233, Batch 10] loss: 0.352
[Epoch 233, Batch 20] loss: 0.345
epoch:  233 val_f1:  0.7400997675433223 val_auroc:  0.8230157492593171 val_auprc:  0.9059088545202156
Validation loss improved to 2.079671
epoch:  233
[Epoch 234, Batch 10] loss: 0.346
[Epoch 234, Batch 20] loss: 0.339
epoch:  234 val_f1:  0.7393838798703758 val_auroc:  0.8231716825198816 val_auprc:  0.9059839575144679
Validation loss improved to 2.078862
epoch:  234
[Epoch 235, Batch 10] loss: 0.342
[Epoch 235, Batch 20] loss: 0.348
epoch:  235 val_f1:  0.7397809537410333 val_auroc:  0.823

[Epoch 269, Batch 10] loss: 0.338
[Epoch 269, Batch 20] loss: 0.338
epoch:  269 val_f1:  0.7615639752005565 val_auroc:  0.8276937470762514 val_auprc:  0.9063255127373678
Validation loss improved to 2.035446
epoch:  269
[Epoch 270, Batch 10] loss: 0.329
[Epoch 270, Batch 20] loss: 0.343
epoch:  270 val_f1:  0.7615639752005565 val_auroc:  0.8281615468579449 val_auprc:  0.9065103679342683
Validation loss improved to 2.034788
epoch:  270
[Epoch 271, Batch 10] loss: 0.319
[Epoch 271, Batch 20] loss: 0.348
epoch:  271 val_f1:  0.7615639752005565 val_auroc:  0.8278496803368158 val_auprc:  0.9064279378109223
Validation loss improved to 2.034425
epoch:  271
[Epoch 272, Batch 10] loss: 0.330
[Epoch 272, Batch 20] loss: 0.341
epoch:  272 val_f1:  0.7615639752005565 val_auroc:  0.8278496803368158 val_auprc:  0.9064279378109223
Validation loss improved to 2.031699
epoch:  272
[Epoch 273, Batch 10] loss: 0.338
[Epoch 273, Batch 20] loss: 0.326
epoch:  273 val_f1:  0.7615639752005565 val_auroc:  0.82

[Epoch 307, Batch 10] loss: 0.312
[Epoch 307, Batch 20] loss: 0.332
epoch:  307 val_f1:  0.7615639752005565 val_auroc:  0.8315920785903634 val_auprc:  0.9079941369830891
Validation loss improved to 1.994087
epoch:  307
[Epoch 308, Batch 10] loss: 0.315
[Epoch 308, Batch 20] loss: 0.329
epoch:  308 val_f1:  0.7615639752005565 val_auroc:  0.8319039451114922 val_auprc:  0.9081425479048031
Validation loss improved to 1.993285
epoch:  308
[Epoch 309, Batch 10] loss: 0.318
[Epoch 309, Batch 20] loss: 0.333
epoch:  309 val_f1:  0.7615639752005565 val_auroc:  0.8320598783720567 val_auprc:  0.9082118896551054
Validation loss improved to 1.991378
epoch:  309
[Epoch 310, Batch 10] loss: 0.343
[Epoch 310, Batch 20] loss: 0.310
epoch:  310 val_f1:  0.7615639752005565 val_auroc:  0.8314361453297988 val_auprc:  0.9077050122903999
EarlyStopping counter: 1 out of 50
epoch:  310
[Epoch 311, Batch 10] loss: 0.320
[Epoch 311, Batch 20] loss: 0.324
epoch:  311 val_f1:  0.7615639752005565 val_auroc:  0.8322

[Epoch 345, Batch 10] loss: 0.311
[Epoch 345, Batch 20] loss: 0.327
epoch:  345 val_f1:  0.7669048358703531 val_auroc:  0.8348666770622174 val_auprc:  0.9087613792191167
Validation loss improved to 1.960516
epoch:  345
[Epoch 346, Batch 10] loss: 0.312
[Epoch 346, Batch 20] loss: 0.316
epoch:  346 val_f1:  0.7669048358703531 val_auroc:  0.8350226103227819 val_auprc:  0.9088039690004498
Validation loss improved to 1.959100
epoch:  346
[Epoch 347, Batch 10] loss: 0.315
[Epoch 347, Batch 20] loss: 0.319
epoch:  347 val_f1:  0.77223771970517 val_auroc:  0.8351785435833464 val_auprc:  0.9088559076661734
EarlyStopping counter: 1 out of 50
epoch:  347
[Epoch 348, Batch 10] loss: 0.308
[Epoch 348, Batch 20] loss: 0.319
epoch:  348 val_f1:  0.7615639752005565 val_auroc:  0.8351785435833464 val_auprc:  0.9088559076661734
Validation loss improved to 1.955518
epoch:  348
[Epoch 349, Batch 10] loss: 0.327
[Epoch 349, Batch 20] loss: 0.302
epoch:  349 val_f1:  0.7669048358703531 val_auroc:  0.835646

[Epoch 383, Batch 10] loss: 0.296
[Epoch 383, Batch 20] loss: 0.317
epoch:  383 val_f1:  0.77223771970517 val_auroc:  0.8381412755340714 val_auprc:  0.9092545061345244
Validation loss improved to 1.929290
epoch:  383
[Epoch 384, Batch 10] loss: 0.309
[Epoch 384, Batch 20] loss: 0.306
epoch:  384 val_f1:  0.7669048358703531 val_auroc:  0.8382972087946359 val_auprc:  0.9092963427645195
Validation loss improved to 1.926254
epoch:  384
[Epoch 385, Batch 10] loss: 0.295
[Epoch 385, Batch 20] loss: 0.308
epoch:  385 val_f1:  0.7659044407870053 val_auroc:  0.8382972087946359 val_auprc:  0.9092963427645195
EarlyStopping counter: 1 out of 50
epoch:  385
[Epoch 386, Batch 10] loss: 0.301
[Epoch 386, Batch 20] loss: 0.306
epoch:  386 val_f1:  0.77223771970517 val_auroc:  0.8382972087946359 val_auprc:  0.9093084272962597
EarlyStopping counter: 2 out of 50
epoch:  386
[Epoch 387, Batch 10] loss: 0.308
[Epoch 387, Batch 20] loss: 0.308
epoch:  387 val_f1:  0.77223771970517 val_auroc:  0.838453142055

[Epoch 421, Batch 10] loss: 0.306
[Epoch 421, Batch 20] loss: 0.296
epoch:  421 val_f1:  0.7717496807151979 val_auroc:  0.8403243411819741 val_auprc:  0.9095970753266348
EarlyStopping counter: 1 out of 50
epoch:  421
[Epoch 422, Batch 10] loss: 0.291
[Epoch 422, Batch 20] loss: 0.315
epoch:  422 val_f1:  0.7711879124332689 val_auroc:  0.8403243411819741 val_auprc:  0.9095970753266348
EarlyStopping counter: 2 out of 50
epoch:  422
[Epoch 423, Batch 10] loss: 0.294
[Epoch 423, Batch 20] loss: 0.299
epoch:  423 val_f1:  0.7717496807151979 val_auroc:  0.8406362077031032 val_auprc:  0.9097425769012681
Validation loss improved to 1.902042
epoch:  423
[Epoch 424, Batch 10] loss: 0.294
[Epoch 424, Batch 20] loss: 0.308
epoch:  424 val_f1:  0.7717496807151979 val_auroc:  0.8407921409636676 val_auprc:  0.9097869108491147
Validation loss improved to 1.900926
epoch:  424
[Epoch 425, Batch 10] loss: 0.304
[Epoch 425, Batch 20] loss: 0.292
epoch:  425 val_f1:  0.783348074841503 val_auroc:  0.8407921

[Epoch 459, Batch 10] loss: 0.292
[Epoch 459, Batch 20] loss: 0.289
epoch:  459 val_f1:  0.7770518483685764 val_auroc:  0.8425074068298768 val_auprc:  0.9107206400684194
EarlyStopping counter: 1 out of 50
epoch:  459
[Epoch 460, Batch 10] loss: 0.286
[Epoch 460, Batch 20] loss: 0.288
epoch:  460 val_f1:  0.7770518483685764 val_auroc:  0.842195540308748 val_auprc:  0.9106965941797015
EarlyStopping counter: 2 out of 50
epoch:  460
[Epoch 461, Batch 10] loss: 0.291
[Epoch 461, Batch 20] loss: 0.289
epoch:  461 val_f1:  0.7770518483685764 val_auroc:  0.8428192733510058 val_auprc:  0.9107998756063198
EarlyStopping counter: 3 out of 50
epoch:  461
[Epoch 462, Batch 10] loss: 0.293
[Epoch 462, Batch 20] loss: 0.285
epoch:  462 val_f1:  0.7770518483685764 val_auroc:  0.8425074068298769 val_auprc:  0.9105515750255596
Validation loss improved to 1.879899
epoch:  462
[Epoch 463, Batch 10] loss: 0.291
[Epoch 463, Batch 20] loss: 0.297
epoch:  463 val_f1:  0.7770518483685764 val_auroc:  0.842507406

[Epoch 497, Batch 10] loss: 0.291
[Epoch 497, Batch 20] loss: 0.282
epoch:  497 val_f1:  0.7711879124332689 val_auroc:  0.8435989396538282 val_auprc:  0.911093766948037
Validation loss improved to 1.862329
epoch:  497
[Epoch 498, Batch 10] loss: 0.294
[Epoch 498, Batch 20] loss: 0.283
epoch:  498 val_f1:  0.7711879124332689 val_auroc:  0.8437548729143927 val_auprc:  0.9111387061497099
Validation loss improved to 1.861977
epoch:  498
[Epoch 499, Batch 10] loss: 0.259
[Epoch 499, Batch 20] loss: 0.298
epoch:  499 val_f1:  0.7711879124332689 val_auroc:  0.8439108061749572 val_auprc:  0.9112024603021502
EarlyStopping counter: 1 out of 50
epoch:  499
[Epoch 500, Batch 10] loss: 0.284
[Epoch 500, Batch 20] loss: 0.292
epoch:  500 val_f1:  0.7711879124332689 val_auroc:  0.8437548729143927 val_auprc:  0.9111575211004773
EarlyStopping counter: 2 out of 50
epoch:  500
[Epoch 501, Batch 10] loss: 0.291
[Epoch 501, Batch 20] loss: 0.284
epoch:  501 val_f1:  0.7711879124332689 val_auroc:  0.8437548

[Epoch 535, Batch 10] loss: 0.296
[Epoch 535, Batch 20] loss: 0.270
epoch:  535 val_f1:  0.7817429034556151 val_auroc:  0.8443786059566506 val_auprc:  0.9109801363793728
EarlyStopping counter: 4 out of 50
epoch:  535
[Epoch 536, Batch 10] loss: 0.291
[Epoch 536, Batch 20] loss: 0.264
epoch:  536 val_f1:  0.7823494776804266 val_auroc:  0.8442226726960861 val_auprc:  0.9106559151460607
Validation loss improved to 1.846354
epoch:  536
[Epoch 537, Batch 10] loss: 0.274
[Epoch 537, Batch 20] loss: 0.281
epoch:  537 val_f1:  0.7817429034556151 val_auroc:  0.8440667394355216 val_auprc:  0.9106098810607706
EarlyStopping counter: 1 out of 50
epoch:  537
[Epoch 538, Batch 10] loss: 0.281
[Epoch 538, Batch 20] loss: 0.268
epoch:  538 val_f1:  0.7817429034556151 val_auroc:  0.8440667394355217 val_auprc:  0.9106861994523285
EarlyStopping counter: 2 out of 50
epoch:  538
[Epoch 539, Batch 10] loss: 0.262
[Epoch 539, Batch 20] loss: 0.304
epoch:  539 val_f1:  0.7764670296430732 val_auroc:  0.84406673

[Epoch 573, Batch 10] loss: 0.260
[Epoch 573, Batch 20] loss: 0.284
epoch:  573 val_f1:  0.8028367790353751 val_auroc:  0.8443786059566506 val_auprc:  0.9107254962480045
EarlyStopping counter: 3 out of 50
epoch:  573
[Epoch 574, Batch 10] loss: 0.253
[Epoch 574, Batch 20] loss: 0.296
epoch:  574 val_f1:  0.8028367790353751 val_auroc:  0.8442226726960861 val_auprc:  0.9106803270425435
EarlyStopping counter: 4 out of 50
epoch:  574
[Epoch 575, Batch 10] loss: 0.283
[Epoch 575, Batch 20] loss: 0.267
epoch:  575 val_f1:  0.7823494776804266 val_auroc:  0.8443786059566506 val_auprc:  0.9107254962480045
Validation loss improved to 1.831691
epoch:  575
[Epoch 576, Batch 10] loss: 0.260
[Epoch 576, Batch 20] loss: 0.292
epoch:  576 val_f1:  0.7764670296430732 val_auroc:  0.8443786059566506 val_auprc:  0.9107254962480045
EarlyStopping counter: 1 out of 50
epoch:  576
[Epoch 577, Batch 10] loss: 0.279
[Epoch 577, Batch 20] loss: 0.268
epoch:  577 val_f1:  0.8028367790353751 val_auroc:  0.84437860

[Epoch 611, Batch 10] loss: 0.276
[Epoch 611, Batch 20] loss: 0.258
epoch:  611 val_f1:  0.7975624256837099 val_auroc:  0.8454701387806021 val_auprc:  0.9106945257519043
EarlyStopping counter: 9 out of 50
epoch:  611
[Epoch 612, Batch 10] loss: 0.256
[Epoch 612, Batch 20] loss: 0.280
epoch:  612 val_f1:  0.8028367790353751 val_auroc:  0.8453142055200376 val_auprc:  0.9106469275786082
EarlyStopping counter: 10 out of 50
epoch:  612
[Epoch 613, Batch 10] loss: 0.270
[Epoch 613, Batch 20] loss: 0.268
epoch:  613 val_f1:  0.7922894474618614 val_auroc:  0.8454701387806018 val_auprc:  0.9107037886061436
Validation loss improved to 1.823185
epoch:  613
[Epoch 614, Batch 10] loss: 0.269
[Epoch 614, Batch 20] loss: 0.280
epoch:  614 val_f1:  0.8028367790353751 val_auroc:  0.8454701387806018 val_auprc:  0.9107037886061436
EarlyStopping counter: 1 out of 50
epoch:  614
[Epoch 615, Batch 10] loss: 0.272
[Epoch 615, Batch 20] loss: 0.274
epoch:  615 val_f1:  0.7922894474618614 val_auroc:  0.8451582

[Epoch 649, Batch 10] loss: 0.262
[Epoch 649, Batch 20] loss: 0.273
epoch:  649 val_f1:  0.7975624256837099 val_auroc:  0.8454701387806018 val_auprc:  0.910594003097408
EarlyStopping counter: 1 out of 50
epoch:  649
[Epoch 650, Batch 10] loss: 0.258
[Epoch 650, Batch 20] loss: 0.268
epoch:  650 val_f1:  0.8081137119249642 val_auroc:  0.845470138780602 val_auprc:  0.9105630325282094
EarlyStopping counter: 2 out of 50
epoch:  650
[Epoch 651, Batch 10] loss: 0.264
[Epoch 651, Batch 20] loss: 0.254
epoch:  651 val_f1:  0.7975624256837099 val_auroc:  0.845470138780602 val_auprc:  0.9105630325282094
EarlyStopping counter: 3 out of 50
epoch:  651
[Epoch 652, Batch 10] loss: 0.251
[Epoch 652, Batch 20] loss: 0.273
epoch:  652 val_f1:  0.8081137119249642 val_auroc:  0.8454701387806018 val_auprc:  0.9104428547897888
EarlyStopping counter: 4 out of 50
epoch:  652
[Epoch 653, Batch 10] loss: 0.273
[Epoch 653, Batch 20] loss: 0.258
epoch:  653 val_f1:  0.7922894474618614 val_auroc:  0.8454701387806

[Epoch 687, Batch 10] loss: 0.256
[Epoch 687, Batch 20] loss: 0.265
epoch:  687 val_f1:  0.8081137119249642 val_auroc:  0.8464057383439888 val_auprc:  0.9108796930788604
EarlyStopping counter: 6 out of 50
epoch:  687
[Epoch 688, Batch 10] loss: 0.270
[Epoch 688, Batch 20] loss: 0.246
epoch:  688 val_f1:  0.8133944513339264 val_auroc:  0.8464057383439888 val_auprc:  0.9108123110761714
EarlyStopping counter: 7 out of 50
epoch:  688
[Epoch 689, Batch 10] loss: 0.261
[Epoch 689, Batch 20] loss: 0.263
epoch:  689 val_f1:  0.8081137119249642 val_auroc:  0.8465616716045533 val_auprc:  0.9109417200374703
EarlyStopping counter: 8 out of 50
epoch:  689
[Epoch 690, Batch 10] loss: 0.258
[Epoch 690, Batch 20] loss: 0.263
epoch:  690 val_f1:  0.8186802482535972 val_auroc:  0.8465616716045533 val_auprc:  0.9109417200374703
EarlyStopping counter: 9 out of 50
epoch:  690
[Epoch 691, Batch 10] loss: 0.278
[Epoch 691, Batch 20] loss: 0.249
epoch:  691 val_f1:  0.8028367790353751 val_auroc:  0.8467176048

[Epoch 725, Batch 10] loss: 0.263
[Epoch 725, Batch 20] loss: 0.250
epoch:  725 val_f1:  0.8186802482535972 val_auroc:  0.8470294713862466 val_auprc:  0.911261269967039
EarlyStopping counter: 2 out of 50
epoch:  725
[Epoch 726, Batch 10] loss: 0.242
[Epoch 726, Batch 20] loss: 0.267
epoch:  726 val_f1:  0.8186802482535972 val_auroc:  0.8470294713862467 val_auprc:  0.9112879739683577
EarlyStopping counter: 3 out of 50
epoch:  726
[Epoch 727, Batch 10] loss: 0.266
[Epoch 727, Batch 20] loss: 0.250
epoch:  727 val_f1:  0.8081137119249642 val_auroc:  0.8470294713862467 val_auprc:  0.911245821994924
EarlyStopping counter: 4 out of 50
epoch:  727
[Epoch 728, Batch 10] loss: 0.272
[Epoch 728, Batch 20] loss: 0.254
epoch:  728 val_f1:  0.8186802482535972 val_auroc:  0.8470294713862466 val_auprc:  0.9112480456959844
EarlyStopping counter: 5 out of 50
epoch:  728
[Epoch 729, Batch 10] loss: 0.282
[Epoch 729, Batch 20] loss: 0.230
epoch:  729 val_f1:  0.8081137119249642 val_auroc:  0.847029471386

[Epoch 763, Batch 10] loss: 0.253
[Epoch 763, Batch 20] loss: 0.253
epoch:  763 val_f1:  0.8186802482535972 val_auroc:  0.8471854046468112 val_auprc:  0.9116958268349993
EarlyStopping counter: 11 out of 50
epoch:  763
[Epoch 764, Batch 10] loss: 0.239
[Epoch 764, Batch 20] loss: 0.274
epoch:  764 val_f1:  0.8186802482535972 val_auroc:  0.8471854046468112 val_auprc:  0.9116958268349993
EarlyStopping counter: 12 out of 50
epoch:  764
[Epoch 765, Batch 10] loss: 0.264
[Epoch 765, Batch 20] loss: 0.232
epoch:  765 val_f1:  0.8186802482535972 val_auroc:  0.8474972711679403 val_auprc:  0.9119370572852641
EarlyStopping counter: 13 out of 50
epoch:  765
[Epoch 766, Batch 10] loss: 0.239
[Epoch 766, Batch 20] loss: 0.272
epoch:  766 val_f1:  0.8133944513339264 val_auroc:  0.8474972711679403 val_auprc:  0.9119370572852641
EarlyStopping counter: 14 out of 50
epoch:  766
[Epoch 767, Batch 10] loss: 0.255
[Epoch 767, Batch 20] loss: 0.255
epoch:  767 val_f1:  0.8186802482535972 val_auroc:  0.847497

[Epoch 801, Batch 10] loss: 0.219
[Epoch 801, Batch 20] loss: 0.266
epoch:  801 val_f1:  0.8186802482535972 val_auroc:  0.8481210042101981 val_auprc:  0.9125543171876036
EarlyStopping counter: 15 out of 50
epoch:  801
[Epoch 802, Batch 10] loss: 0.228
[Epoch 802, Batch 20] loss: 0.257
epoch:  802 val_f1:  0.8133944513339264 val_auroc:  0.8481210042101981 val_auprc:  0.9125543171876036
Validation loss improved to 1.793943
epoch:  802
[Epoch 803, Batch 10] loss: 0.266
[Epoch 803, Batch 20] loss: 0.238
epoch:  803 val_f1:  0.8081137119249642 val_auroc:  0.8481210042101981 val_auprc:  0.9125543171876036
Validation loss improved to 1.791652
epoch:  803
[Epoch 804, Batch 10] loss: 0.244
[Epoch 804, Batch 20] loss: 0.253
epoch:  804 val_f1:  0.8186802482535972 val_auroc:  0.8481210042101981 val_auprc:  0.9125543171876036
EarlyStopping counter: 1 out of 50
epoch:  804
[Epoch 805, Batch 10] loss: 0.249
[Epoch 805, Batch 20] loss: 0.242
epoch:  805 val_f1:  0.8186802482535972 val_auroc:  0.84812

[Epoch 839, Batch 10] loss: 0.252
[Epoch 839, Batch 20] loss: 0.239
epoch:  839 val_f1:  0.8133944513339264 val_auroc:  0.8482769374707625 val_auprc:  0.9126951834394651
EarlyStopping counter: 36 out of 50
epoch:  839
[Epoch 840, Batch 10] loss: 0.254
[Epoch 840, Batch 20] loss: 0.241
epoch:  840 val_f1:  0.8186802482535972 val_auroc:  0.8484328707313271 val_auprc:  0.9127409778372585
EarlyStopping counter: 37 out of 50
epoch:  840
[Epoch 841, Batch 10] loss: 0.256
[Epoch 841, Batch 20] loss: 0.238
epoch:  841 val_f1:  0.8239723791884421 val_auroc:  0.848121004210198 val_auprc:  0.9123433448959269
EarlyStopping counter: 38 out of 50
epoch:  841
[Epoch 842, Batch 10] loss: 0.242
[Epoch 842, Batch 20] loss: 0.250
epoch:  842 val_f1:  0.8239723791884421 val_auroc:  0.8482769374707626 val_auprc:  0.9123891392937203
EarlyStopping counter: 39 out of 50
epoch:  842
[Epoch 843, Batch 10] loss: 0.247
[Epoch 843, Batch 20] loss: 0.235
epoch:  843 val_f1:  0.8239723791884421 val_auroc:  0.8485888

In [8]:
pd.DataFrame(result)

Unnamed: 0,time_elapsed,random_seed,f1_train,auroc_train,auprc_train,f1_test,auroc_test,auprc_test
0,532.493232,42,0.829846,0.91168,0.961068,0.823972,0.848433,0.912476


In [49]:
pd.DataFrame(result)

Unnamed: 0,time_elapsed,random_seed,f1_train,auroc_train,auprc_train,f1_test,auroc_test,auprc_test
0,1090.139801,42,0.795013,0.948927,0.975514,0.7687,0.85046,0.914091


In [50]:
summary(model, input_size=(config['batch_size'],1,input_size))

Layer (type:depth-idx)                   Output Shape              Param #
TAPPredictor                             [32, 1]                   --
├─Sequential: 1-1                        [32, 1]                   --
│    └─Conv1d: 2-1                       [32, 256, 256]            1,536
│    └─BatchNorm1d: 2-2                  [32, 256, 256]            512
│    └─Tanh: 2-3                         [32, 256, 256]            --
│    └─MaxPool1d: 2-4                    [32, 256, 51]             --
│    └─Dropout: 2-5                      [32, 256, 51]             --
│    └─Flatten: 2-6                      [32, 13056]               --
│    └─Linear: 2-7                       [32, 1]                   13,057
Total params: 15,105
Trainable params: 15,105
Non-trainable params: 0
Total mult-adds (M): 13.02
Input size (MB): 0.16
Forward/backward pass size (MB): 33.55
Params size (MB): 0.06
Estimated Total Size (MB): 33.78

In [51]:
torch.save(model.state_dict(),'./../results/final_model/TAP.pt')