In [1]:
import os
import torch.nn as nn
import torch.optim as optim
from sklearn.model_selection import KFold
from myconfig.Config import *
from trans1utils.DataGenerater import *
from transformer.Transformer import Transformer

In [2]:
vocab,vocab_size = vocab_config()

hla_max_len,pep_max_len,tcr_max_len,hla_pep_concat_len,pep_tcr_concat_len = data_config()

d_model, d_ff, d_k, d_v, _, n_heads, epochs, batch_size, threshold, dropout_rate, max_len, device \
    = model_config()

seed = run_config()

# train_hla_pep_loader,test_hla_pep_loader \
#     = hla_pep_data_loader('hla_pep_dataset',vocab,hla_max_len,pep_max_len,batch_size,0.9)

train_pep_tcr_loader,test_pep_tcr_loader \
    = pep_tcr_data_loader('pep_tcr_dataset',vocab,pep_max_len,tcr_max_len,batch_size,0.9)

In [15]:
train_hla_pep_loader,test_hla_pep_loader \
    = hla_pep_data_loader('independent_set',vocab,hla_max_len,pep_max_len,batch_size,0.9)

In [None]:
for n_layers in range(1,10):
    for n_heads in range(1,6):
        for fold in range(5):
            
            train_loader = data_with_loader(vocab,'train',fold,hla_max_len,pep_max_len,batch_size)
            val_loader = data_with_loader(vocab,'val',fold,hla_max_len,pep_max_len,batch_size)

            model = Transformer(vocab_size,n_enc_layers=n_layers,n_enc_heads=n_heads,n_dec_layers=n_layers,n_dec_heads=n_heads,hla_pep_concat_len=hla_pep_concat_len,pep_tcr_concat_len=pep_tcr_concat_len).to(device)

            criterion = nn.CrossEntropyLoss()
            optimizer = optim.Adam(model.parameters(), lr = 1e-3)

            dir_saver = './model/'
            path_saver = './model/model_hp_layer{}_head{}_fold{}.pth'.format(n_layers, n_heads,fold)

            metric_best, ep_best,time_train = 0, -1, 0
            for epoch in range(1,epochs+1):
                _,_,_ \
                    = model.train_per_epoch(train_loader,epoch,epochs,criterion,optimizer,num_group=1,threshold=threshold,metrics_print_=True)

                _, _, metrics_val \
                    = model.eval_per_epoch(val_loader, epoch, epochs,criterion,num_group=1,threshold=threshold,metrics_print_=True)

                metrics_ep_avg = sum(metrics_val[:4])/4
                
                if metrics_ep_avg > metric_best:
                    metric_best, ep_best = metrics_ep_avg, epoch
                    if not os.path.exists(dir_saver):
                        os.makedirs(dir_saver)
                    print('Best epoch = {} | Best metrics_ep_avg = {:.4f}--------'.format(ep_best, metric_best))
                    print('Saving model Path saver: {} --------'.format(path_saver))
                    torch.save(model.state_dict(), path_saver)

