# Q2_A4 V1 to V2 Conversion + STE-FP16 Finetuning

**Configuration:**
- MLP: 2-bit (LUT=4), rank=32
- Attention: 4-bit (LUT=16), rank=8

This notebook:
1. Loads a trained Q2_A4 V1 checkpoint (loss ~0.38)
2. Converts V1 scales to V2 format (unit-norm + rank_magnitude)
3. Finetunes the V2 model with **STE-FP16** (FP32 master weights + FP16 forward)
4. Exports an **FP16 checkpoint** (ANE-ready!)

## STE-FP16 Training:
- **FP32 master weights**: Stable gradients, no underflow
- **FP16 forward pass** (via STE): Matches ANE behavior exactly
- **Export to FP16**: Direct conversion for ANE deployment

In [None]:
# ============================================================
# GOOGLE DRIVE PATHS
# ============================================================
# Folders on the mounted Google Drive. Later cells read the KD cache
# archive from GD_CACHES and copy the exported run archive into GD_RUNS.

GD_RUNS = '/content/drive/MyDrive/qwen3_runs'
GD_CACHES = '/content/drive/MyDrive/qwen3_caches'

# Relative working directories: they resolve against whatever the current
# working directory is when each later cell runs.
LOCAL_RUNS = 'runs'
LOCAL_CACHES = 'caches'

In [None]:
# Mount Google Drive
# Mounts at /content/drive, the prefix used by the Drive paths in this notebook.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Clone/update repo
# First run clones; if the clone fails (e.g. the directory already exists)
# the fallback pulls inside the existing checkout instead.
!git clone https://github.com/anemll/qwen3_apple_style_2bit_qat_lora.git || (cd qwen3_apple_style_2bit_qat_lora && git pull)
%cd qwen3_apple_style_2bit_qat_lora
# NOTE: `git reset --hard HEAD` discards uncommitted edits in the checkout.
!git fetch && git pull && git reset --hard HEAD

import sys
# Drop every already-imported qat_lora module from the import cache so the
# next `import qat_lora` re-executes the freshly pulled sources.
[sys.modules.pop(k) for k in list(sys.modules) if k.startswith('qat_lora')]

# Install dependencies
!pip install -q transformers accelerate safetensors

In [None]:
# ============================================================
# EXTRACT CHECKPOINT + CONFIGURATION
# ============================================================

import torch
import os
import gc
import glob

MODEL_ID = 'Qwen/Qwen3-0.6B'

# Paths
DRIVE_RUNS = '/content/drive/MyDrive/qwen3_runs'

# Extract Q2 checkpoint from Drive (Colab storage is ephemeral!)
os.makedirs(LOCAL_RUNS, exist_ok=True)
archive_path = f'{DRIVE_RUNS}/q2_pt_good1.tgz'

if os.path.exists(archive_path):
    print(f'Extracting {archive_path}...')
    !tar -xzf {archive_path} -C {LOCAL_RUNS}/
    print('Extraction complete.')
else:
    print(f'WARNING: Archive not found at {archive_path}')
    print('Upload q2_pt_good1.tgz to Google Drive qwen3_runs folder')

# Find checkpoint
V1_CHECKPOINT = None
search_paths = [
    f'{LOCAL_RUNS}/tmp/backup_mlp_e2e_w_0.3824.pt',
    f'{LOCAL_RUNS}/q2_pt_good1/backup_mlp_e2e_w_0.3824.pt',
    f'{LOCAL_RUNS}/backup_mlp_e2e_w_0.3824.pt',
]

for p in search_paths:
    if os.path.exists(p):
        V1_CHECKPOINT = p
        break

# Fallback: glob search
if V1_CHECKPOINT is None:
    matches = glob.glob(f'{LOCAL_RUNS}/**/backup_mlp_e2e_w_*.pt', recursive=True)
    if matches:
        V1_CHECKPOINT = matches[0]

assert V1_CHECKPOINT and os.path.exists(V1_CHECKPOINT), \
    f'Checkpoint not found! Found files:\n' + \
    '\n'.join(glob.glob(f'{LOCAL_RUNS}/**/*.pt', recursive=True)[:10])

print(f'V1 checkpoint: {V1_CHECKPOINT}')

