### Implementing the Self-Attention Mechanism in GPT

#### Multi-Head Self-Attention

In [73]:
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiHeadSelfAttention(nn.Module):
    """A vanilla multi-head causal (masked) self-attention layer.

    Each position may attend only to itself and to earlier positions in the
    sequence. The layer maps an input of shape (B, T, C) to an output of the
    same shape, with C == config.n_embd and T <= config.block_size.

    Args:
        config: object providing ``n_embd``, ``n_head``, ``block_size``,
            ``attn_pdrop`` and ``resid_pdrop``. ``n_embd`` must be divisible
            by ``n_head``.
    """

    def __init__(self, config):
        super().__init__()
        # The embedding must split evenly across the heads.
        assert config.n_embd % config.n_head == 0

        # key, query, value projections for all heads
        self.key = nn.Linear(config.n_embd, config.n_embd)
        self.query = nn.Linear(config.n_embd, config.n_embd)
        self.value = nn.Linear(config.n_embd, config.n_embd)

        # output projection
        self.proj = nn.Linear(config.n_embd, config.n_embd)

        # regularization
        self.attn_dropout = nn.Dropout(config.attn_pdrop)
        self.resid_dropout = nn.Dropout(config.resid_pdrop)

        self.n_head = config.n_head
        # Longest sequence the causal mask covers; checked in forward().
        self.block_size = config.block_size

        # Causal mask to ensure that attention is only applied to the left in the input sequence:
        # lower-triangular ones, so mask[0, 0, i, j] == 1 iff j <= i. Registered as a buffer
        # (not a Parameter): it follows .to(device) but is never trained.
        self.register_buffer("mask", torch.tril(torch.ones(config.block_size, config.block_size))
                                     .view(1, 1, config.block_size, config.block_size))

    def forward(self, x):
        """The forward pass for the multi-head masked self-attention layer.

        In this exercise, we include lots of print statements and checks to help you
        understand the code and the shapes of the tensors. When actually training
        such a model you would not log this information to the console.

        Args:
            x: tensor of shape (B, T, C) with T <= block_size and C == n_embd.

        Returns:
            Tensor of shape (B, T, C).

        Raises:
            ValueError: if T exceeds block_size, or if the softmaxed attention
                rows do not sum to 1 (e.g. wrong softmax dim or mask fill value).
        """

        # batch size, sequence length (in tokens), embedding dimensionality (n_embd per token)
        B, T, C = x.size()
        if T > self.block_size:
            # With T > block_size, slicing the mask yields only a (block_size x block_size)
            # mask, which does not broadcast against the (T x T) scores.
            raise ValueError(
                f"Sequence length {T} exceeds block_size {self.block_size}."
            )
        hs = C // self.n_head  # head size

        # print some debug information
        print(f"batch size: {B}")
        print(f"sequence length: {T}")
        print(f"embedding dimensionality: {C}")
        print(f"number of heads: {self.n_head}")
        print(f"head size: {hs}")

        # Calculate the query, key, and value matrices for all the heads in the batch and move head forward to be the batch dimension
        # The resulting dims for k, q, and v are (B, n_head, T, hs)
        k = self.key(x).view(B, T, self.n_head, hs).transpose(1, 2)
        q = self.query(x).view(B, T, self.n_head, hs).transpose(1, 2)
        v = self.value(x).view(B, T, self.n_head, hs).transpose(1, 2)

        print("=== Calculate MatrixMultiplication(Q, K_T) / sqrt(d_k) ===")
        k_t = k.transpose(-2, -1)
        print("Shape of K_T:", k_t.size())
        d_k = k.size(-1)
        print("d_k:", d_k)

        # Matrix multiplication to get the raw attention scores, scaled by 1/sqrt(d_k)
        # so their magnitude does not grow with the head size.
        att = (q @ k_t) / math.sqrt(d_k)  # (B, n_head, T, T)

        print(f"q.shape: {q.shape}")
        print(f"k_t.shape: {k_t.shape}")
        print(f"d_k: {d_k}")
        print(f"att.shape: {att.shape}")

        print("=== Apply the attention mask ===")
        # -inf becomes exactly 0 after softmax, so future positions get no weight.
        masked_fill_value = float("-inf")

        att = att.masked_fill(self.mask[:, :, :T, :T] == 0, masked_fill_value)

        # Show the result of applying the mask
        print(f"att: {att}")

        print("=== Softmax ===")
        att = F.softmax(att, dim=-1)

        # Show the result of applying the softmax and check that
        # the sum of the attention weights in each row is 1.
        # The check must run BEFORE dropout: in training mode with attn_pdrop > 0,
        # dropout zeroes some weights and rescales the rest by 1 / (1 - p), so rows
        # would no longer sum to 1 and a correct implementation would be rejected.
        print(f"att.shape: {att.shape}")
        print(f"att: {att}")
        row_sums = att.sum(dim=-1)
        print(f"att.sum(dim=-1): {row_sums}")
        # |sum - 1| < 1e-3, i.e. the same tolerance as (sum - 1) ** 2 < 1e-6.
        att_rows_sum_to_one = bool(((row_sums - 1.0).abs() < 1e-3).all())
        print(f"att_rows_sum_to_one: {att_rows_sum_to_one}")
        if not att_rows_sum_to_one:
            raise ValueError(
                "Attention weight rows do not sum to 1. Perhaps the softmax dimension or masked_fill_value is not correct?"
            )

        att = self.attn_dropout(att)

        print("=== Calculate final attention ===")
        y = att @ v  # (B, n_head, T, hs)

        print(f"y.shape: {y.shape}")

        # Re-assemble all head outputs side by side: (B, n_head, T, hs) -> (B, T, C)
        y = y.transpose(1, 2).contiguous().view(B, T, C)

        # output projection
        y = self.resid_dropout(self.proj(y))
        print(f"Final output shape: {y.shape}")
        return y
        

