In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
## Implemented using the Attention notebooks in the Attention directory

def selfAttention(input_embeddings, W_q, W_k, W_v, W_o):
    """Single-head causal (masked) scaled dot-product self-attention.

    Args:
        input_embeddings: Tensor of shape (n, d_model), or (..., n, d_model) for
            batched input; the last two dims are (tokens, features).
        W_q, W_k, W_v: Query/key/value projections of shape (d_model, d_k).
        W_o: Output projection of shape (d_k, d_model).

    Returns:
        softmax(mask(Q K^T / sqrt(d_k))) V W_o, with the same leading shape as
        ``input_embeddings`` and last dimension d_model. Token i only attends to
        tokens j <= i.
    """
    n = input_embeddings.shape[-2]
    d_k = W_q.shape[-1]

    Q = torch.matmul(input_embeddings, W_q)
    K = torch.matmul(input_embeddings, W_k)
    V = torch.matmul(input_embeddings, W_v)

    # transpose(-2, -1) rather than .T so that a leading batch dimension is handled
    # correctly (.T would reverse *all* dims of a 3-D tensor).
    attention_scores = torch.matmul(Q, K.transpose(-2, -1)) / d_k ** 0.5

    # Boolean lower-triangular mask created on the input's device: True where j <= i.
    causal_mask = torch.ones(n, n, dtype=torch.bool, device=input_embeddings.device).tril()
    # Future positions get -inf so softmax assigns them zero weight.
    masked_attention_scores = attention_scores.masked_fill(~causal_mask, float('-inf'))

    attention_weights = F.softmax(masked_attention_scores, dim=-1)
    output = torch.matmul(attention_weights, V)
    return torch.matmul(output, W_o)

In [3]:
## Implemented using the LayerNorm notebook in the LayerNorm directory
def residualPlusLayerNorm(attention_output, input_embeddings, gamma, beta, eps=1e-5):
    """Add a residual connection, then layer-normalise over the last dimension.

    Computes ``gamma * (x - mean) / sqrt(var + eps) + beta`` where
    ``x = attention_output + input_embeddings`` and ``var`` is the biased
    (population) variance along the last dim.

    Args:
        attention_output: Sub-layer output, same shape as ``input_embeddings``.
        input_embeddings: Sub-layer input (the residual branch).
        gamma, beta: Per-feature scale and shift, broadcast over the last dim.
        eps: Small constant added to the variance for numerical stability.

    Returns:
        Tensor with the same shape as the inputs.
    """
    summed = attention_output + input_embeddings

    mu = summed.mean(dim=-1, keepdim=True)                       # Shape (n, 1)
    sigma_sq = summed.var(dim=-1, keepdim=True, unbiased=False)  # Shape (n, 1)
    x_hat = (summed - mu) / (sigma_sq + eps).sqrt()              # Shape (n, d)

    return x_hat * gamma + beta

In [5]:
## Implemented using the FeedForwardNetwork notebook in the Transformer directory
def feedForwardNetwork(layernorm_output, W_ff1, b_ff1, W_ff2, b_ff2):
    """Position-wise two-layer MLP: ``ReLU(x @ W_ff1 + b_ff1) @ W_ff2 + b_ff2``.

    Args:
        layernorm_output: Input activations, last dim d_model.
        W_ff1, b_ff1: Expansion layer, (d_model, hidden_dim) and (hidden_dim,).
        W_ff2, b_ff2: Projection back, (hidden_dim, d_model) and (d_model,).

    Returns:
        Tensor with the same leading shape as the input and last dim d_model.
    """
    hidden = F.relu(layernorm_output @ W_ff1 + b_ff1)
    return hidden @ W_ff2 + b_ff2

In [6]:
raw_sentence = "The quick brown fox jumps over the lazy dog"

## Simple whitespace tokenization; a real model would use a subword tokenizer such as BPE or WordPiece.
## The raw text keeps its own name so `sentence` is always the token list used by later cells,
## instead of the same name switching from a string to a list.
sentence = raw_sentence.split()
n = len(sentence)  # sequence length (number of tokens)

print(f"Tokenized sentence: {sentence}")
print(f"Number of tokens: {n}")

Tokenized sentence: ['The', 'quick', 'brown', 'fox', 'jumps', 'over', 'the', 'lazy', 'dog']
Number of tokens: 9


In [7]:
## Toy vocabulary: special tokens, then the sentence's own tokens, then a handful of extra words.
## A real model would use a much larger vocabulary. dict.fromkeys drops repeated words
## (e.g. "the" appears in both word lists) while keeping first-occurrence order.
vocab = list(dict.fromkeys(
    ["<PAD>", "<START>", "<END>", "<UNK>"]
    + sentence
    + ["cat", "runs", "fast", "slowly", "and", "the", "a", "is", "are", "was", "were"]
))
vocab_size = len(vocab)

