In [None]:
import torch
from torch import nn

The multi-query attention component went through several iterations, starting with an additional head class, which was discarded in favour of a simpler "one class" implementation with better performance.

In [None]:
class MultiQueryAttention(nn.Module):
    '''
    Multi Query Attention Layer as described in the paper "Fast Transformer Decoding: One Write-Head is All You Need" (Shazeer, 2019)

    In Multi Query Attention, the keys and values are shared across attention heads, reducing the memory required at inference time at the cost of a small decrease in performance as compared to multi-head attention.
    The aim being to reduce the memory bandwidth requirements.

    This variant keeps one query projection per head (a ``ModuleList``) and runs attention once per head
    against the single shared key/value projection, then concatenates the head outputs.

    https://arxiv.org/abs/1911.02150

    Args:
        model_dimension: Size of the last dimension of the input and of the output.
        n_heads: Number of query heads. Must be positive and divide ``model_dimension`` evenly.
        dropout: Dropout probability, used both inside attention and on the output projection.
        mask: When True, attention is causal (``is_causal=True``).

    Raises:
        ValueError: If ``n_heads`` is not positive or does not divide ``model_dimension``.
    '''
    def __init__(self, model_dimension: int, n_heads: int, dropout: float, mask: bool = True):
        super().__init__()
        # The concatenated heads (n_heads * head_dimension) feed a model_dimension-wide
        # output projection, so the split has to be exact. Fail here with a clear message
        # instead of a shape error deep inside forward().
        if n_heads < 1 or model_dimension % n_heads != 0:
            raise ValueError(
                f"model_dimension ({model_dimension}) must be divisible by a positive n_heads ({n_heads})"
            )
        self.head_dimension = model_dimension // n_heads
        self.mask = mask
        # A single key and a single value projection, shared by every query head.
        self.keys = nn.Linear(model_dimension, self.head_dimension, bias=False)
        self.values = nn.Linear(model_dimension, self.head_dimension, bias=False)
        self.queries = nn.ModuleList([nn.Linear(model_dimension, self.head_dimension, bias=False) for _ in range(n_heads)])
        self.linear = nn.Linear(model_dimension, model_dimension, bias=False)
        self.dropout_p = dropout
        self.r_dropout = nn.Dropout(dropout)

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        '''
        Apply multi query self-attention to ``X``.

        The input's last dimension must be ``model_dimension``; the output has the same shape as the input.
        '''
        K = self.keys(X)
        V = self.values(X)
        # The functional attention call applies dropout whenever dropout_p > 0, independent of
        # the module's train/eval state, so it has to be switched off explicitly for evaluation.
        attention_dropout = self.dropout_p if self.training else 0.0
        heads = [
            nn.functional.scaled_dot_product_attention(
                query(X), K, V, dropout_p=attention_dropout, is_causal=self.mask
            )
            for query in self.queries
        ]
        concat = torch.cat(heads, dim=-1)
        linear = self.linear(concat)
        return self.r_dropout(linear)

I was unhappy with the performance, so I went back to the books and discovered batching and other techniques, which improved performance significantly at the cost of readability.

In [None]:
class MultiQueryAttention(nn.Module):
    '''
    Multi Query Attention Layer as described in the paper "Fast Transformer Decoding: One Write-Head is All You Need" (Shazeer, 2019)

    In Multi Query Attention, the keys and values are shared across attention heads, reducing the memory required at inference time at the cost of a small decrease in performance as compared to multi-head attention.
    The aim being to reduce the memory bandwidth requirements.

    This variant computes all query heads with one projection and the shared key/value pair with
    another, then runs a single batched attention call over every head.

    https://arxiv.org/abs/1911.02150

    Args:
        model_dimension: Size of the last dimension of the input and of the output.
        n_heads: Number of query heads. Must be positive and divide ``model_dimension`` evenly.
        dropout: Dropout probability, used both inside attention and on the output projection.
        mask: When True, attention is causal (``is_causal=True``).

    Raises:
        ValueError: If ``n_heads`` is not positive or does not divide ``model_dimension``.
    '''
    def __init__(self, model_dimension: int, n_heads: int, dropout: float, mask: bool = True):
        super().__init__()
        # forward() reshapes the query projection into (n_heads, head_dimension); that only
        # works when the split is exact, so reject bad configurations up front.
        if n_heads < 1 or model_dimension % n_heads != 0:
            raise ValueError(
                f"model_dimension ({model_dimension}) must be divisible by a positive n_heads ({n_heads})"
            )
        self.head_dimension = model_dimension // n_heads
        self.n_heads = n_heads
        self.mask = mask
        # One projection produces the queries of every head at once.
        self.queries = nn.Linear(model_dimension, model_dimension, bias=False)
        # One projection produces the shared key and value side by side (hence * 2).
        self.kv_projection = nn.Linear(model_dimension, self.head_dimension * 2, bias=False)
        self.linear = nn.Linear(model_dimension, model_dimension, bias=False)
        self.dropout_p = dropout
        self.r_dropout = nn.Dropout(dropout)

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        '''
        Apply multi query self-attention to ``X``.

        ``X`` must be 3-dimensional, (batch, seq_len, model_dimension); the output has the same shape.
        '''
        batch_len, seq_len, embd_dim = X.shape
        Q = self.queries(X).view(batch_len, seq_len, self.n_heads, self.head_dimension).transpose(1, 2) # (batch, n_heads, seq_len, head_dimension)

        # Notes on KV projection:
        # - The key and value projections were initially separated but here we combine them into a single projection to reduce the number of linear layers.
        # - As a result we need to split the output of the projection into two tensors, one for keys and one for values.
        # - We also need to match the shape of the queries tensor by adding a head dimension.
        # - Expanding a tensor does not allocate new memory, but only creates a new view on the existing tensor where a dimension of size
        #   one is expanded to a larger size by setting its stride to 0, so every head reads the same key/value storage.
        K, V = self.kv_projection(X).split(self.head_dimension, dim=-1)
        K = K.unsqueeze(1).expand(-1, self.n_heads, -1, -1) # (batch, n_heads, seq_len, head_dimension)
        V = V.unsqueeze(1).expand(-1, self.n_heads, -1, -1) # (batch, n_heads, seq_len, head_dimension)

        # The functional attention call applies dropout whenever dropout_p > 0, independent of
        # the module's train/eval state, so it has to be switched off explicitly for evaluation.
        attention_dropout = self.dropout_p if self.training else 0.0
        heads = nn.functional.scaled_dot_product_attention(Q, K, V, dropout_p=attention_dropout, is_causal=self.mask)

        # Notes on concat:
        # - With this method we also need a new way of concatenating the heads back together
        # - instead of concatenating the heads as in the previous version we reshape the heads tensor to match the shape of the input tensor
        # - view requires a compatible memory layout, which the transpose breaks, so we call contiguous before calling view
        concat = heads.transpose(1, 2).contiguous().view(batch_len, seq_len, embd_dim)

        linear = self.linear(concat)
        return self.r_dropout(linear)