# # End-to-End: Build → Calibrate → (Re)Train → Compare Accuracy (16×16) → Compare Runtime (various SA)

This notebook follows the same pattern as before, but bundles the whole workflow:

1. **Build models** for two multipliers (`mul8s_acc` and `mul8s_1L2H`) and a fixed SA size (16×16).
2. **Calibrate** both models (percentile histogram, short pass).
3. **(Optional) Re-train** both models briefly (fine-tune) to reduce quantization/approx errors.
4. **Compare accuracy** between the two 16×16 models (calibrated or fine-tuned).
5. **Compare execution time** for a chosen multiplier across multiple SA sizes.


In [1]:
import os, time, timeit
import torch
import torchvision as tv
from torchvision import transforms as T
import pandas as pd
from tqdm import tqdm
from torchvision.datasets import CIFAR10
from torch.utils.data import DataLoader

# Safer dataloaders in constrained environments
torch.multiprocessing.set_sharing_strategy("file_system")

# ---- Global config ----
DEVICE = "cpu"
BATCH_SIZE = 64
NUM_CALIB_BATCHES = 2        # increase to 8–32 for better INT8 quality
SA_ROWS, SA_COLS = 16, 16    # fixed for the accuracy comparison
USE_EXACT = False            # True -> force exact mult; False -> use approx variants

# Two variants to compare (16×16 accuracy)
VARIANTS = ["mul8s_acc", "mul8s_1L2H"]

# (Optional) training settings
DO_FINETUNE = False          # set True to run a brief fine-tune 
FINETUNE_EPOCHS = 1
LR = 1e-4
WD = 0.0

# SA configs for runtime comparison 
SA_CONFIGS = [(8,8), (16,16), (32,8), (8,32)]
RUNTIME_MULT = "mul8s_acc"   # which multiplier to use for the runtime sweep

pd.set_option("display.max_colwidth", 160)


## Data loaders

In [2]:
def val_dataloader(mean = (0.4914, 0.4822, 0.4465), std = (0.2471, 0.2435, 0.2616)):

    transform = T.Compose(
        [
            T.ToTensor(),
            T.Normalize(mean, std),
        ]
    )
    dataset = CIFAR10(root="datasets/cifar10_data", train=False, download=True, transform=transform)
    dataloader = DataLoader(
        dataset,
        batch_size=128,
        num_workers=0,
        drop_last=True,
        pin_memory=False,
    )
    return dataloader

transform = T.Compose(
        [
            T.RandomCrop(32, padding=4),
            T.RandomHorizontalFlip(),
            T.ToTensor(),
            T.Normalize(mean = (0.4914, 0.4822, 0.4465), std = (0.2471, 0.2435, 0.2616)),
        ]
    )

dataset = CIFAR10(root="datasets/cifar10_data", train=True, download=True, transform=transform)

evens = list(range(0, len(dataset), 10))
trainset_1 = torch.utils.data.Subset(dataset, evens)

data = val_dataloader()

# data_t is used for calibration purposes and is a subset of train-set
data_t = DataLoader(trainset_1, batch_size=128,
                                            shuffle=False, num_workers=0)

Files already downloaded and verified
Files already downloaded and verified


## Helpers (evaluate, calibration, amax, finetune)

In [3]:
from pytorch_quantization.nn.modules.tensor_quantizer import TensorQuantizer
from pytorch_quantization import calib
from adapt.approx_layers.systolic_build import precompile_systolic_extensions
from adapt.approx_layers.systolic_utils import swap_to_systolic
import torch
from tqdm import tqdm
import contextlib

@torch.no_grad()
def evaluate(model, loader, device=DEVICE, desc="Eval"):
    model.eval(); model.to(device)
    correct = 0
    total = 0

    model.eval()
    start_time = timeit.default_timer()
    with torch.no_grad():
        for iteraction, (images, labels) in tqdm(enumerate(data), total=len(data)):
            images, labels = images.to("cpu"), labels.to("cpu")
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    print(timeit.default_timer() - start_time)
    print('Accuracy of the network on the 10000 test images: %.4f %%' % (
    100 * correct / total))
    return 100 * correct/total

@torch.no_grad()
def init_weight_amax_from_weights(model):
    n_set, n_total = 0, 0
    for _, m in model.named_modules():
        q_w = getattr(m, "quantizer_w", None)
        if isinstance(q_w, quant_nn.TensorQuantizer):
            n_total += 1
            if getattr(q_w, "amax", None) is None:
                W = getattr(m, "weight", None)
                if W is not None:
                    q_w._amax = torch.as_tensor(W.detach().abs().max(), dtype=torch.float32)
                    n_set += 1
    print(f"[init_weight_amax_from_weights] set {n_set}/{n_total} weight amax")

