In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [5]:
## Whitespace tokenization keeps the walkthrough simple; a real pipeline would use a subword tokenizer such as BPE or WordPiece
sentence = "The quick brown fox jumps over the lazy dog".split()
n = len(sentence)  # sequence length, reused later to size the causal mask

print(
    f"Tokenized sentence: {sentence}",
    f"Number of tokens: {len(sentence)}",
    sep="\n",
)

Tokenized sentence: ['The', 'quick', 'brown', 'fox', 'jumps', 'over', 'the', 'lazy', 'dog']
Number of tokens: 9


In [6]:
## Hand-written toy word embeddings; normally these are learned during language modelling or loaded from a pre-trained model like GloVe or Word2Vec

# embeddings has shape (n, d): one row per token, d = 4 values per row
embeddings = torch.tensor(
    [
        [1.0, 0.5, 0.2, 0.8],  # The
        [0.3, 1.0, 0.7, 0.1],  # quick
        [0.6, 0.2, 1.0, 0.4],  # brown
        [0.9, 0.8, 0.3, 1.0],  # fox
        [0.4, 0.6, 0.8, 0.2],  # jumps
        [0.7, 0.3, 0.5, 0.9],  # over
        [1.0, 0.5, 0.2, 0.8],  # the (same vector as row 0)
        [0.2, 0.9, 0.4, 0.6],  # lazy
        [0.8, 0.4, 0.9, 0.3],  # dog
    ]
)

print("Word embeddings (4-dimensional):")
for position, token in enumerate(sentence):
    # Row `position` of the table belongs to the token at that position
    print(f"  {token:4}: {embeddings[position]}")
print()

Word embeddings (4-dimensional):
  The : tensor([1.0000, 0.5000, 0.2000, 0.8000])
  quick: tensor([0.3000, 1.0000, 0.7000, 0.1000])
  brown: tensor([0.6000, 0.2000, 1.0000, 0.4000])
  fox : tensor([0.9000, 0.8000, 0.3000, 1.0000])
  jumps: tensor([0.4000, 0.6000, 0.8000, 0.2000])
  over: tensor([0.7000, 0.3000, 0.5000, 0.9000])
  the : tensor([1.0000, 0.5000, 0.2000, 0.8000])
  lazy: tensor([0.2000, 0.9000, 0.4000, 0.6000])
  dog : tensor([0.8000, 0.4000, 0.9000, 0.3000])



In [7]:
## Hand-written toy positional encodings; real models generate them with math functions, learn them during training, or use RoPE

# One row per position 0..8, with the same width as the word embeddings so the two can be added
positional_embeddings = torch.tensor(
    [
        [0.0, 1.0, 0.0, 1.0],  # position 0
        [0.1, 0.9, 0.1, 0.9],  # position 1
        [0.2, 0.8, 0.2, 0.8],  # position 2
        [0.3, 0.7, 0.3, 0.7],  # position 3
        [0.4, 0.6, 0.4, 0.6],  # position 4
        [0.5, 0.5, 0.5, 0.5],  # position 5
        [0.6, 0.4, 0.6, 0.4],  # position 6
        [0.7, 0.3, 0.7, 0.3],  # position 7
        [0.8, 0.2, 0.8, 0.2],  # position 8
    ]
)

print("Positional embeddings (Same dimesnions as word embeddings):")
for position, token in enumerate(sentence):
    print(f"  Pos {position} ({token}): {positional_embeddings[position]}")
print()

Positional embeddings (Same dimesnions as word embeddings):
  Pos 0 (The): tensor([0., 1., 0., 1.])
  Pos 1 (quick): tensor([0.1000, 0.9000, 0.1000, 0.9000])
  Pos 2 (brown): tensor([0.2000, 0.8000, 0.2000, 0.8000])
  Pos 3 (fox): tensor([0.3000, 0.7000, 0.3000, 0.7000])
  Pos 4 (jumps): tensor([0.4000, 0.6000, 0.4000, 0.6000])
  Pos 5 (over): tensor([0.5000, 0.5000, 0.5000, 0.5000])
  Pos 6 (the): tensor([0.6000, 0.4000, 0.6000, 0.4000])
  Pos 7 (lazy): tensor([0.7000, 0.3000, 0.7000, 0.3000])
  Pos 8 (dog): tensor([0.8000, 0.2000, 0.8000, 0.2000])



In [8]:
## The Attention block receives the element-wise sum of word embeddings and positional encodings

input_embeddings = torch.add(embeddings, positional_embeddings)

