# FinCompress — Notebook 03: Quantization

**RUN ON: Colab/GPU or Local/CPU**

## Section A: Post-Training Quantization (PTQ)

### How PTQ Works

Quantization reduces the numeric precision of weights (and, with static quantization, activations) from FP32 (32-bit float: 2³² ≈ 4.3 × 10⁹ bit patterns, about 7 significant decimal digits) to INT8 (8-bit integer: 2⁸ = 256 values per weight). This gives:
- **4× size reduction** (4 bytes → 1 byte per weight)
- **Typically 2-4× lower latency** on x86 CPUs with INT8 SIMD instructions
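
The 4× figure applies only to the weights that are actually stored as INT8. The rough sketch below estimates the overall ratio when part of the model stays in FP32. The parameter counts are hypothetical, and it ignores per-tensor scale/zero-point metadata and non-weight buffers. It also assumes 1 MB = 10⁶ bytes, which may not match how `size_mb` is computed in this project. That difference does not affect the ratio.

```python
BYTES_FP32 = 4
BYTES_INT8 = 1


def weight_size_mb(n_int8_params: int, n_fp32_params: int) -> float:
    """Approximate weight storage in MB (1 MB = 10**6 bytes here).

    Ignores quantization metadata (scale/zero-point) and non-weight buffers.
    """
    total_bytes = n_int8_params * BYTES_INT8 + n_fp32_params * BYTES_FP32
    return total_bytes / 1e6


# Hypothetical student: 10M params in quantizable Linear layers,
# 4M params in embeddings + classifier that stay in FP32.
N_LINEAR, N_KEPT_FP32 = 10_000_000, 4_000_000

fp32_mb = weight_size_mb(0, N_LINEAR + N_KEPT_FP32)
mixed_mb = weight_size_mb(N_LINEAR, N_KEPT_FP32)
ratio = fp32_mb / mixed_mb
print(f"FP32: {fp32_mb:.1f} MB, mixed INT8/FP32: {mixed_mb:.1f} MB, ratio {ratio:.2f}x")
```

A measured compression ratio below 4× is therefore not, by itself, a sign that quantization failed.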

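**How observers work:** PyTorch's `prepare()` inserts observer modules that record tensor ranges (running min/max or a histogram) while you run a calibration pass over representative data. From these ranges, PyTorch computes the quantization *scale* and *zero-point*. These are the two parameters of the affine mapping x ≈ scale × (q − zero_point) between the 8-bit integer range (e.g. [−128, 127] for signed INT8) and the observed FP32 range.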
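
Here is a minimal pure-Python sketch of what a min/max observer and the affine mapping compute. It assumes per-tensor quantization to signed INT8 [−128, 127]. The helper names (`MinMaxRange`, `affine_qparams`, `quantize`, `dequantize`) are hypothetical and not PyTorch APIs. PyTorch's real observers (e.g. `MinMaxObserver`) follow the same idea, including widening the range to contain 0.0 so that zero is exactly representable. Python's `round` uses round-half-to-even.

```python
class MinMaxRange:
    """Tracks the running min/max of every value seen during calibration."""

    def __init__(self):
        self.lo = float("inf")
        self.hi = float("-inf")

    def observe(self, batch):
        self.lo = min(self.lo, min(batch))
        self.hi = max(self.hi, max(batch))


def affine_qparams(lo, hi, qmin=-128, qmax=127):
    """Scale and zero-point mapping the integer range [qmin, qmax] onto [lo, hi]."""
    lo, hi = min(lo, 0.0), max(hi, 0.0)  # make 0.0 exactly representable
    scale = (hi - lo) / (qmax - qmin)
    if scale == 0.0:  # every observed value was 0
        scale = 1.0
    zero_point = qmin - round(lo / scale)
    return scale, max(qmin, min(qmax, zero_point))


def quantize(x, scale, zero_point, qmin=-128, qmax=127):
    """FP32 -> INT8: divide by scale, round, shift by zero-point, clamp."""
    return max(qmin, min(qmax, round(x / scale) + zero_point))


def dequantize(q, scale, zero_point):
    """INT8 -> FP32 approximation."""
    return scale * (q - zero_point)


obs = MinMaxRange()
for batch in ([-0.5, 0.2, 1.7], [-1.0, 3.0, 0.4]):  # toy "calibration" batches
    obs.observe(batch)

scale, zp = affine_qparams(obs.lo, obs.hi)  # scale = 4/255, zero_point = -64
print(scale, zp)
print(quantize(0.0, scale, zp))  # real 0.0 maps exactly to the zero-point
print(dequantize(quantize(1.7, scale, zp), scale, zp))  # within scale/2 of 1.7
```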

**Why fbgemm?** fbgemm is optimized for x86 CPUs (Intel/AMD). qnnpack is for ARM mobile (Android/iOS). Using the wrong backend gives misleadingly slow benchmarks.
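
One way to avoid benchmarking with the wrong backend is to derive the engine name from the CPU architecture. The helper below is a hypothetical sketch that uses only the standard library. The architecture strings it matches are common values returned by `platform.machine()`, not an exhaustive list.

```python
import platform


def pick_quant_backend(machine=None):
    """Return the PyTorch quantized engine name matching a CPU architecture string."""
    arch = (machine or platform.machine()).lower()
    if arch in {"x86_64", "amd64", "i386", "i686"}:
        return "fbgemm"  # x86 (Intel/AMD)
    if arch in {"arm64", "aarch64"} or arch.startswith("arm"):
        return "qnnpack"  # ARM (e.g. Android/iOS devices)
    raise ValueError(f"no INT8 backend mapping for architecture {arch!r}")
```

In the notebook, you would assign the returned string to `torch.backends.quantized.engine` and pass it to `torch.ao.quantization.get_default_qconfig(...)`, so that calibration and benchmarking use the same kernels. Check `torch.backends.quantized.supported_engines` first, because not every PyTorch build ships both engines.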

**Why keep first/last layers in FP32?**
- Embedding layers: quantization error introduced at the input is carried into every subsequent layer and can be amplified along the way
- Final classifier: with only 3 output logits, even a small rounding error can change the argmax when the top two logits are close
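
The following toy example illustrates the classifier point. It uses hypothetical weights and a deliberately coarse quantization step (scale 0.1 instead of a realistic INT8 scale) so that the effect is easy to see. The two top FP32 logits differ by only 0.01. Rounding the weights shifts them by more than that margin, so the predicted class flips.

```python
W = [[0.34, 0.34],   # class 0 weights (hypothetical)
     [0.36, 0.31],   # class 1
     [-0.50, 0.10]]  # class 2
x = [1.0, 1.0]       # one input example
SCALE = 0.1          # exaggerated symmetric quantization step


def fake_quant(w, scale, qmax=127):
    """Symmetric quantize-dequantize: round to the nearest multiple of scale, clamp."""
    q = max(-qmax, min(qmax, round(w / scale)))
    return q * scale


def logits(weights, x):
    return [sum(wi * xi for wi, xi in zip(row, x)) for row in weights]


def argmax(values):
    return max(range(len(values)), key=values.__getitem__)


fp32_logits = logits(W, x)
int8_logits = logits([[fake_quant(w, SCALE) for w in row] for row in W], x)
print("FP32 prediction:", argmax(fp32_logits))  # logits 0.68 vs 0.67 -> class 0
print("INT8 prediction:", argmax(int8_logits))  # logits 0.6 vs 0.7 -> class 1
```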

In [None]:
import os, sys
PROJECT_PATH = '/content/drive/MyDrive/fincompress'  # adjust if running locally
os.chdir(PROJECT_PATH)
sys.path.insert(0, PROJECT_PATH)

In [None]:
# PTQ runs on CPU: PyTorch's static INT8 kernels (fbgemm/qnnpack) have no CUDA backend
!python -m fincompress.quantization.ptq

In [None]:
import json
from pathlib import Path

inter_info = json.loads(Path('fincompress/checkpoints/student_intermediate_kd/checkpoint_info.json').read_text())
ptq_info   = json.loads(Path('fincompress/checkpoints/student_ptq/checkpoint_info.json').read_text())

print('PTQ Results')
print('=' * 51)
print(f'{"Metric":<25} {"FP32 Student":>12} {"INT8 PTQ":>12}')
print('-' * 51)
print(f'{"Val Macro F1":<25} {inter_info["val_macro_f1"]:>12.4f} {ptq_info["val_macro_f1"]:>12.4f}')
print(f'{"Size (MB)":<25} {inter_info["size_mb"]:>12.1f} {ptq_info["size_mb"]:>12.1f}')
ratio = inter_info['size_mb'] / ptq_info['size_mb']
f1_drop = inter_info['val_macro_f1'] - ptq_info['val_macro_f1']  # positive = PTQ lost F1
ratio_str = f'{ratio:.1f}×'  # pre-format so the column stays right-aligned
print('-' * 51)
print(f'{"Compression ratio":<25} {"1.0×":>12} {ratio_str:>12}')
print(f'{"F1 drop vs. FP32":<25} {"0.0000":>12} {f1_drop:>+12.4f}')

## Section B: Quantization-Aware Training (QAT)

### PTQ vs. QAT: Conceptual Difference

| Aspect | PTQ | QAT |
|-|-----|-----|
| When applied | After training | During fine-tuning |
| Quantization in forward pass | No (observers only record ranges during calibration) | Yes (fake-quant nodes) |
| Backward pass through quant | N/A (no training) | Straight-through estimator |
| Optimizer awareness | No | Yes: weights adapt to quantization noise |
| Training cost | Calibration pass only | 3 additional fine-tuning epochs |
| Typical accuracy drop vs. FP32 | 2-5% | 0.5-2% |

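**Straight-through estimator (STE):** Rounding is a step function. Its gradient is zero almost everywhere (and undefined at the steps), so no gradient would reach the weights. STE bypasses this by treating the quantize-dequantize step as the identity function in the backward pass (gradient 1), as if quantization were a linear operation. PyTorch's fake-quant modules use a clipped variant that passes the gradient only for values inside the representable range and zeroes it outside. This allows gradient-based optimization through a fundamentally discrete operation.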
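
Below is a minimal pure-Python sketch of clipped STE for a single weight. It assumes a one-parameter model `y = fake_quant(w) * x` with squared-error loss. Both are hypothetical and chosen only to keep the gradients hand-computable. The forward pass uses the quantized weight. The backward pass replaces the zero derivative of `round` with 1 inside the representable range.

```python
def fake_quant(w, scale, qmin=-128, qmax=127):
    """Forward: snap w to the integer grid, then map back to float."""
    q = max(qmin, min(qmax, round(w / scale)))
    return q * scale


def ste_grad(w, scale, qmin=-128, qmax=127):
    """Backward (clipped STE): d fake_quant / d w is treated as 1 inside the range, 0 outside."""
    return 1.0 if qmin * scale <= w <= qmax * scale else 0.0


def qat_step(w, x, target, scale, lr):
    y = fake_quant(w, scale) * x                   # forward pass sees the quantized weight
    dloss_dy = 2.0 * (y - target)                  # loss = (y - target) ** 2
    dloss_dw = dloss_dy * x * ste_grad(w, scale)   # the exact chain rule would multiply by 0 here
    return w - lr * dloss_dw                       # the FP32 "shadow" weight is what gets updated


w = 0.0
for _ in range(50):
    w = qat_step(w, x=1.0, target=0.5, scale=0.1, lr=0.1)
print(w, fake_quant(w, 0.1))  # the quantized weight settles on the grid point 0.5
```

With the true gradient of `round`, `dloss_dw` would be 0 on every step and `w` would never leave 0.0.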

**When is QAT worth the extra training?** When PTQ drops >2% accuracy and you have training data. If PTQ gives acceptable accuracy, skip QAT to save compute.

In [None]:
# QAT training (runs on GPU in Colab)
!python -m fincompress.quantization.qat

In [None]:
import pandas as pd, matplotlib.pyplot as plt

log_df = pd.read_csv('fincompress/logs/qat_training.csv')
fig, ax = plt.subplots(figsize=(8, 4))
ax.plot(log_df['epoch'], log_df['train_loss'], 'o-', color='steelblue', label='Train Loss')
ax2 = ax.twinx()
ax2.plot(log_df['epoch'], log_df['val_f1'], 's-', color='coral', label='Val F1')
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss', color='steelblue')
ax2.set_ylabel('Val Macro F1', color='coral')
ax.set_title('QAT Training Curve')
lines1, labels1 = ax.get_legend_handles_labels()
lines2, labels2 = ax2.get_legend_handles_labels()
ax.legend(lines1 + lines2, labels1 + labels2, loc='upper right')
plt.tight_layout()
plt.show()

In [None]:
import json
from pathlib import Path

inter_info = json.loads(Path('fincompress/checkpoints/student_intermediate_kd/checkpoint_info.json').read_text())
ptq_info   = json.loads(Path('fincompress/checkpoints/student_ptq/checkpoint_info.json').read_text())
qat_info   = json.loads(Path('fincompress/checkpoints/student_qat/checkpoint_info.json').read_text())

print('Three-Way Quantization Comparison')
print('=' * 60)
print(f'{"Metric":<25} {"FP32":>10} {"INT8 PTQ":>10} {"INT8 QAT":>10}')
print('-' * 60)
rows = [
    ('Val Macro F1', 'val_macro_f1', '.4f'),
    ('Size (MB)',    'size_mb',      '.1f'),
]
for label, key, fmt in rows:
    vals = [inter_info[key], ptq_info[key], qat_info[key]]
    print(f'{label:<25} {vals[0]:>10{fmt}} {vals[1]:>10{fmt}} {vals[2]:>10{fmt}}')
print('=' * 60)

## Section C: Production Trade-off Analysis

**Which quantization variant would you deploy in production?**

| Constraint | Deploy | Reason |
|-----------|--------|--------|
| Latency SLA ≤ 5 ms, accuracy loss ≤ 2% | INT8 QAT | Best accuracy at INT8 precision |
| Zero additional training budget | INT8 PTQ | No retraining required |
| PTQ accuracy drop > 2% | INT8 QAT | 3 fine-tuning epochs typically recover most of the PTQ loss |
| Serving on ARM mobile (iOS/Android) | INT8 PTQ, quantized with the qnnpack backend | qnnpack targets ARM CPUs; fbgemm targets x86 |
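
To check whether QAT recovers most of the PTQ loss for a given run, you can compute the fraction of the PTQ macro-F1 drop that QAT wins back. The helper below is hypothetical, and the numbers in the usage line are placeholders, not results from this project.

```python
def qat_recovery(fp32_f1: float, ptq_f1: float, qat_f1: float) -> float:
    """Fraction of the PTQ F1 drop recovered by QAT (1.0 = back to FP32, 0.0 = no better than PTQ)."""
    ptq_drop = fp32_f1 - ptq_f1
    if ptq_drop <= 0:
        raise ValueError("PTQ did not lose F1 vs. FP32; there is nothing to recover")
    return (qat_f1 - ptq_f1) / ptq_drop


# Placeholder values for illustration only
print(f"{qat_recovery(0.80, 0.76, 0.79):.0%} of the PTQ drop recovered")
```

In this notebook, the three arguments would be `inter_info['val_macro_f1']`, `ptq_info['val_macro_f1']` and `qat_info['val_macro_f1']`.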

**Rule of thumb:** Start with PTQ. If the accuracy drop is ≤ 2%, ship it; if it is > 2%, run QAT.
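
The rule of thumb can be written as a tiny decision helper. This sketch rests on one stated assumption: "2%" means an absolute drop of 0.02 in macro-F1, the same units as `f1_drop` in the PTQ results cell. Adjust the hypothetical `max_drop` parameter if you mean a relative drop.

```python
def choose_quantization(fp32_f1: float, ptq_f1: float, max_drop: float = 0.02) -> str:
    """Return 'PTQ' if post-training quantization is accurate enough, else 'QAT'."""
    f1_drop = fp32_f1 - ptq_f1
    return "PTQ" if f1_drop <= max_drop else "QAT"


# Placeholder scores, not measured results
print(choose_quantization(0.85, 0.84))  # drop 0.01 -> PTQ
print(choose_quantization(0.85, 0.80))  # drop 0.05 -> QAT
```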