In [1]:
import math
from sklearn import metrics
from sklearn import preprocessing
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import re
import time
import datetime
import random
random.seed(1234)
from scipy import interp
import warnings
warnings.filterwarnings("ignore")
from Bio.Align import substitution_matrices
from collections import Counter
from functools import reduce
from tqdm import tqdm, trange
from copy import deepcopy
import pickle
from sklearn.metrics import confusion_matrix
from sklearn.metrics import roc_auc_score, auc
from sklearn.metrics import precision_recall_fscore_support
from sklearn.metrics import precision_recall_curve
from sklearn.metrics import classification_report
from sklearn.utils import class_weight
from multiprocessing import Pool
import os
import torch
import torch.nn as nn
from torch.nn import functional as F
import torch.optim as optim
import torch.utils.data as Data
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

In [2]:
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        '''
        x: [seq_len, batch_size, d_model]
        '''
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)

def get_attn_pad_mask(seq_q, seq_k):
    '''
    seq_q: [batch_size, seq_len]
    seq_k: [batch_size, seq_len]
    seq_len could be src_len or it could be tgt_len
    seq_len in seq_q and seq_len in seq_k maybe not equal
    '''
    batch_size, len_q = seq_q.size()
    batch_size, len_k = seq_k.size()
    # eq(zero) is PAD token
    pad_attn_mask = seq_k.data.eq(0).unsqueeze(1)  # [batch_size, 1, len_k], False is masked
    return pad_attn_mask.expand(batch_size, len_q, len_k)  # [batch_size, len_q, len_k]

class ScaledDotProductAttention(nn.Module):
    def __init__(self):
        super(ScaledDotProductAttention, self).__init__()

    def forward(self, Q, K, V, attn_mask):
        '''
        Q: [batch_size, n_heads, len_q, d_k]
        K: [batch_size, n_heads, len_k, d_k]
        V: [batch_size, n_heads, len_v(=len_k), d_v]
        attn_mask: [batch_size, n_heads, seq_len, seq_len]
        '''
        scores = torch.matmul(Q, K.transpose(-1, -2)) / np.sqrt(d_k) # scores : [batch_size, n_heads, len_q, len_k]
        scores.masked_fill_(attn_mask, -1e9) # Fills elements of self tensor with value where mask is True.
#         attn = scores
        attn = nn.Softmax(dim=-1)(scores)
        context = torch.matmul(attn, V) # [batch_size, n_heads, len_q, d_v]
        return context, attn




class MultiHeadAttention(nn.Module):
    def __init__(self):
        super(MultiHeadAttention, self).__init__()
        self.use_cuda = use_cuda
        device = torch.device("cuda" if self.use_cuda else "cpu")
        self.W_Q = nn.Linear(d_model, d_k * n_heads, bias=False)
        self.W_K = nn.Linear(d_model, d_k * n_heads, bias=False)
        self.W_V = nn.Linear(d_model, d_v * n_heads, bias=False)
        self.fc = nn.Linear(n_heads * d_v, d_model, bias=False)
    def forward(self, input_Q, input_K, input_V, attn_mask):
        '''
        input_Q: [batch_size, len_q, d_model]
        input_K: [batch_size, len_k, d_model]
        input_V: [batch_size, len_v(=len_k), d_model]
        attn_mask: [batch_size, seq_len, seq_len]
        '''
        residual, batch_size = input_Q, input_Q.size(0)
        # (B, S, D) -proj-> (B, S, D_new) -split-> (B, S, H, W) -trans-> (B, H, S, W)
        Q = self.W_Q(input_Q).view(batch_size, -1, n_heads, d_k).transpose(1,2)  # Q: [batch_size, n_heads, len_q, d_k]
        K = self.W_K(input_K).view(batch_size, -1, n_heads, d_k).transpose(1,2)  # K: [batch_size, n_heads, len_k, d_k]
        V = self.W_V(input_V).view(batch_size, -1, n_heads, d_v).transpose(1,2)  # V: [batch_size, n_heads, len_v(=len_k), d_v]

        attn_mask = attn_mask.unsqueeze(1).repeat(1, n_heads, 1, 1) # attn_mask : [batch_size, n_heads, seq_len, seq_len]

        # context: [batch_size, n_heads, len_q, d_v], attn: [batch_size, n_heads, len_q, len_k]
        context, attn = ScaledDotProductAttention()(Q, K, V, attn_mask)
        context = context.transpose(1, 2).reshape(batch_size, -1, n_heads * d_v) # context: [batch_size, len_q, n_heads * d_v]
        output = self.fc(context) # [batch_size, len_q, d_model]
        return nn.LayerNorm(d_model).to(device)(output + residual), attn




class PoswiseFeedForwardNet(nn.Module):
    def __init__(self):
        super(PoswiseFeedForwardNet, self).__init__()
        self.use_cuda = use_cuda
        device = torch.device("cuda" if self.use_cuda else "cpu")
        self.fc = nn.Sequential(
            nn.Linear(d_model, d_ff, bias=False),
            nn.ReLU(),
            nn.Linear(d_ff, d_model, bias=False)
        )
    def forward(self, inputs):
        '''
        inputs: [batch_size, seq_len, d_model]
        '''
        residual = inputs
        output = self.fc(inputs)
        return nn.LayerNorm(d_model).to(device)(output + residual) # [batch_size, seq_len, d_model]




class EncoderLayer(nn.Module):
    def __init__(self):
        super(EncoderLayer, self).__init__()
        self.enc_self_attn = MultiHeadAttention()
        self.pos_ffn = PoswiseFeedForwardNet()

    def forward(self, enc_inputs, enc_self_attn_mask):
        '''
        enc_inputs: [batch_size, src_len, d_model]
        enc_self_attn_mask: [batch_size, src_len, src_len]
        '''
        # enc_outputs: [batch_size, src_len, d_model], attn: [batch_size, n_heads, src_len, src_len]
        enc_outputs, attn = self.enc_self_attn(enc_inputs, enc_inputs, enc_inputs, enc_self_attn_mask) # enc_inputs to same Q,K,V
        enc_outputs = self.pos_ffn(enc_outputs) # enc_outputs: [batch_size, src_len, d_model]
        return enc_outputs, attn




class Encoder(nn.Module):
    def __init__(self):
        super(Encoder, self).__init__()
        self.src_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb = PositionalEncoding(d_model)
        self.layers = nn.ModuleList([EncoderLayer() for _ in range(n_layers)])

    def forward(self, enc_inputs):
        '''
        enc_inputs: [batch_size, src_len]
        '''
        enc_outputs = self.src_emb(enc_inputs) # [batch_size, src_len, d_model]
        enc_outputs = self.pos_emb(enc_outputs.transpose(0, 1)).transpose(0, 1) # [batch_size, src_len, d_model]
        enc_self_attn_mask = get_attn_pad_mask(enc_inputs, enc_inputs) # [batch_size, src_len, src_len]
        enc_self_attns = []
        for layer in self.layers:
            # enc_outputs: [batch_size, src_len, d_model], enc_self_attn: [batch_size, n_heads, src_len, src_len]
            enc_outputs, enc_self_attn = layer(enc_outputs, enc_self_attn_mask)
            enc_self_attns.append(enc_self_attn)
        return enc_outputs, enc_self_attns


# ### Decoder



class DecoderLayer(nn.Module):
    def __init__(self):
        super(DecoderLayer, self).__init__()
        self.dec_self_attn = MultiHeadAttention()
        self.pos_ffn = PoswiseFeedForwardNet()

    def forward(self, dec_inputs, dec_self_attn_mask): # dec_inputs = enc_outputs
        '''
        dec_inputs: [batch_size, tgt_len, d_model]
        enc_outputs: [batch_size, src_len, d_model]
        dec_self_attn_mask: [batch_size, tgt_len, tgt_len]
        '''
        # dec_outputs: [batch_size, tgt_len, d_model], dec_self_attn: [batch_size, n_heads, tgt_len, tgt_len]
        dec_outputs, dec_self_attn = self.dec_self_attn(dec_inputs, dec_inputs, dec_inputs, dec_self_attn_mask)
        dec_outputs = self.pos_ffn(dec_outputs) # [batch_size, tgt_len, d_model]
        return dec_outputs, dec_self_attn


class Decoder(nn.Module):
    def __init__(self):
        super(Decoder, self).__init__()
#         self.tgt_emb = nn.Embedding(d_model * 2, d_model)
        self.use_cuda = use_cuda
        device = torch.device("cuda" if self.use_cuda else "cpu")
        self.pos_emb = PositionalEncoding(d_model)
        self.layers = nn.ModuleList([DecoderLayer() for _ in range(n_layers)])
        self.tgt_len = tgt_len
        
    def forward(self, dec_inputs): # dec_inputs = enc_outputs (batch_size, peptide_hla_maxlen_sum, d_model)
        '''
        dec_inputs: [batch_size, tgt_len]
        enc_intpus: [batch_size, src_len]
        enc_outputs: [batsh_size, src_len, d_model]
        '''
#         dec_outputs = self.tgt_emb(dec_inputs) # [batch_size, tgt_len, d_model]
        dec_outputs = self.pos_emb(dec_inputs.transpose(0, 1)).transpose(0, 1).to(device) # [batch_size, tgt_len, d_model]
#         dec_self_attn_pad_mask = get_attn_pad_mask(dec_inputs, dec_inputs).cuda() # [batch_size, tgt_len, tgt_len]
        dec_self_attn_pad_mask = torch.LongTensor(np.zeros((dec_inputs.shape[0], tgt_len, tgt_len))).bool().to(device)
 
        dec_self_attns = []
        for layer in self.layers:
            # dec_outputs: [batch_size, tgt_len, d_model], dec_self_attn: [batch_size, n_heads, tgt_len, tgt_len], dec_enc_attn: [batch_size, h_heads, tgt_len, src_len]
            dec_outputs, dec_self_attn = layer(dec_outputs, dec_self_attn_pad_mask)
            dec_self_attns.append(dec_self_attn)
            
        return dec_outputs, dec_self_attns


# ### Transformer

class Transformer(nn.Module):
    def __init__(self):
        super(Transformer, self).__init__()
        self.use_cuda = use_cuda
        device = torch.device("cuda" if use_cuda else "cpu")
        self.pep_encoder = Encoder().to(device)
        self.hla_encoder = Encoder().to(device)
        self.decoder = Decoder().to(device)
        self.tgt_len = tgt_len
        self.projection = nn.Sequential(
                                        nn.Linear(tgt_len * d_model, 256),
                                        nn.ReLU(True),

                                        nn.BatchNorm1d(256),
                                        nn.Linear(256, 64),
                                        nn.ReLU(True),

                                        #output layer
                                        nn.Linear(64, 2)
                                        ).to(device)
        
    def forward(self, hla_inputs,pep_inputs):
        '''
        pep_inputs: [batch_size, pep_len]
        hla_inputs: [batch_size, hla_len]
        '''
        # tensor to store decoder outputs
        # outputs = torch.zeros(batch_size, tgt_len, tgt_vocab_size).to(self.device)
        
        # enc_outputs: [batch_size, src_len, d_model], enc_self_attns: [n_layers, batch_size, n_heads, src_len, src_len]
        hla_enc_outputs, hla_enc_self_attns = self.hla_encoder(hla_inputs)
        pep_enc_outputs, pep_enc_self_attns = self.pep_encoder(pep_inputs)
        
#         print(hla_enc_outputs)
        enc_outputs = torch.cat((hla_enc_outputs,pep_enc_outputs), 1) # concat pep & hla embedding
        ## reverse ##
#         enc_outputs = pep_enc_outputs*hla_enc_outputs
        
        ## end ##
        # dec_outpus: [batch_size, tgt_len, d_model], dec_self_attns: [n_layers, batch_size, n_heads, tgt_len, tgt_len], dec_enc_attn: [n_layers, batch_size, tgt_len, src_len]
        dec_outputs, dec_self_attns = self.decoder(enc_outputs)
        dec_outputs = dec_outputs.view(dec_outputs.shape[0], -1) # Flatten [batch_size, tgt_len * d_model]
        dec_logits = self.projection(dec_outputs) # dec_logits: [batch_size, tgt_len, tgt_vocab_size]

        return dec_logits.view(-1, dec_logits.size(-1)), pep_enc_self_attns, hla_enc_self_attns, dec_self_attns