# ============================================================
# Q2_A4 Quantization Config (MUST MATCH V1 CHECKPOINT!)
# ============================================================
# These constants feed both the V1 and V2 quantization configs below.
# MLP: 2-bit (lut_size=4), scale_rank=32
LUT_BITS = 2
LUT_SIZE = 2**LUT_BITS  # 4
SCALE_RANK = 32

# Attention: 4-bit (lut_size=16), scale_rank=8
ATTN_LUT_BITS = 4
ATTN_LUT_SIZE = 2**ATTN_LUT_BITS  # 16
ATTN_SCALE_RANK = 8

# Training - SMALLER BATCH SIZE for Q2_A4 (rank=32 uses more memory with STE)
BATCH_SIZE = 4  # Reduced for high rank + STE memory usage
# Passed as `temperature` to evaluate_kd_loss and train_e2e.
DISTILL_TEMP = 2.0

# Fail fast: every later cell moves models to DEVICE.
if not torch.cuda.is_available():
    raise RuntimeError('Training requires CUDA!')
DEVICE = torch.device('cuda')

# STE-FP16 settings
V2_DTYPE = torch.float32  # FP32 master weights
USE_STE_FP16 = True  # forwarded to AnemllQuantConfigV2(use_ste_fp16=...)
USE_FP16 = False  # forwarded to train_e2e(use_fp16=...)

# Tag embedded in the exported run name.
QUAL = f'q{LUT_BITS}_a{ATTN_LUT_BITS}_ste_fp16'
print(f'\n=== Q2_A4 Configuration ===')
print(f'MLP:  {LUT_SIZE} LUT ({LUT_BITS}-bit), rank={SCALE_RANK}')
print(f'Attn: {ATTN_LUT_SIZE} LUT ({ATTN_LUT_BITS}-bit), rank={ATTN_SCALE_RANK}')
print(f'Batch size: {BATCH_SIZE} (reduced for high rank + STE)')
print(f'STE-FP16: {USE_STE_FP16}')

In [None]:
# ============================================================
# LOAD KD CACHE
# ============================================================

import os

CACHE_NAME = 'alpaca_chat_think_both_L128_K128_R1024'
CACHE_TGZ = f'{CACHE_NAME}.tgz'

!mkdir -p {LOCAL_CACHES}

cache_local_path = f'{LOCAL_CACHES}/{CACHE_NAME}'
if not os.path.exists(cache_local_path):
    print(f'Extracting {CACHE_TGZ} from Google Drive...')
    !tar -xzf {GD_CACHES}/{CACHE_TGZ} -C {LOCAL_CACHES}/
else:
    print(f'Cache already exists at {cache_local_path}')

!ls -la {cache_local_path}/ | head -5

# Step 1: Load V1 Model and Checkpoint

In [None]:
# ============================================================
# LOAD BASE MODEL FOR V1
# ============================================================
# Loads the tokenizer and the unquantized base model that the next cell
# turns into the V1 quantized model.

from transformers import AutoModelForCausalLM, AutoTokenizer

print(f'Loading {MODEL_ID}...')
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)

# Load the V1 model in BF16. The V1→V2 conversion cell later casts the
# copied tensors to V2_DTYPE (FP32 master weights), not to FP16.
v1_model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16,  # V1 in BF16
    trust_remote_code=True,
)
print(f'Base model loaded: {sum(p.numel() for p in v1_model.parameters()):,} params')

In [None]:
# ============================================================
# REPLACE WITH V1 LAYERS AND LOAD CHECKPOINT
# ============================================================
# Swaps the base model's linear layers for V1 quantized layers, then loads
# the trained V1 checkpoint into them.

import sys
sys.path.insert(0, '.')

# Reload so edits pulled into the repo are picked up without a restart.
import importlib
import qat_lora
importlib.reload(qat_lora)

from qat_lora import (
    AnemllQATLinear,
    AnemllQuantConfig,
    replace_linear_with_anemll,
    load_checkpoint,
)

# Create V1 configs
v1_mlp_config = AnemllQuantConfig(
    lut_size=LUT_SIZE,
    scale_rank=SCALE_RANK,
)
v1_attn_config = AnemllQuantConfig(
    lut_size=ATTN_LUT_SIZE,
    scale_rank=ATTN_SCALE_RANK,
)

# Replace with V1 layers
print('Replacing with V1 AnemllQATLinear...')
replace_linear_with_anemll(
    v1_model,
    mlp_config=v1_mlp_config,
    attn_config=v1_attn_config,
    quantize_attn=True,
    quantize_lm_head=False,
)

