In [17]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [18]:
## Whitespace tokenization for illustration only; real models use subword
## tokenizers such as BPE or WordPiece.
raw_sentence = "The quick brown fox jumps over the lazy dog"

# Keep the raw string and the token list under separate names so `sentence`
# always means "list of tokens" in the cells below and the cell is idempotent.
sentence = raw_sentence.split()
n = len(sentence)  # number of tokens (sequence length)

print(f"Tokenized sentence: {sentence}")
print(f"Number of tokens: {n}")

Tokenized sentence: ['The', 'quick', 'brown', 'fox', 'jumps', 'over', 'the', 'lazy', 'dog']
Number of tokens: 9


In [19]:
## Sample word embeddings. In practice these are learned during language-model
## training or loaded from a pre-trained model such as GloVe or Word2Vec.
## "The" and "the" are given identical vectors here, so only the positional
## encodings added later can tell those two tokens apart.

# Shape: (n, d) where n is the number of tokens and d is the embedding dimension
embeddings = torch.tensor([
    [1.0, 0.5, 0.2, 0.8],  # The
    [0.3, 1.0, 0.7, 0.1],  # quick
    [0.6, 0.2, 1.0, 0.4],  # brown
    [0.9, 0.8, 0.3, 1.0],  # fox
    [0.4, 0.6, 0.8, 0.2],  # jumps
    [0.7, 0.3, 0.5, 0.9],  # over
    [1.0, 0.5, 0.2, 0.8],  # the
    [0.2, 0.9, 0.4, 0.6],  # lazy
    [0.8, 0.4, 0.9, 0.3],  # dog
])

# Exactly one embedding row per token; the loop below indexes rows by token position.
assert embeddings.shape[0] == n, f"expected {n} embedding rows, got {embeddings.shape[0]}"

# Pad labels to the longest token; a fixed width of 4 lets 5-letter tokens
# ("quick", "jumps") push their rows out of alignment.
word_width = max(len(word) for word in sentence)

print(f"Word embeddings ({embeddings.shape[1]}-dimensional):")
for i, word in enumerate(sentence):
    print(f"  {word:{word_width}}: {embeddings[i]}")
print()

Word embeddings (4-dimensional):
  The : tensor([1.0000, 0.5000, 0.2000, 0.8000])
  quick: tensor([0.3000, 1.0000, 0.7000, 0.1000])
  brown: tensor([0.6000, 0.2000, 1.0000, 0.4000])
  fox : tensor([0.9000, 0.8000, 0.3000, 1.0000])
  jumps: tensor([0.4000, 0.6000, 0.8000, 0.2000])
  over: tensor([0.7000, 0.3000, 0.5000, 0.9000])
  the : tensor([1.0000, 0.5000, 0.2000, 0.8000])
  lazy: tensor([0.2000, 0.9000, 0.4000, 0.6000])
  dog : tensor([0.8000, 0.4000, 0.9000, 0.3000])



In [20]:
## Sample positional encodings. In practice these are produced by fixed
## sinusoidal functions or learned during training; rotary embeddings (RoPE)
## are an alternative that rotates Q and K instead of being added to the input.

positional_embeddings = torch.tensor([
    [0.0, 1.0, 0.0, 1.0],
    [0.1, 0.9, 0.1, 0.9],
    [0.2, 0.8, 0.2, 0.8],
    [0.3, 0.7, 0.3, 0.7],
    [0.4, 0.6, 0.4, 0.6],
    [0.5, 0.5, 0.5, 0.5],
    [0.6, 0.4, 0.6, 0.4],
    [0.7, 0.3, 0.7, 0.3],
    [0.8, 0.2, 0.8, 0.2],
])

# One positional vector per token, same width as the word embeddings, so the
# next step is a plain element-wise sum rather than a silent broadcast.
assert positional_embeddings.shape == embeddings.shape, (
    f"positional {tuple(positional_embeddings.shape)} != word {tuple(embeddings.shape)}"
)

print("Positional embeddings (Same dimensions as word embeddings):")
for i, word in enumerate(sentence):
    print(f"  Pos {i} ({word}): {positional_embeddings[i]}")
print()

Positional embeddings (Same dimesnions as word embeddings):
  Pos 0 (The): tensor([0., 1., 0., 1.])
  Pos 1 (quick): tensor([0.1000, 0.9000, 0.1000, 0.9000])
  Pos 2 (brown): tensor([0.2000, 0.8000, 0.2000, 0.8000])
  Pos 3 (fox): tensor([0.3000, 0.7000, 0.3000, 0.7000])
  Pos 4 (jumps): tensor([0.4000, 0.6000, 0.4000, 0.6000])
  Pos 5 (over): tensor([0.5000, 0.5000, 0.5000, 0.5000])
  Pos 6 (the): tensor([0.6000, 0.4000, 0.6000, 0.4000])
  Pos 7 (lazy): tensor([0.7000, 0.3000, 0.7000, 0.3000])
  Pos 8 (dog): tensor([0.8000, 0.2000, 0.8000, 0.2000])



In [21]:
## The final input to the attention block is the element-wise sum of the word
## embeddings and the positional encodings.

