# Bigram Character Language Model (PyTorch)

This notebook teaches you how a **bigram character language model** works: it predicts the next character using only the current character. It is a tiny but complete language model that helps you understand the core ideas behind LLMs (tokenization, next-token prediction, cross-entropy loss, and sampling) without the complexity of attention or deep networks.
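In symbols, a bigram model makes a first-order Markov assumption: each character depends only on the one before it. The notation below is introduced here for illustration and does not appear in the code. `W` stands for the `(vocab_size, vocab_size)` table of logits, and `N` is the number of (current, next) character pairs.

$$
P(x_1, \dots, x_T) \approx P(x_1) \prod_{t=2}^{T} P(x_t \mid x_{t-1}),
\qquad
P(x_t = j \mid x_{t-1} = i) = \frac{\exp(W_{ij})}{\sum_{k} \exp(W_{ik})}
$$

$$
\mathcal{L} = -\frac{1}{N} \sum_{t} \log P(x_t \mid x_{t-1})
$$

Training minimizes $\mathcal{L}$, which is the cross-entropy loss computed later in the notebook. The first-character term $P(x_1)$ is not modeled here. Generation simply starts from a fixed token id.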

## 1) Setup + imports
We will use PyTorch for modeling and the Hugging Face `datasets` library for loading text. We'll also set global configuration values to keep training reproducible and easy to tweak.

In [None]:
import math
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from datasets import load_dataset

# ---- Global config ----
SEED = 42  # random seed
BATCH_SIZE = 32  # batch size per step
BLOCK_SIZE = 64  # context length
EMBED_DIM = None  # None: direct (vocab_size x vocab_size) logit table; an int such as 64: embedding -> linear
LR = 1e-2  # learning rate
MAX_STEPS = 800  # max training steps
EVAL_EVERY = 200  # steps between evals
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'  # device string for torch

torch.manual_seed(SEED)
random.seed(SEED)

print('Device:', DEVICE)

## 2) Load dataset
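We load the `wikitext-2-raw-v1` configuration of the `wikitext` dataset from the Hugging Face Hub. `load_dataset` downloads it on first use, so network access is required the first time. Later runs read it from the local cache.
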
In [None]:
dataset_name = 'wikitext-2-raw-v1'  # dataset configuration name
ds = load_dataset('wikitext', dataset_name)  # DatasetDict of splits
print('Loaded dataset:', dataset_name)

# load_dataset returns a DatasetDict keyed by split name;
# wikitext-2 has 'train', 'validation', and 'test' splits, each with a 'text' column.
text = ''  # raw text
if 'train' in ds:
    text = '\n'.join(ds['train']['text'])
else:
    # fallback if dataset has only one split
    first_split = list(ds.keys())[0]
    text = '\n'.join(ds[first_split]['text'])

# Filter to ASCII to avoid a huge or noisy character set
text = ''.join(ch for ch in text if ch.isascii())  # ASCII-only text
print('Text length:', len(text))

## 3) Build character vocabulary
We build a vocabulary of all unique characters. Each character gets an integer ID. This is the simplest possible tokenization.
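To see what the vocabulary and the two lookup maps contain, here is the same recipe applied to a short hypothetical string instead of the dataset. IDs follow sorted code-point order, so the space character (code point 32) comes before any letter.

```python
# Build a character vocabulary from a toy string (hypothetical input, not the dataset).
toy_text = "hello world"

chars = sorted(set(toy_text))                 # unique characters, sorted by code point
stoi = {ch: i for i, ch in enumerate(chars)}  # character -> integer id
itos = {i: ch for ch, i in stoi.items()}      # integer id -> character

print(chars)           # [' ', 'd', 'e', 'h', 'l', 'o', 'r', 'w']
print(stoi["l"])       # 4
print(repr(itos[0]))   # ' '
```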

In [None]:
chars = sorted(list(set(text)))  # unique characters
vocab_size = len(chars)  # vocabulary size
stoi = {ch: i for i, ch in enumerate(chars)}  # string-to-id map
itos = {i: ch for ch, i in stoi.items()}  # id-to-string map

print('Vocab size:', vocab_size)
print('First 10 chars:', chars[:10])

## 4) Encode/decode helpers
`encode` turns a string into a list of integers. `decode` reverses that mapping.
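The helpers assume every character already has an ID. In the notebook this holds because the vocabulary is built from the same ASCII-filtered text that gets encoded. New text containing an unseen character would raise a `KeyError`. The sketch below uses a hypothetical toy vocabulary to show both the round trip and that failure case.

```python
# Toy vocabulary (hypothetical) to illustrate encode/decode behavior.
chars = sorted(set("hello world"))
stoi = {ch: i for i, ch in enumerate(chars)}
itos = {i: ch for ch, i in stoi.items()}

def encode(s):
    return [stoi[c] for c in s]

def decode(ids):
    return ''.join(itos[i] for i in ids)

ids = encode("hello")
print(ids)           # [3, 2, 4, 4, 5]
print(decode(ids))   # hello

# 'p' never appeared when the vocabulary was built, so it has no id.
try:
    encode("help")
except KeyError as err:
    print("unknown character:", err)
```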

In [None]:
def encode(s):
    return [stoi[c] for c in s]

def decode(ids):
    return ''.join(itos[i] for i in ids)

# Quick sanity check
sample = text[:100]  # sample text
print('Sample text:', repr(sample[:80]))
encoded = encode(sample)  # encoded token ids
print('Encoded length:', len(encoded))
print('Round-trip:', decode(encoded[:80]))

## 5) Train/val split
We create a 90/10 split for quick evaluation.

In [None]:
data = torch.tensor(encode(text), dtype=torch.long)  # whole corpus as a 1-D tensor of token ids
n = int(0.9 * len(data))  # split index: first 90% train, last 10% validation
train_data = data[:n]  # training split data
val_data = data[n:]  # validation split data

print('Train tokens:', train_data.numel())
print('Val tokens:', val_data.numel())

## 6) Batch sampler
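We sample `BATCH_SIZE` random chunks of `BLOCK_SIZE` consecutive characters. The input `x` is the chunk itself. The target `y` is the same chunk shifted one position to the right, so `y[b, t]` is the character that follows `x[b, t]`. A bigram model looks only at `x[b, t]` when predicting `y[b, t]`, so every position in the block is an independent training example.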
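Using the integers `0..19` as a hypothetical corpus makes the one-step shift easy to see: every target is exactly its input plus one. The slicing mirrors `get_batch`, with small made-up sizes in place of `BLOCK_SIZE` and `BATCH_SIZE`.

```python
import torch

torch.manual_seed(0)

# Toy "corpus" of token ids 0..19 so the shift is easy to see (hypothetical data).
data = torch.arange(20)
block_size, batch_size = 4, 3

# Same slicing scheme as get_batch: y is x shifted one position to the right.
ix = torch.randint(len(data) - block_size - 1, (batch_size,))
x = torch.stack([data[i:i + block_size] for i in ix])
y = torch.stack([data[i + 1:i + block_size + 1] for i in ix])

print(x)
print(y)
# Each position t gives one training pair for the bigram model: x[b, t] -> y[b, t].
```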

In [None]:
def get_batch(split):
    data_split = train_data if split == 'train' else val_data
    # Random starting indices
    ix = torch.randint(len(data_split) - BLOCK_SIZE - 1, (BATCH_SIZE,))
    x = torch.stack([data_split[i:i+BLOCK_SIZE] for i in ix])
    y = torch.stack([data_split[i+1:i+BLOCK_SIZE+1] for i in ix])
    return x.to(DEVICE), y.to(DEVICE)

xb, yb = get_batch('train')
print('xb shape:', xb.shape)
print('yb shape:', yb.shape)

## 7) Bigram model definition
A **bigram model** predicts the next character based solely on the current character. The simplest version is just a lookup table of logits of size `(vocab_size, vocab_size)`.
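A lookup-table bigram model also has a closed-form answer that the notebook does not compute. The maximum-likelihood table is just the normalized counts of each (current, next) pair. Gradient training of the `nn.Embedding` table should move toward the log of these probabilities. The sketch below uses a hypothetical toy string and adds add-one (Laplace) smoothing, which is an extra assumption here, so unseen pairs keep a nonzero probability.

```python
import torch

# Hypothetical toy corpus; the notebook would use the full ASCII-filtered text instead.
text = "abracadabra"
chars = sorted(set(text))
stoi = {ch: i for i, ch in enumerate(chars)}
V = len(chars)
ids = torch.tensor([stoi[c] for c in text])

# counts[i, j] = number of times character j follows character i.
counts = torch.zeros((V, V))
counts.index_put_((ids[:-1], ids[1:]), torch.ones(len(ids) - 1), accumulate=True)

# Add-one smoothing, then normalize each row into a next-character distribution.
probs = (counts + 1) / (counts + 1).sum(dim=1, keepdim=True)

# A matching logit table: log-probabilities (softmax of each row recovers probs).
logit_table = probs.log()

# Average negative log-likelihood of the observed bigrams under this table.
nll = -logit_table[ids[:-1], ids[1:]].mean()
print(counts)
print(f"bigram NLL: {nll.item():.4f}")
```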

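Optional: we can instead use an embedding table of size `(vocab_size, embed_dim)` followed by a linear layer that produces `vocab_size` logits. When `embed_dim` is at least `vocab_size`, this can represent any bigram table, just like the direct lookup. With a smaller `embed_dim`, the full logit table is constrained to a low-rank factorization plus a bias. Characters with similar next-character distributions must then share parameters, which can act as mild regularization.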
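Running the embedding-plus-linear variant on every possible current token recovers its full `(vocab_size, vocab_size)` logit table. That table equals `E @ W.T + b`, so its rank is at most `embed_dim + 1`, since the bias adds at most one. The sizes below are hypothetical and chosen only to make the rank limit visible.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
vocab_size, embed_dim = 50, 8  # hypothetical sizes for illustration

embed = nn.Embedding(vocab_size, embed_dim)
proj = nn.Linear(embed_dim, vocab_size)

with torch.no_grad():
    # Evaluate on every possible current token to get the full logit table.
    all_ids = torch.arange(vocab_size)
    table = proj(embed(all_ids))                           # (vocab_size, vocab_size)
    # Same table in closed form: E @ W^T + b (bias broadcast across rows).
    closed_form = embed.weight @ proj.weight.T + proj.bias
    rank = torch.linalg.matrix_rank(table).item()

print(table.shape, "rank:", rank)
```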

In [None]:
class BigramLanguageModel(nn.Module):
    def __init__(self, vocab_size, embed_dim=None):
        super().__init__()
        self.vocab_size = vocab_size
        self.embed_dim = embed_dim
        if embed_dim is None:
            # Direct bigram logits: for each token id, store logits over next token
            self.logit_table = nn.Embedding(vocab_size, vocab_size)
        else:
            # Optional: smaller embedding -> linear to vocab logits
            self.embed = nn.Embedding(vocab_size, embed_dim)
            self.proj = nn.Linear(embed_dim, vocab_size)

    def forward(self, idx, targets=None):
        # idx: (B, T)
        if self.embed_dim is None:
            logits = self.logit_table(idx)  # (B, T, vocab_size)
        else:
            emb = self.embed(idx)
            logits = self.proj(emb)

        loss = None
        if targets is not None:
            # Flatten the batch/time for cross-entropy
            B, T, C = logits.shape
            logits_flat = logits.view(B*T, C)
            targets_flat = targets.view(B*T)
            loss = F.cross_entropy(logits_flat, targets_flat)
        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=1.0):
        # idx: (B, T)
        for _ in range(max_new_tokens):
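            # A bigram model only needs the most recent token, so feed just that position
            logits, _ = self(idx[:, -1:])  # (B, 1, vocab_size)
            # Scale the last step's logits by temperature (lower = sharper distribution)
            logits_last = logits[:, -1, :] / temperature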
            probs = F.softmax(logits_last, dim=-1)
            next_id = torch.multinomial(probs, num_samples=1)
            idx = torch.cat([idx, next_id], dim=1)
        return idx

model = BigramLanguageModel(vocab_size, EMBED_DIM).to(DEVICE)  # model instance
print(model)

# Sanity check: forward pass shapes and loss
xb, yb = get_batch('train')
logits, loss = model(xb, yb)
print('logits shape:', logits.shape)
print('loss:', loss.item())
print('logits[0,0,:5]:', logits[0, 0, :5].detach().cpu())

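## 8) Loss function (cross-entropy)
Cross-entropy applies a softmax to the logits at each position and returns the average negative log-probability assigned to the true next character. It is the standard loss for classification, and next-token prediction is classification over the vocabulary. In this notebook it is computed inside `forward` with `F.cross_entropy`. As a reference point, a model that predicts every character with equal probability scores exactly `ln(vocab_size)`.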
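The sketch below computes cross-entropy by hand on random hypothetical logits and checks it against `F.cross_entropy`. It also shows that all-equal logits, which give a uniform prediction, score `ln(V)`.

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
N, V = 6, 10                      # hypothetical: 6 positions, 10-character vocabulary
logits = torch.randn(N, V)
targets = torch.randint(V, (N,))

# By hand: negative log-softmax probability of each true target, averaged over positions.
log_probs = F.log_softmax(logits, dim=-1)
manual = -log_probs[torch.arange(N), targets].mean()
builtin = F.cross_entropy(logits, targets)
print(manual.item(), builtin.item())

# Uniform prediction (all logits equal) scores exactly ln(V).
uniform = F.cross_entropy(torch.zeros(N, V), targets)
print(uniform.item(), math.log(V))
```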

## 9) Training loop with periodic evaluation
We train for `MAX_STEPS` steps with AdamW. Every `EVAL_EVERY` steps (and at step 1), `estimate_loss` averages the loss over a few random batches from each split. We also generate a short sample to see qualitative progress.
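The same `zero_grad` / `backward` / `step` pattern can be tried on a hypothetical toy string where the next character is fully predictable (`"abcabc..."`). On that string the loss should fall toward zero. The string, learning rate, and step count are illustration choices, not the notebook's settings.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical toy corpus: after 'a' always 'b', after 'b' always 'c', and so on.
text = "abc" * 50
chars = sorted(set(text))
stoi = {c: i for i, c in enumerate(chars)}
data = torch.tensor([stoi[c] for c in text])
x, y = data[:-1], data[1:]          # inputs and next-character targets

table = nn.Embedding(len(chars), len(chars))   # direct bigram logit table
opt = torch.optim.AdamW(table.parameters(), lr=1e-1)

losses = []
for step in range(200):
    loss = F.cross_entropy(table(x), y)        # table(x): (149, vocab) logits
    opt.zero_grad(set_to_none=True)
    loss.backward()
    opt.step()
    losses.append(loss.item())

print(f"first loss {losses[0]:.4f}, last loss {losses[-1]:.4f}")
```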

In [None]:
@torch.no_grad()
def estimate_loss():
    model.eval()
    losses = {}
    for split in ['train', 'val']:
        loss_vals = []
        for _ in range(10):  # average over 10 random batches to reduce noise
            xb, yb = get_batch(split)
            _, loss = model(xb, yb)
            loss_vals.append(loss.item())
        losses[split] = sum(loss_vals) / len(loss_vals)
    model.train()
    return losses

optimizer = torch.optim.AdamW(model.parameters(), lr=LR)  # optimizer instance

# optimization steps
for step in range(1, MAX_STEPS + 1):
    xb, yb = get_batch('train')
    logits, loss = model(xb, yb)

    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

    if step % EVAL_EVERY == 0 or step == 1:
        losses = estimate_loss()
        print('Step {:4d} | train loss {:.4f} | val loss {:.4f}'.format(step, losses['train'], losses['val']))
        # Generate a tiny sample to see progress
        context = torch.zeros((1, 1), dtype=torch.long, device=DEVICE)  # start from token id 0 (first char in the sorted vocab)
        sample_ids = model.generate(context, max_new_tokens=200, temperature=1.0)[0].tolist()
        print('Sample:', decode(sample_ids))
        print('-' * 60)

## 10) Inference: generate text with sampling
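We can control how random the samples are with a **temperature**: the logits are divided by the temperature before the softmax. Values below 1 sharpen the distribution, so the model sticks more closely to its most likely characters. Values above 1 flatten it, producing more varied (and more error-prone) text.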
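A sketch of the effect on a single set of hypothetical next-character logits: dividing by a smaller temperature concentrates probability on the top character, and a larger one spreads it out. The ranking of characters never changes.

```python
import torch
import torch.nn.functional as F

# Hypothetical next-character logits for a 4-character vocabulary.
logits = torch.tensor([2.0, 1.0, 0.0, -1.0])

probs_by_temp = {}
for temperature in [0.5, 1.0, 2.0]:
    # Same scaling as generate(): divide logits by temperature, then softmax.
    probs = F.softmax(logits / temperature, dim=-1)
    probs_by_temp[temperature] = probs
    print(f"T={temperature}:", [round(p, 3) for p in probs.tolist()])
```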

In [None]:
context = torch.zeros((1, 1), dtype=torch.long, device=DEVICE)  # start from token id 0
for temp in [0.8, 1.0, 1.2]:
    out_ids = model.generate(context, max_new_tokens=400, temperature=temp)[0].tolist()
    print(f'\nTemperature={temp}')
    print(decode(out_ids))

## Scaling Notes
Real LLMs use **many layers**, **many attention heads**, and **large embedding dimensions**, and are trained on huge datasets (billions to trillions of tokens). Typical scale knobs include:
- Number of layers
- Number of attention heads
- Embedding dimension
- Context length
- Dataset size and diversity

Here we deliberately keep the model tiny to make the mechanics easy to understand.