### 根据李沐老师的d2l网站写的Transformer

我还用这个Transformer替换了那个直接调用库中Transformer（调包）实现caption任务的ipynb里的模型。

后来发现这个Transformer特别容易过拟合，推理效果很差：要么是同一个单词反复出现，要么是只输出1-2个单词。

所以我不确定这个模型代码写得对不对，但它仍然可以作为自己实现Transformer的参考。

In [26]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import dataloader, dataset

import math
import pandas as pd

In [20]:
class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding followed by dropout.

    Args:
        num_hiddens: feature dimension of the inputs (may be odd).
        drouput: dropout probability applied after adding the encoding
            (the misspelled name is kept for backward compatibility).
        max_len: longest sequence the precomputed table supports.
    """

    def __init__(self, num_hiddens, drouput, max_len=100):
        super().__init__()
        self.dropout = nn.Dropout(drouput)
        P = torch.zeros((1, max_len, num_hiddens))
        # angles[pos, i] = pos / 10000^(2i / num_hiddens); shape [max_len, ceil(num_hiddens / 2)]
        angles = torch.arange(max_len, dtype=torch.float32).reshape(-1, 1) / torch.pow(
            10000, torch.arange(0, num_hiddens, 2, dtype=torch.float32) / num_hiddens
        )
        P[:, :, 0::2] = torch.sin(angles)
        # With an odd num_hiddens there is one fewer odd column than even column;
        # trim the angle table so the assignment does not fail with a shape error.
        P[:, :, 1::2] = torch.cos(angles[:, : num_hiddens // 2])
        # A non-persistent buffer follows module.to(device) / .half() like any other
        # module tensor, but is not written to the state_dict (old checkpoints still load).
        self.register_buffer("P", P, persistent=False)

    def forward(self, X):
        """Add the encoding to X of shape [batch_size, num_steps, num_hiddens]."""
        X = X + self.P[:, : X.shape[1], :].to(X.device)
        # Dropout keeps the model from becoming too sensitive to the exact values of P.
        return self.dropout(X)

In [21]:
def sequence_mask(X, valid_lens, value=0):
    """Overwrite, in place, every entry past its row's valid length and return X.

    Args:
        X: tensor whose first two axes are [rows, maxlen] (2-D at every call
            site in this notebook); column j of row r is masked when
            j >= valid_lens[r].
        valid_lens: 1-D tensor holding one valid length per row of X.
        value: fill value written into the masked positions.
    """
    positions = torch.arange(X.size(1), dtype=torch.float32, device=X.device)
    keep = positions.unsqueeze(0) < valid_lens.unsqueeze(1)
    # X is modified in place; if X is a view, its base tensor changes too.
    X[keep.logical_not()] = value
    return X



def masked_softmax(X, valid_lens=None):
    """Softmax over the last axis of X that ignores keys beyond each valid length.

    Not every position should take part in attention pooling (e.g. <PAD> tokens
    in a text sequence, or future tokens in a decoder). Masked scores are
    replaced by a very large negative number before the softmax, so their
    weights become (numerically) zero.

    Args:
        X: 3-D score tensor [batch, num_queries, num_keys].
        valid_lens: None (no masking); a 1-D tensor [batch] giving one valid key
            length per batch row, shared by all of that row's queries; or a 2-D
            tensor [batch, num_queries] giving one valid key length per query.

    Returns:
        Tensor shaped like X with a softmax over the last axis.
    """
    if valid_lens is None:
        return F.softmax(X, dim=-1)
    shape = X.shape
    if valid_lens.dim() == 1:
        # One length per batch row: repeat it for every query of that row.
        valid_lens = torch.repeat_interleave(valid_lens, shape[1])
    else:
        # One length per query: flatten to line up with the flattened rows below.
        valid_lens = valid_lens.reshape(-1)
    # Fix: the fill value used to be 1e-6, a tiny *positive* number, which left
    # masked scores at ~0 so padded / future keys still received attention (in
    # the decoder this disabled the causal mask during training). A large
    # negative value drives their softmax output to effectively 0.
    X = sequence_mask(X.reshape(-1, shape[-1]), valid_lens, value=-1e6)
    return F.softmax(X.reshape(shape), dim=-1)



class DotProductAttention(nn.Module):
    """Scaled dot-product attention with optional key masking and dropout on the weights."""

    def __init__(self, dropout):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        # Attention weights from the most recent forward pass (None before the first call).
        self.attention_weights = None

    def forward(self, queries, keys, values, valid_lens=None):
        """
        queries: [batch, num_queries, d]
        keys: [batch, num_keys, d]
        values: [batch, num_keys, value_dim]
        valid_lens: None, [batch] or [batch, num_queries]; keys at or past the
            valid length are masked out before the softmax.

        Returns: [batch, num_queries, value_dim]
        """
        # Dividing by sqrt(d) keeps the scores' scale independent of the feature size.
        scores = torch.bmm(queries, keys.transpose(1, 2)) / math.sqrt(queries.shape[-1])
        # Record the weights so callers (e.g. MultiHeadAttention / the decoder) can inspect them.
        self.attention_weights = masked_softmax(scores, valid_lens)  # [batch, num_queries, num_keys]
        return torch.bmm(self.dropout(self.attention_weights), values)

In [22]:
# 了解一下这个sequence_mask函数
X = torch.tensor([[1, 2, 3], [4, 5, 6]])
sequence_mask(X, torch.tensor([1, 2]))

tensor([[1, 0, 0],
        [4, 5, 0]])

In [None]:
def transpose_qkv(X, num_heads):
    """Split the feature axis into heads and fold the head axis into the batch axis.

    One linear projection followed by this split replaces a separate projection
    per head.

    X: [batch, num_items, num_hiddens]
    -> [batch * num_heads, num_items, num_hiddens / num_heads]
    """
    batch, num_items = X.shape[0], X.shape[1]
    split = X.reshape(batch, num_items, num_heads, -1)  # [batch, num_items, heads, head_dim]
    heads_first = split.permute(0, 2, 1, 3)  # [batch, heads, num_items, head_dim]
    return heads_first.flatten(0, 1)

def transpose_output(X, num_heads):
    """Reverse transpose_qkv: unfold the heads from the batch axis and concatenate them.

    X: [batch * num_heads, num_queries, num_hiddens / num_heads]
    -> [batch, num_queries, num_hiddens]
    """
    grouped = X.reshape(-1, num_heads, X.shape[1], X.shape[2])  # [batch, heads, num_queries, head_dim]
    return grouped.transpose(1, 2).flatten(start_dim=2)


class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    Queries, keys and values are each projected once to num_hiddens features.
    The projections are split into num_heads heads of num_hiddens / num_heads
    features (num_hiddens is assumed divisible by num_heads), attended in
    parallel, concatenated again and projected by W_o.
    """

    def __init__(self, query_size, key_size, value_size, num_hiddens, num_heads, dropout, bias=False, **kwargs):
        super().__init__()
        self.num_heads = num_heads
        # q, k and v are projected first; attention pooling runs on the projections.
        self.W_q = nn.Linear(query_size, num_hiddens, bias=bias)
        self.W_k = nn.Linear(key_size, num_hiddens, bias=bias)
        self.W_v = nn.Linear(value_size, num_hiddens, bias=bias)
        self.W_o = nn.Linear(num_hiddens, num_hiddens, bias=bias)
        self.attention = DotProductAttention(dropout)

    def forward(self, queries, keys, values, valid_lens=None):
        """
        queries: [batch, num_queries, query_size]
        keys: [batch, num_kv_pairs, key_size]
        values: [batch, num_kv_pairs, value_size]
        valid_lens: None, [batch] or [batch, num_queries]

        Returns: [batch, num_queries, num_hiddens]
        """
        # After projection + split: [batch * num_heads, num_items, num_hiddens / num_heads]
        queries = transpose_qkv(self.W_q(queries), self.num_heads)
        keys = transpose_qkv(self.W_k(keys), self.num_heads)
        values = transpose_qkv(self.W_v(values), self.num_heads)

        if valid_lens is not None:
            # Heads are folded into the batch axis batch-major (row b * num_heads + h),
            # so each batch row's lengths are repeated once per head.
            valid_lens = torch.repeat_interleave(valid_lens, repeats=self.num_heads, dim=0)

        # With heads merged into the batch axis, one bmm-based attention call
        # covers all heads: [batch * num_heads, num_queries, num_hiddens / num_heads]
        output = self.attention(queries, keys, values, valid_lens)
        return self.W_o(transpose_output(output, self.num_heads))

    @property
    def attention_weights(self):
        """Return the inner attention module's `attention_weights`, or None if it has none."""
        # getattr with a default: the inner module may not have run yet or may not
        # record its weights; a missing attribute on an nn.Module raises AttributeError.
        return getattr(self.attention, "attention_weights", None)

