In [None]:
import numpy as np
import pandas as pd
import sys
import csv
import random
import os
from io import StringIO
import keras
from keras.layers import Input, Dense, concatenate, Dropout
from keras.models import Model, load_model
from keras import backend as K
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as Data
import time
from sklearn.model_selection import KFold
import matplotlib.pyplot as plt
import seaborn as sns
from sparsemax import Sparsemax
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import accuracy_score, recall_score, roc_auc_score, roc_curve, f1_score, confusion_matrix
from tqdm import tqdm
from scipy.stats import spearmanr

In [None]:
########################### Encoding configuration ##########################

# Locations of the amino-acid factor table and of the HLA protein FASTA library.
aa_dict_dir = '../library/Atchley_factors.csv'
hla_db_dir = '../library/hla_library'

# The 20 amino-acid letters; their order defines the integer codes below.
aa_list = list('ARNDCEQGHILKMFPSTWYV')

# Residue letter -> 1-based index (0 is never assigned, so it can serve as padding).
aa_to_idx_dict = dict(zip(aa_list, range(1, len(aa_list) + 1)))
# Same mapping under a second name, kept as a separate dict object.
aa_dict = dict(aa_to_idx_dict)

# Lookup table: one-letter amino-acid code -> 15-element float vector.
# It holds an entry for each of the 20 letters in 'ARNDCEQGHILKMFPSTWYV'.
# NOTE(review): the name suggests 15 principal components of AAindex
# properties — TODO confirm the provenance of these numbers.
pca15_aaidx = {
    'A': np.array([
        -0.97090603, -0.32368135, 15.72065182, -0.50884066,  3.74021858,
        -0.77879681,  3.38667656, -0.91306275,  3.00614751, -2.32919906,
         0.7870219 , -1.87437966,  1.53335388,  1.344461  ,  3.35815929]),
    'R': np.array([
         8.53814625, -13.78745836,  -4.7706284 ,  -1.16449378,
        -9.46561383,   6.79244266,   1.98971004,  -5.15274046,
        -2.93772088,  -5.89206951,   5.26993696,   1.61036613,
        -1.95208183,  -0.64409   ,   1.53934309]),
    'N': np.array([
        14.86976088,  1.57062419, -3.62321751,  5.63672398, -1.78396131,
        -3.62736822, -1.44652457,  5.39945262,  0.87616878,  4.1128776 ,
         1.91381611,  2.4206327 , -3.81386422, -5.242596  ,  3.97131233]),
    'D': np.array([
        18.12626589, -2.14738155, -0.25216984,  2.31366107,  6.52406372,
        -4.88386469, -9.13828094, -0.71024376, -2.10295526, -1.27865374,
         1.48217398,  1.43032329, -4.04631611,  5.72388976, -1.78674913]),
    'C': np.array([
        -8.36918251e+00,  8.30319345e+00, -6.61966946e+00,  1.38734140e+01,
         8.59531324e+00,  9.79914818e+00,  1.26911187e+00, -4.63752901e+00,
        -9.83921566e-01,  4.36342436e+00,  2.34365256e+00,  3.43445840e-01,
         1.18290876e+00,  1.32558804e-02, -6.96621503e-01]),
    'E': np.array([
        12.03159874, -13.30121908,   8.27995382,  -2.1963886 ,
         8.48308971,  -2.02004094,  -4.58165905,  -2.68526505,
        -2.60424798,  -0.13228263,  -0.2394944 ,  -2.40393485,
          4.1802745 ,  -4.40432039,   0.23979065]),
    'Q': np.array([
         7.9210965 , -8.69666506, -0.59701694,  0.02046296,  0.93026397,
         2.56565474,  1.14554766,  0.50500247,  0.71467344,  0.04438606,
        -6.00954372,  3.89329376,  2.44305793, -2.76002278, -1.6987172 ]),
    'G': np.array([
        14.83920873, 19.24138906,  5.93438829,  5.66738401, -5.81883134,
        -8.78866739,  5.73914235, -4.72849218, -4.20333219, -1.56492845,
        -1.37450169, -0.61253836,  0.6617134 , -0.69921372, -1.40639534]),
    'H': np.array([
         0.68805439, -6.17956596, -6.80550604,  3.94746531,  1.33297594,
        -0.51211304,  4.73783089,  8.74722377, -3.12401236, -1.57598186,
         1.98977266, -7.78869789,  0.99852322,  1.37575734, -0.43120018]),
    'I': np.array([
        -20.34084677,   4.14384693,   3.86394279,  -3.40504491,
         -2.78992754,   1.46082574,  -4.07088528,   0.58387152,
         -3.41337954,   1.84625632,  -3.12082912,  -1.14873131,
         -4.65093114,  -1.64689363,   1.46078507]),
    'L': np.array([
        -17.63623772,  -0.35239597,  11.83924317,  -5.37540851,
         -1.15954659,  -1.97240237,   0.8551395 ,   0.5305574 ,
          2.1201762 ,   4.05167588,   7.06745699,   2.57547229,
          1.4547904 ,   1.28288187,  -0.26554788]),
    'K': np.array([
        11.70658851, -13.64488577,   2.13621858,  -2.01605285,
        -6.38068619,   0.81226081,   4.37857192,  -1.6496082 ,
         1.42540472,   8.6009414 ,  -2.9141596 ,  -1.63197862,
        -1.76542326,   2.93593542,  -2.78651305]),
    'M': np.array([
        -15.59654594,  -5.74596901,   1.05654898,   1.8258127 ,
          6.65236588,  -1.21996347,   6.87749571,   2.80480871,
         -1.90746691,  -3.59099156,  -4.0585946 ,   5.83997368,
         -2.03857504,   2.47274257,   1.55383649]),
    'F': np.array([
        -18.59572859,   0.92235509,  -3.31743268,  -2.52174562,
         -0.45125923,  -4.09153376,  -0.39932096,   2.4112244 ,
         -1.0836413 ,  -1.47072537,   3.19432868,   2.04374527,
          0.29441083,  -3.18840009,  -6.45603094]),
    'P': np.array([
        16.21704723,  15.09127036,  -8.72320617, -18.77750006,
         6.46478124,   4.56244976,   3.05557799,   0.15690408,
        -0.66929999,   0.7338994 ,   0.3953321 ,  -0.16512658,
        -0.43811542,   0.05203475,   0.50555428]),
    'S': np.array([
        11.85274099,  6.88460214,  2.96725951,  3.6290185 , -2.88619312,
         2.32382186, -1.33830204,  3.31676758,  6.26373175, -0.89789266,
         0.739414  ,  1.15021805,  1.36104635,  1.13305848,  1.16395609]),
    'T': np.array([
         4.53029865,  5.12654936,  1.43605965,  1.7684372 , -3.39334852,
         5.24360932, -3.36448143,  2.20062621,  6.52685307, -4.74609137,
        -2.03988324, -0.65055289,  0.19272884, -0.25439792, -3.25573314]),
    'W': np.array([
        -16.29768602,  -3.90769937, -13.79255488,  -0.84325799,
          2.33264713,  -8.04187276,   0.50186494,  -6.42108915,
          7.2327353 ,  -1.12521691,  -1.01561524,  -3.54186466,
         -1.9445611 ,  -0.81438463,   1.37250825]),
    'Y': np.array([
        -7.49926   ,   1.31004507, -12.07792238,  -1.16106141,
        -6.8946393 ,  -3.23103968,  -4.86248111,   0.68815174,
        -2.45473764,   1.48233974,  -1.42869414,   2.18166427,
         7.4468281 ,   3.0783195 ,   2.84874949]),
    'V': np.array([
        -16.01441319,   5.49304582,   7.34505768,  -0.71258533,
         -4.03171245,   5.60745006,  -4.73473404,  -0.44655998,
         -2.68117513,  -0.63176764,  -2.98159019,  -3.67133046,
         -1.09976811,   0.24198258,   0.76951331])}

