In [2]:
import torch
# Multi-head attention is just several single-head attentions, one per slice of the embedding, with their outputs concatenated
class MultiHeadAttention:
    """Parameter-free multi-head scaled dot-product attention.

    The last dimension of every input is cut into ``num_heads`` slices of
    ``head_size`` features. Each slice attends independently and the per-head
    results are concatenated back along the last dimension. There are no
    learned projections and no masking.
    """

    def __init__(self, num_heads, head_size):
        """Store the head layout.

        Args:
            num_heads: Number of attention heads (positive).
            head_size: Number of features handled by each head (positive).

        Raises:
            ValueError: If either argument is not positive.
        """
        if num_heads <= 0 or head_size <= 0:
            raise ValueError(
                f"num_heads and head_size must be positive, got "
                f"num_heads={num_heads}, head_size={head_size}"
            )
        self.num_heads = num_heads
        self.head_size = head_size

    def __call__(self, queries, keys, values):
        """Allow ``mha(q, k, v)`` as a shorthand for ``mha.forward(q, k, v)``."""
        return self.forward(queries, keys, values)

    def _split_heads(self, tensor, name):
        """Reshape ``(batch, ..., num_heads * head_size)`` to ``(batch, num_heads, seq, head_size)``."""
        embed_dim = self.num_heads * self.head_size
        # Without this check a wrong embedding size would be silently folded
        # into the sequence axis by the -1 below and produce garbage.
        if tensor.shape[-1] != embed_dim:
            raise ValueError(
                f"{name}: last dimension must be num_heads * head_size = "
                f"{embed_dim}, got {tensor.shape[-1]}"
            )
        split = tensor.reshape(tensor.shape[0], -1, self.num_heads, self.head_size)
        return split.transpose(1, 2)

    def forward(self, queries, keys, values):
        """Run attention independently per head and concatenate the results.

        Args:
            queries: Tensor of shape ``(batch, q_len, num_heads * head_size)``.
            keys: Tensor of shape ``(batch, kv_len, num_heads * head_size)``.
            values: Tensor of shape ``(batch, kv_len, num_heads * head_size)``.

        Returns:
            Tensor of shape ``(batch, q_len, num_heads * head_size)`` in which
            features ``[i * head_size, (i + 1) * head_size)`` hold head ``i``.

        Raises:
            ValueError: If the last dimension of any input is not
                ``num_heads * head_size``.
        """
        q = self._split_heads(queries, "queries")
        k = self._split_heads(keys, "keys")
        v = self._split_heads(values, "values")

        # matmul and softmax act on the last two axes, so a single call
        # handles every head at once: (batch, num_heads, q_len, head_size).
        per_head = self.single_head_attention(q, k, v)

        # Put the head axis next to the feature axis and merge the two, which
        # is the same layout as concatenating the heads along the last dim.
        merged = per_head.transpose(1, 2)
        return merged.reshape(
            merged.shape[0], merged.shape[1], self.num_heads * self.head_size
        )

    def single_head_attention(self, query, key, value):
        """Scaled dot-product attention over the last two axes.

        Args:
            query: Tensor of shape ``(..., q_len, head_size)``.
            key: Tensor of shape ``(..., kv_len, head_size)``.
            value: Tensor of shape ``(..., kv_len, head_size)``.

        Returns:
            Tensor of shape ``(..., q_len, head_size)``.
        """
        # Dividing by sqrt(head_size) is the standard scaling applied before softmax.
        attention_scores = torch.matmul(query, key.transpose(-2, -1)) / (self.head_size ** 0.5)
        attention_weights = torch.nn.functional.softmax(attention_scores, dim=-1)
        return torch.matmul(attention_weights, value)
    

In [4]:
# Self-attention demo: one random tensor serves as queries, keys and values.
# Its shape is (batch size 2, sequence length 10, embedding size 64).
# NOTE: the name `input` shadows Python's built-in input() for the rest of the session.
input = torch.randn(2, 10, 64)

# 8 heads of 8 features each cover the 64-dimensional embedding exactly.
mha = MultiHeadAttention(num_heads=8, head_size=8)
context_vec = mha.forward(queries=input, keys=input, values=input)

# The per-head outputs are concatenated, so the shape should be (2, 10, 64).
print(context_vec.shape)
print(context_vec)

torch.Size([2, 10, 64])
tensor([[[ 0.5545,  0.3625,  0.2356,  ..., -0.1696,  0.2542,  0.9349],
         [-0.2540,  0.2386, -0.3160,  ..., -0.4753, -0.0434,  0.5628],
         [ 0.9996,  0.8392,  0.3020,  ..., -0.1834,  0.8125,  0.3112],
         ...,
         [-1.0645,  0.4282,  0.1393,  ..., -0.3102,  0.1344,  0.4881],
         [-0.7822,  0.5139,  0.0188,  ..., -0.4683, -0.0082, -0.5034],
         [ 0.4240, -0.5624, -0.1818,  ..., -0.1887,  0.5831,  0.3837]],

        [[-1.1466,  0.3671,  0.0050,  ..., -1.0981, -0.1420, -0.3580],
         [ 1.0708,  0.0838,  0.4759,  ..., -1.1877,  0.0113,  0.2362],
         [-0.0442,  0.0201, -1.1163,  ...,  0.0133,  1.1793, -0.1848],
         ...,
         [-0.1741,  1.7254, -1.7822,  ...,  1.4998,  0.0876, -0.6872],
         [ 0.5054,  0.4028,  0.2271,  ..., -0.7465, -0.4857, -0.0780],
         [-0.2361,  0.7216, -0.4721,  ..., -2.4264,  0.2613, -0.2558]]])


