# 4. Exporting Models to ONNX and Running Inference

This notebook demonstrates how to export the trained/optimized PyTorch models to the ONNX (Open Neural Network Exchange) format. It also shows how to run inference using ONNX Runtime on the exported models.

We will export three specific model versions:
1.  **Baseline FP32 Model**: The original MobileNetV2 model adapted for CIFAR-10.
2.  **Baseline QAT INT8 Model**: The baseline model quantized to INT8 using Quantization-Aware Training.
3.  **L1-Structured Pruned FP32 Model**: The FP32 model after L1 structured pruning.

In [25]:
import os
import torch
import numpy as np
import onnxruntime as ort

from nnopt.model.export import export_model_to_onnx
from nnopt.recipes.mobilenetv2_cifar10 import load_mobilenetv2_cifar10_model
from nnopt.model.prune import remove_pruning_reparameterization
from nnopt.model.const import BASE_MODEL_DIR, DEVICE

# Make log records from the export module visible in the notebook: if neither
# this logger nor any of its ancestors has a handler yet, set up root logging
# at INFO level with a stream handler.
import logging
logger = logging.getLogger("nnopt.model.export")
logging_already_configured = logger.hasHandlers()
if not logging_already_configured:
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
        handlers=[logging.StreamHandler()],
    )

# Record the library versions and the device selected by nnopt for this run.
for environment_line in (
    f"PyTorch version: {torch.__version__}",
    f"ONNX Runtime version: {ort.__version__}",
    f"Using device: {DEVICE}",
):
    print(environment_line)

PyTorch version: 2.7.1+cu126
ONNX Runtime version: 1.22.0
Using device: cuda


## Configuration and Dummy Input

In [None]:
# Version identifiers of the saved models that this notebook exports.
baseline_fp32_version = "mobilenetv2_cifar10/fp32/baseline"
pruned_fp32_version = "mobilenetv2_cifar10/fp32/l1_struct_prune_0.5"
qat_int8_version = "mobilenetv2_cifar10/int8/qat_baseline"

# All ONNX files are written to a dedicated sub-directory of BASE_MODEL_DIR.
ONNX_EXPORT_DIR = os.path.join(BASE_MODEL_DIR, "onnx_exports")
os.makedirs(ONNX_EXPORT_DIR, exist_ok=True)

# Random example input in (batch_size, channels, height, width) layout, kept on
# the CPU because the models are moved to the CPU before export.
# NOTE(review): 224x224 is used here although CIFAR-10 images are 32x32 --
# presumably the data pipeline resizes them; confirm against the dataset code.
dummy_input_shape = (1, 3, 224, 224)
dummy_input = torch.randn(*dummy_input_shape, device=torch.device("cpu"))

# Mark dimension 0 of both the graph input and output as a dynamic batch axis.
dynamic_axes = {
    tensor_name: {0: 'batch_size'} for tensor_name in ('input', 'output')
}

print(f"ONNX models will be saved in: {ONNX_EXPORT_DIR}")
print(f"Dummy input shape: {dummy_input.shape}")

ONNX models will be saved in: /home/pbeuran/repos/nnopt/models/onnx_exports
Dummy input shape: torch.Size([1, 3, 224, 224])


## 1. Baseline FP32 Model Export & Inference

In [27]:
# Load the JIT-traced baseline FP32 model, switch it to eval mode and move it
# to the CPU so it matches the CPU dummy input used for export.
print(f"Loading baseline FP32 model from version: {baseline_fp32_version}")
baseline_fp32_model, _ = load_mobilenetv2_cifar10_model(version=baseline_fp32_version, mode="jit_trace")
baseline_fp32_model.eval()
baseline_fp32_model.to('cpu')

# Destination file for the exported graph.
onnx_path_baseline_fp32 = os.path.join(
    ONNX_EXPORT_DIR, "mobilenetv2_cifar10_baseline_fp32.onnx"
)

# Export with a dynamic batch axis; the helper's return value is used as a
# success flag by this cell and by the evaluation cell.
print(f"Exporting baseline FP32 model to {onnx_path_baseline_fp32}...")
success_fp32 = export_model_to_onnx(
    model=baseline_fp32_model,
    dummy_input=dummy_input,
    onnx_path=onnx_path_baseline_fp32,
    dynamic_axes=dynamic_axes,
    opset_version=13,
)

if not success_fp32:
    print("Baseline FP32 model export failed.")
else:
    print("Baseline FP32 model exported successfully.")
    # Smoke test: push the dummy input through ONNX Runtime on the CPU and
    # report the output shape. Failures are printed instead of raised so the
    # rest of the notebook can still run.
    try:
        ort_session_fp32 = ort.InferenceSession(
            onnx_path_baseline_fp32, providers=['CPUExecutionProvider']
        )
        input_name_fp32 = ort_session_fp32.get_inputs()[0].name
        output_name_fp32 = ort_session_fp32.get_outputs()[0].name

        ort_inputs_fp32 = {input_name_fp32: dummy_input.cpu().numpy()}
        ort_outputs_fp32 = ort_session_fp32.run([output_name_fp32], ort_inputs_fp32)
        print(f"ONNX Runtime (FP32 Baseline) output shape: {ort_outputs_fp32[0].shape}")
    except Exception as ort_error:
        print(f"Error running ONNX Runtime for FP32 baseline model: {ort_error}")

