In [None]:
# Notebook imports: PyTorch / numpy, project-local training, pruning and model
# helpers, and the ONNX -> TensorFlow toolchain used for the TFLite export.
import torch
from torch.utils.data import DataLoader
import numpy as np
import copy
import os
from train import train_loop, evaluate
from utils import fuse_model, show_metrics, set_random_seeds, save_model, load_model
from models import setup_model, QuantizedModel
from datasets import create_url_df, CatCamDataset
from pruning import measure_global_sparsity, iterative_pruning_finetuning, remove_parameters
import onnx
from onnx_tf.backend import prepare
import tensorflow as tf

In [None]:
# Initialization parameters: seed, devices and artifact paths.
random_seed = 42
cuda_device = torch.device("cuda:0")
cpu_device = torch.device("cpu:0")

model_dir = "saved_models"
model_name = "mobilenet_v3_small"
model_filename = model_name + ".pt"
# Prefix handed to iterative_pruning_finetuning (presumably for its checkpoints).
model_filename_prefix = "pruned_model"
pruned_model_filename = model_name + "_pruned.pt"
# The quantized model is exported with torch.onnx.export to this path, so it gets
# an .onnx extension; "_" keeps it consistent with the other artifact names.
quantized_model_filename = model_name + "_quantized.onnx"
model_filepath = os.path.join(model_dir, model_filename)
pruned_model_filepath = os.path.join(model_dir, pruned_model_filename)
quantized_model_filepath = os.path.join(model_dir, quantized_model_filename)

# Dataset root; holds cat_urls.txt and no_cat_urls.txt.
root_dir = "/catcam"

In [None]:
# Training hyper-parameters.
epochs = 20
batch_size = 32
l1_regularization_strength = 0
l2_regularization_strength = 1e-4  # also passed to Adam as weight_decay in the QAT cell
learning_rate = 1e-3
learning_rate_decay = 1
early_stopping_patience = 7

# Augmentation settings for the training split. This must be a plain dict: a
# trailing comma after the closing brace would silently turn it into a 1-tuple.
train_aug_args = {
    "imgsz": 224,
    "hflip_prob": 0.5,
    "rotation_degrees": 15,
    "perspective_scale": 0.2,
    "perspective_prob": 0.3,
    "brightness_jitter": 0.15,
    "contrast_jitter": 0.1,
    "saturation_jitter": 0.2,
    "hue_jitter": 0.02,
    "grayscale_prob": 0.05,
    "gaussian_blur_kernel": (3, 3),
    "gaussian_blur_sigma": (0.1, 0.5),
    "random_erase_prob": 0.2,
    "random_erase_scale": (0.02, 0.08),
    "normalize_mean": [0.485, 0.456, 0.406],
    "normalize_std": [0.229, 0.224, 0.225]
}

# Deterministic preprocessing for the validation and test splits.
val_aug_args = {
    "imgsz": 224,
    "center_crop_size": 224,
    "normalize_mean": [0.485, 0.456, 0.406],
    "normalize_std": [0.229, 0.224, 0.225]
}

In [None]:
# Seed the global RNGs (presumably torch / numpy / random — see utils.set_random_seeds)
# before any data shuffling or model construction happens.
set_random_seeds(random_seed)

In [None]:
# Preparing data: read the cat / no-cat URL lists and split them 60/20/20 into
# train / val / test.
cat_urls_path = os.path.join(root_dir, "cat_urls.txt")
no_cat_urls_path = os.path.join(root_dir, "no_cat_urls.txt")
url_df = create_url_df(cat_urls_path, no_cat_urls_path)

# Seed the shuffle explicitly. With random_state=None the permutation depends on
# how much of numpy's global RNG has already been consumed, so re-running this
# cell would produce a different split and mix samples across train/val/test.
shuffled_df = url_df.sample(frac=1, random_state=random_seed)

# Cumulative boundaries: [0, 0.6) train, [0.6, 0.8) val, [0.8, 1] test.
# Positional iloc slicing gives the same result as np.split (which relies on the
# deprecated DataFrame.swapaxes) and keeps the shuffled index labels.
train_end = int(.6 * len(shuffled_df))
val_end = int(.8 * len(shuffled_df))
train = shuffled_df.iloc[:train_end]
val = shuffled_df.iloc[train_end:val_end]
test = shuffled_df.iloc[val_end:]

train_dataset = CatCamDataset(train, root_dir, train_aug_args)
val_dataset = CatCamDataset(val, root_dir, val_aug_args)
test_dataset = CatCamDataset(test, root_dir, val_aug_args)

