
# YOLO v1

[paper](https://arxiv.org/pdf/1506.02640.pdf)

Key points:

- S x S grid. S = 7
- predicts B boxes for each cell. B = 2
- Responsible cell:
    - the cell that contains bbox midpoint.
    - Among B predicted boxes, only the one that has highest IoU will be responsible.
- predicts confidence each cell. confidence = IoU
- predicts x, y, w, h each cell:
    - x, y: they are midpoint coordinates relative to cell origin, h, w.
        Meaning, cell h, w are 1, 1, and x, y will be in [0, 1]
    - h, w: they are bbox height, width relative to whole image.
- predicts C classes each cell.
- All are trained only when the cell is responsible for a bbox.
- Each cell can only predict 1 object. although it tries to predict B bboxes
- Predicted tensor is of shape [S, S, (C + 5B)]
- Architecture is simply a CNN followed by a flatten and fully-connected layers.
- While inference, multiply C probabilities with predicted confidence.
- While inference, apply NMS
- All losses are MSE variations.

Hyperparams:

- leaky relu
- batch size 64
- epochs 135 (with pre-trained)
- momentum 0.9
- decay: 0.0005
- lr:
    - 10^-3 for few epochs.
    - 10^-2 for +75 epochs
    - 10^-3 for +30 epochs.
    - 10^-4 for +30 epochs.
- Extensive augmentation:
    - Random scaling and translation up to 20%
    - randomly adjust the exposure and saturation of the image by up to a factor of 1.5 in the HSV color space.
- dropout of 0.5 on last fully-connected

Losses:

- Object exists: lambda_coord * sum((x - xhat)^2 + (y - yhat)^2)
- Object exists: lambda_coord * sum((sqrt(w) - sqrt(w_hat))^2 + (sqrt(h) - sqrt(h_hat))^2)
- Object exists: 1 * sum((confidence - confidence_hat)^2)
- No-object exists: lambda_no_object * sum((confidence - confidence_hat)^2)
- Object exists: sum((probability(c) - probability(c_hat))^2)

confidence = IoU
lambda_coord = 5
lambda_no_object = 0.5


In [1]:
# ! pip install --upgrade pytorch-lightning albumentations wandb

In [2]:
import numpy as np
import wandb
from albumentations.pytorch import ToTensorV2
import cv2
from pytorch_lightning.loggers import WandbLogger
from torch.utils.data import DataLoader
from torchmetrics import AverageMeter, MetricCollection
from torchvision.datasets import VOCDetection
import albumentations as A
import torch
import torch.nn as nn
from torch.nn import functional as F
from typing import List, Union, Optional, Tuple, Dict, Any
import pytorch_lightning as pl
from torchsummary import summary

In [3]:
VOC_CLASSES = [
    "aeroplane",
    "bicycle",
    "bird",
    "boat",
    "bottle",
    "bus",
    "car",
    "cat",
    "chair",
    "cow",
    "diningtable",
    "dog",
    "horse",
    "motorbike",
    "person",
    "pottedplant",
    "sheep",
    "sofa",
    "train",
    "tvmonitor",
]

In [4]:
class YoloV1Transforms:
    def __init__(self, h: int, w: int, augment: bool, num_classes: int, grid_size: int):
        self.h = h
        self.w = w
        self.augment = augment
        self.num_classes = num_classes
        self.grid_size = grid_size

        self.albument_transforms = self._get_augmentations(self.h, self.w, self.augment)

    def __call__(self, image, targets: dict):
        """
        The transform function takes in pil image and a dict of target bboxes.
        It applies augmentations and returns an image and target tensor of shape (C+5, S, S)
        The transform will return image tensor and target tensor.

        The target is of the shape excluding unrelated info:
        ```
        annotation:
          object:
            - name: bicycle
              bndbox:
                xmax: 471
                xmin: 54
                ymax: 336
                ymin: 39
        ```
        The output target will be a tensor of shape: (C+5, S, S)
        :return: Callable function
        """
        boxes, classes = self._transform_pre_augmentation(targets)

        transformed = self.albument_transforms(
            image=np.array(image),
            bboxes=boxes,
            class_labels=classes,
        )

        image = transformed["image"]
        boxes = transformed["bboxes"]
        classes = transformed["class_labels"]

        targets = self.transform_targets_to_yolo(boxes, classes)
        return image, targets

    def transform_targets_to_yolo(self, boxes, classes) -> torch.Tensor:
        """
        Converts (xmin, ymin, xmax, ymax) format to yolo format.

        - Get responsible pairs:
            - Find midpoints of all bboxes.
            - For all cells, if there's a bbox midpoint in the cell,
              that cell and bbox will go in a responsible pair list.
        - Convert coordinates from (xmin, ymin, ...) to yolo style.
        - Put everything in a tensor.

        :param boxes: list of tuples of (xmin, ymin, xmax, ymax)
        :param classes: list of integers
        :return: torch.Tensor of shape (C+5, S, S)
        """
        pairs: List[Tuple[int, int, int]] = self._get_responsible_pairs(boxes)
        boxes_yolo = self._convert_boxes_to_yolo(boxes, pairs)

        tensor = torch.zeros((self.num_classes + 5, self.grid_size, self.grid_size))
        for i, (r, c, b) in enumerate(pairs):
            tensor[classes[b], r, c] = 1.0
            tensor[self.num_classes, r, c] = 1.0
            for j in range(4):
                tensor[self.num_classes + 1 + j, r, c] = boxes_yolo[i][j]
        return tensor

    def transform_targets_from_yolo(self):
        # TODO: Finish this
        pass

    def _convert_boxes_to_yolo(
        self,
        boxes: List[Tuple[int, int, int, int]],
        pairs: List[Tuple[int, int, int]],
    ) -> List[Tuple[float, float, float, float]]:
        """
        Returns a yolo style bbox coordinates for each responsible pair.
        """
        cell_h = self.h / self.grid_size
        cell_w = self.w / self.grid_size

        yolo_boxes = []
        for r, c, b in pairs:
            xmin, ymin, xmax, ymax = boxes[b]

            tw = (xmax - xmin) / self.w
            th = (ymax - ymin) / self.h

            mx = (xmax - xmin) / 2
            my = (ymax - ymin) / 2
            tx = mx / cell_w
            ty = my / cell_h

            yolo_boxes.append((tx, ty, tw, th))

        return yolo_boxes

    def _get_responsible_pairs(
        self,
        boxes: List[Tuple[int, int, int, int]],
    ) -> List[Tuple[int, int, int]]:
        """
        - Find midpoints of all bboxes.
        - For all cells, if there's a bbox midpoint in the cell,
          that cell and bbox will go in a responsible pair list.
        """
        midpoints = []
        for (xmin, ymin, xmax, ymax) in boxes:
            x = (xmin + xmax) / 2
            y = (ymin + ymax) / 2
            midpoints.append((x, y))

        cell_h = self.h / self.grid_size
        cell_w = self.w / self.grid_size

        pairs = []
        for r in range(self.grid_size):
            y1 = r * cell_h
            y2 = y1 + cell_h
            for c in range(self.grid_size):
                x1 = c * cell_w
                x2 = x1 + cell_w
                for b, (mx, my) in enumerate(midpoints):
                    if x1 < mx < x2 and y1 < my < y2:
                        pairs.append((r, c, b))
        return pairs

    @staticmethod
    def _get_augmentations(h, w, augment: bool):
        def normalize(x, **kwargs):
            return x / 255.0


        resizing: list = [
            # A.LongestMaxSize(max_size=WIDTH, always_apply=True),
            A.PadIfNeeded(min_height=h, min_width=w, border_mode=cv2.BORDER_CONSTANT),
            A.RandomCrop(h, w),
            # A.Resize(height=HEIGHT, width=WIDTH, always_apply=True),
        ]
        compatibility: list = [
            ToTensorV2(always_apply=True),
            A.Lambda(image=normalize),
        ]

        augmentations: list = []
        if augment:
            augmentations = [
                A.HorizontalFlip(p=0.5),
                A.RandomBrightnessContrast(p=0.2),
            ]

        return A.Compose(
            resizing + augmentations + compatibility,
            bbox_params=A.BboxParams(
                format="pascal_voc", min_visibility=0.05, label_fields=["class_labels"]
            ),
        )

    @staticmethod
    def _transform_pre_augmentation(targets: dict) -> Tuple[list, list]:
        """
        This converts the targets compatible with albumentations
        The target is of the shape excluding unrelated info:
        ```
        annotation:
          object:
            - name: bicycle
              bndbox:
                xmax: 471
                xmin: 54
                ymax: 336
                ymin: 39
        ```
        Output will be of the form:
        (
            [(xmin, ymin, xmax, ymax), ...],
            [3, ...]
        )
        """
        classes = []
        boxes = []
        for object in targets["annotation"]["object"]:
            class_index = VOC_CLASSES.index(object["name"])
            classes.append(class_index)

            box = object["bndbox"]
            box = tuple(int(box[key]) for key in ["xmin", "ymin", "xmax", "ymax"])
            boxes.append(box)

        return boxes, classes

In [5]:
class PartialVOCDetection(VOCDetection):
    def __init__(self, size: int, **kwargs):
        super().__init__(**kwargs)
        self.size = size

    def __len__(self):
        return self.size


class VocYoloDataModule(pl.LightningDataModule):
    def __init__(
        self,
        grid_size: int,
        batch_size: int,
        data_path: str,
        dataloader_num_workers: int = 0,
        data_augment=False,
        **_,
    ):
        super().__init__()
        self.grid_size = grid_size
        self.data_path = data_path
        self.batch_size = batch_size
        self.num_workers = dataloader_num_workers
        self.augment = data_augment

        self.h = 448
        self.w = 448
        self.dims = (3, self.h, self.w)
        self.num_classes = 20
        self.transforms = YoloV1Transforms(
            h=self.h,
            w=self.w,
            augment=self.augment,
            num_classes=self.num_classes,
            grid_size=self.grid_size,
        )

        self.dataset_train, self.dataset_val = None, None

    def prepare_data(self):
        VOCDetection(
            root=self.data_path,
            year="2012",
            image_set="trainval",
            download=True,  # TODO: Makke it True
        )

    def setup(self, stage: Optional[str] = None):
        self.dataset_train = PartialVOCDetection(
            root=self.data_path,
            year="2012",
            image_set="train",
            download=False,
            transforms=self.transforms,
            size=20
        )
        self.dataset_val = PartialVOCDetection(
            root=self.data_path,
            year="2012",
            image_set="val",
            download=False,
            transforms=self.transforms,
            size=20
        )

    def train_dataloader(self):
        return DataLoader(
            self.dataset_train,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
        )

    def val_dataloader(self):
        return DataLoader(
            self.dataset_val,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
        )

In [6]:
class CNNBlock(nn.Module):
    def __init__(self, in_channels, out_channels, **kwargs):
        super(CNNBlock, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
        self.batchnorm = nn.BatchNorm2d(out_channels)
        self.leakyrelu = nn.LeakyReLU(0.1)

    def forward(self, x):
        return self.leakyrelu(self.batchnorm(self.conv(x)))


class SimpleCNN(nn.Module):
    def __init__(
        self,
        architecture: List[Union[tuple, str, list]],
        in_channels: int,
    ):
        super(SimpleCNN, self).__init__()
        layers = []
        for module in architecture:
            if type(module) is tuple:
                layers.append(self._get_cnn_block(module, in_channels))
                in_channels = module[1]
            elif module == "M":
                layers.append(
                    nn.MaxPool2d(
                        kernel_size=(2, 2),
                        stride=(2, 2),
                    )
                )
            elif type(module) is list:
                for i in range(module[-1]):
                    for j in range(len(module) - 1):
                        layers.append(self._get_cnn_block(module[j], in_channels))
                        in_channels = module[j][1]
        self.model = nn.Sequential(*layers)

    @staticmethod
    def _get_cnn_block(module: tuple, in_channels):
        kernel_size, filters, stride, padding = module
        return CNNBlock(
            in_channels,
            filters,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
        )

    def forward(self, x):
        return self.model(x)

In [19]:
"""
Information about architecture config:
- Tuple is structured by (kernel_size, filters, stride, padding)
- "M" is simply maxpooling with stride 2x2 and kernel 2x2
- List is structured by tuples and lastly int with number of repeats
"""

# original_yolo = [
#     (7, 64, 2, 3),
#     "M",
#     (3, 192, 1, 1),
#     "M",
#     (1, 128, 1, 0),
#     (3, 256, 1, 1),
#     (1, 256, 1, 0),
#     (3, 512, 1, 1),
#     "M",
#     [(1, 256, 1, 0), (3, 512, 1, 1), 4],
#     (1, 512, 1, 0),
#     (3, 1024, 1, 1),
#     "M",
#     [(1, 512, 1, 0), (3, 1024, 1, 1), 2],
#     (3, 1024, 1, 1),
#     (3, 1024, 2, 1),
#     (3, 1024, 1, 1),
#     (3, 1024, 1, 1),
# ]
architecture_config = [
    (7, 64, 2, 3),  # 224
    "M",  # 112
    (3, 194, 1, 1),
    "M",  # 56
    [(1, 128, 1, 0), (3, 128, 1, 1), 2],
    "M",  # 28
    [(1, 128, 1, 0), (3, 128, 1, 1), 4],
    "M", # 14
    [(1, 128, 1, 0), (3, 128, 1, 1), 4],
    (3, 64, 2, 1),  # 7
    (3, 32, 1, 1),
]

summary(SimpleCNN(architecture_config, in_channels=3), (3, 448, 448))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 224, 224]           9,408
       BatchNorm2d-2         [-1, 64, 224, 224]             128
         LeakyReLU-3         [-1, 64, 224, 224]               0
          CNNBlock-4         [-1, 64, 224, 224]               0
         MaxPool2d-5         [-1, 64, 112, 112]               0
            Conv2d-6        [-1, 194, 112, 112]         111,744
       BatchNorm2d-7        [-1, 194, 112, 112]             388
         LeakyReLU-8        [-1, 194, 112, 112]               0
          CNNBlock-9        [-1, 194, 112, 112]               0
        MaxPool2d-10          [-1, 194, 56, 56]               0
           Conv2d-11          [-1, 128, 56, 56]          24,832
      BatchNorm2d-12          [-1, 128, 56, 56]             256
        LeakyReLU-13          [-1, 128, 56, 56]               0
         CNNBlock-14          [-1, 128,

In [17]:
import torchvision
# mobilenetv2 = torchvision.models.MobileNetV2(num_classes=20)
# summary(mobilenetv2.features, (3, 448, 448))
cnn = torchvision.models.squeezenet1_0()
summary(cnn.features, (3, 448, 448))
cnn

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 96, 221, 221]          14,208
              ReLU-2         [-1, 96, 221, 221]               0
         MaxPool2d-3         [-1, 96, 110, 110]               0
            Conv2d-4         [-1, 16, 110, 110]           1,552
              ReLU-5         [-1, 16, 110, 110]               0
            Conv2d-6         [-1, 64, 110, 110]           1,088
              ReLU-7         [-1, 64, 110, 110]               0
            Conv2d-8         [-1, 64, 110, 110]           9,280
              ReLU-9         [-1, 64, 110, 110]               0
             Fire-10        [-1, 128, 110, 110]               0
           Conv2d-11         [-1, 16, 110, 110]           2,064
             ReLU-12         [-1, 16, 110, 110]               0
           Conv2d-13         [-1, 64, 110, 110]           1,088
             ReLU-14         [-1, 64, 1

SqueezeNet(
  (features): Sequential(
    (0): Conv2d(3, 96, kernel_size=(7, 7), stride=(2, 2))
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=True)
    (3): Fire(
      (squeeze): Conv2d(96, 16, kernel_size=(1, 1), stride=(1, 1))
      (squeeze_activation): ReLU(inplace=True)
      (expand1x1): Conv2d(16, 64, kernel_size=(1, 1), stride=(1, 1))
      (expand1x1_activation): ReLU(inplace=True)
      (expand3x3): Conv2d(16, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (expand3x3_activation): ReLU(inplace=True)
    )
    (4): Fire(
      (squeeze): Conv2d(128, 16, kernel_size=(1, 1), stride=(1, 1))
      (squeeze_activation): ReLU(inplace=True)
      (expand1x1): Conv2d(16, 64, kernel_size=(1, 1), stride=(1, 1))
      (expand1x1_activation): ReLU(inplace=True)
      (expand3x3): Conv2d(16, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (expand3x3_activation): ReLU(inplace=True)
    )
    (5): Fire(
   

In [20]:

class YoloV1(nn.Module):
    def __init__(self, in_channels, split_size, num_boxes, num_classes):
        super(YoloV1, self).__init__()
        self.backbone = SimpleCNN(architecture_config, in_channels)

        S, B, C = split_size, num_boxes, num_classes
        self.fcs = nn.Sequential(
            nn.Flatten(),
            nn.Linear(32 * S * S, 128),
            nn.Dropout(0.1),
            nn.LeakyReLU(0.1),
            nn.Linear(128, S * S * (C + B * 5)),
        )
        self.final_shape = (-1, (C + B * 5), S, S)

    def forward(self, x):
        x = self.backbone(x)
        out = self.fcs(torch.flatten(x, start_dim=1))
        out = out.view(self.final_shape)
        return out



In [21]:

class YoloV1Loss(nn.Module):
    """
    Losses:

    - Object exists: lambda_coord * sum((x - xhat)^2 + (y - yhat)^2)
    - Object exists: lambda_coord * sum((sqrt(w) - sqrt(w_hat))^2 + (sqrt(h) - sqrt(h_hat))^2)
    - Object exists: 1 * sum((confidence - confidence_hat)^2)
    - No-object exists: lambda_no_object * sum((confidence - confidence_hat)^2)
    - Object exists: sum((probability(c) - probability(c_hat))^2)

    confidence = IoU
    lambda_coord = 5
    lambda_no_object = 0.5
    """

    def __init__(
        self,
        num_boxes: int,
        num_classes: int,
        lambda_coord: float,
        lambda_object_exists: float,
        lambda_no_object: float,
        lambda_class: float,
    ):
        """
        Find the responsible cell-bbox pairs.

        :param num_boxes: (B)
        :param num_classes: (C)
        """
        super().__init__()

        self.num_boxes = num_boxes
        self.num_classes = num_classes
        self.lambda_coord = lambda_coord
        self.lambda_object_exists = lambda_object_exists
        self.lambda_no_object = lambda_no_object
        self.lambda_class = lambda_class

        self.mse = nn.MSELoss(reduction="none")

    def forward(
        self, preds: torch.Tensor, targets: torch.Tensor
    ) -> Dict[str, torch.Tensor]:
        """
        - Responsible box is the one that has the highest IoU.

        IoUs is a 0-1 tensor of shape (batch, B, S, S)
        Responsibility is an index tensor of shape (batch, S, S)
        object_exists is a 0,1 tensor of shape (batch, 1, S, S)

        :param preds: tensor of shape (batch, (C + B * 5), S, S)
        :param targets: tensor of shape (batch, C+5, S, S)
        :return: a dict of all losses.
        """

        ious = self._get_ious(preds.detach(), targets)  # shape: (batch, B, S, S)
        responsibility = F.one_hot(
            ious.argmax(dim=1), num_classes=self.num_boxes
        )  # shape (batch, S, S, B)
        object_exists = targets[:, self.num_classes]  # shape: (batch, S, S)
        object_not_exists = 1 - object_exists

        coords_loss = self._get_coords_loss(
            preds, targets, object_exists, responsibility
        )
        confidence_loss = self._get_confidence_loss(
            preds, ious, object_exists, responsibility
        )
        negative_confidence_loss = self._get_confidence_loss(
            preds, ious, object_not_exists, responsibility
        )
        class_loss = self._get_class_loss(preds, targets, object_exists)

        final_loss = (
            coords_loss * self.lambda_coord
            + confidence_loss * self.lambda_object_exists
            + negative_confidence_loss * self.lambda_no_object
            + class_loss * self.lambda_class
        )

        return {
            "loss": final_loss,
            "loss_coords": coords_loss.detach(),
            "loss_confidence": confidence_loss.detach(),
            "loss_confidence_negative": negative_confidence_loss.detach(),
            "loss_class": class_loss.detach(),
        }

    def _get_class_loss(self, preds, targets, object_exists):
        c = preds[:, : self.num_classes]  # shape (batch, C, S, S)
        c_hat = targets[:, : self.num_classes]  # shape (batch, C, S, S)

        c_loss = self.mse(c_hat, c)  # shape (batch, C, S, S)
        c_loss = c_loss.sum(dim=1)  # shape (batch, S, S)
        c_loss = object_exists * c_loss
        c_loss = c_loss.sum(dim=(1, 2)).mean(dim=0)
        return c_loss

    def _get_confidence_loss(self, preds, ious, object_exists, responsibility):
        c_losses = []
        for i in range(self.num_boxes):
            c = ious[:, i]  # shape (batch, S, S)
            c_hat = preds[:, self.num_classes + (i * 5)]  # shape (batch, S, S)

            c_loss = self.mse(c_hat, c)
            c_loss = object_exists * responsibility[..., i] * c_loss
            c_loss = c_loss.sum(dim=(1, 2)).mean(dim=0)
            c_losses.append(c_loss)
        c_loss = torch.stack(c_losses).sum(dim=0)

        return c_loss

    def _get_coords_loss(self, preds, targets, object_exists, responsibility):
        x = targets[:, self.num_classes + 1]  # shape (batch, S, S)
        y = targets[:, self.num_classes + 2]
        w = targets[:, self.num_classes + 3]
        h = targets[:, self.num_classes + 4]
        w_sqrt = torch.sqrt(torch.abs(w))
        h_sqrt = torch.sqrt(torch.abs(h))

        coords_losses = []  # shape (B,
        for i in range(self.num_boxes):
            start = self.num_classes + (i * 5)
            x_hat = preds[:, start + 1]  # shape (batch, S, S)
            y_hat = preds[:, start + 2]
            w_hat = preds[:, start + 3]
            h_hat = preds[:, start + 4]
            w_hat_sqrt = torch.sqrt(torch.abs(w_hat))
            h_hat_sqrt = torch.sqrt(torch.abs(h_hat))

            xy_loss = self.mse(x_hat, x) + self.mse(y_hat, y)
            wh_loss = self.mse(w_hat_sqrt, w_sqrt) + self.mse(h_hat_sqrt, h_sqrt)
            coords_loss = object_exists * responsibility[..., i] * (xy_loss + wh_loss)
            coords_loss = coords_loss.sum(dim=(1, 2)).mean(
                dim=0
            )  # average over batch, sum over rest.
            coords_losses.append(coords_loss)
        coords_loss = torch.stack(coords_losses).sum(dim=0)  # sum over B
        return coords_loss

    def _get_ious(self, preds, targets) -> torch.Tensor:
        """
        - When sum(target_[x,y,w,h]) is 0, iou is 0.
        - w_cell, h_cell = 1/S
        - w_image, h_image = 1

        - Get x1, y1, x2, y2 for predicted and target boxes.
            - x1 = midpoint_x - (width / 2)
        - find box iou

        :param preds: tensor of shape (batch, (C + B * 5), S, S)
        :param targets: tensor of shape (batch, C+5, S, S)
        :return: tensor of shape (batch, B, S, S)
        """

        all_coords = []
        for i in range(self.num_boxes):
            start = self.num_classes + (i * 5) + 1
            end = start + 4
            coords = preds[:, start:end]
            all_coords.append(coords)

        coords = targets[:, self.num_classes + 1 :]  # shape: (batch, 1, S, S)
        all_coords.append(coords)

        all_coords = torch.stack(all_coords)  # shape (B+1, batch, 4, S, S)
        all_coords = all_coords.moveaxis(2, 4)  # shape (B+1, batch, S, S, 4)

        x = all_coords[..., 0:1]
        y = all_coords[..., 1:2]
        w = all_coords[..., 2:3]
        h = all_coords[..., 3:4]

        w_half = w / 2
        h_half = h / 2

        x1 = x - w_half
        y1 = y - h_half
        x2 = x + w_half
        y2 = y + h_half

        # x1 is of shape (B+1, batch, S, S, 1)
        coords = torch.cat((x1, y1, x2, y2), dim=4)  # shape (B+1, batch, S, S, 4)

        ious = []
        for i in range(self.num_boxes):
            iou = self.custom_ious(coords[i], coords[-1])  # shape (batch, S, S)
            ious.append(iou)
        ious = torch.stack(ious)  # shape (B, batch, S, S)
        ious = ious.moveaxis(0, 1)  # shape (batch, B, S, S)

        return ious

    def custom_ious(self, boxes1, boxes2) -> torch.Tensor:
        """
        Performs 1 to 1 iou
        :param boxes1: tensor of shape (*N, 4)
        :param boxes2: tensor of shape (*N, 4)
        :return: tensor of shape *N
        """
        assert boxes1.shape == boxes2.shape

        ax1 = boxes1[..., 0]
        ay1 = boxes1[..., 1]
        ax2 = boxes1[..., 2]
        ay2 = boxes1[..., 3]

        bx1 = boxes2[..., 0]
        by1 = boxes2[..., 1]
        bx2 = boxes2[..., 2]
        by2 = boxes2[..., 3]

        x1 = self._max(ax1, bx1)
        y1 = self._max(ay1, by1)
        x2 = self._min(ax2, bx2)
        y2 = self._min(ay2, by2)

        zeros = torch.zeros_like(x1)
        ones = torch.ones_like(x1)

        side_x = self._max(zeros, x2 - x1)
        side_y = self._max(zeros, y2 - y1)

        intersection_area = side_x * side_y

        box1_area = (ax2 - ax1) * (ay2 - ay1)
        box2_area = (bx2 - bx1) * (by2 - by1)

        epsilon = 1e-7
        iou = intersection_area / (box1_area + box2_area - intersection_area + epsilon)
        iou = self._min(ones, iou)  # shape (*N)
        iou[bx2 - bx1 == 0] = 0.0  # Make IoU = 0 when width = 0

        return iou

    def _max(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Simply finds the max off the two tensors.
        Shapes of the two tensors has to be same.
        """
        return torch.amax(torch.stack([x, y]), dim=0)

    def _min(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Simply finds the max off the two tensors.
        Shapes of the two tensors has to be same.
        """
        return torch.amin(torch.stack([x, y]), dim=0)


In [22]:

class MyMetricCollection(MetricCollection):
    def update_each(self, params: dict, **kwargs: Any) -> None:
        """params is a dict where key is the metric key and the values are tuples of positional arguments.
        Keyword arguments (kwargs) will be filtered based on the signature of the individual metric.
        """
        for key, m in self.items(keep_base=True):
            if key in params:
                args = params[key]
                if type(args) is not tuple:
                    args = (args,)
                m_kwargs = m._filter_kwargs(**kwargs)
                m.update(*args, **m_kwargs)

    def compute(self):
        result = super().compute()
        self.reset()
        return result


In [23]:

class YoloV1PL(pl.LightningModule):
    def __init__(
        self,
        num_boxes: int,
        num_classes: int,
        in_channels: int,
        grid_size: int,
        lambda_coord: float,
        lambda_object_exists: float,
        lambda_no_object: float,
        lambda_class: float,
        **hp,
    ):
        super().__init__()
        self.hp = hp
        self.yolo_v1 = YoloV1(
            in_channels=in_channels,
            split_size=grid_size,
            num_boxes=num_boxes,
            num_classes=num_classes,
        )
        self.criterion = YoloV1Loss(
            num_boxes=num_boxes,
            num_classes=num_classes,
            lambda_coord=lambda_coord,
            lambda_object_exists=lambda_object_exists,
            lambda_no_object=lambda_no_object,
            lambda_class=lambda_class,
        )

        # --- metrics ---
        self.metrics_train = MyMetricCollection(
            {
                "loss": AverageMeter(),
                "loss_coords": AverageMeter(),
                "loss_confidence": AverageMeter(),
                "loss_confidence_negative": AverageMeter(),
                "loss_class": AverageMeter(),
            },
            prefix="train/",
        )
        self.metrics_val = self.metrics_train.clone(prefix="val/")

    def forward(self, x):
        return self.yolo_v1(x)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.hp["lr_initial"])
        lr_scheduler = torch.optim.lr_scheduler.StepLR(
            optimizer,
            step_size=self.hp["lr_decay_every"],
            gamma=self.hp["lr_decay_by"],
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": lr_scheduler,
                "interval": "epoch",
                "frequency": 1,
                "name": "learning_rate",
            },
        }

    def training_step(self, batch, batch_index):
        images, targets = batch
        preds = self(images)
        losses = self.criterion(preds, targets)

        # --- metrics and logging ----
        print(losses)
        self.metrics_train.update_each(losses)
        self.log("train/loss_step", losses["loss"], prog_bar=True)
        if batch_index == 0:
            images_to_log = images[:self.hp['num_log_images']]
            self.logger.experiment.log({'train/predictions': wandb.Image(images_to_log)})

        return losses['loss']

    def training_epoch_end(self, outputs):
        self.log_dict(self.metrics_train.compute())

    def validation_step(self, batch, _index):
        images, targets = batch
        preds = self(images)
        losses = self.criterion(preds, targets)
        return losses

    def validation_step_end(self, losses):
        self.metrics_val.update_each(losses)

    def on_validation_epoch_end(self):
        self.log_dict(self.metrics_val.compute())

In [24]:
# %load_ext tensorboard
# %tensorboard --logdir runs

In [25]:
hp = {
    "epochs": 5,
    "batch_size": 4,
    "lr_initial": 0.0001,
    "lr_decay_every": 20,
    "lr_decay_by": 0.99,
    "grid_size": 7,
    "data_augment": True,
    "num_boxes": 2,
    "lambda_coord": 5,
    "lambda_object_exists": 1,
    "lambda_no_object": 0.5,
    "lambda_class": 1,
}

config = {
    "output_path": "./output",
    "val_split": 0.1,
    "data_path": "./data",
    "num_classes": 20,
    "in_channels": 3,
    "num_log_images": 3,
    "dataloader_num_workers": 0,
    "num_gpus": 0
}

data_module = VocYoloDataModule(**config, **hp)
model = YoloV1PL(**hp, **config).float()
summary(model, (3, 448, 448))
wandb_logger = WandbLogger(project="yolo_test", log_model=False)
trainer = pl.Trainer(
    gpus=config["num_gpus"],
    max_epochs=hp["epochs"],
    default_root_dir=config["output_path"],
    logger=wandb_logger,
)
# wandb_logger.watch(model)

trainer.fit(model, datamodule=data_module)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs


----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 224, 224]           9,408
       BatchNorm2d-2         [-1, 64, 224, 224]             128
         LeakyReLU-3         [-1, 64, 224, 224]               0
          CNNBlock-4         [-1, 64, 224, 224]               0
         MaxPool2d-5         [-1, 64, 112, 112]               0
            Conv2d-6        [-1, 194, 112, 112]         111,744
       BatchNorm2d-7        [-1, 194, 112, 112]             388
         LeakyReLU-8        [-1, 194, 112, 112]               0
          CNNBlock-9        [-1, 194, 112, 112]               0
        MaxPool2d-10          [-1, 194, 56, 56]               0
           Conv2d-11          [-1, 128, 56, 56]          24,832
      BatchNorm2d-12          [-1, 128, 56, 56]             256
        LeakyReLU-13          [-1, 128, 56, 56]               0
         CNNBlock-14          [-1, 128,


  | Name          | Type               | Params
-----------------------------------------------------
0 | yolo_v1       | YoloV1             | 2.3 M 
1 | criterion     | YoloV1Loss         | 0     
2 | metrics_train | MyMetricCollection | 0     
3 | metrics_val   | MyMetricCollection | 0     
-----------------------------------------------------
2.3 M     Trainable params
0         Non-trainable params
2.3 M     Total params
9.026     Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

  rank_zero_warn(
  rank_zero_warn(
  rank_zero_warn(


Training: -1it [00:00, ?it/s]

{'loss': tensor(50.6697, grad_fn=<AddBackward0>), 'loss_coords': tensor(9.1844), 'loss_confidence': tensor(0.0956), 'loss_confidence_negative': tensor(1.8735), 'loss_class': tensor(3.7152)}
{'loss': tensor(83.5870, grad_fn=<AddBackward0>), 'loss_coords': tensor(15.6762), 'loss_confidence': tensor(0.0452), 'loss_confidence_negative': tensor(1.5584), 'loss_class': tensor(4.3815)}
{'loss': tensor(104.3552, grad_fn=<AddBackward0>), 'loss_coords': tensor(19.7143), 'loss_confidence': tensor(0.0996), 'loss_confidence_negative': tensor(2.2591), 'loss_class': tensor(4.5548)}
{'loss': tensor(40.1079, grad_fn=<AddBackward0>), 'loss_coords': tensor(7.3364), 'loss_confidence': tensor(0.1248), 'loss_confidence_negative': tensor(1.2225), 'loss_class': tensor(2.6900)}
{'loss': tensor(100.5612, grad_fn=<AddBackward0>), 'loss_coords': tensor(19.2026), 'loss_confidence': tensor(0.0335), 'loss_confidence_negative': tensor(1.5894), 'loss_class': tensor(3.7200)}


Validating: 0it [00:00, ?it/s]

{'loss': tensor(80.0152, grad_fn=<AddBackward0>), 'loss_coords': tensor(15.1356), 'loss_confidence': tensor(0.0398), 'loss_confidence_negative': tensor(1.9380), 'loss_class': tensor(3.3284)}
{'loss': tensor(73.3488, grad_fn=<AddBackward0>), 'loss_coords': tensor(12.9016), 'loss_confidence': tensor(0.1014), 'loss_confidence_negative': tensor(1.2165), 'loss_class': tensor(8.1309)}
{'loss': tensor(56.9850, grad_fn=<AddBackward0>), 'loss_coords': tensor(10.6492), 'loss_confidence': tensor(0.0288), 'loss_confidence_negative': tensor(1.4065), 'loss_class': tensor(3.0069)}
{'loss': tensor(72.7366, grad_fn=<AddBackward0>), 'loss_coords': tensor(14.0065), 'loss_confidence': tensor(0.0153), 'loss_confidence_negative': tensor(1.3270), 'loss_class': tensor(2.0254)}
{'loss': tensor(70.2779, grad_fn=<AddBackward0>), 'loss_coords': tensor(12.9014), 'loss_confidence': tensor(0.1235), 'loss_confidence_negative': tensor(1.5520), 'loss_class': tensor(4.8713)}


Validating: 0it [00:00, ?it/s]

{'loss': tensor(nan, grad_fn=<AddBackward0>), 'loss_coords': tensor(nan), 'loss_confidence': tensor(nan), 'loss_confidence_negative': tensor(nan), 'loss_class': tensor(nan)}
{'loss': tensor(nan, grad_fn=<AddBackward0>), 'loss_coords': tensor(nan), 'loss_confidence': tensor(nan), 'loss_confidence_negative': tensor(nan), 'loss_class': tensor(nan)}
{'loss': tensor(nan, grad_fn=<AddBackward0>), 'loss_coords': tensor(nan), 'loss_confidence': tensor(nan), 'loss_confidence_negative': tensor(nan), 'loss_class': tensor(nan)}
{'loss': tensor(nan, grad_fn=<AddBackward0>), 'loss_coords': tensor(nan), 'loss_confidence': tensor(nan), 'loss_confidence_negative': tensor(nan), 'loss_class': tensor(nan)}
{'loss': tensor(nan, grad_fn=<AddBackward0>), 'loss_coords': tensor(nan), 'loss_confidence': tensor(nan), 'loss_confidence_negative': tensor(nan), 'loss_class': tensor(nan)}


Validating: 0it [00:00, ?it/s]

{'loss': tensor(nan, grad_fn=<AddBackward0>), 'loss_coords': tensor(nan), 'loss_confidence': tensor(nan), 'loss_confidence_negative': tensor(nan), 'loss_class': tensor(nan)}
{'loss': tensor(nan, grad_fn=<AddBackward0>), 'loss_coords': tensor(nan), 'loss_confidence': tensor(nan), 'loss_confidence_negative': tensor(nan), 'loss_class': tensor(nan)}
{'loss': tensor(nan, grad_fn=<AddBackward0>), 'loss_coords': tensor(nan), 'loss_confidence': tensor(nan), 'loss_confidence_negative': tensor(nan), 'loss_class': tensor(nan)}
{'loss': tensor(nan, grad_fn=<AddBackward0>), 'loss_coords': tensor(nan), 'loss_confidence': tensor(nan), 'loss_confidence_negative': tensor(nan), 'loss_class': tensor(nan)}
{'loss': tensor(nan, grad_fn=<AddBackward0>), 'loss_coords': tensor(nan), 'loss_confidence': tensor(nan), 'loss_confidence_negative': tensor(nan), 'loss_class': tensor(nan)}


Validating: 0it [00:00, ?it/s]

{'loss': tensor(nan, grad_fn=<AddBackward0>), 'loss_coords': tensor(nan), 'loss_confidence': tensor(nan), 'loss_confidence_negative': tensor(nan), 'loss_class': tensor(nan)}
{'loss': tensor(nan, grad_fn=<AddBackward0>), 'loss_coords': tensor(nan), 'loss_confidence': tensor(nan), 'loss_confidence_negative': tensor(nan), 'loss_class': tensor(nan)}
{'loss': tensor(nan, grad_fn=<AddBackward0>), 'loss_coords': tensor(nan), 'loss_confidence': tensor(nan), 'loss_confidence_negative': tensor(nan), 'loss_class': tensor(nan)}
{'loss': tensor(nan, grad_fn=<AddBackward0>), 'loss_coords': tensor(nan), 'loss_confidence': tensor(nan), 'loss_confidence_negative': tensor(nan), 'loss_class': tensor(nan)}
{'loss': tensor(nan, grad_fn=<AddBackward0>), 'loss_coords': tensor(nan), 'loss_confidence': tensor(nan), 'loss_confidence_negative': tensor(nan), 'loss_class': tensor(nan)}


Validating: 0it [00:00, ?it/s]

In [22]:
import torch
h, w = 448, 448
grid_rows, grid_cols = 7, 7
cell_h = h / grid_rows
cell_w = w / grid_cols

cell_h, cell_w

(64.0, 64.0)

In [23]:
cols = torch.arange(start=0, end=w, step=cell_w).view((1, grid_cols)).expand((grid_rows, grid_cols))
rows = torch.arange(start=0, end=h, step=cell_h).view((grid_rows, 1)).expand((grid_rows, grid_cols))
origins = torch.stack((rows, cols))
origins, origins.shape

(tensor([[[  0.,   0.,   0.,   0.,   0.,   0.,   0.],
          [ 64.,  64.,  64.,  64.,  64.,  64.,  64.],
          [128., 128., 128., 128., 128., 128., 128.],
          [192., 192., 192., 192., 192., 192., 192.],
          [256., 256., 256., 256., 256., 256., 256.],
          [320., 320., 320., 320., 320., 320., 320.],
          [384., 384., 384., 384., 384., 384., 384.]],
 
         [[  0.,  64., 128., 192., 256., 320., 384.],
          [  0.,  64., 128., 192., 256., 320., 384.],
          [  0.,  64., 128., 192., 256., 320., 384.],
          [  0.,  64., 128., 192., 256., 320., 384.],
          [  0.,  64., 128., 192., 256., 320., 384.],
          [  0.,  64., 128., 192., 256., 320., 384.],
          [  0.,  64., 128., 192., 256., 320., 384.]]]),
 torch.Size([2, 7, 7]))

In [28]:
cols[:, [1, 4]]

tensor([[ 64., 256.],
        [ 64., 256.],
        [ 64., 256.],
        [ 64., 256.],
        [ 64., 256.],
        [ 64., 256.],
        [ 64., 256.]])

In [33]:
num_boxes = 2
num_classes = 20

idx = [], [], [], []
idx_x, idx_y, idx_w, idx_h = idx

for i in range(num_boxes):
    for j in range(4):
        idx[j].append(num_classes + (i * 5) + j + 1)

idx

([21, 26], [22, 27], [23, 28], [24, 29])

In [31]:
x = 19 + 1 + 1
y = x + 1
w = y + 1
h = w + 1
x2 = h + 1 + 1
y2 = x2 + 1
w2 = y2 + 1
h2 = w2 + 1

([x, x2], [y, y2], [w, w2], [h, h2])

([21, 26], [22, 27], [23, 28], [24, 29])