In [1]:
import torch


In [None]:
class MultiHeadAttention: 
  def __init__ (self, batch_size:int, sequence_length:int, model_dimension, n_heads):
    assert model_dimension % n_heads == 0 
    self.B = batch_size
    self.L = sequence_length
    self.D = model_dimension
    self.H = n_heads
    self.dk = self.D // self.H 
    
  def forward(self, X, W_Q, W_K, W_V): 
    # 1. Fuse weights (D, 3D)
    W_Q_K_V = torch.cat([W_Q, W_K, W_V], dim=1)
    
    # 2. Project Input: (B, L, D) @ (D, 3D) -> (B, L, 3D)
    projs = torch.matmul(X, W_Q_K_V)
    
    # 3. Reshape to isolate 3 (Q,K,V) and Heads
    # Shape becomes: (B, L, 3, H, Dk)
    T = projs.reshape(self.B, self.L, 3, self.H, self.dk)
    
    # 4. Permute to get Heads *before* Sequence Length
    # We want: (3, B, H, L, Dk)
    # 2 -> 0 (The QKV dimension moves to front)
    # 0 -> 1 (Batch stays)
    # 3 -> 2 (Heads move before Length)
    # 1 -> 3 (Length moves after Heads)
    # 4 -> 4 (Head Dim stays last)
    T = T.permute(2, 0, 3, 1, 4)
    
    # 5. Split into Q, K, V
    Q, K, V = T[0], T[1], T[2]

In [8]:
batch_size = 16 
sequence_length = 8 
model_dimension = 256
n_heads = 32

X = torch.randn(batch_size, sequence_length, model_dimension)
W_Q, W_K, W_V = torch.randn(model_dimension, model_dimension), torch.randn(model_dimension, model_dimension), torch.randn(model_dimension, model_dimension)

In [9]:
H = MultiHeadAttention(batch_size, sequence_length, model_dimension, n_heads)

In [10]:
proj = H.forward(X,W_Q, W_K, W_V)

RuntimeError: permute(sparse_coo): number of dimensions in the tensor input does not match the length of the desired ordering of dimensions i.e. input.dim() = 5 is not equal to len(dims) = 2

In [6]:
proj.shape

NameError: name 'proj' is not defined

In [46]:
print(32*8)

256
