In [1]:
# multi-head attention
import torch
from torch import nn
import torch.functional as F
import math
from torch import Tensor

In [2]:
X = torch.randn(128, 64, 512) 
# Batch Time（Sequence Length） Dimension(Embedding Dimention)(编码后的维度，想embeding映射到多少维度)
print(X.shape)

torch.Size([128, 64, 512])


In [3]:
d_model = 512 # 映射到QKV的多少维度
n_head = 8 # 头数

In [4]:
'''
class A:
    def method(self):
        print("A.method")

class B(A):
    def method(self):
        print("B.method")
        super().method()  # 调用A.method

class C(B):
    def method(self):
        print("C.method")
        super(B, self).method()  # 直接跳过B，调用A.method！
        # 而 super().method() 会调用 B.method

这个为什么会跳过B？
这个跳过B的原因是 super(B, self) 明确指定了从B之后开始查找MRO链。
MRO（方法解析顺序）C -> B -> A
1. super().method() 在C类中
class C(B):
    def method(self):
        print("C.method")
        super().method()  # 等价于 super(C, self).method()

从C之后开始查找MRO链：[B, A]

找到第一个有method的类：B

所以调用 B.method()

2. super(B, self).method() 在C类中
class C(B):
    def method(self):
        print("C.method")
        super(B, self).method()  # 明确指定跳过B
        
查找过程：

从B之后开始查找MRO链：[A]

直接跳过B

找到下一个有method的类：A

所以调用 A.method()
'''

'\nclass A:\n    def method(self):\n        print("A.method")\n\nclass B(A):\n    def method(self):\n        print("B.method")\n        super().method()  # 调用A.method\n\nclass C(B):\n    def method(self):\n        print("C.method")\n        super(B, self).method()  # 直接跳过B，调用A.method！\n        # 而 super().method() 会调用 B.method\n\n这个为什么会跳过B？\n这个跳过B的原因是 super(B, self) 明确指定了从B之后开始查找MRO链。\nMRO（方法解析顺序）C -> B -> A\n1. super().method() 在C类中\nclass C(B):\n    def method(self):\n        print("C.method")\n        super().method()  # 等价于 super(C, self).method()\n\n从C之后开始查找MRO链：[B, A]\n\n找到第一个有method的类：B\n\n所以调用 B.method()\n\n2. super(B, self).method() 在C类中\nclass C(B):\n    def method(self):\n        print("C.method")\n        super(B, self).method()  # 明确指定跳过B\n        \n查找过程：\n\n从B之后开始查找MRO链：[A]\n\n直接跳过B\n\n找到下一个有method的类：A\n\n所以调用 A.method()\n'

In [5]:
class MultiHearAttention(nn.Module):
    def __init__(self, d_model, n_head):
        super().__init__()
        self.n_head = n_head
        self.d_model = d_model
        # 初始化QKV用于映射向量
        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        # 因为是muti-head所以最后要做一个组合映射
        self.w_combine = nn.Linear(d_model, d_model)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, q: torch.Tensor, k: Tensor, v: Tensor, mask = None):
        batch, time, dimension = q.shape
        n_d = self.d_model // self.n_head # 整除，相当于每个头被分到的注意力的维度
        q, k, v = self.w_q(q),self.w_k(k), self.w_v(v)
        # 原始q的形状可能是：(batch, time, d_model)，其中 d_model = n_head * n_d
        # 理解q1，k1是行向量
        q = q.view(batch, time, self.n_head, n_d).permute(0, 2, 1, 3) # head_维不能放在最后两维，因为要用于处理
        k = k.view(batch, time, self.n_head, n_d).permute(0, 2, 1, 3)
        v = v.view(batch, time, self.n_head, n_d).permute(0, 2, 1, 3)

        score = q @ k.transpose(2, 3) / math.sqrt(n_d) # 除以是为了方差更小好归一化
        # 控制方差：点积的方差与 $d_k$ 成正比，除以 $\sqrt{d_k}$ 使方差保持为1，稳定softmax：
        # 防止softmax进入梯度极小的饱和区，改善梯度流：使训练更加稳定，防止梯度消失，
        # 保持注意力分布合理：避免注意力过于集中在少数位置（变成近乎one-hot）
        if mask is not None:
            # mask = torch.tril(torch.ones(time, time, dtype=bool)) # 生成下三角矩阵 这里开始采用传进来的mask
            score = score.masked_fill(mask == 0, float("-inf")) # masked 注意力分数
        
        score_weight: Tensor = self.softmax(score) # softmax 自动对最后一维做，这里刚好需要列做softmax
        output = score_weight @ v
        # 把n_head和time再逆转回去
        output = output.permute(0, 2, 1, 3).contiguous().view(batch, time, dimension)
        # 连续内存（contiguous）
        # view() 要求张量在内存中是连续的 因为它只是改变张量的"视图"，不复制数据
        # reshape()可以替代该功能

        output = self.w_combine(output)
        return output
    
