# Lab 3.2.5: AWQ Quantization - Solutions

This notebook contains solutions for all exercises in Lab 3.2.5.

In [None]:
from typing import Tuple

import matplotlib.pyplot as plt
import numpy as np

## Exercise 1 Solution: Salient Weight Detection

Implement salient weight detection using activation statistics.

In [None]:
def detect_salient_weights(
    weights: np.ndarray,
    activations: np.ndarray,
    percentile: float = 99.0
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Detect salient (important) input channels based on activation patterns.
    
    AWQ key insight: Weights that multiply large activations are important.
    The per-channel score used here is
    ``mean(|activations|) * mean(|weights|)``, and a channel is salient when
    its score is at or above the requested percentile of all scores.
    
    Args:
        weights: Weight matrix (out_features, in_features)
        activations: Calibration activations (batch, in_features)
        percentile: Percentile threshold for saliency (0-100)
        
    Returns:
        Tuple ``(salient_mask, combined_importance)``:
            salient_mask: Boolean mask (in_features,), True for salient channels
            combined_importance: Importance score per channel (in_features,)
    
    Raises:
        ValueError: If weights and activations disagree on in_features.
    """
    weights = np.asarray(weights)
    activations = np.asarray(activations)
    
    # Without this check a mismatched (but broadcastable) feature dimension
    # would silently produce meaningless scores.
    if weights.shape[-1] != activations.shape[-1]:
        raise ValueError(
            f"in_features mismatch: weights has {weights.shape[-1]}, "
            f"activations has {activations.shape[-1]}"
        )
    
    # Compute activation importance per input channel
    # Higher activation magnitude = more important
    activation_importance = np.mean(np.abs(activations), axis=0)
    
    # Compute weight importance (averaged over output features)
    weight_magnitude = np.mean(np.abs(weights), axis=0)
    
    # Combined importance: channels with large weights AND large activations
    combined_importance = activation_importance * weight_magnitude
    
    # Channels at or above the percentile threshold are salient
    threshold = np.percentile(combined_importance, percentile)
    salient_mask = combined_importance >= threshold
    
    return salient_mask, combined_importance


# Test salient weight detection
np.random.seed(42)

# Ground truth for the synthetic layer: these input channels get both larger
# weights and larger activations, so they are the ones we expect to be flagged.
IMPORTANT_START, IMPORTANT_STOP = 100, 110
important_channels = slice(IMPORTANT_START, IMPORTANT_STOP)

# Single source of truth for the saliency threshold, so the detection call
# and the threshold line in the plot cannot drift apart.
SALIENCY_PERCENTILE = 95

# Simulate weights with some important channels
weights = np.random.randn(256, 512).astype(np.float32) * 0.02
weights[:, important_channels] *= 5  # Larger weights

# Simulate activations (drawn after the weights; order matters for the seed)
activations = np.random.randn(100, 512).astype(np.float32)
activations[:, important_channels] *= 3  # Larger activations on the same channels

# Detect salient weights
salient_mask, importance = detect_salient_weights(
    weights, activations, percentile=SALIENCY_PERCENTILE
)

print("Salient Weight Detection")
print("=" * 50)
print(f"Total channels: {len(salient_mask)}")
print(f"Salient channels: {salient_mask.sum()}")
print(f"Salient ratio: {salient_mask.mean()*100:.1f}%")

# Verify detection against the known important channels
detected_important = set(np.where(salient_mask)[0])
actual_important = set(range(IMPORTANT_START, IMPORTANT_STOP))
overlap = detected_important & actual_important
print(f"\nActual important channels: {sorted(actual_important)}")
print(f"Detected as salient: {len(overlap)}/{len(actual_important)}")

# Visualize per-channel importance with the detection threshold
fig, ax = plt.subplots(figsize=(12, 4))
ax.bar(range(len(importance)), importance, width=1.0, alpha=0.7)
ax.axhline(
    np.percentile(importance, SALIENCY_PERCENTILE),
    color='r',
    linestyle='--',
    label=f'{SALIENCY_PERCENTILE}th percentile',
)
ax.set_xlabel('Channel Index')
ax.set_ylabel('Importance Score')
ax.set_title('Channel Importance (AWQ Saliency)')
ax.legend()
plt.show()

## Exercise 2 Solution: AWQ Scale Search

Implement the AWQ scale search algorithm.

In [None]:
def awq_scale_search(
    weights: np.ndarray,
    activations: np.ndarray,
    bits: int = 4,
    n_grid: int = 20
) -> np.ndarray:
    """
    Search for per-channel scales using an AWQ-style grid search.
    
    Candidate scales have the form ``s = mean(|activations|) ** alpha`` with a
    single ``alpha`` shared by all channels. For each candidate the weights are
    scaled, quantized as ONE tensor (shared min/max), scaled back, and the
    resulting output error on the calibration activations is measured. The
    candidate with the lowest error wins.
    
    Why a joint search over the whole tensor: if each column is quantized with
    its own min/max, multiplying that column by a scalar and dividing it back
    out gives the same result as not scaling at all, so a per-column search
    cannot distinguish between alphas. Scaling only matters when channels
    share a quantization grid, which is also how ``quantize_awq`` in this
    notebook applies the scales.
    
    Identity scales (all ones) are always evaluated as the baseline, so the
    returned scales never give a higher calibration error than unscaled
    quantization with the same quantizer.
    
    Args:
        weights: Weight matrix (out_features, in_features)
        activations: Calibration activations (batch, in_features)
        bits: Quantization bits (>= 1)
        n_grid: Number of alpha values tried, evenly spaced in [0.1, 1.0]
        
    Returns:
        Selected scales per channel (in_features,), dtype float32
    
    Raises:
        ValueError: If bits < 1, or the inputs are not 2-D with matching
            in_features.
    """
    if bits < 1:
        raise ValueError(f"bits must be >= 1, got {bits}")
    if (
        weights.ndim != 2
        or activations.ndim != 2
        or weights.shape[1] != activations.shape[1]
    ):
        raise ValueError(
            "expected weights (out_features, in_features) and activations "
            f"(batch, in_features); got {weights.shape} and {activations.shape}"
        )
    
    in_features = weights.shape[1]
    qmax = 2 ** bits - 1
    
    def _fake_quantize(w: np.ndarray) -> np.ndarray:
        """Asymmetric min-max quantize/dequantize using one grid for all of w."""
        w_min, w_max = w.min(), w.max()
        q_scale = (w_max - w_min) / qmax
        w_quant = np.clip(np.round((w - w_min) / (q_scale + 1e-10)), 0, qmax)
        return w_quant * q_scale + w_min
    
    # Reference (unquantized) layer output on the calibration batch
    ref_out = activations @ weights.T
    
    def _output_error(scales: np.ndarray) -> float:
        """MSE of the layer output after scale -> quantize -> unscale."""
        w_final = _fake_quantize(weights * scales) / scales
        return np.mean((activations @ w_final.T - ref_out) ** 2)
    
    # Per-channel activation magnitude; channels with no activation keep scale 1
    act_scales = np.mean(np.abs(activations), axis=0)
    act_scales = np.where(act_scales > 0, act_scales, 1.0)
    
    # Baseline: no scaling at all
    best_scales = np.ones(in_features, dtype=np.float32)
    best_error = _output_error(best_scales)
    
    for alpha in np.linspace(0.1, 1.0, n_grid):
        # Cast before evaluating so the measured error is the error of exactly
        # the float32 scales that get returned. The floor avoids dividing by 0.
        candidate = np.maximum(act_scales ** alpha, 1e-10).astype(np.float32)
        error = _output_error(candidate)
        
        if error < best_error:
            best_error = error
            best_scales = candidate
    
    return best_scales


# Exercise the AWQ scale search on a small synthetic layer
np.random.seed(42)

# Layer dimensions for the synthetic example
n_out, n_in, n_samples = 64, 128, 100

# Draw order is part of the reproducible setup: weights first, then activations
weights = 0.02 * np.random.randn(n_out, n_in).astype(np.float32)
activations = np.random.randn(n_samples, n_in).astype(np.float32)

print("AWQ Scale Search")
print("=" * 50)

scales = awq_scale_search(weights, activations, bits=4, n_grid=20)

# Summarize the scales that the search selected
print(f"Scale range: [{scales.min():.4f}, {scales.max():.4f}]")
print(f"Mean scale: {scales.mean():.4f}")

# Compare with and without AWQ scales
def quantize_naive(w, bits=4):
    """Min-max fake-quantize ``w`` using one grid for the whole array.

    The values are mapped onto ``2**bits`` evenly spaced levels between
    ``w.min()`` and ``w.max()``, rounded to the nearest level, and mapped
    back, so the result has the same shape as ``w``.
    """
    levels = 2**bits - 1
    lo = w.min()
    step = (w.max() - lo) / levels
    # The small epsilon keeps the division finite when all values are equal
    codes = np.clip(np.round((w - lo) / (step + 1e-10)), 0, levels)
    return codes * step + lo

def quantize_awq(w, awq_scales, bits=4):
    """Quantize ``w`` with per-channel AWQ scales applied around the quantizer.

    The weights are multiplied by ``awq_scales``, passed through
    ``quantize_naive`` at the given bit width, and then divided by the same
    scales to return to the original weight range.
    """
    return quantize_naive(w * awq_scales, bits) / awq_scales

# Layer output with the original (unquantized) weights
ref_output = activations @ weights.T

# Naive path: quantize the weights directly
w_naive = quantize_naive(weights)
naive_output = activations @ w_naive.T
naive_error = np.mean((naive_output - ref_output) ** 2)

# AWQ path: quantize with the searched per-channel scales
w_awq = quantize_awq(weights, scales)
awq_output = activations @ w_awq.T
awq_error = np.mean((awq_output - ref_output) ** 2)

# Report the output MSE of each path and their ratio
print(f"\nOutput Error Comparison:")
print(f"  Naive quantization: {naive_error:.6f}")
print(f"  AWQ quantization:   {awq_error:.6f}")
print(f"  Improvement:        {naive_error/awq_error:.2f}x")

## Exercise 3 Solution: AWQ vs GPTQ Comparison

Systematically compare naive, AWQ, and (simplified) GPTQ quantization on the same synthetic layers (random weights and activations).

In [None]:
def compare_awq_gptq(
    weights: np.ndarray,
    activations: np.ndarray,
    bits: int = 4
) -> dict:
    """
    Quantize one layer three ways and report the output MSE of each.
    
    The three variants are naive min-max quantization, AWQ (scale search
    followed by scaled quantization), and a simplified GPTQ-style pass that
    quantizes one input column at a time and pushes the quantization error
    into the columns that have not been quantized yet.
    
    Args:
        weights: Weight matrix (out_features, in_features)
        activations: Calibration activations (batch, in_features)
        bits: Quantization bits
    
    Returns:
        Dict with 'naive_mse', 'awq_mse', 'gptq_mse' and the ratios
        'awq_improvement' (naive/awq), 'gptq_improvement' (naive/gptq)
        and 'awq_vs_gptq' (gptq/awq).
    """
    n_in = weights.shape[1]
    levels = 2**bits - 1
    
    # Output of the unquantized layer; every variant is scored against it
    ref_output = activations @ weights.T
    
    def output_mse(w_hat):
        # Mean squared difference between the quantized and reference outputs
        return np.mean((activations @ w_hat.T - ref_output) ** 2)
    
    # --- Naive quantization ---
    naive_mse = output_mse(quantize_naive(weights, bits))
    
    # --- AWQ quantization ---
    awq_scales = awq_scale_search(weights, activations, bits=bits, n_grid=20)
    awq_mse = output_mse(quantize_awq(weights, awq_scales, bits))
    
    # --- Simplified GPTQ ---
    # Second-moment matrix of the activations, damped on the diagonal by 1%
    # of its mean diagonal value.
    hessian = (activations.T @ activations) / len(activations)
    hessian += 0.01 * np.eye(n_in) * np.mean(np.diag(hessian))
    
    w_gptq = weights.copy()
    for i in range(n_in):
        col = w_gptq[:, i]
        
        # Per-column min-max quantization
        lo = col.min()
        step = (col.max() - lo) / levels
        codes = np.clip(np.round((col - lo) / (step + 1e-10)), 0, levels)
        col_q = codes * step + lo
        
        # Residual must be taken before the column is overwritten in place
        residual = col - col_q
        w_gptq[:, i] = col_q
        
        # Spread the residual over the remaining columns, weighted by
        # H[i, j] / H[i, i] (this uses the Hessian itself, not its inverse).
        h_ii = hessian[i, i]
        if i + 1 < n_in and h_ii > 1e-10:
            w_gptq[:, i+1:] += np.outer(residual, hessian[i, i+1:] / h_ii)
    
    gptq_mse = output_mse(w_gptq)
    
    return {
        'naive_mse': naive_mse,
        'awq_mse': awq_mse,
        'gptq_mse': gptq_mse,
        'awq_improvement': naive_mse / awq_mse,
        'gptq_improvement': naive_mse / gptq_mse,
        'awq_vs_gptq': gptq_mse / awq_mse
    }


# Run comparison across multiple layers
np.random.seed(42)

# Number of independent random layers to average over. Used for both the loop
# and the report header so the printed count always matches what was run.
N_TRIALS = 5

print("AWQ vs GPTQ Comparison")
print("=" * 60)

results = []
for _ in range(N_TRIALS):
    # Fresh synthetic layer per trial (weights drawn before activations)
    weights = np.random.randn(128, 256).astype(np.float32) * 0.02
    activations = np.random.randn(100, 256).astype(np.float32)
    
    results.append(compare_awq_gptq(weights, activations, bits=4))

# Average each metric across trials (ratios are averaged per trial, not
# recomputed from the averaged MSEs)
avg_results = {k: np.mean([r[k] for r in results]) for k in results[0]}

print(f"\nAverage over {N_TRIALS} trials:")
print(f"  Naive MSE:       {avg_results['naive_mse']:.6f}")
print(f"  AWQ MSE:         {avg_results['awq_mse']:.6f}")
print(f"  GPTQ MSE:        {avg_results['gptq_mse']:.6f}")
print()
print(f"  AWQ vs Naive:    {avg_results['awq_improvement']:.2f}x better")
print(f"  GPTQ vs Naive:   {avg_results['gptq_improvement']:.2f}x better")
print(f"  AWQ vs GPTQ:     {avg_results['awq_vs_gptq']:.2f}x")

# awq_vs_gptq is gptq_mse / awq_mse, so a value above 1 means AWQ had lower error
if avg_results['awq_vs_gptq'] > 1:
    print("\nConclusion: AWQ achieves better quality than GPTQ")
else:
    print("\nConclusion: GPTQ achieves better quality than AWQ")

## Summary

Key findings:

1. **Salient weight detection** correctly identifies important channels
2. **AWQ scale search** finds per-channel scales that minimize output error
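3. **AWQ vs GPTQ**: AWQ is simpler (no Hessian inversion) but both achieve similar quality. Note that the GPTQ variant in this notebook is a simplified approximation (it compensates error using `H[i, j] / H[i, i]` rather than the inverse Hessian), so check the printed MSE values from Exercise 3 rather than treating this as a general result.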
4. **AWQ is faster** to quantize since it doesn't require sequential column processing