# 🧠 Titans Neural Memory — GPU Training (Colab)

This notebook trains a GPT model with Titans MAL (Memory as Layer) neural memory
on a real text dataset using a T4 GPU (15GB VRAM).

**What this validates:**
- Loss decreases steadily on real text (not random)
- Memory VRAM usage stays stable (no leaks)
- Mixed precision (float16 on T4, bfloat16 on A100+) is numerically stable
- Inner-loop memory updates work correctly on GPU

**Runtime:** ~10-15 minutes for 500 steps

## 1. Setup & Install Dependencies

In [None]:
# Check GPU
!nvidia-smi
import torch
print(f"\nPyTorch: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"VRAM: {torch.cuda.get_device_properties(0).total_mem / 1024**3:.1f} GB")

In [None]:
# Install dependencies used by later cells: `datasets` (load_dataset) and
# `tiktoken` (GPT-2 encoding).
# NOTE(review): `wandb` is installed here but is not imported in any cell
# shown — confirm it is still needed.
!pip install -q datasets tiktoken wandb

## 2. Clone Repository

In [None]:
import os
if not os.path.exists('nanochat'):
    !git clone https://github.com/Pandurangmopgar/nanochat.git
os.chdir('nanochat')
!git log --oneline -5

## 3. Import Model Components

We import directly from the source — no Rust/maturin build needed for training.

In [None]:
import sys
sys.path.insert(0, '.')

import time
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass

# Import our Titans modules directly
from nanochat.titans_memory import TitansMemory, TitansLayer
from nanochat.gpt import GPT, GPTConfig

print("‚úÖ Imports successful")

## 4. Quick Memory Validation on GPU

In [None]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")


def _measure_surprise(memory, tokens):
    """Return the MSE between memory_net(k) and v for `tokens`, without tracking gradients.

    `tokens` is passed through `memory.conv` (applied on the transposed input
    and transposed back), then projected with `memory.W_K` / `memory.W_V`.
    """
    with torch.no_grad():
        x_conv = memory.conv(tokens.transpose(1, 2)).transpose(1, 2)
        k, v = memory.W_K(x_conv), memory.W_V(x_conv)
        return F.mse_loss(memory.memory_net(k).float(), v.float()).item()


# Quick sanity check: memory learns on GPU
mem = TitansMemory(model_dim=128, memory_dim=64, memory_depth=2).to(device).train()
# Boost learning rate for visible memorization
nn.init.constant_(mem.proj_theta.bias, 0.0)  # softplus(0) ≈ 0.69

pattern = torch.randn(1, 8, 128, device=device)

# Measure surprise before
loss_before = _measure_surprise(mem, pattern)

# Feed pattern 50 times
for _ in range(50):
    mem(pattern, update_memory=True)

# Measure surprise after
loss_after = _measure_surprise(mem, pattern)

print(f"Surprise before: {loss_before:.6f}")
print(f"Surprise after:  {loss_after:.6f}")
if loss_before > 0:
    # Guard the ratio: a zero starting loss would otherwise divide by zero.
    print(f"Reduction: {(1 - loss_after/loss_before)*100:.1f}%")
# Only claim success when the measurement actually supports it.
if loss_after < loss_before:
    print(f"‚úÖ Memory is learning on {device.upper()}!")
else:
    print(f"WARNING: surprise did not decrease on {device.upper()}; memory updates may not be working.")

del mem, pattern  # free VRAM
if device == 'cuda':
    torch.cuda.empty_cache()

## 5. Load Real Text Data

We use WikiText-103 — a standard language modeling benchmark.
Using tiktoken (GPT-2 tokenizer) since the Rust BPE tokenizer requires maturin build.

In [None]:
from datasets import load_dataset
import tiktoken

# Fetch the raw WikiText-103 train and validation splits.
print("Loading WikiText-103...")
dataset = load_dataset("wikitext", "wikitext-103-raw-v1", split="train")
val_dataset = load_dataset("wikitext", "wikitext-103-raw-v1", split="validation")

# GPT-2 encoding from tiktoken.
enc = tiktoken.get_encoding("gpt2")
VOCAB_SIZE = enc.n_vocab  # 50257, padded to 50304 for efficiency
VOCAB_SIZE_PADDED = 50304

print(f"Tokenizer vocab size: {VOCAB_SIZE} (padded to {VOCAB_SIZE_PADDED})")
print(f"Train documents: {len(dataset):,}")
print(f"Val documents: {len(val_dataset):,}")

# Encode every non-blank row and append its ids to one flat list per split.
# Rows are concatenated back-to-back; no separator token is inserted.
print("Tokenizing train split...")
all_train_tokens = []
for row_idx, row in enumerate(dataset):
    stripped = row['text'].strip()
    if stripped:
        all_train_tokens += enc.encode(stripped)
    # Progress line every 100,000 rows (row 0 excluded).
    if row_idx > 0 and row_idx % 100000 == 0:
        print(f"  Processed {row_idx:,} docs, {len(all_train_tokens):,} tokens so far...")

print("\nTokenizing val split...")
all_val_tokens = []
for row in val_dataset:
    stripped = row['text'].strip()
    if not stripped:
        continue
    all_val_tokens += enc.encode(stripped)

# Flat int64 tensors that the batch sampler slices windows from.
train_tokens = torch.tensor(all_train_tokens, dtype=torch.long)
val_tokens = torch.tensor(all_val_tokens, dtype=torch.long)

print(f"\n‚úÖ Train tokens: {len(train_tokens):,}")
print(f"‚úÖ Val tokens: {len(val_tokens):,}")

In [None]:
# Data loader: yields random batches from the token buffer
def get_batch(split, batch_size, seq_len, device):
    """Sample a random batch of (input, target) windows from a token buffer.

    Args:
        split: 'train' selects `train_tokens`; any other value selects `val_tokens`.
        batch_size: number of windows to draw.
        seq_len: number of tokens in each window.
        device: device the returned tensors are moved to.

    Returns:
        A tuple (x, y), both of shape (batch_size, seq_len). `x` is int32 and
        `y` is int64; `y` holds the tokens one position after `x` in the
        buffer (next-token targets).

    Raises:
        ValueError: if the selected buffer is too short to hold one window
            plus its shifted target.
    """
    data = train_tokens if split == 'train' else val_tokens
    num_starts = len(data) - seq_len - 1
    if num_starts <= 0:
        # torch.randint would fail here with an unrelated-looking message.
        raise ValueError(
            f"'{split}' buffer has {len(data)} tokens; need more than "
            f"{seq_len + 1} to sample a window of seq_len={seq_len}"
        )
    ix = torch.randint(num_starts, (batch_size,))
    # Gather every window with one indexing op instead of slicing row by row.
    positions = ix.unsqueeze(1) + torch.arange(seq_len)
    x = data[positions].to(device, dtype=torch.int32)
    y = data[positions + 1].to(device, dtype=torch.long)
    return x, y

# Smoke test: draw a tiny training batch, report its shape, and decode the
# first 20 token ids of the first row back to text.
x, y = get_batch('train', 2, 64, device)
print(f"Batch shape: x={x.shape}, y={y.shape}")
preview_ids = x[0, :20].tolist()
print(f"Sample decoded: {enc.decode(preview_ids)}")

## 6. Initialize Model

**Small config for T4 (15GB VRAM):**
- 4 layers, 384 dim, 6 heads → ~15M params (fits easily)
- seq_len=512 to keep VRAM manageable
- Titans memory: dim=256, depth=2

In [None]:
# Batch geometry chosen to fit a T4 (15GB)
SEQ_LEN = 512
BATCH_SIZE = 8  # Adjust down if OOM
GRAD_ACCUM = 4  # Effective batch = 8 * 4 = 32

# Transformer shape plus the Titans memory (MAL variant) settings.
config = GPTConfig(
    sequence_len=SEQ_LEN,
    vocab_size=VOCAB_SIZE_PADDED,
    n_layer=4,
    n_head=6,
    n_kv_head=6,
    n_embd=384,
    use_titans=True,
    titans_memory_dim=256,
    titans_memory_depth=2,
)

# Build the model, move it to the target device, then run its weight init.
model = GPT(config).to(device)
model.init_weights()

# Parameter breakdown: everything vs. the Titans layer (0 if it is absent).
total_params = sum(p.numel() for p in model.parameters())
if model.titans_layer:
    titans_params = sum(p.numel() for p in model.titans_layer.parameters())
else:
    titans_params = 0
other_params = total_params - titans_params

print(f"Total parameters: {total_params:,}")
print(f"  Transformer:    {other_params:,}")
print(f"  Titans memory:  {titans_params:,}")
print(f"  Memory fraction: {titans_params/total_params*100:.1f}%")
print(f"\nVRAM used after init: {torch.cuda.memory_allocated()/1024**2:.0f} MB")

## 7. Training Loop

500 steps with AdamW, mixed precision, gradient clipping.
Evaluates validation loss every 50 steps.

In [None]:
# Schedule, clipping and evaluation settings for the run
NUM_STEPS = 500     # optimizer steps
WARMUP_STEPS = 50   # steps of linear LR warmup
LR = 3e-4           # peak learning rate
GRAD_CLIP = 1.0     # max gradient norm
EVAL_EVERY = 50     # validate every N steps
EVAL_STEPS = 10  # batches for val loss estimate

# AdamW over every model parameter
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=LR,
    betas=(0.9, 0.95),
    weight_decay=0.1,
)

