In [15]:
import torch
import torch.nn as nn

# Toy 3-dimensional embeddings for the six tokens of "Your journey starts with one step".
inputs = torch.tensor([
    [0.43, 0.15, 0.89],  # Your     (x^1)
    [0.55, 0.87, 0.66],  # journey  (x^2)
    [0.57, 0.85, 0.64],  # starts   (x^3)
    [0.22, 0.58, 0.33],  # with     (x^4)
    [0.77, 0.25, 0.10],  # one      (x^5)
    [0.05, 0.80, 0.55],  # step     (x^6)
])
d_in = inputs.size(1)   # input embedding size, taken from the data (3)
d_out = 2               # output embedding size

# Two identical copies of the sequence stacked along a new leading batch axis.
batch = torch.stack([inputs, inputs])
print(batch.shape)
batch

torch.Size([2, 6, 3])


tensor([[[0.4300, 0.1500, 0.8900],
         [0.5500, 0.8700, 0.6600],
         [0.5700, 0.8500, 0.6400],
         [0.2200, 0.5800, 0.3300],
         [0.7700, 0.2500, 0.1000],
         [0.0500, 0.8000, 0.5500]],

        [[0.4300, 0.1500, 0.8900],
         [0.5500, 0.8700, 0.6600],
         [0.5700, 0.8500, 0.6400],
         [0.2200, 0.5800, 0.3300],
         [0.7700, 0.2500, 0.1000],
         [0.0500, 0.8000, 0.5500]]])

