In [14]:
## 导包和初始化
import torch
from torch import nn
import torch.nn.functional as F
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib as mpl
from IPython.display import Image
# Render figures at 150 dpi (matplotlib's stock value is 100).
mpl.rcParams.update({'figure.dpi': 150})
# Seed torch's global RNG so the random tensors created below are repeatable
# when the notebook is run top to bottom.
torch.manual_seed(42)

---
# 1. 编码器

In [None]:
## Hyper-parameters of the toy encoder
batch_size, seq_len = 1, 3   # one sequence made of three tokens
d_model = 4                  # model (embedding) dimension
nhead = 2                    # number of attention heads
dim_feedforward = 8          # hidden width of the feed-forward network

## d_model is divided among the heads, so it has to be a multiple of nhead
assert d_model % nhead == 0

In [15]:
## Build a random encoder input with axes ordered (seq_len, batch_size, d_model)
input_shape = (seq_len, batch_size, d_model)
encoder_input = torch.randn(input_shape)
# Show the values first, then the shape, one per line.
print(encoder_input, encoder_input.shape, sep="\n")

tensor([[[ 0.3367,  0.1288,  0.2345,  0.2303]],

        [[-1.1229, -0.1863,  2.2082, -0.6380]],

        [[ 0.4617,  0.2674,  0.5349,  0.8094]]])
torch.Size([3, 1, 4])


In [16]:
## Reference result: run the input through PyTorch's packaged encoder layer
encoder_layer = nn.TransformerEncoderLayer(
    d_model=d_model,
    nhead=nhead,
    dim_feedforward=dim_feedforward,
    dropout=0.0,  # dropout disabled
)
memory = encoder_layer(encoder_input)  # encoder output

# Values first, then the shape, one per line.
print(memory, memory.shape, sep="\n")

tensor([[[-1.0328, -0.9185,  0.6710,  1.2804]],

        [[-1.4175, -0.1948,  1.3775,  0.2347]],

        [[-1.0022, -0.8035,  0.3029,  1.5028]]],
       grad_fn=<NativeLayerNormBackward0>)
torch.Size([3, 1, 4])


In [17]:
## Hand-written encoder, step 1: multi-head self-attention
### 1.1. Flatten the input
X = encoder_input
X_flat = X.contiguous().view(-1, d_model)  # [seq_len * batch_size, d_model]
# A bare `X_flat.shape` in the middle of a cell displays nothing, so the
# expected shape is checked explicitly instead.
assert X_flat.shape == (seq_len * batch_size, d_model)

### 1.2. Fetch the attention parameters from the packaged layer
self_attn = encoder_layer.self_attn
W_in = self_attn.in_proj_weight    # input projection; its output is split into Q, K, V below
b_in = self_attn.in_proj_bias
W_out = self_attn.out_proj.weight  # output projection applied after the heads are merged
b_out = self_attn.out_proj.bias

### 1.3. Compute Q, K, V with a single linear map, then split along the feature axis
QKV = F.linear(X_flat, W_in, b_in)   # [seq_len * batch_size, 3 * d_model]
Q, K, V = QKV.split(d_model, dim=1)  # each chunk is [seq_len * batch_size, d_model]