In [3]:
pep_max_len = 12 # peptide; enc_input max sequence length
hla_max_len = 20 # hla; dec_input(=dec_output) max sequence length
tgt_len = pep_max_len + hla_max_len
pep_max_len, hla_max_len
vocab_size = 21
d_model=64 # Embedding Size
d_ff = 256 # FeedForward dimension
d_k = d_v = 64  # dimension of K(=Q), V
n_layers = 1  # number of Encoder of Decoder Layer
n_heads = 9
batch_size = 1024
# batch_size = 5000
epochs = 150
threshold = 0.5
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")

In [4]:
def performances(y_true, y_pred, y_prob, print_ = True):
    
    tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels = [0, 1]).ravel().tolist()
    accuracy = (tp+tn)/(tn+fp+fn+tp)
    try:
        mcc = ((tp*tn) - (fn*fp)) / np.sqrt(np.float((tp+fn)*(tn+fp)*(tp+fp)*(tn+fn)))
    except:
        print('MCC Error: ', (tp+fn)*(tn+fp)*(tp+fp)*(tn+fn))
        mcc = np.nan
    sensitivity = tp/(tp+fn)
    specificity = tn/(tn+fp)
    
    try:
        recall = tp / (tp+fn)
    except:
        recall = np.nan
        
    try:
        precision = tp / (tp+fp)
    except:
        precision = np.nan
        
    try: 
        f1 = 2*precision*recall / (precision+recall)
    except:
        f1 = np.nan
        
    roc_auc = roc_auc_score(y_true, y_prob)
    prec, reca, _ = precision_recall_curve(y_true, y_prob)
    aupr = auc(reca, prec)
    
    if print_:
        print('tn = {}, fp = {}, fn = {}, tp = {}'.format(tn, fp, fn, tp))
        print('y_pred: 0 = {} | 1 = {}'.format(Counter(y_pred)[0], Counter(y_pred)[1]))
        print('y_true: 0 = {} | 1 = {}'.format(Counter(y_true)[0], Counter(y_true)[1]))
        print('auc={:.4f}|sensitivity={:.4f}|specificity={:.4f}|acc={:.4f}|mcc={:.4f}'.format(roc_auc, sensitivity, specificity, accuracy, mcc))
        print('precision={:.4f}|recall={:.4f}|f1={:.4f}|aupr={:.4f}'.format(precision, recall, f1, aupr))
        
    return (roc_auc, accuracy, mcc, f1, sensitivity, specificity, precision, recall, aupr)


# In[25]:


def transfer(y_prob, threshold = 0.5):
    # return np.array([[0, 1][x > threshold] for x in y_prob])
    y_prob = np.array(y_prob)
    return np.where(y_prob > threshold, 1, 0)


# In[26]:


f_mean = lambda l: sum(l)/len(l)


# In[28]:


def performances_to_pd(performances_list):
    metrics_name = ['roc_auc', 'accuracy', 'mcc', 'f1', 'sensitivity', 'specificity', 'precision', 'recall', 'aupr']

    performances_pd = pd.DataFrame(performances_list, columns = metrics_name)
    performances_pd.loc['mean'] = performances_pd.mean(axis = 0)
    performances_pd.loc['std'] = performances_pd.std(axis = 0)
    
    return performances_pd

In [5]:
def train_step(model, train_loader, fold, epoch, epochs, use_cuda = True):
    device = torch.device("cuda" if use_cuda else "cpu")
    
    time_train_ep = 0
    model.train()
    y_true_train_list, y_prob_train_list = [], []
    loss_train_list, dec_attns_train_list = [], []
    for train_pep_inputs, train_hla_inputs, train_labels in tqdm(train_loader):
        '''
        pep_inputs: [batch_size, pep_len]
        hla_inputs: [batch_size, hla_len]
        train_outputs: [batch_size, 2]
        '''
        train_pep_inputs, train_hla_inputs, train_labels = train_pep_inputs.to(device), train_hla_inputs.to(device), train_labels.to(device)
#         print(train_pep_inputs.shape,train_hla_inputs.shape)
        t1 = time.time()
        train_outputs, _, _, train_dec_self_attns = model(train_hla_inputs, train_pep_inputs)
        train_loss = criterion(train_outputs, train_labels)
        time_train_ep += time.time() - t1

        optimizer.zero_grad()
        train_loss.backward()
        optimizer.step()

        y_true_train = train_labels.cpu().numpy()
        y_prob_train = nn.Softmax(dim = 1)(train_outputs)[:, 1].cpu().detach().numpy()
        
        y_true_train_list.extend(y_true_train)
        y_prob_train_list.extend(y_prob_train)
        loss_train_list.append(train_loss)
#         dec_attns_train_list.append(train_dec_self_attns)
        
    y_pred_train_list = transfer(y_prob_train_list, threshold)
    ys_train = (y_true_train_list, y_pred_train_list, y_prob_train_list)
    
    print('Fold-{}****Train (Ep avg): Epoch-{}/{} | Loss = {:.4f} | Time = {:.4f} sec'.format(fold, epoch, epochs, f_mean(loss_train_list), time_train_ep))
    metrics_train = performances(y_true_train_list, y_pred_train_list, y_prob_train_list, print_ = True)
    
    return ys_train, loss_train_list, metrics_train, time_train_ep#, dec_attns_train_list


# In[30]:


def eval_step(model, val_loader, fold, epoch, epochs, use_cuda = True):
    device = torch.device("cuda" if use_cuda else "cpu")
    
    model.eval()
    torch.manual_seed(19961231)
    torch.cuda.manual_seed(19961231)
    with torch.no_grad():
        loss_val_list, dec_attns_val_list = [], []
        y_true_val_list, y_prob_val_list = [], []
        for val_pep_inputs, val_hla_inputs, val_labels in tqdm(val_loader):
            val_pep_inputs, val_hla_inputs, val_labels = val_pep_inputs.to(device), val_hla_inputs.to(device), val_labels.to(device)
            val_outputs, _, _, val_dec_self_attns = model(val_hla_inputs,val_pep_inputs)
            val_loss = criterion(val_outputs, val_labels)

            y_true_val = val_labels.cpu().numpy()
            y_prob_val = nn.Softmax(dim = 1)(val_outputs)[:, 1].cpu().detach().numpy()

            y_true_val_list.extend(y_true_val)
            y_prob_val_list.extend(y_prob_val)
            loss_val_list.append(val_loss)
#             dec_attns_val_list.append(val_dec_self_attns)
            
        y_pred_val_list = transfer(y_prob_val_list, threshold)
        ys_val = (y_true_val_list, y_pred_val_list, y_prob_val_list)
        
        print('Fold-{} ****Test  Epoch-{}/{}: Loss = {:.6f}'.format(fold, epoch, epochs, f_mean(loss_val_list)))
        metrics_val = performances(y_true_val_list, y_pred_val_list, y_prob_val_list, print_ = True)
    return ys_val, loss_val_list, metrics_val#, dec_attns_val_list


In [6]:
def make_data(data):
#     labels = []
    cdr3 = data['CDR3'].values
    epitope = data['Epitope'].values
    labels = data['label'].values
    mat = Tokenizer() 
    hla_inputs = encode_cdr3(cdr3, mat)
#     print(hla_inputs)
    pep_inputs = encode_epi(epitope,mat)
#     epi_encoder = PretrainedEncoder(mat)
#     pep_inputs, epi_vec = epi_encoder.encode_pretrained_epi(epitope)
#     labels.append(label)
    return torch.LongTensor(pep_inputs), torch.LongTensor(hla_inputs), torch.LongTensor(labels)

class MyDataSet(Data.Dataset):
    def __init__(self, pep_inputs, hla_inputs,labels):
        super(MyDataSet, self).__init__()
        self.pep_inputs = pep_inputs
        self.hla_inputs = hla_inputs
        self.labels = labels
        

    def __len__(self): # 样本数
        return self.pep_inputs.shape[0] # 改成hla_inputs也可以哦！

    def __getitem__(self, idx):
#         return self.pep_inputs[idx], self.hla_inputs[idx], self.labels[idx],self.pep_lens[idx]
        return self.pep_inputs[idx],self.hla_inputs[idx], self.labels[idx]



In [7]:

import sys
sys.path.append('.')


def GetBlosumMat(residues_list):
    n_residues = len(residues_list)  # the number of amino acids _ 'X'
    blosum62_mat = np.zeros([n_residues, n_residues])  # plus 1 for gap
    bl_dict = substitution_matrices.load('BLOSUM62')
    for pair, score in bl_dict.items():
        if (pair[0] not in residues_list) or (pair[1] not in residues_list):  # special residues not considered here
            continue
        idx_pair0 = residues_list.index(pair[0])  # index of residues
        idx_pair1 = residues_list.index(pair[1])
        blosum62_mat[idx_pair0, idx_pair1] = score
        blosum62_mat[idx_pair1, idx_pair0] = score
    return blosum62_mat


class Tokenizer:
    def __init__(self,):
        self.res_all = ['G', 'A', 'V', 'L', 'I', 'F', 'W', 'Y', 'D', 'N',
                     'E', 'K', 'Q', 'M', 'S', 'T', 'C', 'P', 'H', 'R'] #+ ['X'] #BJZOU
        self.tokens = ['-'] + self.res_all # '-' for padding encoding

    def tokenize(self, index): # int 2 str
        return self.tokens[index]

    def id(self, token): # str 2 int
        try:
            return self.tokens.index(token.upper())
        except ValueError:
            print('Error letter in the sequences:', token)
            if str.isalpha(token):
                return self.tokens.index('X')

    def tokenize_list(self, seq):
        return [self.tokenize(i) for i in seq]

    def id_list(self, seq):
        return [self.id(s) for s in seq]

    def embedding_mat(self):
        blosum62 = GetBlosumMat(self.res_all)
        mat = np.eye(len(self.tokens))
        mat[1:len(self.res_all) + 1, 1:len(self.res_all) + 1] = blosum62
        return mat

tokenizer = Tokenizer()


def encoding_epi(seqs, max_len=12):
    encoding = np.zeros([len(seqs), max_len], dtype='long')
    for i, seq in tqdm(enumerate(seqs), desc='Encoding epi seqs', total=len(seqs)):
        len_seq = len(seq)
        if len_seq == 8:
            encoding[i, 2:len_seq+2] = tokenizer.id_list(seq)
        elif (len_seq == 9) or (len_seq == 10):
            encoding[i, 1:len_seq+1] = tokenizer.id_list(seq)
        else:
            encoding[i, :len_seq] = tokenizer.id_list(seq)
    return encoding

def encoding_cdr3(seqs, max_len=20):
    encoding = np.zeros([len(seqs), max_len], dtype='long')
    for i, seq in tqdm(enumerate(seqs), desc='Encoding cdr3s', total=len(seqs)):
        len_seq = len(seq)
        i_start =  max_len // 2 - len_seq // 2
        encoding[i, i_start:i_start+len_seq] = tokenizer.id_list(seq)
    return encoding

def encoding_cdr3_single(seq, max_len=20):
    encoding = np.zeros(max_len, dtype='long')
    len_seq = len(seq)
    i_start =  max_len // 2 - len_seq // 2
    encoding[i_start:i_start+len_seq] = tokenizer.id_list(seq)
    return encoding

def encoding_epi_single(seq, max_len=12):
    encoding = np.zeros(max_len, dtype='long')
    len_seq = len(seq)
    if len_seq == 8:
        encoding[2:len_seq+2] = tokenizer.id_list(seq)
    elif (len_seq == 9) or (len_seq == 10):
        encoding[1:len_seq+1] = tokenizer.id_list(seq)
    else:
        encoding[:len_seq] = tokenizer.id_list(seq)
    return encoding


def encoding_dist_mat(mat_list, max_cdr3=20, max_epi=12):
    encoding = np.zeros([len(mat_list), max_cdr3, max_epi], dtype='float32')
    masking = np.zeros([len(mat_list), max_cdr3, max_epi], dtype='bool')
    for i, mat in tqdm(enumerate(mat_list), desc='Encoding dist mat', total=len(mat_list)):
        len_cdr3, len_epi = mat.shape
        i_start_cdr3 = max_cdr3 // 2 - len_cdr3 // 2
        if len_epi == 8:
            i_start_epi = 2
        elif (len_epi == 9) or (len_epi == 10):
            i_start_epi = 1
        else:
            i_start_epi = 0
        encoding[i, i_start_cdr3:i_start_cdr3+len_cdr3, i_start_epi:i_start_epi+len_epi] = mat
        masking[i, i_start_cdr3:i_start_cdr3+len_cdr3, i_start_epi:i_start_epi+len_epi] = True
    return encoding, masking


