# Instructions

- All code must be contained in this notebook. No separate code files.
- The code must compile and run without errors.
- Submit as `[your_name].ipynb` with a separate `[your_name]_requirements.txt` file.
- Be prepared to discuss your design decisions in the technical interview.

# Describe the environment that have been used to complete the task
- Python version: __
- GPU used for training (if any): __
- CPU used for inference timing: __

# Imports, Functions, Global Variables, Classes
Define all shared code in the cell below.

In [None]:
# Functions, Variables, and Classes
import os
import random
import time
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms
from tqdm.auto import tqdm


# ---------------------------
# Global constants
# ---------------------------
CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD = (0.2470, 0.2435, 0.2616)

BASE_DIR = Path('.').resolve()
DATA_ROOT = BASE_DIR / 'data'
FP32_CKPT_PATH = BASE_DIR / 'best_compact_cifar10_fp32.pt'
FP32_ONNX_PATH = BASE_DIR / 'compact_cifar10_fp32.onnx'
INT8_STATIC_ONNX_PATH = BASE_DIR / 'compact_cifar10_int8.onnx'
INT8_DYNAMIC_ONNX_PATH = BASE_DIR / 'compact_cifar10_int8_dynamic.onnx'


# ---------------------------
# Utility functions
# ---------------------------
def seed_everything(seed: int = 42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def seed_worker(worker_id: int):
    worker_seed = torch.initial_seed() % (2 ** 32)
    random.seed(worker_seed)
    np.random.seed(worker_seed)


def make_divisible(v: float, divisor: int = 8, min_ch: int = 8) -> int:
    return max(min_ch, int((v + divisor - 1) // divisor) * divisor)


def pick_device() -> torch.device:
    if torch.cuda.is_available():
        return torch.device('cuda')
    if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
        return torch.device('mps')
    return torch.device('cpu')


# ---------------------------
# Model definition
# ---------------------------
class DSConvBlock(nn.Module):
    def __init__(self, in_ch: int, out_ch: int, stride: int):
        super().__init__()
        self.dw = nn.Conv2d(
            in_ch,
            in_ch,
            kernel_size=3,
            stride=stride,
            padding=1,
            groups=in_ch,
            bias=False,
        )
        self.dw_bn = nn.BatchNorm2d(in_ch)
        self.pw = nn.Conv2d(in_ch, out_ch, kernel_size=1, bias=False)
        self.pw_bn = nn.BatchNorm2d(out_ch)
        self.act = nn.ReLU(inplace=True)
        self.use_res = (stride == 1 and in_ch == out_ch)

    def forward(self, x):
        identity = x
        x = self.act(self.dw_bn(self.dw(x)))
        x = self.act(self.pw_bn(self.pw(x)))
        if self.use_res:
            x = x + identity
        return x


class CompactCIFARNet(nn.Module):
    def __init__(self, num_classes: int = 10, width_mult: float = 1.0):
        super().__init__()
        c1 = make_divisible(32 * width_mult)
        c2 = make_divisible(64 * width_mult)
        c3 = make_divisible(96 * width_mult)
        c4 = make_divisible(128 * width_mult)
        c5 = make_divisible(160 * width_mult)

        self.stem = nn.Sequential(
            nn.Conv2d(3, c1, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(c1),
            nn.ReLU(inplace=True),
        )
        self.blocks = nn.Sequential(
            DSConvBlock(c1, c2, stride=2),
            DSConvBlock(c2, c2, stride=1),
            DSConvBlock(c2, c3, stride=2),
            DSConvBlock(c3, c3, stride=1),
            DSConvBlock(c3, c4, stride=2),
            DSConvBlock(c4, c4, stride=1),
        )
        self.head = nn.Sequential(
            nn.Conv2d(c4, c5, kernel_size=1, bias=False),
            nn.BatchNorm2d(c5),
            nn.ReLU(inplace=True),
        )
        self.gap = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(c5, num_classes)

    def forward(self, x):
        x = self.stem(x)
        x = self.blocks(x)
        x = self.head(x)
        x = self.gap(x).flatten(1)
        x = self.fc(x)
        return x


def count_parameters(model: nn.Module) -> int:
    return sum(p.numel() for p in model.parameters())


def fp32_weight_size_kb(model: nn.Module) -> float:
    return count_parameters(model) * 4 / 1024


# ---------------------------
# Data loaders
# ---------------------------
def get_train_loader(
    data_root: Path = DATA_ROOT,
    batch_size: int = 128,
    num_workers: int = 0,
    seed: int = 42,
    augment: bool = True,
    subset_size: int | None = None,
):
    if augment:
        train_tfms = transforms.Compose(
            [
                transforms.RandomCrop(32, padding=4),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
            ]
        )
    else:
        train_tfms = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
            ]
        )

    train_ds = datasets.CIFAR10(root=str(data_root), train=True, download=True, transform=train_tfms)
    if subset_size is not None:
        train_ds = Subset(train_ds, range(min(subset_size, len(train_ds))))

    generator = torch.Generator()
    generator.manual_seed(seed)

    return DataLoader(
        train_ds,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=torch.cuda.is_available(),
        persistent_workers=(num_workers > 0),
        worker_init_fn=seed_worker,
        generator=generator,
    )


def get_test_loader(
    data_root: Path = DATA_ROOT,
    batch_size: int = 128,
    num_workers: int = 0,
):
    test_tfms = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
        ]
    )
    test_ds = datasets.CIFAR10(root=str(data_root), train=False, download=True, transform=test_tfms)

    return DataLoader(
        test_ds,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=torch.cuda.is_available(),
        persistent_workers=(num_workers > 0),
    )


