# Pico-LLM Standalone Notebook

This notebook mirrors the course starter code but runs entirely within Colab (or any Jupyter runtime) without referencing external `.py` files. Modify the cells directly as you implement the core tasks.


## How to Use
- Execute the dependency cell (once per runtime) to install `tiktoken` and `datasets`.
- Run the subsequent cells in order to define data utilities, model classes, and training helpers.
- Adjust the configuration dictionary near the bottom to tweak hyperparameters or switch between TinyStories and synthetic data.
- Rerun individual cells as you implement missing pieces (e.g., k-gram MLP, nucleus sampling, Transformer).


In [1]:
%%capture
!pip install tiktoken datasets


## Imports and Environment


In [2]:
import os
import time
import random
import math
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

try:
    from datasets import load_dataset
except ImportError as exc:  # safety if pip install skipped
    raise RuntimeError("Please run the dependency installation cell first.") from exc

try:
    import tiktoken
except ImportError as exc:
    raise RuntimeError("Please run the dependency installation cell first.") from exc

# Resolve the compute device once; later cells refer to DEVICE.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if DEVICE.type == "cuda":
    print("CUDA available:", torch.cuda.get_device_name(0))
else:
    print("CUDA not detected; defaulting to CPU.")


  from .autonotebook import tqdm as notebook_tqdm


CUDA not detected; defaulting to CPU.


## Data Utilities


In [3]:
# Shared GPT-2 tokenizer; the data-loading, training, and generation code
# below all encode/decode text through this object.
ENC = tiktoken.get_encoding("gpt2")
# Vocabulary size reported by the tokenizer; passed as `vocab_size` when the
# models are constructed.
VOCAB_SIZE = ENC.n_vocab


def ensure_sample_dataset(path: Path = Path("3seqs.txt"), repeats: int = 1111) -> Path:
    """Return ``path``, writing the toy numeric corpus there first if it is missing.

    An existing file is never touched. Otherwise the three fixed lines
    below are written ``repeats`` times, always in the same order.
    """
    if not path.exists():
        triple = "0 1 2 3 4\n4 3 2 1 0\n1 3 5 7 9\n"
        with path.open("w") as handle:
            handle.writelines(triple for _ in range(repeats))
    return path


class MixedSequenceDataset(torch.utils.data.Dataset):
    """Dataset that returns a randomly drawn sequence from one of two pools.

    The reported length is the combined pool size, but the index passed to
    ``__getitem__`` is ignored: every access makes a fresh random draw.
    When both pools are non-empty the TinyStories pool is chosen with
    probability ``p_tiny``; when only one pool has data it is always used.
    """

    def __init__(self, tinystories_seqs, other_seqs, p_tiny: float):
        super().__init__()
        tiny_count = len(tinystories_seqs)
        other_count = len(other_seqs)
        self.tinystories_seqs = tinystories_seqs
        self.other_seqs = other_seqs
        self.p_tiny = p_tiny
        self.has_tinystories = tiny_count > 0
        self.has_other = other_count > 0
        self.total_length = tiny_count + other_count
        if not self.total_length:
            raise ValueError("No data available in TinyStories or custom sources.")

    def __len__(self):
        return self.total_length

    def __getitem__(self, idx):
        # The draw is made unconditionally (even when only one pool exists)
        # so the random stream is consumed the same way in every case.
        draw = random.random()
        pick_tiny = self.has_tinystories and (not self.has_other or draw < self.p_tiny)
        pool = self.tinystories_seqs if pick_tiny else self.other_seqs
        return torch.tensor(random.choice(pool), dtype=torch.long)


def seq_collate_fn(batch):
    """Stack variable-length 1-D token tensors into one ``(max_len, batch)`` tensor.

    Each sequence occupies one column; positions past a sequence's end stay
    at zero. The result is always ``torch.long``.
    """
    lengths = [len(item) for item in batch]
    padded = torch.zeros(max(lengths), len(batch), dtype=torch.long)
    for column, item in enumerate(batch):
        padded[: item.size(0), column] = item
    return padded


