In [1]:
import torch


In [2]:
class MultiHeadAttention:
  """Multi-head scaled dot-product self-attention with caller-supplied weights.

  Computes softmax(Q K^T / sqrt(dk)) V independently for each head, using a
  fused Q/K/V projection. There is no masking and no output projection: the
  per-head results are concatenated back to width D.

  Args:
    batch_size: Nominal batch size, stored as ``self.B``. ``forward`` reads the
      actual batch size from its input, so this value does not constrain it.
    sequence_length: Nominal sequence length, stored as ``self.L``. ``forward``
      reads the actual length from its input.
    model_dimension: Model width D; must be divisible by ``n_heads``.
    n_heads: Number of heads H; each head has width dk = D // H.
  """

  def __init__(self, batch_size: int, sequence_length: int, model_dimension: int, n_heads: int):
    assert model_dimension % n_heads == 0, (
        f"model_dimension ({model_dimension}) must be divisible by n_heads ({n_heads})"
    )
    self.B = batch_size
    self.L = sequence_length
    self.D = model_dimension
    self.H = n_heads
    self.dk = self.D // self.H

  def forward(self, X, W_Q, W_K, W_V):
    """Apply multi-head self-attention to ``X``.

    Args:
      X: Input of shape (B, L, D). B and L are taken from ``X`` itself and need
        not match the sizes given to the constructor.
      W_Q, W_K, W_V: Query/key/value projection matrices, each (D, D), applied
        as ``X @ W``.

    Returns:
      Tensor of shape (B, L, D): the attention outputs of all heads
      concatenated along the feature axis.
    """
    # Use the input's own batch/sequence sizes. Reshaping with stale
    # constructor sizes either raises or, when B*L happens to match, silently
    # mixes tokens from different sequences.
    B, L, _ = X.shape

    # 1. Fuse weights: three (D, D) matrices -> (D, 3D).
    W_QKV = torch.cat([W_Q, W_K, W_V], dim=1)

    # 2. Project input: (B, L, D) @ (D, 3D) -> (B, L, 3D).
    projs = torch.matmul(X, W_QKV)

    # 3. Split the last axis into (Q/K/V, head, head_dim): (B, L, 3, H, dk).
    #    torch.cat laid Q, K, V out as consecutive D-wide slabs, so axis 2
    #    selects Q/K/V and each slab splits evenly into H heads of width dk.
    T = projs.reshape(B, L, 3, self.H, self.dk)

    # 4. Move Q/K/V to the front and heads before length: (3, B, H, L, dk).
    T = T.permute(2, 0, 3, 1, 4)

    # 5. Split into Q, K, V, each (B, H, L, dk).
    Q, K, V = T.unbind(0)

    # 6. Scaled dot-product attention per head; softmax over the key axis.
    #    The scale is a plain Python float, so no tensor is allocated per call.
    scale = self.dk ** -0.5
    scores = torch.matmul(Q, K.transpose(-2, -1)) * scale   # (B, H, L, L)
    weights = torch.softmax(scores, dim=-1)
    all_head_out = torch.matmul(weights, V)                  # (B, H, L, dk)

    # 7. Concatenate heads: (B, H, L, dk) -> (B, L, H, dk) -> (B, L, D).
    return all_head_out.transpose(1, 2).reshape(B, L, self.D)
    

In [3]:
# Toy problem sizes for the demo below.
batch_size = 16
sequence_length = 4
model_dimension = 256
n_heads = 8
SEED = 0

# Seed the global torch RNG so Restart & Run All reproduces the same tensors.
torch.manual_seed(SEED)

X = torch.randn(batch_size, sequence_length, model_dimension)
# Scale the unit-normal weights by 1/sqrt(D) so each projected feature has
# roughly unit variance. With unscaled N(0, 1) weights the projections have
# variance ~D and the attention logits grow large enough to saturate softmax.
W_Q, W_K, W_V = (
    torch.randn(model_dimension, model_dimension) / model_dimension ** 0.5
    for _ in range(3)
)

In [4]:
# Pass every size by keyword so the mapping to constructor parameters is explicit.
H = MultiHeadAttention(
    batch_size=batch_size,
    sequence_length=sequence_length,
    model_dimension=model_dimension,
    n_heads=n_heads,
)

In [5]:
# Note: Q holds the attention *output* (B, L, D), not the query projection.
Q = H.forward(X, W_Q=W_Q, W_K=W_K, W_V=W_V)

In [6]:
# Self-attention maps each token to a D-dim vector, so output shape must equal input shape.
assert Q.shape == X.shape, f"unexpected attention output shape {tuple(Q.shape)}, expected {tuple(X.shape)}"
print(Q.shape)


torch.Size([16, 4, 256])


In [7]:
print(X.size())  # torch.Size of the input batch: (batch, seq_len, model_dim)

torch.Size([16, 4, 256])
