### TODO
* implement intersection of union (done)
* non-max suppression
  * takes in bounding boxes and two threshold values
  * returns bounding boxes after filtering
* mean average precision
* Load data
* Train
* Test

### 1. Setup

In [1]:
# https://stackoverflow.com/questions/1254370/reimport-a-module-while-interactive
%load_ext autoreload
%autoreload 2

import torch
from torch import nn as nn
from torch.optim.lr_scheduler import StepLR
from torchvision import transforms
from utils import (
    intersection_over_union,
    get_bboxes,
    non_max_suppression,
    mean_average_precision,
)
import numpy as np
import tqdm

  from .autonotebook import tqdm as notebook_tqdm


### 2. Define Loss Function

In [2]:
"""
Implementation of Yolo Loss Function from the original yolo paper
"""

import torch
import torch.nn as nn
from utils import intersection_over_union


class YoloLoss(nn.Module):
    """
    Calculate the loss for yolo (v1) model
    """

    def __init__(self, S=7, B=2, C=20):
        super(YoloLoss, self).__init__()
        self.mse = nn.MSELoss(reduction="sum")

        """
        S is split size of image (in paper 7),
        B is number of boxes (in paper 2),
        C is number of classes (in paper and VOC dataset is 20),
        """
        self.S = S
        self.B = B
        self.C = C

        # These are from Yolo paper, signifying how much we should
        # pay loss for no object (noobj) and the box coordinates (coord)
        self.lambda_noobj = 0.5
        self.lambda_coord = 5

    def forward(self, predictions, target):
        # predictions are shaped (BATCH_SIZE, S*S(C+B*5) when inputted
        predictions = predictions.reshape(-1, self.S, self.S, self.C + self.B * 5)

        # Calculate IoU for the two predicted bounding boxes with target bbox
        iou_b1 = intersection_over_union(predictions[..., 21:25], target[..., 21:25])
        iou_b2 = intersection_over_union(predictions[..., 26:30], target[..., 21:25])
        ious = torch.cat([iou_b1.unsqueeze(0), iou_b2.unsqueeze(0)], dim=0)

        # Take the box with highest IoU out of the two prediction
        # Note that bestbox will be indices of 0, 1 for which bbox was best
        iou_maxes, bestbox = torch.max(ious, dim=0)
        exists_box = target[..., 20].unsqueeze(3)  # in paper this is Iobj_i

        # ======================== #
        #   FOR BOX COORDINATES    #
        # ======================== #

        # Set boxes with no object in them to 0. We only take out one of the two 
        # predictions, which is the one with highest Iou calculated previously.
        box_predictions = exists_box * (
            (
                bestbox * predictions[..., 26:30]
                + (1 - bestbox) * predictions[..., 21:25]
            )
        )

        box_targets = exists_box * target[..., 21:25]

        # Take sqrt of width, height of boxes to ensure that
        box_predictions[..., 2:4] = torch.sign(box_predictions[..., 2:4]) * torch.sqrt(
            torch.abs(box_predictions[..., 2:4] + 1e-6)
        )
        box_targets[..., 2:4] = torch.sqrt(box_targets[..., 2:4])

        box_loss = self.mse(
            torch.flatten(box_predictions, end_dim=-2),
            torch.flatten(box_targets, end_dim=-2),
        )

        # ==================== #
        #   FOR OBJECT LOSS    #
        # ==================== #

        # pred_box is the confidence score for the bbox with highest IoU
        pred_box = (
            bestbox * predictions[..., 25:26] + (1 - bestbox) * predictions[..., 20:21]
        )

        object_loss = self.mse(
            torch.flatten(exists_box * pred_box),
            torch.flatten(exists_box * target[..., 20:21]),
        )

        # ======================= #
        #   FOR NO OBJECT LOSS    #
        # ======================= #

        #max_no_obj = torch.max(predictions[..., 20:21], predictions[..., 25:26])
        #no_object_loss = self.mse(
        #    torch.flatten((1 - exists_box) * max_no_obj, start_dim=1),
        #    torch.flatten((1 - exists_box) * target[..., 20:21], start_dim=1),
        #)

        no_object_loss = self.mse(
            torch.flatten((1 - exists_box) * predictions[..., 20:21], start_dim=1),
            torch.flatten((1 - exists_box) * target[..., 20:21], start_dim=1),
        )

        no_object_loss += self.mse(
            torch.flatten((1 - exists_box) * predictions[..., 25:26], start_dim=1),
            torch.flatten((1 - exists_box) * target[..., 20:21], start_dim=1)
        )

        # ================== #
        #   FOR CLASS LOSS   #
        # ================== #

        class_loss = self.mse(
            torch.flatten(exists_box * predictions[..., :20], end_dim=-2,),
            torch.flatten(exists_box * target[..., :20], end_dim=-2,),
        )

        loss = (
            self.lambda_coord * box_loss  # first two rows in paper
            + object_loss  # third row in paper
            + self.lambda_noobj * no_object_loss  # forth row
            + class_loss  # fifth row
        )

        return loss

### 3. Define Model