def load_sequences(config: Dict) -> Tuple[MixedSequenceDataset, List[List[int]], List[List[int]]]:
    """Tokenize the configured corpora and wrap them in a ``MixedSequenceDataset``.

    Args:
        config: Mapping that must contain ``"block_size"`` (maximum tokens kept
            per sequence) and ``"tinystories_weight"`` (probability of drawing
            from TinyStories). ``"train_subset_size"`` is read only when the
            weight is positive. ``"use_synthetic"`` is optional and defaults
            to ``True``.

    Returns:
        ``(dataset, tinystories_seqs, other_seqs)`` where the two lists hold
        the token-id lists gathered from each source.

    Raises:
        ValueError: Propagated from ``MixedSequenceDataset`` when neither
            source produced any sequence.
    """
    tinystories_seqs: List[List[int]] = []
    other_seqs: List[List[int]] = []
    block_size = config["block_size"]
    p_tiny = config["tinystories_weight"]

    if p_tiny > 0.0:
        subset_size = config["train_subset_size"]
        print(f"Loading TinyStories subset (size={subset_size})...")
        stories = load_dataset("roneneldan/TinyStories", split="train")
        # Clamp so a subset size larger than the split does not raise on
        # out-of-range indices.
        stories = stories.select(range(min(subset_size, len(stories))))
        for sample in stories:
            tokens = ENC.encode(sample["text"])[:block_size]
            if tokens:
                tinystories_seqs.append(tokens)
        print("TinyStories sequences:", len(tinystories_seqs))
    else:
        print("TinyStories weight=0 => skipping TinyStories.")

    if config.get("use_synthetic", True):
        # Read back from the path the helper reports instead of repeating
        # its default file name here.
        sample_path = ensure_sample_dataset()
        with sample_path.open() as fp:
            for line in fp:
                line = line.strip()
                if not line:
                    continue
                tokens = ENC.encode(line)[:block_size]
                if tokens:
                    other_seqs.append(tokens)
        print("Synthetic sequences:", len(other_seqs))

    mixed = MixedSequenceDataset(tinystories_seqs, other_seqs, p_tiny)
    return mixed, tinystories_seqs, other_seqs


In [4]:
def compute_next_token_loss(logits: torch.Tensor, tokens: torch.Tensor) -> torch.Tensor:
    """Mean cross-entropy between each step's logits and the token that follows it.

    ``logits`` must be 3-D with the time step first and the class scores
    last; ``tokens`` is indexed the same way along its first axis. The last
    step's logits and the first token have no partner and are dropped.
    """
    n_steps, _, n_classes = logits.shape
    if n_steps >= 2:
        shifted_logits = logits[:-1].reshape(-1, n_classes)
        following_tokens = tokens[1:].reshape(-1)
        return F.cross_entropy(shifted_logits, following_tokens)
    # A single step has nothing to predict; hand back a zero that still
    # supports ``backward()``.
    return torch.tensor(0.0, device=logits.device, requires_grad=True)