def decoding_one_mat(mat, len_cdr3, len_epi):
    decoding = np.zeros([len_cdr3, len_epi] + list(mat.shape[2:]), dtype=mat.dtype)
    i_start_cdr3 = 10 - len_cdr3 // 2
    if len_epi == 8:
        i_start_epi = 2
    elif (len_epi == 9) or (len_epi == 10):
        i_start_epi = 1
    else:
        i_start_epi = 0
    decoding = mat[i_start_cdr3:i_start_cdr3+len_cdr3, i_start_epi:i_start_epi+len_epi] 
    return decoding

In [8]:
from Bio.Align import substitution_matrices
def GetBlosumMat(residues_list):
    n_residues = len(residues_list)  # the number of amino acids _ 'X'
    blosum62_mat = np.zeros([n_residues, n_residues])  # plus 1 for gap
    bl_dict = substitution_matrices.load('BLOSUM62')
    for pair, score in bl_dict.items():
        if (pair[0] not in residues_list) or (pair[1] not in residues_list):  # special residues not considered here
            continue
        idx_pair0 = residues_list.index(pair[0])  # index of residues
        idx_pair1 = residues_list.index(pair[1])
        blosum62_mat[idx_pair0, idx_pair1] = score
        blosum62_mat[idx_pair1, idx_pair0] = score
    return blosum62_mat
class Tokenizer:
    def __init__(self,):
        self.res_all = ['G', 'A', 'V', 'L', 'I', 'F', 'W', 'Y', 'D', 'N',
                     'E', 'K', 'Q', 'M', 'S', 'T', 'C', 'P', 'H', 'R'] #+ ['X'] #BJZOU
        self.tokens = ['-'] + self.res_all # '-' for padding encoding

    def tokenize(self, index): # int 2 str
        return self.tokens[index]

    def id(self, token): # str 2 int
        try:
            return self.tokens.index(token.upper())
        except ValueError:
            print('Error letter in the sequences:', token)
            if str.isalpha(token):
                return self.tokens.index('X')

    def tokenize_list(self, seq):
        return [self.tokenize(i) for i in seq]

    def id_list(self, seq):
        return [self.id(s) for s in seq]

    def embedding_mat(self):
        blosum62 = GetBlosumMat(self.res_all)
        mat = np.eye(len(self.tokens))
        mat[1:len(self.res_all) + 1, 1:len(self.res_all) + 1] = blosum62
        return mat
def encode_cdr3(cdr3, tokenizer):
    len_cdr3 = [len(s) for s in cdr3]
    max_len_cdr3 = np.max(len_cdr3)
    assert max_len_cdr3 <= 20, 'The cdr3 length must <= 20'
    max_len_cdr3 = 20
    
    seqs_al = get_numbering(cdr3)
    num_samples = len(seqs_al)

    # encoding
    encoding_cdr3 = np.zeros([num_samples, max_len_cdr3], dtype='int32')
    for i, seq in enumerate(seqs_al):
        encoding_cdr3[i, ] = tokenizer.id_list(seq)
    return encoding_cdr3
# def encode_epi(epi, tokenizer):
#     # tokenizer = Tokenizer()
#     encoding_epi = np.zeros([12], dtype='int32')
#     len_epi = len(epi)
#     if len_epi == 8:
#         encoding_epi[2:len_epi+2] = tokenizer.id_list(epi)
#     elif (len_epi == 9) or (len_epi == 10):
#         encoding_epi[1:len_epi+1] = tokenizer.id_list(epi)
#     else:
#         encoding_epi[:len_epi] = tokenizer.id_list(epi)
#     return encoding_epi

def encode_epi(epi, tokenizer):
    tokenizer = Tokenizer()
    encoding_epi = np.zeros([len(epi),12], dtype='int32')
    for i, seq in enumerate(epi):
        len_epi = len(seq)
        
        if len_epi == 8:
        
            encoding_epi[i,2:len_epi+2] = tokenizer.id_list(seq)
        elif (len_epi == 9) or (len_epi == 10) or (len_epi ==11):
            
            encoding_epi[i,1:len_epi+1] = tokenizer.id_list(seq)
        else:
            
            encoding_epi[i,:len_epi] = tokenizer.id_list(seq)
    print(encoding_epi)
    return encoding_epi

def get_numbering(seqs, ):
    """
    get the IMGT numbering of CDR3 with ANARCI tool
    """
    template = ['GVTQTPKFQVLKTGQSMTLQCAQDMNHEYMSWYRQDPGMGLRLIHYSVGAGTTDQGEVPNGYNVSRSTIEDFPLRLLSAAPSQTSVYF', 'GEGSRLTVL']
    # # save fake tcr file
    save_path = 'tmp_faketcr.fasta'
    id_list = []
    seqs_uni = np.unique(seqs)
    with open(save_path, 'w+') as f:
        for i, seq in enumerate(seqs_uni):
            f.write('>'+str(i)+'\n')
            id_list.append(i)
            total_seq = ''.join([template[0], seq ,template[1]])
            f.write(str(total_seq))
            f.write('\n')
    print('Save fasta file to '+save_path + '\n Aligning...')
    df_seqs = pd.DataFrame(list(zip(id_list, seqs_uni)), columns=['Id', 'cdr3'])
    
    # # using ANARCI to get numbering file

   # this environment name should be the same as the one you install anarci
    !ANARCI -i ./tmp_faketcr.fasta  -o tmp_align --csv -p 24
#     res = os.system(cmd)
    
    # # parse numbered seqs data
    try:
        df = pd.read_csv('tmp_align_B.csv')
    except FileNotFoundError:
        raise FileNotFoundError('Error: ANARCI failed to align, please check whether ANARCI exists in your environment')
        
    cols = ['104', '105', '106', '107', '108', '109', '110', '111', '111A', '111B', '112C', '112B', '112A', '112', '113', '114', '115', '116', '117', '118']
    seqs_al = []
    for col in cols:
        if col in df.columns:
            seqs_al_curr = df[col].values
            seqs_al.append(seqs_al_curr)
        else:
            seqs_al_curr = np.full([len(df)], '-')
            seqs_al.append(seqs_al_curr)
    seqs_al = [''.join(seq) for seq in np.array(seqs_al).T]
    df_al = df[['Id']]
    df_al['cdr3_align'] = seqs_al
    
    ## merge
    # os.remove('tmp_align_B.csv')
#     os.remove('tmp_faketcr.fasta')
    df = df_seqs.merge(df_al, how='inner', on='Id')
    df = df.set_index('cdr3')
    return df.loc[seqs, 'cdr3_align'].values

In [9]:
class View(nn.Module):
    def __init__(self, *shape):
        super(View, self).__init__()
        self.shape = shape
    def forward(self, input):
        shape = [input.shape[0]] + list(self.shape)
        return input.view(*shape)
def load_ae_model(tokenizer, path='./epi_ae.ckpt',):
    # tokenizer = Tokenizer()
    ## load model
    model_args = dict(
        tokenizer = tokenizer,
        dim_hid = 32,
        len_seq = 12,
    )
    model = AutoEncoder(**model_args)
    model.eval()

    ## load weights
    state_dict = torch.load(path, map_location=device)
    state_dict = {k[6:]:v for k, v in state_dict.items()}
    model.load_state_dict(state_dict)
    return model
class AutoEncoder(nn.Module):
    def __init__(self, 
        tokenizer,
        dim_hid,
        len_seq,
    ):
        super().__init__()
        embedding = tokenizer.embedding_mat()
        vocab_size, dim_emb = embedding.shape
        self.embedding_module = nn.Embedding.from_pretrained(torch.FloatTensor(embedding), padding_idx=0, )
        self.encoder = nn.Sequential(
            nn.Conv1d(dim_emb, dim_hid, 3, padding=1),
            nn.BatchNorm1d(dim_hid),
            nn.ReLU(),
            nn.Conv1d(dim_hid, dim_hid, 3, padding=1),
            nn.BatchNorm1d(dim_hid),
            nn.ReLU(),
        )

        self.seq2vec = nn.Sequential(
            nn.Flatten(),
            nn.Linear(len_seq * dim_hid, dim_hid),
            nn.ReLU()
        )
        self.vec2seq = nn.Sequential(
            nn.Linear(dim_hid, len_seq * dim_hid),
            nn.ReLU(),
            View(dim_hid, len_seq)
        )
        self.decoder = nn.Sequential(
            nn.ConvTranspose1d(dim_hid, dim_hid, kernel_size=3, padding=1),
            nn.BatchNorm1d(dim_hid),
            nn.ReLU(),
            nn.ConvTranspose1d(dim_hid, dim_hid, kernel_size=3, padding=1),
            nn.BatchNorm1d(dim_hid),
            nn.ReLU(),
        )
        self.out_layer = nn.Linear(dim_hid, vocab_size)

    def forward(self, inputs):
        inputs = inputs.long()
        seq_emb = self.embedding_module(inputs)
        
        seq_enc = self.encoder(seq_emb.transpose(1, 2))
        vec = self.seq2vec(seq_enc)
        seq_repr = self.vec2seq(vec)
        indices = None
        seq_dec = self.decoder(seq_repr)
        out = self.out_layer(seq_dec.transpose(1, 2))
        return out, seq_enc, vec, indices


In [15]:
def data_with_loader(type_ = 'train',fold = None,  batch_size = 128):
    if type_ != 'train' and type_ != 'val':
#         data = pd.read_csv('../data/justina_test.csv')
        data = pd.read_csv('./inputs/inputs_bd.csv')
#         data = pd.read_csv('../data/test/GILGLVFTL.csv')
#         data = pd.read_csv('../data/posi_length.csv')
        
    elif type_ == 'train':
        data = pd.read_csv('./突变负样本/add_10xneg/add_10xneg/train_add10Xneg_{}.csv'.format(fold))

    elif type_ == 'val':
        data = pd.read_csv('./突变负样本/add_10xneg/add_10xneg/eva_add10Xneg_{}.csv'.format(fold))

    pep_inputs, hla_inputs,labels = make_data(data)
#     print(labels)
    loader = Data.DataLoader(MyDataSet(pep_inputs, hla_inputs,labels), batch_size, shuffle = True, num_workers = 0)
    n_samples = len(pep_inputs)
    len_cdr3 = len(hla_inputs[0])
    len_epi = len(pep_inputs[0])
    encoding_mask = np.zeros([n_samples, len_cdr3,len_epi])
    for idx_sample, (enc_cdr3_this, enc_epi_this) in enumerate(zip(hla_inputs, pep_inputs)):
        mask = np.ones([len_cdr3,len_epi])
        zero_cdr3 = (enc_cdr3_this == 0)
        mask[zero_cdr3,:] = 0
        zero_epi = (enc_epi_this == 0)
        mask[:,zero_epi] = 0
#         print(mask.shape)
        encoding_mask[idx_sample] = mask
    return data, pep_inputs, hla_inputs, labels,loader,encoding_mask

In [11]:
def data_with_loader(type_ = 'train',fold = None,  batch_size = 128):
    if type_ != 'train' and type_ != 'val':
#         data = pd.read_csv('../data/justina_test.csv')
        data = pd.read_csv('./inputs/inputs_bd.csv')
#         data = pd.read_csv('../data/test/GILGLVFTL.csv')
#         data = pd.read_csv('../data/posi_length.csv')
        
    elif type_ == 'train':
#         data = pd.read_csv('../new_train/train_VDJ_10X_McPAS_1V1_{}.csv'.format(fold))
        data = pd.read_csv('./使用pMTnet数据训练/final_data/train_mismatch_{}.csv'.format(fold))

    elif type_ == 'val':
#         data = pd.read_csv('../new_train/eva_VDJ_10X_McPAS_1V1_{}.csv'.format(fold))
        data = pd.read_csv('./使用pMTnet数据训练/final_data/eva_mismatch_{}.csv'.format(fold))

    pep_inputs, hla_inputs,labels = make_data(data)
