# 🧪 HQQ From Scratch: Half-Quadratic Quantization (2024)

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adiel2012/model-size-reduction/blob/main/chronology/hqq_demo.ipynb)

## 📖 Theory: Data-Free Optimization

HQQ quantizes without **calibration data**. Instead of inspecting activations
(as GPTQ or AWQ do), it treats quantization as a pure weight-space optimization problem.

### Objective Function

Find a scale $S$ and quantized integer weights $Q$ such that:

$$\min_{S,\,Q}\; \|W - S \cdot Q\|_2^2$$

This is a **bilinear** problem, so HQQ alternates between two sub-problems
(coordinate descent / alternating least squares).

### Iterative Solver

**Step 1 -- Update Scale** (given $Q$, find best $S$ by least squares):

$$S^* = \frac{W \cdot Q}{Q \cdot Q} = \frac{\langle W, Q\rangle}{\|Q\|^2}$$
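
The closed form follows from setting the derivative of the objective with respect to $S$ to zero while $Q$ is held fixed. A short derivation, treating $W$ and $Q$ as flattened vectors and assuming a single scalar scale (as in the code cells of this notebook):

$$
\begin{aligned}
f(S) &= \|W - S\,Q\|_2^2 = \|W\|^2 - 2S\,\langle W, Q\rangle + S^2\,\|Q\|^2,\\
f'(S) &= -2\,\langle W, Q\rangle + 2S\,\|Q\|^2 = 0
\;\;\Longrightarrow\;\; S^* = \frac{\langle W, Q\rangle}{\|Q\|^2}.
\end{aligned}
$$

Since $f''(S) = 2\|Q\|^2 > 0$ whenever $Q \neq 0$, this stationary point is the minimum. The formula is undefined when every entry of $Q$ is zero.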

**Step 2 -- Update $Q$** (given $S$, round-and-clamp):

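$$Q = \text{clamp}\!\left(\text{round}\!\left(\frac{W}{S}\right),\; -2^{b-1},\; 2^{b-1}-1\right)$$

Note that the code cells in this notebook clamp to the symmetric range $[-(2^{b-1}-1),\; 2^{b-1}-1]$, so they use one level fewer than this formula allows.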
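
To make the clamp bounds concrete, the sketch below lists the integer levels for a given bit width, both for the full signed range in the formula and for a symmetric variant that drops the most negative level. The helper name `int_levels` and its `symmetric` flag are illustrative and not part of any library.

```python
def int_levels(bits, symmetric=False):
    """Return the integers a signed `bits`-bit quantizer can represent."""
    lo = -(2 ** (bits - 1))
    hi = 2 ** (bits - 1) - 1
    if symmetric:
        lo = -hi  # drop the most negative level so the grid is centred on 0
    return list(range(lo, hi + 1))

print(int_levels(4))                  # full signed range: -8 .. 7
print(int_levels(4, symmetric=True))  # symmetric range:   -7 .. 7
print(int_levels(2, symmetric=True))  # 2-bit symmetric:   -1, 0, 1
```

The symmetric grid keeps zero exactly representable and makes the scale-only model $W \approx S \cdot Q$ treat positive and negative weights alike, at the cost of one unused code.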

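Convergence is typically fast (2-5 iterations in practice). Each step solves its
sub-problem exactly, so the reconstruction error never increases, and once rounding
$W/S$ reproduces the same $Q$, the scale update returns the same $S$ and the loop
has reached a fixed point.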
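
The two steps can be sketched as a loop that stops as soon as the integer codes stop changing. This is a minimal sketch under stated assumptions: it uses plain Python lists so it runs without PyTorch, a single scalar scale, and the symmetric clamp $[-q_{\max}, q_{\max}]$; the function name `alternate_scale_and_round` is hypothetical.

```python
import random

def alternate_scale_and_round(w, bits=4, max_iter=20):
    """Alternate the closed-form scale update with round-and-clamp.

    Stops when the integer codes no longer change (a fixed point).
    Returns (codes, scale, mse_history).
    """
    q_max = 2 ** (bits - 1) - 1

    def quantize(scale):
        return [max(-q_max, min(q_max, round(x / scale))) for x in w]

    def mse(q, scale):
        return sum((x - scale * c) ** 2 for x, c in zip(w, q)) / len(w)

    scale = max(abs(x) for x in w) / q_max  # min-max initialisation
    if scale == 0:                          # all-zero input: nothing to optimise
        return [0] * len(w), 0.0, [0.0]

    q = quantize(scale)
    history = [mse(q, scale)]

    for _ in range(max_iter):
        # Step 1: least-squares scale for the current codes
        scale = sum(x * c for x, c in zip(w, q)) / sum(c * c for c in q)
        # Step 2: nearest representable code for the new scale
        new_q = quantize(scale)
        history.append(mse(new_q, scale))
        if new_q == q:                      # codes unchanged: fixed point reached
            break
        q = new_q
    return q, scale, history

random.seed(0)
weights = [random.gauss(0.0, 1.0) for _ in range(4096)]
codes, scale, history = alternate_scale_and_round(weights)
print(f"stopped after {len(history) - 1} update(s), "
      f"MSE {history[0]:.6f} -> {history[-1]:.6f}")
```

Recording the MSE after every update makes it easy to check that the error is non-increasing, which is the property the alternating scheme relies on.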

### Half-Quadratic Origin

The name comes from *half-quadratic splitting* in image processing (Geman & Yang, 1995),
which regularises non-smooth problems by introducing an auxiliary variable.
Here $Q$ plays the role of that auxiliary variable, decoupling the integer
constraint from the least-squares optimisation.
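
For orientation, the generic splitting idea can be written as follows. This is the common textbook form of half-quadratic splitting, with $f$ (smooth term), $\phi$ (non-smooth term), $z$ (auxiliary variable) and $\beta$ (penalty weight) as generic placeholder symbols; it is not taken from this notebook and is not a statement of the exact objective used in the HQQ paper.

$$
\min_{x}\; f(x) + \phi(x)
\quad\longrightarrow\quad
\min_{x,\,z}\; f(x) + \phi(z) + \frac{\beta}{2}\,\|x - z\|_2^2
$$

With $z$ fixed, the problem in $x$ is smooth; with $x$ fixed, the problem in $z$ often has a cheap elementwise solution. That is the same alternating structure as Steps 1 and 2 above.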

### HQQ vs GPTQ vs Naive INT4

| Property | Naive INT4 | GPTQ | HQQ |
|---|---|---|---|
| Calibration data | No | Yes | **No** |
| Weight MSE | Highest | Low | Low |
| Speed | Instant | Slow | **Fast** |
| Activation-aware | No | Yes | No |

### Limitations
* Ignores the downstream loss -- only minimises weight reconstruction MSE.
* No activation-aware scaling; can miss outlier weight channels.
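
The outlier point can be illustrated with a toy case in which one output channel has much larger weights than another. This is a minimal sketch using plain min-max rounding on Python lists; the values and the helper names `minmax_quantize` and `mse` are made up for illustration, and per-row scaling is shown only as a contrast to the single-scale simplification used in this notebook.

```python
def minmax_quantize(values, bits=4):
    """Symmetric round-to-nearest with one min-max scale for `values`."""
    q_max = 2 ** (bits - 1) - 1
    scale = max(abs(v) for v in values) / q_max
    if scale == 0:
        return [0.0] * len(values)
    return [scale * max(-q_max, min(q_max, round(v / scale))) for v in values]

def mse(a, b):
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)

# Two output channels (rows); the second has much larger magnitudes.
rows = [
    [0.02, -0.05, 0.03, 0.01, -0.04, 0.05],
    [4.00, -3.50, 2.00, 1.00, -4.00, 3.00],
]
flat = [v for row in rows for v in row]

# One scale for the whole tensor: the large row dictates the step size.
per_tensor = minmax_quantize(flat)
small_row_shared = per_tensor[: len(rows[0])]

# One scale per row: the small row gets a step size matched to its own range.
small_row_own = minmax_quantize(rows[0])

print("small row, shared scale:", small_row_shared)
print("small row, own scale   :", [round(v, 4) for v in small_row_own])
print("MSE shared:", mse(rows[0], small_row_shared))
print("MSE own   :", mse(rows[0], small_row_own))
```

With the shared scale, every weight in the small row rounds to zero, so that channel is effectively erased even though the overall weight MSE is dominated by the large row.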

---

In [None]:
import torch

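def hqq_quantize_block(W, bits=4, n_iter=5):
    """
    Simplified HQQ-style quantizer written from scratch.

    Alternates a least-squares scale update with round-and-clamp, using a
    single scalar scale for the whole tensor.

    W: weight tensor (can be a layer or block)
    Returns (q_w, scale): integer-valued float tensor shaped like W, and a float.
    """
    if bits < 2:
        raise ValueError("bits must be >= 2 for a symmetric signed range")

    orig_shape = W.shape
    W = W.reshape(-1).float()  # reshape also handles non-contiguous inputs

    # Initialize the scale with simple min-max (symmetric range [-q_max, q_max])
    q_max = 2**(bits-1) - 1
    scale = (torch.max(torch.abs(W)) / q_max).item()
    if scale == 0.0:
        # All-zero input: nothing to optimize, and W / scale would be undefined
        return torch.zeros(orig_shape, device=W.device), 0.0

    q_w = torch.round(W / scale).clamp(-q_max, q_max)

    print(f"Optimizing scale for {bits}-bit weight...")
    for i in range(n_iter):
        # 1. Update scale (least squares: S = (W · Q) / (Q · Q))
        num = torch.dot(W, q_w)
        den = torch.dot(q_w, q_w)
        scale = (num / den).item()

        # 2. Update quantized weights
        q_w = torch.round(W / scale).clamp(-q_max, q_max)

        # Report the reconstruction error every other iteration
        mse = torch.mean((W - scale * q_w) ** 2).item()
        if i % 2 == 0:
            print(f"  Iteration {i}: MSE = {mse:.8f}")

    return q_w.view(orig_shape), scale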

# Test implementation
W_raw = torch.randn(512, 512)
W_q, s = hqq_quantize_block(W_raw, bits=4)

final_w = W_q * s
print(f"\nFinal Quantized Scale: {s:.6f}")
print(f"Manual HQQ Reconstruction Mean Error: {torch.abs(W_raw - final_w).mean():.6f}")

## 🔢 Worked Example with Numbers

To see the loop in action, let's trace through the math with a tiny, hand-traceable example.

In [None]:
# Tiny example: HQQ iterative optimization on a 2×2 matrix using 2-bit
import torch

W_small = torch.tensor([[0.8, -0.3],
                        [0.1, -0.9]])
bits  = 2          # symmetric clamp keeps 3 of the 4 levels: -1, 0, 1  (q_max = 1)
q_max = 2**(bits-1) - 1
W     = W_small.view(-1).float()

print(f"Original weights (flattened): {W.tolist()}")
print(f"Bits={bits}  →  q_max={q_max}  →  levels {{{-q_max}…{q_max}}}")

# Initial min-max scale
scale = (W.abs().max() / q_max).item()
q_w   = W.div(scale).round().clamp(-q_max, q_max)
print(f"\n{'Iter':>5}  {'scale':>8}  {'q_w':>20}  {'MSE':>12}")
print(f"{'init':>5}  {scale:8.4f}  {str(q_w.tolist()):>20}  {((W - q_w*scale)**2).mean().item():12.6f}")

for i in range(1, 4):
    # Least-squares scale update
    scale = (W.dot(q_w) / q_w.dot(q_w)).item()
    q_w   = W.div(scale).round().clamp(-q_max, q_max)
    mse   = ((W - q_w * scale)**2).mean().item()
    print(f"{i:>5}  {scale:8.4f}  {str(q_w.tolist()):>20}  {mse:12.6f}")

recon = (q_w * scale).view(2, 2)
print(f"\nOriginal   :\n{W_small.numpy()}")
print(f"Reconstructed:\n{recon.numpy().round(4)}")
print(f"Element-wise error:\n{(W_small - recon).abs().numpy().round(4)}")
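
For comparison with the printed table, the same trace can be done by hand on the flattened weights $W = (0.8,\,-0.3,\,0.1,\,-0.9)$ with $q_{\max} = 1$. The arithmetic below is exact; the values printed by the cell may differ in the last digits because the tensors are float32.

$$
\begin{aligned}
\text{init:}\quad & S_0 = \tfrac{\max|W|}{q_{\max}} = \tfrac{0.9}{1} = 0.9, \quad
W/S_0 \approx (0.889,\,-0.333,\,0.111,\,-1) \;\Rightarrow\; Q = (1,\,0,\,0,\,-1),\\
& \text{MSE}_0 = \tfrac{1}{4}\left(0.1^2 + 0.3^2 + 0.1^2 + 0^2\right) = 0.0275,\\
\text{iter 1:}\quad & S_1 = \tfrac{\langle W, Q\rangle}{\|Q\|^2} = \tfrac{0.8 + 0.9}{2} = 0.85, \quad
W/S_1 \approx (0.941,\,-0.353,\,0.118,\,-1.059) \;\Rightarrow\; Q = (1,\,0,\,0,\,-1),\\
& \text{MSE}_1 = \tfrac{1}{4}\left(0.05^2 + 0.3^2 + 0.1^2 + 0.05^2\right) = 0.02625.
\end{aligned}
$$

$Q$ is unchanged after the first update, so later iterations reproduce $S = 0.85$: this example reaches its fixed point after a single scale update.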


## 🧪 GPT-2 Evaluation

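Apply `hqq_quantize_block` to every 2D parameter of GPT-2 (this includes the token and position embedding tables), using one scale per tensor, and compare perplexity before and after quantization. The weights are replaced by their dequantized values `q_w * scale`, so this measures accuracy impact only, not memory savings. Perplexity is computed on a single short sentence, so treat the numbers as a smoke test rather than a benchmark.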

In [None]:
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import torch, copy, io, contextlib

model_id = "gpt2"
tokenizer = GPT2Tokenizer.from_pretrained(model_id)
model = GPT2LMHeadModel.from_pretrained(model_id).eval()

text = "The quick brown fox jumps over the lazy dog. Transformers are powerful sequence models."
inputs = tokenizer(text, return_tensors="pt")

def perplexity(mdl, inputs):
    with torch.no_grad():
        loss = mdl(**inputs, labels=inputs["input_ids"]).loss
    return torch.exp(loss).item()

baseline_ppl = perplexity(model, inputs)
print(f"Baseline GPT-2 Perplexity: {baseline_ppl:.2f}")

model_q = copy.deepcopy(model)
for name, param in model_q.named_parameters():
    if param.dim() == 2:
        with contextlib.redirect_stdout(io.StringIO()):
            q_w, scale = hqq_quantize_block(param.data, bits=4, n_iter=3)
        param.data = q_w * scale

quant_ppl = perplexity(model_q, inputs)
print(f"HQQ GPT-2 Perplexity:      {quant_ppl:.2f}")
print(f"Delta:                     {quant_ppl - baseline_ppl:+.2f}")


## 📚 References

1. **Badri, H. & Shaji, A.** (2023).  
   *Half-Quadratic Quantization of Large Machine Learning Models.*  
   [arXiv:2401.14112](https://arxiv.org/abs/2401.14112)

2. **Geman, D. & Yang, C.** (1995).  
   *Nonlinear Image Recovery with Half-Quadratic Regularization.*  
   IEEE Transactions on Image Processing, 4(7), 932-946.

3. **Frantar, E., Ashkboos, S., Hoefler, T., & Alistarh, D.** (2022).  
   *GPTQ: Accurate Post-Training Quantization for Generative Pre-trained Transformers.* ICLR 2023.  
   [arXiv:2210.17323](https://arxiv.org/abs/2210.17323)
