In [1]:
import math
from sklearn import metrics
from sklearn import preprocessing
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import re
import time
import datetime
import random
random.seed(1234)

from scipy import interp
import warnings
warnings.filterwarnings("ignore")

from collections import Counter
from functools import reduce
from tqdm import tqdm, trange
from copy import deepcopy

from sklearn.metrics import confusion_matrix
from sklearn.metrics import roc_auc_score, auc
from sklearn.metrics import precision_recall_fscore_support
from sklearn.metrics import precision_recall_curve
from sklearn.metrics import classification_report
from sklearn.utils import class_weight

import os
import torch
import torch.nn as nn
from torch.nn import functional as F
import torch.optim as optim
import torch.utils.data as Data
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

In [2]:
seed = 2022
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

In [3]:
hla_sequence = pd.read_csv('../data/other/HLAI_pseudosequences_34mer.csv')

In [79]:
def make_data(data):
    pep_inputs, hla_inputs, labels = [], [], []
    pep_lens = []
    for pep, hla, label in zip(data.peptide, data.HLA_sequence, data.label):
#         pep_lens.append(len(pep)+34)
        pep_lens.append(49)
        pep, hla = pep.ljust(pep_max_len, '-'), hla.ljust(hla_max_len, '-')
        pep_input = [[vocab[n] for n in pep]] # [[1, 2, 3, 4, 0], [1, 2, 3, 5, 0]]
        hla_input = [[vocab[n] for n in hla]]
        pep_inputs.extend(pep_input)
        hla_inputs.extend(hla_input)
        labels.append(label)
    return torch.LongTensor(pep_inputs), torch.LongTensor(hla_inputs), torch.LongTensor(labels), torch.LongTensor(pep_lens)

class MyDataSet(Data.Dataset):
    def __init__(self, pep_inputs, hla_inputs, labels, pep_lens):
        super(MyDataSet, self).__init__()
        self.pep_inputs = pep_inputs
        self.hla_inputs = hla_inputs
        self.labels = labels
        self.pep_lens = pep_lens

    def __len__(self): # 样本数
        return self.pep_inputs.shape[0] # 改成hla_inputs也可以哦！

    def __getitem__(self, idx):
#         return self.pep_inputs[idx], self.hla_inputs[idx], self.labels[idx],self.pep_lens[idx]
        return torch.cat((self.hla_inputs[idx],self.pep_inputs[idx]), dim=0), self.labels[idx], self.pep_lens[idx]

def seq_len_to_mask(seq_len, max_len=49): #50
    r"""
    将一个表示sequence length的一维数组转换为二维的mask，不包含的位置为0。
    转变 1-d seq_len到2-d mask.
    .. code-block::
    
        >>> seq_len = torch.arange(2, 16)
        >>> mask = seq_len_to_mask(seq_len)
        >>> print(mask.size())
        torch.Size([14, 15])
        >>> seq_len = np.arange(2, 16)
        >>> mask = seq_len_to_mask(seq_len)
        >>> print(mask.shape)
        (14, 15)
        >>> seq_len = torch.arange(2, 16)
        >>> mask = seq_len_to_mask(seq_len, max_len=100)
        >>>print(mask.size())
        torch.Size([14, 100])
    :param np.ndarray,torch.LongTensor seq_len: shape将是(B,)
    :param int max_len: 将长度pad到这个长度。默认(None)使用的是seq_len中最长的长度。但在nn.DataParallel的场景下可能不同卡的seq_len会有
        区别，所以需要传入一个max_len使得mask的长度是pad到该长度。
    :return: np.ndarray, torch.Tensor 。shape将是(B, max_length)， 元素类似为bool或torch.uint8
    """
    if isinstance(seq_len, np.ndarray):
        assert len(np.shape(seq_len)) == 1, f"seq_len can only have one dimension, got {len(np.shape(seq_len))}."
        max_len = int(max_len) if max_len else int(seq_len.max())
        broad_cast_seq_len = np.tile(np.arange(max_len), (len(seq_len), 1))
        mask = broad_cast_seq_len < seq_len.reshape(-1, 1)

    elif isinstance(seq_len, torch.Tensor):
        assert seq_len.dim() == 1, f"seq_len can only have one dimension, got {seq_len.dim() == 1}."
        batch_size = seq_len.size(0)
        max_len = int(max_len) if max_len else seq_len.max().long()
        broad_cast_seq_len = torch.arange(max_len).expand(batch_size, -1).to(seq_len)
        mask = broad_cast_seq_len.lt(seq_len.unsqueeze(1))
    else:
        raise TypeError("Only support 1-d numpy.ndarray or 1-d torch.Tensor.")

    return mask

def get_embeddings(init_embed, padding_idx=None):
    r"""
    根据输入的init_embed返回Embedding对象。如果输入是tuple, 则随机初始化一个nn.Embedding; 如果输入是numpy.ndarray, 则按照ndarray
    的值将nn.Embedding初始化; 如果输入是torch.Tensor, 则按该值初始化nn.Embedding; 如果输入是fastNLP中的embedding将不做处理
    返回原对象。
    :param init_embed: 可以是 tuple:(num_embedings, embedding_dim), 即embedding的大小和每个词的维度;也可以传入
        nn.Embedding 对象, 此时就以传入的对象作为embedding; 传入np.ndarray也行，将使用传入的ndarray作为作为Embedding初始化;
        传入torch.Tensor, 将使用传入的值作为Embedding初始化。
    :param padding_idx: 当传入tuple时，padding_idx有效
    :return nn.Embedding:  embeddings
    """
    if isinstance(init_embed, tuple):
        res = nn.Embedding(num_embeddings=init_embed[0], embedding_dim=init_embed[1], padding_idx=padding_idx)
#         nn.init.uniform_(res.weight.data, a=-np.sqrt(3 / res.weight.data.size(1)),
#                          b=np.sqrt(3 / res.weight.data.size(1)))
    elif isinstance(init_embed, nn.Module):
        res = init_embed
    elif isinstance(init_embed, torch.Tensor):
        res = nn.Embedding.from_pretrained(init_embed, freeze=False)
    elif isinstance(init_embed, np.ndarray):
        init_embed = torch.tensor(init_embed, dtype=torch.float32)
        res = nn.Embedding.from_pretrained(init_embed, freeze=False)
    else:
        raise TypeError(
            'invalid init_embed type: {}'.format((type(init_embed))))
    return res

In [139]:

class StarTransformer(nn.Module):
    r"""
    Star-Transformer 的encoder部分。 输入3d的文本输入, 返回相同长度的文本编码
    paper: https://arxiv.org/abs/1902.09113
    """

    def __init__(self, hidden_size, num_layers, num_head, head_dim, dropout=0.1, max_len=None):
        r"""
        
        :param int hidden_size: 输入维度的大小。同时也是输出维度的大小。
        :param int num_layers: star-transformer的层数
        :param int num_head: head的数量。
        :param int head_dim: 每个head的维度大小。
        :param float dropout: dropout 概率. Default: 0.1
        :param int max_len: int or None, 如果为int，输入序列的最大长度，
            模型会为输入序列加上position embedding。
            若为`None`，忽略加上position embedding的步骤. Default: `None`
        """
        super(StarTransformer, self).__init__()
        self.iters = num_layers

        self.norm = nn.ModuleList([nn.LayerNorm(hidden_size, eps=1e-6) for _ in range(self.iters)])
        # self.emb_fc = nn.Conv2d(hidden_size, hidden_size, 1)
        self.emb_drop = nn.Dropout(dropout)
        self.ring_att = nn.ModuleList(
            [_MSA1(hidden_size, nhead=num_head, head_dim=head_dim, dropout=0.0)
             for _ in range(self.iters)])
        self.star_att = nn.ModuleList(
            [_MSA2(hidden_size, nhead=num_head, head_dim=head_dim, dropout=0.0)
             for _ in range(self.iters)])

        if max_len is not None:
            self.pos_emb = nn.Embedding(max_len, hidden_size)
        else:
            self.pos_emb = None

    def forward(self, data, mask):
        r"""
        :param FloatTensor data: [batch, length, hidden] 输入的序列
        :param ByteTensor mask: [batch, length] 输入序列的padding mask, 在没有内容(padding 部分) 为 0,
            否则为 1
        :return: [batch, length, hidden] 编码后的输出序列
                [batch, hidden] 全局 relay 节点, 详见论文
        """

        def norm_func(f, x):
            # B, H, L, 1
            return f(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)

        B, L, H = data.size()
        mask = (mask.eq(False))  # flip the mask for masked_fill_
        smask = torch.cat([torch.zeros(B, 1, ).byte().to(mask), mask], 1)

        embs = data.permute(0, 2, 1)[:, :, :, None]  # B H L 1
        if self.pos_emb:
            P = self.pos_emb(torch.arange(L, dtype=torch.long, device=embs.device) \
                             .view(1, L)).permute(0, 2, 1).contiguous()[:, :, :, None]  # 1 H L 1
            embs = embs + P
        embs = norm_func(self.emb_drop, embs)
        nodes = embs
        relay = embs.mean(2, keepdim=True)
        ex_mask = mask[:, None, :, None].expand(B, H, L, 1)
        r_embs = embs.view(B, H, 1, L)
#         nodes_attns = []
#         relays_attns = []
        for i in range(self.iters):
            ax = torch.cat([r_embs, relay.expand(B, H, 1, L)], 2)
            nodes = F.leaky_relu(self.ring_att[i](norm_func(self.norm[i], nodes), ax=ax))
            # nodes = F.leaky_relu(self.ring_att[i](nodes, ax=ax))
#             nodes_attns.append(nodes_att)
            relay = F.leaky_relu(self.star_att[i](relay, torch.cat([relay, nodes], 2), smask))
#             relays_attns.append(relay_att)
            nodes = nodes.masked_fill_(ex_mask, 0)

        nodes = nodes.view(B, H, L).permute(0, 2, 1)

        return nodes, relay.view(B, H)#, nodes_attns, relays_attns