#     print(labels)
    loader = Data.DataLoader(MyDataSet(pep_inputs, hla_inputs,labels), batch_size, shuffle = True, num_workers = 0)
    n_samples = len(pep_inputs)
    len_cdr3 = len(hla_inputs[0])
    len_epi = len(pep_inputs[0])
    encoding_mask = np.zeros([n_samples, len_cdr3,len_epi])
    for idx_sample, (enc_cdr3_this, enc_epi_this) in enumerate(zip(hla_inputs, pep_inputs)):
        mask = np.ones([len_cdr3,len_epi])
        zero_cdr3 = (enc_cdr3_this == 0)
        mask[zero_cdr3,:] = 0
        zero_epi = (enc_epi_this == 0)
        mask[:,zero_epi] = 0
#         print(mask.shape)
        encoding_mask[idx_sample] = mask
    return data, pep_inputs, hla_inputs, labels,loader,encoding_mask

In [17]:
for n_heads in range(5,6):
    
    ys_train_fold_dict, ys_val_fold_dict = {}, {}
    train_fold_metrics_list, val_fold_metrics_list = [], []
    independent_fold_metrics_list, external_fold_metrics_list, ys_independent_fold_dict, ys_external_fold_dict = [], [], {}, {}
    attns_train_fold_dict, attns_val_fold_dict, attns_independent_fold_dict, attns_external_fold_dict = {}, {}, {}, {}
    loss_train_fold_dict, loss_val_fold_dict, loss_independent_fold_dict, loss_external_fold_dict = {}, {}, {}, {}

    for fold in range(1,6):
        print('=====Fold-{}====='.format(fold))
        print('-----Generate data loader-----')
        train_data, train_pep_inputs, train_hla_inputs, train_labels, train_loader,_ = data_with_loader(type_ = 'train', fold = fold,  batch_size = batch_size)
        val_data, val_pep_inputs, val_hla_inputs, val_labels, val_loader,_ = data_with_loader(type_ = 'val', fold = fold,  batch_size = batch_size)
        print('Fold-{} Label info: Train = {} | Val = {}'.format(fold, Counter(train_data.label), Counter(val_data.label)))

        print('-----Compile model-----')
        model = Transformer().to(device)
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.Adam(model.parameters(), lr = 1e-3)#, momentum = 0.99)

        print('-----Train-----')
        dir_saver = './model2'
    
        path_saver = './tcr_st_layer{}_multihead{}_fold{}_netmhcpan.pkl'.format(n_layers, n_heads, fold)
        metric_best, ep_best = 0, -1
        time_train = 0
        for epoch in range(1, epochs + 1):

            ys_train, loss_train_list, metrics_train, time_train_ep = train_step(model, train_loader, fold, epoch, epochs, use_cuda) # , dec_attns_train
            ys_val, loss_val_list, metrics_val = eval_step(model, val_loader, fold, epoch, epochs, use_cuda) #, dec_attns_val

            metrics_ep_avg = sum(metrics_val[:4])/4
            if metrics_ep_avg > metric_best: 
                metric_best, ep_best = metrics_ep_avg, epoch
                if not os.path.exists(dir_saver):
                    os.makedirs(dir_saver)
                print('****Saving model: Best epoch = {} | 5metrics_Best_avg = {:.4f}'.format(ep_best, metric_best))
                print('*****Path saver: ', path_saver)
                torch.save(model.eval().state_dict(), path_saver)

            time_train += time_train_ep

        print('-----Optimization Finished!-----')
        print('-----Evaluate Results-----')
        if ep_best >= 0:
            print('*****Path saver: ', path_saver)
            model.load_state_dict(torch.load(path_saver))
            model_eval = model.eval()

            ys_res_train, loss_res_train_list, metrics_res_train = eval_step(model_eval, train_loader, fold, ep_best, epochs, use_cuda) # , train_res_attns
            ys_res_val, loss_res_val_list, metrics_res_val = eval_step(model_eval, val_loader, fold, ep_best, epochs, use_cuda) # , val_res_attns
#             ys_res_independent, loss_res_independent_list, metrics_res_independent = eval_step(model_eval, independent_loader, fold, ep_best, epochs, use_cuda) # , independent_res_attns
#             ys_res_external, loss_res_external_list, metrics_res_external = eval_step(model_eval, external_loader, fold, ep_best, epochs, use_cuda) # , external_res_attns

            train_fold_metrics_list.append(metrics_res_train)
            val_fold_metrics_list.append(metrics_res_val)
#             independent_fold_metrics_list.append(metrics_res_independent)
#             external_fold_metrics_list.append(metrics_res_external)

#             ys_train_fold_dict[fold], ys_val_fold_dict[fold], ys_independent_fold_dict[fold], ys_external_fold_dict[fold] = ys_res_train, ys_res_val, ys_res_independent, ys_res_external    
#             attns_train_fold_dict[fold], attns_val_fold_dict[fold], attns_independent_fold_dict[fold], attns_external_fold_dict[fold] = train_res_attns, val_res_attns, independent_res_attns, external_res_attns   
#             loss_train_fold_dict[fold], loss_val_fold_dict[fold], loss_independent_fold_dict[fold], loss_external_fold_dict[fold] = loss_res_train_list, loss_res_val_list, loss_res_independent_list, loss_res_external_list  

        print("Total training time: {:6.2f} sec".format(time_train))




=====Fold-1=====
-----Generate data loader-----
Save fasta file to tmp_faketcr.fasta
 Aligning...
zsh:1: command not found: ANARCI
[[ 0 12  4 ... 12  0  0]
 [ 0  1  3 ... 16  0  0]
 [ 0 12  4 ... 12  0  0]
 ...
 [ 0 10  4 ...  3  0  0]
 [ 0 12  4 ...  3  0  0]
 [ 0  6 18 ...  2  0  0]]
Save fasta file to tmp_faketcr.fasta
 Aligning...
zsh:1: command not found: ANARCI
[[ 0  6  4 ...  4  0  0]
 [ 0 12  4 ...  3  0  0]
 [ 0  0 20 ...  4  0  0]
 ...
 [ 0 12  4 ...  3  0  0]
 [ 0  1  4 ...  4  0  0]
 [ 0 12  4 ... 12  0  0]]
Fold-1 Label info: Train = Counter({0: 56508, 1: 39168}) | Val = Counter({0: 14122, 1: 9798})
-----Compile model-----
-----Train-----


100%|██████████| 94/94 [00:06<00:00, 15.05it/s]


Fold-1****Train (Ep avg): Epoch-1/150 | Loss = 0.2527 | Time = 1.5878 sec
tn = 50304, fp = 6204, fn = 3896, tp = 35272
y_pred: 0 = 54200 | 1 = 41476
y_true: 0 = 56508 | 1 = 39168
auc=0.9598|sensitivity=0.9005|specificity=0.8902|acc=0.8944|mcc=0.7846
precision=0.8504|recall=0.9005|f1=0.8748|aupr=0.9383


100%|██████████| 24/24 [00:00<00:00, 39.67it/s]


Fold-1 ****Test  Epoch-1/150: Loss = 1.931552
tn = 9881, fp = 4241, fn = 5210, tp = 4588
y_pred: 0 = 15091 | 1 = 8829
y_true: 0 = 14122 | 1 = 9798
auc=0.6029|sensitivity=0.4683|specificity=0.6997|acc=0.6049|mcc=0.1711
precision=0.5197|recall=0.4683|f1=0.4926|aupr=0.5216
****Saving model: Best epoch = 1 | 5metrics_Best_avg = 0.4679
*****Path saver:  ./tcr_st_layer1_multihead5_fold1_netmhcpan.pkl


100%|██████████| 94/94 [00:06<00:00, 15.60it/s]


Fold-1****Train (Ep avg): Epoch-2/150 | Loss = 0.1744 | Time = 1.5507 sec
tn = 52787, fp = 3721, fn = 2545, tp = 36623
y_pred: 0 = 55332 | 1 = 40344
y_true: 0 = 56508 | 1 = 39168
auc=0.9798|sensitivity=0.9350|specificity=0.9342|acc=0.9345|mcc=0.8655
precision=0.9078|recall=0.9350|f1=0.9212|aupr=0.9678


100%|██████████| 24/24 [00:00<00:00, 37.65it/s]


Fold-1 ****Test  Epoch-2/150: Loss = 1.784450
tn = 9946, fp = 4176, fn = 5207, tp = 4591
y_pred: 0 = 15153 | 1 = 8767
y_true: 0 = 14122 | 1 = 9798
auc=0.6111|sensitivity=0.4686|specificity=0.7043|acc=0.6077|mcc=0.1764
precision=0.5237|recall=0.4686|f1=0.4946|aupr=0.5257
****Saving model: Best epoch = 2 | 5metrics_Best_avg = 0.4725
*****Path saver:  ./tcr_st_layer1_multihead5_fold1_netmhcpan.pkl


100%|██████████| 94/94 [00:05<00:00, 15.72it/s]


Fold-1****Train (Ep avg): Epoch-3/150 | Loss = 0.1515 | Time = 1.5563 sec
tn = 53109, fp = 3399, fn = 2143, tp = 37025
y_pred: 0 = 55252 | 1 = 40424
y_true: 0 = 56508 | 1 = 39168
auc=0.9847|sensitivity=0.9453|specificity=0.9398|acc=0.9421|mcc=0.8811
precision=0.9159|recall=0.9453|f1=0.9304|aupr=0.9759


100%|██████████| 24/24 [00:00<00:00, 30.31it/s]


Fold-1 ****Test  Epoch-3/150: Loss = 2.103628
tn = 9803, fp = 4319, fn = 5183, tp = 4615
y_pred: 0 = 14986 | 1 = 8934
y_true: 0 = 14122 | 1 = 9798
auc=0.6031|sensitivity=0.4710|specificity=0.6942|acc=0.6028|mcc=0.1679
precision=0.5166|recall=0.4710|f1=0.4927|aupr=0.5201


100%|██████████| 94/94 [00:05<00:00, 15.71it/s]


Fold-1****Train (Ep avg): Epoch-4/150 | Loss = 0.1379 | Time = 1.5416 sec
tn = 53424, fp = 3084, fn = 1951, tp = 37217
y_pred: 0 = 55375 | 1 = 40301
y_true: 0 = 56508 | 1 = 39168
auc=0.9873|sensitivity=0.9502|specificity=0.9454|acc=0.9474|mcc=0.8919
precision=0.9235|recall=0.9502|f1=0.9366|aupr=0.9800


100%|██████████| 24/24 [00:00<00:00, 35.18it/s]


Fold-1 ****Test  Epoch-4/150: Loss = 2.476070
tn = 9916, fp = 4206, fn = 5239, tp = 4559
y_pred: 0 = 15155 | 1 = 8765
y_true: 0 = 14122 | 1 = 9798
auc=0.5980|sensitivity=0.4653|specificity=0.7022|acc=0.6051|mcc=0.1709
precision=0.5201|recall=0.4653|f1=0.4912|aupr=0.5133


100%|██████████| 94/94 [00:05<00:00, 15.71it/s]


Fold-1****Train (Ep avg): Epoch-5/150 | Loss = 0.1244 | Time = 1.5340 sec
tn = 53707, fp = 2801, fn = 1794, tp = 37374
y_pred: 0 = 55501 | 1 = 40175
y_true: 0 = 56508 | 1 = 39168
auc=0.9897|sensitivity=0.9542|specificity=0.9504|acc=0.9520|mcc=0.9013
precision=0.9303|recall=0.9542|f1=0.9421|aupr=0.9838


100%|██████████| 24/24 [00:00<00:00, 37.07it/s]


Fold-1 ****Test  Epoch-5/150: Loss = 2.930256
tn = 9748, fp = 4374, fn = 5198, tp = 4600
y_pred: 0 = 14946 | 1 = 8974
y_true: 0 = 14122 | 1 = 9798
auc=0.5943|sensitivity=0.4695|specificity=0.6903|acc=0.5998|mcc=0.1623
precision=0.5126|recall=0.4695|f1=0.4901|aupr=0.5057


100%|██████████| 94/94 [00:05<00:00, 15.67it/s]


Fold-1****Train (Ep avg): Epoch-6/150 | Loss = 0.1114 | Time = 1.5503 sec
tn = 53921, fp = 2587, fn = 1647, tp = 37521
y_pred: 0 = 55568 | 1 = 40108
y_true: 0 = 56508 | 1 = 39168
auc=0.9917|sensitivity=0.9580|specificity=0.9542|acc=0.9557|mcc=0.9090
precision=0.9355|recall=0.9580|f1=0.9466|aupr=0.9871


100%|██████████| 24/24 [00:00<00:00, 37.50it/s]


