# 🧪 HQQ From Scratch: Half-Quadratic Quantization (2024)

[!["Open In Colab"](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adiel2012/model-size-reduction/blob/main/chronology/hqq_demo.ipynb)

## 📖 The Theory: Data-Free Optimization

HQQ (Half-Quadratic Quantization) is a fast, data-free quantization method: it does not require **calibration data**. GPTQ and AWQ use activations collected on a calibration set. HQQ instead treats quantization as an optimization problem over the weights $W$ alone.

### The Objective
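We want a scale $S$ and integer codes $Q$ that minimize the reconstruction error:

$$\min_{S,\,Q} \lVert W - S\,Q \rVert_2^2 \quad \text{s.t.}\quad Q_i \in \{-q_{\max}, \dots, q_{\max}\},\; q_{\max} = 2^{b-1} - 1$$

Here $b$ is the bit width. This squared-error objective is a simplification used throughout this notebook. The original HQQ formulation minimizes a sparsity-promoting $\ell_p$ loss with $p < 1$, which is less sensitive to outliers. It solves that loss with half-quadratic splitting, which gives the method its name.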
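
Once $S$ is fixed, the best $Q$ is found by rounding and clamping. That means the objective can be treated as a function of the scale alone. The brute-force scan below illustrates this. The helper name `objective`, the 4096-sample Gaussian test vector, and the scan range of 10% to 100% of the min-max scale are illustrative assumptions, not part of HQQ.

```python
import torch

def objective(W: torch.Tensor, scale: float, bits: int = 4) -> float:
    """||W - S*Q||^2 where Q is the best code for this S (round, then clamp)."""
    q_max = 2 ** (bits - 1) - 1
    q = torch.round(W / scale).clamp(-q_max, q_max)
    return torch.sum((W - scale * q) ** 2).item()

torch.manual_seed(0)
W = torch.randn(4096)
minmax_scale = (W.abs().max() / (2 ** (4 - 1) - 1)).item()

# Evaluate the objective on a grid of scales from 10% to 100% of the min-max scale
candidates = torch.linspace(0.1, 1.0, 91) * minmax_scale
errors = [objective(W, s.item()) for s in candidates]
best = int(torch.tensor(errors).argmin())

print(f"min-max scale      {minmax_scale:.4f}: error {objective(W, minmax_scale):.2f}")
print(f"best scanned scale {candidates[best].item():.4f}: error {errors[best]:.2f}")
```

On bell-shaped weights, the scanned optimum typically sits well below the min-max scale. Min-max lets a few outliers dictate the step size.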

### Iterative Solver
The simplified solver in this notebook is iterative. It alternates between two steps:
1.  **Estimating the Scale $S$**: Given $Q$, find the best $S$ that fits the weights (least squares).
2.  **Quantizing $W$**: Given $S$, round each $W_i / S$ to the nearest integer and clamp it to the range representable with the chosen bit width.
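
Both steps have closed-form solutions. Let $q_{\max} = 2^{b-1} - 1$ for $b$-bit signed codes.

For fixed $Q$, the objective is a one-dimensional quadratic in $S$. Setting its derivative to zero gives the least-squares scale:

$$\frac{\partial}{\partial S}\lVert W - S\,Q\rVert_2^2 = -2\,Q^\top(W - S\,Q) = 0 \quad\Longrightarrow\quad S^\star = \frac{Q^\top W}{Q^\top Q}$$

For fixed $S > 0$, the objective splits into independent per-element terms. Each $Q_i$ is therefore the nearest allowed level:

$$Q_i^\star = \arg\min_{q \in \{-q_{\max},\dots,q_{\max}\}} (W_i - S\,q)^2 = \operatorname{clamp}\!\left(\operatorname{round}(W_i / S),\, -q_{\max},\, q_{\max}\right)$$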

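Each step solves its subproblem exactly while the other variable is held fixed. The squared error therefore never increases from one iteration to the next, so the result is at least as good as the min-max starting point. How much it improves, and how many iterations that takes, depends on the weight distribution. For bell-shaped weights, min-max scaling lets a few outliers set the step size, so the least-squares updates typically shrink the scale and lower the error.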
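
The loop can also be stopped at a fixed point rather than after a fixed number of iterations. Once an S-step followed by a Q-step leaves $Q$ unchanged, every later step repeats the same values. The sketch below records the squared error after every update so the monotone decrease can be checked directly. The function name `alternate_until_fixed` and the `max_iter` cap are hypothetical choices for illustration.

```python
import torch

def alternate_until_fixed(W: torch.Tensor, bits: int = 4, max_iter: int = 100):
    """Alternate S- and Q-steps until Q stops changing; return the error history."""
    W = W.reshape(-1).float()
    q_max = 2 ** (bits - 1) - 1
    scale = W.abs().max() / q_max                       # min-max initialization
    q = torch.round(W / scale).clamp(-q_max, q_max)
    history = [torch.sum((W - scale * q) ** 2).item()]  # error before any update
    for _ in range(max_iter):
        scale = torch.dot(W, q) / torch.dot(q, q)            # S-step: least squares
        q_new = torch.round(W / scale).clamp(-q_max, q_max)  # Q-step: nearest level
        history.append(torch.sum((W - scale * q_new) ** 2).item())
        if torch.equal(q_new, q):                            # fixed point reached
            break
        q = q_new
    return history

torch.manual_seed(0)
hist = alternate_until_fixed(torch.randn(512, 512))
print(f"updates run: {len(hist) - 1}")
print(f"squared error: {hist[0]:.1f} (min-max) -> {hist[-1]:.1f} (final)")
```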

---

```python
import torch

def hqq_quantize_block(W, bits=4, n_iter=5):
    """
    Simplified HQQ algorithm implementation from scratch
    (squared-error objective, symmetric signed grid).
    W: weight tensor (a whole layer or a block of one).
    Returns the integer codes Q (as a float tensor shaped like W) and a scalar scale S.
    """
    orig_shape = W.shape
    W = W.reshape(-1).float()  # reshape also works for non-contiguous tensors

    # Initialize the scale with simple min-max (absmax) scaling
    q_max = 2**(bits-1) - 1
    scale = (torch.max(torch.abs(W)) / q_max).item()
    q_w = torch.round(W / scale).clamp(-q_max, q_max)

    print(f"Optimizing scale for {bits}-bit weights...")
    print(f"  Min-max init: MSE = {torch.mean((W - scale * q_w) ** 2):.8f}")
    for i in range(n_iter):
        # 1. Update scale (least squares: S = (W · Q) / (Q · Q))
        num = torch.dot(W, q_w)
        den = torch.dot(q_w, q_w)
        scale = (num / den).item()

        # 2. Update quantized weights (nearest grid point for the new scale)
        q_w = torch.round(W / scale).clamp(-q_max, q_max)

        # Mean squared reconstruction error after this iteration
        mse = torch.mean((W - scale * q_w) ** 2)
        print(f"  Iteration {i}: MSE = {mse:.8f}")

    return q_w.view(orig_shape), scale

# Test implementation (fixed seed for reproducibility)
torch.manual_seed(0)
W_raw = torch.randn(512, 512)
W_q, s = hqq_quantize_block(W_raw, bits=4)

final_w = W_q * s
print(f"\nFinal Quantized Scale: {s:.6f}")
print(f"Mean Absolute Reconstruction Error: {torch.abs(W_raw - final_w).mean():.6f}")
```
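
The `hqq_quantize_block` function above replaces HQQ's loss and solver with plain least squares. For comparison, the following is a rough sketch of the half-quadratic idea, based on our reading of the original HQQ write-up. It introduces an auxiliary error variable $W_e$ and shrinks that error with an $\ell_p$ ($p < 1$) proximal step. It then updates the zero-point of an asymmetric grid in closed form while the scale stays fixed.

This sketch was written for this notebook and is not the `hqq` library's code. The function names `shrink_lp` and `hq_zero_point`, the per-tensor asymmetric grid, and the defaults ($p = 0.7$, $\beta = 10$, $\kappa = 1.01$, 20 iterations) are assumptions chosen for illustration.

```python
import torch

def shrink_lp(x: torch.Tensor, beta: float, p: float) -> torch.Tensor:
    """Generalized soft-thresholding, an approximate proximal step for |x|^p with p < 1."""
    return torch.sign(x) * torch.relu(x.abs() - x.abs().pow(p - 1) / beta)

def hq_zero_point(W: torch.Tensor, bits: int = 4, p: float = 0.7,
                  beta: float = 10.0, kappa: float = 1.01, iters: int = 20):
    """Hypothetical helper: half-quadratic zero-point search for W ~ S * (Q - z).

    Assumes a per-tensor asymmetric grid Q in {0, ..., 2**bits - 1} and keeps the
    min-max scale S fixed; p, beta, kappa and iters are illustrative defaults.
    """
    W = W.reshape(-1).float()
    q_hi = 2 ** bits - 1
    S = ((W.max() - W.min()) / q_hi).item()   # scale: fixed after min-max init
    z = (-W.min() / S).item()                 # zero-point: min-max init

    def quantize(z: float):
        Q = torch.round(W / S + z).clamp(0, q_hi)
        return Q, S * (Q - z)                 # codes and their reconstruction

    Q, W_r = quantize(z)
    best_z, best_err = z, (W - W_r).abs().mean().item()
    for _ in range(iters):
        W_e = shrink_lp(W - W_r, beta, p)      # sparse estimate of the error W - W_r
        z = (Q - (W - W_e) / S).mean().item()  # closed-form zero-point given W_e
        beta *= kappa                          # gradually tighten the splitting penalty
        Q, W_r = quantize(z)
        err = (W - W_r).abs().mean().item()
        if err < best_err:                     # keep the best zero-point seen so far
            best_z, best_err = z, err
        else:
            break
    return best_z, S, best_err

torch.manual_seed(0)
z, S, mae = hq_zero_point(torch.randn(512, 512))
print(f"zero-point {z:.4f}, scale {S:.4f}, mean |W - W_r| {mae:.6f}")
```

The search starts from the min-max zero-point and keeps only improvements, so its mean absolute error is never worse than that starting point.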