## Transformer 实现

In [28]:
import torch
from torch import nn
import math
from d2l import torch as d2l

### 1. FNN 基于位置的前馈神经网络

In [14]:
class PositionWiseFNN(nn.Module):
    """Position-wise feed-forward network.

    Applies the same two-layer MLP (Linear -> ReLU -> Linear) to the last
    dimension of the input, so every position is transformed independently
    with shared weights.

    Args:
        ffn_num_input: size of the input's last dimension.
        ffn_num_hidden: width of the hidden layer.
        ffn_num_outputs: size of the output's last dimension.
        **kwargs: forwarded to ``nn.Module.__init__``.
    """

    def __init__(self, ffn_num_input, ffn_num_hidden, ffn_num_outputs, **kwargs):
        super().__init__(**kwargs)
        # Creation order (dense1, relu, dense2) is kept so seeded weight
        # initialization and state_dict keys stay the same.
        self.dense1 = nn.Linear(ffn_num_input, ffn_num_hidden)
        self.relu = nn.ReLU()
        self.dense2 = nn.Linear(ffn_num_hidden, ffn_num_outputs)

    def forward(self, X):
        """Return ``dense2(relu(dense1(X)))``."""
        hidden = self.relu(self.dense1(X))
        return self.dense2(hidden)

In [None]:
## FFN test case: a (2, 3, 4) batch of ones should come out as (2, 3, 8).
ffn = PositionWiseFNN(ffn_num_input=4, ffn_num_hidden=4, ffn_num_outputs=8)
ffn.eval()  # switch to inference mode before the check
test_input = torch.ones(2, 3, 4)
ffn(test_input)

### 2. Add & Norm block

In [16]:
class AddNorm(nn.Module):
    """Residual connection followed by layer normalization.

    Computes ``LayerNorm(Dropout(Y) + X)``, where ``X`` is a sub-layer's
    input and ``Y`` is that sub-layer's output.

    Args:
        normalized_shape: trailing shape normalized by ``nn.LayerNorm``.
        dropout: dropout probability applied to ``Y`` before the addition.
        **kwargs: forwarded to ``nn.Module.__init__``.
    """

    def __init__(self, normalized_shape, dropout, **kwargs):
        super().__init__(**kwargs)
        self.dropout = nn.Dropout(dropout)
        self.ln = nn.LayerNorm(normalized_shape)

    def forward(self, X, Y):
        """Add the (dropped-out) sub-layer output to its input, then normalize.

        Args:
            X: the sub-layer's input (the residual branch).
            Y: the sub-layer's output; must be addable to ``X``.
        """
        residual_sum = self.dropout(Y) + X
        return self.ln(residual_sum)
  

In [None]:
## AddNorm test case (eval mode, so dropout is the identity).
add_norm = AddNorm([3, 4], 0.1)
add_norm.eval()
# Named `norm_input` rather than `input` so the builtin is not shadowed
# for the rest of the notebook.
norm_input = torch.ones(2, 3, 4)
# Call the module itself instead of re-spelling its forward pass by hand.
# NOTE: with constant inputs every normalized (3, 4) slice has zero variance,
# so the result is all zeros; the output shape is the meaningful check here.
add_norm(norm_input, norm_input)

### 3. MultiHeadAttention

In [47]:
def sequence_mask(X, valid_lens, value=0):
    """Overwrite positions past each row's valid length, in place.

    For each row ``i`` along dim 0, every position ``j`` along dim 1 with
    ``j >= valid_lens[i]`` is set to ``value``. If ``X`` has more than two
    dimensions, the whole trailing slice ``X[i, j]`` is overwritten.

    Args:
        X: tensor with at least 2 dims; dim 1 indexes sequence positions.
            Modified in place.
        valid_lens: 1D tensor with one valid length per row of ``X``.
        value: fill value for the masked positions.

    Returns:
        ``X`` itself (the same tensor object), after masking.
    """
    max_len = X.size(1)
    positions = torch.arange(max_len, dtype=torch.float32, device=X.device)
    # (rows, max_len) boolean table: True where the position is kept.
    keep = positions.unsqueeze(0) < valid_lens.unsqueeze(1)
    X[~keep] = value
    return X

In [52]:
def masked_softmax(X, valid_lens):
    """Softmax over the last axis, ignoring elements past each valid length.

    Elements outside the valid range are replaced by a very large negative
    value (-1e6) before the softmax, so their probabilities are effectively
    zero. The caller's ``X`` is not modified.

    Args:
        X: 3D tensor of scores; the softmax is taken over its last axis.
        valid_lens: ``None`` (plain softmax, no masking); a 1D tensor with one
            length per entry along dim 0 of ``X``, shared by all of that
            entry's rows along dim 1; or a 2D tensor with one length per
            (dim 0, dim 1) row.

    Returns:
        Tensor with the same shape as ``X``.
    """
    if valid_lens is None:
        return nn.functional.softmax(X, dim=-1)

    shape = X.shape
    if valid_lens.dim() == 1:
        # Repeat each entry's length once for every row along dim 1.
        valid_lens = torch.repeat_interleave(valid_lens, shape[1])
    else:
        valid_lens = valid_lens.reshape(-1)

    # Flatten to 2D rows so sequence_mask masks along the last axis.
    # `reshape` may return a view sharing storage with the caller's X, and
    # sequence_mask writes in place. Clone first so the input is left
    # untouched. This also keeps autograd working when X is a leaf that
    # requires grad, where an in-place write into its view would raise.
    rows = X.reshape(-1, shape[-1]).clone()
    rows = sequence_mask(rows, valid_lens, value=-1e6)
    return nn.functional.softmax(rows.reshape(shape), dim=-1)

# masked_softmax(torch.rand(2, 2, 4), torch.tensor([2, 3]))