Fold-1 ****Test  Epoch-6/150: Loss = 3.484137
tn = 9788, fp = 4334, fn = 5226, tp = 4572
y_pred: 0 = 15014 | 1 = 8906
y_true: 0 = 14122 | 1 = 9798
auc=0.5941|sensitivity=0.4666|specificity=0.6931|acc=0.6003|mcc=0.1625
precision=0.5134|recall=0.4666|f1=0.4889|aupr=0.5092


100%|██████████| 94/94 [00:05<00:00, 15.69it/s]


Fold-1****Train (Ep avg): Epoch-7/150 | Loss = 0.1015 | Time = 1.5376 sec
tn = 54131, fp = 2377, fn = 1484, tp = 37684
y_pred: 0 = 55615 | 1 = 40061
y_true: 0 = 56508 | 1 = 39168
auc=0.9932|sensitivity=0.9621|specificity=0.9579|acc=0.9596|mcc=0.9170
precision=0.9407|recall=0.9621|f1=0.9513|aupr=0.9896


100%|██████████| 24/24 [00:00<00:00, 30.86it/s]


Fold-1 ****Test  Epoch-7/150: Loss = 3.954656
tn = 9781, fp = 4341, fn = 5270, tp = 4528
y_pred: 0 = 15051 | 1 = 8869
y_true: 0 = 14122 | 1 = 9798
auc=0.5824|sensitivity=0.4621|specificity=0.6926|acc=0.5982|mcc=0.1575
precision=0.5105|recall=0.4621|f1=0.4851|aupr=0.4974


100%|██████████| 94/94 [00:05<00:00, 16.36it/s]


Fold-1****Train (Ep avg): Epoch-8/150 | Loss = 0.0942 | Time = 1.4852 sec
tn = 54344, fp = 2164, fn = 1437, tp = 37731
y_pred: 0 = 55781 | 1 = 39895
y_true: 0 = 56508 | 1 = 39168
auc=0.9941|sensitivity=0.9633|specificity=0.9617|acc=0.9624|mcc=0.9225
precision=0.9458|recall=0.9633|f1=0.9545|aupr=0.9911


100%|██████████| 24/24 [00:00<00:00, 37.79it/s]


Fold-1 ****Test  Epoch-8/150: Loss = 4.202820
tn = 9554, fp = 4568, fn = 5216, tp = 4582
y_pred: 0 = 14770 | 1 = 9150
y_true: 0 = 14122 | 1 = 9798
auc=0.5762|sensitivity=0.4676|specificity=0.6765|acc=0.5910|mcc=0.1459
precision=0.5008|recall=0.4676|f1=0.4836|aupr=0.4902


100%|██████████| 94/94 [00:05<00:00, 16.21it/s]


Fold-1****Train (Ep avg): Epoch-9/150 | Loss = 0.0846 | Time = 1.4949 sec
tn = 54562, fp = 1946, fn = 1331, tp = 37837
y_pred: 0 = 55893 | 1 = 39783
y_true: 0 = 56508 | 1 = 39168
auc=0.9953|sensitivity=0.9660|specificity=0.9656|acc=0.9657|mcc=0.9294
precision=0.9511|recall=0.9660|f1=0.9585|aupr=0.9930


100%|██████████| 24/24 [00:00<00:00, 37.69it/s]


Fold-1 ****Test  Epoch-9/150: Loss = 4.075807
tn = 9459, fp = 4663, fn = 5087, tp = 4711
y_pred: 0 = 14546 | 1 = 9374
y_true: 0 = 14122 | 1 = 9798
auc=0.5835|sensitivity=0.4808|specificity=0.6698|acc=0.5924|mcc=0.1517
precision=0.5026|recall=0.4808|f1=0.4914|aupr=0.4938


100%|██████████| 94/94 [00:06<00:00, 15.50it/s]


Fold-1****Train (Ep avg): Epoch-10/150 | Loss = 0.0799 | Time = 1.5566 sec
tn = 54660, fp = 1848, fn = 1262, tp = 37906
y_pred: 0 = 55922 | 1 = 39754
y_true: 0 = 56508 | 1 = 39168
auc=0.9958|sensitivity=0.9678|specificity=0.9673|acc=0.9675|mcc=0.9330
precision=0.9535|recall=0.9678|f1=0.9606|aupr=0.9937


100%|██████████| 24/24 [00:00<00:00, 30.50it/s]


Fold-1 ****Test  Epoch-10/150: Loss = 4.868225
tn = 9963, fp = 4159, fn = 5307, tp = 4491
y_pred: 0 = 15270 | 1 = 8650
y_true: 0 = 14122 | 1 = 9798
auc=0.5789|sensitivity=0.4584|specificity=0.7055|acc=0.6043|mcc=0.1677
precision=0.5192|recall=0.4584|f1=0.4869|aupr=0.4918


100%|██████████| 94/94 [00:05<00:00, 16.50it/s]


Fold-1****Train (Ep avg): Epoch-11/150 | Loss = 0.0700 | Time = 1.4711 sec
tn = 54883, fp = 1625, fn = 1076, tp = 38092
y_pred: 0 = 55959 | 1 = 39717
y_true: 0 = 56508 | 1 = 39168
auc=0.9968|sensitivity=0.9725|specificity=0.9712|acc=0.9718|mcc=0.9418
precision=0.9591|recall=0.9725|f1=0.9658|aupr=0.9953


100%|██████████| 24/24 [00:00<00:00, 36.92it/s]


Fold-1 ****Test  Epoch-11/150: Loss = 5.217054
tn = 10049, fp = 4073, fn = 5402, tp = 4396
y_pred: 0 = 15451 | 1 = 8469
y_true: 0 = 14122 | 1 = 9798
auc=0.5809|sensitivity=0.4487|specificity=0.7116|acc=0.6039|mcc=0.1648
precision=0.5191|recall=0.4487|f1=0.4813|aupr=0.4932


100%|██████████| 94/94 [00:05<00:00, 16.02it/s]


Fold-1****Train (Ep avg): Epoch-12/150 | Loss = 0.0608 | Time = 1.5130 sec
tn = 55123, fp = 1385, fn = 937, tp = 38231
y_pred: 0 = 56060 | 1 = 39616
y_true: 0 = 56508 | 1 = 39168
auc=0.9976|sensitivity=0.9761|specificity=0.9755|acc=0.9757|mcc=0.9499
precision=0.9650|recall=0.9761|f1=0.9705|aupr=0.9965


100%|██████████| 24/24 [00:00<00:00, 38.98it/s]


Fold-1 ****Test  Epoch-12/150: Loss = 5.287415
tn = 9828, fp = 4294, fn = 5293, tp = 4505
y_pred: 0 = 15121 | 1 = 8799
y_true: 0 = 14122 | 1 = 9798
auc=0.5861|sensitivity=0.4598|specificity=0.6959|acc=0.5992|mcc=0.1588
precision=0.5120|recall=0.4598|f1=0.4845|aupr=0.5025


100%|██████████| 94/94 [00:05<00:00, 15.76it/s]


Fold-1****Train (Ep avg): Epoch-13/150 | Loss = 0.0551 | Time = 1.5372 sec
tn = 55289, fp = 1219, fn = 897, tp = 38271
y_pred: 0 = 56186 | 1 = 39490
y_true: 0 = 56508 | 1 = 39168
auc=0.9980|sensitivity=0.9771|specificity=0.9784|acc=0.9779|mcc=0.9543
precision=0.9691|recall=0.9771|f1=0.9731|aupr=0.9972


100%|██████████| 24/24 [00:00<00:00, 34.87it/s]


Fold-1 ****Test  Epoch-13/150: Loss = 5.566410
tn = 9903, fp = 4219, fn = 5227, tp = 4571
y_pred: 0 = 15130 | 1 = 8790
y_true: 0 = 14122 | 1 = 9798
auc=0.5934|sensitivity=0.4665|specificity=0.7012|acc=0.6051|mcc=0.1711
precision=0.5200|recall=0.4665|f1=0.4918|aupr=0.5147


100%|██████████| 94/94 [00:06<00:00, 15.62it/s]


Fold-1****Train (Ep avg): Epoch-14/150 | Loss = 0.0504 | Time = 1.5489 sec
tn = 55396, fp = 1112, fn = 833, tp = 38335
y_pred: 0 = 56229 | 1 = 39447
y_true: 0 = 56508 | 1 = 39168
auc=0.9984|sensitivity=0.9787|specificity=0.9803|acc=0.9797|mcc=0.9580
precision=0.9718|recall=0.9787|f1=0.9753|aupr=0.9977


100%|██████████| 24/24 [00:00<00:00, 28.66it/s]


Fold-1 ****Test  Epoch-14/150: Loss = 6.298601
tn = 9907, fp = 4215, fn = 5216, tp = 4582
y_pred: 0 = 15123 | 1 = 8797
y_true: 0 = 14122 | 1 = 9798
auc=0.5895|sensitivity=0.4676|specificity=0.7015|acc=0.6057|mcc=0.1725
precision=0.5209|recall=0.4676|f1=0.4928|aupr=0.5252


100%|██████████| 94/94 [00:06<00:00, 15.28it/s]


Fold-1****Train (Ep avg): Epoch-15/150 | Loss = 0.0462 | Time = 1.5723 sec
tn = 55548, fp = 960, fn = 763, tp = 38405
y_pred: 0 = 56311 | 1 = 39365
y_true: 0 = 56508 | 1 = 39168
auc=0.9986|sensitivity=0.9805|specificity=0.9830|acc=0.9820|mcc=0.9628
precision=0.9756|recall=0.9805|f1=0.9781|aupr=0.9980


100%|██████████| 24/24 [00:00<00:00, 37.11it/s]


Fold-1 ****Test  Epoch-15/150: Loss = 6.339417
tn = 10171, fp = 3951, fn = 5354, tp = 4444
y_pred: 0 = 15525 | 1 = 8395
y_true: 0 = 14122 | 1 = 9798
auc=0.5898|sensitivity=0.4536|specificity=0.7202|acc=0.6110|mcc=0.1791
precision=0.5294|recall=0.4536|f1=0.4885|aupr=0.5087


100%|██████████| 94/94 [00:05<00:00, 15.71it/s]


Fold-1****Train (Ep avg): Epoch-16/150 | Loss = 0.0402 | Time = 1.5469 sec
tn = 55677, fp = 831, fn = 679, tp = 38489
y_pred: 0 = 56356 | 1 = 39320
y_true: 0 = 56508 | 1 = 39168
auc=0.9990|sensitivity=0.9827|specificity=0.9853|acc=0.9842|mcc=0.9674
precision=0.9789|recall=0.9827|f1=0.9808|aupr=0.9985


100%|██████████| 24/24 [00:00<00:00, 36.84it/s]


Fold-1 ****Test  Epoch-16/150: Loss = 6.091587
tn = 10285, fp = 3837, fn = 5311, tp = 4487
y_pred: 0 = 15596 | 1 = 8324
y_true: 0 = 14122 | 1 = 9798
auc=0.5993|sensitivity=0.4580|specificity=0.7283|acc=0.6176|mcc=0.1923
precision=0.5390|recall=0.4580|f1=0.4952|aupr=0.5215
****Saving model: Best epoch = 16 | 5metrics_Best_avg = 0.4761
*****Path saver:  ./tcr_st_layer1_multihead5_fold1_netmhcpan.pkl


100%|██████████| 94/94 [00:06<00:00, 15.44it/s]


Fold-1****Train (Ep avg): Epoch-17/150 | Loss = 0.0367 | Time = 1.5583 sec
tn = 55801, fp = 707, fn = 605, tp = 38563
y_pred: 0 = 56406 | 1 = 39270
y_true: 0 = 56508 | 1 = 39168
auc=0.9991|sensitivity=0.9846|specificity=0.9875|acc=0.9863|mcc=0.9717
precision=0.9820|recall=0.9846|f1=0.9833|aupr=0.9988


100%|██████████| 24/24 [00:00<00:00, 38.88it/s]


Fold-1 ****Test  Epoch-17/150: Loss = 5.768851
tn = 10105, fp = 4017, fn = 5270, tp = 4528
y_pred: 0 = 15375 | 1 = 8545
y_true: 0 = 14122 | 1 = 9798
auc=0.5959|sensitivity=0.4621|specificity=0.7156|acc=0.6117|mcc=0.1823
precision=0.5299|recall=0.4621|f1=0.4937|aupr=0.5201


100%|██████████| 94/94 [00:06<00:00, 15.61it/s]


