# TPU LoRA KD Training: WikiText-103 L=1024

**Run 1:** Long-context coherence + anti-repetition training

**Target:** TPU v6e-1 (single chip) on Colab

**Goals:**
- Reduce repetitions
- Improve long-context coherence at 1024 tokens
- Minimal drift / no new hallucinations
- Prepare for ANE evaluation

**Prerequisites:**
- WikiText-103 KD cache (L=1024, K=128) on Google Drive
- V2 QAT checkpoint (q4_r32) on Google Drive
- Colab TPU v6e-1 runtime

**Documentation:** See [docs/TPU.md](../docs/TPU.md) for the TPU debugging and troubleshooting guide.

## 1. Setup TPU Environment

In [None]:
# Check TPU availability (don't initialize - let training script handle it)
import os
import sys

# Check if we're on TPU
tpu_env = os.environ.get('TPU_NAME', os.environ.get('COLAB_TPU_ADDR', None))
print(f"TPU environment: {tpu_env or 'not detected'}")

if tpu_env is None:
    print("\n[WARNING] No TPU detected!")
    print("Go to: Runtime > Change runtime type > TPU v6e-1")
else:
    print("TPU detected - will be initialized by training script")

In [None]:
# Install torch_xla for TPU support
try:
    import torch_xla
    print(f'torch_xla already installed: {torch_xla.__version__}')
except ImportError:
    print('Installing torch_xla...')
    !pip install torch_xla[tpu] -f https://storage.googleapis.com/libtpu-releases/index.html -q
    import torch_xla
    print(f'Installed torch_xla: {torch_xla.__version__}')

## 2. Mount Google Drive

In [None]:
from google.colab import drive
drive.mount('/content/drive')

# Define paths
GDRIVE_BASE = '/content/drive/MyDrive'
GDRIVE_RUNS = f'{GDRIVE_BASE}/qwen3_runs'
GDRIVE_CACHES = f'{GDRIVE_BASE}/qwen3_caches'

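print("Drive mounted. Checking paths...")
# Test for the directory first: after `| head`, a trailing `|| echo` never fires
!test -d {GDRIVE_RUNS} && ls -la {GDRIVE_RUNS} | head -10 || echo "qwen3_runs not found"
!test -d {GDRIVE_CACHES} && ls -la {GDRIVE_CACHES} | head -10 || echo "qwen3_caches not found"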

## 3. Setup W&B (Optional)

In [None]:
# Install and login to Weights & Biases
!pip install -q wandb

# Try to get API key from Colab secrets
# Setup: Colab menu -> Secrets (key icon) -> Add "WANDB_API_KEY"
try:
    from google.colab import userdata
    wandb_key = userdata.get('WANDB_API_KEY')
    if wandb_key:
        import wandb
        wandb.login(key=wandb_key)
        print("W&B: Logged in via Colab secret")
        USE_WANDB = True
    else:
        print("W&B: No API key in secrets (will skip wandb logging)")
        USE_WANDB = False
except Exception as e:
    print(f"W&B: Not configured - {e}")
    USE_WANDB = False

## 4. Clone Repository


In [None]:
# Install dependencies
# Quote the version spec: an unquoted `>` is a shell redirect, not a version constraint
!pip install -q transformers accelerate datasets sentencepiece protobuf "jinja2>=3.1.0"
print('Dependencies installed')

## 5. Configuration

In [None]:
# ============================================================
# CONFIGURATION - WikiText L1024 KD-LoRA Training
# ============================================================

# Model
MODEL_ID = "Qwen/Qwen3-0.6B"

# V2 checkpoint (on Google Drive)
# Update this path to your checkpoint
V2_CHECKPOINT_GDRIVE = f"{GDRIVE_RUNS}/SR-008a-revive-hermes/final_v2_q4_r32_fp32_20260107_015153.pt"

# KD Cache (on Google Drive) - WikiText-103 L=1024
# This was generated by Generate_KD_Cache_WikiText103_L1024.ipynb
KD_CACHE_GDRIVE = f"{GDRIVE_CACHES}/wikitext103_32B_L1024_K128_R64_N8000"

# LoRA config - MLP-only first for stability
RECOVERY_R = 8
MLP_ONLY = True  # Start with MLP-only, widen later if needed

# Training config for TPU v6e-1
SEQ_LEN = 1024
BATCH_SIZE = 4
ACCUMULATION_STEPS = 8  # Effective batch: 4 * 1024 * 8 = 32K tokens/update
MAX_STEPS = 1200
LOG_INTERVAL = 10

# KD settings - gentle, no-think
KD_TEMPERATURE = 1.05  # Safe with K=128 logits
KD_ALPHA = 0.7
HARD_TOP1 = 0.02       # Tiny "snap" early
HARD_TOP1_END = 0.00   # Ends purely soft KD

# Regularization
DROPOUT = 0.05
ANCHOR_KL_WEIGHT = 0.01  # Prevents drift into weird behaviors
ANCHOR_SAMPLES = 32

# Learning rate
LR = 8e-5
WARMUP_STEPS = 100

# Saving
SAVE_STEPS = 200
KEEP_CHECKPOINTS = 5

# Output
OUTPUT_DIR = "runs/lora_wikitext1024_r8_mlp_only"
WANDB_RUN = "lora_wikitext_L1024_r8_mlp"
WANDB_PROJECT = "qwen3-recovery-lora"

print("="*60)
print("Configuration: WikiText L1024 KD-LoRA Training")
print("="*60)
print(f"Model:          {MODEL_ID}")
print(f"LoRA rank:      {RECOVERY_R} ({'MLP-only' if MLP_ONLY else 'MLP + attention'})")
print(f"Seq length:     {SEQ_LEN}")
print(f"Batch size:     {BATCH_SIZE} x {ACCUMULATION_STEPS} accum = {BATCH_SIZE * ACCUMULATION_STEPS} effective")
print(f"Tokens/update:  {BATCH_SIZE * SEQ_LEN * ACCUMULATION_STEPS:,}")
print(f"Max steps:      {MAX_STEPS}")
print(f"KD temp:        {KD_TEMPERATURE}, alpha: {KD_ALPHA}")
print(f"Hard top1:      {HARD_TOP1} -> {HARD_TOP1_END}")
print(f"Anchor KL:      {ANCHOR_KL_WEIGHT}")
print(f"LR:             {LR}, warmup: {WARMUP_STEPS}")
print("="*60)

# Build training command
os.chdir(REPO_DIR)

# Prefer local copies (faster I/O) if the sync cell has already defined them;
# otherwise fall back to the Google Drive paths
ckpt_path = globals().get('LOCAL_CHECKPOINT', V2_CHECKPOINT_GDRIVE)
cache_path = globals().get('LOCAL_CACHE', KD_CACHE_GDRIVE)

# Build command as list then join (avoids escaping issues)
cmd_parts = [
    "python scripts/train_recovery_lora.py",
    "--tpu",
    f"--model {MODEL_ID}",
    f'--v2-checkpoint "{ckpt_path}"',
    f'--kd-cache-dir "{cache_path}"',
    "--lora-mode kd",
    f"--recovery-r {RECOVERY_R}",
]

if MLP_ONLY:
    cmd_parts.append("--mlp-only")

cmd_parts.extend([
    f"--seq-len {SEQ_LEN}",
    f"--batch-size {BATCH_SIZE}",
    f"--accumulation-steps {ACCUMULATION_STEPS}",
    f"--max-steps {MAX_STEPS}",
    f"--log-interval {LOG_INTERVAL}",
    f"--kd-temperature {KD_TEMPERATURE}",
    f"--kd-alpha {KD_ALPHA}",
    f"--hard-top1 {HARD_TOP1}",
    f"--hard-top1-end {HARD_TOP1_END}",
    f"--dropout {DROPOUT}",
    f"--lr {LR}",
    f"--warmup-steps {WARMUP_STEPS}",
    f"--anchor-kl-weight {ANCHOR_KL_WEIGHT}",
    f"--anchor-samples {ANCHOR_SAMPLES}",
    f"--save-steps {SAVE_STEPS}",
    f"--keep-checkpoints {KEEP_CHECKPOINTS}",
    f'--output "{OUTPUT_DIR}"',
])

if USE_WANDB:
    cmd_parts.extend([
        "--wandb",
        f"--wandb-project {WANDB_PROJECT}",
        f"--wandb-run {WANDB_RUN}",
    ])

cmd = " \\\n  ".join(cmd_parts)

print("Training command:")
print(cmd)
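The KD settings above mix a soft loss against the cached top-K teacher logits with a small, decaying hard top-1 "snap". How `train_recovery_lora.py` actually combines these terms is not shown in this notebook, so the following is only a single-position sketch under stated assumptions: the cache holds the teacher's top-128 logits and token ids per position, the student is renormalized over those same ids, `KD_ALPHA` weights the soft term against ground-truth cross-entropy, and the hard-top1 weight decays linearly from `HARD_TOP1` to `HARD_TOP1_END` over `MAX_STEPS`. All function names are hypothetical.

```python
import math

def log_softmax(xs):
    """Numerically stable log-softmax over a list of floats."""
    m = max(xs)
    lse = m + math.log(sum(math.exp(x - m) for x in xs))
    return [x - lse for x in xs]

def hard_top1_weight(step, max_steps, start=0.02, end=0.0):
    """Assumed linear decay of the hard top-1 weight from start to end."""
    frac = min(step / max_steps, 1.0)
    return start + (end - start) * frac

def topk_kd_loss(student_logits, teacher_topk_logits, teacher_topk_ids, target_id,
                 temperature=1.05, alpha=0.7, hard_weight=0.0):
    """Hypothetical KD loss for ONE token position.

    student_logits: full-vocabulary student logits
    teacher_topk_logits / teacher_topk_ids: the K cached teacher entries
    target_id: ground-truth next token
    """
    # Soft term: teacher top-K distribution at temperature T vs. the student
    # renormalized over the same K ids (assumption: renormalize, not full vocab)
    t_logp = log_softmax([t / temperature for t in teacher_topk_logits])
    s_logp_k = log_softmax([student_logits[i] / temperature for i in teacher_topk_ids])
    kl = sum(math.exp(tl) * (tl - sl) for tl, sl in zip(t_logp, s_logp_k))
    kd = temperature ** 2 * kl  # T^2 keeps gradient scale comparable across temperatures

    # Hard terms use the full-vocabulary student distribution at T = 1
    s_logp = log_softmax(student_logits)
    ce_ground_truth = -s_logp[target_id]
    best = max(range(len(teacher_topk_logits)), key=lambda j: teacher_topk_logits[j])
    ce_teacher_top1 = -s_logp[teacher_topk_ids[best]]

    # Assumed mixing: alpha * soft KD + (1 - alpha) * ground-truth CE + hard top-1 snap
    return alpha * kd + (1 - alpha) * ce_ground_truth + hard_weight * ce_teacher_top1

# Example: one position, vocabulary of 4, teacher top-2 cached
w = hard_top1_weight(step=600, max_steps=1200)  # halfway through the run: 0.01
loss = topk_kd_loss([2.0, 1.0, 0.5, -1.0], [2.5, 0.5], [0, 1], target_id=0, hard_weight=w)
print(f"hard_top1 weight={w:.3f}, loss={loss:.4f}")
```

A real implementation would vectorize this over batch and sequence; the point here is only which distributions each term compares.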

In [None]:
%%time
# Run training with XLA debug to show compilation events
# 
# Expected timing on TPU v6e-1:
# - Step 1: 5-15 min (XLA compilation)
# - Step 2-3: 1-3 min (possible recompilation)
# - Step 4+: ~1-3 sec/step (stable)
#
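# Debug output will show:
# - [Step N] Starting/Forward/Backward/Optimizer timing
# - If a step hangs for >15 min, a dynamic shape (forcing repeated recompilation) is the likely cause
#
# XLA env vars:
# - PT_XLA_DEBUG=1: PyTorch/XLA debug info (set on the command below)
# - XLA_HLO_DEBUG=1: HLO compilation debug (not set; prefix it to the command to enable)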

import os
os.environ['PJRT_DEVICE'] = 'TPU'

!PT_XLA_DEBUG=1 {cmd}
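The timing notes in the cell above imply a rough wall-clock budget for the full run. A small sketch of that arithmetic, assuming one "step" in those notes is one `--max-steps` unit and ignoring the optional step-2/3 recompilation:

```python
def estimate_wall_clock_min(steps, sec_per_step, compile_min):
    """Rough wall clock: one-off XLA compilation plus steady-state steps."""
    return compile_min + steps * sec_per_step / 60.0

# Bounds taken from the timing comments (1-3 s/step, 5-15 min first compile)
low = estimate_wall_clock_min(1200, sec_per_step=1.0, compile_min=5.0)
high = estimate_wall_clock_min(1200, sec_per_step=3.0, compile_min=15.0)
print(f"Expected wall clock for 1200 steps: {low:.0f}-{high:.0f} min")
```

This gives 25-75 min, which brackets the ~30-60 min quoted for 1200 steps in the Run Training section; a run far outside it is worth checking for recompilation.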

In [None]:
# --- Sync KD Cache ---
import os
from pathlib import Path

print("[2/2] Syncing KD cache...")
if not os.path.exists(KD_CACHE_GDRIVE):
    print(f"  ERROR: Cache not found: {KD_CACHE_GDRIVE}")
    print("\n  Available caches:")
    !ls -la {GDRIVE_CACHES} 2>/dev/null || echo "No caches directory"
else:
    cache_name = Path(KD_CACHE_GDRIVE).name
    LOCAL_CACHE = f"{LOCAL_CACHE_DIR}/{cache_name}"
    
    if not os.path.exists(LOCAL_CACHE):
        # Count files (recursively, skipping directories) and total size in one pass over Drive
        cache_files = [f for f in Path(KD_CACHE_GDRIVE).rglob("*") if f.is_file()]
        num_files = len(cache_files)
        total_size = sum(f.stat().st_size for f in cache_files) / 1e9
        print(f"  Copying {cache_name} ({num_files} files, {total_size:.2f} GB)...")
        print("  This may take a few minutes...")
        !rsync -ah --info=progress2 {KD_CACHE_GDRIVE}/ {LOCAL_CACHE}/
        print(f"  Done: {LOCAL_CACHE}")
    else:
        print(f"  Already exists: {LOCAL_CACHE}")
    
    # Verify cache
    print(f"\n  Cache contents:")
    !ls -la {LOCAL_CACHE} | head -5
    if os.path.exists(f"{LOCAL_CACHE}/meta.json"):
        print(f"\n  Meta:")
        !cat {LOCAL_CACHE}/meta.json
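The cell above skips the copy whenever the local directory already exists, so an interrupted copy stays incomplete. One option is a pure-Python sync that copies only files that are missing or whose size differs at the destination. This is a sketch; `sync_dir` is a hypothetical helper, and comparing sizes is a cheap heuristic that will not catch same-size content changes.

```python
import shutil
from pathlib import Path

def sync_dir(src, dst):
    """Copy files from src to dst when missing or size differs.

    Returns the relative paths that were copied (empty list = already in sync).
    """
    src, dst = Path(src), Path(dst)
    copied = []
    for f in sorted(src.rglob("*")):
        if not f.is_file():
            continue
        rel = f.relative_to(src)
        target = dst / rel
        # Skip files that already exist locally with the same size
        if target.exists() and target.stat().st_size == f.stat().st_size:
            continue
        target.parent.mkdir(parents=True, exist_ok=True)
        shutil.copy2(f, target)  # copy2 also preserves modification times
        copied.append(str(rel))
    return copied

# Usage in this notebook (paths from the cells above):
# sync_dir(KD_CACHE_GDRIVE, LOCAL_CACHE)
```

Running it a second time should copy nothing, which doubles as a quick completeness check on the local cache.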

## 7. Run Training

In [None]:
# Build training command
os.chdir(REPO_DIR)

# Prefer local copies (faster I/O) if the sync cell has already defined them;
# otherwise fall back to the Google Drive paths
ckpt_path = globals().get('LOCAL_CHECKPOINT', V2_CHECKPOINT_GDRIVE)
cache_path = globals().get('LOCAL_CACHE', KD_CACHE_GDRIVE)

cmd = f"""
PJRT_DEVICE=TPU python scripts/train_recovery_lora.py \\
  --model {MODEL_ID} \\
  --v2-checkpoint "{ckpt_path}" \\
  --kd-cache-dir "{cache_path}" \\
  --lora-mode kd \\
  --recovery-r {RECOVERY_R} \\
  {'--mlp-only' if MLP_ONLY else ''} \\
  --seq-len {SEQ_LEN} \\
  --batch-size {BATCH_SIZE} \\
  --accumulation-steps {ACCUMULATION_STEPS} \\
  --max-steps {MAX_STEPS} \\
  --log-interval {LOG_INTERVAL} \\
  --kd-temperature {KD_TEMPERATURE} --kd-alpha {KD_ALPHA} \\
  --hard-top1 {HARD_TOP1} --hard-top1-end {HARD_TOP1_END} \\
  --dropout {DROPOUT} \\
  --lr {LR} --warmup-steps {WARMUP_STEPS} \\
  --anchor-kl-weight {ANCHOR_KL_WEIGHT} --anchor-samples {ANCHOR_SAMPLES} \\
  --save-steps {SAVE_STEPS} --keep-checkpoints {KEEP_CHECKPOINTS} \\
  --output "{OUTPUT_DIR}"
"""

# Add wandb flags if configured
if USE_WANDB:
    # No space between the backslash and the newline, or the shell ends the command early
    cmd = cmd.rstrip() + f" \\\n  --wandb --wandb-project {WANDB_PROJECT} --wandb-run {WANDB_RUN}"

print("Training command:")
print(cmd)
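Building the command as one string with backslash continuations is fragile: a stray space after a trailing `\` silently ends the shell command, and paths with spaces need manual quoting. A sketch of an alternative that keeps the arguments as a list and streams the script's output into the notebook cell. `build_train_argv` and `run_streaming` are hypothetical helpers; the flag names are the ones used in the command above.

```python
import os
import shlex
import subprocess

def build_train_argv(flags, switches=(), script="scripts/train_recovery_lora.py"):
    """Build an argv list: flags maps option name -> value, switches are bare flags."""
    argv = ["python", script]
    for name, value in flags.items():
        argv += [f"--{name}", str(value)]
    argv += [f"--{name}" for name in switches]
    return argv

def run_streaming(argv, extra_env=None):
    """Run argv and echo its output line by line, so it appears in the cell."""
    env = dict(os.environ, **(extra_env or {}))
    proc = subprocess.Popen(argv, env=env, stdout=subprocess.PIPE,
                            stderr=subprocess.STDOUT, text=True)
    for line in proc.stdout:
        print(line, end="")
    return proc.wait()  # exit code of the training script

argv = build_train_argv(
    {"model": "Qwen/Qwen3-0.6B", "recovery-r": 8, "output": "runs/my run"},
    switches=["mlp-only", "wandb"],
)
print(shlex.join(argv))  # shell-safe rendering; quotes the path containing a space
# run_streaming(argv, {"PJRT_DEVICE": "TPU"})
```

Piping stdout through Python matters in notebooks: a plain `subprocess.run` without capture writes to the kernel's stdout, which is usually not shown in the cell.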

In [None]:
%%time
# Run training
# Expected: ~30-60 min for 1200 steps on TPU v6e-1
!{cmd}

## 8. Test Inference

In [None]:
# Find latest checkpoint
from pathlib import Path

output_path = Path(OUTPUT_DIR)
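# Sort by modification time: name sorting is lexicographic, so e.g. step 1000 can sort before step 200
checkpoints = sorted(output_path.glob("recovery_*.pt"), key=lambda p: p.stat().st_mtime)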

if checkpoints:
    LATEST_CKPT = str(checkpoints[-1])
    print(f"Found {len(checkpoints)} checkpoints")
    print(f"Latest: {LATEST_CKPT}")
else:
    print("No checkpoints found!")
    LATEST_CKPT = None
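Modification time breaks down once checkpoints are copied around (for example back from Drive). If the file names embed the step number, picking the numerically largest step is more robust. A sketch that assumes a hypothetical naming scheme such as `recovery_step_1000.pt`, where the step is the last integer in the name; check the actual names the script writes before relying on it.

```python
import re
from pathlib import Path

def step_of(path):
    """Last integer in the file name, or -1 if there is none (assumed naming scheme)."""
    nums = re.findall(r"\d+", Path(path).stem)
    return int(nums[-1]) if nums else -1

def latest_checkpoint(output_dir, pattern="recovery_*.pt"):
    """Return the checkpoint path with the highest step number, or None."""
    ckpts = list(Path(output_dir).glob(pattern))
    return str(max(ckpts, key=step_of)) if ckpts else None

# Usage: LATEST_CKPT = latest_checkpoint(OUTPUT_DIR)
```

A final checkpoint without a number (e.g. `recovery_final.pt`) maps to -1 here and is never chosen; adjust `step_of` if the script uses such a name.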

In [None]:
# Test prompts for repetition + coherence evaluation
TEST_PROMPTS = [
    # Long-context coherence test
    "The ancient library contained thousands of scrolls. Among them was a rare manuscript describing a forgotten city. The city was said to have towers made of crystal. What material were the towers made of?",
    
    # Repetition trigger test
    "List the first 10 prime numbers:",
    
    # General knowledge
    "What is the capital of France?",
    
    # Reasoning
    "If a train travels at 60 mph for 2 hours, how far does it go?",
]
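To track the "reduce repetitions" goal with a number rather than by eye, one simple proxy is the fraction of word n-grams in a completion that repeat an earlier n-gram. A sketch using whitespace tokenization (an approximation; counting over the model's tokenizer would be more faithful); `repeated_ngram_fraction` is a hypothetical helper, not part of the repo's scripts.

```python
def repeated_ngram_fraction(text, n=3):
    """Fraction of word n-grams that duplicate an earlier one (0.0 = no repeats)."""
    words = text.split()
    ngrams = [tuple(words[i:i + n]) for i in range(len(words) - n + 1)]
    if not ngrams:
        return 0.0
    return 1.0 - len(set(ngrams)) / len(ngrams)

# Compare completions captured from test_inference.py for the same prompt
for label, text in [
    ("looping", "the city had towers the city had towers the city had towers"),
    ("varied", "the city had towers made of crystal that caught the light"),
]:
    print(f"{label:8s} repeated 3-gram fraction: {repeated_ngram_fraction(text):.2f}")
```

Scoring base and LoRA outputs on the same prompts gives a direct before/after comparison for the repetition-trigger prompt.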

In [None]:
# Test with LoRA checkpoint
if LATEST_CKPT:
    print("=" * 60)
    print("TESTING WITH RECOVERY LoRA")
    print("=" * 60)
    
    for i, prompt in enumerate(TEST_PROMPTS, 1):
        shown = prompt if len(prompt) <= 80 else prompt[:80] + "..."
        print(f"\n[{i}/{len(TEST_PROMPTS)}] Prompt: {shown}")
        print("-" * 40)
        !python scripts/test_inference.py "{LATEST_CKPT}" --prompt "{prompt}" --max-tokens 256 2>/dev/null | tail -10

In [None]:
# Compare: Base (no LoRA) vs With LoRA
base_ckpt = ckpt_path  # The original V2 checkpoint

print("=" * 60)
print("COMPARISON: BASE vs LoRA")
print("=" * 60)

comparison_prompt = TEST_PROMPTS[0]  # Long-context test

print(f"\nPrompt: {comparison_prompt}")

print("\n--- BASE (no LoRA) ---")
!python scripts/test_inference.py "{base_ckpt}" --prompt "{comparison_prompt}" --max-tokens 256 2>/dev/null | tail -8

if LATEST_CKPT:
    print("\n--- WITH LoRA ---")
    !python scripts/test_inference.py "{LATEST_CKPT}" --prompt "{comparison_prompt}" --max-tokens 256 2>/dev/null | tail -8

## 9. Save to Google Drive

In [None]:
# Save results to Google Drive
import shutil
from datetime import datetime

# Output directory on Drive
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
GDRIVE_OUTPUT = f"{GDRIVE_RUNS}/SR-011-lora-wikitext-L1024-{timestamp}"

os.makedirs(GDRIVE_OUTPUT, exist_ok=True)

# Copy checkpoints
output_path = Path(OUTPUT_DIR)
if output_path.exists():
    print(f"Saving to: {GDRIVE_OUTPUT}")
    
    for f in output_path.glob("*.pt"):
        dest = Path(GDRIVE_OUTPUT) / f.name
        print(f"  Copying {f.name}...")
        shutil.copy(f, dest)
    
    # Copy logs
    for f in output_path.glob("*.csv"):
        shutil.copy(f, Path(GDRIVE_OUTPUT) / f.name)
        print(f"  Copied {f.name}")
    
    for f in output_path.glob("*.json"):
        shutil.copy(f, Path(GDRIVE_OUTPUT) / f.name)
        print(f"  Copied {f.name}")
    
    print(f"\nSaved to: {GDRIVE_OUTPUT}")
else:
    print(f"Output directory not found: {OUTPUT_DIR}")

## 10. Plot Training Loss

In [None]:
import pandas as pd
import matplotlib.pyplot as plt

log_file = Path(OUTPUT_DIR) / "training_log.csv"
if log_file.exists():
    df = pd.read_csv(log_file)
    
    fig, axes = plt.subplots(1, 2, figsize=(14, 4))
    
    # Loss plot
    axes[0].plot(df['step'], df['train_loss'], label='Train Loss', alpha=0.8)
    if 'eval_loss' in df.columns:
        axes[0].plot(df['step'], df['eval_loss'], label='Eval Loss', linestyle='--')
    axes[0].set_xlabel('Step')
    axes[0].set_ylabel('Loss')
    axes[0].set_title('Training Loss')
    axes[0].legend()
    axes[0].grid(True, alpha=0.3)
    
    # LR plot
    if 'lr' in df.columns:
        axes[1].plot(df['step'], df['lr'], color='orange')
        axes[1].set_xlabel('Step')
        axes[1].set_ylabel('Learning Rate')
        axes[1].set_title('Learning Rate Schedule')
        axes[1].grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.show()
    
    print(f"\nFinal loss: {df['train_loss'].iloc[-1]:.4f}")
    print(f"Best loss:  {df['train_loss'].min():.4f}")
    print(f"Last step:  {df['step'].iloc[-1]} ({len(df)} log rows)")
else:
    print("No training log found")
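Per-step KD loss on long sequences is noisy, which makes "final" and "best" loss hard to read. An exponential moving average makes trends visible. A minimal sketch; `ema` is a hypothetical helper, and the smoothing factor is an arbitrary starting point.

```python
def ema(values, alpha=0.1):
    """Exponential moving average; larger alpha follows the raw curve more closely."""
    smoothed, state = [], None
    for v in values:
        state = v if state is None else alpha * v + (1 - alpha) * state
        smoothed.append(state)
    return smoothed

# Overlay on the loss axis in the cell above (before plt.tight_layout()):
# axes[0].plot(df['step'], ema(df['train_loss'].tolist()), label='Train Loss (EMA)')
```

With pandas already loaded, `df['train_loss'].ewm(alpha=0.1, adjust=False).mean()` computes the same recursion.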

---

## Run 2 (Optional): Widen to Attention LoRA

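Only if Run 1 improves repetition and coherence but plateaus, suggesting MLP-only LoRA lacks capacity.

- Remove `--mlp-only` (LoRA on attention + MLP)
- Lower LR to 5e-5 (warmup 50 steps)
- Shorter run (300-800 steps; the command below uses 600)
- Smaller initial hard-top1 weight (0.01) and more frequent saves (every 100 steps)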
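The exact LR schedule used by `train_recovery_lora.py` is not shown in this notebook. The sketch below assumes linear warmup followed by cosine decay to zero, a common choice, to illustrate how Run 2's settings (5e-5 peak, 50 warmup steps, 600 steps) compare with Run 1's (8e-5, 100, 1200). `lr_at` is a hypothetical helper; compare its shape against the `lr` column plotted in section 10.

```python
import math

def lr_at(step, peak_lr, warmup_steps, max_steps, min_lr=0.0):
    """Assumed schedule: linear warmup to peak_lr, then cosine decay to min_lr."""
    if step < warmup_steps:
        return peak_lr * (step + 1) / warmup_steps
    progress = min((step - warmup_steps) / max(1, max_steps - warmup_steps), 1.0)
    return min_lr + 0.5 * (peak_lr - min_lr) * (1.0 + math.cos(math.pi * progress))

for name, peak, warmup, total in [("Run 1", 8e-5, 100, 1200), ("Run 2", 5e-5, 50, 600)]:
    probe_steps = [0, warmup - 1, total // 2, total]
    print(name, [f"{lr_at(s, peak, warmup, total):.2e}" for s in probe_steps])
```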

In [None]:
# Run 2: Attention + MLP LoRA (uncomment to run)

# RUN2_CMD = f"""
# PJRT_DEVICE=TPU python scripts/train_recovery_lora.py \\
#   --model {MODEL_ID} \\
#   --v2-checkpoint "{ckpt_path}" \\
#   --kd-cache-dir "{cache_path}" \\
#   --lora-mode kd \\
#   --recovery-r 8 \\
#   --seq-len 1024 \\
#   --batch-size 4 \\
#   --accumulation-steps 8 \\
#   --max-steps 600 \\
#   --log-interval 10 \\
#   --kd-temperature 1.05 --kd-alpha 0.7 \\
#   --hard-top1 0.01 --hard-top1-end 0.00 \\
#   --dropout 0.05 \\
#   --lr 5e-5 --warmup-steps 50 \\
#   --anchor-kl-weight 0.01 --anchor-samples 32 \\
#   --save-steps 100 --keep-checkpoints 5 \\
#   --output "runs/lora_wikitext1024_r8_full" \\
#   --wandb --wandb-project {WANDB_PROJECT} --wandb-run "lora_wikitext_L1024_r8_full"
# """
# print("Run 2 command (attention + MLP):")
# print(RUN2_CMD)
# # !{RUN2_CMD}

---

## Notes

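**TPU v6e-1 Memory:**
- Single chip: 32 GB HBM
- Activation memory scales with the micro-batch (4 x 1024 = 4K tokens); accumulating 8 micro-batches into 32K tokens/update does not raise it, so this config should fit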
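Two quantities dominate a back-of-envelope check of that estimate: the full-vocabulary logits for one micro-batch and the number of trainable LoRA parameters. A sketch of both, assuming the Qwen3-0.6B dimensions from its published config (hidden 1024, MLP intermediate 3072, 28 layers, vocabulary 151,936; verify against `config.json`). The script may compute logits in bf16 or in chunks, which would shrink the first figure.

```python
# Assumed Qwen3-0.6B dimensions (verify against the model's config.json)
HIDDEN, INTERMEDIATE, LAYERS, VOCAB = 1024, 3072, 28, 151_936

def logits_bytes(batch, seq_len, vocab, bytes_per_elem=4):
    """Memory for one micro-batch of full-vocabulary logits (fp32 by default)."""
    return batch * seq_len * vocab * bytes_per_elem

def lora_mlp_params(rank, hidden, intermediate, layers):
    """LoRA adds rank * (in + out) params per linear; the MLP has gate, up, down."""
    per_linear = rank * (hidden + intermediate)
    return 3 * per_linear * layers

logit_gb = logits_bytes(4, 1024, VOCAB) / 1e9
print(f"fp32 logits per micro-batch: {logit_gb:.2f} GB (backward needs roughly as much again)")
print(f"LoRA r=8 MLP-only params:    {lora_mlp_params(8, HIDDEN, INTERMEDIATE, LAYERS):,}")
```

The logits alone come to about 2.5 GB per micro-batch in fp32, while the rank-8 MLP adapters add under 3M parameters, so the logits rather than the adapters are the first thing to look at if memory gets tight.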

**What to evaluate after training:**
1. Long prompt with detail early, question late
2. Known repetition triggers
3. ANE reproduction prompts (if available)

**If repetition improves on TPU but ANE still loops:**
- Try Run 2 (attention + MLP LoRA)
- Check whether the ANE's numerics (e.g., lower-precision arithmetic) differ from the TPU's enough to change generation