In [1]:
import os
import torch
import torch.nn as nn
from tqdm.notebook import tqdm
from pathlib import Path
from torchvision import transforms
from torch.utils.data import DataLoader
import torchvision.datasets as datasets

from PlantVision import paths

In [2]:
# DATA PIPELINE
def get_dataloader(
    data_location: Path,
    img_size: int,
    batch_size: int = 5,
    shuffle: bool = True,
    num_workers: int = 1,
    drop_last: bool = True,
) -> DataLoader:
    """Build a DataLoader over an image-folder dataset.

    Every image is resized to a square of ``img_size`` pixels, converted to a
    tensor and normalized per channel. The mean/std values used are the widely
    used ImageNet channel statistics.

    :param data_location: Root directory handed to ``datasets.ImageFolder``
        (one sub-directory per class).
    :param img_size: Side length, in pixels, that images are resized to.
    :param batch_size: Number of samples per batch.
    :param shuffle: Whether the sample order is reshuffled on every pass.
    :param num_workers: Number of worker processes used for loading.
    :param drop_last: Whether to drop the final batch when it is smaller than
        ``batch_size``.
    :return: A DataLoader yielding ``(images, labels)`` batches.
    """
    transform = transforms.Compose([
        transforms.Resize((img_size, img_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    # Class labels are derived from the sub-directory names under the root
    dataset = datasets.ImageFolder(root=data_location, transform=transform)

    return DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        drop_last=drop_last,
    )

# Loader over the processed training split: 224x224 images in batches of 10
loader = get_dataloader(
    data_location=paths.DATA_DIR / "processed" / "train",
    img_size=224,
    batch_size=10,
)

In [3]:
# Pull one batch from the loader and display the shape of its image tensor
first_batch = next(iter(loader))
img, lab = first_batch
img.shape

torch.Size([10, 3, 224, 224])

In [4]:
# MODEL DEFINITION
class Model(nn.Module):
    """Small CNN classifier: two 3x3 convolutions, one 2x2 max-pool, one linear layer.

    :param num_classes: Number of output logits produced per image.
    :param img_size: Side length of the square input images. The classifier's
        input width is derived from it, so the default of 224 reproduces the
        original ``15 * 112 * 112`` layer exactly.
    :raises ValueError: If ``img_size`` is smaller than 2 (the pooling layer
        would produce an empty feature map).
    """

    def __init__(self, num_classes: int, img_size: int = 224):
        super().__init__()

        if img_size < 2:
            raise ValueError(f"img_size must be at least 2, got {img_size}")

        # Feature extraction: padding=1 with a 3x3 kernel keeps the spatial size
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=10, kernel_size=3, stride=1, padding=1)  # --> (10, H, W)
        self.relu1 = nn.ReLU()
        self.conv2 = nn.Conv2d(in_channels=10, out_channels=15, kernel_size=3, stride=1, padding=1)  # --> (15, H, W)
        self.relu2 = nn.ReLU()

        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)  # --> (15, H // 2, W // 2)

        # Classifier: the 2x2 / stride-2 pooling halves each side, rounding down
        pooled_size = img_size // 2
        self.fc1 = nn.Linear(in_features=15 * pooled_size * pooled_size, out_features=num_classes)

    def forward(self, x):
        """Return the class logits, shape ``(batch, num_classes)``, for a batch of images."""
        x = self.relu1(self.conv1(x))
        x = self.relu2(self.conv2(x))
        x = self.pool(x)
        # flatten (unlike view) also accepts non-contiguous feature maps
        x = torch.flatten(x, start_dim=1)
        x = self.fc1(x)
        return x

In [5]:
# Instantiate the 38-class network together with its loss function and optimizer
model = Model(num_classes=38)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)

In [6]:
n_epochs = 2