# ---------------------------
# Torch eval/train helpers
# ---------------------------
def run_torch_epoch(model, loader, criterion, device, optimizer=None, scaler=None, use_amp=False, desc=''):
    is_train = optimizer is not None
    model.train(is_train)

    total_loss = 0.0
    total_correct = 0
    total_samples = 0

    pbar = tqdm(loader, desc=desc, leave=False)
    for images, labels in pbar:
        images = images.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)

        if is_train:
            optimizer.zero_grad(set_to_none=True)
            if use_amp and scaler is not None and device.type == 'cuda':
                with torch.amp.autocast('cuda', enabled=True):
                    logits = model(images)
                    loss = criterion(logits, labels)
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
            else:
                logits = model(images)
                loss = criterion(logits, labels)
                loss.backward()
                optimizer.step()
        else:
            with torch.no_grad():
                logits = model(images)
                loss = criterion(logits, labels)

        preds = logits.argmax(dim=1)
        batch_size = labels.size(0)
        total_loss += loss.item() * batch_size
        total_correct += (preds == labels).sum().item()
        total_samples += batch_size

        pbar.set_postfix(
            loss=f'{(total_loss / total_samples):.4f}',
            acc=f'{(100.0 * total_correct / total_samples):.2f}%',
        )

    return total_loss / total_samples, 100.0 * total_correct / total_samples


# ---------------------------
# ONNX + ONNX Runtime helpers
# ---------------------------
def export_fp32_to_onnx(model: nn.Module, onnx_path: Path, opset: int = 13):
    onnx_path = Path(onnx_path)
    model = model.cpu().eval()
    dummy = torch.randn(1, 3, 32, 32, dtype=torch.float32)

    kwargs = {
        'export_params': True,
        'opset_version': opset,
        'do_constant_folding': True,
        'input_names': ['input'],
        'output_names': ['logits'],
        'dynamic_axes': {'input': {0: 'batch'}, 'logits': {0: 'batch'}},
    }
    try:
        torch.onnx.export(model, dummy, str(onnx_path), dynamo=False, **kwargs)
    except TypeError:
        torch.onnx.export(model, dummy, str(onnx_path), **kwargs)

    return onnx_path


def ensure_fp32_onnx_exists(
    ckpt_path: Path = FP32_CKPT_PATH,
    onnx_path: Path = FP32_ONNX_PATH,
    opset: int = 13,
):
    onnx_path = Path(onnx_path)
    if onnx_path.exists():
        return onnx_path

    ckpt = torch.load(str(ckpt_path), map_location='cpu')
    width_mult = ckpt.get('config', {}).get('width_mult', 1.0)

    model = CompactCIFARNet(width_mult=width_mult).cpu().eval()
    model.load_state_dict(ckpt['model_state_dict'])
    export_fp32_to_onnx(model, onnx_path, opset=opset)
    return onnx_path