Fold-1****Train (Ep avg): Epoch-18/150 | Loss = 0.0330 | Time = 1.5485 sec
tn = 55830, fp = 678, fn = 565, tp = 38603
y_pred: 0 = 56395 | 1 = 39281
y_true: 0 = 56508 | 1 = 39168
auc=0.9993|sensitivity=0.9856|specificity=0.9880|acc=0.9870|mcc=0.9731
precision=0.9827|recall=0.9856|f1=0.9842|aupr=0.9990


100%|██████████| 24/24 [00:00<00:00, 30.33it/s]


Fold-1 ****Test  Epoch-18/150: Loss = 6.344634
tn = 9988, fp = 4134, fn = 5229, tp = 4569
y_pred: 0 = 15217 | 1 = 8703
y_true: 0 = 14122 | 1 = 9798
auc=0.5950|sensitivity=0.4663|specificity=0.7073|acc=0.6086|mcc=0.1774
precision=0.5250|recall=0.4663|f1=0.4939|aupr=0.5253


100%|██████████| 94/94 [00:06<00:00, 15.52it/s]


Fold-1****Train (Ep avg): Epoch-19/150 | Loss = 0.0275 | Time = 1.5498 sec
tn = 55957, fp = 551, fn = 460, tp = 38708
y_pred: 0 = 56417 | 1 = 39259
y_true: 0 = 56508 | 1 = 39168
auc=0.9995|sensitivity=0.9883|specificity=0.9902|acc=0.9894|mcc=0.9782
precision=0.9860|recall=0.9883|f1=0.9871|aupr=0.9993


100%|██████████| 24/24 [00:00<00:00, 36.96it/s]


Fold-1 ****Test  Epoch-19/150: Loss = 7.112957
tn = 9740, fp = 4382, fn = 5111, tp = 4687
y_pred: 0 = 14851 | 1 = 9069
y_true: 0 = 14122 | 1 = 9798
auc=0.5906|sensitivity=0.4784|specificity=0.6897|acc=0.6031|mcc=0.1703
precision=0.5168|recall=0.4784|f1=0.4968|aupr=0.5243


100%|██████████| 94/94 [00:05<00:00, 15.95it/s]


Fold-1****Train (Ep avg): Epoch-20/150 | Loss = 0.0255 | Time = 1.5105 sec
tn = 56011, fp = 497, fn = 424, tp = 38744
y_pred: 0 = 56435 | 1 = 39241
y_true: 0 = 56508 | 1 = 39168
auc=0.9996|sensitivity=0.9892|specificity=0.9912|acc=0.9904|mcc=0.9801
precision=0.9873|recall=0.9892|f1=0.9883|aupr=0.9994


100%|██████████| 24/24 [00:00<00:00, 36.40it/s]


Fold-1 ****Test  Epoch-20/150: Loss = 7.905073
tn = 9662, fp = 4460, fn = 5140, tp = 4658
y_pred: 0 = 14802 | 1 = 9118
y_true: 0 = 14122 | 1 = 9798
auc=0.5885|sensitivity=0.4754|specificity=0.6842|acc=0.5987|mcc=0.1616
precision=0.5109|recall=0.4754|f1=0.4925|aupr=0.5382


100%|██████████| 94/94 [00:06<00:00, 15.54it/s]


Fold-1****Train (Ep avg): Epoch-21/150 | Loss = 0.0225 | Time = 1.5516 sec
tn = 56087, fp = 421, fn = 348, tp = 38820
y_pred: 0 = 56435 | 1 = 39241
y_true: 0 = 56508 | 1 = 39168
auc=0.9997|sensitivity=0.9911|specificity=0.9925|acc=0.9920|mcc=0.9834
precision=0.9893|recall=0.9911|f1=0.9902|aupr=0.9995


100%|██████████| 24/24 [00:00<00:00, 29.85it/s]


Fold-1 ****Test  Epoch-21/150: Loss = 8.223892
tn = 9551, fp = 4571, fn = 5071, tp = 4727
y_pred: 0 = 14622 | 1 = 9298
y_true: 0 = 14122 | 1 = 9798
auc=0.5823|sensitivity=0.4824|specificity=0.6763|acc=0.5969|mcc=0.1602
precision=0.5084|recall=0.4824|f1=0.4951|aupr=0.5295


100%|██████████| 94/94 [00:06<00:00, 15.60it/s]


Fold-1****Train (Ep avg): Epoch-22/150 | Loss = 0.0187 | Time = 1.5620 sec
tn = 56158, fp = 350, fn = 290, tp = 38878
y_pred: 0 = 56448 | 1 = 39228
y_true: 0 = 56508 | 1 = 39168
auc=0.9998|sensitivity=0.9926|specificity=0.9938|acc=0.9933|mcc=0.9862
precision=0.9911|recall=0.9926|f1=0.9918|aupr=0.9997


100%|██████████| 24/24 [00:00<00:00, 37.72it/s]


Fold-1 ****Test  Epoch-22/150: Loss = 8.301046
tn = 9639, fp = 4483, fn = 5121, tp = 4677
y_pred: 0 = 14760 | 1 = 9160
y_true: 0 = 14122 | 1 = 9798
auc=0.5921|sensitivity=0.4773|specificity=0.6826|acc=0.5985|mcc=0.1618
precision=0.5106|recall=0.4773|f1=0.4934|aupr=0.5361


100%|██████████| 94/94 [00:06<00:00, 15.62it/s]


Fold-1****Train (Ep avg): Epoch-23/150 | Loss = 0.0200 | Time = 1.5472 sec
tn = 56150, fp = 358, fn = 316, tp = 38852
y_pred: 0 = 56466 | 1 = 39210
y_true: 0 = 56508 | 1 = 39168
auc=0.9997|sensitivity=0.9919|specificity=0.9937|acc=0.9930|mcc=0.9854
precision=0.9909|recall=0.9919|f1=0.9914|aupr=0.9996


100%|██████████| 24/24 [00:00<00:00, 34.95it/s]


Fold-1 ****Test  Epoch-23/150: Loss = 8.527172
tn = 9804, fp = 4318, fn = 5145, tp = 4653
y_pred: 0 = 14949 | 1 = 8971
y_true: 0 = 14122 | 1 = 9798
auc=0.5876|sensitivity=0.4749|specificity=0.6942|acc=0.6044|mcc=0.1718
precision=0.5187|recall=0.4749|f1=0.4958|aupr=0.5349


100%|██████████| 94/94 [00:05<00:00, 16.58it/s]


Fold-1****Train (Ep avg): Epoch-24/150 | Loss = 0.0183 | Time = 1.4888 sec
tn = 56174, fp = 334, fn = 278, tp = 38890
y_pred: 0 = 56452 | 1 = 39224
y_true: 0 = 56508 | 1 = 39168
auc=0.9998|sensitivity=0.9929|specificity=0.9941|acc=0.9936|mcc=0.9868
precision=0.9915|recall=0.9929|f1=0.9922|aupr=0.9996


100%|██████████| 24/24 [00:00<00:00, 37.22it/s]


Fold-1 ****Test  Epoch-24/150: Loss = 8.173071
tn = 9914, fp = 4208, fn = 5259, tp = 4539
y_pred: 0 = 15173 | 1 = 8747
y_true: 0 = 14122 | 1 = 9798
auc=0.5932|sensitivity=0.4633|specificity=0.7020|acc=0.6042|mcc=0.1688
precision=0.5189|recall=0.4633|f1=0.4895|aupr=0.5326


100%|██████████| 94/94 [00:05<00:00, 15.91it/s]


Fold-1****Train (Ep avg): Epoch-25/150 | Loss = 0.0171 | Time = 1.5290 sec
tn = 56186, fp = 322, fn = 274, tp = 38894
y_pred: 0 = 56460 | 1 = 39216
y_true: 0 = 56508 | 1 = 39168
auc=0.9998|sensitivity=0.9930|specificity=0.9943|acc=0.9938|mcc=0.9871
precision=0.9918|recall=0.9930|f1=0.9924|aupr=0.9997


100%|██████████| 24/24 [00:00<00:00, 31.24it/s]


Fold-1 ****Test  Epoch-25/150: Loss = 8.823400
tn = 10026, fp = 4096, fn = 5349, tp = 4449
y_pred: 0 = 15375 | 1 = 8545
y_true: 0 = 14122 | 1 = 9798
auc=0.5894|sensitivity=0.4541|specificity=0.7100|acc=0.6051|mcc=0.1683
precision=0.5207|recall=0.4541|f1=0.4851|aupr=0.5327


100%|██████████| 94/94 [00:05<00:00, 15.88it/s]


Fold-1****Train (Ep avg): Epoch-26/150 | Loss = 0.0147 | Time = 1.5185 sec
tn = 56251, fp = 257, fn = 236, tp = 38932
y_pred: 0 = 56487 | 1 = 39189
y_true: 0 = 56508 | 1 = 39168
auc=0.9999|sensitivity=0.9940|specificity=0.9955|acc=0.9948|mcc=0.9893
precision=0.9934|recall=0.9940|f1=0.9937|aupr=0.9998


100%|██████████| 24/24 [00:00<00:00, 36.67it/s]


Fold-1 ****Test  Epoch-26/150: Loss = 8.556540
tn = 9902, fp = 4220, fn = 5297, tp = 4501
y_pred: 0 = 15199 | 1 = 8721
y_true: 0 = 14122 | 1 = 9798
auc=0.5780|sensitivity=0.4594|specificity=0.7012|acc=0.6021|mcc=0.1640
precision=0.5161|recall=0.4594|f1=0.4861|aupr=0.5264


100%|██████████| 94/94 [00:06<00:00, 15.44it/s]


Fold-1****Train (Ep avg): Epoch-27/150 | Loss = 0.0144 | Time = 1.5631 sec
tn = 56254, fp = 254, fn = 223, tp = 38945
y_pred: 0 = 56477 | 1 = 39199
y_true: 0 = 56508 | 1 = 39168
auc=0.9999|sensitivity=0.9943|specificity=0.9955|acc=0.9950|mcc=0.9897
precision=0.9935|recall=0.9943|f1=0.9939|aupr=0.9998


100%|██████████| 24/24 [00:00<00:00, 36.27it/s]


Fold-1 ****Test  Epoch-27/150: Loss = 9.245920
tn = 10018, fp = 4104, fn = 5325, tp = 4473
y_pred: 0 = 15343 | 1 = 8577
y_true: 0 = 14122 | 1 = 9798
auc=0.5906|sensitivity=0.4565|specificity=0.7094|acc=0.6058|mcc=0.1701
precision=0.5215|recall=0.4565|f1=0.4869|aupr=0.5337


100%|██████████| 94/94 [00:06<00:00, 15.52it/s]


Fold-1****Train (Ep avg): Epoch-28/150 | Loss = 0.0117 | Time = 1.5498 sec
tn = 56322, fp = 186, fn = 180, tp = 38988
y_pred: 0 = 56502 | 1 = 39174
y_true: 0 = 56508 | 1 = 39168
auc=0.9999|sensitivity=0.9954|specificity=0.9967|acc=0.9962|mcc=0.9921
precision=0.9953|recall=0.9954|f1=0.9953|aupr=0.9999


100%|██████████| 24/24 [00:00<00:00, 36.13it/s]


Fold-1 ****Test  Epoch-28/150: Loss = 9.153191
tn = 9827, fp = 4295, fn = 5314, tp = 4484
y_pred: 0 = 15141 | 1 = 8779
y_true: 0 = 14122 | 1 = 9798
auc=0.5894|sensitivity=0.4576|specificity=0.6959|acc=0.5983|mcc=0.1566
precision=0.5108|recall=0.4576|f1=0.4827|aupr=0.5306


100%|██████████| 94/94 [00:06<00:00, 15.44it/s]


Fold-1****Train (Ep avg): Epoch-29/150 | Loss = 0.0106 | Time = 1.5653 sec
tn = 56334, fp = 174, fn = 150, tp = 39018
y_pred: 0 = 56484 | 1 = 39192
y_true: 0 = 56508 | 1 = 39168
auc=0.9999|sensitivity=0.9962|specificity=0.9969|acc=0.9966|mcc=0.9930
precision=0.9956|recall=0.9962|f1=0.9959|aupr=0.9999


100%|██████████| 24/24 [00:00<00:00, 28.95it/s]


