Let's first write down the separate `Head` and `MultiHeadAttention` classes, both as a reference and so that the combined version can be checked against them for correctness.

In [1]:
from fastcore.foundation import *

In [2]:
import torch
from torch import nn
import torch.nn.functional as F

In [3]:
# sample hyperparameters (small values so the tensors printed below stay readable)
n_embd, blk_sz, n_heads = 32, 8, 4  # embedding width, block (context) size, number of heads
head_sz = n_embd // n_heads         # every head gets an equal share of the embedding width
head_sz

8

In [4]:
# lower-triangular mask of ones: row t has ones only in columns 0..t
torch.ones(blk_sz, blk_sz).tril()[:blk_sz, :blk_sz]

tensor([[1., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1., 1., 1.]])

In [5]:
# fix the global torch RNG seed; the trailing `;` suppresses the returned Generator's repr
torch.manual_seed(42);

In [6]:
class Head(nn.Module):
    """A single head of causal (masked) self-attention.

    Projects the input to keys, queries and values of width `head_sz`, lets each
    position attend only to itself and to earlier positions, and returns the
    attention-weighted sum of the values.

    Reads the notebook-level `n_embd` (input width) and `blk_sz` (side length of
    the causal mask, i.e. the longest sequence the mask can cover).
    """
    def __init__(self, head_sz):
        super().__init__()
        self.key   = nn.Linear(n_embd, head_sz)
        self.query = nn.Linear(n_embd, head_sz)
        self.value = nn.Linear(n_embd, head_sz)
        # registered as a buffer: part of the module's state, but not a trainable parameter
        self.register_buffer('tril', torch.tril(torch.ones((blk_sz, blk_sz))))

    def forward(self, x):
        """Map x of shape (B, T, n_embd), with T <= blk_sz, to (B, T, head_sz)."""
        _, T, _ = x.shape  # unpacking also rejects inputs that are not 3-D
        k = self.key(x)    # B, T, head_sz
        q = self.query(x)  # B, T, head_sz
        # Scale by the key/query width (head_sz), NOT the input width n_embd:
        # the q.k dot product sums head_sz terms, so that is what its variance grows with.
        wei = q @ k.transpose(-2, -1) * k.shape[-1]**-0.5  # B, T, T
        # positions above the diagonal are "the future": -inf there becomes weight 0 after softmax
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
        wei = wei.softmax(dim=-1)
        v = self.value(x)  # B, T, head_sz
        out = wei @ v      # B, T, head_sz
        return out

# h = Head(16)
# h(torch.randn((4, 8, 32))).shape

In [7]:
class MultiHeadAttention(nn.Module):
    """Several independent `Head`s run side by side, followed by a linear projection.

    Uses the notebook-level `head_sz` (width of each head) and `n_embd` (output width).
    """
    def __init__(self, n_heads):
        super().__init__()
        head_list = []
        for _ in range(n_heads):
            head_list.append(Head(head_sz=head_sz))
        self.heads = nn.ModuleList(head_list)
        # maps the concatenated head outputs (n_heads * head_sz wide) to n_embd
        self.proj = nn.Linear(n_heads * head_sz, n_embd)

    def forward(self, x):
        """Run every head on x, join the results along the last dim, then project."""
        per_head = [head(x) for head in self.heads]
        return self.proj(torch.cat(per_head, dim=-1))

# mh = MultiHeadAttention(n_heads)
# mh(torch.randn((4, 8, 32))).shape

EX1: The n-dimensional tensor mastery challenge: Combine the `Head` and `MultiHeadAttention` into one class that processes all the heads in parallel, treating the heads as another batch dimension (answer is in nanoGPT).

In [8]:
torch.manual_seed(42)
# one fused, bias-free projection producing k, q and v together (hence 3 * n_embd outputs)
attn = nn.Linear(in_features=n_embd, out_features=3 * n_embd, bias=False)
x = torch.randn(4, blk_sz, n_embd)  # toy input of shape (4, blk_sz, n_embd)
(x.shape, attn)

(torch.Size([4, 8, 32]), Linear(in_features=32, out_features=96, bias=False))

In [11]:
# cut the fused projection into three n_embd-wide pieces along the last dim
k, q, v = torch.split(attn(x), n_embd, dim=-1)
k.size()

torch.Size([4, 8, 32])

In [10]:
# the three dims of x, unpacked for use in the reshapes below
B, T, C = x.size()
(B, T, C)

(4, 8, 32)

Reshape `k`, `q` and `v` to shape `(B, n_heads, T, head_sz)`, here `(4, 4, 8, 8)`, so that the heads behave like an extra batch dimension.

In [11]:
# split the last dim into (n_heads, head_sz), then swap dims 1 and 2:
# (B, T, n_heads, head_sz) -> (B, n_heads, T, head_sz)
k = k.view(B, T, n_heads, head_sz).permute(0, 2, 1, 3)
k.size()

torch.Size([4, 4, 8, 8])

In [12]:
# same per-head layout for q and v: (B, T, n_heads, head_sz) -> (B, n_heads, T, head_sz)
q, v = (t.view(B, T, n_heads, head_sz).permute(0, 2, 1, 3) for t in (q, v))
(q.size(), v.size())

(torch.Size([4, 4, 8, 8]), torch.Size([4, 4, 8, 8]))

In [13]:
# scaled dot-product scores; the matmul runs over the last two dims,
# so the leading (B, n_heads) dims are treated as batch dims
wei = torch.matmul(q, k.transpose(-1, -2)) * head_sz ** -0.5
wei.size()

torch.Size([4, 4, 8, 8])

In [14]:
# causal mask: zeros above the diagonal mark the positions to hide
tril = torch.ones(blk_sz, blk_sz).tril()
# the (T, T) mask broadcasts over the leading dims; -inf entries get weight 0 after softmax
wei = wei.masked_fill(tril[:T, :T] == 0, float('-inf')).softmax(dim=-1)
wei.size()

torch.Size([4, 4, 8, 8])

In [15]:
# attention-weighted sum of the values, computed per head
out = torch.matmul(wei, v)
out.size()

torch.Size([4, 4, 8, 8])

In [19]:
# Merge the heads back: (B, n_heads, T, head_sz) -> (B, T, n_heads, head_sz) -> (B, T, n_embd).
# `view` needs the tensor's strides to be compatible with the requested shape.
# `transpose` moves no data: it returns a view over the same storage with two
# strides swapped, so the result is non-contiguous and `.contiguous()` has to
# copy it into a fresh layout before `view` can flatten the last two dims.
# all the more reason to learn how tensors are represented under-the-hood >:(
#
# Guard: this cell rebinds `out`. Running it a second time would NOT raise (a 3-D
# `out` also survives transpose/contiguous/view) but would silently scramble the
# values, so fail loudly unless `out` still has the per-head layout.
assert out.shape == (B, n_heads, T, head_sz), (
    f"expected per-head output {(B, n_heads, T, head_sz)}, got {tuple(out.shape)}; "
    "re-run the `out = wei @ v` cell first"
)
out = out.transpose(1, 2).contiguous().view(B, T, n_embd)

In [20]:
# shape of `out` after the heads have been merged
out.size()

torch.Size([4, 8, 32])