In [51]:
# Build a 5-head attention layer whose query/key/value sizes and num_hiddens are all 100,
# with dropout 0.5, then switch to eval mode so dropout is disabled.
num_hiddens, num_heads= 100, 5
attetion = MultiHeadAttention(num_hiddens, num_hiddens, num_hiddens, num_hiddens, num_heads, 0.5)
attetion.eval()

# Cell output (module repr):
MultiHeadAttention(
  (W_q): Linear(in_features=100, out_features=100, bias=False)
  (W_k): Linear(in_features=100, out_features=100, bias=False)
  (W_v): Linear(in_features=100, out_features=100, bias=False)
  (W_o): Linear(in_features=100, out_features=100, bias=False)
  (attention): DotProductAttention(
    (dropout): Dropout(p=0.5, inplace=False)
  )
)

In [52]:
# Batch of 2: queries have 4 steps, keys/values have 6; valid key lengths are 3 and 2.
batch_size, num_queries = 2, 4
num_kvpairs, valid_lens = 6, torch.tensor([3, 2])
X = torch.ones((batch_size, num_queries, num_hiddens))
Y = torch.ones((batch_size, num_kvpairs, num_hiddens))
# Expected output shape: [batch_size, num_queries, num_hiddens].
attetion(X, Y, Y, valid_lens).shape