class _MSA1(nn.Module):
    def __init__(self, nhid, nhead=10, head_dim=10, dropout=0.1):
        super(_MSA1, self).__init__()
        # Multi-head Self Attention Case 1, doing self-attention for small regions
        # Due to the architecture of GPU, using hadamard production and summation are faster than dot production when unfold_size is very small
        self.WQ = nn.Conv2d(nhid, nhead * head_dim, 1)
        self.WK = nn.Conv2d(nhid, nhead * head_dim, 1)
        self.WV = nn.Conv2d(nhid, nhead * head_dim, 1)
        self.WO = nn.Conv2d(nhead * head_dim, nhid, 1)

        self.drop = nn.Dropout(dropout)

        self.nhid, self.nhead, self.head_dim, self.unfold_size = nhid, nhead, head_dim, 3

    def forward(self, x, ax=None):
        # x: B, H, L, 1, ax : B, H, X, L append features
        nhid, nhead, head_dim, unfold_size = self.nhid, self.nhead, self.head_dim, self.unfold_size
        B, H, L, _ = x.shape

        q, k, v = self.WQ(x), self.WK(x), self.WV(x)  # x: (B,H,L,1)

        if ax is not None:
            aL = ax.shape[2]
            ak = self.WK(ax).view(B, nhead, head_dim, aL, L)
            av = self.WV(ax).view(B, nhead, head_dim, aL, L)
        q = q.view(B, nhead, head_dim, 1, L)
        k = F.unfold(k.view(B, nhead * head_dim, L, 1), (unfold_size, 1), padding=(unfold_size // 2, 0)) \
            .view(B, nhead, head_dim, unfold_size, L)
        v = F.unfold(v.view(B, nhead * head_dim, L, 1), (unfold_size, 1), padding=(unfold_size // 2, 0)) \
            .view(B, nhead, head_dim, unfold_size, L)
        if ax is not None:
            k = torch.cat([k, ak], 3)
            v = torch.cat([v, av], 3)

        alphas = self.drop(F.softmax((q * k).sum(2, keepdim=True) / np.sqrt(head_dim), 3))  # B N L 1 U
        att = (alphas * v).sum(3).view(B, nhead * head_dim, L, 1)

        ret = self.WO(att)

        return ret #,alphas


class _MSA2(nn.Module):
    def __init__(self, nhid, nhead=10, head_dim=10, dropout=0.1):
        # Multi-head Self Attention Case 2, a broadcastable query for a sequence key and value
        super(_MSA2, self).__init__()
        self.WQ = nn.Conv2d(nhid, nhead * head_dim, 1)
        self.WK = nn.Conv2d(nhid, nhead * head_dim, 1)
        self.WV = nn.Conv2d(nhid, nhead * head_dim, 1)
        self.WO = nn.Conv2d(nhead * head_dim, nhid, 1)

        self.drop = nn.Dropout(dropout)

        self.nhid, self.nhead, self.head_dim, self.unfold_size = nhid, nhead, head_dim, 3
    def forward(self, x, y, mask=None):
        # x: B, H, 1, 1, 1 y: B H L 1
        nhid, nhead, head_dim, unfold_size = self.nhid, self.nhead, self.head_dim, self.unfold_size
        B, H, L, _ = y.shape

        q, k, v = self.WQ(x), self.WK(y), self.WV(y)

        q = q.view(B, nhead, 1, head_dim)  # B, H, 1, 1 -> B, N, 1, h
        k = k.view(B, nhead, head_dim, L)  # B, H, L, 1 -> B, N, h, L
        v = v.view(B, nhead, head_dim, L).permute(0, 1, 3, 2)  # B, H, L, 1 -> B, N, L, h
        pre_a = torch.matmul(q, k) / np.sqrt(head_dim)
        if mask is not None:
            pre_a = pre_a.masked_fill(mask[:, None, None, :], -float('inf'))
        alphas = self.drop(F.softmax(pre_a, 3))  # B, N, 1, L
        att = torch.matmul(alphas, v).view(B, -1, 1, 1)  # B, N, 1, h -> B, N*h, 1, 1
        return self.WO(att) #,alphas
    
class StarTransEnc(nn.Module):
    r"""
    带word embedding的Star-Transformer Encoder
    """

    def __init__(self, embed,
                 hidden_size,
                 num_layers,
                 num_head,
                 head_dim,
                 max_len,
                 emb_dropout,
                 dropout):
        r"""
        
        :param embed: 单词词典, 可以是 tuple, 包括(num_embedings, embedding_dim), 即
            embedding的大小和每个词的维度. 也可以传入 nn.Embedding 对象,此时就以传入的对象作为embedding
        :param hidden_size: 模型中特征维度.
        :param num_layers: 模型层数.
        :param num_head: 模型中multi-head的head个数.
        :param head_dim: 模型中multi-head中每个head特征维度.
        :param max_len: 模型能接受的最大输入长度.
        :param emb_dropout: 词嵌入的dropout概率.
        :param dropout: 模型除词嵌入外的dropout概率.
        """
        super(StarTransEnc, self).__init__()
        self.embedding = get_embeddings(embed,padding_idx=0)
        emb_dim = self.embedding.embedding_dim
        self.emb_fc = nn.Linear(emb_dim, hidden_size)
        # self.emb_drop = nn.Dropout(emb_dropout)
        self.encoder = StarTransformer(hidden_size=hidden_size,
                                       num_layers=num_layers,
                                       num_head=num_head,
                                       head_dim=head_dim,
                                       dropout=dropout,
                                       max_len=max_len)
        
#         conv_block_klass = ConvBlock
# #         Embedding Layer
#         self.stem = nn.Sequential(
#         #             Rearrange('b n d -> b d n'),
# #             Dynamic_conv1d(49, 49, 3,padding = 1),
#             Residual(conv_block_klass(49)),
# #             AttentionPool(49, pool_size = 2)
            
#         )
#         self.stem2 = nn.Sequential(
#         #             Rearrange('b n d -> b d n'),
#             nn.Conv1d(34, 34, 3,padding = 1),
#             Residual(conv_block_klass(34)),
#             AttentionPool(34, pool_size = 2)
#         )

    def forward(self, x, mask):
        r"""
        :param FloatTensor x: [batch, length, hidden] 输入的序列
        :param ByteTensor mask: [batch, length] 输入序列的padding mask, 在没有内容(padding 部分) 为 0,
            否则为 1
        :return: [batch, length, hidden] 编码后的输出序列
                [batch, hidden] 全局 relay 节点, 详见论文
        """
        x = self.embedding(x)
        x = self.emb_fc(x)
#         x = self.stem(x)
        #nodes, relay, nodes_attns, relays_attns = self.encoder(x3, mask3)
        nodes, relay = self.encoder(x, mask)
        return nodes, relay, #nodes_attns, relays_attns


class _Cls(nn.Module):
    def __init__(self, in_dim, num_cls, hid_dim, dropout=0.1):
        super(_Cls, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(in_dim, hid_dim),
            nn.LeakyReLU(),
            nn.Dropout(dropout),
            nn.Linear(hid_dim, num_cls),
        )

    def forward(self, x):
        h = self.fc(x)
        return h

class STSeqCls(nn.Module):
    r"""
    用于分类任务的Star-Transformer
    """

    def __init__(self, embed, num_cls=2,
                 hidden_size=300,
                 num_layers=1,
                 num_head=9,
                 head_dim=32,
                 max_len=512,
                 cls_hidden_size=600,
                 emb_dropout=0.1,
                 dropout=0.1):
        r"""
        
        :param embed: 单词词典, 可以是 tuple, 包括(num_embedings, embedding_dim), 即
            embedding的大小和每个词的维度. 也可以传入 nn.Embedding 对象, 此时就以传入的对象作为embedding
        :param num_cls: 输出类别个数
        :param hidden_size: 模型中特征维度. Default: 300
        :param num_layers: 模型层数. Default: 4
        :param num_head: 模型中multi-head的head个数. Default: 8
        :param head_dim: 模型中multi-head中每个head特征维度. Default: 32
        :param max_len: 模型能接受的最大输入长度. Default: 512
        :param cls_hidden_size: 分类器隐层维度. Default: 600
        :param emb_dropout: 词嵌入的dropout概率. Default: 0.1
        :param dropout: 模型除词嵌入外的dropout概率. Default: 0.1
        """
        super(STSeqCls, self).__init__()
        self.enc = StarTransEnc(embed=embed,
                                hidden_size=hidden_size,
                                num_layers=num_layers,
                                num_head=num_head,
                                head_dim=head_dim,
                                max_len=max_len,
                                emb_dropout=emb_dropout,
                                dropout=dropout)
        self.cls = _Cls(hidden_size, num_cls, cls_hidden_size, dropout=dropout)


    def forward(self, words, seq_len):
        r"""
        :param words: [batch, seq_len] 输入序列
        :param seq_len: [batch,] 输入序列的长度
        :return output: [batch, num_cls] 输出序列的分类的概率
        """
        mask = seq_len_to_mask(seq_len,max_len=49).to(torch.device('cuda' if torch.cuda.is_available() else 'cpu'))
#         mask2 = seq_len_to_mask(torch.tensor([34]*len(seq_len))).to(torch.device('cuda' if torch.cuda.is_available() else 'cpu'))
        nodes, relay = self.enc(words, mask)
        y = 0.5 * (relay + nodes.max(1)[0])
#         y = torch.cat([relay, torch.sort(nodes,dim=1)[0][:,-1,:], torch.sort(nodes,dim=1)[0][:,-2,:]],1)
        
        output = self.cls(y)  # [bsz, n_cls]
        return output#, nodes_attns, relays_attns

In [106]:
def performances(y_true, y_pred, y_prob, print_ = True):
    
    tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels = [0, 1]).ravel().tolist()
    accuracy = (tp+tn)/(tn+fp+fn+tp)
    try:
        mcc = ((tp*tn) - (fn*fp)) / np.sqrt(np.float((tp+fn)*(tn+fp)*(tp+fp)*(tn+fn)))
    except:
        print('MCC Error: ', (tp+fn)*(tn+fp)*(tp+fp)*(tn+fn))
        mcc = np.nan
    sensitivity = tp/(tp+fn)
    specificity = tn/(tn+fp)
    
    try:
        recall = tp / (tp+fn)
    except:
        recall = np.nan
        
    try:
        precision = tp / (tp+fp)
    except:
        precision = np.nan
        
    try: 
        f1 = 2*precision*recall / (precision+recall)
    except:
        f1 = np.nan
        
    roc_auc = roc_auc_score(y_true, y_prob)
    prec, reca, _ = precision_recall_curve(y_true, y_prob)
    aupr = auc(reca, prec)
    
    if print_:
        print('tn = {}, fp = {}, fn = {}, tp = {}'.format(tn, fp, fn, tp))
        print('y_pred: 0 = {} | 1 = {}'.format(Counter(y_pred)[0], Counter(y_pred)[1]))
        print('y_true: 0 = {} | 1 = {}'.format(Counter(y_true)[0], Counter(y_true)[1]))
        print('auc={:.4f}|sensitivity={:.4f}|specificity={:.4f}|acc={:.4f}|mcc={:.4f}'.format(roc_auc, sensitivity, specificity, accuracy, mcc))
        print('precision={:.4f}|recall={:.4f}|f1={:.4f}|aupr={:.4f}'.format(precision, recall, f1, aupr))
    
    return (roc_auc, accuracy, mcc, f1, sensitivity, specificity, precision, recall, aupr)


# In[25]:


def transfer(y_prob, threshold = 0.5):
    return np.array([[0, 1][x > threshold] for x in y_prob])


# In[26]:


f_mean = lambda l: sum(l)/len(l)


# In[28]:


def performances_to_pd(performances_list):
    metrics_name = ['roc_auc', 'accuracy', 'mcc', 'f1', 'sensitivity', 'specificity', 'precision', 'recall', 'aupr']

    performances_pd = pd.DataFrame(performances_list, columns = metrics_name)
    performances_pd.loc['mean'] = performances_pd.mean(axis = 0)
    performances_pd.loc['std'] = performances_pd.std(axis = 0)
    
    return performances_pd

In [107]:
def train_step(model, train_loader, fold, epoch, epochs, use_cuda = True):
    device = torch.device("cuda" if use_cuda else "cpu")
    
    time_train_ep = 0
    model.train()
    y_true_train_list, y_prob_train_list = [], []
    loss_train_list = []
    for train_pep_inputs, train_pep_lens, train_labels in tqdm(train_loader):
        '''
        pep_inputs: [batch_size, pep_len]
        hla_inputs: [batch_size, hla_len]
        train_outputs: [batch_size, 2]
        '''
        train_pep_inputs, train_labels = train_pep_inputs.to(device), train_labels.to(device)
        train_pep_lens = train_pep_lens.to(device)
        t1 = time.time()
        train_outputs = model(train_pep_inputs, train_pep_lens)
        train_loss = criterion(train_outputs, train_labels)
        time_train_ep += time.time() - t1

        optimizer.zero_grad()
        train_loss.backward()
        optimizer.step()

        y_true_train = train_labels.cpu().numpy()
        y_prob_train = nn.Softmax(dim = 1)(train_outputs)[:, 1].cpu().detach().numpy()
        
        y_true_train_list.extend(y_true_train)
        y_prob_train_list.extend(y_prob_train)
        loss_train_list.append(train_loss.item())
#         relays_attns_list.append(relays_attns)
#         nodes_attns_list.append(nodes_attns)
        
    y_pred_train_list = transfer(y_prob_train_list, threshold)
    ys_train = (y_true_train_list, y_pred_train_list, y_prob_train_list)
    
    print('Fold-{}****Train (Ep avg): Epoch-{}/{} | Loss = {:.4f} | Time = {:.4f} sec'.format(fold, epoch, epochs, f_mean(loss_train_list), time_train_ep))
    metrics_train = performances(y_true_train_list, y_pred_train_list, y_prob_train_list, print_ = True)
    
    return ys_train, loss_train_list, metrics_train, time_train_ep#, relays_attns_list, nodes_attns_list


# In[30]:


def eval_step(model, val_loader, fold, epoch, epochs, use_cuda = True):
    device = torch.device("cuda" if use_cuda else "cpu")
    
    model.eval()
    torch.manual_seed(2022)
    torch.cuda.manual_seed(2022)
    with torch.no_grad():
        loss_val_list = []
        y_true_val_list, y_prob_val_list = [], []
        for val_pep_inputs, val_pep_lens, val_labels in tqdm(val_loader):
            val_pep_inputs, val_labels = val_pep_inputs.to(device), val_labels.to(device)
            val_pep_lens = val_pep_lens.to(device)
            val_outputs = model(val_pep_inputs, val_pep_lens)
            val_loss = criterion(val_outputs, val_labels)

            y_true_val = val_labels.cpu().numpy()
            y_prob_val = nn.Softmax(dim = 1)(val_outputs)[:, 1].cpu().detach().numpy()

            y_true_val_list.extend(y_true_val)
            y_prob_val_list.extend(y_prob_val)
            loss_val_list.append(val_loss.item())

            
        y_pred_val_list = transfer(y_prob_val_list, threshold)
        ys_val = (y_true_val_list, y_pred_val_list, y_prob_val_list)
        
        print('Fold-{} ****Test  Epoch-{}/{}: Loss = {:.6f}'.format(fold, epoch, epochs, f_mean(loss_val_list)))
        metrics_val = performances(y_true_val_list, y_pred_val_list, y_prob_val_list, print_ = True)
    return ys_val, loss_val_list, metrics_val


In [108]:
pep_max_len = 15 # peptide; enc_input max sequence length
hla_max_len = 34 # hla; dec_input(=dec_output) max sequence length
tgt_len = pep_max_len + hla_max_len
pep_max_len, hla_max_len

# vocab = np.load('./vocab_dict.npy', allow_pickle = True).item()
vocab = {'-': 0,
 'Y': 1,
 'A': 2,
 'T': 3,
 'V': 4,
 'L': 5,
 'D': 6,
 'E': 7,
 'G': 8,
 'R': 9,
 'H': 10,
 'I': 11,
 'W': 12,
 'Q': 13,
 'K': 14,
 'M': 15,
 'F': 16,
 'N': 17,
 'S': 18,
 'P': 19,
 'C': 20}
vocab_size = len(vocab)


n_layers = 1  # number of Encoder of Decoder Layer
n_heads = 8

batch_size = 1024
epochs = 25
threshold = 0.5

use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")

In [109]:
#G:TransPHLA-AOMP/Dataset/
def data_with_loader(type_ = 'train',fold = None,  batch_size = 128):
    if type_ != 'train' and type_ != 'val':
        data = pd.read_csv('../data/test_set/{}_set.csv'.format(type_), index_col = 0)
    elif type_ == 'train':
        data = pd.read_csv('../data/train_set/NetMHCpan4.1/train_data_fold{}.csv'.format(fold)) #, index_col = 0
    elif type_ == 'val':
        data = pd.read_csv('../data/train_set/NetMHCpan4.1/val_data_fold{}.csv'.format(fold)) #, index_col = 0
        
    pep_inputs, hla_inputs, labels, pep_lens = make_data(data)
    loader = Data.DataLoader(MyDataSet(pep_inputs, hla_inputs, pep_lens, labels), batch_size, shuffle = False, num_workers = 0)
    
    return data, pep_inputs, hla_inputs, pep_lens, labels, loader

In [110]:
independent_data, independent_pep_inputs, independent_hla_inputs, independent_pep_lens, independent_labels, independent_loader = data_with_loader(type_ = 'independent',fold = None,  batch_size = batch_size)
external_data, external_pep_inputs, external_hla_inputs, external_pep_lens, external_labels, external_loader = data_with_loader(type_ = 'external',fold = None,  batch_size = batch_size)

In [111]:
ys_train_fold_dict, ys_val_fold_dict = {}, {}
train_fold_metrics_list, val_fold_metrics_list = [], []
independent_fold_metrics_list, external_fold_metrics_list, ys_independent_fold_dict, ys_external_fold_dict = [], [], {}, {}
attns_train_fold_dict, attns_val_fold_dict, attns_independent_fold_dict, attns_external_fold_dict = {}, {}, {}, {}
loss_train_fold_dict, loss_val_fold_dict, loss_independent_fold_dict, loss_external_fold_dict = {}, {}, {}, {}

for fold in range(3,4):
    print('=====Fold-{}====='.format(fold))
    print('-----Generate data loader-----')
    train_data, train_pep_inputs, train_hla_inputs, train_pep_lens, train_labels, train_loader = data_with_loader(type_ = 'train', fold = fold,  batch_size = batch_size)
    val_data, val_pep_inputs, val_hla_inputs, val_pep_lens, val_labels, val_loader = data_with_loader(type_ = 'val', fold = fold,  batch_size = batch_size)
    print('Fold-{} Label info: Train = {} | Val = {}'.format(fold, Counter(train_data.label), Counter(val_data.label)))

    print('-----Compile model-----')
    model = STSeqCls((21, 100), num_cls=2, hidden_size=300, num_layers=1, num_head=8, max_len=49,cls_hidden_size=600,dropout=0.1,head_dim=32).to(device)
    print(model)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr = 1e-3)#, momentum = 0.99)
#     optimizer = ScheduledOptim(optim.Adam(model.parameters(), betas=(0.9, 0.98), eps=1e-09), 2, 64, 10)

    print('-----Train-----')
    dir_saver = 'G:TransPHLA-AOMP/model/STformer/'
    path_saver = 'G:TransPHLA-AOMP/model/STformer/netmhcpan/st_layer{}_multihead{}_fold{}_netmhcpan.pkl'.format(n_layers, n_heads, fold)
    print('dir_saver: ', dir_saver)
    print('path_saver: ', path_saver)

    metric_best, ep_best = 0, -1
    time_train = 0
    for epoch in range(1, epochs + 1):

        ys_train, loss_train_list, metrics_train, time_train_ep = train_step(model, train_loader, fold, epoch, epochs, use_cuda) # , dec_attns_train
        ys_val, loss_val_list, metrics_val = eval_step(model, val_loader, fold, epoch, epochs, use_cuda) #, dec_attns_val

        metrics_ep_avg = sum(metrics_val[:4])/4
        if metrics_ep_avg > metric_best: 
            metric_best, ep_best = metrics_ep_avg, epoch
            if not os.path.exists(dir_saver):
                os.makedirs(dir_saver)
            print('****Saving model: Best epoch = {} | 5metrics_Best_avg = {:.4f}'.format(ep_best, metric_best))
            print('*****Path saver: ', path_saver)
            torch.save(model.eval().state_dict(), path_saver)

        time_train += time_train_ep

    print('-----Optimization Finished!-----')
    print('-----Evaluate Results-----')
    if ep_best >= 0:
        print('*****Path saver: ', path_saver)
        model.load_state_dict(torch.load(path_saver))
        model_eval = model.eval()

        ys_res_train, loss_res_train_list, metrics_res_train = eval_step(model_eval, train_loader, fold, ep_best, epochs, use_cuda) # , train_res_attns
        ys_res_val, loss_res_val_list, metrics_res_val = eval_step(model_eval, val_loader, fold, ep_best, epochs, use_cuda) # , val_res_attns
        ys_res_independent, loss_res_independent_list, metrics_res_independent = eval_step(model_eval, independent_loader, fold, ep_best, epochs, use_cuda) # , independent_res_attns
        ys_res_external, loss_res_external_list, metrics_res_external = eval_step(model_eval, external_loader, fold, ep_best, epochs, use_cuda) # , external_res_attns

        train_fold_metrics_list.append(metrics_res_train)
        val_fold_metrics_list.append(metrics_res_val)
        independent_fold_metrics_list.append(metrics_res_independent)
        external_fold_metrics_list.append(metrics_res_external)

        ys_train_fold_dict[fold], ys_val_fold_dict[fold], ys_independent_fold_dict[fold], ys_external_fold_dict[fold] = ys_res_train, ys_res_val, ys_res_independent, ys_res_external    
        loss_train_fold_dict[fold], loss_val_fold_dict[fold], loss_independent_fold_dict[fold], loss_external_fold_dict[fold] = loss_res_train_list, loss_res_val_list, loss_res_independent_list, loss_res_external_list  

    print("Total training time: {:6.2f} sec".format(time_train))

=====Fold-0=====
-----Generate data loader-----
Fold-0 Label info: Train = Counter({1: 277903, 0: 277552}) | Val = Counter({0: 69607, 1: 69256})
-----Compile model-----
STSeqCls(
  (enc): StarTransEnc(
    (embedding): Embedding(21, 100, padding_idx=0)
    (emb_fc): Linear(in_features=100, out_features=300, bias=True)
    (encoder): StarTransformer(
      (norm): ModuleList(
        (0): LayerNorm((300,), eps=1e-06, elementwise_affine=True)
      )
      (emb_drop): Dropout(p=0.1, inplace=False)
      (ring_att): ModuleList(
        (0): _MSA1(
          (WQ): Conv2d(300, 256, kernel_size=(1, 1), stride=(1, 1))
          (WK): Conv2d(300, 256, kernel_size=(1, 1), stride=(1, 1))
          (WV): Conv2d(300, 256, kernel_size=(1, 1), stride=(1, 1))
          (WO): Conv2d(256, 300, kernel_size=(1, 1), stride=(1, 1))
          (drop): Dropout(p=0.0, inplace=False)
        )
      )
      (star_att): ModuleList(
        (0): _MSA2(
          (WQ): Conv2d(300, 256, kernel_size=(1, 1), stride=(

100%|████████████████████████████████████████████████████████████████████████████████| 543/543 [00:53<00:00, 10.13it/s]


Fold-0****Train (Ep avg): Epoch-1/25 | Loss = 0.2670 | Time = 5.5459 sec
tn = 247570, fp = 29982, fn = 32362, tp = 245541
y_pred: 0 = 279932 | 1 = 275523
y_true: 0 = 277552 | 1 = 277903
auc=0.9563|sensitivity=0.8835|specificity=0.8920|acc=0.8878|mcc=0.7756
precision=0.8912|recall=0.8835|f1=0.8873|aupr=0.9582


100%|████████████████████████████████████████████████████████████████████████████████| 136/136 [00:06<00:00, 19.92it/s]


Fold-0 ****Test  Epoch-1/25: Loss = 0.180448
tn = 66866, fp = 2741, fn = 6586, tp = 62670
y_pred: 0 = 73452 | 1 = 65411
y_true: 0 = 69607 | 1 = 69256
auc=0.9804|sensitivity=0.9049|specificity=0.9606|acc=0.9328|mcc=0.8670
precision=0.9581|recall=0.9049|f1=0.9307|aupr=0.9819
****Saving model: Best epoch = 1 | 5metrics_Best_avg = 0.9277
*****Path saver:  G:TransPHLA-AOMP/model/STformer/netmhcpan/st_layer1_multihead8_fold0_netmhcpan.pkl


100%|████████████████████████████████████████████████████████████████████████████████| 543/543 [00:55<00:00,  9.80it/s]


Fold-0****Train (Ep avg): Epoch-2/25 | Loss = 0.1679 | Time = 5.5942 sec
tn = 262559, fp = 14993, fn = 19681, tp = 258222
y_pred: 0 = 282240 | 1 = 273215
y_true: 0 = 277552 | 1 = 277903
auc=0.9817|sensitivity=0.9292|specificity=0.9460|acc=0.9376|mcc=0.8753
precision=0.9451|recall=0.9292|f1=0.9371|aupr=0.9831


100%|████████████████████████████████████████████████████████████████████████████████| 136/136 [00:07<00:00, 18.87it/s]


Fold-0 ****Test  Epoch-2/25: Loss = 0.158444
tn = 67079, fp = 2528, fn = 5632, tp = 63624
y_pred: 0 = 72711 | 1 = 66152
y_true: 0 = 69607 | 1 = 69256
auc=0.9843|sensitivity=0.9187|specificity=0.9637|acc=0.9412|mcc=0.8833
precision=0.9618|recall=0.9187|f1=0.9397|aupr=0.9855
****Saving model: Best epoch = 2 | 5metrics_Best_avg = 0.9372
*****Path saver:  G:TransPHLA-AOMP/model/STformer/netmhcpan/st_layer1_multihead8_fold0_netmhcpan.pkl


100%|████████████████████████████████████████████████████████████████████████████████| 543/543 [00:57<00:00,  9.46it/s]


Fold-0****Train (Ep avg): Epoch-3/25 | Loss = 0.1508 | Time = 5.6924 sec
tn = 264445, fp = 13107, fn = 17852, tp = 260051
y_pred: 0 = 282297 | 1 = 273158
y_true: 0 = 277552 | 1 = 277903
auc=0.9851|sensitivity=0.9358|specificity=0.9528|acc=0.9443|mcc=0.8887
precision=0.9520|recall=0.9358|f1=0.9438|aupr=0.9863


100%|████████████████████████████████████████████████████████████████████████████████| 136/136 [00:07<00:00, 19.41it/s]


Fold-0 ****Test  Epoch-3/25: Loss = 0.146688
tn = 66683, fp = 2924, fn = 4572, tp = 64684
y_pred: 0 = 71255 | 1 = 67608
y_true: 0 = 69607 | 1 = 69256
auc=0.9858|sensitivity=0.9340|specificity=0.9580|acc=0.9460|mcc=0.8923
precision=0.9568|recall=0.9340|f1=0.9452|aupr=0.9870
****Saving model: Best epoch = 3 | 5metrics_Best_avg = 0.9423
*****Path saver:  G:TransPHLA-AOMP/model/STformer/netmhcpan/st_layer1_multihead8_fold0_netmhcpan.pkl


100%|████████████████████████████████████████████████████████████████████████████████| 543/543 [00:57<00:00,  9.46it/s]


Fold-0****Train (Ep avg): Epoch-4/25 | Loss = 0.1417 | Time = 5.7554 sec
tn = 265287, fp = 12265, fn = 16797, tp = 261106
y_pred: 0 = 282084 | 1 = 273371
y_true: 0 = 277552 | 1 = 277903
auc=0.9868|sensitivity=0.9396|specificity=0.9558|acc=0.9477|mcc=0.8955
precision=0.9551|recall=0.9396|f1=0.9473|aupr=0.9879


100%|████████████████████████████████████████████████████████████████████████████████| 136/136 [00:07<00:00, 18.40it/s]


Fold-0 ****Test  Epoch-4/25: Loss = 0.142897
tn = 66144, fp = 3463, fn = 3838, tp = 65418
y_pred: 0 = 69982 | 1 = 68881
y_true: 0 = 69607 | 1 = 69256
auc=0.9867|sensitivity=0.9446|specificity=0.9502|acc=0.9474|mcc=0.8949
precision=0.9497|recall=0.9446|f1=0.9471|aupr=0.9877
****Saving model: Best epoch = 4 | 5metrics_Best_avg = 0.9440
*****Path saver:  G:TransPHLA-AOMP/model/STformer/netmhcpan/st_layer1_multihead8_fold0_netmhcpan.pkl


100%|████████████████████████████████████████████████████████████████████████████████| 543/543 [00:57<00:00,  9.48it/s]


Fold-0****Train (Ep avg): Epoch-5/25 | Loss = 0.1342 | Time = 5.6583 sec
tn = 266011, fp = 11541, fn = 15956, tp = 261947
y_pred: 0 = 281967 | 1 = 273488
y_true: 0 = 277552 | 1 = 277903
auc=0.9881|sensitivity=0.9426|specificity=0.9584|acc=0.9505|mcc=0.9011
precision=0.9578|recall=0.9426|f1=0.9501|aupr=0.9891


100%|████████████████████████████████████████████████████████████████████████████████| 136/136 [00:07<00:00, 18.68it/s]


Fold-0 ****Test  Epoch-5/25: Loss = 0.145129
tn = 66167, fp = 3440, fn = 4038, tp = 65218
y_pred: 0 = 70205 | 1 = 68658
y_true: 0 = 69607 | 1 = 69256
auc=0.9864|sensitivity=0.9417|specificity=0.9506|acc=0.9461|mcc=0.8923
precision=0.9499|recall=0.9417|f1=0.9458|aupr=0.9874


100%|████████████████████████████████████████████████████████████████████████████████| 543/543 [00:57<00:00,  9.47it/s]


Fold-0****Train (Ep avg): Epoch-6/25 | Loss = 0.1300 | Time = 5.6970 sec
tn = 266427, fp = 11125, fn = 15497, tp = 262406
y_pred: 0 = 281924 | 1 = 273531
y_true: 0 = 277552 | 1 = 277903
auc=0.9889|sensitivity=0.9442|specificity=0.9599|acc=0.9521|mcc=0.9043
precision=0.9593|recall=0.9442|f1=0.9517|aupr=0.9898


100%|████████████████████████████████████████████████████████████████████████████████| 136/136 [00:07<00:00, 18.83it/s]


Fold-0 ****Test  Epoch-6/25: Loss = 0.141861
tn = 66732, fp = 2875, fn = 4358, tp = 64898
y_pred: 0 = 71090 | 1 = 67773
y_true: 0 = 69607 | 1 = 69256
auc=0.9869|sensitivity=0.9371|specificity=0.9587|acc=0.9479|mcc=0.8960
precision=0.9576|recall=0.9371|f1=0.9472|aupr=0.9879
****Saving model: Best epoch = 6 | 5metrics_Best_avg = 0.9445
*****Path saver:  G:TransPHLA-AOMP/model/STformer/netmhcpan/st_layer1_multihead8_fold0_netmhcpan.pkl


100%|████████████████████████████████████████████████████████████████████████████████| 543/543 [00:57<00:00,  9.49it/s]


Fold-0****Train (Ep avg): Epoch-7/25 | Loss = 0.1253 | Time = 5.6758 sec
tn = 266748, fp = 10804, fn = 14950, tp = 262953
y_pred: 0 = 281698 | 1 = 273757
y_true: 0 = 277552 | 1 = 277903
auc=0.9896|sensitivity=0.9462|specificity=0.9611|acc=0.9536|mcc=0.9074
precision=0.9605|recall=0.9462|f1=0.9533|aupr=0.9905


100%|████████████████████████████████████████████████████████████████████████████████| 136/136 [00:07<00:00, 19.43it/s]


Fold-0 ****Test  Epoch-7/25: Loss = 0.140283
tn = 66795, fp = 2812, fn = 4286, tp = 64970
y_pred: 0 = 71081 | 1 = 67782
y_true: 0 = 69607 | 1 = 69256
auc=0.9872|sensitivity=0.9381|specificity=0.9596|acc=0.9489|mcc=0.8980
precision=0.9585|recall=0.9381|f1=0.9482|aupr=0.9881
****Saving model: Best epoch = 7 | 5metrics_Best_avg = 0.9456
*****Path saver:  G:TransPHLA-AOMP/model/STformer/netmhcpan/st_layer1_multihead8_fold0_netmhcpan.pkl


100%|████████████████████████████████████████████████████████████████████████████████| 543/543 [00:57<00:00,  9.52it/s]


Fold-0****Train (Ep avg): Epoch-8/25 | Loss = 0.1221 | Time = 5.7208 sec
tn = 267076, fp = 10476, fn = 14619, tp = 263284
y_pred: 0 = 281695 | 1 = 273760
y_true: 0 = 277552 | 1 = 277903
auc=0.9902|sensitivity=0.9474|specificity=0.9623|acc=0.9548|mcc=0.9097
precision=0.9617|recall=0.9474|f1=0.9545|aupr=0.9910


100%|████████████████████████████████████████████████████████████████████████████████| 136/136 [00:07<00:00, 18.69it/s]


Fold-0 ****Test  Epoch-8/25: Loss = 0.136898
tn = 66494, fp = 3113, fn = 3787, tp = 65469
y_pred: 0 = 70281 | 1 = 68582
y_true: 0 = 69607 | 1 = 69256
auc=0.9878|sensitivity=0.9453|specificity=0.9553|acc=0.9503|mcc=0.9007
precision=0.9546|recall=0.9453|f1=0.9499|aupr=0.9886
****Saving model: Best epoch = 8 | 5metrics_Best_avg = 0.9472
*****Path saver:  G:TransPHLA-AOMP/model/STformer/netmhcpan/st_layer1_multihead8_fold0_netmhcpan.pkl


100%|████████████████████████████████████████████████████████████████████████████████| 543/543 [00:57<00:00,  9.46it/s]


Fold-0****Train (Ep avg): Epoch-9/25 | Loss = 0.1185 | Time = 5.6671 sec
tn = 267306, fp = 10246, fn = 14170, tp = 263733
y_pred: 0 = 281476 | 1 = 273979
y_true: 0 = 277552 | 1 = 277903
auc=0.9908|sensitivity=0.9490|specificity=0.9631|acc=0.9560|mcc=0.9122
precision=0.9626|recall=0.9490|f1=0.9558|aupr=0.9915


100%|████████████████████████████████████████████████████████████████████████████████| 136/136 [00:07<00:00, 18.85it/s]


Fold-0 ****Test  Epoch-9/25: Loss = 0.139460
tn = 66875, fp = 2732, fn = 4165, tp = 65091
y_pred: 0 = 71040 | 1 = 67823
y_true: 0 = 69607 | 1 = 69256
auc=0.9877|sensitivity=0.9399|specificity=0.9608|acc=0.9503|mcc=0.9009
precision=0.9597|recall=0.9399|f1=0.9497|aupr=0.9886


100%|████████████████████████████████████████████████████████████████████████████████| 543/543 [00:56<00:00,  9.59it/s]


Fold-0****Train (Ep avg): Epoch-10/25 | Loss = 0.1153 | Time = 5.6505 sec
tn = 267537, fp = 10015, fn = 13781, tp = 264122
y_pred: 0 = 281318 | 1 = 274137
y_true: 0 = 277552 | 1 = 277903
auc=0.9913|sensitivity=0.9504|specificity=0.9639|acc=0.9572|mcc=0.9144
precision=0.9635|recall=0.9504|f1=0.9569|aupr=0.9920


100%|████████████████████████████████████████████████████████████████████████████████| 136/136 [00:06<00:00, 19.99it/s]


Fold-0 ****Test  Epoch-10/25: Loss = 0.143469
tn = 66997, fp = 2610, fn = 4394, tp = 64862
y_pred: 0 = 71391 | 1 = 67472
y_true: 0 = 69607 | 1 = 69256
auc=0.9875|sensitivity=0.9366|specificity=0.9625|acc=0.9496|mcc=0.8994
precision=0.9613|recall=0.9366|f1=0.9488|aupr=0.9883


100%|████████████████████████████████████████████████████████████████████████████████| 543/543 [00:53<00:00, 10.06it/s]


Fold-0****Train (Ep avg): Epoch-11/25 | Loss = 0.1109 | Time = 5.5015 sec
tn = 267879, fp = 9673, fn = 13263, tp = 264640
y_pred: 0 = 281142 | 1 = 274313
y_true: 0 = 277552 | 1 = 277903
auc=0.9919|sensitivity=0.9523|specificity=0.9651|acc=0.9587|mcc=0.9175
precision=0.9647|recall=0.9523|f1=0.9585|aupr=0.9926


100%|████████████████████████████████████████████████████████████████████████████████| 136/136 [00:06<00:00, 20.54it/s]


Fold-0 ****Test  Epoch-11/25: Loss = 0.142255
tn = 67040, fp = 2567, fn = 4350, tp = 64906
y_pred: 0 = 71390 | 1 = 67473
y_true: 0 = 69607 | 1 = 69256
auc=0.9877|sensitivity=0.9372|specificity=0.9631|acc=0.9502|mcc=0.9007
precision=0.9620|recall=0.9372|f1=0.9494|aupr=0.9884


100%|████████████████████████████████████████████████████████████████████████████████| 543/543 [00:53<00:00, 10.06it/s]


Fold-0****Train (Ep avg): Epoch-12/25 | Loss = 0.1083 | Time = 5.5350 sec
tn = 268028, fp = 9524, fn = 12980, tp = 264923
y_pred: 0 = 281008 | 1 = 274447
y_true: 0 = 277552 | 1 = 277903
auc=0.9923|sensitivity=0.9533|specificity=0.9657|acc=0.9595|mcc=0.9190
precision=0.9653|recall=0.9533|f1=0.9593|aupr=0.9929


100%|████████████████████████████████████████████████████████████████████████████████| 136/136 [00:06<00:00, 19.88it/s]


Fold-0 ****Test  Epoch-12/25: Loss = 0.144316
tn = 66721, fp = 2886, fn = 4058, tp = 65198
y_pred: 0 = 70779 | 1 = 68084
y_true: 0 = 69607 | 1 = 69256
auc=0.9875|sensitivity=0.9414|specificity=0.9585|acc=0.9500|mcc=0.9001
precision=0.9576|recall=0.9414|f1=0.9494|aupr=0.9882


100%|████████████████████████████████████████████████████████████████████████████████| 543/543 [00:54<00:00, 10.00it/s]


Fold-0****Train (Ep avg): Epoch-13/25 | Loss = 0.1057 | Time = 5.5714 sec
tn = 268234, fp = 9318, fn = 12660, tp = 265243
y_pred: 0 = 280894 | 1 = 274561
y_true: 0 = 277552 | 1 = 277903
auc=0.9927|sensitivity=0.9544|specificity=0.9664|acc=0.9604|mcc=0.9209
precision=0.9661|recall=0.9544|f1=0.9602|aupr=0.9933


100%|████████████████████████████████████████████████████████████████████████████████| 136/136 [00:06<00:00, 19.87it/s]


Fold-0 ****Test  Epoch-13/25: Loss = 0.147516
tn = 66550, fp = 3057, fn = 3952, tp = 65304
y_pred: 0 = 70502 | 1 = 68361
y_true: 0 = 69607 | 1 = 69256
auc=0.9874|sensitivity=0.9429|specificity=0.9561|acc=0.9495|mcc=0.8991
precision=0.9553|recall=0.9429|f1=0.9491|aupr=0.9881


100%|████████████████████████████████████████████████████████████████████████████████| 543/543 [00:54<00:00, 10.03it/s]


Fold-0****Train (Ep avg): Epoch-14/25 | Loss = 0.1034 | Time = 5.5838 sec
tn = 268349, fp = 9203, fn = 12456, tp = 265447
y_pred: 0 = 280805 | 1 = 274650
y_true: 0 = 277552 | 1 = 277903
auc=0.9930|sensitivity=0.9552|specificity=0.9668|acc=0.9610|mcc=0.9221
precision=0.9665|recall=0.9552|f1=0.9608|aupr=0.9935


100%|████████████████████████████████████████████████████████████████████████████████| 136/136 [00:06<00:00, 19.79it/s]


Fold-0 ****Test  Epoch-14/25: Loss = 0.148162
tn = 66010, fp = 3597, fn = 3340, tp = 65916
y_pred: 0 = 69350 | 1 = 69513
y_true: 0 = 69607 | 1 = 69256
auc=0.9876|sensitivity=0.9518|specificity=0.9483|acc=0.9500|mcc=0.9001
precision=0.9483|recall=0.9518|f1=0.9500|aupr=0.9883


  4%|███▍                                                                             | 23/543 [00:02<00:52,  9.92it/s]


KeyboardInterrupt: 

In [141]:
path_saver = '../model/st_layer1_multihead8_fold3_netmhcpan.pkl'
model = STSeqCls((21, 100), num_cls=2, hidden_size=300, num_layers=1, num_head=8, max_len=49,cls_hidden_size=600,dropout=0.1,head_dim=32).to(device)
model.load_state_dict(torch.load(path_saver))

<All keys matched successfully>

In [142]:
# model = model_eval
criterion = nn.CrossEntropyLoss()
fold = 3
ep_best = 7

model_eval = model.eval()
ys_res_independent, loss_res_independent_list, metrics_res_independent = eval_step(model_eval, independent_loader, fold, ep_best, epochs, use_cuda) # , independent_res_attns
ys_res_external, loss_res_external_list, metrics_res_external = eval_step(model_eval, external_loader, fold, ep_best, epochs, use_cuda) # , external_res_attns


100%|████████████████████████████████████████████████████████████████████████████████| 168/168 [00:08<00:00, 19.81it/s]


Fold-3 ****Test  Epoch-7/25: Loss = 0.367765
tn = 74424, fp = 11138, fn = 9313, tp = 76563
y_pred: 0 = 83737 | 1 = 87701
y_true: 0 = 85562 | 1 = 85876
auc=0.9444|sensitivity=0.8916|specificity=0.8698|acc=0.8807|mcc=0.7616
precision=0.8730|recall=0.8916|f1=0.8822|aupr=0.9448


100%|████████████████████████████████████████████████████████████████████████████████| 102/102 [00:05<00:00, 19.71it/s]


Fold-3 ****Test  Epoch-7/25: Loss = 0.336248
tn = 46839, fp = 5042, fn = 7528, tp = 44456
y_pred: 0 = 54367 | 1 = 49498
y_true: 0 = 51881 | 1 = 51984
auc=0.9410|sensitivity=0.8552|specificity=0.9028|acc=0.8790|mcc=0.7588
precision=0.8981|recall=0.8552|f1=0.8761|aupr=0.9472


In [143]:
class StarTransformer(nn.Module):
    r"""
    Star-Transformer 的encoder部分。 输入3d的文本输入, 返回相同长度的文本编码
    paper: https://arxiv.org/abs/1902.09113
    """

    def __init__(self, hidden_size, num_layers, num_head, head_dim, dropout=0.1, max_len=None):
        r"""
        
        :param int hidden_size: 输入维度的大小。同时也是输出维度的大小。
        :param int num_layers: star-transformer的层数
        :param int num_head: head的数量。
        :param int head_dim: 每个head的维度大小。
        :param float dropout: dropout 概率. Default: 0.1
        :param int max_len: int or None, 如果为int，输入序列的最大长度，
            模型会为输入序列加上position embedding。
            若为`None`，忽略加上position embedding的步骤. Default: `None`
        """
        super(StarTransformer, self).__init__()
        self.iters = num_layers

        self.norm = nn.ModuleList([nn.LayerNorm(hidden_size, eps=1e-6) for _ in range(self.iters)])
        # self.emb_fc = nn.Conv2d(hidden_size, hidden_size, 1)
        self.emb_drop = nn.Dropout(dropout)
        self.ring_att = nn.ModuleList(
            [_MSA1(hidden_size, nhead=num_head, head_dim=head_dim, dropout=0.0)
             for _ in range(self.iters)])
        self.star_att = nn.ModuleList(
            [_MSA2(hidden_size, nhead=num_head, head_dim=head_dim, dropout=0.0)
             for _ in range(self.iters)])

        if max_len is not None:
            self.pos_emb = nn.Embedding(max_len, hidden_size)
        else:
            self.pos_emb = None

    def forward(self, data, mask):
        r"""
        :param FloatTensor data: [batch, length, hidden] 输入的序列
        :param ByteTensor mask: [batch, length] 输入序列的padding mask, 在没有内容(padding 部分) 为 0,
            否则为 1
        :return: [batch, length, hidden] 编码后的输出序列
                [batch, hidden] 全局 relay 节点, 详见论文
        """

        def norm_func(f, x):
            # B, H, L, 1
            return f(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)

        B, L, H = data.size()
        mask = (mask.eq(False))  # flip the mask for masked_fill_
        smask = torch.cat([torch.zeros(B, 1, ).byte().to(mask), mask], 1)

        embs = data.permute(0, 2, 1)[:, :, :, None]  # B H L 1
        if self.pos_emb:
            P = self.pos_emb(torch.arange(L, dtype=torch.long, device=embs.device) \
                             .view(1, L)).permute(0, 2, 1).contiguous()[:, :, :, None]  # 1 H L 1
            embs = embs + P
        embs = norm_func(self.emb_drop, embs)
        nodes = embs
        relay = embs.mean(2, keepdim=True)
        ex_mask = mask[:, None, :, None].expand(B, H, L, 1)
        r_embs = embs.view(B, H, 1, L)
#         nodes_attns = []
        relays_attns = []
        for i in range(self.iters):
            ax = torch.cat([r_embs, relay.expand(B, H, 1, L)], 2)
#             nodes, nodes_att = self.ring_att[i](norm_func(self.norm[i], nodes), ax=ax)
#             nodes = F.leaky_relu(nodes)
            nodes = F.leaky_relu(self.ring_att[i](nodes, ax=ax))
            relay, relay_att = self.star_att[i](relay, torch.cat([relay, nodes], 2), smask)
            relay = F.leaky_relu(relay)
            relays_attns.append(relay_att)
#             nodes_attns.append(nodes_att)
            nodes = nodes.masked_fill_(ex_mask, 0)

        nodes = nodes.view(B, H, L).permute(0, 2, 1)

        return nodes, relay.view(B, H), relays_attns


class _MSA1(nn.Module):
    def __init__(self, nhid, nhead=10, head_dim=10, dropout=0.1):
        super(_MSA1, self).__init__()
        # Multi-head Self Attention Case 1, doing self-attention for small regions
        # Due to the architecture of GPU, using hadamard production and summation are faster than dot production when unfold_size is very small
        self.WQ = nn.Conv2d(nhid, nhead * head_dim, 1)
        self.WK = nn.Conv2d(nhid, nhead * head_dim, 1)
        self.WV = nn.Conv2d(nhid, nhead * head_dim, 1)
        self.WO = nn.Conv2d(nhead * head_dim, nhid, 1)

        self.drop = nn.Dropout(dropout)

        self.nhid, self.nhead, self.head_dim, self.unfold_size = nhid, nhead, head_dim, 3

    def forward(self, x, ax=None):
        # x: B, H, L, 1, ax : B, H, X, L append features
        nhid, nhead, head_dim, unfold_size = self.nhid, self.nhead, self.head_dim, self.unfold_size
        B, H, L, _ = x.shape
        q, k, v = self.WQ(x), self.WK(x), self.WV(x)  # x: (B,H,L,1)

        if ax is not None:
            aL = ax.shape[2]
            ak = self.WK(ax).view(B, nhead, head_dim, aL, L)
            av = self.WV(ax).view(B, nhead, head_dim, aL, L)
        q = q.view(B, nhead, head_dim, 1, L)
        k = F.unfold(k.view(B, nhead * head_dim, L, 1), (unfold_size, 1), padding=(unfold_size // 2, 0)) \
            .view(B, nhead, head_dim, unfold_size, L)
        v = F.unfold(v.view(B, nhead * head_dim, L, 1), (unfold_size, 1), padding=(unfold_size // 2, 0)) \
            .view(B, nhead, head_dim, unfold_size, L)
        if ax is not None:
            k = torch.cat([k, ak], 3)
            v = torch.cat([v, av], 3)
        alphas = self.drop(F.softmax((q * k).sum(2, keepdim=True) / np.sqrt(head_dim), 3))  # B N L 1 U
        #print('alphas shape',alphas.shape) #[1024, 8, 1, 5, 49]
        att = (alphas * v).sum(3).view(B, nhead * head_dim, L, 1)
        ret = self.WO(att)

        return ret


class _MSA2(nn.Module):
    def __init__(self, nhid, nhead=10, head_dim=10, dropout=0.1):
        # Multi-head Self Attention Case 2, a broadcastable query for a sequence key and value
        super(_MSA2, self).__init__()
        self.WQ = nn.Conv2d(nhid, nhead * head_dim, 1)
        self.WK = nn.Conv2d(nhid, nhead * head_dim, 1)
        self.WV = nn.Conv2d(nhid, nhead * head_dim, 1)
        self.WO = nn.Conv2d(nhead * head_dim, nhid, 1)

        self.drop = nn.Dropout(dropout)

        self.nhid, self.nhead, self.head_dim, self.unfold_size = nhid, nhead, head_dim, 3
    def forward(self, x, y, mask=None):
        # x: B, H, 1, 1, 1 y: B H L 1
        nhid, nhead, head_dim, unfold_size = self.nhid, self.nhead, self.head_dim, self.unfold_size
        B, H, L, _ = y.shape

        q, k, v = self.WQ(x), self.WK(y), self.WV(y)

        q = q.view(B, nhead, 1, head_dim)  # B, H, 1, 1 -> B, N, 1, h
        k = k.view(B, nhead, head_dim, L)  # B, H, L, 1 -> B, N, h, L
        v = v.view(B, nhead, head_dim, L).permute(0, 1, 3, 2)  # B, H, L, 1 -> B, N, L, h
        pre_a = torch.matmul(q, k) / np.sqrt(head_dim)
        
        if mask is not None:
            pre_a = pre_a.masked_fill(mask[:, None, None, :], -float('inf'))
        alphas = self.drop(F.softmax(pre_a, 3))  # B, N, 1, L
        att = torch.matmul(alphas, v).view(B, -1, 1, 1)  # B, N, 1, h -> B, N*h, 1, 1
        return self.WO(att) ,alphas
    
class StarTransEnc(nn.Module):
    r"""
    带word embedding的Star-Transformer Encoder
    """

    def __init__(self, embed,
                 hidden_size,
                 num_layers,
                 num_head,
                 head_dim,
                 max_len,
                 emb_dropout,
                 dropout):
        r"""
        
        :param embed: 单词词典, 可以是 tuple, 包括(num_embedings, embedding_dim), 即
            embedding的大小和每个词的维度. 也可以传入 nn.Embedding 对象,此时就以传入的对象作为embedding
        :param hidden_size: 模型中特征维度.
        :param num_layers: 模型层数.
        :param num_head: 模型中multi-head的head个数.
        :param head_dim: 模型中multi-head中每个head特征维度.
        :param max_len: 模型能接受的最大输入长度.
        :param emb_dropout: 词嵌入的dropout概率.
        :param dropout: 模型除词嵌入外的dropout概率.
        """
        super(StarTransEnc, self).__init__()
        self.embedding = get_embeddings(embed,padding_idx=0)
        emb_dim = self.embedding.embedding_dim
        self.emb_fc = nn.Linear(emb_dim, hidden_size)
        # self.emb_drop = nn.Dropout(emb_dropout)
        self.encoder = StarTransformer(hidden_size=hidden_size,
                                       num_layers=num_layers,
                                       num_head=num_head,
                                       head_dim=head_dim,
                                       dropout=dropout,
                                       max_len=max_len)

    def forward(self, x, mask):
        r"""
        :param FloatTensor x: [batch, length, hidden] 输入的序列
        :param ByteTensor mask: [batch, length] 输入序列的padding mask, 在没有内容(padding 部分) 为 0,
            否则为 1
        :return: [batch, length, hidden] 编码后的输出序列
                [batch, hidden] 全局 relay 节点, 详见论文
        """
        x = self.embedding(x)
        x = self.emb_fc(x)
        nodes, relay, relays_attns = self.encoder(x, mask)
        return nodes, relay, relays_attns


class _Cls(nn.Module):
    def __init__(self, in_dim, num_cls, hid_dim, dropout=0.1):
        super(_Cls, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(in_dim, hid_dim),
            nn.LeakyReLU(),
            nn.Dropout(dropout),
            nn.Linear(hid_dim, num_cls),
        )

    def forward(self, x):
        h = self.fc(x)
        return h
    
class STSeqCls(nn.Module):
    r"""
    用于分类任务的Star-Transformer
    """

    def __init__(self, embed, num_cls=2,
                 hidden_size=300,
                 num_layers=1,
                 num_head=9,
                 head_dim=32,
                 max_len=512,
                 cls_hidden_size=600,
                 emb_dropout=0.1,
                 dropout=0.1):
        r"""
        
        :param embed: 单词词典, 可以是 tuple, 包括(num_embedings, embedding_dim), 即
            embedding的大小和每个词的维度. 也可以传入 nn.Embedding 对象, 此时就以传入的对象作为embedding
        :param num_cls: 输出类别个数
        :param hidden_size: 模型中特征维度. Default: 300
        :param num_layers: 模型层数. Default: 4
        :param num_head: 模型中multi-head的head个数. Default: 8
        :param head_dim: 模型中multi-head中每个head特征维度. Default: 32
        :param max_len: 模型能接受的最大输入长度. Default: 512
        :param cls_hidden_size: 分类器隐层维度. Default: 600
        :param emb_dropout: 词嵌入的dropout概率. Default: 0.1
        :param dropout: 模型除词嵌入外的dropout概率. Default: 0.1
        """
        super(STSeqCls, self).__init__()
        self.enc = StarTransEnc(embed=embed,
                                hidden_size=hidden_size,
                                num_layers=num_layers,
                                num_head=num_head,
                                head_dim=head_dim,
                                max_len=max_len,
                                emb_dropout=emb_dropout,
                                dropout=dropout)
        self.cls = _Cls(hidden_size, num_cls, cls_hidden_size, dropout=dropout)

    def forward(self, words, seq_len):
        r"""
        :param words: [batch, seq_len] 输入序列
        :param seq_len: [batch,] 输入序列的长度
        :return output: [batch, num_cls] 输出序列的分类的概率
        """
        mask = seq_len_to_mask(seq_len,max_len=49).to(torch.device('cuda' if torch.cuda.is_available() else 'cpu'))
        nodes, relay, relays_attns = self.enc(words, mask)
        y = 0.5 * (relay + nodes.max(1)[0])
        output = self.cls(y)  # [bsz, n_cls]
        return output, relays_attns#, nodes_attns, relays_attns
    

                     

In [144]:
path_saver = '../model/st_layer1_multihead8_fold3_netmhcpan.pkl'
model_eval = STSeqCls((21, 100), num_cls=2, hidden_size=300, num_layers=1, num_head=8, max_len=49,cls_hidden_size=600,dropout=0.1,head_dim=32).to(device)
model_eval.load_state_dict(torch.load(path_saver))

<All keys matched successfully>

In [145]:
threshold = 0.5
def transfer(y_prob, threshold = 0.5):
    return np.array([[0, 1][x > threshold] for x in y_prob])
def eval_step_corrected(model, val_loader, use_cuda = False, save_ = False):
    device = torch.device("cuda" if use_cuda else "cpu")
    
    model.eval()
#     torch.manual_seed(19961231)
#     torch.cuda.manual_seed(19961231)
    with torch.no_grad():
        y_true_val_list, y_prob_val_list, dec_attns_val_list = [], [], []
        for train_pep_inputs, train_pep_lens, train_labels in tqdm(val_loader):
            '''
            pep_inputs: [batch_size, pep_len]
            hla_inputs: [batch_size, hla_len]
            train_outputs: [batch_size, 2]
            '''
            train_pep_inputs, train_labels = train_pep_inputs.to(device), train_labels.to(device)
            train_pep_lens = train_pep_lens.to(device)
#             t1 = time.time()
#             train_outputs = model(train_pep_inputs, train_pep_lens)
            val_outputs, val_dec_self_attns = model(train_pep_inputs, train_pep_lens)
            val_loss = criterion(val_outputs, train_labels)

            y_true_val = train_labels.cpu().numpy()
            y_prob_val = nn.Softmax(dim = 1)(val_outputs)[:, 1].cpu().detach().numpy()

            y_true_val_list.extend(y_true_val)
            y_prob_val_list.extend(y_prob_val)
            
            if save_:
#                 print(len(val_dec_self_attns))
#                 print(val_dec_self_attns[0].shape)
#                 dec_attns_val_list.extend(val_dec_self_attns[0][:, :, :, 34:]) # 只要（34,15）行HLA，列peptide
#                 print(val_dec_self_attns[0].shape)
                dec_attns_val_list.extend(val_dec_self_attns[0][:, :, :, 35:]) # 只要（34,15）行HLA，列peptide
                
#         assert (labels.numpy() == y_true_val_list).all()    
        y_pred_val_list = transfer(y_prob_val_list, threshold)
        ys_val = (y_true_val_list, y_pred_val_list, y_prob_val_list)
#         metrics_val = performances(y_true_val_list, y_pred_val_list, y_prob_val_list, print_ = True)
        
        if save_: 
            return ys_val, dec_attns_val_list
        else:
            return ys_val, None

In [146]:
train_data, train_pep_inputs, train_hla_inputs, train_pep_lens, train_labels, train_loader = data_with_loader(type_ = 'train', fold = 3,  batch_size = batch_size)

In [147]:
# type_ = 'all_corrected'

save_ = True
use_cuda = True
device = torch.device("cuda" if use_cuda else "cpu")

attns_dict = {}
criterion = nn.CrossEntropyLoss()
_, attn_res = eval_step_corrected(model_eval, train_loader, use_cuda, save_)


100%|████████████████████████████████████████████████████████████████████████████████| 543/543 [00:26<00:00, 20.15it/s]


In [148]:
attn_res[0][0]

tensor([[7.3125e-29, 0.0000e+00, 1.5869e-08, 4.8617e-27, 2.4951e-34, 4.4545e-33,
         1.0823e-38, 6.7814e-41, 0.0000e+00, 6.7683e-41, 2.3480e-36, 7.1038e-37,
         1.3530e-33, 1.0855e-31, 1.3035e-31]], device='cuda:0')

In [149]:
def attn_sumhead_peplength_pepposition(data, attn_data, label = None):
    SUM_length_head_dict = {}
    for l in range(8, 15):
        print('Length = ', str(l))
        SUM_length_head_dict[l] = []
        
        if label == None:
            length_index = np.array(data[data.length == l].index)
        elif label == 1:
            length_index = np.array(data[data.label == 1][data.length == l].index)
        elif label == 0:
            length_index = np.array(data[data.label == 0][data.length == l].index)
            
        length_data_num = len(length_index)
        print(length_data_num, length_index)

        for head in trange(8):
            idx_0 = length_index[0]
            temp_length_head = deepcopy(nn.Softmax(dim = -1)(attn_data[idx_0][head][:, :l].float())) # Shape = (34, length), 行是HLA，列是peptide，由行查列

            for idx in length_index[1:]:
                temp_length_head += nn.Softmax(dim = -1)(attn_data[idx][head][:, :l].float())

            temp_length_head = np.array(nn.Softmax(dim = -1)(temp_length_head.sum(axis = 0)).cpu()) # 把这一列的数据相加，shape = （length，）
            SUM_length_head_dict[l].append(temp_length_head)
            
    #############################
    SUM_length_head_sum = []
    for l in range(8, 15):
        print(l)
        temp = pd.DataFrame(SUM_length_head_dict[l], columns = range(1, l+1)).round(4)
        temp.loc['sum'] = temp.sum(axis = 0)
        SUM_length_head_sum.append(list(temp.loc['sum']))
        print(l, temp.loc['sum'].sort_values(ascending = False).index)
        
    return SUM_length_head_dict, SUM_length_head_sum

In [150]:
df_data = pd.read_csv('../data/train_set/NetMHCpan4.1/train_data_fold3.csv')
df_data['length'] = df_data['peptide'].map(lambda x:len(x))

In [151]:
df_data

Unnamed: 0,peptide,HLA,label,HLA_sequence,length
0,LSDKSLSIL,HLA-C16:01,1,YYAGYREKYRQTDVSNLYLWYDSYTWAAQAYTWY,9
1,LLPGGPGPSPEAE,HLA-B07:02,0,YYSEYRNIYAQTDESNLYLSYDYYTWAERAYEWY,13
2,IVTNGLRTTCAS,HLA-B27:09,0,YHTEYREICAKTDEDTLYLNYHHYTWAVLAYEWY,12
3,HPLGDIVAF,HLA-B35:08,1,YYATYRNIFTNTYESNLYIRYDSYTWAVRAYLWY,9
4,VPVEPVLTV,HLA-B51:01,1,YYATYRNIFTNTYENIAYWTYNYYTWAELAYLWH,9
...,...,...,...,...,...
555450,GKALPFGQNDLRQF,HLA-B18:01,0,YHSTYRNISTNTYESNLYLRYDSYTWAVLAYTWH,14
555451,QEQADSLERSL,HLA-B44:02,1,YYTKYREISTNTYENTAYIRYDDYTWAVDAYLSY,11
555452,MPKMDQDSL,HLA-B07:02,1,YYSEYRNIYAQTDESNLYLSYDYYTWAERAYEWY,9
555453,RVGDVYIPR,HLA-A03:01,1,YFAMYQENVAQTDVDTLYIIYRDYTWAELAYTWY,9


In [152]:
attn_res[100].shape

torch.Size([8, 1, 15])

In [153]:
# 正样本
positive_sum_peplength_pepposition, positive_sum_peplength_pepposition_headsum = attn_sumhead_peplength_pepposition(df_data, attn_res, label = 1)
positive_sum_peplength_pepposition_headsum
# 负样本
negative_sum_peplength_pepposition, negative_sum_peplength_pepposition_headsum = attn_sumhead_peplength_pepposition(df_data, attn_res, label = 0)
negative_sum_peplength_pepposition_headsum
# 全部样本
sum_peplength_pepposition, sum_peplength_pepposition_headsum = attn_sumhead_peplength_pepposition(df_data, attn_res)
sum_peplength_pepposition_headsum

Length =  8
19350 [    60     73    166 ... 555435 555441 555447]


100%|████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:06<00:00,  1.15it/s]


Length =  9
154097 [     0      3      4 ... 555448 555452 555453]


100%|████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:55<00:00,  6.92s/it]


Length =  10
50074 [     7     10     39 ... 555395 555404 555427]


100%|████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:18<00:00,  2.26s/it]


Length =  11
32804 [    43     45     61 ... 555386 555418 555451]


100%|████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:11<00:00,  1.47s/it]


Length =  12
11308 [    20     51     70 ... 555311 555349 555367]


100%|████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:04<00:00,  1.89it/s]


Length =  13
6308 [   244    448    592 ... 554741 554907 555333]


100%|████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:02<00:00,  3.53it/s]


Length =  14
3765 [    52    133    144 ... 554513 554678 555223]


100%|████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:01<00:00,  5.81it/s]


8
8 Int64Index([8, 4, 3, 2, 1, 6, 7, 5], dtype='int64')
9
9 Int64Index([9, 1, 2, 3, 4, 5, 6, 7, 8], dtype='int64')
10
10 Int64Index([10, 9, 7, 2, 3, 4, 1, 6, 5, 8], dtype='int64')
11
11 Int64Index([11, 2, 4, 7, 3, 1, 9, 10, 6, 5, 8], dtype='int64')
12
12 Int64Index([12, 2, 7, 3, 4, 1, 11, 9, 6, 10, 8, 5], dtype='int64')
13
13 Int64Index([13, 2, 3, 4, 7, 1, 11, 9, 12, 6, 10, 5, 8], dtype='int64')
14
14 Int64Index([14, 7, 4, 2, 3, 10, 1, 13, 11, 12, 9, 6, 5, 8], dtype='int64')
Length =  8
42840 [    19     28     42 ... 555431 555442 555445]


100%|████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:15<00:00,  1.94s/it]


Length =  9
42583 [    31     40     57 ... 555433 555438 555439]


100%|████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:15<00:00,  1.92s/it]


Length =  10
42676 [    38     44     46 ... 555422 555426 555428]


100%|████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:15<00:00,  1.93s/it]


Length =  11
42381 [    16     25     29 ... 555378 555394 555408]


100%|████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:15<00:00,  1.90s/it]


Length =  12
36832 [     2      6     11 ... 555423 555424 555425]


100%|████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:13<00:00,  1.68s/it]


Length =  13
36437 [     1     26     34 ... 555393 555412 555449]


100%|████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:13<00:00,  1.63s/it]


