In [2]:
import torch

  cpu = _conversion_method_template(device=torch.device("cpu"))


In [3]:
# Multi-head attention mechanism
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Queries, keys and values are all linear projections of the same input.
    The projected features are divided evenly across ``num_heads`` heads,
    attention is computed per head, and the heads are concatenated and passed
    through a final output projection.
    """

    def __init__(self, d_model, num_heads):
        """
        Args:
            d_model (int): total feature width of the input and the output.
            num_heads (int): number of attention heads; must divide d_model.

        Raises:
            AssertionError: if d_model is not a multiple of num_heads.
        """
        super().__init__()
        # Every head receives an equal slice of the model width.
        assert d_model % num_heads == 0

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads  # width of a single head

        # Query / key / value projections, then the output projection.
        self.W_Q = nn.Linear(d_model, d_model)
        self.W_K = nn.Linear(d_model, d_model)
        self.W_V = nn.Linear(d_model, d_model)
        self.W_O = nn.Linear(d_model, d_model)

    def _split_heads(self, projected, batch_size, seq_len):
        """Reshape (batch, seq_len, d_model) into (batch, num_heads, seq_len, d_k)."""
        per_head = projected.view(batch_size, seq_len, self.num_heads, self.d_k)
        return per_head.transpose(1, 2)

    def forward(self, x, mask=None):
        """Apply self-attention to ``x``.

        Args:
            x: tensor of shape (batch_size, seq_len, d_model).
            mask: optional tensor broadcastable to
                (batch_size, num_heads, seq_len, seq_len). Positions where the
                mask equals 0 are excluded from attention. A causal mask for
                seq_len = 3 looks like::

                    [[1, 0, 0],
                     [1, 1, 0],
                     [1, 1, 1]]

        Returns:
            Tensor of shape (batch_size, seq_len, d_model).
        """
        batch_size, seq_len, _ = x.shape

        # 1-2. Project the input, then carve each projection into heads.
        Q = self._split_heads(self.W_Q(x), batch_size, seq_len)
        K = self._split_heads(self.W_K(x), batch_size, seq_len)
        V = self._split_heads(self.W_V(x), batch_size, seq_len)

        # 3. Scaled dot-product scores:
        #    (batch, heads, seq, d_k) @ (batch, heads, d_k, seq) -> (batch, heads, seq, seq)
        scores = (Q @ K.transpose(-2, -1)) / math.sqrt(self.d_k)

        # 4. Masked-out positions get -inf so softmax assigns them zero weight.
        if mask is not None:
            blocked = mask == 0
            scores = scores.masked_fill(blocked, float('-inf'))

        # 5. Normalise over the key dimension.
        weights = torch.softmax(scores, dim=-1)

        # 6. Weighted sum of the values.
        context = weights @ V

        # 7. Merge heads: (batch, heads, seq, d_k) -> (batch, seq, heads, d_k) -> (batch, seq, d_model).
        #    The transposed tensor is non-contiguous, so reshape copies it into
        #    contiguous memory where needed.
        merged = context.transpose(1, 2).reshape(batch_size, seq_len, self.d_model)
        return self.W_O(merged)


In [5]:
# Smoke test: run one causally-masked forward pass and verify the output shape.
d_model = 512
num_heads = 8
batch_size = 2
seq_len = 10
x = torch.randn(batch_size, seq_len, d_model)
mask = torch.tril(torch.ones(seq_len, seq_len))  # lower-triangular (causal) mask

output = MultiHeadAttention(d_model, num_heads)(x, mask)

# Enforce the expected shape instead of relying on eyeballing the printed value.
assert output.shape == (batch_size, seq_len, d_model)
print(output.shape)  # Expected: (2, 10, 512)

torch.Size([2, 10, 512])


In [6]:
# Multi-Query Attention Mechanism
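'''
Difference from MHA:
there are num_heads query heads, but only a single key head and a single value head,
which are shared by all query heads.
'''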
class MultiQueryAttention(nn.Module):
    """Multi-query self-attention.

    There are ``num_heads`` query heads, but only one key head and one value
    head; the single key/value head is broadcast across all query heads.
    """

    def __init__(self, d_model, num_heads):
        """
        Args:
            d_model (int): total feature width of the input and the output.
            num_heads (int): number of query heads; must divide d_model.

        Raises:
            AssertionError: if d_model is not a multiple of num_heads.
        """
        super().__init__()
        assert d_model % num_heads == 0

        self.num_heads = num_heads
        self.d_model = d_model
        self.d_k = d_model // num_heads  # width of a single head

        self.W_Q = nn.Linear(self.d_model, self.d_model)

        # Keys and values have a single head, so they project down to d_k only.
        self.W_K = nn.Linear(self.d_model, self.d_k)
        self.W_V = nn.Linear(self.d_model, self.d_k)

        self.W_O = nn.Linear(self.d_model, self.d_model)

    def forward(self, x, mask=None):
        """Apply multi-query self-attention to ``x``.

        Args:
            x: tensor of shape (batch_size, seq_len, d_model).
            mask: optional tensor broadcastable to
                (batch_size, num_heads, seq_len, seq_len). Positions where the
                mask equals 0 are excluded from attention.

        Returns:
            Tensor of shape (batch_size, seq_len, d_model).
        """
        batch_size, seq_len, _ = x.size()

        # 1. Linear projections.
        Q = self.W_Q(x)
        K = self.W_K(x)
        V = self.W_V(x)

        # 2. Queries: (batch, seq_len, d_model) -> (batch, num_heads, seq_len, d_k).
        Q = Q.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)

        # Keys / values keep a single head: (batch, 1, seq_len, d_k).
        K = K.view(batch_size, seq_len, 1, self.d_k).transpose(1, 2)
        V = V.view(batch_size, seq_len, 1, self.d_k).transpose(1, 2)

        # 3. The size-1 head dimension of K broadcasts against Q's num_heads,
        #    giving scores of shape (batch, num_heads, seq_len, seq_len).
        scores = torch.matmul(Q, K.transpose(-1, -2)) / math.sqrt(self.d_k)

        # 4. Masked-out positions get -inf so softmax assigns them zero weight.
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))

        # 5. Normalise over the key dimension.
        attn_scores = F.softmax(scores, dim=-1)

        # 6. Weighted sum of the (broadcast) values: (batch, num_heads, seq_len, d_k).
        output = torch.matmul(attn_scores, V)

        # 7. Merge heads: (batch, num_heads, seq_len, d_k) -> (batch, seq_len, num_heads, d_k)
        #    -> (batch, seq_len, d_model). contiguous() must be *called* so that
        #    view() can operate on the transposed tensor.
        output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
        output = self.W_O(output)

        return output