# Load V1 checkpoint
print(f'\nLoading V1 checkpoint from {V1_CHECKPOINT}...')
v1_state_dict = torch.load(V1_CHECKPOINT, map_location='cpu')
# strict=False never raises on key mismatches, so surface the report:
# a checkpoint that does not match the config would otherwise load
# "successfully" while leaving layers at their initial values.
load_report = v1_model.load_state_dict(v1_state_dict, strict=False)
del v1_state_dict  # release the CPU copy before moving the model to GPU
if load_report.missing_keys:
    print(f'WARNING: {len(load_report.missing_keys)} model keys missing from checkpoint, '
          f'e.g. {load_report.missing_keys[:5]}')
if load_report.unexpected_keys:
    print(f'WARNING: {len(load_report.unexpected_keys)} checkpoint keys not used by model, '
          f'e.g. {load_report.unexpected_keys[:5]}')
v1_model.to(DEVICE)
print('V1 checkpoint loaded!')

In [None]:
# ============================================================
# EVALUATE V1 MODEL
# ============================================================
# Baseline KD loss of the loaded V1 model; `v1_loss` is compared against
# the converted and finetuned V2 losses in later cells.

from qat_lora import evaluate_kd_loss

v1_model.eval()
v1_loss = evaluate_kd_loss(v1_model, cache_local_path, DEVICE, num_samples=50, temperature=DISTILL_TEMP)
print(f'V1 KD Loss: {v1_loss:.4f}')

# Step 2: Create V2 Model and Convert

In [None]:
# ============================================================
# CREATE V2 MODEL (FP32 master weights + STE-FP16 forward)
# ============================================================
# Builds an empty V2 quantized model whose layers the conversion cell
# fills from the V1 model.

from qat_lora import (
    AnemllQATLinearV2,
    AnemllQuantConfigV2,
    replace_linear_with_anemll_v2,
    freeze_Q_all,
)

# Load fresh base model for V2 in FP32 (master weights)
print(f'Loading fresh base model for V2 in {V2_DTYPE}...')
v2_model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=V2_DTYPE,  # FP32 master weights!
    trust_remote_code=True,
)

# Create V2 configs with STE-FP16 enabled
# - force_positive_scales=False: V1 scales can be negative
# - magnitude_activation='identity': Store raw magnitude
# - use_ste_fp16=True: Forward simulates FP16 via STE
v2_mlp_config = AnemllQuantConfigV2(
    lut_size=LUT_SIZE,
    scale_rank=SCALE_RANK,
    force_positive_scales=False,  # V1 has negative scales!
    magnitude_activation='identity',  # No softplus - store raw magnitude
    use_ste_fp16=USE_STE_FP16,  # Enable STE-FP16 forward!
)
v2_attn_config = AnemllQuantConfigV2(
    lut_size=ATTN_LUT_SIZE,
    scale_rank=ATTN_SCALE_RANK,
    force_positive_scales=False,  # V1 has negative scales!
    magnitude_activation='identity',  # No softplus - store raw magnitude
    use_ste_fp16=USE_STE_FP16,  # Enable STE-FP16 forward!
)

# Replace with V2 layers
print('Replacing with V2 AnemllQATLinearV2...')
replace_linear_with_anemll_v2(
    v2_model,
    mlp_config=v2_mlp_config,
    attn_config=v2_attn_config,
    # Stated explicitly, matching the V1 replacement and the
    # checkpoint-loading cell, rather than relying on the function default:
    # the conversion only copies layers that exist on both models.
    quantize_attn=True,
)

v2_model.to(DEVICE)
print(f'V2 model created in {V2_DTYPE}!')
print(f'STE-FP16 enabled: {USE_STE_FP16}')
print('Note: force_positive_scales=False, magnitude_activation=identity for V1 compatibility')

In [None]:
# ============================================================
# CONVERT V1 → V2 (FP32 master weights)
# ============================================================

