## Multi-Head Attention
![multi_head_attention](./images/multi_head_attention.png)

In [1]:
import torch

In [2]:
# Hand-picked 3-dimensional embeddings for the sentence "Dream big and work
# for it", each row kept next to the word it represents.
_word_vectors = [
    ("Dream", [0.72, 0.45, 0.31]),
    ("big", [0.75, 0.20, 0.55]),
    ("and", [0.30, 0.80, 0.40]),
    ("work", [0.85, 0.35, 0.60]),
    ("for", [0.55, 0.15, 0.75]),
    ("it", [0.25, 0.20, 0.85]),
]
inputs = torch.tensor([vector for _, vector in _word_vectors])

# Corresponding words, in the same order as the rows of `inputs`.
words = [word for word, _ in _word_vectors]
del _word_vectors

In [9]:
import torch.nn as nn


class CausalAttention(nn.Module):
    """Single-head scaled dot-product self-attention with a causal mask.

    Each position attends only to itself and earlier positions: scores for
    later positions are set to ``-inf`` before the softmax.

    Args:
        d_in: Size of the last dimension of the input ``x``.
        d_out: Size of the query/key/value projections and of the output.
        context_length: Maximum number of tokens supported; sizes the mask.
        dropout: Dropout probability applied to the attention weights.
        qkv_bias: Whether the query/key/value projections use a bias term.
    """

    def __init__(self, d_in, d_out, context_length, dropout, qkv_bias=False):
        super().__init__()
        self.d_in = d_in
        self.d_out = d_out
        self.dropout = nn.Dropout(dropout)

        # Query, key and value projections (created in this order so seeded
        # initialization is unchanged). NOTE: the query attribute keeps its
        # original misspelled name so existing state_dict keys still load.
        self.w_queryquery = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.w_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.w_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        # Ones strictly above the diagonal mark "future" positions to hide.
        self.register_buffer(
            "mask",
            torch.triu(torch.ones(context_length, context_length), diagonal=1),
        )

    def forward(self, x):
        """Return causal context vectors for ``x``.

        Args:
            x: Tensor of shape ``(batch, num_tokens, d_in)`` with
                ``num_tokens <= context_length``.

        Returns:
            Tensor of shape ``(batch, num_tokens, d_out)``.
        """
        num_tokens = x.shape[1]
        keys = self.w_key(x)            # (batch, num_tokens, d_out)
        queries = self.w_queryquery(x)  # (batch, num_tokens, d_out)
        values = self.w_value(x)        # (batch, num_tokens, d_out)

        # Swap the token and feature axes of keys -> (batch, num_tokens, num_tokens).
        attn_scores = queries @ keys.transpose(1, 2)
        # Slice the mask to the actual sequence length: the full
        # (context_length, context_length) buffer cannot be applied to a
        # shorter input.
        causal_mask = self.mask.bool()[:num_tokens, :num_tokens]
        attn_scores.masked_fill_(causal_mask, float("-inf"))
        attn_weights = torch.softmax(attn_scores / self.d_out ** 0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)

        return attn_weights @ values

In [10]:
# Input size comes from the embedding width; each projection outputs 2 features.
d_in, d_out = inputs.shape[-1], 2

In [31]:
# Prepend a batch axis of size 1 so the single sequence becomes a batch.
batch = torch.stack([inputs], dim=0)
batch.shape  # (1, 6, 3)

torch.Size([1, 6, 3])

In [32]:
class MultiHeadAttentionWrapper(nn.Module):
    """Run several independent CausalAttention heads and concatenate them.

    Every head sees the same input; their outputs are joined along the last
    axis, so the result's last dimension is ``num_heads * d_out``.
    """

    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        self.num_heads = num_heads
        attention_heads = []
        for _ in range(num_heads):
            attention_heads.append(
                CausalAttention(d_in, d_out, context_length, dropout, qkv_bias=qkv_bias)
            )
        self.heads = nn.ModuleList(attention_heads)

    def forward(self, x):
        per_head_outputs = [attention_head(x) for attention_head in self.heads]
        return torch.cat(per_head_outputs, dim=-1)

In [35]:
# Seed first so the heads' weight initialization is reproducible.
torch.manual_seed(0)
# Derive sizes from `batch` instead of hard-coding them, so this cell stays
# consistent with the input data.
context_length = batch.shape[1]  # tokens per sequence
d_in = batch.shape[-1]           # embedding size of each token
d_out = 2                        # output size of each head
num_heads = 4

mha = MultiHeadAttentionWrapper(d_in, d_out, context_length, dropout=0.0, num_heads=num_heads)

In [36]:
# Heads are concatenated on the last axis:
# (batch, num_tokens, num_heads * d_out) = (1, 6, 8).
context_vecs = mha(batch)
context_vecs

tensor([[[-0.6430,  0.2255, -0.0536, -0.5405, -0.0793,  0.5405,  0.1871,
           0.1333],
         [-0.6319,  0.2390, -0.1641, -0.5589, -0.0362,  0.5299,  0.2138,
           0.1721],
         [-0.6097,  0.2691, -0.0483, -0.5788, -0.0075,  0.4847,  0.2282,
           0.0269],
         [-0.6439,  0.2779, -0.0935, -0.6058, -0.0053,  0.5151,  0.2411,
           0.0674],
         [-0.6230,  0.2833, -0.1427, -0.6075,  0.0220,  0.4946,  0.2532,
           0.0693],
         [-0.5880,  0.2937, -0.1667, -0.6075,  0.0578,  0.4567,  0.2675,
           0.0352]]], grad_fn=<CatBackward0>)