print("Vocabulary size:", vocab_size)
print("Vocabulary:", vocab)

Vocabulary size: 23
Vocabulary: ['<PAD>', '<START>', '<END>', '<UNK>', 'The', 'quick', 'brown', 'fox', 'jumps', 'over', 'the', 'lazy', 'dog', 'cat', 'runs', 'fast', 'slowly', 'and', 'a', 'is', 'are', 'was', 'were']


In [9]:
## Build word -> index and index -> word lookup tables so text can be represented as integers
word_to_idx = dict(zip(vocab, range(len(vocab))))
idx_to_word = dict(zip(word_to_idx.values(), word_to_idx.keys()))

print("Word to index mapping:", word_to_idx)
print("Index to word mapping:", idx_to_word)

Word to index mapping: {'<PAD>': 0, '<START>': 1, '<END>': 2, '<UNK>': 3, 'The': 4, 'quick': 5, 'brown': 6, 'fox': 7, 'jumps': 8, 'over': 9, 'the': 10, 'lazy': 11, 'dog': 12, 'cat': 13, 'runs': 14, 'fast': 15, 'slowly': 16, 'and': 17, 'a': 18, 'is': 19, 'are': 20, 'was': 21, 'were': 22}
Index to word mapping: {0: '<PAD>', 1: '<START>', 2: '<END>', 3: '<UNK>', 4: 'The', 5: 'quick', 6: 'brown', 7: 'fox', 8: 'jumps', 9: 'over', 10: 'the', 11: 'lazy', 12: 'dog', 13: 'cat', 14: 'runs', 15: 'fast', 16: 'slowly', 17: 'and', 18: 'a', 19: 'is', 20: 'are', 21: 'was', 22: 'were'}


In [10]:
## Convert the sentence to a list of vocabulary indices. Words missing from the vocabulary map
## to the <UNK> token instead of raising KeyError, so other sentences can be encoded too.
input_indices = [word_to_idx.get(word, word_to_idx["<UNK>"]) for word in sentence]

print(f"Input indices: {input_indices}")

Input indices: [4, 5, 6, 7, 8, 9, 10, 11, 12]


In [11]:
## Model dimensions: embedding size, attention (Q/K/V) size, and feed-forward hidden size
d_model, d_k, hidden_dim = 8, 6, 16

In [None]:
## Random token-embedding table with one d_model-dim row per vocabulary entry.
## In practice this is learned during training or loaded from a pre-trained model.
torch.manual_seed(42)  # seed the global RNG so the torch.randn calls that follow are reproducible when run in order
embedding_matrix = 0.3 * torch.randn(vocab_size, d_model)  # Shape (vocab_size, d_model)

print("Embedding matrix shape:", embedding_matrix.shape)
print("Embedding matrix:", embedding_matrix)


