# Dependencies

In [None]:
from model.detector import Detector
from model.loss import DetectorLoss
import torch
import yaml
from dataset import Dataset, collate_fn
import tqdm
import math
from model.evaluator import CocoDetectionEvaluator
import os
import datetime

# Initialization

In [None]:
# Read the run configuration. The context manager closes the file handle
# deterministically instead of leaving it to the garbage collector.
with open("config.yaml", "r", encoding="utf-8") as config_file:
    configuration = yaml.safe_load(config_file)
epochs = configuration["epochs"]
batch_size = configuration["batch_size"]
learning_rate = configuration["learning_rate"]
milestones = configuration["milestones"]
input_size = configuration["input_size"]
num_workers = configuration["num_workers"]

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# NOTE(review): the meaning of Detector's second positional argument (True) is
# defined in model.detector — confirm what it toggles.
model = Detector(len(configuration["classes"]), True).to(device)
loss = DetectorLoss(device)
evaluator = CocoDetectionEvaluator(configuration["classes"], device)
optimizer = torch.optim.SGD(
    params=model.parameters(),
    lr=learning_rate,
    momentum=0.949,
    weight_decay=0.0005,
)
# Multiply the learning rate by 0.1 each time the epoch counter reaches one of
# the configured `milestones` (scheduler.step() is called once per epoch).
scheduler = torch.optim.lr_scheduler.MultiStepLR(
    optimizer, milestones=milestones, gamma=0.1
)
# Training state shared with the later cells: `epoch` is the first epoch to
# train, `batch` counts optimizer steps taken so far (drives the LR warmup).
epoch = 0
batch = 0

# Example input used only to trace the model for ONNX export (N, C, H, W).
dummy_input = torch.rand(1, 3, input_size, input_size).to(device)

# DataLoader raises ValueError for persistent_workers=True when num_workers is
# 0, so only keep workers alive when worker processes actually exist.
persistent_workers = num_workers > 0

train_dataset = Dataset(configuration, augment=True)
test_dataset = Dataset(configuration, test=True)
train_dataloader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
    persistent_workers=persistent_workers,
    collate_fn=collate_fn,
)
test_dataloader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=batch_size,
    num_workers=num_workers,
    persistent_workers=persistent_workers,
    collate_fn=collate_fn,
)

# Dataset summary

In [None]:
# print(len(train_dataset), len(test_dataset))
# train_dataset.show_distribution()
# test_dataset.show_distribution()
# train_dataset.show_sample()

# Load weights

In [None]:
# Checkpoint directory written by the training cell, i.e.
# "<checkpoints_path>/<epoch>_<mAP05>". Leave as None (or "") to train from scratch.
load = None

if load:
    # The metadata file holds a header line "epoch;mAP05" followed by
    # "<epoch>;<mAP05>"; training resumes at the epoch after the saved one.
    with open(f"{load}/metadata", "r") as file:
        epoch = int(file.readlines()[1].split(";")[0]) + 1
    # Restore the optimizer-step counter too. Otherwise the LR warmup in the
    # training cell restarts from 0 and overrides the restored scheduler's rate
    # (e.g. undoing milestone decays that already happened).
    batch = epoch * len(train_dataloader)
    # map_location lets a checkpoint saved on GPU be restored on a CPU-only host.
    model.load_state_dict(torch.load(f"{load}/weights.pt", map_location=device))
    optimizer.load_state_dict(torch.load(f"{load}/optimizer.pt", map_location=device))
    scheduler.load_state_dict(torch.load(f"{load}/scheduler.pt", map_location=device))

# Train

In [None]:
# Each run writes its checkpoints under its own timestamped directory.
checkpoints_path = (
    f"./checkpoints/{datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}"
)

# Number of optimizer steps (5 epochs' worth) over which the learning rate
# ramps from 0 up to `learning_rate`. Loop-invariant, so computed once.
warmup = 5 * len(train_dataloader)

# `epochs` is the index of the final epoch (inclusive), matching the absolute
# epoch numbers used by `milestones`. A resumed run therefore finishes the
# configured schedule instead of training `epochs + 1` additional epochs.
for epoch in range(epoch, epochs + 1):
    model.train()

    progress_bar = tqdm.tqdm(train_dataloader)
    for x, y in progress_bar:
        x = x.to(device)
        y = y.to(device)

        # Quartic LR warmup. It is applied BEFORE optimizer.step() so that this
        # step uses the warmed-up rate; setting it afterwards made the very
        # first step run at the full learning rate.
        if batch <= warmup:
            scale = math.pow(batch / warmup, 4)
            for g in optimizer.param_groups:
                g["lr"] = learning_rate * scale
        lr = optimizer.param_groups[-1]["lr"]

        predictions = model(x)

        iou, obj, cls, total = loss(predictions, y)

        total.backward()
        optimizer.step()
        optimizer.zero_grad()

        info = f"Epoch: {epoch}\t| LR: {lr}\t| IoU: {iou}\t| Obj: {obj}\t| Cls: {cls}\t| Total: {total}\t"
        progress_bar.set_description(info)
        batch += 1

    # Evaluate and checkpoint every 10th epoch (epoch 0 is skipped).
    if epoch % 10 == 0 and epoch > 0:
        model.eval()
        print("Compute mAP...")
        # Evaluation is inference only; no autograd graph is needed.
        with torch.no_grad():
            mAP05 = evaluator.compute_map(test_dataloader, model)
        checkpoint_dir = os.path.join(checkpoints_path, f"{epoch}_{mAP05}")
        os.makedirs(checkpoint_dir, exist_ok=True)
        torch.save(
            model.state_dict(),
            f"{checkpoint_dir}/weights.pt",
        )
        torch.save(
            optimizer.state_dict(),
            f"{checkpoint_dir}/optimizer.pt",
        )
        torch.save(
            scheduler.state_dict(),
            f"{checkpoint_dir}/scheduler.pt",
        )
        # dummy_input is (N, C, H, W): dim 2 is height, dim 3 is width.
        torch.onnx.export(
            model,
            dummy_input,
            f"{checkpoint_dir}/model.onnx",
            export_params=True,
            input_names=["input"],
            output_names=["output"],
            dynamic_axes={
                "input": {0: "batch", 2: "height", 3: "width"},
                "output": {0: "batch"},
            },
            do_constant_folding=True,
        )
        # Header line, then "<epoch>;<mAP05>" — the format the
        # "Load weights" cell parses when resuming.
        with open(f"{checkpoint_dir}/metadata", "w") as file:
            file.write(";".join(["epoch", "mAP05"]) + "\n")
            file.write(";".join([str(epoch), str(mAP05)]))
    scheduler.step()