# Building a Bigram model

I am going to build a simple bigram model (as an introduction before building a transformer), following Andrej Karpathy's tutorial `Let's build GPT: from scratch, in code, spelled out`.

The idea is to improve our knowledge of transformers and connect the theory to the actual code.

To train our model we will use a Shakespeare dataset. In theory, after training, the model will be able to generate sequences of text that resemble Shakespeare.

### Import packages

In [None]:
import torch
import torch.nn as nn
from torch.nn import functional as F

### Device agnostic code

In [None]:
device = "cuda" if torch.cuda.is_available() else "cpu"
device

'cpu'

### Download the Shakespeare Dataset

1. Download the `.txt` file from Karpathy's `char-rnn` repo.
2. Read it and store its contents in the `text` variable.

In [None]:
!wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt

--2025-12-03 12:16:05--  https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.110.133, 185.199.108.133, 185.199.109.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.110.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1115394 (1.1M) [text/plain]
Saving to: ‘input.txt’


2025-12-03 12:16:06 (22.9 MB/s) - ‘input.txt’ saved [1115394/1115394]



In [None]:
with open(file="/content/input.txt", mode="r", encoding="utf-8") as f:
  text = f.read()

print(f"Length of dataset in characters: {len(text)}\n")
print(f"First 100 characters of the dataset:\n{text[:100]}")

Length of dataset in characters: 1115394

First 100 characters of the dataset:
First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You


### Build vocabulary based on characters (Character-level)

These are the only characters the model can see or generate.

In [None]:
# Characters
characters = sorted(list(set(text)))
# Vocabulary size
vocab_size = len(characters)

print(f"Characters: {''.join(characters)}")
print(f"Vocab size: {vocab_size}")

Characters: 
 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz
Vocab size: 65


### Character level lookup tables

This is a simple tokenizer for our characters/vocab. We will have one lookup table to transform characters into tokens/ints and another for the reverse process (transforming tokens/ints back into characters/text).

In [None]:
# We create the lookup tables (character-based)
string_to_id = { char:idx for idx, char in enumerate(characters)}
id_to_string = { idx:char for idx, char in enumerate(characters)}

# Create the functions that look characters/ids up in the tables
encode = lambda sentence: [string_to_id[char] for char in sentence]
decode = lambda id_list: "".join(id_to_string[idx] for idx in id_list)

In [None]:
# Testing our tokenizer
print(encode("hii there"))
print(decode(encode("hii there")))

[46, 47, 47, 1, 58, 46, 43, 56, 43]
hii there
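
One property of this tokenizer is worth noting: the tables only cover characters that appear in the text, so encoding anything else fails. The sketch below rebuilds the same kind of tables on a toy string and shows both the round trip and the failure case. The `toy_` names and the sample string are illustrative assumptions, not part of the notebook.

```python
# Toy corpus standing in for the Shakespeare text (illustrative assumption)
toy_text = "hello there"
toy_chars = sorted(set(toy_text))

toy_stoi = {ch: i for i, ch in enumerate(toy_chars)}
toy_itos = {i: ch for i, ch in enumerate(toy_chars)}

def toy_encode(s):
    """Map each character to its integer id."""
    return [toy_stoi[ch] for ch in s]

def toy_decode(ids):
    """Map each integer id back to its character."""
    return "".join(toy_itos[i] for i in ids)

# Round trip: decoding the encoding returns the original string
round_trip = toy_decode(toy_encode("hello"))
print(round_trip)

# A character that never appeared in the corpus has no id
try:
    toy_encode("hello!")
    unknown_ok = True
except KeyError as err:
    unknown_ok = False
    print(f"Unknown character: {err}")
```

For the Shakespeare vocabulary this is rarely a problem in practice, but it does mean prompts must stick to the 65 characters listed above.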


### Tokenize the entire text

We are going to encode the whole dataset and transform it into a tensor.

In [None]:
data = torch.tensor(encode(text), dtype=torch.long)

print(f"Data tensor shape: {data.shape, data.dtype}")
print(f"First 100 encodes: {data[:100]}")

Data tensor shape: (torch.Size([1115394]), torch.int64)
First 100 encodes: tensor([18, 47, 56, 57, 58,  1, 15, 47, 58, 47, 64, 43, 52, 10,  0, 14, 43, 44,
        53, 56, 43,  1, 61, 43,  1, 54, 56, 53, 41, 43, 43, 42,  1, 39, 52, 63,
         1, 44, 59, 56, 58, 46, 43, 56,  6,  1, 46, 43, 39, 56,  1, 51, 43,  1,
        57, 54, 43, 39, 49,  8,  0,  0, 13, 50, 50, 10,  0, 31, 54, 43, 39, 49,
         6,  1, 57, 54, 43, 39, 49,  8,  0,  0, 18, 47, 56, 57, 58,  1, 15, 47,
        58, 47, 64, 43, 52, 10,  0, 37, 53, 59])


### Split data between training & validation set

This will allow us to train our model and later test it with data it hasn't seen yet. We are going to do it manually (90% for training).

In [None]:
# Get the index where we need to "cut"
n = int(0.9 * len(data))
# Training set
train_data = data[:n]
# Validation set
val_data = data[n:]

print(f"Training length: {len(train_data)}")
print(f"Validation length: {len(val_data)}")

Training length: 1003854
Validation length: 111540


### Context size (Block size)

We want to sample chunks of the text. The idea is that the model learns what should come next for every context length, from a single character up to `block_size` characters.

In [None]:
block_size = 8
train_data[:block_size + 1]

tensor([18, 47, 56, 57, 58,  1, 15, 47, 58])

In [None]:
x = train_data[:block_size]
y = train_data[1:block_size + 1]

for t in range(block_size):
  context = x[:t+1]
  target = y[t]

  print(f"When input is {context} the target: {target}")

When input is tensor([18]) the target: 47
When input is tensor([18, 47]) the target: 56
When input is tensor([18, 47, 56]) the target: 57
When input is tensor([18, 47, 56, 57]) the target: 58
When input is tensor([18, 47, 56, 57, 58]) the target: 1
When input is tensor([18, 47, 56, 57, 58,  1]) the target: 15
When input is tensor([18, 47, 56, 57, 58,  1, 15]) the target: 47
When input is tensor([18, 47, 56, 57, 58,  1, 15, 47]) the target: 58


Let's generalize this to also work with batches.

In [None]:
from typing import Literal
torch.manual_seed(1337)

batch_size = 4 # Amount of sequences processed in parallel
block_size = 8 # Max context

def get_batch(split: Literal["train", "validation"]) -> tuple[torch.Tensor, torch.Tensor]:
  """
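  Generate a small batch of inputs x and targets y.
  It first samples `batch_size` random start positions in the dataset, and then extracts a `block_size`-long chunk at each of them.

  Args:
    split ("train" | "validation"): Which data split to use
  Returns:
    x -> Tensor of shape (batch_size, block_size) with the inputs
    y -> Tensor of shape (batch_size, block_size) with the targets (x shifted by one position)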
  """
  data = train_data if split == "train" else val_data
  # Sample random parts
  ix = torch.randint(len(data) - block_size, (batch_size,))
  # Get the actual data for that random sample
  x = torch.stack([data[i:i+block_size] for i in ix])
  y = torch.stack([data[i+1:i+block_size+1] for i in ix])
  x, y = x.to(device), y.to(device)

  return x, y

In [None]:
xb, yb = get_batch('train')
print('inputs:')
print(xb.shape)
print(xb)
print('targets:')
print(yb.shape)
print(yb)

inputs:
torch.Size([4, 8])
tensor([[24, 43, 58,  5, 57,  1, 46, 43],
        [44, 53, 56,  1, 58, 46, 39, 58],
        [52, 58,  1, 58, 46, 39, 58,  1],
        [25, 17, 27, 10,  0, 21,  1, 54]])
targets:
torch.Size([4, 8])
tensor([[43, 58,  5, 57,  1, 46, 43, 39],
        [53, 56,  1, 58, 46, 39, 58,  1],
        [58,  1, 58, 46, 39, 58,  1, 46],
        [17, 27, 10,  0, 21,  1, 54, 39]])
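
Each row of the batch packs `block_size` training examples, exactly as in the single-chunk loop earlier. As a small sketch, the first two rows printed above are copied here as plain Python lists (the `xb_rows`, `yb_rows`, and `examples` names are only for illustration) and unrolled into individual (context, target) pairs.

```python
# First two rows of the inputs and targets printed above, copied as plain lists
xb_rows = [[24, 43, 58, 5, 57, 1, 46, 43],
           [44, 53, 56, 1, 58, 46, 39, 58]]
yb_rows = [[43, 58, 5, 57, 1, 46, 43, 39],
           [53, 56, 1, 58, 46, 39, 58, 1]]

examples = []
for b in range(len(xb_rows)):          # batch dimension
    for t in range(len(xb_rows[b])):   # time dimension
        context = xb_rows[b][:t + 1]   # everything up to and including position t
        target = yb_rows[b][t]         # the character that follows that context
        examples.append((context, target))

print(len(examples))   # 2 rows * 8 positions = 16
print(examples[0])     # ([24], 43)
```

With the full batch of 4 rows, the same unrolling gives 32 examples, all processed in parallel.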


# Simple model (Bigram)

We are going to implement a simpler model first and then move on to the more complex one, starting with a `Bigram Language Model`. A bigram model only needs the last token to predict the next one, but `generate` still carries the whole sequence because the transformer we build later will use all of it. Keeping that interface now helps us understand how the data flows and how the pieces fit together.
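
To make the "lookup table" idea concrete, here is a minimal sketch with a hypothetical 5-token vocabulary (`toy_vocab_size` and `table` are illustrative names). It checks two things: the logits for a position are simply the row of the embedding weight selected by that token id, and the same token produces the same logits no matter where it sits in the sequence.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

toy_vocab_size = 5  # hypothetical small vocabulary, for illustration only
table = nn.Embedding(toy_vocab_size, toy_vocab_size)

idx = torch.tensor([[0, 3, 3]])  # (B=1, T=3): token 3 appears twice
logits = table(idx)              # (1, 3, 5)

# The logits at a position are the weight row selected by that token id
same_as_rows = torch.equal(logits[0, 1], table.weight[3])
# The same token gives the same logits regardless of its position
position_independent = torch.equal(logits[0, 1], logits[0, 2])

print(logits.shape, same_as_rows, position_independent)
```

This is why the model is a bigram model: the prediction depends on the current token alone.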

In [None]:
torch.manual_seed(1337)

class BigramLanguageModel(nn.Module):
  """
  Predicts the next token only using the current token
  """
  def __init__(self, vocab_size: int):
    """
    Initializes the Bigram class
    """
    super().__init__()
    # Each token directly reads off the logits for the next token from a lookup table
    self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)

  def forward(self, idx, targets=None):
    """
    Forward method for the BigramLanguageModel
    """
    # (B,T,C) -> (Batch_size, Sequence length, vocab_size)
    logits = self.token_embedding_table(idx)

    if targets is None:
      loss = None
    else:
      # F.cross_entropy expects the class dimension second, so flatten (B,T,C) into (B*T, C)
      B, T, C = logits.shape
      # Stretch the tensor to make it 2-dimensional
      logits = logits.view(B*T, C)
      # We have to do the same for the target
      targets = targets.view(B*T)
      # Calculate loss
      loss = F.cross_entropy(logits, targets)

    return logits, loss

  def generate(self, idx, max_new_tokens):
    """
    Generate predictions

    Args:
      idx: (B, T) tensor of indices in the current context
      max_new_tokens (int): Max amount of tokens to generate
    """
    for _ in range(max_new_tokens):
      # Get the prediction
      logits, loss = self(idx)
      # We only want to focus on the last time step
      logits = logits[:, -1, :] # (B, C)
      # Apply softmax to the logits -> get probabilities
      probs = F.softmax(logits, dim=-1) # (B, C)
      # Sample from the distribution
      idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)
      # Append sampled index to current sequence
      idx = torch.cat([idx, idx_next], dim=1) # (B, T+1)

    return idx
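
The reshape in `forward` exists because `F.cross_entropy` wants the class dimension in position 1. As a hedged sketch on made-up tensors (`toy_logits` and `toy_targets` are assumptions for illustration), the flattening used in the notebook and the alternative `(B, C, T)` layout give the same loss.

```python
import torch
from torch.nn import functional as F

torch.manual_seed(0)

B, T, C = 4, 8, 65
toy_logits = torch.randn(B, T, C)        # stand-in for the model output
toy_targets = torch.randint(C, (B, T))   # stand-in for the target ids

# Option used in the notebook: merge batch and time into a single axis
loss_flat = F.cross_entropy(toy_logits.view(B * T, C), toy_targets.view(B * T))

# Alternative: move the class axis to position 1, giving (B, C, T)
loss_bct = F.cross_entropy(toy_logits.transpose(1, 2), toy_targets)

print(torch.allclose(loss_flat, loss_bct))
```

Flattening has the side effect, visible in the next cell, that the returned logits have shape `(B*T, C)` whenever targets are passed.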

In [None]:
m = BigramLanguageModel(vocab_size).to(device)
logits, loss = m(xb, yb)

print(f"Example: {logits.shape}")
print(f"Loss: {loss}")

Example: torch.Size([32, 65])
Loss: 4.878634929656982
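
A quick sanity check on this number: a model that spreads probability evenly over all 65 characters would have a cross-entropy of `-ln(1/65)`. The sketch below computes that baseline. The observed loss is higher, which is plausible for a randomly initialized table, since its logits are not uniform and can be confidently wrong.

```python
import math

vocab_size = 65  # from the vocabulary built earlier

# Cross-entropy of a uniform guess over the vocabulary
uniform_loss = -math.log(1.0 / vocab_size)
print(round(uniform_loss, 4))  # 4.1744
```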


The text generated in the next cell is essentially random characters, because the model has not been trained yet. We need to train it before it starts producing anything that resembles real text.

In [None]:
# Example
idx = torch.zeros((1, 1), dtype=torch.long, device=device) # Where we kick off the generation
generated_list = m.generate(idx, max_new_tokens=100)[0].tolist()

print(f"Example: {decode(generated_list)}")

Example: 
SKIcLT;AcELMoTbvZv C?nq-QE33:CJqkOKH-q;:la!oiywkHjgChzbQ?u!3bLIgwevmyFJGUGp
wnYWmnxKWWev-tDqXErVKLgJ
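
The sampling step inside `generate` can be looked at in isolation. The sketch below uses made-up logits for a 3-token vocabulary (an assumption for illustration) and contrasts `torch.multinomial`, which the notebook uses, with a greedy `argmax`, which the notebook does not use and is shown only for comparison.

```python
import torch
from torch.nn import functional as F

torch.manual_seed(0)

toy_logits = torch.tensor([[2.0, 0.0, -2.0]])  # (B=1, C=3), made-up scores

probs = F.softmax(toy_logits, dim=-1)               # probabilities that sum to 1
sampled = torch.multinomial(probs, num_samples=1)   # random draw weighted by probs, (B, 1)
greedy = probs.argmax(dim=-1, keepdim=True)         # always the most likely id, (B, 1)

print(probs)
print(sampled.shape, greedy)
```

Sampling keeps the output varied: running generation twice from the same start can give different text, whereas the greedy choice would repeat itself.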


### Let's train our Bigram Model

In [None]:
# AdamW optimizer
optimizer = torch.optim.AdamW(params=m.parameters(),
                             lr=1e-3)

### Training loop

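We will train our model for 10,000 steps. The code calls them epochs, but each iteration uses one randomly sampled batch rather than a full pass over the dataset.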

In [None]:
batch_size = 32
epochs = 10_000

for epoch in range(epochs):
  # Get a batch of data
  xb, yb = get_batch(split="train")
  # 1. Forward pass + Loss calculation
  logits, loss = m(xb, yb)
  # 2. Reset the gradients from the previous step
  optimizer.zero_grad(set_to_none=True)
  # 3. Back propagation
  loss.backward()
  # 4. Optimizer step
  optimizer.step()

  if epoch % 1000 == 0:
    print(f"Epoch {epoch}: {loss.item()} loss")

Epoch 0: 4.692410945892334 loss
Epoch 1000: 3.7637593746185303 loss
Epoch 2000: 3.2342257499694824 loss
Epoch 3000: 2.892245292663574 loss
Epoch 4000: 2.703908681869507 loss
Epoch 5000: 2.515348196029663 loss
Epoch 6000: 2.4889943599700928 loss
Epoch 7000: 2.514069080352783 loss
Epoch 8000: 2.444497585296631 loss
Epoch 9000: 2.3975775241851807 loss
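
Each loss printed above comes from a single random batch, so the values are noisy (the loss at step 7000 is slightly higher than at step 6000). One common remedy is to average the loss over several batches for both splits, which also puts the validation set to use. The sketch below is a minimal version of that idea on toy random data. The names `splits`, `get_toy_batch`, `table`, and `estimate_loss` are assumptions for illustration and are not defined in the notebook.

```python
import torch
import torch.nn as nn
from torch.nn import functional as F

torch.manual_seed(0)

# Toy stand-ins (assumptions): random token ids instead of the Shakespeare tensors
toy_vocab_size, block_size, batch_size = 10, 8, 4
splits = {
    "train": torch.randint(toy_vocab_size, (500,)),
    "validation": torch.randint(toy_vocab_size, (100,)),
}

def get_toy_batch(split):
    """Sample a random batch of inputs and shifted targets from one split."""
    data = splits[split]
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i:i + block_size] for i in ix])
    y = torch.stack([data[i + 1:i + block_size + 1] for i in ix])
    return x, y

# Minimal bigram "model": one row of logits per token
table = nn.Embedding(toy_vocab_size, toy_vocab_size)

@torch.no_grad()  # evaluation only, no gradients needed
def estimate_loss(eval_iters=20):
    """Average the loss over several random batches for each split."""
    averages = {}
    for split in splits:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            x, y = get_toy_batch(split)
            logits = table(x)  # (B, T, C)
            losses[k] = F.cross_entropy(logits.view(-1, toy_vocab_size), y.view(-1))
        averages[split] = losses.mean().item()
    return averages

loss_report = estimate_loss()
print(loss_report)
```

In the notebook, the same pattern could call `get_batch` and `m` instead of the toy stand-ins, for example every 1000 steps inside the training loop.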


We can say there was an improvement, though not a large one. At least the loss is going down. Now we need better context: the tokens should be able to see what came before them.

Remember that this bigram is an extremely simple model. With the transformer model we will see much better results.
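
One way to read the loss values is to convert them to perplexity, `exp(loss)`, which can be interpreted loosely as the number of equally likely characters the model is choosing between. The sketch below applies this to the first and last losses from the training log, taking the numbers straight from that output. It goes from roughly 109 down to roughly 11.

```python
import math

initial_loss = 4.692410945892334   # step 0, from the training log
final_loss = 2.3975775241851807    # step 9000, from the training log

# Perplexity is the exponential of the cross-entropy loss
initial_perplexity = math.exp(initial_loss)
final_perplexity = math.exp(final_loss)

print(f"{initial_perplexity:.1f} -> {final_perplexity:.1f}")
```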

In [None]:
# Example with trained model
idx = torch.zeros((1, 1), dtype=torch.long, device=device) # Where we kick off the generation
generated_list = m.generate(idx, max_new_tokens=100)[0].tolist()

print(f"Example: {decode(generated_list)}")

Example: 
lso br. ave aviasurf my, yxMPZI ivee iuedrd whar ksth y h bora s be hese, woweee; the! KI 'de, ulsee


*Note*: There is some additional insight in the `matrix_example_attention.ipynb` colab.

# Bigram Model V2

Building toward our Transformer model.

In [None]:
torch.manual_seed(1337)
NUM_EMBEDDINGS = 32

class BigramLanguageModel(nn.Module):
  """
  Predicts the next token using the current token and its position
  """
  def __init__(self, vocab_size: int):
    """
    Initializes the Bigram class
    """
    super().__init__()
    # Tokens and positions are mapped to NUM_EMBEDDINGS-dimensional vectors; lm_head turns them into logits
    self.token_embedding_table = nn.Embedding(vocab_size, NUM_EMBEDDINGS)
    self.position_embedding_table = nn.Embedding(block_size, NUM_EMBEDDINGS)
    self.lm_head = nn.Linear(NUM_EMBEDDINGS, vocab_size)

  def forward(self, idx, targets=None):
    """
    Forward method for the BigramLanguageModel
    """
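    # idx is (B, T); T is needed below to build the position indices
    B, T = idx.shape
    # Token embedding: (B,T,C) -> (Batch_size, Sequence length, NUM_EMBEDDINGS)
    token_embedding = self.token_embedding_table(idx)
    # Positional embedding: one vector per position 0..T-1
    pos_embedding = self.position_embedding_table(torch.arange(T, device=idx.device)) # (T, C)

    # Hold the token identity + the position where it occurs; (T, C) broadcasts over the batch
    x = token_embedding + pos_embedding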

    logits = self.lm_head(x) # (B,T,vocab_size)

    if targets is None:
      loss = None
    else:
      # F.cross_entropy expects the class dimension second, so flatten (B,T,C) into (B*T, C)
      B, T, C = logits.shape
      # Stretch the tensor to make it 2-dimensional
      logits = logits.view(B*T, C)
      # We have to do the same for the target
      targets = targets.view(B*T)
      # Calculate loss
      loss = F.cross_entropy(logits, targets)

    return logits, loss

  def generate(self, idx, max_new_tokens):
    """
    Generate predictions

    Args:
      idx: (B, T) tensor of indices in the current context
      max_new_tokens (int): Max amount of tokens to generate
    """
    for _ in range(max_new_tokens):
      # crop idx to the last block_size tokens
      idx_cond = idx[:, -block_size:]
      # Get the prediction
      logits, loss = self(idx_cond)
      # We only want to focus on the last time step
      logits = logits[:, -1, :] # (B, C)
      # Apply softmax to the logits -> get probabilities
      probs = F.softmax(logits, dim=-1) # (B, C)
      # Sample from the distribution
      idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)
      # Append sampled index to current sequence
      idx = torch.cat([idx, idx_next], dim=1) # (B, T+1)

    return idx
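
Two details of this version are easy to miss: the sum of token and position embeddings relies on broadcasting, and `generate` crops the context because the position table only has `block_size` rows. The sketch below illustrates both with small made-up sizes (all `toy_` names and the dimensions are assumptions for illustration). The out-of-range lookup is run on CPU, where it raises an `IndexError`.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

B, T, C = 2, 4, 3                       # small made-up sizes
toy_vocab_size, toy_block_size = 10, 8

tok_table = nn.Embedding(toy_vocab_size, C)
pos_table = nn.Embedding(toy_block_size, C)

idx = torch.randint(toy_vocab_size, (B, T))
tok = tok_table(idx)                    # (B, T, C)
pos = pos_table(torch.arange(T))        # (T, C)
x = tok + pos                           # (T, C) is broadcast across the batch -> (B, T, C)

# Every sequence in the batch receives the same positional vectors
same_positions = torch.allclose(x[0] - tok[0], pos, atol=1e-6) and \
                 torch.allclose(x[1] - tok[1], pos, atol=1e-6)
print(x.shape, same_positions)

# Positions beyond the table size cannot be looked up, hence the crop in generate
try:
    pos_table(torch.arange(toy_block_size + 1))
    too_long_ok = True
except IndexError:
    too_long_ok = False
print(too_long_ok)
```

Cropping with `idx[:, -block_size:]` keeps the generated sequence free to grow while the model only ever sees the most recent `block_size` tokens.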