# Vanilla MultiHeadAttention

In [1]:
import math
from typing import Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F

In [2]:
# Toy hyper-parameters shared by every attention variant in this notebook.
batch_size, block_size = 4, 8
n_embed, n_head = 48, 6
dropout, bias = 0.2, False

# Random input of shape (batch_size, block_size, n_embed): uniform samples
# drawn by numpy (float64) and cast down to float32.
x = torch.from_numpy(np.random.rand(batch_size, block_size, n_embed)).to(
    torch.float32
)

## Learn Level 3: No Flash Attention

In [12]:
class MultiHeadAttentionL3(nn.Module):
    """Causal multi-head self-attention computed with explicit matmuls.

    A single linear layer (``c_attn``) produces queries, keys and values,
    which are split into ``n_head`` heads of size ``n_embed // n_head``.
    Attention scores are scaled, causally masked, soft-maxed and dropped out
    before being applied to the values; the heads are then concatenated and
    passed through ``c_proj`` followed by a residual dropout.

    Args:
        block_size: Maximum sequence length; side of the causal mask buffer.
        n_head: Number of attention heads. Must divide ``n_embed``.
        n_embed: Embedding (channel) dimension of the input and output.
        bias: Whether the two linear layers use a bias term.
        dropout: Dropout probability for attention weights and the output.
    """

    def __init__(self, block_size, n_head, n_embed, bias, dropout):
        super().__init__()
        assert n_embed % n_head == 0

        self.block_size = block_size
        self.n_head = n_head
        self.n_embed = n_embed
        self.bias = bias
        self.dropout = dropout

        # One fused projection for q, k and v; split again in forward().
        self.c_attn = nn.Linear(n_embed, 3 * n_embed, bias=bias)
        self.c_proj = nn.Linear(n_embed, n_embed, bias=bias)
        self.attn_dropout = nn.Dropout(dropout)
        self.residual_dropout = nn.Dropout(dropout)

        # Lower-triangular ones, shaped (1, 1, block_size, block_size) so it
        # broadcasts over the batch and head dimensions.
        ltm = torch.tril(torch.ones(self.block_size, self.block_size)).view(
            1, 1, self.block_size, self.block_size
        )
        self.register_buffer("causal_mask", ltm)

    def forward(self, x):
        """Apply causal self-attention.

        Args:
            x: Tensor of shape ``(B, T, n_embed)`` with ``T <= block_size``.

        Returns:
            Tensor of shape ``(B, T, n_embed)``.
        """
        B, T, C = x.size()
        assert C == self.n_embed
        assert T <= self.block_size, (
            f"sequence length {T} exceeds block_size {self.block_size}"
        )

        q, k, v = self.c_attn(x).split(
            self.n_embed, dim=2
        )  # B, T, C @ C, 3 x C ; O(BTC^2)
        # (B, T, C) -> (B, nh, T, hs)
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)

        attn = (q @ k.transpose(-2, -1)) * (
            1 / math.sqrt(k.size(-1))
        )  # B, nh, T, hs @ B, nh, hs, T; O(BnhhsT^2)
        # Crop the mask to the actual sequence length; the full
        # (block_size, block_size) mask does not broadcast against (T, T)
        # scores when T < block_size.
        mask = self.causal_mask[:, :, :T, :T]
        attn = attn.masked_fill(mask == 0, float("-inf"))
        attn = F.softmax(attn, dim=-1)
        attn = self.attn_dropout(attn)

        y = attn @ v  # (B, nh, T, hs)
        # Re-assemble the heads side by side: (B, nh, T, hs) -> (B, T, C)
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        y = self.residual_dropout(self.c_proj(y))

        return y

In [15]:
# Build the explicit-matmul attention layer and run the sample batch through it.
yl3 = MultiHeadAttentionL3(
    block_size=block_size,
    n_head=n_head,
    n_embed=n_embed,
    bias=bias,
    dropout=dropout,
)(x)
yl3.shape

torch.Size([4, 8, 48])

## Learn Level 3: Flash Attention

