<a href="https://colab.research.google.com/github/AchrafAsh/ml_projects/blob/main/image_detection_yolo.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [58]:
from collections import Counter
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
def intersection_over_union(box_preds, box_labels, box_format="midpoint"):
    """
    Calculates the intersection over union (IoU) of pairs of boxes.

    Parameters:
        box_preds (tensor): Predicted bounding boxes, shape (..., 4).
        box_labels (tensor): Ground-truth bounding boxes, shape (..., 4);
            must be broadcastable against ``box_preds``.
        box_format (str): "midpoint" if boxes are (x, y, w, h) with (x, y)
            the box centre, or "corners" if boxes are (x1, y1, x2, y2).

    Returns:
        tensor: IoU of each pair, shape (..., 1) (broadcast shape of inputs).

    Raises:
        ValueError: if ``box_format`` is neither "midpoint" nor "corners".
    """
    if box_format not in ("midpoint", "corners"):
        raise ValueError(
            f'box_format must be "midpoint" or "corners", got {box_format!r}'
        )

    def to_corners(boxes):
        # Return (x1, y1, x2, y2), each of shape (..., 1).
        if box_format == "midpoint":
            center_x, center_y = boxes[..., 0:1], boxes[..., 1:2]
            half_w, half_h = boxes[..., 2:3] / 2, boxes[..., 3:4] / 2
            return (center_x - half_w, center_y - half_h,
                    center_x + half_w, center_y + half_h)
        return boxes[..., 0:1], boxes[..., 1:2], boxes[..., 2:3], boxes[..., 3:4]

    box1_x1, box1_y1, box1_x2, box1_y2 = to_corners(box_preds)
    box2_x1, box2_y1, box2_x2, box2_y2 = to_corners(box_labels)

    x1 = torch.max(box1_x1, box2_x1)
    y1 = torch.max(box1_y1, box2_y1)
    x2 = torch.min(box1_x2, box2_x2)
    y2 = torch.min(box1_y2, box2_y2)

    # clamp so that disjoint boxes yield 0 instead of a negative "area"
    intersection = (x2 - x1).clamp(min=0) * (y2 - y1).clamp(min=0)

    box1_area = ((box1_x2 - box1_x1) * (box1_y2 - box1_y1)).abs()
    box2_area = ((box2_x2 - box2_x1) * (box2_y2 - box2_y1)).abs()
    union = box1_area + box2_area - intersection

    return intersection / (union + 1e-6)  # 1e-6 avoids division by zero

In [5]:
def non_max_suppression(box_preds, iou_threshold,
                        confidence_threshold, box_format="corners"):
    """
    Greedy per-class non-maximum suppression.

    Parameters:
        box_preds (list): Boxes as [class, confidence, x1, y1, x2, y2]
            (coordinates interpreted according to ``box_format``).
        iou_threshold (float): A box is dropped when its IoU with an already
            kept box of the same class is >= this value.
        confidence_threshold (float): Boxes with confidence <= this value are
            discarded before suppression.
        box_format (str): "corners" or "midpoint", forwarded to
            ``intersection_over_union``.

    Returns:
        list: The kept boxes, in descending order of confidence.

    Raises:
        TypeError: if ``box_preds`` is not a list.
    """
    if not isinstance(box_preds, list):
        raise TypeError(
            f"box_preds must be a list, got {type(box_preds).__name__}"
        )

    candidates = sorted(
        (box for box in box_preds if box[1] > confidence_threshold),
        key=lambda box: box[1],
        reverse=True,
    )

    bboxes_after_nms = []
    while candidates:
        chosen_box = candidates.pop(0)
        chosen_coords = torch.tensor(chosen_box[2:])
        # Keep boxes of other classes, and same-class boxes that overlap the
        # chosen one by less than the threshold.
        candidates = [
            box for box in candidates
            if box[0] != chosen_box[0]
            or intersection_over_union(chosen_coords,
                                       torch.tensor(box[2:]),
                                       box_format=box_format).item()
            < iou_threshold
        ]
        bboxes_after_nms.append(chosen_box)

    return bboxes_after_nms

