# NanoPy Training Pipeline

This notebook trains a GPT-style model on Python code. It has been optimized for performance and observability.

**Improvements:**
- **Flash Attention**: Uses PyTorch 2.0 scaled dot product attention.
- **Efficient Data Loading**: Tokenizes with a batched, multi-process `Dataset.map` instead of building the tokenized corpus as in-memory Python lists.
- **Monitoring**: Live train/validation loss plots and periodic sample text generation.
- **Checkpointing**: Saves the best model (lowest validation loss) and periodic checkpoints.

## 1. Imports and Setup

In [None]:
import contextlib
import math
import os
import time

import matplotlib.pyplot as plt
import tiktoken
import torch
import torch.optim as optim
from datasets import load_dataset
from IPython.display import clear_output
from torch.utils.data import DataLoader

# Import model and config from local file
from model import NanoPy, GPTConfig

# ---------------------------------------------------------------------------
# Training configuration
# ---------------------------------------------------------------------------
# Each optimizer step accumulates gradients over ACCUM_STEPS micro-batches of
# BATCH_SIZE sequences, i.e. TOTAL_BATCH_SIZE sequences per update.
BATCH_SIZE = 8       # sequences per forward/backward pass (micro-batch)
ACCUM_STEPS = 4      # micro-batches accumulated per optimizer step
TOTAL_BATCH_SIZE = ACCUM_STEPS * BATCH_SIZE

# Optimization schedule (all counts are optimizer steps)
LEARNING_RATE = 3e-4
MAX_STEPS = 5000
WARMUP_STEPS = 100

# Monitoring / checkpoint cadence (optimizer steps)
VALIDATE_EVERY = 250
GENERATE_EVERY = 500
SAVE_EVERY = 1000

BLOCK_SIZE = 512     # sequence length in tokens fed to the model
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# Checkpoints are written here; create it up front so torch.save never fails.
OUTPUT_DIR = "checkpoints"
os.makedirs(OUTPUT_DIR, exist_ok=True)

for _line in (f"Device: {DEVICE}", f"Total Batch Size: {TOTAL_BATCH_SIZE}"):
    print(_line)

## 2. Dataset Preparation (Optimized)

In [None]:
# GPT-2 BPE tokenizer. Its end-of-text id is used below both to terminate each
# sample and as the padding token.
enc = tiktoken.get_encoding("gpt2")
eos_token = enc.eot_token

# Raw instruction/output pairs; a validation split is carved out after tokenization.
ds = load_dataset("jtatman/python-code-dataset-500k", split="train")

def process_and_tokenize(examples):
    """Turn a batch of instruction/output pairs into fixed-length LM training pairs.

    Each pair is rendered as the instruction wrapped in a triple-double-quoted
    docstring, followed by the code. It is then encoded with ``enc``,
    terminated with ``eos_token``, and truncated or right-padded (with
    ``eos_token``) to exactly ``BLOCK_SIZE + 1`` tokens. Finally it is split
    into inputs ``[0:-1]`` and next-token targets ``[1:]``.

    Args:
        examples: A batch as passed by ``Dataset.map(batched=True)``; must
            provide list-valued ``'instruction'`` and ``'output'`` columns.

    Returns:
        dict with ``"input_ids"`` and ``"labels"``, each a list containing one
        ``BLOCK_SIZE``-long list of token ids per example.
    """
    seq_len = BLOCK_SIZE + 1  # one extra token so inputs/labels can be shifted by one
    batch_input_ids = []
    batch_labels = []

    for inst, out in zip(examples['instruction'], examples['output']):
        text = f'"""\n{inst}\n"""\n{out}'

        # encode_ordinary() treats special-token strings such as "<|endoftext|>"
        # appearing inside the source text as plain text. enc.encode() would
        # raise ValueError on them and abort the whole map() job.
        tokens = enc.encode_ordinary(text)
        tokens.append(eos_token)

        # Truncate long samples, then right-pad short ones with EOS.
        # NOTE(review): padded label positions are EOS and are trained on like
        # real targets. If the model's loss supports an ignore_index, masking
        # them would avoid teaching the model to emit runs of EOS. Confirm
        # against model.py before changing.
        tokens = tokens[:seq_len]
        tokens += [eos_token] * (seq_len - len(tokens))

        batch_input_ids.append(tokens[:-1])
        batch_labels.append(tokens[1:])

    return {"input_ids": batch_input_ids, "labels": batch_labels}

print("Processing dataset... This might take a moment but is memory efficient.")
# Batched, multi-process tokenization. The original text columns are dropped
# so only token ids remain.
tokenized_ds = ds.map(
    process_and_tokenize,
    batched=True,
    remove_columns=ds.column_names,
    num_proc=4,
)
tokenized_ds.set_format(type='torch', columns=['input_ids', 'labels'])

# Hold out 1% for validation; the fixed seed keeps the split reproducible.
split_ds = tokenized_ds.train_test_split(test_size=0.01, seed=42)
train_ds = split_ds['train']
val_ds = split_ds['test']

print(f"Train size: {len(train_ds)}")
print(f"Val size: {len(val_ds)}")

