In [5]:
import torch
import torch.nn as nn

# 1- Model Architecture

In [6]:
# Backbone layer specification, read in order when the conv layers are built.
#   tuple -> one conv block: (kernel_size, filters, stride, padding)
#   "M"   -> one max-pooling layer
#   list  -> [conv tuple, conv tuple, repetitions]: the pair is repeated that many times
architecture_config = [
    # stem
    (7, 64, 2, 3),
    "M",

    (3, 192, 1, 1),
    "M",

    # alternating 1x1 / 3x3 convolutions
    (1, 128, 1, 0),
    (3, 256, 1, 1),
    (1, 256, 1, 0),
    (3, 512, 1, 1),
    "M",

    # the 1x1 -> 3x3 pair below is repeated 4 times
    [(1, 256, 1, 0), (3, 512, 1, 1), 4],
    (1, 512, 1, 0),
    (3, 1024, 1, 1),
    "M",

    # the 1x1 -> 3x3 pair below is repeated 2 times
    [(1, 512, 1, 0), (3, 1024, 1, 1), 2],
    (3, 1024, 1, 1),
    (3, 1024, 2, 1),  # the only stride-2 convolution besides the stem
    (3, 1024, 1, 1),
    (3, 1024, 1, 1),
]

In [7]:
class CNNBlock(nn.Module):
    """Convolution, then batch normalisation, then LeakyReLU(0.1).

    Args:
        in_channels: Number of channels of the incoming feature map.
        out_channels: Number of convolution filters; also the batch-norm width.
        **kwargs: Passed straight through to ``nn.Conv2d`` (for example
            ``kernel_size``, ``stride`` and ``padding``).
    """

    def __init__(self, in_channels, out_channels, **kwargs):
        super().__init__()

        # The convolution is created without a bias term; batch norm follows it directly.
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            bias=False,
            **kwargs,
        )
        self.batchnorm = nn.BatchNorm2d(num_features=out_channels)
        self.leakyrelu = nn.LeakyReLU(negative_slope=0.1)

    def forward(self, x):
        """Run ``x`` through conv -> batch norm -> leaky ReLU and return the result."""
        convolved = self.conv(x)
        normalised = self.batchnorm(convolved)
        return self.leakyrelu(normalised)