In [6]:
# Mean Average Precision mAP
def mean_average_precision(box_preds, box_labels, iou_threshold=0.5,
                           box_format="corners", num_classes=20):
    """
    Computes mean average precision (mAP) at a single IoU threshold.

    Parameters:
        box_preds (list): Predicted boxes, each
            [train_idx, class_pred, confidence, x1, y1, x2, y2]
            (coordinates interpreted according to ``box_format``).
        box_labels (list): Ground-truth boxes in the same layout; the
            confidence entry is not used.
        iou_threshold (float): A detection matches a ground truth only when
            their IoU is strictly greater than this value.
        box_format (str): "corners" or "midpoint", forwarded to
            ``intersection_over_union``.
        num_classes (int): Classes 0 .. num_classes - 1 are evaluated.

    Returns:
        tensor: Mean of the per-class APs over the classes that have at least
        one ground-truth box (0.0 if no class has any). Each AP is the
        trapezoidal area under the precision/recall curve, starting from the
        point (recall 0, precision 1).
    """
    average_precisions = []

    for c in range(num_classes):
        detections = [box for box in box_preds if box[1] == c]
        ground_truths = [box for box in box_labels if box[1] == c]
        total_true_bboxes = len(ground_truths)
        if total_true_bboxes == 0:
            # AP is undefined without ground truth; skip rather than count as 0.
            continue

        # Ground truths grouped by image, plus one "already matched" flag per
        # ground truth so each one can yield at most one true positive.
        gts_by_image = {}
        for gt in ground_truths:
            gts_by_image.setdefault(gt[0], []).append(gt)
        matched = {img: torch.zeros(len(gts)) for img, gts in gts_by_image.items()}

        detections.sort(key=lambda x: x[2], reverse=True)
        TP = torch.zeros(len(detections))
        FP = torch.zeros(len(detections))

        for detection_idx, detection in enumerate(detections):
            image_gts = gts_by_image.get(detection[0], [])
            det_box = torch.tensor(detection[3:])
            best_iou, best_gt_idx = 0, None

            for idx, gt in enumerate(image_gts):
                iou = intersection_over_union(
                    det_box,
                    torch.tensor(gt[3:]),
                    box_format=box_format,
                ).item()
                if iou > best_iou:
                    best_iou, best_gt_idx = iou, idx

            if (best_gt_idx is not None and best_iou > iou_threshold
                    and matched[detection[0]][best_gt_idx] == 0):
                TP[detection_idx] = 1
                matched[detection[0]][best_gt_idx] = 1
            else:
                # No overlap, too little overlap, or a duplicate detection.
                FP[detection_idx] = 1

        TP_cumsum = torch.cumsum(TP, dim=0)
        FP_cumsum = torch.cumsum(FP, dim=0)

        recalls = TP_cumsum / (total_true_bboxes + 1e-6)
        precisions = TP_cumsum / (TP_cumsum + FP_cumsum + 1e-6)
        # Prepend (recall 0, precision 1) so the area is measured from the origin.
        precisions = torch.cat((torch.tensor([1.0]), precisions))
        recalls = torch.cat((torch.tensor([0.0]), recalls))

        average_precisions.append(torch.trapz(precisions, recalls))

    if not average_precisions:
        return torch.tensor(0.0)
    return sum(average_precisions) / len(average_precisions)

In [49]:
# Backbone layout consumed by Yolo._create_conv_layers (stored as Yolo.darknet).
# Entry kinds:
#   (kernel_size, out_channels, stride, padding) -> one CNNBlock
#   "M"                                          -> 2x2 max-pool with stride 2
#   [conv_a, conv_b, num_repeats]                -> conv_a then conv_b, repeated
architecture_config = [
    # kernel, filters, stride, padding
    (7, 64, 2, 3), "M",
    (3, 192, 1, 1), "M",
    (1, 128, 1, 0), (3, 256, 1, 1), (1, 256, 1, 0), (3, 512, 1, 1), "M",
    [(1, 256, 1, 0), (3, 512, 1, 1), 4],
    (1, 512, 1, 0), (3, 1024, 1, 1), "M",
    [(1, 512, 1, 0), (3, 1024, 1, 1), 2],
    (3, 1024, 1, 1), (3, 1024, 2, 1),
    (3, 1024, 1, 1), (3, 1024, 1, 1),
]