def enc_list_bl_max_len(aa_seqs, aa_codes, max_seq_len):
    '''
    Encode a list of amino acid sequences and zero-pad them to a common length.

    Every residue is replaced by its code vector from ``aa_codes``; positions
    past the end of a sequence stay at zero.

    parameters:
        - aa_seqs : list with AA sequences
        - aa_codes : dictionary: key= AA, value= code vector (all of equal length)
        - max_seq_len: common length for padding
    returns:
        - enc_aa_seq : np.ndarray of dtype float32 and shape
          (len(aa_seqs), max_seq_len, n_features)
    raises:
        - ValueError if a sequence is longer than max_seq_len
        - SystemExit (code 2) if a residue is missing from aa_codes
    '''
    # Take the feature width from the code table itself so that an empty
    # `aa_seqs` still yields a correctly shaped (0, max_seq_len, n_features) array.
    n_features = len(next(iter(aa_codes.values())))

    # Write straight into a float32 buffer. The previous int32 staging buffer
    # truncated real-valued codes (e.g. 0.78 -> 0) before they were stored.
    enc_aa_seq = np.zeros((len(aa_seqs), max_seq_len, n_features), dtype=np.float32)

    for row, seq in enumerate(aa_seqs):
        if len(seq) > max_seq_len:
            raise ValueError(
                "Sequence of length %d does not fit into max_seq_len=%d"
                % (len(seq), max_seq_len))
        for col, aa in enumerate(seq):
            if aa not in aa_codes:
                sys.stderr.write("Unknown amino acid in peptides: "+ aa +", encoding aborted!\n")
                sys.exit(2)
            enc_aa_seq[row, col] = aa_codes[aa]

    return enc_aa_seq

In [None]:
########################### HLA pseudo-sequence ##########################

# The A, B, C and E protein FASTA files inside the HLA library folder.
HLA_ABC = [hla_db_dir + fasta_suffix
           for fasta_suffix in ('/A_prot.fasta', '/B_prot.fasta', '/C_prot.fasta', '/E_prot.fasta')]
# HLA_ABC=[hla_db_dir+'/AA.txt']
# Allele name -> protein sequence, filled from every file listed in HLA_ABC.
HLA_seq_lib = {}

# Walk through the file paths in HLA_ABC.
for one_class in HLA_ABC:
    with open(one_class, 'r') as prot:
        name, sequence = '', ''
        for line in prot:
            if not line.startswith('>HLA'):
                # Sequence line: extend the record that is currently being read.
                sequence += line.strip()
                continue
            # Header line: store the finished record before starting a new one.
            if name and sequence:
                HLA_seq_lib[name] = sequence
            # The allele name is the second space-separated field of the header.
            name = line.split(' ')[1]
            sequence = ''

        # Store the last HLA record of the file.
        if name and sequence:
            HLA_seq_lib[name] = sequence


