In [None]:
import torch
from torch import nn

# Seed PyTorch's global RNG so the randomly initialised layers created later are reproducible.
torch.manual_seed(123)

# Toy input: 3 tokens (rows), each represented by a 6-dimensional embedding (columns).
inputs = torch.tensor([
    [0.43, 0.15, 0.89, 0.55, 0.87, 0.66],  # token 1
    [0.57, 0.85, 0.64, 0.22, 0.58, 0.33],  # token 2
    [0.77, 0.25, 0.10, 0.05, 0.80, 0.55],  # token 3
])

# Two copies of the same sequence stacked on a new leading (batch) axis -> shape (2, 3, 6).
batch = torch.stack([inputs] * 2, dim=0)
print(batch.shape)

torch.Size([2, 3, 6])


In [7]:
class MultiHeadedAttentionOptimized(nn.Module):
    """Causal multi-head self-attention computed for all heads in one batched pass.

    Rather than running separate single-head modules, one ``d_in -> d_out`` projection
    is computed each for queries, keys and values, and its last dimension is split into
    ``num_heads`` chunks of size ``head_dim = d_out // num_heads``.

    Args:
        d_in: Size of each input token embedding (last dimension of ``x``).
        d_out: Size of each output token embedding; must be divisible by ``num_heads``.
        context_length: Maximum number of tokens supported; sizes the causal mask.
        num_heads: Number of attention heads.
        dropout: Dropout probability applied to the attention weights.
        qkv_bias: Whether the query/key/value projections have a bias term.
        verbose: If True (default, matching the original behaviour), ``forward`` prints
            intermediate shapes and tensors. Pass False to run the module silently.
    """

    def __init__(self, d_in, d_out, context_length, num_heads, dropout, qkv_bias=False, verbose=True):
        super().__init__()

        # Each head receives an equal slice of d_out, so it has to split evenly.
        assert d_out % num_heads == 0, "d_out must be divisible by num_heads"

        self.d_in = d_in  # input embedding dimension
        self.d_out = d_out  # output embedding dimension
        self.num_heads = num_heads
        self.context_length = context_length  # maximum supported sequence length
        self.head_dim = d_out // num_heads  # d_k, the per-head dimension
        self.verbose = verbose

        # NOTE: creation order matters for reproducibility under a fixed seed -- keep it.
        # Each nn.Linear stores its weight with shape (d_out, d_in).
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.out_proj = nn.Linear(d_out, d_out)  # mixes the concatenated head outputs
        self.dropout = nn.Dropout(dropout)
        # Ones strictly above the diagonal mark "future" positions a token may not attend to.
        # Registered as a buffer so it moves with the module (e.g. .to(device)) and is saved.
        self.register_buffer(
            "mask",
            torch.triu(torch.ones(context_length, context_length), diagonal=1),
        )

    def forward(self, x):
        """Apply causal multi-head self-attention to ``x``.

        Args:
            x: Tensor of shape ``(b, num_tokens, d_in)`` with ``num_tokens <= context_length``.

        Returns:
            Tensor of shape ``(b, num_tokens, d_out)``.

        Raises:
            ValueError: If ``num_tokens`` exceeds ``context_length`` (the causal mask
                buffer would be too small to cover the sequence).
        """
        # getattr keeps instances pickled before `verbose` existed working (printing as before).
        verbose = getattr(self, "verbose", True)

        b, num_tokens, _ = x.shape
        if num_tokens > self.context_length:
            raise ValueError(
                f"Input has {num_tokens} tokens but this module supports at most "
                f"context_length={self.context_length}"
            )

        keys = self.W_key(x)  # (b, num_tokens, d_out)
        queries = self.W_query(x)
        values = self.W_value(x)
        if verbose:
            print(f"Keys shape before splitting d_out: {keys.shape}")

        # Split the last dim into heads: (b, num_tokens, d_out) -> (b, num_tokens, num_heads, head_dim)
        keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)
        values = values.view(b, num_tokens, self.num_heads, self.head_dim)
        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)
        if verbose:
            print(f"Keys shape after splitting d_out: {keys.shape}")

        # Group by head: (b, num_tokens, num_heads, head_dim) -> (b, num_heads, num_tokens, head_dim)
        keys = keys.transpose(1, 2)
        queries = queries.transpose(1, 2)
        values = values.transpose(1, 2)
        if verbose:
            print(f"Keys shape after grouping heads together: {keys.shape}")

        # Per-head dot products: (b, num_heads, num_tokens, num_tokens)
        attn_scores = queries @ keys.transpose(2, 3)
        if verbose:
            print(f"Attention scores shape: {attn_scores.shape}")
            print(f"Attention Scores: {attn_scores}")

        # Slice first, then convert: handles sequences shorter than context_length without
        # converting the whole context_length x context_length mask on every call.
        mask_bool = self.mask[:num_tokens, :num_tokens].bool()

        # Future positions get -inf so softmax assigns them zero weight.
        attn_scores.masked_fill_(mask_bool, -torch.inf)

        # Scale by sqrt(head_dim) before softmax (scaled dot-product attention).
        attn_weights = torch.softmax(attn_scores / keys.shape[-1] ** 0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)  # (b, num_heads, num_tokens, num_tokens)
        if verbose:
            print(f"Attention weights shape: {attn_weights.shape}")
            print(f"Attention Weights: {attn_weights}")

        # Weighted sum of values per head: (b, num_heads, num_tokens, head_dim)
        context_vec = attn_weights @ values
        if verbose:
            print(f"Context vector shape before combining heads: {context_vec.shape}")
            print(f"Context Vector: {context_vec}")
        # Back to token-major layout: (b, num_tokens, num_heads, head_dim)
        context_vec = context_vec.transpose(1, 2)
        if verbose:
            print(f"Context vector shape after transposing heads: {context_vec.shape}")

        # Concatenate heads (d_out == num_heads * head_dim). transpose() yields a
        # non-contiguous tensor, and view() requires contiguous memory.
        context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)
        context_vec = self.out_proj(context_vec)  # lets information from different heads interact
        if verbose:
            print(f"Context vector shape after combining heads: {context_vec.shape}")
        return context_vec