for epoch in range(n_epochs):
    model.train()
    running_loss = 0.0
    progress_bar = tqdm(loader, desc=f"Epoch {epoch + 1} / {n_epochs}")

    for imgs, labels in progress_bar:
        # Clear gradients left over from the previous step before the new pass
        optimizer.zero_grad()
        outputs = model(imgs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Read the scalar once and reuse it for the running total and the bar
        batch_loss = loss.item()
        running_loss += batch_loss
        progress_bar.set_postfix(loss=batch_loss)

    # Mean loss per batch; max(..., 1) avoids a ZeroDivisionError on an empty loader
    epoch_loss = running_loss / max(len(loader), 1)
    print(f"Epoch {epoch + 1} / {n_epochs} - mean loss: {epoch_loss:.4f}")

Epoch 1 / 2:   0%|          | 0/1084 [00:00<?, ?it/s]

Epoch 2 / 2:   0%|          | 0/1084 [00:00<?, ?it/s]

In [7]:
# Save the trained weights (state dict only) to the working directory
torch.save(model.state_dict(), "model.pth")

# Post-Training Dynamic Quantization

In [8]:
import torch.quantization

In [9]:
# Rebuild the architecture and restore the FP32 weights saved to model.pth
model_fp32 = Model(num_classes=38)
model_fp32.load_state_dict(state_dict=torch.load("model.pth"))
# Switch to inference mode; as the cell's last expression this also displays the module
model_fp32.eval()

Model(
  (conv1): Conv2d(3, 10, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (relu1): ReLU()
  (conv2): Conv2d(10, 15, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (relu2): ReLU()
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (fc1): Linear(in_features=188160, out_features=38, bias=True)
)

In [10]:
# APPLYING DYNAMIC QUANTIZATION
# The weights are converted from float32 to int8 ahead of time.
# During inference, the activations are converted to int8 on the fly inside
# the layer, just before the dot product. The matrix multiply runs on int8
# operands and the result is converted back to float32.

# Only the layer types listed in the set are quantized
model_int8_dynamic = torch.quantization.quantize_dynamic(
    model_fp32,
    {torch.nn.Linear},  # A set of layer types to quantize
    dtype=torch.qint8,  # The target dtype
)

# Save quantized model
torch.save(model_int8_dynamic.state_dict(), "model_int8_dynamic.pth")

# Compare on-disk sizes in MB (1 MB = 1e6 bytes)
fp32_size = os.path.getsize('model.pth') / 1e6
int8_size = os.path.getsize('model_int8_dynamic.pth') / 1e6
print(fp32_size)
print(f"{int8_size}, a {100 - (int8_size/fp32_size) * 100:.2f}% reduction in size")

28.609853
7.160743, a 74.97% reduction in size


# Post-Training Static Quantization (ONNX)

In [11]:
import time

import numpy as np
import onnx
import onnxruntime
from onnxruntime.quantization import quantize_static, QuantFormat, QuantType, CalibrationDataReader
from onnxruntime.quantization.shape_inference import quant_pre_process

class ONNXDataReader(CalibrationDataReader):
    """Feeds batches from a PyTorch DataLoader to the ONNX Runtime calibrator."""

    def __init__(self, dataloader: DataLoader, onnx_input_name: str, max_batches=None):
        """
        Initializes the data reader

        :param dataloader: A PyTorch DataLoader providing calibration data as (images, labels) batches
        :param onnx_input_name: The name of the input node in the ONNX model
        :param max_batches: Optional cap on the number of batches served per pass;
            None (the default) serves every batch the dataloader yields
        """

        self.dataloader = dataloader
        self.onnx_input_name = onnx_input_name
        self.max_batches = max_batches
        self.rewind()

    def rewind(self):
        """Restarts iteration from the first batch so the reader can be reused"""
        self.iterator = iter(self.dataloader)
        self._batches_served = 0

    def get_next(self):
        """
        Returns the next batch of data for calibration
        The quantizer will call this method repeatedly

        :return: A dictionary mapping the input name to a numpy array, or None
            once the data (or the max_batches budget) is exhausted
        """
        if self.max_batches is not None and self._batches_served >= self.max_batches:
            return None

        # Keep the try narrow so that only dataloader exhaustion is treated as end-of-data
        try:
            images, _ = next(self.iterator)
        except StopIteration:
            # Returns None to signal the end of the calibration dataset
            return None

        self._batches_served += 1
        # The quantizer expects a dictionary mapping input names to numpy arrays;
        # detach/cpu make the conversion safe for any tensor the loader may yield
        return {self.onnx_input_name: images.detach().cpu().numpy()}

In [12]:
# Some helper functions
def print_onnx_model_size(model_path, label):
    """Print the on-disk size of the file at ``model_path`` in MB (1 MB = 1e6 bytes), tagged with ``label``."""
    megabytes = os.stat(model_path).st_size / 1e6
    print(f"Size of {label} ONNX model: {megabytes:.2f} MB")

def benchmark_onnx_inference(model_path, label, input_shape=(1, 3, 224, 224), n_warmup=10, n_runs=100):
    """Runs a quick benchmark to measure ONNX model inference latency

    :param model_path: Path of the ONNX model file to load
    :param label: Human-readable name used in the printed report
    :param input_shape: Shape of the random float32 input fed to the model
    :param n_warmup: Number of untimed runs executed before measuring
    :param n_runs: Number of timed runs that are averaged
    :raises ValueError: If n_runs is smaller than 1
    """
    if n_runs < 1:
        raise ValueError(f"n_runs must be at least 1, got {n_runs}")

    # Create an ONNX Runtime inference session
    session = onnxruntime.InferenceSession(model_path)
    input_name = session.get_inputs()[0].name

    # Random data is enough here: only the latency is measured, not the outputs
    dummy_input = np.random.randn(*input_shape).astype(np.float32)
    feed = {input_name: dummy_input}

    # Warm-up runs, excluded from the measurement
    for _ in range(n_warmup):
        session.run(None, feed)

    # Timed runs; perf_counter is a monotonic, high-resolution clock meant for
    # measuring short durations, unlike the wall clock returned by time.time()
    timings = []
    for _ in range(n_runs):
        start_time = time.perf_counter()
        session.run(None, feed)
        timings.append(time.perf_counter() - start_time)

    avg_latency = np.mean(timings) * 1000  # seconds -> milliseconds
    print(f"Average latency for {label} ONNX model: {avg_latency:.3f} ms")

In [13]:
# Temporary artifacts written by the export -> pre-process -> quantize steps
fp32_onnx_path = "temp_model_fp32.onnx"
preprocessed_onnx_path = "temp_model_preprocessed.onnx"
int8_onnx_path = "temp_model_int8.onnx"


# 1. ARRANGE: Restore the trained FP32 model and set up the calibration data
fp32_model = Model(num_classes=38)
fp32_model.load_state_dict(state_dict=torch.load("model.pth"))
fp32_model.eval()

# Calibration batches are drawn from the processed training images
calibration_loader = get_dataloader(
    data_location=paths.DATA_DIR / "processed" / "train",
    img_size=224,
    batch_size=10,
)


# 2. EXPORT: Convert the PyTorch model to FP32 ONNX format
# The exporter needs an example input to trace the model's computation graph
dummy_input = torch.randn(1, 3, 224, 224)
torch.onnx.export(
    fp32_model,
    (dummy_input,),
    fp32_onnx_path,
    opset_version=13,
    input_names=["input"],
    output_names=["output"],
    # Mark dimension 0 of both tensors as a variable batch size
    dynamic_axes={tensor_name: {0: "batch_size"} for tensor_name in ("input", "output")},
)

# Report the size of the exported file
print_onnx_model_size(model_path=fp32_onnx_path, label="Original FP32")

# 3. PRE-PROCESS: Run graph optimizations and shape inference
quant_pre_process(
    input_model_path=fp32_onnx_path,
    output_model_path=preprocessed_onnx_path,
)

Size of Original FP32 ONNX model: 28.61 MB


In [14]:
# 4. QUANTIZE: Perform Static Quantization on the pre-processed model

# a. Create the Data Reader
# First, get the input name from the ONNX model graph
onnx_model = onnx.load(preprocessed_onnx_path)
input_name = onnx_model.graph.input[0].name

calibration_data_reader = ONNXDataReader(
    dataloader=calibration_loader,
    onnx_input_name=input_name
)

# b. Call the quantization function
quantize_static(
    model_input=preprocessed_onnx_path,
    model_output=int8_onnx_path,
    calibration_data_reader=calibration_data_reader,
    # Pass the QuantFormat enum member (Quantize-Dequantize format); the
    # parameter takes a QuantFormat value rather than a plain string
    quant_format=QuantFormat.QDQ,
    per_channel=True,  # Use per-channel quantization for conv layers
    weight_type=QuantType.QUInt8,
    activation_type=QuantType.QUInt8,
)

print_onnx_model_size(int8_onnx_path, "INT8 (Static)")

Size of INT8 (Static) ONNX model: 7.16 MB


In [15]:
# Benchmark: compare inference latency of the FP32 export against the INT8 model
benchmark_onnx_inference(model_path=fp32_onnx_path, label="FP32")
benchmark_onnx_inference(model_path=int8_onnx_path, label="INT8 (Static)")

Average latency for FP32 ONNX model: 3.226 ms
Average latency for INT8 (Static) ONNX model: 2.828 ms


In [16]:
# 6. CLEAN UP
print("\n--- 6. Cleaning up temporary files ---")
for temp_onnx_file in (fp32_onnx_path, preprocessed_onnx_path, int8_onnx_path):
    # Skip files that are already gone so re-running this cell does not raise
    if os.path.exists(temp_onnx_file):
        os.remove(temp_onnx_file)
print("Done.")


--- 6. Cleaning up temporary files ---
Done.
