In [1]:
import sys

import os
import json
import time
import argparse

import pandas as pd
import numpy as np

from Bio import SeqIO
import esm

import torch
from torch import nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, Subset

from torchinfo import summary

from sklearn.metrics import f1_score, roc_auc_score, roc_curve, auc, precision_recall_curve
from sklearn.model_selection import train_test_split, StratifiedKFold
from sklearn.preprocessing import StandardScaler

import wandb

sys.path.append('./../../../src/')

from utils import *
from utils_torch import * 
from MHCCBM import *
from PGPredictor_CNN import *
from tqdm import tqdm

In [2]:
# read peptides list
# Load the flank-0 peptide table: one row per peptide, with the `peptide`
# string, the `hit` label, and the numbered embedding columns. The first CSV
# column is used as the row index.
flank0_csv = './../../../data/PG/esm1b/flank{}_peptides_esm1b.csv'.format(0)
flank0_df = pd.read_csv(flank0_csv, index_col=0)
flank0_df

Unnamed: 0,peptide,hit,0,1,2,3,4,5,6,7,...,1270,1271,1272,1273,1274,1275,1276,1277,1278,1279
0,ILRPKPDYF,0.0,0.533767,1.457934,1.779880,1.169885,-0.719667,1.368618,-0.419608,0.639073,...,-0.299567,1.581104,-0.044787,0.533028,-0.205300,0.492291,-0.345337,-0.407906,0.564666,-0.563590
1,IYPPGFSYL,0.0,0.680211,0.295262,1.714029,1.323705,-0.237892,1.272062,-0.181969,0.669378,...,0.072727,2.108028,0.541240,0.055342,-0.848335,0.260021,-0.807489,0.198852,0.629822,-0.655833
2,RYMPQNPCII,1.0,1.246705,0.606267,1.848789,1.179841,0.466007,-0.020620,-0.366055,0.944435,...,-0.581674,1.101377,0.572914,0.947558,0.473713,-0.519548,-1.464290,0.026348,0.093463,0.058640
3,NYSVNGNCEW,0.0,1.919065,0.869667,0.985458,0.445299,0.271855,0.105966,0.492187,-0.248743,...,-0.638702,-1.316010,0.213583,1.149094,0.373075,-0.711103,-0.847282,0.016607,1.334911,-0.414205
4,LFCDFGEEM,0.0,1.277201,2.439435,0.759459,0.604564,-1.739369,0.803881,0.101670,-0.449408,...,0.290795,1.325325,1.342853,1.117716,0.400609,0.980317,0.538834,0.267250,0.730641,-1.842685
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
297543,AIIHGQMAYI,0.0,-0.204722,0.361943,0.376750,0.218967,0.053348,-0.406903,-1.334947,1.373394,...,0.680224,2.435557,0.793749,1.056151,-0.783031,-0.228265,-1.012242,0.847001,-0.322014,0.437284
297544,GLPEYLSKT,0.0,0.987672,0.042746,0.537548,1.322701,-0.548914,-0.227333,0.410077,0.104871,...,-0.025391,1.265950,-0.156332,0.253393,-1.491441,-0.875211,0.134408,-0.503012,0.850081,-0.875928
297545,RLYPELPSQL,1.0,0.939923,0.162232,0.693979,0.949450,-0.567565,0.443507,-0.485680,0.758387,...,-0.276031,2.206114,0.142970,-0.075833,-1.424216,-0.885639,0.603988,-0.622470,0.679971,-1.805622
297546,HLAKLKEAV,0.0,1.050682,1.284589,-0.609613,0.696779,-0.943364,-0.633307,-0.708188,-0.545659,...,0.070686,1.785624,0.075044,1.157744,-1.672122,-0.336969,1.009572,-0.394122,0.312486,-0.375061


In [4]:
# Build the model inputs from the table:
#   y - the `hit` label column
#   X - every remaining column once `peptide` and `hit` are removed
y = flank0_df['hit'].to_numpy()
X = flank0_df.drop(columns=['peptide', 'hit']).to_numpy()

In [5]:
# Standardise each feature column, then convert to a float32 tensor with a
# singleton axis inserted at position 1, i.e. shape (n_samples, 1, n_features).
# X is deliberately rebound at each step so the float64 intermediate array is
# not kept alive next to the tensor.
# NOTE(review): the scaler is fitted on every row before the train/test split
# performed in a later cell, so held-out rows contribute to the mean/variance.
# Consider fitting on the training split only - TODO confirm whether this is
# intentional.
scaler = StandardScaler()
X = scaler.fit_transform(X.squeeze())
X = torch.tensor(X, dtype=torch.float32).unsqueeze(1)


In [10]:
# Seed handed to train_test_split further down.
seed = 42

# Hyper-parameters for this run. `config` is the flat dictionary read by the
# cells below; `config_dict` wraps the very same object under the "config"
# key and is what gets passed to the training loop.
config = {
    "hidden_channels": [32],
    "kernel_size": 512,
    "pool_kernel_size": 512,
    "epochs": 1000,
    "batch_size": 1024,
    "lr": 1e-3,
    "dropout_p": 0.0,
    "architecture": "CNN",
}
config_dict = {"config": config}

