In [1]:
import torch
from torch import nn

In [2]:
# Layer specification consumed by Yolo._create_conv_layers, in forward order:
#   tuple (kernel_size, out_channels, stride, padding) -> one conv block
#   "M"                                                -> 2x2 max-pool, stride 2
#   list  [spec_a, spec_b, repeats]                    -> spec_a then spec_b, repeated
architecture_config = [
    (7, 64, 2, 3),
    "M",
    (3, 192, 1, 1),
    "M",
    (1, 128, 1, 0),
    (3, 256, 1, 1),
    (1, 256, 1, 0),
    (3, 512, 1, 1),
    "M",
    [(1, 256, 1, 0), (3, 512, 1, 1), 4],
    (1, 512, 1, 0),
    (3, 1024, 1, 1),
    "M",
    [(1, 512, 1, 0), (3, 1024, 1, 1), 2],
    (3, 1024, 1, 1),
    (3, 1024, 2, 1),
    (3, 1024, 1, 1),
    (3, 1024, 1, 1),
]

In [3]:
class CNNBlock(nn.Module):
    """Convolution -> batch normalisation -> LeakyReLU(0.1) building block.

    Parameters:
              in_channels(int) : channels of the incoming feature map
              out_channels(int) : channels produced by the convolution
              **kwargs : forwarded unchanged to nn.Conv2d
                         (e.g. kernel_size, stride, padding)
    """

    def __init__(self, in_channels, out_channels, **kwargs):
        super().__init__()
        self.in_channels, self.out_channels = in_channels, out_channels
        # NOTE(review): the conv bias is cancelled by the mean subtraction of the
        # following BatchNorm2d; it is kept so existing state_dict keys stay valid.
        self.conv = nn.Conv2d(in_channels, out_channels, bias=True, **kwargs)
        self.batchnorm = nn.BatchNorm2d(out_channels)
        self.leakyrelu = nn.LeakyReLU(0.1)

    def forward(self, x):
        """Apply conv, batch norm and the activation, in that order."""
        convolved = self.conv(x)
        normalised = self.batchnorm(convolved)
        return self.leakyrelu(normalised)
class Yolo(nn.Module):
    """Detector made of a convolutional backbone and a fully connected head.

    The backbone is assembled from the module-level ``architecture_config``;
    the head maps the flattened feature map to
    ``split_size * split_size * (num_classes + num_boxes * 5)`` values per image.

    Parameters:
              in_channels(int) : channels of the input images (default 3)
              **kwargs : split_size, num_boxes and num_classes, forwarded to
                         _create_fcs
    """

    def __init__(self, in_channels=3, **kwargs):
        super().__init__()
        self.architecture = architecture_config
        self.in_channels = in_channels
        self.darknet = self._create_conv_layers(self.architecture)
        self.fcs = self._create_fcs(**kwargs)

    def forward(self, x):
        """Run backbone then head; the result has shape (batch, S*S*(C+B*5))."""
        # self.fcs begins with nn.Flatten(), which flattens everything after the
        # batch dimension, so a separate torch.flatten call is not needed.
        return self.fcs(self.darknet(x))

    def _create_conv_layers(self, architecture):
        """Build the backbone described by ``architecture``.

        Parameters:
                  architecture(iterable) : entries of one of three kinds
                      tuple (kernel_size, out_channels, stride, padding) : one CNNBlock
                      str : a 2x2 max-pool with stride 2
                      list [spec_a, spec_b, repeats] : spec_a then spec_b, ``repeats`` times
        Returns:
                nn.Sequential holding the layers in order
        Raises:
               TypeError: if an entry is not a tuple, str or list
        """

        def conv_block(channels_in, spec):
            # spec layout: (kernel_size, out_channels, stride, padding)
            return CNNBlock(
                channels_in,
                out_channels=spec[1],
                kernel_size=spec[0],
                stride=spec[2],
                padding=spec[3],
            )

        layers = []
        in_channels = self.in_channels
        for entry in architecture:
            # isinstance (rather than an exact type comparison) also accepts
            # subclasses such as namedtuples.
            if isinstance(entry, tuple):
                layers.append(conv_block(in_channels, entry))
                in_channels = entry[1]
            elif isinstance(entry, str):
                layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
            elif isinstance(entry, list):
                first_spec, second_spec, num_repeats = entry[0], entry[1], entry[2]
                for _ in range(num_repeats):
                    layers.append(conv_block(in_channels, first_spec))
                    layers.append(conv_block(first_spec[1], second_spec))
                    in_channels = second_spec[1]
            else:
                # Previously such entries were skipped silently, which produced a
                # backbone that did not match the requested architecture.
                raise TypeError(
                    "Unsupported architecture entry {!r}: expected tuple, str or list".format(entry)
                )
        return nn.Sequential(*layers)

    def _create_fcs(self, split_size, num_boxes, num_classes):
        """Build the fully connected head.

        Parameters:
                  split_size(int) : grid size S; the head expects 1024*S*S input features
                  num_boxes(int) : boxes B predicted per grid cell
                  num_classes(int) : number of classes C
        Returns:
                nn.Sequential mapping (batch, 1024, S, S) to (batch, S*S*(C+B*5))
        """
        s, b, c = split_size, num_boxes, num_classes
        return nn.Sequential(
            nn.Flatten(),
            nn.Linear(1024 * s * s, 4096),
            # p=0 makes this layer a no-op; it is kept so the positions of the
            # Linear layers inside the Sequential (and their state_dict keys) stay put.
            nn.Dropout(0),
            nn.LeakyReLU(0.1),
            nn.Linear(4096, s * s * (c + b * 5)),
        )

