In [None]:
import torch
import torch.nn as nn
import math

# --- 1. Synthetic Dataset --- #
# A toy batch of random "token embeddings" standing in for one short sentence.
# Layout used by every demo below: (batch_size, sequence_length, embedding_dim).
batch_size, sequence_length, embedding_dim = 1, 5, 8

# Values are drawn from a standard normal distribution; no seed is set, so they
# differ from run to run (only the shape is printed).
x = torch.randn(batch_size, sequence_length, embedding_dim)
print(f"Input sequence (x):\n{x.shape}\n")

# --- 2. Single-Head Attention Implementation --- #
class SingleHeadAttention(nn.Module):
    """Scaled dot-product attention with a single head.

    Queries, keys and values each get their own linear projection from
    ``input_dim`` to ``output_dim``; attention scores are divided by
    sqrt(output_dim) before the softmax.

    Args:
        input_dim: Feature size of the query/key/value inputs.
        output_dim: Width of the projected queries/keys/values and of the output.
        dropout_rate: Dropout probability applied to the attention weights.
    """

    def __init__(self, input_dim, output_dim, dropout_rate=0.1):
        super().__init__()
        self.query_proj = nn.Linear(input_dim, output_dim)
        self.key_proj = nn.Linear(input_dim, output_dim)
        self.value_proj = nn.Linear(input_dim, output_dim)
        self.scale = math.sqrt(output_dim)  # standard 1/sqrt(d) attention scaling
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, query, key, value, mask=None):
        """Attend from ``query`` over ``key``/``value``.

        Args:
            query: Tensor of shape (..., query_len, input_dim).
            key: Tensor of shape (..., key_len, input_dim).
            value: Tensor of shape (..., key_len, input_dim).
            mask: Optional tensor broadcastable to (..., query_len, key_len);
                positions where ``mask == 0`` receive (effectively) zero weight.

        Returns:
            ``(output, attention_weights)`` where ``output`` has shape
            (..., query_len, output_dim) and ``attention_weights`` has shape
            (..., query_len, key_len). The returned weights are taken after dropout.
        """
        Q = self.query_proj(query)  # (..., query_len, output_dim)
        K = self.key_proj(key)      # (..., key_len, output_dim)
        V = self.value_proj(value)  # (..., key_len, output_dim)

        # Dot-product scores between every query and every key.
        scores = torch.matmul(Q, K.transpose(-2, -1)) / self.scale  # (..., query_len, key_len)

        if mask is not None:
            # Use the most negative finite value of the scores' own dtype. A fixed
            # literal such as -1e9 is not representable in float16, and masked_fill
            # raises on that overflow; for float32 the softmax result is unchanged.
            scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)

        attention_weights = torch.softmax(scores, dim=-1)
        attention_weights = self.dropout(attention_weights)

        # Weighted sum of the value vectors.
        output = torch.matmul(attention_weights, V)  # (..., query_len, output_dim)
        return output, attention_weights

print("--- Single-Head Attention Demonstration ---")
# Self-attention: the same tensor x supplies queries, keys and values. Projecting
# back to embedding_dim keeps the output as wide as the input embeddings.
single_head_attn = SingleHeadAttention(input_dim=embedding_dim, output_dim=embedding_dim)
single_output, single_weights = single_head_attn(query=x, key=x, value=x)

# Only shapes are reported; the module is in training mode, so dropout is active
# and the returned weights are not guaranteed to sum to 1 per row.
print(f"Single-Head Attention Output Shape: {single_output.shape}")
print(f"Single-Head Attention Weights Shape: {single_weights.shape}\n")