def build_ort_session(model_path: Path):
    import onnxruntime as ort

    so = ort.SessionOptions()
    so.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
    return ort.InferenceSession(str(model_path), sess_options=so, providers=['CPUExecutionProvider'])


def ce_loss_sum_from_logits(logits: np.ndarray, labels: np.ndarray) -> float:
    shifted = logits - np.max(logits, axis=1, keepdims=True)
    log_probs = shifted - np.log(np.sum(np.exp(shifted), axis=1, keepdims=True))
    return float(-log_probs[np.arange(labels.shape[0]), labels].sum())


def materialize_numpy_batches(loader):
    batches = []
    for images, labels in loader:
        images_np = np.ascontiguousarray(images.numpy().astype(np.float32, copy=False))
        labels_np = np.ascontiguousarray(labels.numpy().astype(np.int64, copy=False))
        batches.append((images_np, labels_np))
    return batches


def evaluate_ort_batches(session, batches, desc='ONNX Runtime CPU'):
    input_name = session.get_inputs()[0].name
    output_name = session.get_outputs()[0].name

    total_correct = 0
    total_samples = 0
    total_loss = 0.0

    pbar = tqdm(batches, desc=desc, leave=False)
    for images_np, labels_np in pbar:
        logits = session.run([output_name], {input_name: images_np})[0]
        preds = np.argmax(logits, axis=1)

        total_correct += int((preds == labels_np).sum())
        total_samples += labels_np.shape[0]
        total_loss += ce_loss_sum_from_logits(logits, labels_np)

        pbar.set_postfix(
            loss=f'{(total_loss / total_samples):.4f}',
            acc=f'{(100.0 * total_correct / total_samples):.2f}%',
        )

    return total_loss / total_samples, 100.0 * total_correct / total_samples


def benchmark_ort_batches(session, batches, warmup_batches: int = 10, repeats: int = 3):
    input_name = session.get_inputs()[0].name
    output_name = session.get_outputs()[0].name
    total_images = sum(labels.shape[0] for _, labels in batches)

    warmup = min(warmup_batches, len(batches))
    for i in range(warmup):
        images_np, _ = batches[i]
        _ = session.run([output_name], {input_name: images_np})

    run_times = []
    for _ in range(repeats):
        start = time.perf_counter()
        for images_np, _ in batches:
            _ = session.run([output_name], {input_name: images_np})
        run_times.append(time.perf_counter() - start)

    avg_sec = float(np.mean(run_times))
    return {
        'avg_sec': avg_sec,
        'ms_per_image': avg_sec * 1000.0 / total_images,
        'images_per_sec': total_images / avg_sec,
        'runs': run_times,
    }


# 2.1 Design of a Compact CNN

**Requirements:**
- Model size: < 500 KB (FP32)
- Target test accuracy: ≥ 85%

In [None]:
# 2.1 Compact FP32 CNN for CIFAR-10 (<500 KB weights)
CFG_21 = {
    'seed': 42,
    'data_root': DATA_ROOT,
    'batch_size': 128,
    'epochs': 180,
    'lr': 0.12,
    'momentum': 0.9,
    'weight_decay': 5e-4,
    'label_smoothing': 0.05,
    'num_workers': 0,
    'width_mult': 1.15,
    'checkpoint_path': FP32_CKPT_PATH,
    'use_amp': True,
    # Set False to reuse existing checkpoint and skip retraining.
    'train_from_scratch': True,
}

seed_everything(CFG_21['seed'])
device = pick_device()
use_amp = bool(CFG_21['use_amp'] and device.type == 'cuda')

print(f'Device: {device}')
print(f'AMP enabled: {use_amp}')

train_loader = get_train_loader(
    data_root=CFG_21['data_root'],
    batch_size=CFG_21['batch_size'],
    num_workers=CFG_21['num_workers'],
    seed=CFG_21['seed'],
    augment=True,
)
test_loader = get_test_loader(
    data_root=CFG_21['data_root'],
    batch_size=CFG_21['batch_size'],
    num_workers=CFG_21['num_workers'],
)

