In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np

The input image is divided into an $S \times S$ grid.

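Number of predictions: $S \times S \times (B \cdot 5 + C)$ — each of the $B$ boxes per cell has 5 values $(x, y, w, h, \text{conf})$, and $C$ is the number of classes.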

Architecture: 24 convolutional layers followed by 2 fully connected layers.

The first 20 convolutional layers are pretrained on ImageNet.

Input image size: $448 \times 448$.

Output size: $7 \times 7 \times 30$, i.e. $S = 7$, $B = 2$, $C = 20$.

In [2]:
class ConvLayers(nn.Module):
    """YOLOv1 convolutional backbone: 24 conv layers, each followed by leaky ReLU.

    ``preTrain`` runs the first 20 conv layers (conv1 .. conv4_5), the part that
    is pretrained on ImageNet. ``forward`` continues through conv5 and conv6.
    For a (N, 3, 448, 448) input, ``preTrain`` returns (N, 1024, 14, 14) and
    ``forward`` returns (N, 1024, 7, 7).

    Args:
        negative_slope: slope of the leaky ReLU applied after every conv layer
            (0.1 in the YOLOv1 paper).
    """

    def __init__(self, negative_slope=0.1):
        super().__init__()
        self.negative_slope = negative_slope

        # NOTE(review): the YOLOv1 paper lists 7x7x64-s-2 and 3x3x192 for the
        # first two conv layers; the 192/256 output channels here differ —
        # confirm intent.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=192, kernel_size=7, stride=2, padding=3)
        self.maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.conv2 = nn.Conv2d(in_channels=192, out_channels=256, kernel_size=3, padding=1)
        self.maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.conv3 = nn.Sequential(
            nn.Conv2d(in_channels=256, out_channels=128, kernel_size=1),
            nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, padding=1),
            nn.Conv2d(in_channels=256, out_channels=256, kernel_size=1),
            nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, padding=1)
        )
        self.maxpool3 = nn.MaxPool2d(kernel_size=2, stride=2)

        # {1x1x256, 3x3x512} x 4 — shape-preserving (512 -> 512 channels).
        self.conv3_5 = nn.Sequential(
            *[nn.Sequential(
                nn.Conv2d(in_channels=512, out_channels=256, kernel_size=1),
                nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, padding=1)
            ) for _ in range(4)]
        )

        self.conv4 = nn.Sequential(
            nn.Conv2d(in_channels=512, out_channels=512, kernel_size=1),
            nn.Conv2d(in_channels=512, out_channels=1024, kernel_size=3, padding=1)
        )
        self.maxpool4 = nn.MaxPool2d(kernel_size=2, stride=2)

        # {1x1x512, 3x3x1024} x 2 — shape-preserving (1024 -> 1024 channels).
        self.conv4_5 = nn.Sequential(
            *[nn.Sequential(
                nn.Conv2d(in_channels=1024, out_channels=512, kernel_size=1),
                nn.Conv2d(in_channels=512, out_channels=1024, kernel_size=3, padding=1)
            ) for _ in range(2)]
        )

        # Everything above (20 conv layers) is what gets pretrained on ImageNet.

        self.conv5 = nn.Sequential(
            nn.Conv2d(in_channels=1024, out_channels=1024, kernel_size=3, padding=1),
            nn.Conv2d(in_channels=1024, out_channels=1024, kernel_size=3, stride=2, padding=1)
        )

        self.conv6 = nn.Sequential(
            nn.Conv2d(in_channels=1024, out_channels=1024, kernel_size=3, padding=1),
            nn.Conv2d(in_channels=1024, out_channels=1024, kernel_size=3, padding=1)
        )

    def _run_convs(self, block, x):
        """Apply every Conv2d inside ``block`` in registration order, each followed by leaky ReLU.

        Walking ``modules()`` (depth-first, registration order) lets nested
        Sequentials get an activation between consecutive convs without
        changing the module tree, so existing state_dict keys stay valid.
        """
        for layer in block.modules():
            if isinstance(layer, nn.Conv2d):
                x = F.leaky_relu(layer(x), self.negative_slope)
        return x

    def preTrain(self, x):
        """Run the first 20 conv layers (conv1 .. conv4_5) with their max-pools."""
        x = self.maxpool1(self._run_convs(self.conv1, x))
        x = self.maxpool2(self._run_convs(self.conv2, x))
        x = self.maxpool3(self._run_convs(self.conv3, x))
        x = self._run_convs(self.conv3_5, x)
        x = self.maxpool4(self._run_convs(self.conv4, x))
        x = self._run_convs(self.conv4_5, x)
        return x

    def forward(self, x):
        """Run all 24 conv layers; the stride-2 conv in conv5 halves the spatial size."""
        x = self.preTrain(x)
        x = self._run_convs(self.conv5, x)
        x = self._run_convs(self.conv6, x)
        return x

In [3]:
class FCLayers(nn.Module):
    """YOLOv1 detection head: two fully connected layers on the flattened backbone map.

    Runs fc1 -> leaky ReLU -> dropout -> fc2. The final layer is linear (no
    activation). The defaults reproduce the 7*7*1024 -> 4096 -> 7*7*30 head.

    Args:
        S: grid size; the backbone feature map is assumed to be S x S spatially.
        B: boxes predicted per grid cell.
        C: number of classes.
        in_channels: channels of the backbone feature map.
        hidden: width of fc1.
        dropout: dropout probability applied after fc1.
        negative_slope: slope of the leaky ReLU after fc1 (0.1 in the paper).
    """

    def __init__(self, S=7, B=2, C=20, in_channels=1024, hidden=4096, dropout=0.5, negative_slope=0.1):
        super().__init__()
        self.negative_slope = negative_slope

        self.fc1 = nn.Linear(S * S * in_channels, hidden)
        self.dropout1 = nn.Dropout(dropout)
        self.fc2 = nn.Linear(hidden, S * S * (B * 5 + C))

    def forward(self, x):
        """Flatten all but the batch dim and return (N, S*S*(B*5 + C)) raw predictions."""
        x = x.flatten(start_dim=1)
        # Leaky ReLU (not GELU) to match the backbone and the YOLOv1 paper.
        x = F.leaky_relu(self.fc1(x), self.negative_slope)
        x = self.dropout1(x)
        return self.fc2(x)

In [40]:
class YOLOv1(nn.Module):
    """YOLOv1 detector: conv backbone + fully connected head + sum-squared-error loss.

    ``forward`` returns a (N, S, S, B*5 + C) tensor. Per cell, the last dim holds
    ``[x, y, w, h, conf] * B`` followed by ``C`` class scores, which is the
    layout ``destruct`` splits.
    """

    def __init__(self):
        super().__init__()

        self.convLayers = ConvLayers()
        self.fcLayers = FCLayers()

        self.S = 7   # grid size
        self.B = 2   # boxes predicted per cell
        self.C = 20  # number of classes
        self.coord_scale = 5    # lambda_coord: weight of the box-coordinate terms
        self.noobj_scale = 0.5  # lambda_noobj: weight of the no-object confidence term

    def forward(self, x):
        """Return predictions of shape (N, S, S, B*5 + C).

        Presumably ``x`` is (N, 3, 448, 448) so the backbone yields the 7x7x1024
        map the head expects — confirm for other input sizes.
        """
        x = self.convLayers(x)
        x = self.fcLayers(x)
        x = x.view(-1, self.S, self.S, self.B * 5 + self.C)
        return x

    def loss_fn(self, predictions, targets):
        """Sum-squared-error YOLOv1 loss, averaged over the batch.

        Args:
            predictions: (N, S, S, B*5 + C) output of ``forward``.
            targets: (N, S, S, Bt*5 + C) ground truth in the same per-cell
                layout, with ``Bt >= 1`` box slots (e.g. 25 or 30 wide when
                C = 20). Only the first slot is read, since each cell carries at
                most one object. A cell contains an object when that slot's
                confidence is > 0.

        Returns:
            Scalar tensor: total loss divided by the batch size.

        Raises:
            ValueError: if the last dim of ``targets`` is not ``Bt*5 + C``.
        """
        target_B, remainder = divmod(targets.shape[-1] - self.C, 5)
        if remainder != 0 or target_B < 1:
            raise ValueError(
                f"targets last dim must be Bt*5 + {self.C} with Bt >= 1, got {targets.shape[-1]}"
            )

        pred_bboxes, pred_confs, pred_classes = self.destruct(predictions)
        true_bboxes, true_confs, true_classes = self.destruct(targets, target_B)

        # Broadcast each cell's single ground-truth box over all B predictors.
        true_bboxes = true_bboxes[..., :1, :].expand_as(pred_bboxes)
        true_confs = true_confs[..., :1].expand_as(pred_confs)
        obj_cells = true_confs[..., 0] > 0  # (N, S, S): cell contains an object

        # IoU only selects the responsible predictor; keep it out of the graph.
        with torch.no_grad():
            iou = self.findIOU(pred_bboxes, true_bboxes)
        best_idx = self.get_best_iou_mask(iou).squeeze(-1)
        responsible = F.one_hot(best_idx, num_classes=pred_confs.shape[-1]).bool()

        # 1^obj_ij: predictor j of cell i is responsible for an object.
        # 1^noobj_ij is its complement, so non-responsible predictors in object
        # cells are also pushed toward zero confidence.
        obj_mask = responsible & obj_cells.unsqueeze(-1)
        noobj_mask = ~obj_mask

        xy_loss = F.mse_loss(
            pred_bboxes[..., :2][obj_mask],
            true_bboxes[..., :2][obj_mask],
            reduction='sum'
        )
        # Raw w/h predictions can be negative; a sign-preserving sqrt avoids NaN.
        wh_loss = F.mse_loss(
            self._signed_sqrt(pred_bboxes[..., 2:4][obj_mask]),
            self._signed_sqrt(true_bboxes[..., 2:4][obj_mask]),
            reduction='sum'
        )
        bbox_loss = self.coord_scale * (xy_loss + wh_loss)

        # NOTE(review): the paper's confidence target is IoU(pred, truth); the
        # stored target confidence is used here instead — confirm intent.
        obj_conf_loss = F.mse_loss(
            pred_confs[obj_mask],
            true_confs[obj_mask],
            reduction='sum'
        )
        noobj_confs = pred_confs[noobj_mask]
        noobj_conf_loss = F.mse_loss(noobj_confs, torch.zeros_like(noobj_confs), reduction='sum')
        conf_loss = obj_conf_loss + self.noobj_scale * noobj_conf_loss

        class_loss = F.mse_loss(
            pred_classes[obj_cells],
            true_classes[obj_cells],
            reduction='sum'
        )

        total_loss = bbox_loss + conf_loss + class_loss
        return total_loss / predictions.shape[0]

    @staticmethod
    def _signed_sqrt(t, eps=1e-6):
        """Return sign(t) * sqrt(|t| + eps).

        For t > 0 this equals sqrt(t + eps). For t < 0 it stays finite where
        sqrt(t + eps) would be NaN.
        """
        return torch.sign(t) * torch.sqrt(torch.abs(t) + eps)

    def destruct(self, x, B=None):
        """Split a per-cell prediction/target tensor into boxes, confidences and class scores.

        Args:
            x: (N, S, S, B*5 + C) tensor laid out as ``[x, y, w, h, conf] * B``
                followed by ``C`` class scores.
            B: number of box slots in ``x``; defaults to ``self.B``.

        Returns:
            Tuple of bboxes (N, S, S, B, 4), confs (N, S, S, B) and
            classes (N, S, S, C).
        """
        if B is None:
            B = self.B
        # reshape (not view): the slice is non-contiguous in general.
        bboxes_and_confs = x[..., :B * 5].reshape(-1, self.S, self.S, B, 5)
        bboxes = bboxes_and_confs[..., :4]
        confs = bboxes_and_confs[..., 4]
        classes = x[..., -self.C:]

        return bboxes, confs, classes

    def findIOU(self, pred_bboxes, true_bboxes):
        """IoU of two broadcast-compatible (..., 4) tensors of (x, y, w, h) boxes.

        Boxes with negative width or height get zero area rather than a signed
        one. Returns a tensor with the broadcast leading shape.

        NOTE(review): treats x, y, w, h as sharing one unit. If x, y are
        cell-relative while w, h are image-relative, the IoU mixes units —
        confirm the target encoding.
        """
        pred_tl, pred_br = self.bbox_to_coords(pred_bboxes)
        true_tl, true_br = self.bbox_to_coords(true_bboxes)

        inter_wh = (torch.min(pred_br, true_br) - torch.max(pred_tl, true_tl)).clamp(min=0)
        intersection = inter_wh[..., 0] * inter_wh[..., 1]

        pred_area = (pred_br - pred_tl).clamp(min=0).prod(dim=-1)
        true_area = (true_br - true_tl).clamp(min=0).prod(dim=-1)
        union = pred_area + true_area - intersection

        return intersection / (union + 1e-6)

    def bbox_to_coords(self, bboxes):
        """Convert (..., 4) center-size boxes (x, y, w, h) to top-left and bottom-right (..., 2) corners."""
        x, y, w, h = bboxes[..., 0], bboxes[..., 1], bboxes[..., 2], bboxes[..., 3]
        tl = torch.stack((x - w / 2, y - h / 2), dim=-1)
        br = torch.stack((x + w / 2, y + h / 2), dim=-1)
        return tl, br

    def get_best_iou_mask(self, iou):
        """Return the argmax index over the last dim of ``iou``, with a trailing singleton dim.

        Despite the name, this is an index tensor, not a boolean mask.
        """
        best_iou_idx = torch.argmax(iou, dim=-1)
        return best_iou_idx.unsqueeze(-1)

In [41]:
# Shape check of `destruct` on a target tensor with one box slot per cell.
torch.manual_seed(0)
model = YOLOv1()

# Targets hold a single box per cell: (N, S, S, 1*5 + C).
targets = torch.randn(1, model.S, model.S, 1 * 5 + model.C)
target_boxes, target_confs, target_classes = model.destruct(targets, 1)
assert target_boxes.shape == (1, model.S, model.S, 1, 4)
print(target_confs.shape)
print(target_classes.shape)

torch.Size([1, 7, 7, 1])
torch.Size([1, 7, 7, 20])
