In [4]:
# OPTIMIZING THE MULTIHEAD ATTENTION

# Inputs : 
# [batch_size, num_tokens = 3, d_in = 6]

# Decide d_out and num_head
"""
    d_out  = 6 {same as d_in} // Generally done in GPT Models
    d_out is the dimension of final context vectors i.e 6 and it is final
    num_heads = 2
    
    so head_dim = d_out(6) / num_heads(2) = 3
"""

# Initialize Wq, Wk, Wv
# dimensions : [d_in, d_out] == [6,6]
# this step is done in the __init__ method of the class

"""
Calculate the Keys, Queries, Values matrices : 
Keys, Values, Queries == Input @ W_key, Input @ W_value, Input @ W_query
[batch_size, 3,6] * [6,6] ==> [batch_size, 3, 6]
"""

# Reshaping(Splitting) the Keys, Values, Queries Matrix according to the 
# num_heads

"""
head_dim = d_out / num_heads = 3

Keys [batch_size,num_tokens,d_out(6)] ==> [batch_size, num_tokens,num_heads, head_dim_out]
Keys ==> [batch_size, 3, 2,3]
 
batch_size == num of inputs in each batch
num_heads = num of heads for each token
num_tokens = num of tokens in each input
head_dim_out or head_dim = for each head and token, the dim_out
"""

# The above code groups by num_tokens not by num_heads
## we will now group it by number of heads

# [batch_size,num_tokens,num_heads,head_dim] ==> [batch_size,num_heads, num_tokens, head_dim]
# Transpose(1,2)

## FIND ATTENTION SCORES
# Queries @ Keys.Transpose(2,3)
# [b,num_heads,num_tokens, head_dim] @ [b, num_heads, head_dim, num_tokens]
# Attention scores matrix  : [batch_size, num_heads, num_tokens, num_tokens]
### [batch_size, num_heads(2), 3, 3]


# FINDING ATTENTION WEIGHTS
import torch
# Causal-mask demo: the ones strictly above the main diagonal mark the
# "future" positions that a token is not allowed to attend to.
num_tokens = 3
mask = torch.ones(num_tokens, num_tokens).triu(diagonal=1)
print(mask)

# attn_weights = attn_matrix.masked_fill_(mask.bool(), -torch.inf)
# WE have to mask the elements above the diagonal to -inf
# then we have to softmax the output
# before softmaxing we will divide the attn_scores by sqrt(dim(key)), where dim(key) == keys.shape[-1]
# here dim(key) is head_dim i.e. 3

"""
We can have dropout after this
"""

## Finding the context_vectors : 
"""
context vectors = attn_weights @ values

[batch_size, num_heads, num_tokens,num_tokens] @ [batch_size, num_heads, num_tokens, head_dim_out]

context vectors ==> [batch_size, num_heads, num_tokens, head_dim_out]
BUT the resultant context vectors should have dimension of : 
[batch_size, num_tokens, d_out]

Now we will transpose it to : [batch_size, num_tokens, num_heads, head_dim_out]
Now join the last 2 dimensions : [batch_size, num_tokens, d_out]
"""


tensor([[0., 1., 1.],
        [0., 0., 1.],
        [0., 0., 0.]])


In [27]:
from torch import nn

