# Dependencies

In [None]:
from model.detector import Detector
from model.loss import DetectorLoss
import torch
import matplotlib.pyplot as plt
from dataset import Dataset, collate_fn
import tqdm
import math
from model.evaluator import CocoDetectionEvaluator
import os
import datetime
import utils
import yaml

%matplotlib tk


# def get_default_qconfig():
#     qconfig = torch.ao.quantization.QConfig(
#         activation=torch.ao.quantization.FusedMovingAvgObsFakeQuantize.with_args(
#             observer=torch.ao.quantization.MovingAverageMinMaxObserver,
#             quant_min=0,
#             quant_max=255,
#             reduce_range=True,
#         ),
#         weight=torch.ao.quantization.default_fused_per_channel_wt_fake_quant,
#     )
#     return qconfig


# def get_custom_qconfig():
#     activation_observer = torch.quantization.MovingAverageMinMaxObserver.with_args(
#         quant_min=0,
#         quant_max=255,
#         dtype=torch.quint8,
#         qscheme=torch.per_tensor_symmetric 

#     )

#     weight_observer = torch.quantization.MovingAverageMinMaxObserver.with_args(
#         quant_min=0,
#         quant_max=255,
#         dtype=torch.quint8,
#         qscheme=torch.per_tensor_symmetric 
#     )

#     qconfig = torch.quantization.QConfig(
#         activation=torch.quantization.FusedMovingAvgObsFakeQuantize.with_args(
#             observer=activation_observer
#         ),
#         weight=torch.quantization.FusedMovingAvgObsFakeQuantize.with_args(
#             observer=weight_observer
#         )
#     )

#     return qconfig


# Initialization

In [None]:
# Configuration
configuration = utils.load_configuration("config.yaml")
epochs = configuration["epochs"]
batch_size = configuration["batch_size"]
learning_rate = configuration["learning_rate"]
milestones = configuration["milestones"]
input_size = configuration["input_size"]
num_workers = configuration["num_workers"]
pretrained = configuration["pretrained"]
checkpoint = configuration["checkpoint"]

# Training device selection
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Running on {device}")

# Loading model: build the 80-class network, optionally load pretrained
# weights, then resize the classification output to this dataset's classes.
model = Detector(80, True)
if pretrained is not None:
    # NOTE(review): the configured `pretrained` value only acts as an on/off
    # switch here; the weight file path is hard-coded. Confirm whether the
    # configured value was meant to be the path.
    # map_location lets weights saved from a GPU run load on a CPU-only host.
    model.load_state_dict(
        torch.load(
            "weights/weight_AP05_0.253207_280-epoch.pth", map_location=device
        )
    )
model.reshape_category_num(len(configuration["classes"]))
# Move to the device only after reshaping, so any layer rebuilt by
# reshape_category_num is placed on `device` as well.
model = model.to(device)

# Prepare model for Quantization Aware Training
# fusing_list = [
#     line.strip().split(";")
#     for line in open("fusing_list.csv", "r").readlines()
# ]
# torch.ao.quantization.fuse_modules_qat(
#     model,
#     fusing_list,
#     inplace=True,
# )
# model.qconfig = get_default_qconfig()
# torch.ao.quantization.prepare_qat(model, inplace=True)

# Training parameters
loss = DetectorLoss(device).to(device)
evaluator = CocoDetectionEvaluator(configuration["classes"], device)
optimizer = torch.optim.SGD(
    params=model.parameters(),
    lr=learning_rate,
    momentum=0.949,
    weight_decay=0.0005,
)
# Multiplies the LR by 0.1 each time the (per-epoch) scheduler step reaches
# one of the configured milestones.
scheduler = torch.optim.lr_scheduler.MultiStepLR(
    optimizer, milestones=milestones, gamma=0.1
)

# ONNX exportation dummy input, shape (1, 3, input_size, input_size)
dummy_input = torch.rand(1, 3, input_size, input_size).to(device)

# Loading dataset
train_dataset = Dataset(configuration, augment=True)
test_dataset = Dataset(configuration, test=True)
# DataLoader rejects persistent_workers=True when num_workers == 0, and pinned
# host memory only benefits copies to a CUDA device.
use_persistent_workers = num_workers > 0
use_pinned_memory = device.type == "cuda"
train_dataloader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
    persistent_workers=use_persistent_workers,
    collate_fn=collate_fn,
    pin_memory=use_pinned_memory,
)
test_dataloader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=batch_size,
    num_workers=num_workers,
    persistent_workers=use_persistent_workers,
    collate_fn=collate_fn,
    pin_memory=use_pinned_memory,
)

