# Debug LUT Training Experiment

**Rev: 1.0 (2026-01-17)**

## Goal
Investigate why LUT + scales training appears to change `_Q` tensors unexpectedly (observed when comparing the srLUT-004c output against `v2_initial.pt`).

## Hypothesis
During training, the `_Q` buffer should NOT change because:
1. `_Q` is a buffer (not a parameter), so it has no gradients and is not updated by the optimizer
2. When LUT training is enabled, the forward pass uses `lut[_indices]` and bypasses `_Q`
3. `_Q` is only written at init (`freeze_Q()`) and by explicit calls to `sync_q_from_indices()` or baking (`bake_lut.py`)
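The hypothesis can be illustrated with a toy model of one quantized layer, written in pure Python. This is a minimal sketch and not the repo's code. The class `ToyLUTLayer`, its fields, and the `train_lut` flag are hypothetical stand-ins for the real module's `lut`, `_indices`, and `_Q`. Only `sync_q_from_indices()` mirrors a name used in this notebook. The sketch shows that an optimizer-style update to the LUT changes the effective weights, while `_Q` stays frozen until an explicit sync.

```python
class ToyLUTLayer:
    """Hypothetical stand-in for a LUT-quantized layer (not the repo's class)."""

    def __init__(self, lut, indices):
        self.lut = list(lut)                              # trainable, like a parameter
        self._indices = list(indices)                     # frozen codebook indices
        self._Q = [self.lut[i] for i in self._indices]    # frozen buffer, filled once at init

    def effective_weights(self, train_lut):
        # With LUT training on, the forward path gathers lut[_indices] and ignores _Q.
        if train_lut:
            return [self.lut[i] for i in self._indices]
        return list(self._Q)

    def sgd_step_on_lut(self, grads, lr):
        # Only the LUT is handed to the "optimizer"; _Q receives no update.
        self.lut = [v - lr * g for v, g in zip(self.lut, grads)]

    def sync_q_from_indices(self):
        # Explicit refresh, as baking does: _Q <- lut[_indices].
        self._Q = [self.lut[i] for i in self._indices]


layer = ToyLUTLayer(lut=[-1.5, -0.5, 0.5, 1.5], indices=[0, 3, 2, 2, 1])
q_before = list(layer._Q)
layer.sgd_step_on_lut(grads=[0.25, 0.0, -0.5, 0.0], lr=1.0)

print(layer._Q == q_before)                      # True: the step left _Q alone
print(layer.effective_weights(train_lut=True))   # [-1.75, 1.5, 1.0, 1.0, -0.5]
layer.sync_q_from_indices()
print(layer._Q == layer.effective_weights(train_lut=True))  # True after the sync
```

If `_Q` differs between two checkpoints, this model points to an explicit sync or bake between them, not to the training step itself.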

## Experiment Plan
1. Snapshot initial checkpoint tensor hashes
2. Run short training (20 steps) with LUT + scales
3. Compare tensors: what actually changed?
4. Verify `_Q` buffers are unchanged
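Steps 1 and 3 of the plan amount to hashing every entry of a state dict and diffing the two snapshots. The notebook itself hashes tensor bytes in cell 6 and compares tensors with `torch.equal` in cell 8b. This minimal sketch works in pure Python and assumes each entry has already been serialized to `bytes`. The helper names `snapshot` and `diff_snapshots` are hypothetical.

```python
import hashlib


def snapshot(state):
    """Map each key to an MD5 digest of its serialized bytes."""
    return {k: hashlib.md5(v).hexdigest() for k, v in state.items()}


def diff_snapshots(before, after):
    """Split keys into added, removed, changed, and unchanged lists."""
    common = before.keys() & after.keys()
    return {
        "added": sorted(after.keys() - before.keys()),
        "removed": sorted(before.keys() - after.keys()),
        "changed": sorted(k for k in common if before[k] != after[k]),
        "unchanged": sorted(k for k in common if before[k] == after[k]),
    }


# Toy state dicts: byte strings stand in for serialized tensors.
initial = {"layer0._Q": b"\x01\x02", "layer0.lut": b"\x10\x20"}
trained = {"layer0._Q": b"\x01\x02", "layer0.lut": b"\x11\x20",
           "layer0._lut_raw_deltas": b"\x00"}

report = diff_snapshots(snapshot(initial), snapshot(trained))
print(report)
```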

## Expected Results

| Tensor Type | Should Change? | Why |
|-------------|----------------|-----|
| `.lut` | YES | `--train-lut` enables LUT training |
| `_lut_raw_deltas` | YES | LUT training uses delta parameterization |
| `scale_A`, `scale_B` | YES | `train_scales=True` (hardcoded) |
| `rank_magnitude` | YES | Part of scale training |
| `_Q` | **NO** | Buffer, bypassed when LUT training enabled |
| `_indices` | NO | Frozen at init |
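The table can be encoded as key-name rules that reuse the substring patterns the notebook applies in cells 6 and 9: `.lut` but not `_lut`, `._Q`, `._indices`, `scale_A`/`scale_B`/`rank_magnitude`, and `_lut_raw_deltas`. This is a sketch. The rule order and the `EXPECT_CHANGE` mapping are this sketch's encoding of the table above, not code from the repo, and the example key names are hypothetical.

```python
# Expected behavior per category, transcribed from the table above.
EXPECT_CHANGE = {
    "lut": True,
    "lut_raw_deltas": True,
    "scale": True,
    "Q": False,
    "indices": False,
}


def categorize(key):
    """Return the category of a state-dict key using the notebook's substring rules."""
    if "_lut_raw_deltas" in key:
        return "lut_raw_deltas"
    if ".lut" in key and "_lut" not in key:
        return "lut"
    if "._Q" in key:
        return "Q"
    if "._indices" in key:
        return "indices"
    if "scale_A" in key or "scale_B" in key or "rank_magnitude" in key:
        return "scale"
    return "other"


def unexpected(changed_keys):
    """Changed keys whose category should NOT change, or that match no category."""
    return [k for k in changed_keys if not EXPECT_CHANGE.get(categorize(k), False)]


changed = ["model.layers.0.mlp.up_proj.lut",
           "model.layers.0.mlp.up_proj.scale_A",
           "model.layers.0.mlp.up_proj._Q"]
print(unexpected(changed))
```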

## Cell Index

| Cell | Purpose |
|------|---------|
| 1 | Google Drive paths |
| 2 | Mount Google Drive |
| 3 | Clone/update repo |
| 4 | Install dependencies |
| 5 | Configuration |
| 5b | Download KD cache (if not already present) |
| 6 | Snapshot initial checkpoint |
| 7, 7b | Run short training |
| 8, 8b | Load trained checkpoint; compare all tensors |
| 9 | Categorize changes |
| 10 | Analyze `_Q` changes (key result) |
| 11 | Verify LUT training worked |
| 12 | Conclusion |
| 13-14 | Bonus: compare srLUT-004c input vs output; trace the source of `_Q` changes |

In [None]:
# [CELL 1: Google Drive paths]
# ============================================================
# GOOGLE DRIVE PATHS (STANDARD)
# ============================================================

# Checkpoints/runs go here
GD_RUNS = '/content/drive/MyDrive/qwen3_runs'

# KD caches go here
GD_CACHES = '/content/drive/MyDrive/qwen3_caches'

# Local directories (on Colab VM)
LOCAL_RUNS = 'runs'
LOCAL_CACHES = 'caches'

In [None]:
# [CELL 2: Mount Google Drive]
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# [CELL 3: Clone/update repo]
# ============================================================
# CLONE/UPDATE REPO (re-run safe)
# ============================================================

import os

REPO_DIR = '/content/qwen3_apple_style_2bit_qat_lora'

if not os.path.exists(REPO_DIR):
    print(f"Cloning repo to {REPO_DIR}...")
    !git clone https://github.com/anemll/qwen3_apple_style_2bit_qat_lora.git {REPO_DIR}
else:
    print(f"Repo exists at {REPO_DIR}, updating...")

%cd {REPO_DIR}
!git fetch && git pull

# Clear cached imports so the freshly pulled qat_lora code is re-imported
import sys
for mod_name in [k for k in sys.modules if k.startswith('qat_lora')]:
    del sys.modules[mod_name]

print(f"\nWorking directory: {os.getcwd()}")
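The `startswith('qat_lora')` filter in cell 3 would also drop unrelated modules whose names merely share the prefix, such as a hypothetical `qat_lora_utils`. This slightly stricter sketch removes only the package and its submodules. The helper name `purge_package` is hypothetical, and the demo uses fake modules so it runs anywhere.

```python
import sys
import types


def purge_package(name):
    """Remove `name` and `name.*` from sys.modules so the next import re-reads the source."""
    doomed = [m for m in sys.modules if m == name or m.startswith(name + ".")]
    for m in doomed:
        del sys.modules[m]
    return sorted(doomed)


# Demo with fake module objects.
for fake in ("qat_lora", "qat_lora.model", "qat_lora_utils"):
    sys.modules[fake] = types.ModuleType(fake)

print(purge_package("qat_lora"))
print("qat_lora_utils" in sys.modules)
```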

In [None]:
# [CELL 4: Install dependencies]
!pip install -q transformers accelerate

In [None]:
# [CELL 5: Configuration]
# ============================================================
# EXPERIMENT CONFIGURATION
# ============================================================

import os

# Initial checkpoint (the "truth" before any training)
INITIAL_CKPT = f'{GD_RUNS}/sr011-q4a4-FP4-init/v2_initial.pt'

# KD cache for training
CACHE_NAME = 'alpaca_chat_think_both_L128_K128_R1024'
CACHE_DIR = f'{LOCAL_CACHES}/{CACHE_NAME}'

# Output directory for this experiment
OUTPUT_DIR = f'{LOCAL_RUNS}/debug_lut_training'

# Training config
MAX_STEPS = 20      # Short run to verify behavior
BATCH_SIZE = 2
ACCUMULATION = 1
WARMUP_STEPS = 5
LR = 2e-5
LUT_LR = 1e-5
LUT_MAX_ABS = 2.0

# Verify paths
print(f"=== Debug LUT Training Experiment ===")
print(f"")
print(f"Initial checkpoint: {INITIAL_CKPT}")
print(f"  Exists: {os.path.exists(INITIAL_CKPT)}")
print(f"")
print(f"Cache dir: {CACHE_DIR}")
print(f"  Exists: {os.path.exists(CACHE_DIR)}")
print(f"")
print(f"Output dir: {OUTPUT_DIR}")

In [None]:
# [CELL 5b: Download cache if needed]
# ============================================================
# DOWNLOAD KD CACHE (if not already present)
# ============================================================

if not os.path.exists(CACHE_DIR):
    print(f"Cache not found locally, copying from Google Drive...")
    !mkdir -p {LOCAL_CACHES}
    
    # Try tgz first
    tgz_path = f"{GD_CACHES}/{CACHE_NAME}.tgz"
    if os.path.exists(tgz_path):
        print(f"Extracting {tgz_path}...")
        !tar -xzf "{tgz_path}" -C {LOCAL_CACHES}
    else:
        # Try direct folder copy
        folder_path = f"{GD_CACHES}/{CACHE_NAME}"
        if os.path.exists(folder_path):
            print(f"Copying {folder_path}...")
            !cp -r "{folder_path}" {CACHE_DIR}
        else:
            print(f"ERROR: Cache not found in Google Drive!")
            print(f"  Tried: {tgz_path}")
            print(f"  Tried: {folder_path}")
else:
    print(f"Cache already exists: {CACHE_DIR}")

# Verify
if os.path.exists(CACHE_DIR):
    !ls -la {CACHE_DIR} | head -5
    print(f"...")
    print(f"{len(os.listdir(CACHE_DIR))} entries total")
else:
    print(f"ERROR: {CACHE_DIR} is still missing - training (cell 7b) will fail.")
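Cell 5b's fallback logic tries `{CACHE_NAME}.tgz` on Drive first, then a plain folder. The same logic can be written with the standard library instead of shell escapes, so a failure raises an error instead of only printing. This is a sketch. `ensure_cache` is a hypothetical helper, and it assumes the archive's top-level entry is the `CACHE_NAME` folder, which the `tar -xzf ... -C {LOCAL_CACHES}` call in cell 5b also relies on.

```python
import os
import shutil
import tarfile


def ensure_cache(local_caches, drive_caches, name):
    """Make `local_caches/name` exist; return how it was obtained."""
    dest = os.path.join(local_caches, name)
    if os.path.isdir(dest):
        return "present"
    os.makedirs(local_caches, exist_ok=True)

    tgz_path = os.path.join(drive_caches, name + ".tgz")
    folder_path = os.path.join(drive_caches, name)
    if os.path.isfile(tgz_path):
        with tarfile.open(tgz_path, "r:gz") as tar:
            # The 'data' filter rejects unsafe member paths; it exists in
            # Python 3.12+ and some patched earlier releases, so fall back if absent.
            if hasattr(tarfile, "data_filter"):
                tar.extractall(local_caches, filter="data")
            else:
                tar.extractall(local_caches)
        if not os.path.isdir(dest):
            raise RuntimeError(f"{tgz_path} has no top-level '{name}' folder")
        return "extracted"
    if os.path.isdir(folder_path):
        shutil.copytree(folder_path, dest)
        return "copied"
    raise FileNotFoundError(f"cache not found: tried {tgz_path} and {folder_path}")
```

In this notebook the call would be `ensure_cache(LOCAL_CACHES, GD_CACHES, CACHE_NAME)`.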

## Snapshot Initial Checkpoint

In [None]:
# [CELL 6: Snapshot initial checkpoint]
# ============================================================
# LOAD AND ANALYZE INITIAL CHECKPOINT
# ============================================================

import torch
import hashlib

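def tensor_hash(t):
    """Quick hash for tensor comparison (also works for bfloat16, which NumPy lacks)."""
    # Reinterpret the raw storage as bytes so any dtype can be hashed
    raw = t.detach().cpu().contiguous().reshape(-1).view(torch.uint8)
    return hashlib.md5(raw.numpy().tobytes()).hexdigest()[:16]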

print(f"Loading initial checkpoint: {INITIAL_CKPT}")
initial_sd = torch.load(INITIAL_CKPT, map_location='cpu', weights_only=False)

# Categorize keys
lut_keys = [k for k in initial_sd if '.lut' in k and '_lut' not in k]
q_keys = [k for k in initial_sd if '._Q' in k]
indices_keys = [k for k in initial_sd if '._indices' in k]
scale_keys = [k for k in initial_sd if 'scale_A' in k or 'scale_B' in k or 'rank_magnitude' in k]
delta_keys = [k for k in initial_sd if '_lut_raw_deltas' in k]

print(f"\n=== Initial Checkpoint Structure ===")
print(f"Total keys: {len(initial_sd)}")
print(f"")
print(f"  LUT keys (.lut):        {len(lut_keys)}")
print(f"  _Q keys:                {len(q_keys)}")
print(f"  _indices keys:          {len(indices_keys)}")
print(f"  Scale keys:             {len(scale_keys)}")
print(f"  _lut_raw_deltas keys:   {len(delta_keys)}")

# Save hashes for comparison
initial_q_hashes = {k: tensor_hash(initial_sd[k]) for k in q_keys}
initial_lut_hashes = {k: tensor_hash(initial_sd[k]) for k in lut_keys}

print(f"\nCaptured {len(initial_q_hashes)} _Q hashes")
print(f"Captured {len(initial_lut_hashes)} LUT hashes")

# Show sample
print(f"\nSample _Q tensor:")
sample_q = q_keys[0] if q_keys else None
if sample_q:
    t = initial_sd[sample_q]
    print(f"  {sample_q}")
    print(f"  shape: {t.shape}, dtype: {t.dtype}")
    print(f"  hash: {initial_q_hashes[sample_q]}")

## Run Short Training

In [None]:
# [CELL 7: Run short training]
# ============================================================
# RUN SHORT TRAINING (LUT + SCALES)
# ============================================================

# Clean output directory
!rm -rf {OUTPUT_DIR}
!mkdir -p {OUTPUT_DIR}

print(f"Starting training...")
print(f"  Checkpoint: {INITIAL_CKPT}")
print(f"  Cache: {CACHE_DIR}")
print(f"  Output: {OUTPUT_DIR}")
print(f"  Steps: {MAX_STEPS}")
print(f"")

In [None]:
%%time
# [CELL 7b: Execute training]
# Training command with LUT + scales

!python scripts/train_v2_simple.py \
    --tpu --mixed-precision \
    --config q4_r32 \
    --v2-checkpoint "{INITIAL_CKPT}" \
    --cache-dir {CACHE_DIR} \
    --output-dir {OUTPUT_DIR} \
    --max-steps {MAX_STEPS} \
    --batch-size {BATCH_SIZE} \
    --accumulation-steps {ACCUMULATION} \
    --warmup-steps {WARMUP_STEPS} \
    --train-lut --lut-scope all --lut-max-abs {LUT_MAX_ABS} --lut-lr {LUT_LR} \
    --lr {LR} \
    --hard-top1 0.0 --hard-full 0.0 \
    --save-steps {MAX_STEPS} \
    --eval-steps 0
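Building the command as an argument list and running it with `subprocess.run(..., check=True)` avoids shell-quoting problems in paths. It also makes a failed run stop the notebook instead of leaving an empty `OUTPUT_DIR`. This sketch reuses exactly the flags from cell 7b. `build_train_cmd` and `run_training` are hypothetical helpers, and the meaning of each flag is defined by `scripts/train_v2_simple.py`, not here.

```python
import subprocess
import sys


def build_train_cmd(ckpt, cache_dir, output_dir, max_steps, batch_size,
                    accumulation, warmup_steps, lr, lut_lr, lut_max_abs):
    """Return the cell 7b training invocation as an argv list."""
    return [
        sys.executable, "scripts/train_v2_simple.py",
        "--tpu", "--mixed-precision",
        "--config", "q4_r32",
        "--v2-checkpoint", ckpt,
        "--cache-dir", cache_dir,
        "--output-dir", output_dir,
        "--max-steps", str(max_steps),
        "--batch-size", str(batch_size),
        "--accumulation-steps", str(accumulation),
        "--warmup-steps", str(warmup_steps),
        "--train-lut", "--lut-scope", "all",
        "--lut-max-abs", str(lut_max_abs), "--lut-lr", str(lut_lr),
        "--lr", str(lr),
        "--hard-top1", "0.0", "--hard-full", "0.0",
        "--save-steps", str(max_steps),
        "--eval-steps", "0",
    ]


def run_training(cmd):
    # check=True raises CalledProcessError on a non-zero exit code.
    subprocess.run(cmd, check=True)
```

In this notebook the call would be `run_training(build_train_cmd(INITIAL_CKPT, CACHE_DIR, OUTPUT_DIR, MAX_STEPS, BATCH_SIZE, ACCUMULATION, WARMUP_STEPS, LR, LUT_LR, LUT_MAX_ABS))`, run from `REPO_DIR` like the original.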

## Compare Results

In [None]:
# [CELL 8: Load trained checkpoint and compare]
# ============================================================
# COMPARE INITIAL VS TRAINED
# ============================================================

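import glob
import re

def step_of(path):
    """Step number from checkpoint_step<N>.pt (or checkpoint_step_<N>.pt); -1 if absent."""
    m = re.search(r'checkpoint_step_?(\d+)', os.path.basename(path))
    return int(m.group(1)) if m else -1

# Prefer the highest-step checkpoint (the state after MAX_STEPS); glob order is arbitrary.
# best_state_dict.pt is only a fallback, since it may come from an earlier step.
step_ckpts = sorted(glob.glob(f"{OUTPUT_DIR}/checkpoint_step*.pt"), key=step_of, reverse=True)
ckpt_files = step_ckpts + glob.glob(f"{OUTPUT_DIR}/best_state_dict.pt")
print(f"Found checkpoints: {ckpt_files}")

if not ckpt_files:
    print("ERROR: No checkpoint found! Training may have failed.")
    trained_sd = None
else:
    trained_path = ckpt_files[0]
    print(f"\nLoading trained checkpoint: {trained_path}")
    trained_sd = torch.load(trained_path, map_location='cpu', weights_only=False)
    print(f"Loaded {len(trained_sd)} keys")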

In [None]:
# [CELL 8b: Compare all tensors]
# ============================================================
# TENSOR-BY-TENSOR COMPARISON
# ============================================================

if trained_sd:
    changed = []
    unchanged = []
    
    common_keys = set(initial_sd.keys()) & set(trained_sd.keys())
    only_initial = set(initial_sd.keys()) - set(trained_sd.keys())
    only_trained = set(trained_sd.keys()) - set(initial_sd.keys())
    
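    for k in sorted(common_keys):
        t1 = initial_sd[k]
        t2 = trained_sd[k]
        if not (torch.is_tensor(t1) and torch.is_tensor(t2)):
            continue  # skip non-tensor entries (e.g. metadata), if any
        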
        if t1.shape != t2.shape:
            changed.append((k, 'shape_mismatch', None, None))
        elif not torch.equal(t1, t2):
            diff = (t1.float() - t2.float()).abs()
            changed.append((k, diff.max().item(), diff.mean().item(), diff.sum().item()))
        else:
            unchanged.append(k)
    
    print(f"=== COMPARISON RESULTS ===")
    print(f"")
    print(f"Common keys:      {len(common_keys)}")
    print(f"Only in initial:  {len(only_initial)}")
    print(f"Only in trained:  {len(only_trained)}")
    print(f"")
    print(f"Unchanged:        {len(unchanged)} tensors")
    print(f"Changed:          {len(changed)} tensors")
    
    if only_trained:
        print(f"\nNew keys in trained (first 5):")
        for k in sorted(only_trained)[:5]:
            print(f"  {k}")

In [None]:
# [CELL 9: Categorize changes]
# ============================================================
# CATEGORIZE CHANGES BY TENSOR TYPE
# ============================================================

# Run even when nothing changed, so cells 10-12 always find these lists
if trained_sd:
    # Shape mismatches carry no numeric diff; report them separately
    shape_mismatches = [k for k, mx, mn, sm in changed if mx == 'shape_mismatch']
    valid_changes = [(k, mx, mn, sm) for k, mx, mn, sm in changed if mx != 'shape_mismatch']
    
    lut_changed = [x for x in valid_changes if '.lut' in x[0] and '_lut' not in x[0]]
    q_changed = [x for x in valid_changes if '._Q' in x[0]]
    scale_changed = [x for x in valid_changes if 'scale_A' in x[0] or 'scale_B' in x[0] or 'rank_magnitude' in x[0]]
    delta_changed = [x for x in valid_changes if '_lut_raw_deltas' in x[0]]
    indices_changed = [x for x in valid_changes if '._indices' in x[0]]
    # "Other" = every changed tensor not claimed by a category above
    categorized = {x[0] for x in lut_changed + q_changed + scale_changed + delta_changed + indices_changed}
    other_changed = [x for x in valid_changes if x[0] not in categorized]
    
    print(f"=== CHANGES BY CATEGORY ===")
    print(f"")
    print(f"EXPECTED CHANGES:")
    print(f"  LUT (.lut):           {len(lut_changed):3d}  {'✓' if lut_changed else '✗'}")
    print(f"  _lut_raw_deltas:      {len(delta_changed):3d}  {'✓' if delta_changed else '✗'}")
    print(f"  Scales (A/B/mag):     {len(scale_changed):3d}  {'✓' if scale_changed else '✗'}")
    print(f"")
    print(f"SHOULD NOT CHANGE:")
    print(f"  _Q buffers:           {len(q_changed):3d}  {'✗ BUG!' if q_changed else '✓ OK'}")
    print(f"  _indices:             {len(indices_changed):3d}  {'✗ BUG!' if indices_changed else '✓ OK'}")
    print(f"  Other:                {len(other_changed):3d}  {'?' if other_changed else '✓ OK'}")
    print(f"  Shape mismatches:     {len(shape_mismatches):3d}  {'?' if shape_mismatches else '✓ OK'}")

In [None]:
# [CELL 10: Analyze _Q changes - THE KEY RESULT]
# ============================================================
# DETAILED _Q ANALYSIS (THIS IS WHAT WE'RE TESTING)
# ============================================================

if trained_sd:
    if q_changed:
        print(f"!!! _Q BUFFERS CHANGED (UNEXPECTED) !!!")
        print(f"")
        print(f"Found {len(q_changed)} _Q tensors that changed during training.")
        print(f"This indicates a bug in the training code.")
        print(f"")
        print(f"Top 20 changes:")
        for k, mx, mn, sm in sorted(q_changed, key=lambda x: x[1], reverse=True)[:20]:
            print(f"  max={mx:.6e}  mean={mn:.6e}  {k}")
        
        # Analyze one in detail
        sample_key = q_changed[0][0]
        t1 = initial_sd[sample_key]
        t2 = trained_sd[sample_key]
        
        print(f"")
        print(f"=== Detailed analysis: {sample_key} ===")
        print(f"  Shape: {t1.shape}")
        print(f"  dtype: {t1.dtype} -> {t2.dtype}")
        
        diff = (t1.float() - t2.float())
        changed_mask = diff != 0
        n_changed = changed_mask.sum().item()
        
        print(f"  Elements changed: {n_changed} / {t1.numel()} ({100*n_changed/t1.numel():.2f}%)")
        if n_changed > 0 and n_changed < 100:
            print(f"  Change values: {diff[changed_mask].unique().tolist()}")
    else:
        print(f"=== _Q BUFFERS UNCHANGED (EXPECTED) ===")
        print(f"")
        print(f"All {len([k for k in unchanged if '._Q' in k])} _Q tensors remained unchanged.")
        print(f"")
        print(f"This is consistent with:")
        print(f"  - _Q being a buffer (not trained)")
        print(f"  - the forward pass using lut[_indices] when LUT training is enabled")
        print(f"  - _Q changes in production runs coming from BAKING, not training")

In [None]:
# [CELL 11: Verify LUT training worked]
# ============================================================
# VERIFY LUT TRAINING IS WORKING
# ============================================================

if trained_sd:
    print(f"=== LUT TRAINING VERIFICATION ===")
    print(f"")
    
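    # Check delta values
    delta_keys_trained = [k for k in trained_sd if '_lut_raw_deltas' in k]
    if delta_keys_trained:
        print(f"Found {len(delta_keys_trained)} _lut_raw_deltas tensors")
        
        # Count non-zero deltas across ALL tensors; print only the first 5
        nonzero_deltas = 0
        for k in delta_keys_trained:
            delta = trained_sd[k]
            if delta.abs().max() > 1e-8:
                nonzero_deltas += 1
                if nonzero_deltas <= 5:
                    print(f"")
                    print(f"  {k}:")
                    print(f"    values: {delta.tolist()}")
                    print(f"    max_abs: {delta.abs().max().item():.6f}")
        
        print(f"")
        print(f"{nonzero_deltas} / {len(delta_keys_trained)} delta tensors are non-zero")
        print(f"{len(delta_changed)} delta tensors differ from the initial checkpoint")
        # Non-zero alone is not proof of training: the deltas may start non-zero
        if delta_changed:
            print(f"✓ LUT deltas changed during training - LUT training is working!")
        elif nonzero_deltas > 0:
            print(f"✗ WARNING: deltas are non-zero but identical to the initial checkpoint - LUT may not be training!")
        else:
            print(f"✗ WARNING: LUT deltas are all zero - LUT may not be training!")
    else:
        print(f"No _lut_raw_deltas found in trained checkpoint")
        print(f"This might mean LUT training wasn't enabled properly.")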
    
    # Check LUT value changes
    print(f"")
    print(f"=== LUT VALUE CHANGES ===")
    if lut_changed:
        print(f"")
        for k, mx, mn, sm in sorted(lut_changed, key=lambda x: x[1], reverse=True)[:3]:
            print(f"{k}:")
            print(f"  max_diff: {mx:.6f}")
            print(f"  initial:  {initial_sd[k].tolist()}")
            print(f"  trained:  {trained_sd[k].tolist()}")
            print(f"")
    else:
        print(f"No LUT values changed (they might be computed on-the-fly from deltas)")
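Cell 11's fallback message suggests that `.lut` may be derived from `_lut_raw_deltas` rather than stored independently. The notebook does not show how the repo maps deltas to LUT values. The following is therefore only one plausible parameterization, offered as an assumption: softplus produces positive gaps, the gaps are accumulated into a sorted LUT, and the result is centered and clamped to `±lut_max_abs` (mirroring the `--lut-max-abs` flag). The function names are hypothetical. The sketch illustrates why training the deltas changes the LUT while `_Q`, a separate buffer, stays untouched.

```python
import math


def softplus(x):
    # Numerically stable log(1 + exp(x)).
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def lut_from_raw_deltas(raw_deltas, lut_max_abs):
    """HYPOTHETICAL mapping: raw deltas -> sorted, centered, clamped LUT."""
    gaps = [softplus(d) for d in raw_deltas]   # strictly positive spacings
    values = [0.0]
    for g in gaps:
        values.append(values[-1] + g)          # increasing entries
    mid = (values[0] + values[-1]) / 2.0
    centered = [v - mid for v in values]       # symmetric around zero
    # Clamping can flatten the extremes, so monotonicity becomes non-strict there.
    return [max(-lut_max_abs, min(lut_max_abs, v)) for v in centered]


before = lut_from_raw_deltas([0.0, 0.0, 0.0], lut_max_abs=2.0)
after = lut_from_raw_deltas([0.0, 0.5, 0.0], lut_max_abs=2.0)
print(before)
print(after)
```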

In [None]:
# [CELL 12: Conclusion]
# ============================================================
# EXPERIMENT CONCLUSION
# ============================================================

if trained_sd:
    print(f"="*60)
    print(f"EXPERIMENT SUMMARY")
    print(f"="*60)
    print(f"")
    print(f"Training: {MAX_STEPS} steps, LUT + scales, from v2_initial.pt")
    print(f"")
    print(f"Expected changes:")
    print(f"  - LUT values:     {len(lut_changed)} changed")
    print(f"  - LUT deltas:     {len(delta_changed)} changed")
    print(f"  - Scale params:   {len(scale_changed)} changed")
    print(f"")
    print(f"Unexpected changes:")
    print(f"  - _Q buffers:     {len(q_changed)} changed")
    print(f"  - _indices:       {len(indices_changed)} changed")
    print(f"  - Other:          {len(other_changed)} changed")
    print(f"")
    
    if len(q_changed) == 0:
        print(f">>> RESULT: _Q buffers correctly UNCHANGED during training.")
        print(f"")
        print(f"    The _Q changes observed in srLUT-004c vs v2_initial")
        print(f"    therefore most likely came from the PREVIOUS run (srLUT-004):")
        print(f"      bake_lut.py refreshed _Q from the trained lut[_indices]")
        print(f"      before srLUT-004c started (cells 13-14 check this directly).")
        print(f"")
        print(f"    This is EXPECTED behavior - baking updates _Q to reflect the trained LUT.")
    else:
        print(f">>> BUG FOUND: _Q buffers changed during training!")
        print(f"")
        print(f"    This indicates a bug in the training code.")
        print(f"    _Q should be a frozen buffer that only changes via:")
        print(f"      - freeze_Q() at init")
        print(f"      - sync_q_from_indices() explicitly")
        print(f"      - bake_lut.py")

## Bonus: Compare srLUT-004c Input vs Output

In [None]:
# [CELL 13: Compare baked input to srLUT-004c output]
# ============================================================
# VERIFY: _Q unchanged during srLUT-004c training itself
# ============================================================
# This compares the INPUT checkpoint to srLUT-004c vs its OUTPUT
# to confirm that training doesn't modify _Q.

BAKED_INPUT = f"{GD_RUNS}/srLUT-004_all_alpaca_M2p0_lr1e5_from_best/baked_step200.pt"
SRLUT_004C_OUTPUT = f"{GD_RUNS}/srLUT-004c_all_alpaca_from_baked200_lut_plus_scales/best_state_dict.pt"

if os.path.exists(BAKED_INPUT) and os.path.exists(SRLUT_004C_OUTPUT):
    print(f"Comparing srLUT-004c INPUT vs OUTPUT...")
    print(f"")
    print(f"Input:  {BAKED_INPUT}")
    print(f"Output: {SRLUT_004C_OUTPUT}")
    print(f"")
    
    baked_sd = torch.load(BAKED_INPUT, map_location='cpu', weights_only=False)
    output_sd = torch.load(SRLUT_004C_OUTPUT, map_location='cpu', weights_only=False)
    
    q_diff_in_run = 0
    q_same_in_run = 0
    
    for k in sorted(baked_sd.keys()):
        if '._Q' in k and k in output_sd:
            if torch.equal(baked_sd[k], output_sd[k]):
                q_same_in_run += 1
            else:
                q_diff_in_run += 1
    
    print(f"_Q comparison (baked_step200 vs best_state_dict):")
    print(f"  Unchanged: {q_same_in_run}")
    print(f"  Changed:   {q_diff_in_run}")
    print(f"")
    
    if q_diff_in_run == 0:
        print(f">>> _Q unchanged during srLUT-004c training (as expected)!")
        print(f"    All _Q changes vs v2_initial came from the PREVIOUS run's baking.")
    else:
        print(f">>> _Q changed during srLUT-004c training!")
        print(f"    This is unexpected - need to investigate further.")
else:
    missing = []
    if not os.path.exists(BAKED_INPUT):
        missing.append(BAKED_INPUT)
    if not os.path.exists(SRLUT_004C_OUTPUT):
        missing.append(SRLUT_004C_OUTPUT)
    print(f"Cannot compare - missing files:")
    for m in missing:
        print(f"  {m}")

In [None]:
# [CELL 14: Compare v2_initial vs baked_step200]
# ============================================================
# TRACE THE SOURCE OF _Q CHANGES
# ============================================================
# This shows where the _Q changes actually came from.

if os.path.exists(BAKED_INPUT):
    print(f"Comparing v2_initial vs baked_step200 (source of _Q changes)...")
    print(f"")
    
    if 'baked_sd' not in globals():
        baked_sd = torch.load(BAKED_INPUT, map_location='cpu', weights_only=False)
    
    q_diff_from_init = 0
    q_same_from_init = 0
    
    for k in sorted(initial_sd.keys()):
        if '._Q' in k and k in baked_sd:
            if torch.equal(initial_sd[k], baked_sd[k]):
                q_same_from_init += 1
            else:
                q_diff_from_init += 1
    
    print(f"_Q comparison (v2_initial vs baked_step200):")
    print(f"  Unchanged: {q_same_from_init}")
    print(f"  Changed:   {q_diff_from_init}")
    print(f"")
    
    if q_diff_from_init > 0:
        print(f">>> The baked_step200 checkpoint already has {q_diff_from_init} different _Q values!")
        print(f"")
        print(f"This is consistent with the following chain of events:")
        print(f"  1. v2_initial.pt created with original _Q")
        print(f"  2. srLUT-004 trained LUTs (LUT values changed, _Q unchanged)")
        print(f"  3. bake_lut.py refreshed _Q from trained lut[_indices]")
        print(f"  4. baked_step200.pt saved with NEW _Q values")
        print(f"  5. srLUT-004c started from baked_step200 (already has new _Q)")
        print(f"  6. srLUT-004c training did NOT change _Q further (see cell 13 result)")
    else:
        print(f"Unexpected: _Q values are identical in v2_initial and baked_step200,")
        print(f"so baking does not explain the _Q changes observed in srLUT-004c.")
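Step 3 of that chain can be tested more directly. Check whether the `_Q` values in `baked_step200.pt` equal `lut[_indices]` gathered from the same checkpoint, which is what a bake that refreshes `_Q` from `lut[_indices]` should produce. The sketch below rests on unconfirmed assumptions about the repo:
- `_Q` stores raw LUT entries gathered by `_indices`, with no scales applied.
- A layer's `lut`, `_indices`, and `_Q` share a key prefix.

`matches_bake` and the key names are hypothetical, and plain lists stand in for tensors.

```python
def matches_bake(sd, prefix, tol=0.0):
    """True if sd[prefix + '._Q'] equals sd[prefix + '.lut'] gathered at sd[prefix + '._indices']."""
    lut = sd[prefix + ".lut"]
    gathered = [lut[i] for i in sd[prefix + "._indices"]]
    q = sd[prefix + "._Q"]
    return len(q) == len(gathered) and all(abs(x - y) <= tol for x, y in zip(q, gathered))


baked = {
    "layer0.lut": [-1.0, -0.25, 0.25, 1.0],
    "layer0._indices": [3, 0, 1],
    "layer0._Q": [1.0, -1.0, -0.25],          # refreshed from the trained LUT
}
stale = {**baked, "layer0._Q": [0.75, -0.75, -0.25]}  # pre-bake values
print(matches_bake(baked, "layer0"), matches_bake(stale, "layer0"))
```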