# Anemll V2 FP16 Training Pipeline

Converts a V1 checkpoint to the V2 format and fine-tunes it in FP16 for deployment on the Apple Neural Engine (ANE).

## Pipeline:
1. Load V1 model + checkpoint
2. Create V2 model
3. Convert V1 → V2 (extract norms into rank_magnitude)
4. Convert to FP16 (LUT, scales, weights)
5. Freeze Q (indices computed in FP16 = same as ANE)
6. Train in FP16
7. Save (no `snap_for_ane()` step needed)

## Key Benefits:
- **No precision mismatch**: Indices computed in FP16 = same as ANE
- **Proper V1→V2 conversion**: Preserves trained scales
- **ANE-ready**: Direct export without conversion
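
To illustrate the precision-mismatch point above, the following minimal sketch rounds a value to FP16 with the standard library and shows the nearest-LUT index changing as a result. The 5-entry LUT, the helper names `to_fp16` and `nearest_index`, and the tie-breaking rule are assumptions made for illustration; this is not the repository's index computation.

```python
import struct

def to_fp16(x):
    """Round a Python float to the nearest IEEE 754 half-precision value."""
    return struct.unpack('<e', struct.pack('<e', x))[0]

def nearest_index(x, lut):
    """Index of the LUT entry closest to x (ties go to the lower index)."""
    return min(range(len(lut)), key=lambda i: abs(x - lut[i]))

lut = [-1.0, -0.5, 0.0, 0.5, 1.0]  # hypothetical LUT, exactly representable in FP16
x = 0.7501                          # just above the 0.75 decision boundary

idx_high = nearest_index(x, lut)           # index chosen in higher precision
idx_fp16 = nearest_index(to_fp16(x), lut)  # index chosen after rounding to FP16
print(f'high precision -> {idx_high}, FP16 -> {idx_fp16}')
```

Values that sit close to a boundary between two LUT entries are the ones at risk, which is why the pipeline computes indices in the same precision the deployment target uses.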

In [None]:
# ============================================================
# GOOGLE DRIVE PATHS (STANDARD)
# ============================================================

# Checkpoints/runs go here
GD_RUNS = '/content/drive/MyDrive/qwen3_runs'

# KD caches go here
GD_CACHES = '/content/drive/MyDrive/qwen3_caches'

# Local directories (on Colab VM)
LOCAL_RUNS = 'runs'
LOCAL_CACHES = 'caches'

In [None]:
# Mount Google Drive
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Clone repo if needed
!git clone https://github.com/anemll/qwen3_apple_style_2bit_qat_lora.git || (cd qwen3_apple_style_2bit_qat_lora && git pull)
%cd qwen3_apple_style_2bit_qat_lora
!git fetch && git pull
!git reset --hard HEAD

import sys
# Drop cached qat_lora modules so the freshly pulled code is re-imported
for k in [k for k in sys.modules if k.startswith('qat_lora')]:
    sys.modules.pop(k)

# Import both V1 and V2 modules
from qat_lora import (
    # V1 (for loading checkpoint)
    AnemllQATLinear,
    AnemllQuantConfig,
    replace_linear_with_anemll,
    # V2 (for training)
    AnemllQATLinearV2,
    AnemllQuantConfigV2,
    replace_linear_with_anemll_v2,
    freeze_Q_all,
    freeze_model_for_inference_v2,
    unfreeze_model_for_training_v2,
    convert_model_to_fp16_v2,
    evaluate_kd_loss,
    train_e2e,
    save_checkpoint,
)

In [None]:
!pip install -q transformers accelerate safetensors

In [None]:
# ============================================================
# LOAD KD CACHE FROM GOOGLE DRIVE
# ============================================================

CACHE_NAME = 'alpaca_chat_think_both_L128_K128_R1024'
CACHE_TGZ = f'{CACHE_NAME}.tgz'

!mkdir -p {LOCAL_CACHES}

import os
cache_local_path = f'{LOCAL_CACHES}/{CACHE_NAME}'
if not os.path.exists(cache_local_path):
    print(f'Extracting {CACHE_TGZ} from Google Drive...')
    !tar -xzf {GD_CACHES}/{CACHE_TGZ} -C {LOCAL_CACHES}/
else:
    print(f'Cache already exists at {cache_local_path}')

!ls -la {cache_local_path}/ | head -10

In [None]:
# ============================================================
# CONFIGURATION - V1→V2 FP16 CONVERSION
# ============================================================

import torch
import os

# Model
MODEL_ID = 'Qwen/Qwen3-0.6B'

# V1 checkpoint archive (the .tgz file in GD_RUNS)
# ⚠️ SET YOUR V1 CHECKPOINT HERE:
V1_ARCHIVE = 'anemll_q4_a4_e2e_v2_scales_only.tgz'
V1_FOLDER = 'anemll_q4_a4_e2e_v2_scales_only'

# Extract V1 checkpoint from .tgz
os.makedirs(LOCAL_RUNS, exist_ok=True)
v1_extract_path = f'{LOCAL_RUNS}/{V1_FOLDER}'
if not os.path.exists(v1_extract_path):
    print(f'Extracting {V1_ARCHIVE} from Google Drive...')
    !tar -xzf {GD_RUNS}/{V1_ARCHIVE} -C {LOCAL_RUNS}/
else:
    print(f'V1 checkpoint already extracted at {v1_extract_path}')

V1_CHECKPOINT = f'{v1_extract_path}/model_state_dict.pt'
assert os.path.exists(V1_CHECKPOINT), f'V1 checkpoint not found: {V1_CHECKPOINT}'
print(f'V1 checkpoint: {V1_CHECKPOINT}')

# Quantization config (must match V1 checkpoint)
LUT_BITS = 4
LUT_SIZE = 2**LUT_BITS
SCALE_RANK = 4

ATTN_LUT_BITS = 4
ATTN_LUT_SIZE = 2**ATTN_LUT_BITS
ATTN_SCALE_RANK = 4

# Device - CUDA required for FP16 training
if not torch.cuda.is_available():
    raise RuntimeError("FP16 training requires CUDA!")
DEVICE = torch.device('cuda')

# Training (CUDA is guaranteed by the check above, so no CPU fallback is needed)
BATCH_SIZE = 32
DISTILL_TEMP = 2.0

QUAL = f'q{LUT_BITS}_a{ATTN_LUT_BITS}_fp16'

print(f'\n=== V1→V2 FP16 Pipeline ===')
print(f'Quality: {QUAL}')
print(f'Device: {DEVICE}')
print(f'Quant config: lut={LUT_SIZE}, rank={SCALE_RANK}')
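
As a rough guide to what these settings imply for model size, the sketch below estimates the storage of one quantized linear layer. It assumes indices are packed at `LUT_BITS` bits per weight and that the LUT, `scale_A`, `scale_B`, and `rank_magnitude` are stored in FP16 per layer. The function name `quantized_layer_bytes` and the layer shape are hypothetical, and the actual deployed layout may differ.

```python
def quantized_layer_bytes(out_features, in_features, lut_bits=4, scale_rank=4):
    """Estimate storage for one quantized linear layer under the stated assumptions."""
    index_bits = out_features * in_features * lut_bits
    fp16_values = (
        (out_features + in_features) * scale_rank  # scale_A [out, rank] + scale_B [rank, in]
        + scale_rank                               # rank_magnitude [rank]
        + 2 ** lut_bits                            # LUT entries
    )
    return index_bits // 8 + fp16_values * 2       # 2 bytes per FP16 value

out_f, in_f = 3072, 1024                  # hypothetical layer shape
total = quantized_layer_bytes(out_f, in_f)
dense = out_f * in_f * 2                  # the same layer as dense FP16
bits_per_weight = total * 8 / (out_f * in_f)
print(f'{total} bytes quantized vs {dense} bytes dense, {bits_per_weight:.3f} bits/weight')
```

With a low scale rank, the per-layer overhead of the LUT and scale factors stays small compared with the index storage.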

In [None]:
# ============================================================
# STEP 1: LOAD V1 MODEL + CHECKPOINT
# ============================================================

from transformers import AutoModelForCausalLM, AutoTokenizer

print(f'Loading {MODEL_ID} for V1...')
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)

# Load base model (BF16 for V1 - will convert to FP16 after V2 conversion)
v1_model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
)

# Create V1 configs
v1_mlp_config = AnemllQuantConfig(
    lut_size=LUT_SIZE,
    scale_rank=SCALE_RANK,
)
v1_attn_config = AnemllQuantConfig(
    lut_size=ATTN_LUT_SIZE,
    scale_rank=ATTN_SCALE_RANK,
)

# Replace with V1 layers
print('Replacing with V1 AnemllQATLinear...')
replace_linear_with_anemll(
    v1_model,
    mlp_config=v1_mlp_config,
    attn_config=v1_attn_config,
    quantize_attn=True,
    quantize_lm_head=False,
)

# Load V1 checkpoint
print(f'Loading V1 checkpoint: {V1_CHECKPOINT}')
v1_state = torch.load(V1_CHECKPOINT, map_location='cpu')
if 'model_state_dict' in v1_state:
    v1_state = v1_state['model_state_dict']
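# strict=False tolerates key mismatches, so report them instead of ignoring them silently
load_result = v1_model.load_state_dict(v1_state, strict=False)
print(f'  Missing keys: {len(load_result.missing_keys)}, unexpected keys: {len(load_result.unexpected_keys)}')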
v1_model.to(DEVICE)
v1_model.eval()

print('V1 model loaded!')

# Evaluate V1 loss for reference
v1_loss = evaluate_kd_loss(v1_model, cache_local_path, DEVICE, num_samples=50, temperature=DISTILL_TEMP)
print(f'\nV1 KD Loss: {v1_loss:.4f}')
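
For readers unfamiliar with the metric, here is a minimal sketch of a temperature-scaled knowledge-distillation loss on a single token position. It assumes the common formulation, KL(teacher || student) on softened distributions multiplied by T². Whether `evaluate_kd_loss` uses exactly this form (for example, over cached top-K logits) is not confirmed by this notebook, and the helper names are hypothetical.

```python
import math

def softmax(logits, temperature=1.0):
    """Numerically stable softmax over a list of logits."""
    scaled = [z / temperature for z in logits]
    m = max(scaled)
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def kd_loss(student_logits, teacher_logits, temperature=2.0):
    """KL(teacher || student) on temperature-softened distributions, scaled by T^2."""
    p = softmax(teacher_logits, temperature)
    q = softmax(student_logits, temperature)
    kl = sum(pi * (math.log(pi) - math.log(qi)) for pi, qi in zip(p, q))
    return kl * temperature ** 2

teacher = [2.0, 1.0, 0.1]
print(kd_loss(teacher, teacher))          # identical logits give zero loss
print(kd_loss([0.1, 1.0, 2.0], teacher))  # mismatched logits give a positive loss
```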

In [None]:
# ============================================================
# STEP 2: CREATE V2 MODEL
# ============================================================

# Load fresh base model for V2 (in FP16!)
print('Loading fresh base model for V2 in FP16...')
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.float16,  # FP16 from the start!
    trust_remote_code=True,
)

# Create V2 configs
v2_mlp_config = AnemllQuantConfigV2(
    lut_size=LUT_SIZE,
    scale_rank=SCALE_RANK,
    force_positive_scales=True,
    positive_scale_method="abs",
    magnitude_activation="softplus",
    magnitude_eps=1e-6,
)

v2_attn_config = AnemllQuantConfigV2(
    lut_size=ATTN_LUT_SIZE,
    scale_rank=ATTN_SCALE_RANK,
    force_positive_scales=True,
    positive_scale_method="abs",
    magnitude_activation="softplus",
    magnitude_eps=1e-6,
)

# Replace with V2 layers
print('Replacing with V2 AnemllQATLinearV2...')
count = replace_linear_with_anemll_v2(
    model,
    mlp_config=v2_mlp_config,
    attn_config=v2_attn_config,
    quantize_attn=True,
    quantize_lm_head=False,
)

print(f'Replaced {count} layers with V2')
print(f'Model dtype: {next(model.parameters()).dtype}')

In [None]:
# ============================================================
# STEP 3: CONVERT V1 → V2
# ============================================================

def convert_v1_layer_to_v2(v1_layer, v2_layer):
    """Convert V1 layer parameters to V2 format.
    
    V1: scales = A @ B (arbitrary magnitudes)
    V2: scales = (g * A_dir) @ B_dir (unit-norm + magnitude)
    """
    with torch.no_grad():
        # Copy base parameters (convert to FP16)
        v2_layer.weight.data = v1_layer.weight.data.to(torch.float16)
        if v1_layer.bias is not None and v2_layer.bias is not None:
            v2_layer.bias.data = v1_layer.bias.data.to(torch.float16)
        
        # Copy LUT (recomputed in FP16 by convert_model_to_fp16_v2 in Step 4)
        v2_layer.lut.data = v1_layer.lut.data.to(torch.float16)
        
        # Get V1 scales (handle potential padding)
        A = v1_layer.scale_A  # [out, rank]
        B_full = v1_layer.scale_B  # [rank, padded_in]
        B = B_full[:, :v1_layer.in_features]  # [rank, in]
        
        # Compute norms
        A_norms = A.norm(dim=0, keepdim=True).clamp(min=1e-8)  # [1, rank]
        B_norms = B.norm(dim=1, keepdim=True).clamp(min=1e-8)  # [rank, 1]
        
        # V2 stores unit-norm directions + magnitude
        A_dir = A / A_norms  # [out, rank] unit-norm columns
        B_dir = B / B_norms  # [rank, in] unit-norm rows
        
        # Magnitude is product of norms
        # Squeeze explicit dims so the result stays [rank] even when rank == 1
        rank_magnitude = A_norms.squeeze(0) * B_norms.squeeze(1)  # [rank]
        
        # Store in V2 layer (FP16)
        v2_layer.scale_A.data = A_dir.to(torch.float16)
        v2_layer.scale_B.data = B_dir.to(torch.float16)
        v2_layer.rank_magnitude.data = rank_magnitude.to(torch.float16)


print('Converting V1 → V2 in FP16...')
converted = 0

# Collect V1 and V2 layers
v1_layers = {name: m for name, m in v1_model.named_modules() 
             if type(m).__name__ == 'AnemllQATLinear'}
v2_layers = {name: m for name, m in model.named_modules() 
             if type(m).__name__ == 'AnemllQATLinearV2'}

print(f'Found {len(v1_layers)} V1 layers, {len(v2_layers)} V2 layers')

# Convert each layer
for name in v1_layers:
    if name in v2_layers:
        convert_v1_layer_to_v2(v1_layers[name], v2_layers[name])
        converted += 1
        if converted <= 3:
            print(f'  Converted: {name}')

print(f'\nConverted {converted} layers to V2 FP16 format')

# Free V1 model memory
del v1_model, v1_layers
import gc
gc.collect()
torch.cuda.empty_cache()
print('V1 model freed from memory')
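
The decomposition used in `convert_v1_layer_to_v2` can be checked on a tiny example. The sketch below uses plain Python lists (so it runs without PyTorch) to confirm that `(g * A_dir) @ B_dir` reproduces `A @ B` when `g` is the product of the column norms of `A` and the row norms of `B`. The matrices and the `matmul` helper are made up for illustration, and the check ignores any activation the V2 layer may apply to `rank_magnitude`.

```python
import math

def matmul(X, Y):
    """Multiply two matrices stored as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*Y)] for row in X]

A = [[3.0, 1.0], [4.0, 2.0]]            # [out=2, rank=2]
B = [[0.6, 0.8, 0.0], [1.0, 2.0, 2.0]]  # [rank=2, in=3]
rank = len(B)

a_norms = [math.sqrt(sum(row[k] ** 2 for row in A)) for k in range(rank)]  # column norms of A
b_norms = [math.sqrt(sum(v * v for v in B[k])) for k in range(rank)]       # row norms of B

A_dir = [[row[k] / a_norms[k] for k in range(rank)] for row in A]  # unit-norm columns
B_dir = [[v / b_norms[k] for v in B[k]] for k in range(rank)]      # unit-norm rows
g = [a_norms[k] * b_norms[k] for k in range(rank)]                 # per-rank magnitude

scaled_A = [[g[k] * row[k] for k in range(rank)] for row in A_dir]  # g * A_dir
v1_scales = matmul(A, B)
v2_scales = matmul(scaled_A, B_dir)
max_err = max(abs(p - q) for r1, r2 in zip(v1_scales, v2_scales) for p, q in zip(r1, r2))
print(f'g = {g}, max reconstruction error = {max_err:.2e}')
```

The same comparison can be run on a real layer pair before `v1_model` is deleted to confirm that the conversion preserved the trained scales.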

In [None]:
# ============================================================
# STEP 4: ENSURE ALL V2 LAYERS ARE FP16
# ============================================================
# This recomputes LUT in FP16 to ensure indices will be computed in FP16

print('Ensuring all V2 layers are FP16 (recomputing LUT)...')
convert_model_to_fp16_v2(model, verbose=True)

# Verify FP16 conversion
for name, module in model.named_modules():
    if type(module).__name__ == 'AnemllQATLinearV2':
        print(f'\nVerifying {name}:')
        print(f'  weight.dtype: {module.weight.dtype}')
        print(f'  scale_A.dtype: {module.scale_A.dtype}')
        print(f'  lut.dtype: {module.lut.dtype}')
        print(f'  lut range: [{module.lut.min():.4f}, {module.lut.max():.4f}]')
        break

In [None]:
# ============================================================
# STEP 5: FREEZE Q IN FP16
# ============================================================
# Indices computed in FP16 = same precision as ANE!

model.to(DEVICE)

print('Freezing Q (computing indices in FP16)...')
freeze_Q_all(model, verbose=False)
print('Q frozen for all layers in FP16.')

# Verify Q is frozen
for name, module in model.named_modules():
    if type(module).__name__ == 'AnemllQATLinearV2':
        print(f'\nVerifying {name}:')
        print(f'  _Q.dtype: {module._Q.dtype if module._Q is not None else "None"}')
        print(f'  _Q.shape: {module._Q.shape if module._Q is not None else "None"}')
        print(f'  _indices.dtype: {module._indices.dtype if module._indices is not None else "None"}')
        break

# Evaluate converted V2 model
model.eval()
v2_converted_loss = evaluate_kd_loss(model, cache_local_path, DEVICE, num_samples=50, temperature=DISTILL_TEMP)

print(f'\n=== Conversion Results ===')
print(f'V1 Loss: {v1_loss:.4f}')
print(f'V2 Loss (after FP16 conversion): {v2_converted_loss:.4f}')
print(f'Difference: {abs(v2_converted_loss - v1_loss):.4f}')

if abs(v2_converted_loss - v1_loss) < 0.5:
    print('Conversion successful - losses are close!')
else:
    print('WARNING: V2 loss differs from V1 by 0.5 or more. Check the conversion and quant config before training.')

# FP16 Training

Training in pure FP16 (model already in FP16):
- Uses `torch.amp.autocast` for FP16 forward pass
- No GradScaler needed (model is already FP16)
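
As background for the GradScaler point: FP16 cannot represent magnitudes below roughly 6e-8, so very small gradients round to zero, and loss scaling is the usual remedy in mixed-precision setups. The sketch below shows that rounding with the standard library `struct` module. The helper `to_fp16`, the sample gradient value, and the scale factor 1024 are illustrative assumptions; nothing here reflects what `train_e2e` does internally.

```python
import struct

def to_fp16(x):
    """Round a Python float to the nearest IEEE 754 half-precision value."""
    return struct.unpack('<e', struct.pack('<e', x))[0]

tiny_grad = 1e-8          # hypothetical gradient magnitude
scale = 1024.0            # hypothetical loss-scale factor

unscaled = to_fp16(tiny_grad)        # underflows to zero in FP16
scaled = to_fp16(tiny_grad * scale)  # survives when scaled before rounding
print(f'unscaled: {unscaled}, scaled: {scaled}')
```

If the KD loss stalls during pure FP16 training, inspecting gradient magnitudes for this kind of underflow is one reasonable first check.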

In [None]:
# ============================================================
# FP16 E2E TRAINING - MLP SCALES
# ============================================================

print('=== FP16 Training: MLP scales ===')

e2e_mlp_result = train_e2e(
    model=model,
    cache_dir=cache_local_path,
    device=DEVICE,
    max_steps=4000,
    batch_size=BATCH_SIZE,
    lr=5e-4,
    use_cosine_schedule=True,
    warmup_steps=100,
    min_lr_ratio=0.1,
    temperature=DISTILL_TEMP,
    train_weights=False,
    train_scales=True,
    hard_top1_weight=0.0,
    hard_full_weight=0.0,
    logging_steps=20,
    eval_steps=100,
    verbose=True,
    train_mlp_only=True,
    use_fp16=True,  # FP16 training path
)

print(f'\nMLP Training Result:')
print(f'  Initial: {e2e_mlp_result["initial_loss"]:.4f}')
print(f'  Final: {e2e_mlp_result["final_loss"]:.4f}')
print(f'  Best: {e2e_mlp_result["best_loss"]:.4f}')
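
The schedule arguments passed above (`use_cosine_schedule`, `warmup_steps`, `min_lr_ratio`) suggest a linear warmup followed by cosine decay to a floor. The sketch below shows one common way to compute such a schedule. The function `lr_at` is hypothetical, and `train_e2e` may interpret these arguments differently.

```python
import math

def lr_at(step, base_lr=5e-4, max_steps=4000, warmup_steps=100, min_lr_ratio=0.1):
    """Linear warmup to base_lr, then cosine decay to base_lr * min_lr_ratio."""
    if step < warmup_steps:
        return base_lr * step / warmup_steps
    min_lr = base_lr * min_lr_ratio
    progress = min(1.0, (step - warmup_steps) / max(1, max_steps - warmup_steps))
    return min_lr + (base_lr - min_lr) * 0.5 * (1.0 + math.cos(math.pi * progress))

for s in (0, 50, 100, 2050, 4000):
    print(s, f'{lr_at(s):.2e}')
```

Under this reading, the MLP stage would end at one tenth of its peak learning rate rather than decaying to zero.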

In [None]:
# ============================================================
# FP16 E2E TRAINING - ATTENTION SCALES
# ============================================================

unfreeze_model_for_training_v2(model)

# Freeze MLP scales, only train attention
for name, module in model.named_modules():
    if type(module).__name__ == 'AnemllQATLinearV2':
        is_attn = any(x in name for x in ['q_proj', 'k_proj', 'v_proj', 'o_proj'])
        if hasattr(module, 'scale_A') and module.scale_A is not None:
            module.scale_A.requires_grad = is_attn
            module.scale_B.requires_grad = is_attn
            module.rank_magnitude.requires_grad = is_attn
        module.weight.requires_grad = False

e2e_attn_result = train_e2e(
    model=model,
    cache_dir=cache_local_path,
    device=DEVICE,
    max_steps=2000,
    batch_size=BATCH_SIZE,
    lr=1e-4,
    use_cosine_schedule=True,
    warmup_steps=100,
    min_lr_ratio=0.1,
    temperature=DISTILL_TEMP,
    train_weights=False,
    train_scales=True,
    hard_top1_weight=0.0,
    hard_full_weight=0.0,
    logging_steps=20,
    eval_steps=100,
    verbose=True,
    train_mlp_only=False,
    use_fp16=True,
)

print(f'\nAttention Training Result:')
print(f'  Initial: {e2e_attn_result["initial_loss"]:.4f}')
print(f'  Final: {e2e_attn_result["final_loss"]:.4f}')
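
The attention/MLP split above relies on substring matching against module names. The sketch below applies the same test to a few example names to show how layers would be grouped. The names follow the usual Hugging Face layout for this model family but are listed here as assumptions, and `is_attention_proj` is a hypothetical helper.

```python
ATTN_PROJ_KEYS = ('q_proj', 'k_proj', 'v_proj', 'o_proj')

def is_attention_proj(module_name):
    """True if the module name contains one of the attention projection keys."""
    return any(key in module_name for key in ATTN_PROJ_KEYS)

names = [
    'model.layers.0.self_attn.q_proj',
    'model.layers.0.self_attn.o_proj',
    'model.layers.0.mlp.gate_proj',
    'model.layers.0.mlp.down_proj',
]

groups = {'attn': [], 'mlp': []}
for name in names:
    groups['attn' if is_attention_proj(name) else 'mlp'].append(name)

for group, members in groups.items():
    print(group, members)
```

Because this is a substring test, any module whose name happens to contain one of the keys would also be treated as attention, so printing the grouped names once before training is a cheap sanity check.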

In [None]:
# ============================================================
# FP16 E2E TRAINING - JOINT MLP + ATTENTION
# ============================================================

unfreeze_model_for_training_v2(model)

# Enable ALL scales
for name, module in model.named_modules():
    if type(module).__name__ == 'AnemllQATLinearV2':
        if hasattr(module, 'scale_A') and module.scale_A is not None:
            module.scale_A.requires_grad = True
            module.scale_B.requires_grad = True
            module.rank_magnitude.requires_grad = True
        module.weight.requires_grad = False

e2e_joint_result = train_e2e(
    model=model,
    cache_dir=cache_local_path,
    device=DEVICE,
    max_steps=1000,
    batch_size=BATCH_SIZE,
    lr=5e-5,
    use_cosine_schedule=True,
    warmup_steps=50,
    min_lr_ratio=0.1,
    temperature=DISTILL_TEMP,
    train_weights=False,
    train_scales=True,
    hard_top1_weight=0.0,
    hard_full_weight=0.0,
    logging_steps=20,
    eval_steps=100,
    verbose=True,
    train_mlp_only=False,
    use_fp16=True,
)

print(f'\nJoint Training Result:')
print(f'  Initial: {e2e_joint_result["initial_loss"]:.4f}')
print(f'  Final: {e2e_joint_result["final_loss"]:.4f}')

In [None]:
# ============================================================
# SAVE FP16 CHECKPOINT (NO SNAP NEEDED!)
# ============================================================
# Model is already in FP16, indices computed in FP16
# No snap_for_ane() needed - model is ANE-ready!

import json
import os

RUN_NAME = f'anemll_v2_{QUAL}_trained'  # QUAL already ends with '_fp16'
SAVE_DIR = f'{LOCAL_RUNS}/{RUN_NAME}'
os.makedirs(SAVE_DIR, exist_ok=True)

# Save state dict
torch.save(model.state_dict(), f'{SAVE_DIR}/model_state_dict.pt')

# Save config
config = {
    'model_id': MODEL_ID,
    'version': 'v2',
    'precision': 'fp16',  # FP16 trained!
    'lut_bits': LUT_BITS,
    'attn_lut_bits': ATTN_LUT_BITS,
    'scale_rank': SCALE_RANK,
    'attn_scale_rank': ATTN_SCALE_RANK,
    'v1_loss': v1_loss,
    'v2_converted_loss': v2_converted_loss,
    'final_loss': e2e_joint_result['final_loss'],
    'training_mode': 'fp16_pure',
}
with open(f'{SAVE_DIR}/config.json', 'w') as f:
    json.dump(config, f, indent=2)

print(f'\n=== FP16 Checkpoint Saved ===')
print(f'Path: {SAVE_DIR}')
print(f'Precision: FP16 (ANE-ready, no snap needed)')
print(f'V1 Loss: {v1_loss:.4f}')
print(f'V2 Converted Loss: {v2_converted_loss:.4f}')
print(f'Final Loss: {e2e_joint_result["final_loss"]:.4f}')

In [None]:
# Upload to Google Drive
!tar -czvf {RUN_NAME}.tgz -C {LOCAL_RUNS} {RUN_NAME}
!cp {RUN_NAME}.tgz {GD_RUNS}/
print(f'Uploaded to {GD_RUNS}/{RUN_NAME}.tgz')

In [None]:
# ============================================================
# TEST FP16 INFERENCE
# ============================================================

# Freeze for fast inference
freeze_model_for_inference_v2(model, verbose=False)

def run_inference(model, tokenizer, prompt, max_new_tokens=256):
    messages = [{'role': 'user', 'content': prompt}]
    text = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True, enable_thinking=True
    )
    inputs = tokenizer(text, return_tensors='pt').to(DEVICE)
    
    with torch.no_grad():
        output = model.generate(
            **inputs, max_new_tokens=max_new_tokens,
            do_sample=True, temperature=0.6, top_p=0.9,
        )
    return tokenizer.decode(output[0][inputs['input_ids'].shape[1]:], skip_special_tokens=False)

model.eval()

prompts = [
    'What is the capital of France?',
    'What is 2+2?',
    'Explain quantum mechanics briefly.',
]

print('=== FP16 Inference Test ===')
for prompt in prompts:
    response = run_inference(model, tokenizer, prompt)
    print(f'\nPrompt: {prompt}')
    print(f'Response: {response}')
    print('-' * 50)

# Summary

## V1→V2 FP16 Conversion Complete!

Pipeline executed:
1. ✅ Loaded V1 checkpoint
2. ✅ Converted V1 → V2 (extracted norms into rank_magnitude)
3. ✅ All tensors in FP16 (LUT, scales, weights)
4. ✅ Indices computed in FP16 (same as ANE)
5. ✅ Fine-tuned in FP16
6. ✅ Saved - no snap needed!

## Key Benefits:
- **Preserves V1 training**: Converts trained scales to V2 format
- **No precision mismatch**: Indices computed in FP16 = same as ANE
- **ANE-ready**: Direct export without conversion

## Next Steps:
1. Convert to CoreML with ANEMLL converter
2. Deploy to ANE
3. No additional snapping needed!