In [1]:
# Cell 1: Imports and Setup
# ==========================
import torch
import torchvision
# Import the specific quantization models module
import torchvision.models.quantization as models_quant
from torchvision.models import ResNet18_Weights # Use the modern weights API
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import os
import copy
import time
import numpy as np

print(f"PyTorch Version: {torch.__version__}")
print(f"Torchvision Version: {torchvision.__version__}")

# Check for CUDA availability (optional but good practice)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}") # Quantization primarily targets CPU

# Cell 2: Load Original FP32 Model (for weights and comparison)
# ==============================================================
def load_original_model():
    """Loads the pre-trained FP32 ResNet18 model."""
    weights = ResNet18_Weights.DEFAULT # Loads the best available weights (ImageNet V1)
    model = torchvision.models.resnet18(weights=weights)
    model.eval() # Set model to evaluation mode
    model.cpu() # Ensure it's on CPU
    print("Original FP32 ResNet18 model loaded and moved to CPU.")
    # Get the transformation pipeline associated with the weights
    preprocess = weights.transforms()
    print("Preprocessing transforms for model loaded.")
    return model, preprocess

# Load the original floating-point model
fp32_model, preprocess = load_original_model()

# Cell 3: Helper function for Model Size
# =======================================
def print_model_size(model, label=""):
    """Saves the model's state_dict temporarily and prints its size."""
    temp_file_path = f"{label}_temp_model_state.pth" # Use label in filename
    torch.save(model.state_dict(), temp_file_path)
    size_bytes = os.path.getsize(temp_file_path)
    size_mb = size_bytes / (1024 * 1024)
    print(f"{label} Model size: {size_mb:.2f} MB")
    os.remove(temp_file_path)
    return size_mb

# Check the size of the original FP32 model
print("\n--- Checking Initial Model Size ---")
fp32_model_size = print_model_size(fp32_model, "FP32_Original")


# Cell 4: Prepare Quantization-Aware Model Architecture
# ======================================================
print("\n--- Preparing Quantization-Aware Model ---")

# Create an instance of the quantization-aware ResNet18 architecture.
# `quantize=False` initially means it's FP32 but has the structure (stubs, etc.) for quantization.
# We don't pass weights here; we'll load them from our original fp32_model.
model_to_quantize = models_quant.resnet18(weights=None, quantize=False) # ***MODIFIED***

# Load the state dictionary from the original FP32 model into the quantization-aware architecture.
# This transfers the learned weights. `strict=True` is default and should work.
model_to_quantize.load_state_dict(fp32_model.state_dict()) # ***MODIFIED***
model_to_quantize.eval() # Set to evaluation mode
model_to_quantize.cpu() # Ensure it's on CPU

print("Quantization-aware ResNet18 architecture created.")
print("Loaded weights from the original FP32 model.")


# Cell 5: Configure Quantization Backend and QConfig
# ==================================================
print("\n--- Configuring Quantization ---")

q_backend = "none"
# Check for supported engines (prefer fbgemm for x86)
if 'fbgemm' in torch.backends.quantized.supported_engines:
    q_backend = "fbgemm"
elif 'qnnpack' in torch.backends.quantized.supported_engines:
    q_backend = "qnnpack"
else:
    print("Warning: Neither 'fbgemm' nor 'qnnpack' supported. Static quantization might not work well.")

qconfig = None
if q_backend != "none":
    try:
        qconfig = torch.quantization.get_default_qconfig(q_backend)
        torch.backends.quantized.engine = q_backend
        print(f"Quantization backend set to: {q_backend}")

        # Apply the qconfig to the quantization-aware model instance
        # Note: Fusion might be applied later, but qconfig needs setting first.
        model_to_quantize.qconfig = qconfig # ***MODIFIED*** (applied to model_to_quantize)
        print("Quantization configuration applied to the model.")

    except Exception as e:
        print(f"Error setting up quantization backend {q_backend}: {e}")
        qconfig = None
else:
    print("Skipping quantization configuration due to lack of supported backend.")