In [50]:
class CNNBlock(nn.Module):
    """Conv2d (without bias) -> BatchNorm2d -> LeakyReLU(0.1).

    Keyword arguments (kernel_size, stride, padding, ...) are forwarded to
    ``nn.Conv2d``. The convolution has no bias because the batch norm that
    follows supplies its own learnable shift.
    """

    def __init__(self, in_channels, out_channels, **kwargs):
        super().__init__()
        # Attribute names are part of the state_dict keys; keep them stable.
        self.cnn = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
        self.batch_norm = nn.BatchNorm2d(out_channels)
        self.leaky_relu = nn.LeakyReLU(0.1)

    def forward(self, x):
        features = self.cnn(x)
        normalized = self.batch_norm(features)
        return self.leaky_relu(normalized)

In [54]:
class Yolo(nn.Module):
    """YOLO-style detector: convolutional backbone followed by a fully connected head.

    Parameters:
        in_channels (int): Number of channels of the input images.
        split_size (int): Grid size S.
        num_boxes (int): Boxes predicted per grid cell (B).
        num_classes (int): Number of object classes (C).
        architecture (list, optional): Backbone layout in the format of
            ``architecture_config``; defaults to ``architecture_config``.

    The backbone must reduce the input to an S x S feature map, since the
    head's first Linear layer expects ``channels * S * S`` features. With the
    default architecture this means inputs of 64 * S pixels per side.

    ``forward`` returns a tensor of shape (batch, S * S * (C + B * 5)).
    """

    def __init__(self, in_channels, split_size, num_boxes, num_classes,
                 architecture=None):
        super().__init__()
        self.architecture = (architecture_config if architecture is None
                             else architecture)
        self.in_channels = in_channels
        self.darknet = self._create_conv_layers(self.architecture)
        self.fcs = self._create_fcs(split_size, num_boxes, num_classes)

    def _create_conv_layers(self, architecture):
        """Build the backbone; also records its output channel count."""
        layers = []
        in_channels = self.in_channels

        def conv_block(spec, channels_in):
            # spec is (kernel_size, out_channels, stride, padding)
            return CNNBlock(channels_in,
                            out_channels=spec[1],
                            kernel_size=spec[0],
                            stride=spec[2],
                            padding=spec[3])

        for x in architecture:
            if isinstance(x, tuple):
                layers.append(conv_block(x, in_channels))
                in_channels = x[1]

            elif isinstance(x, str):
                layers.append(nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2)))

            elif isinstance(x, list):
                conv1, conv2, num_repeats = x[0], x[1], x[2]
                for _ in range(num_repeats):
                    layers.append(conv_block(conv1, in_channels))
                    layers.append(conv_block(conv2, conv1[1]))
                    in_channels = conv2[1]

            else:
                raise ValueError(f"Unsupported architecture entry: {x!r}")

        # The head sizes its first Linear layer from this value.
        self.darknet_out_channels = in_channels
        return nn.Sequential(*layers)

    def _create_fcs(self, split_size, num_boxes, num_classes):
        """Build the head mapping the S x S backbone output to S*S*(C+B*5) values."""
        S, B, C = split_size, num_boxes, num_classes
        # Layer order is kept as before so existing state_dict keys still match.
        return nn.Sequential(nn.Flatten(),
                             nn.Linear(self.darknet_out_channels * S * S, 496),
                             nn.Dropout(0.0),
                             nn.LeakyReLU(0.1),
                             nn.Linear(496, S * S * (C + B * 5)))

    def forward(self, x):
        # self.fcs begins with nn.Flatten (start_dim=1), so the feature map
        # can be passed to it directly.
        return self.fcs(self.darknet(x))

In [55]:
def test(in_channels=3, split_size=7, num_boxes=2, num_classes=20):
    """Smoke-test Yolo on a random batch of 2 images.

    The input has ``in_channels`` channels and 64 * ``split_size`` pixels per
    side. The default backbone halves the resolution six times (a stride-2
    conv, four 2x2 max-pools, another stride-2 conv), which yields the
    S x S grid the head expects; with the defaults this is 448 x 448.

    Returns:
        torch.Size: the output shape, (2, S * S * (C + B * 5)); it is also printed.
    """
    model = Yolo(in_channels, split_size, num_boxes, num_classes)
    side = 64 * split_size
    x = torch.randn((2, in_channels, side, side))
    shape = model(x).shape
    print(shape)
    expected = (2, split_size * split_size * (num_classes + num_boxes * 5))
    assert tuple(shape) == expected, f"expected {expected}, got {tuple(shape)}"
    return shape