class KGramMLPSeqModel(nn.Module):
    """Next-token model that scores every position from its ``k`` preceding tokens.

    This is still the course stub: ``self.net`` is ``None`` until the
    architecture is filled in, and ``forward`` raises ``NotImplementedError``
    until then. Once defined, ``self.net`` must map a float tensor of shape
    ``(rows, k * vocab_size)`` (concatenated one-hot context) to one row of
    logits per input row, treating rows independently.

    Args:
        vocab_size: Number of token ids; also the one-hot width.
        k: Context length in tokens (must be >= 1).
        embed_size: Stored for the eventual architecture.
        num_inner_layers: Stored for the eventual architecture.
        chunk_size: How many time steps are pushed through ``self.net`` per
            call (must be >= 1); bounds the size of the one-hot tensor.
        variant: Architecture variant name requested by the callers in this
            notebook; stored only.
        hidden_dim: Stored for the eventual architecture.
        conv_hidden_dim: Stored for the eventual architecture.
        allow_alt_variants: Stored only; the stub does not act on it.

    Raises:
        ValueError: If ``k`` or ``chunk_size`` is smaller than 1.
    """

    def __init__(
        self,
        vocab_size,
        k=3,
        embed_size=1024,
        num_inner_layers=1,
        chunk_size=1,
        variant="embedding",
        hidden_dim=512,
        conv_hidden_dim=512,
        allow_alt_variants=False,
    ):
        super().__init__()
        if k < 1:
            raise ValueError(f"k must be >= 1, got {k}")
        if chunk_size < 1:
            # A non-positive chunk size would never advance through the sequence.
            raise ValueError(f"chunk_size must be >= 1, got {chunk_size}")
        self.k = k
        self.vocab_size = vocab_size
        self.embed_size = embed_size
        self.num_inner_layers = num_inner_layers
        self.chunk_size = chunk_size
        self.variant = variant
        self.hidden_dim = hidden_dim
        self.conv_hidden_dim = conv_hidden_dim
        self.allow_alt_variants = allow_alt_variants
        # TODO: implement architecture
        self.net = None

    def forward(self, tokens_seq):
        """Return logits of shape ``(seq_len, batch, net_output_width)``.

        ``tokens_seq`` is a ``(seq_len, batch)`` long tensor. The context for
        step ``t`` is ``tokens_seq[t - k:t]``, left-padded with token id 0
        for the first ``k`` steps.
        """
        if self.net is None:
            raise NotImplementedError("KGramMLPSeqModel.net is not defined yet.")
        seq_len, batch_size = tokens_seq.shape
        if seq_len == 0:
            return torch.zeros(0, batch_size, self.vocab_size, device=tokens_seq.device)

        # Prepend k zero rows, then take every length-k sliding window:
        # windows[t, b] holds exactly the k tokens that precede position t.
        left_pad = tokens_seq.new_zeros(self.k, batch_size)
        windows = torch.cat([left_pad, tokens_seq], dim=0).unfold(0, self.k, 1)

        outputs = []
        for start in range(0, seq_len, self.chunk_size):
            end = min(start + self.chunk_size, seq_len)
            steps = end - start
            context_oh = F.one_hot(windows[start:end], num_classes=self.vocab_size)
            context_flat = context_oh.reshape(steps * batch_size, self.k * self.vocab_size).float()
            chunk_logits = self.net(context_flat)
            outputs.append(chunk_logits.reshape(steps, batch_size, chunk_logits.shape[-1]))
        return torch.cat(outputs, dim=0)


class LSTMSeqModel(nn.Module):
    """Token embedding, a single LSTM, and a linear projection to vocabulary logits.

    The LSTM is built with ``batch_first=False``, so token tensors are
    expected with the time step as the leading axis.
    """

    def __init__(self, vocab_size, embed_size=1024, hidden_size=1024):
        super().__init__()
        # Sub-modules are created in this order so seeded initialisation
        # and parameter names stay the same.
        self.embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embed_size)
        self.lstm = nn.LSTM(input_size=embed_size, hidden_size=hidden_size, batch_first=False)
        self.linear = nn.Linear(in_features=hidden_size, out_features=vocab_size)

    def forward(self, tokens_seq):
        embedded = self.embedding(tokens_seq)
        self.lstm.flatten_parameters()
        hidden_states, _final_state = self.lstm(embedded)
        return self.linear(hidden_states)


class RMSNorm(nn.Module):
    """Placeholder normalization layer; the forward pass is not written yet.

    The constructor already registers a learnable ``weight`` vector of length
    ``dim`` (initialised to ones) and stores ``eps``. Calling the module
    raises ``NotImplementedError`` until ``forward`` is implemented.
    """

    def __init__(self, dim, eps=1e-5):
        super().__init__()
        # TODO: implement per-assignment
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x):
        # Deliberately unimplemented: fail loudly rather than pass `x` through.
        raise NotImplementedError("RMSNorm is currently a stub.")


