# 手写 Transformer

In [1]:
import torch
from torch import nn
import torch.functional as F

import math

In [2]:
X = torch.randn((128, 64, 512))  # dummy input laid out as [batch, time, dimension]

In [3]:
print(X.size())  # sanity-check the input layout, e.g. torch.Size([128, 64, 512])

torch.Size([128, 64, 512])


## 多头注意力机制

In [4]:
# Model width: the feature dimension of the q, k and v weights.
d_model = 512
# Number of attention heads.
n_head = 8

In [5]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    Queries, keys and values are each projected by a ``d_model -> d_model``
    linear layer, split into ``n_head`` heads of size ``d_model // n_head``,
    attended per head, re-merged, and passed through a final
    ``d_model -> d_model`` projection (``w_combine``).

    Args:
        d_model: Feature size of the inputs and of the output.
        n_head: Number of attention heads; must divide ``d_model``.

    Raises:
        ValueError: If ``d_model`` is not divisible by ``n_head``.
    """

    def __init__(self, d_model, n_head) -> None:
        super().__init__()
        # Fail fast: otherwise the head split in ``forward`` dies with an
        # opaque ``view`` shape error.
        if d_model % n_head != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by n_head ({n_head})"
            )

        self.n_head = n_head
        self.d_model = d_model
        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        self.w_combine = nn.Linear(d_model, d_model)
        self.softmax = nn.Softmax(dim=-1)

    def _split_heads(self, x):
        """Reshape (batch, time, d_model) -> (batch, n_head, time, d_model // n_head)."""
        batch, time, _ = x.shape
        n_d = self.d_model // self.n_head
        return x.view(batch, time, self.n_head, n_d).permute(0, 2, 1, 3)

    def forward(self, q, k, v, mask=None):
        """Attend from ``q`` over ``k``/``v``.

        Args:
            q: Query input of shape (batch, len_q, d_model).
            k: Key input of shape (batch, len_k, d_model).
            v: Value input of shape (batch, len_k, d_model).
            mask: ``None`` disables masking. A tensor mask is used as given:
                it must broadcast to (batch, n_head, len_q, len_k), and
                positions where it equals 0 are blocked. Any other non-``None``
                value (e.g. ``1``) requests a causal lower-triangular mask.
                A query row whose keys are all blocked yields NaN after
                softmax.

        Returns:
            Tensor of shape (batch, len_q, d_model).
        """
        batch, len_q, _ = q.shape
        len_k = k.shape[1]  # keys/values may differ in length from queries
        n_d = self.d_model // self.n_head

        q = self._split_heads(self.w_q(q))
        k = self._split_heads(self.w_k(k))
        v = self._split_heads(self.w_v(v))

        score = q @ k.transpose(2, 3) / math.sqrt(n_d)
        if mask is not None:
            if not isinstance(mask, torch.Tensor):
                # Backward compatibility: a non-tensor flag such as ``mask=1``
                # asks for causal masking. Build it on the scores' device so
                # GPU inputs work.
                mask = torch.tril(
                    torch.ones(len_q, len_k, dtype=torch.bool, device=score.device)
                )
            # exp(-inf) == 0, so blocked positions get zero weight after softmax.
            score = score.masked_fill(mask == 0, float("-inf"))
        attended = self.softmax(score) @ v

        # Merge the heads back: (batch, n_head, len_q, n_d) -> (batch, len_q, d_model).
        attended = attended.permute(0, 2, 1, 3).contiguous().view(batch, len_q, self.d_model)

        return self.w_combine(attended)

In [6]:
# Self-attention demo: X serves as query, key and value; mask=1 requests a
# causal (lower-triangular) mask.
attention = MultiHeadAttention(d_model, n_head)
output = attention(q=X, k=X, v=X, mask=1)
print(output)
print(output.size())

tensor([[[-3.2379e-02, -5.5683e-01,  1.2646e-02,  ..., -1.0508e-01,
           3.6439e-04, -2.5211e-01],
         [ 1.6632e-01, -2.8211e-01, -3.3325e-02,  ...,  5.9400e-02,
          -7.0265e-02, -3.6296e-01],
         [ 3.0871e-01, -3.2886e-01, -4.4546e-02,  ...,  2.4793e-01,
           8.0088e-02, -6.1949e-02],
         ...,
         [ 8.5811e-02,  8.9556e-02,  2.0718e-02,  ..., -6.6410e-02,
          -3.5111e-03, -2.6802e-02],
         [ 9.7252e-02,  6.8795e-02,  1.9015e-02,  ..., -4.4383e-02,
          -9.6786e-03, -2.2903e-02],
         [ 1.0361e-01,  8.1382e-02,  1.6885e-02,  ..., -6.1121e-02,
           1.3742e-02, -3.8663e-02]],

        [[ 2.4396e-01,  3.7242e-01,  3.1582e-01,  ...,  3.5661e-01,
           1.6890e-01,  6.2801e-01],
         [ 4.7332e-01,  1.0259e-02,  2.5907e-01,  ...,  2.6994e-01,
           2.8141e-01,  7.4819e-02],
         [ 5.9286e-01, -1.8706e-01,  1.7782e-01,  ...,  5.6798e-02,
           1.1468e-01,  1.8903e-02],
         ...,
         [ 3.0853e-02,  1