# Playground for Multihead Attention

## Prototype Attention Head
An implementation of a single causal self-attention head.

In [27]:
import torch
import torch.nn as nn

# Seed PyTorch's global RNG so that random weight initialization, torch.rand and
# dropout produce the same values on every run of the notebook.
torch.manual_seed(42)

## This is just for demonstration purposes. Not used later
class CausalSelfAttention(nn.Module):
    """Single-head causal (masked) self-attention.

    Every token attends only to itself and to earlier tokens of the sequence.

    Args:
        input_dim: size of the last dimension of the input ``x``.
        output_dim: size of the query/key/value projections and of the output.
        context_length: maximum sequence length supported by the causal mask.
        dropout: probability passed to ``nn.Dropout`` for the attention weights.
        qkv_bias: whether the query/key/value projections use a bias term.
    """

    def __init__(self, input_dim, output_dim, context_length, dropout, qkv_bias=False):
        super().__init__()
        self.output_dim = output_dim

        # Weight matrices for the input -> query, key, value projections
        self.W_query = nn.Linear(input_dim, output_dim, bias=qkv_bias)
        self.W_key = nn.Linear(input_dim, output_dim, bias=qkv_bias)
        self.W_value = nn.Linear(input_dim, output_dim, bias=qkv_bias)

        # Dropout on the attention weights (regularization against overfitting);
        # `dropout` is the probability of zeroing an element.
        self.dropout = nn.Dropout(dropout)

        # Static upper-triangular mask: 1 strictly above the diagonal marks the
        # "future" positions. Registered as a buffer (self.causal_mask) so it moves
        # with the module across devices but is not a trainable parameter.
        self.register_buffer(
            "causal_mask",
            torch.triu(torch.ones(context_length, context_length), diagonal=1),
        )

    def forward(self, x):
        """Return the causal context vectors for ``x``.

        Args:
            x: tensor of shape (batch_size, seq_len, input_dim), e.g. (8, 4, 256),
               with seq_len <= context_length.

        Returns:
            Tensor of shape (batch_size, seq_len, output_dim).
        """
        # Only the sequence length is needed (to crop the causal mask).
        _, con_len, _ = x.shape

        # Project the input to keys, queries and values. nn.Linear acts on the
        # last dimension, so it is applied to every batch element and position.
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)

        # Raw attention scores: each query compared with every key
        # -> (batch_size, seq_len, seq_len)
        attn_scores = queries @ keys.transpose(1, 2)

        # Hide future positions with -inf. masked_fill is NOT in-place, so the
        # result must be assigned -- otherwise the mask is silently ignored.
        mask_bool = self.causal_mask.bool()[:con_len, :con_len]
        attn_scores = attn_scores.masked_fill(mask_bool, -torch.inf)

        # Softmax over the key axis (dim=-1, the innermost dimension), with the
        # scores scaled by sqrt(d_k) where d_k = keys.shape[-1].
        attn_weights = torch.softmax(attn_scores / keys.shape[-1] ** 0.5, dim=-1)

        # Dropout zeros random weights and scales the remaining ones up by
        # 1 / (1 - p) in training mode; it does nothing in eval mode.
        attn_weights = self.dropout(attn_weights)

        # Attention-weighted sum of the values -> (batch_size, seq_len, output_dim)
        context_vec = attn_weights @ values
        return context_vec


## A short look at masking and causal masking

In [28]:
# Triangular Matrix filtering example for causal masking
def mask_demo():
    """Print a step-by-step example of how the causal mask is applied.

    Shows a random 4x4 matrix standing in for attention scores, the
    upper-triangular 0/1 mask, its boolean form, a cropped sub-mask (as is done
    for inputs shorter than the context length), and finally the matrix with
    all masked positions replaced by -inf.
    """
    # Matrix to mask, e.g. attention scores
    scores = torch.rand(4, 4)

    # Same construction as the buffered causal_mask: 1 strictly above the diagonal
    mask = torch.triu(torch.ones(4, 4), diagonal=1)
    mask_bool = mask.bool()  # 0/1 -> False/True

    print("\nmatrix (ex. attention_scores):\n", scores)
    print("\nmask (buffered causal_mask of size context_length):\n", mask)
    print("\nmask.bool():\n", mask_bool)
    # Cropping example: a 3-token input would use only the top-left 3x3 corner
    print("\nmask.bool()[interval] (clipped to match input)\n", mask_bool[:3, :3])
    # The example matrix is 4x4, so the full 4x4 crop is applied; True -> -inf
    print("\nmatrix.masked_fill(mask.bool()[interval], value)\n", scores.masked_fill(mask_bool[:4, :4], -torch.inf))

# mask_demo()