In [48]:
class MultiHeadAttentionL3Flash(nn.Module):
    """Causal multi-head self-attention using the fused PyTorch SDPA kernel.

    Queries, keys and values come from one fused linear layer (``c_attn``)
    and are split into ``n_head`` heads. The masked, scaled soft-max
    attention is delegated to ``F.scaled_dot_product_attention`` with
    ``is_causal=True``; the heads are then concatenated and passed through
    ``c_proj`` followed by a residual dropout.

    Args:
        block_size: Stored on the module; not used by ``forward``.
        n_head: Number of attention heads. Must divide ``n_embed``.
        n_embed: Embedding (channel) dimension of the input and output.
        dropout: Dropout probability for attention weights and the output.
        bias: Whether the two linear layers use a bias term.
    """

    def __init__(self, block_size, n_head, n_embed, dropout, bias):
        super().__init__()
        assert n_embed % n_head == 0

        self.block_size = block_size
        self.n_head = n_head
        self.n_embed = n_embed
        self.dropout = dropout
        self.bias = bias

        # One fused projection for q, k and v; split again in forward().
        self.c_attn = nn.Linear(n_embed, 3 * n_embed, bias=bias)
        self.c_proj = nn.Linear(n_embed, n_embed, bias=bias)
        self.residual_dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Apply causal self-attention.

        Args:
            x: Tensor of shape ``(B, T, n_embed)``.

        Returns:
            Tensor of shape ``(B, T, n_embed)``.
        """
        B, T, C = x.size()
        assert C == self.n_embed

        q, k, v = self.c_attn(x).split(self.n_embed, dim=2)
        # (B, T, C) -> (B, nh, T, hs)
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)

        # scaled_dot_product_attention already returns softmax(q k^T) @ v,
        # so its result must NOT be multiplied by v a second time. The
        # default scale is 1 / sqrt(head_size), so it is left implicit.
        # Attention dropout is only active in training mode.
        y = F.scaled_dot_product_attention(
            q,
            k,
            v,
            attn_mask=None,
            dropout_p=self.dropout if self.training else 0.0,
            is_causal=True,
        )  # (B, nh, T, hs)
        # Re-assemble the heads side by side: (B, nh, T, hs) -> (B, T, C)
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        y = self.residual_dropout(self.c_proj(y))

        return y

In [49]:
# Build the SDPA-based attention layer and run the sample batch through it.
yl3f = MultiHeadAttentionL3Flash(
    block_size=block_size,
    n_head=n_head,
    n_embed=n_embed,
    dropout=dropout,
    bias=bias,
)(x)
yl3f.shape

torch.Size([4, 8, 48])

## Learn Level 4: Multi Head Attention

In [8]:
class MultiHeadAttentionL4(nn.Module):
    """Causal multi-head self-attention computed with explicit matmuls.

    A fused linear layer (``c_attn``) produces queries, keys and values that
    are split into ``n_heads`` heads of size ``n_embed // n_heads``. Scores
    are scaled, causally masked, soft-maxed and dropped out before weighting
    the values; the heads are concatenated and passed through ``c_proj``
    followed by dropout.

    Args:
        block_size: Maximum sequence length; side of the causal mask buffer.
        n_embed: Embedding (channel) dimension of the input and output.
        n_heads: Number of attention heads. Must divide ``n_embed``.
        dropout: Dropout probability for attention weights and the output.
        bias: Whether the two linear layers use a bias term.
    """

    def __init__(
        self,
        block_size: int,
        n_embed: int,
        n_heads: int,
        dropout: float,
        bias: bool,
    ):
        super().__init__()
        assert n_embed % n_heads == 0

        self.n_embed = n_embed
        self.hs = n_embed // n_heads  # per-head feature size
        self.n_heads = n_heads

        # One fused projection for q, k and v; split again in forward().
        self.c_attn = nn.Linear(n_embed, 3 * n_embed, bias=bias)
        self.c_proj = nn.Linear(n_embed, n_embed, bias=bias)
        self.dropout_attn = nn.Dropout(dropout)
        self.dropout_proj = nn.Dropout(dropout)

        # Lower-triangular ones, shaped (1, 1, block_size, block_size) so it
        # broadcasts over the batch and head dimensions.
        ltm = torch.tril(torch.ones(block_size, block_size)).view(
            1, 1, block_size, block_size
        )
        self.register_buffer("causal_mask", ltm)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply causal self-attention.

        Args:
            x: Tensor of shape ``(B, T, n_embed)`` where ``T`` does not
                exceed the ``block_size`` given at construction.

        Returns:
            Tensor of shape ``(B, T, n_embed)``.
        """
        B, T, C = x.size()
        assert C == self.n_embed
        max_len = self.causal_mask.size(-1)
        assert T <= max_len, f"sequence length {T} exceeds block_size {max_len}"

        qkv = self.c_attn(x)
        q, k, v = qkv.split(self.n_embed, dim=-1)

        # (B, T, C) -> (B, nh, T, hs)
        q = q.view(B, T, self.n_heads, self.hs).transpose(1, 2)
        k = k.view(B, T, self.n_heads, self.hs).transpose(1, 2)
        v = v.view(B, T, self.n_heads, self.hs).transpose(1, 2)

        attn = (q @ k.transpose(-2, -1)) * (1 / math.sqrt(k.size(-1)))
        # Crop the mask to the actual sequence length; the full
        # (block_size, block_size) mask does not broadcast against (T, T)
        # scores when T < block_size.
        mask = self.causal_mask[:, :, :T, :T]
        attn = attn.masked_fill(mask == 0, float("-inf"))
        attn = F.softmax(attn, dim=-1)
        attn = self.dropout_attn(attn)

        y = attn @ v  # (B, nh, T, hs)
        # Re-assemble the heads side by side: (B, nh, T, hs) -> (B, T, C)
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        y = self.dropout_proj(self.c_proj(y))

        return y

# Build the Level-4 attention layer and run the sample batch through it;
# note that `mha_l4` holds the layer's output tensor, not the module.
mha_l4 = MultiHeadAttentionL4(
    block_size=block_size,
    n_embed=n_embed,
    n_heads=n_head,
    dropout=dropout,
    bias=bias,
)(x)
mha_l4.shape

torch.Size([4, 8, 48])