# LR scheduler: linear warmup then cosine decay
def get_lr(step):
    """Return the learning rate for `step`.

    Below WARMUP_STEPS the rate ramps linearly as LR * (step + 1) / WARMUP_STEPS.
    From WARMUP_STEPS on it follows a half cosine from LR down to 0 at NUM_STEPS.
    """
    if step >= WARMUP_STEPS:
        # max(1, ...) avoids dividing by zero when there is no decay phase.
        decay_span = max(1, NUM_STEPS - WARMUP_STEPS)
        progress = (step - WARMUP_STEPS) / decay_span
        cosine = math.cos(math.pi * progress)
        return LR * 0.5 * (1 + cosine)
    return LR * (step + 1) / WARMUP_STEPS

# Mixed precision (use float16 for T4, bfloat16 for A100+)
use_amp = device == 'cuda'
# Only query the compute capability when a CUDA device is in use; the call
# fails on a CPU-only runtime, which the `device` fallback otherwise allows.
if use_amp and torch.cuda.get_device_capability()[0] >= 8:
    amp_dtype = torch.bfloat16
else:
    amp_dtype = torch.float16
# Loss scaling is only needed for float16 autocast, and only when AMP is on.
use_grad_scaler = use_amp and amp_dtype == torch.float16
scaler = torch.amp.GradScaler(enabled=use_grad_scaler)
print(f"AMP dtype: {amp_dtype}")
print(f"Grad scaler: {'enabled' if use_grad_scaler else 'disabled'}")