In [4]:
def test(s=7, b=2, c=20):
    """Smoke test: feed a random batch through Yolo and print the output shape.

    Parameters:
              s(int) : grid size passed as split_size
              b(int) : boxes per cell passed as num_boxes
              c(int) : class count passed as num_classes
    """
    net = Yolo(split_size=s, num_boxes=b, num_classes=c)
    # Two random RGB images of 448x448 pixels.
    dummy_images = torch.randn((2, 3, 448, 448))
    output = net(dummy_images)
    print(output.shape)
test()

torch.Size([2, 1470])


In [5]:
def intersection_over_union(boxes_preds, boxes_labels, box_format="midpoint"):
    """
    Calculates the intersection over union

    Parameters:
              boxes_preds(tensor) : Predictions of bounding boxes (..., 4)
              boxes_labels(tensor): Correct labels of bounding boxes (..., 4)
              box_format(str) : "midpoint" if boxes are (x, y, w, h),
                                "corners" if boxes are (x1, y1, x2, y2)
    Returns:
            tensor: Intersection over union for all examples, shape (..., 1)
    Raises:
           ValueError: if box_format is neither "midpoint" nor "corners"
    """
    if box_format not in ("midpoint", "corners"):
        raise ValueError(
            "box_format must be 'midpoint' or 'corners', got {!r}".format(box_format)
        )

    def to_corners(boxes):
        # Slices such as 0:1 (instead of index 0) keep the trailing dimension.
        if box_format == "midpoint":
            half_w = boxes[..., 2:3] / 2
            half_h = boxes[..., 3:4] / 2
            return (
                boxes[..., 0:1] - half_w,
                boxes[..., 1:2] - half_h,
                boxes[..., 0:1] + half_w,
                boxes[..., 1:2] + half_h,
            )
        return boxes[..., 0:1], boxes[..., 1:2], boxes[..., 2:3], boxes[..., 3:4]

    # Each box is converted with its OWN width/height. The label box used to be
    # built from the prediction's width/height, which gave wrong IoU values.
    box1_x1, box1_y1, box1_x2, box1_y2 = to_corners(boxes_preds)
    box2_x1, box2_y1, box2_x2, box2_y2 = to_corners(boxes_labels)

    x1 = torch.max(box1_x1, box2_x1)
    y1 = torch.max(box1_y1, box2_y1)
    x2 = torch.min(box1_x2, box2_x2)
    y2 = torch.min(box1_y2, box2_y2)

    # clamp(0): boxes that do not overlap yield a negative extent, which must
    # count as an empty intersection.
    intersection = (x2 - x1).clamp(0) * (y2 - y1).clamp(0)

    box1_area = torch.abs((box1_x2 - box1_x1) * (box1_y2 - box1_y1))
    box2_area = torch.abs((box2_x2 - box2_x1) * (box2_y2 - box2_y1))

    # The small epsilon avoids a division by zero for degenerate boxes.
    return intersection / (box1_area + box2_area - intersection + 1e-6)

