In [14]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
## Causal (masked) single-head self-attention, implemented following the notebooks in the Attention directory

def selfAttention(input_embeddings, W_q, W_k, W_v, W_o, dropout_p=0.2, training=True):
    """Single-head causal self-attention followed by an output projection.

    Computes ``dropout(softmax(mask(Q @ K^T) / sqrt(d_k))) @ V @ W_o`` (with a
    second dropout on the projected output). The lower-triangular mask lets
    token i attend only to tokens 0..i.

    Args:
        input_embeddings: tensor of shape (..., n, d_model); any leading batch
            dimensions are carried through unchanged.
        W_q, W_k, W_v: projection matrices of shape (d_model, d_k).
        W_o: output projection of shape (d_k, d_model).
        dropout_p: dropout probability for the attention weights and for the
            projected output (default 0.2, matching the previous behaviour).
        training: if False, dropout is disabled and the result is deterministic
            (use for evaluation / inference).

    Returns:
        Tensor of shape (..., n, d_model).
    """
    n = input_embeddings.shape[-2]
    d_k = W_q.shape[-1]

    Q = input_embeddings @ W_q  # (..., n, d_k)
    K = input_embeddings @ W_k  # (..., n, d_k)
    V = input_embeddings @ W_v  # (..., n, d_k)

    # Scaled dot-product scores; transpose only the last two dims so batched input works.
    attention_scores = (Q @ K.transpose(-2, -1)) / (d_k ** 0.5)

    # Causal mask created on the input's device (a CPU-only mask breaks GPU inputs).
    causal_mask = torch.tril(torch.ones(n, n, device=input_embeddings.device)).to(torch.bool)
    attention_scores = attention_scores.masked_fill(~causal_mask, float('-inf'))

    attention_weights = F.softmax(attention_scores, dim=-1)
    # F.dropout honours `training`; a freshly built nn.Dropout is always in training mode,
    # so it could never be switched off.
    attention_weights = F.dropout(attention_weights, p=dropout_p, training=training)

    output = attention_weights @ V  # (..., n, d_k)
    final_output = output @ W_o     # (..., n, d_model)
    return F.dropout(final_output, p=dropout_p, training=training)

In [3]:
## Whitespace tokenization keeps the example readable; a real model would use a
## subword tokenizer such as BPE or WordPiece.
sentence = "The quick brown fox jumps over the lazy dog".split()
n = len(sentence)  # number of tokens in the sequence

print(f"Tokenized sentence: {sentence}")
print(f"Number of tokens: {len(sentence)}")

Tokenized sentence: ['The', 'quick', 'brown', 'fox', 'jumps', 'over', 'the', 'lazy', 'dog']
Number of tokens: 9


In [16]:
## Toy word embeddings. In practice these are learned with the language model or loaded
## from a pre-trained model such as GloVe or Word2Vec.
## Note that the rows for "The" and "the" are identical here, so only the positional
## encoding distinguishes those two tokens.

# Shape of embeddings: (n, d) where n is number of tokens and d is embedding dimension
embeddings = torch.tensor([
        [1.0, 0.5, 0.2, 0.8], 
        [0.3, 1.0, 0.7, 0.1],  
        [0.6, 0.2, 1.0, 0.4],  
        [0.9, 0.8, 0.3, 1.0],  
        [0.4, 0.6, 0.8, 0.2],  
        [0.7, 0.3, 0.5, 0.9],  
        [1.0, 0.5, 0.2, 0.8],  
        [0.2, 0.9, 0.4, 0.6],  
        [0.8, 0.4, 0.9, 0.3]  
    ])
assert embeddings.shape[0] == len(sentence), "need exactly one embedding row per token"

print(f"Word embeddings ({embeddings.shape[1]}-dimensional):")
# Pad to the longest token so the tensors line up (a fixed width of 4 misaligns "quick"/"jumps").
token_width = max(map(len, sentence), default=0)
for i, word in enumerate(sentence):
    print(f"  {word:{token_width}}: {embeddings[i]}")
print()