In [None]:
# kf = KFold(n_splits=5)
# for n_layers in range(1,10):
#     for n_heads in range(1,6):
#         for fold,(train_index,val_index) in enumerate(kf.split(train_pep_tcr_loader.dataset)):
# 
#             train_pep_tcr_set = train_pep_tcr_loader.dataset[train_index]
#             val_pep_tcr_set = train_pep_tcr_loader.dataset[val_index]
# 
#             train_pep_tcr \
#                 = Data.DataLoader(PEP_TCR_DataSet(train_pep_tcr_set[0], train_pep_tcr_set[1],train_pep_tcr_set[2]), batch_size,shuffle=False, num_workers=0)
#             val_pep_tcr \
#                 = Data.DataLoader(PEP_TCR_DataSet(val_pep_tcr_set[0], val_pep_tcr_set[1],    val_pep_tcr_set[2]), batch_size,shuffle=False, num_workers=0)
# 
#             model \
#                 = Transformer(vocab_size,n_enc_layers=n_layers,n_enc_heads=n_heads,n_dec_layers=n_layers,n_dec_heads=n_heads,hla_pep_concat_len=hla_pep_concat_len,pep_tcr_concat_len=pep_tcr_concat_len).to(device)
# 
#             criterion = nn.CrossEntropyLoss()
#             optimizer = optim.Adam(model.parameters(), lr = 1e-3)
# 
#             dir_saver = './model/'
#             path_saver = './model/model_pt_layer{}_head{}_fold{}.pth'.format(n_layers, n_heads,fold)
# 
#             metric_best, ep_best,time_train = 0, -1, 0
#             for epoch in range(1,epochs+1):
#                 _,_,_ \
#                     = model.train_per_epoch(train_pep_tcr,epoch,epochs,criterion,optimizer,num_group=2,threshold=threshold,metrics_print_=True)
# 
#                 _, _, metrics_val \
#                     = model.eval_per_epoch(val_pep_tcr, epoch, epochs,criterion,num_group=2,threshold=threshold,metrics_print_=True)
# 
# 
#                 metrics_ep_avg = metrics_val[0] # auc
# 
#                 if metrics_ep_avg > metric_best:
#                     metric_best, ep_best = metrics_ep_avg, epoch
#                     if not os.path.exists(dir_saver):
#                         os.makedirs(dir_saver)
#                     print('Best epoch = {} | Best metrics_ep_avg = {:.4f}--------'.format(ep_best, metric_best))
#                     print('Saving model Path saver: {} --------'.format(path_saver))
#                     torch.save(model.state_dict(), path_saver)

In [5]:
model = Transformer(vocab_size,n_enc_layers=9,n_enc_heads=5,n_dec_layers=9,n_dec_heads=5,hla_pep_concat_len=hla_pep_concat_len,pep_tcr_concat_len=pep_tcr_concat_len,device=device).to(device)

In [6]:
model