Length =  14
33098 [     5     15     21 ... 555436 555450 555454]


100%|████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:12<00:00,  1.50s/it]


8
8 Int64Index([8, 2, 3, 4, 7, 1, 5, 6], dtype='int64')
9
9 Int64Index([9, 2, 3, 4, 7, 1, 8, 5, 6], dtype='int64')
10
10 Int64Index([10, 1, 2, 3, 4, 7, 5, 6, 8, 9], dtype='int64')
11
11 Int64Index([11, 2, 7, 3, 4, 1, 9, 6, 5, 10, 8], dtype='int64')
12
12 Int64Index([12, 2, 3, 4, 7, 1, 11, 9, 10, 5, 6, 8], dtype='int64')
13
13 Int64Index([13, 4, 2, 3, 7, 1, 11, 12, 6, 9, 5, 8, 10], dtype='int64')
14
14 Int64Index([14, 12, 3, 4, 2, 7, 1, 13, 11, 9, 10, 6, 5, 8], dtype='int64')
Length =  8
62190 [    19     28     42 ... 555442 555445 555447]


100%|████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:22<00:00,  2.82s/it]


Length =  9
196680 [     0      3      4 ... 555448 555452 555453]


100%|████████████████████████████████████████████████████████████████████████████████████| 8/8 [01:10<00:00,  8.84s/it]