torch.Size([2, 4, 100])

In [53]:
class PositionWiseFFN(nn.Module):
    """Position-wise feed-forward network: Linear -> ReLU -> Linear.

    The same two-layer MLP is applied independently at every position:
    [batch, num_steps, ffn_num_input] -> [batch, num_steps, ffn_num_outputs].
    """

    def __init__(self, ffn_num_input, ffn_num_hiddens, ffn_num_outputs, **kwargs):
        super().__init__()
        # Submodule names become state_dict keys, so they stay stable.
        self.dense1 = nn.Linear(ffn_num_input, ffn_num_hiddens)
        self.relu = nn.ReLU()
        self.dense2 = nn.Linear(ffn_num_hiddens, ffn_num_outputs)

    def forward(self, X):
        hidden = self.relu(self.dense1(X))
        return self.dense2(hidden)
    

# Shape check: the last axis goes from ffn_num_input=4 to ffn_num_outputs=8.
ffn = PositionWiseFFN(4, 4, 8)
ffn.eval()
ffn(torch.ones((2, 3, 4))).shape


torch.Size([2, 3, 8])

In [54]:
# Compare LayerNorm and BatchNorm on the same 2x2 input.
ln = nn.LayerNorm(2)
bn = nn.BatchNorm1d(2)  # left in training mode, so it uses this batch's own statistics
X = torch.tensor([[1, 2], [2, 3]], dtype=torch.float32)

# LayerNorm normalizes each sample (row) across its features to mean 0 / variance 1;
# BatchNorm normalizes each feature (column) across the samples in the batch.
print('layer norm:', ln(X), '\nbatch norm:', bn(X))

layer norm: tensor([[-1.0000,  1.0000],
        [-1.0000,  1.0000]], grad_fn=<NativeLayerNormBackward0>) 
batch norm: tensor([[-1.0000, -1.0000],
        [ 1.0000,  1.0000]], grad_fn=<NativeBatchNormBackward0>)


In [55]:
class AddNorm(nn.Module):
    """Residual connection followed by layer normalization: LayerNorm(X + Dropout(Y))."""

    def __init__(self, normalize_shape, dropout):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.ln = nn.LayerNorm(normalize_shape)

    def forward(self, X, Y):
        """X is the sublayer's input (residual branch); Y is the sublayer's output."""
        residual_sum = X + self.dropout(Y)
        return self.ln(residual_sum)
    
# Shape check with dropout disabled: the output keeps the input shape [2, 3, 4].
add_norm = AddNorm([3, 4], 0.5)
add_norm.eval()
add_norm(torch.ones((2, 3, 4)), torch.ones((2, 3, 4))).shape
        

torch.Size([2, 3, 4])

## Encoder