In [8]:
class YOLOv1(nn.Module):
    """YOLOv1 network: a convolutional backbone followed by a fully connected head.

    The backbone is built from the module-level ``architecture_config`` list and
    the head from the keyword arguments, which are forwarded unchanged to
    ``_create_fcs``.

    Args:
        in_channels: Number of channels of the input image.
        **kwargs: ``split_size``, ``num_boxes`` and ``num_classes`` (see
            ``_create_fcs``).
    """

    def __init__(self, in_channels = 3, **kwargs):
        super().__init__()

        # Layer specification: a list of tuples, strings and lists.
        self.architecture = architecture_config
        # Number of channels of the input image.
        self.in_channels = in_channels
        # Device the sub-modules are initially created on.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Convolutional backbone.
        self.darknet = self._create_conv_layers(self.architecture)
        # Fully connected head.
        self.fcs = self._create_fcs(**kwargs)

    def forward(self, x):
        """Return the flat prediction vector for a batch of images.

        Args:
            x: Input batch; it is moved to the device the model's parameters
                currently live on before the backbone is applied.

        Returns:
            Output of the fully connected head, one row per batch element.
        """
        # Follow the parameters rather than the construction-time `self.device`,
        # so the model keeps working after it has been moved with `.to(...)`.
        first_param = next(self.parameters(), None)
        target_device = self.device if first_param is None else first_param.device

        x = self.darknet(x.to(target_device))
        return self.fcs(torch.flatten(x, start_dim = 1))

    def _create_conv_layers(self, architecture):
        """Translate a layer specification into an ``nn.Sequential`` backbone.

        Args:
            architecture: Iterable whose entries are one of
                * tuple ``(kernel_size, filters, stride, padding)`` -> one ``CNNBlock``;
                * str -> one 2x2, stride-2 max-pooling layer;
                * list ``[conv_tuple, conv_tuple, repeats]`` -> the two conv
                  blocks, repeated ``repeats`` times.

        Returns:
            The layers wrapped in ``nn.Sequential`` and moved to ``self.device``.

        Raises:
            TypeError: If an entry is not a tuple, str or list.
        """
        layers = []
        # Running channel count; starts at the image channels.
        in_channels = self.in_channels

        for entry in architecture:
            if isinstance(entry, tuple):
                layers.append(
                    CNNBlock(
                        in_channels,
                        out_channels = entry[1],
                        kernel_size = entry[0],
                        stride = entry[2],
                        padding = entry[3],
                    )
                )
                in_channels = entry[1]
            elif isinstance(entry, str):
                layers.append(nn.MaxPool2d(kernel_size = 2, stride = 2))
            elif isinstance(entry, list):
                # Two conv tuples followed by the number of repetitions.
                conv1, conv2, num_repeats = entry[0], entry[1], entry[2]
                for _ in range(num_repeats):
                    layers.append(
                        CNNBlock(
                            in_channels,
                            conv1[1],
                            kernel_size = conv1[0],
                            stride = conv1[2],
                            padding = conv1[3],
                        )
                    )
                    layers.append(
                        CNNBlock(
                            conv1[1],
                            conv2[1],
                            kernel_size = conv2[0],
                            stride = conv2[2],
                            padding = conv2[3],
                        )
                    )
                    in_channels = conv2[1]
            else:
                # Previously such entries were skipped silently, producing a
                # network that did not match the specification.
                raise TypeError(
                    f"Unsupported architecture entry of type {type(entry).__name__}: {entry!r}"
                )

        return nn.Sequential(*layers).to(self.device)

    def _create_fcs(self, split_size, num_boxes, num_classes):
        """Build the fully connected head.

        Args:
            split_size: Grid size S. The head expects ``1024 * S * S`` input
                features and emits predictions for ``S * S`` cells.
            num_boxes: Number of bounding boxes B predicted per cell.
            num_classes: Number of classes C.

        Returns:
            ``nn.Sequential`` mapping the flattened backbone output to a vector
            of length ``S * S * (C + B * 5)``, moved to ``self.device``.
        """
        S, B, C = split_size, num_boxes, num_classes
        return nn.Sequential(
            # Flattens the input into a vector.
            nn.Flatten(),
            nn.Linear(1024 * S * S, 496), # Originally 4096
            nn.Dropout(0.0),
            nn.LeakyReLU(0.1),
            # Per cell: C class scores plus 5 values for each of the B boxes.
            nn.Linear(496, S * S * (C + B * 5))
        ).to(self.device)


In [12]:
def test(split_size = 7, num_boxes = 2, num_classes = 20):
    """Smoke test: build a YOLOv1 model, feed it a random batch and print the output shape."""
    yolo = YOLOv1(
        split_size = split_size,
        num_boxes = num_boxes,
        num_classes = num_classes,
    )
    # (batch_size, channels, height, width)
    dummy_batch = torch.randn((2, 3, 448, 448))
    prediction = yolo(dummy_batch)
    print(prediction.shape)


test()

torch.Size([2, 1470])


# 2- YOLO Loss

In [None]:
from utils import intersection_over_union