# Pinned host memory only speeds up host->GPU copies. On CPU-only machines
# PyTorch warns that it is unused, so enable it only when training on CUDA.
PIN_MEMORY = DEVICE == 'cuda'
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=4, pin_memory=PIN_MEMORY)
val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=4, pin_memory=PIN_MEMORY)

## 3. Model Initialization

In [None]:
# Initialize Model
config = GPTConfig(
    vocab_size=enc.n_vocab,  # keep the embedding table in sync with the tokenizer (50257 for "gpt2")
    block_size=BLOCK_SIZE,
    n_layer=12,
    n_head=8,
    n_embd=768,
    dropout=0.1,
    bias=True
)
model = NanoPy(config).to(DEVICE)

# Optimizer: decay only >=2-D tensors (presumably linear/embedding weight
# matrices). 1-D tensors (presumably biases and normalization gains; confirm
# in model.py) are excluded, since shrinking them toward zero is not intended
# regularization.
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = optim.AdamW(
    [
        {"params": [p for p in trainable_params if p.dim() >= 2], "weight_decay": 1e-1},
        {"params": [p for p in trainable_params if p.dim() < 2], "weight_decay": 0.0},
    ],
    lr=LEARNING_RATE,
)
# Loss scaling for fp16 autocast; only created when CUDA is available.
# NOTE: recent PyTorch releases prefer torch.amp.GradScaler("cuda") and may
# emit a FutureWarning for this alias.
scaler = torch.cuda.amp.GradScaler() if torch.cuda.is_available() else None

print(f"Model parameters: {model.get_num_params()/1e6:.2f}M")

## 4. Helper Functions

In [None]:
def get_lr(step, *, max_lr=None, warmup_steps=None, max_steps=None, min_lr_ratio=0.1):
    """Return the learning rate for 0-based optimizer step ``step``.

    Schedule (``min_lr = min_lr_ratio * max_lr``):

    * ``step < warmup_steps``: linear warmup ``max_lr * (step + 1) / warmup_steps``.
      It starts above zero so the very first update is not wasted.
    * ``warmup_steps <= step <= max_steps``: cosine decay from ``max_lr`` down
      to ``min_lr``.
    * ``step > max_steps``: held constant at ``min_lr``. This is continuous
      with the end of the decay, so there is no jump back up.

    Args:
        step: Optimizer step index (0-based).
        max_lr: Peak learning rate; defaults to ``LEARNING_RATE``.
        warmup_steps: Warmup length; defaults to ``WARMUP_STEPS``.
        max_steps: Step at which the decay reaches ``min_lr``; defaults to ``MAX_STEPS``.
        min_lr_ratio: Floor as a fraction of ``max_lr`` (0.1 matches the
            original post-schedule value).

    Returns:
        The learning rate as a float.
    """
    if max_lr is None:
        max_lr = LEARNING_RATE
    if warmup_steps is None:
        warmup_steps = WARMUP_STEPS
    if max_steps is None:
        max_steps = MAX_STEPS
    min_lr = max_lr * min_lr_ratio

    if step < warmup_steps:
        return max_lr * (step + 1) / warmup_steps
    if step > max_steps:
        return min_lr

    decay_span = max_steps - warmup_steps
    # Guard warmup_steps == max_steps (no decay span): go straight to the floor.
    decay_ratio = (step - warmup_steps) / decay_span if decay_span > 0 else 1.0
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))  # 1 -> 0 over the decay
    return min_lr + coeff * (max_lr - min_lr)

@torch.no_grad()
def estimate_loss(model, loader, eval_iters=50, *, device=None, use_amp=None):
    """Return the mean loss of ``model`` over the first ``eval_iters`` batches of ``loader``.

    The model is put in eval mode for the evaluation. Its previous train/eval
    mode is restored afterwards, even if an error is raised.

    Args:
        model: Module called as ``model(X, Y)`` and returning ``(logits, loss)``.
        loader: Iterable of dicts with ``'input_ids'`` and ``'labels'`` tensors.
        eval_iters: Maximum number of batches to average over.
        device: Device to move batches to; defaults to ``DEVICE``.
        use_amp: Run the forward pass under fp16 autocast; defaults to
            ``scaler is not None`` (i.e. CUDA is available).

    Returns:
        The mean batch loss as a float, or ``float('nan')`` if no batch was
        evaluated (empty loader or ``eval_iters <= 0``).
    """
    if device is None:
        device = DEVICE
    if use_amp is None:
        use_amp = scaler is not None
    # autocast wants a device *type* ("cuda"), not an indexed device ("cuda:0").
    device_type = torch.device(device).type

    was_training = model.training
    model.eval()
    losses = []
    try:
        for i, batch in enumerate(loader):
            if i >= eval_iters:
                break
            X, Y = batch['input_ids'].to(device), batch['labels'].to(device)
            amp_ctx = (
                torch.amp.autocast(device_type=device_type, dtype=torch.float16)
                if use_amp
                else contextlib.nullcontext()
            )
            with amp_ctx:
                _, loss = model(X, Y)
            losses.append(loss.item())
    finally:
        model.train(was_training)

    if not losses:
        return float('nan')
    return sum(losses) / len(losses)