Length =  10
92750 [     7     10     38 ... 555426 555427 555428]


100%|████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:33<00:00,  4.21s/it]


Length =  11
75185 [    16     25     29 ... 555408 555418 555451]


100%|████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:27<00:00,  3.40s/it]


Length =  12
48140 [     2      6     11 ... 555423 555424 555425]


100%|████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:17<00:00,  2.19s/it]


Length =  13
42745 [     1     26     34 ... 555393 555412 555449]


100%|████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:15<00:00,  1.93s/it]


Length =  14
36863 [     5     15     21 ... 555436 555450 555454]


100%|████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:13<00:00,  1.65s/it]

8
8 Int64Index([8, 1, 2, 3, 4, 7, 5, 6], dtype='int64')
9
9 Int64Index([9, 1, 2, 3, 4, 5, 6, 7, 8], dtype='int64')
10
10 Int64Index([10, 1, 2, 3, 4, 7, 5, 6, 8, 9], dtype='int64')
11
11 Int64Index([11, 1, 2, 3, 4, 7, 5, 6, 8, 9, 10], dtype='int64')
12
12 Int64Index([12, 2, 3, 4, 7, 1, 11, 5, 6, 8, 9, 10], dtype='int64')
13
13 Int64Index([13, 2, 3, 4, 7, 1, 11, 5, 6, 8, 9, 10, 12], dtype='int64')
14
14 Int64Index([14, 12, 2, 3, 4, 7, 1, 13, 11, 5, 6, 8, 9, 10], dtype='int64')





