In [1]:
import torch
from torch.nn import functional as F
import torch.nn as nn

# Input shape: (batch_size, seq_len, hidden_dim) or B x T x C
# Output shape: (batch_size, seq_len, hidden_dim) or B x T x C

device = "cuda" if torch.cuda.is_available() else "cpu"


In [2]:
class Config:
    context_length=512
    embedding_dim=64
    num_heads=4

In [10]:
class MultiHeadAttention(nn.Module):
    def __init__(self):
        super().__init__()
        self.head_size = Config.embedding_dim // Config.num_heads
        C = Config.embedding_dim
        self.attn = nn.Linear(C, C*3, bias=False, device=device)
        self.softmax = nn.Softmax()
        self.register_buffer(
            "bias", torch.tril(torch.ones(Config.context_length, Config.context_length, device=device))
        )
        self.out_proj = nn.Linear(C, C)

    
    def forward(self, x,inbuilt_attn=True):
        B, T, C = x.shape
        qkv = self.attn(x)
        # k: BxTxC
        # q: BxTxC
        # v: BxTxC
        q, k, v = qkv.split(C, dim=-1)

        nh = Config.num_heads
        k = k.view(B, T, nh, C // nh).transpose(1, 2)  # B x nh, T, hs
        q = q.view(B, T, nh, C // nh).transpose(1, 2)  # B x nh, T, hs
        v = v.view(B, T, nh, C // nh).transpose(1, 2)  # B x nh, T, hs

        if inbuilt_attn:
            y = F.scaled_dot_product_attention(q, k, v, is_causal=True) # flash attention

        else:
            wei = ((q @ k.transpose(-2, -1))* k.size(-1)**-0.5) # B x nh x T x T
            wei = wei.masked_fill(self.bias[:T, :T]==0, float("-inf"))
            wei = F.softmax(wei, dim=-1)
            y = wei @ v
            y = y.transpose(1, 2)

        # https://stackoverflow.com/questions/48915810/what-does-contiguous-do-in-pytorch
        y  = y.contiguous().view(B, T, C)
        return self.out_proj(y)
    

In [11]:
mha = MultiHeadAttention()

B, T, C = 4, 10, Config.embedding_dim
x = torch.randn(B, T, C)
mha(x).shape

torch.Size([4, 10, 64])

In [12]:
with torch.no_grad():
    y1 = mha(x, inbuilt_attn=True)
    y2 = mha(x, inbuilt_attn=False)

In [6]:
y1[0][0][:5], y2[0][0][:5]

(tensor([-0.0521, -0.1870,  0.0028,  0.4239,  0.0067]),
 tensor([-0.4268, -0.0518, -0.0083,  0.0816, -0.5620]))