# V1 to V2 Conversion + FP16 Finetuning

This notebook:
1. Loads a trained V1 checkpoint
2. Converts V1 scales to V2 format (unit-norm + rank_magnitude)
3. Converts to **FP16** (LUT, scales, weights)
4. Finetunes the V2 model in **FP16** with scale-only training
5. Saves **ANE-ready** checkpoint (no snap needed!)

## V1 → V2 Conversion:
- V1: `scales = scale_A @ scale_B` (arbitrary magnitudes)
- V2: `scales = (g * A_dir) @ B_dir`, where `A_dir` has unit-norm columns, `B_dir` has unit-norm rows, and `g = softplus(rank_magnitude)` is the per-rank magnitude

## FP16 Benefits:
- **No precision mismatch**: quantization indices are computed in FP16, so they match what the ANE computes
- **ANE-ready**: Direct export without conversion

In [None]:
# ============================================================
# GOOGLE DRIVE PATHS
# ============================================================

# Working directories used locally (relative paths).
LOCAL_RUNS, LOCAL_CACHES = 'runs', 'caches'

# Matching folders under the Google Drive mount point.
GD_RUNS, GD_CACHES = (
    '/content/drive/MyDrive/qwen3_runs',
    '/content/drive/MyDrive/qwen3_caches',
)

In [None]:
# Mount Google Drive
# The mount point /content/drive is the prefix used by GD_RUNS and GD_CACHES,
# so this cell must run before any cell that reads from or writes to Drive.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Clone/update repo
!git clone https://github.com/anemll/qwen3_apple_style_2bit_qat_lora.git || (cd qwen3_apple_style_2bit_qat_lora && git pull)
%cd qwen3_apple_style_2bit_qat_lora
!git fetch && git pull && git reset --hard HEAD

import sys
[sys.modules.pop(k) for k in list(sys.modules) if k.startswith('qat_lora')]

# Install dependencies
!pip install -q transformers accelerate safetensors

In [None]:
# ============================================================
# CONFIGURATION
# ============================================================

import torch
import os

MODEL_ID = 'Qwen/Qwen3-0.6B'

# V1 checkpoint archive name (the .tgz file in GD_RUNS)
V1_ARCHIVE = 'anemll_q4_a4_e2e_v2_scales_only.tgz'
V1_FOLDER = 'anemll_q4_a4_e2e_v2_scales_only'

# Extract V1 checkpoint from .tgz
os.makedirs(LOCAL_RUNS, exist_ok=True)
v1_extract_path = f'{LOCAL_RUNS}/{V1_FOLDER}'
if not os.path.exists(v1_extract_path):
    print(f'Extracting {V1_ARCHIVE} from Google Drive...')
    !tar -xzf {GD_RUNS}/{V1_ARCHIVE} -C {LOCAL_RUNS}/
else:
    print(f'V1 checkpoint already extracted at {v1_extract_path}')

# Path to the actual checkpoint file
V1_CHECKPOINT = f'{v1_extract_path}/model_state_dict.pt'
print(f'V1 checkpoint: {V1_CHECKPOINT}')

# Verify it exists
assert os.path.exists(V1_CHECKPOINT), f'V1 checkpoint not found at {V1_CHECKPOINT}'

# Quantization config (must match V1 checkpoint)
LUT_BITS = 4
LUT_SIZE = 2**LUT_BITS
SCALE_RANK = 4

ATTN_LUT_BITS = 4
ATTN_LUT_SIZE = 2**ATTN_LUT_BITS
ATTN_SCALE_RANK = 4

# Training
BATCH_SIZE = 32 if torch.cuda.is_available() else 4
DISTILL_TEMP = 2.0

# FP16 Training - requires CUDA
if not torch.cuda.is_available():
    raise RuntimeError('FP16 training requires CUDA!')
DEVICE = torch.device('cuda')

# FP16 settings
USE_FP16 = True
V2_DTYPE = torch.float16  # V2 model in FP16 for ANE

QUAL = f'q{LUT_BITS}_a{ATTN_LUT_BITS}_fp16'
print(f'Quality: {QUAL}')
print(f'Device: {DEVICE}')
print(f'V2 dtype: {V2_DTYPE} (FP16 training enabled)')

