## Transformer

Transformer架构

![](https://zh-v2.d2l.ai/_images/transformer.svg)

In [1]:
import sys
sys.path.append('..')
import math
import pandas as pd
import torch
from torch import nn
import d2l

In [2]:
class PositionWiseFFN(nn.Module):
  """Position-wise feed-forward network.

  Two ``nn.Linear`` layers with a ReLU in between. ``nn.Linear`` acts on the
  last dimension only, so every position is transformed by the same weights.

  Args:
    ffn_num_input: size of the last dimension of the input.
    ffn_num_hiddens: width of the hidden layer.
    ffn_num_outputs: size of the last dimension of the output.
  """

  def __init__(self, ffn_num_input, ffn_num_hiddens, ffn_num_outputs):
    super().__init__()
    # Attribute names double as state_dict keys, so they are kept stable.
    self.dense1 = nn.Linear(ffn_num_input, ffn_num_hiddens)
    self.relu = nn.ReLU()
    self.dense2 = nn.Linear(ffn_num_hiddens, ffn_num_outputs)

  def forward(self, X):
    """Return ``dense2(relu(dense1(X)))``."""
    hidden = self.dense1(X)
    activated = self.relu(hidden)
    return self.dense2(activated)


In [3]:
# Build an FFN with 4 input features, 4 hidden units and 8 outputs.
ffn = PositionWiseFFN(4, 4, 8)
ffn.eval()
# Feed a (2, 3, 4) tensor of ones and display the result for the first sample;
# all positions hold the same input, so the displayed rows are identical.
ffn(torch.ones((2, 3, 4)))[0]

tensor([[-0.0956, -0.1658, -0.8312, -0.0336,  0.7260,  0.4568, -0.3645, -0.3293],
        [-0.0956, -0.1658, -0.8312, -0.0336,  0.7260,  0.4568, -0.3645, -0.3293],
        [-0.0956, -0.1658, -0.8312, -0.0336,  0.7260,  0.4568, -0.3645, -0.3293]],
       grad_fn=<SelectBackward0>)

### 残差连接和层规范化

层规范化是基于特征维度进行规范化。在自然语言处理任务中（输入通常是变长序列），层规范化的效果通常比批量规范化好。

In [4]:
# Compare LayerNorm and BatchNorm1d on the same 2x2 input.
ln = nn.LayerNorm(2)
bn = nn.BatchNorm1d(2)
X = torch.tensor([[1, 2], [2, 3]], dtype=torch.float32)
# Compute the mean and variance of X in training mode
# NOTE(review): 'notrm' in the label below looks like a typo for 'norm'; it is
# a runtime string and the recorded cell output uses the same spelling.
print('layer notrm:', ln(X), '\nbatch norm:', bn(X))

layer notrm: tensor([[-1.0000,  1.0000],
        [-1.0000,  1.0000]], grad_fn=<NativeLayerNormBackward0>) 
batch norm: tensor([[-1.0000, -1.0000],
        [ 1.0000,  1.0000]], grad_fn=<NativeBatchNormBackward0>)


In [6]:
class AddNorm(nn.Module):
  """Residual connection followed by layer normalization.

  Args:
    normalized_shape: shape handed to ``nn.LayerNorm``.
    dropout: dropout probability applied to ``Y`` before the addition.
  """

  def __init__(self, normalized_shape, dropout):
    super().__init__()
    self.dropout = nn.Dropout(dropout)
    self.ln = nn.LayerNorm(normalized_shape)

  def forward(self, X, Y):
    """Return ``ln(dropout(Y) + X)``."""
    residual_sum = self.dropout(Y) + X
    return self.ln(residual_sum)

残差连接要求两个输入的形状相同，以便加法操作后输出张量的形状相同。

In [7]:
# normalized_shape [3, 4] matches the last two dimensions of the inputs below;
# 0.5 is the dropout probability.
add_norm = AddNorm([3, 4], 0.5)
add_norm.eval()
# Display the output shape for two (2, 3, 4) inputs of ones.
add_norm(torch.ones((2, 3, 4)), torch.ones((2, 3, 4))).shape

torch.Size([2, 3, 4])

### 编码器

In [None]:
class EncoderBlock(nn.Module):
  '''Transformer编码器块'''
  def __init__(self, key_size, query_size, value_size, num_hiddens, norm_shape, ffn_num_input, ffn_num_hiddens, num_heads, dropout, use_bias=False):
    super().__init__()
    self.attention = d2l.MultiHeadAttention(
      key_size, query_size, value_size, num_hiddens, num_heads, dropout, use_bias
    )
    self.add