In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

In [2]:
## Using for loop to loop through the heads and compute the attention scores is very slow, we need to optimize with tensor operations

In [4]:
## Also we use batches of data here to improve performance instead of looping through the dataset individually.

In [None]:
class MultiHeadedAttention(nn.Module):
    def __init__(self, d_model, num_heads, dropout):
        super(MultiHeadedAttention, self).__init__()

        assert d_model % num_heads == 0 ## d_model must be divisible by num_heads, so that we can break the Q, K and V matrices into num_heads x n x d_k

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        ## Randomly initialize the W_q, W_k and W_v matrices to d x d shape (we will break the Q, K and V to num_heads x d_k). Instead of num_heads x d x d_k, we initialize as d x d 
        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, d_model, bias=False)
        self.W_v = nn.Linear(d_model, d_model, bias=False)
        self.W_o = nn.Linear(d_model, d_model, bias=False)

        ## Initialize the dropout
        self.dropout = nn.Dropout(dropout)

        ## Initialize the W_q, W_k and W_v matrices as a normal distribution with mean 0 and std 0.02
        for module in [self.W_q, self.W_k, self.W_v, self.W_o]:
            nn.init.normal_(module.weight, mean=0.0, std=0.02)

    
    def forward(self, input_embeddings):
        ## input_embeddings have shape of n x d where n is number of tokens (window size) and d is embedding dimension
        n, d_model = input_embeddings.shape

        mask = torch.tril(torch.ones(n, n))

        ## Compute the Q, K and V matrices together before splitting them into heads. Instead of running a loop and creating individual Q, K and V matrices of shape n x d_k for each head, we store the Q, K and V matrices of all heads in 3 matrices for each of shape n x d, which will later be split into num_heads x n x d_k
        Q_all = self.W_q(input_embeddings) # Shape: n x d
        K_all = self.W_k(input_embeddings) # Shape: n x d
        V_all = self.W_v(input_embeddings) # Shape: n x d

        ## Change from shape n x d to n x num_heads x d_k by spliiting them into separate heads
        Q = Q_all.view(n, self.num_heads, self.d_k)
        K = K_all.view(n, self.num_heads, self.d_k)
        V = V_all.view(n, self.num_heads, self.d_k)

        Q = Q.transpose(0, 1) # Shape: num_heads x n x d_k
        K = K.transpose(0, 1) # Shape: num_heads x n x d_k
        V = V.transpose(0, 1) # Shae: num_heads x n x d_k

        # (num_heads, n, d_k) @ (num_heads, d_k, n) -> (num_heads, n, n)
        ## So for every head we have gotten the attention scores together in 1 matrix
        attention_scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        ## Apply masking to mask out scores of future tokens
        attention_scores = attention_scores.masked_fill(mask == 0, float('-inf'))

        attention_weights = F.softmax(attention_scores, dim=-1) # Shape: num_heads x n x n
        attention_weights = self.dropout(attention_weights)

        # (num_heads, n, n) @ (num_heads, n, d_k) -> (num_heads,n, d_k)
        ## For every head, we get the attention outputs in 1 matrix
        attention_output = torch.matmul(attention_weights, V)

        ## Now we want to do the reverse operations, converting from num_heads x n x d_k to n x d
        # (num_heads, n, d_k) -> (n, num_heads, d_k)
        attention_output = attention_output.transpose(0, 1).contiguous()
        # (n, num_heads, d_k) -> (n, d)
        attention_output = attention_output.view(n, self.d_model)

        ## So head outputs can interact with each other
        # n x d @ d x d -> n x d
        output = self.W_o(attention_output)

        return output, attention_weights


In [7]:
sentence = "The quick brown fox jumps over the lazy dog"

## Simple tokenization by splitting on spaces, ideally more complex tokenization would be used like BPE or WordPiece
sentence = sentence.split()
n = len(sentence)

print(f"Tokenized sentence: {sentence}")
print(f"Number of tokens: {len(sentence)}")

Tokenized sentence: ['The', 'quick', 'brown', 'fox', 'jumps', 'over', 'the', 'lazy', 'dog']
Number of tokens: 9


In [8]:
## Sample word embeddings, ideally these would be learned in the language modelling process or loaded from a pre-trained model like GloVe or Word2Vec

# Shape of embeddings: (n, d) where n is number of tokens and d is embedding dimension
embeddings = torch.tensor([
        [1.0, 0.5, 0.2, 0.8], 
        [0.3, 1.0, 0.7, 0.1],  
        [0.6, 0.2, 1.0, 0.4],  
        [0.9, 0.8, 0.3, 1.0],  
        [0.4, 0.6, 0.8, 0.2],  
        [0.7, 0.3, 0.5, 0.9],  
        [1.0, 0.5, 0.2, 0.8],  
        [0.2, 0.9, 0.4, 0.6],  
        [0.8, 0.4, 0.9, 0.3]  
    ])

print("Word embeddings (4-dimensional):")
for i, word in enumerate(sentence):
    print(f"  {word:4}: {embeddings[i]}")
print()

Word embeddings (4-dimensional):
  The : tensor([1.0000, 0.5000, 0.2000, 0.8000])
  quick: tensor([0.3000, 1.0000, 0.7000, 0.1000])
  brown: tensor([0.6000, 0.2000, 1.0000, 0.4000])
  fox : tensor([0.9000, 0.8000, 0.3000, 1.0000])
  jumps: tensor([0.4000, 0.6000, 0.8000, 0.2000])
  over: tensor([0.7000, 0.3000, 0.5000, 0.9000])
  the : tensor([1.0000, 0.5000, 0.2000, 0.8000])
  lazy: tensor([0.2000, 0.9000, 0.4000, 0.6000])
  dog : tensor([0.8000, 0.4000, 0.9000, 0.3000])