# Cell 6: Fuse Modules (Important for Quantization Performance)
# =============================================================
# Even though models_quant.resnet18 defines potential fusions,
# we typically still need to explicitly call fuse_modules after applying qconfig.
# This modifies the model_to_quantize in place.
print("\n--- Fusing Modules ---")
if qconfig:
    model_to_quantize.eval() # Fuse expects eval mode
    # The quantization-aware models often have internal flags or methods,
    # but explicit fusion is standard practice before prepare.
    # Let's attempt fusing common patterns like Conv-BN-ReLU
    try:
        # fuse_modules typically modifies the model in-place
        # We fuse the model_to_quantize *before* preparing it
        torch.quantization.fuse_modules_qat(model_to_quantize, inplace=True) # Use QAT version for safety even in PTQ workflow with standard layers
        # Or use fuse_modules if specifically targetting PTQ patterns and fuse_modules_qat causes issues
        # torch.quantization.fuse_modules(model_to_quantize, [['conv1', 'bn1', 'relu']], inplace=True) # Example specific fusion
        print("Attempted module fusion on the model.")
        # NOTE: The exact fusion list might need adjustment depending on the model
        # and desired fusion patterns. models_quant.resnet18 is designed for common fusions.
    except Exception as e:
        print(f"Could not fuse modules (might be ok if model doesn't have standard patterns): {e}")
else:
    print("Skipping fusion because qconfig was not set.")


# Cell 7: Prepare Calibration Data (No Changes Needed Here)
# =========================================================
print("\n--- Preparing Calibration Data ---")

# Use CIFAR10 dataset for calibration images
calibration_transform = preprocess # Use transforms from original model

data_dir = './data'
if not os.path.exists(data_dir):
    os.makedirs(data_dir)
    print(f"Created directory: {data_dir}")

calibration_loader = None
try:
    calibration_dataset_full = datasets.CIFAR10(
        root=data_dir, train=True, download=True, transform=calibration_transform
    )
    num_calibration_images = 500 # Use a subset
    calibration_subset_indices = list(range(num_calibration_images))
    calibration_dataset = torch.utils.data.Subset(calibration_dataset_full, calibration_subset_indices)
    calibration_loader = torch.utils.data.DataLoader(
        calibration_dataset, batch_size=32, shuffle=False, num_workers=2
    )
    print(f"Using {len(calibration_dataset)} images from CIFAR10 for calibration.")
    print(f"Calibration DataLoader created with batch size {calibration_loader.batch_size}.")
    images, _ = next(iter(calibration_loader))
    print(f"Sample batch tensor shape: {images.shape}, dtype: {images.dtype}")

except Exception as e:
    print(f"\nError loading or processing calibration data: {e}")
    calibration_loader = None


# Cell 8: Prepare Model for Static Quantization (PTQ)
# ===================================================
print("\n--- Preparing Model for Static Quantization ---")

prepared_model = None # Initialize
prepared_model_ready = False
if qconfig and calibration_loader and model_to_quantize: # Check model_to_quantize existence
    model_to_quantize.cpu().eval() # Ensure CPU and eval mode
    # Use torch.quantization.prepare for Post-Training Quantization
    # This inserts observers based on the qconfig. Operates inplace.
    torch.quantization.prepare(model_to_quantize, inplace=True) # ***MODIFIED*** (applied to model_to_quantize)
    prepared_model = model_to_quantize # Assign for clarity
    print("Model prepared for static quantization (observers inserted).")
    prepared_model_ready = True
else:
    print("Skipping model preparation: Check qconfig, calibration data, and model definition.")


# Cell 9: Calibrate the Model
# ===========================
print("\n--- Calibrating the Model ---")

calibration_done = False
if prepared_model_ready and prepared_model:
    print("Running calibration data through the prepared model...")
    prepared_model.cpu().eval() # Ensure CPU and eval mode

    with torch.no_grad():
        for i, (images, _) in enumerate(calibration_loader):
            images_cpu = images.to('cpu')
            prepared_model(images_cpu) # Feed data to the *prepared* model
            print(f"  Calibration batch {i+1}/{len(calibration_loader)} processed.", end='\r')

    print("\nCalibration finished. Activation statistics collected by observers.")
    calibration_done = True