In [None]:
# ============================================================
# LOAD KD CACHE
# ============================================================

import os

CACHE_NAME = 'alpaca_chat_think_both_L128_K128_R1024'
CACHE_TGZ = f'{CACHE_NAME}.tgz'

!mkdir -p {LOCAL_CACHES}

cache_local_path = f'{LOCAL_CACHES}/{CACHE_NAME}'
if not os.path.exists(cache_local_path):
    print(f'Extracting {CACHE_TGZ} from Google Drive...')
    !tar -xzf {GD_CACHES}/{CACHE_TGZ} -C {LOCAL_CACHES}/
else:
    print(f'Cache already exists at {cache_local_path}')

!ls -la {cache_local_path}/ | head -5

# Step 1: Load V1 Model and Checkpoint

In [None]:
# ============================================================
# LOAD BASE MODEL FOR V1
# ============================================================

from transformers import AutoModelForCausalLM, AutoTokenizer

print(f'Loading {MODEL_ID}...')
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)

# The V1 model is held in BF16; the cast to FP16 happens later, during
# the V1→V2 conversion.
v1_load_kwargs = dict(torch_dtype=torch.bfloat16, trust_remote_code=True)
v1_model = AutoModelForCausalLM.from_pretrained(MODEL_ID, **v1_load_kwargs)

v1_param_count = sum(p.numel() for p in v1_model.parameters())
print(f'Base model loaded: {v1_param_count:,} params')

In [None]:
# ============================================================
# REPLACE WITH V1 LAYERS AND LOAD CHECKPOINT
# ============================================================

import sys
sys.path.insert(0, '.')

import importlib
import qat_lora
importlib.reload(qat_lora)

from qat_lora import (
    AnemllQATLinear,
    AnemllQuantConfig,
    replace_linear_with_anemll,
    load_checkpoint,
)

# Create V1 configs
v1_mlp_config = AnemllQuantConfig(
    lut_size=LUT_SIZE,
    scale_rank=SCALE_RANK,
)
v1_attn_config = AnemllQuantConfig(
    lut_size=ATTN_LUT_SIZE,
    scale_rank=ATTN_SCALE_RANK,
)

# Replace with V1 layers
print('Replacing with V1 AnemllQATLinear...')
replace_linear_with_anemll(
    v1_model,
    mlp_config=v1_mlp_config,
    attn_config=v1_attn_config,
    quantize_attn=True,
    quantize_lm_head=False,
)

# Load V1 checkpoint
print(f'\nLoading V1 checkpoint from {V1_CHECKPOINT}...')
# weights_only=True limits unpickling to tensors and plain containers,
# which is all a state dict should contain.
v1_state_dict = torch.load(V1_CHECKPOINT, map_location='cpu', weights_only=True)
load_result = v1_model.load_state_dict(v1_state_dict, strict=False)
del v1_state_dict  # values were copied into the model; release the CPU copy

# strict=False tolerates key mismatches, so report them instead of letting
# a checkpoint that did not match the model pass unnoticed.
if load_result.missing_keys:
    print(f'WARNING: {len(load_result.missing_keys)} model keys missing from checkpoint, '
          f'e.g. {load_result.missing_keys[:3]}')
if load_result.unexpected_keys:
    print(f'WARNING: {len(load_result.unexpected_keys)} checkpoint keys not used by model, '
          f'e.g. {load_result.unexpected_keys[:3]}')

v1_model.to(DEVICE)
print('V1 checkpoint loaded!')

In [None]:
# ============================================================
# EVALUATE V1 MODEL
# ============================================================

from qat_lora import evaluate_kd_loss

# Baseline KD loss of the V1 model, measured in eval mode; it is compared
# against the converted and finetuned V2 losses in later cells.
v1_model.eval()
v1_loss = evaluate_kd_loss(
    v1_model,
    cache_local_path,
    DEVICE,
    num_samples=50,
    temperature=DISTILL_TEMP,
)
print(f'V1 KD Loss: {v1_loss:.4f}')

