In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn import functional as F
from d2l import torch as d2l

### 1. 获取掩码
- 在encoder和encoder-decoder的self-attention中，掩码和序列长度有关
- 在decoder_layer中的第一层self-attention中，掩码是下三角矩阵
- decoder输出掩码

#### 1.1 encoder和encoder-decoder的self-attention

In [68]:
# test
# encoder/encoder-decoder
torch.manual_seed(42)
"""
输入矩阵：
    batch_size * max_seq_len * dim
    batch_size=2, max_seq_len=4, dim=8
有效长度矩阵：
    batch_size *
    batch_size=2
"""
X = torch.rand((2, 4, 8))
valid_len = torch.randint(1, 5, (2, ))
print(X, X.shape)
print(valid_len, valid_len.shape)
print()

max_len = X.shape[1]
# 获取长度矩阵
print(torch.arange(max_len)[None, :])
# 对valid_len扩容
print(valid_len.repeat_interleave(max_len)[:, None])
mask = torch.arange(max_len)[None, :] < valid_len.repeat_interleave(max_len)[:, None]
print(mask)
print()

# test
Q, K, V = X, X, X
score = torch.bmm(Q, K.permute(0, 2, 1)).reshape(-1, score.shape[-1])
print(score, score.shape)

print(score.masked_fill_(~mask, -1e6))
F.softmax(score, -1)

tensor([[[0.8823, 0.9150, 0.3829, 0.9593, 0.3904, 0.6009, 0.2566, 0.7936],
         [0.9408, 0.1332, 0.9346, 0.5936, 0.8694, 0.5677, 0.7411, 0.4294],
         [0.8854, 0.5739, 0.2666, 0.6274, 0.2696, 0.4414, 0.2969, 0.8317],
         [0.1053, 0.2695, 0.3588, 0.1994, 0.5472, 0.0062, 0.9516, 0.0753]],

        [[0.8860, 0.5832, 0.3376, 0.8090, 0.5779, 0.9040, 0.5547, 0.3423],
         [0.6343, 0.3644, 0.7104, 0.9464, 0.7890, 0.2814, 0.7886, 0.5895],
         [0.7539, 0.1952, 0.0050, 0.3068, 0.1165, 0.9103, 0.6440, 0.7071],
         [0.6581, 0.4913, 0.8913, 0.1447, 0.5315, 0.1587, 0.6542, 0.3278]]]) torch.Size([2, 4, 8])
tensor([3, 2]) torch.Size([2])

tensor([[0, 1, 2, 3]])
tensor([[3],
        [3],
        [3],
        [3],
        [2],
        [2],
        [2],
        [2]])