In [8]:
# Derive the attention hyper-parameters from the demo batch: (batch_size, context_length, d_in).
batch_size = batch.shape[0]
context_length = batch.shape[1]
d_in = batch.shape[2]
d_out = 6  # must be divisible by num_heads (here head_dim = 6 // 2 = 3)

mha = MultiHeadedAttentionOptimized(
    d_in,
    d_out,
    context_length,
    num_heads=2,
    dropout=0.0,  # no dropout, so no attention weights are randomly zeroed in this demo
)
context_vecs = mha(batch)
print(context_vecs)
print("context_vecs.shape:", context_vecs.shape)

Keys shape before splitting d_out: torch.Size([2, 3, 6])
Keys shape after splitting d_out: torch.Size([2, 3, 2, 3])
Keys shape after grouping heads together: torch.Size([2, 2, 3, 3])
Attention scores shape: torch.Size([2, 2, 3, 3])
Attention Scores: tensor([[[[-0.1389, -0.3177, -0.2816],
          [ 0.1951,  0.1256, -0.0087],
          [-0.2005, -0.2311, -0.1883]],

         [[ 0.0304,  0.0561,  0.0441],
          [-0.1166, -0.1927, -0.1691],
          [ 0.0022,  0.0300, -0.0119]]],


        [[[-0.1389, -0.3177, -0.2816],
          [ 0.1951,  0.1256, -0.0087],
          [-0.2005, -0.2311, -0.1883]],

         [[ 0.0304,  0.0561,  0.0441],
          [-0.1166, -0.1927, -0.1691],
          [ 0.0022,  0.0300, -0.0119]]]], grad_fn=<UnsafeViewBackward0>)
Attention weights shape: torch.Size([2, 2, 3, 3])
Attention Weights: tensor([[[[1.0000, 0.0000, 0.0000],
          [0.5100, 0.4900, 0.0000],
          [0.3345, 0.3286, 0.3369]],

         [[1.0000, 0.0000, 0.0000],
          [0.5110, 0.4890