In [1]:
import os

import matplotlib.patches as patches
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm

import config
from dataset import YOLODataset
from loss import YoloLoss
from model import YOLOv3
from utils import (
    cells_to_bboxes,
    iou_width_height as iou,
    non_max_suppression as nms,
    plot_image,
    get_evaluation_bboxes,
    mean_average_precision,
    load_checkpoint,
    get_loaders,
    plot_couple_examples,
)

# Build the train / test / train-eval loaders from the CSV files that live
# under the configured dataset root.
train_loader, test_loader, train_eval_loader = get_loaders(
    f"{config.DATASET}/train.csv",
    f"{config.DATASET}/test.csv",
)


# Instantiate the network, move it to the configured device, and create the
# loss used by the training loop.
model = YOLOv3(num_classes=config.NUM_CLASSES)
model = model.to(config.DEVICE)
loss_fn = YoloLoss()

# Set up the training loop
def train_loop(model, optimizer, loss_fn, dataloader, device):
    """Train ``model`` for one pass over ``dataloader``.

    Args:
        model: Network to optimise; it is put into training mode first.
        optimizer: Optimizer whose gradients are reset and stepped per batch.
        loss_fn: Callable invoked as ``loss_fn(predictions, targets)``.
        dataloader: Iterable yielding ``(images, targets)`` pairs, where
            ``targets`` is an iterable of tensors.
        device: Device that images and targets are moved to.

    Returns:
        None. Progress is reported through the tqdm bar only.
    """
    model.train()
    progress = tqdm(dataloader, leave=True)
    batch_losses = []

    for images, targets in progress:
        images = images.to(device)
        targets = tuple(target.to(device) for target in targets)

        predictions = model(images)
        batch_loss = loss_fn(predictions, targets)
        batch_losses.append(batch_loss.item())

        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()

        # Show a running average over (at most) the 50 most recent batches.
        progress.set_postfix(loss=np.mean(batch_losses[-50:]))

# Set up the evaluation loop
def evaluate_model(model, dataloader, device):
    """Compute the mean average precision of ``model`` over ``dataloader``.

    The model is switched to eval mode and run without gradient tracking.
    The outputs and targets of every batch are gathered on the CPU,
    concatenated along the batch dimension separately for each output, and
    then passed image by image to the project helpers.

    Args:
        model: Callable network returning an iterable of tensors (one entry
            per output, each with the batch on dimension 0).
        dataloader: Iterable yielding ``(images, targets)`` pairs, where
            ``targets`` is an iterable of tensors.
        device: Device the images are moved to; also forwarded to the helpers.

    Returns:
        Whatever ``mean_average_precision`` returns for the collected
        prediction boxes and target boxes.

    Raises:
        ValueError: If ``dataloader`` yields no batches.
    """
    model.eval()
    pred_batches = []
    target_batches = []

    with torch.no_grad():
        for x, y in tqdm(dataloader, leave=False):
            out = model(x.to(device))
            pred_batches.append([item.to("cpu") for item in out])
            # Targets are only consumed on the CPU below, so they are not
            # sent to `device` first.
            target_batches.append([item.to("cpu") for item in y])

    if not pred_batches:
        raise ValueError("evaluate_model received an empty dataloader")

    # Each collected entry is a list of tensors, so torch.cat cannot be
    # applied to the outer list directly. zip(*...) regroups the data from
    # [batch][output] to [output][batch], which is then concatenated along the
    # batch dimension one output at a time.
    preds_per_scale = [torch.cat(chunks, dim=0) for chunks in zip(*pred_batches)]
    targets_per_scale = [torch.cat(chunks, dim=0) for chunks in zip(*target_batches)]

    num_images = preds_per_scale[0].shape[0]

    # NOTE(review): get_evaluation_bboxes and cells_to_bboxes are defined in
    # utils and are not visible here. Their keyword arguments are kept exactly
    # as originally written -- confirm they accept per-image inputs like these.
    boxes, targets = [], []
    for image_idx in range(num_images):
        boxes.append(get_evaluation_bboxes(
            [scale_preds[image_idx] for scale_preds in preds_per_scale],
            iou_threshold=config.IOU_THRESHOLD,
            anchors=config.ANCHORS,
            threshold=config.THRESHOLD,
            box_format="midpoint",
            device=device,
        ))
        targets.append(cells_to_bboxes(
            targets_per_scale[0][image_idx],
            anchors=config.ANCHORS,
            device=device,
        ))

    return mean_average_precision(boxes, targets)

