# How to train a Generative Pretrained Transformer (GPT) model?

This notebook explores the main steps involved in training a generative AI model for text.\
We review the key concepts behind training and then look at how the model generates new content.

**Topics:**
- Tokenization
- Main model components
- How the model performance is evaluated
- How the model is trained
- How the model is used to create new text

In [1]:
import os
import subprocess
from pathlib import Path

repo_path = (
    subprocess.check_output(["git", "rev-parse", "--show-toplevel"])
    .strip()
    .decode("utf-8")
)
os.chdir(repo_path)

## The different phases of model training

- Pre-training: the model learns the structure of the language but does not yet know how to answer an instruction or question
- Supervised fine-tuning: the model is trained to answer instructions
- Alignment with human preferences (RLHF): the model is trained so that its answers match human preferences

See slides: [pre-training](https://docs.google.com/presentation/d/1S8ao40-CdclRU0D2D9FdyN5x8fZL1Iv5/edit?_hsenc=p2ANqtz--zYcYKj9_o5fLbt_D3P4tzLanpAyfFm14Z2NXEvCZxbsjLtax9y5mYCzRg-opvXZhsYGEH#slide=id.g2b67c4de1ca_1_13), [supervised fine-tuning](https://docs.google.com/presentation/d/1S8ao40-CdclRU0D2D9FdyN5x8fZL1Iv5/edit?_hsenc=p2ANqtz--zYcYKj9_o5fLbt_D3P4tzLanpAyfFm14Z2NXEvCZxbsjLtax9y5mYCzRg-opvXZhsYGEH#slide=id.g2b67c4de1ca_1_59), [RLHF](https://docs.google.com/presentation/d/1S8ao40-CdclRU0D2D9FdyN5x8fZL1Iv5/edit?_hsenc=p2ANqtz--zYcYKj9_o5fLbt_D3P4tzLanpAyfFm14Z2NXEvCZxbsjLtax9y5mYCzRg-opvXZhsYGEH#slide=id.g2b67c4de1ca_1_77)

## Loading the data

In [2]:
# load wiki english
import json

with open("wikiGPT/data/shuffled_shards/shard_0.json", "r") as file:
    for line in file:
        data = json.loads(line)
        break

In [3]:
data["content"][:200]

'Nowa Kiszewa  is a village in the administrative district of Gmina Kościerzyna, within Kościerzyna County, Pomeranian Voivodeship, in northern Poland. It lies approximately  south-east of Kościerzyna '

## Creating the tokenizer

### A naive tokenizer

In [8]:
vocab = sorted(set(data["content"]))
vocab_size = len(vocab)


def create_naive_tokenizer(text):
    # character-level vocabulary built from the text passed in; ids 0-2 are reserved for special tokens
    chars = sorted(set(text))
    tokenizer = {char: i + 3 for i, char in enumerate(chars)}
    tokenizer["<s>"] = 0
    tokenizer["</s>"] = 1
    tokenizer["<unk>"] = 2
    detokenizer = {v: k for k, v in tokenizer.items()}
    return tokenizer, detokenizer


naive_tokenizer, naive_detokenizer = create_naive_tokenizer(data["content"])


def tokenize(text, tokenizer):
    return [tokenizer.get(letter, 2) for letter in text]


def detokenize(tokens, detokenizer):
    return "".join([detokenizer.get(token, "<unk>") for token in tokens])


tokens = tokenize("Let's train a GPT model!", naive_tokenizer)
print(tokens)
print(detokenize(tokens, naive_detokenizer))

[2, 21, 34, 2, 33, 4, 34, 32, 18, 25, 29, 4, 18, 4, 10, 15, 2, 4, 28, 30, 20, 21, 27, 2]
<unk>et<unk>s train a GP<unk> model<unk>
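
The `<unk>` tokens above appear because the vocabulary only contains characters that occur in the single article used to build it. The sketch below reproduces that effect on a tiny made-up corpus; `toy_corpus`, `toy_tokenize` and `unknown_chars` are hypothetical names used for illustration only, not part of the notebook's code.

```python
# Hypothetical toy corpus; the notebook builds its vocabulary from one Wikipedia article instead.
toy_corpus = "let us train a model"

# Character-level vocabulary: ids 0-2 are reserved for <s>, </s> and <unk>.
toy_vocab = sorted(set(toy_corpus))
char_to_id = {ch: i + 3 for i, ch in enumerate(toy_vocab)}
UNK_ID = 2


def toy_tokenize(text):
    """Map each character to its id, falling back to <unk> for unseen characters."""
    return [char_to_id.get(ch, UNK_ID) for ch in text]


def unknown_chars(text):
    """Return the characters of `text` that are missing from the vocabulary."""
    return sorted(set(text) - set(toy_vocab))


sentence = "Let's train a GPT model!"
print(unknown_chars(sentence))               # characters that will become <unk>
print(toy_tokenize(sentence).count(UNK_ID))  # number of <unk> tokens produced
```

Uppercase letters and punctuation never seen in the corpus all collapse to the same id, so the original text cannot be recovered after detokenization.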


### A tokenizer based on the Byte Pair Encoding (BPE) algorithm

In [5]:
# Byte pair encoding (BPE) learns how to tokenize from the data itself

# Take a corpus of text to train the tokenizer on
# Choose a vocabulary size, i.e. the maximum number of tokens the vocabulary can hold

# Start with a character-level vocabulary built from the corpus
"I love to learn about AI!"  # => [' ', '!', 'A', 'I', 'a', 'b', 'e', 'l', 'n', 'o', 'r', 't', 'u', 'v']

# Count how often each pair of adjacent tokens occurs in the corpus
"I love to learn about AI!"  # => {"I ": 1, " l": 2, "lo": 1, "ov": 1, ...}

# Merge the most frequent pair into a single new token and add it to the vocabulary
"I love to learn about AI!"  # => [' ', '!', 'A', 'I', 'a', 'b', 'e', 'l', 'n', 'o', 'r', 't', 'u', 'v', ' l']

# Repeat counting and merging until the vocabulary size is reached

'I love to learn about AI!'
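
The merge procedure described in the comments above can be sketched in a few lines of plain Python. This is a toy illustration under simplifying assumptions (one string as the whole corpus, character-level start, stop when no pair repeats); `train_bpe` is a hypothetical helper, and the tokenizer loaded in the next cell is trained with a dedicated library whose details may differ.

```python
from collections import Counter


def train_bpe(text, vocab_size):
    """Toy BPE trainer: repeatedly merge the most frequent adjacent pair of tokens."""
    tokens = list(text)  # start from a character-level split
    vocab = sorted(set(tokens))
    merges = []
    while len(vocab) < vocab_size:
        pair_counts = Counter(zip(tokens, tokens[1:]))
        if not pair_counts:
            break
        (left, right), count = pair_counts.most_common(1)[0]
        if count < 2:  # no pair occurs more than once: nothing worth merging
            break
        merged = left + right
        merges.append((left, right))
        vocab.append(merged)
        # Replace every occurrence of the pair with the merged token.
        new_tokens, i = [], 0
        while i < len(tokens):
            if i + 1 < len(tokens) and tokens[i] == left and tokens[i + 1] == right:
                new_tokens.append(merged)
                i += 2
            else:
                new_tokens.append(tokens[i])
                i += 1
        tokens = new_tokens
    return vocab, merges, tokens


vocab, merges, tokens = train_bpe("I love to learn about AI!", vocab_size=15)
print(merges)  # the pairs merged, in order
print(tokens)  # the corpus re-tokenized with the learned merges
```

With a target size of 15 (one more than the 14 distinct characters), the only merge performed is the pair `(' ', 'l')`, which matches the example in the comments above.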

In [6]:
from wikiGPT.tokenize import Tokenizer

tokenizer = Tokenizer(Path("wikiGPT/tokenizers/tok32000.model"))

tokens = tokenizer.encode("I love to learn about AI!", bos=False, eos=False)
print(tokens)
print(tokenizer.decode(tokens))

[318, 3832, 298, 3178, 900, 14788, 19086]
I love to learn about AI!


In [7]:
### A batch of data: from human-readable text to machine-readable tokens

# A batch has two dimensions: the number of samples passed to the model (batch size) and the context length

corpus = [
    "I love to learn about AI!",
    "Let's train a GPT model!",
    data["content"][:30],
]
corpus

['I love to learn about AI!',
 "Let's train a GPT model!",
 'Nowa Kiszewa  is a village in ']

In [8]:
batch_size = 1
context_length = 4

# let's tokenize our text
tokenized_corpus = [tokenizer.encode(text, bos=False, eos=False) for text in corpus]

tokenized_corpus

[[318, 3832, 298, 3178, 900, 14788, 19086],
 [6974, 19055, 19006, 3820, 262, 11867, 19028, 3005, 19086],
 [344, 4397, 385, 273, 15544, 19000, 351, 262, 1429, 280]]

In [9]:
# Creating a first batch for the next token prediction
import random

sample = random.choice(range(len(tokenized_corpus)))
print(f"Selected sample: {sample}")

batch = tokenized_corpus[sample][: context_length + 1]

Selected sample: 0


In [10]:
batch

[318, 3832, 298, 3178, 900]

In [11]:
for i in range(context_length):
    print(f"token(s) {batch[:i+1]} need to predict {batch[i+1]}")

token(s) [318] need to predict 3832
token(s) [318, 3832] need to predict 298
token(s) [318, 3832, 298] need to predict 3178
token(s) [318, 3832, 298, 3178] need to predict 900


In [12]:
for i in range(context_length):
    print(
        f"""subword(s) '{tokenizer.decode(batch[:i+1])}' need to predict '{tokenizer.decode(batch[i+1])}'"""
    )

subword(s) 'I' need to predict 'love'
subword(s) 'I love' need to predict 'to'
subword(s) 'I love to' need to predict 'learn'
subword(s) 'I love to learn' need to predict 'about'
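
In practice the inputs and targets are built once per example as two windows of the same sequence, shifted by one position. A minimal sketch, assuming a hypothetical helper `make_example` (the notebook's own batch iterator may organize this differently):

```python
def make_example(token_ids, context_length, start=0):
    """Cut one training example: targets are the input window shifted by one token."""
    window = token_ids[start : start + context_length + 1]
    inputs, targets = window[:-1], window[1:]
    return inputs, targets


# Token ids of "I love to learn about AI!" taken from the tokenized corpus above.
token_ids = [318, 3832, 298, 3178, 900, 14788, 19086]

inputs, targets = make_example(token_ids, context_length=4)
print(inputs)   # [318, 3832, 298, 3178]
print(targets)  # [3832, 298, 3178, 900]
```

Position `i` of `targets` is the token the model must predict after seeing `inputs[: i + 1]`, so a single window of length `context_length` yields `context_length` training signals.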


### Model architecture: let's have a look at the main components of a text generation model

#### First, the initial embedding of tokens

In [13]:
import torch

### The embedding layer

#### This is the initial layer that will transform the token (int) into a vector
vocab = [0, 1, 2]
vocab_size = len(vocab)
embedding_size = 6

embedding_layer = torch.rand(vocab_size, embedding_size)
print(embedding_layer)

tensor([[0.5374, 0.7877, 0.4807, 0.3653, 0.6011, 0.3297],
        [0.4546, 0.0125, 0.9942, 0.6804, 0.5595, 0.0616],
        [0.3252, 0.3354, 0.2572, 0.9170, 0.7575, 0.3402]])


In [14]:
batch = [1, 0, 0, 2]
batch_size = 1
context_length = 4
embedding = torch.zeros(batch_size, context_length, embedding_size)
for i, token in enumerate(batch):
    embedding[0, i] = embedding_layer[token]

print(embedding.shape)
print(embedding)

torch.Size([1, 4, 6])
tensor([[[0.4546, 0.0125, 0.9942, 0.6804, 0.5595, 0.0616],
         [0.5374, 0.7877, 0.4807, 0.3653, 0.6011, 0.3297],
         [0.5374, 0.7877, 0.4807, 0.3653, 0.6011, 0.3297],
         [0.3252, 0.3354, 0.2572, 0.9170, 0.7575, 0.3402]]])


In [15]:
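# Looking up a row of the embedding table is equivalent to multiplying a one-hot vector by the table
batch_one_hot_encoded = torch.tensor(
    [[0, 1, 0], [1, 0, 0], [1, 0, 0], [0, 0, 1]], dtype=torch.float32
)

batch_one_hot_encoded @ embedding_layer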

tensor([[0.4546, 0.0125, 0.9942, 0.6804, 0.5595, 0.0616],
        [0.5374, 0.7877, 0.4807, 0.3653, 0.6011, 0.3297],
        [0.5374, 0.7877, 0.4807, 0.3653, 0.6011, 0.3297],
        [0.3252, 0.3354, 0.2572, 0.9170, 0.7575, 0.3402]])

#### The self-attention mechanism, at the heart of the transformer

In [16]:
from torch import nn
from torch.nn import functional as F

torch.manual_seed(56)
### What is the attention mechanism?

#### Attention lets the model learn the affinities between tokens
#### Given the previous tokens, how should they interact to predict the most likely next token?

# [I love to learn] ---> about
print("I love to learn")
tokens = tokenizer.encode("I love to learn", bos=False, eos=False)
print(tokens)
vocab_size = 32000
embedding_size = 8

embedding_layer = nn.Embedding(vocab_size, embedding_size)
x = embedding_layer(torch.LongTensor(tokens))
x

I love to learn
[318, 3832, 298, 3178]


tensor([[-0.1219, -0.9217,  0.0804, -0.1049, -0.4088,  1.1438, -0.2155,  0.8280],
        [-0.5726,  2.1356,  0.1306,  1.0591,  0.0731,  0.1786, -0.8104, -0.1310],
        [ 0.0997,  0.6453,  2.4119,  0.5138,  0.2973,  0.0155, -0.8086,  0.3057],
        [ 1.0756,  0.4220, -1.1676,  2.2978,  0.5456,  0.6231,  2.5453, -0.4455]],
       grad_fn=<EmbeddingBackward0>)

In [17]:
# I     ---> 318  ---> [-0.1219, -0.9217,  0.0804, -0.1049, -0.4088,  1.1438, -0.2155,  0.8280]
# love  ---> 3832 ---> [-0.5726,  2.1356,  0.1306,  1.0591,  0.0731,  0.1786, -0.8104, -0.1310]
# to    ---> 298  ---> [ 0.0997,  0.6453,  2.4119,  0.5138,  0.2973,  0.0155, -0.8086,  0.3057]
# learn ---> 3178 ---> [ 1.0756,  0.4220, -1.1676,  2.2978,  0.5456,  0.6231,  2.5453, -0.4455]

In [18]:
context_length = 4
mask = torch.tril(torch.ones(context_length, context_length))
mask

tensor([[1., 0., 0., 0.],
        [1., 1., 0., 0.],
        [1., 1., 1., 0.],
        [1., 1., 1., 1.]])

In [19]:
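# With all scores equal to zero, masking the future positions and applying softmax
# gives every token a uniform average over itself and the tokens before it
weights = torch.zeros(context_length, context_length)
weights = weights.masked_fill(mask == 0, float("-Inf"))
weights = F.softmax(weights, dim=-1)
weights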

tensor([[1.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500]])

In [20]:
weights @ x

tensor([[-0.1219, -0.9217,  0.0804, -0.1049, -0.4088,  1.1438, -0.2155,  0.8280],
        [-0.3473,  0.6069,  0.1055,  0.4771, -0.1679,  0.6612, -0.5130,  0.3485],
        [-0.1983,  0.6197,  0.8743,  0.4894, -0.0128,  0.4460, -0.6115,  0.3342],
        [ 0.1202,  0.5703,  0.3638,  0.9415,  0.1268,  0.4903,  0.1777,  0.1393]],
       grad_fn=<MmBackward0>)

In [21]:
### How can we get meaningful weights for the attention?
# queries and keys interact to decide the level of affinity between tokens
torch.manual_seed(78)
key = nn.Linear(embedding_size, embedding_size, bias=False)
query = nn.Linear(embedding_size, embedding_size, bias=False)
# (context_length, embedding_size) @ (embedding_size, embedding_size) => (context_length, embedding_size)
k = key(x)
q = query(x)
print(k.shape, q.shape)

torch.Size([4, 8]) torch.Size([4, 8])


In [22]:
print(k)
print(q)

tensor([[ 0.2632,  0.2153, -0.1248,  0.4342,  0.6620,  0.3543,  0.1199,  0.3901],
        [-0.4048,  0.7153, -0.1434,  0.3645, -0.2861, -0.3934,  0.7692, -0.6666],
        [-0.5860,  0.5796,  0.3083, -0.0288,  0.1974,  0.6948,  0.3016,  0.3322],
        [ 1.0913, -0.8812,  0.7187, -0.1261,  0.0445, -0.4471,  0.7331,  1.0798]],
       grad_fn=<MmBackward0>)
tensor([[-0.3480, -0.0343,  0.1525, -0.1166,  0.3183,  0.0807,  0.3076, -0.0275],
        [ 0.3468,  0.8050, -0.4144,  0.2355,  0.0146, -0.3874, -0.5229,  0.4610],
        [-0.0678,  1.0615,  0.4411,  1.0523,  0.1471,  0.2051,  0.2539,  0.6423],
        [ 1.4488, -1.9597, -0.4880,  0.0270,  0.0287, -0.8733, -1.3105, -0.9306]],
       grad_fn=<MmBackward0>)


In [23]:
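# affinity score between every query (row) and every key (column)
# note: the standard formulation also divides the scores by sqrt(embedding_size); this is omitted here for simplicity
weights = q @ k.T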

In [24]:
# mask the future positions, then normalize each row into probabilities
weights = weights.masked_fill(mask == 0, float("-Inf"))
weights = F.softmax(weights, dim=-1)
weights

tensor([[1.0000, 0.0000, 0.0000, 0.0000],
        [0.5960, 0.4040, 0.0000, 0.0000],
        [0.3444, 0.2520, 0.4036, 0.0000],
        [0.0996, 0.0306, 0.0069, 0.8629]], grad_fn=<SoftmaxBackward0>)

In [25]:
#          I      love     to     learn
# I     [1.0000, 0.0000, 0.0000, 0.0000]
# love  [0.5960, 0.4040, 0.0000, 0.0000]
# to    [0.3444, 0.2520, 0.4036, 0.0000]
# learn [0.0996, 0.0306, 0.0069, 0.8629]

In [26]:
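# aggregating information from the values
# note: re-seeding with 78 gives `values` the same initial weights as `key` above, so v equals k here;
# in a trained model the key and value projections are learned independently
torch.manual_seed(78)
values = nn.Linear(embedding_size, embedding_size, bias=False)
v = values(x)
output = weights @ v
output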

tensor([[ 0.2632,  0.2153, -0.1248,  0.4342,  0.6620,  0.3543,  0.1199,  0.3901],
        [-0.0067,  0.4173, -0.1323,  0.4060,  0.2789,  0.0522,  0.3822, -0.0368],
        [-0.2478,  0.4883,  0.0453,  0.2298,  0.2355,  0.3033,  0.3568,  0.1004],
        [ 0.9514, -0.7130,  0.6055, -0.0546,  0.0970, -0.3577,  0.6701,  0.9525]],
       grad_fn=<MmBackward0>)
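
The steps from the cells above (projections, affinity scores, causal mask, softmax, weighted sum of values) can be collected into one function. The sketch below is an illustration rather than the notebook's model code: it also applies the 1/sqrt(d) scaling of standard scaled dot-product attention, and the function name `causal_self_attention` and the random weight matrices are assumptions made for the example.

```python
import math

import torch
from torch.nn import functional as F


def causal_self_attention(x, w_q, w_k, w_v):
    """Single-head causal attention over x of shape (context_length, embedding_size)."""
    q, k, v = x @ w_q, x @ w_k, x @ w_v
    scores = (q @ k.T) / math.sqrt(k.size(-1))  # scaled affinity scores
    mask = torch.tril(torch.ones(x.size(0), x.size(0), dtype=torch.bool))
    scores = scores.masked_fill(~mask, float("-inf"))  # hide future positions
    weights = F.softmax(scores, dim=-1)  # each row sums to 1
    return weights @ v, weights


torch.manual_seed(0)
context_length, embedding_size = 4, 8
x = torch.randn(context_length, embedding_size)  # stand-in for token embeddings
w_q, w_k, w_v = (torch.randn(embedding_size, embedding_size) for _ in range(3))

out, weights = causal_self_attention(x, w_q, w_k, w_v)
print(weights)    # lower-triangular attention weights
print(out.shape)  # one output vector per position
```

Because the first token can only attend to itself, the first output row is exactly the first value vector, as in the outputs shown above.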

In [27]:
### Final layer to produce logits and probabilities

final_projection = nn.Linear(embedding_size, vocab_size)
logits = final_projection(output)

logits.shape

torch.Size([4, 32000])

In [28]:
probs = F.softmax(logits, dim=-1)
probs[-1]

tensor([5.8534e-05, 1.7943e-05, 2.7108e-05,  ..., 2.0701e-05, 3.7982e-05,
        5.1995e-05], grad_fn=<SelectBackward0>)

In [29]:
### How can we measure the performance of a model? The cross-entropy loss

##### For each target token, we can look up the probability the model assigns to it
##### In a batch, we have batch_size * context_length predicted tokens
##### One way to measure quality is to multiply together the probabilities the model assigns to each correct next token

# The targets below come from "Let's train a GPT model!" and are used for illustration only:
# `probs` above was computed from the tokens of "I love to learn"
# token(s) [6974] need to predict 19055
# token(s) [6974, 19055] need to predict 19006
# token(s) [6974, 19055, 19006] need to predict 3820
# token(s) [6974, 19055, 19006, 3820] need to predict 262

In [30]:
print(probs[0, 19055])
print(probs[1, 19006])
print(probs[2, 3820])
print(probs[3, 262])

tensor(5.0201e-05, grad_fn=<SelectBackward0>)
tensor(2.8422e-05, grad_fn=<SelectBackward0>)
tensor(3.3656e-05, grad_fn=<SelectBackward0>)
tensor(1.0244e-05, grad_fn=<SelectBackward0>)


In [34]:
# Multiplying many small probabilities quickly underflows, so we sum their logarithms instead
#  log(a*b*c) = log(a) + log(b) + log(c)

(
    torch.log(probs[0, 19055])
    + torch.log(probs[1, 19006])
    + torch.log(probs[2, 3820])
    + torch.log(probs[3, 262])
)

tensor(-41.1317, grad_fn=<AddBackward0>)

In [35]:
# Because we want a quantity to minimize, we take the negative log-likelihood
# (by default, F.cross_entropy reports the mean of this quantity over the targets)

-(
    torch.log(probs[0, 19055])
    + torch.log(probs[1, 19006])
    + torch.log(probs[2, 3820])
    + torch.log(probs[3, 262])
)

tensor(41.1317, grad_fn=<NegBackward0>)
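
A useful sanity check on this value: a model that spreads its probability evenly over the vocabulary has a per-token loss of ln(vocab_size). The sketch below compares that baseline with the mean loss implied by the sum printed above; `mean_negative_log_likelihood` is a hypothetical helper written for this example.

```python
import math


def mean_negative_log_likelihood(target_probs):
    """Average of -log(p) over the probabilities assigned to the correct tokens."""
    return -sum(math.log(p) for p in target_probs) / len(target_probs)


vocab_size = 32000

# A model that guesses uniformly gives every target token probability 1 / vocab_size.
uniform_loss = mean_negative_log_likelihood([1 / vocab_size] * 4)

# Summed negative log-likelihood over 4 targets, copied from the cell output above.
observed_loss = 41.1317 / 4

print(f"uniform baseline:   {uniform_loss:.4f}")
print(f"observed mean loss: {observed_loss:.4f}")
print(f"perplexity:         {math.exp(observed_loss):.0f}")
```

Both values are close to 10.3, which is what one would expect from a model with randomly initialized weights: it has not yet learned to prefer any token over another.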

In [36]:
### How can we actually train the model?
#### What is gradient descent and backpropagation?

from wikiGPT.model import Transformer, ModelArgs

model_config = ModelArgs(
    dim=64,
    n_layers=1,
    n_heads=2,
    vocab_size=32000,
    hidden_dim=64,
    max_context_length=10,
)

model = Transformer(model_config)

In [37]:
print(model)

Transformer(
  (tok_embeddings): Embedding(32000, 64)
  (dropout): Dropout(p=0.0, inplace=False)
  (layers): ModuleList(
    (0): TransformerBlock(
      (attn): Attention(
        (wq): Linear(in_features=64, out_features=64, bias=False)
        (wk): Linear(in_features=64, out_features=64, bias=False)
        (wv): Linear(in_features=64, out_features=64, bias=False)
        (wo): Linear(in_features=64, out_features=64, bias=False)
        (attn_dropout): Dropout(p=0.0, inplace=False)
        (resid_dropout): Dropout(p=0.0, inplace=False)
      )
      (feed_forward): FeedForward(
        (w1): Linear(in_features=64, out_features=96, bias=False)
        (w2): Linear(in_features=96, out_features=64, bias=False)
        (w3): Linear(in_features=64, out_features=96, bias=False)
        (dropout): Dropout(p=0.0, inplace=False)
      )
      (attn_norm): RMSNorm()
      (ff_norm): RMSNorm()
    )
  )
  (norm): RMSNorm()
  (output): Linear(in_features=64, out_features=32000, bias=False)
)


In [38]:
params = [p for p in model.parameters() if p.requires_grad]
print(f"Number of model params: {sum(p.numel() for p in params):,}")

Number of model params: 2,083,008
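
The total can be reproduced by hand from the layer shapes printed above. The sketch assumes that each `RMSNorm` holds a single weight vector of size `dim` and that the output projection shares its weights with the token embedding (so it is counted only once); neither is confirmed by the printout, but the arithmetic matches the reported total under those assumptions.

```python
dim, vocab_size, ff_hidden = 64, 32000, 96  # sizes read from the printed model

embedding = vocab_size * dim        # tok_embeddings
attention = 4 * dim * dim           # wq, wk, wv, wo (no bias)
feed_forward = 3 * dim * ff_hidden  # w1, w2, w3 (no bias)
norms = 3 * dim                     # assumption: one weight vector per RMSNorm
output = dim * vocab_size           # output projection

tied_total = embedding + attention + feed_forward + norms  # output shares the embedding weights
untied_total = tied_total + output                         # if the output layer had its own weights

print(f"{tied_total:,}")
print(f"{untied_total:,}")
```

Almost all of the parameters of this small model sit in the embedding table; the attention and feed-forward layers together account for fewer than 40,000 of them.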


In [49]:
from wikiGPT.iterate import TokenIterator
from functools import partial

# batch iterator over the pre-tokenized training data
iter_params = {
    "pretokenized_source": Path(f"wikiGPT/data/tok{model_config.vocab_size}"),
    "context_length": model_config.max_context_length,
    # "verbose": True,
}
iter_batches = partial(
    TokenIterator.iter_batches,
    batch_size=2,
    device="cpu",
    num_workers=0,
    **iter_params,
)
train_batch_iter = iter_batches(split="train")
X, Y = next(train_batch_iter)

# define an optimizer
optimizer = torch.optim.AdamW(params=model.parameters(), lr=5e-4)

In [50]:
X.shape

torch.Size([2, 10])

#### Good introductory videos on gradient descent and backpropagation from the 3Blue1Brown YouTube channel
[What is a neural network](https://www.youtube.com/watch?v=aircAruvnKk&list=PLZHQObOWTQDNU6R1_67000Dx_ZCJB-3pi&index=1)\
[What is gradient descent](https://www.youtube.com/watch?v=IHZwWFHWa-w&list=PLZHQObOWTQDNU6R1_67000Dx_ZCJB-3pi&index=2)\
[Backpropagation intuitive explanation](https://www.youtube.com/watch?v=Ilg3gGewQ5U&list=PLZHQObOWTQDNU6R1_67000Dx_ZCJB-3pi&index=3)\
[Backpropagation computation](https://www.youtube.com/watch?v=tIeHLnjs5U8&list=PLZHQObOWTQDNU6R1_67000Dx_ZCJB-3pi&index=4)
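
As a complement to the videos, here is gradient descent on a one-parameter toy problem in plain Python. It is only a sketch of the core idea: the loss function, starting point and learning rate are made up for the example, the gradient is written by hand (backpropagation is what computes it automatically in a neural network), and the AdamW optimizer used in this notebook adds further refinements on top of this basic update.

```python
def loss_fn(w):
    """Toy loss with its minimum at w = 3."""
    return (w - 3.0) ** 2


def grad_fn(w):
    """Derivative of the toy loss with respect to w."""
    return 2.0 * (w - 3.0)


w = 0.0              # initial parameter value
learning_rate = 0.1  # step size

for step in range(50):
    w -= learning_rate * grad_fn(w)  # move against the gradient

print(w, loss_fn(w))
```

Each step shrinks the distance to the minimum by a constant factor, so `w` converges towards 3 and the loss towards 0.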

### The training loop

In [56]:
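for i in range(1000):
    # forward pass: compute the logits for every position in the batch
    logits = model(X, Y)

    # compute the cross-entropy loss between the predictions and the targets
    logits = logits.view(-1, logits.size(-1))  # (batch_size * context_length, vocab_size)
    targets = Y.view(-1)  # (batch_size * context_length,)
    loss = F.cross_entropy(logits, targets)

    # backward pass: compute the gradients, update the weights, then reset the gradients
    loss.backward()
    optimizer.step()
    optimizer.zero_grad(set_to_none=True)

    # get the next batch of data
    X, Y = next(train_batch_iter)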

In [58]:
loss

tensor(6.7052, grad_fn=<NllLossBackward0>)

### How do we generate new text with the model?

In [67]:
temperature = 1.0
top_k = None

start_ids = tokenizer.encode("", bos=True, eos=False)
x = torch.tensor(start_ids, dtype=torch.long, device="cpu")[None, ...]

with torch.inference_mode():
    for _ in range(10):
        logits = model(x)
        logits = logits[:, -1, :]  # keep only the logits of the final time step

        if temperature > 0:
            # scale the logits by the desired temperature
            logits = logits / temperature

            # optionally keep only the top k options
            if top_k is not None:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = -float("Inf")

            # apply softmax to convert logits to (normalized) probabilities, then sample
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)

        else:
            # temperature 0: greedy decoding, always pick the most likely token
            _, idx_next = torch.topk(logits, k=1, dim=-1)
        # append sampled index to the running sequence and continue
        x = torch.cat((x, idx_next), dim=1)
        print(x)

    print(tokenizer.decode(x.tolist()))

tensor([[  1, 952]])
tensor([[  1, 952, 370]])
tensor([[  1, 952, 370, 262]])
tensor([[   1,  952,  370,  262, 9294]])
tensor([[   1,  952,  370,  262, 9294,  396]])
tensor([[   1,  952,  370,  262, 9294,  396, 3408]])
tensor([[    1,   952,   370,   262,  9294,   396,  3408, 31157]])
tensor([[    1,   952,   370,   262,  9294,   396,  3408, 31157, 19034]])
tensor([[    1,   952,   370,   262,  9294,   396,  3408, 31157, 19034,   317]])
tensor([[    1,   952,   370,   262,  9294,   396,  3408, 31157, 19034,   317,
           667]])
['oss with a CBS he below딪I l all']
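
To see what `temperature` and `top_k` do to the sampling distribution, the sketch below applies them to a handful of made-up logits in plain Python. The values in `toy_logits` and the helper names are hypothetical; the logic mirrors the generation cell above but is not the notebook's code.

```python
import math
import random


def softmax(logits):
    """Convert logits to probabilities, subtracting the max for numerical stability."""
    m = max(logits)
    exps = [math.exp(l - m) for l in logits]
    total = sum(exps)
    return [e / total for e in exps]


def next_token_probs(logits, temperature=1.0, top_k=None):
    """Apply temperature scaling and optional top-k filtering before the softmax."""
    scaled = [l / temperature for l in logits]
    if top_k is not None:
        kth_largest = sorted(scaled, reverse=True)[min(top_k, len(scaled)) - 1]
        scaled = [l if l >= kth_largest else float("-inf") for l in scaled]
    return softmax(scaled)


toy_logits = [2.0, 1.0, 0.5, -1.0]  # hypothetical scores for a 4-token vocabulary

print(next_token_probs(toy_logits, temperature=1.0))
print(next_token_probs(toy_logits, temperature=0.5))  # sharper distribution
print(next_token_probs(toy_logits, top_k=2))          # only two tokens keep any mass

random.seed(0)
probs = next_token_probs(toy_logits, temperature=0.5, top_k=2)
sampled = random.choices(range(len(toy_logits)), weights=probs, k=1)[0]
print(sampled)
```

A temperature below 1 concentrates probability on the highest-scoring tokens, a temperature above 1 flattens the distribution, and `top_k` removes the long tail entirely. The helper requires a strictly positive temperature, which is why the generation cell treats a temperature of 0 as a separate greedy branch.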