print("Input embeddings (word + positional):")
for position, token in enumerate(sentence):
    print(f"  {token:4}: {input_embeddings[position]}")
print()

Input embeddings (word + positional):
  The : tensor([1.0000, 1.5000, 0.2000, 1.8000])
  quick: tensor([0.4000, 1.9000, 0.8000, 1.0000])
  brown: tensor([0.8000, 1.0000, 1.2000, 1.2000])
  fox : tensor([1.2000, 1.5000, 0.6000, 1.7000])
  jumps: tensor([0.8000, 1.2000, 1.2000, 0.8000])
  over: tensor([1.2000, 0.8000, 1.0000, 1.4000])
  the : tensor([1.6000, 0.9000, 0.8000, 1.2000])
  lazy: tensor([0.9000, 1.2000, 1.1000, 0.9000])
  dog : tensor([1.6000, 0.6000, 1.7000, 0.5000])



In [None]:
## Demo of dropout that will be added intermittently in the model

# Fix the RNG so the mask printed by this demo is the same on every run
torch.manual_seed(42)

demo = torch.ones(6, 6)
print(f"Demo tensor before dropout:{demo}")

# Each element is zeroed independently with probability p = 0.5 and the survivors are
# scaled by 1 / (1 - p) = 2 to keep the expected value unchanged. Because the draws are
# independent, the number of zeros is random - it is not guaranteed to be exactly half.
dropout = nn.Dropout(p=0.5)
demo_dropped = dropout(demo)
print(f"Demo tensor after dropout:{demo_dropped}")

Demo tensor before dropout:tensor([[1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.]])
Demo tensor after dropout:tensor([[0., 2., 0., 0., 2., 2.],
        [0., 0., 0., 2., 2., 0.],
        [0., 0., 0., 0., 0., 0.],
        [0., 2., 0., 2., 2., 2.],
        [2., 2., 0., 0., 0., 2.],
        [0., 2., 0., 2., 0., 2.]])


In [9]:
## Dropout on the attention input
# Seed before sampling the mask: the dropped embeddings are what Q, K and V are computed
# from, so without a fixed seed every downstream value would differ between runs.
torch.manual_seed(42)
dropout = nn.Dropout(p=0.2)

# NOTE: this rebinds input_embeddings to its dropped-out version. Running the cell a second
# time without rebuilding input_embeddings first applies dropout on top of dropout.
input_embeddings = dropout(input_embeddings)
print(f"Input embeddings shape after dropout: {input_embeddings.shape}")
print(f"Input embeddings after dropout:{input_embeddings}")

Input embeddings shape after dropout: torch.Size([9, 4])
Input embeddings after dropout:tensor([[1.2500, 1.8750, 0.2500, 0.0000],
        [0.5000, 0.0000, 1.0000, 1.2500],
        [0.0000, 1.2500, 1.5000, 1.5000],
        [0.0000, 1.8750, 0.7500, 2.1250],
        [1.0000, 1.5000, 0.0000, 1.0000],
        [1.5000, 1.0000, 0.0000, 0.0000],
        [2.0000, 1.1250, 1.0000, 1.5000],
        [0.0000, 0.0000, 1.3750, 0.0000],
        [2.0000, 0.0000, 0.0000, 0.6250]])


In [10]:
# Embedding dimension, read from the second axis of the (n, d) input matrix
d_model = input_embeddings.size(1)
# Width of the query / key / value projections built from this point on
d_k = 3

In [11]:
torch.manual_seed(42)  # Makes the four random weight matrices below reproducible

## W_q, W_k and W_v are d x d_k; W_o is d_k x d so the attention output can be projected back to d dimensions
# Drawn in the order q, k, v (then o) and scaled by 0.3
W_q, W_k, W_v = (
    torch.randn(d_model, d_k, dtype=torch.float32) * 0.3 for _ in range(3)
)
W_o = torch.randn(d_k, d_model, dtype=torch.float32) * 0.3

print(f"W_q (Query weights) shape: {W_q.shape}", W_q, sep="\n")
print(f"\nW_k (Key weights) shape: {W_k.shape}", W_k, sep="\n")
print(f"\nW_v (Value weights) shape: {W_v.shape}", W_v, sep="\n")
print()

print(f"W_o (Output projection weights) shape: {W_o.shape}", W_o, sep="\n")

W_q (Query weights) shape: torch.Size([4, 3])
tensor([[ 0.1010,  0.0386,  0.0703],
        [ 0.0691, -0.3369, -0.0559],
        [ 0.6625, -0.1914,  0.1385],
        [ 0.0802,  0.1605,  0.2428]])