class TransformerModel(nn.Module):
    """Placeholder Transformer language model.

    Construction always raises ``NotImplementedError`` (after the base-class
    initialiser has run), so the class cannot be instantiated until the
    TODO below is completed. ``vocab_size`` defaults to the module-level
    ``VOCAB_SIZE``.
    """

    def __init__(self, vocab_size=VOCAB_SIZE, d_model=1024, n_heads=2, n_blocks=4):
        super().__init__()
        # TODO: implement Transformer (token embeddings, blocks, etc.)
        raise NotImplementedError("TransformerModel is currently a stub.")


## Generation Utilities


In [5]:
def nucleus_sampling(logits: torch.Tensor, p: float = 0.95) -> int:
    """Sample a token id from the smallest set of tokens whose probability mass reaches ``p``.

    Args:
        logits: Unnormalised scores for one position. The tensor is
            flattened, so the returned id is a flat index.
        p: Cumulative-probability threshold. Values <= 0 reduce to picking
            the most likely token; values >= 1 sample from the full
            distribution.

    Returns:
        The sampled token id as a Python ``int``.
    """
    probs = torch.softmax(logits.reshape(-1).float(), dim=-1)
    sorted_probs, sorted_ids = torch.sort(probs, descending=True)
    # A token belongs to the nucleus if the mass of all more-likely tokens
    # is still below p, i.e. it is needed to reach the threshold.
    mass_before = torch.cumsum(sorted_probs, dim=-1) - sorted_probs
    in_nucleus = mass_before < p
    in_nucleus[0] = True  # never leave the nucleus empty (p <= 0)
    nucleus_probs = torch.where(in_nucleus, sorted_probs, torch.zeros_like(sorted_probs))
    nucleus_probs = nucleus_probs / nucleus_probs.sum()
    pick = torch.multinomial(nucleus_probs, num_samples=1)
    return int(sorted_ids[pick].item())


def monosemantic_analysis_for_token(token_id, model, enc, device="cpu", top_n=5):
    """Placeholder for the nearest-neighbour token analysis.

    Every argument is currently ignored and an empty list is returned.
    """
    # Placeholder returning empty list to match starter behavior.
    return []


def generate_text(
    model,
    enc,
    init_text,
    max_new_tokens: int = 20,
    device: str = "cpu",
    top_p: Optional[float] = None,
    monosemantic_info=None,
    do_monosemantic: bool = False,
):
    """Extend ``init_text`` one token at a time and return plain and annotated text.

    Args:
        model: Module called as ``model(tokens)`` with a ``(seq_len, 1)`` long
            tensor; the logits at ``[-1, 0, :]`` choose the next token.
        enc: Tokenizer offering ``encode(text)`` and ``decode(ids)``.
        init_text: Prompt to continue.
        max_new_tokens: Number of tokens to append; values <= 0 append none.
        device: Device on which the token tensor is created.
        top_p: ``None`` picks the arg-max token; any other value is passed to
            ``nucleus_sampling`` as ``p``.
        monosemantic_info: Only compared against ``None`` to decide whether
            neighbour annotations are requested (the analysis helper defined
            in this notebook takes no such argument).
        do_monosemantic: Must also be true for annotations to be collected.

    Returns:
        ``(final_text, annotated_text)``: the decoded prompt plus
        continuation, and the same text with ``[NN=...]`` appended to each
        generated token that has neighbours.
    """
    was_training = model.training
    model.eval()
    annotation_list = []
    try:
        with torch.no_grad():
            context_tokens = list(enc.encode(init_text))
            # Remember where the prompt ends; slicing with -max_new_tokens
            # would drop the whole prompt when max_new_tokens is 0.
            prompt_len = len(context_tokens)
            for _ in range(max_new_tokens):
                seq_tensor = torch.tensor(context_tokens, dtype=torch.long, device=device).unsqueeze(1)
                logits_seq = model(seq_tensor)
                next_logits = logits_seq[-1, 0, :]
                if top_p is None:
                    chosen_token = torch.argmax(next_logits).item()
                else:
                    chosen_token = nucleus_sampling(next_logits, p=top_p)
                context_tokens.append(chosen_token)
                if do_monosemantic and monosemantic_info is not None:
                    # Matches the helper's (token_id, model, enc, device, top_n)
                    # signature; passing monosemantic_info positionally would
                    # collide with the `device` keyword.
                    neighbors = monosemantic_analysis_for_token(
                        chosen_token, model, enc, device=device, top_n=5
                    )
                    annotation_list.append((chosen_token, neighbors))
                else:
                    annotation_list.append((chosen_token, []))
    finally:
        # Restore the caller's train/eval mode even if generation failed.
        model.train(was_training)
    final_text = enc.decode(context_tokens)
    prefix_text = enc.decode(context_tokens[:prompt_len])
    annotated_strs = [prefix_text]
    for (tid, neighs) in annotation_list:
        token_str = enc.decode([tid])
        if neighs:
            neighbor_strs = [f"{enc.decode([x[1]])}" for x in neighs]
            annotated = f"{token_str}[NN={neighbor_strs}]"
        else:
            annotated = token_str
        annotated_strs.append(annotated)
    annotated_text = "".join(annotated_strs)
    return final_text, annotated_text