def HLA_seq_list(HLA):
    """
    Look up the protein sequence of every allele name in ``HLA``.

    A name that is not an exact key of ``HLA_seq_lib`` is resolved to the first
    library key (in insertion order) that starts with that name.

    parameters:
        - HLA : iterable of allele names
    returns:
        - list of sequences, one per input name, in input order
    raises:
        - IndexError naming the allele when no library key matches. The type is
          kept from the previous behaviour, where the failed lookup surfaced as
          an anonymous "list index out of range".
    """
    hla_list = []
    # Remember prefix resolutions: the same allele usually occurs many times,
    # and each resolution scans the whole library.
    resolved_names = {}

    for HLA_name in HLA:
        if HLA_name not in HLA_seq_lib:
            prefix = str(HLA_name)
            if prefix not in resolved_names:
                match = next(
                    (allele for allele in HLA_seq_lib if allele.startswith(prefix)),
                    None)
                if match is None:
                    raise IndexError('cannot find HLA allele: ' + prefix)
                resolved_names[prefix] = match
            HLA_name = resolved_names[prefix]

        hla_list.append(HLA_seq_lib[HLA_name])

    return hla_list

In [None]:
###########################  LSTM module ##########################

class DotProductScore(nn.Module):
    """Score each sequence position by its dot product with a learned query vector."""

    def __init__(self, hidden_size):
        super().__init__()
        # Learnable query column vector of shape (hidden_size, 1).
        self.q = nn.Parameter(torch.empty(hidden_size, 1, dtype=torch.float32))
        self.init_weights()

    def init_weights(self):
        """Fill the query vector in place with uniform values from [-0.5, 0.5)."""
        bound = 0.5
        with torch.no_grad():
            self.q.uniform_(-bound, bound)

    def forward(self, inputs):
        """
        Input:
            - inputs: tensor of shape [batch_size, seq_length, hidden_size]
        Output:
            - tensor of shape [batch_size, seq_length]
        """
        # (batch, seq, hidden) @ (hidden, 1) -> (batch, seq, 1); drop the last axis.
        return (inputs @ self.q).squeeze(-1)

class Attention(nn.Module):
    """Length-aware attention pooling: collapses the sequence axis into one vector."""

    def __init__(self, hidden_size):
        super().__init__()
        # Scoring sub-module that yields one scalar per sequence position.
        self.scores = DotProductScore(hidden_size)

    def forward(self, X, valid_lens):
        # Raw attention scores, shape (batch_size, seq_length).
        raw_scores = self.scores(X)

        # 1.0 for positions inside each sample's valid length, 0.0 beyond it.
        positions = torch.arange(X.size(1), dtype=torch.float32, device=X.device)
        keep = (positions.unsqueeze(0) < valid_lens.unsqueeze(-1)).float()
        drop = 1 - keep

        # Padded positions get a score of -1e9 so softmax gives them ~0 weight.
        masked_scores = raw_scores * keep - drop * 1e9
        weights = masked_scores.softmax(dim=-1)

        # Weighted sum over the sequence axis: (B, 1, L) @ (B, L, H) -> (B, H).
        return (weights.unsqueeze(1) @ X).squeeze(1)

class ModelLSTMAttention(nn.Module):
    """
    Sequence encoder: bidirectional LSTM -> length-masked attention pooling ->
    dropout -> linear projection to ``output_size`` features.
    """

    def __init__(self, input_size=15, hidden_size=30, output_size=64, num_layers=2, dropout=0.4):
        super().__init__()

        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True, dropout=dropout, bidirectional=True)
        # The LSTM is bidirectional, so its output width is hidden_size * 2.
        self.attention = Attention(hidden_size * 2)

        # One dropout layer is enough; it was previously assigned twice.
        self.dropout = nn.Dropout(p=dropout)

        # Final fully connected layer; its input width is also hidden_size * 2.
        self.fc = nn.Linear(hidden_size * 2, output_size)

    def forward(self, seq, valid_lens):
        """
        parameters:
            - seq : float tensor (batch, seq_length, input_size)
            - valid_lens : tensor with one unpadded length per sample; any shape
              that flattens to (batch,)
        returns:
            - tensor (batch, output_size)
        """
        output, _ = self.lstm(seq)
        # Follow the LSTM output's device rather than a notebook-level `device`
        # global, so the module works wherever it has been moved.
        valid_lens = valid_lens.reshape(-1).to(output.device)
        # Attention pooling gives one weighted-average vector per sample.
        out = self.attention(output, valid_lens)

        out = self.dropout(out)

        return self.fc(out)


In [None]:
###########################  Transformer module ##########################

# Add ./python/ (resolved against the current working directory) to the module search path.
# NOTE(review): none of the cells shown here import from that folder — confirm it is still needed.
sys.path.append("./python/")

