# 🧪 Lab 2: Hardware–Software Model Co-Design via Post-Training Quantization & Bit-Width Search

### 📚 Introduction

In **Lab 1**, you saw that compressing transformer weights down to **2 bits** reduced model size by ×16 with only a modest accuracy drop. But compression alone is a *software-centric* solution; actual deployment only succeeds when the model cooperates with the underlying silicon.

In this lab, you’ll adopt a **hardware–software co-design** perspective, treating quantization as the critical interface between the network and its deployment hardware. Quantization affects both **model accuracy** and **execution efficiency**, making it the ideal lever for co-design.

Specifically, you will:

- **Replace every `nn.Linear` in the transformer blocks with a `QLinear` module that simulates integer-only arithmetic**  
- **Post-quantize both weights and activations layerwise**, selecting precision from **8 → 2 bits**  
- **Measure performance of each quantization configuration** using **KL-divergence** and **memory consumption**  
- **Perform automated layerwise bit-width search** to optimize a hardware-aware objective function

> **Why co-design matters:**  
> A model that looks efficient in software may still bottleneck on real hardware due to memory access patterns, compute throughput, or unsupported bit-widths. Hardware–software co-design ensures the model structure aligns with hardware constraints, enabling deployment that is both **accurate** and **efficient** on edge devices.

---

### 🎯 Lab Objectives

1. **Implement `QLinear`**, a simulated integer GEMM layer with scale-only (symmetric) dequantization, compatible with PyTorch CPU kernels  
2. **Post-quantize a pretrained model checkpoint** using per-layer {8, 4, 2}-bit precision, and export metadata for downstream hardware cost modeling  
3. **Profile** model size and **KL-divergence from the FP32 teacher model**  
4. **Run a Tree-structured Parzen Estimator (TPE) search with Hyperopt** to identify a per-layer quantization configuration that minimizes a joint objective (accuracy vs. efficiency)

---

By the end, you'll produce a deployment-ready language model with a **per-layer optimal quantization configuration**—striking the best trade-off between hardware efficiency and model fidelity.



In [20]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import partial
from typing import Dict, Any, Tuple

import pytorch_lightning as pl
from hyperopt import fmin, tpe, hp, Trials, space_eval, STATUS_OK

from src.lab1.shakespeare_trainer import ShakespeareModule



 ## 1️⃣ Building `QLinear`

 **Symmetric uniform quantization** maps a float tensor to signed integers
 in the range [ −(2ᵇ⁻¹−1), …, + (2ᵇ⁻¹−1) ] with a single scale factor **s**.

 Forward pass outline:

 1. **Quantize** incoming activations to ints.
 2. **Integer GEMM** with pre-quantized weights.
 3. **De-quantize** the accumulator by multiplying with the two scales.
 4. Add bias (still floating point).
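
The four steps above can be traced by hand on a tiny example. The sketch below uses plain Python lists instead of tensors. The helper names `quantize` and `qdot` and the input values are made up for illustration, and it assumes the same symmetric scheme described here (one scale per tensor, levels clamped to ±qmax).

```python
def quantize(values, bitwidth):
    """Symmetric uniform quantization of a list of floats.

    Returns (integer levels, scale) with levels in [-qmax, qmax].
    """
    qmax = 2 ** (bitwidth - 1) - 1
    rmax = max(abs(v) for v in values)
    scale = rmax / qmax if rmax > 0 else 1.0
    levels = [max(-qmax, min(qmax, round(v / scale))) for v in values]
    return levels, scale


def qdot(x, w, bias, bitwidth):
    """Quantize x and w, accumulate in integers, then rescale and add bias."""
    qx, sx = quantize(x, bitwidth)             # step 1: quantize activations
    qw, sw = quantize(w, bitwidth)             # weights are quantized ahead of time in practice
    acc = sum(a * b for a, b in zip(qx, qw))   # step 2: integer accumulation
    return acc * sx * sw + bias                # steps 3 and 4: rescale, add float bias


# Illustrative inputs (not taken from the lab model).
x = [0.5, -1.0, 0.25]
w = [0.2, 0.4, -0.8]
exact = sum(a * b for a, b in zip(x, w)) + 0.1
for bits in (8, 4, 2):
    print(bits, round(qdot(x, w, 0.1, bits), 4), "vs exact", round(exact, 4))
```

Running it shows the result drifting away from the exact dot product as the bit-width shrinks, which is the effect the rest of the lab measures at model scale.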

 The class below is written for clarity rather than raw speed.


In [21]:
class QLinear(nn.Module):
    """
    A fully-connected layer with symmetric uniform quantization for weights and activations.
    """
    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True,
        weight_bitwidth: int = 8,
        act_bitwidth: int = 8,
    ) -> None:
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight_bitwidth = weight_bitwidth
        self.act_bitwidth = act_bitwidth

        # Buffers for the quantized weight and its scale. The integer levels are
        # stored as float32 so the standard matmul kernel can simulate integer GEMM.
        self.register_buffer(
            "qweight",
            torch.zeros(out_features, in_features, dtype=torch.float32),
        )
        self.register_buffer("weight_scale", torch.ones(1))

        # Optional bias stored in float32
        if bias:
            self.register_buffer("bias", torch.zeros(out_features, dtype=torch.float32))
        else:
            self.bias = None

    @staticmethod
    def _quantize_tensor(
        x: torch.Tensor, bitwidth: int
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Quantize a tensor to signed integers in [-(2^(b-1)-1), 2^(b-1)-1].
        Returns (quantized_tensor, scale).
        """
        qmax = 2 ** (bitwidth - 1) - 1
        rmax = x.abs().max()
        scale = rmax / qmax if rmax > 0 else torch.tensor(1.0, device=x.device)
        q = torch.clamp(torch.round(x / scale), -qmax, qmax)
        return q, scale

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # 1. Quantize activations
        qx, act_scale = self._quantize_tensor(x, self.act_bitwidth)

        # 2. Integer GEMM
        qx = qx.to(self.qweight.dtype)
        acc = qx.matmul(self.qweight.t())

        # 3. Dequantize
        y = acc * act_scale * self.weight_scale

        # 4. Add bias if present
        if self.bias is not None:
            y = y + self.bias

        return y

    def __repr__(self) -> str:
        return (
            f"{self.__class__.__name__}("
            f"in={self.in_features}, out={self.out_features}, "
            f"w_bits={self.weight_bitwidth}, a_bits={self.act_bitwidth})"
        )

 ### 👉 Quick sanity check

 Run the next cell to quantize a random vector at 2, 4 and 8 bits and
 print the mean absolute reconstruction error.


In [22]:
# %% 
torch.manual_seed(0)
sample = torch.randn(1000)
for b in (2, 4, 8):
    q, s = QLinear._quantize_tensor(sample, b)
    err = (sample - q * s).abs().mean().item()
    print(f"{b}-bit | mean-abs-error: {err:.6f}")


2-bit | mean-abs-error: 0.761839
4-bit | mean-abs-error: 0.144003
8-bit | mean-abs-error: 0.008115
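
One way to read these numbers: with the symmetric range used in this lab, a b-bit tensor has only 2ᵇ − 1 usable levels, so at 2 bits every value collapses to −1, 0 or +1 times the scale. A small sketch of that count (the helper name `num_levels` is ours, not part of the lab code):

```python
def num_levels(bitwidth: int) -> int:
    """Number of distinct integer levels in [-qmax, qmax]."""
    qmax = 2 ** (bitwidth - 1) - 1
    return 2 * qmax + 1


for b in (2, 4, 8):
    print(f"{b}-bit -> {num_levels(b)} levels")
```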



 ## 2️⃣ Swapping Layers in-place

 We’ll walk the model, collect string paths to every `nn.Linear`, and replace
 each with a `QLinear` whose bit-width comes from a **qconfig** dictionary:

 ```text
   {
     "model.transformer_blocks.0.attn.q_proj": 4,
     "model.transformer_blocks.0.attn.k_proj": 2,
     …
   }
 ```

 Starting from the 8-bit default config, you can tweak individual
 layers and re-evaluate within seconds.
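
As a rough illustration of what such a tweak looks like, the sketch below builds a uniform 8-bit config and lowers only the attention projections to 4 bits. The first two layer names follow the pattern shown above; the third (`mlp.fc1`) is an assumed name used purely for illustration. It also shows how a dotted path splits into a parent path and an attribute name, which is what an in-place swap needs.

```python
# Layer names: the first two follow the example above, the third is hypothetical.
layer_names = [
    "model.transformer_blocks.0.attn.q_proj",
    "model.transformer_blocks.0.attn.k_proj",
    "model.transformer_blocks.0.mlp.fc1",
]

# Start from a uniform 8-bit config, then override selected layers.
qconfig = {name: 8 for name in layer_names}
for name in qconfig:
    if ".attn." in name:
        qconfig[name] = 4

# A dotted path splits into (parent module path, attribute name).
parent_path, _, attr = "model.transformer_blocks.0.attn.q_proj".rpartition(".")

print(qconfig)
print(parent_path, attr)
```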


In [14]:
# %%  — utilities for model patching
def quantize_linear(layer: nn.Linear, weight_bitwidth=8, act_bitwidth=8):
    qlayer = QLinear(layer.in_features, layer.out_features,
                     bias=layer.bias is not None,
                     weight_bitwidth=weight_bitwidth,
                     act_bitwidth=act_bitwidth)
    q_w, w_s = QLinear._quantize_tensor(layer.weight.data, weight_bitwidth)
    qlayer.qweight.copy_(q_w)
    qlayer.weight_scale.copy_(w_s)
    if layer.bias is not None:
        qlayer.bias.copy_(layer.bias.data)
    return qlayer


def quantize_model(root: nn.Module, qconfig: Dict[str, int]):
    for path, bw in qconfig.items():
        parent_path, _, attr = path.rpartition('.')
        parent = root if not parent_path else root.get_submodule(parent_path)
        setattr(parent, attr,
                quantize_linear(getattr(parent, attr),
                                weight_bitwidth=bw, act_bitwidth=bw))
    return root


def default_qconfig(model: ShakespeareModule, bitwidth=8):
    cfg = {}
    for name, mod in model.model.transformer_blocks.named_modules():
        if isinstance(mod, nn.Linear):
            cfg[f"model.transformer_blocks.{name}"] = bitwidth
    return cfg



 ## 3️⃣ Accuracy Metric: KL Divergence

 Because the model is not retrained after quantization, there is no new
 training loss to compare.  
 Instead we measure how far the quantized model’s softmax output is from the
 float model’s on held-out batches.
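
For reference, the quantity being estimated is KL(P‖Q) = Σᵢ pᵢ · log(pᵢ / qᵢ), where P is the float model's softmax and Q is the quantized model's. The sketch below computes it in plain Python for one token position. The logits are invented stand-ins for a mild and a harsh perturbation, not outputs of the lab model.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def kl_divergence(p_logits, q_logits):
    """KL(P || Q) in nats, where P and Q are softmaxes of the given logits."""
    p = softmax(p_logits)
    q = softmax(q_logits)
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q))


float_logits = [2.0, 1.0, 0.1]   # made-up reference logits
mild = [1.9, 1.1, 0.1]           # small perturbation
harsh = [0.5, 2.0, 0.1]          # large perturbation that flips the top token

print(kl_divergence(float_logits, float_logits))
print(kl_divergence(float_logits, mild))
print(kl_divergence(float_logits, harsh))
```

The divergence is zero when the two distributions match and grows as the quantized distribution moves probability mass away from the float model's choices.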


In [15]:
# %%  — KL computation
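def compute_kl_divergence(full, quant, batches, batch_size, device):
    """Average KL(float || quantized) over up to `batches` test batches.

    `batch_size` is unused here and kept for call compatibility.
    """
    dl = full.test_dataloader()
    total, seen = 0.0, 0
    full.eval(); quant.eval()
    with torch.no_grad():
        for i, (x, _) in enumerate(dl):
            if i >= batches:
                break
            x = x.to(device)
            f_logits = full.model(x)
            q_logits = quant.model(x)
            # F.kl_div takes log-probabilities as input and probabilities as target.
            # 'batchmean' divides the summed KL by the size of dim 0 only.
            kl = F.kl_div(F.log_softmax(q_logits, dim=-1),
                          F.softmax(f_logits, dim=-1),
                          reduction='batchmean')
            total += kl.item()
            seen += 1
    # Divide by the batches actually seen; the loader may yield fewer than `batches`.
    return total / max(seen, 1)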

 ## 4️⃣ Memory Metric

 The helper below counts *every* parameter:

 * Quantized weights → bitwidth from `qconfig`.
 * Everything else (biases, embeddings, layer-norm) → 32 bits.
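
To get a feel for the numbers, the sketch below applies the same accounting to a single layer. The 512×512 shape and the helper name `linear_size_bytes` are assumptions for illustration, not values from the lab model.

```python
def linear_size_bytes(in_features, out_features, weight_bits, bias=True):
    """Storage for one linear layer: weights at weight_bits, bias at 32 bits."""
    bits = in_features * out_features * weight_bits
    if bias:
        bits += out_features * 32
    return bits // 8


# Hypothetical 512x512 layer at several weight precisions.
for bits in (32, 8, 4, 2):
    size = linear_size_bytes(512, 512, bits)
    print(f"{bits:>2}-bit weights: {size / 1024:.1f} KiB")
```

Because the bias stays in float32, its share of the layer grows as the weights shrink, although for a layer of this shape it remains a small fraction even at 2 bits.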


In [16]:
# %% 
def compute_model_size_bytes(root: nn.Module, qconfig: Dict[str, int]):
    total_bits = 0
    for path, bw in qconfig.items():
        parent_path, _, attr = path.rpartition('.')
        parent = root if not parent_path else root.get_submodule(parent_path)
        lin: QLinear = getattr(parent, attr)
        total_bits += lin.qweight.numel() * bw
        if lin.bias is not None:
            total_bits += lin.bias.numel() * 32
    # QLinear keeps its weight and bias as buffers, so the parameters left on
    # the model are the non-quantized ones; count them all at 32 bits.
    quantized = set(qconfig)
    for name, param in root.named_parameters():
        if name.rpartition('.')[0] in quantized:
            continue  # already counted in the loop above
        total_bits += param.numel() * 32
    return total_bits // 8


 ## 5️⃣ HyperOpt Objective

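 We scalarise two goals, a small KL and a small model, into a single
 loss:  
 `loss = α * KL + (1-α) * Size_MB`.

 KL and size in MB sit on different scales, so α also sets the exchange rate
 between them. Feel free to experiment with α = 0.2, 0.5, 0.8; note that the
 search driver in the next section defaults to `alpha=1e-7`, which makes the
 loss almost entirely size.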

In [17]:
# %%  — objective
def objective(qconfig, batches, batch_size, alpha):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    fp = ShakespeareModule.load_from_checkpoint(
        'src/lab1/checkpoints/float-best.ckpt', batch_size=batch_size).to(device)
    fp.setup('test')
    q = ShakespeareModule.load_from_checkpoint(
        'src/lab1/checkpoints/float-best.ckpt', batch_size=batch_size)
    q.setup('test')
    quantize_model(q, qconfig).to(device)

    kl = compute_kl_divergence(fp, q, batches, batch_size, device)
    size_mb = compute_model_size_bytes(q, qconfig) / (1024 ** 2)
    loss = alpha * kl + (1 - alpha) * size_mb
    #print(f"loss={loss:.4f} | KL={kl:.4f} | size={size_mb:.2f} MB")
    return {'loss': loss, 'status': STATUS_OK}


 ## 6️⃣ Search Space & Driver

In [18]:
# %% 
def hyperopt_search(init_cfg, max_evals=200, batches=10, batch_size=6, alpha=1e-7):
    space = {k: hp.choice(k, [2, 4, 8]) for k in init_cfg}
    trials = Trials()
    fn = partial(objective, batches=batches, batch_size=batch_size, alpha=alpha)
    best = fmin(fn, space, algo=tpe.suggest, max_evals=max_evals, trials=trials)
    return space_eval(space, best)
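
The search space above assigns one of three bit-widths to every layer, so it grows as 3ᴺ with the number of layers N. The sketch below makes that concrete. The layer count of 24 is a hypothetical figure, not the size of the lab model, and `search_space_size` is a helper introduced only for this illustration.

```python
def search_space_size(num_layers: int, choices_per_layer: int = 3) -> int:
    """Number of distinct per-layer bit-width assignments."""
    return choices_per_layer ** num_layers


n_layers = 24  # hypothetical layer count
total = search_space_size(n_layers)
print(f"{total:,} configurations; 200 trials cover {200 / total:.2e} of them")
```

With a budget of a few hundred evaluations, exhaustive enumeration is out of reach for anything but very small models, which is why a guided sampler such as TPE is used here.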

 ## 7️⃣ Main Entry

In [19]:
# %% 
def main():
    base = ShakespeareModule.load_from_checkpoint('src/lab1/checkpoints/float-best.ckpt')
    start_cfg = default_qconfig(base, bitwidth=8)
    best_cfg = hyperopt_search(start_cfg)
    print("Best per-layer bit-widths:", best_cfg)

if __name__ == "__main__":
    main()


loss=6.6328 | KL=25.2938 | size=6.63 MB
loss=7.0391 | KL=6.0643 | size=7.04 MB
loss=6.6016 | KL=37.4991 | size=6.60 MB
loss=6.4453 | KL=37.7337 | size=6.45 MB
loss=6.8672 | KL=39.4967 | size=6.87 MB
loss=6.1797 | KL=46.4083 | size=6.18 MB
loss=7.1953 | KL=16.4744 | size=7.20 MB
loss=6.7734 | KL=28.7283 | size=6.77 MB
loss=6.7109 | KL=28.9668 | size=6.71 MB
loss=6.8203 | KL=16.9455 | size=6.82 MB
loss=6.4609 | KL=29.5076 | size=6.46 MB
loss=7.1797 | KL=15.7862 | size=7.18 MB
loss=5.9141 | KL=47.3078 | size=5.91 MB
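
In the log above, the loss column matches the size column to printed precision, which is what the driver's default `alpha=1e-7` implies: the KL term contributes almost nothing. The sketch below re-scores two of the logged configurations under different α values to show how the preferred configuration changes. The function name `scalarized_loss` is ours; the (KL, size) pairs are copied from the log.

```python
def scalarized_loss(kl: float, size_mb: float, alpha: float) -> float:
    """Same formula as the lab objective: alpha * KL + (1 - alpha) * Size_MB."""
    return alpha * kl + (1 - alpha) * size_mb


# (KL, size in MB) pairs copied from two lines of the log above.
low_kl_cfg = (6.0643, 7.04)
small_size_cfg = (47.3078, 5.91)

for alpha in (1e-7, 0.2, 0.5, 0.8):
    a = scalarized_loss(*low_kl_cfg, alpha)
    b = scalarized_loss(*small_size_cfg, alpha)
    winner = "low-KL config" if a < b else "small-size config"
    print(f"alpha={alpha:g}: {a:.4f} vs {b:.4f} -> prefers {winner}")
```

At the default α the search favours the smallest model regardless of KL; with the larger α values suggested in Section 5, the lower-KL configuration scores better.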


 ## 🔄 Try This

 1. **Aggressive compression** – set the initial cfg to 4 bits and limit the
    search to {2,4}.  How low can the KL stay?
 2. **Latency vs throughput** – time one forward pass before and after
    quantization on CPU.
 3. **Text generation side-by-side** – sample a Shakespeare sonnet with both
    models; can you spot the quantized one?

 Post your findings on the course forum—screenshots, metrics, or even the
 strangest quantization artefacts you encounter.
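
For item 2, a minimal timing harness might look like the sketch below. It uses only the standard library; `time_forward` and `dummy_forward` are illustrative names, and the dummy workload stands in for a real forward call such as `lambda: model(x)`.

```python
import statistics
import time


def time_forward(fn, *args, warmup=3, repeats=20):
    """Return the median wall-clock seconds of fn(*args) over `repeats` runs."""
    for _ in range(warmup):          # warm-up runs are not timed
        fn(*args)
    samples = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn(*args)
        samples.append(time.perf_counter() - start)
    return statistics.median(samples)


# Stand-in workload; in the lab you would pass the model's forward call instead.
def dummy_forward(n):
    return sum(i * i for i in range(n))


print(f"median: {time_forward(dummy_forward, 10_000) * 1e3:.3f} ms")
```

Taking the median over several runs reduces the influence of one-off stalls. When timing a PyTorch model, wrapping the call in `torch.no_grad()` keeps autograd bookkeeping out of the measurement.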

 ---

 🏁 **End of Lab 2** — you now have a fully automated post-training
 quantization pipeline and a taste of multi-objective search.  
 Next stop: **quantization-aware training** and custom int kernels!
