# In [None]:
import os
import sys
sys.path.append("../")

import torch
from torchvision.transforms import Resize

import Quantization.explicit_quant as quantize

from Dataset.dataset import build_dataloader
from Utils.model_loading import *
from Utils.classification import evaluate, finetune

# Run on the first CUDA GPU when PyTorch can see one; otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print("Using {}".format(device))

# Initialize quantization. This runs before the model is constructed —
# presumably so quant_initialize() can substitute quantized layer types;
# TODO confirm in Quantization.explicit_quant.
quantize.quant_initialize()

# Input resolution (pixels). Single source of truth for the model's dummy
# input, the checkpoint filename ("i152") and the Resize transform, so they
# cannot drift apart.
image_size = 152
# Pruning ratio of the checkpoint; passed to the loader and encoded in the
# checkpoint filename ("p0.8").
prune_ratio = 0.8
checkpoint_path = (
    f"../Pretrained/VGG-Pruned/vgg16_bn_custom_ecg_ep30_i{image_size}_p{prune_ratio}.pth"
)

# Load model. The random (1, 3, image_size, image_size) tensor is a dummy
# input handed to the loader — presumably used to rebuild the pruned
# architecture; confirm in Utils.model_loading.
model = load_vgg_custom_from_pruned(
    checkpoint_path, prune_ratio, torch.rand((1, 3, image_size, image_size))
)
model.to(device)

# NOTE(review): the checkpoint path is relative to "../" while the data paths
# are relative to the current directory — confirm the working directory this
# notebook is expected to run from.
# NOTE(review): antialias=None is kept as-is; newer torchvision releases
# changed Resize's antialias default — confirm the intended behavior.
dataloader = build_dataloader(
    train_path = "Data/mitbih_mif_train_small.h5",
    test_path  = "Data/mitbih_mif_test.h5",
    batch_size = 32,
    transform  = Resize((image_size, image_size), antialias=None)
)

# Calibrate the model's quantizers on training batches (num_batches=100),
# then measure test-set accuracy of the calibrated, not-yet-fine-tuned model.
quantize.calibrate(
    num_batches=100,
    model=model,
    dataloader=dataloader["train"],
)

accuracy = evaluate(model, dataloader["test"])
print("Accuracy: {}".format(accuracy))

# Both artifacts of this run (QAT checkpoint and ONNX graph) are written here.
output_dir = "Quantization/Test"
# Create the directory up front. Otherwise a missing directory would only
# surface as a failed save after the whole fine-tuning run.
os.makedirs(output_dir, exist_ok=True)

# QAT: fine-tune the calibrated model for 10 epochs. save_path is where
# finetune() writes its checkpoint (which epoch it keeps is decided inside
# Utils.classification).
finetune(
    model      = model,
    epochs     = 10,
    dataloader = dataloader,
    save_path  = f"{output_dir}/quant_vgg_test.pth",
)

# ONNX export. The literal 1 is presumably the export batch size — confirm
# against quantize.export_to_onnx's signature.
quantize.export_to_onnx(model, image_size, 1, f"{output_dir}/quant_vgg_test.onnx")