W_k (Key weights) shape: torch.Size([4, 3])
tensor([[ 0.3331, -0.5069, -0.2967],
        [ 0.2874,  0.3966,  0.2452],
        [-0.2298, -0.2252,  0.4058],
        [ 0.2059, -0.0983,  0.2385]])

W_v (Value weights) shape: torch.Size([4, 3])
tensor([[ 0.0845,  0.0168,  0.1568],
        [-0.0715, -0.0150,  0.1579],
        [-0.0025,  0.2187,  0.0399],
        [ 0.2592, -0.3047, -0.2666]])

W_o (Output projection weights) shape: torch.Size([3, 4])
tensor([[ 0.0449, -0.0627, -0.1161,  0.2974],
        [ 0.1404, -0.0615, -0.2223,  0.1086],
        [ 0.5760, -0.0676, -0.1025,  0.0912]])


In [12]:
# Project every token's input embedding into query, key and value space; each result is (n, d_k)
Q = input_embeddings @ W_q
K = input_embeddings @ W_k
V = input_embeddings @ W_v

print(f"\nQ (Queries) shape: {Q.shape}", Q, sep="\n")
print(f"\nK (Keys) shape: {K.shape}", K, sep="\n")
print(f"\nV (Values) shape: {V.shape}", V, sep="\n")
print()


Q (Queries) shape: torch.Size([9, 3])
tensor([[ 0.4214, -0.6312,  0.0177],
        [ 0.8132,  0.0285,  0.4772],
        [ 1.2004, -0.4675,  0.5021],
        [ 0.7968, -0.4342,  0.5150],
        [ 0.2849, -0.3062,  0.2293],
        [ 0.2206, -0.2789,  0.0496],
        [ 1.0625, -0.2524,  0.5805],
        [ 0.9109, -0.2632,  0.1904],
        [ 0.2521,  0.1776,  0.2924]])

K (Keys) shape: torch.Size([9, 3])
tensor([[ 0.8978,  0.0537,  0.1903],
        [ 0.1942, -0.6016,  0.5555],
        [ 0.3235,  0.0105,  1.2728],
        [ 0.8041,  0.3659,  1.2708],
        [ 0.9701, -0.0103,  0.3095],
        [ 0.7870, -0.3638, -0.1999],
        [ 1.0686, -0.9403,  0.4459],
        [-0.3159, -0.3096,  0.5579],
        [ 0.7949, -1.0753, -0.4443]])