# Step 2: Create V2 Model and Convert

In [None]:
# ============================================================
# CREATE V2 MODEL (FP16)
# ============================================================

from qat_lora import (
    AnemllQATLinearV2,
    AnemllQuantConfigV2,
    replace_linear_with_anemll_v2,
    freeze_Q_all,
    convert_model_to_fp16_v2,  # FP16 conversion
)

# A second, untouched copy of the base model becomes the V2 host, in FP16.
print('Loading fresh base model for V2 in FP16...')
v2_model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=V2_DTYPE,
    trust_remote_code=True,
)

# Options common to the MLP and attention configs; only the LUT size and
# scale rank are specified per group.
shared_v2_options = dict(
    force_positive_scales=True,
    positive_scale_method='abs',
    magnitude_activation='softplus',
    magnitude_eps=1e-6,
)
v2_mlp_config = AnemllQuantConfigV2(
    lut_size=LUT_SIZE,
    scale_rank=SCALE_RANK,
    **shared_v2_options,
)
v2_attn_config = AnemllQuantConfigV2(
    lut_size=ATTN_LUT_SIZE,
    scale_rank=ATTN_SCALE_RANK,
    **shared_v2_options,
)

# Swap the linear layers for their V2 quantized counterparts.
# NOTE(review): unlike the V1 call, no quantize_attn / quantize_lm_head
# arguments are passed here — confirm the V2 defaults select the same layers.
print('Replacing with V2 AnemllQATLinearV2...')
replace_linear_with_anemll_v2(
    v2_model,
    mlp_config=v2_mlp_config,
    attn_config=v2_attn_config,
)

v2_model.to(DEVICE)
print(f'V2 model created in {V2_DTYPE}!')

In [None]:
# ============================================================
# CONVERT V1 → V2 (FP16)
# ============================================================

import torch.nn.functional as F

def inverse_softplus(y, beta=1.0, threshold=20.0):
    """Invert softplus: return x such that softplus(x, beta) is y.

    Computes x = log(exp(beta * y) - 1) / beta element-wise. Where
    beta * y exceeds ``threshold`` the input y is returned unchanged.

    Args:
        y: tensor of softplus outputs to invert.
        beta: softplus sharpness parameter.
        threshold: value of beta * y above which y is passed through.

    Returns:
        Tensor of the same shape as ``y``.
    """
    scaled = y * beta
    # expm1 is floored at 1e-8 before the log so the result stays finite
    # for very small inputs.
    exact = torch.log(torch.clamp(torch.expm1(scaled), min=1e-8)) / beta
    return torch.where(scaled > threshold, y, exact)

def convert_v1_layer_to_v2(v1_layer, v2_layer):
    """Convert V1 layer parameters to V2 format in FP16.

    V1: scales = A @ B (arbitrary magnitudes)
    V2: scales = (g * A_dir) @ B_dir (unit-norm + magnitude)

    V2 applies softplus to rank_magnitude, so the value stored is
    inverse_softplus(magnitude), not the magnitude itself.

    Args:
        v1_layer: source layer; its weight, bias, lut, scale_A, scale_B and
            in_features attributes are read.
        v2_layer: destination layer; its weight, bias, lut, scale_A, scale_B
            and rank_magnitude data are overwritten in place.

    Returns:
        None.
    """
    with torch.no_grad():
        # Copy base parameters to FP16
        v2_layer.weight.data = v1_layer.weight.data.to(torch.float16)
        if v1_layer.bias is not None and v2_layer.bias is not None:
            v2_layer.bias.data = v1_layer.bias.data.to(torch.float16)
        v2_layer.lut.data = v1_layer.lut.data.to(torch.float16)

        # Get V1 scales (handle potential padding). The factorisation is done
        # in float32 so that norms and divisions are not computed in whatever
        # lower-precision dtype the V1 model was loaded in.
        A = v1_layer.scale_A.float()  # [out, rank]
        B = v1_layer.scale_B[:, :v1_layer.in_features].float()  # [rank, in]

        # Compute norms (floored to avoid division by zero)
        A_norms = A.norm(dim=0, keepdim=True).clamp(min=1e-8)  # [1, rank]
        B_norms = B.norm(dim=1, keepdim=True).clamp(min=1e-8)  # [rank, 1]

        # V2 stores unit-norm directions + magnitude
        A_dir = A / A_norms  # [out, rank] unit-norm columns
        B_dir = B / B_norms  # [rank, in] unit-norm rows

        # Actual magnitude is product of norms. Only the keepdim axis is
        # squeezed, so a rank of 1 still yields shape [1] rather than a
        # 0-dim tensor.
        actual_magnitude = A_norms.squeeze(0) * B_norms.squeeze(1)  # [rank]

        # rank_magnitude_raw such that softplus(rank_magnitude_raw) = actual_magnitude
        # NOTE(review): the clamp truncates magnitudes above roughly 10, so the
        # conversion is lossy for such ranks — confirm this is intended.
        rank_magnitude_raw = inverse_softplus(actual_magnitude).clamp(-10, 10)

        # Store in V2 layer (FP16)
        v2_layer.scale_A.data = A_dir.to(torch.float16)
        v2_layer.scale_B.data = B_dir.to(torch.float16)
        v2_layer.rank_magnitude.data = rank_magnitude_raw.to(torch.float16)