def convert_v1_layer_to_v2(v1_layer, v2_layer, target_dtype=torch.float32):
    """Copy one V1 layer into a V2 layer, refactoring its low-rank scales.

    V1: scales = A @ B (arbitrary, can be negative)
    V2: scales = (g * A_dir) @ B_dir

    Each rank component is split into unit-norm directions (signs kept in
    A_dir/B_dir) and a positive magnitude ``g = ||A_col|| * ||B_row||``, so
    ``(g * A_dir) @ B_dir`` reproduces ``A @ B``. The magnitude is stored
    raw, which matches ``magnitude_activation='identity'``.

    Args:
        v1_layer: Source layer exposing ``weight``, ``bias`` (may be None),
            ``lut``, ``scale_A`` [out, rank], ``scale_B`` [rank, >=in] and
            ``in_features``.
        v2_layer: Destination layer exposing ``weight``, ``bias``, ``lut``,
            ``scale_A``, ``scale_B`` and ``rank_magnitude``; modified in place.
        target_dtype: dtype of every tensor written into ``v2_layer``.
    """
    # Factor the scales in at least FP32: normalizing in a low-precision
    # source dtype (e.g. BF16) and upcasting afterwards would bake rounding
    # error into the "unit" directions.
    compute_dtype = torch.promote_types(target_dtype, torch.float32)

    with torch.no_grad():
        # Copy base parameters.
        # NOTE: `.to` returns the same tensor when the dtype already matches,
        # in which case the V2 tensor shares storage with the V1 tensor.
        v2_layer.weight.data = v1_layer.weight.data.to(target_dtype)
        if v1_layer.bias is not None and v2_layer.bias is not None:
            v2_layer.bias.data = v1_layer.bias.data.to(target_dtype)
        v2_layer.lut.data = v1_layer.lut.data.to(target_dtype)

        # V1 scales; scale_B may be wider than in_features, so trim it.
        A = v1_layer.scale_A.to(compute_dtype)  # [out, rank]
        B = v1_layer.scale_B[:, :v1_layer.in_features].to(compute_dtype)  # [rank, in]

        # Per-rank norms; the clamp avoids division by zero for dead ranks.
        A_norms = A.norm(dim=0, keepdim=True).clamp(min=1e-6)  # [1, rank]
        B_norms = B.norm(dim=1, keepdim=True).clamp(min=1e-6)  # [rank, 1]

        # reshape(-1) keeps a [rank] vector even when rank == 1, where a
        # bare squeeze() would collapse the result to a 0-d scalar.
        magnitude = A_norms.reshape(-1) * B_norms.reshape(-1)  # [rank]

        # Store directly - no inverse_softplus needed with identity activation
        v2_layer.scale_A.data = (A / A_norms).to(target_dtype)
        v2_layer.scale_B.data = (B / B_norms).to(target_dtype)
        v2_layer.rank_magnitude.data = magnitude.to(target_dtype)


# Convert every V1 quantized layer into the V2 layer of the same name,
# then release the V1 model.
print(f'Converting V1 → V2 in {V2_DTYPE}...')
converted = 0

# Collect V1 and V2 layers (matched by class name, keyed by module path)
v1_layers = {name: m for name, m in v1_model.named_modules()
             if type(m).__name__ == 'AnemllQATLinear'}
v2_layers = {name: m for name, m in v2_model.named_modules()
             if type(m).__name__ == 'AnemllQATLinearV2'}

print(f'Found {len(v1_layers)} V1 layers, {len(v2_layers)} V2 layers')

# Convert each layer
for name in v1_layers:
    if name in v2_layers:
        convert_v1_layer_to_v2(v1_layers[name], v2_layers[name], target_dtype=V2_DTYPE)
        converted += 1
        if converted <= 3:
            v2 = v2_layers[name]
            print(f'  {name}: mag=[{v2.rank_magnitude.min():.3f}, {v2.rank_magnitude.max():.3f}]')

print(f'\nConverted {converted} layers to V2 {V2_DTYPE}')

# A layer present on only one side keeps its unconverted values, which
# would silently degrade the V2 model, so report any mismatch.
v1_only = sorted(v1_layers.keys() - v2_layers.keys())
v2_only = sorted(v2_layers.keys() - v1_layers.keys())
if v1_only:
    print(f'WARNING: {len(v1_only)} V1 layers have no V2 counterpart, e.g. {v1_only[:3]}')
if v2_only:
    print(f'WARNING: {len(v2_only)} V2 layers were not converted, e.g. {v2_only[:3]}')

