In [29]:
import numpy as np
import torch
import torch.nn as nn

In [6]:
transformer_model = nn.Transformer(nhead=16, num_encoder_layers=12)
src = torch.rand((10, 32, 512))
tgt = torch.rand((20, 32, 512))
out = transformer_model(src, tgt)  # No positional encoding is added and no masks are applied; you must supply both yourself, otherwise this is not the Transformer you have in mind.
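The comment above notes that nn.Transformer adds no positional information to its inputs. A minimal sketch of one common remedy, assuming the fixed sinusoidal encoding from "Attention Is All You Need" (the helper name `sinusoidal_positional_encoding` is ours, not a PyTorch API, and d_model is assumed to be even), adds a precomputed table to the (S, N, E) tensors before they enter the model:

```python
import math
import torch

def sinusoidal_positional_encoding(seq_len, d_model):
    # Hypothetical helper (not part of PyTorch): returns a (seq_len, 1, d_model)
    # table that broadcasts over the batch dimension of an (S, N, E) input.
    # Assumes d_model is even.
    position = torch.arange(seq_len, dtype=torch.float).unsqueeze(1)  # (seq_len, 1)
    div_term = torch.exp(
        torch.arange(0, d_model, 2, dtype=torch.float) * (-math.log(10000.0) / d_model)
    )                                                                 # (d_model / 2,)
    pe = torch.zeros(seq_len, d_model)
    pe[:, 0::2] = torch.sin(position * div_term)  # even feature indices
    pe[:, 1::2] = torch.cos(position * div_term)  # odd feature indices
    return pe.unsqueeze(1)

src = torch.rand((10, 32, 512))
tgt = torch.rand((20, 32, 512))
# Add position information before calling transformer_model(src, tgt, ...).
src = src + sinusoidal_positional_encoding(10, 512)
tgt = tgt + sinusoidal_positional_encoding(20, 512)
```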

# Implementing masks
The forward function of torch.nn.Transformer accepts masks:<br/>
- Examples: output = transformer_model(src, tgt, src_mask=src_mask, tgt_mask=tgt_mask)
    
Docstring:
Take in and process masked source/target sequences.

Args:<br/> 
    - src: the sequence to the encoder (required).
    - tgt: the sequence to the decoder (required).
    - src_mask: the additive mask for the src sequence (optional).
    - tgt_mask: the additive mask for the tgt sequence (optional).
    - memory_mask: the additive mask for the encoder output (optional).
    - src_key_padding_mask: the ByteTensor mask for src keys per batch (optional).
    - tgt_key_padding_mask: the ByteTensor mask for tgt keys per batch (optional).
    - memory_key_padding_mask: the ByteTensor mask for memory keys per batch (optional).

Shape:
    - src: :math:`(S, N, E)`.
    - tgt: :math:`(T, N, E)`.
    - src_mask: :math:`(S, S)`.
    - tgt_mask: :math:`(T, T)`.
    - memory_mask: :math:`(T, S)`.
    - src_key_padding_mask: :math:`(N, S)`.
    - tgt_key_padding_mask: :math:`(N, T)`.
    - memory_key_padding_mask: :math:`(N, S)`.

    Note: [src/tgt/memory]_mask ensures that position i is allowed to attend the unmasked
    positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend
    while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True``
    are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
    is provided, it will be added to the attention weight. 
    [src/tgt/memory]_key_padding_mask provides specified elements in the key to be ignored by
    the attention. If a ByteTensor is provided, the non-zero positions will be ignored while the zero
    positions will be unchanged. If a BoolTensor is provided, the positions with the
    value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.

    - output: :math:`(T, N, E)`.

    # The decoder output has the same shape as the decoder input
    Note: Due to the multi-head attention architecture in the transformer model,
    the output sequence length of a transformer is same as the input sequence
    (i.e. target) length of the decoder.

    where S is the source sequence length, T is the target sequence length, N is the
    batch size, E is the feature number

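src refers to the encoder input, tgt to the decoder input, and memory to the encoder output, which the second sub-layer of every decoder block attends to in cross-attention.<br/>
\*_key_padding_mask: masks out `<PAD>` positions so that the padding tokens' embeddings are not attended to.<br/>
src_mask/tgt_mask/memory_mask: additional masks that block attention to specified positions. During generation, tgt_mask is what implements the causal mask that hides future positions.<br/>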
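As a sketch of how these masks are typically derived from integer token ids (the padding id `PAD_IDX = 0` and the toy ids are assumptions for illustration), the padding masks come from comparing ids with the padding id, and the causal tgt_mask is a boolean upper-triangular matrix using the BoolTensor convention, where True means "not allowed to attend":

```python
import torch

PAD_IDX = 0  # hypothetical padding id

# (S, N) source ids and (T, N) target ids, sequence-first like nn.Transformer's default.
src_ids = torch.tensor([[5, 7, 9],
                        [6, 8, 0],
                        [4, 0, 0]])  # S=3, N=3
tgt_ids = torch.tensor([[1, 1, 1],
                        [2, 3, 0]])  # T=2, N=3

# Key padding masks are (N, S) and (N, T): True marks a <PAD> key to ignore.
src_key_padding_mask = (src_ids == PAD_IDX).transpose(0, 1)
tgt_key_padding_mask = (tgt_ids == PAD_IDX).transpose(0, 1)
# The decoder's cross-attention keys are the encoder outputs, so the src mask is reused.
memory_key_padding_mask = src_key_padding_mask

# tgt_mask is (T, T): True strictly above the diagonal forbids attending to future positions.
T = tgt_ids.size(0)
tgt_mask = torch.triu(torch.ones(T, T, dtype=torch.bool), diagonal=1)
```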
    
\*_mask corresponds to the attn_mask argument and \*_key_padding_mask to the key_padding_mask argument of torch.nn.modules.activation.MultiheadAttention.forward:<br/>

def forward(self, query, key, value, key_padding_mask=None,
                need_weights=True, attn_mask=None):
        # type: (Tensor, Tensor, Tensor, Optional[Tensor], bool, Optional[Tensor]) -> Tuple[Tensor, Optional[Tensor]]
        r"""
    Args:
        query, key, value: map a query and a set of key-value pairs to an output.
            See "Attention Is All You Need" for more details.
        key_padding_mask: if provided, specified padding elements in the key will
            be ignored by the attention. When given a binary mask and a value is True,
            the corresponding value on the attention layer will be ignored. When given
            a byte mask and a value is non-zero, the corresponding value on the attention
            layer will be ignored
        need_weights: output attn_output_weights.
        attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
            the batches while a 3D mask allows to specify a different mask for the entries of each batch.
    Shape:
        - Inputs:
        - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
          the embedding dimension.
        - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
          the embedding dimension.
        - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
          the embedding dimension.
        - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
          If a ByteTensor is provided, the non-zero positions will be ignored while the position
          with the zero positions will be unchanged. If a BoolTensor is provided, the positions with the
          value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.
        - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
          3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,
          S is the source sequence length. attn_mask ensure that position i is allowed to attend the unmasked
          positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend
          while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True``
          is not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
          is provided, it will be added to the attention weight.

        - Outputs:
        - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
          E is the embedding dimension.
        - attn_output_weights: :math:`(N, L, S)` where N is the batch size,
          L is the target sequence length, S is the source sequence length.
        """

## \*_key_padding_mask

For example, let batch_size be 3 and sequence_length be 4:
```python
[
    ['a','b','c','<PAD>'],
    ['a','b','c','d'],
    ['a','b','<PAD>','<PAD>']
]
```
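When computing self-attention (in the encoder or the decoder), we do not want 'a', 'b', 'c', etc. to use any information from `<PAD>`, so we hide the padding positions with key_padding_mask.<br/>
key_padding_mask has shape (batch_size, sequence_length); for this example it is: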
```python

[
    [False, False, False, True],
    [False, False, False, False],
    [False, False, True, True]
]
```
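key_padding_mask essentially masks the padded key positions: their attention scores become -inf before the softmax, so they receive zero attention weight. The PAD token itself, however, still takes part in the qkv computation as a query. Take the third position of the third row: its q is the embedding of `<PAD>`, its k and v come from the first position 'a' and the second position 'b', and it still produces an output embedding, because its softmax weights over 'a' and 'b' are non-zero. <font color='red'> So when you compute the loss on the transformer's final output during training, you must also set ignore_index=pad_index (the index used for padding). </font>For the third row, the target labels are [3205,1890,0,0] with pad_index=0. This way, even though the `<PAD>` positions freely attend to the meaningful positions and produce embeddings, their loss is never counted, so whatever they output has no effect.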
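The claim that the `<PAD>` query still produces an output can be checked with nn.MultiheadAttention. A minimal sketch, with random vectors standing in for the embeddings of the third row `['a','b','<PAD>','<PAD>']` (the embedding size and head count are arbitrary):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
seq_len, batch_size, embed_dim = 4, 1, 8
mha = nn.MultiheadAttention(embed_dim, num_heads=2)

x = torch.rand(seq_len, batch_size, embed_dim)                  # stands in for ['a','b','<PAD>','<PAD>']
key_padding_mask = torch.tensor([[False, False, True, True]])  # (N, S)

out, weights = mha(x, x, x, key_padding_mask=key_padding_mask)  # self-attention

# No query, not even a <PAD> query, puts any weight on the <PAD> keys...
print(weights[0, :, 2:])
# ...but the first <PAD> query still attends to 'a' and 'b' and gets a non-zero output.
print(weights[0, 2, :2], out[2, 0])
```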

In [24]:
import torch
import torch.nn.functional as F
pred = []
pred.append([0.9, 0.1])
pred.append([0.8, 0.2])
pred = torch.Tensor(pred).view(-1,  2)

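# The classes are 0 or 1; a label of -1 means the sample does not contribute to the loss.
# When averaging (reduction='mean'), only the non-ignored samples are counted: the batch size here is 2, but row index 1 has label -1 and is excluded from the loss.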
label = torch.LongTensor([[1], [-1]])  

# out = F.cross_entropy(pred.view(-1, 2), label.view(-1, )) 
out = F.cross_entropy(pred.view(-1, 2), label.view(-1, ), ignore_index=-1) 
print(out)


tensor(1.1711)
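Applied to a transformer's output, the same idea looks roughly like this sketch. The vocabulary size and the random logits are assumptions; the labels [3205, 1890, 0, 0] with pad_index = 0 are the third-row example from the text. The (T, N, vocab) scores are flattened to (T*N, vocab) because F.cross_entropy expects class scores in dimension 1:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
T, N, vocab_size, pad_index = 4, 1, 4000, 0

logits = torch.randn(T, N, vocab_size)              # e.g. a linear layer on top of the (T, N, E) output
labels = torch.tensor([[3205], [1890], [0], [0]])   # (T, N): the third row's targets, padded with 0

loss = F.cross_entropy(logits.view(-1, vocab_size), labels.view(-1), ignore_index=pad_index)

# Equivalent manual computation: average only over the two non-pad positions.
log_probs = F.log_softmax(logits.view(-1, vocab_size), dim=-1)
manual = -(log_probs[0, 3205] + log_probs[1, 1890]) / 2
print(loss, manual)
```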


## \*_mask corresponds to attn_mask

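Using the same example, take the first row `['a','b','c','<PAD>']` (suppose we are using the decoder for generation and looking at <font color='red'>the first layer of a block</font>, i.e. the self-attention layer). Then:
- 'a' can see 'a'
- 'b' can see 'a', 'b'
- 'c' can see 'a', 'b', 'c'
- `'<PAD>'` in principle should not see anything, but as long as its target label is ignore_index this does not matter, so we let it see 'a', 'b', 'c', `'<PAD>'`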
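The list above is exactly what a causal (subsequent) mask encodes. A small sketch that reads the visible tokens for each query position off a boolean mask (True = not allowed to attend):

```python
import torch

tokens = ['a', 'b', 'c', '<PAD>']
sz = len(tokens)
# True strictly above the diagonal: position i may not attend to any j > i.
causal = torch.triu(torch.ones(sz, sz, dtype=torch.bool), diagonal=1)

visible = {}
for i, tok in enumerate(tokens):
    visible[tok] = [tokens[j] for j in range(sz) if not causal[i, j]]
    print(tok, '->', visible[tok])
```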

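For attn_mask, the shape is (L, S) in the 2D case and (N*num_heads, L, S) in the 3D case. Here q, k and v in the decoder's first layer all come from the same sequence, so L = S and the mask is square.<br/>
torch.nn.Transformer.generate_square_subsequent_mask implements this mask.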

In [27]:
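def generate_square_subsequent_mask(sz):
    # Lower triangle (including the diagonal) is True: position i may attend to positions 0..i
    mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
    # Additive float mask: -inf blocks attention to future positions, 0.0 leaves the score unchanged
    mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
    return mask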
generate_square_subsequent_mask(4)

tensor([[0., -inf, -inf, -inf],
        [0., 0., -inf, -inf],
        [0., 0., 0., -inf],
        [0., 0., 0., 0.]])
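The 3D form (N*num_heads, L, S) is what allows a different mask per sample. A sketch, assuming the toy batch from the key_padding_mask section and arbitrary embedding sizes, that folds each sample's padding into the causal mask and expands it per head. The padding is put into attn_mask here only to show the 3D layout; passing it as key_padding_mask works just as well:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
L, N, E, num_heads = 4, 3, 8, 2
mha = nn.MultiheadAttention(E, num_heads)
x = torch.rand(L, N, E)

# Per-sample padding, as in the toy batch: True = <PAD>.
pad = torch.tensor([[False, False, False, True],
                    [False, False, False, False],
                    [False, False, True, True]])                  # (N, L)
causal = torch.triu(torch.ones(L, L, dtype=torch.bool), diagonal=1)  # (L, L)

# One mask per sample: (N, L, S) with S == L for self-attention.
per_sample = causal.unsqueeze(0) | pad.unsqueeze(1)               # (N, L, L)
# A 3D attn_mask must be (N*num_heads, L, S); row b*num_heads + h belongs to sample b.
attn_mask = per_sample.repeat_interleave(num_heads, dim=0)

out, weights = mha(x, x, x, attn_mask=attn_mask)
print(attn_mask.shape, weights.shape)
```

repeat_interleave is used rather than repeat because PyTorch lays out the (batch, head) pairs batch-major, so rows b*num_heads to (b+1)*num_heads-1 all belong to sample b.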

In [54]:
transformer_model = nn.Transformer(nhead=16, num_encoder_layers=12)
src = torch.rand((10, 32, 512))
tgt = torch.rand((20, 32, 512))

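# True marks a <PAD> position. Use a BoolTensor: recent PyTorch adds float masks
# to the attention scores, so a 0/1 float mask would not hide anything.
src_key_padding_mask = torch.zeros((32, 10), dtype=torch.bool)
src_key_padding_mask[:, -1:] = True
src_key_padding_mask[3:7, -5:-1] = True

# generate_square_subsequent_mask is a staticmethod in recent PyTorch (an instance
# method in older versions); calling it on the instance works in both.
tgt_mask = transformer_model.generate_square_subsequent_mask(20)

# The decoder's cross-attention attends to the encoder output (memory),
# so the padded source positions must be masked there as well.
out = transformer_model(src, tgt, tgt_mask=tgt_mask,
                        src_key_padding_mask=src_key_padding_mask,
                        memory_key_padding_mask=src_key_padding_mask)  # masks applied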

In [56]:
out.size()

torch.Size([20, 32, 512])
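Putting the pieces together, a sketch of the full masked call built from token ids. `make_transformer_masks`, `PAD_IDX = 0`, the toy ids and the small model sizes are all assumptions for illustration, and positional encoding is still omitted:

```python
import torch
import torch.nn as nn

PAD_IDX = 0  # hypothetical padding id

def make_transformer_masks(src_ids, tgt_ids, pad_idx=PAD_IDX):
    """Hypothetical helper: build the masks for nn.Transformer from (S, N) / (T, N) ids."""
    T = tgt_ids.size(0)
    src_key_padding_mask = (src_ids == pad_idx).transpose(0, 1)            # (N, S)
    tgt_key_padding_mask = (tgt_ids == pad_idx).transpose(0, 1)            # (N, T)
    tgt_mask = torch.triu(torch.ones(T, T, dtype=torch.bool), diagonal=1)  # (T, T), True = future
    return dict(tgt_mask=tgt_mask,
                src_key_padding_mask=src_key_padding_mask,
                tgt_key_padding_mask=tgt_key_padding_mask,
                memory_key_padding_mask=src_key_padding_mask)

torch.manual_seed(0)
vocab_size, d_model = 50, 32
embed = nn.Embedding(vocab_size, d_model, padding_idx=PAD_IDX)
model = nn.Transformer(d_model=d_model, nhead=4, num_encoder_layers=2,
                       num_decoder_layers=2, dim_feedforward=64)

src_ids = torch.tensor([[5, 7], [6, 8], [4, 0]])          # (S=3, N=2), sample 1 padded
tgt_ids = torch.tensor([[1, 1], [2, 3], [9, 0], [0, 0]])  # (T=4, N=2), both samples padded

masks = make_transformer_masks(src_ids, tgt_ids)
out = model(embed(src_ids), embed(tgt_ids), **masks)
print(out.shape)  # torch.Size([4, 2, 32])
```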