# Set up the visualization
def visualize_predictions(model, dataloader, device):
    """Plot the NMS-filtered predictions for the first batch of ``dataloader``.

    Args:
        model: Callable network returning an indexable collection of 5-D
            tensors, one per output, with the batch on dimension 0.
        dataloader: Iterable yielding ``(images, targets)`` pairs; only the
            images of the first batch are used.
        device: Device the images are moved to; also passed to
            ``cells_to_bboxes``.

    Returns:
        None. One figure per image is produced through ``plot_image``.
    """
    model.eval()
    x, _ = next(iter(dataloader))
    x = x.to(device)
    num_images = x.shape[0]

    with torch.no_grad():
        out = model(x)
        bboxes = [[] for _ in range(num_images)]
        for scale_idx, scale_out in enumerate(out):
            batch_size, A, S, _, _ = scale_out.shape
            # Each output uses its own consecutive slice of the anchor list.
            first_anchor = config.ANCHORS_PER_SCALE * scale_idx
            anchors = config.ANCHORS[first_anchor: first_anchor + config.ANCHORS_PER_SCALE]
            boxes_scale_i = cells_to_bboxes(
                scale_out.reshape(batch_size, A, S, S, 5 + config.NUM_CLASSES),
                anchors,
                device,
            )
            for idx, box in enumerate(boxes_scale_i):
                bboxes[idx] += box

        for idx in range(num_images):
            # The helper is imported under the alias `nms`; the un-aliased
            # name `non_max_suppression` is not defined in this file.
            nms_boxes = nms(
                bboxes[idx],
                iou_threshold=config.IOU_THRESHOLD,
                threshold=config.THRESHOLD,
                box_format="midpoint",
            )
            # NOTE(review): .byte() truncates floats, so this assumes pixel
            # values are in the 0-255 range -- confirm against the dataset
            # transforms.
            plot_image(x[idx].permute(1, 2, 0).byte().cpu().numpy(), nms_boxes)

# Training pipeline
optimizer = optim.Adam(model.parameters(), lr=config.LEARNING_RATE)
# The scheduler is stepped with mAP, which is a higher-is-better metric, so it
# must run in "max" mode. The default "min" mode would lower the learning rate
# while mAP is still improving and never when it plateaus.
# NOTE(review): `verbose` is deprecated in recent PyTorch releases -- confirm
# it is still accepted by the installed version.
lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode="max", patience=3, verbose=True
)

# Create the checkpoint directory up front so the first torch.save call below
# does not fail when it is missing.
os.makedirs("checkpoints", exist_ok=True)

# Implement the training and evaluation pipeline
for epoch in range(config.NUM_EPOCHS):
    train_loop(model, optimizer, loss_fn, train_loader, config.DEVICE)

    # Evaluate every fifth epoch and let the scheduler react to the result
    if epoch % 5 == 0:
        mAP = evaluate_model(model, train_eval_loader, config.DEVICE)
        print(f"Epoch {epoch}, mAP: {mAP:.4f}")
        lr_scheduler.step(mAP)

    # Save model and optimizer state every tenth epoch
    if epoch % 10 == 0:
        checkpoint = {
            "state_dict": model.state_dict(),
            "optimizer": optimizer.state_dict(),
        }
        torch.save(checkpoint, f"checkpoints/yolov3_epoch_{epoch}.pth")

# Visualize the predictions
visualize_predictions(model, test_loader, config.DEVICE)


[(32, 3, 1), (64, 3, 2), ['B', 1], (128, 3, 2), ['B', 2], (256, 3, 2), ['B', 8], (512, 3, 2), ['B', 8], (1024, 3, 2), ['B', 4], (512, 1, 1), (1024, 3, 1), 'S', (256, 1, 1), 'U', (256, 1, 1), (512, 3, 1), 'S', (128, 1, 1), 'U', (128, 1, 1), (256, 3, 1), 'S']


  Expected `Union[float, tuple[float, float]]` but got `list` - serialized value may not be as expected
  Expected `Union[float, tuple[float, float]]` but got `list` - serialized value may not be as expected
  Expected `Union[float, tuple[float, float]]` but got `list` - serialized value may not be as expected
  Expected `Union[float, tuple[float, float]]` but got `list` - serialized value may not be as expected
  return self.__pydantic_serializer__.to_python(
  0%|          | 0/2 [32:10<?, ?it/s]


KeyboardInterrupt: 