In [9]:
## Sample positional encodings, typically these would be generated using math functions or learned during training or RoPE

positional_embeddings = torch.tensor([
    [0.0, 1.0, 0.0, 1.0],  
    [0.1, 0.9, 0.1, 0.9],  
    [0.2, 0.8, 0.2, 0.8],  
    [0.3, 0.7, 0.3, 0.7],  
    [0.4, 0.6, 0.4, 0.6],  
    [0.5, 0.5, 0.5, 0.5],  
    [0.6, 0.4, 0.6, 0.4],  
    [0.7, 0.3, 0.7, 0.3],  
    [0.8, 0.2, 0.8, 0.2]   
])

print("Positional embeddings (Same dimesnions as word embeddings):")
for i, word in enumerate(sentence):
    print(f"  Pos {i} ({word}): {positional_embeddings[i]}")
print()

Positional embeddings (Same dimesnions as word embeddings):
  Pos 0 (The): tensor([0., 1., 0., 1.])
  Pos 1 (quick): tensor([0.1000, 0.9000, 0.1000, 0.9000])
  Pos 2 (brown): tensor([0.2000, 0.8000, 0.2000, 0.8000])
  Pos 3 (fox): tensor([0.3000, 0.7000, 0.3000, 0.7000])
  Pos 4 (jumps): tensor([0.4000, 0.6000, 0.4000, 0.6000])
  Pos 5 (over): tensor([0.5000, 0.5000, 0.5000, 0.5000])
  Pos 6 (the): tensor([0.6000, 0.4000, 0.6000, 0.4000])
  Pos 7 (lazy): tensor([0.7000, 0.3000, 0.7000, 0.3000])
  Pos 8 (dog): tensor([0.8000, 0.2000, 0.8000, 0.2000])



In [10]:
## The final input to the Attention block is the sum of the word embeddings and positional encodings

input_embeddings = embeddings + positional_embeddings

print("Input embeddings (word + positional):")
for i, word in enumerate(sentence):
    print(f"  {word:4}: {input_embeddings[i]}")
print()

Input embeddings (word + positional):
  The : tensor([1.0000, 1.5000, 0.2000, 1.8000])
  quick: tensor([0.4000, 1.9000, 0.8000, 1.0000])
  brown: tensor([0.8000, 1.0000, 1.2000, 1.2000])
  fox : tensor([1.2000, 1.5000, 0.6000, 1.7000])
  jumps: tensor([0.8000, 1.2000, 1.2000, 0.8000])
  over: tensor([1.2000, 0.8000, 1.0000, 1.4000])
  the : tensor([1.6000, 0.9000, 0.8000, 1.2000])
  lazy: tensor([0.9000, 1.2000, 1.1000, 0.9000])
  dog : tensor([1.6000, 0.6000, 1.7000, 0.5000])



In [11]:
d_model = 4      # embedding dimension of the tokens
num_heads = 2    # number of attention heads
d_k = d_model // num_heads # Dimension of the Q, K and V matrices for each head

In [20]:
attentionBlock = MultiHeadedAttention(d_model, num_heads, dropout=0.2)
print(attentionBlock)

MultiHeadedAttention(
  (W_q): Linear(in_features=4, out_features=4, bias=False)
  (W_k): Linear(in_features=4, out_features=4, bias=False)
  (W_v): Linear(in_features=4, out_features=4, bias=False)
  (W_o): Linear(in_features=4, out_features=4, bias=False)
  (dropout): Dropout(p=0.2, inplace=False)
)


In [14]:
with torch.no_grad():
    output, attention_weights = attentionBlock(input_embeddings)

In [15]:
print(output.shape)
print(output)

torch.Size([9, 4])
tensor([[ 0.0007, -0.0025,  0.0002, -0.0016],
        [ 0.0018, -0.0030,  0.0018, -0.0007],
        [ 0.0012, -0.0024,  0.0002, -0.0019],
        [ 0.0015, -0.0031,  0.0005, -0.0019],
        [ 0.0014, -0.0027,  0.0008, -0.0012],
        [ 0.0015, -0.0031,  0.0005, -0.0017],
        [ 0.0013, -0.0032,  0.0008, -0.0013],
        [ 0.0014, -0.0031,  0.0004, -0.0016],
        [ 0.0009, -0.0023, -0.0002, -0.0015]])


In [16]:
print(attention_weights.shape)
print(attention_weights)

torch.Size([2, 9, 9])
tensor([[[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000],
         [0.6249, 0.6251, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000],
         [0.0000, 0.4167, 0.4168, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000],
         [0.0000, 0.3125, 0.3126, 0.3125, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000],
         [0.2499, 0.2500, 0.2501, 0.2499, 0.2501, 0.0000, 0.0000, 0.0000,
          0.0000],
         [0.2082, 0.2083, 0.2084, 0.2082, 0.2084, 0.2084, 0.0000, 0.0000,
          0.0000],
         [0.1785, 0.1785, 0.1786, 0.1785, 0.0000, 0.1786, 0.0000, 0.0000,
          0.0000],
         [0.1562, 0.1562, 0.1563, 0.1562, 0.0000, 0.1563, 0.1562, 0.1563,
          0.0000],
         [0.0000, 0.1388, 0.1390, 0.0000, 0.1389, 0.1389, 0.1389, 0.1389,
          0.1390]],

        [[1.2500, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000],
         [0.6239, 0.0000, 0.0000, 0.0000, 0.00