# 2) A small helper to attach pre-hooks that "touch" quantizers during calib
def _make_calib_pre_hook(mod):
    @torch.no_grad()
    def _pre(mod_, inputs):
        # inputs is a tuple; x is first positional input tensor
        if not inputs:
            return
        x = inputs[0]
        # Call activation quantizer to collect stats
        q = getattr(mod_, "quantizer", None)
        if isinstance(q, quant_nn.TensorQuantizer) and getattr(q, "_calibrator", None) is not None:
            q(x)
        # Touch weights as well so weight calibrator (if any) can record
        q_w = getattr(mod_, "quantizer_w", None)
        if isinstance(q_w, quant_nn.TensorQuantizer) and getattr(q_w, "_calibrator", None) is not None:
            W = getattr(mod_, "weight", None)
            if W is not None:
                q_w(W)
    return _pre

@contextlib.contextmanager
def attach_calibration_hooks(model):
    """Attach forward-pre-hooks to every systolic layer so quantizers see tensors."""
    hooks = []
    for _, m in model.named_modules():
        # We only hook layers that have our quantizers AND a forward
        has_any_q = isinstance(getattr(m, "quantizer", None), quant_nn.TensorQuantizer) or \
                    isinstance(getattr(m, "quantizer_w", None), quant_nn.TensorQuantizer)
        if has_any_q:
            try:
                h = m.register_forward_pre_hook(_make_calib_pre_hook(m))
                hooks.append(h)
            except Exception:
                pass
    try:
        yield
    finally:
        for h in hooks:
            try:
                h.remove()
            except Exception:
                pass

# 3) Fixed collect_stats using the hooks
def collect_stats(model, data_loader, num_batches=10, device="cpu"):
    """Collect activation histograms. Uses pre-hooks to ensure quantizers see tensors."""
    model.eval()
    model.to(device)

    # Enable calibration (disable quantization) so quantizers record histograms
    for _, module in model.named_modules():
        if isinstance(module, quant_nn.TensorQuantizer):
            if module._calibrator is not None:
                module.disable_quant()
                module.enable_calib()
            else:
                module.disable()

    with torch.no_grad(), attach_calibration_hooks(model):
        for i, (image, _) in enumerate(data_loader):
            image = image.to(device, non_blocking=True)
            _ = model(image)  # pre-hooks will call quantizers here
            if i >= num_batches - 1:
                break

    # Disable calibration (enable quantization for inference)
    for _, module in model.named_modules():
        if isinstance(module, quant_nn.TensorQuantizer):
            if module._calibrator is not None:
                module.enable_quant()
                module.disable_calib()
            else:
                module.enable()

    print("Calibration data collection complete.")

# 4) Compute and sanitize amax 
def compute_amax(model, method="percentile", percentile=99.99, strict=False, fallback=1.0):
    n_loaded, n_fixed = 0, 0
    for name, module in model.named_modules():
        if isinstance(module, quant_nn.TensorQuantizer):
            if module._calibrator is not None:
                try:
                    if isinstance(module._calibrator, calib.MaxCalibrator):
                        module.load_calib_amax(strict=strict)
                    else:
                        module.load_calib_amax(method=method, percentile=percentile, strict=strict)
                    n_loaded += 1
                except RuntimeError:
                    pass
            # sanitize
            amax = getattr(module, "amax", None)
            if (amax is None) or (torch.isnan(amax)) or (float(amax) == 0.0):
                module._amax = torch.tensor(float(fallback), dtype=torch.float32)
                n_fixed += 1
            print(F"{name:40}: {module}")
    print(f"Loaded calibrated amax values. loaded={n_loaded}, sanitized={n_fixed}")
    model.cpu()

In [4]:
from pytorch_quantization import nn as quant_nn
from pytorch_quantization import calib
from adapt.references.classification.train import train_one_epoch, load_data
def finetune_one_epoch(model, loader, device=DEVICE, lr=1e-4, wd=0.0):
    from pytorch_quantization import nn as quant_nn
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.0001)
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.1)
    
    # finetune the model for one epoch based on data_t subset 
    train_one_epoch(model, criterion, optimizer, loader, device, 0, 1)

## Accuracy comparison ( systolic array size 16×16) for two multipliers

In [5]:
# === Accuracy comparison @16x16 using resnet50_systolic ===
rows = []

try:
    from models.resnet_systolic import resnet50_systolic
except Exception:
    from examples.models.resnet_systolic import resnet50_systolic