class PositionalEmbedding(nn.Module):
    """
    Adds fixed sinusoidal position encodings to its input.

    parameters:
        - max_steps : number of positions to precompute
        - max_dims : feature width to precompute (rounded up to an even number)
        - dtype : dtype of the stored table
    """

    def __init__(self, max_steps, max_dims, dtype=torch.float32):
        super().__init__()
        # Sine and cosine columns are interleaved, so the width must be even.
        if max_dims % 2 == 1: max_dims += 1
        p, i = np.meshgrid(np.arange(max_steps), np.arange(max_dims // 2))
        # Angle for every (position, frequency) pair, shape (max_steps, max_dims // 2).
        angles = (p / 10000**(2 * i / max_dims)).T
        pos_emb = np.empty((1, max_steps, max_dims))
        pos_emb[0, :,  ::2] = np.sin(angles)
        pos_emb[0, :, 1::2] = np.cos(angles)
        # A buffer follows the module through .to()/.cuda()/.double(); a plain
        # tensor attribute stayed on the CPU. persistent=False keeps it out of
        # state_dict, so the set of checkpoint keys does not change.
        self.register_buffer('positional_embeddding',
                             torch.tensor(pos_emb, dtype=dtype),
                             persistent=False)

    def forward(self, inputs):
        # Crop the table to the input's sequence length and feature width.
        shape = inputs.shape
        return inputs + self.positional_embeddding[:, :shape[-2], :shape[-1]]
      
class MultiHeadAttention(nn.Module):
    """
    Scaled dot-product attention with ``num_heads`` heads.

    parameters:
        - embed_dim : feature width of queries, keys, values and of the output
        - num_heads : number of heads; must divide embed_dim
    """

    def __init__(self, embed_dim, num_heads):
        super().__init__()
        self.num_heads = num_heads
        self.embed_dim = embed_dim

        assert embed_dim % self.num_heads == 0
        self.depth = embed_dim // self.num_heads
        # Stored as a buffer so it moves with the module instead of needing a
        # `.to(device)` on a notebook-level global at every forward pass.
        # persistent=False keeps it out of state_dict (checkpoint keys unchanged).
        self.register_buffer('scale',
                             torch.sqrt(torch.FloatTensor([self.depth])),
                             persistent=False)
        self.wq = nn.Linear(embed_dim, embed_dim)
        self.wk = nn.Linear(embed_dim, embed_dim)
        self.wv = nn.Linear(embed_dim, embed_dim)
        self.fc = nn.Linear(embed_dim, embed_dim)

    def _split_heads(self, projected, batch_size):
        """Reshape (batch, seq_len, embed_dim) into (batch, num_heads, seq_len, depth)."""
        return projected.view(batch_size, -1, self.num_heads, self.depth).permute(0, 2, 1, 3)

    def forward(self, q, k, v, mask=None):
        """
        parameters:
            - q, k, v : tensors (batch_size, seq_len, embed_dim)
            - mask : optional tensor broadcastable to
              (batch_size, num_heads, seq_len_q, seq_len_k); entries equal to 0
              are excluded from attention
        returns:
            - tensor (batch_size, seq_len_q, embed_dim)
        """
        batch_size = q.shape[0]

        query = self._split_heads(self.wq(q), batch_size)
        key = self._split_heads(self.wk(k), batch_size)
        value = self._split_heads(self.wv(v), batch_size)

        # (batch_size, num_heads, seq_len_q, seq_len_k)
        attention = torch.matmul(query, key.permute(0, 1, 3, 2)) / self.scale

        if mask is not None:
            attention = attention.masked_fill(mask == 0, -1e10)
        attention = torch.softmax(attention, dim=-1)

        x = torch.matmul(attention, value)
        x = x.permute(0, 2, 1, 3).contiguous() # (batch_size, seq_len_q, num_heads, depth)
        x = x.view(batch_size, -1, self.embed_dim) # (batch_size, seq_len, embed_dim)
        return self.fc(x)
      
class Encoder(nn.Module):
    """
    One transformer encoder layer: self-attention and a two-layer feed-forward
    network, each followed by dropout, a residual connection and LayerNorm.

    parameters:
        - embed_dim : feature width of the input and output
        - num_heads : number of attention heads
        - ff_dim : hidden width of the feed-forward network
        - rate : dropout probability
    """

    def __init__(self, embed_dim, num_heads, ff_dim, rate=0.1):
        super().__init__()
        # No `.to(device)` here: a sub-module is moved together with its parent,
        # and the constructor must not depend on a notebook-level `device` global.
        self.mha = MultiHeadAttention(embed_dim, num_heads)
        self.ffn = nn.Sequential(
            nn.Linear(embed_dim, ff_dim),
            nn.ReLU(),
            nn.Linear(ff_dim, embed_dim)
        )
        self.layernorm1 = nn.LayerNorm(embed_dim)
        self.layernorm2 = nn.LayerNorm(embed_dim)

        self.dropout1 = nn.Dropout(rate)
        self.dropout2 = nn.Dropout(rate)

    def forward(self, x):
        # Self-attention sub-layer with residual connection.
        attended = self.dropout1(self.mha(x, x, x))
        out1 = self.layernorm1(x + attended)
        # Position-wise feed-forward sub-layer with residual connection.
        transformed = self.dropout2(self.ffn(out1))
        return self.layernorm2(out1 + transformed)
class TokAndPosEmbedding(nn.Module):
    """
    Sum of a learned token embedding and a learned position embedding.

    parameters:
        - embed_dim : width of both embeddings
        - max_len : number of positions; inputs must have exactly this length
        - voc_size : number of distinct token ids
    """

    def __init__(self, embed_dim, max_len=9, voc_size=21):
        super().__init__()
        # Position ids 0..max_len-1, kept as a buffer so they move with the
        # module instead of being pinned to a notebook-level `device` global.
        # persistent=False keeps them out of state_dict (checkpoint keys unchanged).
        self.register_buffer('pos', torch.arange(max_len), persistent=False)
        self.tok_embedding = nn.Embedding(voc_size, embed_dim)
        self.pos_embedding = nn.Embedding(max_len, embed_dim)

    def forward(self, x):
        # (batch, max_len, embed_dim) + (max_len, embed_dim), broadcast over the batch.
        return self.tok_embedding(x) + self.pos_embedding(self.pos)

class AttenCaldX(nn.Module):
    """
    Peptide feature extractor: token+position embedding, two stacked encoder
    layers whose outputs are concatenated, average pooling over the sequence
    axis and a linear projection to 64 features.
    """

    def __init__(self, embed_dim=16, num_heads=4, ff_dim=64, rate=0.1, max_len=12):
        super().__init__()
        vocabulary_size = 21
        self.emb = TokAndPosEmbedding(embed_dim, max_len, vocabulary_size)
        self.encoder1 = Encoder(embed_dim, num_heads, ff_dim, rate)
        self.dropout1 = nn.Dropout(rate)
        self.encoder2 = Encoder(embed_dim, num_heads, ff_dim, rate)
        self.dropout2 = nn.Dropout(rate)
        self.dropout3 = nn.Dropout(rate)
        self.avg_pool = nn.AdaptiveAvgPool1d(1)
        # Both encoder outputs are concatenated, hence twice the embedding width.
        self.fc = nn.Linear(embed_dim*2, 64)

    def forward(self, x):
        n_samples = x.shape[0]
        first = self.dropout1(self.encoder1(self.emb(x)))
        second = self.dropout2(self.encoder2(first))
        # Concatenate along the feature axis: (batch, seq, 2 * embed_dim).
        stacked = torch.cat((first, second), dim=2)
        # Pool over the sequence axis; the pooling layer expects (batch, channels, seq).
        pooled = self.avg_pool(stacked.transpose(1, 2)).view(n_samples, -1)
        return self.fc(self.dropout3(pooled))

In [None]:
###########################  model ##########################

class Model(nn.Module):
    """
    Binding predictor: peptide features and HLA features (64 each) are stacked
    into a two-row map, passed through a small CNN and then through fully
    connected layers that emit one value per sample.
    """

    def __init__(self):
        super().__init__()
        # Feature extractors for the two inputs.
        self.pep_extractor = AttenCaldX()
        self.hla_extractor = ModelLSTMAttention()

        # Convolutional stack applied to the (1, 2, 64) feature map.
        conv_stack = [
            nn.Conv2d(in_channels=1, out_channels=8, kernel_size=(2,1)),
            nn.BatchNorm2d(8),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=(1, 2), stride=(1, 1)),
            nn.Conv2d(in_channels=8, out_channels=16, kernel_size=(1,4)),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=(1, 2), stride=(1, 1)),
            nn.Conv2d(in_channels=16, out_channels=32, kernel_size=(1,4)),
            nn.ReLU(),
        ]
        self.conv_layers = nn.Sequential(*conv_stack)

        # Classifier head; 1792 is the flattened size of the CNN output.
        dense_stack = [
            nn.Linear(1792, 100),
            nn.ReLU(),
            nn.Linear(100, 100),
            nn.ReLU(),
            nn.Linear(100, 1),
        ]
        self.fc_layers = nn.Sequential(*dense_stack)

    def forward(self, pep, hla_seq, hla_len):
        pep_features = self.pep_extractor(pep)
        hla_features = self.hla_extractor(hla_seq, hla_len)
        # Stack the two vectors as rows, then add the channel axis: (batch, 1, 2, 64).
        paired = torch.stack((pep_features, hla_features), dim=1).unsqueeze(1)
        conv_out = self.conv_layers(paired)
        # Flatten everything except the batch axis before the dense layers.
        return self.fc_layers(conv_out.flatten(start_dim=1))


