In [1]:
import os  
import sys  

# Make the project's `src` package directory importable from this notebook.
# The project root is taken to be the parent of the current working directory.
project_root = os.path.abspath(os.pardir)
src_path = os.path.join(project_root, "src")

# Register the directory only once so re-running the cell stays idempotent.
if sys.path.count(src_path) == 0:
    sys.path.append(src_path)

In [2]:
import os
import copy
import torch
from torch.utils.data import DataLoader
from train import train_loop, evaluate
from utils import show_metrics, set_random_seeds, save_model, load_model
from models import setup_model
from quantization import QuantizedModel, fuse_model
from datasets import CatCamDataset
from pruning import measure_global_sparsity, iterative_pruning_finetuning, remove_parameters

In [3]:
# Global settings: random seed, device handles and artifact locations.
random_seed = 42
cuda_device = torch.device("cuda:0")
cpu_device = torch.device("cpu:0")

# Every artifact is named after the backbone and stored under `model_dir`.
model_dir = "/content/saved_models"
model_name = "efficientnet_b0"

# File names of the artifacts written by the individual pipeline stages.
model_filename = f"{model_name}.pt"
pruned_model_filename = f"{model_name}_pruned.pt"
quantized_model_filename = f"{model_name}_quantized.pt"
onnx_converted_model_filename = f"{model_name}.onnx"

# Matching full paths, all rooted at `model_dir`.
(
    model_filepath,
    pruned_model_filepath,
    quantized_model_filepath,
    onnx_converted_model_filepath,
) = (
    os.path.join(model_dir, artifact_filename)
    for artifact_filename in (
        model_filename,
        pruned_model_filename,
        quantized_model_filename,
        onnx_converted_model_filename,
    )
)

In [4]:
# Fix the random seed up front for reproducibility; which RNGs get seeded is
# decided by the project helper `utils.set_random_seeds`.
set_random_seeds(random_seed)

# Data preparation

In [None]:
# Dataset location and the batch size handed to the DataLoaders.
root_dir = "../data"
batch_size = 8

# Training-time augmentation settings, grouped by the kind of transform they
# configure. The groups are merged (in this order) into `train_aug_args`.
_geometric_aug = {
    "random_scale_limit": (-0.5, -0.2),
    "random_scale_prob": 0.6,
    "pad_if_needed": True,
    "pad_position": "random",
    "pad_border_mode": "reflect",
    "hflip_prob": 0.5,
    "rotation_degrees": 10,
    "rotation_prob": 0.4,
    "perspective_scale": 0.1,
    "perspective_prob": 0.2,
}

_color_aug = {
    "brightness_jitter": 0.2,
    "contrast_jitter": 0.15,
    "saturation_jitter": 0.2,
    "hue_jitter": 0.03,
    "color_jitter_prob": 0.7,
    "grayscale_prob": 0.03,
}

_degradation_aug = {
    "gaussian_blur_kernel": (3, 5),
    "gaussian_blur_sigma": (0.3, 1.2),
    "gaussian_blur_prob": 0.3,
    "gaussian_noise_var": (10.0, 50.0),
    "gaussian_noise_prob": 0.25,
    "motion_blur_kernel_size": (3, 5),
    "motion_blur_prob": 0.15,
    "jpeg_compression_quality": (40, 90),
    "jpeg_compression_prob": 0.3,
}

_occlusion_aug = {
    "coarse_dropout_max_holes": 6,
    "coarse_dropout_max_height": 30,
    "coarse_dropout_max_width": 30,
    "coarse_dropout_prob": 0.15,
}

# Full training configuration: input size, all augmentation groups, and the
# per-channel normalization statistics (the standard ImageNet mean/std).
train_aug_args = {
    "imgsz": 224,
    **_geometric_aug,
    **_color_aug,
    **_degradation_aug,
    **_occlusion_aug,
    "normalize_mean": [0.485, 0.456, 0.406],
    "normalize_std": [0.229, 0.224, 0.225],
}

# Validation/test configuration: same input size and normalization, with no
# augmentation entries at all.
val_aug_args = {
    "imgsz": 224,
    "normalize_mean": [0.485, 0.456, 0.406],
    "normalize_std": [0.229, 0.224, 0.225],
}


In [9]:
# One dataset per split. The training split (the dataset's default mode) gets
# the augmentation settings; "val" and "test" both use `val_aug_args`.
train_dataset = CatCamDataset(root_dir, train_aug_args)
val_dataset, test_dataset = (
    CatCamDataset(root_dir, val_aug_args, mode=split_name)
    for split_name in ("val", "test")
)

In [None]:
# Options common to all three loaders; only the training loader shuffles.
_loader_options = dict(
    batch_size=batch_size,
    num_workers=4,
    pin_memory=True,
    drop_last=False,
)

train_dataloader = DataLoader(train_dataset, shuffle=True, **_loader_options)
val_dataloader = DataLoader(val_dataset, shuffle=False, **_loader_options)
test_dataloader = DataLoader(test_dataset, shuffle=False, **_loader_options)

# Fine-tuning

In [None]:
# Build the network selected by `model_name` through the project's
# `models.setup_model` factory.
model = setup_model(model_name)

In [None]:
# Move the model to the CUDA device used by the evaluation, fine-tuning and
# pruning cells. As the last expression of the cell, the module is also displayed.
model.to(cuda_device)

In [None]:
# Baseline numbers before any fine-tuning in this notebook: validation metrics
# plus the global weight sparsity reported by the pruning helper.
val_metrics = evaluate(model, val_dataloader, cuda_device)
show_metrics(val_metrics)

sparsity = measure_global_sparsity(
    model, weight=True, conv2d_use_mask=True, linear_use_mask=True
)
print(f"Sparsity: {sparsity}")

In [None]:
# Fine-tuning hyperparameters.
epochs = 1  # forwarded to train_loop; also reused by the QAT stage further down
batch_size = 32  # NOTE(review): the DataLoaders above were already built with batch_size=8, so this reassignment does not affect them
l1_reg_strength = 0  # NOTE(review): not referenced by any other cell visible in this notebook
l2_reg_strength = 1e-4  # passed to Adam as weight_decay
lr = 1e-3  # Adam learning rate
lr_decay = 1  # NOTE(review): not referenced by any other cell visible in this notebook
early_stopping_patience = 7  # forwarded to train_loop

In [None]:
# Loss, optimizer and learning-rate schedule for the fine-tuning stage.
criterion = torch.nn.BCEWithLogitsLoss()

# L2 regularization is applied through Adam's weight decay.
optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=l2_reg_strength)

# Reduce the learning rate once the monitored value has stopped decreasing
# ('min' mode) for more than `patience` scheduler steps.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=3)

In [None]:
# Fine-tune the model on the GPU with the project's training loop.
# The returned value is kept so the run can be inspected or plotted later.
# NOTE(review): arguments are positional; their order must match the signature
# of `train.train_loop`, which is not visible from this notebook.
fine_tuning_history = train_loop(
    model,
    train_dataloader,
    val_dataloader,
    criterion,
    optimizer,
    scheduler,
    cuda_device,
    epochs,
    early_stopping_patience
)

TODO: visualize the fine-tuning process in terms of metrics & loss.

In [None]:
# Persist the fine-tuned model under `model_dir` using the file name configured
# at the top of the notebook (`model_filename`).
save_model(model, model_dir, model_filename=model_filename)

# Pruning

In [None]:
# Pruning hyperparameters; every value below is forwarded to
# `iterative_pruning_finetuning` in the next cell.
prune_lr = 1e-3  # learning rate for the fine-tuning done between pruning rounds
prune_l1_reg_strength = 1e-3  # L1 regularization strength
prune_l2_reg_strength = 1e-3  # L2 regularization strength
conv2d_prune_amount = 0.7  # pruning amount for Conv2d layers (exact semantics defined by the pruning helper)
linear_prune_amount = 0.7  # pruning amount for Linear layers (exact semantics defined by the pruning helper)
num_iterations = 3  # number of prune + fine-tune rounds
num_epochs_per_iteration =  5  # fine-tuning epochs in each round
grouped_pruning = True  # NOTE(review): presumably selects global/grouped pruning over per-layer pruning; confirm in pruning.py

In [None]:
# Iterative pruning with fine-tuning between the pruning rounds.
print("Iterative pruning + Fine-Tuning...")

# Prune a copy so the fine-tuned `model` stays available unchanged.
pruned_model = copy.deepcopy(model).to(cuda_device)

prune_history = iterative_pruning_finetuning(
        model=pruned_model,
        train_dataloader=train_dataloader,
        # Monitor on the validation split. Passing the test split here would leak
        # it into the pruning/fine-tuning loop and make any later test-set
        # evaluation optimistic.
        val_dataloader=val_dataloader,
        device=cuda_device,
        learning_rate=prune_lr,
        l1_regularization_strength=prune_l1_reg_strength,
        l2_regularization_strength=prune_l2_reg_strength,
        conv2d_prune_amount=conv2d_prune_amount,
        linear_prune_amount=linear_prune_amount,
        num_iterations=num_iterations,
        num_epochs_per_iteration=num_epochs_per_iteration,
        model_dir=model_dir,
        grouped_pruning=grouped_pruning)

TODO: visualize the pruning process in terms of metrics & loss.

In [None]:
# Validation metrics and mask-based weight sparsity right after iterative
# pruning, i.e. before `remove_parameters` is applied to the model.
val_metrics = evaluate(pruned_model, val_dataloader, cuda_device)
show_metrics(val_metrics)

sparsity = measure_global_sparsity(
    pruned_model,
    weight=True,
    bias=False,
    conv2d_use_mask=True,
    linear_use_mask=True,
)
print(f"Sparsity: {sparsity}")

In [None]:
# Finalize pruning through the project's `pruning.remove_parameters` helper.
# NOTE(review): presumably this strips the pruning masks/re-parametrization and
# bakes the zeros into the weights; confirm in pruning.py.
pruned_model = remove_parameters(pruned_model)

In [None]:
# Same check as before, now on the model returned by `remove_parameters`; the
# metrics are expected to match the previous cell.
# NOTE(review): the sparsity is still measured with the *_use_mask=True flags.
# If `remove_parameters` drops the mask buffers, confirm that
# `measure_global_sparsity` still reports a meaningful value here.
val_metrics = evaluate(pruned_model, val_dataloader, cuda_device)
show_metrics(val_metrics)

sparsity = measure_global_sparsity(
    pruned_model,
    weight=True,
    bias=False,
    conv2d_use_mask=True,
    linear_use_mask=True,
)
print(f"Sparsity: {sparsity}")

In [None]:
# Persist the pruned model next to the fine-tuned one, under its own file name
# (`pruned_model_filename`).
save_model(pruned_model, model_dir, model_filename=pruned_model_filename)

# Quantization

In [None]:
# Hyperparameters of the quantization-aware training (QAT) stage.
quantization_lr = 1e-3  # Adam learning rate used during QAT
quantization_wd = 4e-5  # Adam weight decay used during QAT

In [None]:
# Fuse the *pruned* network so that quantization builds on the pruning result.
# The previous version fused the un-pruned `model`, which silently discarded the
# whole pruning stage.
pruned_model.to(cpu_device)

# Switch to eval mode before fusing: PyTorch's post-training module fusion
# (e.g. Conv+BN folding) is only defined for modules in eval mode.
pruned_model.eval()
fused_model = fuse_model(pruned_model, model_name)

fused_model.eval()

In [None]:
# Quantization-aware training (QAT) of the fused model.
quantized_model = QuantizedModel(model_fp32=fused_model.to(cuda_device))

# prepare_qat needs a QAT qconfig, i.e. one built from FakeQuantize modules.
# get_default_qconfig returns the post-training configuration that only
# attaches observers, so the training below would never see quantization error.
quantization_config = torch.quantization.get_default_qat_qconfig("qnnpack") #or fbgemm
quantized_model.qconfig = quantization_config
torch.quantization.prepare_qat(quantized_model, inplace=True)

# Warm-up passes that give the observers initial ranges before training.
# NOTE(review): uniform random inputs do not follow the normalized image
# distribution; consider feeding real training batches instead.
with torch.no_grad():
    for _ in range(15):
        calibr = torch.rand((8, 3, 224, 224))
        quantized_model(calibr.to(cuda_device))

# The QAT fine-tuning itself runs on the CPU, in training mode.
quantized_model.to(cpu_device).train()
criterion = torch.nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(
        quantized_model.parameters(),
        lr=quantization_lr,
        weight_decay=quantization_wd
    )
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=3)

print("Training QAT Model...")

# The returned history is kept for later visualization.
qat_history = train_loop(
    quantized_model,
    train_dataloader,
    val_dataloader,
    criterion,
    optimizer,
    scheduler,
    cpu_device,
    epochs,
    early_stopping_patience)

In [None]:
# Export the QAT-prepared (not yet converted) model to ONNX, on the CPU and in
# eval mode.
quantized_model.eval().to(cpu_device)
dummy_input = torch.randn(1, 3, 224, 224)

# torch.onnx.export does not create missing directories, so make sure the
# target folder exists before writing the file.
os.makedirs(os.path.dirname(onnx_converted_model_filepath), exist_ok=True)

with torch.no_grad():
    torch.onnx.export(
        quantized_model,
        dummy_input,
        onnx_converted_model_filepath,
        opset_version=13,
        input_names=["input"],
        output_names=["output"],
        # Only the batch dimension is dynamic; the spatial size stays fixed at 224x224.
        dynamic_axes={"input": {0: "batch"}, "output": {0: "batch"}},
        do_constant_folding=True,
        training=torch.onnx.TrainingMode.EVAL,
        keep_initializers_as_inputs=True,
    )

In [None]:
# Run this only after the ONNX export: the export cell above works on the
# QAT-prepared model, while `convert` replaces its modules in place with their
# quantized counterparts.
quantized_model = torch.quantization.convert(quantized_model, inplace=True)

In [None]:
# Persist the converted (quantized) model under `quantized_model_filename`.
save_model(quantized_model, model_dir, model_filename=quantized_model_filename)

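To convert the exported ONNX model to TFLite, run from a shell: `onnx2tf -i model_name.onnx -o model_tflite` (replace `model_name.onnx` with the exported file, e.g. `efficientnet_b0.onnx`).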