In [6]:
class YoloLoss(nn.Module):
    """Sum-of-squares detection loss over an S x S grid.

    Per grid cell the prediction vector is laid out as
    ``[C class scores, (confidence, x, y, w, h) * B]``.
    The target vector uses ``[C class scores, objectness, x, y, w, h, ...]``;
    only its first ``C + 5`` entries are read. Boxes are compared with
    ``intersection_over_union`` in its default (midpoint) format.

    Parameters:
              s(int) : grid size S
              b(int) : number of boxes B predicted per cell
              c(int) : number of classes C
    """

    def __init__(self, s=7, b=2, c=20):
        super().__init__()
        self.mse = nn.MSELoss(reduction="sum")
        self.s = s
        self.b = b
        self.c = c
        # Weights of the no-object confidence term and of the coordinate term.
        self.lambda_noobj = 0.5
        self.lambda_coord = 5

    def forward(self, predictions, targets):
        """Compute the total loss.

        Parameters:
                  predictions(tensor) : any shape reshapeable to (N, S, S, C + B*5)
                  targets(tensor) : assumed (N, S, S, D) with D >= C + 5
        Returns:
                tensor: scalar loss (summed, not averaged, over the batch)
        """
        c = self.c
        predictions = predictions.reshape(-1, self.s, self.s, c + self.b * 5)

        # Channel offsets are derived from self.c / self.b instead of being
        # hard-coded for 20 classes and 2 boxes.
        pred_confs = [
            predictions[..., c + 5 * k:c + 5 * k + 1] for k in range(self.b)
        ]
        pred_boxes = [
            predictions[..., c + 5 * k + 1:c + 5 * k + 5] for k in range(self.b)
        ]
        target_conf = targets[..., c:c + 1]
        target_box = targets[..., c + 1:c + 5]

        # 1 where the cell contains an object, 0 elsewhere; shape (N, S, S, 1).
        exists_box = target_conf
        no_box = 1 - exists_box

        # Pick, per cell, the predicted box with the highest IoU to the target.
        ious = torch.stack(
            [intersection_over_union(box, target_box) for box in pred_boxes], dim=0
        )
        _, bestbox = torch.max(ious, dim=0)
        responsible = [
            (bestbox == k).to(predictions.dtype) for k in range(self.b)
        ]
        best_box = sum(mask * box for mask, box in zip(responsible, pred_boxes))
        best_conf = sum(mask * conf for mask, conf in zip(responsible, pred_confs))

        # Box Coordinates #
        #=================#
        box_predictions = exists_box * best_box
        box_targets = exists_box * target_box
        # Width/height are compared in square-root space. Predictions may be
        # negative, so the sign is reapplied after sqrt(|.|); the epsilon keeps
        # the sqrt gradient finite at zero.
        pred_wh = torch.sign(box_predictions[..., 2:4]) * torch.sqrt(
            torch.abs(box_predictions[..., 2:4] + 1e-6)
        )
        box_predictions = torch.cat([box_predictions[..., 0:2], pred_wh], dim=-1)
        box_targets = torch.cat(
            [box_targets[..., 0:2], torch.sqrt(box_targets[..., 2:4])], dim=-1
        )
        box_loss = self.mse(
            torch.flatten(box_predictions, end_dim=-2),
            torch.flatten(box_targets, end_dim=-2),
        )

        # For Object Loss if there is an object detected
        #==============================================#
        object_loss = self.mse(
            torch.flatten(exists_box * best_conf),
            torch.flatten(exists_box * target_conf),
        )

        # For No Object Loss if there is no object detected
        #==================================================#
        # Every predicted box of an empty cell is pushed towards zero confidence.
        no_object_loss = sum(
            self.mse(
                torch.flatten(no_box * conf, start_dim=1),
                torch.flatten(no_box * target_conf, start_dim=1),
            )
            for conf in pred_confs
        )

        # For Class Loss #
        #================#
        class_loss = self.mse(
            torch.flatten(exists_box * predictions[..., :c], end_dim=-2),
            torch.flatten(exists_box * targets[..., :c], end_dim=-2),
        )

        loss = (
            self.lambda_coord * box_loss
            + object_loss
            + self.lambda_noobj * no_object_loss
            + class_loss
        )
        return loss