# Day 15: WaveNet — Hierarchical Language Models with Dilated Convolutions

**Building LLMs from Scratch** — Following Andrej Karpathy's makemore lectures.

---

## 1. Introduction

Our MLP language model from Day 9 concatenated a fixed context window of embeddings and fed them into a single linear layer. WaveNet (DeepMind, 2016) introduced a smarter architecture: **hierarchically fuse** adjacent pairs of tokens through multiple layers, building up a richer representation at each level.

Instead of flattening the full context window at once, we progressively merge:
- Layer 1: fuse non-overlapping pairs (token $2i$, token $2i+1$) → half as many positions
- Layer 2: fuse pairs of the fused features → a quarter as many positions
- ...
- Final layer: single rich representation → logits

This mirrors WaveNet's **dilated causal convolution** pattern — the receptive field doubles at every layer, so it grows exponentially with depth.

## 2. Setup — Names Dataset

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import random

# Seed both RNGs so the shuffle, the synthetic fallback and the weight
# initialisation are reproducible from run to run.
torch.manual_seed(42)
random.seed(42)

# Build vocabulary from names dataset (or fallback to synthetic)
try:
    # Explicit encoding: without it the platform default is used, which can
    # mis-decode or reject non-ASCII names on some systems.
    with open('../names.txt', encoding='utf-8') as f:
        words = f.read().splitlines()
except FileNotFoundError:
    # Synthetic fallback: 1000 random short 'names'. Announce it so nobody
    # mistakes results on random strings for results on the real dataset.
    import string
    print("names.txt not found; falling back to 1000 synthetic random names")
    words = [''.join(random.choices(string.ascii_lowercase, k=random.randint(3,8))) for _ in range(1000)]

# Character-level vocabulary: index 0 is reserved for '.', every other
# character gets an index starting at 1 in sorted order.
chars = sorted(set(''.join(words)))
vocab_size = len(chars) + 1  # +1 for '.' pad token
stoi = {c: i for i, c in enumerate(chars, start=1)}
stoi['.'] = 0
itos = {i: c for c, i in stoi.items()}  # inverse mapping: index -> character

print(f"Words: {len(words)}, Vocab: {vocab_size}")
print(f"Sample: {words[:5]}")

## 3. Build Dataset — Context Window of 8

In [None]:
block_size = 8  # WaveNet works better with larger context


def build_dataset(words):
    """Turn words into (context, next-character) training pairs.

    Each word is terminated with '.' and scanned left to right. For every
    character, the preceding ``block_size`` character indices (left-padded
    with index 0) form one input row and the character's own index is the
    target.

    Args:
        words: iterable of strings whose characters are all keys of ``stoi``.

    Returns:
        ``(X, Y)`` where ``X`` is an int64 tensor of shape
        ``(num_examples, block_size)`` and ``Y`` is an int64 tensor of shape
        ``(num_examples,)``. Both keep that dtype and rank even when ``words``
        is empty.

    Raises:
        KeyError: if a word contains a character missing from ``stoi``.
    """
    X, Y = [], []
    for w in words:
        context = [0] * block_size
        for ch in w + '.':
            ix = stoi[ch]
            X.append(context)
            Y.append(ix)
            # Slide the window; this builds a new list, so the row already
            # stored in X is not mutated.
            context = context[1:] + [ix]
    # Pin dtype and shape explicitly: torch.tensor([]) would otherwise come
    # back as a float32 tensor of shape (0,), which breaks an empty split.
    inputs = torch.tensor(X, dtype=torch.long).reshape(len(X), block_size)
    targets = torch.tensor(Y, dtype=torch.long)
    return inputs, targets

# Shuffle in place, then cut the word list 80% / 10% / 10% into
# train / dev / test before turning each part into tensors.
random.shuffle(words)
n_words = len(words)
n1 = int(0.8 * n_words)
n2 = int(0.9 * n_words)
train_words, dev_words, test_words = words[:n1], words[n1:n2], words[n2:]

Xtr, Ytr = build_dataset(train_words)
Xdev, Ydev = build_dataset(dev_words)
Xte, Yte = build_dataset(test_words)

print(f"Train: {Xtr.shape}, Dev: {Xdev.shape}, Test: {Xte.shape}")
print(f"Input shape: (batch, block_size={block_size})")

## 4. WaveNet Architecture

We define a hierarchy of **FlattenConsecutive + Linear** layers.

Each `FlattenConsecutive(n)` takes `(B, T, C)` and reshapes to `(B, T//n, n*C)` — merging `n` adjacent tokens into one. After a linear layer, each position now captures `n` tokens worth of context. Stack this and you get a tree-like fusion.

