In [None]:
# Mount Google Drive at /content/drive so the project files stored under
# MyDrive are reachable from this runtime (google.colab is only importable on Colab).
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Copy the dataset from Google Drive onto the runtime's local disk so that
# training reads from local storage instead of the mounted Drive.
!pwd
!mkdir -p /content/cs7643_model_quantization/data
# Label directories (trainId maps and their colorized versions).
!cp -r "/content/drive/MyDrive/cs7643_model_quantization/data/gtFine_trainId" "/content/cs7643_model_quantization/data"
!cp -r "/content/drive/MyDrive/cs7643_model_quantization/data/gtFine_trainIdColorized" "/content/cs7643_model_quantization/data"
# Image archive: copied as a single zip, then extracted locally (-o overwrites existing files without prompting).
!cp -r "/content/drive/MyDrive/cs7643_model_quantization/data/leftImg8bit_trainvaltest.zip" "/content/cs7643_model_quantization/data"
!unzip -o /content/cs7643_model_quantization/data/leftImg8bit_trainvaltest.zip -d /content/cs7643_model_quantization/data/leftImg8bit_trainvaltest
# Sync the rest of the project (code, configs, ...), skipping the data directory handled above,
# hidden files, the scratch notebook and the virtual environment.
!rsync -av --exclude='data' --exclude='.*' --exclude='testing.ipynb' --exclude='myvenv' /content/drive/MyDrive/cs7643_model_quantization/ /content/cs7643_model_quantization/


In [None]:
# Make the local project copy the working directory: the `python -m ...` calls and the
# relative "data/...", "models/..." and "results" paths used below resolve against it.
%cd /content/cs7643_model_quantization/
!pwd

In [None]:
# Run the project's pipeline.fine_tuning module as a script.
# NOTE(review): presumably this produces the baseline checkpoint consumed by the QAT cells — confirm.
!python -m pipeline.fine_tuning

In [None]:
# Run the project's standalone QAT module (src.quantization.qat) as a script.
# The hyperparameter sweep over QAT settings is done separately by run_qat() in the next cell.
!python -m src.quantization.qat

In [None]:
import os
import glob
import yaml
from PIL import Image
from tqdm import tqdm

import numpy as np
import itertools
import json
import copy

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import torch.ao.quantization as tq
import torch.ao.quantization.quantize_fx as quantize_fx
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T
from torchvision.transforms import functional as F

from src.models.deeplabv3_mnv3 import get_empty_model, load_model
from src.quantization.quantization_utils import set_seed, build_qconfig
from pipeline.create_dataset import cityScapesDataset
from pipeline.metrics import calculate_miou


def plot_loss(train_losses, val_losses):
    """Draw per-epoch training and validation loss curves on a new figure.

    Args:
        train_losses: Sequence with one training-loss value per epoch.
        val_losses: Sequence with one validation-loss value per epoch.

    Returns:
        The ``(fig, ax)`` pair created by ``plt.subplots()``.
    """
    fig, ax = plt.subplots()

    # Epochs are numbered from 1; the x-axis length follows the training series.
    epoch_axis = range(1, len(train_losses) + 1)
    for series, tag in ((train_losses, 'train'), (val_losses, 'val')):
        ax.plot(epoch_axis, series, marker='o', label=tag)

    ax.set_title('Training and Validation Loss')
    ax.set_xlabel('Epochs')
    ax.set_ylabel('Loss')
    ax.grid()
    ax.legend(loc='upper right')
    return fig, ax

def plot_miou(mious):
    """Draw the per-epoch validation mIOU curve on a new figure.

    Args:
        mious: Sequence with one mIOU value per epoch.

    Returns:
        The ``(fig, ax)`` pair created by ``plt.subplots()``.
    """
    fig, ax = plt.subplots()

    # Epochs are numbered from 1.
    epoch_axis = range(1, len(mious) + 1)
    ax.plot(epoch_axis, mious, marker='o')

    ax.set_title('Validation mIOU')
    ax.set_xlabel('Epochs')
    ax.set_ylabel('mIOU')
    ax.grid()
    return fig, ax

def _append_json_record(json_path, record):
    """Append ``record`` to the JSON list stored at ``json_path`` (created if missing)."""
    if os.path.exists(json_path):
        with open(json_path, "r") as f:
            records = json.load(f)
    else:
        records = []

    records.append(record)

    with open(json_path, "w") as f:
        json.dump(records, f, indent=2)


def _save_and_close(fig, path):
    """Write ``fig`` to ``path`` and close it so figures do not pile up across sweep runs."""
    fig.savefig(path)
    plt.close(fig)