attention = MultiHearAttention(d_model, n_head)
output = attention(X, X, X)
output

tensor([[[ 1.1749e-02, -3.0374e-02,  4.8668e-02,  ...,  4.9120e-02,
          -7.1959e-02,  6.0836e-02],
         [ 1.3038e-02, -1.5902e-03, -3.2350e-04,  ...,  3.4419e-02,
          -8.1254e-02,  8.3808e-02],
         [ 1.8813e-02, -1.3788e-02, -6.8165e-03,  ...,  4.0832e-02,
          -8.4661e-02,  8.4138e-02],
         ...,
         [-1.4644e-02, -1.3101e-02,  8.0912e-03,  ...,  5.2369e-02,
          -1.2297e-01,  6.4357e-02],
         [ 1.6878e-02,  4.8485e-03,  1.2748e-02,  ...,  3.9039e-02,
          -9.2022e-02,  5.5626e-02],
         [ 1.7730e-03,  2.4077e-02,  2.6180e-02,  ...,  3.8056e-02,
          -6.2393e-02,  8.5839e-02]],

        [[ 6.2945e-02,  9.3349e-02, -4.4470e-02,  ...,  5.7741e-02,
          -6.5188e-02,  1.2842e-03],
         [ 2.9000e-02,  5.9636e-02, -3.2421e-02,  ...,  2.1193e-02,
          -4.1963e-02, -1.4982e-02],
         [ 1.9922e-02,  8.0316e-02, -5.2340e-02,  ...,  4.8672e-02,
          -3.3632e-02, -1.3564e-02],
         ...,
         [ 5.6347e-02,  6

In [6]:
# Token and position embedding

class TokenEmbedding(nn.Embedding):
    def __init__(self, vocab_size, d_model):
        super().__init__(vocab_size, d_model, padding_idx=1)
        

In [7]:
# Position Embedding

class PositionEmbedding(nn.Module):
        def __init__(self, d_model, maxlen, device):
            super().__init__()
            self.encoding = torch.zeros(maxlen, d_model, device)
            # 因为这个编码不需要梯度
            self.encoding.requires_grad(False)

            pos = torch.arange(0, maxlen, device=device)
            pos = pos.float().unsqueeze(1) # 增加一个维度
            # 生成2i的序列
            _2i = torch.arange(0, d_model, step=2, device=device)
            # self.encoding[:, start:stop:step]
            self.encoding[:, 0::2] = torch.sin(pos / (10000 ** (_2i / d_model))) 
            self.encoding[:, 1::2] = torch.cos(pos / (10000 ** (_2i / d_model)))
            # 相当是对位置从0-> Maxlen的位置进行编码，每个位置都是当前pos=i的一个512维的向量

        def forward(self, x: Tensor):
            seq_len = x.shape[1]
            return self.encoding[:seq_len] # (seq_len, d_ model)

In [8]:
# LayerNorm (图像一般是BatchNorm)
# 最核心的是可以减小显存的用量，比batch要的少所以可以减少显存用量
# 相当于在做归一化的操作，需要参数
class LayerNorm(nn.Module):
    def __init__(self, d_model, eps=1e-10):
        super().__init__()
        # 在做归一化的操作，需要参数，下列操作为默认
        # 为什么需要gamma和beta？如果不加这两个参数，归一化会破坏网络学到的特征表示
        self.gamma = nn.Parameter(torch.ones(d_model))
        self.beta = nn.Parameter(torch.zeros(d_model))
        self.eps = eps

    def forward(self, x: Tensor):
        mean = x.mean(-1, keepdim=True) # 是对最后一个维度 变成（batch_size, Sequence_len, d_model->1）
        var = x.var(-1, unbiased=False, keepdim=True)
        out = (x - mean) / torch.sqrt(var + self.eps)
        # 所以相当于对d_model维度做归一化
        out = self.gamma * out + self.beta
        return out
    
# 归一化是必须的，但完全标准化的分布不一定是最优的。所以可以学习这个分布调整方式

In [9]:
# FFN Relu(xW1+b1)W2+b2
class PositionWiseFeedForward(nn.Module):
    def __init__(self, d_model, hidden, dropout=0.1):
        super().__init__()
        self.fc1 = nn.Linear(d_model, hidden)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: Tensor):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.dropout(x)
        return x

In [10]:
# Total Embedding
class TransformerEmbedding(nn.Module):
    def __init__(self, vocab_size, d_model, max_len, drop_prob, device):
        super().__init__()
        self.tok_emb = TokenEmbedding(vocab_size, d_model)
        self.pos_emb = PositionEmbedding(d_model, max_len, device)
        self.dropout = nn.Dropout(drop_prob)

    def forward(self, x):
        pos_emb = self.pos_emb(x)
        tok_emb = self.tok_emb(x)
        return self.dropout(pos_emb+tok_emb)

In [11]:
# Encoder Layer
class EncoderLayer(nn.Module):
    def __init__(self, d_model, ffn_hidden, n_head, drop_prob):
        super().__init__()
        self.attention = MultiHearAttention(d_model, n_head)
        self.norm1 = LayerNorm(d_model)
        self.dropout1 = nn.Dropout(drop_prob)

        self.ffn = PositionWiseFeedForward(d_model, ffn_hidden, drop_prob)
        self.norm2 = LayerNorm(d_model)
        self.dropout2 = nn.Dropout(drop_prob)

    def forward(self, x, mask = None):
        _x = x
        x = self.attention(x, x, x, mask)

        x = self.dropout1(x)
        x = self.norm1(x + _x)

        _x = x
        x = self.ffn(x)
        x = self.dropout2(x)
        x = self.norm2(x + _x)
        return x

In [12]:
# Decoder
# decoder layer vs encoder layer
# 带掩码的attention
# cross-attention
# encoder提供的是key, value, decoder提供querry

In [13]:
class DecoderLayer(nn.Module):
    def __init__(self, d_module, ffn_hidden, n_head, drop_prob):
        super().__init__()
        self.attention = MultiHearAttention(d_model, n_head)
        self.norm1 = LayerNorm(d_model)
        self.dropout1 = nn.Dropout(drop_prob)

        self.cross_attention = MultiHearAttention(d_model, n_head)
        self.norm2 = LayerNorm(d_model)
        self.dropout2 = nn.Dropout(drop_prob)

        self.ffn = PositionWiseFeedForward(d_model, ffn_hidden, drop_prob)
        self.norm3 = LayerNorm(d_model)
        self.dropout3 = nn.Dropout(drop_prob)

    def forward(self, dec, enc, t_mask, s_mask):
        # 两个掩码，一个是padding的掩码，一个是对未来信息的掩码
        # t_mask下三角掩码, t_mask因果关系的掩码
        # s_mask未知的掩码，不需要关注padding的信息
        _x = dec
        x = self.attention(dec, dec, dec, t_mask) #下三角掩码, t_mask因果关系的掩码

        x = self.dropout1(x)
        x = self.norm1(x + _x)

        if enc is not None:
            _x = x
            x = self.cross_attention(x, enc, enc, s_mask)
            x = self.dropout2(x)
            x = self.norm2(x + _x)

        _x = x
        x = self.ffn(x)

        x = self.dropout3(x)
        x = self.norm3(x + _x)
        return x

In [14]:
class Encoder(nn.Module):
    def __init__(self, env_voc_size, max_len, d_model, ffn_hidden, n_head, n_layer, drop_prob, device):
        super().__init__()
        self.embedding = TransformerEmbedding(env_voc_size, d_model, max_len, drop_prob, device)
        self.layers = nn.ModuleList(
            [EncoderLayer(d_model, ffn_hidden, n_head, drop_prob) for _ in range(n_layer)]
        )

    def forward(self, x, s_mask):
        x = self.embedding(x)
        for layer in self.layers:
            x = layer(x, s_mask)
        return x

In [15]:
class Decoder(nn.Module):
    def __init__(self, dec_voc_size, max_len, d_model, ffn_hidden, n_head, n_layer, drop_prob, device):
        super().__init__()
        self.embedding = TransformerEmbedding(dec_voc_size, d_model, max_len, drop_prob, device)
        self.layers = nn.ModuleList(
            [DecoderLayer(d_model, ffn_hidden, n_head, drop_prob) for _ in range(n_layer)]
        )
        self.fc = nn.Linear(d_model, dec_voc_size)

    def forward(self, dec, enc, t_mask, s_mask):
        dec = self.embedding(dec)
        for layer in self.layers:
            dec = layer(dec, enc, t_mask, s_mask)
        dec = self.fc(dec)
        return dec

In [16]:
class Transformer(nn.Module):
    def __init__(self, src_pad_idx, trg_pad_idx, env_voc_size, dec_voc_size, max_len, d_model, n_heads, ffn_hidden, n_layers, drop_prob, device):
        # 两个pad是输入的pad和decoder的pad的标示符的记录
        super().__init__()

        self.encoder = Encoder(env_voc_size, max_len, d_model, ffn_hidden, n_heads, n_layers, drop_prob, device)
        self.decoder = Decoder(dec_voc_size, max_len, d_model, ffn_hidden, n_heads, n_layers, drop_prob, device)

        self.src_pad_idx = src_pad_idx
        self.trg_pad_idx = trg_pad_idx
        self.device = device

    def make_pad_mask(self, q: Tensor, k: Tensor, pad_idx_q, pad_idx_k):
        len_q, len_k = q.shape[1], k.shape[1]
        # (Batch, Time, len_q, len_k) ne, not equal, 不等于pading符时为true
        # q 应该是形状为 (batch_size, seq_len_q) 的序列索引, 而不是
        # 这是序列索引经过嵌入层转换后的连续向量表示，每个整数索引被映射为一个d_model维的向量
        # 形状是 (batch_size, sequence_length, d_model)

        # (Batch, Time, len_q, len_k)
        q = q.ne(pad_idx_q).unsqueeze(1).unsqueeze(3) # 填充到四维
        q = q.repeat(1, 1, 1, len_k) #

        k = k.ne(pad_idx_k).unsqueeze(1).unsqueeze(2)
        k = k.repeat(1, 1, len_q, 1) #

        mask = q & k
        return mask


    def make_casual_mask(self, q: Tensor, k: Tensor):
        len_q, len_k = q.shape[1], k.shape[1]
        mask = torch.tril(torch.ones(len_q, len_k)).type(torch.bool).to(self.device)
        return mask

    def forward(self, src, trg):
        src_mask = self.make_pad_mask(src, src, self.src_pad_idx, self.src_pad_idx)
        # decoder自己的因果mask
        trg_mask = self.make_pad_mask(trg, trg, self.trg_pad_idx, self.trg_pad_idx) * self.make_casual_mask(trg, trg)
        src_trg_mask = self.make_pad_mask(trg, src, self.trg_pad_idx, self.src_pad_idx)

        enc = self.encoder(src, src_mask)
        output = self.decoder(trg, enc, trg_mask, src_trg_mask)
        return output
