In [1]:
import torch
import torch.nn as nn
from torch.nn import functional as F

In [2]:
d_model = 16
d_head = 4
content_length = 16
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# 1 模型定义

## 1.1 FFN

In [3]:
class ffn(nn.Module):
    def __init__(self, d_model, dropout=0.1):
        super(ffn, self).__init__()
        self.linear1 = nn.Linear(d_model, 4 * d_model)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(4 * d_model, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        x = self.linear1(x) # (batch, seq_len, d_model) -> (batch, seq_len, 4 * d_model)
        x = self.relu(x)    # (batch, seq_len, 4 * d_model)
        x = self.linear2(x) # (batch, seq_len, 4 * d_model) -> (batch, seq_len, d_model)
        x = self.dropout(x) # (batch, seq_len, d_model)
        return x

## 1.2 Attention

In [4]:
class Attention(nn.Module):
    def __init__(self, d_model, d_head, dropout=0.1):
        super(Attention, self).__init__()
        self.d_model = d_model
        self.d_head = d_head
        # 这里使用的另一种高效的实现，不是计算后再切分，而是直接进过线性层变换维度为d_head，最后直接拼接
        self.wq = nn.Linear(d_model, d_head)
        self.wk = nn.Linear(d_model, d_head)
        self.wv = nn.Linear(d_model, d_head)
        self.register_buffer('mask', torch.tril(torch.ones(content_length, content_length)))
        self.dropout = nn.Dropout(dropout)

    def forward(self, q, k, v):
        B, T, D = k.shape
        q, k, v = self.wq(q), self.wk(k), self.wv(v)
        output = (q @ k.transpose(-1, -2)) / torch.sqrt(torch.tensor(self.d_head))
        output = output.masked_fill(self.mask[:T, :T] == 0, float('-inf'))
        output = F.softmax(output, dim=-1)
        # output = self.dropout(output)
        output = output @ v
        return output

class MultiHeadSelfAttention(nn.Module):
    def __init__(self, d_model, n_head, dropout = 0.1):
        super(MultiHeadSelfAttention, self).__init__()
        self.d_model = d_model
        self.n_head = n_head
        self.d_head = d_model // n_head
        self.attns = nn.ModuleList([Attention(self.d_model, self.d_head, dropout) for _ in range(n_head)])
        self.linear = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        x = torch.cat([attn(x, x, x) for attn in self.attns], dim=-1)
        x = self.linear(x)
        x = self.dropout(x)
        return x


In [5]:
x = torch.randn(1, content_length, d_model)
attn = MultiHeadSelfAttention(d_model, d_head)
attn(x)

tensor([[[-0.4184,  0.3011,  0.2741, -0.3853,  0.0799,  0.1975,  0.0000,
          -0.5492, -0.4010, -0.0000,  0.0000,  0.2386, -0.4940,  0.4856,
          -0.1477,  0.0871],
         [-0.2510, -0.0513,  0.0573, -0.0000,  0.4182,  0.0616,  0.1208,
          -0.6447, -0.4309, -0.1921,  0.2582,  0.1586, -0.0000,  0.1088,
          -0.0524,  0.1730],
         [-0.2748,  0.0510, -0.3411, -0.3255,  0.0705,  0.1340,  0.0234,
          -0.2968, -0.3339, -0.0000, -0.1511, -0.0755, -0.0826,  0.0950,
          -0.0640,  0.0668],
         [-0.2100, -0.0920, -0.2085, -0.3110,  0.0000,  0.2229, -0.0199,
          -0.3194, -0.0000, -0.1728, -0.0101, -0.0200, -0.0992,  0.2268,
          -0.0783,  0.0793],
         [-0.2838, -0.0073, -0.1365, -0.2856,  0.3233,  0.2559,  0.0061,
          -0.4334, -0.3212, -0.1337,  0.0867, -0.1098, -0.2197,  0.1594,
          -0.0000,  0.0928],
         [-0.3385,  0.0000,  0.1276, -0.1334,  0.3864,  0.2349,  0.0918,
          -0.4840, -0.2200, -0.2542,  0.0410, -0.117