# Free V1 model memory. The v1_layers dict holds a reference to every V1
# quantized layer, so it must be dropped together with the model or those
# layers (and their GPU tensors) stay alive.
del v1_layers
del v1_model
import gc
gc.collect()
torch.cuda.empty_cache()
print('V1 model freed from memory')

In [None]:
# ============================================================
# FREEZE Q (compute quantization indices)
# ============================================================

# Master weights stay in FP32; only the forward pass simulates FP16 via
# STE, so no dtype conversion happens in this cell.

print('Freezing Q (computing indices)...')
freeze_Q_all(v2_model, verbose=False)
print('Q frozen for all V2 layers.')

# Spot-check dtypes on the first V2 layer only (FP32 master weights expected).
first_v2_entry = next(
    ((layer_name, layer) for layer_name, layer in v2_model.named_modules()
     if type(layer).__name__ == 'AnemllQATLinearV2'),
    None,
)
if first_v2_entry is not None:
    checked_name, checked_layer = first_v2_entry
    q_dtype = checked_layer._Q.dtype if checked_layer._Q is not None else None
    print(f'\nVerified {checked_name}:')
    print(f'  weight.dtype: {checked_layer.weight.dtype} (master weights)')
    print(f'  lut.dtype: {checked_layer.lut.dtype}')
    print(f'  _Q.dtype: {q_dtype}')
    print(f'  STE-FP16: {checked_layer.config.use_ste_fp16} (forward simulates FP16)')

In [None]:
# ============================================================
# EVALUATE CONVERTED V2 MODEL
# ============================================================
# Measures how much KD loss moved across the V1 → V2 conversion.

v2_model.eval()
v2_converted_loss = evaluate_kd_loss(
    v2_model,
    cache_local_path,
    DEVICE,
    num_samples=50,
    temperature=DISTILL_TEMP,
)
conversion_gap = abs(v2_converted_loss - v1_loss)

print(f'\n=== Conversion Results ===')
print(f'V1 Loss: {v1_loss:.4f}')
print(f'V2 Loss (after conversion): {v2_converted_loss:.4f}')
print(f'Difference: {conversion_gap:.4f}')

# A gap under 0.1 is treated as a faithful conversion.
conversion_verdict = (
    'Conversion successful - losses are close!'
    if conversion_gap < 0.1
    else 'Note: Some difference expected due to different forward implementations'
)
print(conversion_verdict)

# Step 3: Finetune V2 Model with STE-FP16

Training with:
- **FP32 master weights**: Stable gradients, no underflow
- **STE-FP16 forward**: Each operation rounded to FP16 precision (matches ANE)
- **No autocast**: STE handles the FP16 simulation internally

In [None]:
# ============================================================
# MEMORY CLEANUP
# ============================================================
# Releases cached GPU blocks and unreachable Python objects, then reports
# GPU memory use as a baseline before the training passes.

import gc
import torch

# Clear CUDA cache and collect garbage before training
torch.cuda.empty_cache()
gc.collect()

# Values are bytes converted to decimal gigabytes (divide by 1e9).
print(f'GPU Memory: {torch.cuda.memory_allocated()/1e9:.2f}GB allocated')
print(f'GPU Memory: {torch.cuda.memory_reserved()/1e9:.2f}GB reserved')
!nvidia-smi --query-gpu=memory.used,memory.free --format=csv

In [None]:
# ============================================================
# PASS 1: ALL SCALES (MLP + ATTENTION) - 1000 steps
# ============================================================
# Trains the scale parameters of every V2 layer while the base weights
# stay frozen.

from qat_lora import train_e2e, unfreeze_model_for_training_v2

# Start training from a clean allocator state
torch.cuda.empty_cache()
gc.collect()

# Put the model back into a trainable state
unfreeze_model_for_training_v2(v2_model)

# Scales (A, B, magnitude) trainable on every V2 layer; weights frozen
for layer_name, layer in v2_model.named_modules():
    if type(layer).__name__ != 'AnemllQATLinearV2':
        continue
    if getattr(layer, 'scale_A', None) is not None:
        layer.scale_A.requires_grad = True
        layer.scale_B.requires_grad = True
        layer.rank_magnitude.requires_grad = True
    layer.weight.requires_grad = False

trainable = sum(p.numel() for p in v2_model.parameters() if p.requires_grad)
print(f'Trainable params: {trainable:,}')

