# Quantization: State-of-the-Art Methods

**Objective:** Compare simple quantization methods against production-grade GPTQ, AWQ, and bitsandbytes to understand the gap.

**Key Questions:**
1. How much better are GPTQ (Hessian-based error compensation) and AWQ (activation-aware channel scaling) than simple uniform quantization?
2. Does NF4 (levels at Gaussian quantiles) beat FP4 (4-bit floating point with roughly log-spaced levels)? If so, that supports our Gaussian assumption.
3. Is weight MSE a good proxy for functional distortion (output MSE)?

**Expected findings:**
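- GPTQ/AWQ should achieve ~0.5-1.0 bits better than simple methods due to error compensation. Because both minimize layer-output error on calibration data rather than weight error, the advantage may show up more clearly in functional MSE than in weight MSE
- NF4 should beat FP4, consistent with near-Gaussian weight distributions
- Weight importance should be non-uniform, so weight MSE should not fully predict functional (output) MSE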
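The "Shannon bound" used throughout is the rate-distortion function of a memoryless Gaussian source with variance σ², and the gap is the excess rate a method spends relative to it:

$$
D(R) = \sigma^2\, 2^{-2R} = \frac{\sigma^2}{4^{R}}, \qquad \text{gap}(R) = \frac{1}{2}\log_2 \frac{\mathrm{MSE}}{D(R)}.
$$

At R = 4 bits, D = σ²/256, so a quantizer whose MSE is 4× that value sits exactly 1 bit above the bound. Since the Gaussian has the largest rate-distortion function of any source with a given variance, for non-Gaussian weights this is a conservative reference rather than a hard floor.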

In [None]:
# === AUTHENTICATION (required for the gated meta-llama checkpoints) ===
import os
from huggingface_hub import login

# Read the token from the HF_TOKEN environment variable instead of hard-coding it;
# if the variable is unset, login() prompts for the token interactively
login(token=os.environ.get("HF_TOKEN"))

In [None]:
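# === Install dependencies ===
# Run once per environment (comment these lines out afterwards)

%pip install -q torch torchvision
%pip install -q numpy scipy scikit-learn matplotlib pandas transformers accelerate
%pip install -q bitsandbytes  # For NF4/FP4 quantization
%pip install -q optimum autoawq gptqmodel  # Backends for the pre-quantized AWQ/GPTQ checkpoints (package choice may depend on your transformers version)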


In [None]:
# Core imports
import math
import numpy as np
import torch
import matplotlib.pyplot as plt
import pandas as pd
from transformers import AutoModelForCausalLM, AutoTokenizer

In [None]:
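# Load the full-precision reference model (Llama 3.2 1B Instruct, FP32).
# The pre-quantized AWQ/GPTQ checkpoints used below were built from the Instruct
# variant, so the reference must be the same model for weight comparisons to be meaningful.
model_name = "meta-llama/Llama-3.2-1B-Instruct"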
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float32,
    device_map="auto",
)
model.eval()

print(f"Loaded: {model_name}")
print(f"Parameters: {sum(p.numel() for p in model.parameters()):,}")

In [None]:
# Utility functions needed for baselines

def symmetric_quantize(x, bits, per_channel=False, axis=1):
    """
    Symmetric uniform quantization to signed integers.
    per_channel=True uses per-output-channel scaling (axis=1 for [out, in]).
    Returns dequantized array and scale(s).
    """
    x = x.astype(np.float32)
    qmax = (2 ** (bits - 1)) - 1

    if per_channel:
        max_abs = np.max(np.abs(x), axis=axis, keepdims=True) + 1e-12
        scale = max_abs / qmax
        q = np.clip(np.round(x / scale), -qmax, qmax)
        dq = q * scale
    else:
        max_abs = float(np.max(np.abs(x)))
        scale = max_abs / qmax if max_abs > 0 else 1.0
        q = np.clip(np.round(x / scale), -qmax, qmax)
        dq = q * scale

    return dq.astype(np.float32), scale


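def shannon_distortion(sigma_sq, rate_bits):
    """Gaussian rate-distortion function D(R) = sigma^2 * 2^(-2R)."""
    return sigma_sq / (4 ** rate_bits)


def gap_bits(mse, d_shannon):
    """Excess rate in bits over the bound: 0.5 * log2(MSE / D). Returns 0 for degenerate inputs."""
    if mse <= 0 or d_shannon <= 0:
        return 0.0
    return 0.5 * math.log2(mse / d_shannon)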


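def quantize_group(weights, bits, group_size=128):
    """
    Group-wise symmetric quantization of the flattened tensor: each run of
    `group_size` consecutive weights gets its own FP16 scale.
    Returns (dequantized flat array, MSE, effective bits incl. scale overhead).
    """
    flat = weights.flatten().astype(np.float32)
    n = len(flat)

    pad_size = (group_size - n % group_size) % group_size
    if pad_size > 0:
        flat_padded = np.concatenate([flat, np.zeros(pad_size, dtype=np.float32)])
    else:
        flat_padded = flat

    groups = flat_padded.reshape(-1, group_size)
    qmax = 2 ** (bits - 1) - 1  # signed range -qmax..qmax, as in symmetric_quantize
    quantized_groups = np.zeros_like(groups)

    for i in range(groups.shape[0]):
        group = groups[i]
        max_abs = float(np.max(np.abs(group)))
        if max_abs == 0:
            quantized_groups[i] = group
            continue
        # Codes stay within -qmax..qmax, so they fit in `bits` bits
        scale = max_abs / qmax
        quantized_groups[i] = np.clip(np.round(group / scale), -qmax, qmax) * scale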

    quantized = quantized_groups.flatten()[:n]
    mse = float(np.mean((flat - quantized) ** 2))
    effective_bits = bits + 16 / group_size
    return quantized, mse, effective_bits
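
Before touching real weights, the math behind these utilities can be sanity-checked on synthetic data. The sketch below (with a hypothetical helper `gaussian_gap`) applies per-tensor symmetric quantization, the same rule as `symmetric_quantize`, to one million N(0, 1) samples and scores it against the Gaussian bound. Because the max-abs scale is set by the single most extreme sample, roughly 5σ out, the grid is coarse where most of the mass sits, so even an ideal Gaussian source should land well above the bound.

```python
import math
import numpy as np


def gaussian_gap(bits=4, n=1_000_000, seed=0):
    """Per-tensor symmetric quantization of synthetic N(0, 1) data, scored against
    the Gaussian rate-distortion bound D(R) = sigma^2 / 4**R."""
    rng = np.random.default_rng(seed)
    x = rng.standard_normal(n).astype(np.float32)

    qmax = 2 ** (bits - 1) - 1                   # 7 for 4-bit: codes -7..7
    scale = np.max(np.abs(x)) / qmax             # max-abs scaling, set by the outermost sample
    dq = np.clip(np.round(x / scale), -qmax, qmax) * scale

    mse = float(np.mean((x - dq) ** 2))
    d_shannon = float(np.var(x)) / 4 ** bits
    gap = 0.5 * math.log2(mse / d_shannon)       # excess bits over the bound
    return mse, d_shannon, gap


toy_mse, toy_d, toy_gap = gaussian_gap()
print(f"MSE={toy_mse:.4f}  D(4)={toy_d:.5f}  gap={toy_gap:.2f} bits")
```

Per-channel and group scaling shrink this gap by letting each scale adapt to a smaller, less extreme set of weights.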

## Load Pre-Quantized GPTQ/AWQ Models

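Load community 4-bit AWQ and GPTQ checkpoints of Llama 3.2 1B Instruct through `transformers`. Loading needs the matching AWQ/GPTQ backend packages; if a checkpoint fails to load, that method is skipped in the comparisons below.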
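
Packed AWQ/GPTQ layers typically store integer codes together with scales and zero points (for example `qweight`, `scales`, `qzeros`) rather than a dense `.weight`. One backend-agnostic way to read the effective weight is to feed the layer an identity matrix: a linear layer computes y = x·Wᵀ + b, so the identity input returns Wᵀ + b. The sketch below, with a hypothetical helper `recover_linear_weight`, checks the idea on a plain `torch.nn.Linear`, where the true weight is known. It assumes the module's forward accepts a `[1, in, in]` input, and its memory cost grows with `in_features²`.

```python
import torch


def recover_linear_weight(linear, in_features, dtype=torch.float32):
    """Dense [out, in] weight of a linear-like module, read off its forward pass."""
    # Packed kernels may keep their codes in buffers rather than parameters
    tensors = list(linear.parameters()) + list(linear.buffers())
    device = tensors[0].device
    eye = torch.eye(in_features, dtype=dtype, device=device).unsqueeze(0)  # [1, in, in]
    with torch.no_grad():
        out = linear(eye)[0]  # [in, out] = W.T + b
        bias = getattr(linear, "bias", None)
        if bias is not None:
            out = out - bias
    return out.T.contiguous()


torch.manual_seed(0)
toy_layer = torch.nn.Linear(16, 8)
w_rec = recover_linear_weight(toy_layer, 16)
print(torch.allclose(w_rec, toy_layer.weight, atol=1e-5))
```

For a quantized kernel the same call would typically use `dtype=torch.float16`, matching the dtype the kernel computes in.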

In [None]:
# ============================================
# Reference model (already loaded above)
# ============================================
# We reuse the full-precision 'model' from earlier as the baseline (loaded in FP32; labeled "FP16 baseline" below)

# ============================================
# Load AWQ 4-bit Quantized Model
# ============================================
print("Loading AWQ 4-bit model...")
try:
    model_awq = AutoModelForCausalLM.from_pretrained(
        "AMead10/Llama-3.2-1B-Instruct-AWQ",
        torch_dtype=torch.float16,
        device_map="auto",
    )

    print("✓ AWQ model loaded successfully")
    AWQ_AVAILABLE = True

except Exception as e:
    print(f"✗ AWQ model loading failed: {e}")
    print("  Try upgrading transformers or verifying the checkpoint supports HF quantized weights")
    model_awq = None
    AWQ_AVAILABLE = False

# ============================================
# Load GPTQ 4-bit Quantized Model
# ============================================
print("Loading GPTQ 4-bit model...")
try:
    model_gptq = AutoModelForCausalLM.from_pretrained(
        "clowman/Llama-3.2-1B-Instruct-GPTQ-Int4",
        torch_dtype=torch.float16,
        device_map="auto",
    )

    print("✓ GPTQ model loaded successfully")
    GPTQ_AVAILABLE = True

except Exception as e:
    print(f"✗ GPTQ model loading failed: {e}")
    print("  Try upgrading transformers or verifying the checkpoint supports HF quantized weights")
    model_gptq = None
    GPTQ_AVAILABLE = False

print(f"\nModel availability summary:")
print(f"  FP16 baseline: ✓ (already loaded)")
print(f"  AWQ 4-bit: {'✓' if AWQ_AVAILABLE else '✗'}")
print(f"  GPTQ 4-bit: {'✓' if GPTQ_AVAILABLE else '✗'}")

In [None]:
# ============================================
# Extract and Compare Weights Across Models
# ============================================

# Target layer for comparison
target_layer = 8
target_proj = "mlp.down_proj"

# Reference weights from the full-precision model, as an FP32 NumPy array [out_features, in_features]
w_fp16 = model.model.layers[target_layer].mlp.down_proj.weight.detach().cpu().float().numpy()

print(f"Analyzing layer {target_layer}, projection: {target_proj}")
print(f"Weight shape: {w_fp16.shape}")
print(f"FP16 variance: {np.var(w_fp16):.6e}")
print(f"FP16 kurtosis: {np.mean(((w_fp16 - w_fp16.mean()) / w_fp16.std())**4):.2f}")
print()

# Container for results
sota_results = {
    "method": ["FP16 (baseline)"],
    "mse": [0.0],
    "bits_per_weight": [16.0],
    "sqnr_db": [float('inf')],
}

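def effective_weight(linear, in_features, dtype=torch.float16):
    """
    Dense [out, in] weight of a (possibly packed) linear layer, as an FP32 NumPy array.
    Packed AWQ/GPTQ layers store integer codes plus scales/zero points (e.g. qweight,
    scales, qzeros) rather than a dense .weight, so the weight is read off the forward
    pass instead: y = x @ W.T + b, hence an identity input returns W.T + b.
    Assumes the kernel accepts a [1, in, in] input shaped like [batch, seq, hidden].
    """
    device = next(iter(list(linear.parameters()) + list(linear.buffers()))).device
    eye = torch.eye(in_features, dtype=dtype, device=device).unsqueeze(0)
    with torch.no_grad():
        out = linear(eye)[0]  # [in, out] = W.T (+ bias)
        if getattr(linear, "bias", None) is not None:
            out = out - linear.bias
    return out.T.float().cpu().numpy()


# Extract AWQ weights if available
if AWQ_AVAILABLE and model_awq is not None:
    try:
        # Caveat: AWQ folds per-input-channel scales into the preceding up_proj, so these
        # down_proj weights are not an element-wise match for the originals; read this MSE with care
        w_awq = effective_weight(model_awq.model.layers[target_layer].mlp.down_proj, w_fp16.shape[1])
        mse_awq = np.mean((w_fp16 - w_awq) ** 2)
        sqnr_awq = 10 * np.log10(np.mean(w_fp16 ** 2) / mse_awq) if mse_awq > 0 else float('inf')
        
        sota_results["method"].append("AWQ 4-bit")
        sota_results["mse"].append(mse_awq)
        sota_results["bits_per_weight"].append(4.0)  # nominal; group scales/zero points add a fraction of a bit
        sota_results["sqnr_db"].append(sqnr_awq)
        
        print(f"AWQ 4-bit:")
        print(f"  MSE:     {mse_awq:.2e}")
        print(f"  SQNR:    {sqnr_awq:.2f} dB")
        print(f"  Max err: {np.max(np.abs(w_fp16 - w_awq)):.2e}")
        print()
    except Exception as e:
        print(f"Failed to extract AWQ weights: {e}\n")

# Extract GPTQ weights if available
if GPTQ_AVAILABLE and model_gptq is not None:
    try:
        w_gptq = effective_weight(model_gptq.model.layers[target_layer].mlp.down_proj, w_fp16.shape[1])
        mse_gptq = np.mean((w_fp16 - w_gptq) ** 2)
        sqnr_gptq = 10 * np.log10(np.mean(w_fp16 ** 2) / mse_gptq) if mse_gptq > 0 else float('inf')
        
        sota_results["method"].append("GPTQ 4-bit")
        sota_results["mse"].append(mse_gptq)
        sota_results["bits_per_weight"].append(4.0)  # nominal; group scales/zero points add a fraction of a bit
        sota_results["sqnr_db"].append(sqnr_gptq)
        
        print(f"GPTQ 4-bit:")
        print(f"  MSE:     {mse_gptq:.2e}")
        print(f"  SQNR:    {sqnr_gptq:.2f} dB")
        print(f"  Max err: {np.max(np.abs(w_fp16 - w_gptq)):.2e}")
        print()
    except Exception as e:
        print(f"Failed to extract GPTQ weights: {e}\n")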

# Add simple methods for comparison (4-bit)
print("Comparing to simple quantization methods:")

# Per-tensor symmetric 4-bit
dq_per_tensor, _ = symmetric_quantize(w_fp16, 4, per_channel=False)
mse_per_tensor = np.mean((w_fp16 - dq_per_tensor) ** 2)
sqnr_per_tensor = 10 * np.log10(np.mean(w_fp16 ** 2) / mse_per_tensor)
sota_results["method"].append("Simple per-tensor")
sota_results["mse"].append(mse_per_tensor)
sota_results["bits_per_weight"].append(4.0)
sota_results["sqnr_db"].append(sqnr_per_tensor)
print(f"  Simple per-tensor 4-bit: MSE={mse_per_tensor:.2e}, SQNR={sqnr_per_tensor:.2f} dB")

# Per-channel symmetric 4-bit
dq_per_channel, _ = symmetric_quantize(w_fp16, 4, per_channel=True, axis=1)
mse_per_channel = np.mean((w_fp16 - dq_per_channel) ** 2)
sqnr_per_channel = 10 * np.log10(np.mean(w_fp16 ** 2) / mse_per_channel)
sota_results["method"].append("Simple per-channel")
sota_results["mse"].append(mse_per_channel)
sota_results["bits_per_weight"].append(4.0)
sota_results["sqnr_db"].append(sqnr_per_channel)
print(f"  Simple per-channel 4-bit: MSE={mse_per_channel:.2e}, SQNR={sqnr_per_channel:.2f} dB")

# Group quantization (g=128)
_, mse_group128, eff_bits_g128 = quantize_group(w_fp16, 4, group_size=128)
sqnr_group128 = 10 * np.log10(np.mean(w_fp16 ** 2) / mse_group128)
sota_results["method"].append("Group g=128")
sota_results["mse"].append(mse_group128)
sota_results["bits_per_weight"].append(eff_bits_g128)
sota_results["sqnr_db"].append(sqnr_group128)
print(f"  Group g=128 (4-bit): MSE={mse_group128:.2e}, SQNR={sqnr_group128:.2f} dB, eff_bits={eff_bits_g128:.3f}")

# Shannon bound (Gaussian assumption)
sigma_sq = np.var(w_fp16)
d_shannon_4bit = shannon_distortion(sigma_sq, 4.0)
sqnr_shannon = 10 * np.log10(np.mean(w_fp16 ** 2) / d_shannon_4bit)
sota_results["method"].append("Shannon bound (4-bit)")
sota_results["mse"].append(d_shannon_4bit)
sota_results["bits_per_weight"].append(4.0)
sota_results["sqnr_db"].append(sqnr_shannon)
print(f"  Shannon bound (4-bit): MSE={d_shannon_4bit:.2e}, SQNR={sqnr_shannon:.2f} dB")
print()

# Create summary DataFrame
df_sota = pd.DataFrame(sota_results)
df_sota = df_sota.sort_values("mse", ascending=True)
print("=" * 70)
print("SUMMARY TABLE (sorted by MSE, lower is better)")
print("=" * 70)
print(df_sota.to_string(index=False))
print("=" * 70)

In [None]:
# ============================================
# Visualize SOTA Comparison
# ============================================

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

# Plot 1: MSE comparison (bar chart)
methods_plot = [m for m in df_sota["method"] if m != "FP16 (baseline)"]
mses_plot = [df_sota[df_sota["method"] == m]["mse"].values[0] for m in methods_plot]

colors = []
for m in methods_plot:
    if "Shannon" in m:
        colors.append('black')
    elif "AWQ" in m or "GPTQ" in m:
        colors.append('green')
    elif "Group" in m:
        colors.append('blue')
    else:
        colors.append('orange')

ax1.bar(range(len(methods_plot)), mses_plot, color=colors, alpha=0.7)
ax1.set_xticks(range(len(methods_plot)))
ax1.set_xticklabels(methods_plot, rotation=45, ha='right')
ax1.set_yscale('log')
ax1.set_ylabel('MSE (log scale)')
ax1.set_title(f'Weight MSE Comparison (Layer {target_layer})')
ax1.grid(True, alpha=0.3, axis='y')

# Plot 2: SQNR comparison (higher is better)
sqnrs_plot = [df_sota[df_sota["method"] == m]["sqnr_db"].values[0] for m in methods_plot]
sqnrs_plot = [s if s != float('inf') else 100 for s in sqnrs_plot]

ax2.bar(range(len(methods_plot)), sqnrs_plot, color=colors, alpha=0.7)
ax2.set_xticks(range(len(methods_plot)))
ax2.set_xticklabels(methods_plot, rotation=45, ha='right')
ax2.set_ylabel('SQNR (dB)')
ax2.set_title(f'Signal-to-Quantization-Noise Ratio (Layer {target_layer})')
ax2.grid(True, alpha=0.3, axis='y')
ax2.axhline(y=sqnr_shannon, color='black', linestyle='--', linewidth=1, alpha=0.5, label='Shannon bound')
ax2.legend()

plt.tight_layout()
plt.show()

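# Gap = excess rate over the Gaussian Shannon bound, evaluated at each method's
# effective bits per weight (so group quantization is charged for its scales)
print("\nGap analysis (bits above the Gaussian Shannon bound at each method's effective rate):")
print("=" * 60)
for method in methods_plot:
    if "Shannon" in method:
        continue
    row = df_sota[df_sota["method"] == method].iloc[0]
    d_ref = shannon_distortion(sigma_sq, row["bits_per_weight"])
    gap = gap_bits(row["mse"], d_ref)
    print(f"{method:25s}: {gap:+.3f} bits gap (at {row['bits_per_weight']:.3f} bits/weight)")
print("=" * 60)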
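
The SQNR panel and the gap numbers carry the same information when measured against the same reference distortion D. With SQNR = 10·log10(P/MSE), the distance between a method's bar and the bound is 10·log10(MSE/D) dB, and since gap = ½·log2(MSE/D), each bit of gap corresponds to 20·log10 2 ≈ 6.02 dB. A small sketch on toy numbers (the helper names are hypothetical; zero-mean data is assumed so signal power equals variance):

```python
import math


def gap_bits_from_mse(mse, d_ref):
    """Excess rate over the reference distortion: 0.5 * log2(MSE / D)."""
    return 0.5 * math.log2(mse / d_ref)


def gap_bits_from_sqnr(sqnr_db, sqnr_ref_db):
    """Same gap from SQNRs: their dB difference is 10*log10(MSE / D), and 1 bit = 20*log10(2) dB."""
    return (sqnr_ref_db - sqnr_db) / (20 * math.log10(2))


# Toy numbers, not measurements: unit signal power, 4-bit Gaussian bound,
# and a hypothetical quantizer whose MSE is 4x the bound
toy_power = 1.0
toy_d = toy_power / 4 ** 4
toy_mse = 4 * toy_d
gap_from_mse = gap_bits_from_mse(toy_mse, toy_d)
gap_from_sqnr = gap_bits_from_sqnr(10 * math.log10(toy_power / toy_mse),
                                   10 * math.log10(toy_power / toy_d))
print(gap_from_mse, gap_from_sqnr)
```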

## Compare bitsandbytes NF4 vs FP4

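**Hypothesis:** NF4 places its 16 levels at quantiles of a standard normal distribution, while FP4 is a 4-bit floating-point format whose levels are roughly logarithmically spaced. If the weights are near-Gaussian, NF4 should give lower weight MSE than FP4.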
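
The hypothesis can be explored offline before loading any model. The sketch below builds an NF4-like codebook from standard-normal quantiles and, as a stand-in for FP4, the E2M1 grid from the OCP microscaling format (bitsandbytes' own FP4 table differs), then quantizes synthetic N(0, 1) data with per-block absmax scaling. The quantile offset (0.9677) and block size (64) are assumptions modeled on bitsandbytes' defaults, and all helper names are hypothetical.

```python
import numpy as np
from scipy.stats import norm


def nf4_like_levels(offset=0.9677):
    """NF4-style codebook: 8 positive and 7 negative normal quantiles plus an exact 0, scaled to [-1, 1]."""
    pos = norm.ppf(np.linspace(offset, 0.5, 9)[:-1])    # 8 positive values
    neg = -norm.ppf(np.linspace(offset, 0.5, 8)[:-1])   # 7 negative values
    levels = np.sort(np.concatenate([neg, [0.0], pos]))
    return levels / np.abs(levels).max()


def e2m1_levels():
    """OCP-style FP4 (E2M1) magnitudes with sign, scaled to [-1, 1]; +0 and -0 collapse to 15 levels."""
    mags = np.array([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])
    return np.unique(np.concatenate([-mags, mags])) / 6.0


def blockwise_mse(x, levels, block=64):
    """Absmax-scale each block to [-1, 1], round to the nearest codebook level, return the MSE."""
    blocks = x.reshape(-1, block)
    absmax = np.abs(blocks).max(axis=1, keepdims=True)
    normed = blocks / absmax
    idx = np.abs(normed[..., None] - levels).argmin(axis=-1)
    return float(np.mean((blocks - levels[idx] * absmax) ** 2))


rng = np.random.default_rng(0)
x_gauss = rng.standard_normal(64 * 4096)
for name, levels in [("NF4-like", nf4_like_levels()), ("E2M1", e2m1_levels())]:
    print(f"{name:9s} levels={len(levels):2d}  MSE={blockwise_mse(x_gauss, levels):.5f}")
```

Comparing the two printed MSEs on Gaussian data gives a baseline for how large the NF4 vs FP4 difference on real weights would be if the weights were exactly Gaussian.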

In [None]:
# ============================================
# Load with NF4 (NormalFloat4 - optimized for Gaussian)
# ============================================
print("Loading NF4 (Gaussian-optimized) quantized model...")
try:
    from transformers import BitsAndBytesConfig
    
    config_nf4 = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_use_double_quant=False
    )
    
    model_nf4 = AutoModelForCausalLM.from_pretrained(
        model_name,  # same checkpoint as the full-precision reference
        quantization_config=config_nf4,
        device_map="auto",
        torch_dtype=torch.float16
    )
    
    print("✓ NF4 model loaded successfully")
    NF4_AVAILABLE = True
    
except Exception as e:
    print(f"✗ NF4 loading failed: {e}")
    print("  Install with: pip install bitsandbytes (4-bit loading typically also needs a CUDA GPU)")
    model_nf4 = None
    NF4_AVAILABLE = False

# ============================================
# Load with FP4 (4-bit floating point, roughly log-spaced levels)
# ============================================
print("\nLoading FP4 (4-bit floating point) quantized model...")
try:
    config_fp4 = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="fp4",
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_use_double_quant=False
    )
    
    model_fp4 = AutoModelForCausalLM.from_pretrained(
        model_name,  # same checkpoint as the full-precision reference
        quantization_config=config_fp4,
        device_map="auto",
        torch_dtype=torch.float16
    )
    
    print("✓ FP4 model loaded successfully")
    FP4_AVAILABLE = True
    
except Exception as e:
    print(f"✗ FP4 loading failed: {e}")
    model_fp4 = None
    FP4_AVAILABLE = False

print(f"\nbitsandbytes availability:")
print(f"  NF4 (Gaussian-optimized): {'✓' if NF4_AVAILABLE else '✗'}")
print(f"  FP4 (4-bit float):        {'✓' if FP4_AVAILABLE else '✗'}")

In [None]:
# ============================================
# Extract and Compare NF4 vs FP4 Weights
# ============================================

bnb_results = {
    "method": [],
    "mse": [],
    "sqnr_db": [],
}

# We already have FP16 baseline from earlier
w_fp16_ref = model.model.layers[target_layer].mlp.down_proj.weight.detach().cpu().float().numpy()

print(f"Comparing NF4 vs FP4 on layer {target_layer}, {target_proj}")
print(f"FP16 reference shape: {w_fp16_ref.shape}")
print()

def dequantize_bnb_weight(linear4bit):
    """
    Dense [out, in] weight of a bitsandbytes Linear4bit layer, as an FP32 NumPy array.
    Its .weight stores packed 4-bit codes (uint8) plus a quant_state holding the
    per-block scales, so calling .float() on it returns the codes, not the weights.
    """
    import bitsandbytes.functional as bnb_F
    w = linear4bit.weight
    dense = bnb_F.dequantize_4bit(w.data, quant_state=w.quant_state)
    return dense.float().cpu().numpy()


# Extract NF4 weights
if NF4_AVAILABLE and model_nf4 is not None:
    try:
        w_nf4 = dequantize_bnb_weight(model_nf4.model.layers[target_layer].mlp.down_proj)
        
        mse_nf4 = np.mean((w_fp16_ref - w_nf4) ** 2)
        sqnr_nf4 = 10 * np.log10(np.mean(w_fp16_ref ** 2) / mse_nf4) if mse_nf4 > 0 else float('inf')
        
        bnb_results["method"].append("NF4 (Gaussian-optimized)")
        bnb_results["mse"].append(mse_nf4)
        bnb_results["sqnr_db"].append(sqnr_nf4)
        
        print(f"NF4 (Gaussian-optimized):")
        print(f"  MSE:     {mse_nf4:.2e}")
        print(f"  SQNR:    {sqnr_nf4:.2f} dB")
        print(f"  Max err: {np.max(np.abs(w_fp16_ref - w_nf4)):.2e}")
        print()
        
    except Exception as e:
        print(f"Failed to extract NF4 weights: {e}\n")

# Extract FP4 weights
if FP4_AVAILABLE and model_fp4 is not None:
    try:
        w_fp4 = dequantize_bnb_weight(model_fp4.model.layers[target_layer].mlp.down_proj)
        
        mse_fp4 = np.mean((w_fp16_ref - w_fp4) ** 2)
        sqnr_fp4 = 10 * np.log10(np.mean(w_fp16_ref ** 2) / mse_fp4) if mse_fp4 > 0 else float('inf')
        
        bnb_results["method"].append("FP4 (4-bit float)")
        bnb_results["mse"].append(mse_fp4)
        bnb_results["sqnr_db"].append(sqnr_fp4)
        
        print(f"FP4 (4-bit float):")
        print(f"  MSE:     {mse_fp4:.2e}")
        print(f"  SQNR:    {sqnr_fp4:.2f} dB")
        print(f"  Max err: {np.max(np.abs(w_fp16_ref - w_fp4)):.2e}")
        print()
        
    except Exception as e:
        print(f"Failed to extract FP4 weights: {e}\n")

# Add Shannon bound for reference
bnb_results["method"].append("Shannon bound")
bnb_results["mse"].append(d_shannon_4bit)
bnb_results["sqnr_db"].append(sqnr_shannon)

# Display comparison
if len(bnb_results["method"]) > 1:
    df_bnb = pd.DataFrame(bnb_results)
    df_bnb = df_bnb.sort_values("mse", ascending=True)
    print("=" * 60)
    print("NF4 vs FP4 COMPARISON")
    print("=" * 60)
    print(df_bnb.to_string(index=False))
    print("=" * 60)
    
    # Interpretation (only if both weight sets were actually extracted)
    if "mse_nf4" in globals() and "mse_fp4" in globals():
        improvement = (mse_fp4 - mse_nf4) / mse_fp4 * 100 if mse_fp4 > 0 else 0
        print(f"\nNF4 achieves {improvement:.1f}% lower MSE than FP4")
        if improvement > 5:  # heuristic threshold, not a significance test
            print("✓ NF4 clearly outperforms FP4")
            print("  → Consistent with approximately Gaussian weights")
        else:
            print("⚠ UNEXPECTED: NF4 and FP4 perform similarly")
            print("  → Weights may be more heavy-tailed than expected")
else:
    print("⚠ Could not load bitsandbytes models for comparison")

In [None]:
# ============================================
# Visualize NF4 vs FP4 Comparison
# ============================================

if len(bnb_results["method"]) > 1:
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))
    
    methods = df_bnb["method"].tolist()
    mses = df_bnb["mse"].tolist()
    sqnrs = df_bnb["sqnr_db"].tolist()
    
    colors = ['green' if 'NF4' in m else 'orange' if 'FP4' in m else 'black' for m in methods]
    
    # MSE comparison
    ax1.bar(range(len(methods)), mses, color=colors, alpha=0.7)
    ax1.set_xticks(range(len(methods)))
    ax1.set_xticklabels(methods, rotation=45, ha='right')
    ax1.set_yscale('log')
    ax1.set_ylabel('MSE (log scale)')
    ax1.set_title('NF4 vs FP4: Weight MSE')
    ax1.grid(True, alpha=0.3, axis='y')
    
    # SQNR comparison
    sqnrs_plot = [s if s != float('inf') else 100 for s in sqnrs]
    ax2.bar(range(len(methods)), sqnrs_plot, color=colors, alpha=0.7)
    ax2.set_xticks(range(len(methods)))
    ax2.set_xticklabels(methods, rotation=45, ha='right')
    ax2.set_ylabel('SQNR (dB)')
    ax2.set_title('NF4 vs FP4: Signal Quality')
    ax2.grid(True, alpha=0.3, axis='y')
    
    plt.tight_layout()
    plt.show()
    
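    # Gap analysis at nominal 4 bits; without double quantization, bitsandbytes stores one
    # FP32 absmax per 64-weight block, which adds about 0.5 bits/weight of overhead
    print("\nGap from Shannon bound (at nominal 4 bits):")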
    print("=" * 50)
    for method, mse in zip(methods, mses):
        if "Shannon" not in method:
            gap = gap_bits(mse, d_shannon_4bit)
            print(f"{method:25s}: {gap:+.3f} bits")
    print("=" * 50)
else:
    print("Visualization skipped - bitsandbytes models not available")

## Measure Functional Distortion

**Key Question:** Is weight MSE a good proxy for output distortion?

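AWQ's observation is that a small fraction of weight channels, those multiplying large-magnitude activations, account for most of the output error. If weight importance is non-uniform in this way, weight MSE should not correlate perfectly with functional MSE (the MSE of the output logits). Note that the weight MSE below comes from a single layer (`mlp.down_proj` of layer 8), while functional MSE reflects quantization of every layer, and with at most four methods any correlation is only indicative.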
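
The intuition can be checked on a toy linear layer: when input channels have very different magnitudes, two weight perturbations with identical weight MSE can produce output errors that differ by orders of magnitude. The sketch also shows that the output error equals a weight error weighted by the input second-moment matrix C = XᵀX/n, which is proportional to the Hessian GPTQ uses. Everything here (sizes, the 10× "salient" channels, the 0.01 perturbation) is a hypothetical illustration, not a measurement from Llama.

```python
import numpy as np

rng = np.random.default_rng(0)
n_samples, d_in, d_out = 2000, 16, 8

# Hypothetical activations: input channels 0-1 are 10x larger than the rest
channel_scale = np.ones(d_in)
channel_scale[:2] = 10.0
X = rng.standard_normal((n_samples, d_in)) * channel_scale

# Two weight perturbations with identical weight MSE, on different input channels
E_salient = np.zeros((d_out, d_in))
E_salient[:, 0:2] = 0.01
E_other = np.zeros((d_out, d_in))
E_other[:, 2:4] = 0.01


def output_mse(E):
    """MSE of the layer output x @ W.T when W is replaced by W + E (the error is X @ E.T)."""
    return float(np.mean((X @ E.T) ** 2))


def proxy_mse(E):
    """Same quantity as tr(E C E^T) / d_out with C = X^T X / n."""
    C = X.T @ X / n_samples
    return float(np.trace(E @ C @ E.T)) / d_out


for name, E in [("salient", E_salient), ("other", E_other)]:
    print(f"{name:8s} weight MSE={np.mean(E ** 2):.2e}  "
          f"output MSE={output_mse(E):.2e}  proxy={proxy_mse(E):.2e}")
```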

In [None]:
# ============================================
# Prepare Calibration Data
# ============================================

print("Preparing calibration dataset for functional distortion measurement...")

calibration_texts = [
    "The quick brown fox jumps over the lazy dog.",
    "Machine learning is a subset of artificial intelligence.",
    "The weather today is sunny with a chance of rain.",
    "Python is a popular programming language for data science.",
    "Transformers have revolutionized natural language processing.",
    "Climate change is one of the biggest challenges facing humanity.",
    "The stock market experienced volatility in recent months.",
    "Quantum computing promises to solve complex problems faster.",
    "Renewable energy sources include solar, wind, and hydroelectric power.",
    "The human brain contains approximately 86 billion neurons.",
]

tokenizer = AutoTokenizer.from_pretrained(model_name)  # same tokenizer as the reference model

if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

calibration_inputs = tokenizer(
    calibration_texts,
    return_tensors="pt",
    padding=True,
    truncation=True,
    max_length=128
)

print(f"Calibration dataset: {len(calibration_texts)} samples")
print(f"Input shape: {calibration_inputs['input_ids'].shape}")
print()
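
Because `padding=True` pads every sentence to the longest one, an error averaged over all positions also counts padding tokens, whose logits say nothing about the real text. A minimal sketch of a padding-aware average, using a hypothetical helper `masked_mean` on toy numbers (not model output):

```python
import torch


def masked_mean(per_position, attention_mask):
    """Average a [batch, seq] tensor over real tokens only (attention_mask == 1)."""
    mask = attention_mask.to(per_position.dtype)
    return (per_position * mask).sum() / mask.sum()


# Toy per-position errors; the last position of the first row is padding
toy_err = torch.tensor([[1.0, 1.0, 9.0],
                        [1.0, 1.0, 1.0]])
toy_mask = torch.tensor([[1, 1, 0],
                         [1, 1, 1]])
print(toy_err.mean().item(), masked_mean(toy_err, toy_mask).item())
```

The unmasked mean is pulled up by the single padded position, while the masked mean reflects only the five real tokens.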

In [None]:
# ============================================
# Compute Functional Distortion (Logits MSE)
# ============================================

def compute_functional_mse(model1, model2, inputs, device="cuda"):
    """
    MSE between the output logits of two models on the same inputs.
    Squared errors are averaged over the vocabulary, then over real (non-padding)
    token positions only. Returns (overall MSE, per-position MSE averaged over the batch).
    """
    model1.eval()
    model2.eval()
    
    input_ids = inputs["input_ids"].to(device)
    attention_mask = inputs["attention_mask"].to(device)
    
    with torch.no_grad():
        logits1 = model1(input_ids=input_ids, attention_mask=attention_mask).logits.float().cpu()
        logits2 = model2(input_ids=input_ids, attention_mask=attention_mask).logits.float().cpu()
    
    sq_err = ((logits1 - logits2) ** 2).mean(dim=-1)  # [batch, seq]: MSE over the vocabulary
    mask = attention_mask.float().cpu()               # 1 = real token, 0 = padding
    functional_mse = ((sq_err * mask).sum() / mask.sum()).item()
    per_token_mse = ((sq_err * mask).sum(dim=0) / mask.sum(dim=0).clamp(min=1)).numpy()
    
    return functional_mse, per_token_mse


device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
print()

functional_results = {
    "method": [],
    "weight_mse": [],
    "functional_mse": [],
    "ratio": [],
}

print("Computing functional distortion...")
print("=" * 60)

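# All models were loaded with device_map="auto", so they already sit on the GPU when one is
# available; moving an accelerate-dispatched model with .to() is unnecessary and can fail.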

models_to_compare = []

if AWQ_AVAILABLE and model_awq is not None:
    models_to_compare.append(("AWQ 4-bit", model_awq, mse_awq if 'mse_awq' in locals() else None))

if GPTQ_AVAILABLE and model_gptq is not None:
    models_to_compare.append(("GPTQ 4-bit", model_gptq, mse_gptq if 'mse_gptq' in locals() else None))

if NF4_AVAILABLE and model_nf4 is not None:
    models_to_compare.append(("NF4 4-bit", model_nf4, mse_nf4 if 'mse_nf4' in locals() else None))

if FP4_AVAILABLE and model_fp4 is not None:
    models_to_compare.append(("FP4 4-bit", model_fp4, mse_fp4 if 'mse_fp4' in locals() else None))

for method_name, quant_model, weight_mse in models_to_compare:
    try:
        print(f"Computing functional MSE for {method_name}...")
        func_mse, per_token = compute_functional_mse(model, quant_model, calibration_inputs, device=device)
        
        functional_results["method"].append(method_name)
        functional_results["weight_mse"].append(weight_mse if weight_mse is not None else np.nan)
        functional_results["functional_mse"].append(func_mse)
        functional_results["ratio"].append(func_mse / weight_mse if weight_mse and weight_mse > 0 else np.nan)
        
        print(f"  Weight MSE:     {weight_mse:.2e}" if weight_mse else "  Weight MSE:     N/A")
        print(f"  Functional MSE: {func_mse:.2e}")
        print(f"  Ratio (F/W):    {func_mse / weight_mse:.2e}" if weight_mse and weight_mse > 0 else "  Ratio (F/W):    N/A")
        print()
        
    except Exception as e:
        print(f"  ✗ Failed: {e}\n")

if len(functional_results["method"]) > 0:
    df_functional = pd.DataFrame(functional_results)
    print("=" * 70)
    print("FUNCTIONAL DISTORTION SUMMARY")
    print("=" * 70)
    print(df_functional.to_string(index=False))
    print("=" * 70)
else:
    print("⚠ No models available for functional distortion comparison")

In [None]:
# ============================================
# Visualize Weight MSE vs Functional MSE
# ============================================

if len(functional_results["method"]) > 0:
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))
    
    valid_mask = ~np.isnan(df_functional["weight_mse"]) & ~np.isnan(df_functional["functional_mse"])
    df_valid = df_functional[valid_mask]
    
    if len(df_valid) > 0:
        ax1.scatter(df_valid["weight_mse"], df_valid["functional_mse"], s=100, alpha=0.7)
        
        for i, row in df_valid.iterrows():
            ax1.annotate(row["method"], 
                        (row["weight_mse"], row["functional_mse"]),
                        xytext=(5, 5), textcoords='offset points', fontsize=9)
        
        ax1.set_xlabel('Weight MSE (log scale)')
        ax1.set_ylabel('Functional MSE (log scale)')
        ax1.set_xscale('log')
        ax1.set_yscale('log')
        ax1.set_title('Weight MSE vs Functional MSE')
        ax1.grid(True, alpha=0.3)
        
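        # Reference diagonal: weight MSE and logit MSE have different units, so this is only
        # a visual guide, drawn when the two axis ranges overlap
        xlim = ax1.get_xlim()
        ylim = ax1.get_ylim()
        min_val = max(xlim[0], ylim[0])
        max_val = min(xlim[1], ylim[1])
        if min_val < max_val:
            ax1.plot([min_val, max_val], [min_val, max_val], 'k--', alpha=0.3, linewidth=1, label='y=x')
            ax1.legend()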
    
    methods = df_functional["method"].tolist()
    func_mses = df_functional["functional_mse"].tolist()
    
    colors = ['green' if 'AWQ' in m or 'GPTQ' in m else 'blue' if 'NF4' in m else 'orange' for m in methods]
    
    ax2.bar(range(len(methods)), func_mses, color=colors, alpha=0.7)
    ax2.set_xticks(range(len(methods)))
    ax2.set_xticklabels(methods, rotation=45, ha='right')
    ax2.set_yscale('log')
    ax2.set_ylabel('Functional MSE (log scale)')
    ax2.set_title('Output Logits Error by Method')
    ax2.grid(True, alpha=0.3, axis='y')
    
    plt.tight_layout()
    plt.show()
    
    if len(df_valid) > 1:
        from scipy.stats import spearmanr, pearsonr
        
        corr_pearson, p_pearson = pearsonr(np.log(df_valid["weight_mse"]), np.log(df_valid["functional_mse"]))
        corr_spearman, p_spearman = spearmanr(df_valid["weight_mse"], df_valid["functional_mse"])
        
        print("\nCorrelation Analysis:")
        print("=" * 60)
        print(f"Pearson correlation (log-log):  r={corr_pearson:.3f}, p={p_pearson:.4f}")
        print(f"Spearman correlation (rank):    ρ={corr_spearman:.3f}, p={p_spearman:.4f}")
        print(f"(n={len(df_valid)} methods; with so few points, r, ρ and p are indicative only)")
        print("=" * 60)
        
        if corr_spearman > 0.8:
            print("✓ Strong correlation: weight MSE tracks functional MSE across these methods")
        elif corr_spearman > 0.5:
            print("⚠ Moderate correlation: weight MSE partially predicts functional MSE")
        else:
            print("✗ Weak correlation: consistent with non-uniform weight importance")
            print("  → Supports AWQ-style activation-aware weighting")
else:
    print("Visualization skipped - no functional distortion data available")

## Summary: Key Findings from SOTA Comparison

**Expected Results:**

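1. **GPTQ/AWQ vs Simple Methods:**
   - GPTQ/AWQ should need 0.5-1.0 fewer bits than simple methods at the same distortion
   - The gain comes from GPTQ's Hessian-based error compensation and AWQ's activation-aware channel scaling; because both minimize layer-output error rather than weight error, it may be clearer in functional MSE than in weight MSE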

2. **NF4 vs FP4:**
   - If NF4 >> FP4: consistent with approximately Gaussian weights
   - If NF4 ≈ FP4: Weights may be heavier-tailed than expected

3. **Weight MSE vs Functional MSE:**
   - Strong correlation (ρ > 0.8): weight MSE is a good proxy; importance is roughly uniform
   - Weak correlation (ρ < 0.5): weight importance is non-uniform; activation-aware (per-channel importance) schemes are critical

**Research Opportunities:**
- If GPTQ/AWQ gap is large: Implement Hessian-based importance weighting
- If NF4 wins significantly: Explore optimal quantization levels for near-Gaussian distributions
- If functional MSE deviates: Investigate per-layer and per-channel importance patterns