In [None]:
class YOLOLoss(nn.Module):
    """Sum-of-squared-errors YOLOv1 loss for two predicted boxes per grid cell.

    The last dimension of the predictions is read as
    ``[C class scores, conf_1, box_1 (4 values), conf_2, box_2 (4 values)]``.
    The target uses the same layout for its first ``C + 5`` entries: class
    scores, an object-present flag, and one box. The last two values of each
    box are treated as width and height.

    Args:
        S: Grid size; predictions are reshaped to ``(-1, S, S, C + B * 5)``.
        B: Boxes per cell. It only affects the reshape width; ``forward`` always
            reads exactly two boxes.
        C: Number of classes.
    """

    def __init__(self, S = 7, B = 2, C = 20):
        super().__init__()

        self.S = S
        self.B = B
        self.C = C

        # Mean squared error with sum reduction, i.e. a plain sum of squares.
        self.mse = nn.MSELoss(reduction="sum")

        # Weights of the no-object confidence term and the box-coordinate term.
        self.lambda_noobj = 0.5
        self.lambda_coord = 5

    @staticmethod
    def _signed_sqrt(values):
        """Return ``sign(v) * sqrt(|v| + 1e-6)`` element-wise.

        Predicted widths/heights may be negative, so the root is taken of the
        magnitude and the sign is restored afterwards. The epsilon keeps the
        gradient of the square root finite near zero.
        """
        return torch.sign(values) * torch.sqrt(torch.abs(values) + 1e-6)

    def forward(self, predictions, target):
        """Compute the total loss.

        Args:
            predictions: Tensor reshaped to ``(-1, S, S, C + B * 5)``.
            target: 4-D tensor indexed as ``(batch, S, S, features)`` with at
                least ``C + 5`` features per cell.

        Returns:
            Scalar tensor:
            ``lambda_coord * box + object + lambda_noobj * no_object + class``.
        """
        C = self.C
        predictions = predictions.reshape(-1, self.S, self.S, C + self.B * 5)

        # Slices of the per-cell feature vector (offsets derived from C).
        conf_1 = predictions[..., C:C + 1]
        box_1 = predictions[..., C + 1:C + 5]
        conf_2 = predictions[..., C + 5:C + 6]
        box_2 = predictions[..., C + 6:C + 10]
        target_conf = target[..., C:C + 1]
        target_box = target[..., C + 1:C + 5]

        # IoU of each predicted box with the target box; the box with the
        # larger IoU is the one held responsible (index 0 or 1).
        iou_b1 = intersection_over_union(box_1, target_box)
        iou_b2 = intersection_over_union(box_2, target_box)
        ious = torch.cat([iou_b1.unsqueeze(0), iou_b2.unsqueeze(0)], dim = 0)
        _, best_box = torch.max(ious, dim = 0)

        # 1 where the target marks an object in the cell, 0 otherwise: (N, S, S, 1)
        exists_box = target[..., C].unsqueeze(3)

        # ======== For Box Coordinates ======== #
        # Select the responsible box and zero out cells without an object.
        box_predictions = exists_box * (best_box * box_2 + (1 - best_box) * box_1)
        box_targets = exists_box * target_box

        # Square-root the width and height. Built out-of-place so that no tensor
        # autograd may have saved for the backward pass is overwritten.
        box_predictions = torch.cat(
            [box_predictions[..., :2], self._signed_sqrt(box_predictions[..., 2:4])],
            dim = -1,
        )
        box_targets = torch.cat(
            [box_targets[..., :2], torch.sqrt(box_targets[..., 2:4])],
            dim = -1,
        )

        # (N, S, S, 4) -> (N * S * S, 4)
        box_loss = self.mse(
            torch.flatten(box_predictions, end_dim = -2),
            torch.flatten(box_targets, end_dim = -2)
        )

        # ======== For Object Loss ======== #
        # Confidence score of the responsible box.
        pred_conf = best_box * conf_2 + (1 - best_box) * conf_1
        # (N, S, S, 1) -> (N * S * S,)
        object_loss = self.mse(
            torch.flatten(exists_box * pred_conf),
            torch.flatten(exists_box * target_conf)
        )

        # ======== For No Object Loss ======== #
        # Both boxes are penalised: each should report that there is no object.
        # (N, S, S, 1) -> (N, S * S)
        no_object = 1 - exists_box
        no_object_target = torch.flatten(no_object * target_conf, start_dim = 1)
        no_object_loss = self.mse(
            torch.flatten(no_object * conf_1, start_dim = 1),
            no_object_target
        ) + self.mse(
            torch.flatten(no_object * conf_2, start_dim = 1),
            no_object_target
        )

        # ======== For Class Loss ======== #
        # (N, S, S, C) -> (N * S * S, C)
        class_loss = self.mse(
            torch.flatten(exists_box * predictions[..., :C], end_dim = -2),
            torch.flatten(exists_box * target[..., :C], end_dim = -2)
        )

        # ======== Final Loss ======== #
        loss = (
            self.lambda_coord * box_loss # first two rows in the paper
            + object_loss # third row in the paper
            + self.lambda_noobj * no_object_loss # fourth row in the paper
            + class_loss # fifth row in the paper
        )

        return loss

# 3- Util functions