In [1]:
# load packages
import numpy as np
import pandas as pd
import seaborn as sns
import time
import matplotlib.pyplot as plt

from sklearn.ensemble import RandomForestRegressor,RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error

import torch
import esm
from tqdm import tqdm
import pickle

from torchinfo import summary

from sklearn.metrics import f1_score, roc_auc_score, roc_curve, auc, precision_recall_curve
from sklearn.model_selection import train_test_split, StratifiedKFold
from sklearn.preprocessing import StandardScaler, MinMaxScaler

import sys
sys.path.append('./../../src/')


from utils import *
from TDPredictor_CNN import *
# from TDPredictor_MLP import *
from MHCCBM import *


In [2]:
# load full TD dataframe
TD_full_df = pd.read_csv('./../../data/TD/processed_data//TD_full.csv',index_col=0)
TD_full_df = TD_full_df.rename(columns={'HLA_full':'allele'})
TD_full_df

Unnamed: 0,allele,HLA,MFI_ratio,SD,Source,ID,Sequence,length
0,HLA-A*01:01,A*01:01,32.11,13.77,Bashirova,HLA:HLA00001,MAVMAPRTLLLLLSGALALTQTWAGSHSMRYFFTSVSRPGRGEPRF...,365
1,HLA-A*01:02,A*01:02,109.86,35.04,Bashirova,HLA:HLA00002,MAVMAPRTLLLLLSGALALTQTWAGSHSMRYFSTSVSRPGSGEPRF...,365
2,HLA-A*02:01,A*02:01,2.02,0.22,Bashirova,HLA:HLA00005,MAVMAPRTLVLLLSGALALTQTWAGSHSMRYFFTSVSRPGRGEPRF...,365
3,HLA-A*02:02,A*02:02,1.45,0.16,Bashirova,HLA:HLA00007,MAVMAPRTLVLLLSGALALTQTWAGSHSMRYFFTSVSRPGRGEPRF...,365
4,HLA-A*02:05,A*02:05,1.49,0.12,Bashirova,HLA:HLA00010,MAVMAPRTLVLLLSGALALTQTWAGSHSMRYFYTSVSRPGRGEPRF...,365
...,...,...,...,...,...,...,...,...
92,HLA-C*16:01,C*16:01,5.98,1.33,Bashirova,HLA:HLA00475,MRVMAPRTLILLLSGALALTETWACSHSMRYFYTAVSRPGRGEPRF...,366
93,HLA-C*17:01,C*17:01,2.92,0.98,Bashirova,HLA:HLA04311,MRVMAPQALLLLLSGALALIETWAGSHSMRYFYTAVSRPGRGEPRF...,372
94,HLA-C*17:03,C*17:03,3.63,1.29,Bashirova,HLA:HLA00993,MRVMAPQALLLLLSGALALIETWTGSHSMRYFYTAVSRPGRGEPRF...,372
95,HLA-C*18:01,C*18:01,2.76,0.80,Bashirova,HLA:HLA00483,MRVMAPRALLLLLSGGLALTETWACSHSMRYFDTAVSRPGRGEPRF...,366


In [50]:
TD_full_df[TD_full_df['HLA']=='C*07:01']

Unnamed: 0,allele,HLA,MFI_ratio,SD,Source,ID,Sequence,length
82,HLA-C*07:01,C*07:01,1.38,0.45,Bashirova,HLA:HLA00433,MRVMAPRALLLLLSGGLALTETWACSHSMRYFDTAVSRPGRGEPRF...,366


In [3]:
# load embeddings
with open('./../../data/TD/processed_data//allele_esm1b.pkl','rb') as f:
    embedding_dict = pickle.load(f)
    
embedding_dict