[[1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 3.0],
 [1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 4.0],
 [1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 3.0],
 [1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 3.0],
 [0.9998, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0001, 3.0],
 [0.9982,
  1.0001,
  1.0001,
  1.0001,
  0.0001,
  0.0001,
  1.0001,
  0.0001,
  0.0001,
  0.0001,
  0.0003,
  0.0001,
  3.0004999999999997],
 [0.9957,
  1.0002,
  1.0002,
  1.0002,
  0.0002,
  0.0002,
  1.0002,
  0.0002,
  0.0002,
  0.0002,
  0.0006,
  1.0003,
  0.0012,
  2.0004]]

In [124]:
from plotly.subplots import make_subplots
import plotly.graph_objects as go

In [154]:
# 两个基本参数：设置行、列
fig = make_subplots(rows=1, cols=3,#start_cell="bottom-left", # 'bottom-left', 'top-left
                    subplot_titles=["All Samples","Positive Samples","Negative Samples"],shared_yaxes=True)  # 1行2列

# 添加两个数据轨迹，构成两个图形
fig.add_trace(
    go.Heatmap(z=pd.DataFrame(sum_peplength_pepposition_headsum),colorscale='teal',
                   x=['1', '2', '3', '4', '5','6','7','8','9','10','11', '12', '13', '14'],
                   y=['8', '9', '10','11', '12', '13','14'],coloraxis='coloraxis',
                   hoverongaps = False),
    row=1, col=1  # 第一行第一列
)

