# TPU Training Test - XLA Warmup Verification

**Rev: 1.5 (2025-01-05 16:45) - Fix TPU lock (don't init in notebook)**

Quick test notebook to verify TPU training works without hanging.

**Requirements:**
- Colab with TPU runtime (Runtime > Change runtime type > TPU)
- ~4 min total (cache: ~1 min, training: ~3 min)

**What this tests:**
- XLA warmup precompilation (forward + backward + optimizer)
- Training loop stability (20 steps)
- Compilation count (should be 1-3, not 20+)

**Expected output:**
```
Replacing with V2 layers... (parallel SVD init with 23 workers)
...done (10-15s)

[TPU] Warmup: compiling XLA graph... forward... backward... optimizer... done (90s)
Training: ....................
[20/20] loss=X.XXXX ...
[TPU] XLA compilations: 1
```
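
The compilation-count check can be automated by scanning the captured output. The following is a minimal sketch, assuming the training log is available as a string and the final line has the `[TPU] XLA compilations: N` form shown above. The helper names `xla_compilation_count` and `compilation_count_ok` are hypothetical and not part of the repository.

```python
import re

def xla_compilation_count(log_text: str):
    """Return N from the last '[TPU] XLA compilations: N' line, or None if absent."""
    matches = re.findall(r"\[TPU\] XLA compilations:\s*(\d+)", log_text)
    return int(matches[-1]) if matches else None

def compilation_count_ok(log_text: str, max_allowed: int = 3) -> bool:
    """True when the reported count is within the 1-3 range this notebook expects."""
    count = xla_compilation_count(log_text)
    return count is not None and 1 <= count <= max_allowed

# Sample text copied from the expected output above.
sample = """[TPU] Warmup: compiling XLA graph... forward... backward... optimizer... done (90s)
Training: ....................
[TPU] XLA compilations: 1
"""
print(xla_compilation_count(sample), compilation_count_ok(sample))
```

A count near the number of training steps (20+) would suggest the graph is being recompiled on every step rather than reused.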

## Step 1: Setup TPU

In [None]:
# Check torch_xla is installed (but DON'T initialize TPU - that locks the device)
import os
import sys

try:
    import torch_xla
    print(f'torch_xla installed: {torch_xla.__version__}')
    # Just verify TPU exists without initializing
    tpu_env = os.environ.get('TPU_NAME', os.environ.get('COLAB_TPU_ADDR', 'not set'))
    print(f'TPU environment: {tpu_env}')
    print('TPU will be initialized by training script')
except ImportError:
    print('Installing torch_xla...')
    !pip install 'torch_xla[tpu]' -f https://storage.googleapis.com/libtpu-releases/index.html -q
    import torch_xla
    print(f'Installed torch_xla: {torch_xla.__version__}')
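
An even more conservative check is possible: the package can be detected without importing it at all. This is a sketch using only the standard library. The helper name `installed_version` is hypothetical, and the sketch assumes the distribution is registered under the same name as the import (`torch_xla`).

```python
import importlib.metadata
import importlib.util

def installed_version(package: str):
    """Return the installed version string, or None if the package is absent.

    Neither call imports the package itself, so no accelerator runtime is touched.
    """
    if importlib.util.find_spec(package) is None:
        return None
    try:
        return importlib.metadata.version(package)
    except importlib.metadata.PackageNotFoundError:
        # Importable, but no distribution metadata is registered under this name.
        return "unknown"

print("torch_xla:", installed_version("torch_xla") or "not installed")
```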

## Step 2: Clone Repository

In [None]:
os.chdir('/content')

if os.path.exists('qwen3_apple_style_2bit_qat_lora'):
    print('Repo exists, pulling latest...')
    !cd qwen3_apple_style_2bit_qat_lora && git fetch && git reset --hard origin/main
else:
    print('Cloning repo...')
    !git clone https://github.com/anemll/qwen3_apple_style_2bit_qat_lora.git

os.chdir('/content/qwen3_apple_style_2bit_qat_lora')
print(f'Working directory: {os.getcwd()}')

# Show latest commit
!git log --oneline -3

## Step 3: Install Dependencies

In [None]:
# Quote the version specifier: an unquoted ">" is treated by the shell as output redirection
!pip install transformers datasets accelerate "jinja2>=3.1.0" -q
print('Dependencies installed')

## Step 4: Generate Minimal Test Cache

Small cache for quick testing (1K sequences, 64 tokens each)

In [None]:
%%time

CACHE_DIR = 'caches/test_L64_K64_N1K'

if os.path.exists(f'{CACHE_DIR}/meta.json'):
    print(f'Cache exists: {CACHE_DIR}')
    !cat {CACHE_DIR}/meta.json
else:
    print('Generating test cache (1K sequences from Alpaca - fast download)...')
    # Use Alpaca (52K examples, ~25MB) instead of OpenHermes (1M examples, 2GB)
    !python scripts/precompute_teacher_topk.py \
        --output_dir {CACHE_DIR} \
        --teacher_model_name_or_path Qwen/Qwen3-0.6B \
        --dataset_name tatsu-lab/alpaca \
        --dataset_format alpaca \
        --max_length 64 \
        --topk 64 \
        --num_sequences 1000 \
        --batch_size 32 \
        --shard_size 500 \
        --dtype bf16
    print('\nCache generated!')
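
To inspect the cache from Python instead of `cat`, something like the following could be used. It is a sketch only: `summarize_cache` is a hypothetical helper, and it makes no assumption about which keys `meta.json` contains or how the shard files are named. The demo runs on a temporary directory with a made-up `example_key` so it works anywhere.

```python
import json
import tempfile
from pathlib import Path

def summarize_cache(cache_dir):
    """Report whether meta.json exists and how much data sits in the directory."""
    root = Path(cache_dir)
    files = [p for p in root.rglob("*") if p.is_file()] if root.is_dir() else []
    meta_path = root / "meta.json"
    meta = json.loads(meta_path.read_text()) if meta_path.is_file() else None
    return {
        "has_meta": meta is not None,
        "meta": meta,
        "num_files": len(files),
        "total_bytes": sum(p.stat().st_size for p in files),
    }

# Demo on a throwaway directory; the meta.json content here is invented.
with tempfile.TemporaryDirectory() as tmp:
    (Path(tmp) / "meta.json").write_text(json.dumps({"example_key": 1}))
    demo = summarize_cache(tmp)
    print(demo["has_meta"], demo["num_files"])
```

In the notebook, calling `summarize_cache(CACHE_DIR)` after this step would show whether generation produced a `meta.json` and any other files.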

## Step 5: Run Training Test (20 steps)

This tests:
1. XLA warmup phase (forward + backward + optimizer compilation)
2. First few training steps
3. No hang at optimizer step

**Expected output:**
- Warmup phase: ~60-120s (one-time compilation)
- Steps 1-20: fast (< 1s per step after warmup)

In [None]:
%%time

# Note: layer init uses parallel SVD (worker count is auto-detected from CPU cores).
# --fast-init is left out here so layers get proper SVD initialization and a better initial loss.
!python scripts/train_v2_simple.py \
    --from-scratch \
    --cache-dir caches/test_L64_K64_N1K \
    --output-dir runs/tpu_test \
    --config q4_r32 \
    --max-steps 20 \
    --batch-size 4 \
    --accumulation-steps 2 \
    --lr 3e-5 \
    --hard-top1 0.2 \
    --temperature 2.0 \
    --warmup-steps 5 \
    --tpu
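
As a sanity check on the flags above, the arithmetic below estimates how much of the 1K-sequence cache the run touches. It is a sketch under one labeled assumption: that `--max-steps` counts optimizer updates (after gradient accumulation) rather than micro-batches. The script's actual step semantics are not confirmed here.

```python
# Values copied from the training command in this step and the cache in Step 4.
batch_size = 4
accumulation_steps = 2
max_steps = 20
num_sequences = 1000

# ASSUMPTION: one "step" is one optimizer update after gradient accumulation.
effective_batch = batch_size * accumulation_steps
sequences_seen = effective_batch * max_steps
cache_fraction = sequences_seen / num_sequences

print(f"effective batch: {effective_batch}")
print(f"sequences seen:  {sequences_seen} of {num_sequences} ({cache_fraction:.0%})")
```

Under that assumption the run sees 160 sequences, about 16% of the cache. If steps are counted per micro-batch instead, the figure halves to 80.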

## Success Criteria

The test passes if:
1. Warmup completes with "optimizer... done"
2. Training runs 20 steps without hanging
3. Loss values are printed every 5 steps

If it hangs after "[TPU] Warmup: compiling XLA graph...", the fix didn't work.
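
The criteria above can be checked against a finished log with a few string tests. The sketch below assumes the log was captured as text and that progress lines follow the `[20/20] loss=X.XXXX` form from the expected output. `check_run` is a hypothetical helper, and the loss value in the sample is made up. A log can only show that a run finished, so a hang still has to be spotted by watching the cell.

```python
import re

def check_run(log_text: str, total_steps: int = 20) -> dict:
    """Map each success criterion to True/False based on the captured log."""
    return {
        "warmup_completed": "optimizer... done" in log_text,
        "reached_final_step": f"[{total_steps}/{total_steps}]" in log_text,
        "loss_reported": re.search(r"loss=\d", log_text) is not None,
    }

# Sample log modeled on the expected output; the loss value is illustrative only.
sample_log = """[TPU] Warmup: compiling XLA graph... forward... backward... optimizer... done (90s)
Training: ....................
[20/20] loss=2.5000 ...
[TPU] XLA compilations: 1
"""
result = check_run(sample_log)
print(result, "PASS" if all(result.values()) else "FAIL")
```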

## Optional: Test Multi-Step Training

If the quick test passes, try a longer run:

In [None]:
# Uncomment to run a longer test (100 steps).
# To time it, add %%time as the first line of the cell (cell magics must come first).
# !python scripts/train_v2_simple.py \
#     --from-scratch \
#     --cache-dir caches/test_L64_K64_N1K \
#     --output-dir runs/tpu_test_100 \
#     --config q4_r32 \
#     --max-steps 100 \
#     --batch-size 8 \
#     --accumulation-steps 4 \
#     --lr 3e-5 \
#     --warmup-steps 10 \
#     --tpu