# Multi-Head Attention

In [37]:
import torch
import torch.nn as nn

## Causal Self-Attention Class

In [38]:
class CausalSelfAttention(nn.Module):
    """Single-head scaled dot-product self-attention with a causal mask.

    Each position may attend only to itself and earlier positions; scores for
    later positions are set to ``-inf`` before the softmax.

    Args:
        d_in: Size of the last dimension of the input embeddings.
        d_out: Size of the query/key/value projections (and of the output).
        dropout_prob: Dropout probability applied to the attention weights.
        context_size: Maximum sequence length supported; sizes the causal mask.
    """

    def __init__(self, d_in, d_out, dropout_prob, context_size):
        super().__init__()
        # Bias-free linear projections for queries, keys and values.
        self.Wq = nn.Linear(d_in, d_out, bias=False)
        self.Wk = nn.Linear(d_in, d_out, bias=False)
        self.Wv = nn.Linear(d_in, d_out, bias=False)
        self.dropout = nn.Dropout(dropout_prob)
        # Ones strictly above the diagonal mark the "future" positions each
        # query must not see. Registered as a buffer so it moves with the
        # module across devices and is stored in the state dict.
        self.register_buffer(
            "mask", torch.triu(torch.ones(context_size, context_size), diagonal=1)
        )

    def forward(self, input_embeddings):
        """Return causal context vectors for ``input_embeddings``.

        Args:
            input_embeddings: Tensor of shape ``(..., num_tokens, d_in)``; any
                number of leading batch dimensions (including none) is accepted.

        Returns:
            Tensor of shape ``(..., num_tokens, d_out)``.

        Raises:
            ValueError: If ``num_tokens`` exceeds the ``context_size`` the
                causal mask was built for.
        """
        num_tokens = input_embeddings.shape[-2]
        context_size = self.mask.shape[-1]
        if num_tokens > context_size:
            # Otherwise the sliced mask would be too small and masked_fill
            # would fail with an opaque broadcasting error.
            raise ValueError(
                f"sequence length {num_tokens} exceeds context_size {context_size}"
            )

        queries = self.Wq(input_embeddings)
        keys = self.Wk(input_embeddings)
        values = self.Wv(input_embeddings)

        # (..., T, d_out) @ (..., d_out, T) -> (..., T, T). Transposing the last
        # two dims (not dims 1 and 2) keeps this valid for any batch rank.
        attention_scores = queries @ keys.transpose(-2, -1)

        causal_mask = self.mask.bool()[:num_tokens, :num_tokens]
        attention_scores = attention_scores.masked_fill(causal_mask, -torch.inf)

        # Scale by sqrt(d_out); masked entries stay -inf and get zero weight.
        attention_weights = torch.softmax(attention_scores / keys.shape[-1] ** 0.5, dim=-1)
        attention_weights = self.dropout(attention_weights)

        return attention_weights @ values

## Multi-Head Attention Wrapper Class

In [39]:
class MultiHeadAttentionWrapper(nn.Module):
    """Run several independent causal self-attention heads and concatenate them.

    Args:
        d_in: Size of the last dimension of the input embeddings.
        d_out: Output size of each individual head.
        num_heads: Number of ``CausalSelfAttention`` heads to create.
        dropout_prob: Dropout probability passed to every head.
        context_size: Maximum sequence length passed to every head.

    The output's last dimension is ``num_heads * d_out``.
    """

    def __init__(self, d_in, d_out, num_heads, dropout_prob, context_size):
        super().__init__()
        self.num_heads = num_heads
        heads = []
        for _ in range(num_heads):
            # Each head is a separate module with its own Wq/Wk/Wv parameters.
            heads.append(CausalSelfAttention(d_in, d_out, dropout_prob, context_size))
        self.attention_heads = nn.ModuleList(heads)

    def forward(self, input_embeddings):
        """Apply every head to ``input_embeddings`` and join results on the last axis."""
        per_head = tuple(head(input_embeddings) for head in self.attention_heads)
        return torch.cat(per_head, dim=-1)

