In [16]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from collections import Counter

In [17]:
def intersection_over_union(boxes_preds, boxes_labels, box_format="midpoint"):
    """
    Calculates intersection over union between paired boxes.

    Parameters:
        boxes_preds (tensor): Predicted boxes, shape (..., 4). Any number of
            leading dimensions is accepted because only the last axis is sliced.
        boxes_labels (tensor): Target boxes, shape (..., 4).
        box_format (str): "midpoint" if boxes are (x, y, w, h) with (x, y) the
            box centre, "corners" if boxes are (x1, y1, x2, y2).
    Returns:
        tensor: IoU for every pair of boxes, shape (..., 1).
    Raises:
        ValueError: if box_format is neither "midpoint" nor "corners".
    """

    if box_format == "midpoint":
        # Convert centre/size to corner coordinates. Slicing with 0:1 (rather
        # than indexing with 0) keeps the trailing axis, so results are (..., 1).
        box1_x1 = boxes_preds[..., 0:1] - boxes_preds[..., 2:3] / 2
        box1_y1 = boxes_preds[..., 1:2] - boxes_preds[..., 3:4] / 2
        box1_x2 = boxes_preds[..., 0:1] + boxes_preds[..., 2:3] / 2
        box1_y2 = boxes_preds[..., 1:2] + boxes_preds[..., 3:4] / 2
        box2_x1 = boxes_labels[..., 0:1] - boxes_labels[..., 2:3] / 2
        box2_y1 = boxes_labels[..., 1:2] - boxes_labels[..., 3:4] / 2
        box2_x2 = boxes_labels[..., 0:1] + boxes_labels[..., 2:3] / 2
        box2_y2 = boxes_labels[..., 1:2] + boxes_labels[..., 3:4] / 2
    elif box_format == "corners":
        box1_x1 = boxes_preds[..., 0:1]
        box1_y1 = boxes_preds[..., 1:2]
        box1_x2 = boxes_preds[..., 2:3]
        box1_y2 = boxes_preds[..., 3:4]
        box2_x1 = boxes_labels[..., 0:1]
        box2_y1 = boxes_labels[..., 1:2]
        box2_x2 = boxes_labels[..., 2:3]
        box2_y2 = boxes_labels[..., 3:4]
    else:
        # Without this branch an unknown format falls through and fails later
        # with an unhelpful "referenced before assignment" error.
        raise ValueError(
            f"box_format must be 'midpoint' or 'corners', got {box_format!r}"
        )

    # Corners of the overlapping region.
    x1 = torch.max(box1_x1, box2_x1)
    y1 = torch.max(box1_y1, box2_y1)
    x2 = torch.min(box1_x2, box2_x2)
    y2 = torch.min(box1_y2, box2_y2)

    # .clamp(0) handles the case where the boxes do not intersect
    # (a negative extent would otherwise produce a bogus positive area).
    intersection = (x2 - x1).clamp(0) * (y2 - y1).clamp(0)

    box1_area = abs((box1_x2 - box1_x1) * (box1_y2 - box1_y1))
    box2_area = abs((box2_x2 - box2_x1) * (box2_y2 - box2_y1))

    # 1e-6 guards against division by zero for degenerate (zero-area) boxes.
    return intersection / (box1_area + box2_area - intersection + 1e-6)

In [18]:
from PIL import Image, ImageDraw, ImageFont
from torchvision.transforms.functional import to_tensor, to_pil_image

In [19]:
def get_bboxes(
    loader,
    model,
    iou_threshold,
    threshold,
    pred_format="cells",
    box_format="midpoint",
    device="cuda",
):
  """
  Runs the model over every batch in `loader` and collects predicted and
  ground-truth boxes as flat lists, each box prefixed with a running image index.

  Parameters:
      loader: iterable yielding (images, labels) batches
      model: network called as model(images); switched to eval mode for the
          duration of the loop and to train mode before returning
      iou_threshold (float): IoU threshold forwarded to non_max_suppression
      threshold (float): confidence threshold, used both for NMS and to keep
          only ground-truth cells whose confidence entry exceeds it
      pred_format (str): accepted but not used by this function
      box_format (str): box format forwarded to non_max_suppression
      device (str): device the batches are moved to
  Returns:
      tuple(list, list): (all_pred_boxes, all_true_boxes); every entry is
      [image_index] + box as produced by cellboxes_to_boxes
  """
  pred_records, true_records = [], []
  image_id = 0  # running index over every sample seen, across batches

  model.eval()
  for images, targets in loader:
    images = images.to(device)
    targets = targets.to(device)

    with torch.no_grad():
      outputs = model(images)

    # Channels 20:25 of the labels are appended again before conversion.
    # Presumably labels carry 25 channels and this pads them to the 30 that
    # convert_cellboxes reshapes to -- TODO confirm against the dataset.
    padded_targets = torch.concat([targets, targets[..., 20:25]], dim=-1)
    target_boxes = cellboxes_to_boxes(padded_targets)
    output_boxes = cellboxes_to_boxes(outputs)

    for sample_idx in range(images.shape[0]):
      kept = non_max_suppression(
          output_boxes[sample_idx],
          iou_threshold=iou_threshold,
          threshold=threshold,
          box_format=box_format,
      )
      pred_records.extend([image_id] + kept_box for kept_box in kept)
      # Only cells whose confidence entry (index 1) passes the threshold
      # count as ground-truth boxes.
      true_records.extend(
          [image_id] + gt_box
          for gt_box in target_boxes[sample_idx]
          if gt_box[1] > threshold
      )
      image_id += 1

  model.train()
  return pred_records, true_records