V (Values) shape: torch.Size([9, 3])
tensor([[-0.0291,  0.0477,  0.5021],
        [ 0.3637, -0.1537, -0.2149],
        [ 0.2956, -0.1477, -0.1426],
        [ 0.4148, -0.5115, -0.2406],
        [ 0.2364, -0.3103,  0.1270],
        [ 0.0552,  0.0103,  0.3931

In [13]:
## Causal mask: a lower-triangular matrix of ones, so row i has 1s only in columns 0..i.
## A 1 means the attention score is kept (the token itself or an earlier token); a 0 means it is masked out (a future token).

mask = torch.ones(n, n).tril()  # Shape: (n, n)
print(f"Mask shape: {mask.shape}", mask, sep="\n")

Mask shape: torch.Size([9, 9])
tensor([[1., 0., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1.]])


In [14]:
# Raw scores: entry (i, j) is the dot product of query i with key j. Shape: (n, n)
attention_scores = Q @ K.t()

print(f"Attention scores shape: {attention_scores.shape}", attention_scores, sep="\n")

Attention scores shape: torch.Size([9, 9])
tensor([[ 3.4782e-01,  4.7136e-01,  1.5225e-01,  1.3049e-01,  4.2082e-01,
          5.5773e-01,  1.0517e+00,  7.2186e-02,  1.0058e+00],
        [ 8.2241e-01,  4.0583e-01,  8.7070e-01,  1.2707e+00,  9.3629e-01,
          5.3428e-01,  1.0550e+00,  4.9726e-04,  4.0372e-01],
        [ 1.1481e+00,  7.9320e-01,  1.0224e+00,  1.4322e+00,  1.3247e+00,
          1.0144e+00,  1.9462e+00,  4.5658e-02,  1.2337e+00],
        [ 7.9005e-01,  7.0201e-01,  9.0872e-01,  1.1364e+00,  9.3689e-01,
          6.8212e-01,  1.4894e+00,  1.7005e-01,  8.7140e-01],
        [ 2.8292e-01,  3.6687e-01,  3.8077e-01,  4.0842e-01,  3.5047e-01,
          2.8974e-01,  6.9455e-01,  1.3274e-01,  4.5378e-01],
        [ 1.9251e-01,  2.3817e-01,  1.3157e-01,  1.3839e-01,  2.3224e-01,
          2.6516e-01,  5.2012e-01,  4.4340e-02,  4.5321e-01],
        [ 1.0508e+00,  6.8061e-01,  1.0799e+00,  1.4997e+00,  1.2130e+00,
          8.1200e-01,  1.6316e+00,  6.6359e-02,  8.5801e-01],
     

In [15]:
## Work on a copy so attention_scores itself is left untouched, then overwrite every position
## where the mask is 0 with -inf; softmax will turn those into attention weights of 0 for the future tokens
masked_attention_scores = attention_scores.clone().masked_fill_(mask.eq(0), float("-inf"))

print(f"\nMasked attention scores shape: {masked_attention_scores.shape}", masked_attention_scores, sep="\n")


Masked attention scores shape: torch.Size([9, 9])
tensor([[ 0.3478,    -inf,    -inf,    -inf,    -inf,    -inf,    -inf,    -inf,
            -inf],
        [ 0.8224,  0.4058,    -inf,    -inf,    -inf,    -inf,    -inf,    -inf,
            -inf],
        [ 1.1481,  0.7932,  1.0224,    -inf,    -inf,    -inf,    -inf,    -inf,
            -inf],
        [ 0.7900,  0.7020,  0.9087,  1.1364,    -inf,    -inf,    -inf,    -inf,
            -inf],
        [ 0.2829,  0.3669,  0.3808,  0.4084,  0.3505,    -inf,    -inf,    -inf,
            -inf],
        [ 0.1925,  0.2382,  0.1316,  0.1384,  0.2322,  0.2652,    -inf,    -inf,
            -inf],
        [ 1.0508,  0.6806,  1.0799,  1.4997,  1.2130,  0.8120,  1.6316,    -inf,
            -inf],
        [ 0.8399,  0.4410,  0.5342,  0.8781,  0.9453,  0.7746,  1.3057, -0.1000,
            -inf],
        [ 0.2915,  0.1046,  0.4556,  0.6393,  0.3333,  0.0754,  0.2329,  0.0285,
         -0.1205]])


In [16]:
# Scale by sqrt(d_k); the -inf entries stay -inf.
# NOTE: the result is stored back under the same name, so running this cell twice divides twice.
masked_attention_scores = masked_attention_scores / torch.tensor(d_k, dtype=torch.float32).sqrt()

print(f"\nScaled masked attention scores shape: {masked_attention_scores.shape}", masked_attention_scores, sep="\n")


Scaled masked attention scores shape: torch.Size([9, 9])
tensor([[ 0.2008,    -inf,    -inf,    -inf,    -inf,    -inf,    -inf,    -inf,
            -inf],
        [ 0.4748,  0.2343,    -inf,    -inf,    -inf,    -inf,    -inf,    -inf,
            -inf],
        [ 0.6628,  0.4580,  0.5903,    -inf,    -inf,    -inf,    -inf,    -inf,
            -inf],
        [ 0.4561,  0.4053,  0.5246,  0.6561,    -inf,    -inf,    -inf,    -inf,
            -inf],
        [ 0.1633,  0.2118,  0.2198,  0.2358,  0.2023,    -inf,    -inf,    -inf,
            -inf],
        [ 0.1111,  0.1375,  0.0760,  0.0799,  0.1341,  0.1531,    -inf,    -inf,
            -inf],
        [ 0.6067,  0.3929,  0.6235,  0.8659,  0.7003,  0.4688,  0.9420,    -inf,
            -inf],
        [ 0.4849,  0.2546,  0.3084,  0.5070,  0.5458,  0.4472,  0.7539, -0.0577,
            -inf],
        [ 0.1683,  0.0604,  0.2631,  0.3691,  0.1924,  0.0435,  0.1344,  0.0165,
         -0.0696]])


In [17]:
# Softmax along the last axis turns each row of scores into weights over the keys;
# the positions previously set to -inf come out as exactly 0
attention_weights = torch.softmax(masked_attention_scores, dim=-1)

print(f"\nAttention weights shape: {attention_weights.shape}", attention_weights, sep="\n")


Attention weights shape: torch.Size([9, 9])
tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5598, 0.4402, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3643, 0.2968, 0.3388, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2357, 0.2240, 0.2524, 0.2879, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1915, 0.2010, 0.2026, 0.2059, 0.1991, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1659, 0.1703, 0.1602, 0.1608, 0.1698, 0.1730, 0.0000, 0.0000, 0.0000],
        [0.1335, 0.1078, 0.1358, 0.1731, 0.1467, 0.1163, 0.1867, 0.0000, 0.0000],
        [0.1321, 0.1049, 0.1107, 0.1350, 0.1404, 0.1272, 0.1729, 0.0768, 0.0000],
        [0.1144, 0.1027, 0.1258, 0.1399, 0.1172, 0.1010, 0.1106, 0.0983, 0.0902]])


In [18]:
## Dropout on the attention weights
# Surviving weights are scaled by 1 / (1 - 0.2) = 1.25, so rows no longer sum to 1 afterwards.
# NOTE: attention_weights is rebound to the dropped version; re-running this cell alone drops again.
dropout = nn.Dropout(p=0.2)
attention_weights = dropout(attention_weights)

print(f"\nAttention weights after dropout shape: {attention_weights.shape}", attention_weights, sep="\n")


Attention weights after dropout shape: torch.Size([9, 9])
tensor([[1.2500, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.6998, 0.5502, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.4554, 0.0000, 0.4235, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2946, 0.2800, 0.3155, 0.3598, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.2533, 0.2573, 0.2489, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2074, 0.0000, 0.2002, 0.2010, 0.2122, 0.2163, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.1348, 0.1698, 0.2163, 0.1833, 0.1454, 0.2334, 0.0000, 0.0000],
        [0.1651, 0.1312, 0.1384, 0.1688, 0.1755, 0.0000, 0.2161, 0.0000, 0.0000],
        [0.1430, 0.1284, 0.1572, 0.1748, 0.0000, 0.1262, 0.1382, 0.1229, 0.1127]])


In [19]:
# Each output row is the attention-weighted combination of the value vectors. Shape: (n, d_k)
output = attention_weights @ V

print(f"\nAttention output shape: {output.shape}", output, sep="\n")


Attention output shape: torch.Size([9, 3])
tensor([[-0.0364,  0.0596,  0.6276],
        [ 0.1797, -0.0512,  0.2331],
        [ 0.1119, -0.0408,  0.1682],
        [ 0.3358, -0.2597, -0.0438],
        [ 0.2404, -0.2463, -0.0664],
        [ 0.1986, -0.1861,  0.1392],
        [ 0.3511, -0.2635,  0.0059],
        [ 0.2979, -0.2214,  0.0450],
        [ 0.2710, -0.1356,  0.0708]])


In [20]:
## Projecting the output back to d dimensions
final_output = output @ W_o  # Shape: (n, d)

print(f"\nFinal output shape: {final_output.shape}", final_output, sep="\n")


Final output shape: torch.Size([9, 4])
tensor([[ 0.3682, -0.0438, -0.0733,  0.0529],
        [ 0.1351, -0.0239, -0.0334,  0.0691],
        [ 0.0962, -0.0159, -0.0212,  0.0442],
        [-0.0466, -0.0021,  0.0232,  0.0677],
        [-0.0620,  0.0046,  0.0336,  0.0387],
        [ 0.0630, -0.0104,  0.0040,  0.0516],
        [-0.0178, -0.0062,  0.0172,  0.0763],
        [ 0.0082, -0.0081,  0.0100,  0.0687],
        [ 0.0339, -0.0134, -0.0086,  0.0723]])


In [21]:
## Dropout on the projected attention output
# NOTE: final_output is rebound to the dropped version; re-running this cell alone drops again.
dropout = nn.Dropout(p=0.2)
final_output = dropout(final_output)

print(f"\nFinal output after dropout shape: {final_output.shape}", final_output, sep="\n")


Final output after dropout shape: torch.Size([9, 4])
tensor([[ 0.4602, -0.0000, -0.0000,  0.0661],
        [ 0.1689, -0.0298, -0.0000,  0.0864],
        [ 0.1202, -0.0198, -0.0265,  0.0552],
        [-0.0583, -0.0026,  0.0290,  0.0846],
        [-0.0775,  0.0057,  0.0420,  0.0484],
        [ 0.0787, -0.0130,  0.0051,  0.0644],
        [-0.0223, -0.0077,  0.0215,  0.0954],
        [ 0.0103, -0.0101,  0.0125,  0.0858],
        [ 0.0000, -0.0168, -0.0107,  0.0000]])