# Default values
# Number of batches (5 epochs' worth) over which the LR is warmed up.
warmup = 5 * len(train_dataloader)
epoch = 0
map05 = 0
map05_int8 = 0
batch = 0  # global batch counter across epochs, drives the warmup
x_epoch = []
y_lr = []
y_iou = []
y_obj = []
y_cls = []
y_total = []
y_map05 = []
y_map05_int8 = []

# Dataset preview

In [None]:
# print(len(train_dataset), len(test_dataset))
# train_dataset.show_distribution()
# test_dataset.show_distribution()
# train_dataset.show_sample()

# Load checkpoint

In [None]:
# Resume model / optimizer / scheduler state and the training history from the
# `checkpoint` directory (the layout written by the training loop: weights.pt,
# optimizer.pt, scheduler.pt, history.csv). Skipped when `pretrained` is set.
if checkpoint is not None and pretrained is None:
    print(f"Loading checkpoint from {checkpoint}")
    # map_location lets a checkpoint saved on GPU be resumed on a CPU-only host.
    model.load_state_dict(
        torch.load(os.path.join(checkpoint, "weights.pt"), map_location=device)
    )
    optimizer.load_state_dict(
        torch.load(os.path.join(checkpoint, "optimizer.pt"), map_location=device)
    )
    scheduler.load_state_dict(torch.load(os.path.join(checkpoint, "scheduler.pt")))

    # Same column order as the history.csv header written by the training loop.
    history_series = (
        x_epoch, y_lr, y_iou, y_obj, y_cls, y_total, y_map05, y_map05_int8
    )
    # Clear first so re-running this cell does not append duplicate rows.
    for series in history_series:
        series.clear()

    with open(os.path.join(checkpoint, "history.csv"), "r") as file:
        next(file, None)  # skip the header row
        for line in file:
            if not line.strip():
                continue  # tolerate blank lines (e.g. a trailing newline)
            fields = line.strip().split(";")
            x_epoch.append(int(fields[0]))
            for series, value in zip(history_series[1:], fields[1:8]):
                series.append(float(value))

    # Continue numbering after the last recorded epoch, if any was recorded.
    if x_epoch:
        epoch = x_epoch[-1] + 1
        map05 = y_map05[-1]
        map05_int8 = y_map05_int8[-1]
elif checkpoint is not None:
    print(f"Ignoring checkpoint {checkpoint} because pretrained weights are configured")

# Training loop

In [None]:
# Plotting initialization
plt.ion()
plt.rcParams["keymap.quit"].clear()
plt.rcParams["keymap.save"].clear()
plt.rcParams["toolbar"] = "None"
fig, ax = plt.subplots()
fig.canvas.manager.set_window_title("Training losses")
ax.set_title("Training losses")
ax.set_xlabel("Epoch")
ax.set_ylabel("Loss")
(lr_plot,) = ax.plot(x_epoch, y_lr, ".-", label="LR")
(iou_plot,) = ax.plot(x_epoch, y_iou, ".-", label="IoU")
(obj_plot,) = ax.plot(x_epoch, y_obj, ".-", label="Obj")
(cls_plot,) = ax.plot(x_epoch, y_cls, ".-", label="Cls")
(total_plot,) = ax.plot(x_epoch, y_total, ".-", label="Total")
(map05_plot,) = ax.plot(x_epoch, y_map05, ".-", label="mAP@0.5")
(map05_int8_plot,) = ax.plot(x_epoch, y_map05_int8, ".-", label="mAP@0.5 int8")
ax.legend()
plt.draw()
plt.pause(0.1)

# Evaluate mAP and write a checkpoint every `eval_interval` epochs (not epoch 0).
eval_interval = 10

# Checkpoints path: one timestamped directory per run
checkpoints_path = (
    f"./checkpoints/{datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}"
)