## Multi Head Attention
An implementation of multi-head attention.
Each head works on its own slice of the projected query, key and value (QKV) matrices and computes attention independently. The head outputs are concatenated and combined by a final linear layer into the context vector.

In [29]:
class MultiHeadAttention(nn.Module):
    """Causal multi-head self-attention.

    The query/key/value projections of width ``output_dim`` are split into
    ``num_heads`` heads of width ``head_dim = output_dim // num_heads``. Each
    head runs scaled, causally masked dot-product attention independently; the
    head outputs are concatenated and mixed by a final linear layer.

    Args:
        input_dim: size of the last dimension of the input ``x``.
        output_dim: total width of the projections and of the output; must be
            divisible by ``num_heads``.
        context_length: maximum sequence length supported by the causal mask.
        dropout: probability passed to ``nn.Dropout`` for the attention weights.
        num_heads: number of attention heads.
        qkv_bias: whether the query/key/value projections use a bias term.
        verbose: print a summary of the created layers.
    """

    def __init__(self, input_dim, output_dim, context_length, dropout, num_heads, qkv_bias=False, verbose=False):
        super().__init__()

        # output_dim is split into num_heads chunks of length head_dim,
        # so it must be divisible by num_heads without remainder.
        assert output_dim % num_heads == 0, "Output dimension must be dividable by head_num"

        self.output_dim = output_dim
        # Legacy alias: output_dim used to be stored only under this misleading name.
        self.TransformerBlock = output_dim
        self.num_heads = num_heads
        self.head_dim = output_dim // num_heads

        # Full-width input -> Q/K/V projections (split into heads in forward)
        self.W_query = nn.Linear(input_dim, output_dim, bias=qkv_bias)
        self.W_key = nn.Linear(input_dim, output_dim, bias=qkv_bias)
        self.W_value = nn.Linear(input_dim, output_dim, bias=qkv_bias)

        # Upper-triangular mask (1 = future position); a buffer follows the
        # module's device but is not trained.
        self.register_buffer("causal_mask", torch.triu(torch.ones(context_length, context_length), diagonal=1))

        self.dropout = nn.Dropout(dropout)

        # Linear layer that combines the concatenated head outputs
        self.out_proj = nn.Linear(output_dim, output_dim)

        if verbose:
            print(f"\n=== MultiHeadAttention Initialization ===")
            print(f"    input_dim =", input_dim)
            print(f"    output_dim =", output_dim)
            print(f"    num_heads =", self.num_heads)
            print(f"    head_dim =", self.head_dim)
            print(f"    Generating nn.Linear({input_dim}, {output_dim}) weights for query, key and value")
            print(f"    Generating causal diagonal mask torch.triu(torch.ones({context_length}, {context_length}), diagonal=1) for causal masking of attn_scores")
            print(f"    Generating dropout nn.Dropout({dropout}) for random dropout of attn_weights")
            print(f"    Generating optional nn.Linear({output_dim}, {output_dim}) weights for final context_vector projection")
            print(f"=== End MultiHeadAttention Initialization ===\n")

    def forward(self, x, verbose=False):
        """Apply causal multi-head self-attention to ``x``.

        Args:
            x: tensor of shape (batch_size, context_length, input_dim); its
               context_length must not exceed the one given at construction,
               since the causal mask is cropped from the stored buffer.
            verbose: print every intermediate step (shapes, batch 0 / head 0 values).

        Returns:
            Tensor of shape (batch_size, context_length, output_dim).
        """
        batch_size, context_length, input_dim = x.shape

        if verbose:
            print(f"\n=== MultiHeadAttention Forward Pass ===")
            print(f"Input shape: {x.shape} (batch_size={batch_size}, context_length={context_length}, input_dim={input_dim})")
            print(f"Config: num_heads={self.num_heads}, head_dim={self.head_dim}, output_dim={self.output_dim}")
            print(f"\nInput tensor (batch 0 with shape {x[0].shape}):")
            print(f" {x[0]}")

        # Project the input to Q/K/V (not yet split into heads); nn.Linear acts on
        # the last dim, so it applies to every batch element and position.
        # -> (batch_size, context_length, output_dim)
        queries = self.W_query(x)
        keys = self.W_key(x)
        values = self.W_value(x)

        if verbose:
            print(f"\n1. QKV Projection:")
            print(f"   QKV shapes: {queries.shape}")

        # Split implicitly by unrolling the last dim into (num_heads, head_dim):
        # (batch_size, context_length, output_dim) -> (batch_size, context_length, num_heads, head_dim)
        # e.g. (8, 4, 256) -> (8, 4, 8, 32) for num_heads = 8 and head_dim = 32
        queries = queries.view(batch_size, context_length, self.num_heads, self.head_dim)
        keys = keys.view(batch_size, context_length, self.num_heads, self.head_dim)
        values = values.view(batch_size, context_length, self.num_heads, self.head_dim)

        if verbose:
            print(f"\n2. Split into heads:")
            print(f"(batch_size, context_length, output_dim) -> (batch_size, context_length, head_num, head_dim)")
            print(f"   QKV shapes after view: {queries.shape}")

        # Move the head axis in front of the sequence axis so each head is an
        # independent batch of (context_length, head_dim) matrices:
        # -> (batch_size, num_heads, context_length, head_dim)
        queries = queries.transpose(1, 2)
        keys = keys.transpose(1, 2)
        values = values.transpose(1, 2)

        if verbose:
            print(f"\n3. Transpose for attention computation:")
            print(f"(batch_size, context_length, num_heads, head_dim) -> (batch_size, num_head, context_length, head_dim)")
            print(f"   QKV shapes after transpose: {queries.shape}")

        # Dot product per head -> (batch_size, num_heads, context_length, context_length)
        attn_scores = queries @ keys.transpose(2, 3)

        if verbose:
            print(f"\n4. Attention scores computation:")
            print(f"   attn_scores shape: {attn_scores.shape}")
            print(f"   Scale factor (1/sqrt(head_dim)): {1/keys.shape[-1]**0.5:.4f}")
            print(f"   Raw attention scores for head 0, batch 0:\n{attn_scores[0, 0]}")

        # Causal masking: crop the buffer to the actual length; it broadcasts
        # over the batch and head axes.
        mask_bool = self.causal_mask.bool()[:context_length, :context_length]
        attn_scores = attn_scores.masked_fill(mask_bool, -torch.inf)

        if verbose:
            print(f"\n5. Causal masking:")
            print(f"   Causal mask:\n{mask_bool}\n")
            print(f"   Then masked_fill True -> -torch.inf\n")
            print(f"   Masked attention scores for head 0, batch 0:\n{attn_scores[0, 0]}")

        # Scaled softmax over the key axis (scale 1 / sqrt(head_dim))
        attn_weights = torch.softmax(attn_scores / keys.shape[-1] ** 0.5, dim=-1)

        if verbose:
            print(f"\n6. Softmax attention weights:")
            print(f"   attn_weights shape: {attn_weights.shape}")
            print(f"   Attention weights for head 0, batch 0:\n{attn_weights[0, 0]}\n")
            print(f"   Sum of weights (should be ~1.0): {attn_weights[0, 0].sum(dim=-1)}")

        attn_weights = self.dropout(attn_weights)

        if verbose:
            print(f"\n7. After dropout:")
            print(f"   Attention weights after dropout for head 0, batch 0:\n{attn_weights[0, 0]}")

        # Weighted values per head, then swap heads and positions back:
        # -> (batch_size, context_length, num_heads, head_dim)
        context_vec = (attn_weights @ values).transpose(1, 2)

        if verbose:
            print(f"\n8. Compute context vectors:")
            print(f"   context_vec shape after attention: {context_vec.shape}")
            print(f"   First context vector (batch 0, token 0, head 0): {context_vec[0, 0, 0]}...")

        # Concatenate heads (output_dim == num_heads * head_dim). The transpose
        # made the tensor non-contiguous, so .contiguous() is needed before .view.
        context_vec = context_vec.contiguous().view(batch_size, context_length, self.output_dim)

        if verbose:
            print(f"\n9. Concatenate heads:")
            print(f"   context_vec shape after view: {context_vec.shape}")
            print(f"   First concatenated context vector (batch 0): {context_vec[0]}...")

        # Mix the information of the individual heads
        context_vec = self.out_proj(context_vec)

        if verbose:
            print(f"\n10. Final output projection:")
            print(f"   Final context_vec shape: {context_vec.shape}")
            print(f"   Final context vector (batch 0): {context_vec[0]}...")
            print(f"=== End MultiHeadAttention Forward Pass ===\n")

        return context_vec

## Test Run

In [30]:
def multi_head_attention_test_run(verbose=False):
    if verbose: print("\n\n------- initializing Multi Head Attention ----------------\n")
    mha = MultiHeadAttention(input_dim=768, output_dim=128, context_length=4, dropout=0.2, num_heads=8, verbose=verbose)

    if verbose: print("\n\n----- Using real data input (see DataPreparation.ipynb) --\n")
    %run "./01. DataPreparation.ipynb"
    batch = get_test_input_embedding(verbose=verbose) # defined in DataPreparation.ipynb

    if verbose: print("\n\n------- performing multi head attention ------------------\n")
    context_vector = mha(batch, verbose=verbose)

# _test_run = multi_head_attention_test_run(True)