# Guard against silent broadcasting: e.g. a (1, d) positional tensor would add
# the same position vector to every token instead of raising an error.
assert embeddings.shape == positional_embeddings.shape, (
    f"shape mismatch: {tuple(embeddings.shape)} vs {tuple(positional_embeddings.shape)}"
)
input_embeddings = embeddings + positional_embeddings  # (n, d)

# Align labels to the longest token so 5-letter words do not break the columns.
word_width = max(len(word) for word in sentence)

print("Input embeddings (word + positional):")
for i, word in enumerate(sentence):
    print(f"  {word:{word_width}}: {input_embeddings[i]}")
print()

Input embeddings (word + positional):
  The : tensor([1.0000, 1.5000, 0.2000, 1.8000])
  quick: tensor([0.4000, 1.9000, 0.8000, 1.0000])
  brown: tensor([0.8000, 1.0000, 1.2000, 1.2000])
  fox : tensor([1.2000, 1.5000, 0.6000, 1.7000])
  jumps: tensor([0.8000, 1.2000, 1.2000, 0.8000])
  over: tensor([1.2000, 0.8000, 1.0000, 1.4000])
  the : tensor([1.6000, 0.9000, 0.8000, 1.2000])
  lazy: tensor([0.9000, 1.2000, 1.1000, 0.9000])
  dog : tensor([1.6000, 0.6000, 1.7000, 0.5000])



In [22]:
d_model = 4      # embedding dimension of the tokens
num_heads = 2    # number of attention heads

# Each head works in a d_k-dimensional subspace. Concatenating the heads must
# give back exactly d_model columns for the (d_model x d_model) output
# projection, so a remainder here would silently drop dimensions.
assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
d_k = d_model // num_heads  # dimension of the Q, K and V vectors for each head

In [31]:
torch.manual_seed(42)  # fixed seed so the random weights (and every output below) are reproducible

WEIGHT_INIT_SCALE = 0.3  # multiplier applied to the standard-normal weight samples

## Per-head projection weights have shape (num_heads, d_model, d_k); the output
## projection W_o has shape (d_model, d_model).
# Draw order (Q, K, V, then O) determines the values produced under the seed;
# keep it unchanged so the printed weights stay reproducible.
W_q = torch.randn(num_heads, d_model, d_k, dtype=torch.float32) * WEIGHT_INIT_SCALE  # query weights
W_k = torch.randn(num_heads, d_model, d_k, dtype=torch.float32) * WEIGHT_INIT_SCALE  # key weights
W_v = torch.randn(num_heads, d_model, d_k, dtype=torch.float32) * WEIGHT_INIT_SCALE  # value weights
# Output projection, applied after concatenating the heads so that information
# from different heads can be mixed together.
W_o = torch.randn(d_model, d_model, dtype=torch.float32) * WEIGHT_INIT_SCALE

print(f"W_q (Query weights) shape: {W_q.shape}")
print(W_q)
print(f"\nW_k (Key weights) shape: {W_k.shape}")
print(W_k)
print(f"\nW_v (Value weights) shape: {W_v.shape}")
print(W_v)
print()

print(f"W_o (Output projection weights) shape: {W_o.shape}")
print(W_o)

W_q (Query weights) shape: torch.Size([2, 4, 2])
tensor([[[ 0.5781,  0.4462],
         [ 0.2702, -0.6317],
         [ 0.2035, -0.3704],
         [-0.0129, -0.4814]],

        [[-0.2256,  0.4946],
         [-0.1177, -0.4211],
         [-0.2184, -0.1678],
         [-0.2307,  0.2287]]])

W_k (Key weights) shape: torch.Size([2, 4, 2])
tensor([[[ 0.4927, -0.0479],
         [-0.1492,  0.1319],
         [-0.2274,  0.3235],
         [ 0.2402,  0.5042]],

        [[ 0.3837,  0.3889],
         [ 0.1831,  0.4004],
         [-0.0695,  0.0125],
         [-0.0755,  0.2580]]])

W_v (Value weights) shape: torch.Size([2, 4, 2])
tensor([[[-0.4154, -0.2614],
         [-0.0670,  0.5152],
         [ 0.0957, -0.1274],
         [ 0.0917, -0.2324]],

        [[-0.4673,  0.2987],
         [-0.2639, -0.1803],
         [-0.3822,  0.6368],
         [-0.3704, -0.1464]]])