fig.add_trace(
    go.Heatmap(z=pd.DataFrame(positive_sum_peplength_pepposition_headsum),colorscale='teal',
                   x=['1', '2', '3', '4', '5','6','7','8','9','10','11', '12', '13', '14'],
                   y=['8', '9', '10','11', '12', '13','14'],coloraxis='coloraxis',
                   hoverongaps = False),
    row=1, col=2  # 第一行第二列
)

fig.add_trace(
    go.Heatmap(z=pd.DataFrame(negative_sum_peplength_pepposition_headsum),colorscale='teal',
                   x=['1', '2', '3', '4', '5','6','7','8','9','10','11', '12', '13', '14'],
                   y=['8', '9', '10','11', '12', '13','14'],coloraxis='coloraxis',
                   hoverongaps = False),
    row=1, col=3  # 第一行第二列
)
fig.update_layout(width=1000,height=350,
                      paper_bgcolor='rgba(0,0,0,0)',
    plot_bgcolor='rgba(0,0,0,0)',
        font=dict(
        family="black",
        size=13))
fig.update_xaxes(tickangle=0, tickfont=dict(family='black', size=12))
fig.update_xaxes(title='peptide position')
fig.update_yaxes(title_text="peptide length", row=1, col=1)
fig.update_coloraxes(colorbar_thickness=5)
fig.update_coloraxes(colorscale='teal')
# fig.update_yaxes(title='peptide length')
# # 设置图形的宽高和标题
# fig.update_layout(height=300, 
#                   width=800, 
#                   title_text="子图制作")
fig.show()