In [None]:
class FlattenConsecutive(nn.Module):
    """Reshape (B, T, C) -> (B, T//n, n*C) by merging n adjacent timesteps.

    When the merged time dimension has length 1 it is squeezed away, so the
    last level of the hierarchy yields a plain (B, n*C) tensor.
    """

    def __init__(self, n):
        super().__init__()
        self.n = n

    def forward(self, x):
        B, T, C = x.shape
        # reshape (rather than view) so non-contiguous inputs are accepted too
        x = x.reshape(B, T // self.n, C * self.n)
        if x.shape[1] == 1:
            x = x.squeeze(1)  # collapse time dim when T=1
        return x


class _BatchNormLastDim(nn.BatchNorm1d):
    """BatchNorm1d that always normalizes the trailing feature dimension.

    ``nn.BatchNorm1d`` treats dim 1 of a 3-D input as the channel axis, i.e.
    it expects (B, C, T). The layers here keep features last, (B, T, C), so a
    3-D input is folded to (B*T, C), normalized per feature over batch and
    time together, and unfolded again. 2-D inputs take the stock path.
    """

    def forward(self, x):
        if x.dim() == 3:
            B, T, C = x.shape
            return super().forward(x.reshape(B * T, C)).reshape(B, T, C)
        return super().forward(x)


class WaveNet(nn.Module):
    """Hierarchical character model that fuses adjacent positions pairwise.

    Args:
        vocab_size: number of token indices (embedding rows and output logits).
        embed_dim: size of each token embedding.
        hidden: width of every fusion level.
        block_size: context length; must be a power of two >= 2. The number
            of fusion levels is log2(block_size), so the default of 8 gives
            the three levels 8 -> 4 -> 2 -> 1.

    Raises:
        ValueError: if ``block_size`` is not a power of two >= 2.
    """

    def __init__(self, vocab_size, embed_dim=24, hidden=128, block_size=8):
        super().__init__()
        if block_size < 2 or block_size & (block_size - 1):
            raise ValueError(
                f"block_size must be a power of two >= 2, got {block_size}"
            )
        self.block_size = block_size
        self.embed = nn.Embedding(vocab_size, embed_dim)

        # Each level: flatten 2 adjacent -> linear -> batchnorm -> tanh.
        # Every level halves the number of positions, so log2(block_size)
        # levels reduce the context to a single position.
        n_levels = block_size.bit_length() - 1
        stages = []
        in_features = embed_dim
        for _ in range(n_levels):
            stages += [
                FlattenConsecutive(2),
                # bias=False: the batchnorm that follows has its own shift
                nn.Linear(in_features * 2, hidden, bias=False),
                _BatchNormLastDim(hidden),
                nn.Tanh(),
            ]
            in_features = hidden
        stages.append(nn.Linear(hidden, vocab_size))
        self.layers = nn.Sequential(*stages)

    def forward(self, x):
        # x: (B, block_size) int
        x = self.embed(x)  # (B, block_size, embed_dim)
        x = self.layers(x) # (B, vocab_size)
        return x


# Pass the shared block_size constant instead of repeating the literal 8, so
# the model is configured from the same value the dataset was built with.
model = WaveNet(vocab_size=vocab_size, embed_dim=24, hidden=128, block_size=block_size)
total_params = sum(p.numel() for p in model.parameters())
print(f"Total parameters: {total_params:,}")

# Test forward: push a tiny batch through once to confirm the shapes line up
# before spending time on training.
x_test = Xtr[:4]
logits = model(x_test)
print(f"Input shape: {x_test.shape}")
print(f"Output shape: {logits.shape}")

## 5. Training

In [None]:
# AdamW at lr=1e-3; StepLR multiplies the learning rate by 0.3 every 5000 steps.
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5000, gamma=0.3)

losses = []
for step in range(15000):
    # Draw a random minibatch of 64 training rows (with replacement)
    batch_ix = torch.randint(0, Xtr.shape[0], (64,))
    Xb = Xtr[batch_ix]
    Yb = Ytr[batch_ix]

    # Forward pass and loss
    logits = model(Xb)
    loss = F.cross_entropy(logits, Yb)

    # Backward pass, parameter update, then advance the LR schedule
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    scheduler.step()

    # Only every 1000th step is logged and kept for the plot
    if step % 1000 != 0:
        continue
    loss_value = loss.item()
    losses.append(loss_value)
    print(f"step {step:5d} | loss {loss_value:.4f}")

# One point per logged step, hence the x1000 on the axis label
plt.plot(losses)
plt.xlabel('Step (x1000)')
plt.ylabel('Loss')
plt.title('WaveNet Training Loss')
plt.grid(True)
plt.show()

## 6. Evaluate & Sample