In [3]:
model_architectures = {
    # https://arxiv.org/pdf/1506.02640.pdf
    "yolov1": [
        # (kernel_width, kernel_height filters, stride)
        (7, 64, 2, 3),
        # maxpooling
        "M",
        (3, 192, 1, 1),
        "M",
        (1, 128, 1, 0),
        (3, 256, 1, 1),
        (1, 256, 1, 0),
        (3, 512, 1, 1),
        "M",
        # repeats
        [(1, 256, 1, 0), (3, 512, 1, 1), 4],
        (1, 512, 1, 0),
        (3, 1024, 1, 1),
        "M",
        [(1, 512, 1, 0), (3, 1024, 1, 1), 2],
        (3, 1024, 1, 1),
        (3, 1024, 2, 1),
        (3, 1024, 1, 1),
        (3, 1024, 1, 1),
    ]
}

class CNNBlock(nn.Module):
    def __init__(self, in_channals, out_channels, **kwargs):
        super(CNNBlock, self).__init__()
        self.conv = nn.Conv2d(in_channels=in_channals, out_channels=out_channels, bias=False, **kwargs)
        self.batch_norm = nn.BatchNorm2d(out_channels)
        self.leaky_relu = nn.LeakyReLU(0.1)

    def forward(self, x):
        return self.leaky_relu(self.batch_norm(self.conv(x)))

class YoloV1(nn.Module):
    def __init__(self, model_configuration, in_channels=3, **kwargs):
        super(YoloV1, self).__init__()
        self.in_channels = in_channels
        self.conv_layers = self._create_conv_layers(model_configuration)
        self.fc_layers = self._create_fc_layers(**kwargs)

    def forward(self, x):
        x = self.conv_layers(x)
        x = self.fc_layers(torch.flatten(x, start_dim=1))
        return x

    def  _create_conv_from_layer(self, layer):
        kernel_size, filters, stride, padding = layer
        return CNNBlock(
            self.in_channels,
            filters,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding
        ), filters

    def _create_conv_layers(self, model_configuration):
        layers = []
        for layer in model_configuration:
            if isinstance(layer, tuple):
                conv_layer, out_channels = self._create_conv_from_layer(layer)
                layers.append(conv_layer)
                self.in_channels = out_channels

            elif isinstance(layer, str):
                if layer == "M":
                    layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
                elif layer == "U":
                    layers.append(nn.Upsample(scale_factor=2))

            elif isinstance(layer, list):
                repeats = layer[2]
                for _ in range(repeats):
                    for conv_layer in layer[:2]:
                        conv_layer, out_channels = self._create_conv_from_layer(conv_layer)
                        layers.append(conv_layer)
                        self.in_channels = out_channels

        return nn.Sequential(*layers)

    def _create_fc_layers(self, split_size=7, num_classes=20, num_boxes=2):
        print(f"{self.in_channels}")
        return nn.Sequential(
            nn.Flatten(),
            nn.Linear(self.in_channels * split_size * split_size, 496),
            nn.Dropout(0.0),
            nn.LeakyReLU(0.1),
            nn.Linear(496, split_size**2 * (5 * num_boxes + num_classes)),
        )

### 4. Dataloader

In [4]:
class Compose(object):
    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, x, boxes):
        for t in self.transforms:
            x, boxes = t(x), boxes
        return x, boxes

In [5]:
import pandas as pd
from PIL import Image
import os
from math import floor
        
class VOCDataset(torch.utils.data.Dataset):
    def __init__(
        self, 
        csv_file, 
        label_dir,
        image_dir, 
        transform=None, 
        split_size=7,
        num_boxes=2,
        num_classes=20,
    ):
        self.annotations = pd.read_csv(csv_file)
        self.label_dir = label_dir
        self.image_dir = image_dir
        self.transform = transform
        self.split_size = split_size
        self.num_bboxes = num_boxes
        self.num_classes = num_classes

    def __len__(self):
        return len(self.annotations)

    def __getitem__(self, index):
        # read labels
        label_path = os.path.join(self.label_dir, self.annotations.iloc[index, 1])
        boxes = []
        for label in open(label_path).read().splitlines():
            class_type, x, y, w, h = [
                # if errors then fix this
                float(value) for value in label.split()
            ]

            boxes.append([class_type, x, y, w, h])

        image_path = os.path.join(self.image_dir, self.annotations.iloc[index, 0])
        image = Image.open(image_path)
        boxes = torch.tensor(boxes)

        if self.transform is not None:
            image, boxes = self.transform(image, boxes)
        else:
            image = torch.tensor(np.array(image), dtype=torch.float32)

        label_tensor = torch.zeros(
            (self.split_size, 
            self.split_size, 
            self.num_classes + 5 * self.num_bboxes))

        for box in boxes:
            class_type, x, y, w, h = box.tolist()
            class_type = int(class_type)

            width = w * self.split_size
            height = h * self.split_size

            x_index = floor(x * self.split_size)
            y_index = floor(y * self.split_size)
            x_relative = (x * self.split_size) - x_index
            y_relative = (y * self.split_size) - y_index


            # restricting every cell to only have one bbox
            if label_tensor[y_index, x_index, 20] == 0:
                label_tensor[y_index, x_index, 20] = 1
                label_tensor[y_index, x_index, class_type] = 1

                box_property = torch.tensor([x_relative, y_relative, width, height])
                label_tensor[y_index, x_index, 21:25] = box_property

        return image, label_tensor