for axx_mult in VARIANTS:  # z.B. ["mul8s_acc", "mul8s_1L2H"]
    print(f"=== resnet50_systolic @16x16: {axx_mult} ===")

    # 1) Precompile 
    precompile_systolic_extensions(
        axx_mult=axx_mult,
        use_exact_variants=(USE_EXACT,),
        sa_rows=16, sa_cols=16,
        verbose=False
    )

    model = resnet50_systolic(pretrained=True, axx_mult=axx_mult, use_exact=USE_EXACT)
    model.eval()  

    # 3) calibration
    with torch.no_grad():
        #init_weight_amax_from_weights(model)  
        stats = collect_stats(model, data_t, num_batches=2)
        amax = compute_amax(model, method="percentile", percentile=99.99)
    
    # optional - test different calibration methods
    #amax = compute_amax(model, method="mse")
    #amax = compute_amax(model, method="entropy")


    acc_cal = evaluate(model, data, DEVICE, desc=f"Eval resnet50_systolic ({axx_mult}) calibrated")
    rows.append({"Variant": axx_mult, "Type": "calibrated", "Accuracy %": acc_cal})

    import torch.nn as nn

    if DO_FINETUNE:
        
        for epoch in range(FINETUNE_EPOCHS):
            finetune_one_epoch(model, data_t, lr=LR, wd=WD)
        model.eval()
        acc_ft = evaluate(model, data, 'cpu',desc=f"Eval resnet50_systolic ({axx_mult}) finetuned")
        rows.append({"Variant": axx_mult, "Type": "finetuned", "Accuracy %": acc_ft})

df_acc = pd.DataFrame(rows).sort_values(by=["Variant", "Type"]).reset_index(drop=True)
df_acc


=== resnet50_systolic @16x16: mul8s_acc ===
Pre-compiling systolic extensions for mul8s_acc...
  Mode: approx
    • linear (r16×c16)
    • conv2d (r16×c16)
Pre-compilation complete! Models will now load instantly.
✓ Loaded systolic conv2d kernel: mul8s_acc, exact=False, SA=16x16
✓ Loaded systolic conv2d kernel: mul8s_acc, exact=False, SA=16x16
✓ Loaded systolic conv2d kernel: mul8s_acc, exact=False, SA=16x16
✓ Loaded systolic conv2d kernel: mul8s_acc, exact=False, SA=16x16
✓ Loaded systolic conv2d kernel: mul8s_acc, exact=False, SA=16x16
✓ Loaded systolic conv2d kernel: mul8s_acc, exact=False, SA=16x16
✓ Loaded systolic conv2d kernel: mul8s_acc, exact=False, SA=16x16
✓ Loaded systolic conv2d kernel: mul8s_acc, exact=False, SA=16x16
✓ Loaded systolic conv2d kernel: mul8s_acc, exact=False, SA=16x16
✓ Loaded systolic conv2d kernel: mul8s_acc, exact=False, SA=16x16
✓ Loaded systolic conv2d kernel: mul8s_acc, exact=False, SA=16x16
✓ Loaded systolic conv2d kernel: mul8s_acc, exact=False, SA=

W1111 10:17:08.398038 140134759839552 tensor_quantizer.py:173] Disable HistogramCalibrator
W1111 10:17:08.398667 140134759839552 tensor_quantizer.py:173] Disable HistogramCalibrator
W1111 10:17:08.399150 140134759839552 tensor_quantizer.py:173] Disable HistogramCalibrator
W1111 10:17:08.399617 140134759839552 tensor_quantizer.py:173] Disable HistogramCalibrator
W1111 10:17:08.400077 140134759839552 tensor_quantizer.py:173] Disable HistogramCalibrator
W1111 10:17:08.400535 140134759839552 tensor_quantizer.py:173] Disable HistogramCalibrator
W1111 10:17:08.401030 140134759839552 tensor_quantizer.py:173] Disable HistogramCalibrator
W1111 10:17:08.401548 140134759839552 tensor_quantizer.py:173] Disable HistogramCalibrator
W1111 10:17:08.402094 140134759839552 tensor_quantizer.py:173] Disable HistogramCalibrator
W1111 10:17:08.402627 140134759839552 tensor_quantizer.py:173] Disable HistogramCalibrator
W1111 10:17:08.404590 140134759839552 tensor_quantizer.py:173] Disable HistogramCalibrator

W1111 10:17:08.462644 140134759839552 tensor_quantizer.py:173] Disable HistogramCalibrator
W1111 10:17:08.463077 140134759839552 tensor_quantizer.py:173] Disable HistogramCalibrator
W1111 10:17:08.463559 140134759839552 tensor_quantizer.py:173] Disable HistogramCalibrator
W1111 10:17:08.464105 140134759839552 tensor_quantizer.py:173] Disable HistogramCalibrator
W1111 10:17:08.464596 140134759839552 tensor_quantizer.py:173] Disable HistogramCalibrator
W1111 10:17:08.465112 140134759839552 tensor_quantizer.py:173] Disable HistogramCalibrator
W1111 10:17:08.465624 140134759839552 tensor_quantizer.py:173] Disable HistogramCalibrator
W1111 10:17:08.466152 140134759839552 tensor_quantizer.py:173] Disable HistogramCalibrator
W1111 10:17:08.466735 140134759839552 tensor_quantizer.py:173] Disable HistogramCalibrator
W1111 10:17:08.467173 140134759839552 tensor_quantizer.py:173] Disable HistogramCalibrator
W1111 10:17:08.467830 140134759839552 tensor_quantizer.py:173] Disable HistogramCalibrator