In [11]:
# Seed torch's global RNG before anything stochastic happens in this cell.
# `seed` was previously passed only to train_test_split, so the model's weight
# initialisation and the shuffling order of train_loader changed on every run.
torch.manual_seed(seed)

# Hold out 20% of the samples. The held-out part is what the training cell
# passes as `valid_loader`.
train_sequences, test_sequences, train_labels, test_labels = train_test_split(
    X, y, test_size=0.2, random_state=seed)

# Wrap both splits in datasets and loaders; only the training loader shuffles.
train_dataset = ProteinSequenceDataset(train_sequences, train_labels)
test_dataset = ProteinSequenceDataset(test_sequences, test_labels)

train_loader = DataLoader(train_dataset, batch_size=config['batch_size'], shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=config['batch_size'], shuffle=False)

#### model training
# Length of the last axis of X, i.e. the width of one embedding vector.
# NOTE(review): the previous comment described this as the embedding size for
# esm2_t33_650M_UR50D (allele + peptide), but the CSV is read from an `esm1b`
# peptide file - confirm which model produced the embeddings.
input_size = X.shape[-1]
model = PGPredictor(input_size, config['hidden_channels'],
                    config['kernel_size'], config['pool_kernel_size'], config['dropout_p'])

# Class weighting: ratio of label-0 count to label-1 count in the training
# split, handed to the training loop as `pos_weight`.
labels_tensor = torch.tensor(train_labels, dtype=torch.int16)
# minlength=2 guarantees index 1 exists even if no positive label is present,
# so the failure below is explicit instead of an opaque IndexError.
class_counts = torch.bincount(labels_tensor, minlength=2)
if class_counts[1] == 0:
    raise ValueError("training split contains no positive (hit == 1) samples; "
                     "cannot compute pos_weight")
pos_weight = (class_counts[0] / class_counts[1]).to(dtype=torch.float32)

# Last expression of the cell: render the layer-by-layer model summary.
summary(model, input_size=(config['batch_size'], 1, input_size))

Layer (type:depth-idx)                   Output Shape              Param #
PGPredictor                              [1024, 1]                 --
├─Sequential: 1-1                        [1024, 1]                 --
│    └─Conv1d: 2-1                       [1024, 32, 769]           16,416
│    └─BatchNorm1d: 2-2                  [1024, 32, 769]           64
│    └─Tanh: 2-3                         [1024, 32, 769]           --
│    └─MaxPool1d: 2-4                    [1024, 32, 1]             --
│    └─Dropout: 2-5                      [1024, 32, 1]             --
│    └─Flatten: 2-6                      [1024, 32]                --
│    └─Linear: 2-7                       [1024, 1]                 33
Total params: 16,513
Trainable params: 16,513
Non-trainable params: 0
Total mult-adds (G): 12.93
Input size (MB): 5.24
Forward/backward pass size (MB): 403.19
Params size (MB): 0.07
Estimated Total Size (MB): 408.49

In [12]:
# Train the model and report how long the call took.
# time.perf_counter() is a monotonic, high-resolution clock, so the measured
# duration cannot be distorted by system clock adjustments (time.time() can).
# The try/finally ensures the elapsed time is still printed when the run is
# stopped early (e.g. by a KeyboardInterrupt); the exception then propagates
# as before.
start = time.perf_counter()
try:
    # The held-out split is used as the validation set; no separate test
    # loader is supplied.
    model.train_loop(train_loader=train_loader, valid_loader=test_loader, test_loader=None,
                     config_dict=config_dict, pos_weight=pos_weight)
finally:
    end = time.perf_counter()
    time_elapsed = end - start
    print("Time for training: ", time_elapsed)

epoch:  0
[Epoch 1, Batch 10] loss: 1.561
[Epoch 1, Batch 20] loss: 1.357
[Epoch 1, Batch 30] loss: 1.250
[Epoch 1, Batch 40] loss: 1.118
[Epoch 1, Batch 50] loss: 1.039
[Epoch 1, Batch 60] loss: 0.960
[Epoch 1, Batch 70] loss: 0.938
[Epoch 1, Batch 80] loss: 0.912
[Epoch 1, Batch 90] loss: 0.916
[Epoch 1, Batch 100] loss: 0.920
[Epoch 1, Batch 110] loss: 0.912
[Epoch 1, Batch 120] loss: 0.913
[Epoch 1, Batch 130] loss: 0.908
[Epoch 1, Batch 140] loss: 0.904
[Epoch 1, Batch 150] loss: 0.909
[Epoch 1, Batch 160] loss: 0.904
[Epoch 1, Batch 170] loss: 0.901
[Epoch 1, Batch 180] loss: 0.904
[Epoch 1, Batch 190] loss: 0.907
[Epoch 1, Batch 200] loss: 0.899
[Epoch 1, Batch 210] loss: 0.904
[Epoch 1, Batch 220] loss: 0.905
[Epoch 1, Batch 230] loss: 0.905
epoch:  1 val_f1:  0.6064339884131245 val_auroc:  0.6011333125048586 val_auprc:  0.41222767186962994
epoch:  1
[Epoch 2, Batch 10] loss: 0.907
[Epoch 2, Batch 20] loss: 0.902
[Epoch 2, Batch 30] loss: 0.894
[Epoch 2, Batch 40] loss: 0.894
[


KeyboardInterrupt