In [74]:
class GPTConfig:
    """Toy hyper-parameters used to exercise MultiHeadSelfAttention.

    Every setting is a class attribute, so ``GPTConfig`` and ``GPTConfig()``
    expose identical values.
    """

    # Vocabulary size (not read by the attention layer itself).
    vocab_size: int = 11
    # Maximum sequence length; sizes the causal mask.
    block_size: int = 5
    # Number of heads; the embedding is split evenly among them (12 / 4 = 3 dims each).
    n_head: int = 4
    # Embedding width per token.
    n_embd: int = 12

    # Dropout probabilities; 0.0 disables dropout, keeping the forward pass deterministic.
    attn_pdrop: float = 0.0
    resid_pdrop: float = 0.0

    

In [75]:
attention = MultiHeadSelfAttention(GPTConfig())

# A single batch of 5 tokens, each a 12-dimensional embedding (4 heads @ 3 dims each).
# Every token is all ones except the first, whose four 3-dim head slices hold 1, 2, 3 and 4.
x = torch.ones(1, 5, 12)
x[0, 0] = torch.arange(1.0, 5.0).repeat_interleave(3)

In [76]:
x.size()

torch.Size([1, 5, 12])

In [77]:
# Overwrite every parameter (weights *and* biases) with 0.1 so that the outputs
# and the exact gradient checks further down are reproducible.
for param in attention.parameters():
    nn.init.constant_(param, 0.1)

In [78]:
# Set the model to evaluation mode to disable dropout. With dropout rates of 0.0
# this changes no numbers, but it keeps the cell doing what it says if the rates
# are raised. The trailing ';' suppresses the module repr in the cell output.
attention.eval();

In [79]:
# Perform a forward pass
y = attention(x)
assert y.shape == x.shape