print('Converting V1 → V2 in FP16...')
converted = 0

# Collect V1 and V2 layers, keyed by module path. Layers are selected by
# class name.
v1_layers = {name: m for name, m in v1_model.named_modules()
             if type(m).__name__ == 'AnemllQATLinear'}
v2_layers = {name: m for name, m in v2_model.named_modules()
             if type(m).__name__ == 'AnemllQATLinearV2'}

print(f'Found {len(v1_layers)} V1 layers, {len(v2_layers)} V2 layers')

# Report layers that exist on one side only; these are not converted and
# would otherwise go unnoticed.
v1_only = [name for name in v1_layers if name not in v2_layers]
v2_only = [name for name in v2_layers if name not in v1_layers]
if v1_only:
    print(f'WARNING: {len(v1_only)} V1 layers have no V2 counterpart (skipped), e.g. {v1_only[:3]}')
if v2_only:
    print(f'WARNING: {len(v2_only)} V2 layers have no V1 source (not converted), e.g. {v2_only[:3]}')

# Convert each layer
for name in v1_layers:
    if name not in v2_layers:
        continue
    convert_v1_layer_to_v2(v1_layers[name], v2_layers[name])
    converted += 1
    if converted <= 3:
        # Verify conversion on the first few layers
        v2 = v2_layers[name]
        g_raw = v2.rank_magnitude
        g_eff = F.softplus(g_raw.float())
        print(f'  {name}: rank_mag raw=[{g_raw.min():.2f}, {g_raw.max():.2f}] -> eff=[{g_eff.min():.2f}, {g_eff.max():.2f}]')

print(f'\nConverted {converted} layers to V2 FP16')

# Free V1 model memory. The v1_layers dict references every V1 quantized
# module, so it has to be dropped too or those modules stay alive after
# `del v1_model`.
del v1_layers
del v1_model
import gc
gc.collect()
torch.cuda.empty_cache()
print('V1 model freed from memory')

In [None]:
# ============================================================
# ENSURE FP16 LUT + FREEZE Q
# ============================================================

# Run the FP16 conversion helper over the whole V2 model before freezing.
print('Ensuring LUT is FP16...')
convert_model_to_fp16_v2(v2_model, verbose=True)

# Freeze Q for every V2 layer.
print('\nFreezing Q (computing indices in FP16)...')
freeze_Q_all(v2_model, verbose=False)
print('Q frozen for all V2 layers in FP16.')

# Spot-check dtypes on the first V2 layer only.
first_v2_layer = next(
    ((layer_name, layer) for layer_name, layer in v2_model.named_modules()
     if type(layer).__name__ == 'AnemllQATLinearV2'),
    None,
)
if first_v2_layer is not None:
    layer_name, layer = first_v2_layer
    print(f'\nVerified {layer_name}:')
    print(f'  weight.dtype: {layer.weight.dtype}')
    print(f'  lut.dtype: {layer.lut.dtype}')
    print(f'  _Q.dtype: {layer._Q.dtype if layer._Q is not None else None}')