# Pass 1 settings: 1000 steps, cosine schedule with 100 warmup steps
pass1_settings = dict(
    model=v2_model,
    cache_dir=cache_local_path,
    device=DEVICE,
    max_steps=1000,
    batch_size=BATCH_SIZE,
    lr=1e-4,
    use_cosine_schedule=True,
    warmup_steps=100,
    min_lr_ratio=0.1,
    temperature=DISTILL_TEMP,
    train_weights=False,
    train_scales=True,
    hard_top1_weight=0.0,
    hard_full_weight=0.0,
    logging_steps=20,
    eval_steps=100,
    verbose=True,
    train_mlp_only=False,  # MLP and attention scales together
    use_fp16=USE_FP16,
)

print('\n=== PASS 1: All Scales (MLP + Attention) ===')
e2e_result = train_e2e(**pass1_settings)
pass1_final_loss = e2e_result["final_loss"]
print(f'\nPass 1 complete! Final loss: {pass1_final_loss:.4f}')

In [None]:
# ============================================================
# PASS 2: MLP-ONLY REFINEMENT - 1000 steps
# ============================================================
# Second pass at a lower learning rate, training only the scales of
# layers whose module path contains '.mlp.'.

from qat_lora import train_e2e, unfreeze_model_for_training_v2

# Start training from a clean allocator state
torch.cuda.empty_cache()
gc.collect()

# Put the model back into a trainable state
unfreeze_model_for_training_v2(v2_model)

# MLP scales trainable, attention scales frozen, all weights frozen
for layer_name, layer in v2_model.named_modules():
    if type(layer).__name__ != 'AnemllQATLinearV2':
        continue
    in_mlp = '.mlp.' in layer_name
    if getattr(layer, 'scale_A', None) is not None:
        layer.scale_A.requires_grad = in_mlp
        layer.scale_B.requires_grad = in_mlp
        layer.rank_magnitude.requires_grad = in_mlp
    layer.weight.requires_grad = False

mlp_trainable = sum(p.numel() for p in v2_model.parameters() if p.requires_grad)
print(f'MLP-only trainable params: {mlp_trainable:,}')

# Pass 2 settings: 1000 steps, half the pass-1 LR, 50 warmup steps
pass2_settings = dict(
    model=v2_model,
    cache_dir=cache_local_path,
    device=DEVICE,
    max_steps=1000,
    batch_size=BATCH_SIZE,
    lr=5e-5,
    use_cosine_schedule=True,
    warmup_steps=50,
    min_lr_ratio=0.1,
    temperature=DISTILL_TEMP,
    train_weights=False,
    train_scales=True,
    hard_top1_weight=0.0,
    hard_full_weight=0.0,
    logging_steps=20,
    eval_steps=100,
    verbose=True,
    train_mlp_only=True,  # restrict training to MLP layers
    use_fp16=USE_FP16,
)

print('\n=== PASS 2: MLP-Only Refinement ===')
mlp_result = train_e2e(**pass2_settings)

pass2_final_loss = mlp_result["final_loss"]
print(f'\nPass 2 complete! Final loss: {pass2_final_loss:.4f}')

In [None]:
# ============================================================
# EXPORT TO FP16 (ANE-READY)
# ============================================================
# Casts the trained model to FP16, saves its state dict plus a JSON config
# into the run directory, and drops a convenience copy in /tmp.
# NOTE: `v2_model.half()` converts the in-memory model in place, so every
# later cell that uses `v2_model` sees the FP16 model.

import os
import json
import shutil

V2_RUN_NAME = f'anemll_v2_{QUAL}_from_v1'
V2_SAVE_DIR = f'{LOCAL_RUNS}/{V2_RUN_NAME}'
os.makedirs(V2_SAVE_DIR, exist_ok=True)

# Convert to FP16 for ANE export
# Training was in FP32 (master weights), now we snap to FP16
print('Converting model to FP16 for ANE export...')
v2_model.half()  # Convert all params to FP16

# Disable STE (not needed in FP16)
for name, m in v2_model.named_modules():
    if type(m).__name__ == 'AnemllQATLinearV2':
        m.config.use_ste_fp16 = False

# Save checkpoint
v2_checkpoint_file = f'{V2_SAVE_DIR}/model_state_dict.pt'
torch.save(v2_model.state_dict(), v2_checkpoint_file)

