In [None]:
# Install dependencies (run once in Colab / environment)
!pip install --upgrade pip
!pip install onnx onnxruntime onnxruntime-tools onnxruntime-extensions torchvision pillow tqdm matplotlib
!pip install onnxscript


Collecting onnxscript
  Downloading onnxscript-0.5.6-py3-none-any.whl.metadata (13 kB)
Collecting onnx_ir<2,>=0.1.12 (from onnxscript)
  Downloading onnx_ir-0.1.12-py3-none-any.whl.metadata (3.2 kB)
Downloading onnxscript-0.5.6-py3-none-any.whl (683 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m683.0/683.0 kB[0m [31m22.6 MB/s[0m  [33m0:00:00[0m
[?25hDownloading onnx_ir-0.1.12-py3-none-any.whl (129 kB)
Installing collected packages: onnx_ir, onnxscript
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2/2[0m [onnxscript]
[1A[2KSuccessfully installed onnx_ir-0.1.12 onnxscript-0.5.6


In [None]:
import kagglehub

# Download latest version
path = kagglehub.dataset_download("ifigotin/imagenetmini-1000")

print("Path to dataset files:", path)

Using Colab cache for faster access to the 'imagenetmini-1000' dataset.
Path to dataset files: /kaggle/input/imagenetmini-1000


In [None]:
# selective_qdq_pipeline.py
import os
import time
import numpy as np
import torch
from torch.utils.data import DataLoader
from torchvision import transforms, datasets, models
import onnx
import onnxruntime as ort
from onnxruntime.quantization import (
    quantize_static,
    CalibrationDataReader,
    QuantFormat,
    QuantType
)
from PIL import Image
from collections import defaultdict
import pandas as pd
import matplotlib.pyplot as plt

# ----------------------------
# CONFIG
# ----------------------------
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
IMAGE_SIZE = 224
BATCH_EVAL = 16         # batch used during ONNX eval (keep small if CPU)
CALIB_BATCH = 8         # batch used by calibration reader
CALIB_SAMPLES = 256     # number of images for calibration
EVAL_SAMPLES = 512      # number of val images to evaluate
OPSET = 13

# Set this to your imagenet-mini root with train/val folders
#DATA_ROOT = "/root/.cache/kagglehub/datasets/ifigotin/imagenetmini-1000/versions/1/imagenet-mini"
DATA_ROOT = "/kaggle/input/imagenetmini-1000/"
FP32_ONNX = "resnet18_fp32.onnx"
QONNX_SEL_TEMPLATE = "resnet18_qdq_selective_k{}.onnx"
QONNX_FULL = "resnet18_qlinear_full.onnx"  # optional full qlinear path

# ----------------------------
# DATA PREPROCESSING & LOADERS
# ----------------------------
preprocess = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize([0.485,0.456,0.406], [0.229,0.224,0.225])
])

def get_val_loader(split="val", batch=BATCH_EVAL, max_samples=None):
    folder = os.path.join(DATA_ROOT, split)
    if os.path.isdir(folder):
        ds = datasets.ImageFolder(folder, transform=preprocess)
        if max_samples:
            ds = torch.utils.data.Subset(ds, list(range(min(max_samples, len(ds)))))
        return DataLoader(ds, batch_size=batch, shuffle=False, num_workers=2)
    else:
        # fallback: FakeData
        from torchvision.datasets import FakeData
        ds = FakeData(size=1024, image_size=(3, IMAGE_SIZE, IMAGE_SIZE), transform=preprocess)
        if max_samples:
            ds = torch.utils.data.Subset(ds, list(range(min(max_samples, len(ds)))))
        return DataLoader(ds, batch_size=batch, shuffle=False, num_workers=2)

def collect_calib_image_paths(root=DATA_ROOT, split="val", max_files=CALIB_SAMPLES):
    folder = os.path.join(root, split)
    files = []
    if not os.path.isdir(folder):
        raise FileNotFoundError(f"Calibration folder {folder} not found: update DATA_ROOT")
    for cls in os.listdir(folder):
        cls_folder = os.path.join(folder, cls)
        if not os.path.isdir(cls_folder):
            continue
        for f in os.listdir(cls_folder):
            if f.lower().endswith((".jpg", ".jpeg", ".png")):
                files.append(os.path.join(cls_folder, f))
                if len(files) >= max_files:
                    return files
    return files

# ----------------------------
# EXPORT FP32 ONNX with dynamic axes (batch dynamic)
# ----------------------------
def export_pytorch_to_onnx(model, out_path=FP32_ONNX, opset=OPSET):
    model = model.eval().cpu()
    dummy = torch.randn(1,3,IMAGE_SIZE,IMAGE_SIZE, dtype=torch.float32)
    torch.onnx.export(
        model, dummy, out_path,
        input_names=["input"], output_names=["output"],
        dynamic_axes={"input": {0: "batch"}, "output": {0: "batch"}},
        opset_version=opset, do_constant_folding=True
    )
    print("Exported FP32 ONNX:", out_path)

# ----------------------------
# CalibrationDataReader for quantize_static (reads image files)
# ----------------------------
class ImageCalibrationReader(CalibrationDataReader):
    def __init__(self, onnx_model_path, image_files, batch_size=CALIB_BATCH):
        self.session = ort.InferenceSession(onnx_model_path, providers=["CPUExecutionProvider"])
        self.input_name = self.session.get_inputs()[0].name
        self.files = image_files
        self.batch_size = batch_size
        self.idx = 0
        self.transform = preprocess

    def get_next(self):
        if self.idx >= len(self.files):
            return None
        batch_files = self.files[self.idx:self.idx + self.batch_size]
        arrs = []
        for p in batch_files:
            img = Image.open(p).convert("RGB")
            x = self.transform(img).unsqueeze(0).numpy().astype(np.float32)
            arrs.append(x)
        self.idx += self.batch_size
        batch = np.concatenate(arrs, axis=0)
        return {self.input_name: batch}

# ----------------------------
# PROFILE: execution order, per-module latency & activation memory
# ----------------------------
def is_leaf_module(m):
    return len(list(m.children())) == 0

def tensor_bytes(t):
    return t.numel() * t.element_size()

def profile_forward_order(model, sample_input, runs=8):
    model = model.eval()
    device = next(model.parameters()).device if any(True for _ in model.parameters()) else torch.device("cpu")
    x = sample_input.to(device)

    start_times = {}
    total_time = defaultdict(float)
    total_mem = defaultdict(float)
    calls = defaultdict(int)
    exec_order = []
    name_map = {}

    def pre_hook(m, inp):
        if device.type == "cuda":
            torch.cuda.synchronize()
        start_times[id(m)] = time.perf_counter()

    def post_hook(m, inp, out):
        if device.type == "cuda":
            torch.cuda.synchronize()
        elapsed = time.perf_counter() - start_times[id(m)]
        total_time[m] += elapsed
        calls[m] += 1
        # activation bytes
        if isinstance(out, torch.Tensor):
            total_mem[m] += tensor_bytes(out.detach().cpu())
        elif isinstance(out, (list, tuple)):
            s = 0
            for o in out:
                if isinstance(o, torch.Tensor):
                    s += tensor_bytes(o.detach().cpu())
            total_mem[m] += s
        if m not in exec_order:
            exec_order.append(m)

    hooks = []
    name_map = {mod: name for name, mod in model.named_modules()}
    for name, mod in model.named_modules():
        if is_leaf_module(mod):
            hooks.append(mod.register_forward_pre_hook(pre_hook))
            hooks.append(mod.register_forward_hook(post_hook))

    with torch.no_grad():
        # warmup
        for _ in range(2):
            model(x)
        # runs
        for _ in range(runs):
            model(x)

    for h in hooks: h.remove()

    avg_time = {m: total_time[m] / calls[m] for m in total_time}
    avg_mem = {m: total_mem[m] / calls[m] for m in total_mem}
    return exec_order, avg_time, avg_mem, name_map

# ----------------------------
# SENSITIVITY: L2 output change when quantizing module weights (weights-only simulation)
# ----------------------------
def quantize_tensor_sym(x, num_bits=8):
    if num_bits >= 32:
        return x.clone()
    qmax = 2 ** (num_bits - 1) - 1
    t_abs = x.abs().max()
    if t_abs == 0:
        return x.clone()
    scale = t_abs / qmax
    q = torch.clamp((x / scale).round(), -qmax, qmax)
    return q * scale

def layer_sensitivity_l2(model, module, sample_input, num_bits=8):
    model.eval()
    device = next(model.parameters()).device if any(True for _ in model.parameters()) else torch.device("cpu")
    x = sample_input.to(device)

    orig = None
    def hook_orig(m, inp, out):
        nonlocal orig
        orig = out.detach().cpu().clone()
    h1 = module.register_forward_hook(hook_orig)
    with torch.no_grad():
        model(x)
    h1.remove()
    if orig is None:
        return float("inf")

    # backup & quantize inplace
    backup = {}
    for name, p in module.named_parameters(recurse=False):
        backup[name] = p.data.clone()
        p.data = quantize_tensor_sym(p.data, num_bits=num_bits).to(p.device)

    pert = None
    def hook_pert(m, inp, out):
        nonlocal pert
        pert = out.detach().cpu().clone()
    h2 = module.register_forward_hook(hook_pert)
    with torch.no_grad():
        model(x)
    h2.remove()

    # restore
    for name, p in module.named_parameters(recurse=False):
        p.data.copy_(backup[name])

    if pert is None:
        return float("inf")
    diff = (orig - pert).view(-1)
    l2 = torch.norm(diff).item()
    return l2 / (orig.numel() + 1e-12)

# ----------------------------
# SCORE computation & selection
# ----------------------------
def compute_scores(exec_order, avg_time, avg_mem, name_map, model, sample_input):
    candidates = []
    lat_list, mem_list, sens_list = [], [], []
    infos = []

    for m in exec_order:
        # only modules with parameters
        if any(True for _ in m.parameters(recurse=False)):
            t = float(avg_time.get(m, 0.0))
            mem = float(avg_mem.get(m, 0.0)) + sum(p.numel() * p.element_size() for p in m.parameters(recurse=False))
            sens = float(layer_sensitivity_l2(model, m, sample_input))
            lat_list.append(t); mem_list.append(mem); sens_list.append(sens)
            infos.append((m, name_map.get(m, "<noname>"), t, mem, sens))

    if len(infos) == 0:
        return []

    def norm(arr):
        mn = min(arr); mx = max(arr)
        if mx - mn < 1e-12:
            return [0.0 for _ in arr]
        return [(v - mn) / (mx - mn) for v in arr]

    lat_n = norm(lat_list); mem_n = norm(mem_list); sens_n = norm(sens_list)

    scored = []
    for i, (m, nm, t, mem, sens) in enumerate(infos):
        score = (lat_n[i] + mem_n[i]) / (sens_n[i] + 1e-9)
        scored.append((score, m, nm, t, mem, sens))
    scored.sort(reverse=True, key=lambda x: x[0])
    return scored

# ----------------------------
# MAP selected PyTorch modules -> ONNX Conv/Gemm nodes (order heuristic)
# ----------------------------
def map_modules_to_onnx_nodes(scored_list, onnx_path, max_map=None):
    model = onnx.load(onnx_path)
    onnx_nodes = [n for n in model.graph.node if n.op_type in ("Conv", "Gemm", "MatMul")]
    onnx_names = [n.name if n.name != "" else f"{n.op_type}_{i}" for i, n in enumerate(onnx_nodes)]
    # Map first N scored modules to first N ONNX conv/gemm nodes
    N = len(scored_list) if max_map is None else min(len(scored_list), max_map)
    mapped = onnx_names[:N]
    return mapped

# ----------------------------
# SELECTIVE QUANTIZATION (QDQ static)
# ----------------------------
def quantize_static_selective(fp32_onnx, out_qonnx, calib_files, nodes_to_quantize, per_channel=True):
    reader = ImageCalibrationReader(fp32_onnx, calib_files, batch_size=CALIB_BATCH)
    print(f"Quantizing (QDQ) selected nodes count={len(nodes_to_quantize)} ...")
    try:
        quantize_static(
            model_input=fp32_onnx,
            model_output=out_qonnx,
            calibration_data_reader=reader,
            quant_format=QuantFormat.QDQ,
            activation_type=QuantType.QInt8,
            weight_type=QuantType.QInt8,
            per_channel=per_channel,
            nodes_to_quantize=nodes_to_quantize
        )
        print("Saved selective QDQ model:", out_qonnx)
    except TypeError as e:
        # nodes_to_quantize arg might not exist in some ORT versions
        print("quantize_static selective failed (nodes_to_quantize not supported). Falling back to full QDQ.")
        quantize_static(
            model_input=fp32_onnx,
            model_output=out_qonnx,
            calibration_data_reader=reader,
            quant_format=QuantFormat.QDQ,
            activation_type=QuantType.QInt8,
            weight_type=QuantType.QInt8,
            per_channel=per_channel
        )
        print("Saved full QDQ model as fallback:", out_qonnx)

# ----------------------------
# EVALUATE ONNX (accuracy, latency, size)
# ----------------------------
def evaluate_onnx_model(onnx_path, loader, provider="CPUExecutionProvider", max_images=EVAL_SAMPLES):
    sess = ort.InferenceSession(onnx_path, providers=[provider])
    input_name = sess.get_inputs()[0].name
    correct = 0
    total = 0
    times = []
    for images, labels in loader:
        imgs = images.numpy().astype(np.float32)
        t0 = time.time()
        outputs = sess.run(None, {input_name: imgs})[0]
        t1 = time.time()
        preds = np.argmax(outputs, axis=1)
        correct += (preds == labels.numpy()).sum()
        total += labels.size(0)
        times.append(t1 - t0)
        if total >= max_images:
            break
    acc = correct / total if total>0 else 0.0
    latency_ms = (sum(times) / len(times)) * 1000.0 if len(times)>0 else 0.0
    size_mb = os.path.getsize(onnx_path) / (1024*1024)
    return acc, latency_ms, size_mb

# ----------------------------
# MAIN PIPELINE
# ----------------------------
def main():
    print("Loading PyTorch model (ResNet18 pretrained)...")
    pt_model = models.resnet18(weights="IMAGENET1K_V1").to("cpu").eval()

    # sample for profiling: 8 images
    val_small_loader = get_val_loader("val", batch=8, max_samples=8)
    sample_images, _ = next(iter(val_small_loader))
    sample_images = sample_images[:8]

    print("Profiling forward order / latency / activation memory...")
    exec_order, avg_time_map, avg_mem_map, name_map = profile_forward_order(pt_model, sample_images, runs=8)
    print(f"Leaf modules executed: {len(exec_order)}")

    print("Computing sensitivity & scores (this may take time)...")
    scored = compute_scores(exec_order, avg_time_map, avg_mem_map, name_map, pt_model, sample_images)
    print("Top 10 scored modules:")
    for s, m, nm, t, mem, sens in scored[:10]:
        print(f"{nm:40s} | score={s:.3e} | lat={t:.6f}s | mem={mem/1024:.1f}KB | sens={sens:.3e}")

    print("Exporting FP32 model to ONNX with dynamic batch...")
    export_pytorch_to_onnx(pt_model, FP32_ONNX)

    print("Collecting calibration images...")
    calib_files = collect_calib_image_paths(root=DATA_ROOT, split="val", max_files=CALIB_SAMPLES)
    print(f"Calibration images: {len(calib_files)}")

    # optional: do a full QLinear quant (comment if you don't want)
    # print("Running full QLinear static quant (optional)...")
    # run_full_qlinear = False
    # if run_full_qlinear:
    #     # You can use quantize_static (QOperator) to produce QONNX_FULL
    #     pass

    # prepare evaluation loader (val)
    eval_loader = get_val_loader("val", batch=BATCH_EVAL, max_samples=EVAL_SAMPLES)

    print("Evaluate FP32 ONNX baseline...")
    fp32_acc, fp32_lat, fp32_size = evaluate_onnx_model(FP32_ONNX, eval_loader)
    print(f"FP32: Acc={fp32_acc:.4f}, Latency={fp32_lat:.2f} ms, Size={fp32_size:.2f} MB")

    # sweep K (number of top-scored modules to quantize)
    max_map = min(len(scored),  len([n for n in onnx.load(FP32_ONNX).graph.node if n.op_type in ("Conv","Gemm","MatMul")]))
    ks = list(sorted(set([0, 1, 2, 5, 10, 20, max_map])))  # tune as needed
    ks = [k for k in ks if k <= max_map]

    results = []
    for k in ks:
        print("==== Sweep K =", k, "====")
        topk = scored[:k]
        # map to ONNX node names by order heuristic
        onnx_nodes_to_quantize = map_modules_to_onnx_nodes(topk, FP32_ONNX, max_map=k)
        print(f"Mapped {len(onnx_nodes_to_quantize)} ONNX nodes for quantization.")

        out_q = QONNX_SEL_TEMPLATE.format(k)
        quantize_static_selective(FP32_ONNX, out_q, calib_files, onnx_nodes_to_quantize, per_channel=True)

        acc, lat, size = evaluate_onnx_model(out_q, eval_loader)
        print(f"K={k} -> Acc={acc:.4f}, Lat={lat:.2f} ms, Size={size:.2f} MB")
        results.append({"K": k, "accuracy": acc, "latency_ms": lat, "size_mb": size})

    # include FP32 baseline at K=0 (already included if 0 in ks)
    df = pd.DataFrame(results)
    df.to_csv("selective_qdq_sweep_results.csv", index=False)
    print("Saved sweep results to selective_qdq_sweep_results.csv")
    print(df)

    # PLOT
    plt.figure(figsize=(6,4)); plt.plot(df["K"], df["accuracy"], marker='o'); plt.xlabel("K (quantized nodes)"); plt.ylabel("Top-1 Accuracy"); plt.grid(True); plt.title("Accuracy vs K"); plt.show()
    plt.figure(figsize=(6,4)); plt.plot(df["K"], df["latency_ms"], marker='o'); plt.xlabel("K (quantized nodes)"); plt.ylabel("Latency (ms)"); plt.grid(True); plt.title("Latency vs K"); plt.show()
    plt.figure(figsize=(6,4)); plt.plot(df["K"], df["size_mb"], marker='o'); plt.xlabel("K (quantized nodes)"); plt.ylabel("On-disk Size (MB)"); plt.grid(True); plt.title("Model Size vs K"); plt.show()

if __name__ == "__main__":
    main()


Loading PyTorch model (ResNet18 pretrained)...
Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth


100%|██████████| 44.7M/44.7M [00:00<00:00, 156MB/s]


Profiling forward order / latency / activation memory...
Leaf modules executed: 52
Computing sensitivity & scores (this may take time)...
Top 10 scored modules:
bn1                                      | score=1.337e+09 | lat=0.016521s | mem=25088.5KB | sens=1.152e-06
layer1.1.conv2                           | score=1.096e+02 | lat=0.031730s | mem=6416.0KB | sens=3.617e-06
layer1.0.bn2                             | score=1.084e+02 | lat=0.001205s | mem=6272.5KB | sens=1.844e-06
layer4.1.conv2                           | score=1.016e+02 | lat=0.037423s | mem=10000.0KB | sens=4.609e-06
layer4.0.conv2                           | score=9.997e+01 | lat=0.035935s | mem=10000.0KB | sens=4.572e-06
layer3.1.conv2                           | score=9.595e+01 | lat=0.031435s | mem=3872.0KB | sens=3.617e-06
layer2.1.conv2                           | score=9.500e+01 | lat=0.030656s | mem=3712.0KB | sens=3.569e-06
layer1.1.bn2                             | score=7.509e+01 | lat=0.001175s | mem=6272.5

ModuleNotFoundError: No module named 'onnxscript'

In [None]:
import os
import torch
import torch.nn as nn
from torchvision import models, datasets, transforms
from torch.utils.data import DataLoader
import numpy as np
import time
from collections import defaultdict

import onnx
import onnxruntime
from onnxruntime.quantization import quantize_static, CalibrationDataReader, QuantType

################################################################################
# 0. CONFIGURATION
################################################################################

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
IMAGE_SIZE = 224
BATCH_SIZE = 1

DATA_ROOT = "/kaggle/input/imagenetmini-1000/imagenet-mini/train"
#DATA_ROOT = "/root/.cache/kagglehub/datasets/ifigotin/imagenetmini-1000/versions/1/imagenet-mini/train"

transform = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize([0.485,0.456,0.406], [0.229,0.224,0.225])
])

dataset = datasets.ImageFolder(DATA_ROOT, transform)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=False)

model = models.resnet18(weights="IMAGENET1K_V1").eval().to(DEVICE)

################################################################################
# 1. BASIC UTILITY FUNCTIONS
################################################################################

def is_leaf(m):
    return (len(list(m.children())) == 0)

def tensor_bytes(t):
    return t.numel() * t.element_size()

def param_bytes(m):
    return sum(p.numel() * p.element_size() for p in m.parameters(recurse=False))

################################################################################
# 2. PROFILE LAYER LATENCY + MEMORY
################################################################################
def profile_layers(model, x):
    model.eval()
    x = x.to(DEVICE)

    names = {m: n for n, m in model.named_modules()}
    lat, mem, calls = defaultdict(float), defaultdict(float), defaultdict(int)

    start_times = {}

    def pre_hook(m, inp):
        torch.cuda.synchronize() if DEVICE == "cuda" else None
        start_times[m] = time.perf_counter()

    def fwd_hook(m, inp, out):
        torch.cuda.synchronize() if DEVICE == "cuda" else None
        dt = time.perf_counter() - start_times[m]
        lat[m] += dt
        calls[m] += 1
        if isinstance(out, torch.Tensor):
            mem[m] += tensor_bytes(out.cpu())

    hooks = []
    for name, m in model.named_modules():
        if is_leaf(m):
            hooks.append(m.register_forward_pre_hook(pre_hook))
            hooks.append(m.register_forward_hook(fwd_hook))

    with torch.no_grad():
        for _ in range(5):
            model(x)

    for h in hooks:
        h.remove()

    for m in lat:
        lat[m] /= calls[m]
        mem[m] /= calls[m]

    return lat, mem, names

################################################################################
# 3. SENSITIVITY (L2 output difference from quant-dequant weights)
################################################################################
def quant_dequant(w, num_bits=8):
    qmax = 2**(num_bits-1) - 1
    scale = w.abs().max() / qmax
    return (w/scale).round().clamp(-qmax, qmax) * scale

def layer_sensitivity(model, module, sample):
    model.eval()
    sample = sample.to(DEVICE)

    orig_out = None
    def h1(m, i, o):
        nonlocal orig_out
        orig_out = o.detach().cpu()

    hook1 = module.register_forward_hook(h1)
    with torch.no_grad(): model(sample)
    hook1.remove()

    if orig_out is None:
        return float("inf")

    backup = {n: p.data.clone() for n, p in module.named_parameters(recurse=False)}

    for n, p in module.named_parameters(recurse=False):
        p.data = quant_dequant(p.data)

    pert_out = None
    def h2(m, i, o):
        nonlocal pert_out
        pert_out = o.detach().cpu()

    hook2 = module.register_forward_hook(h2)
    with torch.no_grad(): model(sample)
    hook2.remove()

    # restore
    for n, p in module.named_parameters(recurse=False):
        p.data.copy_(backup[n])

    if pert_out is None:
        return float("inf")

    diff = (orig_out - pert_out).view(-1)
    return torch.norm(diff).item() / orig_out.numel()

################################################################################
# 4. SELECT BEST LAYERS FOR QUANTIZATION
################################################################################
def select_layers(model, lat, mem, names, sample, K=10):
    scores = []

    for m in lat:
        if any(True for _ in m.parameters(recurse=False)):
            sens = layer_sensitivity(model, m, sample)
            score = (lat[m] + mem[m]) / (sens + 1e-9)
            scores.append((score, names[m]))

    scores.sort(reverse=True)
    selected = [name for _, name in scores[:K]]
    return selected, scores

################################################################################
# 5. EXPORT ONNX
################################################################################
def export_onnx(model):
    dummy = torch.randn(1,3,224,224).to(DEVICE)
    torch.onnx.export(model, dummy, "model.onnx",
                      input_names=["input"],
                      output_names=["output"],
                      opset_version=13,
                      dynamic_axes={"input": {0:"batch"},
                                    "output": {0:"batch"}})
    print("Exported model.onnx")

################################################################################
# 6. DUMMY CALIBRATION FOR STATIC QUANTIZATION
################################################################################
class DummyDataReader(CalibrationDataReader):
    def __init__(self):
        self.batch = {"input": np.random.randn(1,3,224,224).astype(np.float32)}
        self.done = False
    def get_next(self):
        if self.done:
            return None
        self.done = True
        return self.batch

################################################################################
# 7. RUN STATIC QUANTIZATION (WITH SELECTIVE LAYER LIST)
################################################################################
def selective_quantize_static(selected_nodes, output="model_int8.onnx"):
    dr = DummyDataReader()  # no need to load val dataset

    quantize_static(
        model_input="model.onnx",
        model_output=output,
        calibration_data_reader=dr,
        quant_format="QLinearOps",
        activation_type=QuantType.QUInt8,
        weight_type=QuantType.QInt8,
        nodes_to_quantize=selected_nodes
    )
    quantize_static(
    model_input="model.onnx",
    model_output="model_int8.onnx",
    calibration_data_reader=dr,
    quant_format=QuantFormat.QOperator,        # <--- IMPORTANT
    activation_type=QuantType.QUInt8,
    weight_type=QuantType.QInt8,
)

    print("Saved quantized model as:", output)

################################################################################
# 8. ONNX INFERENCE
################################################################################
def eval_onnx(path, loader):
    sess = onnxruntime.InferenceSession(path, providers=["CPUExecutionProvider"])
    name = sess.get_inputs()[0].name

    correct, total, lat = 0, 0, 0

    for x, y in loader:
        x = x.numpy()

        t0 = time.perf_counter()
        out = sess.run(None, {name: x})[0]
        lat += time.perf_counter() - t0

        pred = out.argmax(1)
        correct += (pred == y.numpy()).sum()
        total += len(y)

        if total > 200:  # speed up
            break

    return correct/total, lat/total

################################################################################
# 9. MAIN
################################################################################
def main():
    sample, _ = next(iter(dataloader))
    sample = sample[:8]

    print("Profiling model...")
    lat, mem, names = profile_layers(model, sample)

    print("Selecting layers...")
    selected, scores = select_layers(model, lat, mem, names, sample, K=12)
    print("Selected:", selected," total layers:",len(scores))

    print("Exporting ONNX...")
    export_onnx(model)

    print("Quantizing selectively (INT8 QLinear)...")
    selective_quantize_static(selected, "model_int8.onnx")

    print("Evaluating FP32...")
    acc_fp32, lat_fp32 = eval_onnx("model.onnx", dataloader)

    print("Evaluating INT8...")
    acc_int8, lat_int8 = eval_onnx("model_int8.onnx", dataloader)

    print("\n==============================")
    print(" FP32  : Acc =", acc_fp32, ", Lat =", lat_fp32*1000, "ms/img")
    print(" INT8  : Acc =", acc_int8, ", Lat =", lat_int8*1000, "ms/img")
    print("Size FP32 :", os.path.getsize("model.onnx")/1e6, "MB")
    print("Size INT8 :", os.path.getsize("model_int8.onnx")/1e6, "MB")
    print("==============================")

main()


Profiling model...
Selecting layers...
Selected: ['bn1', 'layer1.1.bn1', 'layer1.0.bn1', 'layer1.0.bn2', 'layer1.1.bn2', 'layer2.0.bn1', 'layer2.0.bn2', 'layer2.1.bn1', 'layer1.1.conv2', 'layer2.0.downsample.1', 'layer2.1.bn2', 'conv1']  total layers: 41
Exporting ONNX...


  torch.onnx.export(model, dummy, "model.onnx",
W1128 19:53:13.354000 193 torch/onnx/_internal/exporter/_compat.py:114] Setting ONNX exporter to use operator set version 18 because the requested opset_version 13 is a lower version than we have implementations for. Automatic version conversion will be performed, which may not be successful at converting to the requested version. If version conversion is unsuccessful, the opset version of the exported model will be kept at 18. Please consider setting opset_version >=18 to leverage latest ONNX features


[torch.onnx] Obtain model graph for `ResNet([...]` with `torch.export.export(..., strict=False)`...
[torch.onnx] Obtain model graph for `ResNet([...]` with `torch.export.export(..., strict=False)`... ✅
[torch.onnx] Run decomposition...


Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/onnxscript/version_converter/__init__.py", line 127, in call
    converted_proto = _c_api_utils.call_onnx_api(
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/onnxscript/version_converter/_c_api_utils.py", line 65, in call_onnx_api
    result = func(proto)
             ^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/onnxscript/version_converter/__init__.py", line 122, in _partial_convert_version
    return onnx.version_converter.convert_version(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/onnx/version_converter.py", line 39, in convert_version
    converted_model_str = C.convert_version(model_str, target_version)
                          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: /github/workspace/onnx/version_converter/adapters/axes_input_to_attribute.h:65: adapt: Asserti

[torch.onnx] Run decomposition... ✅
[torch.onnx] Translate the graph into ONNX...
[torch.onnx] Translate the graph into ONNX... ✅
Applied 40 of general pattern rewrite rules.




Exported model.onnx
Quantizing selectively (INT8 QLinear)...




ValueError: No data is collected.