@torch.no_grad()
def generate_sample(model, prompt="def fibonacci(", max_new_tokens=100, temperature=0.8, top_k=50):
    """Sample a continuation of ``prompt`` from ``model`` and return it as text.

    The model is switched to eval mode for generation. Its previous
    train/eval mode is restored afterwards, even if generation raises.

    Args:
        model: Model exposing ``generate(idx, max_new_tokens=..., temperature=..., top_k=...)``.
        prompt: Text to condition on; encoded with ``enc``.
        max_new_tokens: Number of tokens to sample.
        temperature: Sampling temperature passed to ``model.generate``.
        top_k: Top-k cutoff passed to ``model.generate``.

    Returns:
        The decoded prompt plus generated tokens (first sequence of the batch).
    """
    was_training = model.training
    model.eval()
    try:
        idx = torch.tensor([enc.encode(prompt)], dtype=torch.long, device=DEVICE)  # shape (1, len)
        generated = model.generate(
            idx, max_new_tokens=max_new_tokens, temperature=temperature, top_k=top_k
        )
    finally:
        model.train(was_training)
    return enc.decode(generated[0].tolist())

## 5. Training Loop with Monitoring

In [None]:
train_losses = []             # one entry per optimizer step (mean loss over its micro-batches)
val_losses = []               # (step, val_loss) pairs
best_val_loss = float('inf')
step = 0
t0 = time.time()

# Forward-pass context: fp16 autocast when a GradScaler exists (CUDA),
# otherwise a no-op. The no-op must NOT be torch.no_grad(): that would drop the
# autograd graph and make loss.backward() fail on CPU.
amp_ctx = (
    torch.amp.autocast(device_type=DEVICE, dtype=torch.float16)
    if scaler is not None
    else contextlib.nullcontext()
)

model.train()

print("Starting training...")

# One persistent iterator over the training data. It is re-created at the end
# of each epoch, so training runs for exactly MAX_STEPS optimizer steps.
iter_loader = iter(train_loader)

while step < MAX_STEPS:
    # Set this step's learning rate *before* the update, so warmup applies from
    # the very first optimizer step (the optimizer was created with the peak LR).
    lr = get_lr(step)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

    optimizer.zero_grad()
    accum_loss = 0.0

    for _ in range(ACCUM_STEPS):
        try:
            batch = next(iter_loader)
        except StopIteration:
            iter_loader = iter(train_loader)
            batch = next(iter_loader)

        X, Y = batch['input_ids'].to(DEVICE), batch['labels'].to(DEVICE)

        with amp_ctx:
            _, loss = model(X, Y)
            # Scale so the accumulated gradients equal the mean over micro-batches.
            loss = loss / ACCUM_STEPS

        accum_loss += loss.item()
        if scaler is not None:
            scaler.scale(loss).backward()
        else:
            loss.backward()

    # Gradient clipping & optimizer step. With AMP, unscale first so the clip
    # threshold applies to the true gradient norm.
    if scaler is not None:
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        scaler.step(optimizer)
        scaler.update()
    else:
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

    step += 1
    train_losses.append(accum_loss)

    # --- Monitoring ---
    if step % 10 == 0:
        # Inline progress; "Time" is wall-clock seconds for the last 10 steps.
        print(f"Step {step} | Loss: {accum_loss:.4f} | LR: {lr:.2e} | Time: {time.time()-t0:.2f}s", end='\r')
        t0 = time.time()

    if step % VALIDATE_EVERY == 0:
        val_loss = estimate_loss(model, val_loader)
        val_losses.append((step, val_loss))

        # Save Best Model
        is_best = val_loss < best_val_loss
        if is_best:
            best_val_loss = val_loss
            torch.save(model.state_dict(), os.path.join(OUTPUT_DIR, "best_model.pt"))

        # Redraw the plot *before* printing: clear_output() would otherwise
        # wipe this round's validation messages.
        clear_output(wait=True)
        plt.figure(figsize=(10, 5))
        # train_losses[i] belongs to optimizer step i + 1.
        plt.plot(range(1, len(train_losses) + 1), train_losses, label='Train Loss', alpha=0.3)
        val_x, val_y = zip(*val_losses)
        plt.plot(val_x, val_y, label='Val Loss', marker='o', color='red')
        plt.xlabel('Steps')
        plt.ylabel('Loss')
        plt.legend()
        plt.show()

        print(f"Step {step} Val Loss: {val_loss:.4f}")
        if is_best:
            print("Saved new best model!")

    if step % GENERATE_EVERY == 0:
        print("\n--- Sample Generation ---")
        print(generate_sample(model))
        print("-------------------------\n")

    if step % SAVE_EVERY == 0:
        torch.save(model.state_dict(), os.path.join(OUTPUT_DIR, f"ckpt_step_{step}.pt"))

# Keep the final weights even when MAX_STEPS is not a multiple of SAVE_EVERY.
if step % SAVE_EVERY != 0:
    torch.save(model.state_dict(), os.path.join(OUTPUT_DIR, f"ckpt_step_{step}.pt"))