Word embeddings (4-dimensional):
  The : tensor([1.0000, 0.5000, 0.2000, 0.8000])
  quick: tensor([0.3000, 1.0000, 0.7000, 0.1000])
  brown: tensor([0.6000, 0.2000, 1.0000, 0.4000])
  fox : tensor([0.9000, 0.8000, 0.3000, 1.0000])
  jumps: tensor([0.4000, 0.6000, 0.8000, 0.2000])
  over: tensor([0.7000, 0.3000, 0.5000, 0.9000])
  the : tensor([1.0000, 0.5000, 0.2000, 0.8000])
  lazy: tensor([0.2000, 0.9000, 0.4000, 0.6000])
  dog : tensor([0.8000, 0.4000, 0.9000, 0.3000])



In [5]:
## Toy positional encodings. Real models use fixed sinusoidal encodings, learned
## position embeddings, or rotary position embeddings (RoPE).

# Shape: (n, d), identical to the word embeddings so the two can be added element-wise
positional_embeddings = torch.tensor([
    [0.0, 1.0, 0.0, 1.0],  
    [0.1, 0.9, 0.1, 0.9],  
    [0.2, 0.8, 0.2, 0.8],  
    [0.3, 0.7, 0.3, 0.7],  
    [0.4, 0.6, 0.4, 0.6],  
    [0.5, 0.5, 0.5, 0.5],  
    [0.6, 0.4, 0.6, 0.4],  
    [0.7, 0.3, 0.7, 0.3],  
    [0.8, 0.2, 0.8, 0.2]   
])
assert positional_embeddings.shape == embeddings.shape, "positional encodings must match the word-embedding shape"

print("Positional embeddings (same dimensions as word embeddings):")
for i, word in enumerate(sentence):
    print(f"  Pos {i} ({word}): {positional_embeddings[i]}")
print()

Positional embeddings (Same dimesnions as word embeddings):
  Pos 0 (The): tensor([0., 1., 0., 1.])
  Pos 1 (quick): tensor([0.1000, 0.9000, 0.1000, 0.9000])
  Pos 2 (brown): tensor([0.2000, 0.8000, 0.2000, 0.8000])
  Pos 3 (fox): tensor([0.3000, 0.7000, 0.3000, 0.7000])
  Pos 4 (jumps): tensor([0.4000, 0.6000, 0.4000, 0.6000])
  Pos 5 (over): tensor([0.5000, 0.5000, 0.5000, 0.5000])
  Pos 6 (the): tensor([0.6000, 0.4000, 0.6000, 0.4000])
  Pos 7 (lazy): tensor([0.7000, 0.3000, 0.7000, 0.3000])
  Pos 8 (dog): tensor([0.8000, 0.2000, 0.8000, 0.2000])



In [6]:
## The input to the attention block is the sum of the word embeddings and positional encodings.
## Dropout is then applied to this sum (as in the original Transformer): each entry is zeroed
## with probability p and the surviving entries are scaled by 1 / (1 - p).

input_embeddings = embeddings + positional_embeddings
dropout = nn.Dropout(p=0.2)
torch.manual_seed(42)  # the dropout mask is random; seed it so re-runs reproduce the same input
input_embeddings = dropout(input_embeddings)

print("Input embeddings (word + positional, after dropout):")
# Pad to the longest token so the tensors line up.
token_width = max(map(len, sentence), default=0)
for i, word in enumerate(sentence):
    print(f"  {word:{token_width}}: {input_embeddings[i]}")
print()

Input embeddings (word + positional):
  The : tensor([0.0000, 1.8750, 0.2500, 2.2500])
  quick: tensor([0.5000, 2.3750, 1.0000, 1.2500])
  brown: tensor([1.0000, 1.2500, 1.5000, 1.5000])
  fox : tensor([1.5000, 1.8750, 0.7500, 2.1250])
  jumps: tensor([1.0000, 1.5000, 1.5000, 0.0000])
  over: tensor([1.5000, 0.0000, 1.2500, 1.7500])
  the : tensor([2.0000, 1.1250, 1.0000, 1.5000])
  lazy: tensor([1.1250, 1.5000, 1.3750, 1.1250])
  dog : tensor([2.0000, 0.7500, 2.1250, 0.6250])



