    --- Multi-Head Attention Mechanism -> GPT-2 ---
    1. Input and output dimension = 768
    2. Context length = 1024
    3. Attention heads = 12

In [7]:
# pytorch
import torch
import torch.nn as nn

In [9]:
# toy input: six rows with three features each
inputs = torch.tensor((
    (0.43, 0.15, 0.89),
    (0.55, 0.87, 0.66),
    (0.57, 0.85, 0.64),
    (0.22, 0.58, 0.33),
    (0.77, 0.25, 0.10),
    (0.05, 0.80, 0.55),
))
print(inputs)

tensor([[0.4300, 0.1500, 0.8900],
        [0.5500, 0.8700, 0.6600],
        [0.5700, 0.8500, 0.6400],
        [0.2200, 0.5800, 0.3300],
        [0.7700, 0.2500, 0.1000],
        [0.0500, 0.8000, 0.5500]])


In [29]:
# multi head attention class -> GPT 2
class MultiHeadAttention(nn.Module):
    """Causal multi-head self-attention.

    The input is projected to queries, keys and values of width ``d_out``,
    split into ``num_heads`` heads of width ``d_out // num_heads``, and
    combined with scaled dot-product attention. An upper-triangular mask
    stops each position from attending to positions after it. The heads are
    then concatenated and passed through a final linear projection.
    """

    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias = False):
        """Build the projections, dropout layer and causal mask.

        Args:
            d_in: Size of the last dimension of the input.
            d_out: Size of the last dimension of the output; must be
                divisible by ``num_heads``.
            context_length: Side length of the square causal mask, i.e. the
                longest sequence the mask can cover.
            dropout: Dropout probability applied to the attention weights.
            num_heads: Number of attention heads.
            qkv_bias: Whether the query/key/value projections have a bias.

        Raises:
            AssertionError: If ``d_out`` is not divisible by ``num_heads``.
        """
        super().__init__()
        assert d_out % num_heads == 0, "d_out must be divisible by num_heads"

        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads  # width of each individual head

        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.out_proj = nn.Linear(d_out, d_out)  # mixes the concatenated heads

        # Use the caller-supplied rate; it was previously fixed at 0.5, which
        # silently ignored the ``dropout`` argument.
        self.dropout = nn.Dropout(p=dropout)

        # Ones strictly above the diagonal mark the "future" positions.
        # Registered as a buffer so it follows the module across devices
        # without being a trainable parameter.
        self.register_buffer(
            "mask",
            torch.triu(torch.ones(context_length, context_length), diagonal=1),
        )

    def forward(self, x):
        """Apply causal multi-head attention.

        Args:
            x: Tensor of shape ``(batch, num_tokens, d_in)``.

        Returns:
            Tensor of shape ``(batch, num_tokens, d_out)``.
        """
        b, num_tokens, d_in = x.shape
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)

        # Split the projected width into heads:
        # (b, num_tokens, d_out) -> (b, num_tokens, num_heads, head_dim)
        keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)
        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)
        values = values.view(b, num_tokens, self.num_heads, self.head_dim)

        # Move the head axis ahead of the token axis so the matmuls below
        # run per head: (b, num_heads, num_tokens, head_dim)
        keys = keys.transpose(1, 2)
        queries = queries.transpose(1, 2)
        values = values.transpose(1, 2)

        # Dot product of every query with every key, per head:
        # (b, num_heads, num_tokens, num_tokens)
        attention_scores = queries @ keys.transpose(2, 3)

        # Crop the mask to the current sequence length; -inf scores become
        # zero weight after the softmax.
        mask = self.mask.bool()[:num_tokens, :num_tokens]
        attention_scores.masked_fill_(mask, -torch.inf)

        # Scale by sqrt(head_dim) before normalising over the key axis.
        attention_weights = torch.softmax(
            attention_scores / self.head_dim ** 0.5, dim=-1
        )
        attention_weights = self.dropout(attention_weights)

        # Weighted sum of values, then back to (b, num_tokens, num_heads, head_dim).
        context_vector = (attention_weights @ values).transpose(1, 2)
        # transpose() leaves the tensor non-contiguous, so make it contiguous
        # before view() merges the heads back into d_out.
        context_vector = context_vector.contiguous().view(b, num_tokens, self.d_out)
        context_vector = self.out_proj(context_vector)
        return context_vector