In [37]:
from plotly.offline import iplot

In [74]:
iplot(fig, image='svg', filename='fold3_new', image_width=1000, image_height=350)

In [75]:
def attn_HLA_length_aatype_position_num(data, attn_data, hla = 'HLA-A*11:01', label = None, length = 9, show_num = False):
    aatype_position = dict()
    if label == None:
        length_index = np.array(data[data.length == length][data.HLA == hla].index)
    else:
        length_index = np.array(data[data.length == length][data.HLA == hla][data.label == label].index)

    length_data_num = len(length_index)
    print(length_data_num)

    for head in trange(8):
        for idx in length_index:
            temp_peptide = data.iloc[idx].peptide
            temp_length_head = deepcopy(nn.Softmax(dim=-1)(attn_data[idx][head][:, :length].float())) # Shape = (34, 9), 行是HLA，列是peptide，由行查列
            temp_length_head = nn.Softmax(dim=-1)(temp_length_head.sum(axis = 0)) # 把这一列的数据相加，shape = （9，）

            for i, aa in enumerate(temp_peptide): 
                aatype_position.setdefault(aa, {})
                aatype_position[aa].setdefault(i, 0)
                aatype_position[aa][i] += temp_length_head[i] 
    
    if show_num:
        aatype_position_num = dict()
        for idx in length_index:
            temp_peptide = data.iloc[idx].peptide
            for i, aa in enumerate(temp_peptide):
                aatype_position_num.setdefault(aa, {})
                aatype_position_num[aa].setdefault(i, 0)
                aatype_position_num[aa][i] += 1
             
        return aatype_position, aatype_position_num
    else:
        return aatype_position
def attn_HLA_length_aatype_position_pd(HLA_length_aatype_position, length = 9, softmax = True, unsoftmax = True):
        
    HLA_length_aatype_position_pd = np.zeros((20, length))
    
    aai, aa_indexs = 0, []
    for aa, aa_posi in HLA_length_aatype_position.items():
        aa_indexs.append(aa)
        for posi, v in aa_posi.items():
            HLA_length_aatype_position_pd[aai, posi] = v
        aai += 1
    
    if len(aa_indexs) != 20: 
        aatype_sorts = list('YATVLDEGRHIWQKMFNSPC')
        abscent_aa = list(set(aatype_sorts).difference(set(aa_indexs)))
        aa_indexs += abscent_aa
    
    if softmax and not unsoftmax: 
        HLA_length_aatype_position_softmax_pd = deepcopy(nn.Softmax(dim = -1)(torch.Tensor(HLA_length_aatype_position_pd)))
        HLA_length_aatype_position_softmax_pd = np.array(HLA_length_aatype_position_softmax_pd)
        HLA_length_aatype_position_softmax_pd = pd.DataFrame(HLA_length_aatype_position_softmax_pd, 
                                                             index = aa_indexs, columns = range(1, length + 1))
        return HLA_length_aatype_position_softmax_pd
    
    elif unsoftmax and not softmax:
        HLA_length_aatype_position_unsoftmax_pd = pd.DataFrame(HLA_length_aatype_position_pd,
                                                               index = aa_indexs, columns = range(1, length + 1))
        return HLA_length_aatype_position_unsoftmax_pd
    
    elif softmax and unsoftmax:
        HLA_length_aatype_position_softmax_pd = deepcopy(nn.Softmax(dim = -1)(torch.Tensor(HLA_length_aatype_position_pd)))
        HLA_length_aatype_position_softmax_pd = np.array(HLA_length_aatype_position_softmax_pd)
        HLA_length_aatype_position_softmax_pd = pd.DataFrame(HLA_length_aatype_position_softmax_pd, 
                                                             index = aa_indexs, columns = range(1, length + 1))
        
        HLA_length_aatype_position_unsoftmax_pd = pd.DataFrame(HLA_length_aatype_position_pd,
                                                               index = aa_indexs, columns = range(1, length + 1))
        return HLA_length_aatype_position_softmax_pd, HLA_length_aatype_position_unsoftmax_pd
def draw_hla_length_aatype_position(data, attn_data, hla = 'HLA-B*27:05', label = None, length = 9, 
                                    show = True, softmax = True, unsoftmax = True):
    
    HLA_length_aatype_position = attn_HLA_length_aatype_position_num(data, attn_data, hla, label, length, show_num = False)
    print(HLA_length_aatype_position)
    
    if softmax and unsoftmax:
        HLA_length_aatype_position_softmax_pd, HLA_length_aatype_position_unsoftmax_pd = attn_HLA_length_aatype_position_pd(
                                                                                     HLA_length_aatype_position, 
                                                                                     length, 
                                                                                     softmax,
                                                                                     unsoftmax)
        HLA_length_aatype_position_softmax_pd = sort_aatype(HLA_length_aatype_position_softmax_pd)
        HLA_length_aatype_position_unsoftmax_pd = sort_aatype(HLA_length_aatype_position_unsoftmax_pd)
        
        if show:
            fig, axes = plt.subplots(nrows = 1, ncols = 2, figsize = (10, 8))
            sns.heatmap(HLA_length_aatype_position_softmax_pd,
                        ax = axes[0], cmap = 'YlGn', square = True)

            sns.heatmap(HLA_length_aatype_position_unsoftmax_pd,
                        ax = axes[1], cmap = 'YlGn', square = True)

            axes[0].set_title(hla + ' Softmax Normalization')
            axes[1].set_title(hla + ' UnNormalization')
            plt.show()

        return HLA_length_aatype_position_softmax_pd, HLA_length_aatype_position_unsoftmax_pd
    
    else:
        HLA_length_aatype_position_pd = attn_HLA_length_aatype_position_pd(HLA_length_aatype_position, 
                                                                           length, 
                                                                           softmax,
                                                                           unsoftmax)
        HLA_length_aatype_position_pd = sort_aatype(HLA_length_aatype_position_pd)
        return HLA_length_aatype_position_pd
def sort_aatype(df):
    aatype_sorts = list('YATVLDEGRHIWQKMFNSPC')
    df.reset_index(inplace = True)
    df['index'] = df['index'].astype('category')
    df['index'].cat.reorder_categories(aatype_sorts, inplace=True)
    df.sort_values('index', inplace=True)
    df.rename(columns = {'index':''}, inplace = True)
    df = df.set_index('')
    return df


In [76]:
A0101_length9_positive_aatype_position_unsoftmax_pd = draw_hla_length_aatype_position(df_data, attn_res, 'HLA-A01:01', label = 1, length = 9, show = False, softmax = False, unsoftmax = True)
A0201_length9_positive_aatype_position_unsoftmax_pd = draw_hla_length_aatype_position(df_data, attn_res, 'HLA-A02:01', label = 1, length = 9, show = False, softmax = False, unsoftmax = True)
A0301_length9_positive_aatype_position_unsoftmax_pd = draw_hla_length_aatype_position(df_data, attn_res, 'HLA-A03:01', label = 1, length = 9, show = False, softmax = False, unsoftmax = True)
B0702_length9_positive_aatype_position_unsoftmax_pd = draw_hla_length_aatype_position(df_data, attn_res, 'HLA-B07:02', label = 1, length = 9, show = False, softmax = False, unsoftmax = True)
B2705_length9_positive_aatype_position_unsoftmax_pd = draw_hla_length_aatype_position(df_data, attn_res, 'HLA-B27:05', label = 1, length = 9, show = False, softmax = False, unsoftmax = True)
B5701_length9_positive_aatype_position_unsoftmax_pd = draw_hla_length_aatype_position(df_data, attn_res, 'HLA-B57:01', label = 1, length = 9, show = False, softmax = False, unsoftmax = True)

2939


100%|████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:09<00:00,  1.13s/it]