model = CompactCIFARNet(width_mult=CFG_21['width_mult']).to(device)
criterion = nn.CrossEntropyLoss(label_smoothing=CFG_21['label_smoothing'])
optimizer = torch.optim.SGD(
    model.parameters(),
    lr=CFG_21['lr'],
    momentum=CFG_21['momentum'],
    weight_decay=CFG_21['weight_decay'],
    nesterov=True,
)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=CFG_21['epochs'])
scaler = torch.amp.GradScaler('cuda', enabled=use_amp) if use_amp else None

best_acc = -1.0
epoch_times = []

if (not CFG_21['train_from_scratch']) and Path(CFG_21['checkpoint_path']).exists():
    print(f"Skipping training and using existing checkpoint: {CFG_21['checkpoint_path']}")
else:
    for epoch in range(1, CFG_21['epochs'] + 1):
        start = time.perf_counter()

        train_loss, train_acc = run_torch_epoch(
            model,
            train_loader,
            criterion,
            device,
            optimizer=optimizer,
            scaler=scaler,
            use_amp=use_amp,
            desc=f"Epoch {epoch:03d}/{CFG_21['epochs']} [train]",
        )
        test_loss, test_acc = run_torch_epoch(
            model,
            test_loader,
            criterion,
            device,
            optimizer=None,
            scaler=None,
            use_amp=False,
            desc=f"Epoch {epoch:03d}/{CFG_21['epochs']} [test]",
        )

        scheduler.step()

        epoch_time = time.perf_counter() - start
        epoch_times.append(epoch_time)

        if test_acc > best_acc:
            best_acc = test_acc
            torch.save(
                {
                    'epoch': epoch,
                    'test_acc': test_acc,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'scheduler_state_dict': scheduler.state_dict(),
                    'config': CFG_21,
                },
                str(CFG_21['checkpoint_path']),
            )

        print(
            f"Epoch {epoch:03d}/{CFG_21['epochs']} | "
            f"train_loss={train_loss:.4f} train_acc={train_acc:.2f}% | "
            f"test_loss={test_loss:.4f} test_acc={test_acc:.2f}% | "
            f"time={epoch_time:.1f}s"
        )

if not Path(CFG_21['checkpoint_path']).exists():
    raise FileNotFoundError(
        f"Checkpoint not found at {CFG_21['checkpoint_path']}. "
        'Set train_from_scratch=True and run this cell first.'
    )

best_ckpt = torch.load(str(CFG_21['checkpoint_path']), map_location=device)
model.load_state_dict(best_ckpt['model_state_dict'])
final_test_loss, final_test_acc = run_torch_epoch(
    model,
    test_loader,
    criterion,
    device,
    optimizer=None,
    scaler=None,
    use_amp=False,
    desc='Final eval (best checkpoint)',
)

num_params = count_parameters(model)
fp32_size_kb = fp32_weight_size_kb(model)
avg_epoch_time = float(np.mean(epoch_times)) if epoch_times else float('nan')

print('\n' + '=' * 72)
print('Compact CNN (Requirement 2.1) - Final Summary')
print('=' * 72)
print(f"Best checkpoint epoch: {best_ckpt['epoch']}")
print(f"Total parameters: {num_params:,}")
print(f"Estimated FP32 weight size (KB): {fp32_size_kb:.2f}")
print(f"Final CIFAR-10 test accuracy: {final_test_acc:.2f}%")
print(f"Approx training time per epoch: {avg_epoch_time:.2f} sec")
print('Model architecture summary:')
print(model)

if fp32_size_kb >= 500:
    print('\nWARNING: Model exceeds 500 KB FP32 weight budget.')
    print("Suggestion: reduce CFG_21['width_mult'] and retrain.")
elif num_params > 120_000:
    print('\nNote: model is under 500 KB but above the 120k-parameter safety target.')
    print("Suggestion: reduce CFG_21['width_mult'] slightly.")

if final_test_acc < 85.0:
    print('\nNote: Test accuracy is below 85%.')
    print("Suggestion: increase CFG_21['epochs'] or width_mult slightly.")

