## D2L版本

In [None]:
import math
import pandas as pd
import torch
from torch import nn
import torch.nn.functional as F

class DotProd_Attention(nn.Module):
    """Scaled dot-product attention with optional length masking.

    The attention weights of the most recent call are kept on
    ``self.attention_weights`` (before dropout is applied).
    """

    def __init__(self, dropout, **kwargs) -> None:
        super().__init__(**kwargs)
        self.dropout = nn.Dropout(dropout)

    def forward(self, querys, keys, values, valid_lens=None):
        """Return the attention-weighted combination of ``values``.

        ``querys``, ``keys`` and ``values`` must be 3-D because they are fed
        to ``torch.bmm``. ``valid_lens`` is passed to ``masked_softmax``
        unchanged; ``None`` means no masking.
        """
        scale = math.sqrt(querys.shape[-1])
        # Swap the last two axes of the keys so that bmm produces one
        # similarity score per (query, key) pair.
        keys_t = keys.transpose(1, 2)
        scores = torch.bmm(querys, keys_t) / scale
        self.attention_weights = masked_softmax(scores, valid_lens)
        dropped = self.dropout(self.attention_weights)
        return torch.bmm(dropped, values)
    
def masked_softmax(x, valid_lens):
    """Softmax over the last axis of ``x`` with trailing positions masked out.

    Args:
        x: score tensor, expected to be 3-D; the softmax runs over its last
            axis.
        valid_lens: ``None`` for a plain softmax. A 1-D tensor gives one
            length per entry of axis 0, shared by every row of that entry.
            Any other tensor is flattened and must give one length per row
            of ``x``.

    Returns:
        A tensor shaped like ``x``. In every row, positions at index
        ``>= valid_len`` are set to -1e6 before the softmax, so their weight
        is effectively zero.
    """
    if valid_lens is None:
        # Nothing to mask: ordinary softmax over the last axis.
        return F.softmax(x, dim=-1)

    shape = x.shape
    if valid_lens.dim() == 1:
        # One length per sample: repeat it for each row of that sample.
        valid_lens = torch.repeat_interleave(valid_lens, shape[1])
    else:
        # One length per row already: just flatten.
        valid_lens = valid_lens.reshape(-1)

    # Build the keep-mask directly with torch; the file does not import d2l,
    # so it cannot rely on d2l.sequence_mask.
    flat = x.reshape(-1, shape[-1])
    positions = torch.arange(shape[-1], dtype=torch.float32, device=x.device)
    keep = positions[None, :] < valid_lens[:, None]
    # masked_fill returns a new tensor, so the caller's scores are untouched.
    flat = flat.masked_fill(~keep, -1e6)
    return F.softmax(flat.reshape(shape), dim=-1)
    
def transpose_qkv(X, num_heads):
    """Split the hidden axis into heads and fold the heads into the batch axis.

    Input shape:  (batch_size, num_items, num_hiddens)
    Output shape: (batch_size * num_heads, num_items, num_hiddens / num_heads)

    ``num_items`` is the number of queries or of key-value pairs. Folding the
    heads into the batch axis lets all heads be computed in parallel.
    """
    batch_size, num_items = X.shape[0], X.shape[1]
    # (batch_size, num_items, num_heads, head_dim)
    per_head = X.reshape(batch_size, num_items, num_heads, -1)
    # (batch_size, num_heads, num_items, head_dim)
    heads_first = per_head.transpose(1, 2)
    # (batch_size * num_heads, num_items, head_dim)
    return heads_first.reshape(-1, num_items, heads_first.shape[3])