### 1.4. Width of each attention head
head_dim = d_model // nhead
def reshape_for_heads(x, num_tokens=None, bsz=None, num_heads=None):
    """Split projected activations into per-head slices for batched attention.

    The sequence length and head width are derived from ``x`` itself rather
    than from notebook-level globals, so the same helper works for tensors of
    different lengths (e.g. queries and keys of unequal length).

    Args:
        x: Tensor holding one feature vector per (token, batch) pair, ordered
            token-major, e.g. ``[T * B, d_model]`` or ``[T, B, d_model]``.
            Its last dimension must be divisible by ``num_heads``.
        num_tokens: Sequence length ``T``. Inferred from the number of
            elements in ``x`` when omitted.
        bsz: Batch size ``B``. Falls back to the notebook-level ``batch_size``.
        num_heads: Number of heads. Falls back to the notebook-level ``nhead``.

    Returns:
        Tensor of shape ``[bsz * num_heads, num_tokens, feature_dim // num_heads]``
        where row ``b * num_heads + h`` holds head ``h`` of batch element ``b``.
    """
    if bsz is None:
        bsz = batch_size
    if num_heads is None:
        num_heads = nhead
    feature_dim = x.shape[-1]
    if num_tokens is None:
        # Works for both the flattened [T * B, d] and the unflattened [T, B, d] layout.
        num_tokens = x.numel() // (bsz * feature_dim)
    per_head = feature_dim // num_heads
    return (
        x.contiguous()
        .view(num_tokens, bsz, num_heads, per_head)
        .permute(1, 2, 0, 3)  # -> [B, heads, T, per_head]
        .reshape(bsz * num_heads, num_tokens, per_head)
    )
# Split Q, K and V into heads: each becomes [batch_size * nhead, seq_len, head_dim]
Q, K, V = (reshape_for_heads(projected) for projected in (Q, K, V))

### 1.5. Scaled dot-product scores, then softmax over the key axis
scale = head_dim ** 0.5
scores = torch.bmm(Q, K.transpose(-2, -1)) / scale  # [batch_size * nhead, seq_len, seq_len]
attn_weights = torch.softmax(scores, dim=-1)        # [batch_size * nhead, seq_len, seq_len]

### 1.6. Attention output: weighted sum of the values
per_head_output = torch.bmm(attn_weights, V)  # [batch_size * nhead, seq_len, head_dim]

### 1.7. Merge the heads back into one feature vector per token
merged = (
    per_head_output
    .view(batch_size, nhead, seq_len, head_dim)
    .permute(2, 0, 1, 3)  # -> [seq_len, batch_size, nhead, head_dim]
    .contiguous()
    .view(seq_len * batch_size, d_model)
)

### 1.8. Output projection, reshaped back to [seq_len, batch_size, d_model]
attn_output = F.linear(merged, W_out, b_out).view(seq_len, batch_size, d_model)

In [18]:
## Hand-written encoder, step 2: residual connection and layer normalisation
### 2.1. Fetch the layer-norm weight and bias
norm1 = encoder_layer.norm1
print(norm1.weight,norm1.bias)

### 2.2. Residual connection followed by layer normalisation
residual = X + attn_output  # [seq_len, batch_size, d_model]
# eps is read from the module rather than left at the functional default, so the
# hand-written path keeps matching if the layer is built with another layer_norm_eps.
normalized = F.layer_norm(residual, (d_model,), weight=norm1.weight, bias=norm1.bias,
                          eps=norm1.eps)  # [seq_len, batch_size, d_model]
print(normalized)  # normalisation runs over the last axis, i.e. per token feature vector

Parameter containing:
tensor([1., 1., 1., 1.], requires_grad=True) Parameter containing:
tensor([0., 0., 0., 0.], requires_grad=True)
tensor([[[-0.9493, -1.0434,  1.1045,  0.8881]],

        [[-1.0025, -0.1531,  1.6511, -0.4955]],

        [[-1.0129, -0.9286,  0.6342,  1.3073]]],
       grad_fn=<NativeLayerNormBackward0>)


In [19]:
## Hand-written encoder, step 3: feed-forward network
### 3.1. Fetch the parameters of both linear layers
W_1, b_1 = encoder_layer.linear1.weight, encoder_layer.linear1.bias
W_2, b_2 = encoder_layer.linear2.weight, encoder_layer.linear2.bias

### 3.2. First linear map followed by ReLU
tokens = normalized.view(-1, d_model)            # [seq_len * batch_size, d_model]
hidden = torch.relu(F.linear(tokens, W_1, b_1))  # [seq_len * batch_size, dim_feedforward]

### 3.3. Second linear map, reshaped back to [seq_len, batch_size, d_model]
ffn_output = F.linear(hidden, W_2, b_2).view(seq_len, batch_size, d_model)

