Instead of maintaining two separate classes, MultiHeadAttentionWrapper and CausalAttention, we combine both of these
concepts into a single MultiHeadAttention class.

In the MultiHeadAttentionWrapper, multiple heads are implemented by creating a list of CausalAttention objects (self.heads),
each representing a separate attention head.

The CausalAttention class independently performs the attention mechanism, and the results from each head are concatenated.

In contrast, the following MultiHeadAttention class integrates the multi-head functionality within a single class.

In [1]:

import torch.nn as nn
import torch 

In [2]:

class MultiHeadAttention(nn.Module):
    """Causal multi-head self-attention implemented in a single module.

    The query/key/value projections each map ``d_in`` to ``d_out`` features.
    The ``d_out`` features are then split into ``num_heads`` heads of
    ``head_dim = d_out // num_heads`` features each, scaled dot-product
    attention with a causal mask is applied per head, and the heads are
    concatenated and passed through a final ``d_out -> d_out`` projection.

    Args:
        d_in: Size of the last dimension of the input.
        d_out: Total output size across all heads; must be divisible by
            ``num_heads``.
        context_len: Side length of the square causal mask, i.e. the largest
            number of tokens ``forward`` accepts.
        dropout: Dropout probability applied to the attention weights.
        num_heads: Number of attention heads.
        qkv_bias: Whether the query/key/value projections use a bias term.

    Raises:
        ValueError: If ``num_heads`` is not positive or ``d_out`` is not
            divisible by ``num_heads``.
    """

    def __init__(
        self,
        d_in: int,
        d_out: int,
        context_len: int,
        dropout: float,
        num_heads: int,
        qkv_bias: bool = False,
    ) -> None:
        super().__init__()
        # Validate with explicit exceptions rather than ``assert``: asserts are
        # stripped under ``python -O``, and ``assert (cond, "msg")`` tests a
        # non-empty tuple, which is always true.
        if num_heads < 1:
            raise ValueError(f"num_heads must be positive, got {num_heads}")
        if d_out % num_heads != 0:
            raise ValueError(
                f"d_out ({d_out}) must be divisible by num_heads ({num_heads})"
            )

        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads

        # Creation order is kept (queries, keys, values, out_proj) so that
        # seeded parameter initialisation matches earlier runs.
        self.W_queries = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_keys = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_values = nn.Linear(d_in, d_out, bias=qkv_bias)

        self.out_proj = nn.Linear(d_out, d_out)
        self.dropout = nn.Dropout(dropout)

        # Upper-triangular ones above the diagonal mark the "future" positions.
        # Registered as a buffer so it follows the module across devices.
        self.register_buffer(
            "mask", torch.triu(torch.ones(context_len, context_len), diagonal=1)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply causal multi-head self-attention to ``x``.

        Args:
            x: Tensor of shape ``(batch_size, num_tokens, d_in)``.

        Returns:
            Tensor of shape ``(batch_size, num_tokens, d_out)``.

        Raises:
            ValueError: If ``num_tokens`` exceeds the mask's context length.
        """
        batch_size, num_tokens, _ = x.shape

        max_tokens = self.mask.shape[0]
        if num_tokens > max_tokens:
            raise ValueError(
                f"num_tokens ({num_tokens}) exceeds context_len ({max_tokens})"
            )

        keys = self.W_keys(x)
        queries = self.W_queries(x)
        values = self.W_values(x)

        # Implicitly split into heads by unrolling the last dimension:
        # (batch_size, num_tokens, d_out)
        #   -> (batch_size, num_tokens, num_heads, head_dim)
        # then transpose so each head attends over its own token axis:
        #   -> (batch_size, num_heads, num_tokens, head_dim)
        split_shape = (batch_size, num_tokens, self.num_heads, self.head_dim)
        keys = keys.view(*split_shape).transpose(1, 2)
        values = values.view(*split_shape).transpose(1, 2)
        queries = queries.view(*split_shape).transpose(1, 2)

        # Raw dot-product scores: (batch_size, num_heads, num_tokens, num_tokens)
        atten_scores = queries @ keys.transpose(2, 3)

        # Truncate the mask to the current sequence length and hide future
        # positions; -inf becomes a weight of exactly 0 after softmax.
        mask_bool = self.mask.bool()[:num_tokens, :num_tokens]
        atten_scores.masked_fill_(mask_bool, -torch.inf)

        # Scale by sqrt(head_dim) and normalise over the key axis.
        atten_weights = torch.softmax(
            atten_scores / (self.head_dim ** 0.5), dim=-1
        )
        atten_weights = self.dropout(atten_weights)

        # (batch_size, num_heads, num_tokens, head_dim)
        #   -> (batch_size, num_tokens, num_heads, head_dim)
        context_vec = (atten_weights @ values).transpose(1, 2)

        # Combine heads, where d_out = num_heads * head_dim. ``contiguous`` is
        # needed because ``view`` cannot reinterpret the transposed layout.
        context_vec = context_vec.contiguous().view(
            batch_size, num_tokens, self.d_out
        )
        return self.out_proj(context_vec)

In [3]:

# Fix the RNG seed so the randomly initialised layer weights are reproducible.
torch.manual_seed(123)

# Three token embeddings (rows), each with six features (columns).
inputs = torch.tensor(
    [
        [0.43, 0.15, 0.89, 0.55, 0.87, 0.66],
        [0.57, 0.85, 0.64, 0.22, 0.58, 0.45],
        [0.77, 0.25, 0.10, 0.05, 0.80, 0.55],
    ]
)

# Duplicate the sequence to form a batch of two identical examples.
batch = torch.stack([inputs, inputs], dim=0)
print("batch shape : ", batch.shape)

batch_size, num_tokens, embed_dims = batch.shape
d_out = 6
context_len = num_tokens

# Two heads over six output features gives three features per head.
mha = MultiHeadAttention(
    embed_dims,
    d_out,
    context_len=context_len,
    dropout=0.0,
    num_heads=2,
)
context_vec = mha(batch)

print("context vector : \n", context_vec)
print("context vector shape : \n", context_vec.shape)

batch shape :  torch.Size([2, 3, 6])
context vector : 
 tensor([[[ 0.1569, -0.0873,  0.0210,  0.0215, -0.3243, -0.2518],
         [ 0.1038, -0.0506,  0.0340, -0.0179, -0.3291, -0.2941],
         [ 0.1145, -0.0466,  0.0280, -0.0614, -0.2814, -0.2546]],

        [[ 0.1569, -0.0873,  0.0210,  0.0215, -0.3243, -0.2518],
         [ 0.1038, -0.0506,  0.0340, -0.0179, -0.3291, -0.2941],
         [ 0.1145, -0.0466,  0.0280, -0.0614, -0.2814, -0.2546]]],
       grad_fn=<ViewBackward0>)
context vector shape : 
 torch.Size([2, 3, 6])
