# PythonCoder v2 - Diff-Llama-MTP Training

**Architecture**: Differential Attention + GQA + Multi-Token Prediction

## Model (1B Parameters)
| Feature | Config | Benchmark Gain |
|---------|--------|----------------|
| **Differential Attention** | Head-pairing | Matches a standard Transformer with ~35-40% fewer parameters or tokens (reported) |
| **MTP (4 tokens)** | Enabled | +12% HumanEval, +17% MBPP (reported at 13B) |
| **HLP** (Horizon Length Prediction) | Enabled | Up to +24% FIM accuracy (relative) |
| **YaRN** | 4x | Free context extension |
| **Context** | 8192 | 2026 minimum |
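
The Differential Attention row can be illustrated with a minimal sketch. It is not the code in `model_v2.py`: it handles a single head pair, uses NumPy for readability (the same lines port to `jax.numpy`), and takes `lam` as a fixed scalar, whereas the paper learns it and adds per-head normalization. The function and argument names are assumptions made for this example.

```python
import numpy as np

def softmax(x, axis=-1):
    # Subtract the row max for numerical stability before exponentiating.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def diff_attention(q1, k1, q2, k2, v, lam):
    """One head pair of differential attention with a causal mask.

    q1, k1, q2, k2: (seq, d) arrays; v: (seq, d_v); lam: scalar weight.
    """
    seq, d = q1.shape
    scale = 1.0 / np.sqrt(d)
    # 0 where attention is allowed (current and past positions), -inf elsewhere.
    allowed = np.tril(np.ones((seq, seq), dtype=bool))
    mask = np.where(allowed, 0.0, -np.inf)
    a1 = softmax(q1 @ k1.T * scale + mask)
    a2 = softmax(q2 @ k2.T * scale + mask)
    # The second map is subtracted to cancel attention noise common to both.
    return (a1 - lam * a2) @ v

rng = np.random.default_rng(0)
q1, k1, q2, k2 = (rng.normal(size=(5, 8)) for _ in range(4))
v = rng.normal(size=(5, 4))
out = diff_attention(q1, k1, q2, k2, v, lam=0.5)
print(out.shape)
```

Setting `lam=0` recovers ordinary causal attention, which is a convenient sanity check when porting the idea.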

## Requirements
- **Colab Pro+** with **TPU v6e-1** runtime (32GB HBM)
- Google Drive with preprocessed data (~44GB)

## What Changed (v2)
- **Removed MoE**: Not worth it at 1B scale
- **Removed MLA**: GQA sufficient at 1B scale
- **Removed Mamba**: Untested ROI for code
- **Code**: 3002 → 979 lines (-67%)
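
For readers unfamiliar with GQA, the sketch below shows the core idea: several query heads share one key head, so fewer key/value heads need to be stored. The head counts (8 query heads, 2 KV heads) and the function name are illustrative assumptions, not values read from `model_v2.py`.

```python
import numpy as np

def gqa_scores(q, k):
    """Scaled attention scores with grouped-query sharing.

    q: (n_q_heads, seq, d); k: (n_kv_heads, seq, d).
    Returns scores of shape (n_q_heads, seq, seq).
    """
    n_q_heads, n_kv_heads = q.shape[0], k.shape[0]
    if n_q_heads % n_kv_heads != 0:
        raise ValueError("query heads must be a multiple of KV heads")
    group_size = n_q_heads // n_kv_heads
    # Each KV head is reused by `group_size` consecutive query heads.
    k_shared = np.repeat(k, group_size, axis=0)
    return q @ k_shared.transpose(0, 2, 1) / np.sqrt(q.shape[-1])

rng = np.random.default_rng(0)
q = rng.normal(size=(8, 4, 16))   # hypothetical: 8 query heads
k = rng.normal(size=(2, 4, 16))   # hypothetical: 2 shared KV heads
scores = gqa_scores(q, k)
print(scores.shape)
```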

In [None]:
# Cell 1: Install Dependencies
# ============================================================================
# IMPORTANT: Run this cell FIRST, then restart runtime before continuing!
# ============================================================================

# Uninstall existing JAX to avoid conflicts
!pip uninstall -y jax jaxlib flax optax orbax-checkpoint 2>/dev/null

# Install JAX with TPU support
!pip install -q "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

# Install Flax NNX + Optax + Orbax
!pip install -q "flax>=0.12.0" "optax>=0.2.0" "orbax-checkpoint>=0.6.0"

# Data loading
!pip install -q "transformers>=4.40.0" "grain>=0.2.0"
!pip install -q "pyarrow>=14.0.0" "array-record>=0.6.0"

print("\n" + "="*60)
print("RESTART RUNTIME NOW!")
print("Runtime -> Restart runtime, then run Cell 2")
print("="*60)

In [None]:
# Cell 2: Mount Google Drive & Setup
# ============================================================================
from google.colab import drive
drive.mount('/content/drive')

import os

# Project paths
PROJECT_DIR = '/content/drive/MyDrive/python-coder-v6e'
DATA_DIR = f'{PROJECT_DIR}/preprocessed_data/train'
CHECKPOINT_DIR = f'{PROJECT_DIR}/checkpoints_v2'  # v2 checkpoints
JAX_CACHE_DIR = f'{PROJECT_DIR}/jax_cache'

# Create directories
for d in [CHECKPOINT_DIR, JAX_CACHE_DIR]:
    os.makedirs(d, exist_ok=True)

# Configure JAX cache
os.environ["JAX_COMPILATION_CACHE_DIR"] = JAX_CACHE_DIR

print(f"Project: {PROJECT_DIR}")
print(f"Data: {DATA_DIR}")
print(f"Checkpoints: {CHECKPOINT_DIR}")
print(f"JAX Cache: {JAX_CACHE_DIR}")

In [None]:
# Cell 3: Copy v2 Files from Drive
# ============================================================================
import shutil
import os

# Copy v2 files (new architecture)
v2_files = ['model_v2.py', 'train_v2.py']
for f in v2_files:
    src = f'{PROJECT_DIR}/{f}'
    if os.path.exists(src):
        shutil.copy(src, f'./{f}')
        print(f"✓ Copied {f}")
    else:
        print(f"✗ ERROR: {f} not found!")

# Copy other required files
other_files = ['inference.py', 'preprocess_data.py']
for f in other_files:
    src = f'{PROJECT_DIR}/{f}'
    if os.path.exists(src):
        shutil.copy(src, f'./{f}')
        print(f"✓ Copied {f}")

# Copy tokenizer
tokenizer_src = f'{PROJECT_DIR}/qwen_tokenizer'
if os.path.exists(tokenizer_src):
    if os.path.exists('./qwen_tokenizer'):
        shutil.rmtree('./qwen_tokenizer')
    shutil.copytree(tokenizer_src, './qwen_tokenizer')
    print("✓ Copied qwen_tokenizer/")
else:
    print("✗ ERROR: qwen_tokenizer not found!")

print("\nLocal files:")
!ls -la *.py 2>/dev/null | grep -E "model_v2|train_v2"

In [None]:
# Cell 4: Verify Setup
# ============================================================================
import os
import glob

print("=" * 60)
print("SETUP VERIFICATION (v2 Architecture)")
print("=" * 60)

# Check v2 files
print("\n[v2 Files]")
v2_ok = True
for f in ['model_v2.py', 'train_v2.py']:
    if os.path.exists(f):
        with open(f) as fh:  # close the file handle after counting
            lines = sum(1 for _ in fh)
        print(f"  ✓ {f} ({lines} lines)")
    else:
        print(f"  ✗ {f} MISSING")
        v2_ok = False

# Check tokenizer
print("\n[Tokenizer]")
if os.path.exists('./qwen_tokenizer/tokenizer.json'):
    print("  ✓ qwen_tokenizer/ present")
else:
    print("  ✗ Tokenizer missing!")
    v2_ok = False

# Check training data
print("\n[Training Data]")
parquet_files = glob.glob(f'{DATA_DIR}/*.parquet')
if parquet_files:
    total_size = sum(os.path.getsize(f) for f in parquet_files) / (1024**3)
    print(f"  ✓ {len(parquet_files)} shards ({total_size:.1f} GB)")
else:
    print(f"  ✗ No data in {DATA_DIR}")
    v2_ok = False

# Check existing checkpoints
print("\n[v2 Checkpoints]")
ckpts = glob.glob(f'{CHECKPOINT_DIR}/epoch_*')
if ckpts:
    latest = max(ckpts, key=lambda p: int(os.path.basename(p).split('_')[1]))
    print(f"  Found {len(ckpts)} checkpoints")
    print(f"  Latest: {os.path.basename(latest)}")
    print("  → Training will auto-resume")
else:
    print("  No checkpoints (fresh training)")

print("\n" + "=" * 60)
if v2_ok:
    print("✓ READY FOR v2 TRAINING")
    print("\nArchitecture: Diff-Llama-MTP (1B params)")
    print("Features: DiffAttn + MTP + HLP + YaRN")
else:
    print("✗ FIX ERRORS ABOVE")
print("=" * 60)

In [None]:
# Cell 5: Start Training (v2)
# ============================================================================
# Trains 1B Diff-Llama-MTP model
#
# Architecture (v2):
#   - Differential Attention + GQA (~35-40% fewer params/tokens for matched loss)
#   - Multi-Token Prediction (+12% HumanEval, +17% MBPP reported)
#   - Horizon Length Prediction (+24% FIM)
#   - YaRN (4x context extension)
#   - Dense FFN (no MoE - not worth it at 1B)
#
# Config:
#   - 1B parameters
#   - 8192 context length
#   - Batch: micro-batch 2 x 32 accumulation steps = 64 sequences per update
#   - LR: 3e-4 with warmup-cosine
#   - Auto-resume from checkpoint
# ============================================================================

!python train_v2.py
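
The "warmup-cosine" learning-rate setting in the config comment can be pictured with the sketch below. Only the peak value of 3e-4 comes from this notebook; `warmup_steps`, `total_steps`, and `end_lr` are placeholder assumptions, and `train_v2.py` may implement the schedule differently (Optax, for instance, ships `warmup_cosine_decay_schedule`).

```python
import math

def warmup_cosine_lr(step, peak_lr=3e-4, warmup_steps=1000,
                     total_steps=100_000, end_lr=3e-5):
    """Linear warmup to peak_lr, then cosine decay down to end_lr."""
    if step < warmup_steps:
        return peak_lr * step / warmup_steps
    # Fraction of the decay phase completed, clamped to [0, 1].
    progress = min(1.0, (step - warmup_steps) / max(1, total_steps - warmup_steps))
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
    return end_lr + (peak_lr - end_lr) * cosine

for step in (0, 500, 1000, 50_000, 100_000):
    print(step, f"{warmup_cosine_lr(step):.2e}")
```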

In [None]:
# Cell 6: Check Training Progress
# ============================================================================
import os
import glob

print("=" * 60)
print("TRAINING PROGRESS (v2)")
print("=" * 60)

ckpts = sorted(glob.glob(f'{CHECKPOINT_DIR}/epoch_*'),
               key=lambda p: int(os.path.basename(p).split('_')[1]))

if ckpts:
    print(f"\nCheckpoints: {len(ckpts)}")
    print("\nEpoch | Size")
    print("-" * 20)
    for ckpt in ckpts:
        epoch = os.path.basename(ckpt).split('_')[1]
        # Walk the whole tree so files in nested checkpoint subfolders are counted
        size_mb = sum(os.path.getsize(os.path.join(root, name))
                      for root, _, names in os.walk(ckpt)
                      for name in names) / (1024**2)
        print(f"  {epoch:3s}  | {size_mb:.0f} MB")
else:
    print("\nNo checkpoints yet.")

# JAX cache
cache_files = glob.glob(f'{JAX_CACHE_DIR}/*')
if cache_files:
    cache_size = sum(os.path.getsize(f) for f in cache_files if os.path.isfile(f)) / (1024**2)
    print(f"\nJAX cache: {len(cache_files)} files ({cache_size:.0f} MB)")

In [None]:
# Cell 7: Quick Inference Test
# ============================================================================
# Test the trained model
# WARNING: This initializes TPU - may conflict with training
# ============================================================================

RUN_TEST = False  # Set to True to test

if RUN_TEST:
    test_code = '''
import sys
sys.path.insert(0, ".")

import jax.numpy as jnp
from flax import nnx
from model_v2 import create_model, CONFIG_1B

# Create model
print("Loading model...")
model = create_model(CONFIG_1B)

# Simple test
print("Testing forward pass...")
input_ids = jnp.ones((1, 64), dtype=jnp.int32)
out = model(input_ids)
print(f"Output shape: {out['logits'].shape}")
print("✓ Model works!")
'''
    with open('/tmp/test_v2.py', 'w') as f:
        f.write(test_code)
    !python /tmp/test_v2.py
else:
    print("Test disabled. Set RUN_TEST = True to run.")

## Troubleshooting

### TPU Not Found
1. Runtime → Change runtime type → TPU
2. Restart runtime, re-run from Cell 2
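
A quick way to confirm which devices JAX sees after the restart is sketched below. The helper name `accelerator_summary` is made up for this example; it relies only on `jax.devices()` and each device's `platform` attribute. If the platform list shows `cpu` rather than `tpu`, the runtime type is probably wrong.

```python
def accelerator_summary():
    """Return a one-line description of the devices JAX can see."""
    try:
        import jax
    except ImportError:
        return "jax is not installed; run Cell 1 and restart the runtime"
    devices = jax.devices()
    platforms = sorted({d.platform for d in devices})
    return f"{len(devices)} device(s) on {', '.join(platforms)}"

print(accelerator_summary())
```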

### Out of Memory
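- The 1B weights take ~4GB in float32; gradients, optimizer state, and activations come on top
- TPU v6e-1 has 32GB of HBM, which should leave room for all of these
- If you still hit OOM, reduce `micro_batch_size` in `train_v2.py`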
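
As a rough guide to where the memory goes, the arithmetic below estimates the persistent training state. It assumes float32 weights and gradients and an Adam-style optimizer with two float32 moments per parameter; none of this is confirmed by `train_v2.py`, and activations are excluded because they depend on batch size and sequence length.

```python
def training_state_gb(n_params, weight_bytes=4, grad_bytes=4,
                      moment_bytes=4, n_moments=2):
    """Persistent-memory estimate in GB (1 GB = 1e9 bytes), excluding activations."""
    parts = {
        "weights": n_params * weight_bytes,
        "gradients": n_params * grad_bytes,
        # Assumed: two optimizer moments per parameter, as in Adam.
        "optimizer": n_params * moment_bytes * n_moments,
    }
    parts["total"] = sum(parts.values())
    return {name: size / 1e9 for name, size in parts.items()}

estimate = training_state_gb(1_000_000_000)
for name, gb in estimate.items():
    print(f"{name:>10}: {gb:.1f} GB")
```

Under these assumptions the state comes to about 16 GB, so the remaining HBM is what activations have to fit into. That is why lowering `micro_batch_size` is the first knob to try.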

### Session Timeout
- Training auto-resumes from latest checkpoint
- Re-run Cells 2-5
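
The resume step depends on finding the newest `epoch_<N>` entry in the checkpoint folder. The sketch below shows one defensive way to do that lookup. It follows the `epoch_*` naming used in Cells 4 and 6, but the helper itself is hypothetical and is not taken from `train_v2.py`.

```python
import os
import re

_EPOCH_DIR = re.compile(r"^epoch_(\d+)$")

def latest_checkpoint(checkpoint_dir):
    """Return (epoch, path) for the highest-numbered epoch_<N> entry, or None."""
    best = None
    for name in os.listdir(checkpoint_dir):
        match = _EPOCH_DIR.match(name)
        if match is None:
            continue  # skip temp folders or unrelated files
        epoch = int(match.group(1))
        if best is None or epoch > best[0]:
            best = (epoch, os.path.join(checkpoint_dir, name))
    return best
```

Matching the full name with a regular expression avoids a crash on leftovers such as a partially written `epoch_3.tmp` folder, and comparing epochs as integers keeps `epoch_10` ahead of `epoch_2`.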

### Slow First Step
- JAX compiles the training step on first use; the result is cached in `JAX_CACHE_DIR` for later sessions
- The first epoch may be 2-3x slower as a result

## Architecture Notes

### v2 vs v1
| Component | v1 | v2 | Why |
|-----------|----|----|-----|
| FFN | MoE 32×4 | Dense | MoE needs 5T+ tokens |
| Attention | MLA | GQA | MLA overkill for 1B |
| Hybrid | Mamba | None | Untested ROI |
| Code | 3002 lines | 979 lines | -67% |

### Benchmark Justifications
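- **Differential Attention**: matches a standard Transformer with roughly 35-40% fewer parameters or training tokens ([Microsoft](https://arxiv.org/abs/2410.05258))
- **MTP**: +12% HumanEval and +17% MBPP, reported for 13B models ([Meta](https://arxiv.org/pdf/2404.19737))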
- **HLP**: up to +24% FIM accuracy (relative)
- **YaRN**: Free context extension
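
To make the MTP entry concrete, here is a minimal sketch of a multi-token prediction loss: head `k` at position `t` is trained to predict the token `k + 1` steps ahead, and the per-head cross-entropies are averaged. Equal weighting of the heads, the array layout, and the function name are assumptions for this example; the loss in `train_v2.py` may differ.

```python
import numpy as np

def mtp_loss(logits, tokens):
    """Average cross-entropy over several look-ahead heads.

    logits: (n_heads, seq, vocab), where head k at position t scores token t+k+1.
    tokens: (seq,) integer array of token ids.
    """
    n_heads, seq, _ = logits.shape
    losses = []
    for k in range(n_heads):
        offset = k + 1
        valid = seq - offset                 # positions that have a target
        head_logits = logits[k, :valid]
        targets = tokens[offset:]
        # Numerically stable log-softmax over the vocabulary axis.
        shifted = head_logits - head_logits.max(axis=-1, keepdims=True)
        log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
        losses.append(-log_probs[np.arange(valid), targets].mean())
    return float(np.mean(losses))

rng = np.random.default_rng(0)
tokens = rng.integers(0, 7, size=10)
uniform_logits = np.zeros((4, 10, 7))        # 4 heads, as in the config table
print(mtp_loss(uniform_logits, tokens))
```

With all-zero logits every head predicts a uniform distribution, so the loss equals the log of the vocabulary size.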