# Training loop.
# NOTE(review): the range includes both endpoints, so this runs `epochs + 1`
# epochs. Confirm the extra epoch is intended.
for epoch in range(epoch, epoch + epochs + 1):
    model.train()

    progress_bar = tqdm.tqdm(train_dataloader)
    for x, y in progress_bar:
        x = x.to(device, non_blocking=True)
        y = y.to(device, non_blocking=True)

        # LR warmup over the first `warmup` batches: scale the LR by
        # (batch / warmup) ** 4. This is set BEFORE optimizer.step() so that
        # this batch's update uses the warmed-up LR. Previously the very first
        # update ran at the full LR.
        if batch <= warmup:
            if batch == 0:
                target_lr = optimizer.param_groups[0]["lr"]
            optimizer.param_groups[0]["lr"] = target_lr * math.pow(batch / warmup, 4)
        lr = optimizer.param_groups[0]["lr"]

        predictions = model(x)
        iou, obj, cls, total = loss(predictions, y)

        total.backward()
        optimizer.step()
        optimizer.zero_grad()

        # float() gives readable numbers instead of full tensor reprs.
        progress_bar.set_description(
            f"Epoch: {epoch}\t| LR: {lr:.6g}\t| IoU: {float(iou):.4f}\t"
            f"| Obj: {float(obj):.4f}\t| Cls: {float(cls):.4f}\t"
            f"| Total: {float(total):.4f}\t"
        )
        batch += 1

    # Update plotting data (LR and losses of the epoch's last batch)
    x_epoch.append(epoch)
    y_lr.append(float(lr))
    y_iou.append(float(iou))
    y_obj.append(float(obj))
    y_cls.append(float(cls))
    y_total.append(float(total))

    is_eval_epoch = epoch > 0 and epoch % eval_interval == 0

    # mAP:05 evaluation
    if is_eval_epoch:
        print("Compute mAP...")
        model.eval()
        map05 = evaluator.compute_map(test_dataloader, model)
        # model_int8 = torch.ao.quantization.convert(model.eval())
        # model_int8.eval()
        # map05_int8 = evaluator.compute_map(test_dataloader, model_int8)

    # Update plotting. The mAP values are carried forward between evaluations.
    y_map05.append(float(map05))
    y_map05_int8.append(float(map05_int8))
    lr_plot.set_data(x_epoch, y_lr)
    iou_plot.set_data(x_epoch, y_iou)
    obj_plot.set_data(x_epoch, y_obj)
    cls_plot.set_data(x_epoch, y_cls)
    total_plot.set_data(x_epoch, y_total)
    map05_plot.set_data(x_epoch, y_map05)
    map05_int8_plot.set_data(x_epoch, y_map05_int8)
    ax.relim()
    ax.autoscale_view()
    plt.draw()
    plt.pause(0.1)

    # Saving weights and training history and exporting to ONNX
    if is_eval_epoch:
        checkpoint_dir = os.path.join(checkpoints_path, f"{epoch}_{map05}")
        os.makedirs(checkpoint_dir, exist_ok=True)
        torch.save(model.state_dict(), os.path.join(checkpoint_dir, "weights.pt"))
        torch.save(
            optimizer.state_dict(), os.path.join(checkpoint_dir, "optimizer.pt")
        )
        torch.save(
            scheduler.state_dict(), os.path.join(checkpoint_dir, "scheduler.pt")
        )
        # Export in eval mode so the graph traces inference behaviour
        # (the model is already in eval mode after evaluation; stated explicitly).
        model.eval()
        torch.onnx.export(
            model,
            dummy_input,
            os.path.join(checkpoint_dir, "model.onnx"),
            export_params=True,
            input_names=["input"],
            output_names=["output"],
            do_constant_folding=True,
        )
        # torch.onnx.export(
        #     model_int8,
        #     dummy_input,
        #     f"{checkpoint_dir}/model_int8.onnx",
        #     export_params=True,
        #     input_names=["input"],
        #     output_names=["output"],
        #     do_constant_folding=True,
        # )
        # history.csv is parsed back by the "Load checkpoint" cell; keep the
        # column order in sync with it.
        with open(os.path.join(checkpoint_dir, "history.csv"), "w") as file:
            file.write("epoch;lr;iou;obj;cls;total;map05;map05_int8\n")
            for row in zip(
                x_epoch, y_lr, y_iou, y_obj, y_cls, y_total, y_map05, y_map05_int8
            ):
                file.write(";".join(str(value) for value in row) + "\n")
        with open(os.path.join(checkpoint_dir, "config.yaml"), "w") as file:
            yaml.dump(configuration, file)
        fig.savefig(os.path.join(checkpoint_dir, "losses.png"))
    scheduler.step()