In [7]:
d_k = 3  # query/key/value size; usually smaller than d_model (e.g. d_model / num_heads in multi-head attention)
d_model = input_embeddings.shape[-1]  # embedding dimension d (last axis of the (n, d) input)

In [8]:
## Projection weights: W_q, W_k and W_v have shape (d_model, d_k) and map each token embedding
## to its query / key / value vector; W_o has shape (d_k, d_model) and projects the attention
## output back to the embedding dimension. Scaled standard-normal values stand in for learned weights.
torch.manual_seed(42)  # For reproducible results

# Drawn in the order W_q, W_k, W_v, then W_o, so the seeded values are unchanged.
W_q, W_k, W_v = (torch.randn(d_model, d_k, dtype=torch.float32) * 0.3 for _ in range(3))
W_o = torch.randn(d_k, d_model, dtype=torch.float32) * 0.3

for header, weight in (
    (f"W_q (Query weights) shape: {W_q.shape}", W_q),
    (f"\nW_k (Key weights) shape: {W_k.shape}", W_k),
    (f"\nW_v (Value weights) shape: {W_v.shape}", W_v),
):
    print(header)
    print(weight)
print()

print(f"W_o (Output projection weights) shape: {W_o.shape}")
print(W_o)

W_q (Query weights) shape: torch.Size([4, 3])
tensor([[ 0.1010,  0.0386,  0.0703],
        [ 0.0691, -0.3369, -0.0559],
        [ 0.6625, -0.1914,  0.1385],
        [ 0.0802,  0.1605,  0.2428]])

W_k (Key weights) shape: torch.Size([4, 3])
tensor([[ 0.3331, -0.5069, -0.2967],
        [ 0.2874,  0.3966,  0.2452],
        [-0.2298, -0.2252,  0.4058],
        [ 0.2059, -0.0983,  0.2385]])

W_v (Value weights) shape: torch.Size([4, 3])
tensor([[ 0.0845,  0.0168,  0.1568],
        [-0.0715, -0.0150,  0.1579],
        [-0.0025,  0.2187,  0.0399],
        [ 0.2592, -0.3047, -0.2666]])

W_o (Output projection weights) shape: torch.Size([3, 4])
tensor([[ 0.0449, -0.0627, -0.1161,  0.2974],
        [ 0.1404, -0.0615, -0.2223,  0.1086],
        [ 0.5760, -0.0676, -0.1025,  0.0912]])


In [9]:
## Run the causal self-attention layer over the whole input sequence
attention_output = selfAttention(
    input_embeddings=input_embeddings,
    W_q=W_q,
    W_k=W_k,
    W_v=W_v,
    W_o=W_o,
)

print("\nAttention output shape:", attention_output.shape)
print(attention_output)


Attention output shape: torch.Size([9, 4])
tensor([[-3.7750e-01,  0.0000e+00,  0.0000e+00,  5.4733e-02],
        [-1.2751e-01,  1.5926e-02,  0.0000e+00,  6.7920e-02],
        [-1.2334e-01,  8.3253e-03,  5.6893e-02,  7.0013e-02],
        [-8.0065e-02, -3.1087e-04,  6.1870e-02,  1.1917e-01],
        [ 5.2497e-02, -1.8745e-02, -1.0964e-02,  9.9666e-02],
        [-5.4889e-02, -5.5895e-03,  3.0832e-02,  1.0706e-01],
        [ 4.8455e-02, -2.3002e-02, -1.1746e-02,  1.3226e-01],
        [ 1.6850e-02, -1.3240e-02,  1.1290e-02,  1.0419e-01],
        [ 0.0000e+00, -2.0309e-02,  1.1206e-03,  0.0000e+00]])


In [10]:
## Residual connection: add the attention block's input (the embeddings) back onto its output
residual_output = torch.add(attention_output, input_embeddings)

