In [None]:
import os
import torch
import time 
import tqdm
import numpy as np

In [None]:
# Path to the serialized full model (placeholder: point this at your checkpoint).
MODEL_PATH = 'your_model.pth'

# SECURITY: weights_only=False unpickles arbitrary Python objects. That is required
# to restore a whole nn.Module (not just a state_dict), so only load trusted files.
# map_location="cpu" lets a checkpoint saved from GPU load on a CPU-only machine;
# the model is moved to the CPU for quantization anyway.
model = torch.load(MODEL_PATH, map_location="cpu", weights_only=False)

In [None]:
# List every parameter of the float model together with its dtype.
named_params = model.named_parameters()
for entry_name, weight in named_params:
    print(entry_name, weight.dtype)

In [None]:
import copy

from torch.ao.quantization import get_default_qconfig_mapping
from torch.ao.quantization.quantize_fx import prepare_fx, convert_fx

# Use ONE backend name for both the qconfig mapping and the runtime kernel engine.
# The PyTorch FX PTQ tutorial sets both from the same name. A qnnpack qconfig run
# on a different engine (e.g. fbgemm/x86) does not use the reduce_range setting
# that engine expects.
QUANT_BACKEND = "qnnpack"
torch.backends.quantized.engine = QUANT_BACKEND
qconfig_mapping = get_default_qconfig_mapping(QUANT_BACKEND)

model.eval()
model.cpu()  # must be on CPU for quantization

# Quantize a deep copy so `model` stays the untouched float baseline used by the
# size / latency comparisons further down.
model_to_quantize = copy.deepcopy(model)

# NOTE(review): `val_loader` is not defined anywhere in this notebook; it must come
# from a missing earlier cell. It is assumed to yield (images, labels) batches.
example_inputs = (next(iter(val_loader))[0],)
prepared_model = prepare_fx(model_to_quantize, qconfig_mapping, example_inputs)

# Calibration: push representative data through the inserted observers.
# no_grad (as in the PyTorch tutorial) rather than inference_mode, so observer
# state is never created as inference tensors.
with torch.no_grad():
    for imgs, _ in val_loader:
        prepared_model(imgs.to('cpu'))

quantized_model = convert_fx(prepared_model)

In [None]:
# Quantized Conv/Linear modules keep their int8 weights in packed params rather
# than nn.Parameter, so named_parameters() would omit them. Walk the state_dict
# instead; its values may be tensors, (weight, bias) tuples, or dtype entries.
def _describe_dtype(value):
    """Return a readable dtype summary for one state_dict value."""
    if isinstance(value, torch.Tensor):
        return value.dtype
    if isinstance(value, torch.dtype):
        return value
    if isinstance(value, (tuple, list)):
        return tuple(_describe_dtype(item) for item in value)
    return type(value).__name__


for name, value in quantized_model.state_dict().items():
    print(name, _describe_dtype(value))

In [None]:
def print_size_of_model(model):
    """
    Print, and return, the serialized size of ``model``'s state_dict in megabytes.

    The state_dict is serialized with ``torch.save`` into an in-memory buffer.
    No scratch file is written to the working directory, so nothing can be
    clobbered or left behind if serialization fails.

    Args:
        model: any ``torch.nn.Module``.

    Returns:
        float: the printed size in MB (1 MB = 1e6 bytes).
    """
    import io

    buffer = io.BytesIO()
    torch.save(model.state_dict(), buffer)
    size_mb = buffer.getbuffer().nbytes / 1e6
    print('Size (MB):', size_mb)
    return size_mb

# Float baseline first, then the int8 model, one size line each.
for candidate in (model, quantized_model):
    print_size_of_model(candidate)

In [None]:
class Timer:
    """
    A simple stopwatch that reports elapsed time in milliseconds.

    Two back-ends:
    - CUDA events (``torch.cuda.Event``). ``stop()`` records the end event and
      synchronizes the device, so work queued between start and stop is included.
    - ``time.perf_counter()`` otherwise: a monotonic, high-resolution clock that is
      unaffected by system clock adjustments (unlike ``time.time()``).

    Args:
        use_cuda: choose the back-end explicitly. ``None`` (default) uses CUDA events
            whenever CUDA is available. Pass ``False`` to time CPU-only work on a
            CUDA-capable machine.

    Methods:
        start(): Start the timer.
        stop(): Stop the timer and return the elapsed time in milliseconds.
    """

    def __init__(self, use_cuda=None):
        if use_cuda is None:
            use_cuda = torch.cuda.is_available()
        self.use_cuda = bool(use_cuda)
        if self.use_cuda:
            self.starter = torch.cuda.Event(enable_timing=True)
            self.ender = torch.cuda.Event(enable_timing=True)
        # Set by start() on the CPU path; None means "not started yet".
        self.start_time = None

    def start(self):
        if self.use_cuda:
            self.starter.record()
        else:
            self.start_time = time.perf_counter()

    def stop(self):
        if self.use_cuda:
            self.ender.record()
            torch.cuda.synchronize()
            return self.starter.elapsed_time(self.ender)  # ms
        if self.start_time is None:
            raise RuntimeError("Timer.stop() called before Timer.start()")
        return (time.perf_counter() - self.start_time) * 1000.0  # ms