In [40]:
D_IN: int = 4          # width of each input embedding (last dim of input_embeddings)
D_OUT: int = 3         # per-head projection width; heads are concatenated on the last axis
CONTEXT_SIZE: int = 6  # demo sequence length, also the size of each head's causal mask
NUM_BATCH: int = 2     # number of sequences in the demo batch

In [41]:
# Seed the global RNG so the random embeddings below, and anything drawn from
# the RNG later in this notebook when run in order, are reproducible.
torch.manual_seed(0)
# Random stand-in for token embeddings, shape (NUM_BATCH, CONTEXT_SIZE, D_IN).
input_embeddings = torch.randn(NUM_BATCH, CONTEXT_SIZE, D_IN)

print(f"Input Embeddings:\n{input_embeddings}\nShape: {input_embeddings.shape}\n")

Input Embeddings:
tensor([[[-1.1258, -1.1524, -0.2506, -0.4339],
         [ 0.8487,  0.6920, -0.3160, -2.1152],
         [ 0.3223, -1.2633,  0.3500,  0.3081],
         [ 0.1198,  1.2377,  1.1168, -0.2473],
         [-1.3527, -1.6959,  0.5667,  0.7935],
         [ 0.5988, -1.5551, -0.3414,  1.8530]],

        [[ 0.7502, -0.5855, -0.1734,  0.1835],
         [ 1.3894,  1.5863,  0.9463, -0.8437],
         [-0.6136,  0.0316, -0.4927,  0.2484],
         [ 0.4397,  0.1124,  0.6408,  0.4412],
         [-0.1023,  0.7924, -0.2897,  0.0525],
         [ 0.5229,  2.3022, -1.4689, -1.5867]]])
Shape: torch.Size([2, 6, 4])



In [42]:
NUM_HEADS: int = 2  # number of parallel heads; wrapper output width is NUM_HEADS * D_OUT

In [43]:
# NUM_HEADS heads, each projecting D_IN -> D_OUT, with dropout_prob=0.2 on the
# attention weights. No .eval() call is made, so the module stays in its
# default training mode and dropout is active when it is called below.
multi_head_attention = MultiHeadAttentionWrapper(D_IN, D_OUT, NUM_HEADS, dropout_prob=0.2, context_size=CONTEXT_SIZE)

In [44]:
# Forward pass. The wrapper concatenates per-head outputs on the last axis, so
# the result has shape (NUM_BATCH, CONTEXT_SIZE, NUM_HEADS * D_OUT).
context_vectors = multi_head_attention(input_embeddings)

print(f"Context Vectors:\n{context_vectors}\nShape: {context_vectors.shape}\n")

Context Vectors:
tensor([[[-1.2488e+00,  6.4150e-01,  1.6728e-01, -1.6614e-01, -1.0396e-01,
           5.5464e-02],
         [-5.8235e-01,  2.9914e-01,  7.8004e-02, -3.5638e-01,  3.3327e-01,
          -1.3804e-01],
         [-7.4325e-01,  5.9319e-01,  2.2963e-01, -3.6690e-01,  1.5348e-01,
          -1.9105e-01],
         [-3.1232e-01,  3.8020e-01,  1.5948e-01, -1.4189e-01,  6.2545e-02,
           3.3245e-02],
         [-2.8217e-01,  3.3133e-01,  3.8277e-01, -2.1264e-01,  4.9905e-01,
          -7.9651e-04],
         [ 3.5510e-03,  8.3250e-02, -1.4802e-01, -1.1702e-01,  1.0744e-01,
          -6.1177e-02]],

        [[-1.1205e-01,  3.2981e-01, -9.5889e-01,  3.1649e-02, -7.8833e-01,
          -1.2182e-01],
         [ 4.1930e-01,  7.5069e-02, -1.5388e-01, -1.8299e-01,  9.9638e-02,
          -1.6085e-01],
         [ 3.1157e-01, -1.7752e-01,  3.4384e-01, -4.7404e-02,  8.2824e-02,
          -7.6649e-02],
         [ 3.7046e-01, -7.2181e-02, -6.9792e-02, -8.3627e-02,  3.4039e-01,
          -4.36