In [None]:
from model.detector import Detector
from model.loss import DetectorLoss
import torch
import yaml
from dataset import Dataset
import cv2
import utils
import tqdm
import math
from model.evaluator import CocoDetectionEvaluator
import os

In [None]:
# Load the run configuration. A context manager guarantees the file handle is
# closed (the previous bare `open(...)` leaked it until garbage collection).
with open("config.yaml", "r", encoding="utf-8") as config_file:
    configuration = yaml.safe_load(config_file)

# Optional reproducibility: if config.yaml defines `seed`, seed torch before the
# model is built so weight initialisation is repeatable. Absent key -> no change.
seed = configuration.get("seed")
if seed is not None:
    torch.manual_seed(seed)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# NOTE(review): the second positional argument (True) is a Detector-specific
# flag — presumably "load pretrained weights"; confirm in model/detector.py.
model = Detector(len(configuration["classes"]), True).to(device)

In [None]:
# Build the augmented training split and the evaluation split from the same
# configuration, then report how many samples each one holds.
train_dataset = Dataset(configuration, augment=True)
test_dataset = Dataset(configuration, test=True)
print(*(len(split) for split in (train_dataset, test_dataset)))

In [None]:
# Sanity-check the data pipeline: draw the annotations of the first training
# sample and show it in an OpenCV window (waitKey(0) blocks until a key press).
x, y = train_dataset[0]
y = y.numpy()
# CHW tensor -> HWC array, then RGB -> BGR because OpenCV expects BGR order.
x = cv2.cvtColor(x.numpy().transpose((1, 2, 0)), cv2.COLOR_RGB2BGR)
# Drop column 0 of every label row (collate_fn reuses that column for the
# in-batch sample index) and keep the remaining per-box values.
y = [row[1:] for row in y]
x = utils.draw_annotation(x, y, configuration["classes"])
cv2.imshow("image", x)
cv2.waitKey(0)
cv2.destroyAllWindows()

In [None]:
# Training hyperparameters and optimisation objects.
epochs = configuration["epochs"]
batch_size = configuration["batch_size"]
loss = DetectorLoss(device)
learning_rate = configuration["learning_rate"]
# Previously hard-coded; now overridable from config.yaml. The defaults
# reproduce the original values exactly.
momentum = configuration.get("momentum", 0.949)
weight_decay = configuration.get("weight_decay", 0.0005)
optimizer = torch.optim.SGD(
    params=model.parameters(),
    lr=learning_rate,
    momentum=momentum,
    weight_decay=weight_decay,
)
# Multiply the LR by 0.1 each time the scheduler's step count reaches an entry
# in `milestones` (the scheduler is stepped once per epoch).
milestones = configuration["milestones"]
scheduler = torch.optim.lr_scheduler.MultiStepLR(
    optimizer, milestones=milestones, gamma=0.1
)

evaluator = CocoDetectionEvaluator(configuration["classes"], device)


def collate_fn(batch):
    """Merge (image, label) samples into one batch for the DataLoader.

    Images are stacked along a new leading batch dimension. Label tensors are
    concatenated row-wise after column 0 of each row is set to the index of
    the sample it came from, so boxes can be traced back to their image.

    The dataset's label tensors are cloned before column 0 is overwritten, so
    tensors owned by the dataset (e.g. cached labels) are never mutated.

    Args:
        batch: sequence of ``(image_tensor, label_tensor)`` pairs; images must
            share a shape so they can be stacked.

    Returns:
        ``(images, labels)`` where ``images = torch.stack(...)`` and ``labels``
        is the row-wise concatenation of the per-sample label tensors.
    """
    images, labels = zip(*batch)
    indexed_labels = []
    for sample_index, label in enumerate(labels):
        label = label.clone()
        if label.shape[0] > 0:
            label[:, 0] = sample_index
        indexed_labels.append(label)
    return torch.stack(images), torch.cat(indexed_labels, 0)


# Build the train/test loaders. `persistent_workers=True` is only valid when
# worker processes exist: with num_workers == 0 PyTorch raises ValueError, so
# enable it only when num_workers > 0.
num_workers = configuration["num_workers"]
use_persistent_workers = num_workers > 0
train_dataloader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
    persistent_workers=use_persistent_workers,
    collate_fn=collate_fn,
)
test_dataloader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=batch_size,
    num_workers=num_workers,
    persistent_workers=use_persistent_workers,
    collate_fn=collate_fn,
)

# Main training loop. `epoch` takes the values 0..epochs inclusive (epochs + 1
# passes), so the last epoch is evaluated too when it is a multiple of
# `eval_interval`.
eval_interval = configuration.get("eval_interval", 10)
# Number of optimizer steps covered by the LR warmup (5 epochs' worth).
# Loop-invariant, so it is computed once instead of on every batch.
warmup_num = 5 * len(train_dataloader)


def _as_display(value):
    """Return a compact string for a loss component shown in the progress bar.

    One-element tensors are rendered as a plain number instead of their full
    ``tensor(..., device=..., grad_fn=...)`` repr; other values use ``str()``.
    """
    if torch.is_tensor(value) and value.numel() == 1:
        return f"{value.item():.4f}"
    return str(value)


batch_num = 0
for epoch in range(epochs + 1):
    model.train()
    progress_bar = tqdm.tqdm(train_dataloader)
    for x, y in progress_bar:
        x = x.to(device)
        y = y.to(device)

        # Warmup: scale the LR by (batch_num / warmup_num) ** 4. Applied BEFORE
        # optimizer.step() so this step actually uses the warmed-up rate
        # (previously the update was made after stepping, so step 0 ran at
        # the full learning rate).
        for g in optimizer.param_groups:
            if batch_num <= warmup_num:
                g["lr"] = learning_rate * math.pow(batch_num / warmup_num, 4)
            lr = g["lr"]

        predictions = model(x)

        iou, obj, cls, total = loss(predictions, y)

        total.backward()
        optimizer.step()
        optimizer.zero_grad()

        info = (
            f"Epoch: {epoch}\tLR: {lr}\tIOU: {_as_display(iou)}"
            f"\tObj: {_as_display(obj)}\tCls: {_as_display(cls)}"
            f"\tTotal: {_as_display(total)}"
        )
        progress_bar.set_description(info)
        batch_num += 1

    if epoch % eval_interval == 0 and epoch > 0:
        model.eval()
        print("compute mAP...")
        # Evaluation needs no autograd graph; disabling it saves memory/time.
        with torch.no_grad():
            mAP05 = evaluator.compute_map(test_dataloader, model)
        os.makedirs("./checkpoints", exist_ok=True)
        # ':' is not a legal filename character on Windows, so use '_'.
        torch.save(
            model.state_dict(),
            f"./checkpoints/weight_AP05_{mAP05}_{epoch}-epoch.pt",
        )

    scheduler.step()