def estimate_latency(model, example_inputs, repetitions=10, warmup=5):
    """
    Return the mean and standard deviation of forward-pass latency in ms.

    The model is called as ``model(example_inputs)``: first ``warmup`` untimed
    times, then ``repetitions`` timed times. Everything runs under
    ``torch.no_grad()``, so no autograd graph is built. Timing uses ``Timer``.

    Args:
        model: callable module to benchmark.
        example_inputs: the single argument passed to every forward call.
        repetitions: number of timed calls (must be >= 1).
        warmup: number of untimed calls run first (must be >= 0).

    Returns:
        tuple[float, float]: (mean_ms, std_ms) over the timed calls.

    Raises:
        ValueError: if ``repetitions < 1`` or ``warmup < 0``.
    """
    if repetitions < 1:
        raise ValueError(f"repetitions must be >= 1, got {repetitions}")
    if warmup < 0:
        raise ValueError(f"warmup must be >= 0, got {warmup}")

    # The file-level `import tqdm` binds the module, which is not callable.
    # Import the progress-bar function itself and fall back to a no-op wrapper.
    try:
        from tqdm.auto import tqdm as progress
    except ImportError:
        def progress(iterable, **_kwargs):
            return iterable

    # NOTE(review): Timer() picks CUDA events whenever CUDA is available, even
    # when example_inputs live on the CPU.
    timer = Timer()
    timings = np.zeros(repetitions)

    with torch.no_grad():
        for _ in range(warmup):
            model(example_inputs)
        for rep in progress(range(repetitions), desc="Measuring latency"):
            timer.start()
            model(example_inputs)
            timings[rep] = timer.stop()

    return float(np.mean(timings)), float(np.std(timings))

def estimate_latency_full(model, tag, skip_gpu, input_shape=(10, 3, 384, 384)):
    """
    Print model latency on CPU and, unless skipped, on GPU.

    CPU latency is always measured. GPU latency is measured only when
    ``skip_gpu`` is False and CUDA is available. Both measurements use random
    inputs of the same ``input_shape``, so the two numbers are comparable.

    Args:
        model: module to benchmark. It is moved to the CPU, then to the GPU for
            the GPU measurement, and is back on the CPU when this returns.
        tag: label used in the printed lines.
        skip_gpu: if True, only measure on CPU. Use this for CPU-only models,
            e.g. the quantized model built in this notebook.
        input_shape: shape of the random input batch. The default matches the
            384x384 dummy input used for ONNX export.
    """
    cpu_input = torch.rand(*input_shape)
    model.cpu()
    latency_mu, latency_std = estimate_latency(model, cpu_input)
    print(f"Latency ({tag}, on CPU): {latency_mu:.2f} ± {latency_std:.2f} ms")

    if skip_gpu or not torch.cuda.is_available():
        return

    gpu_input = torch.rand(*input_shape, device="cuda")
    model.cuda()
    try:
        latency_mu, latency_std = estimate_latency(model, gpu_input)
    finally:
        # Leave the model where the CPU path left it, even if timing fails.
        model.cpu()
    print(f"Latency ({tag}, on GPU): {latency_mu:.2f} ± {latency_std:.2f} ms")

In [None]:
# Benchmark the float baseline and the int8 model on CPU only.
latency_targets = (
    (model, "NextViT"),
    (quantized_model, "NextViT (Quantized)"),
)
for candidate, label in latency_targets:
    estimate_latency_full(candidate, label, skip_gpu=True)

In [None]:
# Export the floating-point model to ONNX.
# This exports `model` (the float baseline, not `quantized_model`), so the
# output file is named for what it actually contains.
ONNX_PATH = "model_fp32.onnx"

# Put the dummy input on whatever device the model's parameters live on.
model_device = next(model.parameters()).device
dummy_input = torch.randn(1, 3, 384, 384, device=model_device)

torch.onnx.export(
    model,
    dummy_input,
    ONNX_PATH,
    input_names=["input"],
    output_names=["output"],
    dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}},
    opset_version=16,
    verbose=False,
)