In [None]:
!pip install plotly

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import plotly.graph_objects as go
import plotly.express as px
from plotly.offline import iplot


## plot

In [None]:
# Per-allele scores read from netmhcpan41_test_data.csv, sorted by allele name.
compare = pd.read_csv('../pic/netmhcpan41_test_data.csv')
compare = compare.sort_values(by='allele')
compare_A = compare[:]

# (CSV column, legend label) pairs, in the order the methods are stacked.
method_columns = [
    ('st', 'Star-Transformer'),
    ('st_lstm', 'Star-Transformer-LSTM'),
    ('st_cnn', 'Star-Transformer-CNN'),
    ('transformer', 'Transformer Encoder'),
]
alleles = compare_A['allele'].to_list()

# Long format: one row per (allele, method). Each frame carries its own method label,
# so the result no longer depends on the table having exactly 36 alleles.
df_cat = pd.concat([
    pd.DataFrame({'HLA': alleles,
                  'AUC': compare_A[column].to_list(),
                  'Method': method_label})
    for column, method_label in method_columns
])
# Facet key built from the first character of the allele name.
df_cat['Type'] = df_cat['HLA'].map(lambda x: 'HLA-' + x[0])


fig = px.scatter(df_cat,  # data set
                 x="HLA",  # x axis
                 y="AUC",  # y axis
                 color="Method",  # colour per method
                 symbol='Method',
                 facet_col="Type",
                 color_discrete_sequence=px.colors.qualitative.Bold,
                 facet_col_spacing=0.01
                 )
fig.update_xaxes(tickangle=90, tickfont=dict(family='black', size=12))
# Each facet keeps its own x categories instead of sharing one axis.
fig.update_xaxes(matches=None)
# Facet titles come as "Type=HLA-A"; keep only the part after "=".
fig.for_each_annotation(lambda a: a.update(text=a.text.split("=")[-1]))

# Blank the title of every x axis.
fig.update_xaxes(title_text='')

fig.update_traces(marker=dict(size=12),
                  selector=dict(mode='markers'))
fig.update_xaxes(showline=True, linewidth=1, linecolor='black', mirror=False)
fig.update_yaxes(showline=True, linewidth=1, linecolor='black', mirror=False)

fig.update_layout(width=1000, height=500,
                  paper_bgcolor='rgba(0,0,0,0)',
                  plot_bgcolor='rgba(0,0,0,0)',
                  font=dict(
                      family="black",
                      size=18))
fig.update_layout(
    legend=dict(
        title_font_family="black",  # legend title font
        font=dict(  # legend entry font
            family="black",
            size=15,
        ),
        itemwidth=30,
        tracegroupgap=0,
    )
)
fig.update_layout(legend=dict(
    orientation="h",  # horizontal legend
    yanchor="bottom",
    y=1.12,
    xanchor="right",
    x=1
))
fig.update_layout(showlegend=True,  # keep the legend visible (True is also the default)
                  legend_title_text=''  # no legend title
                  )
fig.show()

In [None]:
# Load the external-set comparison table; the scatter plot in the next cell reads its
# nt_f1, st_f1 and HLA columns.
data = pd.read_csv('../pic/external_set_compare.csv')

In [None]:
# Scatter of st_f1 against nt_f1 from `data`, one colour per HLA value.
fig = px.scatter(
    data,
    x="nt_f1",
    y="st_f1",
    color="HLA",
    color_discrete_sequence=px.colors.qualitative.Dark24,
)

# Dashed grey diagonal from (0.5, 0.5) to (1, 1) as the equal-score reference.
fig.add_trace(
    go.Scatter(
        x=[0.5, 1],
        y=[0.5, 1],
        mode="lines",
        line=dict(dash='dash', color="gray"),
        showlegend=False,
    )
)

# Enlarge the point markers only (the diagonal is a 'lines' trace), then hide all legend entries.
fig.update_traces(marker=dict(size=12), selector=dict(mode='markers'))
fig.update_traces(showlegend=False)

# Black axis lines on the left/bottom only; both axes span the same [0.5, 1] range.
axis_style = dict(showline=True, linewidth=1, linecolor='black', mirror=False, range=[0.5, 1])
fig.update_xaxes(**axis_style)
fig.update_yaxes(**axis_style)

# Square transparent canvas, titles and fonts.
fig.update_layout(
    width=500,
    height=500,
    paper_bgcolor='rgba(0,0,0,0)',
    plot_bgcolor='rgba(0,0,0,0)',
    title="F1",
    xaxis_title="NetMHCpan4.1",
    yaxis_title="Star-Transformer",
    font=dict(family="black", size=18),
)
fig.show()

In [None]:
# Show the current figure through plotly.offline.iplot, requesting an SVG image named
# 'external_set_f1' at 500x500.
# NOTE(review): presumably this triggers a browser download of the SVG — confirm.
iplot(fig, image='svg', filename='external_set_f1', image_width=500, image_height=500)

