In [1]:
import torch
import torch.nn as nn
from torch.nn import functional as F

# Hyperparameters

In [2]:
# data hyperparameters
seq_len = 8 # context window: number of tokens in each sequence

# model hyperparameters
embed_dim = 128 # size of each token-embedding vector
n_heads = 4 # number of attention heads; must divide embed_dim evenly

# the embedding is split evenly across the heads, so check divisibility up front
assert embed_dim % n_heads == 0, f'embed_dim ({embed_dim}) must be a multiple of n_heads ({n_heads})'

# training hyperparameters
batch_size = 5 # number of sequences per batch

# Create a class for multi-head attention

In [3]:
class MultiHeadAttention(nn.Module):
  """Causal multi-head self-attention using PyTorch's scaled_dot_product_attention.

  The Q, K, and V projections for all heads are stored as single
  [embed_dim, embed_dim] linear layers; the projected activations are then
  reshaped into `num_heads` slices of width `head_dim = embed_dim // num_heads`.

  Args:
    num_heads: number of attention heads (positive integer).
    embed_dim: size of the token embeddings; must be a multiple of `num_heads`.

  Raises:
    ValueError: if `num_heads` is not positive or does not divide `embed_dim`.
  """

  def __init__(self, num_heads, embed_dim):
    super().__init__()

    # fail early with a clear message; otherwise the floor division below
    # silently truncates and the error only surfaces later inside .view()
    if num_heads < 1 or embed_dim % num_heads != 0:
      raise ValueError(
          f"embed_dim ({embed_dim}) must be a positive multiple of num_heads ({num_heads})")

    # head-dimensionality is embed_dim split across the heads
    self.num_heads = num_heads
    self.head_dim = embed_dim // num_heads

    # num_heads Q, K, and V matrices, initialized as one "super-head"
    #    note: in model 5, these three matrices are combined into one
    self.query = nn.Linear(embed_dim, embed_dim, bias=False)
    self.key   = nn.Linear(embed_dim, embed_dim, bias=False)
    self.value = nn.Linear(embed_dim, embed_dim, bias=False)

    # final linear projection merges the heads' outputs
    self.W0 = nn.Linear(embed_dim, embed_dim, bias=False)

  def forward(self,x,track_sizes=False):
    """Apply causal multi-head self-attention to a batch of token embeddings.

    Args:
      x: tensor of shape [batch, seq_len, embed_dim].
      track_sizes: if True, print the tensor shape after each processing step.

    Returns:
      Tensor with the same shape as `x`.
    """

    # extract the dimension sizes of the inputs (token embeddings)
    B, T, E = x.shape # [batch, tokens (sequence length), embed_dim]
    if track_sizes: print(f"1){' Input data shape:':>28} {x.shape}")

    # push data through Q, K, and V (actually multiple heads still in the same matrix)
    q = self.query(x) # [batch, seq_len, embed_dim]
    k = self.key(x)
    v = self.value(x)
    if track_sizes: print(f"2){'q/k/v pre-split shape:':>28} {q.shape}")

    # reshape to split up the heads (note: head-splitting is done after XW_Q)
    q = q.view(B, T, self.num_heads, self.head_dim)
    k = k.view(B, T, self.num_heads, self.head_dim)
    v = v.view(B, T, self.num_heads, self.head_dim)
    if track_sizes: print(f"3){'q/k/v post-split shape:':>28} {q.shape}")

    # but pytorch's SDPA function needs the shape to be [B, num_heads, T, head_dim]
    q = q.transpose(1,2)
    k = k.transpose(1,2)
    v = v.transpose(1,2)
    if track_sizes: print(f"4){'q/k/v trnasposed shape:':>28} {q.shape}")

    # now we can call SDPA; is_causal=True masks out attention to later tokens
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
    if track_sizes: print(f"5){'Data post-attention shape:':>28} {out.shape}")

    # but our code still needs [B, T, num_heads, head_dim]
    out = out.transpose(1,2)
    if track_sizes: print(f"6){'Post-attention data reshape:':>28} {out.shape}")

    # merge heads back into embed_dim
    # (reshape rather than view: the transpose above left the tensor non-contiguous)
    out = out.reshape(B, T, E)
    if track_sizes: print(f"7){'Data merged to size:':>28} {out.shape}")

    # finally, apply linear mixing matrix
    out = self.W0(out)
    if track_sizes: print(f"8){'Post-MHA H0 linear mixing:':>28} {out.shape}")

    return out

In [4]:
# build one attention layer from the hyperparameters; the bare name below
# makes the notebook display the module's layer summary
mha = MultiHeadAttention(num_heads=n_heads, embed_dim=embed_dim)
mha

MultiHeadAttention(
  (query): Linear(in_features=128, out_features=128, bias=False)
  (key): Linear(in_features=128, out_features=128, bias=False)
  (value): Linear(in_features=128, out_features=128, bias=False)
  (W0): Linear(in_features=128, out_features=128, bias=False)
)

In [5]:
# push a batch of random "token embeddings" through the layer,
# printing the tensor shape at every step along the way
data = torch.randn(batch_size, seq_len, embed_dim)
out = mha(data, track_sizes=True)

# attention should leave the overall shape unchanged
print(f'Input size:  {data.shape}')
print(f'Output size: {out.shape}')

1)           Input data shape: torch.Size([5, 8, 128])
2)      q/k/v pre-split shape: torch.Size([5, 8, 128])
3)     q/k/v post-split shape: torch.Size([5, 8, 4, 32])
4)     q/k/v trnasposed shape: torch.Size([5, 4, 8, 32])
5)  Data post-attention shape: torch.Size([5, 4, 8, 32])
6)Post-attention data reshape: torch.Size([5, 8, 4, 32])
7)        Data merged to size: torch.Size([5, 8, 128])
8)  Post-MHA H0 linear mixing: torch.Size([5, 8, 128])
Input size:  torch.Size([5, 8, 128])
Output size: torch.Size([5, 8, 128])


In [6]:
# recap of the configuration, one line per setting
print(
  f'    Sequence length: {seq_len:2d}',
  f'Embedding dimension: {embed_dim}',
  f'    Number of heads: {n_heads:2d}',
  f'Head dimensionality: {embed_dim // n_heads}',
  sep='\n',
)

# trace the tensor shapes through one forward pass of the attention layer
print('\nDimensions of the data as it passes through the attention sublayer of one Transformer block:')
out = mha(data, track_sizes=True)

    Sequence length:  8
Embedding dimension: 128
    Number of heads:  4
Head dimensionality: 32

Dimensions of the data as it passes through the attention sublayer of one Transformer block:
1)           Input data shape: torch.Size([5, 8, 128])
2)      q/k/v pre-split shape: torch.Size([5, 8, 128])
3)     q/k/v post-split shape: torch.Size([5, 8, 4, 32])
4)     q/k/v trnasposed shape: torch.Size([5, 4, 8, 32])
5)  Data post-attention shape: torch.Size([5, 4, 8, 32])
6)Post-attention data reshape: torch.Size([5, 8, 4, 32])
7)        Data merged to size: torch.Size([5, 8, 128])
8)  Post-MHA H0 linear mixing: torch.Size([5, 8, 128])