print("\nResidual output shape:", residual_output.shape)
print(residual_output)


Residual output shape: torch.Size([9, 4])
tensor([[-0.3775,  1.8750,  0.2500,  2.3047],
        [ 0.3725,  2.3909,  1.0000,  1.3179],
        [ 0.8767,  1.2583,  1.5569,  1.5700],
        [ 1.4199,  1.8747,  0.8119,  2.2442],
        [ 1.0525,  1.4813,  1.4890,  0.0997],
        [ 1.4451, -0.0056,  1.2808,  1.8571],
        [ 2.0485,  1.1020,  0.9883,  1.6323],
        [ 1.1418,  1.4868,  1.3863,  1.2292],
        [ 2.0000,  0.7297,  2.1261,  0.6250]])


In [11]:
## Learnable LayerNorm parameters, one value per embedding dimension (shape (d_model,)):
## gamma scales and beta shifts the normalized values. Starting from gamma = 1 and
## beta = 0 makes the affine step an identity map.
gamma, beta = torch.ones(d_model), torch.zeros(d_model)

for label, param in (
    ("\nGamma (scaling factor) shape:", gamma),
    ("\nBeta (shifting factor) shape:", beta),
):
    print(label, param.shape)
    print(param)


Gamma (scaling factor) shape: torch.Size([4])
tensor([1., 1., 1., 1.])

Beta (shifting factor) shape: torch.Size([4])
tensor([0., 0., 0., 0.])


In [12]:
## Layer normalization applied after the residual add (Post-LN).
## Many recent models normalize before each sub-layer instead (Pre-LN) for more stable training.
def layerNorm(residual_output, gamma, beta, eps=1e-5):
    """Normalize every row over its last dimension, then apply a per-dimension affine map.

    Each token (row) is shifted to zero mean and divided by the square root of its
    population variance plus ``eps``. The result is then scaled by ``gamma`` and
    shifted by ``beta``. The shapes of the intermediate statistics are printed for
    illustration.

    Args:
        residual_output: tensor of shape (n, d) (any tensor whose last axis is the
            embedding dimension works).
        gamma: per-dimension scale, broadcast across rows (shape (d,)).
        beta: per-dimension shift, broadcast across rows (shape (d,)).
        eps: added to the variance before the square root for numerical stability.

    Returns:
        Tensor with the same shape as ``residual_output``.
    """
    # Per-token statistics over the embedding axis; keepdim leaves a trailing size-1
    # axis so they broadcast back across each row.
    means = residual_output.mean(dim=-1, keepdim=True)
    print(f"Means shape: {means.shape}")
    variances = residual_output.var(dim=-1, keepdim=True, unbiased=False)  # population variance
    print(f"Variances shape: {variances.shape}")

    centered = residual_output - means
    std = (variances + eps).sqrt()
    normalized = centered / std  # zero mean, unit variance per row
    print(f"Normalized shape: {normalized.shape}")

    # gamma and beta act column-wise: one scale/shift per embedding dimension.
    return normalized * gamma + beta

In [13]:
## Post-LN: normalize the residual stream to produce this block's final output
final_output = layerNorm(residual_output, gamma=gamma, beta=beta)

print("\nFinal output shape:", final_output.shape)
print(final_output)

Means shape: torch.Size([9, 1])
Variances shape: torch.Size([9, 1])
Normalized shape: torch.Size([9, 4])

Final output shape: torch.Size([9, 4])
tensor([[-1.2529,  0.7766, -0.6875,  1.1638],
        [-1.2283,  1.5330, -0.3698,  0.0651],
        [-1.5540, -0.2024,  0.8550,  0.9014],
        [-0.3137,  0.5368, -1.4510,  1.2279],
        [ 0.0387,  0.7965,  0.8103, -1.6454],
        [ 0.4319, -1.6515,  0.1960,  1.0235],
        [ 1.4223, -0.8001, -1.0672,  0.4450],
        [-1.2618,  1.3107,  0.5614, -0.6103],
        [ 0.9058, -0.9212,  1.0872, -1.0718]])