class MultiHeadAttention(nn.Module):
    """Causal multi-head self-attention followed by an output projection.

    The input is projected to queries, keys and values of width ``d_out``,
    split into ``num_heads`` heads of width ``d_out // num_heads``, attended
    with a causal (upper-triangular) mask and scaled softmax, merged back to
    width ``d_out`` and passed through a final linear layer.

    Args:
        d_in: Embedding dimension of each input token.
        d_out: Dimension of each output context vector; must be divisible
            by ``num_heads``.
        context_length: Maximum number of tokens per sequence; fixes the
            size of the registered causal mask.
        dropout: Dropout probability applied to the attention weights.
        num_heads: Number of attention heads.
        qkv_bias: Whether the query/key/value projections use a bias term.
    """

    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        assert (d_out % num_heads == 0), \
            "d_out must be divisible by num_heads"

        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads

        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.dropout = nn.Dropout(dropout)
        self.out_proj = nn.Linear(d_out, d_out)  # combines the head outputs

        # Ones strictly above the diagonal mark the future positions to hide.
        # Registered as a buffer so it follows the module across devices.
        self.register_buffer(
            "mask",
            torch.triu(torch.ones(context_length, context_length), diagonal=1)
        )

    def forward(self, x):
        """Compute the context vectors for a batch of token embeddings.

        Args:
            x: Tensor of shape ``[batch_size, num_tokens, d_in]``.

        Returns:
            Tensor of shape ``[batch_size, num_tokens, d_out]``.

        Raises:
            ValueError: If ``num_tokens`` exceeds the ``context_length`` the
                module was built with (the causal mask cannot cover it).
        """
        b, num_tokens, _ = x.shape

        # Fail early with a readable message instead of an opaque broadcasting
        # error from masked_fill_ further down.
        context_length = self.mask.shape[0]
        if num_tokens > context_length:
            raise ValueError(
                f"Input has {num_tokens} tokens but this module only supports "
                f"up to context_length={context_length}"
            )

        # [b, num_tokens, d_in] -> [b, num_tokens, d_out]
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)
        print("Dimension of keys ", keys.shape)

        # Split the last dimension into heads:
        # [b, num_tokens, d_out] -> [b, num_tokens, num_heads, head_dim]
        keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)
        values = values.view(b, num_tokens, self.num_heads, self.head_dim)
        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)

        print("Dimension of keys ", keys.shape)

        # Group by head instead of by token:
        # [b, num_tokens, num_heads, head_dim] -> [b, num_heads, num_tokens, head_dim]
        keys = keys.transpose(1, 2)
        queries = queries.transpose(1, 2)
        values = values.transpose(1, 2)

        print("New dimension of keys ", keys.shape)

        # [b, num_heads, num_tokens, num_tokens]
        attn_scores = queries @ keys.transpose(2, 3)

        # Causal masking: crop the mask to the actual sequence length and set
        # the future positions to -inf so softmax gives them zero weight.
        mask_bool = self.mask.bool()[:num_tokens, :num_tokens]
        attn_scores.masked_fill_(mask_bool, -torch.inf)
        print("Attention scores masked ", attn_scores)

        # Scale by sqrt(head_dim) before the softmax over the key axis.
        attn_weights = torch.softmax(attn_scores / keys.shape[-1] ** 0.5, dim=-1)

        # Dropout on the attention weights.
        attn_weights = self.dropout(attn_weights)

        # attn_weights @ values -> [b, num_heads, num_tokens, head_dim];
        # transpose back to [b, num_tokens, num_heads, head_dim] so the heads
        # of each token sit next to each other.
        context_vec = (attn_weights @ values).transpose(1, 2)

        # Merge the heads: [b, num_tokens, num_heads * head_dim == d_out].
        # contiguous() is required because view cannot operate on the
        # transposed (non-contiguous) tensor.
        context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)
        print(context_vec.shape)

        # [b, num_tokens, d_out] -> [b, num_tokens, d_out]
        context_vec = self.out_proj(context_vec)

        return context_vec

In [28]:
# Seed the RNG first so the randomly initialised projection weights of the
# attention module (created below) are reproducible.
torch.manual_seed(123)

# Toy sequence: 3 tokens, each embedded in 6 dimensions.
inputs = torch.tensor([
    [0.43, 0.15, 0.89, 0.55, 0.87, 0.66],  # token 1
    [0.57, 0.85, 0.64, 0.22, 0.58, 0.33],  # token 2
    [0.77, 0.25, 0.10, 0.05, 0.80, 0.55],  # token 3
])

# Duplicate the sequence to build a batch of two identical examples.
batch = torch.stack([inputs, inputs])
print(batch.shape) 

# context_size is the number of tokens per sequence.
batch_size, context_size, d_in = batch.size()
d_out = 6

mha = MultiHeadAttention(
    d_in=d_in,
    d_out=d_out,
    context_length=context_size,
    dropout=0.0,
    num_heads=2,
)
context_vecs = mha(batch)
print(context_vecs)
print(context_vecs.shape)

torch.Size([2, 3, 6])
Dimension of keys  torch.Size([2, 3, 6])
Dimension of keys  torch.Size([2, 3, 2, 3])
New dimension of keys  torch.Size([2, 2, 3, 3])
Attention scores masked  tensor([[[[-0.0917,    -inf,    -inf],
          [-0.1737, -0.3922,    -inf],
          [-0.0993, -0.2388, -0.1278]],

         [[-0.0584,    -inf,    -inf],
          [-0.0252, -0.2527,    -inf],
          [ 0.0454, -0.1627,  0.0923]]],


        [[[-0.0917,    -inf,    -inf],
          [-0.1737, -0.3922,    -inf],
          [-0.0993, -0.2388, -0.1278]],

         [[-0.0584,    -inf,    -inf],
          [-0.0252, -0.2527,    -inf],
          [ 0.0454, -0.1627,  0.0923]]]], grad_fn=<MaskedFillBackward0>)
torch.Size([2, 3, 6])
tensor([[[ 0.1569, -0.0873,  0.0210,  0.0215, -0.3243, -0.2518],
         [ 0.1117, -0.0547,  0.0406, -0.0213, -0.3251, -0.2993],
         [ 0.1196, -0.0491,  0.0318, -0.0635, -0.2788, -0.2578]],

        [[ 0.1569, -0.0873,  0.0210,  0.0215, -0.3243, -0.2518],
         [ 0.1117, -0.0547