# 🚀 Motivation
This notebook is based on [Andrej Karpathy](https://karpathy.ai/)'s transformer tutorial and provides a step-by-step demonstration of the pieces needed to understand the transformer implementation.

The original tutorial can be found [here](https://www.youtube.com/watch?v=kCc8FmEb1nY). For anyone interested in Transformers and LLMs, I highly recommend his full [series](https://www.youtube.com/playlist?list=PLAqhIrjkxbuWI23v9cThsA9GvCAUhRvKZ).

# 🧱 Building Blocks

The goal of this notebook is to put together a **character-based decoder-only transformer**, step by step, tailored to text prediction. The following areas will be covered:
- How do I manipulate the data?
- Text prediction first iteration: bigram
- Implementing attention

After understanding these components, we'll proceed to a decoder-only transformer based on GPT-2's architecture.

## 📘 How do I manipulate the data?
- Choosing our corpus
- Tokenization (vocabulary definition)
- Dataset preparation
- Batch generation

### 🧰 Requirements

In [None]:
!pip install tiktoken



In [None]:
# retrieving the dataset we are going to use
!wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt

--2025-08-08 21:12:42--  https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.109.133, 185.199.108.133, 185.199.111.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.109.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1115394 (1.1M) [text/plain]
Saving to: ‘input.txt’


2025-08-08 21:12:42 (28.5 MB/s) - ‘input.txt’ saved [1115394/1115394]



### 📚 Choosing Our Corpus

In [None]:
# read it in to inspect it
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

In [None]:
# let's look at the first 1000 characters
print(text[:1000])

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you know Caius Marcius is chief enemy to the people.

All:
We know't, we know't.

First Citizen:
Let us kill him, and we'll have corn at our own price.
Is't a verdict?

All:
No more talking on't; let it be done: away, away!

Second Citizen:
One word, good citizens.

First Citizen:
We are accounted poor citizens, the patricians good.
What authority surfeits on would relieve us: if they
would yield us but the superfluity, while it were
wholesome, we might guess they relieved us humanely;
but they think we are too dear: the leanness that
afflicts us, the object of our misery, is as an
inventory to particularise their abundance; our
sufferance is a gain to them Let us revenge this with
our pikes, ere we become rakes: for the gods know I
speak this in hunger for bread, not in thirst for revenge.



### 🪙 Tokenization
There needs to be a mechanism that maps elements of the vocabulary to a numerical representation. This implementation can be as simple or complex as desired.

In [None]:
# all the unique characters that occur in this text
vocabulary = sorted(list(set(text))) # this could be letters, chunks, words, etc...
vocab_size = len(vocabulary)
print(vocab_size)
print(''.join(vocabulary))

65

 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz


In [None]:
# create a mapping from characters to integers
stoi = { ch:i for i,ch in enumerate(vocabulary) }
itos = { i:ch for i,ch in enumerate(vocabulary) }
encode = lambda s: [stoi[c] for c in s] # encoder: take a string, output a list of integers
decode = lambda l: ''.join([itos[i] for i in l]) # decoder: take a list of integers, output a string

print(encode("welcome to advanced DL topics!"))
print(decode(encode("welcome to advanced DL topics!")))

[61, 43, 50, 41, 53, 51, 43, 1, 58, 53, 1, 39, 42, 60, 39, 52, 41, 43, 42, 1, 16, 24, 1, 58, 53, 54, 47, 41, 57, 2]
welcome to advanced DL topics!


Is there any difference between having a larger or a smaller vocabulary❓

In [None]:
import tiktoken

openai_encoder = tiktoken.get_encoding("gpt2")
print(f"vocab_size for OpenAI's GPT-2: {openai_encoder.n_vocab}")
print(openai_encoder.encode("welcome to advanced DL topics!"))

vocab_size for OpenAI's GPT-2: 50257
[86, 9571, 284, 6190, 23641, 10233, 0]
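
One way to reason about the question above is to put the two printed results side by side: a larger vocabulary encodes the same sentence in fewer tokens, but needs a much larger embedding table. The sketch below only does arithmetic on the numbers printed above; `embedding_params` is a hypothetical helper, and `n_embd = 64` is borrowed from the transformer hyperparameters used later in this notebook.

```python
# Numbers copied from the cell outputs above for "welcome to advanced DL topics!"
char_vocab_size, char_num_tokens = 65, 30
gpt2_vocab_size, gpt2_num_tokens = 50257, 7

n_embd = 64  # embedding width used for the transformer later in this notebook

def embedding_params(vocab_size, n_embd):
    """Number of weights in a (vocab_size, n_embd) embedding table."""
    return vocab_size * n_embd

char_params = embedding_params(char_vocab_size, n_embd)
gpt2_params = embedding_params(gpt2_vocab_size, n_embd)

print(f"char-level: {char_num_tokens} tokens, {char_params} embedding weights")
print(f"GPT-2 BPE:  {gpt2_num_tokens} tokens, {gpt2_params} embedding weights")
```

With a fixed context length, shorter token sequences mean the model sees more text per window, which is part of what the larger vocabulary buys.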


### 🍲 Dataset Preparation
The dataset, which is a sequence of characters, needs to be encoded and partitioned. There also needs to be a train-validation split.

In [None]:
import torch

In [None]:
# tokenization of the dataset
encoded_dataset = torch.tensor(encode(text), dtype=torch.long)

print(f'encoded dataset shape: {encoded_dataset.shape}')
print(f"let's look at the first 10 tokens:\n{encoded_dataset[:10]}\nwhich is:\n'{text[:10]}'")

encoded dataset shape: torch.Size([1115394])
let's look at the first 10 tokens:
tensor([18, 47, 56, 57, 58,  1, 15, 47, 58, 47])
which is:
'First Citi'


Why aren't we randomizing what we choose as train or validation❓

In [None]:
partition_idx = int(0.9*len(encoded_dataset)) # 90% train, 10% validation
train_data = encoded_dataset[:partition_idx]
val_data = encoded_dataset[partition_idx:]

### 🔄 Batch Generation
We create a single batch by using, for each token, the token that follows it as the ground truth. Because the labels come from the data itself, this is considered self-supervised learning.

In [None]:
def get_batch(data, context_length, batch_size):
    """
    This method considers context_length+1 tokens to prepare examples
    that will be later used to train a model.
    """
    # generate a small batch of data of inputs x and targets y
    ix = torch.randint(len(data) - context_length, (batch_size,)) # random offsets in the data
    x = torch.stack([data[i:i+context_length] for i in ix])
    y = torch.stack([data[i+1:i+context_length+1] for i in ix])
    return x, y

In [None]:
context_length = 8 # sequence size, this determines the window we send to the model
batch_size = 1 # this distinguishes between online/mini-batch/batch training

xb, yb = get_batch(
    data=train_data,
    context_length=context_length,
    batch_size=batch_size)

print('inputs:')
print(xb)
print('targets:')
print(yb)

inputs:
tensor([[58, 63,  8,  0,  0, 24, 53, 56]])
targets:
tensor([[63,  8,  0,  0, 24, 53, 56, 42]])
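
A single row of `x` and `y` actually packs `context_length` training examples, one per position. The sketch below unrolls the batch printed above using plain lists (the values are copied from that output; the variable names are illustrative).

```python
# Values copied from the printed batch above (first and only row)
x = [58, 63, 8, 0, 0, 24, 53, 56]
y = [63, 8, 0, 0, 24, 53, 56, 42]

examples = []
for t in range(len(x)):
    context = x[: t + 1]  # everything seen up to and including position t
    target = y[t]         # the token that follows that context
    examples.append((context, target))
    print(f"when input is {context} the target is {target}")
```

A bigram model only looks at the last token of each context, while the transformer built later can use the whole context.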


## 💬 Text prediction first iteration: bigram
One of the simplest model families is the N-gram family. An N-gram model uses the previous N-1 tokens as context to predict the next token, so a bigram model uses a single token to predict the one that follows. Further reading in this [Medium article](https://medium.com/@roshmitadey/understanding-language-modeling-from-n-grams-to-transformer-based-neural-models-d2bdf1532c6d).

In [None]:
import torch
import torch.nn as nn
from torch.nn import functional as F

class BigramLanguageModel(nn.Module):
    """

    This is a simple bigram model where the next token is predicted
    from a lookup table. Let's imagine our vocabulary is
    [a, b, c, d]; then lookup_table could be

    ---- a ---- b ---- c ---- d ----
    a   1.4    9.2    -5.6    -2.1
    b   2.3    1.1     3.4    0.1
    c   2.9    1.2     0.5    0.3
    d   8.1    0.2     1.3    5.4

    Each row belongs to the current token and each element is the logit
    (unnormalized score) of a candidate next token.
    Logits can be any real number; softmax over a row turns them into
    a probability distribution. In the binary case a logit is the ln of
    the odds and is related to the probability through the sigmoid:
    https://stats.stackexchange.com/questions/52825/what-does-the-logit-value-actually-mean
    """

    def __init__(self, vocab_size):
        super().__init__()

        # wrapper for a look-up table: https://pytorch.org/docs/stable/generated/torch.nn.Embedding.html
        # each token directly reads off the logits for the next token from a lookup table
        self.lookup_table = nn.Embedding(vocab_size, vocab_size)

    def generate(self, idx, max_new_tokens):
        """
        This method generates new tokens based on the previous tokens.

        idx:
          (batch_size, context_length) tensor with the indexes for the lookup table
        max_new_tokens:
          int, the number of new tokens to be generated
        """
        for _ in range(max_new_tokens):
            # get the logits
            logits, _ = self(idx) # this applies a forward pass
            # focus only on the last element of the context
            logits = logits[:, -1, :] # shape (batch_size, vocab_size)
            # apply softmax to get probabilities
            probs = F.softmax(logits, dim=-1) # shape (batch_size, vocab_size)

            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1) # shape (batch_size, 1)
            # append sampled index to the running sequence
            idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)

        return idx

    def forward(self, idx, targets=None):
        """
        This method performs a forward pass that returns the logits and,
        when targets are provided (i.e. during training), the loss.

        idx:
          (batch_size, context_length) tensor with the indexes for the lookup table
        targets:
          (batch_size, context_length) tensor with expected value for each element in idx
        """

        # shape is (batch_size, context_length, vocab_size)
        logits = self.lookup_table(idx)

        if targets is None:
            loss = None
        else:
            batch_size, context_length, vocab_size = logits.shape
            # the loss function expects 2D array with the vocab_size to be the second dimension
            logits = logits.view(batch_size*context_length, vocab_size) # (batch*context) x vocab dimension
            # the loss function expects 1D array
            targets = targets.view(batch_size*context_length) # (batch*context) dimension

            # article explaining cross entropy applied to the following task
            # https://marinafuster.medium.com/cross-entropy-loss-for-next-token-prediction-83c684fa26d5
            loss = F.cross_entropy(logits, targets) # target as the "desired class"

        return logits, loss

### 🎲 Playing with the untrained model

In [None]:
torch.manual_seed(1337)

context_length = 1
batch_size = 1

xb, yb = get_batch(
    data=train_data,
    context_length=context_length,
    batch_size=batch_size)

bigram = BigramLanguageModel(vocab_size)
print(f'using inputs: {xb}, targets: {yb}')
logits, loss = bigram(xb, yb)
probabilities = F.softmax(logits[0], dim=-1)
print(f'\nprobability distribution for first token:\n{probabilities}\n')
print(f'probability for target token {yb[0]}: {probabilities[yb[0]]}\n')
print(f'loss {loss}')

using inputs: tensor([[41]]), targets: tensor([[46]])

probability distribution for first token:
tensor([0.0037, 0.0118, 0.0044, 0.0038, 0.0172, 0.0046, 0.0097, 0.0140, 0.0086,
        0.0018, 0.0554, 0.0024, 0.0073, 0.0118, 0.0009, 0.0061, 0.0085, 0.0073,
        0.0114, 0.0169, 0.0175, 0.0016, 0.0112, 0.0797, 0.0144, 0.0090, 0.0590,
        0.0542, 0.0063, 0.0098, 0.0318, 0.0035, 0.0325, 0.0046, 0.0100, 0.0023,
        0.0069, 0.0013, 0.0093, 0.0105, 0.0009, 0.0137, 0.0062, 0.0249, 0.0052,
        0.0023, 0.0076, 0.0067, 0.0272, 0.0207, 0.0373, 0.0136, 0.0203, 0.0089,
        0.0555, 0.0069, 0.0133, 0.0150, 0.0100, 0.0034, 0.0048, 0.0168, 0.0522,
        0.0184, 0.0252], grad_fn=<SoftmaxBackward0>)

probability for target token tensor([46]): tensor([0.0076], grad_fn=<IndexBackward0>)

loss 4.873363494873047
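
The loss printed above can be checked by hand: for a single example, cross entropy is the negative log of the probability assigned to the target token. The sketch below uses the rounded probability from the output above, so it only matches the printed loss approximately, and compares it with the loss of a model that spreads probability evenly over the 65 characters.

```python
import math

vocab_size = 65
p_target = 0.0076  # probability given to the target token (rounded value printed above)

loss_single = -math.log(p_target)         # cross entropy for this one example
loss_uniform = -math.log(1 / vocab_size)  # loss of a perfectly uniform prediction

print(f"-ln({p_target}) = {loss_single:.3f}")
print(f"ln({vocab_size}) = {loss_uniform:.3f}")
```

The untrained bigram does worse than uniform on this example simply because its random initial logits happened to give the target a below-average probability.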


Is the model deterministic or stochastic❓

In [None]:
# generating new tokens with the untrained model

# batch of size 1 with context length 1, holding token 24 ('L') as the starting context
encoded_context = torch.full((1, 1), 24)
print(f"encoded: {encoded_context[0]} and decoded: {decode(encoded_context[0].tolist())}.\n")

encoded: tensor([24]) and decoded: L.



In [None]:
max_new_tokens = 1000
generated_sequences = bigram.generate(encoded_context, max_new_tokens=max_new_tokens)
for sequence_index in range(1): # it is batch of size 1
    print(decode(generated_sequences[sequence_index].tolist()))

LL?DP-QW3rErLHjMDVIcLVy'!IHDTHdhs Yv&ix$,3sDYZwLEPS'pweuqOzsZuAT:F-QTkeMk xZAQ$FCLg!iW3.O!tDGIA Ysq3pdc-fKnneeJydQ,'rF!QSOk!:vwWSpwWxbjP
DmQgBKnjtfxr$Z-JI$WSaJq'yE-y-stSqyweZfY && qDmluBv&x$LptM!z:UKaMUM;v,pFZQs'Zalk'u!tKqgoSXujG'LJ $$lE-a cKsDVRHwpvOafSo'VNGINFMgy'hZCWhc.w.h!YP 
aGxID!plxw?pxe?xsUHHVxG&kotWujosI.JFWVzqygBNw iEQa UEavePOIY LHwJMH?zdRR;yhCVjwW.RRvMjt-fpckstn pa bKn e;?xq.PLJo!MkH?DNYJpIISF!TueKXG.RUhnsDmgRGFrSUIM!Sp!?BeMU!y
,MobQzOP'ybNWCbNSpL?ZxiODDLp!!yEFmY'ZOqkWURplxyRBSOVbGdMjWS-AtWvevUndOmmf!QaOosAWVYwaeohz-
ogVUxJ;I&nSq,
omxpYgQjH?xyj'aM'SoBhrBfQiNY.hFst!'xqWRR:BxJ;hXKcKL$FU&V.bESfHn b;:b
vosASkcxI'dI&gydEGWjfAbNJSTWBSUb-TKi-smYHeYcjlgi
D;kyBjQICE y
'LfkGa FvxIAThhzXTKDF evqohyXDLp-kywOaacKLOJkHG3el?leX
IUthrK-kNB;:BZflJNcytT&wwBNfMfwMNcNB,abFUFAyt,QFyMU:BGiAFD.hLT&UP'?$,asAjtC,RI$PO,htMJsACJgiafYkyZ?NeD'aHU?XG!pwJI-P
kjwOPlfwRRUuArHnohiviu.hFEhyM?tZW'PB?Nz;3ef&;&tCgI$yks'yF!.e;3rfYw3,oBjtfSUM;;EaGC$Fkybn?ZQVi-KtMXD.Q iOK;mZInohhYEXbfEOSUuLHC.QVNp'&rSOurkfciCJIfpy
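
On the question of determinism: the forward pass is deterministic for fixed weights, but `generate` draws the next token with `torch.multinomial`, so repeated calls give different text. The sketch below illustrates the difference with plain Python and made-up logits for a 3-token vocabulary; `softmax`, `sampled` and `greedy` are illustrative names, and greedy (argmax) decoding is shown only for contrast, it is not what the notebook does.

```python
import math
import random

def softmax(logits):
    """Turn a list of logits into probabilities that sum to 1."""
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]

logits = [2.0, 1.0, 0.1]  # made-up scores for a 3-token vocabulary
probs = softmax(logits)

rng = random.Random(1337)
token_ids = range(len(probs))

# stochastic: draw each token according to its probability
sampled = [rng.choices(token_ids, weights=probs, k=1)[0] for _ in range(10)]
# deterministic: always pick the most likely token
greedy = [max(token_ids, key=probs.__getitem__) for _ in range(10)]

print("sampled:", sampled)
print("greedy: ", greedy)
```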

### 🏋️‍♀️ Training the bigram model

In [None]:
optimizer = torch.optim.AdamW(bigram.parameters(), lr=1e-3)

def train(training_context_length, training_batch_size, iterations):
  for steps in range(iterations):
    input_batch, target_batch = get_batch(
        train_data,
        training_context_length,
        training_batch_size)

    # evaluate the loss
    logits, loss = bigram(input_batch, target_batch)
    optimizer.zero_grad(set_to_none=True) # clear gradients left over from the previous step
    loss.backward() # backpropagate the loss
    optimizer.step() # based on the gradients, take a step in the right direction

  return loss

In [None]:
loss = train(training_context_length=8, training_batch_size=32, iterations=1000)
print(f'loss is {loss}')

loss is 3.1414802074432373


In [None]:
# generate a new sequence with the 'L' from before
generated_sequences = bigram.generate(encoded_context, max_new_tokens=300)
print(decode(generated_sequences[0].tolist()))

LAgo heveld mof hou ing be ore, ENICAStho hinoutorin n
Ththathels ghy chis m theanir his trur.
Alateisavitave win hee pe:
O HO,

NI larowry windowou se pllme.
UKE urs ISe nt ld and ibeano y:
TUS:

Fot.

Ba VIOKIEOMENGonkity ashan acks d brselas t hea w My, r momeisiock's glindr my.
TEOLotheerthear pp


In [None]:
loss = train(training_context_length=8, training_batch_size=32, iterations=30000)
print(f'loss is {loss}')

loss is 2.8482561111450195


In [None]:
# generate a new sequence with the 'L' from before
generated_sequences = bigram.generate(encoded_context, max_new_tokens=300)
print(decode(generated_sequences[0].tolist()))

LI be hasty myor h

INENCHentl y, nto ckeseal towousson f Bon huses:

The Ant?
DUCooy he s.
ICltous ishe t the byockind, I ken br anginde oris so oullatuampe fowarielicad.
Toursicormeno mmendefor t a lanenofetenson ar toflalenkevou brdinst.

Cle llyo theleearal mie s o whofan m, t, s m Tho ie horavir


### 💡 Takeaways
- Scores are also referred to as "logits".
- Cross entropy is the loss function that allows the optimizer to adjust these scores so that probabilities are closer to the "real" structure of the corpus. This is done by comparing the "real" distribution of outputs vs. the bigram (could be other model) outputs.
- Text generation is stochastic.
- A bigram is one of the simplest models to "predict" the next token. It could avoid training altogether if we just computed the probabilities by counting occurrences in the corpus. This approach is more rigid, though.
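
The counting alternative from the last takeaway can be sketched in a few lines. This is a minimal illustration on a tiny stand-in string rather than the Shakespeare corpus; `sample_text`, `counts` and `next_char_probs` are hypothetical names.

```python
from collections import Counter, defaultdict

sample_text = "hello hello help"  # tiny stand-in corpus (assumption, not input.txt)

# count how often each character is followed by each other character
counts = defaultdict(Counter)
for current_char, next_char in zip(sample_text, sample_text[1:]):
    counts[current_char][next_char] += 1

def next_char_probs(ch):
    """Normalize the counts for one character into a probability distribution."""
    total = sum(counts[ch].values())
    return {nxt: c / total for nxt, c in counts[ch].items()}

print(next_char_probs("l"))
print(next_char_probs("h"))
```

The rigidity shows up immediately: any pair that never occurs in the corpus gets probability zero, whereas the trained lookup table always keeps some probability mass on every token.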

## 🧐 Implementing Attention
In the bigram version, the model has no information about how relevant earlier tokens are to the token it is predicting from. There is no **affinity** between tokens. What strategies could be used to assign a weight to each previous token so that it influences the next-token prediction?

Remember that for token prediction we base the decision on PAST tokens.

In [None]:
torch.manual_seed(1337)

# 4 instead of 65 just so the printed tensor is not so long
batch_size, context_length, n_embed = 1, 4, 4

fake_embeddings = torch.randn(batch_size, context_length, n_embed)

fake_tokens = [0, 3, 2, 1]
print(f"Let's imagine I have the tokens {fake_tokens}")
print(fake_embeddings) #token representations

Let's imagine I have the tokens [0, 3, 2, 1]
tensor([[[ 0.1808, -0.0700, -0.3596, -0.9152],
         [ 0.6258,  0.0255,  0.9545,  0.0643],
         [ 0.3612,  1.1679, -1.3499, -0.5102],
         [ 0.2360, -0.2398, -0.9211,  1.5433]]])


### 🧮 Averaged Weights Approach

In [None]:
averaged_logits = torch.zeros((batch_size, context_length, n_embed))

for batch_element in range(batch_size): # each sample from the batch
    for context_element in range(context_length):
        fake_previous_logits = fake_embeddings[batch_element,:context_element+1] # (t,C)
        averaged_logits[batch_element, context_element] = torch.mean(fake_previous_logits, 0)

print(f"A sample from the batch:")
for context_element in range(context_length):
    print(f"token {fake_tokens[context_element]}, logit: {fake_embeddings[0][context_element]}" \
        f", averaged logits: {averaged_logits[0][context_element]}")

A sample from the batch:
token 0, logit: tensor([ 0.1808, -0.0700, -0.3596, -0.9152]), averaged logits: tensor([ 0.1808, -0.0700, -0.3596, -0.9152])
token 3, logit: tensor([0.6258, 0.0255, 0.9545, 0.0643]), averaged logits: tensor([ 0.4033, -0.0222,  0.2974, -0.4254])
token 2, logit: tensor([ 0.3612,  1.1679, -1.3499, -0.5102]), averaged logits: tensor([ 0.3892,  0.3745, -0.2517, -0.4537])
token 1, logit: tensor([ 0.2360, -0.2398, -0.9211,  1.5433]), averaged logits: tensor([ 0.3509,  0.2209, -0.4190,  0.0456])


In [None]:
from torch.nn import functional as F

# alternative way of doing the same average computation
lower_triangular = torch.tril(torch.ones(context_length, context_length))
weights = torch.zeros((context_length, context_length))

weights = weights.masked_fill(lower_triangular == 0, float('-inf')) # the mask indicates we can't look into the future

# what softmax does is e^coefficient (e^0 is 1, e^-inf is 0) and then divide by the sum of elements in the row.
weights = F.softmax(weights, dim=-1)
print(f'weights are:\n{weights}')
print(f'averaged logits with new version:\n{weights @ fake_embeddings}')


weights are:
tensor([[1.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500]])
averaged logits with new version:
tensor([[[ 0.1808, -0.0700, -0.3596, -0.9152],
         [ 0.4033, -0.0222,  0.2974, -0.4254],
         [ 0.3892,  0.3745, -0.2517, -0.4537],
         [ 0.3509,  0.2209, -0.4190,  0.0456]]])


Which problems can you think of with this approach❓

### 🎭 Masked Weights Approach

In [None]:
torch.manual_seed(1337)

# let's see a single head perform self-attention
head_size = 8 # this is d_k in the "Attention Is All You Need" paper
key = torch.nn.Linear(n_embed, head_size, bias=False)
query = torch.nn.Linear(n_embed, head_size, bias=False)
value = torch.nn.Linear(n_embed, head_size, bias=False)

# Access the weights of the layer
weights = key.weight
print(weights)

Parameter containing:
tensor([[-0.4217, -0.0044,  0.1231, -0.0776],
        [-0.2996, -0.4713,  0.0851,  0.1967],
        [-0.3239, -0.2405,  0.2086,  0.0809],
        [-0.4426,  0.2669,  0.3778, -0.2566],
        [ 0.1005,  0.2079,  0.0102, -0.0935],
        [ 0.3864, -0.1422,  0.3963,  0.4639],
        [-0.4852,  0.2358,  0.2884,  0.4469],
        [-0.0344,  0.3378, -0.3731, -0.2868]], requires_grad=True)


In [None]:
k = key(fake_embeddings)   # k is (batch_size, context_length, head_size)
q = query(fake_embeddings) # q is (batch_size, context_length, head_size)
# wei is (batch_size, context_length, context_length)
wei =  q @ k.transpose(-2, -1)

# mask future positions: in a decoder each token may only attend to itself and earlier tokens
tril = torch.tril(torch.ones(context_length, context_length))
wei = wei.masked_fill(tril == 0, float('-inf'))

# apply softmax to the masked scores to obtain attention weights (each row sums to 1)
softmax_weights = F.softmax(wei, dim=-1)
print(f'weights are:\n{softmax_weights}')

weights are:
tensor([[[1.0000, 0.0000, 0.0000, 0.0000],
         [0.3351, 0.6649, 0.0000, 0.0000],
         [0.3062, 0.3892, 0.3046, 0.0000],
         [0.2787, 0.2557, 0.2636, 0.2020]]], grad_fn=<SoftmaxBackward0>)


In [None]:
scaled_wei = wei / head_size**0.5
v = value(fake_embeddings)
out = F.softmax(scaled_wei, dim=-1) @ v
print(out)

tensor([[[ 0.5184,  0.2507,  0.2316,  0.1183,  0.2188,  0.3971, -0.2551,
          -0.2845],
         [ 0.1033, -0.0743, -0.0812,  0.0715,  0.2333,  0.2587, -0.4965,
           0.0815],
         [ 0.3060,  0.1583,  0.2150, -0.0083,  0.2695,  0.3020, -0.2340,
          -0.1180],
         [ 0.2844,  0.2292,  0.1089,  0.0274,  0.1873,  0.0954,  0.0206,
          -0.1145]]], grad_fn=<UnsafeViewBackward0>)


Why are the weights scaled in the paper❓

In [None]:
print("softmax on small variance weights")
print(torch.softmax(torch.tensor([0.1, -0.2, 0.3, -0.2, 0.5]), dim=-1))

print("softmax on large variance weights")
print(torch.softmax(torch.tensor([0.1, -0.2, 0.3, -0.2, 0.5])*100, dim=-1))

softmax on small variance weights
tensor([0.1925, 0.1426, 0.2351, 0.1426, 0.2872])
softmax on large variance weights
tensor([4.2484e-18, 3.9754e-31, 2.0612e-09, 3.9754e-31, 1.0000e+00])
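
The reason the scaling factor is the square root of `head_size` can be checked empirically. If the entries of `q` and `k` are roughly independent with unit variance, their dot product has variance close to `head_size`, and dividing by `sqrt(head_size)` brings it back to about 1. The sketch below is a plain-Python simulation under that assumption (standard normal entries, `head_size = 64`); it is not taken from the notebook's model.

```python
import random
import statistics

rng = random.Random(0)
head_size = 64      # illustrative d_k
num_samples = 2000

def random_vector(n):
    """Vector of n independent standard normal entries."""
    return [rng.gauss(0.0, 1.0) for _ in range(n)]

raw_scores = []
for _ in range(num_samples):
    q, k = random_vector(head_size), random_vector(head_size)
    raw_scores.append(sum(qi * ki for qi, ki in zip(q, k)))

scaled_scores = [s / head_size ** 0.5 for s in raw_scores]

raw_var = statistics.pvariance(raw_scores)
scaled_var = statistics.pvariance(scaled_scores)
print(f"variance of q.k:             {raw_var:.2f}")
print(f"variance of q.k / sqrt(d_k): {scaled_var:.2f}")
```

Keeping the variance near 1 keeps the softmax in the "small variance" regime shown above, instead of collapsing onto a single token.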


# 🤖 Transformer Implementation

## 🏛️ Architecture

In [None]:
import torch
import torch.nn as nn
from torch.nn import functional as F
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [None]:
class Head(nn.Module):
    """ one head of self-attention """
    def __init__(
        self,
        head_size,
        n_embd, # this is d_model in the paper
        context_length,
        dropout):

        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        self.register_buffer('tril', torch.tril(torch.ones(context_length, context_length)))

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # input x is (batch_size, context_length, n_embd)
        batch_size, context_length, n_embd = x.shape

        k = self.key(x)   # (batch_size, context_length, head_size)
        q = self.query(x) # (batch_size, context_length, head_size)

        # (BS, CL, HS) @ (BS, HS, CL) -> (BS, CL, CL)
        wei = q @ k.transpose(-2,-1) * k.shape[-1]**-0.5
        wei = wei.masked_fill(self.tril[:context_length, :context_length] == 0, float('-inf'))
        wei = F.softmax(wei, dim=-1)
        wei = self.dropout(wei)

        # perform the weighted aggregation of the values
        v = self.value(x)
        out = wei @ v # (BS, CL, CL) @ (BS, CL, HS) -> (BS, CL, HS)

        # output (batch_size, context_length, head_size)
        return out

In [None]:
class MultiHeadAttention(nn.Module):
    """ multiple heads of self-attention in parallel """

    def __init__(self, num_heads, head_size, n_embd, context_length, dropout):
        super().__init__()
        self.heads = nn.ModuleList(
            [Head(head_size, n_embd, context_length, dropout) for _ in range(num_heads)]
        )
        self.proj = nn.Linear(head_size * num_heads, n_embd) # output projection applied to the concatenated heads
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        out = torch.cat([h(x) for h in self.heads], dim=-1)
        out = self.dropout(self.proj(out))

        return out

In [None]:
class FeedFoward(nn.Module):
    """ a simple linear layer followed by a non-linearity """

    def __init__(self, n_embd, dropout):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd),
            nn.ReLU(),
            # this linear projects back to n_embd, like the projection in the multi-head attention;
            # it could live outside of this sequential, but keeping it here is simpler
            nn.Linear(4 * n_embd, n_embd),
            # this is something you can add right before the residual connection goes back into the pathway.
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)

In [None]:
class Block(nn.Module):
    """ Transformer block: communication followed by computation """

    def __init__(self, n_embd, n_head, context_length, dropout):
        # n_embd: embedding dimension, n_head: the number of heads we'd like
        super().__init__()

        head_size = n_embd // n_head
        self.sa = MultiHeadAttention(n_head, head_size, n_embd, context_length, dropout) # outputs n_embd (heads are concatenated, then projected)

        self.ffwd = FeedFoward(n_embd, dropout) # outputs n_embed

        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

    def forward(self, x):
        x = x + self.sa(self.ln1(x)) # implementing the skip connection
        x = x + self.ffwd(self.ln2(x)) # implementing the skip connection
        return x

In [None]:
class TransformerLanguageModel(nn.Module):

    def __init__(self, vocab_size, n_embd, context_length, n_head, n_layer, dropout):
        super().__init__()

        self.context_length = context_length

        # each token id is mapped to an n_embd vector, and each position in the context gets its own vector too
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.position_embedding_table = nn.Embedding(context_length, n_embd)

        # in here, we are combining "communication" with "computation/pattern representation" many times
        self.blocks = nn.Sequential(*[Block(n_embd, n_head, context_length, dropout) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embd) # final layer norm

        self.lm_head = nn.Linear(n_embd, vocab_size)

        # out of the scope of this presentation
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        batch_size, context_length = idx.shape
        tok_emb = self.token_embedding_table(idx) # (batch_size, context_length, n_embd)
        pos_emb = self.position_embedding_table(torch.arange(context_length, device=idx.device)) # (context_length, n_embd)
        x = tok_emb + pos_emb # broadcast to (batch_size, context_length, n_embd)
        x = self.blocks(x) # (batch_size, context_length, n_embd)
        x = self.ln_f(x) # (batch_size, context_length, n_embd)
        logits = self.lm_head(x) # (batch_size, context_length, vocab_size)

        if targets is None:
            loss = None
        else:
            _, _, vocab_size = logits.shape
            logits = logits.view(batch_size*context_length, vocab_size)
            targets = targets.view(batch_size*context_length)
            # article explaining cross entropy applied to the following task
            # https://marinafuster.medium.com/cross-entropy-loss-for-next-token-prediction-83c684fa26d5
            loss = F.cross_entropy(logits, targets)

        return logits, loss

    def generate(self, idx, max_new_tokens):
        for _ in range(max_new_tokens):
            # crop idx to the last context_length tokens
            idx_cond = idx[:, -self.context_length:]
            # get the predictions
            logits, loss = self(idx_cond)
            # focus only on the last time step
            logits = logits[:, -1, :]
            # apply softmax to get probabilities
            probs = F.softmax(logits, dim=-1)
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1)
            # append sampled index to the running sequence
            idx = torch.cat((idx, idx_next), dim=1)
        return idx

## 🏋️‍♀️ Training the transformer model

In [None]:
# hyperparameters
batch_size = 16 # how many independent sequences will we process in parallel?
context_length = 32 # what is the maximum context length for predictions?
max_iters = 5000
eval_interval = 100
learning_rate = 1e-3
eval_iters = 200
n_embd = 64
n_head = 4
n_layer = 4
dropout = 0.0
vocab_size = 65

In [None]:
model = TransformerLanguageModel(
    vocab_size=vocab_size,
    n_embd=n_embd,
    context_length=context_length,
    n_head=n_head,
    n_layer=n_layer,
    dropout=dropout
)
m = model.to(device)

# print the number of parameters in the model
print(sum(p.numel() for p in m.parameters())/1e6, 'M parameters')

0.209729 M parameters


In [None]:
from collections import defaultdict

param_counts = defaultdict(int)

for name, param in m.named_parameters():
    if param.requires_grad:
        module_name = name.split('.')[0]
        param_counts[module_name] += param.numel()

for module, count in param_counts.items():
    print(f"{module}: {count} parameters")

total_params = sum(p.numel() for p in m.parameters() if p.requires_grad)
print(f"Total: {total_params / 1e6:.6f}M parameters")

token_embedding_table: 4160 parameters
position_embedding_table: 2048 parameters
blocks: 199168 parameters
ln_f: 128 parameters
lm_head: 4225 parameters
Total: 0.209729M parameters
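
The per-module counts above can be reproduced by hand from the hyperparameters, which is a useful check that the architecture is understood. The sketch below is plain arithmetic; `linear_params` is a hypothetical helper, and the breakdown assumes the layers exactly as defined in the classes above (bias-free key/query/value, biased projection and feed-forward layers). The `tril` mask is a registered buffer, not a parameter, so it is not counted.

```python
vocab_size, n_embd, context_length, n_head, n_layer = 65, 64, 32, 4, 4
head_size = n_embd // n_head

def linear_params(n_in, n_out, bias=True):
    """Weights (n_in * n_out) plus an optional bias vector of size n_out."""
    return n_in * n_out + (n_out if bias else 0)

layer_norm = 2 * n_embd  # one scale and one shift value per feature

# key, query and value projections for every head (no bias)
attention = n_head * 3 * linear_params(n_embd, head_size, bias=False)
projection = linear_params(n_head * head_size, n_embd)
feed_forward = linear_params(n_embd, 4 * n_embd) + linear_params(4 * n_embd, n_embd)
block = attention + projection + feed_forward + 2 * layer_norm

expected_counts = {
    "token_embedding_table": vocab_size * n_embd,
    "position_embedding_table": context_length * n_embd,
    "blocks": n_layer * block,
    "ln_f": layer_norm,
    "lm_head": linear_params(n_embd, vocab_size),
}
expected_total = sum(expected_counts.values())

for name, count in expected_counts.items():
    print(f"{name}: {count} parameters")
print(f"Total: {expected_total} parameters")
```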


In [None]:
# data loading
def get_batch_for_transformer(split, context_length):
    # generate a small batch of data of inputs x and targets y
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - context_length, (batch_size,))
    x = torch.stack([data[i:i+context_length] for i in ix])
    y = torch.stack([data[i+1:i+context_length+1] for i in ix])
    x, y = x.to(device), y.to(device)
    return x, y


@torch.no_grad()
def estimate_loss(context_length):
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch_for_transformer(split, context_length)
            logits, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out

In [None]:
# create a PyTorch optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

def train_transformer(iterations):
  for step in range(iterations):

      # every once in a while evaluate the loss on train and val sets
      if step % eval_interval == 0 or step == iterations - 1:
          losses = estimate_loss(context_length)
          print(f"step {step}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

      # sample a batch of data
      xb, yb = get_batch_for_transformer('train', context_length)

      # evaluate the loss
      logits, loss = model(xb, yb)
      optimizer.zero_grad(set_to_none=True)
      loss.backward()
      optimizer.step()

  return loss

In [None]:
loss = train_transformer(1)
print(f'loss is {loss}')

step 0: train loss 4.2124, val loss 4.2100
loss is 4.198349475860596


In [None]:
# generate from the model
context = torch.zeros((1, 1), dtype=torch.long, device=device)
print(decode(m.generate(context, max_new_tokens=300)[0].tolist()))


& .nf ioQDBg'PIWo!KtwbuZNYP&?$'eioyGv'lwW;&T 'zuS'XfjZNXJViDtVRq$XazZ;$Sk:PG-:!.QeGuFyzoPHy!pmMzV,!qN&y!?wLnseTt,
JHa.hdtuV!$EDofld
,&JqrUPR patViNE:mIrqio sdDSn'Idf&o!AHdmNFFGVm3TZhMKYGhvnLJQ!OM
'UPH,;?eWlzOb$Nl3:rrqkhZq?AWGlRSPBTvt$
sUiclJgrOnZEGry3fFngGKm fgslXLy 'deXMAvX KwhteFFSbdE'!rWcUR:&Nd;.


In [None]:
loss = train_transformer(2000)
print(f'loss is {loss}')

step 0: train loss 3.9855, val loss 3.9919
step 100: train loss 2.6256, val loss 2.6245
step 200: train loss 2.4633, val loss 2.4563
step 300: train loss 2.3899, val loss 2.3902
step 400: train loss 2.3209, val loss 2.3319
step 500: train loss 2.2603, val loss 2.2824
step 600: train loss 2.2129, val loss 2.2190
step 700: train loss 2.1469, val loss 2.1830
step 800: train loss 2.1001, val loss 2.1368
step 900: train loss 2.0641, val loss 2.1099
step 1000: train loss 2.0461, val loss 2.0807
step 1100: train loss 1.9984, val loss 2.0674
step 1200: train loss 1.9821, val loss 2.0450
step 1300: train loss 1.9684, val loss 2.0493
step 1400: train loss 1.9444, val loss 2.0169
step 1500: train loss 1.9102, val loss 2.0167
step 1600: train loss 1.8905, val loss 1.9873
step 1700: train loss 1.8679, val loss 1.9916
step 1800: train loss 1.8501, val loss 1.9606
step 1900: train loss 1.8460, val loss 1.9566
loss is 1.8301169872283936


In [None]:
# generate from the model
context = torch.zeros((1, 1), dtype=torch.long, device=device)
print(decode(m.generate(context, max_new_tokens=300)[0].tolist()))


Fhich, onather my harl Second poine!

CORIXENES:
Sorrooly, you sonsmile your thie heave fast.

LUCIO:
I you:
You weard hold hisend: couged that you polly dukence!

Then, lord, une whow:
Aseass, pirice diley.

VILCHETH:
Mayar, upoinion.

DUKE OF OF RIV:
We honst da my sir, the see he do,
If grame par