2025-06-12 13:58:44,386 - nnopt.recipes.mobilenetv2_cifar10 - INFO - Loading MobileNetV2 model for CIFAR-10 from version: mobilenetv2_cifar10/fp32/baseline at /home/pbeuran/repos/nnopt/models
2025-06-12 13:58:44,387 - nnopt.recipes.mobilenetv2_cifar10 - INFO - Loaded metadata: {'metrics_values': {'val_metrics': {'accuracy': 0.9254, 'avg_loss': 0.21455251355171204, 'samples_per_second': 9594.238581494608, 'avg_time_per_batch': 0.006596786050597328, 'avg_time_per_sample': 0.00010422921959943779, 'params_stats': {'int_weight_params': 0, 'float_weight_params': 2202560, 'float_bias_params': 10, 'bn_param_params': 34112, 'other_float_params': 0, 'total_params': 2236682, 'approx_memory_mb_for_params': 8.532264709472656}}, 'test_metrics': {'accuracy': 0.9288, 'avg_loss': 0.20640371625423432, 'samples_per_second': 9116.575566348814, 'avg_time_per_batch': 0.006986643948966146, 'avg_time_per_sample': 0.00010969030999876849, 'params_stats': {'int_weight_params': 0, 'float_weight_params': 2202560, 

Loading baseline FP32 model from version: mobilenetv2_cifar10/fp32/baseline
Exporting baseline FP32 model to /home/pbeuran/repos/nnopt/models/onnx_exports/mobilenetv2_cifar10_baseline_fp32.onnx...


2025-06-12 13:58:44,705 - nnopt.model.export - INFO - Model successfully exported to /home/pbeuran/repos/nnopt/models/onnx_exports/mobilenetv2_cifar10_baseline_fp32.onnx


Baseline FP32 model exported successfully.
ONNX Runtime (FP32 Baseline) output shape: (1, 10)


### Evaluate Baseline FP32 ONNX Model

In [28]:
# Evaluation helper and dataset loader used by this and the later evaluation cells.
from nnopt.model.eval import eval_onnx_model
from nnopt.recipes.mobilenetv2_cifar10 import get_cifar10_datasets

# Load the CIFAR-10 datasets from the loader's default location and keep the
# one used for evaluation. `test_dataset` is reused by every later evaluation
# and analysis cell.
# NOTE(review): the *second* element of the returned tuple is taken here, while
# the loader logs "training and validation" before "test" and the commented-out
# snippet in the Analysis cell unpacks `_, _, test_dataset`. This may therefore
# be the validation split -- confirm the return order of get_cifar10_datasets().
_, test_dataset, _ = get_cifar10_datasets()

if success_fp32:  # Only evaluate when the baseline export succeeded
    print("\n--- Evaluating Baseline FP32 ONNX Model on CPU ---")
    # `onnx_metrics_cpu` is read again by the Analysis cell as the baseline
    # "ONNX FP32" result.
    onnx_metrics_cpu = eval_onnx_model(
        onnx_model_path=onnx_path_baseline_fp32,
        test_dataset=test_dataset,
        batch_size=32, # Adjust as needed
        device="cpu",
        num_warmup_batches=2 # Smaller warmup for quicker testing
    )
    print(f"CPU ONNX Metrics: {onnx_metrics_cpu}")

    # The GPU pass needs both a CUDA device visible to PyTorch and an
    # ONNX Runtime build that reports GPU support.
    if torch.cuda.is_available() and ort.get_device() == 'GPU':
        print("\n--- Evaluating Baseline FP32 ONNX Model on GPU ---")
        onnx_metrics_gpu = eval_onnx_model(
            onnx_model_path=onnx_path_baseline_fp32,
            test_dataset=test_dataset,
            batch_size=32, # Adjust as needed
            device="cuda",
            num_warmup_batches=2
        )
        print(f"GPU ONNX Metrics: {onnx_metrics_gpu}")
    else:
        print("\nSkipping GPU ONNX evaluation as CUDA is not available or ONNX Runtime GPU provider is not set up.")
else:
    print("\nSkipping ONNX model evaluation as the export failed.")

2025-06-12 13:58:47,085 - nnopt.recipes.mobilenetv2_cifar10 - INFO - Loading existing training and validation datasets...
2025-06-12 13:58:51,942 - nnopt.recipes.mobilenetv2_cifar10 - INFO - Loading existing test dataset...
2025-06-12 13:58:52,097 - nnopt.model.eval - INFO - Starting ONNX model evaluation for: /home/pbeuran/repos/nnopt/models/onnx_exports/mobilenetv2_cifar10_baseline_fp32.onnx
2025-06-12 13:58:52,097 - nnopt.model.eval - INFO - Evaluation on PyTorch device: cpu, batch size: 32
2025-06-12 13:58:52,098 - nnopt.model.eval - INFO - Using ONNX Runtime providers: ['CPUExecutionProvider']
2025-06-12 13:58:52,124 - nnopt.model.eval - INFO - ONNX Model Input Name: input, Output Name: output
2025-06-12 13:58:52,124 - nnopt.model.eval - INFO - Starting warmup for 2 batches...



--- Evaluating Baseline FP32 ONNX Model on CPU ---


[ONNX Warmup]: 100%|██████████| 2/2 [00:00<00:00,  6.34it/s]
2025-06-12 13:58:52,503 - nnopt.model.eval - INFO - Warmup complete.
2025-06-12 13:58:52,503 - nnopt.model.eval - INFO - Starting ONNX model evaluation pass...
[ONNX Evaluation]: 100%|██████████| 157/157 [00:13<00:00, 11.43it/s]


ONNX Evaluation Complete: Avg Loss: 0.2173, Accuracy: 0.9248
Throughput: 465.50 samples/sec | Avg Batch Time: 68.42 ms | Avg Sample Time: 2.15 ms
System Stats (PyTorch side): CPU Usage: 4.90% | RAM Usage: 9.2/30.9GB (38.7%)
CPU ONNX Metrics: {'accuracy': 0.9248, 'avg_loss': 0.21728210688829422, 'samples_per_second': 465.4980484405136, 'avg_time_per_batch': 0.06841518211441343, 'avg_time_per_sample': 0.0021482367183925815}

Skipping GPU ONNX evaluation as CUDA is not available or ONNX Runtime GPU provider is not set up.


## 2. Baseline QAT INT8 Model Export & Inference

In [29]:
# Load the JIT-traced QAT INT8 model directly onto the CPU, put it in eval
# mode and keep it on the CPU for export.
print(f"Loading Baseline QAT INT8 model from version: {qat_int8_version}")
qat_int8_model, _ = load_mobilenetv2_cifar10_model(version=qat_int8_version, device='cpu', mode="jit_trace")
qat_int8_model.eval()
qat_int8_model.to('cpu')

# Destination file for the exported graph.
onnx_path_qat_int8 = os.path.join(
    ONNX_EXPORT_DIR, "mobilenetv2_cifar10_qat_int8.onnx"
)

# Export with the same FP32 dummy input, dynamic batch axis and opset (13) as
# the FP32 model; the helper's return value is used as a success flag.
print(f"Exporting Baseline QAT INT8 model to {onnx_path_qat_int8}...")
success_qat_int8 = export_model_to_onnx(
    model=qat_int8_model,
    dummy_input=dummy_input,
    onnx_path=onnx_path_qat_int8,
    dynamic_axes=dynamic_axes,
    opset_version=13,
)

if not success_qat_int8:
    print("Baseline QAT INT8 model export failed.")
else:
    print("Baseline QAT INT8 model exported successfully.")
    # Smoke test on the CPU provider: feed the FP32 dummy input and report the
    # output shape. Failures are printed instead of raised.
    try:
        ort_session_qat_int8 = ort.InferenceSession(
            onnx_path_qat_int8, providers=['CPUExecutionProvider']
        )
        input_name_qat_int8 = ort_session_qat_int8.get_inputs()[0].name
        output_name_qat_int8 = ort_session_qat_int8.get_outputs()[0].name

        ort_inputs_qat_int8 = {input_name_qat_int8: dummy_input.cpu().numpy()}
        ort_outputs_qat_int8 = ort_session_qat_int8.run([output_name_qat_int8], ort_inputs_qat_int8)
        print(f"ONNX Runtime (QAT INT8) output shape: {ort_outputs_qat_int8[0].shape}")
    except Exception as ort_error:
        print(f"Error running ONNX Runtime for QAT INT8 model: {ort_error}")

2025-06-12 13:59:06,287 - nnopt.recipes.mobilenetv2_cifar10 - INFO - Loading MobileNetV2 model for CIFAR-10 from version: mobilenetv2_cifar10/int8/qat_baseline at /home/pbeuran/repos/nnopt/models
2025-06-12 13:59:06,288 - nnopt.recipes.mobilenetv2_cifar10 - INFO - Loaded metadata: {'metrics_values': {'accuracy': 0.8008, 'avg_loss': 0.6162318737983704, 'samples_per_second': 781.2681917092245, 'avg_time_per_batch': 0.04076338201903243, 'avg_time_per_sample': 0.0012799701953976183, 'params_stats': {'int_weight_params': 2202560, 'float_weight_params': 0, 'float_bias_params': 17066, 'bn_param_params': 0, 'other_float_params': 0, 'total_params': 2219626, 'approx_memory_mb_for_params': 2.1656265258789062}}}
2025-06-12 13:59:06,288 - nnopt.recipes.mobilenetv2_cifar10 - INFO - Loading JIT traced model from /home/pbeuran/repos/nnopt/models/mobilenetv2_cifar10/int8/qat_baseline/jit_trace.pt
2025-06-12 13:59:06,437 - nnopt.recipes.mobilenetv2_cifar10 - INFO - Successfully loaded JIT traced model f

Loading Baseline QAT INT8 model from version: mobilenetv2_cifar10/int8/qat_baseline
Exporting Baseline QAT INT8 model to /home/pbeuran/repos/nnopt/models/onnx_exports/mobilenetv2_cifar10_qat_int8.onnx...


2025-06-12 13:59:06,817 - nnopt.model.export - INFO - Model successfully exported to /home/pbeuran/repos/nnopt/models/onnx_exports/mobilenetv2_cifar10_qat_int8.onnx


Baseline QAT INT8 model exported successfully.
ONNX Runtime (QAT INT8) output shape: (1, 10)


### Evaluate Baseline QAT INT8 ONNX Model

In [30]:
# Requires `test_dataset` and `eval_onnx_model` from the baseline evaluation cell.

if not success_qat_int8:
    print("\nSkipping ONNX model evaluation for QAT INT8 model as the export failed.")
else:
    # Arguments shared by the CPU and GPU evaluation runs.
    qat_int8_eval_kwargs = dict(
        onnx_model_path=onnx_path_qat_int8,
        test_dataset=test_dataset,
        batch_size=32,
        num_warmup_batches=2,
    )

    print("\n--- Evaluating Baseline QAT INT8 ONNX Model on CPU ---")
    # `onnx_metrics_qat_int8_cpu` is read again by the Analysis cell.
    onnx_metrics_qat_int8_cpu = eval_onnx_model(device="cpu", **qat_int8_eval_kwargs)
    print(f"CPU ONNX Metrics (QAT INT8): {onnx_metrics_qat_int8_cpu}")

    # Optional GPU run. It is wrapped in try/except because running the INT8
    # graph on the GPU may fail; any error is reported rather than raised.
    gpu_runtime_available = torch.cuda.is_available() and ort.get_device() == 'GPU'
    if not gpu_runtime_available:
        print("\nSkipping GPU ONNX evaluation for QAT INT8 model as CUDA is not available or ONNX Runtime GPU provider is not set up.")
    else:
        print("\n--- Evaluating Baseline QAT INT8 ONNX Model on GPU (Experimental) ---")
        try:
            onnx_metrics_qat_int8_gpu = eval_onnx_model(device="cuda", **qat_int8_eval_kwargs)
            print(f"GPU ONNX Metrics (QAT INT8): {onnx_metrics_qat_int8_gpu}")
        except Exception as gpu_error:
            print(f"Could not run QAT INT8 ONNX model on GPU: {gpu_error}")
            print("This might be due to operator support or other configuration issues.")

2025-06-12 13:59:06,868 - nnopt.model.eval - INFO - Starting ONNX model evaluation for: /home/pbeuran/repos/nnopt/models/onnx_exports/mobilenetv2_cifar10_qat_int8.onnx
2025-06-12 13:59:06,870 - nnopt.model.eval - INFO - Evaluation on PyTorch device: cpu, batch size: 32
2025-06-12 13:59:06,871 - nnopt.model.eval - INFO - Using ONNX Runtime providers: ['CPUExecutionProvider']
2025-06-12 13:59:06,902 - nnopt.model.eval - INFO - ONNX Model Input Name: input, Output Name: output
2025-06-12 13:59:06,903 - nnopt.model.eval - INFO - Starting warmup for 2 batches...



--- Evaluating Baseline QAT INT8 ONNX Model on CPU ---


[ONNX Warmup]: 100%|██████████| 2/2 [00:00<00:00,  7.54it/s]
2025-06-12 13:59:07,233 - nnopt.model.eval - INFO - Warmup complete.
2025-06-12 13:59:07,234 - nnopt.model.eval - INFO - Starting ONNX model evaluation pass...
[ONNX Evaluation]: 100%|██████████| 157/157 [00:08<00:00, 18.90it/s]

ONNX Evaluation Complete: Avg Loss: 0.6166, Accuracy: 0.7980
Throughput: 926.23 samples/sec | Avg Batch Time: 34.38 ms | Avg Sample Time: 1.08 ms
System Stats (PyTorch side): CPU Usage: 89.00% | RAM Usage: 9.0/30.9GB (38.0%)
CPU ONNX Metrics (QAT INT8): {'accuracy': 0.798, 'avg_loss': 0.6165598166465759, 'samples_per_second': 926.2332375917704, 'avg_time_per_batch': 0.03438349269430789, 'avg_time_per_sample': 0.0010796416706012678}

Skipping GPU ONNX evaluation for QAT INT8 model as CUDA is not available or ONNX Runtime GPU provider is not set up.





## 3. L1-Structured Pruned FP32 Model Export & Inference

In [33]:
# Load the JIT-traced L1-structured-pruned FP32 model, switch it to eval mode
# and move it to the CPU so it matches the CPU dummy input used for export.
print(f"Loading l1 structured pruned FP32 model from version: {pruned_fp32_version}")
pruned_fp32_model, _ = load_mobilenetv2_cifar10_model(version=pruned_fp32_version, mode="jit_trace")
pruned_fp32_model.eval()
pruned_fp32_model.to('cpu')

# Destination file for the exported graph.
onnx_path_pruned_fp32 = os.path.join(
    ONNX_EXPORT_DIR, "mobilenetv2_cifar10_pruned_fp32.onnx"
)

# Export with a dynamic batch axis.
# NOTE(review): `success_fp32` (and the `*_fp32` session/input/output names
# below) are the same names the baseline export cell assigns, so running this
# cell overwrites the baseline values. The pruned evaluation cell reads
# `success_fp32`, which is why the name is kept here.
print(f"Exporting pruned FP32 model to {onnx_path_pruned_fp32}...")
success_fp32 = export_model_to_onnx(
    model=pruned_fp32_model,
    dummy_input=dummy_input,
    onnx_path=onnx_path_pruned_fp32,
    dynamic_axes=dynamic_axes,
    opset_version=13,
)

if not success_fp32:
    print("Pruned FP32 model export failed.")
else:
    print("Pruned FP32 model exported successfully.")
    # Smoke test: push the dummy input through ONNX Runtime on the CPU and
    # report the output shape. Failures are printed instead of raised.
    try:
        ort_session_fp32 = ort.InferenceSession(
            onnx_path_pruned_fp32, providers=['CPUExecutionProvider']
        )
        input_name_fp32 = ort_session_fp32.get_inputs()[0].name
        output_name_fp32 = ort_session_fp32.get_outputs()[0].name

        ort_inputs_fp32 = {input_name_fp32: dummy_input.cpu().numpy()}
        ort_outputs_fp32 = ort_session_fp32.run([output_name_fp32], ort_inputs_fp32)
        print(f"ONNX Runtime (FP32 Pruned) output shape: {ort_outputs_fp32[0].shape}")
    except Exception as ort_error:
        print(f"Error running ONNX Runtime for FP32 pruned model: {ort_error}")

2025-06-12 14:00:08,666 - nnopt.recipes.mobilenetv2_cifar10 - INFO - Loading MobileNetV2 model for CIFAR-10 from version: mobilenetv2_cifar10/fp32/l1_struct_prune_0.1 at /home/pbeuran/repos/nnopt/models
2025-06-12 14:00:08,667 - nnopt.recipes.mobilenetv2_cifar10 - INFO - Loaded metadata: {'metrics_values': {'val_metrics': {'accuracy': 0.8706, 'avg_loss': 0.3692050254821777, 'samples_per_second': 7585.646960704619, 'avg_time_per_batch': 0.008343538734186928, 'avg_time_per_sample': 0.00013182791200015345, 'params_stats': {'int_weight_params': 0, 'float_weight_params': 1456640, 'float_bias_params': 10, 'bn_param_params': 27802, 'other_float_params': 0, 'total_params': 1484452, 'approx_memory_mb_for_params': 5.6627349853515625}}, 'test_metrics': {'accuracy': 0.8649, 'avg_loss': 0.3818217437744141, 'samples_per_second': 7764.16992856416, 'avg_time_per_batch': 0.00820361585358844, 'avg_time_per_sample': 0.00012879676890133852, 'params_stats': {'int_weight_params': 0, 'float_weight_params': 1

Loading l1 structured pruned FP32 model from version: mobilenetv2_cifar10/fp32/l1_struct_prune_0.1
Exporting pruned FP32 model to /home/pbeuran/repos/nnopt/models/onnx_exports/mobilenetv2_cifar10_pruned_fp32.onnx...


2025-06-12 14:00:08,980 - nnopt.model.export - INFO - Model successfully exported to /home/pbeuran/repos/nnopt/models/onnx_exports/mobilenetv2_cifar10_pruned_fp32.onnx


Pruned FP32 model exported successfully.
ONNX Runtime (FP32 Pruned) output shape: (1, 10)


### Evaluate Pruned FP32 ONNX Model

In [34]:
# Evaluate the exported pruned FP32 ONNX model.
# Requires `test_dataset` and `eval_onnx_model` from the baseline evaluation
# cell, and `success_fp32` / `onnx_path_pruned_fp32` from the pruned export cell.
#
# The results are stored under pruned-specific names. The Analysis cell reads
# `onnx_metrics_cpu` as the *baseline* FP32 ONNX metrics, so that variable must
# not be reassigned here.
if success_fp32: # Only proceed if the pruned ONNX model was exported successfully
    print("\n--- Evaluating Pruned FP32 ONNX Model on CPU ---")
    onnx_metrics_pruned_fp32_cpu = eval_onnx_model(
        onnx_model_path=onnx_path_pruned_fp32,
        test_dataset=test_dataset,
        batch_size=32, # Adjust as needed
        device="cpu",
        num_warmup_batches=2 # Smaller warmup for quicker testing
    )
    print(f"CPU ONNX Metrics: {onnx_metrics_pruned_fp32_cpu}")

    # The GPU pass needs both a CUDA device visible to PyTorch and an
    # ONNX Runtime build that reports GPU support.
    if torch.cuda.is_available() and ort.get_device() == 'GPU':
        print("\n--- Evaluating Pruned FP32 ONNX Model on GPU ---")
        onnx_metrics_pruned_fp32_gpu = eval_onnx_model(
            # Evaluate the pruned export here, not the baseline file.
            onnx_model_path=onnx_path_pruned_fp32,
            test_dataset=test_dataset,
            batch_size=32, # Adjust as needed
            device="cuda",
            num_warmup_batches=2
        )
        print(f"GPU ONNX Metrics: {onnx_metrics_pruned_fp32_gpu}")
    else:
        print("\nSkipping GPU ONNX evaluation as CUDA is not available or ONNX Runtime GPU provider is not set up.")
else:
    print("\nSkipping ONNX model evaluation as the export failed.")

2025-06-12 14:03:10,363 - nnopt.model.eval - INFO - Starting ONNX model evaluation for: /home/pbeuran/repos/nnopt/models/onnx_exports/mobilenetv2_cifar10_pruned_fp32.onnx
2025-06-12 14:03:10,364 - nnopt.model.eval - INFO - Evaluation on PyTorch device: cpu, batch size: 32
2025-06-12 14:03:10,364 - nnopt.model.eval - INFO - Using ONNX Runtime providers: ['CPUExecutionProvider']
2025-06-12 14:03:10,381 - nnopt.model.eval - INFO - ONNX Model Input Name: input, Output Name: output
2025-06-12 14:03:10,382 - nnopt.model.eval - INFO - Starting warmup for 2 batches...



--- Evaluating Pruned FP32 ONNX Model on CPU ---


[ONNX Warmup]: 100%|██████████| 2/2 [00:00<00:00,  3.49it/s]
2025-06-12 14:03:11,019 - nnopt.model.eval - INFO - Warmup complete.
2025-06-12 14:03:11,019 - nnopt.model.eval - INFO - Starting ONNX model evaluation pass...
[ONNX Evaluation]:  36%|███▋      | 57/157 [00:11<00:20,  4.90it/s]


KeyboardInterrupt: 

## Analysis

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import os
from nnopt.model.eval import eval_model # For PyTorch model evaluation
from nnopt.recipes.mobilenetv2_cifar10 import init_mobilenetv2_cifar10_model, get_cifar10_datasets # To load models and dataset

# Collects accuracy, CPU latency and size figures for the PyTorch and ONNX
# variants of the baseline FP32 and QAT INT8 models; the plotting cells below
# read `model_labels`, `accuracies`, `cpu_time_per_sample` and `model_sizes_mb`.
#
# `test_dataset` must already exist (it is created in the
# "Evaluate Baseline FP32 ONNX Model" cell).

# The ONNX metrics are produced by the earlier ONNX evaluation cells. Resolve
# them once, up front, so that a skipped evaluation cell leads to placeholder
# values plus the warning at the end of this cell instead of a NameError.
# (At notebook top level, globals() is the namespace those cells assign into.)
# `onnx_metrics_cpu` must still hold the *baseline* FP32 ONNX metrics here --
# verify that no later cell has reassigned it.
onnx_fp32_metrics = globals().get('onnx_metrics_cpu')
onnx_qat_int8_metrics = globals().get('onnx_metrics_qat_int8_cpu')

# --- 1. PyTorch FP32 Baseline Model Evaluation (CPU) ---
print("Evaluating PyTorch FP32 Baseline Model on CPU for analysis...")
pytorch_fp32_model, _ = init_mobilenetv2_cifar10_model(
    version=baseline_fp32_version, # Defined in the configuration cell
    device='cpu'
)
pytorch_fp32_model.eval()
pytorch_fp32_metrics_cpu = eval_model(
    model=pytorch_fp32_model,
    test_dataset=test_dataset,
    device="cpu",
    use_amp=False, # No AMP for CPU FP32
    dtype=torch.float32,
    batch_size=32, # Same batch size as the ONNX evaluations
    num_warmup_batches=2
)
print(f"PyTorch FP32 CPU Metrics: {pytorch_fp32_metrics_cpu}")

# --- 2. PyTorch QAT INT8 Baseline Model Evaluation (CPU) ---
print("\nEvaluating PyTorch QAT INT8 Baseline Model on CPU for analysis...")
pytorch_qat_int8_model, _ = init_mobilenetv2_cifar10_model(
    version=qat_int8_version, # Defined in the configuration cell
    device='cpu'
)
pytorch_qat_int8_model.eval()
pytorch_qat_int8_metrics_cpu = eval_model(
    model=pytorch_qat_int8_model,
    test_dataset=test_dataset,
    device="cpu",
    use_amp=False,
    dtype=torch.float32,
    batch_size=32,
    num_warmup_batches=2
)
print(f"PyTorch QAT INT8 CPU Metrics: {pytorch_qat_int8_metrics_cpu}")

# --- 3. ONNX Model Metrics (from the earlier evaluation cells) ---
# Prints None for a metric whose evaluation cell has not been run.
print(f"\nUsing pre-calculated ONNX FP32 CPU Metrics: {onnx_fp32_metrics}")
print(f"Using pre-calculated ONNX QAT INT8 CPU Metrics: {onnx_qat_int8_metrics}")


# --- 4. Model Sizes ---
# PyTorch FP32: use the parameter-memory approximation reported by eval_model
# (0 if the model reports no parameters).
fp32_params_stats = pytorch_fp32_metrics_cpu['params_stats']
pytorch_fp32_size_mb = (
    fp32_params_stats['approx_memory_mb_for_params']
    if fp32_params_stats['total_params'] > 0
    else 0
)

# PyTorch QAT INT8: measure the size of the serialized state_dict instead of
# relying on the parameter-based approximation. The temporary file is removed
# even if saving or measuring fails.
qat_state_dict_path = "temp_qat_int8_model.pth"
try:
    torch.save(pytorch_qat_int8_model.state_dict(), qat_state_dict_path)
    pytorch_qat_int8_size_mb = os.path.getsize(qat_state_dict_path) / (1024 * 1024)
finally:
    if os.path.exists(qat_state_dict_path):
        os.remove(qat_state_dict_path)
print(f"PyTorch QAT INT8 Model Size (saved state_dict): {pytorch_qat_int8_size_mb:.2f} MB")


# ONNX model file sizes (0 if the exported file does not exist).
# `onnx_path_baseline_fp32` and `onnx_path_qat_int8` come from the export cells.
onnx_fp32_size_mb = os.path.getsize(onnx_path_baseline_fp32) / (1024 * 1024) if os.path.exists(onnx_path_baseline_fp32) else 0
onnx_qat_int8_size_mb = os.path.getsize(onnx_path_qat_int8) / (1024 * 1024) if os.path.exists(onnx_path_qat_int8) else 0
print(f"ONNX FP32 Model Size: {onnx_fp32_size_mb:.2f} MB")
print(f"ONNX QAT INT8 Model Size: {onnx_qat_int8_size_mb:.2f} MB")


# --- 5. Prepare data for plotting ---
# The four lists below are index-aligned with `model_labels`.
model_labels = [
    "PyTorch FP32",
    "PyTorch QAT INT8",
    "ONNX FP32",
    "ONNX QAT INT8"
]

# Accuracy on `test_dataset`; 0 is the placeholder for a missing ONNX result.
accuracies = [
    pytorch_fp32_metrics_cpu['accuracy'],
    pytorch_qat_int8_metrics_cpu['accuracy'],
    onnx_fp32_metrics['accuracy'] if onnx_fp32_metrics else 0,
    onnx_qat_int8_metrics['accuracy'] if onnx_qat_int8_metrics else 0
]

# CPU inference time per sample in seconds; infinity is the placeholder for a
# missing ONNX result.
cpu_time_per_sample = [
    pytorch_fp32_metrics_cpu['avg_time_per_sample'],
    pytorch_qat_int8_metrics_cpu['avg_time_per_sample'],
    onnx_fp32_metrics['avg_time_per_sample'] if onnx_fp32_metrics else float('inf'),
    onnx_qat_int8_metrics['avg_time_per_sample'] if onnx_qat_int8_metrics else float('inf')
]

# Model sizes in MB
model_sizes_mb = [
    pytorch_fp32_size_mb,
    pytorch_qat_int8_size_mb, # Saved state_dict size
    onnx_fp32_size_mb,
    onnx_qat_int8_size_mb
]

print("\nData for plotting:")
print(f"Labels: {model_labels}")
print(f"Accuracies: {accuracies}")
print(f"CPU Time/Sample (s): {cpu_time_per_sample}")
print(f"Model Sizes (MB): {model_sizes_mb}")

# Warn when either ONNX result is missing, since the plots then contain the
# placeholder values chosen above.
if not (onnx_fp32_metrics and onnx_qat_int8_metrics):
    print("\nWARNING: ONNX metrics might not be fully loaded. Plots might be incomplete or show zero/infinity values.")
    print("Please ensure the cells evaluating ONNX models (d52e6d98, 8b57290a) have been run successfully.")


In [None]:
# Bar chart comparing the accuracy of every model variant (CPU evaluation).
# `x` and `width` defined here are reused by the two plotting cells below.
x = np.arange(len(model_labels))
width = 0.5  # one bar per model

fig, ax = plt.subplots(figsize=(10, 6))
rects = ax.bar(x, accuracies, width, label='Test Accuracy (CPU)')

ax.set_ylabel('Accuracy')
ax.set_title('Model Test Accuracy Comparison (CPU)')
ax.set_xticks(x)
ax.set_xticklabels(model_labels, rotation=45, ha="right")
ax.legend()
ax.grid(True, linestyle='--', alpha=0.7)

# Zoom the y-axis to roughly +/-10% around the observed range; fall back to
# 0 / 1 when the corresponding bound is not positive.
lowest_accuracy, highest_accuracy = min(accuracies), max(accuracies)
y_lower = lowest_accuracy * 0.9 if lowest_accuracy > 0 else 0
y_upper = highest_accuracy * 1.1 if highest_accuracy > 0 else 1
ax.set_ylim(y_lower, y_upper)

def autolabel(rects_to_label, ax_to_use):
    """Write each bar's height (four decimals) just above the bar.

    rects_to_label: the bar artists to annotate (e.g. the result of ``ax.bar``).
    ax_to_use: the axes on which the annotations are drawn.
    """
    for bar in rects_to_label:
        bar_height = bar.get_height()
        bar_center_x = bar.get_x() + bar.get_width() / 2
        ax_to_use.annotate(
            f'{bar_height:.4f}',
            xy=(bar_center_x, bar_height),
            xytext=(0, 3),  # shift the text 3 points upwards
            textcoords="offset points",
            ha='center', va='bottom', fontsize=9,
        )

autolabel(rects, ax)

fig.tight_layout()
plt.show()

In [None]:
# Bar chart of CPU inference time per sample for every model variant.
# Relies on `x` and `width` from the accuracy-plot cell.
fig, ax = plt.subplots(figsize=(10, 6))

# The metrics are in seconds; plot milliseconds for readability.
cpu_time_per_sample_ms = [seconds * 1000 for seconds in cpu_time_per_sample]
rects = ax.bar(x, cpu_time_per_sample_ms, width, label='CPU Time/Sample (ms)')

ax.set_ylabel('CPU Time/Sample (milliseconds)')
ax.set_title('Model Inference Time Comparison (CPU)')
ax.set_xticks(x)
ax.set_xticklabels(model_labels, rotation=45, ha="right")
ax.legend()
ax.grid(True, linestyle='--', alpha=0.7)
# If the timings differ by orders of magnitude, a log scale
# (ax.set_yscale('log')) can make the chart easier to read.

def autolabel_time(rects_to_label, ax_to_use):
    """Write each bar's height (three decimals) just above the bar.

    rects_to_label: the bar artists to annotate (e.g. the result of ``ax.bar``).
    ax_to_use: the axes on which the annotations are drawn.
    """
    for bar in rects_to_label:
        bar_height = bar.get_height()
        bar_center_x = bar.get_x() + bar.get_width() / 2
        ax_to_use.annotate(
            f'{bar_height:.3f}',
            xy=(bar_center_x, bar_height),
            xytext=(0, 3),  # shift the text 3 points upwards
            textcoords="offset points",
            ha='center', va='bottom', fontsize=9,
        )

autolabel_time(rects, ax)

fig.tight_layout()
plt.show()

In [None]:
# Bar chart of the model sizes (MB) for every model variant.
# Relies on `x` and `width` from the accuracy-plot cell.
fig, ax = plt.subplots(figsize=(10, 6))
rects = ax.bar(x, model_sizes_mb, width, label='Model Size (MB)')

ax.set_ylabel('Model Size (MB)')
ax.set_title('Model Size Comparison')
ax.set_xticks(x)
ax.set_xticklabels(model_labels, rotation=45, ha="right")
ax.legend()
ax.grid(True, linestyle='--', alpha=0.7)

def autolabel_size(rects_to_label, ax_to_use):
    """Write each bar's height (two decimals) just above the bar.

    rects_to_label: the bar artists to annotate (e.g. the result of ``ax.bar``).
    ax_to_use: the axes on which the annotations are drawn.
    """
    for bar in rects_to_label:
        bar_height = bar.get_height()
        bar_center_x = bar.get_x() + bar.get_width() / 2
        ax_to_use.annotate(
            f'{bar_height:.2f}',
            xy=(bar_center_x, bar_height),
            xytext=(0, 3),  # shift the text 3 points upwards
            textcoords="offset points",
            ha='center', va='bottom', fontsize=9,
        )

autolabel_size(rects, ax)

fig.tight_layout()
plt.show()