In [20]:
## Hand-written encoder, step 4: second residual connection and layer normalisation
### 4.1. Fetch the layer-norm weight and bias
norm2 = encoder_layer.norm2
print(norm2.weight,norm2.bias)

### 4.2. Residual connection followed by layer normalisation
residual2 = normalized + ffn_output  # [seq_len, batch_size, d_model]
# eps is read from the module rather than left at the functional default, so the
# hand-written path keeps matching if the layer is built with another layer_norm_eps.
normalized2 = F.layer_norm(residual2, (d_model,), weight=norm2.weight, bias=norm2.bias,
                           eps=norm2.eps)  # [seq_len, batch_size, d_model]

Parameter containing:
tensor([1., 1., 1., 1.], requires_grad=True) Parameter containing:
tensor([0., 0., 0., 0.], requires_grad=True)


In [22]:
## Final check: the hand-written encoder should reproduce the packaged layer
### Hand-written encoder output
print(normalized2)
### Packaged encoder output
print(memory)
# The printed tensors are rounded, so equality is verified numerically as well.
# A failure here also flags stale state, e.g. `memory` recomputed by a later cell.
assert torch.allclose(normalized2, memory, atol=1e-5), \
    "hand-written encoder output differs from nn.TransformerEncoderLayer"

tensor([[[-1.0328, -0.9185,  0.6710,  1.2804]],

        [[-1.4175, -0.1948,  1.3775,  0.2347]],

        [[-1.0022, -0.8035,  0.3029,  1.5028]]],
       grad_fn=<NativeLayerNormBackward0>)
tensor([[[-1.0328, -0.9185,  0.6710,  1.2804]],

        [[-1.4175, -0.1948,  1.3775,  0.2347]],

        [[-1.0022, -0.8035,  0.3029,  1.5028]]],
       grad_fn=<NativeLayerNormBackward0>)


---
# 2. 解码器

In [24]:
## Hyper-parameters for the decoder section
d_model = 4          # model (embedding) dimension
nhead = 2            # number of attention heads
dim_feedforward = 8  # hidden width of the feed-forward network
batch_size = 1

src_seq_len = 3      # length of the encoder (source) sequence
trg_seq_len = 5      # length of the decoder (target) sequence

## d_model and nhead are redefined here, so the divisibility requirement is re-checked
assert d_model % nhead == 0, "d_model must be divisible by nhead"

In [25]:
# Random source and target sequences, both laid out as [T, B, d]
src_shape = (src_seq_len, batch_size, d_model)
trg_shape = (trg_seq_len, batch_size, d_model)
# The encoder input is drawn before the decoder input.
encoder_input = torch.randn(src_shape)
decoder_input = torch.randn(trg_shape)

In [27]:
## Reference result: run the packaged encoder and decoder layers
layer_config = dict(
    d_model=d_model,
    nhead=nhead,
    dim_feedforward=dim_feedforward,
    dropout=0.0,  # dropout disabled
)
# The encoder layer is constructed before the decoder layer.
encoder_layer = nn.TransformerEncoderLayer(**layer_config)
decoder_layer = nn.TransformerDecoderLayer(**layer_config)

memory = encoder_layer(encoder_input)          # encoder output
output = decoder_layer(decoder_input, memory)  # decoder output; no masks are passed
print(output)
print(encoder_input.shape, memory.shape, output.shape)

tensor([[[-0.1402, -0.7298, -0.8038,  1.6738]],

        [[-0.7154,  1.6404, -0.8922, -0.0327]],

        [[ 0.0890,  0.6641, -1.6547,  0.9016]],

        [[ 0.2811, -1.0920,  1.5008, -0.6900]],

        [[-0.4675,  1.5222, -1.2012,  0.1465]]],
       grad_fn=<NativeLayerNormBackward0>)
torch.Size([3, 1, 4]) torch.Size([3, 1, 4]) torch.Size([5, 1, 4])


In [None]:
## 手写解码器——1.实现多头注意力机制
