<a href="https://colab.research.google.com/github/AchrafAsh/ml_projects/blob/main/image_detection_yolo.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
from collections import Counter
import torch
import torch.nn

In [None]:
def _box_corners(boxes, box_format):
    """Split ``boxes`` into ``(x1, y1, x2, y2)`` corner tensors.

    Slicing with ``i:i+1`` keeps the last dimension (size 1) so the four
    results keep the leading shape of ``boxes``.
    """
    if box_format == "midpoint":
        # (x, y, w, h): corners are the centre shifted by half the size.
        center_x = boxes[..., 0:1]
        center_y = boxes[..., 1:2]
        half_w = boxes[..., 2:3] / 2
        half_h = boxes[..., 3:4] / 2
        return (
            center_x - half_w,
            center_y - half_h,
            center_x + half_w,
            center_y + half_h,
        )

    if box_format == "corners":
        return (
            boxes[..., 0:1],
            boxes[..., 1:2],
            boxes[..., 2:3],
            boxes[..., 3:4],
        )

    raise ValueError(
        f"box_format must be 'midpoint' or 'corners', got {box_format!r}"
    )


def intersection_over_union(box_preds, box_labels, box_format="midpoint"):
    """
    Calculates the intersection over union

    Parameters:
        box_preds (tensor): Predictions of Bounding boxes (..., 4)
        box_labels (tensor): Correct labels of Bounding boxes (..., 4)
        box_format (str): midpoint/corners, if boxes (x, y, w, h) or (x1, y1, x2, y2)

    Returns:
        tensor: IoU of each pair of boxes, with a trailing dimension of size 1
        (the last axis of the inputs is sliced, not squeezed).

    Raises:
        ValueError: if box_format is neither "midpoint" nor "corners".
    """
    box1_x1, box1_y1, box1_x2, box1_y2 = _box_corners(box_preds, box_format)
    box2_x1, box2_y1, box2_x2, box2_y2 = _box_corners(box_labels, box_format)

    # Corners of the overlapping rectangle (if any).
    x1 = torch.max(box1_x1, box2_x1)
    y1 = torch.max(box1_y1, box2_y1)
    x2 = torch.min(box1_x2, box2_x2)
    y2 = torch.min(box1_y2, box2_y2)

    # clamp for when the intersection is empty (negative width or height)
    intersection = (x2 - x1).clamp(0) * (y2 - y1).clamp(0)

    box1_area = abs((box1_x2 - box1_x1) * (box1_y2 - box1_y1))
    box2_area = abs((box2_x2 - box2_x1) * (box2_y2 - box2_y1))
    union = box1_area + box2_area - intersection

    return intersection / (union + 1e-6)  # 1e-6 for stability

In [None]:
def non_max_suppression(box_preds, iou_threshold,
                        confidence_threshold, box_format="corners"):
    """
    Filters overlapping predictions with non-max suppression

    Parameters:
        box_preds (list): boxes as [class, confidence, x1, y1, x2, y2]
            (the four coordinates are interpreted according to box_format)
        iou_threshold (float): a remaining box of the same class is dropped
            when its IoU with the chosen box is not below this value
        confidence_threshold (float): boxes whose confidence is not above
            this value are discarded before suppression
        box_format (str): format forwarded to intersection_over_union

    Returns:
        list: the kept boxes, ordered by decreasing confidence
    """
    assert isinstance(box_preds, list)

    bboxes = [box for box in box_preds if box[1] > confidence_threshold]
    bboxes = sorted(bboxes, key=lambda box: box[1], reverse=True)

    bboxes_after_nms = []
    while bboxes:
        # The most confident remaining box is always kept.
        chosen_box = bboxes.pop(0)
        chosen_coords = torch.tensor(chosen_box[2:])

        # Keep boxes of another class, or of the same class with low overlap.
        bboxes = [
            box for box in bboxes
            if box[0] != chosen_box[0]
            or intersection_over_union(
                chosen_coords,
                torch.tensor(box[2:]),
                box_format=box_format,
            ) < iou_threshold
        ]
        bboxes_after_nms.append(chosen_box)

    return bboxes_after_nms

In [None]:
# Mean Average Precision mAP
def mean_average_precision(box_preds, box_labels, iou_threshold=0.5,
                           box_format="corners", num_classes=20):
    """
    Calculates the mean average precision (mAP) over classes

    Parameters:
        box_preds (list): predictions as
            [train_idx, class_pred, confidence, x1, y1, x2, y2]
        box_labels (list): ground truths laid out like box_preds
            (index 0 is the image index, index 1 the class, 3: the box;
            index 2 is not read for labels)
        iou_threshold (float): a detection counts as correct when its best
            IoU with a ground truth of the same image exceeds this value
        box_format (str): format forwarded to intersection_over_union
        num_classes (int): classes 0..num_classes-1 are evaluated

    Returns:
        tensor: mean of the per-class average precisions. Classes without
        any ground-truth box are skipped; if no class has ground truths the
        result is 0.
    """
    average_precisions = []

    for c in range(num_classes):
        detections = [det for det in box_preds if det[1] == c]
        ground_truths = [gt for gt in box_labels if gt[1] == c]

        total_true_bboxes = len(ground_truths)
        if total_true_bboxes == 0:
            # AP is undefined without ground truths; do not let the class
            # drag the mean towards zero.
            continue

        # For every image: one flag per ground-truth box, set to 1 once a
        # detection has been matched to it.
        amount_bboxes = {
            img_idx: torch.zeros(count)
            for img_idx, count in Counter(
                gt[0] for gt in ground_truths
            ).items()
        }

        # Most confident detections get the first chance to claim a box.
        detections.sort(key=lambda x: x[2], reverse=True)
        TP = torch.zeros(len(detections))
        FP = torch.zeros(len(detections))

        for detection_idx, detection in enumerate(detections):
            ground_truth_img = [
                bbox for bbox in ground_truths if bbox[0] == detection[0]
            ]

            best_iou = 0.0
            best_gt_idx = None

            for idx, gt in enumerate(ground_truth_img):
                iou = intersection_over_union(
                    torch.tensor(detection[3:]),
                    torch.tensor(gt[3:]),
                    box_format=box_format
                ).item()

                if iou > best_iou:
                    best_iou = iou
                    best_gt_idx = idx

            is_match = best_gt_idx is not None and best_iou > iou_threshold
            if is_match and amount_bboxes[detection[0]][best_gt_idx] == 0:
                TP[detection_idx] = 1
                amount_bboxes[detection[0]][best_gt_idx] = 1
            else:
                # Low overlap, or the ground truth was already claimed by a
                # more confident detection.
                FP[detection_idx] = 1

        TP_cumsum = torch.cumsum(TP, dim=0)
        FP_cumsum = torch.cumsum(FP, dim=0)

        recalls = TP_cumsum / (total_true_bboxes + 1e-6)
        precisions = TP_cumsum / (TP_cumsum + FP_cumsum + 1e-6)
        # Add the point (recall=0, precision=1) to compute the area below
        # the graph precisions = f(recalls)
        precisions = torch.cat((torch.tensor([1.0]), precisions))
        recalls = torch.cat((torch.tensor([0.0]), recalls))

        average_precisions.append(torch.trapz(precisions, recalls))

    if not average_precisions:
        return torch.tensor(0.0)

    return sum(average_precisions) / len(average_precisions)