# Evaluation function
@torch.no_grad()
def estimate_val_loss():
    """Estimate validation loss as the mean over EVAL_STEPS random val batches.

    The model is switched to eval mode for the measurement and then returned
    to whichever mode it was in before the call, even if a batch raises.

    Returns:
        float: mean of the per-batch losses.
    """
    was_training = model.training
    model.eval()
    try:
        losses = []
        for _ in range(EVAL_STEPS):
            x, y = get_batch('val', BATCH_SIZE, SEQ_LEN, device)
            with torch.amp.autocast(device_type='cuda', dtype=amp_dtype, enabled=use_amp):
                loss = model(x, y)
            losses.append(loss.item())
    finally:
        # Restore the caller's mode rather than forcing train mode.
        model.train(was_training)
    return sum(losses) / len(losses)

# Summarize run size: micro-batch x accumulation, tokens per optimizer step,
# and tokens over the whole run.
effective_batch = BATCH_SIZE * GRAD_ACCUM
tokens_per_step = effective_batch * SEQ_LEN
print(f"\nConfig: {NUM_STEPS} steps, batch={BATCH_SIZE}x{GRAD_ACCUM}={effective_batch}, seq_len={SEQ_LEN}")
print(f"Tokens/step: {tokens_per_step:,}")
print(f"Total tokens: {tokens_per_step * NUM_STEPS:,}")