In [20]:
def convert_cellboxes(predictions, S=7):
  """
  Converts per-cell box predictions into coordinates relative to the whole image.

  The input is reshaped to (Batch, S, S, 30) and read as: channels 0-19 class
  scores, 20 confidence of box 1, 21-24 box 1 (x, y, w, h), 25 confidence of
  box 2, 26-29 box 2 (x, y, w, h). For every cell the box with the higher
  confidence is kept.

  Parameters:
      predictions (tensor): anything reshapeable to (Batch, S, S, 30)
      S (int): number of grid cells per side
  Returns:
      tensor: (Batch, S, S, 6) holding
      [class index, best confidence, x, y, w, h], where x, y, w, h are the
      cell-relative values scaled by 1 / S (x and y also offset by the cell's
      column / row index)
  """
  predictions = predictions.to("cpu")
  batch_size = predictions.shape[0]
  predictions = predictions.reshape(batch_size, S, S, 30)
  bboxes1 = predictions[..., 21:25]
  bboxes2 = predictions[..., 26:30]
  scores = torch.cat(
      (predictions[..., 20].unsqueeze(0), predictions[..., 25].unsqueeze(0)), dim=0
  )
  # 0 where box 1 has the higher confidence, 1 where box 2 does.
  best_box = scores.argmax(0).unsqueeze(-1)
  best_boxes = bboxes1 * (1 - best_box) + best_box * bboxes2            # (Batch, S, S, 4)
  # Column index of every cell; built from S (not a literal 7) so that other
  # grid sizes work. Permuting axes 1 and 2 below yields the row index.
  cell_indices = torch.arange(S).repeat(batch_size, S, 1).unsqueeze(-1) # (Batch, S, S, 1)
  x = 1 / S * (best_boxes[..., :1] + cell_indices)
  y = 1 / S * (best_boxes[..., 1:2] + cell_indices.permute(0, 2, 1, 3))
  w_h = 1 / S * best_boxes[..., 2:4]
  converted_bboxes = torch.cat((x, y, w_h), dim=-1)                     # (Batch, S, S, 4)
  best_confidence = torch.max(predictions[..., 20], predictions[..., 25]).unsqueeze(-1)
  # argmax yields int64; cast explicitly so the concatenation below does not
  # rely on torch.cat's implicit type promotion.
  predicted_class = (
      predictions[..., :20].argmax(-1).unsqueeze(-1).to(best_confidence.dtype)
  )                                                                     # (Batch, S, S, 1)

  converted_preds = torch.cat((predicted_class, best_confidence, converted_bboxes), dim=-1)

  return converted_preds  # (Batch, S, S, 6)


In [21]:
def cellboxes_to_boxes(out, S=7):
  """
  Converts raw cell predictions into nested Python lists of boxes.

  Parameters:
      out (tensor): batch of cell predictions; out.shape[0] is the batch size
      S (int): number of grid cells per side, forwarded to convert_cellboxes
  Returns:
      list: nested list of shape (Batch, S * S, 6); each innermost list is
      [class index, confidence, x, y, w, h] as Python floats
  """
  batch_size = out.shape[0]
  # Forward S so that a non-default grid size is honoured by the conversion.
  converted_pred = convert_cellboxes(out, S=S).reshape(batch_size, S * S, -1)  # (Batch, S * S, 6)
  # The previous `.long()` write-back into column 0 was a no-op: the values are
  # already whole numbers and the float tensor casts them straight back.
  # tolist() produces the same nested lists of Python floats as the former
  # element-by-element .item() loops.
  return converted_pred.tolist()  # list: (Batch, S * S, 6)