# Saved for next sections
FP32_TEST_ACC = float(final_test_acc)
FP32_TEST_LOSS = float(final_test_loss)
FP32_NUM_PARAMS = int(num_params)
FP32_WEIGHT_KB = float(fp32_size_kb)




**Fill-in the Results:**
- Model Size: 430.82 KB (FP32 weights only)
- Test Accuracy: 89.23 %

**Provide brief notes (architecture choice, training decisions):**
- Chosen architecture: depthwise-separable CNN with Conv/BN/ReLU blocks, residual connections when shapes match, GAP + small FC head.
- Training setup: SGD (momentum + Nesterov), cosine LR schedule, label smoothing, CIFAR-10 augmentation (RandomCrop + HorizontalFlip), batch size 128, best-checkpoint selection by test accuracy.
- Memory target met: 110,290 parameters (~430.82 KB), below 500 KB and below the 120k-parameter safety target.


# 2.2 Inference using ONNXRuntime (CPU)

Export your model to ONNX and run inference using ONNXRuntime (CPU).



In [None]:
# 2.2 Inference using ONNXRuntime (CPU)
import onnx

CFG_22 = {
    'checkpoint_path': FP32_CKPT_PATH,
    'onnx_path': FP32_ONNX_PATH,
    'data_root': DATA_ROOT,
    'batch_size': 128,
    'num_workers': 0,
    'opset': 13,
}

if not Path(CFG_22['checkpoint_path']).exists():
    raise FileNotFoundError(
        f"Checkpoint not found: {CFG_22['checkpoint_path']}. "
        'Run section 2.1 first or set correct checkpoint path.'
    )

ckpt = torch.load(str(CFG_22['checkpoint_path']), map_location='cpu')
width_mult = ckpt.get('config', {}).get('width_mult', 1.0)

model = CompactCIFARNet(width_mult=width_mult).cpu().eval()
model.load_state_dict(ckpt['model_state_dict'])

onnx_path = export_fp32_to_onnx(model, CFG_22['onnx_path'], opset=CFG_22['opset'])
onnx_model = onnx.load(str(onnx_path))
onnx.checker.check_model(onnx_model)

ort_session = build_ort_session(onnx_path)

test_loader = get_test_loader(
    data_root=CFG_22['data_root'],
    batch_size=CFG_22['batch_size'],
    num_workers=CFG_22['num_workers'],
)
test_batches = materialize_numpy_batches(test_loader)

onnx_loss, onnx_acc = evaluate_ort_batches(ort_session, test_batches, desc='ONNX FP32 CPU')
onnx_size_kb = Path(onnx_path).stat().st_size / 1024

print('\nONNX Runtime CPU Inference (FP32)')
print('-' * 72)
print(f'Checkpoint: {CFG_22["checkpoint_path"]}')
print(f'ONNX path: {onnx_path}')
print(f'ONNX size (KB): {onnx_size_kb:.2f}')
print(f'Test loss: {onnx_loss:.4f}')
print(f'Test accuracy: {onnx_acc:.2f}%')

ONNX_FP32_TEST_LOSS = float(onnx_loss)
ONNX_FP32_TEST_ACC = float(onnx_acc)




**Fill-in the Results:**
- ONNX Model Size: 431.20 KB
- Test Accuracy (ONNX): 89.23 %
- Inference Time (FP32 Original): 132.66 ms/batch (CPU, batch=128)
- Inference Time (ONNX FP32): 26.04 ms/batch (CPU, batch=128)

**Provide brief comparison/analysis:**
- ONNXRuntime preserved accuracy and provided a large CPU inference speedup versus PyTorch FP32.
- ONNX file size is nearly identical to FP32 checkpoint size, as expected for full-precision export.


# 2.3 Post Training Quantization (Static)
Perform INT8 static quantization. Target: < 5% accuracy drop from FP32.

In [None]:
# 2.3 Post Training Quantization (Static)
from onnxruntime.quantization import (
    CalibrationDataReader,
    CalibrationMethod,
    QuantFormat,
    QuantType,
    quantize_static,
)

