## 2 Transformer架构解析

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import numpy as np
import math
import warnings
warnings.filterwarnings("ignore")

#### 2.2.1 Embedding

In [4]:
class Embeddings(nn.Module):
    """Token-embedding lookup whose output is scaled by ``sqrt(d_model)``.

    Args:
        d_model: Size of each embedding vector.
        vocab: Number of distinct token ids in the vocabulary.
    """

    def __init__(self, d_model, vocab) -> None:
        super().__init__()
        # "lut" = look-up table: one learnable row of length d_model per token id.
        self.lut = nn.Embedding(num_embeddings=vocab, embedding_dim=d_model)
        self.d_model = d_model

    def forward(self, x):
        """Look up the embeddings for the token ids in ``x`` and scale them.

        The result has the shape of ``x`` with one extra trailing dimension
        of size ``d_model``.
        """
        looked_up = self.lut(x)
        scale = math.sqrt(self.d_model)
        return looked_up * scale

In [5]:
# Demo: embed a batch of 2 sequences, each holding 4 token ids.
d_model = 512
vocab = 1000
# torch.autograd.Variable is deprecated (tensors carry autograd state
# themselves), so the index tensor is built directly as an int64 tensor.
x = torch.tensor([[100, 2, 421, 508], [491, 998, 1, 221]], dtype=torch.long)
emb = Embeddings(d_model, vocab)
emb_result = emb(x)
print("Embedding Result", emb_result)
print("Embedding Result Shape", emb_result.shape)

Embedding Result tensor([[[ 2.8251e+01,  1.2709e+01, -2.3858e+01,  ...,  5.2831e+01,
           1.1292e+01, -2.4750e+00],
         [ 5.5775e+01, -3.6318e+01, -9.9836e+00,  ...,  6.6456e+00,
          -4.0601e+01, -6.6463e+00],
         [-2.7960e+01,  2.2780e+00, -5.9441e+00,  ...,  1.1343e+00,
           8.0652e+00, -2.5514e+01],
         [-1.8473e+01,  5.5878e+00,  3.0744e+01,  ...,  8.9195e-01,
          -1.6658e+01, -4.1357e+01]],

        [[-3.6581e+01,  1.6614e+01,  1.9827e+01,  ...,  1.8057e+01,
          -9.4189e+00, -6.0719e+00],
         [ 4.2156e+01,  1.1574e+01,  2.1524e+00,  ..., -3.9088e-03,
           1.4521e+01, -2.7413e+01],
         [ 3.5479e+01,  2.8121e+01,  1.6331e+00,  ..., -2.0129e+01,
           3.3834e+00,  8.4812e+00],
         [-1.1242e+01, -1.4735e+01,  4.1057e+01,  ...,  6.9943e+00,
           2.2428e+00, -2.8003e+01]]], grad_fn=<MulBackward0>)
Embedding Result Shape torch.Size([2, 4, 512])


#### 2.2.2 PositionalEncoding