In [None]:
# ==================== TRAINING ====================
model.train()
train_losses = []  # (step, mean micro-batch loss) per optimizer step
val_losses = []    # (step, validation loss) per evaluation
step_times = []    # wall-clock milliseconds per optimizer step

print("Starting training...")
print(f"{'Step':>6} | {'Train Loss':>10} | {'Val Loss':>10} | {'LR':>10} | {'ms/step':>8} | {'VRAM MB':>8}")
print("-" * 75)

# Initial val loss
val_loss = estimate_val_loss()
val_losses.append((0, val_loss))
print(f"{'0':>6} | {'---':>10} | {val_loss:>10.4f} | {'---':>10} | {'---':>8} | {torch.cuda.memory_allocated()/1024**2:>8.0f}")

for step in range(1, NUM_STEPS + 1):
    t0 = time.time()

    # Set learning rate. `step` is 1-based here, while get_lr() is written for
    # 0-based steps (its warmup branch uses step + 1). Passing step - 1 keeps
    # the first warmup value and avoids running the last optimizer step with
    # a learning rate of exactly 0.
    lr = get_lr(step - 1)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

    # Gradient accumulation loop
    optimizer.zero_grad(set_to_none=True)
    total_loss = 0.0

    for micro_step in range(GRAD_ACCUM):
        x, y = get_batch('train', BATCH_SIZE, SEQ_LEN, device)
        with torch.amp.autocast(device_type='cuda', dtype=amp_dtype, enabled=use_amp):
            loss = model(x, y)
        # Divide so the accumulated gradient is the mean over micro-batches.
        loss_scaled = loss / GRAD_ACCUM
        scaler.scale(loss_scaled).backward()
        total_loss += loss.item()

    # Gradient clipping (unscale first so the threshold applies to true gradients)
    scaler.unscale_(optimizer)
    torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP)

    # Optimizer step
    scaler.step(optimizer)
    scaler.update()

    # Wait for queued GPU work so the step time is meaningful.
    if device == 'cuda':
        torch.cuda.synchronize()
    t1 = time.time()
    dt_ms = (t1 - t0) * 1000
    step_times.append(dt_ms)

    avg_loss = total_loss / GRAD_ACCUM
    train_losses.append((step, avg_loss))

    # Evaluate periodically
    if step % EVAL_EVERY == 0 or step == NUM_STEPS:
        val_loss = estimate_val_loss()
        val_losses.append((step, val_loss))
        vram_mb = torch.cuda.memory_allocated() / 1024**2
        print(f"{step:>6} | {avg_loss:>10.4f} | {val_loss:>10.4f} | {lr:>10.6f} | {dt_ms:>8.1f} | {vram_mb:>8.0f}")

    # Progress printout every 10 steps (overwritten in place via '\r')
    elif step % 10 == 0:
        print(f"{step:>6} | {avg_loss:>10.4f} | {'---':>10} | {lr:>10.6f} | {dt_ms:>8.1f} | {'':>8}", end='\r')