W_o (Output projection weights) shape: torch.Size([4, 4])
tensor([[-0.2741, -0.1974,  0.0234,  0.1577],
        [-0.1464,  0.3574, -0.2442, -0.2

In [32]:
## Collector for the per-head attention outputs, concatenated later into a single
## (n, d_model) tensor. The attention loop appends one (n, d_k) tensor per head.
head_outputs = list()

In [33]:
for head in range(num_heads):
    print(f"=== HEAD {head + 1} ===")

    # Project the shared input through this head's own slice of each weight stack:
    # (n, d_model) @ (d_model, d_k) -> (n, d_k) for each of Q, K and V.
    Q_h, K_h, V_h = (input_embeddings @ weights[head] for weights in (W_q, W_k, W_v))

    print(f"Head {head + 1} - Computing Q, K, V using weight matrices:")
    print(f"  Q_h = input_embeddings @ W_q[{head}]  ->  shape: {Q_h.shape}")
    print(f"  K_h = input_embeddings @ W_k[{head}]  ->  shape: {K_h.shape}")
    print(f"  V_h = input_embeddings @ W_v[{head}]  ->  shape: {V_h.shape}")
    print()

=== HEAD 1 ===
Head 1 - Computing Q, K, V using weight matrices:
  Q_h = input_embeddings @ W_q[0]  ->  shape: torch.Size([9, 2])
  K_h = input_embeddings @ W_k[0]  ->  shape: torch.Size([9, 2])
  V_h = input_embeddings @ W_v[0]  ->  shape: torch.Size([9, 2])

=== HEAD 2 ===
Head 2 - Computing Q, K, V using weight matrices:
  Q_h = input_embeddings @ W_q[1]  ->  shape: torch.Size([9, 2])
  K_h = input_embeddings @ W_k[1]  ->  shape: torch.Size([9, 2])
  V_h = input_embeddings @ W_v[1]  ->  shape: torch.Size([9, 2])



In [None]:
# Reset the collector inside this cell so re-running it does not append a second
# set of head outputs. A stale list would make the later concatenation
# (n, 2 * d_model) wide and break the output projection.
head_outputs = []

# Standard 1/sqrt(d_k) scaling of scaled dot-product attention; loop-invariant.
score_scale = torch.sqrt(torch.tensor(d_k, dtype=torch.float32))

for head in range(num_heads):
    print(f"=== HEAD {head + 1} ===")

    Q_h = torch.matmul(input_embeddings, W_q[head])  # (n, d_k)
    K_h = torch.matmul(input_embeddings, W_k[head])  # (n, d_k)
    V_h = torch.matmul(input_embeddings, W_v[head])  # (n, d_k)

    # Entry [i, j] scores how strongly token i attends to token j.
    attention_scores_h = torch.matmul(Q_h, K_h.T) / score_scale  # (n, n)
    print(f"Head {head + 1} - Attention scores (Q_h @ K_h.T) scaled by sqrt(d_k): {attention_scores_h.shape}")

    # Softmax over the last dim so each query token's weights sum to 1.
    attention_weights_h = F.softmax(attention_scores_h, dim=-1)  # (n, n)
    print(f"Head {head + 1} - Attention weights after softmax: {attention_weights_h.shape}")

    output_h = torch.matmul(attention_weights_h, V_h)  # (n, d_k)
    print(f"Head {head + 1} - Output (attention_weights_h @ V_h): {output_h.shape}")

    # After the loop, head_outputs holds num_heads tensors of shape (n, d_k).
    head_outputs.append(output_h)
    print()


=== HEAD 1 ===
Head 1 - Attention scores (Q_h @ K_h.T) scaled by sqrt(d_k): torch.Size([9, 9])
Head 1 - Attention weights after softmax: torch.Size([9, 9])
Head 1 - Output (attention_weights_h @ V_h): torch.Size([9, 2])

=== HEAD 2 ===
Head 2 - Attention scores (Q_h @ K_h.T) scaled by sqrt(d_k): torch.Size([9, 9])
Head 2 - Attention weights after softmax: torch.Size([9, 9])
Head 2 - Output (attention_weights_h @ V_h): torch.Size([9, 2])



In [35]:
## Concatenate the per-head outputs along the last (feature/column) dimension:
## num_heads tensors of shape (n, d_k) are placed side by side, giving
## (n, num_heads * d_k) = (n, d_model).
concatenated_output = torch.cat(head_outputs, dim=-1)  # (n, d_model)

# Catches stale notebook state, e.g. the attention loop having appended twice.
assert concatenated_output.shape == (n, d_model), (
    f"expected {(n, d_model)}, got {tuple(concatenated_output.shape)}"
)
print(f"Concatenated output from all heads: {concatenated_output.shape}")

Concatenated output from all heads: torch.Size([9, 4])


In [38]:
# Mix information across heads with the output projection:
# (n, d_model) @ (d_model, d_model) -> (n, d_model).
final_output = concatenated_output @ W_o
print(f"Final output after output projection (concatenated_output @ W_o): {final_output.shape}")

Final output after output projection (concatenated_output @ W_o): torch.Size([9, 4])


In [39]:
# Final multi-head attention output: one d_model-wide row per token.
# Bare last-line expression so Jupyter renders it as the cell's Out[] value.
final_output

tensor([[ 0.7657,  0.3064, -0.1404, -0.1354],
        [ 0.7626,  0.3212, -0.1500, -0.1233],
        [ 0.7642,  0.3108, -0.1435, -0.1325],
        [ 0.7672,  0.3047, -0.1389, -0.1348],
        [ 0.7640,  0.3141, -0.1455, -0.1295],
        [ 0.7662,  0.3032, -0.1384, -0.1383],
        [ 0.7682,  0.2995, -0.1357, -0.1403],
        [ 0.7646,  0.3121, -0.1442, -0.1309],
        [ 0.7678,  0.3030, -0.1378, -0.1371]])