In [56]:
class EncoderBlock(nn.Module):
    """One Transformer encoder layer.

    Self-attention and a position-wise FFN, each wrapped in a residual
    connection + LayerNorm (AddNorm). normalize_shape is handed to
    nn.LayerNorm unchanged. NOTE(review): a shape such as [num_steps, num_hiddens]
    also normalizes over the sequence axis and fixes the sequence length;
    [num_hiddens] is the usual Transformer choice — confirm what is intended.
    """

    def __init__(self, query_size, key_size, value_size, num_hiddens, normalize_shape,
                 ffn_num_input, ffn_num_hiddens, num_heads, dropout, bias=False, **kwargs):
        super().__init__()
        self.attention = MultiHeadAttention(
            query_size, key_size, value_size, num_hiddens, num_heads, dropout, bias=bias
        )
        self.addnorm1 = AddNorm(normalize_shape, dropout)
        # The FFN maps back to num_hiddens so its output can be added to its input.
        self.ffn = PositionWiseFFN(ffn_num_input, ffn_num_hiddens, num_hiddens)
        self.addnorm2 = AddNorm(normalize_shape, dropout)

    def forward(self, X, valid_lens):
        """X: [batch, num_steps, num_hiddens]; valid_lens masks padded key positions."""
        attended = self.attention(X, X, X, valid_lens)
        Y = self.addnorm1(X, attended)
        transformed = self.ffn(Y)
        return self.addnorm2(Y, transformed)
    

# Shape check: batch 2, 100 steps, 24 hidden units, 8 heads, FFN width 48, dropout 0.5.
X = torch.ones((2, 100, 24))
valid_lens = torch.tensor([3, 2])
encoder_blk = EncoderBlock(24, 24, 24, 24, [100, 24], 24, 48, 8, 0.5)
encoder_blk.eval()
encoder_blk(X, valid_lens).shape

torch.Size([2, 100, 24])

In [None]:
class TransformerEncoder(nn.Module):
    """Transformer encoder: token embedding + positional encoding + num_layers EncoderBlocks."""

    def __init__(self, vocab_size, query_size, key_size, value_size, num_hiddens, normalize_shape,
                 ffn_num_input, ffn_num_hiddens, num_heads, num_layers, dropout, bias=False, **kwargs):
        super().__init__()
        self.num_hiddens = num_hiddens
        self.embedding = nn.Embedding(vocab_size, num_hiddens)
        self.pos_encoding = PositionalEncoding(num_hiddens, dropout)
        self.blks = nn.Sequential()
        for layer_idx in range(num_layers):
            block = EncoderBlock(query_size, key_size, value_size, num_hiddens, normalize_shape,
                                 ffn_num_input, ffn_num_hiddens, num_heads, dropout, bias)
            self.blks.add_module(f"block{layer_idx}", block)

    def forward(self, X, valid_lens, *args):
        """Encode token ids X [batch, num_steps] into [batch, num_steps, num_hiddens].

        Each block's attention weights are stored in self.attention_weights.
        """
        # The positional encoding lies in [-1, 1]; the *embeddings* (not the
        # encoding) are scaled up by sqrt(num_hiddens) before the two are added.
        embedded = self.embedding(X) * math.sqrt(self.num_hiddens)
        X = self.pos_encoding(embedded)
        self.attention_weights = [None] * len(self.blks)
        for idx, blk in enumerate(self.blks):
            X = blk(X, valid_lens)
            self.attention_weights[idx] = blk.attention.attention_weights
        return X
    

# Shape check: vocabulary of 200 tokens, 2 encoder layers; valid_lens comes from an earlier cell.
encoder = TransformerEncoder(200, 24, 24, 24, 24, [100, 24], 24, 48, 8, 2, 0.5)
encoder.eval()
encoder(torch.ones((2, 100), dtype=torch.long), valid_lens).shape

torch.Size([2, 100, 24])

## Decoder