def run_qat(idx, config, results_dir=None):
    """Run one quantization-aware-training experiment and record its results.

    Steps: load the checkpoint named in ``config``, prepare it for FX-graph QAT,
    optionally run a few calibration batches, train/validate for the configured
    number of epochs, convert the model, then append the metrics to
    ``<results_dir>/qat_results.json`` and save the loss / mIOU plots.

    Args:
        idx: Identifier of this run; stored in the JSON record and used in the
            plot file names (``qat_loss_{idx}.png``, ``miou_{idx}.png``).
        config: Experiment configuration. Keys read here: ``model_checkpoint``,
            ``mode``, ``calibration`` (``enabled``, ``steps``) and ``training``
            (``epochs``, ``batch_size``, ``learning_rate``, ``weight_decay``,
            ``train_transforms``, ``val_transforms``). The whole dict is also
            passed to ``build_qconfig`` and written to the JSON record, so it
            must be JSON-serializable.
        results_dir: Directory for the JSON record and plots. When ``None`` the
            module-level ``results_dir`` is used if it exists (the behaviour
            earlier callers relied on), otherwise ``"results"``.

    Returns:
        None. Results are written to disk.
    """
    if results_dir is None:
        # Backward compatibility: this function used to read a module-level global.
        results_dir = globals().get("results_dir", "results")
    # Create the output directory up front so a bad path fails before the long training run.
    os.makedirs(results_dir, exist_ok=True)

    set_seed()
    print("--- Running QAT Script ---")
    print("Loading Configuration ...")

    device = 'mps' if torch.backends.mps.is_available() else ('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using Device: {device}")

    print(f"Loading Model from checkpoint: {config['model_checkpoint']} ...")
    # Get empty model and load checkpoint weights.
    model = get_empty_model()
    model = load_model(model, config["model_checkpoint"], device=device)
    model.eval()
    print(f"Baseline Model Size (MB): {os.path.getsize(config['model_checkpoint']) / 1e6:.2f}")

    # Get Quantization Configuration
    print(f"Building QConfig with mode: {config['mode']}...")
    qconfig_mapping = build_qconfig("qat", config)
    print(f"qconfig_mapping global config: {qconfig_mapping.global_qconfig}")
    print(f"qconfig_mapping weight config: {qconfig_mapping.global_qconfig.weight}")
    print(f"qconfig_mapping activation config: {qconfig_mapping.global_qconfig.activation}")

    # Paths are relative to the current working directory (the project root).
    train_img_path = "data/leftImg8bit_trainvaltest/leftImg8bit/train"
    train_label_path = "data/gtFine_trainId/gtFine/train"
    val_img_path = "data/leftImg8bit_trainvaltest/leftImg8bit/val"
    val_label_path = "data/gtFine_trainId/gtFine/val"

    # Calibration draws from the training split with the training transforms.
    cal_dataset = cityScapesDataset(train_img_path, train_label_path, config['training']['train_transforms'])
    train_dataset = cityScapesDataset(train_img_path, train_label_path, config['training']['train_transforms'])
    val_dataset = cityScapesDataset(val_img_path, val_label_path, config['training']['val_transforms'])

    cal_dataloader = DataLoader(cal_dataset, batch_size=2, shuffle=True)
    train_dataloader = DataLoader(train_dataset, batch_size=config['training']['batch_size'], shuffle=True)
    val_dataloader = DataLoader(val_dataset, batch_size=config['training']['batch_size'], shuffle=False)

    # One batch is needed as example input for FX tracing.
    # NOTE(review): the model is in eval() mode when passed to prepare_qat_fx; the PyTorch
    # examples prepare a model that is in train() mode — confirm this is intended.
    example_image, _ = next(iter(cal_dataloader))
    example_image = example_image.to(device)
    prepared_model = quantize_fx.prepare_qat_fx(model, qconfig_mapping, (example_image,))
    prepared_model = prepared_model.to(device)
    prepared_model.eval()

    # Label value 255 is excluded from the loss.
    loss_function = nn.CrossEntropyLoss(ignore_index=255)
    optimizer = optim.Adam(prepared_model.parameters(), lr=float(config['training']['learning_rate']),
                                                        weight_decay=float(config['training']['weight_decay']))

    if config['calibration']['enabled']:
        print("Start Calibration...")
        # Forward-only passes through the prepared model before training starts.
        with torch.no_grad():
            for i, (image, _) in enumerate(cal_dataloader):
                image = image.to(device, non_blocking=True)

                prepared_model(image)
                if (i % 10 == 0 and i > 0):
                    print(f"  Calibrated {i} batches ...")

                if (i >= config['calibration']['steps'] - 1):
                    print(f"  Completed {config['calibration']['steps']} calibration steps.")
                    break

    print("Starting QAT...")
    train_losses = []
    val_losses = []
    val_mious = []

    epochs = config['training']['epochs']
    for epoch in range(epochs):
        prepared_model.train()
        training_loss = 0
        for image, label in tqdm(train_dataloader, desc=f"Training Epoch {epoch}"):
            image = image.to(device, non_blocking=True)
            label = label.to(device, non_blocking=True)

            optimizer.zero_grad()
            out = prepared_model(image)['out']
            loss = loss_function(out, label)
            training_loss += loss.item()
            loss.backward()
            optimizer.step()

        prepared_model.eval()
        validation_loss = 0
        val_miou = 0
        with torch.no_grad():
            for image, label in tqdm(val_dataloader, desc=f"Validation Epoch {epoch}"):
                image = image.to(device, non_blocking=True)
                label = label.to(device, non_blocking=True)

                out = prepared_model(image)['out']
                pred = out.argmax(dim=1)
                loss = loss_function(out, label)
                validation_loss += loss.item()
                # float() keeps the accumulator a plain Python number so the results stay
                # JSON-serializable even if calculate_miou returns a tensor / NumPy scalar
                # (its return type is not visible here).
                val_miou += float(calculate_miou(pred, label))

        # Per-batch averages (each batch weighted equally).
        average_training_loss = training_loss / len(train_dataloader)
        average_validation_loss = validation_loss / len(val_dataloader)
        average_val_miou = val_miou / len(val_dataloader)

        train_losses.append(average_training_loss)
        val_losses.append(average_validation_loss)
        val_mious.append(average_val_miou)

        print(f"Epoch: {epoch}, Training Loss: {average_training_loss}, Validation Loss: {average_validation_loss}, mIOU: {average_val_miou}")

    print("Convert QAT model ...")
    # must move model to CPU to convert, else it errors!
    prepared_model = prepared_model.cpu()
    # The converted model is currently not saved or evaluated (see the disabled block below);
    # the call still verifies that conversion succeeds for this configuration.
    quantized_model = quantize_fx.convert_fx(prepared_model.eval())

    # print("Saving QAT Model ...")
    # torch.save(quantized_model.state_dict(), f"models/qat_quantized_model{idx}.pth")
    # print(f"QAT Model Size (MB): {os.path.getsize(f'models/qat_quantized_model{idx}.pth') / 1e6:.2f}")

    # Collect the metrics of this run; per-epoch series are rounded to 2 decimals,
    # the final values are stored unrounded (None when no epoch was run).
    result = {
      "idx": idx,
      "config": config,
      "train_losses": [round(loss, 2) for loss in train_losses],
      "val_losses": [round(loss, 2) for loss in val_losses],
      "val_mious": [round(miou, 2) for miou in val_mious],
      "final_train_loss": train_losses[-1] if train_losses else None,
      "final_val_loss": val_losses[-1] if val_losses else None,
      "final_val_miou": val_mious[-1] if val_mious else None
    }

    _append_json_record(os.path.join(results_dir, "qat_results.json"), result)

    # Save the plots and close each figure right away: the sweep calls this function
    # many times and unclosed figures would otherwise accumulate in memory.
    fig, ax = plot_loss(train_losses, val_losses)
    _save_and_close(fig, os.path.join(results_dir, f"qat_loss_{idx}.png"))

    fig, ax = plot_miou(val_mious)
    _save_and_close(fig, os.path.join(results_dir, f"miou_{idx}.png"))

### start here... ##################################################################################
# Directory that receives qat_results.json and the per-run plots.
results_dir = "results"
os.makedirs(results_dir, exist_ok=True)

# Base configuration for the sweep. The list-valued entries ("mode",
# activations.observer, training.batch_size, training.learning_rate, "skip_aspp")
# are sweep axes: each run receives a copy with one concrete value per axis.
qat_config = {
    "model_checkpoint": "models/baseline_init_model.pth",
    "mode": ["int8", "int6", "int4"],

    "weights": {
        "dtype": "qint8",              # Keep constant.
        "granularity": "per_channel",  # "per_channel" or "per_tensor"
    },

    "activations": {
        "dtype": "quint8",             # Keep constant.
        "granularity": "per_tensor",   # Keep constant
        "observer": ["histogram", "minmax"],
    },

    "calibration": {
        "enabled": True,
        "steps": 10,
    },

    "training": {
        "epochs": 5,
        "batch_size": [8],
        "learning_rate": [1e-4, 1e-3],
        "weight_decay": 1e-5,
        "train_transforms": {"crop": True, "resize": True, "flip": True},
        "val_transforms": {"crop": False, "resize": False, "flip": False},
    },
    "skip_aspp": [True, False],
}

# Pull the sweep axes out of the base configuration.
modes = qat_config["mode"]
act_observers = qat_config["activations"]["observer"]
training_batch_sizes = qat_config["training"]["batch_size"]
training_learning_rates = qat_config["training"]["learning_rate"]
skip_aspps = qat_config["skip_aspp"]

sweep_axes = (modes, act_observers, training_batch_sizes, training_learning_rates, skip_aspps)

# One run per element of the Cartesian product of all axes; the run index
# doubles as the identifier stored with the results.
for i, combination in enumerate(itertools.product(*sweep_axes)):
    mode, act_observer, training_batch_size, training_learning_rate, skip_aspp = combination

    # Deep copy so the nested dicts of the base configuration are never modified.
    current_config = copy.deepcopy(qat_config)
    current_config["mode"] = mode
    current_config["skip_aspp"] = skip_aspp
    current_config["activations"]["observer"] = act_observer
    current_config["training"].update(
        batch_size=training_batch_size,
        learning_rate=training_learning_rate,
    )

    print(f"Run {i} with config: {current_config}")

    run_qat(i, current_config)