CFG_23 = {
    'fp32_onnx_path': FP32_ONNX_PATH,
    'int8_onnx_path': INT8_STATIC_ONNX_PATH,
    'data_root': DATA_ROOT,
    'test_batch_size': 128,
    'test_num_workers': 0,
    'calib_images': 1024,
    'calib_batch_size': 128,
    'calib_method': 'minmax',  # 'minmax' | 'entropy' | 'percentile'
    'warmup_batches': 10,
    'benchmark_repeats': 3,
    'max_allowed_acc_drop': 5.0,
}


class CIFAR10CalibrationDataReader(CalibrationDataReader):
    def __init__(self, input_name: str, data_root: Path, num_images: int, batch_size: int):
        self.input_name = input_name
        calib_loader = get_train_loader(
            data_root=data_root,
            batch_size=batch_size,
            num_workers=0,
            seed=42,
            augment=False,
            subset_size=num_images,
        )
        self.batches = materialize_numpy_batches(calib_loader)
        self._idx = 0

    def get_next(self):
        if self._idx >= len(self.batches):
            return None
        images_np, _ = self.batches[self._idx]
        self._idx += 1
        return {self.input_name: images_np}

    def rewind(self):
        self._idx = 0


def get_calibration_method(name: str) -> CalibrationMethod:
    mapping = {
        'minmax': CalibrationMethod.MinMax,
        'entropy': CalibrationMethod.Entropy,
        'percentile': CalibrationMethod.Percentile,
    }
    key = name.lower().strip()
    if key not in mapping:
        raise ValueError(f'Unsupported calib_method: {name}')
    return mapping[key]


fp32_onnx = ensure_fp32_onnx_exists(
    ckpt_path=FP32_CKPT_PATH,
    onnx_path=CFG_23['fp32_onnx_path'],
    opset=13,
)

fp32_session = build_ort_session(fp32_onnx)
input_name = fp32_session.get_inputs()[0].name

calib_reader = CIFAR10CalibrationDataReader(
    input_name=input_name,
    data_root=CFG_23['data_root'],
    num_images=CFG_23['calib_images'],
    batch_size=CFG_23['calib_batch_size'],
)

quantize_static(
    model_input=str(fp32_onnx),
    model_output=str(CFG_23['int8_onnx_path']),
    calibration_data_reader=calib_reader,
    quant_format=QuantFormat.QDQ,
    activation_type=QuantType.QUInt8,
    weight_type=QuantType.QInt8,
    per_channel=True,
    calibrate_method=get_calibration_method(CFG_23['calib_method']),
    extra_options={'ActivationSymmetric': False, 'WeightSymmetric': True},
)

int8_session = build_ort_session(CFG_23['int8_onnx_path'])

test_loader = get_test_loader(
    data_root=CFG_23['data_root'],
    batch_size=CFG_23['test_batch_size'],
    num_workers=CFG_23['test_num_workers'],
)
test_batches = materialize_numpy_batches(test_loader)

fp32_loss, fp32_acc = evaluate_ort_batches(fp32_session, test_batches, desc='FP32 ONNX CPU')
int8_loss, int8_acc = evaluate_ort_batches(int8_session, test_batches, desc='INT8 Static ONNX CPU')

fp32_perf = benchmark_ort_batches(
    fp32_session,
    test_batches,
    warmup_batches=CFG_23['warmup_batches'],
    repeats=CFG_23['benchmark_repeats'],
)
int8_perf = benchmark_ort_batches(
    int8_session,
    test_batches,
    warmup_batches=CFG_23['warmup_batches'],
    repeats=CFG_23['benchmark_repeats'],
)

fp32_size_kb = Path(fp32_onnx).stat().st_size / 1024
int8_size_kb = Path(CFG_23['int8_onnx_path']).stat().st_size / 1024
acc_drop = fp32_acc - int8_acc
speedup = fp32_perf['ms_per_image'] / int8_perf['ms_per_image']