In [22]:
# Smoke test: push a random batch of 16 grids (7 x 7 cells, 30 channels each)
# through cellboxes_to_boxes to check that the conversion runs end to end.
test_input = torch.randn(16, 7, 7, 30)
# `ret` is the nested-list result; its contents are random and only useful
# for inspecting the output structure.
ret = cellboxes_to_boxes(test_input)

In [23]:
def non_max_suppression(bboxes, iou_threshold, threshold, box_format="corners"):
  """
  Greedy non-max suppression over the boxes of a single image.

  Parameters:
      bboxes (list): boxes given as [class, confidence, c0, c1, c2, c3], where
          the last four entries are coordinates interpreted by box_format
      iou_threshold (float): a lower-scored box of the same class is dropped
          unless its IoU with the selected box is below this value
      threshold (float): boxes whose confidence is not above this are
          discarded before suppression
      box_format (str): passed through to intersection_over_union
  Returns:
      list: surviving boxes, ordered by descending confidence
  """
  # e.g. iou_threshold = 0.5, threshold = 0.4
  assert type(bboxes) == list

  candidates = [candidate for candidate in bboxes if candidate[1] > threshold]
  candidates.sort(key=lambda candidate: candidate[1], reverse=True)

  selected = []
  while candidates:
    chosen, rest = candidates[0], candidates[1:]
    survivors = []
    for other in rest:
      # Boxes of a different class never suppress each other, so the IoU
      # is only computed for same-class pairs.
      if other[0] != chosen[0]:
        survivors.append(other)
        continue
      overlap = intersection_over_union(
          torch.tensor(chosen[2:]),
          torch.tensor(other[2:]),
          box_format=box_format,
      )
      if overlap < iou_threshold:
        survivors.append(other)
    candidates = survivors
    selected.append(chosen)
  return selected

In [24]:
from collections import Counter

def mean_average_precision(
    pred_boxes, true_boxes, iou_threshold=0.5, box_format="midpoint", num_classes=20
):
  """
  Calculates mean average precision.

  Parameters:
      pred_boxes (list): list of boxes, each given as
          [train_idx, class_prediction, prob_score, c0, c1, c2, c3]; the last
          four entries are coordinates interpreted according to box_format
      true_boxes (list): same layout as pred_boxes, holding the ground truths
      iou_threshold (float): IoU a prediction must exceed to count as correct
      box_format (str): "midpoint" or "corners", forwarded to
          intersection_over_union
      num_classes (int): number of classes
  Returns:
      tensor: 0-dim tensor with the mAP across all classes that have at least
      one ground-truth box; 0 when no class has any ground truth
  """

  # list storing the AP of every class that has ground truths
  average_precisions = []

  # used for numerical stability later on
  epsilon = 1e-6

  for c in range(num_classes):
    # Only predictions and targets of the current class are considered.
    detections = [detection for detection in pred_boxes if detection[1] == c]
    ground_truths = [true_box for true_box in true_boxes if true_box[1] == c]
    total_true_bboxes = len(ground_truths)

    # If no ground truth exists for this class it does not contribute.
    if total_true_bboxes == 0:
      continue

    # Group ground truths by image once (order preserved), instead of
    # re-filtering the full list for every detection.
    gts_by_image = {}
    for gt in ground_truths:
      gts_by_image.setdefault(gt[0], []).append(gt)

    # One flag per ground truth, e.g. {0: tensor([0, 0, 0]), 1: tensor([0, 0])};
    # a flag becomes 1 once that ground truth has been claimed by a detection.
    matched = {
        image_idx: torch.zeros(len(image_gts))
        for image_idx, image_gts in gts_by_image.items()
    }

    # sort by box probabilities, which is index 2
    detections.sort(key=lambda x: x[2], reverse=True)
    TP = torch.zeros((len(detections)))
    FP = torch.zeros((len(detections)))

    for detection_idx, detection in enumerate(detections):
      ground_truth_img = gts_by_image.get(detection[0], [])

      best_iou = 0.0
      # Reset per detection so an index from an earlier detection can never
      # leak into this one.
      best_gt_idx = -1

      for idx, gt in enumerate(ground_truth_img):
        iou = float(
            intersection_over_union(
                torch.tensor(detection[3:]),
                torch.tensor(gt[3:]),
                box_format=box_format,
            )
        )

        if iou > best_iou:
          best_iou = iou
          best_gt_idx = idx

      is_match = best_gt_idx >= 0 and best_iou > iou_threshold
      # Each ground truth may be detected only once; a second hit on the
      # same ground truth is a false positive, as is any low-IoU detection.
      if is_match and matched[detection[0]][best_gt_idx] == 0:
        TP[detection_idx] = 1
        matched[detection[0]][best_gt_idx] = 1
      else:
        FP[detection_idx] = 1

    TP_cumsum = torch.cumsum(TP, dim=0)
    FP_cumsum = torch.cumsum(FP, dim=0)
    recalls = TP_cumsum / (total_true_bboxes + epsilon)
    precisions = torch.divide(TP_cumsum, (TP_cumsum + FP_cumsum + epsilon))
    # Prepend the (recall=0, precision=1) point so the curve starts on the axis.
    precisions = torch.cat((torch.tensor([1.0]), precisions))
    recalls = torch.cat((torch.tensor([0.0]), recalls))
    # torch.trapz for numerical integration of the precision-recall curve
    average_precisions.append(torch.trapz(precisions, recalls))

  # Without any ground truth there is nothing to average; return 0 rather
  # than dividing by zero.
  if not average_precisions:
    return torch.tensor(0.0)

  return sum(average_precisions) / len(average_precisions)


