In [2]:
from typing import Tuple
import torch

In [18]:
class MultiHeadAttention(torch.nn.Module):
    """Multi-head scaled dot-product self-attention with an optional causal mask.

    The input is projected to queries, keys and values of width ``d_out``,
    split into ``num_heads`` heads of width ``d_out // num_heads``, attended
    per head, merged back to ``d_out`` and passed through an output projection.

    Args:
        d_in: Size of the last dimension of the input.
        d_out: Total output width (sum over all heads); must be divisible by
            ``num_heads``.
        max_num_tokens: Longest sequence length accepted by ``forward``; also
            the side length of the precomputed causal mask.
        num_heads: Number of attention heads.
        dropout_rate: Dropout probability applied to the attention weights.
        with_bias: Whether the query/key/value projections use a bias term.
            The output projection always uses ``torch.nn.Linear``'s default bias.
        with_mask: Whether to apply a causal (upper-triangular) mask so that a
            token cannot attend to later tokens.

    Raises:
        ValueError: If ``d_out`` is not divisible by ``num_heads``.
    """

    def __init__(
            self,
            d_in: int,
            d_out: int,
            max_num_tokens: int,
            num_heads: int,
            dropout_rate: float,
            with_bias: bool = False,
            with_mask: bool = False,
    ):
        super().__init__()
        self.d_in = d_in
        self.d_out = d_out
        self.max_num_tokens = max_num_tokens
        self.num_heads = num_heads
        self.dropout_rate = dropout_rate
        self.with_bias = with_bias
        self.with_mask = with_mask

        if d_out % num_heads != 0:
            raise ValueError("d_out必须可以被num_heads整除")
        self.d_head = d_out // num_heads

        # Placeholders; the real submodules are created in _init_parameters().
        # Note: "mask" is deliberately NOT pre-declared here, because
        # register_buffer() refuses a name that already exists as a plain attribute.
        self.wq = None
        self.wk = None
        self.wv = None
        self.dropout = None
        self.out_proj = None
        self._init_parameters()

    def _init_parameters(self):
        """Create the projection layers, the optional causal mask and dropout."""
        d_in, d_out, with_bias = self.d_in, self.d_out, self.with_bias
        self.wq = torch.nn.Linear(in_features=d_in, out_features=d_out, bias=with_bias)
        self.wk = torch.nn.Linear(in_features=d_in, out_features=d_out, bias=with_bias)
        self.wv = torch.nn.Linear(in_features=d_in, out_features=d_out, bias=with_bias)

        block_size = self.max_num_tokens
        if self.with_mask:
            # True strictly above the diagonal == "future" positions to hide.
            mask = torch.triu(torch.ones(block_size, block_size), diagonal=1).bool()
            # Registered as a buffer so the mask is saved together with the model.
            self.register_buffer(name="mask", tensor=mask)

        self.dropout = torch.nn.Dropout(p=self.dropout_rate)

        self.out_proj = torch.nn.Linear(d_out, d_out)

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        """Apply multi-head self-attention.

        Args:
            X: Tensor of shape (batch_size, num_tokens, d_in) with
                ``num_tokens <= max_num_tokens``.

        Returns:
            Tensor of shape (batch_size, num_tokens, d_out).
        """
        # The actual sequence length of this input may be shorter than max_num_tokens.
        batch_size, num_tokens, d_in = X.shape
        assert num_tokens <= self.max_num_tokens, f"输入序列长度 {num_tokens} 超过最大允许长度 {self.max_num_tokens}"
        assert d_in == self.d_in, "输入维度不正确"

        Q, K, V = self._compute_qkv(X)
        Q, K, V = self._reshape_qkv(Q, K, V, num_tokens)
        Q, K, V = self._transpose_qkv(Q, K, V)

        attention_scores = self._compute_attention_scores(Q, K)
        if self.with_mask:
            attention_scores = self._mask_attention_scores(attention_scores, num_tokens)

        attention_weights = self._compute_attention_weights(attention_scores)
        attention_weights = self.dropout(attention_weights)

        contexts = self._compute_contexts(attention_weights, V)
        contexts = self._transpose_contexts(contexts)
        contexts = self._reshape_contexts(contexts, num_tokens)

        Y = self.out_proj(contexts)
        return Y

    def _compute_qkv(self, X: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Project the input to queries, keys and values.

        X.shape:        (batch_size, num_tokens, d_in)
        Q, K, V.shape:  (batch_size, num_tokens, d_out)
        """
        Q, K, V = self.wq(X), self.wk(X), self.wv(X)
        return Q, K, V

    def _reshape_qkv(self, Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor, num_tokens: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Split the last dimension into (num_heads, d_head).

        ``num_tokens`` is the token count of the current input, which is not
        necessarily equal to ``self.max_num_tokens`` (the upper bound fixed at
        construction time).

        Q, K, V.shape:  (batch_size, num_tokens, d_out) -> (batch_size, num_tokens, num_heads, d_head)
        """
        batch_size = Q.shape[0]
        Q = Q.reshape(batch_size, num_tokens, self.num_heads, self.d_head)
        K = K.reshape(batch_size, num_tokens, self.num_heads, self.d_head)
        V = V.reshape(batch_size, num_tokens, self.num_heads, self.d_head)
        return Q, K, V

    def _transpose_qkv(self, Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Move the head axis in front of the token axis.

        Q, K, V.shape:  (batch_size, num_tokens, num_heads, d_head) -> (batch_size, num_heads, num_tokens, d_head)
        """
        Q = Q.transpose(1, 2).contiguous()
        K = K.transpose(1, 2).contiguous()
        V = V.transpose(1, 2).contiguous()
        return Q, K, V

    def _compute_attention_scores(self, Q: torch.Tensor, K: torch.Tensor) -> torch.Tensor:
        """Compute the unscaled dot products between queries and keys.

        attention_scores.shape: (batch_size, num_heads, num_tokens, num_tokens)
        """
        attention_scores = Q @ K.transpose(2, 3)
        return attention_scores

    def _mask_attention_scores(self, attention_scores: torch.Tensor, num_tokens: int) -> torch.Tensor:
        """Fill the scores of future positions with -inf, in place.

        The stored mask has shape (max_num_tokens, max_num_tokens); only its
        top-left (num_tokens, num_tokens) corner applies to the current input,
        because the input may be shorter than the configured maximum.

        Returns the same tensor object that was passed in.
        """
        mask = self.mask[:num_tokens, :num_tokens]
        # Add leading singleton axes so the mask broadcasts over batch and heads.
        attention_scores.masked_fill_(mask[None, None, :, :], -torch.inf)
        return attention_scores

    def _compute_attention_weights(self, attention_scores: torch.Tensor) -> torch.Tensor:
        """Scale the scores by sqrt(d_head) and normalise them over the key axis.

        attention_weights.shape: (batch_size, num_heads, num_tokens, num_tokens)
        """
        attention_weights = torch.softmax(attention_scores / self.d_head ** 0.5, dim=-1)
        return attention_weights

    def _compute_contexts(self, attention_weights: torch.Tensor, V: torch.Tensor) -> torch.Tensor:
        """Compute the attention-weighted sum of the values for each head.

        contexts.shape: (batch_size, num_heads, num_tokens, d_head)
        """
        contexts = attention_weights @ V
        return contexts

    def _transpose_contexts(self, contexts: torch.Tensor) -> torch.Tensor:
        """Move the token axis back in front of the head axis.

        contexts.shape: (batch_size, num_heads, num_tokens, d_head) -> (batch_size, num_tokens, num_heads, d_head)
        """
        contexts = contexts.transpose(1, 2).contiguous()
        return contexts

    def _reshape_contexts(self, contexts: torch.Tensor, num_tokens: int) -> torch.Tensor:
        """Concatenate the heads back into a single feature axis.

        contexts.shape: (batch_size, num_tokens, num_heads, d_head) -> (batch_size, num_tokens, d_out)
        """
        batch_size = contexts.shape[0]
        contexts = contexts.reshape(batch_size, num_tokens, self.d_out)
        return contexts


In [19]:
# Six toy token embeddings, one 3-dimensional vector per word of the sentence.
row = torch.tensor([
    [0.43, 0.15, 0.89],  # Your     (x^1)
    [0.55, 0.87, 0.66],  # journey  (x^2)
    [0.57, 0.85, 0.64],  # starts   (x^3)
    [0.22, 0.58, 0.33],  # with     (x^4)
    [0.77, 0.25, 0.10],  # one      (x^5)
    [0.05, 0.80, 0.55],  # step     (x^6)
])

# Duplicate the sentence along a new leading axis to form a batch of two.
X = torch.stack([row, row])

# NOTE(review): the second label says "C", but the value printed is X.
print(f"X' shape is: {X.shape}")
print(f"and C is: \n{X}")

X' shape is: torch.Size([2, 6, 3])
and C is: 
tensor([[[0.4300, 0.1500, 0.8900],
         [0.5500, 0.8700, 0.6600],
         [0.5700, 0.8500, 0.6400],
         [0.2200, 0.5800, 0.3300],
         [0.7700, 0.2500, 0.1000],
         [0.0500, 0.8000, 0.5500]],

        [[0.4300, 0.1500, 0.8900],
         [0.5500, 0.8700, 0.6600],
         [0.5700, 0.8500, 0.6400],
         [0.2200, 0.5800, 0.3300],
         [0.7700, 0.2500, 0.1000],
         [0.0500, 0.8000, 0.5500]]])


In [20]:
# Seed first so the randomly initialised projection weights are reproducible.
torch.manual_seed(123)

# Two heads of width 1 (d_out=2), causal masking on, dropout disabled.
mha = MultiHeadAttention(
    d_in=3,
    d_out=2,
    max_num_tokens=6,
    num_heads=2,
    dropout_rate=0.0,
    with_bias=False,
    with_mask=True,
)

# Run the batch through the attention layer.
Y = mha(X)

print(f"Y's shape is: {Y.shape}")
print(f"Y is: \n{Y}")

Y's shape is: torch.Size([2, 6, 2])
Y is: 
tensor([[[0.3190, 0.4858],
         [0.2943, 0.3897],
         [0.2856, 0.3593],
         [0.2693, 0.3873],
         [0.2639, 0.3928],
         [0.2575, 0.4028]],

        [[0.3190, 0.4858],
         [0.2943, 0.3897],
         [0.2856, 0.3593],
         [0.2693, 0.3873],
         [0.2639, 0.3928],
         [0.2575, 0.4028]]], grad_fn=<ViewBackward0>)


In [21]:
# Reference output recorded for seed 123. Both batch entries are the same
# sentence, so the per-sentence result is written once and repeated twice.
expected = torch.tensor([
    [0.3190, 0.4858],
    [0.2943, 0.3897],
    [0.2856, 0.3593],
    [0.2693, 0.3873],
    [0.2639, 0.3928],
    [0.2575, 0.4028],
]).repeat(2, 1, 1)

# Regression check: the layer output must match the reference to 4 decimals.
assert torch.allclose(Y, expected, atol=1e-4)