## Training Helpers


In [6]:
def train_one_model(
    model: nn.Module,
    loader: torch.utils.data.DataLoader,
    epochs: int,
    model_name: str,
    device: torch.device,
    lr: float = 1e-3,
    log_steps: int = 100,
    sample_interval: float = 30.0,
    max_steps_per_epoch: Optional[int] = None,
    prompt: str = "Once upon a",
):
    """Train ``model`` with Adam on next-token prediction, logging progress.

    Args:
        model: Module mapping a batch of token ids to logits.
        loader: Iterable of token batches; each batch is moved to ``device``.
        epochs: Number of passes over ``loader``.
        model_name: Label used as a prefix in every printed line.
        device: Device the batches are moved to.
        lr: Adam learning rate.
        log_steps: Print the running average loss every this many steps;
            values <= 0 disable the periodic log.
        sample_interval: Minimum number of seconds between greedy text
            samples. The first sample is taken after the first step.
        max_steps_per_epoch: If set, stop each epoch after this many steps.
        prompt: Text the greedy samples continue.

    Returns:
        ``None``; all output is printed.
    """
    optimizer = optim.Adam(model.parameters(), lr=lr)
    next_sample_time = time.time()
    global_step = 0
    for epoch in range(1, epochs + 1):
        model.train()
        total_loss = 0.0
        partial_loss = 0.0
        partial_count = 0
        step_in_epoch = 0
        for batch_tokens in loader:
            step_in_epoch += 1
            global_step += 1
            batch_tokens = batch_tokens.to(device)
            logits = model(batch_tokens)
            loss = compute_next_token_loss(logits, batch_tokens)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # Read the scalar once; each .item() call forces a device sync.
            loss_value = loss.item()
            total_loss += loss_value
            partial_loss += loss_value
            partial_count += 1
            if log_steps > 0 and step_in_epoch % log_steps == 0:
                avg_part_loss = partial_loss / partial_count
                print(
                    f"[{model_name}] Epoch {epoch}/{epochs}, Step {step_in_epoch}/{len(loader)}"
                    f" (global={global_step}) Partial Avg Loss: {avg_part_loss:.4f}"
                )
                partial_loss = 0.0
                partial_count = 0
            current_time = time.time()
            if current_time >= next_sample_time:
                text_greedy, _ = generate_text(model, ENC, prompt, max_new_tokens=20, device=str(device))
                print(f"[{model_name}] Sample (greedy): {text_greedy}")
                next_sample_time = current_time + sample_interval
            if max_steps_per_epoch is not None and step_in_epoch >= max_steps_per_epoch:
                print(
                    f"[{model_name}] Reached max_steps_per_epoch={max_steps_per_epoch}, "
                    f"ending epoch {epoch} early."
                )
                break
        if step_in_epoch == 0:
            # An empty loader yields no steps; there is no average to report.
            print(f"[{model_name}] *** End of Epoch {epoch} *** no batches were produced.")
            continue
        avg_loss = total_loss / step_in_epoch
        print(f"[{model_name}] *** End of Epoch {epoch} *** Avg Loss: {avg_loss:.4f}")