In [None]:
@torch.no_grad()
def evaluate(model, X, Y):
    """Return the cross-entropy of ``model`` on ``(X, Y)`` as a Python float.

    The model is switched to eval mode for the forward pass and put back into
    whatever mode it was in before the call, so evaluating a model that is
    already in eval mode does not silently flip it to train mode.

    Args:
        model: module mapping ``X`` to logits.
        X: model inputs for the whole split (evaluated in a single pass).
        Y: target class indices matching the rows of ``X``.

    Returns:
        The loss computed by ``F.cross_entropy`` on the full split.
    """
    was_training = model.training
    model.eval()
    try:
        logits = model(X)
        return F.cross_entropy(logits, Y).item()
    finally:
        # Restore the caller's mode even if the forward pass raises
        model.train(was_training)

# Loss over the complete train and dev splits
train_loss = evaluate(model, Xtr, Ytr)
dev_loss = evaluate(model, Xdev, Ydev)
print(f"Train loss: {train_loss:.4f}")
print(f"Dev loss:   {dev_loss:.4f}")

@torch.no_grad()
def sample(model, n=10, max_len=100):
    """Generate ``n`` names by sampling one character at a time.

    Each name starts from a context of ``block_size`` zeros. The model's
    logits are turned into probabilities, one index is drawn, and the context
    window slides by one. Drawing index 0 ends the name.

    Args:
        model: module mapping a ``(1, block_size)`` index tensor to logits.
        n: number of names to generate.
        max_len: maximum characters per name, so a model that never emits the
            end token cannot loop forever. ``None`` disables the cap.

    Returns:
        List of ``n`` generated strings (without the end token).
    """
    was_training = model.training
    model.eval()
    out = []
    try:
        for _ in range(n):
            context = [0] * block_size
            name = []
            while max_len is None or len(name) < max_len:
                x = torch.tensor([context])
                logits = model(x)
                probs = F.softmax(logits, dim=-1)
                ix = torch.multinomial(probs, 1).item()
                if ix == 0:
                    break
                name.append(itos[ix])
                context = context[1:] + [ix]
            out.append(''.join(name))
    finally:
        # Put the model back in the mode the caller had it in
        model.train(was_training)
    return out

# Show 15 generated names, each indented by two spaces
print("\nSampled names:")
generated = sample(model, 15)
print("\n".join(f"  {name}" for name in generated))

## 7. Why Hierarchical Fusion Works

Compare the receptive field of our flat MLP vs WaveNet:

| Architecture | Layer | Tokens in context |
|---|---|---|
| Flat MLP | 1 layer | 8 (all at once, flat) |
| WaveNet | Layer 1 | 2 |
| WaveNet | Layer 2 | 4 |
| WaveNet | Layer 3 | 8 |

WaveNet's first layer captures **local structure** (bigrams), the second layer **medium-range** structure (4-grams), and the third layer the **full context** (8-grams). This inductive bias — building long-range features out of short-range ones — matches how language is structured: adjacent characters form syllables, syllables form words, words form phrases.

In [None]:
# Draw the fusion tree: 8 input tokens on the bottom row, and each higher row
# holds half as many nodes, every node linked to the two nodes it merges.
fig, ax = plt.subplots(figsize=(10, 5))

n_tokens = 8
levels = 4  # input + 3 layers
colors = ['#4C72B0', '#DD8452', '#55A868', '#C44E52']


def node_center(index, span):
    """Horizontal centre of node `index` when each node covers `span` tokens."""
    return index * span + span / 2 - 0.5


for level in range(levels):
    stride = 2 ** level
    half = stride // 2
    for i in range(n_tokens // stride):
        x_pos = node_center(i, stride)
        ax.scatter(x_pos, level, s=200, color=colors[level], zorder=3)
        if level == 0:
            continue  # input tokens have no children
        # Link this node to its left and right child one row below
        for child in (2 * i, 2 * i + 1):
            ax.plot([x_pos, node_center(child, half)], [level, level - 1], 'k-', alpha=0.4)

ax.set_yticks(range(levels))
ax.set_yticklabels(['Input (8 tokens)', 'Layer 1 (4 nodes)', 'Layer 2 (2 nodes)', 'Layer 3 (1 node)'])
ax.set_xticks([])
ax.set_title('WaveNet Hierarchical Fusion — Each Node Fuses 2 Children')
ax.grid(axis='y', alpha=0.3)
plt.tight_layout()
plt.show()

---

**Building LLMs from Scratch** — [Day 15: WaveNet](https://omkarray.com/llm-day15.html) | [← Prev](llm_day14_cross_entropy.ipynb) | [Next →](llm_day16_self_attention.ipynb)