In [25]:
class MultiHeadAttention(nn.Module):
    """Causal multi-head self-attention.

    Projects the input into queries, keys and values of size ``d_out``, splits
    them into ``num_heads`` heads of size ``d_out // num_heads``, applies scaled
    dot-product attention with a causal (upper-triangular) mask, merges the
    heads back together and applies a final linear output projection.

    Args:
        d_in: size of the last dimension of the input tensor.
        d_out: size of the output embedding; must be divisible by ``num_heads``.
        context_length: maximum sequence length the causal mask is built for.
        dropout: dropout probability applied to the attention weights.
        num_heads: number of attention heads.
        qkv_bias: whether the query/key/value projections have a bias term.
        verbose: when True (the default, preserving the original behaviour),
            print the output-projection parameters at construction time and
            several intermediate tensors on every ``forward`` call.
    """

    def __init__(self, d_in, d_out,
                 context_length, dropout, num_heads, qkv_bias=False,
                 verbose=True):
        super().__init__()
        assert (d_out % num_heads == 0), \
            "d_out must be divisible by num_heads"

        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads    # per-head size; heads together span d_out
        self.verbose = verbose
        # Construction order is kept so seeded weight initialisation is unchanged.
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.out_proj = nn.Linear(d_out, d_out)   # mixes the concatenated head outputs
        if self.verbose:
            print(f"Weights: {self.out_proj.weight}")
            print(f"bias: {self.out_proj.bias}")
        self.dropout = nn.Dropout(dropout)
        # Ones strictly above the diagonal mark "future" positions to be hidden.
        self.register_buffer(
            "mask",
            torch.triu(torch.ones(context_length, context_length),
                       diagonal=1)
        )

    def forward(self, x):
        """Apply causal multi-head attention.

        Args:
            x: tensor of shape (batch, num_tokens, d_in) with
                ``num_tokens <= context_length``.

        Returns:
            Tensor of shape (batch, num_tokens, d_out).

        Raises:
            ValueError: if ``num_tokens`` exceeds the ``context_length`` the
                causal mask was built for (previously this either failed with
                an opaque broadcast error or, for ``context_length == 1``,
                silently disabled causal masking).
        """
        if self.verbose:
            print(f"x.shape: {x.shape}")
        b, num_tokens, _ = x.shape
        max_tokens = self.mask.shape[0]
        if num_tokens > max_tokens:
            raise ValueError(
                f"sequence length {num_tokens} exceeds context_length {max_tokens}"
            )

        # Linear projections: (b, num_tokens, d_out) each.
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)
        if self.verbose:
            print(f"keys: {keys}")

        # Split the last dimension into heads:
        # (b, num_tokens, d_out) -> (b, num_tokens, num_heads, head_dim)
        keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)
        values = values.view(b, num_tokens, self.num_heads, self.head_dim)
        queries = queries.view(
            b, num_tokens, self.num_heads, self.head_dim
        )
        if self.verbose:
            print(f"keys new view: {keys}")

        # Put heads before tokens so each head attends independently:
        # (b, num_heads, num_tokens, head_dim)
        keys = keys.transpose(1, 2)
        queries = queries.transpose(1, 2)
        values = values.transpose(1, 2)
        if self.verbose:
            print(f"keys transposed: {keys}")

        # Per-head dot products: (b, num_heads, num_tokens, num_tokens)
        attn_scores = queries @ keys.transpose(2, 3)
        # Truncate the stored mask to the current sequence length.
        mask_bool = self.mask.bool()[:num_tokens, :num_tokens]

        # Future positions get -inf so softmax gives them zero weight.
        attn_scores.masked_fill_(mask_bool, -torch.inf)

        # Scale by sqrt(head_dim) before the softmax.
        attn_weights = torch.softmax(
            attn_scores / keys.shape[-1]**0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # (b, num_heads, num_tokens, head_dim) -> (b, num_tokens, num_heads, head_dim)
        context_vec = (attn_weights @ values).transpose(1, 2)
        if self.verbose:
            print(f"context_vec: {context_vec}")
        # Merge heads back into one vector per token (d_out = num_heads * head_dim);
        # contiguous() is required because transpose produced a non-contiguous view.
        context_vec = context_vec.contiguous().view(
            b, num_tokens, self.d_out
        )
        if self.verbose:
            print(f"context_vec reshaped: {context_vec}")
        context_vec = self.out_proj(context_vec)
        if self.verbose:
            print(f"context_vec out_proj: {context_vec}")
        return context_vec

In [27]:
# Read the dimensions straight off the example batch: (2, 6, 3).
batch_size, context_length, d_in = batch.shape
d_out = 2

# Seed immediately before construction so the layer weights are reproducible.
torch.manual_seed(123)
mha = MultiHeadAttention(
    d_in, d_out, context_length, dropout=0.0, num_heads=2
)
context_vecs = mha(batch)
print(context_vecs)
print("context_vecs.shape:", context_vecs.shape)

Weights: Parameter containing:
tensor([[-0.1668,  0.2270],
        [ 0.5000,  0.1317]], requires_grad=True)
bias: Parameter containing:
tensor([0.1934, 0.6825], requires_grad=True)
x.shape: torch.Size([2, 6, 3])
keys: tensor([[[-0.5740,  0.2727],
         [-0.8709,  0.1008],
         [-0.8628,  0.1060],
         [-0.4789,  0.0051],
         [-0.4744,  0.1696],
         [-0.5888, -0.0388]],

        [[-0.5740,  0.2727],
         [-0.8709,  0.1008],
         [-0.8628,  0.1060],
         [-0.4789,  0.0051],
         [-0.4744,  0.1696],
         [-0.5888, -0.0388]]], grad_fn=<UnsafeViewBackward0>)
keys new view: tensor([[[[-0.5740],
          [ 0.2727]],

         [[-0.8709],
          [ 0.1008]],

         [[-0.8628],
          [ 0.1060]],

         [[-0.4789],
          [ 0.0051]],

         [[-0.4744],
          [ 0.1696]],

         [[-0.5888],
          [-0.0388]]],


        [[[-0.5740],
          [ 0.2727]],

         [[-0.8709],
          [ 0.1008]],

         [[-0.8628],
         

In [45]:
# Exercise: a GPT-2-sized multi-head attention layer
# (12 heads, 768-dimensional embeddings) on random inputs.

torch.manual_seed(123)

# One batch of 12 sequences, each 6 tokens long with embedding size 768.
# The comprehension makes the same 12 torch.randn(6, 768) calls in the same
# order as writing them out by hand, so the generated values are identical.
batch_gpt2 = torch.stack([torch.randn(6, 768) for _ in range(12)], dim=0)

batch_size, context_length, d_in = batch_gpt2.shape
d_out = 768
mha = MultiHeadAttention(d_in, d_out, context_length, 0.0, num_heads=12)
context_vecs = mha(batch_gpt2)
print(context_vecs)
print("context_vecs.shape:", context_vecs.shape)

Weights: Parameter containing:
tensor([[ 0.0288,  0.0091, -0.0089,  ..., -0.0188, -0.0202, -0.0159],
        [ 0.0112, -0.0108, -0.0133,  ..., -0.0217, -0.0196,  0.0315],
        [ 0.0046,  0.0129, -0.0118,  ..., -0.0281, -0.0161,  0.0091],
        ...,
        [-0.0064,  0.0348, -0.0162,  ...,  0.0339, -0.0086,  0.0251],
        [ 0.0350,  0.0139, -0.0014,  ..., -0.0224,  0.0312,  0.0273],
        [ 0.0014, -0.0176,  0.0133,  ..., -0.0298, -0.0010, -0.0219]],
       requires_grad=True)
bias: Parameter containing:
tensor([-2.8296e-02, -1.6673e-02,  2.9840e-02,  1.3796e-02,  7.6705e-03,
         3.2462e-02, -3.3612e-02,  2.5259e-02,  1.6343e-02, -3.2406e-02,
        -1.0553e-02,  1.5028e-02, -3.2378e-02,  2.7463e-02,  2.4137e-02,
         6.7063e-03, -3.0391e-02, -2.0159e-03,  1.1820e-02,  1.3553e-02,
         2.3600e-02, -2.7559e-02,  1.8782e-02,  9.4929e-03, -1.1878e-02,
        -1.7413e-02, -3.5899e-02, -1.7677e-02,  2.3865e-02, -3.5720e-02,
        -1.6256e-02, -4.1255e-03,  2.5740e