<a href="https://colab.research.google.com/github/DiiGii/gpt2-scratch/blob/main/gpt2_bigram_models.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Bigram Models

**Bigram models** predict the probability of a word based solely on the preceding word. They analyze text by counting the occurrences of word pairs (bigrams) and use these counts to estimate the likelihood of one word following another. This simple approach captures some local context but doesn't account for longer-range dependencies in language. You'll implement a bigram model in the notebook below, and see how it performs on a short piece of input text.

## Character Patterns: Understanding the inner mechanisms of transformer models

Before we move on to training large language models, let's talk about the history of NLP (Natural Language Processing).

Historically, natural language processing involved many of the steps we use in LLM training today:
1. Tokenization & parsing: breaking down sentences into tokens and building a parse tree (some may be familiar with ASTs).
2. Building models that predict next tokens based on explicit patterns in the text.

One such common pattern is the bigram. After tokenization, our training data may look something like this:
["Hello", "!", "How", "are", "you", "?"]. A simple pattern we can use is to look at every token (e.g. "Hello") and learn the statistical distribution of the tokens that tend to come next (here, "!").
While bigram models are extremely simple and don't do any high-level reasoning, they have a few properties that are preserved in LLMs:
1. Bigram models are context aware (although their "context window" is only one token). Past techniques like bag of words were not.
2. Bigram models and other n-gram statistical models are actually **learned by transformers as circuits / circuit components**. That is, within a modern LLM, there tend to be a few copies of circuits that are responsible for modelling base token frequencies, which both directly influence the output distribution as well as provide information to circuits deeper in the network.
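
To make the counting idea concrete, here is a minimal sketch of a purely statistical bigram model, with no neural network involved. The function name `bigram_probabilities` and the toy token list are made up for illustration; they are not part of this notebook's code.

```python
from collections import Counter, defaultdict

def bigram_probabilities(tokens):
    """Estimate P(next | current) from raw bigram counts."""
    counts = defaultdict(Counter)
    # Walk over every adjacent pair (current token, next token)
    for current, nxt in zip(tokens, tokens[1:]):
        counts[current][nxt] += 1
    # Normalize each row of counts into a probability distribution
    return {
        current: {nxt: n / sum(followers.values()) for nxt, n in followers.items()}
        for current, followers in counts.items()
    }

# Hypothetical toy corpus, already tokenized
toy_tokens = ["Hello", "!", "How", "are", "you", "?", "Hello", "you", "!"]
probs = bigram_probabilities(toy_tokens)
print(probs["Hello"])  # {'!': 0.5, 'you': 0.5}
print(probs["are"])    # {'you': 1.0}
```

The neural version built later in this notebook learns roughly the same table, but stores it in weights instead of a dictionary of counts.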

## Steps:
Recall our embedding process:
1. Tokenize the input text
2. Convert tokens to one-hot vectors
3. Project these vectors into a continuous embedding space

With bigrams, we add another step:

4. Use the embedding of the current token to predict the next token

### Components of a Bigram Model

1. **Embedding Layer**: As before, this transforms our discrete tokens into continuous vectors.
2. **Prediction Layer**: A new component that takes the current token's embedding and outputs probabilities for the next token.
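
As a rough sketch of how these two components fit together, the snippet below pushes a few one-hot tokens through a bias-free linear projection and then a prediction layer. The sizes (`vocab_size = 6`, `embedding_dim = 4`) and the variable names are assumptions chosen for illustration. Note that the prediction layer itself outputs raw scores (logits); a softmax turns them into probabilities.

```python
import torch
from torch import nn
from torch.nn import functional as F

vocab_size, embedding_dim = 6, 4  # hypothetical toy sizes
torch.manual_seed(0)

projection = nn.Linear(vocab_size, embedding_dim, bias=False)  # embedding layer
prediction = nn.Linear(embedding_dim, vocab_size)              # prediction layer

indices = torch.tensor([0, 3, 5])  # three token ids
one_hot = F.one_hot(indices, num_classes=vocab_size).float()

embedded = projection(one_hot)  # shape (3, embedding_dim)
# A one-hot input simply selects one column of the weight matrix per token
assert torch.allclose(embedded, projection.weight[:, indices].T)

logits = prediction(embedded)      # shape (3, vocab_size)
probs = F.softmax(logits, dim=-1)  # each row sums to 1
print(embedded.shape, probs.shape, probs.sum(dim=-1))
```

Because a one-hot vector only ever picks out one column, the projection behaves like a lookup table of token embeddings.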

### From last session (don't edit, just run):

In [None]:
# Run this block
import torch
from typing import List
from torch.nn import functional as F
from torch import nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import requests

vocab = """abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ., '\""()[]!?"""

def tokenize(text: str) -> List[str]:
    return [char for char in text if char in vocab]

char_to_index = {char: idx for idx, char in enumerate(vocab)}
index_to_char = {idx: char for char, idx in char_to_index.items()}

def vectorize(tokens: List[str]) -> torch.Tensor:
    indices = torch.tensor([char_to_index[char] for char in tokens])
    return F.one_hot(indices, num_classes=len(vocab)).float()

def detokenize(tensor: torch.Tensor):
    indices = tensor.argmax(dim=-1).tolist()
    return ''.join(index_to_char[idx] for idx in indices)

class EmbeddingProjection(nn.Module):
    def __init__(self, vocab_size: int, embedding_dim: int):
        super().__init__()
        self.projection = nn.Linear(vocab_size, embedding_dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.projection(x)

class TextDataset(Dataset):
    def __init__(self, text, seq_length):
        self.text = text
        self.seq_length = seq_length
        self.tokens = tokenize(text)

    def __len__(self):
        return len(self.tokens) - self.seq_length

    def __getitem__(self, idx):
        input_seq = self.tokens[idx:idx+self.seq_length]
        target_seq = self.tokens[idx+1:idx+self.seq_length+1]
        return vectorize(input_seq).squeeze(), vectorize(target_seq).squeeze()

# Download toy data (Shakespeare sonnets)
url = "https://www.gutenberg.org/files/1041/1041-0.txt"
response = requests.get(url)
text = response.text.split("THE SONNETS", 1)[1].split("End of the Project Gutenberg EBook", 1)[0]

# Prepare the dataset
seq_length = 1
dataset = TextDataset(text, seq_length)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)


To get a sense of what your input data looks like:

In [None]:
batch, target = next(iter(dataloader))
# batch is the input tensor to your model, shape (batch_size, vocab_size)
# It's the vector representation of the single token your bigram model has as context.
# target is the target tensor, shape (batch_size, vocab_size), representing the next token in the sequence (which your model is tasked with predicting).
print(batch.shape, target.shape)

torch.Size([32, 64]) torch.Size([32, 64])


In [None]:
detokenized_targets = detokenize(target)
for index, item in enumerate(detokenize(batch[:6])):
    print(f"Context: {item}, Target: {detokenized_targets[index]}")

# Seems like a tough task, eh?

Context: u, Target: l
Context: I, Target:  
Context: e, Target: p
Context: t, Target: h
Context: t, Target: h
Context: s, Target: e


In [None]:
# Exercise 1:
# Implement a multilayer linear model. Feel free to use nn.Linear and nn.ReLU.

# Your projection layer is a linear projection from vocab size -> model size. Make sure your intermediate linear layers are projections from model size -> model size,
# and your final layer is a projection from model size -> vocab size.
class BigramModel(nn.Module):
    def __init__(self, model_dim = 128, vocab_size = len(vocab)):
        super().__init__()
        self.projection = EmbeddingProjection(vocab_size, model_dim)
        self.linear1 = nn.Linear(model_dim, model_dim)
        self.linear2 = nn.Linear(model_dim, vocab_size)
        self.relu = nn.ReLU()
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.projection(x)
        x = self.linear1(x)
        x = self.relu(x)
        x = self.linear2(x)
        return x

def test_bigram_model():
    model = BigramModel()
    out = model(batch)
    assert out.shape == target.shape, f"Expected output shape {target.shape} but got {out.shape}"
    print("Success!")

test_bigram_model()

Success!


### Write your own training loop:
Remember the elements of a training loop:
1. Send your training data to the device (both data and targets).
2. Use your model to predict an output based on the data.
3. Call your loss criterion on the output and target to get your loss.
4. Backpropagate on the loss (`loss.backward()` and `optimizer.step()`).
5. Zero your gradients.
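
To get a feel for what the loss in step 3 measures, here is a small pure-Python sketch of cross-entropy for a single prediction: the negative log of the probability that the softmax assigns to the correct next token. The helper `cross_entropy` is written for illustration and is not the PyTorch implementation. The vocabulary size of 64 matches the tensor shapes printed earlier.

```python
import math

def cross_entropy(logits, target_index):
    """Negative log-probability of the target class under softmax(logits)."""
    m = max(logits)  # subtract the max for numerical stability
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[target_index]

vocab_size = 64

# A model that guesses uniformly over all 64 characters
uniform_logits = [0.0] * vocab_size
baseline = cross_entropy(uniform_logits, target_index=3)
print(round(baseline, 3))  # 4.159, i.e. ln(64)

# A model that strongly favours the correct character scores a lower loss
confident_logits = [0.0] * vocab_size
confident_logits[3] = 5.0
print(cross_entropy(confident_logits, 3) < baseline)  # True
```

A useful sanity check while training: a loss near ln(64) ≈ 4.16 means the model is doing no better than random guessing, so anything well below that indicates it has picked up some character statistics.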

In [None]:
print(torch.__version__)

2.5.0+cu121


In [None]:
from tqdm import tqdm, trange

# Training loop
num_epochs = 2
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Initialize model and transfer it to the device
model = BigramModel().to(device)
# Initialize optimizer (from torch.optim). We recommend using AdamW with the default parameters.
optimizer = optim.AdamW(model.parameters())
# Initialize the loss criterion (from torch.nn). Since this is basically a classification task (we decide which character comes next), we recommend using nn.CrossEntropyLoss.
criterion = nn.CrossEntropyLoss()

model.train()
loss_ema = None
for epoch in range(num_epochs):
    with tqdm(dataloader) as pbar:
        for batch, target in pbar:
            # pass
            # Training loop
            # ------------------
            batch, target = batch.to(device), target.to(device)
            output = model(batch)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            # ------------------
            if loss_ema is None:
                loss_ema = loss.item()
            else:
                loss_ema = 0.95 * loss_ema + 0.05 * loss.item()
            # Show the smoothed loss (EMA), which is less noisy than the per-batch loss
            pbar.set_description(f"Loss: {loss_ema:.3f}")

Loss: 2.368: 100%|██████████| 2888/2888 [00:32<00:00, 88.58it/s] 
Loss: 2.422: 100%|██████████| 2888/2888 [00:34<00:00, 84.65it/s]


In [None]:
# Generate some text
model.eval()
start_text = "Shall I compare thee to a summer's day?"
# vectorize already returns a tensor, so no torch.tensor(...) wrapper is needed
input_seq = vectorize(tokenize(start_text)).unsqueeze(0).to(device)[:, -1, :]
generated_text = start_text

with torch.no_grad():
    for _ in range(100):
        output = model(input_seq)
        next_char = output.argmax(dim=-1)
        generated_text += index_to_char[next_char.item()]
        input_seq = F.one_hot(next_char, num_classes=len(vocab)).float()

print("Generated text:")
print(generated_text)

Generated text:
Shall I compare thee to a summer's day? the the the the the the the the the the the the the the the the the the the the the the the the the



Yikes. Well, we see why NLP performed so poorly in the early days. If your model was anything like mine, it probably fell into a failure mode like mode collapse (where it just learns to generate the most frequent word). This is reminiscent of the failure modes of the early GPT models, which eased as models grew in scale and complexity. See you all next week!
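
One footnote on the repeated output above: part of the repetition comes from how the text is generated, not only from the model. The generation loop takes the `argmax` at every step, and the model sees only one character of context, so the next character is a fixed function of the current one. With a finite vocabulary that sequence must eventually loop (here: space, "t", "h", "e", space, and so on). A common alternative is to sample from the predicted distribution instead. The sketch below assumes a hypothetical helper `sample_next` and a made-up `temperature` parameter; neither is part of the notebook's original code.

```python
import torch
from torch.nn import functional as F

def sample_next(logits: torch.Tensor, temperature: float = 1.0) -> torch.Tensor:
    """Sample one token index per row from softmax(logits / temperature)."""
    probs = F.softmax(logits / temperature, dim=-1)
    return torch.multinomial(probs, num_samples=1).squeeze(-1)

torch.manual_seed(0)
# Hypothetical logits for a 4-token vocabulary; index 2 is the most likely token
logits = torch.tensor([[1.0, 0.5, 2.0, 0.1]])

greedy = [logits.argmax(dim=-1).item() for _ in range(5)]
sampled = [sample_next(logits).item() for _ in range(5)]
print(greedy)   # always [2, 2, 2, 2, 2]
print(sampled)  # a mix of indices; exact values depend on the seed
```

In the generation cell, swapping `output.argmax(dim=-1)` for `sample_next(output)` should be shape-compatible, since both return one index per row. The resulting text will still be gibberish at the word level, but it should no longer lock into a single repeating cycle.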