else:
    print("Skipping calibration step because model was not prepared successfully.")


# Cell 10: Convert Model to Quantized INT8
# =======================================
print("\n--- Converting the Model to Quantized INT8 ---")

int8_model = None # Initialize variable
conversion_done = False

if calibration_done and prepared_model:
    prepared_model.cpu().eval() # Ensure CPU and eval mode before conversion
    try:
        # Convert the calibrated model. Operates inplace by default.
        # Assign to int8_model for clarity, although it modifies prepared_model.
        torch.quantization.convert(prepared_model, inplace=True) # ***MODIFIED*** (applied to prepared_model)
        int8_model = prepared_model # prepared_model is now the converted int8 model
        print("Model successfully converted to INT8 quantized format.")
        conversion_done = True
        # Optional: Print the structure
        # print("\nStructure of the INT8 Model:")
        # print(int8_model)
    except Exception as e:
        print(f"Error during model conversion: {e}")
        conversion_done = False
else:
    print("Skipping conversion because calibration was not completed successfully or prepared model missing.")


# Cell 11: Compare Model Sizes
# ============================
print("\n--- Comparing Model Sizes ---")

if conversion_done and int8_model is not None:
    print("Original FP32 Model:")
    # Use the size calculated earlier or recalculate
    fp32_model_size_check = print_model_size(fp32_model, "FP32_Original") # Use original fp32 model

    print("\nQuantized INT8 Model:")
    int8_model_size = print_model_size(int8_model, "INT8") # Use the converted int8_model

    if int8_model_size > 0 and fp32_model_size_check > 0:
      size_reduction = fp32_model_size_check / int8_model_size
      print(f"\nSize reduction factor: {size_reduction:.2f}x")
      print(f"Model size reduced from {fp32_model_size_check:.2f} MB to {int8_model_size:.2f} MB.")
    else:
      print("\nCould not calculate size reduction (one or both model sizes are zero or invalid).")
else:
    print("Skipping size comparison because conversion step failed or was skipped.")


# Cell 12: Compare Inference Speed (CPU) - UPDATED
# ================================================
print("\n--- Comparing Inference Speed (CPU) ---")

# --- Explicitly set quantization backend --- (Should match what was used for qconfig)
try:
    current_backend = torch.backends.quantized.engine
    print(f"Using quantization backend: {current_backend}")
except Exception as e:
    print(f"Warning: Could not verify quantization backend. Error: {e}")
    current_backend = q_backend # Fallback to the intended backend
    print(f"Assuming backend is: {current_backend}")