### 6. loading data and train

In [6]:
def train_fn(train_loader, model, loss_fn, optimizer, device):
    loader = tqdm.tqdm(train_loader, leave=True)
    model.train()

    mean_loss = []

    for batch_idx, (data, target) in enumerate(loader):
        data, target = data.to(device), target.to(device)
        output = model(data)
        loss = loss_fn(output, target)
        optimizer.zero_grad()
        loss.backward()

        optimizer.step()
        loader.set_description(f"Train loss: {loss.item():.4f}")
        #loader.postfix(loss=round(loss.item(), 4))

        mean_loss.append(loss.item())
    
    print(f"Mean train loss: {sum(mean_loss)/len(mean_loss):.4f}")
    

def test_fn(test_loader, model, device):
    model.eval()

    loader = tqdm.tqdm(test_loader, leave=True)
    loader.set_description("Testing")
    pred_boxes, target_boxes = get_bboxes(
        loader, model, iou_threshold=0.5, threshold=0.4, device=device
    )

    mean_avg_prec = mean_average_precision(
        pred_boxes, target_boxes, iou_threshold=0.5, box_format="midpoint"
    )
    print(f"Train mAP: {mean_avg_prec}")

def save_model(model, optimizer, model_path):
    state_dict = {
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
    }
    torch.save(state_dict, model_path)
    print("==> Saved model state to {}".format(model_path))

def load_model(model, optimizer, model_path):
    state_dict = torch.load(model_path)
    model.load_state_dict(state_dict["model_state_dict"])
    optimizer.load_state_dict(state_dict["optimizer_state_dict"])
    print("==> Loaded model state from {}".format(model_path))


### 7. train


In [7]:
hyperparams = {
    "learning_rate": 2e-5,
    "device": "cuda" if torch.cuda.is_available() else "cpu",
    "batch_size": 64,
    "epochs": 100,
    "num_workers": 2,
    "weight_decay": 0,
    "img_dir": "./data/data/images",
    "label_dir": "./data/data/labels",
    "pin_memory": True,
    "load_model": True,
    "save_model": False,
    "save_per_epochs": 10,
}

MODEL_PATH = "./trained_models/yolov1.pt"

def main():
    torch.manual_seed(123)
    if hyperparams["device"] == "cuda":
        torch.cuda.empty_cache()

    model = YoloV1(model_architectures["yolov1"], in_channels=3, num_classes=20, num_boxes=2, split_size=7).to(hyperparams["device"])

    transform = Compose([
        transforms.Resize((448, 448)),
        transforms.ToTensor(),
    ])

    train_set = VOCDataset(
        csv_file="train.csv",
        label_dir=hyperparams["label_dir"],
        image_dir=hyperparams["img_dir"],
        transform=transform,
    )

    test_set = VOCDataset(
        csv_file="test.csv",
        label_dir=hyperparams["label_dir"],
        image_dir=hyperparams["img_dir"],
        transform=transform,
    )

    train_loader = torch.utils.data.DataLoader(
        train_set,
        batch_size=hyperparams["batch_size"],
        shuffle=True,
        drop_last=True,
        num_workers=hyperparams["num_workers"],
        pin_memory=hyperparams["pin_memory"],
    )

    test_loader = torch.utils.data.DataLoader(
        test_set,
        batch_size=hyperparams["batch_size"],
        shuffle=False,
        drop_last=True,
        num_workers=hyperparams["num_workers"],
        pin_memory=hyperparams["pin_memory"],
    )

    optimizer = torch.optim.Adam(
        model.parameters(), 
        lr=hyperparams["learning_rate"], 
        weight_decay=hyperparams["weight_decay"]
    )

    loss_fn = YoloLoss()

    if hyperparams["load_model"]:
        load_model(model, optimizer, MODEL_PATH)
    
    for epoch in range(hyperparams["epochs"]):
        print(f"Epoch: {epoch+1}")
        #train_fn(train_loader, model, loss_fn, optimizer, hyperparams["device"])

        if epoch % hyperparams["save_per_epochs"] == hyperparams["save_per_epochs"] - 1 and hyperparams["save_model"]:
            save_model(model, optimizer, MODEL_PATH)
        #test_fn(test_loader, model, hyperparams["device"])
        test_fn(test_loader, model, hyperparams["device"])

    if hyperparams["save_model"]:
        save_model(model, optimizer, MODEL_PATH)

main()


1024
==> Loaded model state from ./trained_models/yolov1.pt
Epoch: 1


Testing: 100%|██████████| 15/15 [00:09<00:00,  1.65it/s]


Train mAP: 0.11399567127227783
Epoch: 2


Testing: 100%|██████████| 15/15 [00:06<00:00,  2.41it/s]


Train mAP: 0.11399567127227783
Epoch: 3


Testing:   7%|▋         | 1/15 [00:02<00:30,  2.16s/it]


KeyboardInterrupt: 