# Results of cells that may have been skipped are looked up by name, so
# the export still works when only one (or no) training pass was run.
notebook_vars = globals()

# Get final loss from whichever pass was run (pass 2 wins over pass 1)
if 'mlp_result' in notebook_vars:
    final_loss_value = mlp_result['final_loss']
elif 'e2e_result' in notebook_vars:
    final_loss_value = e2e_result.get('final_loss', 0.0)
else:
    final_loss_value = 0.0
# float() keeps the config JSON-serializable even if a loss arrives as a
# 0-d tensor or numpy scalar rather than a Python float.
final_loss_value = float(final_loss_value)

# Save config with ALL Q2_A4 parameters
config = {
    'model_id': MODEL_ID,
    'version': 'v2',
    'precision': 'fp16',
    'training': 'ste_fp16',
    # MLP config
    'mlp_lut_bits': LUT_BITS,
    'mlp_lut_size': LUT_SIZE,
    'mlp_scale_rank': SCALE_RANK,
    # Attention config
    'attn_lut_bits': ATTN_LUT_BITS,
    'attn_lut_size': ATTN_LUT_SIZE,
    'attn_scale_rank': ATTN_SCALE_RANK,
    # Losses
    'v1_loss': float(notebook_vars.get('v1_loss', 0.0)),
    'v2_converted_loss': float(notebook_vars.get('v2_converted_loss', 0.0)),
    'final_loss': final_loss_value,
}
with open(f'{V2_SAVE_DIR}/config.json', 'w') as f:
    json.dump(config, f, indent=2)

# Also save to tmp for quick access. Copying the file just written avoids
# serializing the whole state dict a second time.
shutil.copyfile(v2_checkpoint_file, '/tmp/v2_q2a4_ste_fp16.pt')

print(f'Saved FP16 V2 checkpoint to {V2_SAVE_DIR}')
print(f'Also saved to /tmp/v2_q2a4_ste_fp16.pt')
print(f'Final loss: {final_loss_value:.4f}')
print(f'ANE-ready - trained with STE-FP16, exported to FP16!')
print(f'\nConfig: MLP={LUT_BITS}b/r{SCALE_RANK}, Attn={ATTN_LUT_BITS}b/r{ATTN_SCALE_RANK}')

In [None]:
# ============================================================
# FINAL EVALUATION
# ============================================================
# Re-measures KD loss on the current `v2_model` and prints it next to the
# V1 baseline and the post-conversion loss.

v2_model.eval()
final_loss = evaluate_kd_loss(
    v2_model,
    cache_local_path,
    DEVICE,
    num_samples=50,
    temperature=DISTILL_TEMP,
)
total_improvement = v1_loss - final_loss

print(f'\n=== STE-FP16 Training Results ===')
print(f'V1 Original:           {v1_loss:.4f}')
print(f'V2 After Convert:      {v2_converted_loss:.4f}')
print(f'V2 After STE-FP16 Train: {final_loss:.4f}')
print(f'Total Improvement:     {total_improvement:.4f}')
print(f'\nModel exported to FP16 - ANE-ready!')
print(f'Training used STE-FP16: FP32 master weights + FP16 forward simulation')

In [None]:
# ============================================================
# UPLOAD TO GOOGLE DRIVE
# ============================================================
# Packs the exported run directory into a .tgz (paths relative to
# LOCAL_RUNS) and copies it to the Drive runs folder.
# NOTE: the final message is printed unconditionally; check the shell
# output above it to confirm that tar and cp actually succeeded.

!tar -czvf {V2_RUN_NAME}.tgz -C {LOCAL_RUNS} {V2_RUN_NAME}
!cp {V2_RUN_NAME}.tgz {GD_RUNS}/
print(f'\nUploaded to {GD_RUNS}/{V2_RUN_NAME}.tgz')

# Load V2 Checkpoint (for inference from saved model)

This section shows how to load a saved V2 checkpoint, including its `_Q` and `_indices` buffers.
Use `load_v2_checkpoint()` rather than `model.load_state_dict()` so that these buffers are loaded correctly.

In [None]:
# ============================================================
# LOAD V2 CHECKPOINT (for inference from saved model)
# ============================================================
# Rebuilds a V2 model from the base weights and restores a saved V2
# checkpoint into it via load_v2_checkpoint(), which takes care of the
# _Q and _indices buffers.