In [None]:

# Define the function to compute metrics
def compute_metrics(y_labels, y_preds, cutoff=0.5, digit=4):
    """
    Compute binary-classification metrics for predicted scores.

    parameters:
        - y_labels : true binary labels
        - y_preds : predicted scores
        - cutoff : scores >= cutoff count as positive; a value <= 0 or >= 1
          selects the ROC threshold with the largest |TPR - FPR| instead
        - digit : number of decimals the scalar metrics are rounded to
    returns:
        - dict with rounded 'Accuracy', 'Sensitivity', 'Specificity',
          'Threshold', 'F1', 'AUC', 'SRCC' and the unrounded ROC arrays
          'FPR' and 'TPR'
    """
    # ROC operating points: returned to the caller and used for the cutoff search.
    fpr, tpr, threshold = roc_curve(y_labels, y_preds)

    if (cutoff <= 0) or (cutoff >= 1):
        ranked = sorted(zip(np.abs(tpr - fpr), threshold),
                        key=lambda pair: pair[0], reverse=True)
        cutoff = ranked[0][1]

    predicted = np.array([1. if score >= cutoff else 0. for score in y_preds])

    # Specificity = TN / (TN + FP), defined as 0 when there are no negatives.
    tn, fp, fn, tp = confusion_matrix(y_labels, predicted).ravel()
    negatives = tn + fp
    specificity = tn / negatives if negatives != 0 else 0

    spearman_corr, _ = spearmanr(y_labels, y_preds)

    scalars = (
        ('Accuracy', accuracy_score(y_labels, predicted)),
        ('Sensitivity', recall_score(y_labels, predicted)),
        ('Specificity', specificity),
        ('Threshold', cutoff),
        ('F1', f1_score(y_labels, predicted)),
        ('AUC', roc_auc_score(y_labels, y_preds)),
        ('SRCC', spearman_corr),
    )
    metrics = {label: np.round(value, digit) for label, value in scalars}
    metrics['FPR'] = fpr
    metrics['TPR'] = tpr
    return metrics