{'HLA-A*01:01': tensor([[-0.0175, -0.0270, -0.0234,  ..., -0.0569,  0.0393,  0.0693]]),
 'HLA-A*01:02': tensor([[-0.0160, -0.0260, -0.0257,  ..., -0.0601,  0.0384,  0.0687]]),
 'HLA-A*02:01': tensor([[-0.0016, -0.0314, -0.0237,  ..., -0.0590,  0.0384,  0.0710]]),
 'HLA-A*02:02': tensor([[-0.0020, -0.0312, -0.0239,  ..., -0.0593,  0.0410,  0.0700]]),
 'HLA-A*02:05': tensor([[-0.0019, -0.0322, -0.0257,  ..., -0.0591,  0.0420,  0.0692]]),
 'HLA-A*03:01': tensor([[-0.0115, -0.0268, -0.0182,  ..., -0.0563,  0.0385,  0.0768]]),
 'HLA-A*11:01': tensor([[-0.0127, -0.0270, -0.0204,  ..., -0.0585,  0.0396,  0.0772]]),
 'HLA-A*11:02': tensor([[-0.0161, -0.0275, -0.0226,  ..., -0.0567,  0.0384,  0.0760]]),
 'HLA-A*23:01': tensor([[-0.0117, -0.0215, -0.0158,  ..., -0.0574,  0.0455,  0.0724]]),
 'HLA-A*24:02': tensor([[-0.0135, -0.0240, -0.0174,  ..., -0.0590,  0.0413,  0.0787]]),
 'HLA-A*25:01': tensor([[-0.0043, -0.0157, -0.0223,  ..., -0.0545,  0.0368,  0.0696]]),
 'HLA-A*26:01': tensor([[-0.0055

In [4]:
## Combined embeddings and y
merged_df = pd.DataFrame({'allele':embedding_dict.keys(),
                          'embedding':embedding_dict.values()}).merge(TD_full_df, 
                                                                      on='allele')[['embedding','MFI_ratio']]

X = torch.cat(merged_df['embedding'].to_list())
y = merged_df['MFI_ratio'].to_numpy()

# Scale the data
scaler = MinMaxScaler()
X = scaler.fit_transform(X.squeeze())
X = torch.tensor(X, dtype=torch.float32)

# Scale y
y = np.where(y>2, 1, 0)


## Neural network

In [5]:
seed = 42

# best config [128], 3, 32, 0.001, 0

config_dict = {"project": "MHCCBM", 
                 "name": "CNN_80", 
                 "config": {"hidden_channels": [128], 
                            "epochs": 200,  # determined from number of epochs needed from hyperopts
                            "kernel_size":13, #13, 
                            "pool_kernel_size": 13, #13, 
                            "classes": 2, 
                            "seed": seed, 
                            "batch_size": 32, 
                            "lr": 1e-03, 
                            "dataset": "TD bashirova MFI data", 
                            "dropout_p": 0.0, 
                            "architecture": "CNN"}}

config = config_dict['config']


In [12]:
skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
result = {'idx':[], 'time_elapsed':[],'f1':[],'auroc':[],'auprc':[]}

for idx, (train, test) in enumerate(skf.split(X, y)):
    result['idx'] = result['idx'] + [idx]

    train_sequences, train_labels = X[train], y[train]
    test_sequences, test_labels = X[test], y[test]
    
    # Create dataset and dataloaders
    train_dataset = ProteinSequenceDataset(train_sequences.reshape(train_sequences.shape[0],1,-1), train_labels)
    test_dataset = ProteinSequenceDataset(test_sequences.reshape(test_sequences.shape[0],1,-1), test_labels)

    train_loader = DataLoader(train_dataset, batch_size=config['batch_size'], shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=config['batch_size'], shuffle=False)

    # Calculate class weights
    labels_tensor = torch.tensor(train_labels)
    class_counts = torch.bincount(labels_tensor)
    pos_weight = class_counts[0]/class_counts[1]
    print(pos_weight)
    
    input_size = train_sequences.shape[1] #embedding size for esm2_t33_650M_UR50D (HLA)
    model = TDPredictor(input_size, 
                         hidden_channels=config['hidden_channels'],
                         dropout_p=config['dropout_p'],
                         kernel_size = config['kernel_size'],
                         pool_kernel_size = config['pool_kernel_size'])

    # model training
    start = time.time()
    model.train_loop(train_loader=train_loader, 
                     valid_loader=test_loader, 
                     pos_weight=pos_weight,
                     config_dict=config_dict)
    end = time.time()

    time_elapsed = end-start
    print("Time taken: ", time_elapsed)

    result['time_elapsed'] =  result['time_elapsed'] + [time_elapsed]
    f1,auroc,auprc = model.eval_dataset(test_loader, return_label=False, )
    
    result['f1'] =  result['f1'] + [f1]
    result['auroc'] =  result['auroc'] + [auroc]
    result['auprc'] =  result['auprc'] + [auprc]
    
pd.DataFrame(result).mean(), pd.DataFrame(result).std()

tensor(0.2222)
epoch:  0
epoch:  1 val_loss:  tensor(0.2601) val_f1:  0.6380952380952382 val_auroc:  0.640625 val_auprc:  0.8896827821194113
epoch:  1
epoch:  2 val_loss:  tensor(0.2849) val_f1:  0.06666666666666668 val_auroc:  0.5 val_auprc:  0.7891616224494031
EarlyStopping counter: 1 out of 50
epoch:  2
epoch:  3 val_loss:  tensor(0.2910) val_f1:  0.06666666666666668 val_auroc:  0.46875 val_auprc:  0.7388081998303555
EarlyStopping counter: 2 out of 50
epoch:  3
epoch:  4 val_loss:  tensor(0.2658) val_f1:  0.06666666666666668 val_auroc:  0.46875 val_auprc:  0.7388081998303555
EarlyStopping counter: 3 out of 50
epoch:  4
epoch:  5 val_loss:  tensor(0.2754) val_f1:  0.7359307359307359 val_auroc:  0.46875 val_auprc:  0.7388081998303555
EarlyStopping counter: 4 out of 50
epoch:  5
epoch:  6 val_loss:  tensor(0.2820) val_f1:  0.6857142857142857 val_auroc:  0.46875 val_auprc:  0.7388081998303555
EarlyStopping counter: 5 out of 50
epoch:  6
epoch:  7 val_loss:  tensor(0.2672) val_f1:  0.679

epoch:  54 val_loss:  tensor(0.2548) val_f1:  0.7711598746081505 val_auroc:  0.65625 val_auprc:  0.7980198334893092
Validation loss improved to 0.254808
epoch:  54
epoch:  55 val_loss:  tensor(0.2545) val_f1:  0.7711598746081505 val_auroc:  0.65625 val_auprc:  0.7980198334893092
Validation loss improved to 0.254493
epoch:  55
epoch:  56 val_loss:  tensor(0.2542) val_f1:  0.7711598746081505 val_auroc:  0.65625 val_auprc:  0.7980198334893092
Validation loss improved to 0.254173
epoch:  56
epoch:  57 val_loss:  tensor(0.2538) val_f1:  0.7711598746081505 val_auroc:  0.65625 val_auprc:  0.7980198334893092
Validation loss improved to 0.253847
epoch:  57
epoch:  58 val_loss:  tensor(0.2535) val_f1:  0.7711598746081505 val_auroc:  0.65625 val_auprc:  0.7980198334893092
Validation loss improved to 0.253520
epoch:  58
epoch:  59 val_loss:  tensor(0.2532) val_f1:  0.7711598746081505 val_auroc:  0.65625 val_auprc:  0.7980198334893092
Validation loss improved to 0.253198
epoch:  59
epoch:  60 val_l

epoch:  105 val_loss:  tensor(0.2368) val_f1:  0.7598566308243728 val_auroc:  0.703125 val_auprc:  0.807455167381638
Validation loss improved to 0.236799
epoch:  105
epoch:  106 val_loss:  tensor(0.2364) val_f1:  0.7598566308243728 val_auroc:  0.703125 val_auprc:  0.807455167381638
Validation loss improved to 0.236398
epoch:  106
epoch:  107 val_loss:  tensor(0.2360) val_f1:  0.7598566308243728 val_auroc:  0.703125 val_auprc:  0.807455167381638
Validation loss improved to 0.235994
epoch:  107
epoch:  108 val_loss:  tensor(0.2356) val_f1:  0.7598566308243728 val_auroc:  0.703125 val_auprc:  0.807455167381638
Validation loss improved to 0.235589
epoch:  108
epoch:  109 val_loss:  tensor(0.2352) val_f1:  0.7598566308243728 val_auroc:  0.703125 val_auprc:  0.807455167381638
Validation loss improved to 0.235179
epoch:  109
epoch:  110 val_loss:  tensor(0.2348) val_f1:  0.7598566308243728 val_auroc:  0.703125 val_auprc:  0.807455167381638
Validation loss improved to 0.234767
epoch:  110
epoc

epoch:  159 val_loss:  tensor(0.2132) val_f1:  0.8 val_auroc:  0.78125 val_auprc:  0.9164209411911617
Validation loss improved to 0.213204
epoch:  159
epoch:  160 val_loss:  tensor(0.2128) val_f1:  0.8 val_auroc:  0.78125 val_auprc:  0.9164209411911617
Validation loss improved to 0.212771
epoch:  160
epoch:  161 val_loss:  tensor(0.2123) val_f1:  0.8 val_auroc:  0.78125 val_auprc:  0.9164209411911617
Validation loss improved to 0.212338
epoch:  161
epoch:  162 val_loss:  tensor(0.2119) val_f1:  0.8 val_auroc:  0.78125 val_auprc:  0.9164209411911617
Validation loss improved to 0.211907
epoch:  162
epoch:  163 val_loss:  tensor(0.2115) val_f1:  0.8 val_auroc:  0.78125 val_auprc:  0.9164209411911617
Validation loss improved to 0.211478
epoch:  163
epoch:  164 val_loss:  tensor(0.2111) val_f1:  0.8 val_auroc:  0.78125 val_auprc:  0.9164209411911617
Validation loss improved to 0.211052
epoch:  164
epoch:  165 val_loss:  tensor(0.2106) val_f1:  0.8 val_auroc:  0.78125 val_auprc:  0.916420941

epoch:  13 val_loss:  tensor(0.2483) val_f1:  0.7285714285714285 val_auroc:  0.6875 val_auprc:  0.923501871150865
EarlyStopping counter: 1 out of 50
epoch:  13
epoch:  14 val_loss:  tensor(0.2476) val_f1:  0.7285714285714285 val_auroc:  0.6875 val_auprc:  0.923501871150865
EarlyStopping counter: 2 out of 50
epoch:  14
epoch:  15 val_loss:  tensor(0.2472) val_f1:  0.6826666666666666 val_auroc:  0.6875 val_auprc:  0.923501871150865
Validation loss improved to 0.247214
epoch:  15
epoch:  16 val_loss:  tensor(0.2477) val_f1:  0.581074168797954 val_auroc:  0.6875 val_auprc:  0.923501871150865
EarlyStopping counter: 1 out of 50
epoch:  16
epoch:  17 val_loss:  tensor(0.2475) val_f1:  0.6333333333333333 val_auroc:  0.6875 val_auprc:  0.923501871150865
EarlyStopping counter: 2 out of 50
epoch:  17
epoch:  18 val_loss:  tensor(0.2472) val_f1:  0.6849002849002848 val_auroc:  0.6875 val_auprc:  0.923501871150865
Validation loss improved to 0.247165
epoch:  18
epoch:  19 val_loss:  tensor(0.2474) 

epoch:  64 val_loss:  tensor(0.2511) val_f1:  0.6266666666666667 val_auroc:  0.65625 val_auprc:  0.9089077923334405
EarlyStopping counter: 46 out of 50
epoch:  64
epoch:  65 val_loss:  tensor(0.2512) val_f1:  0.6266666666666667 val_auroc:  0.65625 val_auprc:  0.9089077923334405
EarlyStopping counter: 47 out of 50
epoch:  65
epoch:  66 val_loss:  tensor(0.2514) val_f1:  0.6266666666666667 val_auroc:  0.65625 val_auprc:  0.9089077923334405
EarlyStopping counter: 48 out of 50
epoch:  66
epoch:  67 val_loss:  tensor(0.2515) val_f1:  0.6266666666666667 val_auroc:  0.65625 val_auprc:  0.9089077923334405
EarlyStopping counter: 49 out of 50
epoch:  67
epoch:  68 val_loss:  tensor(0.2516) val_f1:  0.6266666666666667 val_auroc:  0.65625 val_auprc:  0.9089077923334405
EarlyStopping counter: 50 out of 50
Early stopping triggered
Early stopping triggered!
Time taken:  5.928755760192871
tensor(0.2381)
epoch:  0
epoch:  1 val_loss:  tensor(0.2466) val_f1:  0.7699248120300751 val_auroc:  0.35416666666

epoch:  44 val_loss:  tensor(0.2237) val_f1:  0.724812030075188 val_auroc:  0.7916666666666667 val_auprc:  0.9597298631451109
Validation loss improved to 0.223678
epoch:  44
epoch:  45 val_loss:  tensor(0.2233) val_f1:  0.724812030075188 val_auroc:  0.7916666666666667 val_auprc:  0.9597298631451109
Validation loss improved to 0.223312
epoch:  45
epoch:  46 val_loss:  tensor(0.2230) val_f1:  0.724812030075188 val_auroc:  0.7916666666666667 val_auprc:  0.9597298631451109
Validation loss improved to 0.222956
epoch:  46
epoch:  47 val_loss:  tensor(0.2226) val_f1:  0.724812030075188 val_auroc:  0.7916666666666667 val_auprc:  0.9597298631451109
Validation loss improved to 0.222618
epoch:  47
epoch:  48 val_loss:  tensor(0.2223) val_f1:  0.724812030075188 val_auroc:  0.7916666666666667 val_auprc:  0.9597298631451109
Validation loss improved to 0.222279
epoch:  48
epoch:  49 val_loss:  tensor(0.2219) val_f1:  0.724812030075188 val_auroc:  0.7916666666666667 val_auprc:  0.9597298631451109
Vali

epoch:  92 val_loss:  tensor(0.2024) val_f1:  0.7670901391409558 val_auroc:  0.8125 val_auprc:  0.966425511247009
Validation loss improved to 0.202390
epoch:  92
epoch:  93 val_loss:  tensor(0.2018) val_f1:  0.7670901391409558 val_auroc:  0.8125 val_auprc:  0.966425511247009
Validation loss improved to 0.201807
epoch:  93
epoch:  94 val_loss:  tensor(0.2012) val_f1:  0.7670901391409558 val_auroc:  0.8125 val_auprc:  0.966425511247009
Validation loss improved to 0.201220
epoch:  94
epoch:  95 val_loss:  tensor(0.2006) val_f1:  0.8165413533834586 val_auroc:  0.8125 val_auprc:  0.966425511247009
Validation loss improved to 0.200623
epoch:  95
epoch:  96 val_loss:  tensor(0.2000) val_f1:  0.8165413533834586 val_auroc:  0.8125 val_auprc:  0.966425511247009
Validation loss improved to 0.200017
epoch:  96
epoch:  97 val_loss:  tensor(0.1994) val_f1:  0.8165413533834586 val_auroc:  0.8125 val_auprc:  0.966425511247009
Validation loss improved to 0.199403
epoch:  97
epoch:  98 val_loss:  tensor

epoch:  143 val_loss:  tensor(0.1638) val_f1:  0.8165413533834586 val_auroc:  0.9166666666666667 val_auprc:  0.9858209978070176
Validation loss improved to 0.163811
epoch:  143
epoch:  144 val_loss:  tensor(0.1629) val_f1:  0.8165413533834586 val_auroc:  0.9166666666666667 val_auprc:  0.9858209978070176
Validation loss improved to 0.162938
epoch:  144
epoch:  145 val_loss:  tensor(0.1621) val_f1:  0.8165413533834586 val_auroc:  0.9166666666666667 val_auprc:  0.9858209978070176
Validation loss improved to 0.162063
epoch:  145
epoch:  146 val_loss:  tensor(0.1612) val_f1:  0.8165413533834586 val_auroc:  0.9166666666666667 val_auprc:  0.9858209978070176
Validation loss improved to 0.161187
epoch:  146
epoch:  147 val_loss:  tensor(0.1603) val_f1:  0.8165413533834586 val_auroc:  0.9166666666666667 val_auprc:  0.9858209978070176
Validation loss improved to 0.160310
epoch:  147
epoch:  148 val_loss:  tensor(0.1594) val_f1:  0.8165413533834586 val_auroc:  0.9166666666666667 val_auprc:  0.9858

epoch:  194 val_loss:  tensor(0.1238) val_f1:  0.8602540834845736 val_auroc:  0.9375 val_auprc:  0.9898574561403508
Validation loss improved to 0.123843
epoch:  194
epoch:  195 val_loss:  tensor(0.1232) val_f1:  0.8602540834845736 val_auroc:  0.9375 val_auprc:  0.9898574561403508
Validation loss improved to 0.123234
epoch:  195
epoch:  196 val_loss:  tensor(0.1226) val_f1:  0.8602540834845736 val_auroc:  0.9375 val_auprc:  0.9898574561403508
Validation loss improved to 0.122631
epoch:  196
epoch:  197 val_loss:  tensor(0.1220) val_f1:  0.8602540834845736 val_auroc:  0.9375 val_auprc:  0.9898574561403508
Validation loss improved to 0.122038
epoch:  197
epoch:  198 val_loss:  tensor(0.1215) val_f1:  0.8602540834845736 val_auroc:  0.9375 val_auprc:  0.9898574561403508
Validation loss improved to 0.121458
epoch:  198
epoch:  199 val_loss:  tensor(0.1209) val_f1:  0.8602540834845736 val_auroc:  0.9375 val_auprc:  0.9898574561403508
Validation loss improved to 0.120888
epoch:  199
epoch:  20

epoch:  44 val_loss:  tensor(0.2445) val_f1:  0.751394615571186 val_auroc:  0.4375 val_auprc:  0.8040837124476831
Validation loss improved to 0.244515
epoch:  44
epoch:  45 val_loss:  tensor(0.2442) val_f1:  0.751394615571186 val_auroc:  0.4375 val_auprc:  0.8040837124476831
Validation loss improved to 0.244194
epoch:  45
epoch:  46 val_loss:  tensor(0.2439) val_f1:  0.751394615571186 val_auroc:  0.4375 val_auprc:  0.8040837124476831
Validation loss improved to 0.243895
epoch:  46
epoch:  47 val_loss:  tensor(0.2436) val_f1:  0.751394615571186 val_auroc:  0.4375 val_auprc:  0.8040837124476831
Validation loss improved to 0.243602
epoch:  47
epoch:  48 val_loss:  tensor(0.2433) val_f1:  0.751394615571186 val_auroc:  0.4375 val_auprc:  0.8040837124476831
Validation loss improved to 0.243308
epoch:  48
epoch:  49 val_loss:  tensor(0.2430) val_f1:  0.7894736842105263 val_auroc:  0.4375 val_auprc:  0.8040837124476831
Validation loss improved to 0.243015
epoch:  49
epoch:  50 val_loss:  tenso

epoch:  95 val_loss:  tensor(0.2315) val_f1:  0.828708133971292 val_auroc:  0.6041666666666667 val_auprc:  0.9010331172095878
Validation loss improved to 0.231528
epoch:  95
epoch:  96 val_loss:  tensor(0.2313) val_f1:  0.828708133971292 val_auroc:  0.625 val_auprc:  0.9063677131691839
Validation loss improved to 0.231292
epoch:  96
epoch:  97 val_loss:  tensor(0.2311) val_f1:  0.828708133971292 val_auroc:  0.625 val_auprc:  0.9063677131691839
Validation loss improved to 0.231059
epoch:  97
epoch:  98 val_loss:  tensor(0.2308) val_f1:  0.828708133971292 val_auroc:  0.625 val_auprc:  0.9063677131691839
Validation loss improved to 0.230825
epoch:  98
epoch:  99 val_loss:  tensor(0.2306) val_f1:  0.828708133971292 val_auroc:  0.625 val_auprc:  0.9063677131691839
Validation loss improved to 0.230589
epoch:  99
epoch:  100 val_loss:  tensor(0.2304) val_f1:  0.828708133971292 val_auroc:  0.625 val_auprc:  0.9063677131691839
Validation loss improved to 0.230356
epoch:  100
epoch:  101 val_los

epoch:  146 val_loss:  tensor(0.2282) val_f1:  0.828708133971292 val_auroc:  0.625 val_auprc:  0.9107057362572069
EarlyStopping counter: 17 out of 50
epoch:  146
epoch:  147 val_loss:  tensor(0.2285) val_f1:  0.828708133971292 val_auroc:  0.625 val_auprc:  0.9107057362572069
EarlyStopping counter: 18 out of 50
epoch:  147
epoch:  148 val_loss:  tensor(0.2287) val_f1:  0.828708133971292 val_auroc:  0.625 val_auprc:  0.9107057362572069
EarlyStopping counter: 19 out of 50
epoch:  148
epoch:  149 val_loss:  tensor(0.2290) val_f1:  0.828708133971292 val_auroc:  0.625 val_auprc:  0.9107057362572069
EarlyStopping counter: 20 out of 50
epoch:  149
epoch:  150 val_loss:  tensor(0.2293) val_f1:  0.828708133971292 val_auroc:  0.625 val_auprc:  0.9107057362572069
EarlyStopping counter: 21 out of 50
epoch:  150
epoch:  151 val_loss:  tensor(0.2296) val_f1:  0.828708133971292 val_auroc:  0.6458333333333334 val_auprc:  0.9190762719714926
EarlyStopping counter: 22 out of 50
epoch:  151
epoch:  152 val

epoch:  17 val_loss:  tensor(0.1912) val_f1:  0.7124060150375939 val_auroc:  0.95 val_auprc:  0.9885620915032679
EarlyStopping counter: 1 out of 50
epoch:  17
epoch:  18 val_loss:  tensor(0.1959) val_f1:  0.5501990269792129 val_auroc:  0.95 val_auprc:  0.9885620915032679
EarlyStopping counter: 2 out of 50
epoch:  18
epoch:  19 val_loss:  tensor(0.1941) val_f1:  0.6614797864225782 val_auroc:  0.95 val_auprc:  0.9885620915032679
EarlyStopping counter: 3 out of 50
epoch:  19
epoch:  20 val_loss:  tensor(0.1905) val_f1:  0.8083670715249661 val_auroc:  0.95 val_auprc:  0.9885620915032679
EarlyStopping counter: 4 out of 50
epoch:  20
epoch:  21 val_loss:  tensor(0.1901) val_f1:  0.8083670715249661 val_auroc:  0.95 val_auprc:  0.9885620915032679
EarlyStopping counter: 5 out of 50
epoch:  21
epoch:  22 val_loss:  tensor(0.1921) val_f1:  0.7124060150375939 val_auroc:  0.95 val_auprc:  0.9885620915032679
EarlyStopping counter: 6 out of 50
epoch:  22
epoch:  23 val_loss:  tensor(0.1937) val_f1:  

(idx              2.000000
 time_elapsed    12.402620
 f1               0.811988
 auroc            0.813333
 auprc            0.953367
 dtype: float64,
 idx             1.581139
 time_elapsed    6.003646
 f1              0.109971
 auroc           0.156504
 auprc           0.039609
 dtype: float64)

In [16]:
result['auprc']

[0.9533009015086221,
 0.9089077923334405,
 0.9898574561403508,
 0.9190762719714926,
 0.9956944444444444]

In [None]:
# Final dataset
#### Create dataset and dataloaders
train_dataset = ProteinSequenceDataset(X.reshape(X.shape[0],1,-1), y)
train_loader = DataLoader(train_dataset, batch_size=config['batch_size'], shuffle=True)


##### Calculate class weights
labels_tensor = torch.tensor(y)
class_counts = torch.bincount(labels_tensor)
pos_weight = class_counts[0]/class_counts[1]
print(pos_weight)

In [8]:
# #### model init
# input_size = X.shape[1] #embedding size for esm2_t33_650M_UR50D (HLA)
# model = TDPredictor(input_size, 
#                      hidden_channels=config['hidden_channels'],
#                      dropout_p=config['dropout_p'],
#                      kernel_size = config['kernel_size'],
#                      pool_kernel_size = config['pool_kernel_size'])

# model.early_stopping = False


# # model training
# start = time.time()
# model.train_loop(train_loader=train_loader, 
#                  valid_loader=None, 
#                  pos_weight=pos_weight,
#                  config_dict=config_dict)
# end = time.time()

# time_elapsed = end-start
# print("Time taken: ", time_elapsed)

# result = {'time_elapsed': [time_elapsed]}
# result['random_seed'] = [seed] 

epoch:  0
epoch:  1
epoch:  2
epoch:  3
epoch:  4
epoch:  5
epoch:  6
epoch:  7
epoch:  8
epoch:  9
epoch:  10
epoch:  11
epoch:  12
epoch:  13
epoch:  14
epoch:  15
epoch:  16
epoch:  17
epoch:  18
epoch:  19
epoch:  20
epoch:  21
epoch:  22
epoch:  23
epoch:  24
epoch:  25
epoch:  26
epoch:  27
epoch:  28
epoch:  29
epoch:  30
epoch:  31
epoch:  32
epoch:  33
epoch:  34
epoch:  35
epoch:  36
epoch:  37
epoch:  38
epoch:  39
epoch:  40
epoch:  41
epoch:  42
epoch:  43
epoch:  44
epoch:  45
epoch:  46
epoch:  47
epoch:  48
epoch:  49
epoch:  50
epoch:  51
epoch:  52
epoch:  53
epoch:  54
epoch:  55
epoch:  56
epoch:  57
epoch:  58
epoch:  59
epoch:  60
epoch:  61
epoch:  62
epoch:  63
epoch:  64
epoch:  65
epoch:  66
epoch:  67
epoch:  68
epoch:  69
epoch:  70
epoch:  71
epoch:  72
epoch:  73
epoch:  74
epoch:  75
epoch:  76
epoch:  77
epoch:  78
epoch:  79
epoch:  80
epoch:  81
epoch:  82
epoch:  83
epoch:  84
epoch:  85
epoch:  86
epoch:  87
epoch:  88
epoch:  89
epoch:  90
epoch:  9

In [9]:
summary(model)

Layer (type:depth-idx)                   Param #
TDPredictor                              --
├─Sequential: 1-1                        --
│    └─Conv1d: 2-1                       1,792
│    └─BatchNorm1d: 2-2                  256
│    └─Tanh: 2-3                         --
│    └─MaxPool1d: 2-4                    --
│    └─Dropout: 2-5                      --
│    └─Flatten: 2-6                      --
│    └─Linear: 2-7                       12,417
├─BCEWithLogitsLoss: 1-2                 --
Total params: 14,465
Trainable params: 14,465
Non-trainable params: 0

In [10]:
# torch.save(model,'./results/final_model/TD.pt')

In [45]:

# Model correctly classifies B4402 as tapasin dependent and B4405 as tapasin independent
model = torch.load('./results/final_model/TD.pt')

with torch.no_grad():
    print(torch.sigmoid(model(torch.tensor(scaler.transform(embedding_dict['HLA-B*44:02']).reshape(1,1,-1),dtype=torch.float32))))
    print(torch.sigmoid(model(torch.tensor(scaler.transform(embedding_dict['HLA-B*44:05']).reshape(1,1,-1),dtype=torch.float32))))

tensor([[0.7300]])
tensor([[0.3594]])


## RF

In [17]:
skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
result = {'idx':[], 'f1':[],'auroc':[],'auprc':[]}

for idx, (train, test) in enumerate(skf.split(X, y)):
    result['idx'] = result['idx'] + [idx]
    
    X_train, y_train = X[train], y[train]
    X_test, y_test = X[test], y[test]
    
    # Calculate class weights
    labels_tensor = torch.tensor(y_train)
    class_counts = torch.bincount(labels_tensor)
    pos_weight = class_counts[0]/class_counts[1]
    print(pos_weight)
    
    # Initialize and train the Random Forest regressor
    RF = RandomForestClassifier(n_estimators=100, max_depth=5, random_state=seed, class_weight={1:pos_weight})
    RF.fit(X_train.squeeze(), y_train)

    # Make predictions on the test set
    y_pred = RF.predict(X_test.squeeze())
    y_proba = RF.predict_proba(X_test.squeeze())

    f1 = f1_score(y_test, y_pred, average='weighted')
        
    fpr, tpr, _ = roc_curve(y_test, y_proba[:,1])
    auroc = auc(fpr, tpr)

    precision, recall, _ = precision_recall_curve(y_test, y_proba[:,1])
    auprc = auc(recall, precision)
    
    result['f1'] =  result['f1'] + [f1]
    result['auroc'] =  result['auroc'] + [auroc]
    result['auprc'] =  result['auprc'] + [auprc]
    
pd.DataFrame(result).mean(), pd.DataFrame(result).std()

tensor(0.2222)
tensor(0.2222)
tensor(0.2381)
tensor(0.2381)
tensor(0.2188)


(idx      2.000000
 f1       0.789244
 auroc    0.718333
 auprc    0.902864
 dtype: float64,
 idx      1.581139
 f1       0.090280
 auroc    0.185700
 auprc    0.079007
 dtype: float64)

In [18]:
result['auroc'], result['auprc']

([0.875, 0.5625, 0.9583333333333334, 0.5625, 0.6333333333333333],
 [0.971515415376677,
  0.8106995742529315,
  0.992172181372549,
  0.8415002889267594,
  0.8984314570305283])

# LR

In [19]:
from sklearn.linear_model import LogisticRegression

skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
result = {'idx':[], 'f1':[],'auroc':[],'auprc':[]}

for idx, (train, test) in enumerate(skf.split(X, y)):
    result['idx'] = result['idx'] + [idx]
    
    X_train, y_train = X[train], y[train]
    X_test, y_test = X[test], y[test]
    
    clf = LogisticRegression(random_state=42).fit(X_train.squeeze(), y_train)

    # Make predictions on the test set
    y_pred = clf.predict(X_test.squeeze())
    y_proba = clf.predict_proba(X_test.squeeze())


    f1 = f1_score(y_test, y_pred, average='weighted')
        
    fpr, tpr, _ = roc_curve(y_test, y_proba[:,1])
    auroc = auc(fpr, tpr)

    precision, recall, _ = precision_recall_curve(y_test, y_proba[:,1])
    auprc = auc(recall, precision)
    
    result['f1'] =  result['f1'] + [f1]
    result['auroc'] =  result['auroc'] + [auroc]
    result['auprc'] =  result['auprc'] + [auprc]
    
pd.DataFrame(result).mean(), pd.DataFrame(result).std()

STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(
STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(
STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver opt

(idx      2.000000
 f1       0.819686
 auroc    0.765417
 auprc    0.918160
 dtype: float64,
 idx      1.581139
 f1       0.073169
 auroc    0.200193
 auprc    0.095834
 dtype: float64)

In [20]:
result['auroc'], result['auprc']

([0.9375,
  0.71875,
  0.9791666666666667,
  0.7083333333333334,
  0.48333333333333334],
 [0.9852685866013071,
  0.9225698654200976,
  0.9962086397058824,
  0.9297994897259603,
  0.7569555553379084])

In [21]:
from sklearn import svm


skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
result = {'idx':[], 'f1':[],'auroc':[],'auprc':[]}

for idx, (train, test) in enumerate(skf.split(X, y)):
    result['idx'] = result['idx'] + [idx]
    
    X_train, y_train = X[train], y[train]
    X_test, y_test = X[test], y[test]
    
    # , 'poly', 'rbf', 'sigmoid'
    clf = svm.SVC(probability=True, degree=3, kernel='poly').fit(X_train.squeeze(), y_train)

    # Make predictions on the test set
    y_pred = clf.predict(X_test.squeeze())
    y_proba = clf.predict_proba(X_test.squeeze())



    f1 = f1_score(y_test, y_pred, average='weighted')
        
    fpr, tpr, _ = roc_curve(y_test, y_proba[:,1])
    auroc = auc(fpr, tpr)

    precision, recall, _ = precision_recall_curve(y_test, y_proba[:,1])
    auprc = auc(recall, precision)
    
    result['f1'] =  result['f1'] + [f1]
    result['auroc'] =  result['auroc'] + [auroc]
    result['auprc'] =  result['auprc'] + [auprc]
    
pd.DataFrame(result).mean(), pd.DataFrame(result).std()



(idx      2.000000
 f1       0.806487
 auroc    0.761250
 auprc    0.933329
 dtype: float64,
 idx      1.581139
 f1       0.063196
 auroc    0.164778
 auprc    0.056385
 dtype: float64)

In [22]:
result['auroc'], result['auprc']

([0.875, 0.78125, 0.9583333333333334, 0.625, 0.5666666666666667],
 [0.9634443859857463,
  0.950719322816846,
  0.9928513071895425,
  0.9132602997492703,
  0.8463694747518277])