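This notebook is based on the production-grade implementation, which features a Chunkwise Parallel Scan for efficient training and Stateful Inference for $O(1)$-per-token generation. To keep the demo readable, the training path in this notebook uses a simple sequential scan over the sequence instead of the chunkwise scan.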

Repo: https://github.com/Sk16er/hope_nano

Made by [Shushank](https://shushank.site)

In [None]:
# @title Installation
!pip install -q tiktoken datasets matplotlib tqdm seaborn

In [None]:
# @title Imports & Setup
import os
import time
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import tiktoken
import matplotlib.pyplot as plt
import seaborn as sns # For memory visualization
from tqdm import tqdm
from datasets import load_dataset
from torch.utils.data import DataLoader, IterableDataset
from torch.cuda.amp import autocast, GradScaler
from dataclasses import dataclass, field
from typing import Optional, Tuple, List

# Choose the compute device once; later cells read the module-level `device`.
cuda_available = torch.cuda.is_available()
device = 'cuda' if cuda_available else 'cpu'

# Report the runtime environment so runs are easy to compare.
print(f"Device: {device}")
print(f"PyTorch version: {torch.__version__}")
if cuda_available:
    print(f"GPU: {torch.cuda.get_device_name(0)}")

In [None]:
# @title Configuration
@dataclass
class HOPEConfig:
    """Hyperparameters for the HOPE demo model.

    Raises:
        ValueError: if ``n_head`` is not positive or does not divide ``n_embd``
            evenly (each head needs an integer ``n_embd // n_head`` width).
    """
    vocab_size: int = 50257  # GPT-2 vocab size
    n_embd: int = 384        # Embedding width (reduced for Colab demo)
    n_head: int = 6          # Number of memory heads (reduced for Colab demo)
    n_layer: int = 6         # Number of stacked blocks (reduced for Colab demo)
    block_size: int = 256    # Sequence length / size of the positional embedding table
    dropout: float = 0.1
    bias: bool = False       # Whether the Linear layers carry a bias term

    # HOPE specific
    # NOTE: neither of the two fields below is read by the code in this notebook.
    cms_update_periods: List[int] = field(default_factory=lambda: [1, 4, 16])
    learning_rate_memory: float = 1e-2

    def __post_init__(self):
        # Raise explicitly instead of `assert`: asserts vanish under `python -O`,
        # and n_head == 0 would otherwise surface as a ZeroDivisionError.
        if self.n_head <= 0:
            raise ValueError(f"n_head must be positive, got {self.n_head}")
        if self.n_embd % self.n_head != 0:
            raise ValueError(
                f"n_embd ({self.n_embd}) must be divisible by n_head ({self.n_head})"
            )

config = HOPEConfig()

# Rough weight count from the dominant terms only (LayerNorm gains/biases and
# the per-head alpha/beta scalars are ignored):
#   - token embedding, shared with the output head
#   - positional embedding
#   - per block: four n_embd x n_embd projections in the memory module
#     plus the two n_embd x 4*n_embd matrices of the MLP  ->  12 * n_embd**2
_approx_params = (
    config.vocab_size * config.n_embd
    + config.block_size * config.n_embd
    + config.n_layer * 12 * config.n_embd ** 2
)
print(f"Model size: ~{_approx_params / 1e6:.1f}M parameters (approx)")

# CMS Architectural Clarification
The Continuum Memory System (CMS) Block is structurally implemented here as a standard Multi-Layer Perceptron (MLP). Conceptually, the CMS is intended to capture longer-term knowledge, often through multi-rate updates (e.g., updating parameters only every $\small{N}$ steps). While the current implementation does not enforce multi-rate logic, its purpose is to create the hierarchical memory structure of the HOPE architecture.

In [None]:
# @title HOPE Model Components (TitansL2, CMSBlock, HOPE)
# Content from model.py
class TitansL2(nn.Module):
    """Titans memory module with an L2 / delta-rule update.

    Every head owns a (head_dim x head_dim) memory matrix ``M``. For each token
    the module first reads ``y_t = q_t @ M^T`` and then writes

        M <- M - alpha * (M k_t) k_t^T + beta * v_t k_t^T

    using L2-normalised keys. ``alpha`` and ``beta`` are learnable per-head
    scalars squashed into (0, 0.5).

    Multi-token inputs are processed with a plain sequential loop over time
    (not a chunkwise parallel scan); single-token inputs with a carried state
    use a dedicated one-step path.
    """
    def __init__(self, config: "HOPEConfig"):
        super().__init__()
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.head_dim = config.n_embd // config.n_head
        # Kept for compatibility; no method of this class reads it.
        self.chunk_size = 128

        # Projections
        self.c_q = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
        self.c_k = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
        self.c_v = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)

        # Learnable parameters (bounded to prevent explosion)
        self.alpha_raw = nn.Parameter(torch.zeros(1, self.n_head, 1, 1))
        self.beta_raw = nn.Parameter(torch.zeros(1, self.n_head, 1, 1))

    @property
    def alpha(self):
        """Per-head forget strength, sigmoid-bounded to (0, 0.5)."""
        return torch.sigmoid(self.alpha_raw) * 0.5

    @property
    def beta(self):
        """Per-head write strength, sigmoid-bounded to (0, 0.5)."""
        return torch.sigmoid(self.beta_raw) * 0.5

    def forward(self, x: torch.Tensor, state: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the memory over ``x`` and return ``(output, final_state)``.

        Args:
            x: input of shape (B, T, n_embd).
            state: optional initial memory of shape (B, n_head, head_dim, head_dim).
                ``None`` starts from an all-zero memory.

        Returns:
            The projected read-outs of shape (B, T, n_embd) and the memory
            after the last token, shape (B, n_head, head_dim, head_dim).
        """
        B, T, C = x.size()
        q = self.c_q(x).view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        k = self.c_k(x).view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        v = self.c_v(x).view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        k = F.normalize(k, dim=-1)

        if state is not None and T == 1:
            # Single-token step (stateful generation): one read + one write.
            return self._forward_recurrent(q, k, v, state)
        # Everything else goes through the token-by-token scan, which honours an
        # incoming state. Sending a multi-token input through the one-step path
        # would make every token read the same stale memory and collapse all
        # writes into one summed update.
        return self._forward_parallel(q, k, v, state)

    def _forward_recurrent(self, q, k, v, state):
        """One delta-rule step; exact only when q, k, v hold a single token."""
        # q, k, v: (B, H, 1, D)
        # state: (B, H, D, D)

        # 1. Read: y = q @ M^T
        y = torch.matmul(q, state.transpose(-1, -2))

        # 2. Update
        k_t = k.transpose(-1, -2) # (B, H, D, 1)
        v_t = v.transpose(-1, -2) # (B, H, D, 1)

        # M_new = M - alpha * (M k) k^T + beta * v k^T
        Mk = torch.matmul(state, k_t)
        forget_term = torch.matmul(Mk, k)
        write_term = torch.matmul(v_t, k)

        new_state = state - self.alpha * forget_term + self.beta * write_term

        B, H, T, D = y.shape
        y = y.transpose(1, 2).contiguous().view(B, T, self.n_embd)

        return self.c_proj(y), new_state

    @torch.jit.ignore
    def _forward_parallel(self, q, k, v, state=None):
        """Token-by-token scan over the sequence, optionally from ``state``.

        Despite the name this is a sequential loop: each step reads from the
        current memory and then applies the delta-rule write, so token ``t``
        only sees tokens ``< t`` (plus whatever ``state`` carried in).
        """
        B, H, T, D = q.shape
        if state is None:
            M = torch.zeros(B, H, D, D, device=q.device, dtype=q.dtype)
        else:
            M = state

        # The bounded scalars do not change inside the loop; compute them once.
        alpha, beta = self.alpha, self.beta
        ys = []

        for t in range(T):
            q_t = q[:, :, t:t+1, :]
            k_t = k[:, :, t:t+1, :]
            v_t = v[:, :, t:t+1, :]

            # Read before write so a token never retrieves its own value.
            ys.append(torch.matmul(q_t, M.transpose(-1, -2)))

            k_col = k_t.transpose(-1, -2)
            v_col = v_t.transpose(-1, -2)
            Mk = torch.matmul(M, k_col)
            M = M - alpha * torch.matmul(Mk, k_t) + beta * torch.matmul(v_col, k_t)

        y = torch.cat(ys, dim=2).transpose(1, 2).contiguous().view(B, T, self.n_embd)
        return self.c_proj(y), M

class CMSBlock(nn.Module):
    """Continuum Memory System block, realised as a plain position-wise MLP.

    Layout: Linear (n_embd -> 4 * n_embd), GELU, Linear (4 * n_embd -> n_embd),
    Dropout.
    """
    def __init__(self, config: HOPEConfig):
        super().__init__()
        width = config.n_embd
        hidden = 4 * width
        stages = [
            nn.Linear(width, hidden, bias=config.bias),
            nn.GELU(),
            nn.Linear(hidden, width, bias=config.bias),
            nn.Dropout(config.dropout),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Apply the MLP to ``x``; the last dimension stays ``n_embd``."""
        out = self.net(x)
        return out

class HOPEBlock(nn.Module):
    """One pre-norm residual layer: Titans memory, then the CMS MLP."""
    def __init__(self, config: HOPEConfig, layer_idx: int):
        # `layer_idx` is accepted for the caller's convenience; it is not used here.
        super().__init__()
        self.ln1 = nn.LayerNorm(config.n_embd)
        self.titans = TitansL2(config)
        self.ln2 = nn.LayerNorm(config.n_embd)
        self.cms = CMSBlock(config)

    def forward(self, x, state: Optional[torch.Tensor] = None):
        """Return ``(block_output, new_memory_state)`` for input ``x``."""
        # Memory sub-layer with residual connection.
        memory_out, next_state = self.titans(self.ln1(x), state)
        hidden = x + memory_out
        # MLP sub-layer with residual connection.
        mlp_out = self.cms(self.ln2(hidden))
        return hidden + mlp_out, next_state

class HOPE(nn.Module):
    """Language model: token + position embeddings, a stack of HOPE blocks,
    a final LayerNorm and an output head whose weight is tied to the token
    embedding."""
    def __init__(self, config: "HOPEConfig"):
        super().__init__()
        self.config = config
        self.wte = nn.Embedding(config.vocab_size, config.n_embd)
        self.wpe = nn.Embedding(config.block_size, config.n_embd)
        self.drop = nn.Dropout(config.dropout)
        self.blocks = nn.ModuleList([HOPEBlock(config, i) for i in range(config.n_layer)])
        self.ln_f = nn.LayerNorm(config.n_embd)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        # Weight tying: the embedding and the output head share one matrix.
        self.wte.weight = self.lm_head.weight
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """N(0, 0.02) for Linear/Embedding weights, zeros for Linear biases."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None, states=None, pos_offset=0):
        """Run the model on token ids ``idx`` of shape (b, t).

        Args:
            idx: integer token ids, shape (b, t).
            targets: optional ids of shape (b, t); entries equal to -1 are
                ignored by the loss.
            states: optional list with one memory state (or ``None``) per block.
            pos_offset: absolute position of ``idx[:, 0]``; lets a stateful
                caller continue the position sequence across calls.

        Returns:
            ``(logits, loss, new_states)``. With ``targets`` the logits cover
            every position and ``loss`` is the cross-entropy; without, only the
            last position's logits (shape (b, 1, vocab)) are computed and
            ``loss`` is ``None``. ``new_states`` holds each block's final memory.

        Raises:
            ValueError: if ``states`` is given with a length other than the
                number of blocks.
        """
        device = idx.device
        b, t = idx.size()

        if states is None:
            states = [None] * len(self.blocks)
        elif len(states) != len(self.blocks):
            raise ValueError(
                f"expected {len(self.blocks)} memory states, got {len(states)}"
            )

        # Absolute positions, wrapped into the positional table so that
        # generation past block_size does not index out of range.
        pos = torch.arange(pos_offset, pos_offset + t, dtype=torch.long, device=device)

        tok_emb = self.wte(idx)
        pos_emb = self.wpe(pos % self.config.block_size)
        x = self.drop(tok_emb + pos_emb)

        new_states = []
        for block, block_state in zip(self.blocks, states):
            x, new_block_state = block(x, state=block_state)
            new_states.append(new_block_state)

        x = self.ln_f(x)

        if targets is not None:
            logits = self.lm_head(x)
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
        else:
            # Inference: only the last position is needed to pick the next token.
            logits = self.lm_head(x[:, [-1], :])
            loss = None

        return logits, loss, new_states

    @staticmethod
    def _sample_next(logits, temperature, top_k):
        """Sample one token id per row from the last position of ``logits``."""
        logits = logits[:, -1, :] / temperature
        if top_k is not None:
            kth, _ = torch.topk(logits, min(top_k, logits.size(-1)))
            # Everything below the k-th largest logit gets zero probability.
            logits = logits.masked_fill(logits < kth[:, [-1]], -float('Inf'))
        probs = F.softmax(logits, dim=-1)
        return torch.multinomial(probs, num_samples=1)

    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
        """Stateful generation: append ``max_new_tokens`` sampled ids to ``idx``.

        The prompt is processed once to build the memory states; afterwards
        each step feeds only the newest token together with the carried states.
        Returns ``idx`` unchanged when ``max_new_tokens`` is not positive.
        """
        if max_new_tokens <= 0:
            return idx

        # 1. Prefill: consume the whole prompt and sample the first new token.
        logits, _, states = self(idx, pos_offset=0)
        idx_next = self._sample_next(logits, temperature, top_k)
        out = torch.cat((idx, idx_next), dim=1)

        # 2. Generation loop: one token in, one token out per step.
        current_pos = idx.size(1)
        for _ in range(max_new_tokens - 1):
            logits, _, states = self(idx_next, states=states, pos_offset=current_pos)
            idx_next = self._sample_next(logits, temperature, top_k)
            out = torch.cat((out, idx_next), dim=1)
            current_pos += 1

        return out

# Build the network and move it to the selected device.
model = HOPE(config).to(device)

# torch.compile is only present on newer PyTorch builds, so probe for it first.
if not hasattr(torch, 'compile'):
    print("⚠ torch.compile not available (update PyTorch to 2.0+)")
else:
    print("Compiling model with torch.compile (6-10x speedup)...")
    model = torch.compile(model)
    print("✓ Compilation enabled")

# Report the actual parameter count of the instantiated model.
total_params = sum(p.numel() for p in model.parameters())
print(f"Parameters: {total_params/1e6:.2f}M")

# Training
This section sets up the data pipeline and runs the stateful training loop, including loss tracking and periodic state resets.

In [None]:
# @title Streaming Dataset (Memory Efficient)
class StreamingTextDataset(IterableDataset):
    """Streams TinyStories and yields fixed-length next-token training pairs.

    Stories are tokenized with the GPT-2 encoding and concatenated into one
    running token buffer; every yielded pair is ``(x, y)`` where ``y`` is ``x``
    shifted one token to the left, both of length ``block_size``.
    """
    def __init__(self, split="train", block_size=config.block_size):
        self.dataset = load_dataset("roneneldan/TinyStories", split=split, streaming=True)
        self.tokenizer = tiktoken.get_encoding("gpt2")
        self.block_size = block_size

    def __iter__(self):
        stride = self.block_size
        span = stride + 1  # inputs plus the one extra token needed for the shifted targets
        pending = []
        for record in self.dataset:
            pending += self.tokenizer.encode(record['text'])
            while len(pending) >= span:
                window = pending[:span]
                # Advance by block_size only: the window's last token stays in
                # the buffer and becomes the first input of the next window.
                del pending[:stride]
                inputs = torch.tensor(window[:-1], dtype=torch.long)
                labels = torch.tensor(window[1:], dtype=torch.long)
                yield inputs, labels

train_dataset = StreamingTextDataset(split="train")
# drop_last=True keeps every batch at exactly 8 rows. The training loop carries
# per-row memory states from one batch to the next, so a short final batch
# would not match the batch dimension of the carried states.
train_loader = DataLoader(train_dataset, batch_size=8, drop_last=True)
print("✓ Streaming dataset ready")

In [None]:
# @title Training Loop (STATEFUL + AMP)
# Hyperparameters
max_iters = 5000        # total optimizer steps
learning_rate = 3e-4    # peak learning rate reached at the end of warmup
min_lr = 3e-5           # floor of the cosine schedule
warmup_iters = 200      # steps of linear warmup
grad_clip = 1.0         # max global gradient norm

optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=0.1)
# Loss scaling is only meaningful for CUDA mixed precision; a disabled scaler
# makes scale/unscale_/step/update plain pass-throughs and avoids the
# "CUDA is not available" warning on CPU-only runtimes.
scaler = GradScaler(enabled=(device == 'cuda'))

def get_lr(it):
    """Learning rate for step ``it``: linear warmup, cosine decay, then a floor.

    Ramps linearly from 0 to ``learning_rate`` over the first ``warmup_iters``
    steps, follows a half-cosine down to ``min_lr`` until ``max_iters``, and
    returns ``min_lr`` for any later step.
    """
    if it >= warmup_iters:
        if it > max_iters:
            # Past the end of the schedule: stay at the floor.
            return min_lr
        progress = (it - warmup_iters) / (max_iters - warmup_iters)
        # Goes from 1 at the start of the decay to 0 at max_iters.
        cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
        return min_lr + cosine * (learning_rate - min_lr)
    return learning_rate * it / warmup_iters

# Memory states carried from one batch to the next (None = start from empty memory).
# NOTE(review): the loader batches consecutive chunks of a single token stream,
# so row i of one batch is not the textual continuation of row i of the previous
# batch, and positions restart at 0 each step (pos_offset is not passed).
# Confirm that carrying state across batches is the intended training signal.
persistent_states = None
loss_history = []
state_reset_interval = 500

# Mixed precision is only used on CUDA; on CPU autocast is switched off
# explicitly instead of relying on the warn-and-disable fallback.
amp_enabled = (device == 'cuda')

model.train()
train_iter = iter(train_loader)
pbar = tqdm(range(max_iters), desc="Training")

for step in pbar:
    lr = get_lr(step)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

    try:
        X, Y = next(train_iter)
    except StopIteration:
        # Stream exhausted: start another pass over the dataset.
        train_iter = iter(train_loader)
        X, Y = next(train_iter)

    X, Y = X.to(device), Y.to(device)

    # Carried states have a fixed batch dimension; if this batch has a different
    # number of rows (e.g. a short final batch) they cannot be reused.
    if persistent_states is not None and any(
        s is not None and s.size(0) != X.size(0) for s in persistent_states
    ):
        persistent_states = None

    # Forward pass with state persistence
    with autocast(enabled=amp_enabled):
        logits, loss, new_states = model(X, Y, states=persistent_states)

    # Detach so gradients never flow back into earlier batches.
    persistent_states = [s.detach() if s is not None else None for s in new_states]

    optimizer.zero_grad(set_to_none=True)
    scaler.scale(loss).backward()
    # Unscale first so the clipping threshold applies to the true gradient norm.
    scaler.unscale_(optimizer)
    torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
    scaler.step(optimizer)
    scaler.update()

    # .item() forces a device sync; read the value once and reuse it.
    loss_value = loss.item()
    loss_history.append(loss_value)
    pbar.set_postfix({'loss': f'{loss_value:.4f}', 'lr': f'{lr:.2e}'})

    # Reset states periodically to prevent drift
    if step % state_reset_interval == 0 and step > 0:
        persistent_states = None

print("\n✓ Training complete!")

# Loss curve: one point per optimizer step recorded during training.
fig, ax = plt.subplots(figsize=(10, 4))
ax.plot(loss_history)
ax.set_title("Training Loss (Stateful)")
ax.set_xlabel("Step")
ax.set_ylabel("Loss")
ax.grid(True)
plt.show()

# Stateful Generation and Memory Visualization

This section demonstrates the $\small{O(1)}$ stateful generation process and includes the new visualization step to see the Titans memory matrix $\small{M}$ change in real-time.

In [None]:
# @title Stateful Generation and Memory Visualization
@torch.no_grad()
def visualize_and_generate(model, prompt, max_tokens=20, temperature=0.8):
    """Generate ``max_tokens`` tokens step by step and plot the memory state.

    The prompt is processed once (prefill); afterwards each step feeds only the
    newest token together with the carried memory states. A heatmap of the
    layer-0 / head-0 memory matrix is drawn after the first and the last step.

    Args:
        model: the HOPE model; it is switched to eval mode.
        prompt: text to continue; must encode to at least one token.
        max_tokens: number of tokens to generate.
        temperature: softmax temperature used for sampling.

    Returns:
        The prompt followed by the generated text.

    Raises:
        ValueError: if ``prompt`` encodes to zero tokens.
    """
    model.eval()
    tokenizer = tiktoken.get_encoding("gpt2")
    token_ids = tokenizer.encode(prompt)
    if not token_ids:
        raise ValueError("prompt must encode to at least one token")
    tokens = torch.tensor(token_ids, dtype=torch.long, device=device).unsqueeze(0)

    # 1. Prefill: build the memory states and sample the first new token.
    logits, _, states = model(tokens, pos_offset=0)
    next_token = torch.multinomial(F.softmax(logits[0, -1] / temperature, dim=-1), 1)

    print("="*40)
    print("PROMPT:", prompt)
    print("Generated Text (Step-by-step):")

    current_pos = tokens.size(1)
    out = tokenizer.decode(token_ids)

    for i in range(max_tokens):
        # Record the sampled token BEFORE feeding it to the model, so the first
        # generated token is part of the output. Decode the whole id list each
        # time: decoding ids one by one can split multi-byte characters.
        token_ids.append(next_token.item())
        out = tokenizer.decode(token_ids)

        # 2. Generation step: one token in, updated states out.
        logits, _, states = model(next_token.unsqueeze(0), states=states, pos_offset=current_pos)

        # --- Memory Visualization ---
        if i == 0 or i == max_tokens - 1:
            # states[0] is layer 0's memory (B, H, D, D); [0, 0] selects
            # batch row 0, head 0. Cast to float32 so numpy can hold it.
            M_state = states[0][0, 0].float().cpu().numpy()

            plt.figure(figsize=(6, 5))
            sns.heatmap(M_state, cmap='viridis', square=True,
                        cbar_kws={'label': 'Memory Value'},
                        vmax=0.1, vmin=-0.1) # Bounding for clear color contrast
            plt.title(f"Layer 0, Head 0 Memory State (Step {i+1}/{max_tokens})")
            plt.ylabel("Value Dim")
            plt.xlabel("Key Dim")
            plt.show()

        print(f"[{i+1}] {out} | M Updated.")

        # Sample the token for the next step.
        next_token = torch.multinomial(F.softmax(logits[0, -1] / temperature, dim=-1), 1)
        current_pos += 1

    print("="*40)
    print("\nFINAL OUTPUT:\n", out)
    return out

# Run the demo
# The prompt deliberately has no trailing space: the GPT-2 encoding attaches the
# space to the start of the following word, so ending on a space leaves a
# dangling space token right before generation begins.
prompt = "Once upon a time, a small mouse named Timmy"
generated_text = visualize_and_generate(model, prompt, max_tokens=5)