# SR-011: Q4_A4_r32 From Scratch

**Rev: 1.4 (2026-01-09 16:00 PST)** - Deferred cache download until after model creation

**Configuration: 4-bit MLP, 4-bit Attention, rank=32**

| Component | LUT Size | Bits | Scale Rank |
|-----------|----------|------|------------|
| MLP | 16 | 4 | 32 |
| Attention | 16 | 4 | 32 |

## Workflow (IMPORTANT ORDER)

### Step 0: Create Initial V2 Checkpoint (MUST BE FP32)
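1. Load Qwen3-0.6B base model in FP32 (Cell 8)
2. Replace linears with V2 layers (group init -> SVD decomposition) (Cell 9)
3. Freeze Q (compute quantization indices once) (Cell 9)
4. Save a /tmp copy, FP16-snap it, and compare (Cell 9b)
5. Download KD cache (deferred to avoid wasting time if model creation fails) (Cell 9c)
6. Verify gradients and compute the initial KD loss, which needs the cache (Cell 10)
7. **Save initial checkpoint** (this is your "truth" checkpoint) (Cell 11)
8. Run sanity checks (mags, LUT FP16 representability) (Cell 12)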
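The LUT FP16-representability check and the FP16 mag snap both reduce to an FP32 -> FP16 -> FP32 round trip (the same `.half().float()` pattern used in Cells 9b and 12). A minimal NumPy sketch; the two 16-entry LUTs are hypothetical examples, not the LUT that `qat_lora` builds:

```python
import numpy as np

def fp16_roundtrip_error(values) -> float:
    """Max abs error after FP32 -> FP16 -> FP32; 0.0 means every value is FP16-exact."""
    v32 = np.asarray(values, dtype=np.float32)
    back = v32.astype(np.float16).astype(np.float32)
    return float(np.max(np.abs(v32 - back)))

def snap_fp16(values) -> np.ndarray:
    """Round values to the nearest FP16 value, kept in FP32 storage."""
    return np.asarray(values, dtype=np.float32).astype(np.float16).astype(np.float32)

# Dyadic values such as k/8 are FP16-exact; a uniform grid with step 2/15 is not.
dyadic_lut = np.arange(-8, 8, dtype=np.float32) / 8.0
uniform_lut = np.linspace(-1.0, 1.0, 16, dtype=np.float32)
print(fp16_roundtrip_error(dyadic_lut))       # 0.0
print(fp16_roundtrip_error(uniform_lut) > 0)  # True
```

Snapping is idempotent: once values have been through `snap_fp16`, a second round trip changes nothing, which is why snapped mags can be frozen safely.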

### Step 1: MLP-Only Training + AutoSnap Mags (Terminal)
- Train MLP scales only (`--mlp-only`)
- Enable AutoSnap for rank_magnitude (`--auto-snap-mags --auto-snap-target mlp`)
- Use **FP32** dtype (`--dtype fp32`) to avoid BF16 rounding during mag training
- Mags automatically frozen when stable
- Use `scripts/train_v2_simple.py` (terminal command in CELL 14)
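AutoSnap itself lives in `scripts/train_v2_simple.py` and its code is not shown here. The sketch below is a hypothetical controller built only from the Cell 6 parameter descriptions (threshold = max abs delta, patience = consecutive stable saves, start step, minimum saves, dry run), so it may differ from the script in details such as which saves count toward `min_saves`. `AutoSnapMonitor` and `on_save` are illustrative names, not the script's API.

```python
import numpy as np

class AutoSnapMonitor:
    """Hypothetical AutoSnap controller (illustration only, not train_v2_simple.py)."""

    def __init__(self, threshold=0.05, patience=2, start_step=100, min_saves=2, dry_run=False):
        self.threshold = threshold    # max abs delta between saves to count as stable
        self.patience = patience      # consecutive stable saves required
        self.start_step = start_step  # no audits before this step
        self.min_saves = min_saves    # audited saves required before eligibility
        self.dry_run = dry_run        # report, but never freeze
        self.prev = None
        self.saves_seen = 0
        self.stable_count = 0
        self.frozen = False

    def on_save(self, step, mags):
        """Call at each checkpoint save; returns FP16-snapped mags when freeze criteria are met, else None."""
        if self.frozen or step < self.start_step:
            return None
        mags = np.asarray(mags, dtype=np.float32)
        self.saves_seen += 1
        if self.prev is not None:
            delta = float(np.max(np.abs(mags - self.prev)))
            self.stable_count = self.stable_count + 1 if delta <= self.threshold else 0
        self.prev = mags.copy()
        if self.saves_seen >= self.min_saves and self.stable_count >= self.patience:
            if not self.dry_run:
                self.frozen = True  # caller would now set rank_magnitude.requires_grad = False
            return mags.astype(np.float16).astype(np.float32)
        return None

monitor = AutoSnapMonitor()
for step, mags in [(100, [1.00, 2.00]), (200, [1.01, 2.00]), (300, [1.02, 2.01])]:
    snapped = monitor.on_save(step, mags)
    print(step, snapped is not None, monitor.frozen)
```

Resetting `stable_count` to zero on any unstable save is what makes `patience` mean "consecutive" stable saves.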

### Step 2: Attention Training (Python API)
- Train attention scales via the Python API (the CLI has no `--attn-only` flag)
- MLP layers frozen (scales and mags)
- Uses `train_e2e()` with explicit freeze/enable logic (CELL 18)

### Step 3: Global Fine-tune + Export
- Low-LR polish (optional)
- Final snap + export for ANE

## CLI Flag Reference

| Flag | Purpose |
|------|---------|
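| `--v2-checkpoint` | Load a V2 checkpoint (NOT `--checkpoint`, which is for V1) |
| `--dtype fp32` | FP32 training (recommended for scale/mag training) |
| `--dtype bf16` | BF16 training (faster; use after mags are frozen) |
| `--mlp-only` | Train MLP layers only, freeze attention |
| `--freeze-mags-mlp` | Freeze MLP rank_magnitude (after AutoSnap) |
| `--auto-snap-mags` | Enable auto-snap of rank_magnitude when stable |
| `--auto-snap-target` | Layers AutoSnap applies to (`mlp` for the MLP-only phase) |
| `--auto-snap-threshold` | Max abs change in mags between saves to count as stable |
| `--auto-snap-patience` | Consecutive stable saves required before snap + freeze |
| `--auto-snap-start-step` | Step before which AutoSnap does not audit |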

## Cell Index

| Cell | Purpose |
|------|---------|
| 1-4 | Setup (paths, mount, clone, deps) |
| 5 | Cache config (paths only, no download) |
| 6 | Configuration (Q4_A4_r32, AutoSnap params) |
| 7-9 | Load model, replace linears, freeze Q |
| 9b | Save to /tmp, FP16-snap with `snap_and_test_v2.py`, compare |
| **9c** | **Download KD cache (deferred)** |
| 10 | Verify gradients, compute initial loss |
| 11 | Save initial checkpoint (Step 0 complete) |
| 12 | **Sanity checks** (Q frozen, LUT, mags) |
| 13 | Upload to Google Drive |
| **14** | **Phase A: MLP-only + AutoSnap (terminal cmd)** |
| 15-17 | Optional: weight tuning, MLP refinement |
| **18** | **Phase B: Attention training (Python API)** |
| 19-20 | Joint fine-tune, save final |
| 21-25 | Inference, export for ANE |

## V2 Key Features:
- Rank-by-rank forward pass (ANE-friendly)
- Per-rank normalization: `scale_A` (unit columns), `scale_B` (unit rows), `rank_magnitude`
- **Frozen Q**: indices computed once at init via group init + SVD
- **STE-FP16**: Built into V2 forward (no --use-ste flag needed)
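A minimal NumPy sketch of why a rank-by-rank pass matches a dense one, assuming the effective weight is `Q * (scale_A @ diag(rank_magnitude) @ scale_B)` elementwise. That form is consistent with the SVD initialization of per-weight scales described in this document, but it has not been checked against `AnemllQATLinearV2.forward`; the function names are illustrative. Each rank costs one matmul with the shared frozen Q plus two elementwise scalings, so the full [out, in] scale map is never materialized.

```python
import numpy as np

def dense_forward(x, q, scale_A, rank_magnitude, scale_B):
    """Materialize W = Q * (A diag(g) B) and apply it: y = x @ W.T."""
    scales = scale_A @ np.diag(rank_magnitude) @ scale_B  # [out, in]
    return x @ (q * scales).T

def rank_by_rank_forward(x, q, scale_A, rank_magnitude, scale_B):
    """Same result without forming W: one matmul with the frozen Q per rank."""
    y = np.zeros((x.shape[0], q.shape[0]))
    for r in range(rank_magnitude.shape[0]):
        xr = x * scale_B[r]                            # scale inputs by row r of scale_B
        yr = xr @ q.T                                  # shared matmul with frozen Q
        y += rank_magnitude[r] * yr * scale_A[:, r]    # scale outputs by column r of scale_A
    return y

rng = np.random.default_rng(0)
out_f, in_f, rank, batch = 6, 10, 3, 4
q = rng.choice(np.linspace(-1, 1, 16), size=(out_f, in_f))
scale_A = rng.standard_normal((out_f, rank))
scale_A /= np.linalg.norm(scale_A, axis=0)                 # unit columns
scale_B = rng.standard_normal((rank, in_f))
scale_B /= np.linalg.norm(scale_B, axis=1, keepdims=True)  # unit rows
g = rng.uniform(0.5, 2.0, size=rank)
x = rng.standard_normal((batch, in_f))
print(np.allclose(dense_forward(x, q, scale_A, g, scale_B),
                  rank_by_rank_forward(x, q, scale_A, g, scale_B)))
```

The identity behind it: `sum_j Q_ij * (sum_r A_ir g_r B_rj) * x_j = sum_r g_r A_ir * sum_j Q_ij (B_rj x_j)`.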

## From-Scratch Initialization (Step 0):
1. Group init: Compute per-group max-abs scales
2. Expand: Repeat to per-weight scales
3. SVD: `u, s, vh = svd(scales)` -> `scale_A=u[:,:r]`, `scale_B=vh[:r,:]`, `rank_magnitude=s[:r]`
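A NumPy sketch of the three steps, assuming groups run along the input dimension and a hypothetical `group_size` (the notebook states neither). `np.linalg.svd` returns singular values in descending order, so keeping the first `r` keeps the strongest ranks; `u` has orthonormal columns and `vh` orthonormal rows, which is where the unit-norm `scale_A` columns and `scale_B` rows come from.

```python
import numpy as np

def svd_scale_init(weight, group_size, rank):
    """Return (scale_A, scale_B, rank_magnitude) from group max-abs scales via truncated SVD."""
    out_f, in_f = weight.shape
    if in_f % group_size:
        raise ValueError("in_features must be divisible by group_size")
    # 1. Group init: max-abs over each run of `group_size` input columns, per output row
    group_scales = np.abs(weight).reshape(out_f, in_f // group_size, group_size).max(axis=-1)
    # 2. Expand: repeat each group scale across the columns it covers -> [out, in]
    scales = np.repeat(group_scales, group_size, axis=1)
    # 3. SVD, truncated to `rank`
    u, s, vh = np.linalg.svd(scales, full_matrices=False)
    return u[:, :rank], vh[:rank, :], s[:rank]

rng = np.random.default_rng(0)
w = rng.standard_normal((64, 128)).astype(np.float32)
scale_A, scale_B, rank_magnitude = svd_scale_init(w, group_size=16, rank=4)
print(scale_A.shape, scale_B.shape, rank_magnitude.shape)  # (64, 4) (4, 128) (4,)
```

The expanded scale matrix has at most `in_features / group_size` distinct columns, so with `rank` at or above that count the product `scale_A @ diag(rank_magnitude) @ scale_B` reproduces it exactly; smaller ranks give the least-squares truncation (Eckart-Young).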

In [None]:
# [CELL 1: Google Drive paths]
# ============================================================
# GOOGLE DRIVE PATHS (STANDARD)
# ============================================================

# Checkpoints/runs go here
GD_RUNS = '/content/drive/MyDrive/qwen3_runs'

# KD caches go here
GD_CACHES = '/content/drive/MyDrive/qwen3_caches'

# Local directories (on Colab VM)
LOCAL_RUNS = 'runs'
LOCAL_CACHES = 'caches'

In [None]:
# [CELL 2: Mount Google Drive]
# Mount Google Drive
from google.colab import drive
drive.mount('/content/drive')

### GitHub

In [None]:
# [CELL 3: Clone/update repo]
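# Clone repo if needed, otherwise update the existing checkout
!git clone https://github.com/anemll/qwen3_apple_style_2bit_qat_lora.git || (cd qwen3_apple_style_2bit_qat_lora && git pull)
%cd qwen3_apple_style_2bit_qat_lora
# Discard local edits first, then pull, so local changes cannot block the update
!git fetch
!git reset --hard HEAD
!git pull
import sys
# Drop cached qat_lora modules so the imports below load the fresh checkout
for _mod in [k for k in sys.modules if k.startswith('qat_lora')]:
    sys.modules.pop(_mod)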

# Import V2 modules
from qat_lora import (
    # V2 ANE-friendly classes
    AnemllQATLinearV2,
    AnemllQuantConfigV2,
    replace_linear_with_anemll_v2,
    freeze_Q_all,
    freeze_model_for_inference_v2,
    unfreeze_model_for_training_v2,
    # Training utilities
    evaluate_kd_loss,
    train_e2e,
    save_checkpoint,
    load_checkpoint,
)

In [None]:
# [CELL 4: Install dependencies]
# Install dependencies
!pip install -q transformers accelerate safetensors

In [None]:
# [CELL 5: Cache paths (download deferred)]
# ============================================================
# KD CACHE CONFIGURATION (download happens later)
# ============================================================
# Define cache name and paths here. Actual download is deferred
# until after the V2 model is created (Cell 9c).

import os

#CACHE_NAME = 'alpaca_chat_think_both_L128_K32_R256'
#CACHE_NAME = 'alpaca_chat_think_both_L128_K64_R512'
CACHE_NAME = 'alpaca_chat_think_both_L128_K128_R1024'

CACHE_TGZ = f'{CACHE_NAME}.tgz'
cache_local_path = f'{LOCAL_CACHES}/{CACHE_NAME}'

print(f"Cache configured: {CACHE_NAME}")
print(f"  Local path: {cache_local_path}")
print(f"  Download will happen after model creation (Cell 9c)")

In [None]:
# [CELL 6: Configuration - Q4_A4_r32]
# ============================================================
# CONFIGURATION - Q4_A4_r32 (4-bit MLP, 4-bit Attention, rank=32)
# ============================================================
# IMPORTANT: Step 0 MUST use FP32 for SVD init accuracy!

import torch

# Model
MODEL_ID = 'Qwen/Qwen3-0.6B'

# === MLP Quantization (4-bit) ===
LUT_BITS = 4
LUT_SIZE = 2**LUT_BITS  # 16 values
SCALE_RANK = 32  # High rank for quality

# === Attention Quantization (4-bit) ===
ATTN_LUT_BITS = 4
ATTN_LUT_SIZE = 2**ATTN_LUT_BITS  # 16 values
ATTN_SCALE_RANK = 32  # Same rank as MLP

# Training
BATCH_SIZE = 8
GRAD_ACCUM = 4

if torch.cuda.is_available():
    BATCH_SIZE = 16
    GRAD_ACCUM = 2

LR = 5e-5
MIN_LR_RATIO = 0.1

# KD / Distillation params
DISTILL_TEMP = 2.0
HARD_TOP1_WEIGHT = 0.0  # Pure KD for scale training
HARD_FULL_WEIGHT = 0.0

# Device
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# CRITICAL: Step 0 must use FP32 for accurate SVD init
# Training can later use BF16 for speed
DTYPE = torch.float32  # FP32 for Step 0 checkpoint creation

# ============================================================
# AUTO-SNAP CONFIGURATION (for MLP-only training in Step 1)
# ============================================================
AUTO_SNAP_ENABLE = False   # OFF for Step 0, enable in Step 1
AUTO_SNAP_TARGET = "mlp"   # "mlp" for MLP-only phase
AUTO_SNAP_THRESHOLD = 0.05  # Max abs delta to consider stable
AUTO_SNAP_PATIENCE = 2      # Consecutive stable saves before freeze
AUTO_SNAP_START_STEP = 100  # Don't audit before this step
AUTO_SNAP_MIN_SAVES = 2     # Minimum saves before eligible
AUTO_SNAP_DRY_RUN = False   # Set True to audit without freezing

# Quality string for filenames
QUAL = f'q{LUT_BITS}_a{ATTN_LUT_BITS}_r{SCALE_RANK}'

print(f'=== SR-011: Q4_A4_r32 From Scratch ===')
print(f'Quality: {QUAL} (4-bit MLP, 4-bit Attention)')
print(f'Device: {DEVICE}')
print(f'')
print(f'*** STEP 0: FP32 dtype for SVD init accuracy ***')
print(f'dtype: {DTYPE}')
print(f'')
print(f'MLP Config:')
print(f'  LUT size: {LUT_SIZE} ({LUT_BITS}-bit)')
print(f'  Scale rank: {SCALE_RANK}')
print(f'')
print(f'Attention Config:')
print(f'  LUT size: {ATTN_LUT_SIZE} ({ATTN_LUT_BITS}-bit)')
print(f'  Scale rank: {ATTN_SCALE_RANK}')
print(f'')
print(f'AutoSnap Config (for Step 1):')
print(f'  Enabled: {AUTO_SNAP_ENABLE}')
print(f'  Target: {AUTO_SNAP_TARGET}')
print(f'  Threshold: {AUTO_SNAP_THRESHOLD}')
print(f'  Patience: {AUTO_SNAP_PATIENCE}')


In [None]:
# [CELL 8: Load model]
# ============================================================
# LOAD MODEL
# ============================================================

from transformers import AutoModelForCausalLM, AutoTokenizer

print(f'Loading {MODEL_ID}...')
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=DTYPE,
    trust_remote_code=True,
)
model.to(DEVICE)
model.eval()
print(f'Loaded. Parameters: {sum(p.numel() for p in model.parameters()):,}')

In [None]:
# [CELL 9: Replace linears + SVD init + freeze Q]
# ============================================================
# REPLACE LINEARS WITH AnemllQATLinearV2
# ============================================================

import sys
sys.path.insert(0, '.')

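# Force reimport to get latest code. Reload the submodules first, then the
# package: reloading qat_lora re-runs its __init__, which (if it re-exports
# names from these submodules) then binds the freshly reloaded classes.
import importlib
import qat_lora.ane_qat_linear_v2 as ane_module_v2
importlib.reload(ane_module_v2)
import qat_lora.layer_qat as layer_module
importlib.reload(layer_module)
import qat_lora
importlib.reload(qat_lora)

from qat_lora import AnemllQuantConfigV2, replace_linear_with_anemll_v2, freeze_Q_all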

# Debug: Check what modules exist in the model
print("Checking model structure...")
import torch.nn as nn
linear_count = 0
for name, m in model.named_modules():
    if isinstance(m, nn.Linear):
        linear_count += 1
        if linear_count <= 5:
            print(f"  Found Linear: {name}")
print(f"Total Linear modules: {linear_count}")

# Create V2 configs with POSITIVE SCALE CONSTRAINTS
# This prevents scale sign flips when Q is frozen (key fix!)
mlp_config = AnemllQuantConfigV2(
    lut_size=LUT_SIZE,
    scale_rank=SCALE_RANK,
    learnable_lut=False,
    # Positive scale constraints (prevents sign flips with frozen Q)
    force_positive_scales=True,
    positive_scale_method="abs",      # "abs" or "softplus"
    magnitude_activation="softplus",  # keeps g >= 0
    magnitude_eps=1e-6,
)

attn_config = AnemllQuantConfigV2(
    lut_size=ATTN_LUT_SIZE,
    scale_rank=ATTN_SCALE_RANK,
    learnable_lut=False,
    # Positive scale constraints
    force_positive_scales=True,
    positive_scale_method="abs",
    magnitude_activation="softplus",
    magnitude_eps=1e-6,
)

print(f'\nV2 Config: force_positive_scales={mlp_config.force_positive_scales}')
print(f'           positive_scale_method={mlp_config.positive_scale_method}')
print(f'           magnitude_activation={mlp_config.magnitude_activation}')

print('\nReplacing linear layers with V2...')
count = replace_linear_with_anemll_v2(
    model,
    mlp_config=mlp_config,
    attn_config=attn_config,
    quantize_attn=True,
    quantize_lm_head=False,
)

# Verify replacement worked (use type name to avoid reload issues)
qat_count = sum(1 for _, m in model.named_modules() if type(m).__name__ == 'AnemllQATLinearV2')
print(f"\nVerification: {qat_count} AnemllQATLinearV2 modules in model")

# ============================================================
# FREEZE Q - CRITICAL STEP FOR V2
# ============================================================
# Compute Q = lut[indices] once. After this, only scales are trained.
print('\nFreezing Q (computing indices once)...')
freeze_Q_all(model, verbose=False)
print('Q frozen for all layers. Training will only update scale_A, scale_B, rank_magnitude.')

In [None]:
# [CELL 9b: Save V2 model + snap + compare]
# ============================================================
# SAVE V2 MODEL TO /tmp + SNAP WITH SCRIPT + COMPARE
# ============================================================
# 1. Save initial V2 model (FP32)
# 2. Save config.json for reference (the snap script takes its settings as CLI flags)
# 3. Run snap_and_test_v2.py (proper snap with all params)
# 4. Run debug_snap_difference.py to compare
#
# NOTE: FP16 snap is always done on CPU internally (.cpu().half().float())

import os
import json
import subprocess

# Paths
ORIG_PATH = '/tmp/v2_initial_frozen_Q.pt'
SNAP_PATH = '/tmp/v2_initial_frozen_Q_snapped.pt'
CONFIG_PATH = '/tmp/v2_config.json'

# ============================================================
# STEP 1: Save original model
# ============================================================
print("=" * 60)
print("STEP 1: Save original V2 model")
print("=" * 60)
torch.save(model.state_dict(), ORIG_PATH)
print(f"Saved: {ORIG_PATH}")
print(f"  Size: {os.path.getsize(ORIG_PATH) / 1e6:.1f} MB")

# ============================================================
# STEP 2: Save config.json for reference
# ============================================================
print("\n" + "=" * 60)
print("STEP 2: Save config.json")
print("=" * 60)
snap_config = {
    'model_id': MODEL_ID,
    'version': 'v2',
    'mlp_lut_bits': LUT_BITS,
    'mlp_scale_rank': SCALE_RANK,
    'attn_lut_bits': ATTN_LUT_BITS,
    'attn_scale_rank': ATTN_SCALE_RANK,
}
with open(CONFIG_PATH, 'w') as f:
    json.dump(snap_config, f, indent=2)
print(f"Saved: {CONFIG_PATH}")
print(f"  MLP:  {LUT_BITS}-bit LUT, rank={SCALE_RANK}")
print(f"  Attn: {ATTN_LUT_BITS}-bit LUT, rank={ATTN_SCALE_RANK}")

# ============================================================
# STEP 3: Run snap_and_test_v2.py
# ============================================================
# CLI args (from --help):
#   --lut-bits        : MLP LUT bits
#   --attn-lut-bits   : Attention LUT bits
#   --scale-rank      : MLP scale rank
#   --attn-scale-rank : Attention scale rank
#   --fp16            : Snap to FP16 for ANE (always uses CPU internally)
print("\n" + "=" * 60)
print("STEP 3: Run snap_and_test_v2.py")
print("=" * 60)

snap_cmd = [
    'python', 'scripts/snap_and_test_v2.py',
    '--checkpoint', ORIG_PATH,
    '--output', SNAP_PATH,
    '--lut-bits', str(LUT_BITS),
    '--scale-rank', str(SCALE_RANK),
    '--attn-lut-bits', str(ATTN_LUT_BITS),
    '--attn-scale-rank', str(ATTN_SCALE_RANK),
    '--fp16',  # Snap to FP16 for ANE (internally uses .cpu().half().float())
    '--no-test',  # Skip inference test (no tokenizer loaded yet)
]
print(f"Running: {' '.join(snap_cmd)}")
print()

result = subprocess.run(snap_cmd, capture_output=True, text=True)
print(result.stdout)
if result.stderr:
    print("STDERR:", result.stderr)

snap_ok = result.returncode == 0
if not snap_ok:
    print(f"ERROR: snap_and_test_v2.py failed with code {result.returncode}")
else:
    print(f"Snapped checkpoint saved: {SNAP_PATH}")

# ============================================================
# STEP 4: Run debug_snap_difference.py to compare
# ============================================================
# Note: This script always loads on CPU (map_location='cpu')
print("\n" + "=" * 60)
print("STEP 4: Compare with debug_snap_difference.py")
print("=" * 60)

if snap_ok:
    diff_cmd = [
        'python', 'scripts/debug_snap_difference.py',
        ORIG_PATH,
        SNAP_PATH,
    ]
    print(f"Running: {' '.join(diff_cmd)}")
    print()

    result = subprocess.run(diff_cmd, capture_output=True, text=True)
    print(result.stdout)
    if result.stderr:
        print("STDERR:", result.stderr)
else:
    # Avoid comparing against a missing or stale snapped file left in /tmp
    print("Skipping comparison because the snap step failed.")

# ============================================================
# External commands for reference
# ============================================================
print("\n" + "=" * 60)
print("External commands (for terminal):")
print("=" * 60)
print(f"  python scripts/check_mags_fp16.py {ORIG_PATH}")
print(f"  python scripts/debug_snap_difference.py {ORIG_PATH} {SNAP_PATH}")

In [None]:
# [CELL 9c: Download KD cache (after model created)]
# ============================================================
# SYNC KD CACHE FROM GOOGLE DRIVE
# ============================================================
# Priority order:
#   1. Folder (rsync) - fastest, no extraction needed
#   2. .tar.lz4 - fast decompression (copy locally first!)
#   3. .tgz - slowest but most compatible (copy locally first!)
#
# IMPORTANT: Copy archive to local disk BEFORE extracting!
# Extracting directly from GDrive is extremely slow.

import os
import subprocess
from pathlib import Path

print(f"Syncing KD cache: {CACHE_NAME}")

# Verify drive is mounted
if not os.path.exists('/content/drive/MyDrive'):
    print('Google Drive not mounted! Mounting now...')
    from google.colab import drive
    drive.mount('/content/drive')

# Create local cache directory
os.makedirs(LOCAL_CACHES, exist_ok=True)

# Define possible source paths
gdrive_folder = f'{GD_CACHES}/{CACHE_NAME}'
gdrive_lz4 = f'{GD_CACHES}/{CACHE_NAME}.tar.lz4'
gdrive_tgz = f'{GD_CACHES}/{CACHE_NAME}.tgz'

# Check if cache exists locally already
if os.path.exists(cache_local_path) and list(Path(cache_local_path).glob('*.pt')):
    print(f'Cache already exists at {cache_local_path}')
    cache_files = list(Path(cache_local_path).glob('*.pt'))
    print(f'  Found {len(cache_files)} .pt files')

# Priority 1: Rsync from folder (fastest - no extraction)
elif os.path.isdir(gdrive_folder):
    print(f'Found folder: {gdrive_folder}')
    print('Using rsync (fastest - no extraction needed)...')
    os.makedirs(cache_local_path, exist_ok=True)
    result = subprocess.run(
        ['rsync', '-ah', '--progress', f'{gdrive_folder}/', f'{cache_local_path}/'],
        capture_output=False
    )
    if result.returncode != 0:
        raise RuntimeError(f'rsync failed with code {result.returncode}')

# Priority 2: Copy .tar.lz4 locally, then extract (fast)
elif os.path.exists(gdrive_lz4):
    print(f'Found: {gdrive_lz4}')
    size_gb = os.path.getsize(gdrive_lz4) / (1024**3)

    # Install lz4 if needed
    print('  Installing lz4 if needed...')
    subprocess.run(['apt-get', 'install', '-y', '-qq', 'lz4'], capture_output=True)

    # Copy to local first with rsync --progress (MUCH faster than extracting from GDrive)
    local_archive = f'{LOCAL_CACHES}/{CACHE_NAME}.tar.lz4'
    print(f'  Copying to local disk ({size_gb:.2f} GB)...')
    result = subprocess.run(
        ['rsync', '-ah', '--progress', gdrive_lz4, local_archive],
        capture_output=False  # Show progress
    )
    if result.returncode != 0:
        raise RuntimeError(f'rsync copy failed with code {result.returncode}')

    # Extract from local disk
    print(f'  Extracting from local disk...')
    result = subprocess.run(
        f'lz4 -d "{local_archive}" -c | tar -xf - -C "{LOCAL_CACHES}"',
        shell=True,
        capture_output=True,
        text=True
    )
    if result.returncode != 0:
        print(f'  STDERR: {result.stderr}')
        raise RuntimeError(f'lz4 extraction failed with code {result.returncode}')

    # Delete local archive to save space
    print(f'  Cleaning up local archive...')
    os.remove(local_archive)

# Priority 3: Copy .tgz locally, then extract
elif os.path.exists(gdrive_tgz):
    print(f'Found: {gdrive_tgz}')
    size_gb = os.path.getsize(gdrive_tgz) / (1024**3)

    # Copy to local first with rsync --progress
    local_archive = f'{LOCAL_CACHES}/{CACHE_NAME}.tgz'
    print(f'  Copying to local disk ({size_gb:.2f} GB)...')
    result = subprocess.run(
        ['rsync', '-ah', '--progress', gdrive_tgz, local_archive],
        capture_output=False  # Show progress
    )
    if result.returncode != 0:
        raise RuntimeError(f'rsync copy failed with code {result.returncode}')

    # Extract from local disk
    print(f'  Extracting from local disk...')
    result = subprocess.run(
        ['tar', '-xzf', local_archive, '-C', LOCAL_CACHES],
        capture_output=True,
        text=True
    )
    if result.returncode != 0:
        print(f'  STDERR: {result.stderr}')
        raise RuntimeError(f'tar extraction failed with code {result.returncode}')

    # Delete local archive to save space
    print(f'  Cleaning up local archive...')
    os.remove(local_archive)

else:
    print(f'ERROR: Cache not found in Google Drive!')
    print(f'  Checked:')
    print(f'    - {gdrive_folder}/ (folder)')
    print(f'    - {gdrive_lz4} (.tar.lz4)')
    print(f'    - {gdrive_tgz} (.tgz)')
    raise FileNotFoundError(f'Cache {CACHE_NAME} not found in {GD_CACHES}')

# Verify cache exists
assert os.path.exists(cache_local_path), f'Cache not found at {cache_local_path}'
cache_files = list(Path(cache_local_path).glob('*.pt'))
print(f'\n✓ Cache ready: {len(cache_files)} .pt files in {cache_local_path}')

In [None]:
# [CELL 10: Verify gradient flow]
# ============================================================
# VERIFY GRADIENT FLOW FOR V2
# ============================================================

from qat_lora import evaluate_kd_loss

print('Verifying V2 gradient flow...')

# Find a V2 layer (use type name to avoid reload issues)
layer0 = model.model.layers[0]
test_module = None
for name, m in layer0.named_modules():
    if type(m).__name__ == 'AnemllQATLinearV2':
        test_module = m
        break

if test_module is None:
    print("ERROR: No AnemllQATLinearV2 modules found! Replacement failed.")
else:
    # V2: weight is frozen, only scales are trained
    print(f"  weight.requires_grad: {test_module.weight.requires_grad} (should be False after freeze_Q)")
    print(f"  scale_A.requires_grad: {test_module.scale_A.requires_grad}")
    print(f"  scale_B.requires_grad: {test_module.scale_B.requires_grad}")
    print(f"  rank_magnitude.requires_grad: {test_module.rank_magnitude.requires_grad}")
    print(f"  Q frozen: {test_module._Q is not None}")
    
    # Test gradient flow through scales
    test_module.scale_A.requires_grad = True
    test_module.scale_B.requires_grad = True
    test_module.rank_magnitude.requires_grad = True
    
    x = torch.randn(1, 10, test_module.in_features, device=DEVICE, dtype=DTYPE)
    y = test_module(x)
    loss = y.sum()
    try:
        loss.backward()
        if test_module.scale_A.grad is not None:
            print(f"  Gradient OK: scale_A.grad.shape = {test_module.scale_A.grad.shape}")
            print(f"               scale_B.grad.shape = {test_module.scale_B.grad.shape}")
            print(f"               rank_magnitude.grad.shape = {test_module.rank_magnitude.grad.shape}")
            # Clear for actual training
            test_module.scale_A.grad = None
            test_module.scale_B.grad = None
            test_module.rank_magnitude.grad = None
        else:
            print("  ERROR: scale_A.grad is None after backward!")
    except Exception as e:
        print(f"  ERROR during backward: {e}")

# Compute initial KD loss
print('\nComputing initial KD loss...')
initial_loss = evaluate_kd_loss(model, cache_local_path, DEVICE, num_samples=40, temperature=DISTILL_TEMP)
print(f'Initial KD Loss: {initial_loss:.4f}')

# **V2 SCALE OPTIMIZATION** (Q and Weights Frozen)

In V2, we train only the scale parameters:
- `scale_A`: [out, rank] - unit-norm columns
- `scale_B`: [rank, in] - unit-norm rows  
- `rank_magnitude`: [rank] - the only magnitude term (`scale_A` and `scale_B` stay unit-norm)

The frozen components:
- `weight`: Original FP weights (frozen by `freeze_Q()`)
- `_Q`: LUT values = `lut[indices]` (computed once by `freeze_Q()`)
- `_indices`: Quantization indices (computed once)
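How `freeze_Q()` picks the indices is not shown in this notebook. A minimal sketch, assuming nearest-entry assignment of the normalized weight `W / S` (S = per-weight scales) to a hypothetical uniform 16-entry LUT on [-1, 1]; `freeze_q` is an illustrative name, not the library function:

```python
import numpy as np

def freeze_q(weight, scales, lut, eps=1e-8):
    """Return (indices, Q) where Q = lut[indices] is the nearest LUT entry to W / S."""
    normalized = weight / np.maximum(np.abs(scales), eps)  # [out, in]; guard against zero scales
    dist = np.abs(normalized[..., None] - lut)             # [out, in, lut_size]
    indices = dist.argmin(axis=-1)                         # [out, in], computed once
    return indices, lut[indices]                           # Q: [out, in]

lut = np.linspace(-1.0, 1.0, 16)                 # hypothetical 4-bit LUT
rng = np.random.default_rng(0)
true_idx = rng.integers(0, 16, size=(4, 8))
scales = rng.uniform(0.5, 2.0, size=(4, 8))
weight = lut[true_idx] * scales                  # weights that sit exactly on the LUT grid
indices, q = freeze_q(weight, scales, lut)
print((indices == true_idx).all())
```

Because `indices` (and therefore Q) depend on the scales at freeze time, training the scales afterwards changes the effective weight without re-running this assignment; Stage 2 below is the path that recomputes Q.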

**Why this works:**
- Q represents the normalized quantized weights in [-1, 1]
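- The per-weight scale map is approximated by `scale_A @ diag(rank_magnitude) @ scale_B`, one outer-product term per rank
- Training adjusts how much each rank term contributes to the effective weight, and hence to the output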

**Training strategy:** E2E training (all layers together) is more effective than layer-by-layer.

In [None]:
# [CELL 11: Save initial checkpoint]
# ============================================================
# SAVE INITIAL V2 CHECKPOINT (BEFORE TRAINING)
# ============================================================

import os

RUN_NAME = f'SR-011_{QUAL}_from_scratch'
SAVE_DIR = f'{LOCAL_RUNS}/{RUN_NAME}'

os.makedirs(SAVE_DIR, exist_ok=True)

# Save state dict
torch.save(model.state_dict(), f'{SAVE_DIR}/model_state_dict.pt')

# Save config
import json
config = {
    'model_id': MODEL_ID,
    'version': 'v2',  # Mark as V2
    'lut_size': LUT_SIZE,
    'scale_rank': SCALE_RANK,
    'attn_lut_size': ATTN_LUT_SIZE,
    'attn_scale_rank': ATTN_SCALE_RANK,
    'initial_kd_loss': initial_loss,
}
with open(f'{SAVE_DIR}/config.json', 'w') as f:
    json.dump(config, f, indent=2)

print(f'Saved V2 checkpoint to {SAVE_DIR}')

In [None]:
# [CELL 12: Sanity checks (CPU-only, before training)]
# ============================================================
# SANITY CHECKS FOR V2 INITIALIZATION
# ============================================================
# Run these checks BEFORE training to verify correct setup:
# 1. Q is frozen (indices computed)
# 2. LUT values are FP16-representable
# 3. rank_magnitude stats look reasonable

print("=" * 60)
print("SANITY CHECKS (Step 0 Verification)")
print("=" * 60)

# Check 1: Q frozen
q_frozen_count = 0
q_not_frozen = []
for name, module in model.named_modules():
    if type(module).__name__ == 'AnemllQATLinearV2':
        if module._Q is not None:
            q_frozen_count += 1
        else:
            q_not_frozen.append(name)

print(f"\n[Check 1] Q Frozen: {q_frozen_count} layers")
if q_not_frozen:
    print(f"  WARNING: {len(q_not_frozen)} layers have Q=None!")
    for n in q_not_frozen[:5]:
        print(f"    - {n}")
else:
    print("  ✓ All layers have frozen Q")

# Check 2: LUT FP16 representability
lut_issues = []
for name, module in model.named_modules():
    if type(module).__name__ == 'AnemllQATLinearV2':
        lut = module.lut.data.cpu()
        lut_fp16 = lut.half().float()
        max_diff = (lut - lut_fp16).abs().max().item()
        if max_diff > 1e-6:
            lut_issues.append((name, max_diff))
        break  # Just check one (they should all be the same)

print(f"\n[Check 2] LUT FP16 Representability:")
if lut_issues:
    print(f"  WARNING: {len(lut_issues)} layers have LUT precision issues")
else:
    print("  ✓ LUT values are FP16-representable")

# Check 3: rank_magnitude stats
mag_stats = {'min': float('inf'), 'max': float('-inf'), 'mean': 0, 'count': 0}
for name, module in model.named_modules():
    if type(module).__name__ == 'AnemllQATLinearV2':
        mag = module.rank_magnitude.data.cpu()
        mag_stats['min'] = min(mag_stats['min'], mag.min().item())
        mag_stats['max'] = max(mag_stats['max'], mag.max().item())
        mag_stats['mean'] += mag.mean().item()
        mag_stats['count'] += 1

if mag_stats['count'] > 0:
    mag_stats['mean'] /= mag_stats['count']

print(f"\n[Check 3] rank_magnitude Stats:")
print(f"  min: {mag_stats['min']:.6f}")
print(f"  max: {mag_stats['max']:.6f}")
print(f"  mean: {mag_stats['mean']:.6f}")

if mag_stats['min'] < 0:
    print("  WARNING: Negative rank_magnitude values detected!")
else:
    print("  ✓ All positive (softplus activation working)")

# Check 4: scale_A/scale_B norms
norm_issues = 0
for name, module in model.named_modules():
    if type(module).__name__ == 'AnemllQATLinearV2':
        A_norms = module.scale_A.data.norm(dim=0)  # Column norms
        B_norms = module.scale_B.data.norm(dim=1)  # Row norms
        if (A_norms - 1.0).abs().max() > 0.1:
            norm_issues += 1
        if (B_norms - 1.0).abs().max() > 0.1:
            norm_issues += 1

print(f"\n[Check 4] Scale Normalization:")
if norm_issues > 0:
    print(f"  WARNING: {norm_issues} scale matrices not unit-norm")
else:
    print("  ✓ scale_A columns and scale_B rows are unit-norm")

# Check 5: Dtype verification
print(f"\n[Check 5] Dtype Verification:")
sample_module = None
for name, module in model.named_modules():
    if type(module).__name__ == 'AnemllQATLinearV2':
        sample_module = module
        break
if sample_module:
    print(f"  weight.dtype: {sample_module.weight.dtype}")
    print(f"  scale_A.dtype: {sample_module.scale_A.dtype}")
    print(f"  _Q.dtype: {sample_module._Q.dtype if sample_module._Q is not None else 'N/A'}")
    if sample_module.weight.dtype == torch.float32:
        print("  ✓ FP32 (correct for Step 0)")
    else:
        print(f"  ⚠ Expected FP32 for Step 0, got {sample_module.weight.dtype}")

print("\n" + "=" * 60)
print("SANITY CHECKS COMPLETE - Ready for Step 1 (MLP Training)")
print("=" * 60)

In [None]:
# [CELL 13: Upload initial to Google Drive]
# ============================================================
# UPLOAD TO GOOGLE DRIVE
# ============================================================

!tar -czvf {RUN_NAME}.tgz -C {LOCAL_RUNS} {RUN_NAME}
!cp {RUN_NAME}.tgz {GD_RUNS}/
print(f'Uploaded to {GD_RUNS}/{RUN_NAME}.tgz')

# **V2 END-TO-END KD-QAT TRAINING**

This is the **primary training stage** for V2. Train all MLP scales together.

**V2 Training Strategy:**

| Stage | What's Trained | What's Frozen | Notes |
|-------|---------------|---------------|-------|
| Stage 1 | scale_A, scale_B, rank_magnitude | weight, Q, lut | Primary training |
| Stage 2 | weight + scales | lut | Optional, recomputes Q |

**Recommended approach:** Stay in Stage 1 (scales only) for most cases.

In [None]:
# [CELL 14: Phase A - MLP-only training + AutoSnap (Terminal)]
# ============================================================
# STEP 1: MLP-ONLY TRAINING + AUTO-SNAP MAGS
# ============================================================
# This cell prints the terminal command for MLP-only training.
# Run in terminal for better monitoring (nohup, logs, etc.)
#
# AutoSnap: Automatically snaps rank_magnitude to FP16 when stable,
# then freezes mags for the rest of training.
#
# IMPORTANT FLAGS:
#   --v2-checkpoint : Load V2 checkpoint (NOT --checkpoint which is for V1)
#   --dtype fp32    : FP32 for scale/mag training (no BF16 rounding errors)
#   --mlp-only      : Train only MLP layers, freeze attention
#   --auto-snap-mags: Enable auto-snap of rank_magnitude when stable

import os

# Checkpoint from Step 0
INIT_CKPT = f'{SAVE_DIR}/model_state_dict.pt'

# Training output dir
MLP_RUN_NAME = f'SR-011_{QUAL}_mlp_autosnap'
MLP_SAVE_DIR = f'{LOCAL_RUNS}/{MLP_RUN_NAME}'

# Build command
cmd = f'''python scripts/train_v2_simple.py \\
    --v2-checkpoint {INIT_CKPT} \\
    --cache-dir {cache_local_path} \\
    --output-dir {MLP_SAVE_DIR} \\
    --mlp-only \\
    --auto-snap-mags \\
    --auto-snap-target mlp \\
    --auto-snap-threshold {AUTO_SNAP_THRESHOLD} \\
    --auto-snap-patience {AUTO_SNAP_PATIENCE} \\
    --auto-snap-start-step {AUTO_SNAP_START_STEP} \\
    --max-steps 4000 \\
    --batch-size {BATCH_SIZE} \\
    --accumulation-steps {GRAD_ACCUM} \\
    --lr 5e-4 \\
    --warmup-steps 100 \\
    --save-steps 200 \\
    --eval-steps 100 \\
    --temperature {DISTILL_TEMP} \\
    --dtype fp32'''

print("=" * 70)
print("STEP 1: MLP-ONLY TRAINING + AUTO-SNAP")
print("=" * 70)
print()
print("Copy this command to terminal:")
print()
print(cmd)
print()
print("=" * 70)
print("Expected behavior:")
print("  - Trains MLP scales only (attention frozen)")
print("  - AutoSnap monitors rank_magnitude stability at save checkpoints")
print("  - When stable for 2 consecutive saves: snap mags to FP16 + freeze")
print("  - After freeze: only scale_A, scale_B trained (fewer params)")
print()
print("NOTE: Using --dtype fp32 to avoid BF16 rounding errors in mags.")
print("      AutoSnap will snap to FP16 when stable, then freeze.")
print("=" * 70)

In [None]:
# [CELL 15: GC cleanup]
import gc
gc.collect()
torch.cuda.empty_cache()

In [None]:
# [CELL 16: E2E weight tuning (optional)]
# ============================================================
# V2 OPTIONAL: E2E GENTLE WEIGHT + SCALE TUNING (Stage 2)
# ============================================================
# Only run this if scale-only training isn't sufficient
# Requires clearing _Q to recompute on-the-fly

ENABLE_E2E_WEIGHT_TUNING = False  # Set to True to enable

if ENABLE_E2E_WEIGHT_TUNING:
    from qat_lora import unfreeze_model_for_training_v2
    
    # Clear cached weights and frozen Q
    unfreeze_model_for_training_v2(model)
    
    # Unfreeze weights and clear Q
    for name, module in model.named_modules():
        if isinstance(module, AnemllQATLinearV2):
            module.weight.requires_grad = True
            module._Q = None  # Will compute Q on-the-fly
    
    print('V2 E2E weight training (Stage 2)...')
    print('WARNING: Q will be recomputed on-the-fly for each forward pass')
    
    e2e_weights_result = train_e2e(
        model=model,
        cache_dir=cache_local_path,
        device=DEVICE,
        max_steps=1000,
        batch_size=64 if torch.cuda.is_available() else 32,
        lr=1e-5,  # Very small LR for weights
        use_cosine_schedule=True,
        warmup_steps=100,
        min_lr_ratio=0.1,
        temperature=DISTILL_TEMP,
        train_weights=True,
        train_scales=True,  # Also train scales
        hard_top1_weight=0.2,
        hard_full_weight=0.0,
        logging_steps=50,
        eval_steps=500,
        verbose=True,
    )
    
    # Re-freeze Q after training
    print('Re-freezing Q...')
    freeze_Q_all(model, verbose=False)
else:
    print('E2E weight tuning disabled. Using scale-only training.')
    print('Set ENABLE_E2E_WEIGHT_TUNING = True above to enable.')
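
With `_Q = None`, each forward pass has to re-derive the quantized values from the current weights. A common rule is nearest-codebook assignment; the sketch below assumes that rule and a hypothetical uniform 16-entry LUT purely for illustration. The actual `AnemllQATLinearV2` index computation (including any normalization by the scales before lookup) is not shown in this notebook.

```python
def nearest_lut_indices(weights, lut):
    """For each weight, the index of the closest LUT entry (ties -> lowest index)."""
    return [min(range(len(lut)), key=lambda k: abs(w - lut[k])) for w in weights]


def quantize(weights, lut):
    """Q = LUT[indices]: the values a frozen _Q would otherwise cache."""
    return [lut[i] for i in nearest_lut_indices(weights, lut)]


# Hypothetical 4-bit LUT: 16 evenly spaced levels in [-1, 1] (assumption).
LUT4 = [-1.0 + 2.0 * k / 15 for k in range(16)]
```

Repeating this 16-way search for every weight on every forward pass is the extra cost that freezing Q avoids.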

In [None]:
# [CELL 17: E2E MLP refinement]
# ============================================================
# V2 E2E: ADDITIONAL SCALE REFINEMENT (MLP ONLY)
# ============================================================
# Continue scale training for MLP layers

e2e_mlp_scales_result = train_e2e(
    model=model,
    cache_dir=cache_local_path,
    device=DEVICE,
    max_steps=1000,
    batch_size=64 if torch.cuda.is_available() else 32,
    lr=1e-4,  # Lower LR for refinement
    use_cosine_schedule=True,
    warmup_steps=100,
    min_lr_ratio=0.1,
    temperature=DISTILL_TEMP,
    train_weights=False,
    train_scales=True,
    hard_top1_weight=0.0,
    hard_full_weight=0.0,
    logging_steps=50,
    eval_steps=500,
    verbose=True,
    train_mlp_only=True,
)
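
The `use_cosine_schedule`, `warmup_steps`, and `min_lr_ratio` arguments suggest a linear-warmup-then-cosine-decay schedule. The function below sketches that common schedule, assuming `train_e2e` decays to `min_lr_ratio * lr` at `max_steps`; the exact warmup shape inside `train_e2e` may differ, and `lr_at` is a hypothetical helper.

```python
import math


def lr_at(step, base_lr, max_steps, warmup_steps, min_lr_ratio):
    """Linear warmup to base_lr, then cosine decay to base_lr * min_lr_ratio."""
    if step < warmup_steps:
        return base_lr * step / warmup_steps
    progress = min(1.0, (step - warmup_steps) / max(1, max_steps - warmup_steps))
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
    return base_lr * (min_lr_ratio + (1.0 - min_lr_ratio) * cosine)
```

With CELL 17's settings (lr=1e-4, 100 warmup steps, 1000 steps, min ratio 0.1) this gives 1e-4 at step 100, 5.5e-5 at step 550, and 1e-5 at step 1000.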

# **V2 ATTENTION SCALE TRAINING**

Train attention layer scales (Q, K, V, O projections) while keeping MLP frozen.

In [None]:
# [CELL 18: Phase B - Attention training (Python API)]
# ============================================================
# STEP 2: ATTENTION SCALE TRAINING (MLP FROZEN)
# ============================================================
# NOTE: The CLI has no --attn-only flag, so this step uses the
# Python API (train_e2e) with an explicit attention-only config.
#
# Alternative CLI approach (trains MLP and attention scales; only MLP mags frozen):
#   python scripts/train_v2_simple.py \
#       --v2-checkpoint {MLP_BEST_CKPT} \
#       --cache-dir {cache_local_path} \
#       --output-dir {ATTN_SAVE_DIR} \
#       --freeze-mags-mlp \
#       --max-steps 2000 --lr 1e-4 --dtype fp32

from qat_lora import train_e2e, evaluate_kd_loss

# Load best MLP checkpoint
MLP_BEST_CKPT = f'{MLP_SAVE_DIR}/best_state_dict.pt'

print("=" * 70)
print("STEP 2: ATTENTION SCALE TRAINING")
print("=" * 70)

# Check if MLP checkpoint exists
import os
if os.path.exists(MLP_BEST_CKPT):
    print(f"Loading MLP checkpoint: {MLP_BEST_CKPT}")
    model.load_state_dict(torch.load(MLP_BEST_CKPT, map_location=DEVICE))
    print("Loaded successfully.")
else:
    print(f"WARNING: MLP checkpoint not found at {MLP_BEST_CKPT}")
    print("Make sure to run Step 1 (MLP training) first!")

# Freeze MLP layers, enable attention layers only
print("\nFreezing MLP layers, enabling attention only...")
for name, module in model.named_modules():
    if type(module).__name__ == 'AnemllQATLinearV2':
        is_mlp = any(p in name for p in ['gate_proj', 'up_proj', 'down_proj'])
        is_attn = any(p in name for p in ['q_proj', 'k_proj', 'v_proj', 'o_proj'])
        
        if is_mlp:
            # Freeze ALL MLP params
            module.scale_A.requires_grad = False
            module.scale_B.requires_grad = False
            module.rank_magnitude.requires_grad = False
        elif is_attn:
            # Train attention scales
            module.scale_A.requires_grad = True
            module.scale_B.requires_grad = True
            module.rank_magnitude.requires_grad = True

# Count trainable params
attn_trainable = sum(p.numel() for n, p in model.named_parameters() 
                     if p.requires_grad and any(x in n for x in ['q_proj', 'k_proj', 'v_proj', 'o_proj']))
mlp_trainable = sum(p.numel() for n, p in model.named_parameters()
                    if p.requires_grad and any(x in n for x in ['gate_proj', 'up_proj', 'down_proj']))
print(f"Trainable attention params: {attn_trainable:,}")
print(f"Trainable MLP params: {mlp_trainable:,} (should be 0)")

# Eval before attention training
loss_before = evaluate_kd_loss(model, cache_local_path, DEVICE, num_samples=40, temperature=DISTILL_TEMP)
print(f"\nKD Loss before attention training: {loss_before:.4f}")

# Train attention scales
ATTN_RUN_NAME = f'SR-011_{QUAL}_attn'
ATTN_SAVE_DIR = f'{LOCAL_RUNS}/{ATTN_RUN_NAME}'
os.makedirs(ATTN_SAVE_DIR, exist_ok=True)

print("\nTraining attention scales...")
attn_result = train_e2e(
    model=model,
    cache_dir=cache_local_path,
    device=DEVICE,
    max_steps=2000,
    batch_size=BATCH_SIZE,
    lr=1e-4,
    use_cosine_schedule=True,
    warmup_steps=100,
    min_lr_ratio=0.1,
    temperature=DISTILL_TEMP,
    train_weights=False,
    train_scales=True,
    hard_top1_weight=0.0,
    hard_full_weight=0.0,
    logging_steps=50,
    eval_steps=100,
    verbose=True,
)

# Save attention checkpoint
torch.save(model.state_dict(), f'{ATTN_SAVE_DIR}/best_state_dict.pt')
print(f"\nSaved attention checkpoint to {ATTN_SAVE_DIR}/best_state_dict.pt")

# **V2 JOINT MLP + ATTENTION FINE-TUNING**

After separate MLP and attention training, fine-tune both together for final polish.

**Training Strategy:**
- Stage 1: MLP scales only
- Stage 2: Attention scales only (MLP frozen)
- **Stage 3: Joint MLP + Attention** (CELL 19), so the two sets of scales co-adapt
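
The three stages in the list above differ only in which V2 layers have trainable `scale_A`/`scale_B`/`rank_magnitude`. The sketch below expresses that as a name-based mask, reusing the projection-name substrings from CELL 18. The `STAGE_GROUPS` mapping and helper names are hypothetical; in the notebook the flags are set directly on each module's parameters.

```python
MLP_KEYS = ('gate_proj', 'up_proj', 'down_proj')
ATTN_KEYS = ('q_proj', 'k_proj', 'v_proj', 'o_proj')

# Hypothetical mapping: training stage -> layer groups whose scales train.
STAGE_GROUPS = {1: {'mlp'}, 2: {'attn'}, 3: {'mlp', 'attn'}}


def layer_group(name):
    """Classify a module name as 'mlp', 'attn', or None (any other layer)."""
    if any(k in name for k in MLP_KEYS):
        return 'mlp'
    if any(k in name for k in ATTN_KEYS):
        return 'attn'
    return None


def trainable_mask(module_names, stage):
    """Map each module name to whether its scales should train in `stage`."""
    groups = STAGE_GROUPS[stage]
    return {n: layer_group(n) in groups for n in module_names}
```

Checking the mask before training makes it easy to confirm that stage 2 leaves no MLP module trainable, which is what CELL 18 reports as `Trainable MLP params`.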

In [None]:
# [CELL 19: Joint MLP + Attention fine-tuning]
# ============================================================
# V2 E2E: JOINT MLP + ATTENTION FINE-TUNING
# ============================================================
# Train both MLP and attention scales together for final polish
# This helps the scales co-adapt after separate training

from qat_lora import unfreeze_model_for_training_v2

unfreeze_model_for_training_v2(model)

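# Enable ALL scales (both MLP and attention)
# NOTE: this also re-enables rank_magnitude, including the MLP mags that AutoSnap
# snapped and froze in Step 1. Set it to False below to keep them frozen.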
for name, module in model.named_modules():
    if isinstance(module, AnemllQATLinearV2):
        if hasattr(module, 'scale_A') and module.scale_A is not None:
            module.scale_A.requires_grad = True
            module.scale_B.requires_grad = True
            module.rank_magnitude.requires_grad = True
        module.weight.requires_grad = False  # Keep weights frozen

trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f'Trainable params (joint MLP+Attn): {trainable:,}')

# Joint training with lower LR for fine-tuning
e2e_joint_result = train_e2e(
    model=model,
    cache_dir=cache_local_path,
    device=DEVICE,
    max_steps=1000,  # Shorter since already trained
    batch_size=32,
    lr=5e-5,         # Lower LR for fine-tuning
    use_cosine_schedule=True,
    warmup_steps=50,
    min_lr_ratio=0.1,
    temperature=DISTILL_TEMP,
    train_weights=False,
    train_scales=True,
    hard_top1_weight=0.0,
    hard_full_weight=0.0,
    logging_steps=20,
    eval_steps=100,
    verbose=True,
    train_mlp_only=False,
)

In [None]:
# [CELL 20: Save final checkpoint + upload]
# ============================================================
# SAVE FINAL V2 CHECKPOINT
# ============================================================

from qat_lora import unfreeze_model_for_training_v2

unfreeze_model_for_training_v2(model)

E2E_RUN_NAME = f'SR-011_{QUAL}_e2e'
E2E_SAVE_DIR = f'{LOCAL_RUNS}/{E2E_RUN_NAME}'

# Save with config
config = {
    'model_id': MODEL_ID,
    'version': 'v2',
    'lut_size': LUT_SIZE,
    'scale_rank': SCALE_RANK,
    'attn_lut_size': ATTN_LUT_SIZE,
    'attn_scale_rank': ATTN_SCALE_RANK,
    'e2e_scales_result': e2e_scales_result,
}

save_checkpoint(model, E2E_SAVE_DIR, config=config)

# Upload to Google Drive
!tar -czvf {E2E_RUN_NAME}.tgz -C {LOCAL_RUNS} {E2E_RUN_NAME}
!cp {E2E_RUN_NAME}.tgz {GD_RUNS}/
print(f'\nUploaded to {GD_RUNS}/{E2E_RUN_NAME}.tgz')
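
The two `!` lines depend on IPython shell magics. Where those are unavailable (for example, in a plain Python script), a standard-library equivalent using `tarfile` and `shutil` is sketched below; `archive_run` is a hypothetical helper that, like the shell version, writes the archive to the current working directory before copying it.

```python
import os
import shutil
import tarfile


def archive_run(local_runs, run_name, dest_dir):
    """Mirror `tar -czvf {run}.tgz -C {local_runs} {run}` followed by `cp`.

    Returns the path of the copied archive inside dest_dir.
    """
    archive = f'{run_name}.tgz'
    with tarfile.open(archive, 'w:gz') as tar:
        # arcname keeps paths relative to local_runs, like `tar -C`.
        tar.add(os.path.join(local_runs, run_name), arcname=run_name)
    os.makedirs(dest_dir, exist_ok=True)
    return shutil.copy(archive, dest_dir)
```

For example, `archive_run(LOCAL_RUNS, E2E_RUN_NAME, GD_RUNS)` would replace both shell lines.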

# **INFERENCE OPTIMIZATION**

Before running inference, freeze all layers to precompute quantized weights.
This avoids recomputing `LUT[indices] * (scale_A @ scale_B)` on every forward pass.
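
To make the cached quantity concrete, the sketch below evaluates the formula exactly as written, `W_eff = LUT[indices] * (scale_A @ scale_B)`, with `scale_A` of shape [out, r] and `scale_B` of shape [r, in], and caches it the way a frozen layer would. It uses nested lists for readability. `TinyV2Weight` is hypothetical, and it omits `rank_magnitude` and any normalization because the formula above does not show where they enter.

```python
def matmul(a, b):
    """Plain-Python matrix product of nested lists: [m, k] @ [k, n] -> [m, n]."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


class TinyV2Weight:
    """Hypothetical stand-in for a V2 layer's effective-weight path."""

    def __init__(self, lut, indices, scale_A, scale_B):
        self.lut, self.indices = lut, indices
        self.scale_A, self.scale_B = scale_A, scale_B
        self._W_eff = None  # filled by freeze(), cleared by unfreeze()

    def compute(self):
        scales = matmul(self.scale_A, self.scale_B)  # [out, in] rank-r product
        return [[self.lut[idx] * s for idx, s in zip(row_idx, row_s)]
                for row_idx, row_s in zip(self.indices, scales)]

    def freeze(self):
        self._W_eff = self.compute()  # pay the cost once

    def unfreeze(self):
        self._W_eff = None  # scales may change again during training

    def weight(self):
        return self._W_eff if self._W_eff is not None else self.compute()
```

Once frozen, edits to the scales are not reflected until the cache is cleared, which is why CELL 21 notes that `unfreeze_model_for_training_v2` must be called before resuming training.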

In [None]:
# [CELL 21: Freeze model for inference]
# ============================================================
# V2 FREEZE MODEL FOR FAST INFERENCE
# ============================================================
# Precompute full W_eff = Q * scales for all layers
# This caches the effective weight to avoid rank-by-rank computation per token

from qat_lora import freeze_model_for_inference_v2, unfreeze_model_for_training_v2

print('Freezing V2 model for inference...')
num_frozen = freeze_model_for_inference_v2(model, verbose=False)
print(f'Frozen {num_frozen} V2 layers')
print('Cached W_eff = Q * scales for fast inference')

# To resume training later:
# unfreeze_model_for_training_v2(model)

In [None]:
# [CELL 22: Test inference (greedy)]
import torch

# ============================================================
# TEST INFERENCE
# ============================================================

def run_inference(model, tokenizer, prompt, max_new_tokens=128):
    messages = [
        {'role': 'system', 'content': 'You are a helpful assistant.'},
        {'role': 'user', 'content': prompt}
    ]
    text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(text, return_tensors='pt').to(DEVICE)

    with torch.no_grad():
        output = model.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False)

    return tokenizer.decode(output[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)

# List of prompts to test
prompts = [
    'What is the capital of France?',
    'What is Apple Neural Engine?',
    'Explain quantum mechanics',
    'What is speed of light'
]

model.eval() # Set model to evaluation mode once

for prompt in prompts:
    response = run_inference(model, tokenizer, prompt, max_new_tokens=1024)
    print(f'Prompt: {prompt}')
    print(f'Response: {response}')
    print('-' * 50) # Separator for readability

In [None]:
# [CELL 23: Test inference (sampling)]
# ============================================================
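# TEST INFERENCE (SAMPLING, THINKING ENABLED)
# Redefines run_inference; later cells (e.g. CELL 25) use this sampling version.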
# ============================================================

def run_inference(model, tokenizer, prompt, max_new_tokens=512):
    messages = [
        {'role': 'user', 'content': prompt}
    ]
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
        enable_thinking=True
    )
    inputs = tokenizer(text, return_tensors='pt').to(DEVICE)

    with torch.no_grad():
        output = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=0.6,
            top_p=0.9,
            repetition_penalty=1.1,
        )

    # skip_special_tokens=False keeps special/template tokens visible in the decoded text
    return tokenizer.decode(output[0][inputs['input_ids'].shape[1]:], skip_special_tokens=False)

# List of prompts to test
prompts = [
    'What is the capital of France?',
    'What is Apple Neural Engine?',
    'Explain quantum mechanics',
    'What is speed of light'
]

model.eval()

for prompt in prompts:
    response = run_inference(model, tokenizer, prompt, max_new_tokens=512)
    print(f'Prompt: {prompt}')
    print(f'Response: {response}')
    print('-' * 50)
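
CELL 23 delegates sampling to `model.generate`. The sketch below shows, on a single list of logits, what its three knobs do under their standard definitions: a repetition penalty (positive logits of already-seen tokens are divided by the penalty, negative ones multiplied), temperature scaling, and top-p (nucleus) filtering that keeps the smallest highest-probability set whose cumulative probability reaches `top_p`. It illustrates the technique and is not the Transformers source; the helper names, order of operations, and tie-breaking are assumptions.

```python
import math
import random


def apply_repetition_penalty(logits, seen_ids, penalty):
    """Penalize tokens that already appear in the sequence (prompt included)."""
    out = list(logits)
    for t in set(seen_ids):
        out[t] = out[t] / penalty if out[t] > 0 else out[t] * penalty
    return out


def softmax(logits, temperature=1.0):
    scaled = [x / temperature for x in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def top_p_keep(probs, top_p):
    """Token ids of the smallest highest-probability set with cumulative prob >= top_p."""
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    kept, cum = [], 0.0
    for i in order:
        kept.append(i)
        cum += probs[i]
        if cum >= top_p:
            break
    return kept


def sample_next(logits, seen_ids, temperature=0.6, top_p=0.9, penalty=1.1, rng=random):
    probs = softmax(apply_repetition_penalty(logits, seen_ids, penalty), temperature)
    kept = top_p_keep(probs, top_p)
    # random.choices accepts unnormalized weights, so no explicit renormalization.
    return rng.choices(kept, weights=[probs[i] for i in kept], k=1)[0]
```

Lower temperature sharpens the distribution before top-p is applied, so at 0.6 the nucleus usually contains fewer tokens than it would at 1.0.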

## Next Steps

After E2E training, you can:

1. **Additional refinement** - Continue training with lower LR
2. **LoRA recovery** - Add LoRA adapters to recover quality
3. **Export for ANE** - Use `snap_for_export()` to bake normalization into the parameters

# **V2 EXPORT FOR ANEMLL CONVERTER**

Export the V2 quantized representation for the ANEMLL converter and other external tools.

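**V2 Export includes** (per `AnemllQATLinearV2` layer):
- `Q` (from `_Q`): frozen quantized values, i.e. `lut[indices]`, shape [out, in]
- `indices` (from `_indices`): quantization indices into the LUT, shape [out, in]
- `scale_A`, `scale_B`, `rank_magnitude`: low-rank scale parameters
- `lut`: the LUT itself (16 entries for 4-bit)
- `bias` and layer metadata (`in_features`, `out_features`, `scale_rank`, `lut_bits`)

**For inference optimization:**
- Call `snap_for_export()` to bake normalization into the parameters
- This eliminates runtime norm computation in the exported model
- CELL 24 saves the parameters as they are when it runs and CELL 25 snaps afterwards, so run the snap first if the export file should contain snapped values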
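
The notebook does not show which normalization `snap_for_export()` bakes in. As a hedged illustration only, assume the training-time path divides each rank column of `scale_A` by its L2 norm before forming `scale_A @ diag(rank_magnitude) @ scale_B`. Under that assumption, baking means dividing once at export time and storing the result in `scale_A`, so the exported product matches without any runtime norm computation. All function names here are hypothetical.

```python
import math


def column_norms(A):
    """L2 norm of each of the r columns of A ([out, r] nested lists)."""
    return [math.sqrt(sum(row[k] ** 2 for row in A)) for k in range(len(A[0]))]


def scaled_product(A, mags, B, normalize):
    """A @ diag(mags) @ B, optionally normalizing A's columns first (assumed runtime path)."""
    norms = column_norms(A) if normalize else [1.0] * len(mags)
    r = len(mags)
    return [[sum(A[i][k] / norms[k] * mags[k] * B[k][j] for k in range(r))
             for j in range(len(B[0]))] for i in range(len(A))]


def bake_normalization(A):
    """Hypothetical snap: divide each column of A by its norm once, at export time."""
    norms = column_norms(A)
    return [[row[k] / norms[k] for k in range(len(row))] for row in A]
```

After baking, the exported layer can use `scaled_product(..., normalize=False)` and still reproduce the training-time scales.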

In [None]:
# [CELL 24: Export for ANEMLL converter]
# ============================================================
# V2 EXPORT FOR ANEMLL CONVERTER
# ============================================================

import os

from qat_lora import unfreeze_model_for_training_v2

# First unfreeze to clear cached weights
unfreeze_model_for_training_v2(model)

# Export V2 model representation
print('Exporting V2 quantized model representation...')

export_dict = {}
for name, module in model.named_modules():
    if isinstance(module, AnemllQATLinearV2):
        layer_export = {
            'indices': module._indices.cpu() if module._indices is not None else None,
            'Q': module._Q.cpu() if module._Q is not None else None,
            'scale_A': module.scale_A.data.cpu(),
            'scale_B': module.scale_B.data.cpu(),
            'rank_magnitude': module.rank_magnitude.data.cpu(),
            'lut': module.lut.cpu(),
            'bias': module.bias.data.cpu() if module.bias is not None else None,
            'in_features': module.in_features,
            'out_features': module.out_features,
            'scale_rank': module.scale_rank,
            'lut_bits': module.lut_bits,
        }
        export_dict[name] = layer_export

print(f'Exported {len(export_dict)} V2 layers')

# Save export for ANEMLL converter
EXPORT_DIR = f'{LOCAL_RUNS}/{E2E_RUN_NAME}_export'
os.makedirs(EXPORT_DIR, exist_ok=True)
torch.save(export_dict, f'{EXPORT_DIR}/v2_quantized_model.pt')
print(f'\nSaved V2 export to {EXPORT_DIR}/v2_quantized_model.pt')

In [None]:
# [CELL 25: Snap for export + test]
# ============================================================
# V2 SNAP FOR EXPORT AND TEST
# ============================================================
# Bake normalization into parameters for CoreML-friendly export

print('Snapping V2 model for export (baking normalization)...')

for name, module in model.named_modules():
    if isinstance(module, AnemllQATLinearV2):
        module.snap_for_export()

print('V2 model snapped for export.')
print('Normalization baked into scale_A (no runtime norm computation)')

# Test inference with snapped model
model.eval()
print('\nTesting V2 inference with snapped weights...')
response = run_inference(model, tokenizer, 'What is 2+2?', max_new_tokens=256)
print('Prompt: What is 2+2?')
print(f'Response: {response}')

In [None]:
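# [CELL 26: Backup checkpoint]
# Scratch backups to local /tmp (fast). Both lines save the CURRENT model state;
# the suffix in each filename is only a label, so keep just the line you need.
torch.save(model.state_dict(), '/tmp/backup_mlp_e2e_0.4613.pt')  # Local, fast

torch.save(model.state_dict(), '/tmp/backup_mlp_e2e_w_0.3824.pt')  # Local, fast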

In [None]:
# [CELL 27: Backup checkpoint]
torch.save(model.state_dict(), '/tmp/backup_mlp_e4e_4_4.pt')  # Local, fast

In [None]:
# [CELL 28: Load backup]
model.load_state_dict(torch.load('/tmp/backup_initial.pt', map_location=DEVICE))

In [None]:
# [CELL 29: Load backup]
model.load_state_dict(torch.load('/tmp/backup_mlp_e2e_w_0.3824.pt', map_location=DEVICE))