In [56]:
class YoloLoss(nn.Module):
    """Sum-of-squared-errors detection loss in the style of YOLOv1.

    Per grid cell, ``preds`` (after reshaping) holds C class scores followed by
    B groups of (confidence, x, y, w, h). ``target`` holds C class scores
    (indices [0:C]), an object indicator ([C]) and one (x, y, w, h) box
    ([C+1:C+5]). Boxes are compared in midpoint format. With the defaults
    (C=20, B=2) these are the indices 20, 21:25, 25 and 26:30.

    Loss terms (summed over the batch, not averaged):
      * box coordinates of the responsible predictor (highest IoU with the
        target box) in object cells, with w and h compared through a square
        root, weighted by ``lambda_coord``;
      * confidence of the responsible predictor in object cells;
      * confidence of every predictor in cells without an object, weighted by
        ``lambda_noobj``;
      * class scores in object cells.

    NOTE(review): assumes target[..., C] is a 0/1 indicator and the target
    w, h are non-negative (they go through sqrt). Confirm against the dataset.
    """

    def __init__(self, split_size=7, num_boxes=2, num_classes=20):
        super().__init__()
        self.mse = nn.MSELoss(reduction="sum")
        self.split_size = split_size
        self.num_boxes = num_boxes
        self.num_classes = num_classes
        self.lambda_noobj = .5
        self.lambda_coord = 5

    def forward(self, preds, target):
        """Return the scalar loss.

        Parameters:
            preds (tensor): Model output, reshapeable to (N, S, S, C + B * 5).
            target (tensor): Labels of shape (N, S, S, >= C + 5).
        """
        S, B, C = self.split_size, self.num_boxes, self.num_classes
        preds = preds.reshape(-1, S, S, C + B * 5)

        exists_box = target[..., C:C + 1]          # (N, S, S, 1)
        target_coords = target[..., C + 1:C + 5]   # (N, S, S, 4)

        # Split the B predictors of each cell into confidence and coordinates.
        pred_boxes = preds[..., C:C + B * 5].reshape(-1, S, S, B, 5)
        pred_conf = pred_boxes[..., 0:1]           # (N, S, S, B, 1)
        pred_coords = pred_boxes[..., 1:5]         # (N, S, S, B, 4)

        # The responsible predictor of a cell is the one whose box has the
        # highest IoU with the target box.
        ious = intersection_over_union(pred_coords,
                                       target_coords.unsqueeze(-2))  # (N, S, S, B, 1)
        best_box = ious.argmax(dim=-2, keepdim=True)                 # (N, S, S, 1, 1)
        resp_coords = pred_coords.gather(
            -2, best_box.expand(-1, -1, -1, -1, 4)).squeeze(-2)      # (N, S, S, 4)
        resp_conf = pred_conf.gather(-2, best_box).squeeze(-2)       # (N, S, S, 1)

        # Box coordinates, object cells only.
        box_pred = exists_box * resp_coords
        box_target = exists_box * target_coords
        # Compare sqrt(w), sqrt(h). Raw predictions may be negative, so use
        # sign * sqrt(|.| + eps): real-valued, with a finite gradient at 0.
        box_pred = torch.cat(
            (box_pred[..., :2],
             torch.sign(box_pred[..., 2:4])
             * torch.sqrt(torch.abs(box_pred[..., 2:4]) + 1e-6)),
            dim=-1,
        )
        box_target = torch.cat(
            (box_target[..., :2], torch.sqrt(box_target[..., 2:4])), dim=-1
        )
        box_loss = self.mse(box_pred, box_target)

        # Confidence of the responsible predictor; the target is the indicator.
        object_loss = self.mse(exists_box * resp_conf,
                               exists_box * target[..., C:C + 1])

        # Confidence of every predictor in empty cells, pushed towards 0.
        no_object = (1 - exists_box).unsqueeze(-2)  # (N, S, S, 1, 1)
        no_object_loss = self.mse(no_object * pred_conf,
                                  torch.zeros_like(pred_conf))

        # Class scores, object cells only.
        class_loss = self.mse(exists_box * preds[..., :C],
                              exists_box * target[..., :C])

        return (self.lambda_coord * box_loss
                + object_loss
                + self.lambda_noobj * no_object_loss
                + class_loss)

torch.Size([2, 1470])
