In [21]:
import torch
import torch.nn as nn
from torch.nn import functional as F

In [22]:
d_model = 16
d_head = 4
content_length = 16
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# 1 模型定义

## 1.1 FFN

In [23]:
class ffn(nn.Module):
    def __init__(self, d_model, dropout=0.1):
        super(ffn, self).__init__()
        self.linear1 = nn.Linear(d_model, 4 * d_model)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(4 * d_model, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        x = self.linear1(x) # (batch, seq_len, d_model) -> (batch, seq_len, 4 * d_model)
        x = self.relu(x)    # (batch, seq_len, 4 * d_model)
        x = self.linear2(x) # (batch, seq_len, 4 * d_model) -> (batch, seq_len, d_model)
        x = self.dropout(x) # (batch, seq_len, d_model)
        return x

## 1.2 Attention

In [24]:
class Attention(nn.Module):
    def __init__(self, d_model, d_head, dropout=0.1):
        super(Attention, self).__init__()
        self.d_model = d_model
        self.d_head = d_head
        self.wq = nn.Linear(d_model, d_head)
        self.wk = nn.Linear(d_model, d_head)
        self.wv = nn.Linear(d_model, d_head)
        self.register_buffer('mask', torch.tril(torch.ones(content_length, content_length)))
        self.dropout = nn.Dropout(dropout)

    def forward(self, q, k, v):
        B, T, D = k.shape
        q, k, v = self.wq(q), self.wk(k), self.wv(v)
        output = (q @ k.transpose(-1, -2)) / torch.sqrt(torch.tensor(self.d_head))
        output = output.masked_fill(self.mask[:T, :T] == 0, float('-inf'))
        output = F.softmax(output, dim=-1)
        # output = self.dropout(output)
        output = output @ v
        return output

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, n_head, dropout = 0.1):
        super(MultiHeadAttention, self).__init__()
        self.d_model = d_model
        self.n_head = n_head
        self.d_head = d_model // n_head
        self.attns = nn.ModuleList([Attention(self.d_model, self.d_head, dropout) for _ in range(n_head)])
        self.linear = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, q, k, v):
        x = torch.cat([attn(q, k, v) for attn in self.attns], dim=-1)
        x = self.linear(x)
        x = self.dropout(x)
        return x


In [25]:
x = torch.randn(1, content_length, d_model)
attn = MultiHeadAttention(d_model, d_head)
attn(x, x, x)

tensor([[[ 0.0000,  0.5958,  0.1640, -0.0561, -0.0000, -0.2798,  1.1937,
          -0.1650, -0.3092, -1.2333, -0.1858, -0.1201,  0.1896, -0.1828,
          -0.5437,  0.2806],
         [ 0.0411,  0.2474, -0.1662, -0.0161, -0.1543, -0.0492,  0.4396,
           0.0109,  0.1336, -0.6234, -0.1889, -0.3460,  0.2228,  0.1669,
          -0.0702, -0.1018],
         [-0.0071,  0.1386, -0.1260, -0.0817, -0.0188, -0.0161,  0.2137,
          -0.0269,  0.3236, -0.4561, -0.0000, -0.5035,  0.2920,  0.1162,
           0.1887, -0.3010],
         [ 0.0302,  0.1596,  0.0136, -0.1083, -0.0831, -0.0985,  0.2624,
           0.0679,  0.3374, -0.5200, -0.0913, -0.3921,  0.5072,  0.1286,
           0.1298, -0.2699],
         [-0.0924,  0.2128, -0.1262,  0.0000, -0.1432,  0.0271,  0.3496,
           0.0000,  0.2732, -0.5393, -0.0162, -0.3434,  0.4345,  0.1667,
          -0.0000, -0.0000],
         [-0.0613,  0.0000, -0.0601,  0.1259, -0.1188, -0.0490,  0.0000,
           0.0394,  0.2111, -0.4838, -0.0226, -0.239