# Anemll-Style Layer-by-Layer QAT V2 (ANE-Friendly)

This notebook implements layer-by-layer quantization-aware training (QAT) using **`AnemllQATLinearV2`** with:
- Rank-by-rank forward pass (ANE-friendly, no A @ B materialization)
- Per-rank normalization: `scale_A` (unit columns), `scale_B` (unit rows), `rank_magnitude`
- **Frozen Q**: indices computed once at init, not recomputed during training
- **Positive scale constraints**: `force_positive_scales=True` prevents sign flips with frozen Q
- KD cache for distillation
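
As a rough orientation for the distillation item above, a temperature-softened KD loss can be sketched as follows. This is a minimal NumPy sketch under stated assumptions, not the `evaluate_kd_loss` implementation: the helper names `log_softmax` and `kd_loss` are hypothetical, full-vocabulary logits are assumed (a cache may store only a subset of logits), and the `temperature**2` factor is a common convention that this notebook does not confirm.

```python
import numpy as np

def log_softmax(logits, temperature=1.0):
    """Numerically stable log-softmax over the last axis."""
    z = logits / temperature
    z = z - z.max(axis=-1, keepdims=True)
    return z - np.log(np.exp(z).sum(axis=-1, keepdims=True))

def kd_loss(student_logits, teacher_logits, temperature=2.0):
    """KL(teacher || student) on temperature-softened distributions, averaged over tokens."""
    log_p_teacher = log_softmax(teacher_logits, temperature)
    log_p_student = log_softmax(student_logits, temperature)
    kl_per_token = (np.exp(log_p_teacher) * (log_p_teacher - log_p_student)).sum(axis=-1)
    # Assumed convention: scale by T^2 so gradient magnitudes stay comparable across temperatures
    return float(kl_per_token.mean() * temperature ** 2)

rng = np.random.default_rng(0)
teacher = rng.normal(size=(4, 10))                    # 4 tokens, toy vocabulary of 10
student = teacher + 0.1 * rng.normal(size=(4, 10))    # slightly perturbed student

print(kd_loss(teacher, teacher))   # identical logits give zero loss
print(kd_loss(student, teacher))   # perturbed logits give a positive loss
```

A higher temperature flattens both distributions, so the loss pays more attention to the relative ordering of low-probability tokens.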

## V2 Key Differences from V1:
- `freeze_Q()` is called once to compute quantization indices
- Forward: `y = Σₖ gₖ · (aₖ ⊙ (Q (bₖ ⊙ x)))`
- Only scales are trained (weight is frozen after Q is computed)
- **Positive scales by construction** (abs/softplus on factors) replaces clamp(A@B)
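
The forward formula above can be checked numerically. The sketch below uses NumPy rather than PyTorch and hypothetical function names; it compares the rank-by-rank sum against a reference that materializes the full scale matrix. It assumes `Q` is `[out, in]`, `A` is `[out, rank]`, `B` is `[rank, in]`, and `g` is `[rank]`, and it is not the `AnemllQATLinearV2` code.

```python
import numpy as np

def forward_rank_by_rank(x, Q, A, B, g):
    """y = sum_k g[k] * (A[:, k] * (Q @ (B[k] * x))), without ever forming A @ B."""
    y = np.zeros(Q.shape[0])
    for k in range(len(g)):
        y += g[k] * (A[:, k] * (Q @ (B[k] * x)))
    return y

def forward_materialized(x, Q, A, B, g):
    """Reference path: build S = (A * g) @ B, then apply the elementwise-scaled weight."""
    S = (A * g) @ B            # [out, in] scale matrix
    return (Q * S) @ x

rng = np.random.default_rng(0)
out_f, in_f, rank = 6, 8, 4
Q = rng.uniform(-1.0, 1.0, size=(out_f, in_f))   # stand-in for frozen LUT values
A = rng.normal(size=(out_f, rank))
B = rng.normal(size=(rank, in_f))
g = rng.uniform(0.5, 1.5, size=rank)
x = rng.normal(size=in_f)

y_rank = forward_rank_by_rank(x, Q, A, B, g)
y_full = forward_materialized(x, Q, A, B, g)
print(np.allclose(y_rank, y_full))
```

The rank-by-rank path only needs `rank` matrix-vector products against the same `Q`, which is why the `[out, in]` scale matrix never has to be stored.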

## Pipeline:
1. Load model and replace linears with AnemllQATLinearV2
2. Call `freeze_Q_all()` to compute Q once
3. Layer-by-layer scale optimization (Q and weights frozen)
4. End-to-end scale refinement
5. (Optional) Stage 2: Gentle weight tuning with STE

## Training Parameters:
- `scale_A`, `scale_B`, `rank_magnitude` are trained
- `weight` and `_Q` are frozen during Stage 1
- `force_positive_scales=True` ensures S = (A_dir * g) @ B_dir >= 0
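
To illustrate why the constraint guarantees `S >= 0`, here is a minimal NumPy sketch. It assumes `abs` on the factors, `softplus` on the magnitude, and an `eps` added to the norms and to `g` (mirroring the `positive_scale_method`, `magnitude_activation`, and `magnitude_eps` options used later); the function names are hypothetical and the exact placement of `eps` inside the library is not known from this notebook.

```python
import numpy as np

def softplus(z):
    """Numerically stable softplus: log(1 + exp(z))."""
    return np.logaddexp(0.0, z)

def positive_scale_matrix(scale_A, scale_B, rank_magnitude, eps=1e-6):
    """Build S = (A_dir * g) @ B_dir where every factor is non-negative."""
    A_pos = np.abs(scale_A)
    B_pos = np.abs(scale_B)
    # Approximately unit-norm columns of A and rows of B (eps avoids division by zero)
    A_dir = A_pos / (np.linalg.norm(A_pos, axis=0, keepdims=True) + eps)
    B_dir = B_pos / (np.linalg.norm(B_pos, axis=1, keepdims=True) + eps)
    g = softplus(rank_magnitude) + eps   # strictly positive per-rank magnitude
    return (A_dir * g) @ B_dir

rng = np.random.default_rng(0)
raw_A = rng.normal(size=(5, 4))   # raw parameters may be negative
raw_B = rng.normal(size=(4, 7))
raw_g = rng.normal(size=4)

S = positive_scale_matrix(raw_A, raw_B, raw_g)
print(S.min() >= 0.0)
```

Because a product and sum of non-negative terms cannot go negative, no gradient step on the raw parameters can flip the sign of a scale, which keeps the frozen indices in `Q` meaningful.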

In [1]:
# ============================================================
# GOOGLE DRIVE PATHS (STANDARD)
# ============================================================

# Checkpoints/runs go here
GD_RUNS = '/content/drive/MyDrive/qwen3_runs'

# KD caches go here
GD_CACHES = '/content/drive/MyDrive/qwen3_caches'

# Local directories (on Colab VM)
LOCAL_RUNS = 'runs'
LOCAL_CACHES = 'caches'

In [2]:
# Mount Google Drive
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


### GitHub

In [None]:
# Clone repo if needed
!git clone https://github.com/anemll/qwen3_apple_style_2bit_qat_lora.git || (cd qwen3_apple_style_2bit_qat_lora && git pull)
%cd qwen3_apple_style_2bit_qat_lora
# to allow updates
!git fetch
!git pull
!git reset --hard HEAD
import sys
# Drop cached qat_lora modules so the import below picks up the freshly pulled code
for k in list(sys.modules):
    if k.startswith('qat_lora'):
        sys.modules.pop(k)

# Import V2 modules
from qat_lora import (
    # V2 ANE-friendly classes
    AnemllQATLinearV2,
    AnemllQuantConfigV2,
    replace_linear_with_anemll_v2,
    freeze_Q_all,
    freeze_model_for_inference_v2,
    unfreeze_model_for_training_v2,
    # Training utilities (shared)
    evaluate_kd_loss,
    train_all_layers,
    train_e2e,
    save_checkpoint,
    load_checkpoint,
)

In [4]:
# Install dependencies
!pip install -q transformers accelerate safetensors

In [5]:
# ============================================================
# LOAD KD CACHE FROM GOOGLE DRIVE
# ============================================================

#CACHE_NAME = 'alpaca_chat_think_both_L128_K32_R256'
#CACHE_NAME = 'alpaca_chat_think_both_L128_K64_R512'
CACHE_NAME = 'alpaca_chat_think_both_L128_K128_R1024'


CACHE_TGZ = f'{CACHE_NAME}.tgz'

!mkdir -p {LOCAL_CACHES}

# Check if cache exists locally
import os
cache_local_path = f'{LOCAL_CACHES}/{CACHE_NAME}'
if not os.path.exists(cache_local_path):
    print(f'Extracting {CACHE_TGZ} from Google Drive...')
    !tar -xzf {GD_CACHES}/{CACHE_TGZ} -C {LOCAL_CACHES}/
else:
    print(f'Cache already exists at {cache_local_path}')

!ls -la {cache_local_path}/ | head -10

Extracting alpaca_chat_think_both_L128_K128_R1024.tgz from Google Drive...
total 17157672
drwx------ 2 root root      4096 Dec 26 02:45 .
drwxr-xr-x 3 root root      4096 Dec 29 00:46 ..
-rw------- 1 root root       423 Dec 26 02:45 meta.json
-rw------- 1 root root 899550149 Dec 26 02:46 shard_00000.pt
-rw------- 1 root root 899550149 Dec 26 02:43 shard_00001.pt
-rw------- 1 root root 899550149 Dec 26 02:44 shard_00002.pt
-rw------- 1 root root 899550149 Dec 26 02:44 shard_00003.pt
-rw------- 1 root root 899550149 Dec 26 02:44 shard_00004.pt
-rw------- 1 root root 899550149 Dec 26 02:43 shard_00005.pt


In [6]:
# ============================================================
# CONFIGURATION
# ============================================================

import torch

# Model
MODEL_ID = 'Qwen/Qwen3-0.6B'

# Quantization config (4-bit with groupwise LUT)
LUT_BITS = 4
LUT_SIZE = 2**LUT_BITS
GROUP_SIZE = 16      # Group size for scales (printed below, but not passed to the V2 configs in this notebook)
SCALE_RANK = 4       # Low-rank for A @ B scales

# Attention quantization (same params)
ATTN_LUT_BITS = 4
ATTN_LUT_SIZE = 2**ATTN_LUT_BITS
ATTN_GROUP_SIZE = 16
ATTN_SCALE_RANK = 4

# Training
BATCH_SIZE = 4
GRAD_ACCUM = 4

# Use larger batches and no accumulation when a GPU is available
if torch.cuda.is_available():
    BATCH_SIZE = 32
    GRAD_ACCUM = 1

LR = 2e-5
EPOCHS_PER_LAYER = 1

# KD / Distillation params
DISTILL_TEMP = 2.0
HARD_TOP1_WEIGHT = 0.2    # Hard label top-1 loss (helps convergence)
HARD_FULL_WEIGHT = 0.00005    # Hard label full vocab loss (optional)

# Device
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
DTYPE = torch.bfloat16


QUAL = f'q{LUT_BITS}_a{ATTN_LUT_BITS}'

print(f'Quality: {QUAL}')

print(f'Device: {DEVICE}, dtype: {DTYPE}')
print(f'Quant config: lut={LUT_SIZE}, group={GROUP_SIZE}, rank={SCALE_RANK}')
print(f'Distillation: temp={DISTILL_TEMP}, hard_top1={HARD_TOP1_WEIGHT}, hard_full={HARD_FULL_WEIGHT}')

Quality: q4_a4
Device: cuda, dtype: torch.bfloat16
Quant config: lut=16, group=16, rank=4
Distillation: temp=2.0, hard_top1=0.2, hard_full=5e-05


In [7]:
# ============================================================
# VERIFY LOCAL CACHE (RE-EXTRACT IF MISSING)
# ============================================================

import os
from pathlib import Path

# Verify drive is mounted and cache exists
if not os.path.exists('/content/drive/MyDrive'):
    print('Google Drive not mounted! Mounting now...')
    from google.colab import drive
    drive.mount('/content/drive')

if not os.path.exists(cache_local_path):
    print(f'Cache not found at {cache_local_path}')
    print(f'Extracting from Google Drive...')
    os.makedirs(LOCAL_CACHES, exist_ok=True)
    !tar -xzf {GD_CACHES}/{CACHE_TGZ} -C {LOCAL_CACHES}/

# Verify cache exists now
assert os.path.exists(cache_local_path), f'Cache still not found at {cache_local_path}'
cache_files = list(Path(cache_local_path).glob('*.pt'))
print(f'Cache ready: {len(cache_files)} files in {cache_local_path}')

Cache ready: 20 files in caches/alpaca_chat_think_both_L128_K128_R1024


In [8]:
# ============================================================
# LOAD MODEL
# ============================================================

from transformers import AutoModelForCausalLM, AutoTokenizer

print(f'Loading {MODEL_ID}...')
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=DTYPE,
    trust_remote_code=True,
)
model.to(DEVICE)
model.eval()
print(f'Loaded. Parameters: {sum(p.numel() for p in model.parameters()):,}')

Loading Qwen/Qwen3-0.6B...


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.



`torch_dtype` is deprecated! Use `dtype` instead!



Loaded. Parameters: 596,049,920


In [None]:
# ============================================================
# REPLACE LINEARS WITH AnemllQATLinearV2
# ============================================================

import sys
sys.path.insert(0, '.')

# Force reimport to get latest code
import importlib
import qat_lora
importlib.reload(qat_lora)
import qat_lora.ane_qat_linear_v2 as ane_module_v2
importlib.reload(ane_module_v2)
import qat_lora.layer_qat as layer_module
importlib.reload(layer_module)

from qat_lora import AnemllQuantConfigV2, replace_linear_with_anemll_v2, freeze_Q_all

# Debug: Check what modules exist in the model
print("Checking model structure...")
import torch.nn as nn
linear_count = 0
for name, m in model.named_modules():
    if isinstance(m, nn.Linear):
        linear_count += 1
        if linear_count <= 5:
            print(f"  Found Linear: {name}")
print(f"Total Linear modules: {linear_count}")

# Create V2 configs with POSITIVE SCALE CONSTRAINTS
# This prevents scale sign flips when Q is frozen (key fix!)
mlp_config = AnemllQuantConfigV2(
    lut_size=LUT_SIZE,
    scale_rank=SCALE_RANK,
    learnable_lut=False,
    # Positive scale constraints (prevents sign flips with frozen Q)
    force_positive_scales=True,
    positive_scale_method="abs",      # "abs" or "softplus"
    magnitude_activation="softplus",  # keeps g >= 0
    magnitude_eps=1e-6,
)

attn_config = AnemllQuantConfigV2(
    lut_size=ATTN_LUT_SIZE,
    scale_rank=ATTN_SCALE_RANK,
    learnable_lut=False,
    # Positive scale constraints
    force_positive_scales=True,
    positive_scale_method="abs",
    magnitude_activation="softplus",
    magnitude_eps=1e-6,
)

print(f'\nV2 Config: force_positive_scales={mlp_config.force_positive_scales}')
print(f'           positive_scale_method={mlp_config.positive_scale_method}')
print(f'           magnitude_activation={mlp_config.magnitude_activation}')

print('\nReplacing linear layers with V2...')
count = replace_linear_with_anemll_v2(
    model,
    mlp_config=mlp_config,
    attn_config=attn_config,
    quantize_attn=True,
    quantize_lm_head=False,
)

# Verify replacement worked (use type name to avoid reload issues)
qat_count = sum(1 for _, m in model.named_modules() if type(m).__name__ == 'AnemllQATLinearV2')
print(f"\nVerification: {qat_count} AnemllQATLinearV2 modules in model")

# ============================================================
# FREEZE Q - CRITICAL STEP FOR V2
# ============================================================
# Compute Q = lut[indices] once. After this, only scales are trained.
print('\nFreezing Q (computing indices once)...')
freeze_Q_all(model, verbose=False)
print('Q frozen for all layers. Training will only update scale_A, scale_B, rank_magnitude.')

In [None]:
# ============================================================
# VERIFY GRADIENT FLOW FOR V2
# ============================================================

from qat_lora import evaluate_kd_loss

print('Verifying V2 gradient flow...')

# Find a V2 layer (use type name to avoid reload issues)
layer0 = model.model.layers[0]
test_module = None
for name, m in layer0.named_modules():
    if type(m).__name__ == 'AnemllQATLinearV2':
        test_module = m
        break

if test_module is None:
    print("ERROR: No AnemllQATLinearV2 modules found! Replacement failed.")
else:
    # V2: weight is frozen, only scales are trained
    print(f"  weight.requires_grad: {test_module.weight.requires_grad} (should be False after freeze_Q)")
    print(f"  scale_A.requires_grad: {test_module.scale_A.requires_grad}")
    print(f"  scale_B.requires_grad: {test_module.scale_B.requires_grad}")
    print(f"  rank_magnitude.requires_grad: {test_module.rank_magnitude.requires_grad}")
    print(f"  Q frozen: {test_module._Q is not None}")
    
    # Test gradient flow through scales
    test_module.scale_A.requires_grad = True
    test_module.scale_B.requires_grad = True
    test_module.rank_magnitude.requires_grad = True
    
    x = torch.randn(1, 10, test_module.in_features, device=DEVICE, dtype=DTYPE)
    y = test_module(x)
    loss = y.sum()
    try:
        loss.backward()
        if test_module.scale_A.grad is not None:
            print(f"  Gradient OK: scale_A.grad.shape = {test_module.scale_A.grad.shape}")
            print(f"               scale_B.grad.shape = {test_module.scale_B.grad.shape}")
            print(f"               rank_magnitude.grad.shape = {test_module.rank_magnitude.grad.shape}")
            # Clear for actual training
            test_module.scale_A.grad = None
            test_module.scale_B.grad = None
            test_module.rank_magnitude.grad = None
        else:
            print("  ERROR: scale_A.grad is None after backward!")
    except Exception as e:
        print(f"  ERROR during backward: {e}")

# Compute initial KD loss
print('\nComputing initial KD loss...')
initial_loss = evaluate_kd_loss(model, cache_local_path, DEVICE, num_samples=40, temperature=DISTILL_TEMP)
print(f'Initial KD Loss: {initial_loss:.4f}')

# **V2 SCALE OPTIMIZATION** (Q and Weights Frozen)

In V2, we train only the scale parameters:
- `scale_A`: [out, rank] - unit-norm columns
- `scale_B`: [rank, in] - unit-norm rows
- `rank_magnitude`: [rank] - the only parameter that carries magnitude

The frozen components:
- `weight`: Original FP weights (frozen by `freeze_Q()`)
- `_Q`: LUT values = `lut[indices]` (computed once by `freeze_Q()`)
- `_indices`: Quantization indices (computed once)

**Why this works:**
- Q represents the normalized quantized weights in [-1, 1]
- Scales modulate the contribution of each rank
- Training adjusts how much each rank contributes to the output
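
The "compute once" step for `_indices` and `_Q` can be pictured with the NumPy sketch below. It is illustrative only: the function name `freeze_q`, the uniform 16-entry LUT, and per-row max-abs normalization are all assumptions made for this example, and the library's `freeze_Q()` may normalize differently (for example per group).

```python
import numpy as np

def freeze_q(weight, lut):
    """Quantize once: normalize each row to [-1, 1], then pick the nearest LUT entry."""
    row_max = np.abs(weight).max(axis=1, keepdims=True)
    row_max = np.where(row_max == 0, 1.0, row_max)              # guard all-zero rows
    w_norm = weight / row_max
    indices = np.abs(w_norm[..., None] - lut).argmin(axis=-1)   # [out, in] integer indices
    Q = lut[indices]                                            # [out, in] LUT values
    return indices, Q

lut = np.linspace(-1.0, 1.0, 16)     # 16 entries, i.e. a 4-bit LUT (assumed uniform)
rng = np.random.default_rng(0)
W = rng.normal(size=(4, 8))          # toy full-precision weight

indices, Q = freeze_q(W, lut)
print(indices.shape, Q.min(), Q.max())
```

After this step, `indices` and `Q` never change during Stage 1; only the scale parameters move, so the magnitude that was divided out here has to be recovered by the learned scales.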

In [None]:
# ============================================================
# V2 LAYER-BY-LAYER SCALE OPTIMIZATION
# ============================================================
# Q is frozen, only train scale_A, scale_B, rank_magnitude
# Higher LR since fewer parameters and per-rank normalization ensures stability

SCALE_LR = 1e-3  # Higher LR for scales (fewer params)
SCALE_EPOCHS = 2  # More epochs since scales have less capacity

print('Starting V2 scale-only layer-by-layer optimization...')
print(f'LR: {SCALE_LR}, Epochs per layer: {SCALE_EPOCHS}')
print('Training: scale_A, scale_B, rank_magnitude (Q and weights frozen)')

# Get loss before scale optimization
pre_scale_loss = evaluate_kd_loss(model, cache_local_path, DEVICE, num_samples=40, temperature=DISTILL_TEMP)
print(f'KD Loss before scale optimization: {pre_scale_loss:.4f}')

# Train scales layer-by-layer
# Note: In V2, train_weights=False means weight AND Q are frozen
#       train_scales=True means scale_A, scale_B, rank_magnitude are trained
scale_losses = train_all_layers(
    model=model,
    cache_dir=cache_local_path,
    device=DEVICE,
    batch_size=BATCH_SIZE,
    lr=SCALE_LR,
    epochs_per_layer=SCALE_EPOCHS,
    grad_accum=GRAD_ACCUM,
    temperature=DISTILL_TEMP,
    train_weights=False,  # Keep Q and weights frozen
    train_scales=True,    # Train scale_A, scale_B, rank_magnitude
    local_weight=0.5,
    global_weight=0.5,
    hard_top1_weight=0.0,  # Not needed for scale optimization
    hard_full_weight=0.0,
    verbose=True,
    steps_per_layer=100,
)

# Evaluate after scale optimization
post_scale_loss = evaluate_kd_loss(model, cache_local_path, DEVICE, num_samples=40, temperature=DISTILL_TEMP)
print(f'\n=== V2 Scale Optimization Results ===')
print(f'Before: {pre_scale_loss:.4f}')
print(f'After:  {post_scale_loss:.4f}')
print(f'Improvement: {pre_scale_loss - post_scale_loss:.4f}')

# **OPTIONAL: GENTLE WEIGHT+SCALE TUNING**

After scale-only optimization, you can optionally do gentle weight+scale tuning.

**Important for V2:**
- When `train_weights=True`, Q needs to be recomputed periodically
- Use very small weight LR (10-100x smaller than scale LR)
- This is Stage 2 in the V2 training strategy

**Note:** For most cases, scale-only training (Stage 1) is sufficient.
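
For readers unfamiliar with the straight-through estimator (STE) mentioned for weight tuning, here is a toy NumPy sketch with a manually written gradient. Everything in it is an assumption for illustration: the uniform LUT, the function names, and the identity-input loss. In PyTorch the same effect is commonly obtained with `w + (quantize(w) - w).detach()`.

```python
import numpy as np

lut = np.linspace(-1.0, 1.0, 16)   # toy 4-bit LUT (assumed uniform)

def quantize(w):
    """Forward quantizer: snap each latent weight to its nearest LUT entry."""
    idx = np.abs(w[:, None] - lut).argmin(axis=1)
    return lut[idx]

def ste_step(w, target, lr):
    """One SGD step on L = 0.5 * ||quantize(w) - target||^2 using the STE.

    The quantizer has zero gradient almost everywhere, so the STE treats
    d quantize(w) / d w as the identity and applies dL/dq directly to w.
    """
    grad_q = quantize(w) - target   # exact gradient w.r.t. the quantized weights
    return w - lr * grad_q          # STE: reuse it as the gradient w.r.t. the latent weights

rng = np.random.default_rng(0)
target = lut[rng.integers(0, len(lut), size=8)]   # targets that lie exactly on the LUT grid
w = rng.uniform(-1.0, 1.0, size=8)                # latent full-precision weights

loss_before = 0.5 * np.sum((quantize(w) - target) ** 2)
for _ in range(300):
    w = ste_step(w, target, lr=0.1)
loss_after = 0.5 * np.sum((quantize(w) - target) ** 2)
print(loss_before, loss_after)
```

The latent weights drift continuously while the quantized values change only when a weight crosses a bin boundary, which is why a small weight learning rate and periodic recomputation of Q go together.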

In [None]:
# ============================================================
# V2 OPTIONAL: GENTLE WEIGHT + SCALE TUNING (Stage 2)
# ============================================================
# Unfreeze weights with a very small LR and continue training scales.
# Because Q is frozen in V2, weight updates only take effect once Q is recomputed, so either:
#   1. unfreeze weights and recompute Q periodically during training, or
#   2. use the V1 path, which does fake_quant on the fly.
# Disabled by default: scale-only training is the recommended V2 path.
# If you want to train weights, consider using V1 or implementing STE.

ENABLE_WEIGHT_TUNING = False  # Set to True to enable

if ENABLE_WEIGHT_TUNING:
    print('WARNING: V2 weight tuning requires special handling.')
    print('Q will be recomputed to match updated weights.')
    
    # Unfreeze weights for all V2 layers
    for name, module in model.named_modules():
        if type(module).__name__ == 'AnemllQATLinearV2':  # type-name check survives the module reloads above
            module.weight.requires_grad = True
            module._Q = None  # Clear frozen Q, will use on-the-fly computation
    
    # train_all_layers is called with a single lr below, so weights and scales share WEIGHT_LR
    WEIGHT_LR = 1e-5  # 100x smaller than SCALE_LR (1e-3)
    
    print(f'Starting gentle weight+scale tuning...')
    print(f'Weight LR: {WEIGHT_LR}')
    
    layer_losses = train_all_layers(
        model=model,
        cache_dir=cache_local_path,
        device=DEVICE,
        batch_size=BATCH_SIZE,
        lr=WEIGHT_LR,  # Small LR for weights
        epochs_per_layer=1,
        grad_accum=GRAD_ACCUM,
        temperature=DISTILL_TEMP,
        train_weights=True,   # Train weights with small LR
        train_scales=True,    # Also train scales
        local_weight=0.5,
        global_weight=0.5,
        hard_top1_weight=HARD_TOP1_WEIGHT,
        hard_full_weight=HARD_FULL_WEIGHT,
        verbose=True,
        steps_per_layer=50,  # Fewer steps for gentle tuning
    )
    
    # Re-freeze Q after weight tuning
    print('Re-freezing Q after weight tuning...')
    freeze_Q_all(model, verbose=False)
else:
    print('Weight tuning disabled. Using scale-only training (recommended for V2).')
    print('Set ENABLE_WEIGHT_TUNING = True above to enable weight tuning.')

In [None]:
# ============================================================
# EVALUATE AFTER V2 LAYER-BY-LAYER
# ============================================================

model.eval()
post_layer_loss = evaluate_kd_loss(model, cache_local_path, DEVICE, num_samples=40, temperature=DISTILL_TEMP)
print(f'Initial KD Loss: {initial_loss:.4f}')
print(f'After V2 Layer-by-Layer: {post_layer_loss:.4f}')
print(f'Improvement: {initial_loss - post_layer_loss:.4f}')

In [None]:
# ============================================================
# SAVE V2 CHECKPOINT
# ============================================================

import os

RUN_NAME = f'anemll_v2_{QUAL}_layer_by_layer'
SAVE_DIR = f'{LOCAL_RUNS}/{RUN_NAME}'

os.makedirs(SAVE_DIR, exist_ok=True)

# Save state dict
torch.save(model.state_dict(), f'{SAVE_DIR}/model_state_dict.pt')

# Save config
import json
config = {
    'model_id': MODEL_ID,
    'version': 'v2',  # Mark as V2
    'lut_size': LUT_SIZE,
    'scale_rank': SCALE_RANK,
    'attn_lut_size': ATTN_LUT_SIZE,
    'attn_scale_rank': ATTN_SCALE_RANK,
    'initial_kd_loss': initial_loss,
    'post_layer_loss': post_layer_loss,
    'scale_losses': scale_losses,
}
with open(f'{SAVE_DIR}/config.json', 'w') as f:
    json.dump(config, f, indent=2)

print(f'Saved V2 checkpoint to {SAVE_DIR}')

In [None]:
# ============================================================
# UPLOAD TO GOOGLE DRIVE
# ============================================================

!mkdir -p {GD_RUNS}
!tar -czvf {RUN_NAME}.tgz -C {LOCAL_RUNS} {RUN_NAME}
!cp {RUN_NAME}.tgz {GD_RUNS}/
print(f'Uploaded to {GD_RUNS}/{RUN_NAME}.tgz')

# **V2 END-TO-END KD-QAT REFINEMENT**

After layer-by-layer training, refine the model end to end: all layers are trained jointly against the KD loss instead of one layer at a time.

**V2 Training Strategy:**

| Stage | What's Trained | What's Frozen | Notes |
|-------|---------------|---------------|-------|
| Stage 1 | scale_A, scale_B, rank_magnitude | weight, Q, lut | Primary training |
| Stage 2 | weight + scales | lut | Optional, recomputes Q |

**Recommended approach:** Stay in Stage 1 (scales only) for most cases.
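
The end-to-end cells in this section pass `use_cosine_schedule=True`, `warmup_steps`, and `min_lr_ratio` to `train_e2e`. A schedule of that kind can be sketched as below; this is a hedged illustration with a hypothetical function name, assuming linear warmup followed by cosine decay to `lr * min_lr_ratio`, and the actual shape inside `train_e2e` may differ.

```python
import math

def lr_at_step(step, base_lr, max_steps, warmup_steps=100, min_lr_ratio=0.1):
    """Linear warmup to base_lr, then cosine decay down to base_lr * min_lr_ratio."""
    if step < warmup_steps:
        return base_lr * (step + 1) / warmup_steps
    progress = (step - warmup_steps) / max(1, max_steps - warmup_steps)
    progress = min(1.0, progress)                       # clamp past the final step
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress)) # goes from 1 down to 0
    min_lr = base_lr * min_lr_ratio
    return min_lr + (base_lr - min_lr) * cosine

# Values matching the first end-to-end call: lr=5e-4, max_steps=4000
for step in (0, 99, 100, 2050, 4000):
    print(step, lr_at_step(step, base_lr=5e-4, max_steps=4000))
```

With `min_lr_ratio=0.1`, the learning rate never drops below a tenth of its peak, so late steps still make small adjustments to the scales.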

In [None]:
# ============================================================
# V2 END-TO-END: TRAIN SCALES (Q AND WEIGHTS FROZEN)
# ============================================================
# This is the primary V2 training mode

# Train scales (Q and weights frozen) - higher LR since fewer params
e2e_scales_result = train_e2e(
    model=model,
    cache_dir=cache_local_path,
    device=DEVICE,
    max_steps=4000,
    batch_size=64 if torch.cuda.is_available() else 32,
    lr=5e-4,  # Higher LR for scales
    use_cosine_schedule=True,
    warmup_steps=100,
    min_lr_ratio=0.1,
    temperature=DISTILL_TEMP,
    train_weights=False,  # Keep Q and weights frozen
    train_scales=True,    # Train scale_A, scale_B, rank_magnitude
    hard_top1_weight=0.0,
    hard_full_weight=0.0,
    logging_steps=20,
    eval_steps=100,
    verbose=True,
    train_mlp_only=True,  # Focus on MLP (more bits needed)
)

In [None]:
import gc
gc.collect()
torch.cuda.empty_cache()

In [None]:
# ============================================================
# V2 OPTIONAL: E2E GENTLE WEIGHT + SCALE TUNING (Stage 2)
# ============================================================
# Only run this if scale-only training isn't sufficient
# Requires clearing _Q to recompute on-the-fly

ENABLE_E2E_WEIGHT_TUNING = False  # Set to True to enable

if ENABLE_E2E_WEIGHT_TUNING:
    from qat_lora import unfreeze_model_for_training_v2
    
    # Clear cached weights and frozen Q
    unfreeze_model_for_training_v2(model)
    
    # Unfreeze weights and clear Q
    for name, module in model.named_modules():
        if type(module).__name__ == 'AnemllQATLinearV2':  # type-name check survives module reloads
            module.weight.requires_grad = True
            module._Q = None  # Will compute Q on-the-fly
    
    print('V2 E2E weight training (Stage 2)...')
    print('WARNING: Q will be recomputed on-the-fly for each forward pass')
    
    e2e_weights_result = train_e2e(
        model=model,
        cache_dir=cache_local_path,
        device=DEVICE,
        max_steps=1000,
        batch_size=64 if torch.cuda.is_available() else 32,
        lr=1e-5,  # Very small LR for weights
        use_cosine_schedule=True,
        warmup_steps=100,
        min_lr_ratio=0.1,
        temperature=DISTILL_TEMP,
        train_weights=True,
        train_scales=True,  # Also train scales
        hard_top1_weight=0.2,
        hard_full_weight=0.0,
        logging_steps=50,
        eval_steps=500,
        verbose=True,
    )
    
    # Re-freeze Q after training
    print('Re-freezing Q...')
    freeze_Q_all(model, verbose=False)
else:
    print('E2E weight tuning disabled. Using scale-only training.')
    print('Set ENABLE_E2E_WEIGHT_TUNING = True above to enable.')

In [None]:
# ============================================================
# V2 E2E: ADDITIONAL SCALE REFINEMENT (MLP ONLY)
# ============================================================
# Continue scale training for MLP layers

e2e_mlp_scales_result = train_e2e(
    model=model,
    cache_dir=cache_local_path,
    device=DEVICE,
    max_steps=1000,
    batch_size=64 if torch.cuda.is_available() else 32,
    lr=1e-4,  # Lower LR for refinement
    use_cosine_schedule=True,
    warmup_steps=100,
    min_lr_ratio=0.1,
    temperature=DISTILL_TEMP,
    train_weights=False,
    train_scales=True,
    hard_top1_weight=0.0,
    hard_full_weight=0.0,
    logging_steps=50,
    eval_steps=500,
    verbose=True,
    train_mlp_only=True,
)

# **V2 ATTENTION SCALE TRAINING**

Train attention layer scales (Q, K, V, O projections) while keeping MLP frozen.
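
The attention/MLP split relies on matching projection names inside each module's qualified name. A small pure-Python sketch of that selection is shown below; the helper names and the example module paths are illustrative assumptions, and the substring test mirrors the check used in the training cell.

```python
ATTN_PROJ_NAMES = ('q_proj', 'k_proj', 'v_proj', 'o_proj')

def is_attention_module(name):
    """True if the qualified module name contains an attention projection name."""
    return any(part in name for part in ATTN_PROJ_NAMES)

def split_by_kind(module_names):
    """Partition module names into (attention, everything else)."""
    attn = [n for n in module_names if is_attention_module(n)]
    other = [n for n in module_names if not is_attention_module(n)]
    return attn, other

# Illustrative module paths for one transformer block
names = [
    'model.layers.0.self_attn.q_proj',
    'model.layers.0.self_attn.k_proj',
    'model.layers.0.self_attn.v_proj',
    'model.layers.0.self_attn.o_proj',
    'model.layers.0.mlp.gate_proj',
    'model.layers.0.mlp.up_proj',
    'model.layers.0.mlp.down_proj',
]

attn_names, mlp_names = split_by_kind(names)
print(len(attn_names), len(mlp_names))
```

Printing the two lists before training is a cheap way to confirm that the intended parameters will have `requires_grad` enabled.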

In [None]:
# ============================================================
# V2 E2E: ATTENTION SCALE TRAINING (MLP FROZEN)
# ============================================================
# Train attention scales while keeping MLP scales frozen

from qat_lora import unfreeze_model_for_training_v2

unfreeze_model_for_training_v2(model)

# Freeze MLP scales, only train attention scales
for name, module in model.named_modules():
    if type(module).__name__ == 'AnemllQATLinearV2':  # type-name check survives module reloads
        is_attn = any(x in name for x in ['q_proj', 'k_proj', 'v_proj', 'o_proj'])
        if hasattr(module, 'scale_A') and module.scale_A is not None:
            module.scale_A.requires_grad = is_attn
            module.scale_B.requires_grad = is_attn
            module.rank_magnitude.requires_grad = is_attn
        module.weight.requires_grad = False  # Keep weights frozen

# Train attention scales
e2e_attn_scales_result = train_e2e(
    model=model,
    cache_dir=cache_local_path,
    device=DEVICE,
    max_steps=2000,
    batch_size=64 if torch.cuda.is_available() else 32,
    lr=1e-4,
    use_cosine_schedule=True,
    warmup_steps=100,
    min_lr_ratio=0.1,
    temperature=DISTILL_TEMP,
    train_weights=False,
    train_scales=True,
    hard_top1_weight=0.0,
    hard_full_weight=0.0,
    logging_steps=20,
    eval_steps=100,
    verbose=True,
    train_mlp_only=False,  # Train all (but MLP scales are frozen above)
)

In [None]:
# ============================================================
# SAVE FINAL V2 CHECKPOINT
# ============================================================

from qat_lora import unfreeze_model_for_training_v2

unfreeze_model_for_training_v2(model)

E2E_RUN_NAME = f'anemll_v2_{QUAL}_e2e_scales_only'
E2E_SAVE_DIR = f'{LOCAL_RUNS}/{E2E_RUN_NAME}'

# Save with config
config = {
    'model_id': MODEL_ID,
    'version': 'v2',
    'lut_size': LUT_SIZE,
    'scale_rank': SCALE_RANK,
    'attn_lut_size': ATTN_LUT_SIZE,
    'attn_scale_rank': ATTN_SCALE_RANK,
    'e2e_scales_result': e2e_scales_result,
}

save_checkpoint(model, E2E_SAVE_DIR, config=config)

# Upload to Google Drive
!mkdir -p {GD_RUNS}
!tar -czvf {E2E_RUN_NAME}.tgz -C {LOCAL_RUNS} {E2E_RUN_NAME}
!cp {E2E_RUN_NAME}.tgz {GD_RUNS}/
print(f'\nUploaded to {GD_RUNS}/{E2E_RUN_NAME}.tgz')

# **INFERENCE OPTIMIZATION**

Before running inference, freeze all layers to precompute quantized weights.
This avoids recomputing `LUT[indices] * (scale_A @ scale_B)` on every forward pass.
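
The caching idea can be illustrated with a toy NumPy class. This is a sketch with hypothetical names (`CachedScaledLinear`, `freeze_for_inference`, `unfreeze_for_training`), not the library's `freeze_model_for_inference_v2`; it only shows that computing the effective weight once and reusing it gives the same outputs as recomputing it on every call.

```python
import numpy as np

class CachedScaledLinear:
    """Toy stand-in for a quantized linear layer with an optional cached weight."""

    def __init__(self, Q, A_dir, B_dir, g):
        self.Q, self.A_dir, self.B_dir, self.g = Q, A_dir, B_dir, g
        self._cached_weight = None

    def effective_weight(self):
        """W_eff = Q * ((A_dir * g) @ B_dir), recomputed from the current scales."""
        return self.Q * ((self.A_dir * self.g) @ self.B_dir)

    def freeze_for_inference(self):
        self._cached_weight = self.effective_weight()

    def unfreeze_for_training(self):
        self._cached_weight = None

    def __call__(self, x):
        if self._cached_weight is None:
            W = self.effective_weight()    # training path: rebuild every call
        else:
            W = self._cached_weight        # inference path: reuse the cached weight
        return x @ W.T

rng = np.random.default_rng(0)
layer = CachedScaledLinear(
    Q=rng.uniform(-1.0, 1.0, size=(6, 8)),
    A_dir=rng.uniform(0.0, 1.0, size=(6, 4)),
    B_dir=rng.uniform(0.0, 1.0, size=(4, 8)),
    g=rng.uniform(0.5, 1.5, size=4),
)
x = rng.normal(size=(3, 8))     # batch of 3 inputs

y_train = layer(x)
layer.freeze_for_inference()
y_infer = layer(x)
layer.unfreeze_for_training()
print(np.allclose(y_train, y_infer))
```

The cache must be cleared before any further training, since a stale cached weight would silently ignore updated scales.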

In [None]:
# ============================================================
# V2 FREEZE MODEL FOR FAST INFERENCE
# ============================================================
# Precompute full W_eff = Q * scales for all layers
# This caches the effective weight to avoid rank-by-rank computation per token

from qat_lora import freeze_model_for_inference_v2, unfreeze_model_for_training_v2

print('Freezing V2 model for inference...')
num_frozen = freeze_model_for_inference_v2(model, verbose=False)
print(f'Frozen {num_frozen} V2 layers')
print('Cached W_eff = Q * scales for fast inference')

# To resume training later:
# unfreeze_model_for_training_v2(model)

In [19]:
import torch

# ============================================================
# TEST INFERENCE
# ============================================================

def run_inference(model, tokenizer, prompt, max_new_tokens=128):
    messages = [
        {'role': 'system', 'content': 'You are a helpful assistant.'},
        {'role': 'user', 'content': prompt}
    ]
    text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(text, return_tensors='pt').to(DEVICE)

    with torch.no_grad():
        output = model.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False)

    return tokenizer.decode(output[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)

# List of prompts to test
prompts = [
    'What is the capital of France?',
    'What is Apple Neural Engine?',
    'Explain quantum mechanics',
    'What is speed of light'
]

model.eval() # Set model to evaluation mode once

for prompt in prompts:
    response = run_inference(model, tokenizer, prompt, max_new_tokens=1024)
    print(f'Prompt: {prompt}')
    print(f'Response: {response}')
    print('-' * 50) # Separator for readability


The following generation flags are not valid and may be ignored: ['temperature', 'top_p', 'top_k']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


Prompt: What is the capital of France?
Response: <think>
<think>
</think>

The capital of France is **Paris**.
--------------------------------------------------
Prompt: What is Apple Neural Engine?
Response: <think>
<think>
</think>

The **Apple Neural Engine** is a powerful computing platform developed by Apple Inc. It is designed to run on Apple devices and is used for various applications, including AI and machine learning. It is a key component of Apple's ecosystem and is known for its performance and efficiency in handling complex tasks.
--------------------------------------------------
Prompt: Explain quantum mechanics
Response: <think>
<think>
</think>

Quantum mechanics is a fundamental theory of physics that describes the behavior of particles at the smallest scales, such as atoms, molecules, and even subatomic particles. It is a revolutionary theory that challenges our classical understanding of the physical world, which is based on the principles of classical mechanics and

In [None]:
# ============================================================
# TEST INFERENCE (SAMPLING, THINKING ENABLED)
# ============================================================

def run_inference(model, tokenizer, prompt, max_new_tokens=512):
    messages = [
        {'role': 'user', 'content': prompt}
    ]
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
        enable_thinking=True
    )
    inputs = tokenizer(text, return_tensors='pt').to(DEVICE)

    with torch.no_grad():
        output = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=0.6,
            top_p=0.9,
            repetition_penalty=1.1,
        )

    return tokenizer.decode(output[0][inputs['input_ids'].shape[1]:], skip_special_tokens=False)

# List of prompts to test
prompts = [
    'What is the capital of France?',
    'What is Apple Neural Engine?',
    'Explain quantum mechanics',
    'What is speed of light'
]

model.eval()

for prompt in prompts:
    response = run_inference(model, tokenizer, prompt, max_new_tokens=512)
    print(f'Prompt: {prompt}')
    print(f'Response: {response}')
    print('-' * 50)


Prompt: What is the capital of France?
Response: <think>
<think>
</think>

The capital of France is **Paris**, located in the western part of the country.<|im_end|>
--------------------------------------------------
Prompt: What is Apple Neural Engine?
Response: <think>
<think>
</think>

The **Apple Neural Engine** is a key component of Apple's operating system, specifically the **iOS** operating system. It is developed by Apple and is part of the iOS ecosystem. The neural engine is responsible for handling various tasks such as image processing, speech recognition, and machine learning, which are essential for applications like apps, games, and other services. It plays a crucial role in enabling the development and performance of Apple products.<|im_end|>
--------------------------------------------------
Prompt: Explain quantum mechanics
Response: <think>
<think>
</think>

Quantum Mechanics is a fundamental field of physics that describes the behavior of particles at the smallest pos

## Next Steps

After the scale-only training above, you can:

1. **Further end-to-end refinement** - Run additional `train_e2e` passes over the MLP or attention scales
2. **Stage 2 weight tuning** - Set `ENABLE_E2E_WEIGHT_TUNING = True` if scale-only training is not sufficient
3. **LoRA recovery** - Add LoRA adapters to recover quality

# **V2 EXPORT FOR ANEMLL CONVERTER**

Export V2 model for external tools.

**V2 Export includes:**
- `_Q`: Frozen LUT values [out, in]
- `_indices`: Quantization indices [out, in]
- `scale_A`, `scale_B`, `rank_magnitude`: Scale parameters
- `lut`: LUT values

**For inference optimization:**
- Call `snap_for_export()` to bake normalization into parameters
- This eliminates runtime norm computation in the exported model
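
What "baking normalization" means can be sketched numerically. The NumPy example below is an assumption-laden illustration, not the `snap_for_export()` implementation: it assumes column norms of `scale_A` and row norms of `scale_B` are divided out, that `rank_magnitude` is folded into the snapped `scale_A`, and the function names are hypothetical.

```python
import numpy as np

def scales_with_runtime_norm(A, B, g, eps=1e-6):
    """Training-time path: normalize A and B on every call, then combine with g."""
    A_dir = A / (np.linalg.norm(A, axis=0, keepdims=True) + eps)   # unit columns
    B_dir = B / (np.linalg.norm(B, axis=1, keepdims=True) + eps)   # unit rows
    return (A_dir * g) @ B_dir

def snap_scales(A, B, g, eps=1e-6):
    """Export-time path: precompute normalized factors so no norms are needed later."""
    A_snapped = A / (np.linalg.norm(A, axis=0, keepdims=True) + eps) * g
    B_snapped = B / (np.linalg.norm(B, axis=1, keepdims=True) + eps)
    return A_snapped, B_snapped

rng = np.random.default_rng(0)
A = np.abs(rng.normal(size=(6, 4)))
B = np.abs(rng.normal(size=(4, 8)))
g = rng.uniform(0.5, 1.5, size=4)

A_snapped, B_snapped = snap_scales(A, B, g)
S_runtime = scales_with_runtime_norm(A, B, g)
S_snapped = A_snapped @ B_snapped        # plain matmul, no norm computation
print(np.allclose(S_runtime, S_snapped))
```

After snapping, the exported graph only needs the two factor matrices, which keeps norm and division ops out of the converted model.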

In [None]:
# ============================================================
# V2 EXPORT FOR ANEMLL CONVERTER
# ============================================================

from qat_lora import unfreeze_model_for_training_v2

# First unfreeze to clear cached weights
unfreeze_model_for_training_v2(model)

# Export V2 model representation
print('Exporting V2 quantized model representation...')

export_dict = {}
for name, module in model.named_modules():
    if type(module).__name__ == 'AnemllQATLinearV2':  # type-name check survives module reloads
        layer_export = {
            'indices': module._indices.cpu() if module._indices is not None else None,
            'Q': module._Q.cpu() if module._Q is not None else None,
            'scale_A': module.scale_A.data.cpu(),
            'scale_B': module.scale_B.data.cpu(),
            'rank_magnitude': module.rank_magnitude.data.cpu(),
            'lut': module.lut.cpu(),
            'bias': module.bias.data.cpu() if module.bias is not None else None,
            'in_features': module.in_features,
            'out_features': module.out_features,
            'scale_rank': module.scale_rank,
            'lut_bits': module.lut_bits,
        }
        export_dict[name] = layer_export

print(f'Exported {len(export_dict)} V2 layers')

# Save export for ANEMLL converter
EXPORT_DIR = f'{LOCAL_RUNS}/{E2E_RUN_NAME}_export'
os.makedirs(EXPORT_DIR, exist_ok=True)
torch.save(export_dict, f'{EXPORT_DIR}/v2_quantized_model.pt')
print(f'\nSaved V2 export to {EXPORT_DIR}/v2_quantized_model.pt')

In [None]:
# ============================================================
# V2 SNAP FOR EXPORT AND TEST
# ============================================================
# Bake normalization into parameters for CoreML-friendly export

print('Snapping V2 model for export (baking normalization)...')

for name, module in model.named_modules():
    if type(module).__name__ == 'AnemllQATLinearV2':  # type-name check survives module reloads
        module.snap_for_export()

print('V2 model snapped for export.')
print('Normalization baked into scale_A (no runtime norm computation)')

# Test inference with snapped model
model.eval()
print('\nTesting V2 inference with snapped weights...')
response = run_inference(model, tokenizer, 'What is 2+2?', max_new_tokens=256)
print(f'Prompt: What is 2+2?')
print(f'Response: {response}')

In [None]:
torch.save(model.state_dict(), '/tmp/backup_mlp_e2e_0.4613.pt')  # Local, fast

torch.save(model.state_dict(), '/tmp/backup_mlp_e2e_w_0.3824.pt')  # Local, fast

In [15]:
torch.save(model.state_dict(), '/tmp/backup_mlp_e4e_4_4.pt')  # Local, fast

In [None]:
model.load_state_dict(torch.load('/tmp/backup_initial.pt', map_location=DEVICE))

<All keys matched successfully>

In [None]:
model.load_state_dict(torch.load('/tmp/backup_mlp_e2e_w_0.3824.pt', map_location=DEVICE))

<All keys matched successfully>