In [None]:
# Pool of 34 hex colours; the bar-chart code below slices it to colour the y tick labels.
# NOTE(review): the first 24 entries look like Plotly's qualitative Dark24 palette and the
# last 10 like G10 — confirm before replacing the literal with px.colors references.
color_bar = ['#2E91E5',
 '#E15F99',
 '#1CA71C',
 '#FB0D0D',
 '#DA16FF',
 '#222A2A',
 '#B68100',
 '#750D86',
 '#EB663B',
 '#511CFB',
 '#00A08B',
 '#FB00D1',
 '#FC0080',
 '#B2828D',
 '#6C7C32',
 '#778AAE',
 '#862A16',
 '#A777F1',
 '#620042',
 '#1616A7',
 '#DA60CA',
 '#6C4516',
 '#0D2A63',
 '#AF0038',
'#3366CC',
 '#DC3912',
 '#FF9900',
 '#109618',
 '#990099',
 '#0099C6',
 '#DD4477',
 '#66AA00',
 '#B82E2E',
 '#316395']
def color(color, text):
    """Wrap ``text`` in an HTML ``<span>`` whose CSS colour is ``color``.

    Both arguments are converted with ``str``; the text is padded with one
    space on each side inside the span.
    """
    return f"<span style='color:{color!s}'> {text!s} </span>"


# Label colours: 15 entries of color_bar, reused cyclically when there are more categories
# than colours so that every tick still gets a coloured label.
colors = color_bar[19:34]
# NOTE(review): y_list, x_list_pos and x_list_neg are not defined in any visible cell;
# they must be created by an earlier cell before this one runs.
ticks = y_list
keys = {tick: colors[i % len(colors)] for i, tick in enumerate(ticks)}

fig = go.Figure()

# Horizontal bars: one per entry of y_list, positives first, negatives stacked after them.
fig.add_trace(go.Bar(
    y=y_list,
    x=x_list_pos,
    name='#Positives',
    orientation='h',
    marker=dict(
        color='rgba(246, 78, 139, 0.6)',
        line=dict(color='rgba(246, 78, 139, 1.0)', width=3)
    )
))
fig.add_trace(go.Bar(
    y=y_list,
    x=x_list_neg,
    name='#Negatives',
    orientation='h',
    marker=dict(
        color='rgba(58, 71, 80, 0.6)',
        line=dict(color='rgba(58, 71, 80, 1.0)', width=3)
    )
))

fig.update_layout(
    width=600, height=600,
    paper_bgcolor='rgba(0,0,0,0)',
    plot_bgcolor='rgba(0,0,0,0)',
    font=dict(family="black", size=18),
    legend=dict(
        orientation="h",  # horizontal legend
        yanchor="bottom",
        y=1.02,
        xanchor="right",
        x=0.9,
    ),
    barmode='stack',
    xaxis_type="log",
)
fig.update_xaxes(showline=True, linewidth=1, linecolor='black', mirror=False)
fig.update_yaxes(showline=True, linewidth=1, linecolor='black', mirror=False)

# Coloured HTML tick labels; tickvals is taken from the same mapping so both lists always
# have the same length and order.
ticktext = [color(v, k) for k, v in keys.items()]
fig.update_layout(
    yaxis=dict(tickmode='array', ticktext=ticktext, tickvals=list(keys))
)
fig.show()

## Attention plots (adapted from TransPHLA)

In [None]:
class StarTransformer(nn.Module):
    r"""
    Encoder part of the Star-Transformer. Takes a 3-d sequence input and returns an
    encoding of the same length, plus the relay node and the relay attention weights.
    paper: https://arxiv.org/abs/1902.09113
    """

    def __init__(self, hidden_size, num_layers, num_head, head_dim, dropout=0.1, max_len=None):
        r"""
        :param int hidden_size: size of the input features; also the size of the output features.
        :param int num_layers: number of Star-Transformer layers (iterations).
        :param int num_head: number of attention heads.
        :param int head_dim: feature size of each head.
        :param float dropout: dropout probability applied to the input embeddings
            (the attention modules themselves are built with dropout=0.0). Default: 0.1
        :param int max_len: int or None. If an int, it is the size of the position-embedding
            table that is added to the input. If `None`, no position embedding is added.
            Default: `None`
        """
        super(StarTransformer, self).__init__()
        self.iters = num_layers

        self.norm = nn.ModuleList([nn.LayerNorm(hidden_size, eps=1e-6) for _ in range(self.iters)])
        # self.emb_fc = nn.Conv2d(hidden_size, hidden_size, 1)
        self.emb_drop = nn.Dropout(dropout)
        self.ring_att = nn.ModuleList(
            [_MSA1(hidden_size, nhead=num_head, head_dim=head_dim, dropout=0.0)
             for _ in range(self.iters)])
        self.star_att = nn.ModuleList(
            [_MSA2(hidden_size, nhead=num_head, head_dim=head_dim, dropout=0.0)
             for _ in range(self.iters)])

        if max_len is not None:
            self.pos_emb = nn.Embedding(max_len, hidden_size)
        else:
            self.pos_emb = None

    def forward(self, data, mask):
        r"""
        :param FloatTensor data: [batch, length, hidden] input sequence
        :param ByteTensor mask: [batch, length] padding mask of the input sequence,
            0 where there is no content (padding), 1 otherwise
        :return: [batch, length, hidden] encoded output sequence,
                [batch, hidden] global relay node (see the paper),
                list with one relay-attention tensor per layer, as returned by `_MSA2`
        """

        def norm_func(f, x):
            # B, H, L, 1
            return f(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)

        B, L, H = data.size()
        mask = (mask.eq(False))  # flip the mask for masked_fill_
        # Prepend one never-masked column for the relay node, which is concatenated in front of the nodes below.
        smask = torch.cat([torch.zeros(B, 1, ).byte().to(mask), mask], 1)

        embs = data.permute(0, 2, 1)[:, :, :, None]  # B H L 1
        if self.pos_emb:
            P = self.pos_emb(torch.arange(L, dtype=torch.long, device=embs.device) \
                             .view(1, L)).permute(0, 2, 1).contiguous()[:, :, :, None]  # 1 H L 1
            embs = embs + P
        embs = norm_func(self.emb_drop, embs)
        nodes = embs
        # The relay node starts as the mean of the embeddings over the length axis.
        relay = embs.mean(2, keepdim=True)
        ex_mask = mask[:, None, :, None].expand(B, H, L, 1)
        r_embs = embs.view(B, H, 1, L)