from qat_lora import (
    AnemllQuantConfigV2,
    replace_linear_with_anemll_v2,
    load_v2_checkpoint,
    freeze_model_for_inference_v2,
)

# Checkpoint written by the export cell; point this elsewhere to load a
# different run, e.g.
# V2_CHECKPOINT = '/path/to/anemll_v2_q2_a4_ste_fp16_from_v1/model_state_dict.pt'
V2_CHECKPOINT = f'{LOCAL_RUNS}/{V2_RUN_NAME}/model_state_dict.pt'

# Fresh FP16 base model to host the V2 layers
print(f'Loading base model for inference...')
inference_model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.float16,
    trust_remote_code=True,
)

# Options common to both configs; they mirror the training configs (Q2_A4)
# except that STE is switched off for inference.
inference_quant_options = dict(
    force_positive_scales=False,
    magnitude_activation='identity',
    use_ste_fp16=False,
)
v2_mlp_config = AnemllQuantConfigV2(
    lut_size=LUT_SIZE,  # 4 (2-bit)
    scale_rank=SCALE_RANK,  # 32
    **inference_quant_options,
)
v2_attn_config = AnemllQuantConfigV2(
    lut_size=ATTN_LUT_SIZE,  # 16 (4-bit)
    scale_rank=ATTN_SCALE_RANK,  # 8
    **inference_quant_options,
)

# Swap in the V2 layers
print('Replacing with V2 layers...')
print(f'  MLP: {LUT_SIZE} LUT, rank={SCALE_RANK}')
print(f'  Attn: {ATTN_LUT_SIZE} LUT, rank={ATTN_SCALE_RANK}')
replace_linear_with_anemll_v2(
    inference_model,
    mlp_config=v2_mlp_config,
    attn_config=v2_attn_config,
    quantize_attn=True,
)

# Restore the checkpoint, including the _Q and _indices buffers
print(f'\nLoading V2 checkpoint from {V2_CHECKPOINT}...')
stats = load_v2_checkpoint(
    inference_model,
    V2_CHECKPOINT,
    device=DEVICE,
    verbose=True,
)

# Expected to be a no-op when _Q was already restored from the checkpoint
freeze_model_for_inference_v2(inference_model, verbose=False)
print('\nModel ready for inference!')

In [None]:
# ============================================================
# FREEZE FOR INFERENCE
# ============================================================
# Freezes `v2_model` (the in-memory trained model), not the
# `inference_model` rebuilt from the saved checkpoint.

from qat_lora import freeze_model_for_inference_v2

print('Freezing V2 model for inference...')
freeze_model_for_inference_v2(v2_model, verbose=False)
print('Ready for inference!')

In [None]:
# ============================================================
# TEST INFERENCE
# ============================================================

def run_inference(model, tokenizer, prompt, max_new_tokens=256):
    """Generate a greedy chat completion for a single user prompt.

    The prompt is wrapped in a fixed system + user conversation, rendered
    with the tokenizer's chat template, and moved to the global DEVICE.
    Only the newly generated tokens (everything after the prompt) are
    decoded and returned, with special tokens stripped.
    """
    conversation = [
        {'role': 'system', 'content': 'You are a helpful assistant.'},
        {'role': 'user', 'content': prompt},
    ]
    rendered = tokenizer.apply_chat_template(
        conversation,
        tokenize=False,
        add_generation_prompt=True,
    )
    encoded = tokenizer(rendered, return_tensors='pt').to(DEVICE)
    prompt_length = encoded['input_ids'].shape[1]

    # do_sample=False selects greedy decoding
    with torch.no_grad():
        generated = model.generate(**encoded, max_new_tokens=max_new_tokens, do_sample=False)

    completion_ids = generated[0][prompt_length:]
    return tokenizer.decode(completion_ids, skip_special_tokens=True)

# Smoke-test prompts, each run through run_inference() below.
prompts = [
    'What is the capital of France?',
    'Explain quantum mechanics briefly.',
    'What is 2+2?',
    'What is Apple Neural Engine',
    'What is History of Alibaba Group',
]

# Uses the in-memory `v2_model`, not `inference_model`.
v2_model.eval()
for prompt in prompts:
    response = run_inference(v2_model, tokenizer, prompt)
    print(f'Prompt: {prompt}')
    print(f'Response: {response}')
    print('-' * 50)