# ==================== DONE ====================
print(f"\n{'='*75}")
print(f"Training complete!")
print(f"Peak VRAM: {torch.cuda.max_memory_allocated()/1024**2:.0f} MB")
print(f"Avg step time: {sum(step_times)/len(step_times):.1f} ms")
print(f"Best val loss: {min(v for _, v in val_losses):.4f}")
print(f"Final val loss: {val_losses[-1][1]:.4f}")

## 8. Training Curves

In [None]:
import matplotlib.pyplot as plt

fig, axes = plt.subplots(1, 2, figsize=(14, 5))
ax1, ax2 = axes

# Split the (step, loss) pairs into parallel lists.
steps_t = [entry[0] for entry in train_losses]
losses_t = [entry[1] for entry in train_losses]
steps_v = [entry[0] for entry in val_losses]
losses_v = [entry[1] for entry in val_losses]

# Exponential moving average of the training loss, seeded with the first value.
ema = losses_t[0]
smooth = [ema]
for raw in losses_t[1:]:
    ema = 0.95 * ema + 0.05 * raw
    smooth.append(ema)

# Left panel: raw and smoothed training loss. Right panel: validation loss.
ax1.plot(steps_t, losses_t, alpha=0.2, color='blue', label='Raw')
ax1.plot(steps_t, smooth, color='blue', linewidth=2, label='Smoothed')
ax2.plot(steps_v, losses_v, 'ro-', linewidth=2, markersize=6)

for axis, y_label, title in (
    (ax1, 'Train Loss', 'Training Loss'),
    (ax2, 'Val Loss', 'Validation Loss'),
):
    axis.set_xlabel('Step')
    axis.set_ylabel(y_label)
    axis.set_title(title)
    axis.grid(True, alpha=0.3)
ax1.legend()

plt.tight_layout()
plt.savefig('titans_training_curves.png', dpi=150, bbox_inches='tight')
plt.show()
print("\nüìä Saved: titans_training_curves.png")

## 9. Memory Module Inspection

In [None]:
# Report statistics of the trained Titans layer: hyperparameters computed on a
# random input, projection norms, memory MLP weights, gate and momentum state.
titans = model.titans_layer
mem = titans.memory

print("=== Titans Memory State ===")
print("\nüìä Data-dependent hyperparameters (on a random batch):")
x_test = torch.randn(1, 64, 384, device=device)
alpha, eta, theta = mem._compute_hyperparams(x_test)
print(f"  alpha (forgetting rate):  {alpha.item():.6f}")
print(f"  eta (momentum decay):     {eta.item():.6f}")
print(f"  theta (learning rate):    {theta.item():.6f}")

print("\nüîë K/V Projection norms:")
k_norm = mem.W_K.weight.norm().item()
v_norm = mem.W_V.weight.norm().item()
print(f"  W_K norm: {k_norm:.4f}")
print(f"  W_V norm: {v_norm:.4f}")

print("\nüß† Memory MLP weight norms per layer:")
for name, p in mem.memory_net.named_parameters():
    print(f"  {name}: norm={p.norm().item():.4f}, mean={p.mean().item():.6f}, std={p.std().item():.6f}")

print("\nüö™ Gate statistics:")
gate = titans.gate
print(f"  Weight norm: {gate.weight.norm().item():.4f}")
print(f"  Bias mean: {gate.bias.mean().item():.4f} (sigmoid = {torch.sigmoid(gate.bias).mean().item():.4f})")

print("\nüíæ Momentum state:")
if mem._momentum_state is None:
    print("  Not initialized (eval mode or no forward pass)")
else:
    ms = mem._momentum_state
    print(f"  Shape: {ms.shape}")
    print(f"  Norm: {ms.norm().item():.6f}")
    # Count entries whose magnitude exceeds 1e-8.
    print(f"  Non-zero: {(ms.abs() > 1e-8).sum().item()} / {ms.numel()}")

## 10. Generate Text

In [None]:
# Greedy decoding: extend each prompt by 50 tokens, always taking the argmax.
model.eval()