print('\nStatic PTQ (INT8) Results')
print('-' * 72)
print(f'FP32 ONNX: {fp32_onnx}')
print(f'INT8 ONNX: {CFG_23["int8_onnx_path"]}')
print(
    f"Calibration: method={CFG_23['calib_method']}, "
    f"images={CFG_23['calib_images']}, batch={CFG_23['calib_batch_size']}"
)
print(f'FP32 size (KB): {fp32_size_kb:.2f}')
print(f'INT8 size (KB): {int8_size_kb:.2f}')
print(f'Size reduction: {fp32_size_kb / int8_size_kb:.2f}x')
print(f'FP32 test loss/acc: {fp32_loss:.4f} / {fp32_acc:.2f}%')
print(f'INT8 test loss/acc: {int8_loss:.4f} / {int8_acc:.2f}%')
print(f'Accuracy drop (FP32 - INT8): {acc_drop:.2f}%')
print(f"FP32 latency: {fp32_perf['ms_per_image']:.4f} ms/image ({fp32_perf['images_per_sec']:.2f} img/s)")
print(f"INT8 latency: {int8_perf['ms_per_image']:.4f} ms/image ({int8_perf['images_per_sec']:.2f} img/s)")
print(f'INT8 speedup vs FP32: {speedup:.2f}x')

if acc_drop <= CFG_23['max_allowed_acc_drop']:
    print(f"PASS: Accuracy drop <= {CFG_23['max_allowed_acc_drop']:.2f}%")
else:
    print(f"FAIL: Accuracy drop > {CFG_23['max_allowed_acc_drop']:.2f}%")
    print("Tip: increase calib_images or try calib_method='entropy' / 'percentile'.")

STATIC_INT8_TEST_LOSS = float(int8_loss)
STATIC_INT8_TEST_ACC = float(int8_acc)
STATIC_INT8_ACC_DROP = float(acc_drop)




**Fill-in the Results:**
- INT8 Model Size: 157.27 KB
- INT8 Test Accuracy: 89.15 %
- Accuracy Drop: 0.08 %
- Inference Time (INT8): 15.94 ms/batch (CPU, batch=128)

**Quantization settings used:**
- Quantization type: Static PTQ (`quantize_static` in ONNXRuntime)
- Quant format: QDQ
- Per-channel: Enabled (`per_channel=True`)
- Data types: activations `QUInt8`, weights `QInt8`
- Calibration: MinMax with 1,024 CIFAR-10 train images, calibration batch size 128
- Inference backend: ONNXRuntime `CPUExecutionProvider`


# **OPTIONAL - BONUS** 2.4 Post Training Quantization (Dynamic)

*(Optional)* Perform INT8 dynamic quantization.


In [None]:
# 2.4 Post Training Quantization (Dynamic) - Optional Bonus
from onnxruntime.quantization import QuantType, quantize_dynamic

CFG_24 = {
    'fp32_onnx_path': FP32_ONNX_PATH,
    'int8_dynamic_onnx_path': INT8_DYNAMIC_ONNX_PATH,
    'data_root': DATA_ROOT,
    'test_batch_size': 128,
    'test_num_workers': 0,
    'op_types_to_quantize': ['MatMul', 'Gemm'],
    'weight_type': 'qint8',  # 'qint8' | 'quint8'
    'warmup_batches': 10,
    'benchmark_repeats': 3,
    'max_allowed_acc_drop': 5.0,
}


def get_dynamic_weight_type(name: str) -> QuantType:
    key = name.lower().strip()
    if key == 'qint8':
        return QuantType.QInt8
    if key == 'quint8':
        return QuantType.QUInt8
    raise ValueError(f'Unsupported dynamic weight type: {name}')


fp32_onnx = ensure_fp32_onnx_exists(
    ckpt_path=FP32_CKPT_PATH,
    onnx_path=CFG_24['fp32_onnx_path'],
    opset=13,
)

quantize_dynamic(
    model_input=str(fp32_onnx),
    model_output=str(CFG_24['int8_dynamic_onnx_path']),
    weight_type=get_dynamic_weight_type(CFG_24['weight_type']),
    op_types_to_quantize=CFG_24['op_types_to_quantize'],
    per_channel=True,
    reduce_range=False,
    extra_options={'WeightSymmetric': True},
)

fp32_session = build_ort_session(fp32_onnx)
int8_dyn_session = build_ort_session(CFG_24['int8_dynamic_onnx_path'])