Transformer(
  (hla_embedding): Embedding(
    (src_emb): Embedding(27, 64)
    (pos_emb): PositionEmbedding(
      (dropout): Dropout(p=0.1, inplace=False)
    )
  )
  (pep_embedding): Embedding(
    (src_emb): Embedding(27, 64)
    (pos_emb): PositionEmbedding(
      (dropout): Dropout(p=0.1, inplace=False)
    )
  )
  (tcr_embedding): Embedding(
    (src_emb): Embedding(27, 64)
    (pos_emb): PositionEmbedding(
      (dropout): Dropout(p=0.1, inplace=False)
    )
  )
  (hla_encoder): Encoder(
    (layers): ModuleList(
      (0): EncoderLayer(
        (multi_head_attention): MultiHeadAttention(
          (W_Q): Linear(in_features=64, out_features=320, bias=False)
          (W_K): Linear(in_features=64, out_features=320, bias=False)
          (W_V): Linear(in_features=64, out_features=320, bias=False)
          (FC): Linear(in_features=320, out_features=64, bias=False)
        )
        (pos_wise_feed_forward_net): PosWiseFeedForwardNet(
          (FC): Sequential(
            (0): Li

In [4]:
model.load_state_dict(
    torch.load('D:\ProjectsSTC\pytorchProject\model\model_pt_layer9_head5_fold4.pth')
)

NameError: name 'model' is not defined

In [3]:
criterion = nn.CrossEntropyLoss()
model.eval_per_epoch(train_pep_tcr_loader,None,None,criterion,2,0.5,True)

NameError: name 'model' is not defined

In [13]:
model.eval_per_epoch(test_pep_tcr_loader,None,None,criterion,2,0.5,True)

******开始验证******


100%|██████████| 12/12 [00:05<00:00,  2.12it/s]


以下是评估得分:
MCC Error:  921653598363345
y_true: 0 = 7691 | 1 = 3841
y_pred: 0 = 7197 | 1 = 4335
tn = 7053, fp = 638, fn = 144, tp = 3697
auc=0.9657|sensitivity=0.9625|specificity=0.9170|acc=0.9322|mcc=nan
precision=0.8528|recall=0.9625|f1=0.9044|ap=0.8913
******结束验证: Loss = 0.191216******


(([0,
   0,
   1,
   0,
   0,
   1,
   0,
   0,
   1,
   0,
   0,
   1,
   1,
   0,
   0,
   0,
   0,
   0,
   1,
   1,
   1,
   0,
   0,
   0,
   1,
   0,
   0,
   0,
   0,
   1,
   0,
   1,
   0,
   1,
   0,
   1,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   1,
   1,
   0,
   0,
   0,
   0,
   1,
   0,
   1,
   1,
   0,
   0,
   1,
   1,
   0,
   0,
   0,
   1,
   0,
   0,
   0,
   1,
   0,
   0,
   1,
   0,
   0,
   0,
   0,
   1,
   0,
   0,
   0,
   1,
   0,
   0,
   1,
   1,
   0,
   0,
   1,
   1,
   1,
   0,
   1,
   0,
   1,
   0,
   0,
   0,
   0,
   1,
   0,
   0,
   0,
   0,
   0,
   1,
   1,
   0,
   0,
   1,
   0,
   0,
   1,
   1,
   1,
   0,
   0,
   1,
   1,
   0,
   1,
   0,
   0,
   1,
   0,
   0,
   1,
   1,
   1,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   1,
   0,
   0,
   0,
   1,
   0,
   1,
   0,
   0,
   1,
   0,
   1,
   1,
   0,
   1,
   0,
   1,
   0,
   0,
   0,
   0,
   0,
   0,
   1,
   1,
   1,
   0,
   0,
   1,
   1,
   1,
   0,
   0,
   0

In [19]:
model.load_state_dict(
    torch.load('D:\ProjectsSTC\pytorchProject\model\model_hp_layer9_head5_fold4.pth')
)

<All keys matched successfully>

In [20]:
model

Transformer(
  (hla_embedding): Embedding(
    (src_emb): Embedding(27, 64)
    (pos_emb): PositionEmbedding(
      (dropout): Dropout(p=0.1, inplace=False)
    )
  )
  (pep_embedding): Embedding(
    (src_emb): Embedding(27, 64)
    (pos_emb): PositionEmbedding(
      (dropout): Dropout(p=0.1, inplace=False)
    )
  )
  (tcr_embedding): Embedding(
    (src_emb): Embedding(27, 64)
    (pos_emb): PositionEmbedding(
      (dropout): Dropout(p=0.1, inplace=False)
    )
  )
  (hla_encoder): Encoder(
    (layers): ModuleList(
      (0): EncoderLayer(
        (multi_head_attention): MultiHeadAttention(
          (W_Q): Linear(in_features=64, out_features=320, bias=False)
          (W_K): Linear(in_features=64, out_features=320, bias=False)
          (W_V): Linear(in_features=64, out_features=320, bias=False)
          (FC): Linear(in_features=320, out_features=64, bias=False)
        )
        (pos_wise_feed_forward_net): PosWiseFeedForwardNet(
          (FC): Sequential(
            (0): Li

In [21]:
model.eval_per_epoch(train_hla_pep_loader,None,None,criterion,1,0.5,True)

******开始验证******


100%|██████████| 151/151 [01:12<00:00,  2.08it/s]


以下是评估得分:
MCC Error:  35395917891480367104
y_true: 0 = 76918 | 1 = 77376
y_pred: 0 = 75056 | 1 = 79238
tn = 69431, fp = 7487, fn = 5625, tp = 71751
auc=0.9705|sensitivity=0.9273|specificity=0.9027|acc=0.9150|mcc=nan
precision=0.9055|recall=0.9273|f1=0.9163|ap=0.9688
******结束验证: Loss = 0.219511******


(([1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1

In [22]:
model.eval_per_epoch(test_hla_pep_loader,None,None,criterion,1,0.5,True)

******开始验证******


100%|██████████| 17/17 [00:08<00:00,  2.00it/s]

以下是评估得分:
MCC Error:  5395691622552000
y_true: 0 = 8644 | 1 = 8500
y_pred: 0 = 8366 | 1 = 8778
tn = 8054, fp = 590, fn = 312, tp = 8188
auc=0.9868|sensitivity=0.9633|specificity=0.9317|acc=0.9474|mcc=nan
precision=0.9328|recall=0.9633|f1=0.9478|ap=0.9854
******结束验证: Loss = 0.144655******





(([1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0