In [None]:
import os
import numpy as np
import torch
import tensorflow as tf
from torch.utils.data import DataLoader
from sklearn.linear_model import LogisticRegression
from train import evaluate
from models import setup_model
from datasets import CatCamDataset
from utils import set_random_seeds, load_model, measure_pytorch_time, measure_tflite_time, calculate_metrics, show_metrics, inference_tflite_model

In [None]:
"""
!pip install onnx
!pip install onnx2tf
!pip install ai_edge_litert
!pip install sng4onnx
!pip install onnx_graphsurgeon
!pip install tflite-runtime

!pip uninstall protobuf
!pip install protobuf==4.25.3

!pip uninstall numpy -y
!pip install numpy==1.26.4

!pip uninstall tensorflow tensorflow-cpu tensorflow-gpu -y
!pip install tensorflow==2.15.0
"""

The ONNX → TensorFlow SavedModel conversion is done beforehand with `onnx2tf`; this notebook starts from its output.

In [None]:
# Reproducibility and device settings: all benchmarking below runs on the CPU.
random_state = 42
cpu_device = torch.device("cpu:0")

# Artifact locations: every file shares the model-name stem and lives in model_dir.
model_dir = "/content/saved_models"
model_name = "efficientnet_b0"

model_filename = f"{model_name}.pt"
pruned_model_filename = f"{model_name}_pruned.pt"
quantized_model_filename = f"{model_name}_quantized.pt"
tflite_model_filename = f"{model_name}.tflite"

original_model_filepath, pruned_model_filepath, quantized_model_filepath, tflite_model_filepath = (
    os.path.join(model_dir, artifact_name)
    for artifact_name in (
        model_filename,
        pruned_model_filename,
        quantized_model_filename,
        tflite_model_filename,
    )
)

In [None]:
root_dir = "/content/catcam"
batch_size = 1

# Evaluation-time preprocessing shared by the validation and test splits.
# TODO(review): with imgsz == center_crop_size the center crop may be a no-op;
# confirm against CatCamDataset's transform before removing it.
val_aug_args = {
    "imgsz" : 224,
    "center_crop_size" : 224,
    "normalize_mean": [0.485, 0.456, 0.406],
    "normalize_std": [0.229, 0.224, 0.225]
}


def _make_eval_dataloader(dataset, num_workers=4, pin_memory=False):
    """Wrap an evaluation split in a deterministic DataLoader.

    Sample order is fixed (``shuffle=False``) and the final partial batch is
    kept (``drop_last=False``), so every sample is seen exactly once.

    Args:
        dataset: the split to iterate.
        num_workers: number of worker processes used for loading.
        pin_memory: defaults to False because every consumer in this notebook
            (PyTorch evaluation on ``cpu_device``, TFLite calibration and
            inference) runs on the CPU; pinned memory only speeds up
            host-to-GPU copies.

    Returns:
        A ``DataLoader`` using the module-level ``batch_size``.
    """
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=pin_memory,
        drop_last=False,
    )


val_dataset = CatCamDataset(root_dir, val_aug_args, mode="val")
val_dataloader = _make_eval_dataloader(val_dataset)

test_dataset = CatCamDataset(root_dir, val_aug_args, mode="test")
test_dataloader = _make_eval_dataloader(test_dataset)

In [None]:
# Seed the RNGs via the project helper (presumably random/numpy/torch — confirm in utils).
# NOTE(review): this runs after the dataloaders were built; they use shuffle=False,
# but if CatCamDataset draws randomness at construction time, seed in the config cell instead.
set_random_seeds(random_state)

# Conversion to TFLite (int8)

In [None]:
def representative_dataset(dataloader=None, num_samples=None):
    """Yield calibration inputs for TFLite full-integer quantization.

    The TFLite converter calls this with no arguments. By default it therefore
    draws from the notebook's ``val_dataloader``, the validation split, so the
    test split stays held out.

    Args:
        dataloader: iterable of ``(image, label)`` pairs where ``image`` is a
            torch tensor in NCHW layout. ``None`` means ``val_dataloader``,
            looked up when the generator starts.
        num_samples: optional cap on the number of batches yielded; ``None``
            uses the whole loader.

    Yields:
        A single-element list holding the batch as a float32 NHWC numpy array.
    """
    import itertools  # local import keeps this cell self-contained

    if dataloader is None:
        # Looked up lazily: reading a global needs no `global` statement.
        dataloader = val_dataloader
    for image, _label in itertools.islice(dataloader, num_samples):
        # The TF model takes channels-last input, so transpose NCHW -> NHWC.
        yield [image.permute(0, 2, 3, 1).numpy().astype(np.float32)]

In [None]:
# Full-integer (int8) post-training quantization of the TF SavedModel that
# onnx2tf produced from the ONNX export.
saved_model_dir = "model_qat"
converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)

# Enable quantization and calibrate activation ranges on validation batches.
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_dataset

# Allow only int8 builtin kernels (conversion errors out on ops it cannot
# quantize instead of keeping them in float) and make the model I/O int8 too.
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
converter.inference_input_type = tf.int8
converter.inference_output_type = tf.int8
converter.allow_custom_ops = False