In [59]:
class DecoderBlock(nn.Module):
    """The i-th Transformer decoder layer.

    Masked self-attention, encoder-decoder attention and a position-wise FFN,
    each followed by AddNorm.
    """

    def __init__(self, query_size, key_size, value_size, num_hiddens, normalize_shape,
                 ffn_num_input, ffn_num_hiddens, num_heads, dropout, i, bias=False, **kwargs):
        super().__init__()
        self.i = i  # index of this block in the decoder; selects its slot in state[2]
        self.masked_attention = MultiHeadAttention(query_size, key_size, value_size, num_hiddens, num_heads, dropout, bias)
        self.addnorm1 = AddNorm(normalize_shape, dropout)
        self.attention = MultiHeadAttention(query_size, key_size, value_size, num_hiddens, num_heads, dropout, bias)
        self.addnorm2 = AddNorm(normalize_shape, dropout)
        self.ffn = PositionWiseFFN(ffn_num_input, ffn_num_hiddens, num_hiddens)
        self.addnorm3 = AddNorm(normalize_shape, dropout)

    def forward(self, X, state):
        """
        X: [batch, num_steps, num_hiddens] decoder inputs for this call.
        state: [enc_outputs, enc_valid_lens, cache]; cache[self.i] holds this
            block's inputs from earlier calls, or None when there are none.

        Returns (output [batch, num_steps, num_hiddens], state). state[2][self.i]
        is updated in place so that it also contains X.
        """
        enc_outputs, enc_valid_lens = state[0], state[1]
        cached = state[2][self.i]
        # Training processes the whole target sequence in one call, so the cache is None.
        # Step-by-step prediction appends each new step to the cache so that later
        # steps can attend to every earlier representation produced at this block.
        key_values = X if cached is None else torch.cat((cached, X), dim=1)
        state[2][self.i] = key_values

        # Causal mask: query t of this call sits at absolute position num_past + t and
        # may only attend to keys 0 .. num_past + t. The mask used to be applied only
        # in training mode, so running a full target sequence in eval mode (e.g. a
        # teacher-forced validation loss) let queries see future tokens. In training
        # (num_past == 0) this equals the old mask; for one-step incremental decoding
        # it covers every key, i.e. nothing is masked, as before.
        batch_size, num_steps, _ = X.shape
        num_past = key_values.shape[1] - num_steps
        dec_valid_lens = torch.arange(
            num_past + 1, num_past + num_steps + 1, device=X.device
        ).repeat(batch_size, 1)  # [batch, num_steps]

        Y = self.addnorm1(X, self.masked_attention(X, key_values, key_values, dec_valid_lens))
        Z = self.addnorm2(Y, self.attention(Y, enc_outputs, enc_outputs, enc_valid_lens))
        return self.addnorm3(Z, self.ffn(Z)), state
    

# Shape check for block 0 in eval mode.
# state = [encoder outputs, encoder valid lengths, per-block cache (one slot, initially None)].
decoder_blk = DecoderBlock(24, 24, 24, 24, [100, 24], 24, 48, 8, 0.5, 0, bias=False)
decoder_blk.eval()
X = torch.ones((2, 100, 24))
state = [encoder_blk(X, valid_lens), valid_lens, [None]]
decoder_blk(X, state)[0].shape

torch.Size([2, 100, 24])

In [60]:
class TransformerDecoder(nn.Module):
    """Transformer decoder.

    Embedding + positional encoding + num_layers DecoderBlocks, followed by a
    linear layer that produces vocabulary logits.
    """

    def __init__(self, vocab_size, query_size, key_size, value_size, num_hiddens, normalize_shape,
                 ffn_num_input, ffn_num_hiddens, num_heads, num_layers, dropout, bias=False, **kwargs):
        super().__init__()
        self.num_hiddens = num_hiddens
        self.num_layers = num_layers
        self.embedding = nn.Embedding(vocab_size, num_hiddens)
        self.pos_encoding = PositionalEncoding(num_hiddens, dropout)
        self.blks = nn.Sequential()
        for i in range(num_layers):
            self.blks.add_module("block"+str(i),
                                 DecoderBlock(query_size, key_size, value_size, num_hiddens, normalize_shape,
                                              ffn_num_input, ffn_num_hiddens, num_heads, dropout, i, bias))
        self.dense = nn.Linear(num_hiddens, vocab_size)
        # Filled by forward(): [self-attention weights per block, cross-attention weights per block].
        self._attention_weights = None

    def init_state(self, enc_outputs, enc_valid_lens, *args):
        """Return the decoding state [enc_outputs, enc_valid_lens, per-block cache]."""
        return [enc_outputs, enc_valid_lens, [None] * self.num_layers]

    def forward(self, X, state):
        """Decode token ids X [batch, num_steps].

        Returns (logits [batch, num_steps, vocab_size], state).
        """
        X = self.pos_encoding(self.embedding(X) * math.sqrt(self.num_hiddens))
        self._attention_weights = [[None] * len(self.blks) for _ in range(2)]
        for i, blk in enumerate(self.blks):
            X, state = blk(X, state)
            # Read the weights through the MultiHeadAttention.attention_weights property
            # instead of reaching into its inner DotProductAttention, which is not
            # guaranteed to have an `attention_weights` attribute (a missing attribute
            # on an nn.Module raises AttributeError).
            self._attention_weights[0][i] = blk.masked_attention.attention_weights  # decoder self-attention
            self._attention_weights[1][i] = blk.attention.attention_weights  # encoder-decoder attention
        return self.dense(X), state

    @property
    def attention_weights(self):
        """[self-attention weights, encoder-decoder attention weights], one entry per block."""
        return self._attention_weights

    @property
    def attetion_weights(self):
        """Misspelled alias of `attention_weights`, kept for backward compatibility."""
        return self._attention_weights