tensor([[ True,  True,  True, False],
        [ True,  True,  True, False],
        [ True,  True,  True, False],
        [ True,  True,  True, False],
        [ True,  True, False, False],
        [ True,  True, F

tensor([[0.5236, 0.2350, 0.2413, 0.0000],
        [0.2534, 0.5926, 0.1541, 0.0000],
        [0.4537, 0.2687, 0.2775, 0.0000],
        [0.2740, 0.5073, 0.2187, 0.0000],
        [0.5842, 0.4158, 0.0000, 0.0000],
        [0.3829, 0.6171, 0.0000, 0.0000],
        [0.5999, 0.4001, 0.0000, 0.0000],
        [0.4191, 0.5809, 0.0000, 0.0000]])

#### 1.2 decoder_layer中的第一层self-attention
例如序列长度为3, 最大长度为4，下三角矩阵:
$$
\begin{pmatrix}
1 & 0 & 0 & 0 \\
1 & 1 & 0 & 0 \\
1 & 1 & 1 & 0 \\
1 & 1 & 1 & 1
\end{pmatrix}
$$
因为序列长度为3，实际的mask为:
$$
\begin{pmatrix}
1 & 0 & 0 & 0 \\
1 & 1 & 0 & 0 \\
1 & 1 & 1 & 0 \\
0 & 0 & 0 & 0
\end{pmatrix}
$$
但这里其实没有必要，最后计算loss的时候会有一个loss掩码消除，后续会说明

In [72]:
# test
torch.manual_seed(46)
"""
输入矩阵：
    batch_size * max_seq_len * dim
    batch_size=2, max_seq_len=4, dim=8
有效长度矩阵：
    batch_size *
    batch_size=2
"""
X = torch.rand((2, 4, 8))
valid_len = torch.randint(1, 5, (2, ))
print(X, X.shape)
print(valid_len, valid_len.shape)
print()

batch_size = X.shape[0]
max_len = X.shape[1]
dec_valid_len = torch.arange(1, max_len + 1).repeat(batch_size)
print(dec_valid_len)
mask = torch.arange(max_len)[None, :] < dec_valid_len[:, None]
print(mask)
print()

# test
Q, K, V = X, X, X
score = torch.bmm(Q, K.permute(0, 2, 1)).reshape(-1, score.shape[-1])
# print(score, score.shape)

print(score.masked_fill_(~mask, -1e6))
F.softmax(score, -1)

tensor([[[0.6611, 0.0600, 0.5174, 0.1596, 0.7550, 0.8390, 0.0674, 0.4631],
         [0.1477, 0.3597, 0.9328, 0.0170, 0.9736, 0.4108, 0.8620, 0.8799],
         [0.6569, 0.8152, 0.4810, 0.7388, 0.0312, 0.7049, 0.7364, 0.1079],
         [0.1455, 0.2633, 0.9035, 0.6618, 0.9728, 0.9471, 0.8585, 0.9694]],

        [[0.6430, 0.4919, 0.3397, 0.7519, 0.0770, 0.1563, 0.7086, 0.5063],
         [0.2131, 0.3311, 0.7764, 0.2493, 0.1992, 0.9874, 0.2860, 0.0898],
         [0.1783, 0.0602, 0.5747, 0.9875, 0.1572, 0.1534, 0.7301, 0.7916],
         [0.6019, 0.7746, 0.4704, 0.7769, 0.8160, 0.4427, 0.1632, 0.6080]]]) torch.Size([2, 4, 8])
tensor([3, 4]) torch.Size([2])

tensor([1, 2, 3, 4, 1, 2, 3, 4])
tensor([[ True, False, False, False],
        [ True,  True, False, False],
        [ True,  True,  True, False],
        [ True,  True,  True,  True],
        [ True, False, False, False],
        [ True,  True, False, False],
        [ True,  True,  True, False],
        [ True,  True,  True,  True]])

ten

tensor([[1.0000, 0.0000, 0.0000, 0.0000],
        [0.1816, 0.8184, 0.0000, 0.0000],
        [0.1588, 0.2223, 0.6189, 0.0000],
        [0.0728, 0.2366, 0.0691, 0.6214],
        [1.0000, 0.0000, 0.0000, 0.0000],
        [0.3196, 0.6804, 0.0000, 0.0000],
        [0.3217, 0.1412, 0.5370, 0.0000],
        [0.1909, 0.1251, 0.1764, 0.5076]])

#### decoder输出掩码

In [95]:
# test
torch.manual_seed(46)
"""
输入矩阵：
    batch_size * max_seq_len 
    batch_size=2, max_seq_len=4, dim=8
有效长度矩阵：
    batch_size *
    batch_size=2
"""
y_hat = torch.rand((2, 4, 8))
y = torch.randint(0, 4, (2, 4))
out_valid_len = torch.randint(1, 5, (2, ))
print(y_hat, y_hat.shape)
print(y, y.shape)
print(out_valid_len, out_valid_len.shape)
print()

loss = nn.CrossEntropyLoss(reduction="none")
l = loss(y_hat.permute(0, 2, 1), y)
print(l)
print()

max_len = y_hat.shape[1]
mask = torch.arange(max_len)[None, :] < out_valid_len[:, None]
print(mask)
mask_one = torch.ones_like(y)
mask_one[~mask] = 0
l = l * mask_one
print(l, l.mean())

tensor([[[0.6611, 0.0600, 0.5174, 0.1596, 0.7550, 0.8390, 0.0674, 0.4631],
         [0.1477, 0.3597, 0.9328, 0.0170, 0.9736, 0.4108, 0.8620, 0.8799],
         [0.6569, 0.8152, 0.4810, 0.7388, 0.0312, 0.7049, 0.7364, 0.1079],
         [0.1455, 0.2633, 0.9035, 0.6618, 0.9728, 0.9471, 0.8585, 0.9694]],

        [[0.6430, 0.4919, 0.3397, 0.7519, 0.0770, 0.1563, 0.7086, 0.5063],
         [0.2131, 0.3311, 0.7764, 0.2493, 0.1992, 0.9874, 0.2860, 0.0898],
         [0.1783, 0.0602, 0.5747, 0.9875, 0.1572, 0.1534, 0.7301, 0.7916],
         [0.6019, 0.7746, 0.4704, 0.7769, 0.8160, 0.4427, 0.1632, 0.6080]]]) torch.Size([2, 4, 8])
tensor([[2, 3, 1, 0],
        [2, 3, 2, 1]]) torch.Size([2, 4])
tensor([2, 4]) torch.Size([2])

tensor([[2.0437, 2.6967, 1.8350, 2.6926],
        [2.2255, 2.2697, 2.0158, 1.9066]])

tensor([[ True,  True, False, False],
        [ True,  True,  True,  True]])
tensor([[2.0437, 2.6967, 0.0000, 0.0000],
        [2.2255, 2.2697, 2.0158, 1.9066]]) tensor(1.6447)


#### 1.4 上述掩码整合并测试