if conversion_done and int8_model is not None:
    # Ensure both models are on CPU and in eval mode
    fp32_model.cpu().eval() # Original FP32 model
    int8_model.cpu().eval() # Converted INT8 model

    # Create a sample input tensor (using one batch from the calibration loader)
    try:
        # Re-create iterator in case it was exhausted
        calib_iter = iter(calibration_loader)
        sample_input, _ = next(calib_iter)
        sample_input_cpu = sample_input.to('cpu')
        print(f"Using sample input batch of shape: {sample_input_cpu.shape} on CPU for timing.")

        # Helper function to time inference runs accurately
        def time_model_inference(model, input_tensor, num_runs=50, warm_up=10):
            model.eval()
            model.to('cpu')
            input_tensor = input_tensor.to('cpu')
            times = []

            with torch.no_grad():
                # Warm-up runs
                print(f"  Performing {warm_up} warm-up runs...")
                for _ in range(warm_up):
                    _ = model(input_tensor)

                # Timed runs
                print(f"  Performing {num_runs} timed runs...")
                for _ in range(num_runs):
                    start_time = time.time()
                    _ = model(input_tensor)
                    end_time = time.time()
                    times.append((end_time - start_time) * 1000) # Store time in milliseconds

            avg_time_ms = np.mean(times)
            std_dev_ms = np.std(times)
            print(f"  Avg time: {avg_time_ms:.3f} ms, Std Dev: {std_dev_ms:.3f} ms")
            return avg_time_ms

        # --- Time FP32 model inference ---
        print("\nTiming Original FP32 model inference...")
        fp32_avg_time = time_model_inference(fp32_model, sample_input_cpu)
        print(f"Average FP32 inference time: {fp32_avg_time:.3f} ms per batch")

        # --- Time INT8 model inference ---
        print("\nTiming INT8 model inference...")
        # Ensure PyTorch threading is set for optimal performance (often helps INT8)
        # torch.set_num_threads(1) # Optional: test single-thread performance
        int8_avg_time = time_model_inference(int8_model, sample_input_cpu)
        # torch.set_num_threads(torch.get_num_threads()) # Reset if changed
        print(f"Average INT8 inference time: {int8_avg_time:.3f} ms per batch")

        # --- Calculate and print speedup ---
        if int8_avg_time > 0:
            speedup_factor = fp32_avg_time / int8_avg_time
            print(f"\nInference speedup factor (INT8 vs FP32 on CPU): {speedup_factor:.2f}x")
        else:
            print("\nCould not calculate speedup factor (INT8 average time was zero or invalid).")

    except StopIteration:
        print("\nError: Could not get a batch from calibration_loader. Was it exhausted?")
        # Try re-initializing the iterator if needed
        # (Add code here to re-create calibration_loader if necessary)

    except RuntimeError as e_runtime:
        # This is the error we were trying to fix. If it still occurs,
        # there might be deeper issues (backend install, unsupported ops).
        print(f"\nRuntimeError during inference timing: {e_runtime}")
        print("This might indicate backend incompatibility or missing kernels even with stubs.")
        print(f"Verify backend '{current_backend}' support in your PyTorch installation.")
        print("Check if all operations in the model are supported for quantization with this backend.")

    except Exception as e:
        print(f"\nAn unexpected error occurred during inference timing: {e}")

else:
    print("Skipping inference speed comparison because conversion step failed or was skipped.")

PyTorch Version: 2.6.0+cu124
Torchvision Version: 0.21.0+cu124
Using device: cuda
Original FP32 ResNet18 model loaded and moved to CPU.
Preprocessing transforms for model loaded.

--- Checking Initial Model Size ---
FP32_Original Model size: 44.67 MB

--- Preparing Quantization-Aware Model ---
Quantization-aware ResNet18 architecture created.
Loaded weights from the original FP32 model.

--- Configuring Quantization ---
Quantization backend set to: fbgemm
Quantization configuration applied to the model.

--- Fusing Modules ---
Could not fuse modules (might be ok if model doesn't have standard patterns): module 'torch.quantization' has no attribute 'fuse_modules_qat'

--- Preparing Calibration Data ---
Using 500 images from CIFAR10 for calibration.
Calibration DataLoader created with batch size 32.
Sample batch tensor shape: torch.Size([32, 3, 224, 224]), dtype: torch.float32

--- Preparing Model for Static Quantization ---
Model prepared for static quantization (observers inserted).

-



  Calibration batch 16/16 processed.
Calibration finished. Activation statistics collected by observers.

--- Converting the Model to Quantized INT8 ---
Model successfully converted to INT8 quantized format.

--- Comparing Model Sizes ---
Original FP32 Model:
FP32_Original Model size: 44.67 MB

Quantized INT8 Model:
INT8 Model size: 11.38 MB

Size reduction factor: 3.92x
Model size reduced from 44.67 MB to 11.38 MB.

--- Comparing Inference Speed (CPU) ---
Using quantization backend: fbgemm
Using sample input batch of shape: torch.Size([32, 3, 224, 224]) on CPU for timing.

Timing Original FP32 model inference...
  Performing 10 warm-up runs...
  Performing 50 timed runs...
  Avg time: 6754.067 ms, Std Dev: 9240.614 ms
Average FP32 inference time: 6754.067 ms per batch

Timing INT8 model inference...
  Performing 10 warm-up runs...
  Performing 50 timed runs...
  Avg time: 2569.394 ms, Std Dev: 239.379 ms
Average INT8 inference time: 2569.394 ms per batch

Inference speedup factor (INT