def run_experiment(config: Dict):
    """Build the data loader and the enabled models, train each, and sample text.

    Returns a dict keyed by model name; each value maps ``"greedy"``,
    ``"top_p_0.95"`` and ``"top_p_1.0"`` to the ``(text, annotated_text)``
    pair produced by ``generate_text`` for that decoding setting.
    """
    dataset, _, _ = load_sequences(config)
    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=config["batch_size"],
        shuffle=True,
        num_workers=0,
        collate_fn=seq_collate_fn,
    )
    if config["device_id"] == "auto":
        device = DEVICE
    else:
        device = torch.device(config["device_id"])
    print(
        "Using device:",
        device,
        "block_size=", config["block_size"],
        "kgram_k=", config["kgram_k"],
        "chunk_size=", config["kgram_chunk_size"],
        "embed_size=", config["embed_size"],
    )

    models: Dict[str, nn.Module] = {}
    if config.get("enable_kgram", False):
        variant = config.get("kgram_variant", "embedding")
        allow_alt = config.get("allow_alt_kgram", False)
        # Anything other than the embedding variant needs an explicit opt-in.
        if not allow_alt and variant != "embedding":
            raise ValueError(
                "Set allow_alt_kgram=True to try the onehot or conv variants; they are gated by default."
            )
        kgram_kwargs = dict(
            vocab_size=VOCAB_SIZE,
            k=config["kgram_k"],
            embed_size=config["embed_size"],
            num_inner_layers=config["num_inner_layers"],
            chunk_size=config["kgram_chunk_size"],
            variant=variant,
            hidden_dim=config.get("kgram_hidden_dim", 512),
            conv_hidden_dim=config.get("kgram_conv_hidden_dim", 512),
            allow_alt_variants=allow_alt,
        )
        models["kgram_mlp_seq"] = KGramMLPSeqModel(**kgram_kwargs).to(device)
    if config.get("enable_lstm", True):
        lstm_model = LSTMSeqModel(
            vocab_size=VOCAB_SIZE,
            embed_size=config["embed_size"],
            hidden_size=config["embed_size"],
        )
        models["lstm_seq"] = lstm_model.to(device)
    if config.get("enable_transformer", False):
        models["transformer"] = TransformerModel().to(device)

    # Decoding settings sampled after training, in this order.
    decoding_modes = (("greedy", None), ("top_p_0.95", 0.95), ("top_p_1.0", 1.0))
    results = {}
    for name, model in models.items():
        print(f"\n=== Training model: {name} ===")
        train_one_model(
            model,
            loader,
            epochs=config["epochs"],
            model_name=name,
            device=device,
            lr=config["learning_rate"],
            log_steps=config["log_interval_steps"],
            sample_interval=config["sample_interval_seconds"],
            max_steps_per_epoch=config["max_steps_per_epoch"],
            prompt=config["prompt"],
        )
        samples = {}
        with torch.no_grad():
            for label, top_p in decoding_modes:
                samples[label] = generate_text(
                    model, ENC, config["prompt"], max_new_tokens=20, device=str(device), top_p=top_p
                )
        results[name] = samples
        text_greedy = samples["greedy"][0]
        print(f"[{name}] Final sample (greedy):\n{text_greedy}\n")
    print("\n*** Run complete ***")
    return results


## K-gram Variant Benchmark (Optional)
Use the cell below to compare the three architectural variants. Non-embedding variants are gated; set `allow_alt` to `True` when explicitly benchmarking them.


In [None]:
import matplotlib.pyplot as plt