In [6]:
# Another implementation: an nn.Module with learned query/key/value projections, a causal mask, and dropout
import torch.nn as nn
class MultiHeadAttention(nn.Module):
    """Causal multi-head self-attention with learned projections.

    The input is projected to queries, keys and values of width ``d_out``,
    split into ``num_heads`` heads of ``d_out // num_heads`` features, attended
    with a causal (upper-triangular) mask, and recombined through a final
    linear layer.
    """

    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        """Build the projections, dropout and causal mask.

        Args:
            d_in: Feature size of the input tokens.
            d_out: Total feature size of the output (all heads combined).
            context_length: Longest sequence the causal mask is built for.
            dropout: Dropout probability applied to the attention weights.
            num_heads: Number of attention heads; must divide ``d_out``.
            qkv_bias: Whether the query/key/value projections carry a bias.

        Raises:
            AssertionError: If ``d_out`` is not divisible by ``num_heads``.
        """
        super().__init__()
        assert (d_out % num_heads == 0), \
            "d_out must be divisible by num_heads"

        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads  # per-head width so the heads add up to d_out

        # Creation order is kept (query, key, value, out_proj) so a given
        # random seed yields the same initial weights as before.
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.out_proj = nn.Linear(d_out, d_out)  # mixes the concatenated head outputs
        self.dropout = nn.Dropout(dropout)

        # Ones strictly above the diagonal mark the future positions that a
        # token must not attend to. Registered as a buffer so it follows the
        # module across devices without being a trainable parameter.
        self.register_buffer(
            "mask",
            torch.triu(torch.ones(context_length, context_length), diagonal=1),
        )

    def _split_heads(self, projected, b, num_tokens):
        """Reshape ``(b, num_tokens, d_out)`` to ``(b, num_heads, num_tokens, head_dim)``."""
        return projected.view(b, num_tokens, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(self, x):
        """Apply causal multi-head self-attention.

        Args:
            x: Tensor of shape ``(b, num_tokens, d_in)``.

        Returns:
            Tensor of shape ``(b, num_tokens, d_out)``.

        Raises:
            ValueError: If ``num_tokens`` exceeds the ``context_length`` the
                mask was built for.
        """
        b, num_tokens, _ = x.shape

        # Slicing the mask cannot grow it. Without this check an over-long
        # sequence fails with an opaque broadcasting error, or, when
        # context_length == 1, silently runs with no causal masking at all.
        max_tokens = self.mask.shape[0]
        if num_tokens > max_tokens:
            raise ValueError(
                f"sequence length {num_tokens} exceeds context_length {max_tokens}"
            )

        # Project, then expose the head axis: (b, num_heads, num_tokens, head_dim)
        keys = self._split_heads(self.W_key(x), b, num_tokens)
        queries = self._split_heads(self.W_query(x), b, num_tokens)
        values = self._split_heads(self.W_value(x), b, num_tokens)

        # Dot product of every query with every key, per head:
        # (b, num_heads, num_tokens, num_tokens)
        attn_scores = queries @ keys.transpose(2, 3)

        # Future positions get -inf so softmax assigns them zero weight.
        mask_bool = self.mask.bool()[:num_tokens, :num_tokens]
        attn_scores.masked_fill_(mask_bool, -torch.inf)

        # Scale by sqrt(head_dim) before normalising over the key axis.
        attn_weights = torch.softmax(attn_scores / self.head_dim ** 0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # Weighted sum of values, heads moved back next to the features:
        # (b, num_tokens, num_heads, head_dim)
        context_vec = (attn_weights @ values).transpose(1, 2)

        # Merge heads: d_out == num_heads * head_dim
        context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)
        return self.out_proj(context_vec)

In [7]:
torch.manual_seed(123)

# Toy sequence: 3 tokens (rows), each a 6-dimensional embedding (columns),
# so `inputs` has shape (3, 6).
inputs = torch.tensor([
    [0.43, 0.15, 0.89, 0.55, 0.87, 0.66],  # token 1
    [0.57, 0.85, 0.64, 0.22, 0.58, 0.33],  # token 2
    [0.77, 0.25, 0.10, 0.05, 0.80, 0.55],  # token 3
])

# Duplicate the sequence along a new leading axis to form a batch of two
# identical examples, giving shape (2, 3, 6).
batch = torch.stack([inputs, inputs])
print(batch.shape)

torch.Size([2, 3, 6])


In [8]:
# Read the model dimensions off the batch: (batch size, tokens, embedding size).
batch_size, context_length, d_in = batch.shape
d_out = 6

# Two heads with dropout disabled (probability 0.0).
mha = MultiHeadAttention(
    d_in=d_in,
    d_out=d_out,
    context_length=context_length,
    dropout=0.0,
    num_heads=2,
)
context_vecs = mha(batch)

print(context_vecs)
print("context_vecs.shape:", context_vecs.shape)

tensor([[[ 0.1569, -0.0873,  0.0210,  0.0215, -0.3243, -0.2518],
         [ 0.1117, -0.0547,  0.0406, -0.0213, -0.3251, -0.2993],
         [ 0.1196, -0.0491,  0.0318, -0.0635, -0.2788, -0.2578]],

        [[ 0.1569, -0.0873,  0.0210,  0.0215, -0.3243, -0.2518],
         [ 0.1117, -0.0547,  0.0406, -0.0213, -0.3251, -0.2993],
         [ 0.1196, -0.0491,  0.0318, -0.0635, -0.2788, -0.2578]]],
       grad_fn=<ViewBackward0>)
context_vecs.shape: torch.Size([2, 3, 6])