In [None]:

# # Load the HLA dataset
def load_hla_dataframe(fname):
    """Read a tab-separated table and drop rows whose Peptide contains X or B."""
    table = pd.read_table(fname)
    # Ambiguous residue symbols: X, B.
    has_ambiguous = table['Peptide'].str.contains('X|B')
    # Keep only the rows where the match is explicitly False.
    return table[has_ambiguous.eq(False)]

def get_hla_subtype(df, min_len=9, max_len=12):
    """Keep the rows whose Peptide length lies in [min_len, max_len], bounds included."""
    peptide_len = df['Peptide'].str.len()
    return df[peptide_len.between(min_len, max_len)]


def vectorize_peps(peps, max_len=12):
    """
    Turn peptides into a right-aligned int32 matrix of amino-acid codes.

    Each row holds the `aa_dict` codes of one peptide in its last positions;
    the leading positions of shorter peptides stay 0.
    """
    encoded = np.zeros((len(peps), max_len), dtype=np.int32)
    for row, pep in zip(encoded, peps):
        codes = [aa_dict[aa] for aa in pep]
        assert len(codes) <= max_len
        # `row` is a view into `encoded`, so this writes into the matrix itself.
        row[max_len - len(codes):] = codes
    return encoded.copy()
    
class HLADataset(Dataset):
    """Dataset of aligned (peptide, hla, length, label) samples."""

    def __init__(self, peptide, hla,lengths, label):
        self.peptide, self.hla = peptide, hla
        self.label, self.lengths = label, lengths

    def __len__(self):
        # The label container defines the number of samples.
        return len(self.label)

    def __getitem__(self, i):
        fields = (self.peptide, self.hla, self.lengths, self.label)
        return tuple(field[i] for field in fields)



#Learning rate scheduling
def adjust_learning_rate(epoch):
    """Return the step-wise learning rate scheduled for the given epoch index."""
    # (exclusive upper epoch bound, learning rate), checked in order.
    schedule = (
        (100, 0.0001),
        (200, 0.00005),
        (250, 0.00001),
    )
    for upper_bound, rate in schedule:
        if epoch < upper_bound:
            return rate
    # Any epoch at or past the last bound.
    return 0.000001


In [None]:
def data_processing(df, BATCH_SIZE=64, max_hla_len=366, test_size=0.2, random_state=2024):
    """
    Encode a peptide/HLA table and split it into train and test datasets.

    parameters:
        - df : DataFrame with the columns Peptide, HLA and BindingCategory
        - BATCH_SIZE : batch size of both data loaders
        - max_hla_len : length every encoded HLA sequence is zero-padded to
        - test_size : fraction of the rows held out for testing
        - random_state : seed of the train/test split
    returns:
        - (train_data, test_data, train_loader, test_loader); each sample is
          (peptide codes, encoded HLA sequence, HLA sequence length, label)
    """
    # Vectorize peptide sequences
    Peptide = vectorize_peps(list(df.Peptide))
    # Label is 1 where BindingCategory equals 1 and 0 otherwise, shape (n, 1).
    label = np.array(
        [1 if x==1 else 0 for x in df.BindingCategory], dtype=np.float32).reshape(-1,1)
    print(label.shape)
    print(Peptide.shape)

    # Resolve allele names to sequences and keep their unpadded lengths.
    HLA_sequence = HLA_seq_list(df.HLA)
    lengths_HLA_sequence = [len(sequence) for sequence in HLA_sequence]
    # Encode HLA sequences
    HLA_seq = enc_list_bl_max_len(HLA_sequence, pca15_aaidx, max_hla_len)

    # Split data into train and test sets
    (Peptide_train, Peptide_test, HLA_train, HLA_test,
     lengths_train, lengths_test, label_train, label_test) = train_test_split(
        Peptide, HLA_seq, lengths_HLA_sequence, label,
        test_size=test_size, random_state=random_state)
    print("X_train shape:", Peptide_train.shape)
    print("X_test shape:", Peptide_test.shape)
    print("Z_train shape:", HLA_train.shape)
    print("Z_test shape:", HLA_test.shape)
    print("y_train shape:", label_train.shape)
    print("y_test shape:", label_test.shape)

    # Convert the arrays directly. Going through list(array) first makes
    # torch.tensor walk a list of ndarrays, which is very slow and triggers
    # a PyTorch warning. The resulting dtypes are unchanged.
    train_data = HLADataset(
        torch.as_tensor(Peptide_train), torch.as_tensor(HLA_train),
        torch.tensor(lengths_train), torch.as_tensor(label_train))
    test_data = HLADataset(
        torch.as_tensor(Peptide_test), torch.as_tensor(HLA_test),
        torch.tensor(lengths_test), torch.as_tensor(label_test))

    # Create data loaders; only the training data is shuffled.
    train_loader = DataLoader(train_data, shuffle=True, batch_size=BATCH_SIZE)
    test_loader = DataLoader(test_data, shuffle=False, batch_size=BATCH_SIZE)

    return train_data, test_data, train_loader, test_loader