# --- 3. Multi-Head Attention Implementation --- #
class MultiHeadAttention(nn.Module):
    """Scaled dot-product attention computed by several heads in parallel.

    Every head projects the full input into its own ``head_dim``-wide subspace,
    attends there independently, and the per-head results are concatenated and
    mapped back to ``input_dim`` by a final linear layer.

    Args:
        input_dim: Feature size of the query/key/value inputs and of the output.
        head_dim: Feature size of each head's query/key/value projections.
        num_heads: Number of attention heads.
        dropout_rate: Dropout probability applied to the attention weights.
    """

    def __init__(self, input_dim, head_dim, num_heads, dropout_rate=0.1):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.input_dim = input_dim

        # One Linear per role produces every head at once (num_heads * head_dim
        # features); heads are separated afterwards by reshaping.
        self.query_proj = nn.Linear(input_dim, num_heads * head_dim)
        self.key_proj = nn.Linear(input_dim, num_heads * head_dim)
        self.value_proj = nn.Linear(input_dim, num_heads * head_dim)
        self.output_proj = nn.Linear(num_heads * head_dim, input_dim)  # merges heads back to input_dim
        self.dropout = nn.Dropout(dropout_rate)
        self.scale = math.sqrt(head_dim)

    def _split_heads(self, projected):
        """Reshape (batch, length, num_heads * head_dim) to (batch, num_heads, length, head_dim)."""
        batch_size, length, _ = projected.shape
        return projected.view(batch_size, length, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(self, query, key, value, mask=None):
        """Attend from ``query`` over ``key``/``value`` with all heads at once.

        Args:
            query: Tensor of shape (batch, query_len, input_dim).
            key: Tensor of shape (batch, key_len, input_dim).
            value: Tensor of shape (batch, key_len, input_dim).
            mask: Optional tensor; positions where ``mask == 0`` receive
                (effectively) zero weight. A 2-D (query_len, key_len) mask or any
                mask broadcastable to (batch, num_heads, query_len, key_len) is used
                as-is; a 3-D (batch, query_len, key_len) mask is shared by all heads.

        Returns:
            ``(output, attention_weights)`` where ``output`` has shape
            (batch, query_len, input_dim) and ``attention_weights`` has shape
            (batch, num_heads, query_len, key_len). The returned weights are taken
            after dropout.
        """
        batch_size, query_len, _ = query.shape

        # Each input is split using its own length, so key/value may be longer or
        # shorter than the query (cross-attention). Reshaping K/V with the query's
        # length would raise whenever the lengths differ.
        Q = self._split_heads(self.query_proj(query))  # (batch, heads, query_len, head_dim)
        K = self._split_heads(self.key_proj(key))      # (batch, heads, key_len, head_dim)
        V = self._split_heads(self.value_proj(value))  # (batch, heads, key_len, head_dim)

        scores = torch.matmul(Q, K.transpose(-2, -1)) / self.scale  # (batch, heads, query_len, key_len)

        if mask is not None:
            if mask.dim() == 3:
                # (batch, query_len, key_len) -> (batch, 1, query_len, key_len).
                # Without the head axis, broadcasting would align the mask's batch
                # dimension against the heads dimension.
                mask = mask.unsqueeze(1)
            # Most negative finite value of the scores' dtype: a literal -1e9
            # overflows float16 and masked_fill raises in that case.
            scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)

        attention_weights = torch.softmax(scores, dim=-1)
        attention_weights = self.dropout(attention_weights)

        x = torch.matmul(attention_weights, V)  # (batch, heads, query_len, head_dim)

        # Concatenate the heads back into one feature axis, then project.
        x = x.transpose(1, 2).contiguous().view(batch_size, query_len, self.num_heads * self.head_dim)
        output = self.output_proj(x)

        return output, attention_weights

print("--- Multi-Head Attention Demonstration ---")
num_heads = 2
# Size each head so that the concatenated heads are exactly embedding_dim wide.
# Note that every head still projects the *full* embedding into its own subspace.
head_dim = embedding_dim // num_heads

multi_head_attn = MultiHeadAttention(embedding_dim, head_dim=head_dim, num_heads=num_heads)
# Self-attention again: x plays the role of query, key and value.
multi_output, multi_weights = multi_head_attn(query=x, key=x, value=x)

print(f"Multi-Head Attention Output Shape: {multi_output.shape}")
print(f"Multi-Head Attention Weights Shape: {multi_weights.shape} (batch, heads, seq_len, seq_len)\n")

# Summary printed at the end of the demo; one print of newline-joined lines emits
# the same text as one print per line.
print("\n".join((
    "Key Differences:",
    "- Single-Head Attention: Computes one set of attention weights and output vectors for the entire input dimension.",
    "- Multi-Head Attention: Divides the input into 'num_heads' smaller parts (or projects into 'num_heads' different subspaces), computes attention independently for each head, and then concatenates and linearly transforms the results. This allows the model to jointly attend to information from different representation subspaces at different positions.",
)))


Input sequence (x):
torch.Size([1, 5, 8])

--- Single-Head Attention Demonstration ---
Single-Head Attention Output Shape: torch.Size([1, 5, 8])
Single-Head Attention Weights Shape: torch.Size([1, 5, 5])

--- Multi-Head Attention Demonstration ---
Multi-Head Attention Output Shape: torch.Size([1, 5, 8])
Multi-Head Attention Weights Shape: torch.Size([1, 2, 5, 5]) (batch, heads, seq_len, seq_len)

Key Differences:
- Single-Head Attention: Computes one set of attention weights and output vectors for the entire input dimension.
- Multi-Head Attention: Divides the input into 'num_heads' smaller parts (or projects into 'num_heads' different subspaces), computes attention independently for each head, and then concatenates and linearly transforms the results. This allows the model to jointly attend to information from different representation subspaces at different positions.