def benchmark_kgram_variants(
    variants=("embedding", "onehot", "conv"),
    allow_alt=False,
    max_batches=10,
    epochs=1,
    batch_size=32,
    block_size=32,
    kgram_k=2,
    embed_size=64,
    num_inner_layers=1,
    chunk_size=1,
    learning_rate=1e-3,
    hidden_dim=512,
    conv_hidden_dim=512,
):
    """Briefly train one ``KGramMLPSeqModel`` per variant and time it.

    Only the synthetic corpus is used (TinyStories weight is 0.0). For each
    name in ``variants`` a fresh model and Adam optimizer are created and
    trained on ``DEVICE``.

    Returns:
        A list with one dict per variant holding ``"variant"``,
        ``"avg_loss"`` (mean loss over the steps taken),
        ``"tokens_per_sec"`` and ``"elapsed"`` (seconds). The token count is
        ``batch_tokens.numel()``, so it counts every element of each batch
        tensor.
    """
    ensure_sample_dataset()
    dataset, _, _ = load_sequences(
        {
            "tinystories_weight": 0.0,
            "use_synthetic": True,
            "train_subset_size": 1,
            "block_size": block_size,
        }
    )
    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=0,
        collate_fn=seq_collate_fn,
    )
    device = DEVICE
    metrics = []
    for variant in variants:
        # The embedding variant never receives the opt-in flag; the others
        # receive whatever the caller passed as `allow_alt`.
        allow_flag = allow_alt if variant != "embedding" else False
        model = KGramMLPSeqModel(
            vocab_size=VOCAB_SIZE,
            k=kgram_k,
            embed_size=embed_size,
            num_inner_layers=num_inner_layers,
            chunk_size=chunk_size,
            variant=variant,
            hidden_dim=hidden_dim,
            conv_hidden_dim=conv_hidden_dim,
            allow_alt_variants=allow_flag,
        ).to(device)
        optimizer = optim.Adam(model.parameters(), lr=learning_rate)
        total_loss = 0.0
        steps = 0
        tokens_processed = 0
        start = time.time()
        for epoch in range(epochs):
            for batch_idx, batch_tokens in enumerate(loader):
                # The batch at index `max_batches` is fetched and then
                # discarded here before the loop stops.
                if batch_idx >= max_batches:
                    break
                batch_tokens = batch_tokens.to(device)
                optimizer.zero_grad()
                logits = model(batch_tokens)
                loss = compute_next_token_loss(logits, batch_tokens)
                loss.backward()
                optimizer.step()
                total_loss += loss.item()
                steps += 1
                tokens_processed += batch_tokens.numel()
            else:
                # Inner loop ran out of batches without hitting the cap:
                # move on to the next epoch.
                continue
            # Reached only when the cap triggered the inner `break`; this
            # ends the remaining epochs too, so once the cap is hit no
            # further batches are processed for this variant.
            break
        elapsed = time.time() - start
        metrics.append(
            {
                "variant": variant,
                # max(...) guards keep the divisions defined when no step
                # ran or the timer reports zero.
                "avg_loss": total_loss / max(steps, 1),
                "tokens_per_sec": tokens_processed / max(elapsed, 1e-6),
                "elapsed": elapsed,
            }
        )
    return metrics

# Run the benchmark with the non-embedding variants explicitly enabled and
# show the raw numbers before plotting.
variant_metrics = benchmark_kgram_variants(allow_alt=True)
print(variant_metrics)

# Split the per-variant dicts into parallel lists for plotting.
labels = [m["variant"] for m in variant_metrics]
losses = [m["avg_loss"] for m in variant_metrics]
throughputs = [m["tokens_per_sec"] for m in variant_metrics]

# Two side-by-side bar charts: average loss (left) and throughput (right).
# The colour lists hold three entries, one per default variant.
fig, axes = plt.subplots(1, 2, figsize=(10, 4))
axes[0].bar(labels, losses, color=["#4c72b0", "#55a868", "#c44e52"])
axes[0].set_title("Average Loss")
axes[0].set_ylabel("Loss")
axes[1].bar(labels, throughputs, color=["#4c72b0", "#55a868", "#c44e52"])
axes[1].set_title("Tokens per Second")
axes[1].set_ylabel("tokens/s")
plt.tight_layout()
plt.show()