batch size: 1
sequence length: 5
embedding dimensionality: 12
number of heads: 4
head size: 3
=== Calculate MatrixMultiplication(Q, K_T) / sqrt(d_k) ===
Shape of K_T: torch.Size([1, 4, 3, 5])
d_k: 3
q.shape: torch.Size([1, 4, 5, 3])
k_t.shape: torch.Size([1, 4, 3, 5])
d_k: 3
att.shape: torch.Size([1, 4, 5, 5])
=== Apply the attention mask ===
att: tensor([[[[16.6450,    -inf,    -inf,    -inf,    -inf],
          [ 6.9802,  2.9272,    -inf,    -inf,    -inf],
          [ 6.9802,  2.9272,  2.9272,    -inf,    -inf],
          [ 6.9802,  2.9272,  2.9272,  2.9272,    -inf],
          [ 6.9802,  2.9272,  2.9272,  2.9272,  2.9272]],

         [[16.6450,    -inf,    -inf,    -inf,    -inf],
          [ 6.9802,  2.9272,    -inf,    -inf,    -inf],
          [ 6.9802,  2.9272,  2.9272,    -inf,    -inf],
          [ 6.9802,  2.9272,  2.9272,  2.9272,    -inf],
          [ 6.9802,  2.9272,  2.9272,  2.9272,  2.9272]],

         [[16.6450,    -inf,    -inf,    -inf,    -inf],
          [ 6.9802,

In [80]:
print("=== Showing the input and output ===")
# Input first, then the layer's output, one tensor per print.
for tensor in (x, y):
    print(tensor)

=== Showing the input and output ===
tensor([[[1., 1., 1., 2., 2., 2., 3., 3., 3., 4., 4., 4.],
         [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]]])
tensor([[[3.8200, 3.8200, 3.8200, 3.8200, 3.8200, 3.8200, 3.8200, 3.8200,
          3.8200, 3.8200, 3.8200, 3.8200],
         [3.7831, 3.7831, 3.7831, 3.7831, 3.7831, 3.7831, 3.7831, 3.7831,
          3.7831, 3.7831, 3.7831, 3.7831],
         [3.7475, 3.7475, 3.7475, 3.7475, 3.7475, 3.7475, 3.7475, 3.7475,
          3.7475, 3.7475, 3.7475, 3.7475],
         [3.7130, 3.7130, 3.7130, 3.7130, 3.7130, 3.7130, 3.7130, 3.7130,
          3.7130, 3.7130, 3.7130, 3.7130],
         [3.6797, 3.6797, 3.6797, 3.6797, 3.6797, 3.6797, 3.6797, 3.6797,
          3.6797, 3.6797, 3.6797, 3.6797]]], grad_fn=<ViewBackward0>)


In [84]:
print("=== Checking gradients ===")
# backward() accumulates into .grad, so clear anything left over from an earlier
# run of the forward/backward cells; otherwise the exact gradient check below
# would see summed (e.g. doubled) gradients.
attention.zero_grad()
# Reduce the output to a scalar so backward() can be called on it.
loss = y.sum()
loss.backward()

=== Checking gradients ===


In [85]:
# Display the scalar loss (the sum of every element of y) along with its grad_fn.
loss

tensor(224.9195, grad_fn=<SumBackward0>)

In [86]:
# check if nan in y
if torch.isnan(y).any().item():
    raise ValueError(
        "It appears that the output contains NaNs. Perhaps the softmax dimension is incorrect?"
    )

In [87]:
# Integer checksum int(sum(grad ** 2)) for each projection's weight and bias, in the
# order query.weight, query.bias, key.weight, key.bias, value.weight, value.bias.
gradients = [
    int(param.grad.pow(2).sum().item())
    for layer in (attention.query, attention.key, attention.value)
    for param in (layer.weight, layer.bias)
]

In [88]:
print(f"Gradients: {gradients}")

Gradients: [161, 13, 294, 0, 37187, 432]


In [89]:
# Reference checksums int(sum(grad ** 2)) for, in order:
# query.weight, query.bias, key.weight, key.bias, value.weight, value.bias.
EXPECTED_GRADIENTS = [161, 13, 294, 0, 37187, 432]
# Checksums this check associates with forgetting to divide the scores by sqrt(d_k).
UNSCALED_SCORES_GRADIENTS = [1, 0, 2, 0, 38787, 432]

if gradients == EXPECTED_GRADIENTS:
    print("Success! 🚀🚀🚀")
elif gradients == UNSCALED_SCORES_GRADIENTS:
    raise RuntimeError(
        "There is an error in your implementation. Please check your code. Did you remember to divide by the square root of d_k?"
    )
elif gradients[-1] == EXPECTED_GRADIENTS[-1]:
    # value.bias matches but other entries do not; the hint points at the mask fill value.
    raise RuntimeError(
        "There is an error in your implementation. Please check your code. Did you use -inf as the masked_fill_value?"
    )
else:
    raise RuntimeError(
        "There is an error in your implementation. Please check your code."
    )

Success! 🚀🚀🚀