prompts = [
    "The capital of France is",
    "In the year 2025,",
    "The largest planet in our solar system",
    "Once upon a time",
]

for prompt_text in prompts:
    prompt_ids = enc.encode(prompt_text)
    x = torch.tensor([prompt_ids], dtype=torch.long, device=device)

    with torch.no_grad():
        with torch.amp.autocast(device_type='cuda', dtype=amp_dtype, enabled=use_amp):
            for _ in range(50):
                # Feed at most the last SEQ_LEN tokens, cast to int32.
                window = x[:, -SEQ_LEN:].to(torch.int32)
                logits = model(window)
                last_logits = logits[:, -1, :]
                next_token = last_logits.argmax(dim=-1, keepdim=True).long()
                x = torch.cat([x, next_token], dim=1)

    generated = enc.decode(x[0].tolist())
    print(f"\nüí¨ {prompt_text}")
    print(f"   {generated}")

# Back to training mode once generation is done.
model.train()
print("\n‚ö†Ô∏è Note: Generation quality requires much longer training (100K+ steps).")
print("The goal here is to verify the model runs without errors, not to produce good text.")

## 11. VRAM Analysis

In [None]:
# Report current/peak allocator statistics and compare the peak against the
# device's total memory.
print("=== VRAM Summary ===")
print(f"Current allocated: {torch.cuda.memory_allocated()/1024**2:.0f} MB")
print(f"Peak allocated:    {torch.cuda.max_memory_allocated()/1024**2:.0f} MB")
print(f"Current reserved:  {torch.cuda.memory_reserved()/1024**2:.0f} MB")
print(f"Peak reserved:     {torch.cuda.max_memory_reserved()/1024**2:.0f} MB")

# The device-properties attribute is `total_memory` (bytes); `total_mem`
# does not exist and raises AttributeError.
total_vram = torch.cuda.get_device_properties(0).total_memory / 1024**2
peak = torch.cuda.max_memory_allocated() / 1024**2
usage = peak / total_vram
print(f"\nUsage: {peak:.0f} / {total_vram:.0f} MB ({usage*100:.1f}% of VRAM)")

# Below 80%: room to grow. Above 95%: at risk of OOM. Otherwise: fine.
if usage < 0.8:
    print("\nüí° Tip: You have headroom. Try increasing BATCH_SIZE or n_layer for better quality.")
elif usage > 0.95:
    print("\n‚ö†Ô∏è Close to VRAM limit. Reduce BATCH_SIZE if you see OOM errors.")
else:
    print("\n‚úÖ VRAM usage is healthy.")

## 12. Save Checkpoint (Optional)

In [None]:
# Bundle weights, optimizer state, the model configuration values and the
# recorded loss curves into a single checkpoint file.
checkpoint_path = 'titans_checkpoint.pt'
model_config = dict(
    sequence_len=SEQ_LEN,
    vocab_size=VOCAB_SIZE_PADDED,
    n_layer=4,
    n_head=6,
    n_kv_head=6,
    n_embd=384,
    use_titans=True,
    titans_memory_dim=256,
    titans_memory_depth=2,
)
checkpoint = dict(
    model_state_dict=model.state_dict(),
    optimizer_state_dict=optimizer.state_dict(),
    config=model_config,
    train_losses=train_losses,
    val_losses=val_losses,
    step=NUM_STEPS,
)
torch.save(checkpoint, checkpoint_path)
size_mb = os.path.getsize(checkpoint_path) / 1024 / 1024
print(f"‚úÖ Checkpoint saved: titans_checkpoint.pt ({size_mb:.1f} MB)")

# Offer the artifacts as browser downloads when running inside Colab;
# elsewhere the import fails and the files simply stay on disk.
try:
    from google.colab import files
    for artifact in (checkpoint_path, 'titans_training_curves.png'):
        files.download(artifact)
except ImportError:
    print("Not in Colab ‚Äî files saved locally.")