def transpose_output(X, num_heads):
    """Undo ``transpose_qkv``: unfold the heads from the batch axis and
    concatenate them back along the last axis.

    Input shape:  (batch_size * num_heads, num_items, head_dim)
    Output shape: (batch_size, num_items, num_heads * head_dim)
    """
    num_items, head_dim = X.shape[1], X.shape[2]
    # (batch_size, num_heads, num_items, head_dim)
    by_head = X.reshape(-1, num_heads, num_items, head_dim)
    # (batch_size, num_items, num_heads, head_dim)
    items_first = by_head.transpose(1, 2)
    return items_first.reshape(items_first.shape[0], num_items, -1)

class Multihead_attention(nn.Module):
    """Multi-head attention built on scaled dot-product attention.

    Queries, keys and values are linearly projected to ``num_hiddens``, split
    into ``num_heads`` heads that are processed in parallel, and the
    concatenated head outputs go through a final linear layer.
    """

    def __init__(self, num_heads, dropout, query_size, key_size, value_size,
                 num_hiddens, bias=False, **kwargs):
        super().__init__(**kwargs)
        self.num_heads = num_heads
        self.attention = DotProd_Attention(dropout)
        self.W_q = nn.Linear(query_size, num_hiddens, bias=bias)
        self.W_k = nn.Linear(key_size, num_hiddens, bias=bias)
        self.W_v = nn.Linear(value_size, num_hiddens, bias=bias)
        self.W_o = nn.Linear(num_hiddens, num_hiddens, bias=bias)

    def forward(self, queries, keys, values, valid_lens):
        """Apply multi-head attention.

        queries / keys / values: (batch_size, num_items, feature_size)
        valid_lens: None, (batch_size,) or (batch_size, num_queries)
        Returns: (batch_size, num_queries, num_hiddens)
        """
        heads = self.num_heads
        # After projection and splitting, each tensor has shape
        # (batch_size * num_heads, num_items, num_hiddens / num_heads).
        queries = transpose_qkv(self.W_q(queries), heads)
        keys = transpose_qkv(self.W_k(keys), heads)
        values = transpose_qkv(self.W_v(values), heads)

        if valid_lens is not None:
            # The heads were folded into the batch axis, so each sample's
            # lengths are repeated num_heads times along axis 0 to line up.
            valid_lens = valid_lens.repeat_interleave(heads, dim=0)

        # (batch_size * num_heads, num_queries, num_hiddens / num_heads)
        per_head_out = self.attention(queries, keys, values, valid_lens)
        # (batch_size, num_queries, num_hiddens)
        merged = transpose_output(per_head_out, heads)
        return self.W_o(merged)

# Position-wise feed-forward network (FFN): two linear layers with a ReLU in between, applied identically at every position
class PositionWiseFFN(nn.Module):
    """Position-wise feed-forward network: Linear -> ReLU -> Linear.

    The same two-layer MLP acts on the last axis of the input, so identical
    inputs at different positions produce identical outputs.
    """

    def __init__(self, ffn_num_input, ffn_num_hiddens, ffn_num_outputs,
                 **kwargs):
        super().__init__(**kwargs)
        # Registration order (dense1, relu, dense2) fixes both the state_dict
        # key order and the order in which parameters are initialised.
        self.dense1 = nn.Linear(ffn_num_input, ffn_num_hiddens)
        self.relu = nn.ReLU()
        self.dense2 = nn.Linear(ffn_num_hiddens, ffn_num_outputs)

    def forward(self, X):
        hidden = self.dense1(X)
        activated = self.relu(hidden)
        return self.dense2(activated)
# 测试下前馈网络
# ffn = PositionWiseFFN(4, 4, 8)
# ffn.eval()
# print(ffn(torch.ones((2, 3, 4)))[0])  # 可以观察到同一位置的输出值相等

# layerNorm和batchNorm的区别在于说，layernorm是针对于一个样本的所有特征来做归一化的，使得从一个样本上看过去是均值为0方差为1
# 而batchnorm则是对当前一个batch内所有样本的同一列特征来做归一化，也就是说两者处理的维度不同