Embedding matrix shape: torch.Size([23, 8])
Embedding matrix: tensor([[ 5.7807e-01,  4.4619e-01,  2.7022e-01, -6.3166e-01,  2.0353e-01,
         -3.7036e-01, -1.2920e-02, -4.8140e-01],
        [-2.2564e-01,  4.9462e-01, -1.1774e-01, -4.2108e-01, -2.1836e-01,
         -1.6783e-01, -2.3065e-01,  2.2873e-01],
        [ 4.9270e-01, -4.7879e-02, -1.4922e-01,  1.3188e-01, -2.2744e-01,
          3.2350e-01,  2.4024e-01,  5.0419e-01],
        [ 3.8374e-01,  3.8893e-01,  1.8314e-01,  4.0042e-01, -6.9487e-02,
          1.2528e-02, -7.5473e-02,  2.5796e-01],
        [-4.1540e-01, -2.6137e-01, -6.7010e-02,  5.1521e-01,  9.5664e-02,
         -1.2736e-01,  9.1716e-02, -2.3238e-01],
        [-4.6727e-01,  2.9869e-01, -2.6394e-01, -1.8034e-01, -3.8225e-01,
          6.3684e-01, -3.7040e-01, -1.4637e-01],
        [-2.7415e-01, -1.9744e-01,  2.3407e-02,  1.5774e-01, -1.4640e-01,
          3.5741e-01, -2.4420e-01, -2.2080e-01],
        [-4.2097e-01,  1.0801e-02, -1.9043e-02,  2.0268e-01, -2.9342e-02,
   

In [13]:
## Gather one embedding row per token index (indexing with a list of ints selects rows in order)
input_embeddings = embedding_matrix[input_indices]  # Shape (n, d_model)

print("Input embeddings shape:", input_embeddings.shape)
print("Input embeddings:", input_embeddings)

Input embeddings shape: torch.Size([9, 8])
Input embeddings: tensor([[-0.4154, -0.2614, -0.0670,  0.5152,  0.0957, -0.1274,  0.0917, -0.2324],
        [-0.4673,  0.2987, -0.2639, -0.1803, -0.3822,  0.6368, -0.3704, -0.1464],
        [-0.2741, -0.1974,  0.0234,  0.1577, -0.1464,  0.3574, -0.2442, -0.2208],
        [-0.4210,  0.0108, -0.0190,  0.2027, -0.0293,  0.5534, -0.3554,  0.4151],
        [ 0.4335,  0.2569,  0.6654,  0.1569,  0.1040, -0.0592, -0.3164,  0.3834],
        [-0.0517,  0.1571,  0.0170,  0.1279,  0.1725, -0.1925, -0.6619, -0.2252],
        [ 0.0033, -0.1016, -0.4022, -0.1756,  0.1609,  0.1574,  0.3424,  0.0155],
        [ 0.2232, -0.1445, -0.3148,  0.1812, -0.5167, -0.2483,  0.4004,  0.1451],
        [-0.7529,  0.1464,  0.2354,  0.0086,  0.1922,  0.1750,  0.3201, -0.1350]])


In [14]:
## Sample positional embeddings: one random d_model-dim vector per token position.
## Ideally these would be learned during training (or use a fixed sinusoidal encoding).
positional_embeddings = torch.randn(n, d_model) * 0.2  # Shape (n, d_model)

## Rebuild from the token-embedding lookup instead of updating `input_embeddings` in place,
## so re-running this cell does not add positional information a second time.
input_embeddings = embedding_matrix[input_indices] + positional_embeddings  # Shape (n, d_model)

print(f"Input embeddings with positional encoding shape: {input_embeddings.shape}")
print(f"Input embeddings with positional encoding: {input_embeddings}")

Input embeddings with positional encoding shape: torch.Size([9, 8])
Input embeddings with positional encoding: tensor([[-0.4473, -0.3464,  0.1218,  0.4782,  0.3078, -0.0857, -0.0239, -0.1673],
        [-0.4149,  0.1467, -0.6732, -0.4862, -0.3013,  0.7632, -0.3079, -0.1531],
        [-0.0135, -0.0999,  0.2502,  0.0866, -0.0740,  0.7573, -0.1116, -0.0799],
        [-0.4167, -0.1551, -0.2352,  0.0459,  0.0721,  0.5698, -0.2666,  0.2703],
        [ 0.3413,  0.2441,  0.3921,  0.2229, -0.0925,  0.0012, -0.2806,  0.3575],
        [-0.3667,  0.6073,  0.2172,  0.4007,  0.2992, -0.1115, -0.5936, -0.2695],
        [ 0.0378,  0.1086, -0.4007, -0.1911,  0.2894,  0.2722,  0.4597,  0.0117],
        [-0.0527, -0.0193, -0.8318,  0.1764, -0.5411, -0.3977,  0.7423,  0.1566],
        [-0.5143,  0.5339,  0.3811,  0.2048,  0.2751,  0.4063,  0.3739, -0.1424]])


In [15]:
## Self-attention projections: W_q, W_k, W_v map d_model -> d_k; W_o maps attended values d_k -> d_model.
## The generator is consumed left to right, so the random draws happen in the order W_q, W_k, W_v.
W_q, W_k, W_v = (torch.randn(d_model, d_k) * 0.3 for _ in range(3))  # each Shape (d_model, d_k)
W_o = torch.randn(d_k, d_model) * 0.3  # Shape (d_k, d_model)

In [16]:
## LayerNorm scale (gamma) and shift (beta): set 1 follows self-attention, set 2 follows the
## feed-forward network. Ones/zeros make the affine step start as the identity.
gamma1, gamma2 = torch.ones(d_model), torch.ones(d_model)  # each Shape (d_model,)
beta1, beta2 = torch.zeros(d_model), torch.zeros(d_model)  # each Shape (d_model,)

In [17]:
## Feed-forward network parameters: expand d_model -> hidden_dim, then project back to d_model.
## Statement order fixes the sequence of random draws: W_ff1, b_ff1, W_ff2, b_ff2.
W_ff1 = 0.3 * torch.randn(d_model, hidden_dim)  # Shape (d_model, hidden_dim)
b_ff1 = 0.3 * torch.randn(hidden_dim)           # Shape (hidden_dim,)
W_ff2 = 0.3 * torch.randn(hidden_dim, d_model)  # Shape (hidden_dim, d_model)
b_ff2 = 0.3 * torch.randn(d_model)              # Shape (d_model,)

In [18]:
## Causal self-attention over the position-aware input embeddings
attention_output = selfAttention(input_embeddings, W_q, W_k, W_v, W_o)  # Shape (n, d_model)

print("Attention output shape:", attention_output.shape)
print("Attention output:", attention_output)

Attention output shape: torch.Size([9, 8])
Attention output: tensor([[ 0.0420,  0.1177, -0.1520,  0.1020, -0.0923, -0.3580, -0.2704,  0.3188],
        [-0.1931, -0.0459,  0.0348, -0.2076,  0.1689,  0.1606,  0.2770, -0.1112],
        [-0.1394, -0.0246,  0.0010, -0.1322,  0.0920,  0.0782,  0.1727, -0.0423],
        [-0.1905, -0.0089, -0.0098, -0.1450,  0.1244,  0.0854,  0.2088, -0.0361],
        [-0.1461, -0.0080, -0.0061, -0.1069,  0.1072,  0.0867,  0.1692, -0.0402],
        [-0.1287,  0.0028,  0.0051, -0.0920,  0.1210,  0.0837,  0.1603, -0.0400],
        [-0.1551,  0.0135,  0.0064, -0.0830,  0.1035,  0.0726,  0.1603, -0.0402],
        [-0.1783, -0.0072,  0.0260, -0.1135,  0.0913,  0.1007,  0.1822, -0.0851],
        [-0.2061,  0.0017,  0.0406, -0.1097,  0.0788,  0.0983,  0.2022, -0.0940]])


In [19]:
## First residual connection (attention output + attention input) followed by LayerNorm
layernorm_output1 = residualPlusLayerNorm(attention_output, input_embeddings, gamma1, beta1)

print("LayerNorm output shape:", layernorm_output1.shape)
print("LayerNorm output:", layernorm_output1)

LayerNorm output shape: torch.Size([9, 8])
LayerNorm output: tensor([[-1.0499, -0.5177,  0.0803,  1.9196,  0.8206, -1.1656, -0.7152,  0.6278],
        [-0.8836,  0.5394, -0.9445, -1.0558,  0.0713,  2.1916,  0.2750, -0.1935],
        [-0.7902, -0.6979,  0.5241, -0.4412, -0.2346,  2.4244, -0.0943, -0.6903],
        [-1.6947, -0.4349, -0.6654, -0.2505,  0.5893,  1.8928, -0.1331,  0.6964],
        [ 0.2632,  0.5326,  1.5188, -0.2581, -0.9253, -0.4430, -1.7550,  1.0667],
        [-1.3674,  1.4724,  0.4763,  0.6982,  0.9844, -0.1663, -1.2078, -0.8898],
        [-0.6106,  0.1186, -1.4543, -1.0884,  0.9432,  0.7969,  1.6348, -0.3402],
        [-0.2899,  0.1425, -1.5062,  0.3317, -0.7530, -0.4298,  2.1546,  0.3500],
        [-2.1284,  0.8038,  0.5379, -0.2246,  0.3798,  0.7315,  0.8985, -0.9985]])


In [20]:
## Position-wise feed-forward network applied to the normalised attention block output
feedforward_output = feedForwardNetwork(layernorm_output1, W_ff1, b_ff1, W_ff2, b_ff2)  # Shape (n, d_model)

print("Feed Forward Network output shape:", feedforward_output.shape)
print("Feed Forward Network output:", feedforward_output)

Feed Forward Network output shape: torch.Size([9, 8])
Feed Forward Network output: tensor([[ 0.2579,  0.8045,  0.5444,  1.7155,  0.7242,  0.9203, -0.6189, -0.5029],
        [-0.8781, -0.4182,  1.3479,  1.4541,  1.2107,  2.2080,  0.2971, -0.2043],
        [-0.1201, -0.5148,  1.7449,  1.2045,  0.1516,  3.6753,  0.6008, -0.0437],
        [ 0.4315, -0.3111,  1.3706,  2.3913,  0.7638,  2.6776,  0.1653, -0.3813],
        [-0.2345,  0.2486,  0.2702,  0.6519,  1.0916,  1.1160, -0.4263, -0.7526],
        [-1.2025,  0.6422, -0.0513,  0.6233,  1.1179,  0.4635, -0.4650, -0.2125],
        [-2.1081, -0.5484,  1.0487, -0.7822,  1.2553,  0.9405,  0.3068,  0.3993],
        [-0.8725,  0.0732,  1.6564,  0.5992,  1.4034,  0.4075, -0.3222,  0.1004],
        [-1.6970, -0.1806,  0.9394, -0.3444,  0.9858,  0.9045,  0.1255,  0.2424]])


In [21]:
## Second residual connection (FFN output + FFN input) followed by LayerNorm
layernorm_output2 = residualPlusLayerNorm(feedforward_output, layernorm_output1, gamma2, beta2)

print("Final LayerNorm output shape:", layernorm_output2.shape)
print("Final LayerNorm output:", layernorm_output2)

Final LayerNorm output shape: torch.Size([9, 8])
Final LayerNorm output: tensor([[-0.8806, -0.1341,  0.0998,  2.1828,  0.7364, -0.5023, -1.2557, -0.2462],
        [-1.4478, -0.3066, -0.1356, -0.1387,  0.3969,  2.2864, -0.0333, -0.6212],
        [-0.7774, -0.9119,  0.6368, -0.0329, -0.4094,  2.3409, -0.1472, -0.6989],
        [-1.2491, -0.9488, -0.1064,  0.7270,  0.2698,  2.1375, -0.4971, -0.3328],
        [-0.2065,  0.5101,  1.4698,  0.1411, -0.0755,  0.4070, -2.3111,  0.0652],
        [-1.6405,  1.2223,  0.1898,  0.7377,  1.2149,  0.1117, -1.0923, -0.7436],
        [-1.6472, -0.2923, -0.2780, -1.1451,  1.2635,  0.9905,  1.1114, -0.0029],
        [-1.9304, -0.2063, -0.2883,  0.6883,  0.3374, -0.5041,  1.8161,  0.0873],
        [-2.2997,  0.2920,  0.7896, -0.4025,  0.7245,  0.8821,  0.5255, -0.5115]])


In [None]:
## Language-model head: a linear layer projecting each d_model-dim hidden state to one score per vocabulary entry
W_lm_head = 0.3 * torch.randn(d_model, vocab_size)  # Shape (d_model, vocab_size)
b_lm_head = 0.1 * torch.randn(vocab_size)           # Shape (vocab_size,)

print("LM head weights shape:", W_lm_head.shape)
print("LM head bias shape:", b_lm_head.shape)

LM head weights shape: torch.Size([8, 23])
LM head bias shape: torch.Size([23])


In [23]:
## Project every position's final hidden state onto the vocabulary to get unnormalised logits
logits = layernorm_output2 @ W_lm_head + b_lm_head  # Shape (n, vocab_size)

print("Logits shape:", logits.shape)

Logits shape: torch.Size([9, 23])


In [24]:
## Softmax over the vocabulary axis: row i is the predicted distribution for the token that
## follows position i, given tokens 0..i (the causal mask hides later tokens).
probabilities = logits.softmax(dim=-1)  # Shape (n, vocab_size); each row sums to 1

print("Probabilities shape:", probabilities.shape)
print("Probabilities:", probabilities)

Probabilities shape: torch.Size([9, 23])
Probabilities: tensor([[0.0232, 0.0844, 0.0464, 0.0309, 0.0216, 0.0316, 0.0227, 0.0698, 0.0434,
         0.0279, 0.0440, 0.0525, 0.0262, 0.0781, 0.0685, 0.0717, 0.0209, 0.0803,
         0.0246, 0.0060, 0.0074, 0.0921, 0.0256],
        [0.0092, 0.0709, 0.0117, 0.2183, 0.0219, 0.0365, 0.0161, 0.0270, 0.0756,
         0.0544, 0.0258, 0.0056, 0.0184, 0.0131, 0.0419, 0.0680, 0.0134, 0.0044,
         0.0241, 0.0583, 0.0941, 0.0789, 0.0125],
        [0.0059, 0.0482, 0.0077, 0.0985, 0.0242, 0.0675, 0.0174, 0.0238, 0.1012,
         0.0410, 0.0290, 0.0123, 0.0230, 0.0370, 0.0214, 0.0758, 0.0206, 0.0055,
         0.0296, 0.0636, 0.0942, 0.1316, 0.0208],
        [0.0067, 0.0507, 0.0107, 0.2075, 0.0190, 0.0444, 0.0128, 0.0410, 0.0861,
         0.0528, 0.0252, 0.0107, 0.0270, 0.0242, 0.0446, 0.0858, 0.0193, 0.0084,
         0.0224, 0.0365, 0.0564, 0.0897, 0.0180],
        [0.0173, 0.0506, 0.0128, 0.0049, 0.0117, 0.0671, 0.0111, 0.0625, 0.0215,
         0.0567