In [63]:
def plot_image(image, boxes):
  """
  Draws bounding boxes with their class names on an image and shows it.

  Parameters:
      image: image accepted by torchvision's to_pil_image
      boxes: iterable of boxes given as
          [class, prob_score, center_x, center_y, w, h]; coordinates are
          multiplied by the image width / height, so they are presumably
          fractions of the image size -- TODO confirm against the caller

  NOTE(review): class names are looked up in a module-level `classes`
  sequence that is not defined in the visible part of this notebook.
  """
  img = to_pil_image(image)
  draw = ImageDraw.Draw(img)
  W, H = img.size

  for box in boxes:
    color = np.random.randint(0, 255, size=(3,), dtype="uint8").tolist()
    # float() turns tensor scalars into plain numbers before handing them to PIL.
    center_x, center_y = float(box[2]) * W, float(box[3]) * H
    half_w, half_h = float(box[4]) * W / 2, float(box[5]) * H / 2

    # Image coordinates grow downwards, so the top edge has the SMALLER y.
    # PIL expects (x0, y0) to be the upper-left corner and recent Pillow
    # releases raise ValueError when x1 < x0 or y1 < y0; sorting also keeps
    # a negative predicted width/height from producing inverted corners.
    left, right = sorted((center_x - half_w, center_x + half_w))
    top, bottom = sorted((center_y - half_h, center_y + half_h))

    draw.rectangle(((left, top), (right, bottom)), outline=tuple(color), width=3)
    draw.text((left, top), classes[int(box[0])], fill=(255, 255, 255, 0))
  plt.figure(figsize=(8, 8))
  plt.imshow(img)
  plt.show()

In [26]:
def showImage(img, label_matrix, C=7):
  """
  Draws the ground-truth box of every occupied grid cell on an image.

  Parameters:
      img: image accepted by torchvision's to_pil_image
      label_matrix (tensor): per-cell labels indexed as [row, column, channel].
          Boxes are drawn only when the last dimension is 25: channel 20 is
          the object flag (a box is drawn where it equals 1) and channels
          21-24 are x, y, w, h expressed in units of one cell. A matrix with
          any other channel count (e.g. 30) is shown without boxes.
      C (int): number of grid cells per side
  """
  img = to_pil_image(img)
  draw = ImageDraw.Draw(img)
  W, H = img.size

  cell_w, cell_h = W / C, H / C
  if label_matrix.size(-1) == 25:
    for i in range(C):
      for j in range(C):
        # Index the individual cell; indexing the whole matrix with a channel
        # number (label_matrix[20]) addresses a grid row instead and fails.
        cell = label_matrix[i, j]
        if float(cell[20]) != 1:
          continue

        color = np.random.randint(0, 255, size=(3,), dtype="uint8").tolist()
        # Column index j drives x, row index i drives y.
        center_x = cell_w * (j + float(cell[21]))
        center_y = cell_h * (i + float(cell[22]))
        half_w = cell_w * float(cell[23]) / 2
        half_h = cell_h * float(cell[24]) / 2

        # Image y grows downwards: the upper-left corner has the smaller y.
        # Recent Pillow releases reject rectangles with y1 < y0.
        left, right = sorted((center_x - half_w, center_x + half_w))
        top, bottom = sorted((center_y - half_h, center_y + half_h))
        draw.rectangle(((left, top), (right, bottom)), outline=tuple(color), width=3)
  plt.figure(figsize=(15, 15))
  plt.imshow(img)