NameError: name 'config' is not defined

## Recorded Benchmark Results
The plot below visualizes the synthetic-corpus experiment comparing the three K-gram variants (20 mini-batches, CPU-only).


In [None]:
import matplotlib.pyplot as plt

# Hard-coded numbers from an earlier benchmark run; nothing in this cell is
# recomputed.
recorded_metrics = [
    {"variant": "embedding", "avg_loss": 9.4895, "tokens_per_sec": 42.3},
    {"variant": "conv", "avg_loss": 10.2317, "tokens_per_sec": 32.3},
    {"variant": "onehot", "avg_loss": 10.6328, "tokens_per_sec": 17.6},
]
# Parallel lists for plotting, in the order the records are listed above.
labels = [m["variant"] for m in recorded_metrics]
losses = [m["avg_loss"] for m in recorded_metrics]
throughputs = [m["tokens_per_sec"] for m in recorded_metrics]

# Left panel: average loss per variant.
fig, axes = plt.subplots(1, 2, figsize=(10, 4))
axes[0].bar(labels, losses, color=["#4c72b0", "#dd8452", "#55a868"], alpha=0.8)
axes[0].set_title("Average Loss (Lower is Better)")
axes[0].set_ylabel("Loss")
axes[0].grid(axis="y", linestyle="--", alpha=0.4)

# Right panel: throughput per variant, same colours as the left panel.
axes[1].bar(labels, throughputs, color=["#4c72b0", "#dd8452", "#55a868"], alpha=0.8)
axes[1].set_title("Throughput (Higher is Better)")
axes[1].set_ylabel("Tokens / Second")
axes[1].grid(axis="y", linestyle="--", alpha=0.4)

plt.suptitle("K-Gram Variant Comparison (Synthetic Dataset)")
plt.tight_layout()
plt.show()


or

## Configuration


In [None]:
# Experiment settings consumed by run_experiment / load_sequences.
config = {
    "tinystories_weight": 0.0,  # probability of drawing from TinyStories; 0 skips loading it
    "use_synthetic": True,  # include the generated 3seqs.txt corpus
    "train_subset_size": 20000,  # TinyStories rows to keep (read only when the weight is > 0)
    "block_size": 32,  # maximum tokens kept per sequence
    "batch_size": 16,
    "epochs": 3,
    "learning_rate": 1e-3,
    "log_interval_steps": 100,  # print the running loss every N steps
    "sample_interval_seconds": 30,  # minimum gap between greedy samples during training
    "max_steps_per_epoch": 1,  # each epoch stops after this many batches
    "kgram_k": 2,
    "kgram_chunk_size": 1,
    "embed_size": 64,  # also used as the LSTM hidden size
    "num_inner_layers": 1,
    "kgram_variant": "embedding",
    "allow_alt_kgram": False,  # must be True to request a non-embedding k-gram variant
    "kgram_hidden_dim": 512,
    "kgram_conv_hidden_dim": 512,
    "enable_kgram": False,
    "enable_lstm": True,
    "enable_transformer": False,
    "prompt": "Once upon a",
    "device_id": "auto",  # "auto" uses DEVICE; otherwise a torch device string such as "cpu" / "cuda:0"
}

# setdefault leaves any OMP_NUM_THREADS value already in the environment alone.
# NOTE(review): torch is imported earlier in the notebook; confirm the
# variable still takes effect when it is set this late.
os.environ.setdefault("OMP_NUM_THREADS", "1")
results = run_experiment(config)


## Next Steps
- Implement the remaining TODOs directly in this notebook (k-gram MLP, true top-$p$ sampling, RMSNorm, Transformer blocks).
- Create additional cells for evaluation figures, loss curves, or attention visualizations.
- Save results to Google Drive or another persistent location if you plan to revisit the session later.