In [30]:
# settings for the demo run: batch of 2, plus the sizes from the notebook header
batch_size, context_length = 2, 1024
embedding_dim, num_heads = 768, 12

In [31]:
# random batch shaped (batch_size, context_length, embedding_dim)
batch_shape = (batch_size, context_length, embedding_dim)
batch = torch.rand(batch_shape)
print(batch)
print(batch.size())

tensor([[[0.9376, 0.0789, 0.0154,  ..., 0.7146, 0.6156, 0.8013],
         [0.0518, 0.5691, 0.5170,  ..., 0.2004, 0.8347, 0.8188],
         [0.7874, 0.9510, 0.7134,  ..., 0.3541, 0.6651, 0.5039],
         ...,
         [0.3708, 0.4403, 0.4082,  ..., 0.8869, 0.1839, 0.9136],
         [0.3791, 0.2145, 0.2959,  ..., 0.9943, 0.1894, 0.0997],
         [0.2855, 0.2986, 0.4613,  ..., 0.7675, 0.1691, 0.0940]],

        [[0.0322, 0.2788, 0.6866,  ..., 0.5323, 0.5390, 0.4320],
         [0.6601, 0.2336, 0.2429,  ..., 0.9476, 0.7902, 0.2701],
         [0.1568, 0.3891, 0.7309,  ..., 0.4209, 0.9941, 0.6957],
         ...,
         [0.5627, 0.8837, 0.4123,  ..., 0.9091, 0.8926, 0.7092],
         [0.7204, 0.2424, 0.8280,  ..., 0.9554, 0.0900, 0.9750],
         [0.3342, 0.8437, 0.0028,  ..., 0.6076, 0.9679, 0.4777]]])
torch.Size([2, 1024, 768])


In [35]:
# Build the attention layer from the config values defined above.
# `input_dim` / `output_dim` are never defined in this notebook, so use
# `embedding_dim` for both: it matches the last dimension of `batch`.
ma = MultiHeadAttention(embedding_dim, embedding_dim, context_length, 0.0, num_heads = num_heads)
context_vector = ma(batch)
print(context_vector)
print("Batch, Context_length, Output_dimension")
print(context_vector.shape)

tensor([[[ 0.1262, -0.3485, -0.2013,  ...,  0.4537,  0.0804, -0.0666],
         [ 0.1460,  0.1017, -0.0718,  ...,  0.0064,  0.1986, -0.0366],
         [ 0.2423, -0.0534,  0.0023,  ..., -0.3047,  0.0735,  0.0331],
         ...,
         [ 0.1853, -0.0299, -0.0615,  ..., -0.0128,  0.0646,  0.0571],
         [ 0.1941, -0.0295, -0.0729,  ..., -0.0260,  0.0623,  0.0570],
         [ 0.1965, -0.0280, -0.0609,  ..., -0.0146,  0.0651,  0.0717]],

        [[-0.0181,  0.1903, -0.2322,  ..., -0.3012,  0.2181, -0.2026],
         [ 0.0825, -0.0283,  0.1364,  ..., -0.1499, -0.0179,  0.1686],
         [ 0.2580, -0.0461, -0.1338,  ..., -0.0244,  0.0315, -0.1044],
         ...,
         [ 0.2036, -0.0195, -0.0606,  ..., -0.0172,  0.0588,  0.0525],
         [ 0.1918, -0.0334, -0.0627,  ..., -0.0291,  0.0618,  0.0553],
         [ 0.1925, -0.0305, -0.0624,  ..., -0.0129,  0.0581,  0.0563]]],
       grad_fn=<ViewBackward0>)
Batch, Context_length, Output_dimension
torch.Size([2, 1024, 768])