In [8]:
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position information to a batch of embeddings.

    A table of shape ``(1, max_len, d_model)`` is precomputed once: even
    feature columns hold ``sin(pos * w_i)`` and odd columns hold
    ``cos(pos * w_i)`` with ``w_i = 10000 ** (-2i / d_model)``.

    Args:
        d_model: Size of the feature dimension (odd values are supported).
        p_dropout: Dropout probability applied after the table is added.
        max_len: Number of positions precomputed; inputs are expected to
            have a sequence length of at most ``max_len``.
    """

    def __init__(self, d_model, p_dropout=0.1, max_len=5000) -> None:
        super().__init__()
        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)
        )
        # (max_len, ceil(d_model / 2)) table of angles pos * w_i.
        angles = position * div_term

        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(angles)
        # There are only d_model // 2 odd columns; when d_model is odd the
        # last frequency has no cosine slot, so it is dropped here instead
        # of raising a shape-mismatch error.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])

        # Registered as a buffer: saved with the module and moved by .to(),
        # but not a trainable parameter.
        self.register_buffer("pe", pe.unsqueeze(0))
        self.dropout = nn.Dropout(p=p_dropout)

    def forward(self, x):
        """Return ``dropout(x + pe[:, :seq_len])``.

        ``x`` is indexed as ``(batch, seq_len, d_model)``; the table slice
        broadcasts over the batch dimension.
        """
        # Buffers never require grad, so no Variable(...) wrapper is needed.
        x = x + self.pe[:, : x.size(1)]
        return self.dropout(x)

In [9]:
# Demo: add positional encodings to the embedding output computed above.
d_model = 512
dropout = 0.1
max_len = 60
x = emb_result
pe = PositionalEncoding(d_model, dropout, max_len)
# A freshly built module is in training mode, so dropout is active here;
# that is presumably why some entries of the printed result are exactly 0.
pe_result = pe(x)
print("PE Result", pe_result)
print("PE Result Shape", pe_result.shape)

PE Result tensor([[[ 31.3905,  15.2318, -26.5090,  ...,  59.8119,   0.0000,  -1.6388],
         [ 62.9070, -39.7534, -10.1797,  ...,   8.4951, -45.1120,  -6.2737],
         [-30.0562,   2.0687,  -5.5641,  ...,   2.3714,   8.9616, -27.2383],
         [-20.3684,   5.1087,  34.4318,  ...,   2.1022, -18.5091, -44.8415]],

        [[-40.6450,  19.5714,  22.0305,  ...,  21.1747, -10.4654,  -5.6355],
         [ 47.7747,  13.4607,   3.3047,  ...,   1.1068,  16.1348, -29.3475],
         [ 40.4311,  30.7834,   2.8550,  ..., -21.2543,   0.0000,   0.0000],
         [ -0.0000, -17.4720,  45.8916,  ...,   8.8826,   2.4923, -30.0030]]],
       grad_fn=<MulBackward0>)
PE Result Shape torch.Size([2, 4, 512])


#### 2.3.1 掩码张量

In [13]:
def subsequent_mask(size):
    """Build a causal ("no peeking ahead") mask of shape ``(1, size, size)``.

    Entry ``[0, i, j]`` is 1 when ``j <= i`` (position ``i`` may look at
    position ``j``) and 0 otherwise. The result is a ``torch.uint8`` tensor.
    """
    # The lower triangle, diagonal included, marks the visible positions.
    visible = np.tril(np.ones((1, size, size), dtype="uint8"))
    return torch.from_numpy(visible)

In [14]:
# Demo: build and print the mask for a sequence of length 5.
size = 5
sm = subsequent_mask(size)
print("SM", sm)
print("SM Shape", sm.shape)

SM tensor([[[1, 0, 0, 0, 0],
         [1, 1, 0, 0, 0],
         [1, 1, 1, 0, 0],
         [1, 1, 1, 1, 0],
         [1, 1, 1, 1, 1]]], dtype=torch.uint8)
SM Shape torch.Size([1, 5, 5])


#### 2.3.2 注意力机制

In [18]:
import torch
import torch.nn.functional as F

def attention(query, key, value, mask=None, dropout=None):
    """Scaled dot-product attention.

    Args:
        query, key, value: Tensors whose last two dimensions are
            ``(seq_len, d_k)``; any leading dimensions are batched by
            ``torch.matmul``.
        mask: Optional tensor broadcastable to the score matrix; positions
            where ``mask == 0`` are filled with ``-1e9`` before the softmax.
        dropout: Optional callable applied to the attention weights.

    Returns:
        A pair ``(output, weights)``: ``weights`` is the softmax over the
        last dimension of the (masked) scores, after the optional dropout,
        and ``output`` is ``weights @ value``.
    """
    # Size of the last dimension of the query, used as the scaling factor.
    d_k = query.size(-1)

    # The keys are scaled before the product, exactly as the scores are defined
    # here: query @ (key^T / sqrt(d_k)).
    scaled_keys_t = key.transpose(-2, -1) / math.sqrt(d_k)
    scores = torch.matmul(query, scaled_keys_t)

    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)

    weights = F.softmax(scores, dim=-1)
    if dropout is not None:
        weights = dropout(weights)

    context = torch.matmul(weights, value)
    return context, weights

- 学习小结
1. 什么是注意力计算规则

#### 2.3.3 多头注意力机制

多头注意力机制：只使用一组线性变换层（Q、K、V 各一个，外加一个输出线性层，共 4 个）。先对 Q、K、V 分别进行线性变换，再把最后一维切分成 head 个头、各头分别计算注意力，最后将各头的结果拼接并送入输出线性层。

In [2]:
import copy  # 深度copy工具包
import torch.nn as nn

def clones(module, N):
    """Return an ``nn.ModuleList`` of ``N`` independent deep copies of ``module``.

    Each copy has its own parameters, so the copies start with equal
    values but do not share storage.
    """
    replicas = nn.ModuleList()
    for _ in range(N):
        replicas.append(copy.deepcopy(module))
    return replicas

class MultiHeadedAttention(nn.Module):
    """Multi-head attention built on the file-level ``attention`` function.

    Args:
        head: Number of attention heads; must divide ``embedding_dim``.
        embedding_dim: Size of the feature dimension of the inputs.
        p_dropout: Dropout probability applied to the attention weights.
    """

    def __init__(self, head, embedding_dim, p_dropout=0.1) -> None:
        super().__init__()
        assert embedding_dim % head == 0
        # Each head works on an equal slice of the feature dimension.
        self.d_k = embedding_dim // head
        self.head = head
        # Four same-shaped linear layers: three input projections
        # (query, key, value) and one final output projection.
        self.linears = clones(nn.Linear(embedding_dim, embedding_dim), 4)
        # Attention weights of the most recent forward pass.
        self.attn = None
        self.dropout = nn.Dropout(p=p_dropout)

    def forward(self, query, key, value, mask=None):
        """Run multi-head attention over ``query``, ``key`` and ``value``.

        Inputs are reshaped as ``(batch, seq_len, embedding_dim)``; the
        returned tensor has ``head * d_k`` features in its last dimension.
        The attention weights are stored on ``self.attn``.
        """
        batch_size = query.size(0)
        if mask is not None:
            # Add a head axis so the same mask broadcasts over every head.
            mask = mask.unsqueeze(1)

        def split_heads(projection, tensor):
            # (batch, seq, embed) -> (batch, head, seq, d_k)
            projected = projection(tensor)
            per_head = projected.view(batch_size, -1, self.head, self.d_k)
            return per_head.transpose(1, 2)

        query = split_heads(self.linears[0], query)
        key = split_heads(self.linears[1], key)
        value = split_heads(self.linears[2], value)

        x, self.attn = attention(query, key, value, mask, self.dropout)

        # (batch, head, seq, d_k) -> (batch, seq, head * d_k)
        merged = x.transpose(1, 2).contiguous()
        merged = merged.view(batch_size, -1, self.head * self.d_k)
        return self.linears[-1](merged)

In [17]:
# Demo: self-attention over the positional-encoding output computed above.
head = 8 
embedding_dim = 512
dropout = 0.2

# Self-attention: query, key and value are all the same tensor.
query = key = value = pe_result
# The mask size is 4, the sequence length of the demo batch.
mask = subsequent_mask(4)
mha =MultiHeadedAttention(head, embedding_dim, dropout)
mha_result = mha(query, key, value, mask)
print(mha_result)
print(mha_result.shape)

tensor([[[ -2.5778,   7.2659,  -5.3006,  ...,   5.6536,   0.3294,  -5.5139],
         [ -7.7403,  -1.4697, -10.4636,  ...,  -1.4353,  -4.3027,   3.6523],
         [ 16.5337,  -5.5156,   6.1483,  ..., -17.6196,  -3.8547,  10.5025],
         [ 16.2973,   6.1084,  -0.1795,  ...,  -2.1778,   4.2438,  -3.7886]],

        [[  0.1977,   5.6523,  -6.5110,  ...,   2.0864,  -4.6978,   5.9171],
         [  2.6465,  18.9712,   4.1097,  ...,  -2.9097,  -9.6876,  10.7208],
         [ -0.8884,  17.1208,   6.4450,  ...,   3.0732,  -6.1221,  -8.1069],
         [  2.1331,   6.3038,   0.7080,  ...,   1.0429,  -2.7919,  -6.0905]]],
       grad_fn=<ViewBackward0>)
torch.Size([2, 4, 512])