In [None]:
# ============================================================
# EVALUATE CONVERTED V2 MODEL
# ============================================================

v2_model.eval()
v2_converted_loss = evaluate_kd_loss(
    v2_model,
    cache_local_path,
    DEVICE,
    num_samples=50,
    temperature=DISTILL_TEMP,
)

# Absolute gap between the V1 baseline and the freshly converted V2 model.
loss_gap = abs(v2_converted_loss - v1_loss)

print(f'\n=== Conversion Results ===')
print(f'V1 Loss: {v1_loss:.4f}')
print(f'V2 Loss (after conversion): {v2_converted_loss:.4f}')
print(f'Difference: {loss_gap:.4f}')

# A gap below 0.1 is reported as a successful conversion.
verdict = (
    'Conversion successful - losses are close!'
    if loss_gap < 0.1
    else 'Note: Some difference expected due to different forward implementations'
)
print(verdict)

# Step 3: Finetune V2 Model

In [None]:
# ============================================================
# V2 SCALE FINETUNING IN FP16 (MLP + ATTENTION)
# ============================================================

from qat_lora import train_e2e, unfreeze_model_for_training_v2

# Put the model back into a trainable state.
unfreeze_model_for_training_v2(v2_model)

# On every V2 layer: mark the three scale tensors trainable (when the layer
# has scales) and keep the base weight frozen.
for _, layer in v2_model.named_modules():
    if type(layer).__name__ != 'AnemllQATLinearV2':
        continue
    if getattr(layer, 'scale_A', None) is not None:
        layer.scale_A.requires_grad = True
        layer.scale_B.requires_grad = True
        layer.rank_magnitude.requires_grad = True
    layer.weight.requires_grad = False

trainable = sum(p.numel() for p in v2_model.parameters() if p.requires_grad)
print(f'Trainable params: {trainable:,}')

# Scale-only training run in FP16.
print('\nStarting V2 scale finetuning in FP16...')
e2e_result = train_e2e(
    model=v2_model,
    cache_dir=cache_local_path,
    device=DEVICE,
    # schedule
    max_steps=2000,
    batch_size=BATCH_SIZE,
    lr=1e-4,
    use_cosine_schedule=True,
    warmup_steps=100,
    min_lr_ratio=0.1,
    # loss terms
    temperature=DISTILL_TEMP,
    hard_top1_weight=0.0,
    hard_full_weight=0.0,
    # what gets trained
    train_weights=False,
    train_scales=True,
    train_mlp_only=False,
    use_fp16=USE_FP16,
    # reporting
    logging_steps=20,
    eval_steps=100,
    verbose=True,
)

In [None]:
# ============================================================
# SAVE FP16 CHECKPOINT (ANE-READY, NO SNAP NEEDED!)
# ============================================================

import os
import json

V2_RUN_NAME = f'anemll_v2_{QUAL}_from_v1_fp16'
V2_SAVE_DIR = f'{LOCAL_RUNS}/{V2_RUN_NAME}'
os.makedirs(V2_SAVE_DIR, exist_ok=True)

# Build the state dict once and write it to both destinations.
v2_state_dict = v2_model.state_dict()

# Save checkpoint
torch.save(v2_state_dict, f'{V2_SAVE_DIR}/model_state_dict.pt')

# Save config. The attention settings are recorded next to the MLP ones
# because the two groups are configured independently in this notebook.
config = {
    'model_id': MODEL_ID,
    'version': 'v2',
    'precision': 'fp16',  # FP16!
    'lut_bits': LUT_BITS,
    'scale_rank': SCALE_RANK,
    'attn_lut_bits': ATTN_LUT_BITS,
    'attn_scale_rank': ATTN_SCALE_RANK,
    'v1_loss': v1_loss,
    'v2_converted_loss': v2_converted_loss,
    'final_loss': e2e_result['final_loss'],
}
with open(f'{V2_SAVE_DIR}/config.json', 'w') as f:
    # default=float is only consulted for values json cannot serialise
    # natively (e.g. a loss returned as a 0-dim tensor or NumPy scalar).
    json.dump(config, f, indent=2, default=float)