In [None]:
# Build all three loaders from one factory so their shared settings stay in sync;
# only the training loader reshuffles its data each epoch.
def make_dataloader(dataset, shuffle):
    """Return a DataLoader over `dataset` with the notebook-wide batch size."""
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=4,
        pin_memory=True,
        drop_last=False,
    )


train_dataloader = make_dataloader(train_dataset, shuffle=True)
val_dataloader = make_dataloader(val_dataset, shuffle=False)
test_dataloader = make_dataloader(test_dataset, shuffle=False)

In [None]:
# Set up the model architecture named in the config cell, so changing
# `model_name` there changes both the network and its artifact filenames.
model = setup_model(model_name)

In [None]:
# Baseline evaluation of the unpruned model on the held-out test split.
test_metrics = evaluate(model, test_dataloader, cuda_device)
show_metrics(test_metrics)
# Pruning has not been applied yet (it happens in the next cell), so there are no
# pruning masks to read; count zeros in the weights themselves instead.
sparsity = measure_global_sparsity(model, conv2d_use_mask=False)
print(f"Sparsity: {sparsity}")

In [None]:
# Iterative pruning + fine-tuning, applied to a copy so `model` stays the
# dense baseline.
print("Iterative pruning + Fine-Tuning...")

pruned_model = copy.deepcopy(model)

# Pruning schedule, kept as named values rather than inline magic numbers.
conv2d_prune_amount = 0.98
linear_prune_amount = 0
pruning_iterations = 1
pruning_epochs_per_iteration = 500

iterative_pruning_finetuning(
    model=pruned_model,
    train_loader=train_dataloader,
    # Monitor fine-tuning on the validation split; the test split stays held out
    # for the evaluation cells so its metrics are not optimistically biased.
    test_loader=val_dataloader,
    device=cuda_device,
    learning_rate=learning_rate,
    learning_rate_decay=learning_rate_decay,
    l1_regularization_strength=l1_regularization_strength,
    l2_regularization_strength=l2_regularization_strength,
    conv2d_prune_amount=conv2d_prune_amount,
    linear_prune_amount=linear_prune_amount,
    num_iterations=pruning_iterations,
    num_epochs_per_iteration=pruning_epochs_per_iteration,
    model_filename_prefix=model_filename_prefix,
    model_dir=model_dir,
    grouped_pruning=True,
)

# Presumably makes pruning permanent (folds masks into the weights) — see
# pruning.remove_parameters.
remove_parameters(model=pruned_model)

In [None]:
# Evaluate the *pruned* model; evaluating `model` here would only repeat the
# dense-baseline numbers.
test_metrics = evaluate(pruned_model, test_dataloader, cuda_device)
show_metrics(test_metrics)
# remove_parameters() was applied in the previous cell, so any pruning masks have
# presumably been folded into the weights; measure zeros in the weights directly.
sparsity = measure_global_sparsity(pruned_model, conv2d_use_mask=False)
print(f"Sparsity: {sparsity}")

In [None]:
# Save the pruned weights, then reload them so every later cell (fusion, QAT,
# export) works from the pruned network instead of the dense baseline checkpoint.
save_model(pruned_model, model_dir, model_filename=pruned_model_filename)
# Load into a copy of the pruned architecture so the state-dict keys are
# guaranteed to match what was just saved.
model = load_model(copy.deepcopy(pruned_model), pruned_model_filepath, cuda_device)

In [None]:
# Fuse layers on a CPU copy of the model. Fusing a deepcopy keeps `model`
# untouched even if fuse_model works in place, and the fused result is no longer
# overwritten by an unfused copy.
model.to(cpu_device)
fused_model = fuse_model(copy.deepcopy(model), model_name)

# Leave both models in eval mode (a train() call followed by eval() ends in the
# same state, so only eval() is needed).
model.eval()
fused_model.eval()

In [None]:
# Model quantization via quantization-aware training (QAT) on the fused model.

# Quantized kernels come from the active engine; match the backend the qconfig
# below targets (qnnpack) when this build provides it.
if "qnnpack" in torch.backends.quantized.supported_engines:
    torch.backends.quantized.engine = "qnnpack"

quantized_model = QuantizedModel(model_fp32=fused_model)
# QAT needs the QAT qconfig (fake-quantize modules). get_default_qconfig returns
# plain post-training observers, which would leave training unaware of int8 rounding.
quantization_config = torch.quantization.get_default_qat_qconfig("qnnpack")
quantized_model.qconfig = quantization_config
# prepare_qat expects the model to be in training mode.
quantized_model.train()
torch.quantization.prepare_qat(quantized_model, inplace=True)

criterion = torch.nn.BCELoss()
# Optimize the QAT model's own parameters; it is the model trained below.
optimizer = torch.optim.Adam(
    quantized_model.parameters(),
    lr=learning_rate,
    weight_decay=l2_regularization_strength,
)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=3)