#         nodes_attns = []
        relays_attns = []
        for i in range(self.iters):
            # Extra keys/values for the ring attention: the original embeddings and the current relay.
            ax = torch.cat([r_embs, relay.expand(B, H, 1, L)], 2)
            nodes, nodes_att = self.ring_att[i](norm_func(self.norm[i], nodes), ax=ax)
            nodes = F.leaky_relu(nodes)
            # nodes = F.leaky_relu(self.ring_att[i](nodes, ax=ax))
            relay, relay_att = self.star_att[i](relay, torch.cat([relay, nodes], 2), smask)
            relay = F.leaky_relu(relay)
            # Only the relay (star) attention is collected; the ring attention is discarded.
            relays_attns.append(relay_att)
#             nodes_attns.append(nodes_att)
            # Zero the padded positions in place.
            nodes = nodes.masked_fill_(ex_mask, 0)

        nodes = nodes.view(B, H, L).permute(0, 2, 1)

        return nodes, relay.view(B, H), relays_attns


class _MSA1(nn.Module):
    def __init__(self, nhid, nhead=10, head_dim=10, dropout=0.1):
        super(_MSA1, self).__init__()
        # Multi-head Self Attention Case 1, doing self-attention for small regions
        # Due to the architecture of GPU, using hadamard production and summation are faster than dot production when unfold_size is very small
        self.WQ = nn.Conv2d(nhid, nhead * head_dim, 1)
        self.WK = nn.Conv2d(nhid, nhead * head_dim, 1)
        self.WV = nn.Conv2d(nhid, nhead * head_dim, 1)
        self.WO = nn.Conv2d(nhead * head_dim, nhid, 1)

        self.drop = nn.Dropout(dropout)

        self.nhid, self.nhead, self.head_dim, self.unfold_size = nhid, nhead, head_dim, 3
        self.attn = None

    def forward(self, x, ax=None):
        # x: B, H, L, 1, ax : B, H, X, L append features
        nhid, nhead, head_dim, unfold_size = self.nhid, self.nhead, self.head_dim, self.unfold_size
        B, H, L, _ = x.shape
        q, k, v = self.WQ(x), self.WK(x), self.WV(x)  # x: (B,H,L,1)

        if ax is not None:
            aL = ax.shape[2]
            ak = self.WK(ax).view(B, nhead, head_dim, aL, L)
            av = self.WV(ax).view(B, nhead, head_dim, aL, L)
        q = q.view(B, nhead, head_dim, 1, L)
        k = F.unfold(k.view(B, nhead * head_dim, L, 1), (unfold_size, 1), padding=(unfold_size // 2, 0)) \
            .view(B, nhead, head_dim, unfold_size, L)
        v = F.unfold(v.view(B, nhead * head_dim, L, 1), (unfold_size, 1), padding=(unfold_size // 2, 0)) \
            .view(B, nhead, head_dim, unfold_size, L)
        if ax is not None:
            k = torch.cat([k, ak], 3)
            v = torch.cat([v, av], 3)
        alphas = self.drop(F.softmax((q * k).sum(2, keepdim=True) / np.sqrt(head_dim), 3))  # B N L 1 U
        print('alphas shape',alphas.shape) #[1024, 8, 1, 5, 49]
        att = (alphas * v).sum(3).view(B, nhead * head_dim, L, 1)
        ret = self.WO(att)

        return ret ,alphas


class _MSA2(nn.Module):
    def __init__(self, nhid, nhead=10, head_dim=10, dropout=0.1):
        # Multi-head Self Attention Case 2, a broadcastable query for a sequence key and value
        super(_MSA2, self).__init__()
        self.WQ = nn.Conv2d(nhid, nhead * head_dim, 1)
        self.WK = nn.Conv2d(nhid, nhead * head_dim, 1)
        self.WV = nn.Conv2d(nhid, nhead * head_dim, 1)
        self.WO = nn.Conv2d(nhead * head_dim, nhid, 1)

        self.drop = nn.Dropout(dropout)

        self.nhid, self.nhead, self.head_dim, self.unfold_size = nhid, nhead, head_dim, 3
    def forward(self, x, y, mask=None):
        # x: B, H, 1, 1, 1 y: B H L 1
        nhid, nhead, head_dim, unfold_size = self.nhid, self.nhead, self.head_dim, self.unfold_size
        B, H, L, _ = y.shape

        q, k, v = self.WQ(x), self.WK(y), self.WV(y)

        q = q.view(B, nhead, 1, head_dim)  # B, H, 1, 1 -> B, N, 1, h
        k = k.view(B, nhead, head_dim, L)  # B, H, L, 1 -> B, N, h, L
        v = v.view(B, nhead, head_dim, L).permute(0, 1, 3, 2)  # B, H, L, 1 -> B, N, L, h
        pre_a = torch.matmul(q, k) / np.sqrt(head_dim)
        
        if mask is not None:
            pre_a = pre_a.masked_fill(mask[:, None, None, :], -float('inf'))
        alphas = self.drop(F.softmax(pre_a, 3))  # B, N, 1, L
        att = torch.matmul(alphas, v).view(B, -1, 1, 1)  # B, N, 1, h -> B, N*h, 1, 1
        return self.WO(att) ,alphas
    
class StarTransEnc(nn.Module):
    r"""
    Star-Transformer encoder with a token-embedding front end.
    """

    def __init__(self, embed,
                 hidden_size,
                 num_layers,
                 num_head,
                 head_dim,
                 max_len,
                 emb_dropout,
                 dropout):
        r"""
        :param embed: vocabulary specification handed to ``get_embeddings``: either a tuple
            (num_embeddings, embedding_dim) or an ``nn.Embedding`` object, which is then used
            as the embedding layer.
        :param hidden_size: feature size inside the model.
        :param num_layers: number of layers.
        :param num_head: number of attention heads.
        :param head_dim: feature size of each attention head.
        :param max_len: maximum input length accepted by the model.
        :param emb_dropout: embedding dropout probability (accepted but not used here).
        :param dropout: dropout probability forwarded to the encoder.
        """
        super().__init__()
        self.embedding = get_embeddings(embed, padding_idx=0)
        # Project the embedding width onto the model's hidden size.
        self.emb_fc = nn.Linear(self.embedding.embedding_dim, hidden_size)
        encoder_config = dict(hidden_size=hidden_size,
                              num_layers=num_layers,
                              num_head=num_head,
                              head_dim=head_dim,
                              dropout=dropout,
                              max_len=max_len)
        self.encoder = StarTransformer(**encoder_config)

    def forward(self, x, mask):
        r"""
        :param x: input fed to the embedding layer
            (presumably [batch, length] token indices — TODO confirm against the data loader).
        :param ByteTensor mask: [batch, length] padding mask, 0 on padding and 1 elsewhere.
        :return: [batch, length, hidden] encoded sequence,
                [batch, hidden] global relay node,
                list of relay attention tensors from the encoder.
        """
        projected = self.emb_fc(self.embedding(x))
        encoded_nodes, relay_node, relay_attentions = self.encoder(projected, mask)
        return encoded_nodes, relay_node, relay_attentions


class _Cls(nn.Module):
    def __init__(self, in_dim, num_cls, hid_dim, dropout=0.1):
        super(_Cls, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(in_dim, hid_dim),
            nn.LeakyReLU(),
            nn.Dropout(dropout),
            nn.Linear(hid_dim, num_cls),
        )

    def forward(self, x):
        h = self.fc(x)
        return h
    
class STSeqCls(nn.Module):
    r"""
    Star-Transformer for sequence classification.
    """

    def __init__(self, embed, num_cls=2,
                 hidden_size=300,
                 num_layers=1,
                 num_head=9,
                 head_dim=32,
                 max_len=512,
                 cls_hidden_size=600,
                 emb_dropout=0.1,
                 dropout=0.1):
        r"""
        :param embed: vocabulary specification: either a tuple (num_embeddings, embedding_dim)
            or an ``nn.Embedding`` object, which is then used as the embedding layer.
        :param num_cls: number of output classes. Default: 2
        :param hidden_size: feature size inside the model. Default: 300
        :param num_layers: number of layers. Default: 1
        :param num_head: number of attention heads. Default: 9
        :param head_dim: feature size of each attention head. Default: 32
        :param max_len: maximum input length accepted by the model. Default: 512
        :param cls_hidden_size: hidden size of the classifier. Default: 600
        :param emb_dropout: embedding dropout probability. Default: 0.1
        :param dropout: dropout probability for the rest of the model. Default: 0.1
        """
        super(STSeqCls, self).__init__()
        self.enc = StarTransEnc(embed=embed,
                                hidden_size=hidden_size,
                                num_layers=num_layers,
                                num_head=num_head,
                                head_dim=head_dim,
                                max_len=max_len,
                                emb_dropout=emb_dropout,
                                dropout=dropout)
        self.cls = _Cls(hidden_size, num_cls, cls_hidden_size, dropout=dropout)

    def forward(self, words, seq_len):
        r"""
        :param words: [batch, seq_len] input sequences
        :param seq_len: [batch,] lengths of the input sequences
        :return: tuple ``(output, relays_attns)``; ``output`` is [batch, num_cls] raw class
            scores (no softmax is applied here) and ``relays_attns`` is the list of relay
            attention tensors returned by the encoder.
        """
        # Build the padding mask to the actual padded width of `words` and on the same device,
        # instead of a hard-coded width of 49 on "cuda if available".
        mask = seq_len_to_mask(seq_len, max_len=words.size(1)).to(words.device)
        nodes, relay, relays_attns = self.enc(words, mask)
        # Sequence representation: average of the relay node and the max-pooled nodes.
        y = 0.5 * (relay + nodes.max(1)[0])
        output = self.cls(y)  # [bsz, n_cls]
        return output, relays_attns
    

                     

In [None]:
# Choose the device here so this cell does not depend on a `device` variable that is only
# assigned in a later cell.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# NOTE(review): no checkpoint is loaded in the visible cells, so as written the model is
# evaluated with freshly initialised weights — confirm that a state_dict is loaded elsewhere.
model_eval = STSeqCls((21, 100), num_cls=2, hidden_size=300, num_layers=1, num_head=8, max_len=49,cls_hidden_size=600,dropout=0.1,head_dim=32).to(device)

In [None]:
# Default decision threshold; eval_step_corrected reads this module-level value.
threshold = 0.5


def transfer(y_prob, threshold = 0.5):
    """Convert probabilities into hard 0/1 labels.

    :param y_prob: iterable of probabilities.
    :param threshold: values strictly greater than this become 1, all others 0.
    :return: integer ``np.ndarray`` of labels, one per probability.
    """
    # Vectorised comparison; avoids indexing a Python list with a NumPy boolean scalar.
    return (np.asarray(y_prob) > threshold).astype(int)
def eval_step_corrected(model, val_loader, use_cuda = False, save_ = False):
    """Run ``model`` over ``val_loader`` without gradients and collect its predictions.

    :param model: callable returning ``(outputs, attentions)`` for ``(inputs, lengths)``.
    :param val_loader: iterable yielding ``(inputs, lengths, labels)`` batches.
    :param use_cuda: move each batch to "cuda" when True, otherwise to "cpu".
    :param save_: when True, also collect the attention slices of every sample.
    :return: ``((y_true, y_pred, y_prob), attentions)``; ``attentions`` is None unless
        ``save_`` is set. ``y_pred`` is obtained with the module-level ``threshold``.
    """
    device = torch.device("cuda" if use_cuda else "cpu")

    model.eval()
    labels_all, probs_all, attns_all = [], [], []
    with torch.no_grad():
        for batch_inputs, batch_lens, batch_labels in tqdm(val_loader):
            batch_inputs = batch_inputs.to(device)
            batch_labels = batch_labels.to(device)
            batch_lens = batch_lens.to(device)

            logits, layer_attns = model(batch_inputs, batch_lens)

            labels_all.extend(batch_labels.cpu().numpy())
            # Probability of class 1 for every sample of the batch.
            probs_all.extend(nn.Softmax(dim = 1)(logits)[:, 1].cpu().detach().numpy())

            if save_:
                # Keep the first attention tensor only, from column 34 onwards.
                # NOTE(review): if the attention has a leading relay column followed by
                # 34 HLA positions, the peptide would start at column 35, not 34 — confirm.
                attns_all.extend(layer_attns[0][:, :, :, 34:])

        preds_all = transfer(probs_all, threshold)

    ys_val = (labels_all, preds_all, probs_all)
    if save_:
        return ys_val, attns_all
    return ys_val, None

In [None]:
type_ = 'all_corrected'

# model_file = 'model_layer1_multihead9_fold4.pkl'

save_ = True
# Use the GPU only when one is present; a hard-coded True fails on CPU-only machines.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
# NOTE(review): train_loader is not defined in any visible cell; it must come from an earlier cell.
_, attn_res = eval_step_corrected(model_eval, train_loader, use_cuda, save_)


In [None]:
def attn_sumhead_peplength_pepposition(data, attn_data, label = None):
    """Accumulate attention per peptide position, for every peptide length 8..14 and every head.

    :param data: DataFrame with ``length`` and ``label`` columns; its index values are used
        to index ``attn_data``.
    :param attn_data: per-sample attention tensors, indexable as ``attn_data[idx][head]``.
    :param label: None to use every sample, or 0 / 1 to keep only that label.
    :return: ``(per_head, head_sum)``: ``per_head[l]`` is a list with one array per head for
        peptide length ``l``; ``head_sum`` is a list (one entry per length) holding the
        per-position sum over the heads.
    :raises ValueError: if ``label`` is not None, 0 or 1.
    """
    if label not in (None, 0, 1):
        raise ValueError('label must be None, 0 or 1, got {!r}'.format(label))

    SUM_length_head_dict = {}
    for l in range(8, 15):
        print('Length = ', str(l))
        SUM_length_head_dict[l] = []

        # One combined boolean mask instead of chained boolean indexing.
        selected = data.length == l
        if label is not None:
            selected = selected & (data.label == label)
        length_index = np.array(data[selected].index)

        length_data_num = len(length_index)
        print(length_data_num, length_index)

        for head in trange(8):
            # NOTE(review): this indexes attn_data[idx][head] with three slices, whereas
            # attn_HLA_length_aatype_position_num uses two — confirm the rank of each entry.
            idx_0 = length_index[0]
            temp_length_head = deepcopy(nn.Softmax(dim = -1)(attn_data[idx_0][head][:,:,:l].float()))

            for idx in length_index[1:]:
                temp_length_head += nn.Softmax(dim = -1)(attn_data[idx][head][:,:,:l].float())

            # Sum over axis 1, take the first row and renormalise over the positions.
            temp_length_head = np.array(nn.Softmax(dim = -1)(temp_length_head.sum(axis = 1)[0]).cpu())
            print(temp_length_head.shape)
            SUM_length_head_dict[l].append(temp_length_head)

    # Per length: one row per head, then a 'sum' row adding the heads together.
    SUM_length_head_sum = []
    for l in range(8, 15):
        print(l)
        temp = pd.DataFrame(SUM_length_head_dict[l], columns = range(1, l+1)).round(4)
        temp.loc['sum'] = temp.sum(axis = 0)
        SUM_length_head_sum.append(list(temp.loc['sum']))
        print(l, temp.loc['sum'].sort_values(ascending = False).index)

    return SUM_length_head_dict, SUM_length_head_sum

In [None]:
# Positive samples
positive_sum_peplength_pepposition, positive_sum_peplength_pepposition_headsum = attn_sumhead_peplength_pepposition(df_data, attn_res, label = 1)
# Negative samples
negative_sum_peplength_pepposition, negative_sum_peplength_pepposition_headsum = attn_sumhead_peplength_pepposition(df_data, attn_res, label = 0)
# All samples: label=None keeps every row (label=0 here would just repeat the negative-only result)
sum_peplength_pepposition, sum_peplength_pepposition_headsum = attn_sumhead_peplength_pepposition(df_data, attn_res, label = None)
# Only the last expression of a cell is displayed, so show the all-sample table here.
sum_peplength_pepposition_headsum

In [None]:
from plotly.subplots import make_subplots
import plotly.graph_objects as go

In [None]:
# Axis labels: peptide positions 1..14 on x, peptide lengths 8..14 on y.
position_labels = [str(position) for position in range(1, 15)]
length_labels = [str(pep_len) for pep_len in range(8, 15)]

# One panel per sample subset, left to right (1 row, 3 columns).
heatmap_panels = [
    ("All Samples", sum_peplength_pepposition_headsum),
    ("Positive Samples", positive_sum_peplength_pepposition_headsum),
    ("Negative Samples", negative_sum_peplength_pepposition_headsum),
]

fig = make_subplots(rows=1, cols=3,
                    subplot_titles=[panel_title for panel_title, panel_values in heatmap_panels],
                    shared_yaxes=True)

for panel_col, (panel_title, panel_values) in enumerate(heatmap_panels, start=1):
    fig.add_trace(
        go.Heatmap(z=pd.DataFrame(panel_values), colorscale='teal',
                   x=position_labels,
                   y=length_labels, coloraxis='coloraxis',
                   hoverongaps=False),
        row=1, col=panel_col
    )

fig.update_layout(
    width=1000, height=350,
    paper_bgcolor='rgba(0,0,0,0)',
    plot_bgcolor='rgba(0,0,0,0)',
    font=dict(family="black", size=13),
)
fig.update_xaxes(tickangle=0, tickfont=dict(family='black', size=12), title='peptide position')
fig.update_yaxes(title_text="peptide length", row=1, col=1)
# All three panels share one colour axis, styled here.
fig.update_coloraxes(colorbar_thickness=5, colorscale='teal')
fig.show()

In [None]:
# Show the current figure through plotly.offline.iplot, requesting an SVG image named
# 'attention' at 1000x500.
# NOTE(review): the figure layout above is 1000x350, so the requested image height differs — confirm intent.
iplot(fig, image='svg', filename='attention', image_width=1000, image_height=500)

In [None]:
def attn_HLA_length_aatype_position_num(data, attn_data, hla = 'HLA-A*11:01', label = None, length = 9, show_num = False):
    """Accumulate attention per (amino acid, peptide position) for one HLA and peptide length.

    :param data: DataFrame with ``length``, ``HLA``, ``label`` and ``peptide`` columns.
        NOTE(review): index values are used both with ``iloc`` and to index ``attn_data``,
        which is only consistent for a default 0..n-1 index — confirm.
    :param attn_data: per-sample attention tensors, indexable as ``attn_data[idx][head]``.
    :param hla: value of the ``HLA`` column to select.
    :param label: None to ignore the label, otherwise only rows with this label are used.
    :param length: peptide length to select.
    :param show_num: when True, also return how often each amino acid occurs per position.
    :return: ``{aa: {position: accumulated attention}}``, and with ``show_num`` a second dict
        ``{aa: {position: count}}``; positions are 0-based.
    """
    aatype_position = dict()

    # One combined boolean mask instead of chained boolean indexing.
    selected = (data.length == length) & (data.HLA == hla)
    if label is not None:
        selected = selected & (data.label == label)
    length_index = np.array(data[selected].index)

    length_data_num = len(length_index)
    print(length_data_num)

    # Look each peptide up once rather than once per head.
    peptides = {idx: data.iloc[idx].peptide for idx in length_index}

    for head in trange(8):
        for idx in length_index:
            temp_peptide = peptides[idx]
            # Normalise over the first `length` columns, sum over axis 0, normalise again.
            temp_length_head = nn.Softmax(dim=-1)(attn_data[idx][head][:, :length].float())
            temp_length_head = nn.Softmax(dim=-1)(temp_length_head.sum(axis = 0))

            for i, aa in enumerate(temp_peptide):
                aatype_position.setdefault(aa, {})
                aatype_position[aa].setdefault(i, 0)
                aatype_position[aa][i] += temp_length_head[i]

    if show_num:
        aatype_position_num = dict()
        for idx in length_index:
            for i, aa in enumerate(peptides[idx]):
                aatype_position_num.setdefault(aa, {})
                aatype_position_num[aa].setdefault(i, 0)
                aatype_position_num[aa][i] += 1

        return aatype_position, aatype_position_num
    else:
        return aatype_position
def attn_HLA_length_aatype_position_pd(HLA_length_aatype_position, length = 9, softmax = True, unsoftmax = True):
    """Turn the ``{aa: {position: value}}`` mapping into amino-acid x position table(s).

    :param HLA_length_aatype_position: mapping amino acid -> {0-based position -> value}.
    :param length: number of peptide positions (columns are labelled 1..length).
    :param softmax: include the table with a softmax applied along every row.
    :param unsoftmax: include the table with the raw values.
    :return: the softmax table, the raw table, or the tuple ``(softmax, raw)`` when both
        flags are set. Rows follow the order of the mapping; amino acids missing from it
        are appended as all-zero rows in the canonical 'YATVLDEGRHIWQKMFNSPC' order.
    :raises ValueError: if both ``softmax`` and ``unsoftmax`` are False.
    """
    if not (softmax or unsoftmax):
        raise ValueError('at least one of softmax / unsoftmax must be True')

    aatype_sorts = list('YATVLDEGRHIWQKMFNSPC')
    HLA_length_aatype_position_pd = np.zeros((20, length))

    aa_indexs = []
    for aai, (aa, aa_posi) in enumerate(HLA_length_aatype_position.items()):
        aa_indexs.append(aa)
        for posi, v in aa_posi.items():
            HLA_length_aatype_position_pd[aai, posi] = v

    if len(aa_indexs) != 20:
        # Deterministic order for the absent amino acids (a set difference has no fixed order).
        aa_indexs += [aa for aa in aatype_sorts if aa not in aa_indexs]

    columns = range(1, length + 1)

    unsoftmax_pd = pd.DataFrame(HLA_length_aatype_position_pd,
                                index = aa_indexs, columns = columns)
    if not softmax:
        return unsoftmax_pd

    # Softmax along each row, i.e. across the peptide positions.
    softmax_values = nn.Softmax(dim = -1)(torch.Tensor(HLA_length_aatype_position_pd)).numpy()
    softmax_pd = pd.DataFrame(softmax_values, index = aa_indexs, columns = columns)
    if not unsoftmax:
        return softmax_pd

    return softmax_pd, unsoftmax_pd
def draw_hla_length_aatype_position(data, attn_data, hla = 'HLA-B*27:05', label = None, length = 9,
                                    show = True, softmax = True, unsoftmax = True):
    """Build (and optionally plot) the amino-acid x position attention table(s) for one HLA.

    :param data: sample table passed on to ``attn_HLA_length_aatype_position_num``.
    :param attn_data: per-sample attention tensors passed on to the same helper.
    :param hla: HLA name to select.
    :param label: None for all samples, otherwise the label to keep.
    :param length: peptide length to select.
    :param show: draw the two tables side by side with seaborn; only used when both
        ``softmax`` and ``unsoftmax`` are True.
    :param softmax: request the softmax-normalised table.
    :param unsoftmax: request the raw table.
    :return: ``(softmax_table, raw_table)`` when both flags are set, otherwise the single
        requested table; rows are reordered by ``sort_aatype``.
    """
    aatype_position = attn_HLA_length_aatype_position_num(data, attn_data, hla, label, length, show_num = False)
    print(aatype_position)

    tables = attn_HLA_length_aatype_position_pd(aatype_position, length, softmax, unsoftmax)

    # Single-table case: nothing to plot, just reorder the rows.
    if not (softmax and unsoftmax):
        return sort_aatype(tables)

    softmax_table, raw_table = tables
    softmax_table = sort_aatype(softmax_table)
    raw_table = sort_aatype(raw_table)

    if show:
        fig, axes = plt.subplots(nrows = 1, ncols = 2, figsize = (10, 8))
        panels = (
            (softmax_table, ' Softmax Normalization'),
            (raw_table, ' UnNormalization'),
        )
        for ax, (table, _suffix) in zip(axes, panels):
            sns.heatmap(table, ax = ax, cmap = 'YlGn', square = True)
        for ax, (_table, suffix) in zip(axes, panels):
            ax.set_title(hla + suffix)
        plt.show()

    return softmax_table, raw_table
def sort_aatype(df):
    """Return ``df`` with its rows reordered to the fixed amino-acid order.

    :param df: DataFrame whose (unnamed) index holds exactly the 20 amino-acid letters.
    :return: a new DataFrame, rows in 'YATVLDEGRHIWQKMFNSPC' order, with a categorical
        index named ''. The input frame is left untouched.
    """
    aatype_sorts = list('YATVLDEGRHIWQKMFNSPC')
    # Work on the copy returned by reset_index so the caller's frame is not modified.
    out = df.reset_index()
    # reorder_categories returns a new Series; its `inplace` argument no longer exists in pandas 2.
    out['index'] = out['index'].astype('category').cat.reorder_categories(aatype_sorts)
    out = out.sort_values('index').rename(columns = {'index': ''})
    return out.set_index('')


In [None]:
# Settings shared by every call: positive 9-mers, raw (un-normalised) table only, no matplotlib preview.
raw_positive_9mer_kwargs = dict(label = 1, length = 9, show = False, softmax = False, unsoftmax = True)

A0101_length9_positive_aatype_position_unsoftmax_pd = draw_hla_length_aatype_position(
    df_data, attn_res, 'HLA-A01:01', **raw_positive_9mer_kwargs)
A0201_length9_positive_aatype_position_unsoftmax_pd = draw_hla_length_aatype_position(
    df_data, attn_res, 'HLA-A02:01', **raw_positive_9mer_kwargs)
A0301_length9_positive_aatype_position_unsoftmax_pd = draw_hla_length_aatype_position(
    df_data, attn_res, 'HLA-A03:01', **raw_positive_9mer_kwargs)
B0702_length9_positive_aatype_position_unsoftmax_pd = draw_hla_length_aatype_position(
    df_data, attn_res, 'HLA-B07:02', **raw_positive_9mer_kwargs)
B2705_length9_positive_aatype_position_unsoftmax_pd = draw_hla_length_aatype_position(
    df_data, attn_res, 'HLA-B27:05', **raw_positive_9mer_kwargs)
B5701_length9_positive_aatype_position_unsoftmax_pd = draw_hla_length_aatype_position(
    df_data, attn_res, 'HLA-B57:01', **raw_positive_9mer_kwargs)

In [None]:
# Axis labels: peptide positions 1..9 on x, amino acids in the fixed display order on y.
peptide_positions = [str(position) for position in range(1, 10)]
amino_acids = list('YATVLDEGRHIWQKMFNSPC')

# One heatmap per allele, left to right (1 row, 6 columns).
allele_panels = [
    ("HLA-A01:01", A0101_length9_positive_aatype_position_unsoftmax_pd),
    ("HLA-A02:01", A0201_length9_positive_aatype_position_unsoftmax_pd),
    ("HLA-A03:01", A0301_length9_positive_aatype_position_unsoftmax_pd),
    ('HLA-B07:02', B0702_length9_positive_aatype_position_unsoftmax_pd),
    ('HLA-B27:05', B2705_length9_positive_aatype_position_unsoftmax_pd),
    ('HLA-B57:01', B5701_length9_positive_aatype_position_unsoftmax_pd),
]

fig = make_subplots(rows=1, cols=6, horizontal_spacing=0.02, x_title='Peptide position',
                    subplot_titles=[allele_name for allele_name, allele_table in allele_panels],
                    shared_yaxes=True, shared_xaxes=True)

for allele_col, (allele_name, allele_table) in enumerate(allele_panels, start=1):
    fig.add_trace(
        go.Heatmap(z=pd.DataFrame(allele_table), colorscale='teal',
                   x=peptide_positions,
                   y=amino_acids, coloraxis='coloraxis',
                   hoverongaps=False),
        row=1, col=allele_col
    )

fig.update_layout(
    width=1000, height=500,
    paper_bgcolor='rgba(0,0,0,0)',
    plot_bgcolor='rgba(0,0,0,0)',
    font=dict(family="black", size=13),
)
fig.update_xaxes(tickangle=0, tickfont=dict(family='black', size=12))
fig.update_yaxes(title_text="Amino acid type", row=1, col=1)
# All six panels share one colour axis, styled here.
fig.update_coloraxes(colorbar_thickness=8, colorscale='teal')
fig.show()

In [None]:
# Show the current figure through plotly.offline.iplot, requesting an SVG image named
# 'attention' at 1000x500.
# NOTE(review): the earlier peptide-length heatmap is exported under the same filename
# 'attention' — the two exports may collide; confirm whether a distinct name is wanted.
iplot(fig, image='svg', filename='attention', image_width=1000, image_height=500)