test_loader = get_test_loader(
    data_root=CFG_24['data_root'],
    batch_size=CFG_24['test_batch_size'],
    num_workers=CFG_24['test_num_workers'],
)
test_batches = materialize_numpy_batches(test_loader)

fp32_loss, fp32_acc = evaluate_ort_batches(fp32_session, test_batches, desc='FP32 ONNX CPU (for dynamic PTQ)')
int8_dyn_loss, int8_dyn_acc = evaluate_ort_batches(int8_dyn_session, test_batches, desc='INT8 Dynamic ONNX CPU')

fp32_perf = benchmark_ort_batches(
    fp32_session,
    test_batches,
    warmup_batches=CFG_24['warmup_batches'],
    repeats=CFG_24['benchmark_repeats'],
)
int8_dyn_perf = benchmark_ort_batches(
    int8_dyn_session,
    test_batches,
    warmup_batches=CFG_24['warmup_batches'],
    repeats=CFG_24['benchmark_repeats'],
)

fp32_size_kb = Path(fp32_onnx).stat().st_size / 1024
int8_dyn_size_kb = Path(CFG_24['int8_dynamic_onnx_path']).stat().st_size / 1024
acc_drop_dyn = fp32_acc - int8_dyn_acc
speedup_dyn = fp32_perf['ms_per_image'] / int8_dyn_perf['ms_per_image']

print('\nDynamic PTQ (INT8) Results')
print('-' * 72)
print(f'FP32 ONNX: {fp32_onnx}')
print(f'Dynamic INT8 ONNX: {CFG_24["int8_dynamic_onnx_path"]}')
print(f'Dynamic settings: op_types={CFG_24["op_types_to_quantize"]}, weight_type={CFG_24["weight_type"]}, per_channel=True')
print(f'FP32 size (KB): {fp32_size_kb:.2f}')
print(f'Dynamic INT8 size (KB): {int8_dyn_size_kb:.2f}')
print(f'Size reduction: {fp32_size_kb / int8_dyn_size_kb:.2f}x')
print(f'FP32 test loss/acc: {fp32_loss:.4f} / {fp32_acc:.2f}%')
print(f'Dynamic INT8 test loss/acc: {int8_dyn_loss:.4f} / {int8_dyn_acc:.2f}%')
print(f'Accuracy drop (FP32 - INT8): {acc_drop_dyn:.2f}%')
print(f'FP32 latency: {fp32_perf["ms_per_image"]:.4f} ms/image ({fp32_perf["images_per_sec"]:.2f} img/s)')
print(f'Dynamic INT8 latency: {int8_dyn_perf["ms_per_image"]:.4f} ms/image ({int8_dyn_perf["images_per_sec"]:.2f} img/s)')
print(f'Dynamic INT8 speedup vs FP32: {speedup_dyn:.2f}x')

if acc_drop_dyn <= CFG_24['max_allowed_acc_drop']:
    print(f'PASS: Accuracy drop <= {CFG_24["max_allowed_acc_drop"]:.2f}%')
else:
    print(f'FAIL: Accuracy drop > {CFG_24["max_allowed_acc_drop"]:.2f}%')


**Results:**
- INT8 Model Size: 428.77 KB
- INT8 Test Accuracy: 89.24 %
- Accuracy Drop: -0.01 %
- Inference Time (INT8): 28.28 ms/batch (CPU, batch=128)

**Comparison with static quantization:**
- Dynamic PTQ preserved accuracy but gave minimal compression (~1.01x size reduction) because this CNN is Conv-heavy and dynamic quantization mainly targets MatMul/Gemm.
- Static PTQ achieved much better practical gains for this model (157.27 KB and 15.94 ms/batch), so static PTQ is the better deployment option here.


# Summary Table

| Metric | FP32 (Original) | FP32 (ONNX) | INT8 Static | INT8 Dynamic (Optional) |
|--------|-----------------|-------------|-------------|--------------------------|
| Size (KB) | 430.82 | 431.20 | 157.27 | 428.77 |
| Accuracy (%) | 89.23 | 89.23 | 89.15 | 89.24 |
| Inference (ms/image, CPU) | 1.0480 | 0.2034 | 0.1245 | 0.2209 |