print("Training QAT Model...")

qat_history = train_loop(quantized_model, train_dataloader, val_dataloader, criterion, optimizer, scheduler, cuda_device, epochs, early_stopping_patience)  # to visualize
# Conversion to real int8 modules is done on CPU.
quantized_model.to(cpu_device)

quantized_model = torch.quantization.convert(quantized_model, inplace=True)

In [None]:
# Export the converted (int8) QAT model to ONNX.
quantized_model.eval()

# The dummy input only fixes the traced graph's per-sample shape; 224x224 matches
# the eval transforms in val_aug_args. The batch axis is made dynamic below.
dummy_input = torch.rand(1, 3, 224, 224)

# NOTE(review): ONNX export of eager-mode quantized models covers only a subset of
# quantized ops — confirm every op in this model exports at opset 12.
torch.onnx.export(
    quantized_model,
    dummy_input,
    quantized_model_filepath,
    input_names=["input"],
    output_names=["output"],
    dynamic_axes={"input": {0: "batch"}, "output": {0: "batch"}},
    opset_version=12,
)

In [None]:
# Load the ONNX graph written by the export cell and convert it to a
# TensorFlow SavedModel.
onnx_model = onnx.load(quantized_model_filepath)

tf_rep = prepare(onnx_model)

savedmodel_dir = os.path.join(model_dir, model_name + "_savedmodel")
tf_rep.export_graph(savedmodel_dir)

# Convert the SavedModel to TFLite.
converter = tf.lite.TFLiteConverter.from_saved_model(savedmodel_dir)

# Optional float16 weight compression (disabled):
# converter.optimizations = [tf.lite.Optimize.DEFAULT]
# converter.target_spec.supported_types = [tf.float16]

tflite_model = converter.convert()

# Save the flatbuffer next to the other artifacts of this model.
tflite_model_filepath = os.path.join(model_dir, model_name + "_quantized.tflite")
with open(tflite_model_filepath, "wb") as f:
    f.write(tflite_model)

In [None]:


# TODO: load the TFLite model
# TODO: evaluate it on the test split
# TODO: calculate & show sparsity and inference latency (fp32, int8 on CPU and CUDA)

# The code below is TorchScript-based, but we need TFLite (via ONNX?).
# NOTE(review): this string is dead reference code, not executed. Issues to fix
# before reviving it:
# - save_torchscript_model, load_torchscript_model and measure_inference_latency
#   are not imported in this notebook.
# - It evaluates `model` rather than the quantized model.
# - It uses input_size=(1,3,32,32), while this pipeline works on 224x224 images.
"""save_torchscript_model(model=quantized_model, model_dir=model_dir, model_filename=quantized_model_filename)

quantized_jit_model = load_torchscript_model(model_filepath=quantized_model_filepath, device=cpu_device)

test_metrics = evaluate(model, test_dataloader, cuda_device)
show_metrics(test_metrics)
sparsity = measure_global_sparsity(model, conv2d_use_mask=True)
print(f"Sparsity: {sparsity}")

fp32_cpu_inference_latency = measure_inference_latency(model=model, device=cpu_device, input_size=(1,3,32,32), num_samples=100)
int8_cpu_inference_latency = measure_inference_latency(model=quantized_model, device=cpu_device, input_size=(1,3,32,32), num_samples=100)
int8_jit_cpu_inference_latency = measure_inference_latency(model=quantized_jit_model, device=cpu_device, input_size=(1,3,32,32), num_samples=100)
fp32_gpu_inference_latency = measure_inference_latency(model=model, device=cuda_device, input_size=(1,3,32,32), num_samples=100)

print("FP32 CPU Inference Latency: {:.2f} ms / sample".format(fp32_cpu_inference_latency * 1000))
print("FP32 CUDA Inference Latency: {:.2f} ms / sample".format(fp32_gpu_inference_latency * 1000))
print("INT8 CPU Inference Latency: {:.2f} ms / sample".format(int8_cpu_inference_latency * 1000))
print("INT8 JIT CPU Inference Latency: {:.2f} ms / sample".format(int8_jit_cpu_inference_latency * 1000))"""

In [None]:
"""
Steps to QAT:
pretrained model -> transferring to CPU ->
-> fused model -> preparation (wrap in QuantizedModel, set the qconfig, prepare_qat) -> fake-quantized ("pseudo-quantized") model ->
-> training setup (loss, optimizer, scheduler) -> training the fake-quantized model -> transferring it to CPU ->
-> conversion (real quantization) -> saving the model

Steps to pruning:
Prune during training according to a pruning rule.

How to combine them?
Apply pruning while training the fake-quantized (pseudo-quantized) model.
"""

'\nSteps to QAT:\n'