tflite_model = converter.convert()

In [None]:
# Persist the converted int8 flatbuffer next to the PyTorch checkpoints.
# Create the directory first so a fresh runtime does not fail inside open().
os.makedirs(os.path.dirname(tflite_model_filepath) or ".", exist_ok=True)
with open(tflite_model_filepath, "wb") as tflite_file:
    tflite_file.write(tflite_model)

print("Model was successfully converted to int8!")

# Calibration

In [None]:
# Platt scaling: learn p = sigmoid(a * logit + b) from the TFLite model's logits.
# Fit on the validation split explicitly so the test split used for the final
# metrics stays untouched (no calibration leakage).
# NOTE(review): LogisticRegression's default L2 penalty (C=1.0) slightly shrinks a;
# consider a weaker penalty if the calibration looks under-confident.
calibration_model = LogisticRegression()
tflite_logits, labels = inference_tflite_model(tflite_model_filepath, val_dataloader)

# sklearn expects X with shape (n_samples, n_features); the single feature here
# is the logit, so a 1-D logit vector must become a column. Labels must be 1-D.
calibration_features = np.asarray(tflite_logits, dtype=np.float64).reshape(-1, 1)
calibration_targets = np.asarray(labels).ravel()
assert calibration_features.shape[0] == calibration_targets.shape[0], (
    "expected exactly one logit per sample for Platt scaling"
)
calibration_model.fit(calibration_features, calibration_targets)

# Calibration coefficients: logit -> sigmoid(a*logit+b)
a = calibration_model.coef_[0][0]
b = calibration_model.intercept_[0]

# Speed comparison

In [None]:
# Load the three PyTorch checkpoints onto the CPU.
# Each checkpoint gets its own freshly built architecture. If load_model loads
# weights into the module it receives, sharing one instance would leave all
# three names pointing at a single object holding the last-loaded weights.
def _build_untrained_model():
    """Return a new, non-pretrained instance of the configured architecture."""
    return setup_model(model_name, pretrained=False)


init_model = _build_untrained_model()

original_model = load_model(init_model, original_model_filepath, cpu_device)
pruned_model = load_model(_build_untrained_model(), pruned_model_filepath, cpu_device)
quantized_model = load_model(_build_untrained_model(), quantized_model_filepath, cpu_device)

In [None]:
# Latency benchmark: every model is timed over the same number of runs.
num_runs = 150

# PyTorch variants, timed one after another in a fixed order.
original_time, pruned_time, quantized_time = (
    measure_pytorch_time(torch_model, num_runs=num_runs)
    for torch_model in (original_model, pruned_model, quantized_model)
)

# The int8 TFLite model is timed from the serialized model returned by converter.convert().
tflite_time = measure_tflite_time(tflite_model, num_runs=num_runs)

In [None]:
# Latency summary. Units are whatever measure_*_time return — TODO confirm
# (presumably per-run time) and add them to the labels.
print("Torchscript: ")
print(f"Original model: {original_time:6.2f}")
print(f"Pruned model: {pruned_time:6.2f}")
print(f"QAT model: {quantized_time:6.2f}")

print("\nTensorflow Lite:")
print(f"TFLite model: {tflite_time:6.2f}")

print("=" * 45)
# Relative latency reduction of the int8 TFLite model versus the original model.
# Guard against a zero baseline so the report never dies on ZeroDivisionError.
if original_time > 0:
    improvement = (original_time - tflite_time) / original_time * 100
    print(f"Improvement over original: {improvement:.1f}%")
else:
    improvement = float("nan")
    print("Improvement over original: n/a (original time is zero)")

# Accuracy metrics comparison

In [None]:
# Test-split metrics for the three PyTorch variants, each evaluated on the CPU.
(
    original_model_metrics,
    pruned_model_metrics,
    quantized_model_metrics,
) = [
    evaluate(torch_model, test_dataloader, cpu_device)
    for torch_model in (original_model, pruned_model, quantized_model)
]

In [None]:
# TFLite predictions on the held-out test split, passing the fitted calibration
# coefficients (a, b) — presumably applied as sigmoid(a*logit+b); confirm in utils.
tflite_test_outputs = inference_tflite_model(tflite_model_filepath, test_dataloader, a, b)
tflite_logits, labels = tflite_test_outputs
tflite_model_metrics = calculate_metrics(tflite_logits, labels)

In [None]:
# Side-by-side test metrics for every model variant.
# Only the TFLite run receives the calibration coefficients (a, b); evaluate()
# is not given them. TODO: decide whether the PyTorch models should be
# calibrated too, for a like-for-like comparison.
# Fix: the original model's entry previously passed the model object itself
# instead of its metrics.
for metrics, title in (
    (original_model_metrics, "Original model"),
    (pruned_model_metrics, "Pruned model"),
    (quantized_model_metrics, "Quantized model"),
    (tflite_model_metrics, "TFLite model"),
):
    show_metrics(metrics, title)