# 残差连接与layernorm实现
class AddNorm(nn.Module):
    """Residual connection followed by layer normalisation.

    ``forward(X, Y)`` applies dropout to the sub-layer output ``Y``, adds the
    sub-layer input ``X`` and layer-normalises the sum, so ``X`` and ``Y``
    must have matching shapes.
    """

    def __init__(self, normalized_shape, dropout, **kwargs):
        super().__init__(**kwargs)
        self.dropout = nn.Dropout(dropout)
        self.ln = nn.LayerNorm(normalized_shape)

    def forward(self, X, Y):
        residual_sum = self.dropout(Y) + X
        return self.ln(residual_sum)
# 测试下残差连接层，两个输入维度要一致
# add_norm = AddNorm([3, 4], 0.5)
# add_norm.eval()
# print(add_norm(torch.ones((2, 3, 4)), torch.ones((2, 3, 4))).shape)

# 编码器block的实现，transformer是要叠好几个encoderblock和decoderblock
class EncoderBlock(nn.Module):
    """Transformer encoder block.

    Two sub-layers, multi-head self-attention and a position-wise
    feed-forward network, each wrapped in a residual connection followed by
    layer normalisation. The block does not change the shape of its input.

    Args:
        key_size, query_size, value_size: input feature sizes of the
            attention projections.
        num_hiddens: width of the attention output and of the FFN output.
        norm_shape: ``normalized_shape`` given to both LayerNorm layers.
        ffn_num_input, ffn_num_hiddens: input and hidden widths of the FFN.
        num_heads: number of attention heads.
        dropout: dropout probability used in attention and in both AddNorms.
        use_bias: whether the attention projection layers carry a bias term.
    """

    def __init__(self, key_size, query_size, value_size, num_hiddens,
                 norm_shape, ffn_num_input, ffn_num_hiddens, num_heads,
                 dropout, use_bias=False, **kwargs):
        super(EncoderBlock, self).__init__(**kwargs)
        # use_bias is forwarded explicitly; otherwise the attention layer
        # would always fall back to its own default (bias=False).
        self.attention = Multihead_attention(
            num_heads, dropout, query_size, key_size, value_size,
            num_hiddens, bias=use_bias)
        self.addnorm1 = AddNorm(norm_shape, dropout)
        self.ffn = PositionWiseFFN(
            ffn_num_input, ffn_num_hiddens, num_hiddens)
        self.addnorm2 = AddNorm(norm_shape, dropout)

    def forward(self, X, valid_lens):
        """Run self-attention (queries = keys = values = X), then the FFN."""
        Y = self.addnorm1(X, self.attention(X, X, X, valid_lens))
        return self.addnorm2(Y, self.ffn(Y))
# 可以看到，transformer编码器中的任何层都不会改变其输⼊的形状
# 测试编码器block
# X = torch.ones((2, 100, 24))
# valid_lens = torch.tensor([3, 2])   # 分别设置两个样本的有效长度
# encoder_blk = EncoderBlock(24, 24, 24, 24, [100, 24], 24, 48, 8, 0.5)
# encoder_blk.eval()
# print(encoder_blk(X, valid_lens).shape)