Fold-1 ****Test  Epoch-29/150: Loss = 9.324862
tn = 9777, fp = 4345, fn = 5192, tp = 4606
y_pred: 0 = 14969 | 1 = 8951
y_true: 0 = 14122 | 1 = 9798
auc=0.5909|sensitivity=0.4701|specificity=0.6923|acc=0.6013|mcc=0.1651
precision=0.5146|recall=0.4701|f1=0.4913|aupr=0.5488


100%|██████████| 94/94 [00:05<00:00, 16.51it/s]


Fold-1****Train (Ep avg): Epoch-30/150 | Loss = 0.0096 | Time = 1.4775 sec
tn = 56344, fp = 164, fn = 149, tp = 39019
y_pred: 0 = 56493 | 1 = 39183
y_true: 0 = 56508 | 1 = 39168
auc=0.9999|sensitivity=0.9962|specificity=0.9971|acc=0.9967|mcc=0.9932
precision=0.9958|recall=0.9962|f1=0.9960|aupr=0.9999


100%|██████████| 24/24 [00:00<00:00, 38.91it/s]


Fold-1 ****Test  Epoch-30/150: Loss = 10.199240
tn = 9921, fp = 4201, fn = 5362, tp = 4436
y_pred: 0 = 15283 | 1 = 8637
y_true: 0 = 14122 | 1 = 9798
auc=0.5899|sensitivity=0.4527|specificity=0.7025|acc=0.6002|mcc=0.1590
precision=0.5136|recall=0.4527|f1=0.4813|aupr=0.5416


100%|██████████| 94/94 [00:05<00:00, 15.81it/s]


Fold-1****Train (Ep avg): Epoch-31/150 | Loss = 0.0107 | Time = 1.5207 sec
tn = 56326, fp = 182, fn = 167, tp = 39001
y_pred: 0 = 56493 | 1 = 39183
y_true: 0 = 56508 | 1 = 39168
auc=0.9999|sensitivity=0.9957|specificity=0.9968|acc=0.9964|mcc=0.9925
precision=0.9954|recall=0.9957|f1=0.9955|aupr=0.9999


100%|██████████| 24/24 [00:00<00:00, 36.61it/s]


Fold-1 ****Test  Epoch-31/150: Loss = 9.476794
tn = 9880, fp = 4242, fn = 5257, tp = 4541
y_pred: 0 = 15137 | 1 = 8783
y_true: 0 = 14122 | 1 = 9798
auc=0.5902|sensitivity=0.4635|specificity=0.6996|acc=0.6029|mcc=0.1664
precision=0.5170|recall=0.4635|f1=0.4888|aupr=0.5423


100%|██████████| 94/94 [00:05<00:00, 16.56it/s]


Fold-1****Train (Ep avg): Epoch-32/150 | Loss = 0.0098 | Time = 1.4679 sec
tn = 56327, fp = 181, fn = 151, tp = 39017
y_pred: 0 = 56478 | 1 = 39198
y_true: 0 = 56508 | 1 = 39168
auc=0.9999|sensitivity=0.9961|specificity=0.9968|acc=0.9965|mcc=0.9928
precision=0.9954|recall=0.9961|f1=0.9958|aupr=0.9999


100%|██████████| 24/24 [00:00<00:00, 30.70it/s]


Fold-1 ****Test  Epoch-32/150: Loss = 9.346182
tn = 10052, fp = 4070, fn = 5209, tp = 4589
y_pred: 0 = 15261 | 1 = 8659
y_true: 0 = 14122 | 1 = 9798
auc=0.6083|sensitivity=0.4684|specificity=0.7118|acc=0.6121|mcc=0.1844
precision=0.5300|recall=0.4684|f1=0.4973|aupr=0.5553


100%|██████████| 94/94 [00:06<00:00, 15.50it/s]


Fold-1****Train (Ep avg): Epoch-33/150 | Loss = 0.0101 | Time = 1.5588 sec
tn = 56336, fp = 172, fn = 170, tp = 38998
y_pred: 0 = 56506 | 1 = 39170
y_true: 0 = 56508 | 1 = 39168
auc=0.9999|sensitivity=0.9957|specificity=0.9970|acc=0.9964|mcc=0.9926
precision=0.9956|recall=0.9957|f1=0.9956|aupr=0.9999


100%|██████████| 24/24 [00:00<00:00, 38.87it/s]


Fold-1 ****Test  Epoch-33/150: Loss = 8.921778
tn = 9949, fp = 4173, fn = 5251, tp = 4547
y_pred: 0 = 15200 | 1 = 8720
y_true: 0 = 14122 | 1 = 9798
auc=0.5971|sensitivity=0.4641|specificity=0.7045|acc=0.6060|mcc=0.1722
precision=0.5214|recall=0.4641|f1=0.4911|aupr=0.5449


100%|██████████| 94/94 [00:05<00:00, 15.79it/s]


Fold-1****Train (Ep avg): Epoch-34/150 | Loss = 0.0102 | Time = 1.5248 sec
tn = 56322, fp = 186, fn = 172, tp = 38996
y_pred: 0 = 56494 | 1 = 39182
y_true: 0 = 56508 | 1 = 39168
auc=0.9999|sensitivity=0.9956|specificity=0.9967|acc=0.9963|mcc=0.9923
precision=0.9953|recall=0.9956|f1=0.9954|aupr=0.9999


100%|██████████| 24/24 [00:00<00:00, 36.47it/s]


Fold-1 ****Test  Epoch-34/150: Loss = 10.269807
tn = 9849, fp = 4273, fn = 5278, tp = 4520
y_pred: 0 = 15127 | 1 = 8793
y_true: 0 = 14122 | 1 = 9798
auc=0.5938|sensitivity=0.4613|specificity=0.6974|acc=0.6007|mcc=0.1619
precision=0.5140|recall=0.4613|f1=0.4863|aupr=0.5480


100%|██████████| 94/94 [00:06<00:00, 15.55it/s]


Fold-1****Train (Ep avg): Epoch-35/150 | Loss = 0.0101 | Time = 1.5443 sec
tn = 56323, fp = 185, fn = 165, tp = 39003
y_pred: 0 = 56488 | 1 = 39188
y_true: 0 = 56508 | 1 = 39168
auc=0.9999|sensitivity=0.9958|specificity=0.9967|acc=0.9963|mcc=0.9924
precision=0.9953|recall=0.9958|f1=0.9955|aupr=0.9999


100%|██████████| 24/24 [00:00<00:00, 37.15it/s]


Fold-1 ****Test  Epoch-35/150: Loss = 10.566199
tn = 9842, fp = 4280, fn = 5277, tp = 4521
y_pred: 0 = 15119 | 1 = 8801
y_true: 0 = 14122 | 1 = 9798
auc=0.5884|sensitivity=0.4614|specificity=0.6969|acc=0.6005|mcc=0.1615
precision=0.5137|recall=0.4614|f1=0.4862|aupr=0.5411


100%|██████████| 94/94 [00:05<00:00, 16.29it/s]


Fold-1****Train (Ep avg): Epoch-36/150 | Loss = 0.0117 | Time = 1.4945 sec
tn = 56287, fp = 221, fn = 183, tp = 38985
y_pred: 0 = 56470 | 1 = 39206
y_true: 0 = 56508 | 1 = 39168
auc=0.9999|sensitivity=0.9953|specificity=0.9961|acc=0.9958|mcc=0.9913
precision=0.9944|recall=0.9953|f1=0.9948|aupr=0.9999


100%|██████████| 24/24 [00:00<00:00, 31.30it/s]


Fold-1 ****Test  Epoch-36/150: Loss = 9.796679
tn = 10013, fp = 4109, fn = 5420, tp = 4378
y_pred: 0 = 15433 | 1 = 8487
y_true: 0 = 14122 | 1 = 9798
auc=0.5871|sensitivity=0.4468|specificity=0.7090|acc=0.6016|mcc=0.1602
precision=0.5158|recall=0.4468|f1=0.4789|aupr=0.5296


100%|██████████| 94/94 [00:05<00:00, 16.07it/s]


Fold-1****Train (Ep avg): Epoch-37/150 | Loss = 0.0140 | Time = 1.4998 sec
tn = 56240, fp = 268, fn = 222, tp = 38946
y_pred: 0 = 56462 | 1 = 39214
y_true: 0 = 56508 | 1 = 39168
auc=0.9999|sensitivity=0.9943|specificity=0.9953|acc=0.9949|mcc=0.9894
precision=0.9932|recall=0.9943|f1=0.9937|aupr=0.9998


100%|██████████| 24/24 [00:00<00:00, 37.09it/s]


Fold-1 ****Test  Epoch-37/150: Loss = 9.734117
tn = 10064, fp = 4058, fn = 5479, tp = 4319
y_pred: 0 = 15543 | 1 = 8377
y_true: 0 = 14122 | 1 = 9798
auc=0.5962|sensitivity=0.4408|specificity=0.7126|acc=0.6013|mcc=0.1582
precision=0.5156|recall=0.4408|f1=0.4753|aupr=0.5353


100%|██████████| 94/94 [00:05<00:00, 16.42it/s]


Fold-1****Train (Ep avg): Epoch-38/150 | Loss = 0.0123 | Time = 1.4675 sec
tn = 56278, fp = 230, fn = 196, tp = 38972
y_pred: 0 = 56474 | 1 = 39202
y_true: 0 = 56508 | 1 = 39168
auc=0.9999|sensitivity=0.9950|specificity=0.9959|acc=0.9955|mcc=0.9908
precision=0.9941|recall=0.9950|f1=0.9946|aupr=0.9998


100%|██████████| 24/24 [00:00<00:00, 35.61it/s]


Fold-1 ****Test  Epoch-38/150: Loss = 10.324287
tn = 9941, fp = 4181, fn = 5179, tp = 4619
y_pred: 0 = 15120 | 1 = 8800
y_true: 0 = 14122 | 1 = 9798
auc=0.6006|sensitivity=0.4714|specificity=0.7039|acc=0.6087|mcc=0.1788
precision=0.5249|recall=0.4714|f1=0.4967|aupr=0.5626


100%|██████████| 94/94 [00:06<00:00, 15.51it/s]


Fold-1****Train (Ep avg): Epoch-39/150 | Loss = 0.0095 | Time = 1.5439 sec
tn = 56342, fp = 166, fn = 147, tp = 39021
y_pred: 0 = 56489 | 1 = 39187
y_true: 0 = 56508 | 1 = 39168
auc=0.9999|sensitivity=0.9962|specificity=0.9971|acc=0.9967|mcc=0.9932
precision=0.9958|recall=0.9962|f1=0.9960|aupr=0.9999


100%|██████████| 24/24 [00:00<00:00, 36.27it/s]


Fold-1 ****Test  Epoch-39/150: Loss = 10.444735
tn = 9923, fp = 4199, fn = 5235, tp = 4563
y_pred: 0 = 15158 | 1 = 8762
y_true: 0 = 14122 | 1 = 9798
auc=0.5928|sensitivity=0.4657|specificity=0.7027|acc=0.6056|mcc=0.1719
precision=0.5208|recall=0.4657|f1=0.4917|aupr=0.5593


100%|██████████| 94/94 [00:06<00:00, 15.31it/s]


Fold-1****Train (Ep avg): Epoch-40/150 | Loss = 0.0080 | Time = 1.5368 sec
tn = 56367, fp = 141, fn = 130, tp = 39038
y_pred: 0 = 56497 | 1 = 39179
y_true: 0 = 56508 | 1 = 39168
auc=1.0000|sensitivity=0.9967|specificity=0.9975|acc=0.9972|mcc=0.9941
precision=0.9964|recall=0.9967|f1=0.9965|aupr=0.9999


100%|██████████| 24/24 [00:00<00:00, 37.45it/s]


Fold-1 ****Test  Epoch-40/150: Loss = 10.625496
tn = 9746, fp = 4376, fn = 5141, tp = 4657
y_pred: 0 = 14887 | 1 = 9033
y_true: 0 = 14122 | 1 = 9798
auc=0.6041|sensitivity=0.4753|specificity=0.6901|acc=0.6021|mcc=0.1678
precision=0.5156|recall=0.4753|f1=0.4946|aupr=0.5692


100%|██████████| 94/94 [00:05<00:00, 15.90it/s]