# Also save to tmp for quick access
torch.save(v2_state_dict, '/tmp/v2_from_v1_fp16.pt')
del v2_state_dict

print(f'Saved FP16 V2 checkpoint to {V2_SAVE_DIR}')
print(f'Also saved to /tmp/v2_from_v1_fp16.pt')
print(f'ANE-ready - no snap needed!')

In [None]:
# ============================================================
# FINAL EVALUATION
# ============================================================

v2_model.eval()
final_loss = evaluate_kd_loss(
    v2_model,
    cache_local_path,
    DEVICE,
    num_samples=50,
    temperature=DISTILL_TEMP,
)

# Summary of the three measurements taken in this notebook; "improvement"
# is the V1 baseline minus the final loss, so positive means better.
report_lines = (
    f'\n=== FP16 Training Results ===',
    f'V1 Original:        {v1_loss:.4f}',
    f'V2 After Convert:   {v2_converted_loss:.4f}',
    f'V2 After FP16 Train:{final_loss:.4f}',
    f'Total Improvement:  {v1_loss - final_loss:.4f}',
    f'\nModel is in FP16 - ANE-ready!',
)
for report_line in report_lines:
    print(report_line)

In [None]:
# ============================================================
# UPLOAD TO GOOGLE DRIVE
# ============================================================

# Pack the run folder (resolved relative to LOCAL_RUNS) into
# <V2_RUN_NAME>.tgz in the current working directory.
!tar -czvf {V2_RUN_NAME}.tgz -C {LOCAL_RUNS} {V2_RUN_NAME}
# Copy the archive into the Drive runs folder.
!cp {V2_RUN_NAME}.tgz {GD_RUNS}/
# NOTE(review): the exit status of tar/cp is not checked, so this message
# is printed even when the upload failed.
print(f'\nUploaded to {GD_RUNS}/{V2_RUN_NAME}.tgz')

# Test Inference

In [None]:
# ============================================================
# FREEZE FOR INFERENCE
# ============================================================

from qat_lora import freeze_model_for_inference_v2

# NOTE(review): this cell comes after the save and upload cells, so
# presumably only the in-memory model is affected — confirm against
# freeze_model_for_inference_v2.
print('Freezing V2 model for inference...')
freeze_model_for_inference_v2(v2_model, verbose=False)
print('Ready for inference!')

In [None]:
# ============================================================
# TEST INFERENCE
# ============================================================

def run_inference(model, tokenizer, prompt, max_new_tokens=256):
    """Generate a greedy reply to ``prompt`` and return only the new text.

    The prompt is wrapped in a system + user chat, rendered with the
    tokenizer's chat template, and moved to DEVICE before generation.

    Args:
        model: object exposing ``generate``.
        tokenizer: tokenizer providing ``apply_chat_template`` and ``decode``.
        prompt: user message text.
        max_new_tokens: upper bound on generated tokens.

    Returns:
        The decoded continuation, with the prompt tokens and special
        tokens removed.
    """
    chat = [
        {'role': 'system', 'content': 'You are a helpful assistant.'},
        {'role': 'user', 'content': prompt}
    ]
    rendered = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
    encoded = tokenizer(rendered, return_tensors='pt').to(DEVICE)

    with torch.no_grad():
        generated = model.generate(**encoded, max_new_tokens=max_new_tokens, do_sample=False)

    # The output sequence begins with the prompt tokens; slice them off.
    prompt_len = encoded['input_ids'].shape[1]
    new_tokens = generated[0][prompt_len:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True)

# Smoke-test prompts for the finetuned model.
prompts = [
    'What is the capital of France?',
    'Explain quantum mechanics briefly.',
    'What is 2+2?',
]

v2_model.eval()
separator = '-' * 50
for prompt in prompts:
    response = run_inference(v2_model, tokenizer, prompt)
    # One line each for the prompt, the response and the separator.
    print(f'Prompt: {prompt}', f'Response: {response}', separator, sep='\n')