In [2]:
import torch
import torch.nn as nn

In [23]:
# Define the tensor with 3 rows and 6 columns
inputs = torch.tensor(
    [[0.43, 0.15, 0.89, 0.55, 0.87, 0.66],  # Row 1
     [0.57, 0.85, 0.64, 0.22, 0.58, 0.33],  # Row 2
     [0.77, 0.25, 0.10, 0.05, 0.80, 0.55]]  # Row 3
)

batch = torch.stack((inputs, inputs), dim = 0)
batch

tensor([[[0.4300, 0.1500, 0.8900, 0.5500, 0.8700, 0.6600],
         [0.5700, 0.8500, 0.6400, 0.2200, 0.5800, 0.3300],
         [0.7700, 0.2500, 0.1000, 0.0500, 0.8000, 0.5500]],

        [[0.4300, 0.1500, 0.8900, 0.5500, 0.8700, 0.6600],
         [0.5700, 0.8500, 0.6400, 0.2200, 0.5800, 0.3300],
         [0.7700, 0.2500, 0.1000, 0.0500, 0.8000, 0.5500]]])

In [26]:
class GPTConfig():
  def __init__(self) -> None:
    self.n_heads = 2
    self.d_in = 6
    self.d_out = 6
    self.context_length = 3
    self.batch_size = 2
    self.device = "CUDA" if torch.cuda else "cpu"

In [56]:
class MultiHeadAttention(nn.Module):
  def __init__(self, config: GPTConfig) -> None:
    super().__init__()
    self.config = config
    self.W_Q = nn.Linear(self.config.d_in, self.config.d_out, bias = False)
    self.W_K = nn.Linear(self.config.d_in, self.config.d_out, bias = False)
    self.W_V = nn.Linear(self.config.d_in, self.config.d_out, bias = False)
    self.head_dim = self.config.d_out // self.config.n_heads
    self.scale = self.head_dim ** -0.5
    self.n_heads = self.config.n_heads
    self.batch_size = self.config.batch_size
    self.register_buffer("mask", torch.tril(torch.ones(self.config.d_out, self.config.d_out)))

  def forward(self, x):
    context_vector = None
    b, n_tokens, d_in = x.shape

    print("Input shape: ", x.shape, end = "\n")

    self.Q = self.W_Q(x)
    self.K = self.W_K(x)
    self.V = self.W_V(x)

    print("Q shape: ", self.Q.shape)
    print("K shape: ", self.K.shape)
    print("V shape: ", self.V.shape, end = "\n\n")

    # Attention
    self.Q = self.Q.view(self.batch_size, n_tokens, self.n_heads, self.head_dim)
    self.K = self.K.view(self.batch_size, n_tokens, self.n_heads, self.head_dim)
    self.V = self.V.view(self.batch_size, n_tokens, self.n_heads, self.head_dim)

    print("QKV shape after view")
    print("Q shape: ", self.Q.shape)
    print("K shape: ", self.K.shape)
    print("V shape: ", self.V.shape, end= "\n\n")

    self.Q = self.Q.transpose(1, 2)
    self.K = self.K.transpose(1, 2)
    self.V = self.V.transpose(1, 2)

    print("QKV shape after transpose")
    print("Q shape: ", self.Q.shape)
    print("K shape: ", self.K.shape)
    print("V shape: ", self.V.shape, end= "\n\n")

    attention_scores = self.Q @ self.K.transpose(-1, -2)
    print("Attention scores: " ,attention_scores.shape)
    print(attention_scores, end="\n\n")

    #apply the mask
    mask_bool = self.mask.bool()[:n_tokens, :n_tokens]

    print("mask shape : ", mask_bool.shape)
    print(mask_bool, end="\n\n")

    attention_scores = attention_scores.masked_fill(mask_bool == False, -torch.inf)
    print("Attention scores after mask: " ,attention_scores.shape)
    print(attention_scores, end="\n\n")

    attention_weights = torch.softmax(attention_scores * self.config.d_out ** 0.5, dim = -1)
    print("Attention weights: ", attention_weights.shape)
    print(attention_weights, end="\n\n")

    context_vector = (attention_weights @ self.V).transpose(1, 2)
    print("Context vector: ", context_vector.shape)
    print(context_vector, end="\n\n")

    context_vector = context_vector.contiguous().view(self.batch_size, self.config.context_length, self.config.d_out)
    print("Context vector after view: ", context_vector.shape)
    print(context_vector, end="\n\n")

    return context_vector

In [57]:
torch.manual_seed(1332)

mha = MultiHeadAttention(GPTConfig())
mha(batch)

Input shape:  torch.Size([2, 3, 6])
Q shape:  torch.Size([2, 3, 6])
K shape:  torch.Size([2, 3, 6])
V shape:  torch.Size([2, 3, 6])

QKV shape after view
Q shape:  torch.Size([2, 3, 2, 3])
K shape:  torch.Size([2, 3, 2, 3])
V shape:  torch.Size([2, 3, 2, 3])

QKV shape after transpose
Q shape:  torch.Size([2, 2, 3, 3])
K shape:  torch.Size([2, 2, 3, 3])
V shape:  torch.Size([2, 2, 3, 3])

Attention scores:  torch.Size([2, 2, 3, 3])
tensor([[[[-0.0366, -0.0253,  0.0250],
          [ 0.0821, -0.0049, -0.0078],
          [-0.0173,  0.0074, -0.0315]],

         [[ 0.0909,  0.0849,  0.0057],
          [ 0.1518,  0.1120, -0.0596],
          [ 0.0210,  0.0333, -0.0547]]],


        [[[-0.0366, -0.0253,  0.0250],
          [ 0.0821, -0.0049, -0.0078],
          [-0.0173,  0.0074, -0.0315]],

         [[ 0.0909,  0.0849,  0.0057],
          [ 0.1518,  0.1120, -0.0596],
          [ 0.0210,  0.0333, -0.0547]]]], grad_fn=<UnsafeViewBackward0>)

mask shape :  torch.Size([3, 3])
tensor([[ True, Fals

tensor([[[ 0.3915, -0.3742, -0.0789,  0.1823,  0.0470,  0.4516],
         [ 0.2928, -0.2979, -0.2070,  0.0497,  0.0251,  0.4967],
         [ 0.3632, -0.2514, -0.1607,  0.0796, -0.0119,  0.5361]],

        [[ 0.3915, -0.3742, -0.0789,  0.1823,  0.0470,  0.4516],
         [ 0.2928, -0.2979, -0.2070,  0.0497,  0.0251,  0.4967],
         [ 0.3632, -0.2514, -0.1607,  0.0796, -0.0119,  0.5361]]],
       grad_fn=<ViewBackward0>)