Fold-1****Train (Ep avg): Epoch-41/150 | Loss = 0.0063 | Time = 1.5230 sec
tn = 56390, fp = 118, fn = 86, tp = 39082
y_pred: 0 = 56476 | 1 = 39200
y_true: 0 = 56508 | 1 = 39168
auc=1.0000|sensitivity=0.9978|specificity=0.9979|acc=0.9979|mcc=0.9956
precision=0.9970|recall=0.9978|f1=0.9974|aupr=1.0000


100%|██████████| 24/24 [00:00<00:00, 39.02it/s]


Fold-1 ****Test  Epoch-41/150: Loss = 11.666096
tn = 9788, fp = 4334, fn = 5163, tp = 4635
y_pred: 0 = 14951 | 1 = 8969
y_true: 0 = 14122 | 1 = 9798
auc=0.5947|sensitivity=0.4731|specificity=0.6931|acc=0.6030|mcc=0.1688
precision=0.5168|recall=0.4731|f1=0.4940|aupr=0.5669


100%|██████████| 94/94 [00:05<00:00, 16.14it/s]


Fold-1****Train (Ep avg): Epoch-42/150 | Loss = 0.0057 | Time = 1.5085 sec
tn = 56404, fp = 104, fn = 88, tp = 39080
y_pred: 0 = 56492 | 1 = 39184
y_true: 0 = 56508 | 1 = 39168
auc=1.0000|sensitivity=0.9978|specificity=0.9982|acc=0.9980|mcc=0.9959
precision=0.9973|recall=0.9978|f1=0.9975|aupr=1.0000


100%|██████████| 24/24 [00:00<00:00, 35.90it/s]


Fold-1 ****Test  Epoch-42/150: Loss = 11.001197
tn = 9961, fp = 4161, fn = 5246, tp = 4552
y_pred: 0 = 15207 | 1 = 8713
y_true: 0 = 14122 | 1 = 9798
auc=0.5953|sensitivity=0.4646|specificity=0.7054|acc=0.6067|mcc=0.1737
precision=0.5224|recall=0.4646|f1=0.4918|aupr=0.5605


100%|██████████| 94/94 [00:05<00:00, 15.80it/s]


Fold-1****Train (Ep avg): Epoch-43/150 | Loss = 0.0070 | Time = 1.5298 sec
tn = 56382, fp = 126, fn = 120, tp = 39048
y_pred: 0 = 56502 | 1 = 39174
y_true: 0 = 56508 | 1 = 39168
auc=1.0000|sensitivity=0.9969|specificity=0.9978|acc=0.9974|mcc=0.9947
precision=0.9968|recall=0.9969|f1=0.9969|aupr=0.9999


100%|██████████| 24/24 [00:00<00:00, 30.21it/s]


Fold-1 ****Test  Epoch-43/150: Loss = 10.228483
tn = 9807, fp = 4315, fn = 5182, tp = 4616
y_pred: 0 = 14989 | 1 = 8931
y_true: 0 = 14122 | 1 = 9798
auc=0.6075|sensitivity=0.4711|specificity=0.6944|acc=0.6030|mcc=0.1683
precision=0.5169|recall=0.4711|f1=0.4929|aupr=0.5616


100%|██████████| 94/94 [00:06<00:00, 15.42it/s]


Fold-1****Train (Ep avg): Epoch-44/150 | Loss = 0.0078 | Time = 1.5583 sec
tn = 56368, fp = 140, fn = 121, tp = 39047
y_pred: 0 = 56489 | 1 = 39187
y_true: 0 = 56508 | 1 = 39168
auc=1.0000|sensitivity=0.9969|specificity=0.9975|acc=0.9973|mcc=0.9944
precision=0.9964|recall=0.9969|f1=0.9967|aupr=0.9999


100%|██████████| 24/24 [00:00<00:00, 36.82it/s]


Fold-1 ****Test  Epoch-44/150: Loss = 12.157547
tn = 9903, fp = 4219, fn = 5241, tp = 4557
y_pred: 0 = 15144 | 1 = 8776
y_true: 0 = 14122 | 1 = 9798
auc=0.5930|sensitivity=0.4651|specificity=0.7012|acc=0.6045|mcc=0.1697
precision=0.5193|recall=0.4651|f1=0.4907|aupr=0.5606


100%|██████████| 94/94 [00:06<00:00, 15.42it/s]


Fold-1****Train (Ep avg): Epoch-45/150 | Loss = 0.0091 | Time = 1.5504 sec
tn = 56364, fp = 144, fn = 143, tp = 39025
y_pred: 0 = 56507 | 1 = 39169
y_true: 0 = 56508 | 1 = 39168
auc=0.9999|sensitivity=0.9963|specificity=0.9975|acc=0.9970|mcc=0.9938
precision=0.9963|recall=0.9963|f1=0.9963|aupr=0.9999


100%|██████████| 24/24 [00:00<00:00, 36.36it/s]


Fold-1 ****Test  Epoch-45/150: Loss = 10.641571
tn = 9800, fp = 4322, fn = 5155, tp = 4643
y_pred: 0 = 14955 | 1 = 8965
y_true: 0 = 14122 | 1 = 9798
auc=0.5999|sensitivity=0.4739|specificity=0.6940|acc=0.6038|mcc=0.1705
precision=0.5179|recall=0.4739|f1=0.4949|aupr=0.5641


 88%|████████▊ | 83/94 [00:05<00:00, 16.26it/s]


KeyboardInterrupt: 

In [None]:
for n_heads in range(5,10):
    for fold in range(1,6):

        path_saver = './VDJ_10X_McPAS_随机错配_1V5/tcr_st_layer1_multihead{}_fold{}_netmhcpan.pkl'.format(n_heads,fold)
#         path_saver = 'model/tcr_st_layer1_multihead8_fold0_netmhcpan.pkl'
# model = STSeqCls((21, 100), num_cls=2, hidden_size=300, num_layers=1, num_head=8, max_len=29,cls_hidden_size=600,dropout=0.1,head_dim=32).to(device)
        model = Transformer().to(device)
        model.load_state_dict(torch.load(path_saver))
# model_eval = model.eval()
        type_ = 'val'
        save_ = False
        use_cuda = True
        device = torch.device("cuda" if use_cuda else "cpu")
        criterion = nn.CrossEntropyLoss()
        # fold = 0
        ep_best = None
        print("n_head is:"+str(n_heads))

        data, pep_inputs, hla_inputs, labels, loader,_ = data_with_loader(type_,fold = fold,  batch_size = batch_size)
        independent_metrics_res, independent_ys_res, independent_attn_res = eval_step(model, loader, fold, ep_best, epochs, use_cuda)



In [12]:
def data_with_loader(type_ = 'train',fold = None,  batch_size = 128):
    if type_ != 'train' and type_ != 'val':
#         data = pd.read_csv('../data/justina_test.csv')
        data = pd.read_csv('./inputs/inputs_bd.csv')
#         data = pd.read_csv('../data/test/GILGLVFTL.csv')
#         data = pd.read_csv('../data/posi_length.csv')
        
    elif type_ == 'train':
        data = pd.read_csv('./Combine_10Xneg/Com_train_B.csv')

    elif type_ == 'val':
        data = pd.read_csv('./Combine_10Xneg/Com_eva_test_B.csv')

    pep_inputs, hla_inputs,labels = make_data(data)
#     print(labels)
    loader = Data.DataLoader(MyDataSet(pep_inputs, hla_inputs,labels), batch_size, shuffle = True, num_workers = 0)
    n_samples = len(pep_inputs)
    len_cdr3 = len(hla_inputs[0])
    len_epi = len(pep_inputs[0])
    encoding_mask = np.zeros([n_samples, len_cdr3,len_epi])
    for idx_sample, (enc_cdr3_this, enc_epi_this) in enumerate(zip(hla_inputs, pep_inputs)):
        mask = np.ones([len_cdr3,len_epi])
        zero_cdr3 = (enc_cdr3_this == 0)
        mask[zero_cdr3,:] = 0
        zero_epi = (enc_epi_this == 0)
        mask[:,zero_epi] = 0
#         print(mask.shape)
        encoding_mask[idx_sample] = mask
    return data, pep_inputs, hla_inputs, labels,loader,encoding_mask

In [13]:
for n_heads in range(5,6):
    
    ys_train_fold_dict, ys_val_fold_dict = {}, {}
    train_fold_metrics_list, val_fold_metrics_list = [], []
    independent_fold_metrics_list, external_fold_metrics_list, ys_independent_fold_dict, ys_external_fold_dict = [], [], {}, {}
    attns_train_fold_dict, attns_val_fold_dict, attns_independent_fold_dict, attns_external_fold_dict = {}, {}, {}, {}
    loss_train_fold_dict, loss_val_fold_dict, loss_independent_fold_dict, loss_external_fold_dict = {}, {}, {}, {}

    for fold in range(1,6):
        print('=====Fold-{}====='.format(fold))
        print('-----Generate data loader-----')
        train_data, train_pep_inputs, train_hla_inputs, train_labels, train_loader,_ = data_with_loader(type_ = 'train', fold = fold,  batch_size = batch_size)
        val_data, val_pep_inputs, val_hla_inputs, val_labels, val_loader,_ = data_with_loader(type_ = 'val', fold = fold,  batch_size = batch_size)
        print('Fold-{} Label info: Train = {} | Val = {}'.format(fold, Counter(train_data.label), Counter(val_data.label)))

        print('-----Compile model-----')
        model = Transformer().to(device)
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.Adam(model.parameters(), lr = 1e-3)#, momentum = 0.99)

        print('-----Train-----')
        dir_saver = './model2'
    
        path_saver = './model2/tcr_st_layer{}_multihead{}_fold{}_netmhcpan.pkl'.format(n_layers, n_heads, fold)
        metric_best, ep_best = 0, -1
        time_train = 0
        for epoch in range(1, epochs + 1):

            ys_train, loss_train_list, metrics_train, time_train_ep = train_step(model, train_loader, fold, epoch, epochs, use_cuda) # , dec_attns_train
            ys_val, loss_val_list, metrics_val = eval_step(model, val_loader, fold, epoch, epochs, use_cuda) #, dec_attns_val

            metrics_ep_avg = sum(metrics_val[:4])/4
            if metrics_ep_avg > metric_best: 
                metric_best, ep_best = metrics_ep_avg, epoch
                if not os.path.exists(dir_saver):
                    os.makedirs(dir_saver)
                print('****Saving model: Best epoch = {} | 5metrics_Best_avg = {:.4f}'.format(ep_best, metric_best))
                print('*****Path saver: ', path_saver)
                torch.save(model.eval().state_dict(), path_saver)

            time_train += time_train_ep

        print('-----Optimization Finished!-----')
        print('-----Evaluate Results-----')
        if ep_best >= 0:
            print('*****Path saver: ', path_saver)
            model.load_state_dict(torch.load(path_saver))
            model_eval = model.eval()

            ys_res_train, loss_res_train_list, metrics_res_train = eval_step(model_eval, train_loader, fold, ep_best, epochs, use_cuda) # , train_res_attns
            ys_res_val, loss_res_val_list, metrics_res_val = eval_step(model_eval, val_loader, fold, ep_best, epochs, use_cuda) # , val_res_attns
#             ys_res_independent, loss_res_independent_list, metrics_res_independent = eval_step(model_eval, independent_loader, fold, ep_best, epochs, use_cuda) # , independent_res_attns
#             ys_res_external, loss_res_external_list, metrics_res_external = eval_step(model_eval, external_loader, fold, ep_best, epochs, use_cuda) # , external_res_attns

            train_fold_metrics_list.append(metrics_res_train)
            val_fold_metrics_list.append(metrics_res_val)
#             independent_fold_metrics_list.append(metrics_res_independent)
#             external_fold_metrics_list.append(metrics_res_external)

#             ys_train_fold_dict[fold], ys_val_fold_dict[fold], ys_independent_fold_dict[fold], ys_external_fold_dict[fold] = ys_res_train, ys_res_val, ys_res_independent, ys_res_external    
#             attns_train_fold_dict[fold], attns_val_fold_dict[fold], attns_independent_fold_dict[fold], attns_external_fold_dict[fold] = train_res_attns, val_res_attns, independent_res_attns, external_res_attns   
#             loss_train_fold_dict[fold], loss_val_fold_dict[fold], loss_independent_fold_dict[fold], loss_external_fold_dict[fold] = loss_res_train_list, loss_res_val_list, loss_res_independent_list, loss_res_external_list  

        print("Total training time: {:6.2f} sec".format(time_train))




=====Fold-1=====
-----Generate data loader-----


AssertionError: The cdr3 length must <= 20