W1111 10:17:08.571635 140134759839552 tensor_quantizer.py:237] Load calibrated amax, shape=torch.Size([]).
W1111 10:17:08.573474 140134759839552 tensor_quantizer.py:237] Load calibrated amax, shape=torch.Size([]).
W1111 10:17:08.575646 140134759839552 tensor_quantizer.py:237] Load calibrated amax, shape=torch.Size([]).
W1111 10:17:08.579190 140134759839552 tensor_quantizer.py:237] Load calibrated amax, shape=torch.Size([]).
W1111 10:17:08.581215 140134759839552 tensor_quantizer.py:237] Load calibrated amax, shape=torch.Size([]).
W1111 10:17:08.583300 140134759839552 tensor_quantizer.py:237] Load calibrated amax, shape=torch.Size([]).
W1111 10:17:08.584754 140134759839552 tensor_quantizer.py:237] Load calibrated amax, shape=torch.Size([]).
W1111 10:17:08.586052 140134759839552 tensor_quantizer.py:237] Load calibrated amax, shape=torch.Size([]).
W1111 10:17:08.587328 140134759839552 tensor_quantizer.py:237] Load calibrated amax, shape=torch.Size([]).
W1111 10:17:08.588169 140134759839552

Calibration data collection complete.
conv1.quantizer                         : TensorQuantizer(8bit per-tensor amax=2.1255 calibrator=HistogramCalibrator quant)
conv1.quantizer_w                       : TensorQuantizer(8bit per-tensor amax=0.1625 calibrator=HistogramCalibrator quant)
layer1.0.conv1.quantizer                : TensorQuantizer(8bit per-tensor amax=0.6109 calibrator=HistogramCalibrator quant)
layer1.0.conv1.quantizer_w              : TensorQuantizer(8bit per-tensor amax=0.0804 calibrator=HistogramCalibrator quant)
layer1.0.conv2.quantizer                : TensorQuantizer(8bit per-tensor amax=0.2126 calibrator=HistogramCalibrator quant)
layer1.0.conv2.quantizer_w              : TensorQuantizer(8bit per-tensor amax=0.0355 calibrator=HistogramCalibrator quant)
layer1.0.conv3.quantizer                : TensorQuantizer(8bit per-tensor amax=0.2418 calibrator=HistogramCalibrator quant)
layer1.0.conv3.quantizer_w              : TensorQuantizer(8bit per-tensor amax=0.0544 calibrat

 97%|█████████████████████████████████████████▉ | 76/78 [41:09<01:04, 32.49s/it]


KeyboardInterrupt: 

## Runtime comparison across SA sizes (same multiplier)

In [None]:
# We compare elapsed time for a short pass (10 batches) and full evaluation.
results = []

for (r, c) in SA_CONFIGS:
    print(f"=== Runtime: {RUNTIME_MULT} @ SA {r}x{c} ===")
    precompile_systolic_extensions(axx_mult=RUNTIME_MULT, use_exact_variants=(USE_EXACT,),
                                   sa_rows=r, sa_cols=c, verbose=False)
    
    model = resnet50_systolic(pretrained=True, axx_mult=axx_mult, use_exact=USE_EXACT,sa_rows=r, sa_cols=c)
    model.eval() 

    # Calibrate quickly (reuse same approach for fair comparison)
    with torch.no_grad():
        stats = collect_stats(model, data_t, num_batches=2)
        amax = compute_amax(model, method="percentile", percentile=99.99)    

    # Warmup
    xb, yb = next(iter(data))
    _ = model(xb)

    # Short run (10 batches)
    iters = 10
    start = timeit.default_timer()
    with torch.no_grad():
        for i, (x, _) in enumerate(data):
            _ = model(x)
            if i >= iters - 1:
                break
    t_small = timeit.default_timer() - start

    # Full eval timing
    start = timeit.default_timer()
    acc = evaluate(model, data, desc=f"Eval runtime SA {r}x{c}")
    t_full = timeit.default_timer() - start

    results.append({
        "sa_rows": r,
        "sa_cols": c,
        "accuracy %": acc,
        "time_10_batches_sec": t_small,
        "time_full_eval_sec": t_full,
    })

df_rt = pd.DataFrame(results).sort_values(by=["sa_rows","sa_cols"]).reset_index(drop=True)
df_rt