{'L': {0: tensor(333.3163, device='cuda:0'), 1: tensor(362.3422, device='cuda:0'), 7: tensor(439.9067, device='cuda:0'), 6: tensor(584.8994, device='cuda:0'), 5: tensor(219.5494, device='cuda:0'), 4: tensor(238.2338, device='cuda:0'), 3: tensor(225.7533, device='cuda:0'), 2: tensor(43.4134, device='cuda:0'), 8: tensor(77.9455, device='cuda:0')}, 'S': {1: tensor(727.3334, device='cuda:0'), 4: tensor(185.3352, device='cuda:0'), 6: tensor(181.6656, device='cuda:0'), 5: tensor(197.5648, device='cuda:0'), 3: tensor(206.3036, device='cuda:0'), 0: tensor(168.1044, device='cuda:0'), 2: tensor(120.5066, device='cuda:0'), 7: tensor(190.1796, device='cuda:0'), 8: tensor(26.7682, device='cuda:0')}, 'E': {2: tensor(603.5454, device='cuda:0'), 7: tensor(172.6453, device='cuda:0'), 5: tensor(88.5871, device='cuda:0'), 0: tensor(131.7071, device='cuda:0'), 1: tensor(24.0152, device='cuda:0'), 6: tensor(230.4143, device='cuda:0'), 4: tensor(180.1410, device='cuda:0'), 3: tensor(223.5154, device='cuda:0

100%|████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:26<00:00,  3.35s/it]


{'A': {0: tensor(879.4124, device='cuda:0'), 5: tensor(295.3068, device='cuda:0'), 4: tensor(482.5713, device='cuda:0'), 7: tensor(593.9031, device='cuda:0'), 6: tensor(550.7026, device='cuda:0'), 2: tensor(869.0291, device='cuda:0'), 3: tensor(394.5589, device='cuda:0'), 8: tensor(647.1201, device='cuda:0'), 1: tensor(144.8120, device='cuda:0')}, 'L': {1: tensor(5363.2920, device='cuda:0'), 4: tensor(702.3273, device='cuda:0'), 6: tensor(735.6033, device='cuda:0'), 8: tensor(2936.6729, device='cuda:0'), 3: tensor(184.8746, device='cuda:0'), 2: tensor(1192.6862, device='cuda:0'), 0: tensor(605.6728, device='cuda:0'), 5: tensor(1459.5144, device='cuda:0'), 7: tensor(538.1249, device='cuda:0')}, 'S': {2: tensor(571.4554, device='cuda:0'), 5: tensor(429.6005, device='cuda:0'), 4: tensor(476.5929, device='cuda:0'), 0: tensor(899.6631, device='cuda:0'), 6: tensor(342.3687, device='cuda:0'), 7: tensor(803.3450, device='cuda:0'), 3: tensor(655.0251, device='cuda:0'), 1: tensor(69.3470, device

100%|████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:17<00:00,  2.16s/it]


{'F': {0: tensor(35.3646, device='cuda:0'), 7: tensor(180.4139, device='cuda:0'), 5: tensor(337.2780, device='cuda:0'), 2: tensor(489.8577, device='cuda:0'), 6: tensor(227.8229, device='cuda:0'), 3: tensor(114.4189, device='cuda:0'), 4: tensor(172.5659, device='cuda:0'), 1: tensor(21.6399, device='cuda:0'), 8: tensor(3.5784, device='cuda:0')}, 'I': {1: tensor(615.5153, device='cuda:0'), 6: tensor(523.2121, device='cuda:0'), 0: tensor(207.8266, device='cuda:0'), 5: tensor(459.6172, device='cuda:0'), 4: tensor(233.1544, device='cuda:0'), 2: tensor(128.2034, device='cuda:0'), 7: tensor(91.9593, device='cuda:0'), 3: tensor(110.1196, device='cuda:0'), 8: tensor(4.4222, device='cuda:0')}, 'N': {2: tensor(328.4175, device='cuda:0'), 3: tensor(215.6404, device='cuda:0'), 5: tensor(122.2688, device='cuda:0'), 4: tensor(212.1309, device='cuda:0'), 6: tensor(103.1891, device='cuda:0'), 7: tensor(305.1306, device='cuda:0'), 1: tensor(12.3995, device='cuda:0'), 0: tensor(28.3030, device='cuda:0'), 

100%|████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:19<00:00,  2.47s/it]


{'G': {0: tensor(264.5667, device='cuda:0'), 3: tensor(483.0464, device='cuda:0'), 7: tensor(276.9437, device='cuda:0'), 4: tensor(418.7203, device='cuda:0'), 5: tensor(639.7430, device='cuda:0'), 6: tensor(263.5501, device='cuda:0'), 2: tensor(228.9206, device='cuda:0'), 1: tensor(21.4868, device='cuda:0'), 8: tensor(9.9143, device='cuda:0')}, 'P': {1: tensor(4909.9604, device='cuda:0'), 3: tensor(860.9327, device='cuda:0'), 6: tensor(514.1602, device='cuda:0'), 4: tensor(1016.9318, device='cuda:0'), 5: tensor(489.6921, device='cuda:0'), 0: tensor(39.8013, device='cuda:0'), 7: tensor(272.1208, device='cuda:0'), 8: tensor(81.1988, device='cuda:0'), 2: tensor(113.1826, device='cuda:0')}, 'R': {2: tensor(1235.1449, device='cuda:0'), 5: tensor(742.2578, device='cuda:0'), 4: tensor(508.4388, device='cuda:0'), 6: tensor(300.2628, device='cuda:0'), 0: tensor(842.3100, device='cuda:0'), 3: tensor(233.8139, device='cuda:0'), 7: tensor(172.1710, device='cuda:0'), 8: tensor(20.6929, device='cuda

100%|████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:11<00:00,  1.49s/it]


{'G': {0: tensor(527.3246, device='cuda:0'), 5: tensor(142.3102, device='cuda:0'), 7: tensor(283.7634, device='cuda:0'), 3: tensor(290.6187, device='cuda:0'), 4: tensor(230.5927, device='cuda:0'), 2: tensor(63.6170, device='cuda:0'), 6: tensor(30.2774, device='cuda:0'), 1: tensor(1.7694, device='cuda:0'), 8: tensor(0.9199, device='cuda:0')}, 'R': {1: tensor(3393.9658, device='cuda:0'), 3: tensor(160.0770, device='cuda:0'), 5: tensor(184.3609, device='cuda:0'), 4: tensor(128.7104, device='cuda:0'), 0: tensor(606.1377, device='cuda:0'), 8: tensor(436.1251, device='cuda:0'), 6: tensor(71.7368, device='cuda:0'), 2: tensor(58.5739, device='cuda:0'), 7: tensor(235.5918, device='cuda:0')}, 'V': {2: tensor(239.8559, device='cuda:0'), 6: tensor(382.0424, device='cuda:0'), 8: tensor(197.8715, device='cuda:0'), 4: tensor(357.2553, device='cuda:0'), 5: tensor(367.1438, device='cuda:0'), 7: tensor(185.6136, device='cuda:0'), 0: tensor(90.1147, device='cuda:0'), 3: tensor(130.3526, device='cuda:0'),

100%|████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:14<00:00,  1.82s/it]

{'L': {0: tensor(527.4846, device='cuda:0'), 2: tensor(671.2939, device='cuda:0'), 8: tensor(289.6098, device='cuda:0'), 6: tensor(350.6541, device='cuda:0'), 5: tensor(639.7095, device='cuda:0'), 4: tensor(662.9524, device='cuda:0'), 1: tensor(193.2834, device='cuda:0'), 3: tensor(303.2322, device='cuda:0'), 7: tensor(635.4509, device='cuda:0')}, 'S': {1: tensor(1376.2479, device='cuda:0'), 5: tensor(207.1985, device='cuda:0'), 3: tensor(245.6723, device='cuda:0'), 7: tensor(378.4660, device='cuda:0'), 2: tensor(240.9603, device='cuda:0'), 0: tensor(253.8992, device='cuda:0'), 6: tensor(206.5729, device='cuda:0'), 4: tensor(162.1341, device='cuda:0'), 8: tensor(0.8855, device='cuda:0')}, 'F': {2: tensor(310.0356, device='cuda:0'), 3: tensor(112.8059, device='cuda:0'), 8: tensor(1231.5442, device='cuda:0'), 4: tensor(225.3648, device='cuda:0'), 6: tensor(247.9592, device='cuda:0'), 5: tensor(198.2885, device='cuda:0'), 7: tensor(216.7805, device='cuda:0'), 0: tensor(151.8981, device='c




In [77]:
# 两个基本参数：设置行、列
fig = make_subplots(rows=1, cols=6,horizontal_spacing=0.02,x_title='Peptide position',#start_cell="bottom-left", # 'bottom-left', 'top-left
                    subplot_titles=["HLA-A01:01","HLA-A02:01","HLA-A03:01",'HLA-B07:02','HLA-B27:05','HLA-B57:01'],shared_yaxes=True,shared_xaxes=True)  # 1行2列

# 添加两个数据轨迹，构成两个图形
fig.add_trace(
    go.Heatmap(z=pd.DataFrame(A0101_length9_positive_aatype_position_unsoftmax_pd),colorscale='teal',
                   x=['1', '2', '3', '4', '5','6','7','8','9'],
                   y=['Y', 'A', 'T', 'V', 'L','D','E','G','R', 'H','I', 'W', 'Q','K', 'M','F', 'N', 'S','P','C'],coloraxis='coloraxis',
                   hoverongaps = False),
    row=1, col=1  # 第一行第一列
)

fig.add_trace(
    go.Heatmap(z=pd.DataFrame(A0201_length9_positive_aatype_position_unsoftmax_pd),colorscale='teal',
                   x=['1', '2', '3', '4', '5','6','7','8','9'],
                   y=['Y', 'A', 'T', 'V', 'L','D','E','G','R', 'H','I', 'W', 'Q','K', 'M','F', 'N', 'S','P','C'],coloraxis='coloraxis',
                   hoverongaps = False),
    row=1, col=2  # 第一行第二列
)

fig.add_trace(
    go.Heatmap(z=pd.DataFrame(A0301_length9_positive_aatype_position_unsoftmax_pd),colorscale='teal',
                    x=['1', '2', '3', '4', '5','6','7','8','9'],
                   y=['Y', 'A', 'T', 'V', 'L','D','E','G','R', 'H','I', 'W', 'Q','K', 'M','F', 'N', 'S','P','C'],coloraxis='coloraxis',
                   hoverongaps = False),
    row=1, col=3  # 第一行第二列
)
fig.add_trace(
    go.Heatmap(z=pd.DataFrame(B0702_length9_positive_aatype_position_unsoftmax_pd),colorscale='teal',
                    x=['1', '2', '3', '4', '5','6','7','8','9'],
                   y=['Y', 'A', 'T', 'V', 'L','D','E','G','R', 'H','I', 'W', 'Q','K', 'M','F', 'N', 'S','P','C'],coloraxis='coloraxis',
                   hoverongaps = False),
    row=1, col=4  # 第一行第二列
)

fig.add_trace(
    go.Heatmap(z=pd.DataFrame(B2705_length9_positive_aatype_position_unsoftmax_pd),colorscale='teal',
                    x=['1', '2', '3', '4', '5','6','7','8','9'],
                   y=['Y', 'A', 'T', 'V', 'L','D','E','G','R', 'H','I', 'W', 'Q','K', 'M','F', 'N', 'S','P','C'],coloraxis='coloraxis',
                   hoverongaps = False),
    row=1, col=5  # 第一行第二列
)

fig.add_trace(
    go.Heatmap(z=pd.DataFrame(B5701_length9_positive_aatype_position_unsoftmax_pd),colorscale='teal',
                    x=['1', '2', '3', '4', '5','6','7','8','9'],
                   y=['Y', 'A', 'T', 'V', 'L','D','E','G','R', 'H','I', 'W', 'Q','K', 'M','F', 'N', 'S','P','C'],coloraxis='coloraxis',
                   hoverongaps = False),
    row=1, col=6  # 第一行第二列
)
fig.update_layout(width=1000,height=500,
                      paper_bgcolor='rgba(0,0,0,0)',
    plot_bgcolor='rgba(0,0,0,0)',
        font=dict(
        family="black",
        size=13))
fig.update_xaxes(tickangle=0, tickfont=dict(family='black', size=12))
# fig.update_xaxes(title='Peptide position', row=1, col=3)
fig.update_yaxes(title_text="Amino acid type", row=1, col=1)
fig.update_coloraxes(colorbar_thickness=8)
fig.update_coloraxes(colorscale='teal')
# fig.update_yaxes(title='peptide length')
# # 设置图形的宽高和标题
# fig.update_layout(height=300, 
#                   width=800, 
#                   title_text="子图制作")
fig.show()

In [78]:
iplot(fig, image='svg', filename='fold3_other_new', image_width=1000, image_height=500)