## Training function

In [None]:

def train_network(model,train_data, test_data,train_loader, test_loader, loss_fn, NUM_EPOCHS=300,cur_lr = 0.0001,BATCH_SIZE=64):
    """
    Train `model`, evaluate it on `test_loader` after every epoch and save the
    final weights to '../model/TlcMHCpan_model.pth'.

    parameters:
        - model : network called as model(peptide, hla, hla_len); it must
          already be on the device it should train on
        - train_data, test_data : kept for interface compatibility (unused)
        - train_loader, test_loader : loaders yielding (peptide, hla, hla_len, label)
        - loss_fn : loss called as loss_fn(prediction, label)
        - NUM_EPOCHS : number of epochs
        - cur_lr : initial learning rate; adjust_learning_rate(epoch) overrides
          it as soon as the schedule disagrees
        - BATCH_SIZE : kept for interface compatibility (unused)
    returns:
        - (best test AUC, train AUC of that epoch, 0-based best epoch,
           train FPR, train TPR, test FPR, test TPR)
    """
    # Run on the device the model lives on instead of a notebook-level global.
    model_device = next(model.parameters()).device

    # Average the loss over the batches the loaders really yield. The previous
    # len(data) // BATCH_SIZE ignored a final partial batch and was 0 for
    # datasets smaller than one batch.
    trainSteps = max(len(train_loader), 1)
    testSteps = max(len(test_loader), 1)

    optimizer = torch.optim.Adam(model.parameters(), lr=cur_lr) # Optimizer

    # With BCEWithLogitsLoss the model emits raw logits; they are mapped to
    # probabilities before compute_metrics applies its 0.5 cutoff.
    outputs_are_logits = isinstance(loss_fn, nn.BCEWithLogitsLoss)

    def to_scores(raw_pred):
        """Return detached 1-D numpy scores for the metric computation."""
        scores = torch.sigmoid(raw_pred) if outputs_are_logits else raw_pred
        return scores.detach().cpu().numpy().ravel()

    # Per-epoch history (recorded here, not returned).
    H = {"train_loss": [], "test_loss": [], "train_acc": [], "test_acc": [],
         "train_auc": [], "test_auc": [], "train_sensitivity": [], "test_sensitivity": [],
         "train_specificity": [], "test_specificity": [], "train_f1": [], "test_f1": [],
         "train_SRCC": [], "test_SRCC": []}

    best_auc = 0.0
    best_epoch = 0
    # Initialised up front so the return statement is safe even if no epoch
    # ever improves on best_auc.
    best_auc_train = 0.0

    # ROC data of the best epoch
    roc_fpr_train = np.array([])
    roc_tpr_train = np.array([])
    roc_fpr_test = np.array([])
    roc_tpr_test = np.array([])

    startTime = time.time()
    for e in tqdm(range(NUM_EPOCHS)):
        lr = adjust_learning_rate(e)
        if cur_lr != lr:
            cur_lr = lr
            # Change the rate in place; rebuilding the optimizer would throw
            # away Adam's running moment estimates.
            for group in optimizer.param_groups:
                group['lr'] = cur_lr

        # ---- training pass ----
        model.train()
        totalTrainLoss = 0.0
        all_train_predictions = []
        all_train_targets = []

        for x, z, l, y in train_loader:
            x, z, l, y = (t.to(model_device) for t in (x, z, l, y))

            pred = model(x, z, l)
            loss = loss_fn(pred, y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # .item() stores a float, so the autograd graph of each batch can be freed.
            totalTrainLoss += loss.item()

            all_train_predictions.extend(to_scores(pred))
            all_train_targets.extend(y.detach().cpu().numpy().ravel())

        # ---- evaluation pass ----
        model.eval()
        totalTestLoss = 0.0
        all_test_predictions = []
        all_test_targets = []

        # No gradients are needed for evaluation.
        with torch.no_grad():
            for x1, z1, l1, y1 in test_loader:
                x1, z1, l1, y1 = (t.to(model_device) for t in (x1, z1, l1, y1))
                pred1 = model(x1, z1, l1)
                totalTestLoss += loss_fn(pred1, y1).item()

                all_test_predictions.extend(to_scores(pred1))
                all_test_targets.extend(y1.detach().cpu().numpy().ravel())

        avgTrainLoss = totalTrainLoss / trainSteps
        avgTestLoss = totalTestLoss / testSteps

        train_metrics = compute_metrics(all_train_targets, all_train_predictions)
        test_metrics = compute_metrics(all_test_targets, all_test_predictions)

        # Update training history data
        H["train_loss"].append(avgTrainLoss)
        H["test_loss"].append(avgTestLoss)
        H["train_acc"].append(train_metrics['Accuracy'])
        H["test_acc"].append(test_metrics['Accuracy'])
        H["train_auc"].append(train_metrics['AUC'])
        H["test_auc"].append(test_metrics['AUC'])
        H["train_sensitivity"].append(train_metrics['Sensitivity'])
        H["test_sensitivity"].append(test_metrics['Sensitivity'])
        H["train_specificity"].append(train_metrics['Specificity'])
        H["test_specificity"].append(test_metrics['Specificity'])
        H["train_f1"].append(train_metrics['F1'])
        H["test_f1"].append(test_metrics['F1'])
        H["train_SRCC"].append(train_metrics['SRCC'])
        H["test_SRCC"].append(test_metrics['SRCC'])

        # Print training information for the current epoch
        print("[INFO] Epoch: {}/{}, LR: {:.6f}".format(e + 1, NUM_EPOCHS, cur_lr))
        print("Train Loss: {:.4f}, Test Loss: {:.4f}".format(avgTrainLoss, avgTestLoss))
        print("Train ACC: {:.3f}, Test ACC: {:.3f}".format(train_metrics['Accuracy'], test_metrics['Accuracy']))
        print("Train AUC: {:.3f}, Test AUC: {:.3f}".format(train_metrics['AUC'], test_metrics['AUC']))
        print("Train Sensitivity: {:.3f}, Test Sensitivity: {:.3f}".format(train_metrics['Sensitivity'], test_metrics['Sensitivity']))
        print("Train Specificity: {:.3f}, Test Specificity: {:.3f}".format(train_metrics['Specificity'], test_metrics['Specificity']))
        print("Train F1: {:.3f}, Test F1: {:.3f}".format(train_metrics['F1'], test_metrics['F1']))
        # The test column previously printed the F1 score by mistake.
        print("Train SRCC: {:.3f}, Test SRCC: {:.3f}".format(train_metrics['SRCC'], test_metrics['SRCC']))

        # Remember the epoch with the best test AUC together with its ROC curves.
        if test_metrics['AUC'] > best_auc:
            best_auc = test_metrics['AUC']
            best_epoch = e
            best_auc_train = train_metrics['AUC']

            roc_fpr_train = train_metrics['FPR']
            roc_tpr_train = train_metrics['TPR']
            roc_fpr_test = test_metrics['FPR']
            roc_tpr_test = test_metrics['TPR']

    # Print total time taken for training
    endTime = time.time()
    print("[INFO] Total time taken to train the model: {:.2f}s".format(endTime - startTime))

    # Save the model. Create the target folder first so a long run is not lost
    # to a missing directory.
    # NOTE(review): this stores the weights of the final epoch, not of best_epoch.
    save_path = '../model/TlcMHCpan_model.pth'
    os.makedirs(os.path.dirname(save_path), exist_ok=True)
    torch.save(model.state_dict(), save_path)
    return  best_auc,best_auc_train, best_epoch, roc_fpr_train, roc_tpr_train, roc_fpr_test, roc_tpr_test


In [None]:
def train_and_evaluate_model(train_data_path, model, loss_fn, device="cuda"):
    """
    Load the table at `train_data_path`, train `model` on it and plot the
    train/test ROC curves of the best epoch (also saved as PROTEINS_ROC.png).
    """
    # Load the table, apply the peptide length filter and build datasets/loaders.
    df_data = get_hla_subtype(load_hla_dataframe(train_data_path))
    # df_data1 = df_data[:5000]
    train_data, test_data, train_loader, test_loader = data_processing(df_data)

    # Move the model to the specified device
    model = model.to(device)

    # Train the network
    results = train_network(model, train_data, test_data, train_loader, test_loader, loss_fn)
    (best_auc, best_auc_train, best_epoch,
     roc_fpr_train, roc_tpr_train, roc_fpr_test, roc_tpr_test) = results

    # Plot the ROC curves
    plt.figure(figsize=(8, 6))
    curves = (
        (roc_fpr_train, roc_tpr_train, 'Train ROC curve (area = %0.3f)' % best_auc_train, 'darkorange'),
        (roc_fpr_test, roc_tpr_test, 'Test ROC curve (area = %0.3f)' % best_auc, 'blue'),
    )
    for curve_fpr, curve_tpr, legend_text, colour in curves:
        plt.plot(curve_fpr, curve_tpr, label=legend_text, color=colour)
    # Diagonal reference line.
    plt.plot([0, 1], [0, 1], linestyle='--', lw=2, color='black')
    plt.xlabel('False Positive Rate (FPR)')
    plt.ylabel('True Positive Rate (TPR)')
    plt.title('PROTEINS Dataset')
    plt.legend(loc="lower right")
    plt.tight_layout()

    # Save the figure before showing it.
    plt.savefig("PROTEINS_ROC.png")
    plt.show()

    # best_epoch is 0-based, so report it 1-based.
    summary = "Best AUC: {:.3f} at Epoch: {} and train AUC: {:.3f}"
    print(summary.format(best_auc, best_epoch + 1, best_auc_train))


In [None]:
# Seed every RNG in use so weight initialisation and batch shuffling are repeatable.
SEED = 2024
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Use the GPU when one is available and fall back to the CPU otherwise;
# a hard-coded "cuda" fails on machines without a CUDA device.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# Initialize the model and move it to the specified device
model = Model().to(device)
# Loss function (expects raw logits from the model)
loss_fn = nn.BCEWithLogitsLoss()

train_data_path = '../data/proteins.txt'  # Path to the training data

# Pass the device on explicitly so the function does not fall back to its "cuda" default.
train_and_evaluate_model(train_data_path, model, loss_fn, device=device)