class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding followed by dropout.

    Args:
        num_hiddens: feature width of the encoding; odd widths are supported.
        dropout: dropout probability applied after adding the encoding.
        max_len: number of positions precomputed in the table ``self.P``.
    """

    def __init__(self, num_hiddens, dropout, max_len=1000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(dropout)
        # Precompute a table P of shape (1, max_len, num_hiddens).
        self.P = torch.zeros((1, max_len, num_hiddens))
        positions = torch.arange(max_len, dtype=torch.float32).reshape(-1, 1)
        freqs = torch.pow(
            10000,
            torch.arange(0, num_hiddens, 2, dtype=torch.float32) / num_hiddens)
        X = positions / freqs
        # Even feature columns take the sine, odd columns the cosine.
        self.P[:, :, 0::2] = torch.sin(X)
        # With an odd num_hiddens there is one odd column fewer than even
        # columns, so the cosine block is trimmed to fit.
        self.P[:, :, 1::2] = torch.cos(X[:, :num_hiddens // 2])

    def forward(self, X):
        """Add the encoding for the first ``X.shape[1]`` positions to ``X``."""
        # P is a plain tensor attribute, so it is moved to X's device here.
        X = X + self.P[:, :X.shape[1], :].to(X.device)
        return self.dropout(X)

# 叠加encoder_block
# Transformer编码器输出的形状是（批量⼤⼩，时间步数⽬，num_hiddens）
class TransformerEncoder(nn.Module):
    """Transformer encoder: token embedding, positional encoding and a stack
    of ``num_layers`` encoder blocks.

    The output has shape (batch_size, num_steps, num_hiddens). After each
    forward pass, ``self.attention_weights`` holds one attention-weight
    tensor per block.
    """

    def __init__(self, vocab_size, key_size, query_size, value_size,
                 num_hiddens, norm_shape, ffn_num_input, ffn_num_hiddens,
                 num_heads, num_layers, dropout, use_bias=False, **kwargs):
        super().__init__(**kwargs)
        self.num_hiddens = num_hiddens
        self.embedding = nn.Embedding(vocab_size, num_hiddens)
        self.pos_encoding = PositionalEncoding(num_hiddens, dropout)
        self.blks = nn.Sequential()
        for layer_idx in range(num_layers):
            block = EncoderBlock(
                key_size, query_size, value_size, num_hiddens,
                norm_shape, ffn_num_input, ffn_num_hiddens,
                num_heads, dropout, use_bias)
            self.blks.add_module(f"block{layer_idx}", block)

    def forward(self, X, valid_lens):
        # Positional-encoding values lie in [-1, 1], so the embeddings are
        # scaled by sqrt(num_hiddens) before the two are added.
        scaled = self.embedding(X) * math.sqrt(self.num_hiddens)
        X = self.pos_encoding(scaled)
        self.attention_weights = [None] * len(self.blks)
        for idx, blk in enumerate(self.blks):
            X = blk(X, valid_lens)
            # Weights recorded by the block's inner dot-product attention.
            self.attention_weights[idx] = (
                blk.attention.attention.attention_weights)
        return X

# 测试Transformer_Encoder，两个block
# encoder = TransformerEncoder(200, 24, 24, 24, 24, [100, 24], 24, 48, 8, 2, 0.5)
# encoder.eval()
# valid_lens = torch.tensor([3, 2])
# print(encoder(torch.ones((2, 100), dtype=torch.long), valid_lens).shape)

# transformer解码器也是由多个相同的层组成。在DecoderBlock类中实现的每个层包含了三个⼦层:
# 解码器⾃注意⼒、“编码器-解码器”注意⼒和基于位置的前馈⽹络。这些⼦层也都被和紧随的layernorm围绕
# 在掩蔽多头解码器⾃注意⼒层（第⼀个⼦层）中，查询、键和值都来⾃上⼀个解码器层的输出
# 为了在解码器中保留⾃回归的属性，其掩蔽⾃注意⼒设定了参数dec_valid_lens，以便任何查询都只会与解码器中所有已经⽣成词元的位置（即直到该查询位置为⽌）进⾏注意⼒计算
class DecoderBlock(nn.Module):
    """The i-th block of the Transformer decoder.

    Three sub-layers, each followed by AddNorm: masked decoder
    self-attention, encoder-decoder attention and a position-wise
    feed-forward network.
    """

    def __init__(self, key_size, query_size, value_size, num_hiddens,
                 norm_shape, ffn_num_input, ffn_num_hiddens, num_heads,
                 dropout, i, **kwargs):
        super().__init__(**kwargs)
        self.i = i

        self.attention1 = Multihead_attention(
            num_heads, dropout, query_size, key_size, value_size, num_hiddens)
        self.addnorm1 = AddNorm(norm_shape, dropout)
        self.attention2 = Multihead_attention(
            num_heads, dropout, query_size, key_size, value_size, num_hiddens)
        self.addnorm2 = AddNorm(norm_shape, dropout)
        self.ffn = PositionWiseFFN(ffn_num_input, ffn_num_hiddens, num_hiddens)
        self.addnorm3 = AddNorm(norm_shape, dropout)

    def forward(self, X, state):
        """Run the block and return ``(output, state)``.

        ``state`` is ``[enc_outputs, enc_valid_lens, per_block_cache]``; the
        cache slot of this block, ``state[2][self.i]``, is updated in place.
        """
        enc_outputs, enc_valid_lens = state[0], state[1]
        # Training: the whole target sequence is processed at once, so the
        # cache slot starts out as None.
        # Prediction: tokens are decoded one by one, so the slot holds this
        # block's inputs for all earlier time steps.
        cached = state[2][self.i]
        key_values = X if cached is None else torch.cat((cached, X), dim=1)
        state[2][self.i] = key_values

        dec_valid_lens = None
        if self.training:
            batch_size, num_steps, _ = X.shape
            # Shape (batch_size, num_steps); every row is [1, 2, ..., num_steps],
            # so query t attends only to positions up to and including t.
            steps = torch.arange(1, num_steps + 1, device=X.device)
            dec_valid_lens = steps.repeat(batch_size, 1)

        # Masked self-attention.
        self_attended = self.attention1(X, key_values, key_values, dec_valid_lens)
        Y = self.addnorm1(X, self_attended)
        # Encoder-decoder attention; enc_outputs has shape
        # (batch_size, num_steps, num_hiddens).
        cross_attended = self.attention2(Y, enc_outputs, enc_outputs, enc_valid_lens)
        Z = self.addnorm2(Y, cross_attended)
        return self.addnorm3(Z, self.ffn(Z)), state




# 构建transformer解码器，还有后面全连接层输出
class TransformerDecoder(nn.Module):
    """Transformer decoder: token embedding, positional encoding, a stack of
    ``num_layers`` decoder blocks and a final linear layer that projects to
    the vocabulary.
    """

    def __init__(self, vocab_size, key_size, query_size, value_size,
                 num_hiddens, norm_shape, ffn_num_input, ffn_num_hiddens,
                 num_heads, num_layers, dropout, **kwargs):
        super().__init__(**kwargs)
        self.num_hiddens = num_hiddens
        self.num_layers = num_layers
        self.embedding = nn.Embedding(vocab_size, num_hiddens)
        self.pos_encoding = PositionalEncoding(num_hiddens, dropout)
        self.blks = nn.Sequential()
        for layer_idx in range(num_layers):
            block = DecoderBlock(
                key_size, query_size, value_size, num_hiddens,
                norm_shape, ffn_num_input, ffn_num_hiddens, num_heads,
                dropout, layer_idx)
            self.blks.add_module(f"block{layer_idx}", block)
        self.dense = nn.Linear(num_hiddens, vocab_size)

    def init_state(self, enc_outputs, enc_valid_lens, *args):
        """Build the decoder state: encoder outputs, encoder valid lengths
        and one empty (None) cache slot per decoder block."""
        per_block_cache = [None for _ in range(self.num_layers)]
        return [enc_outputs, enc_valid_lens, per_block_cache]

    def forward(self, X, state):
        """Return ``(logits, state)`` for the token indices ``X``."""
        scaled = self.embedding(X) * math.sqrt(self.num_hiddens)
        X = self.pos_encoding(scaled)
        num_blocks = len(self.blks)
        # Row 0: decoder self-attention weights; row 1: encoder-decoder
        # attention weights. One entry per block in each row.
        self._attention_weights = [[None] * num_blocks for _ in range(2)]
        self_attn, cross_attn = self._attention_weights
        for idx, blk in enumerate(self.blks):
            X, state = blk(X, state)
            self_attn[idx] = blk.attention1.attention.attention_weights
            cross_attn[idx] = blk.attention2.attention.attention_weights
        return self.dense(X), state

    @property
    def attention_weights(self):
        """Attention weights recorded by the most recent forward pass."""
        return self._attention_weights

## 自写学习版

In [1]:
import torch
import numpy as ny
import torch.nn as nn
import torch.nn.functional as F


# Word-embedding walkthrough for a toy sequence-modelling setup with a batch
# of source sentences and a batch of target sentences. Every token is
# represented by its index in a vocabulary.

# Vocabulary sizes: the token ids sampled below lie in [1, max_num_*_words).
max_num_src_words = 8
max_num_tgt_words = 8
# Feature width of the model (512 in the paper; kept small here so the
# printed tensors stay readable).
model_dim = 8

# Number of real (non-padding) tokens in each sentence of the batch.
src_len = torch.tensor([2, 4], dtype=torch.int32)
tgt_len = torch.tensor([4, 3], dtype=torch.int32)

# Maximum sequence lengths, and the number of positions covered by the
# position-embedding table built in the next cell.
max_src_seq_len = 5
max_tgt_seq_len = 5
max_pos_len = 5


# Sample random token ids for each sentence, right-pad with 0 up to the
# longest sentence in the batch, and stack the rows into a
# (batch, length) tensor suitable for mini-batch processing.
src_seq = torch.stack([
    F.pad(torch.randint(1, max_num_src_words, (L,)), (0, max(src_len) - L))
    for L in src_len
])
tgt_seq = torch.stack([
    F.pad(torch.randint(1, max_num_tgt_words, (L,)), (0, max(tgt_len) - L))
    for L in tgt_len
])

# Embedding tables map each token index to a model_dim vector. One extra row
# is allocated so that index 0 can be used by the padding token.
src_embedding_table = nn.Embedding(max_num_src_words + 1, model_dim)
tgt_embedding_table = nn.Embedding(max_num_tgt_words + 1, model_dim)
# Calling the table looks up one row per token index in the sequence.
src_embedding = src_embedding_table(src_seq)
tgt_embedding = tgt_embedding_table(tgt_seq)

# Printing all three shows that each embedding row is just the table row
# selected by the token index.
print(src_seq)
print(src_embedding_table.weight)
print(src_embedding)


tensor([[7, 1, 0, 0],
        [6, 2, 1, 3]])
Parameter containing:
tensor([[-0.1941,  0.3929, -0.7884,  1.0630,  0.2310, -0.6385, -0.3659, -0.6311],
        [ 0.0708,  0.6229,  0.0506,  0.3370, -1.7553, -1.0314,  0.5341,  1.8713],
        [ 0.2468,  0.4240, -1.3750, -0.0041,  0.1440,  0.4418,  1.4496, -0.8339],
        [ 0.1772, -0.6893, -0.5072,  1.7150, -0.9951, -1.5135, -1.0859,  0.1859],
        [ 0.1687, -0.6958, -0.5035,  1.1720, -0.0776,  1.0870, -0.3335,  0.0664],
        [-1.1862, -0.9110, -1.5377,  1.8045, -0.2976,  2.1323, -0.4915,  0.7331],
        [-1.4168, -2.2570, -2.2182, -0.0821,  0.4044,  0.9650,  1.1716,  0.7415],
        [ 0.0207, -0.1584,  1.1953,  0.9947, -0.6578, -1.7685,  1.2537,  0.9954],
        [ 0.1161, -0.2006, -0.9234,  0.2654, -0.7097,  0.4820,  0.5363,  0.4558]],
       requires_grad=True)
tensor([[[ 0.0207, -0.1584,  1.1953,  0.9947, -0.6578, -1.7685,  1.2537,
           0.9954],
         [ 0.0708,  0.6229,  0.0506,  0.3370, -1.7553, -1.0314,  0.5341,
 

  from .autonotebook import tqdm as notebook_tqdm


### embedding

In [2]:
# Sinusoidal position embedding, built with broadcasting (see the paper for
# the formula): `pos` indexes the rows (positions) and `i` the columns
# (feature dimensions). The encoding gives the model access to token order;
# ViT uses position information for the same reason.

pos_mat = torch.arange(max_pos_len).unsqueeze(1)
i_mat = torch.pow(10000, torch.arange(0, model_dim, 2).unsqueeze(0) / model_dim)
pos_embedding_table = torch.zeros(max_pos_len, model_dim)
pos_embedding_table[:, 0::2] = (pos_mat / i_mat).sin()  # even columns: sine
pos_embedding_table[:, 1::2] = (pos_mat / i_mat).cos()  # odd columns: cosine

# Wrap the fixed table in an Embedding layer whose weight is frozen.
pe_embedding = nn.Embedding(max_pos_len, model_dim)
pe_embedding.weight = nn.Parameter(pos_embedding_table, requires_grad=False)

# Position indices 0..max_len-1, repeated once per sentence in the batch.
src_pos = torch.arange(max(src_len)).unsqueeze(0).repeat(len(src_len), 1).to(torch.int32)
tgt_pos = torch.arange(max(tgt_len)).unsqueeze(0).repeat(len(tgt_len), 1).to(torch.int32)

src_pe_embedding = pe_embedding(src_pos)
tgt_pe_embedding = pe_embedding(tgt_pos)
print(src_pe_embedding)
print(tgt_pe_embedding)

tensor([[[ 0.0000e+00,  1.0000e+00,  0.0000e+00,  1.0000e+00,  0.0000e+00,
           1.0000e+00,  0.0000e+00,  1.0000e+00],
         [ 8.4147e-01,  5.4030e-01,  9.9833e-02,  9.9500e-01,  9.9998e-03,
           9.9995e-01,  1.0000e-03,  1.0000e+00],
         [ 9.0930e-01, -4.1615e-01,  1.9867e-01,  9.8007e-01,  1.9999e-02,
           9.9980e-01,  2.0000e-03,  1.0000e+00],
         [ 1.4112e-01, -9.8999e-01,  2.9552e-01,  9.5534e-01,  2.9995e-02,
           9.9955e-01,  3.0000e-03,  1.0000e+00]],

        [[ 0.0000e+00,  1.0000e+00,  0.0000e+00,  1.0000e+00,  0.0000e+00,
           1.0000e+00,  0.0000e+00,  1.0000e+00],
         [ 8.4147e-01,  5.4030e-01,  9.9833e-02,  9.9500e-01,  9.9998e-03,
           9.9995e-01,  1.0000e-03,  1.0000e+00],
         [ 9.0930e-01, -4.1615e-01,  1.9867e-01,  9.8007e-01,  1.9999e-02,
           9.9980e-01,  2.0000e-03,  1.0000e+00],
         [ 1.4112e-01, -9.8999e-01,  2.9552e-01,  9.5534e-01,  2.9995e-02,
           9.9955e-01,  3.0000e-03,  1.0000e+00]

### Masked

In [3]:
# Self-attention mask for the encoder.

# Per-sentence validity vector: 1 for real tokens, 0 for padding. F.pad with
# (0, max(src_len) - L) adds nothing on the left and max(src_len) - L zeros
# on the right. Resulting shape: (batch, max_len, 1).
valid_encode_pos = torch.stack([
    F.pad(torch.ones(L), (0, max(src_len) - L))
    for L in src_len
]).unsqueeze(2)
# Q @ K^T yields a (max_len, max_len) score matrix per sentence, so the outer
# product of the validity vector with itself marks with 1 exactly the
# (query, key) pairs in which both tokens are real.
valid_matrix = valid_encode_pos @ valid_encode_pos.transpose(1, 2)
# Complement as booleans, ready for masked_fill: True marks a pair to mask.
invalid_matrix = (1 - valid_matrix).to(torch.bool)

# Random scores to try the mask on.
score = torch.randn(2, max(src_len), max(src_len))
# Masked positions get a very large negative value so that softmax drives
# their weight to 0.
masked_score = score.masked_fill(invalid_matrix, -1e9)
prob = masked_score.softmax(dim=-1)
print(masked_score)
print(prob)

tensor([[[-6.6827e-01, -6.2055e-01, -1.0000e+09, -1.0000e+09],
         [-1.6404e-01,  5.9722e-01, -1.0000e+09, -1.0000e+09],
         [-1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09],
         [-1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09]],

        [[-2.3849e-01,  5.4275e-01, -5.8056e-01,  3.5201e-02],
         [-8.8779e-01, -3.4283e-02,  7.3794e-01, -1.0497e+00],
         [ 6.6987e-01,  9.8241e-01, -3.8788e-01,  1.8632e+00],
         [ 1.3347e+00, -1.6914e+00, -4.0881e-01,  5.8134e-01]]])
tensor([[[0.4881, 0.5119, 0.0000, 0.0000],
         [0.3184, 0.6816, 0.0000, 0.0000],
         [0.2500, 0.2500, 0.2500, 0.2500],
         [0.2500, 0.2500, 0.2500, 0.2500]],

        [[0.1920, 0.4193, 0.1364, 0.2524],
         [0.1078, 0.2530, 0.5476, 0.0916],
         [0.1663, 0.2274, 0.0578, 0.5486],
         [0.5902, 0.0286, 0.1032, 0.2779]]])


In [18]:
# Self-attention mask for the decoder.
# The full target sequence is known during training, so to keep prediction
# honest every position must be prevented from seeing the tokens after it.
# The mask is therefore applied to the scores before the softmax.
#
# torch.tril(torch.ones(L, L)) is a lower-triangular matrix of ones: row t
# has ones in columns 0..t, i.e. the positions query t may attend to.
# The F.pad tuple (0, p, 0, p) is ordered (left, right, top, bottom): p zero
# columns are appended on the right and p zero rows at the bottom, bringing
# every sentence to the same (max_len, max_len) shape so they can be stacked.
valid_decoder_tri_matrix = torch.stack([
    F.pad(torch.tril(torch.ones(L, L)),
          (0, max(tgt_len) - L, 0, max(tgt_len) - L))
    for L in tgt_len
])
# From here on the steps mirror the encoder mask: complement, cast to bool,
# fill the masked scores with a large negative value, softmax.
invalid_decoder_tri_matrix = (1 - valid_decoder_tri_matrix).to(torch.bool)
print(invalid_decoder_tri_matrix)
score = torch.randn(2, max(tgt_len), max(tgt_len))
masked_score = score.masked_fill(invalid_decoder_tri_matrix, -1e9)
F.softmax(masked_score, -1)

tensor([[[False,  True,  True,  True],
         [False, False,  True,  True],
         [False, False, False,  True],
         [False, False, False, False]],

        [[False,  True,  True,  True],
         [False, False,  True,  True],
         [False, False, False,  True],
         [ True,  True,  True,  True]]])


tensor([[[1.0000, 0.0000, 0.0000, 0.0000],
         [0.7647, 0.2353, 0.0000, 0.0000],
         [0.8152, 0.1500, 0.0348, 0.0000],
         [0.5327, 0.2886, 0.0371, 0.1417]],

        [[1.0000, 0.0000, 0.0000, 0.0000],
         [0.8665, 0.1335, 0.0000, 0.0000],
         [0.3606, 0.4226, 0.2168, 0.0000],
         [0.2500, 0.2500, 0.2500, 0.2500]]])