In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torch.utils.data.dataloader import DataLoader
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.anchor_utils import AnchorGenerator

import glob
import os
import random
import cv2
import numpy as np
from PIL import Image
from tqdm import tqdm
from torch.utils.data.dataset import Dataset
import xml.etree.ElementTree as ET

torch.backends.cudnn.benchmark = True
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [2]:
dataset_params = {
  "im_train_path": "VOC2007/JPEGImages",
  "ann_train_path": "VOC2007/Annotations",
  "im_test_path": "VOC2007-test/JPEGImages",
  "ann_test_path": "VOC2007-test/Annotations",
  "num_classes": 21,
}

train_params = {
  "task_name": "voc",
  "seed": 1111,
  "acc_steps": 1, # Increase you want to get gradients from >1 steps(kind of mimicking >1 batch size)
  "num_epochs": 20,
  "lr_steps" : [12, 16],
  "lr": 0.001,
  "ckpt_name": "faster_rcnn_voc2007.pth",
}

In [3]:
from google.colab import drive
drive.mount("/content/drive")

drive_tar_dir = "/content/drive/MyDrive/Machine Learning/advanced-deep-learning/object-detection/R-CNN/"

Mounted at /content/drive


In [4]:
!mkdir -p /content/VOC2007
!mkdir -p /content/VOC2007-test

In [5]:
!tar -xvf "{drive_tar_dir}VOCtrainval_06-Nov-2007.tar" \
    -C /content/VOC2007 \
    --strip-components=2 \
    VOCdevkit/VOC2007/JPEGImages \
    VOCdevkit/VOC2007/Annotations

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
VOCdevkit/VOC2007/JPEGImages/000026.jpg
VOCdevkit/VOC2007/JPEGImages/000030.jpg
VOCdevkit/VOC2007/JPEGImages/000032.jpg
VOCdevkit/VOC2007/JPEGImages/000033.jpg
VOCdevkit/VOC2007/JPEGImages/000034.jpg
VOCdevkit/VOC2007/JPEGImages/000035.jpg
VOCdevkit/VOC2007/JPEGImages/000036.jpg
VOCdevkit/VOC2007/JPEGImages/000039.jpg
VOCdevkit/VOC2007/JPEGImages/000041.jpg
VOCdevkit/VOC2007/JPEGImages/000042.jpg
VOCdevkit/VOC2007/JPEGImages/000044.jpg
VOCdevkit/VOC2007/JPEGImages/000046.jpg
VOCdevkit/VOC2007/JPEGImages/000047.jpg
VOCdevkit/VOC2007/JPEGImages/000048.jpg
VOCdevkit/VOC2007/JPEGImages/000050.jpg
VOCdevkit/VOC2007/JPEGImages/000051.jpg
VOCdevkit/VOC2007/JPEGImages/000052.jpg
VOCdevkit/VOC2007/JPEGImages/000060.jpg
VOCdevkit/VOC2007/JPEGImages/000061.jpg
VOCdevkit/VOC2007/JPEGImages/000063.jpg
VOCdevkit/VOC2007/JPEGImages/000064.jpg
VOCdevkit/VOC2007/JPEGImages/000065.jpg
VOCdevkit/VOC2007/JPEGImages/000066.jpg
VOCdevkit/VOC20

In [6]:
!tar -xvf "{drive_tar_dir}VOCtest_06-Nov-2007.tar" \
    -C /content/VOC2007-test \
    --strip-components=2 \
    VOCdevkit/VOC2007/JPEGImages \
    VOCdevkit/VOC2007/Annotations

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
VOCdevkit/VOC2007/Annotations/009875.xml
VOCdevkit/VOC2007/Annotations/009876.xml
VOCdevkit/VOC2007/Annotations/009883.xml
VOCdevkit/VOC2007/Annotations/009885.xml
VOCdevkit/VOC2007/Annotations/009888.xml
VOCdevkit/VOC2007/Annotations/009889.xml
VOCdevkit/VOC2007/Annotations/009890.xml
VOCdevkit/VOC2007/Annotations/009891.xml
VOCdevkit/VOC2007/Annotations/009892.xml
VOCdevkit/VOC2007/Annotations/009893.xml
VOCdevkit/VOC2007/Annotations/009895.xml
VOCdevkit/VOC2007/Annotations/009899.xml
VOCdevkit/VOC2007/Annotations/009901.xml
VOCdevkit/VOC2007/Annotations/009903.xml
VOCdevkit/VOC2007/Annotations/009906.xml
VOCdevkit/VOC2007/Annotations/009907.xml
VOCdevkit/VOC2007/Annotations/009909.xml
VOCdevkit/VOC2007/Annotations/009910.xml
VOCdevkit/VOC2007/Annotations/009912.xml
VOCdevkit/VOC2007/Annotations/009914.xml
VOCdevkit/VOC2007/Annotations/009915.xml
VOCdevkit/VOC2007/Annotations/009916.xml
VOCdevkit/VOC2007/Annotations/009

In [7]:
def collate_function(data):
  return tuple(zip(*data))

In [8]:
def load_images_and_anns(im_dir, ann_dir, label2idx):
    """
    Method to get the xml files.
    For each file get all the objects and their ground truth detection information for the dataset

    im_dir: Path of the images
    ann_dir: Path of annotation xmlfiles
    label2idx: Class Name to index mapping for dataset
    """
    im_infos = []
    for ann_file in tqdm(glob.glob(os.path.join(ann_dir, "*.xml"))):
      im_info = {}

      im_info["img_id"] = os.path.basename(ann_file).split(".xml")[0]
      im_info["filename"] = os.path.join(im_dir, f"{im_info['img_id']}.jpg")

      ann_info = ET.parse(ann_file)
      root = ann_info.getroot()
      size = root.find("size")

      width = int(size.find("width").text)
      height = int(size.find("height").text)

      im_info["width"] = width
      im_info["height"] = height

      detections = []
      for obj in ann_info.findall("object"):
        det = {}
        label = label2idx[obj.find("name").text]
        bbox_info = obj.find("bndbox")
        bbox = [
            int(float(bbox_info.find("xmin").text))-1,
            int(float(bbox_info.find("ymin").text))-1,
            int(float(bbox_info.find("xmax").text))-1,
            int(float(bbox_info.find("ymax").text))-1
        ]
        det["label"] = label
        det["bbox"] = bbox
        detections.append(det)

      im_info["detections"] = detections
      im_infos.append(im_info)

    print(f"Total {len(im_infos)} images found")
    return im_infos

In [9]:
class VOCDataset(Dataset):
  def __init__(self, split, im_dir, ann_dir):
    self.split = split
    self.im_dir = im_dir
    self.ann_dir = ann_dir

    classes = [
        'person', 'bird', 'cat', 'cow', 'dog', 'horse', 'sheep',
        'aeroplane', 'bicycle', 'boat', 'bus', 'car', 'motorbike', 'train',
        'bottle', 'chair', 'diningtable', 'pottedplant', 'sofa', 'tvmonitor'
    ]

    classes = sorted(classes)
    classes = ["background"] + classes

    self.label2idx = {classes[idx]: idx for idx in range(len(classes))}
    self.idx2label = {idx: classes[idx] for idx in range(len(classes))}
    print(self.idx2label)

    self.images_info = load_images_and_anns(im_dir, ann_dir, self.label2idx)

  def __len__(self):
    return len(self.images_info)

  def __getitem__(self, index):
    im_info = self.images_info[index]
    im = Image.open(im_info["filename"])
    to_flip = False

    # Horizontal Flip Image Augmentation
    if self.split == "train" and random.random() < 0.5:
      to_flip = True
      im = im.transpose(Image.FLIP_LEFT_RIGHT)

    im_tensor = torchvision.transforms.ToTensor()(im)
    targets = {}

    targets["bboxes"] = torch.as_tensor([detection["bbox"] for detection in im_info["detections"]])
    targets["labels"] = torch.as_tensor([detection["label"] for detection in im_info["detections"]])

    if to_flip:
      for idx, box in enumerate(targets["bboxes"]):
        x1, y1, x2, y2 = box
        w = x2 - x1
        im_w = im_tensor.shape[-1]
        x1 = im_w - x1 - w
        x2 = x1 + w
        targets["bboxes"][idx] = torch.as_tensor([x1, y1, x2, y2])

    return im_tensor, targets, im_info["filename"]

In [10]:
def train(use_resnet50_fpn=True):
  dataset_config = dataset_params
  train_config = train_params

  # Set the random seed
  seed = train_config["seed"]
  torch.manual_seed(seed)
  np.random.seed(seed)
  random.seed(seed)
  if device.type == "cuda":
    torch.cuda.manual_seed_all(seed)

  voc = VOCDataset("train",
                   im_dir=dataset_config["im_train_path"],
                   ann_dir=dataset_config["ann_train_path"])

  train_dataset = DataLoader(voc,
                             batch_size=4,
                             shuffle=True,
                             num_workers=os.cpu_count(),
                             pin_memory=True,
                             collate_fn=collate_function)

  if use_resnet50_fpn:
    faster_rcnn_model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True,
                                                                             min_size=600,
                                                                             max_size=1000)

    # Initialize the final classificaiton and bounding box regression layers in the ROI Head

    # Keep the input features the same
    # Change num_classes=21 because trained on COCO dataset with different number of classes than VOC
    faster_rcnn_model.roi_heads.box_predictor = FastRCNNPredictor(faster_rcnn_model.roi_heads.box_predictor.cls_score.in_features, num_classes=21)
  else:
    # Pre-trained ResNet34 backbone with new RPN and ROI
    backbone = torchvision.models.resnet34(pretrained=True, norm_layer=torchvision.ops.FrozenBatchNorm2d)

    # Remove last 3 layers to have the final stride of 16 similar to VGG16
    backbone = nn.Sequential(*list(backbone.children())[:-3])
    backbone.out_channels = 256

    roi_align = torchvision.ops.MultiScaleRoIAlign(featmap_names=["0"], output_size=7, sampling_ratio=2)
    rpn_anchor_generator = AnchorGenerator() # By default 3 sizes=(128, 256, 512) and 3 aspect_ratios=(0.5, 1.0, 2.0)
    faster_rcnn_model = torchvision.models.detection.FasterRCNN(backbone,
                                                                num_classes=21,
                                                                min_size=600,
                                                                max_size=1000,
                                                                rpn_anchor_generator=rpn_anchor_generator,
                                                                rpn_pre_nms_top_n_train=12000,
                                                                rpn_pre_nms_top_n_test=6000,
                                                                box_batch_size_per_image=128,
                                                                rpn_post_nms_top_n_test=300)

  faster_rcnn_model.train()
  faster_rcnn_model.to(device)

  if not os.path.exists(train_config["task_name"]):
    os.mkdir(train_config["task_name"])

  optimizer = torch.optim.SGD(lr=1E-4,
                              params=filter(lambda p: p.requires_grad, faster_rcnn_model.parameters()),
                              weight_decay=5E-5, momentum=0.9)

  num_epochs = train_config["num_epochs"]
  step_count = 0

  # Resume training from a saved checkpoint if it exists
  path_base = "tv_frcnn_r50fpn_" if use_resnet50_fpn else "tv_frcnn_"
  checkpoint_path = os.path.join(train_config["task_name"], path_base + train_config["ckpt_name"])
  start_epoch = 0

  if os.path.exists(checkpoint_path):
    print("Resuming training from checkpoint...")

    checkpoint = torch.load(checkpoint_path, map_location=device)

    faster_rcnn_model.load_state_dict(checkpoint["model_state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    start_epoch = checkpoint["epoch"]
    step_count = checkpoint["step_count"]

  optimizer.zero_grad()

  for i in range(start_epoch, num_epochs):
    rpn_classification_losses = []
    rpn_localization_losses = []
    frcnn_classification_losses = []
    frcnn_localization_losses = []

    for ims, targets, _ in tqdm(train_dataset):
      optimizer.zero_grad()

      for target in targets:
        target["boxes"] = target["bboxes"].float().to(device)
        del target["bboxes"]
        target["labels"] = target["labels"].long().to(device)

      images = [im.float().to(device) for im in ims]
      batch_losses = faster_rcnn_model(images, targets)

      loss = batch_losses["loss_classifier"]
      loss += batch_losses["loss_box_reg"]
      loss += batch_losses["loss_rpn_box_reg"]
      loss += batch_losses["loss_objectness"]

      rpn_classification_losses.append(batch_losses["loss_objectness"].item())
      rpn_localization_losses.append(batch_losses["loss_rpn_box_reg"].item())
      frcnn_classification_losses.append(batch_losses["loss_classifier"].item())
      frcnn_localization_losses.append(batch_losses["loss_box_reg"].item())

      loss.backward()
      optimizer.step()
      step_count += 1

    print(f"Finished epoch {i}")

    checkpoint = {
        "epoch": i,
        "step_count": step_count,
        "model_state_dict": faster_rcnn_model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
    }
    path_base = "tv_frcnn_r50fpn_" if use_resnet50_fpn else "tv_frcnn_"

    torch.save(checkpoint, os.path.join(train_config["task_name"], path_base + train_config["ckpt_name"]))

    loss_output = ""
    loss_output += f"RPN Classification Loss : {np.mean(rpn_classification_losses):.4f}"
    loss_output += f" | RPN Localization Loss : {np.mean(rpn_localization_losses):.4f}"
    loss_output += f" | FRCNN Classification Loss : {np.mean(frcnn_classification_losses):.4f}"
    loss_output += f" | FRCNN Localization Loss : {np.mean(frcnn_localization_losses):.4f}"
    print(loss_output)

  print("Done Training...")


train(use_resnet50_fpn=True)

{0: 'background', 1: 'aeroplane', 2: 'bicycle', 3: 'bird', 4: 'boat', 5: 'bottle', 6: 'bus', 7: 'car', 8: 'cat', 9: 'chair', 10: 'cow', 11: 'diningtable', 12: 'dog', 13: 'horse', 14: 'motorbike', 15: 'person', 16: 'pottedplant', 17: 'sheep', 18: 'sofa', 19: 'train', 20: 'tvmonitor'}


100%|██████████| 5011/5011 [00:00<00:00, 12322.38it/s]


Total 5011 images found


Downloading: "https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth" to /root/.cache/torch/hub/checkpoints/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth
100%|██████████| 160M/160M [00:00<00:00, 216MB/s]
100%|██████████| 1253/1253 [03:41<00:00,  5.65it/s]


Finished epoch 0
RPN Classification Loss : 0.0173 | RPN Localization Loss : 0.0096 | FRCNN Classification Loss : 0.7306 | FRCNN Localization Loss : 0.3567


100%|██████████| 1253/1253 [02:47<00:00,  7.46it/s]


Finished epoch 1
RPN Classification Loss : 0.0154 | RPN Localization Loss : 0.0099 | FRCNN Classification Loss : 0.5018 | FRCNN Localization Loss : 0.2930


100%|██████████| 1253/1253 [02:42<00:00,  7.71it/s]


Finished epoch 2
RPN Classification Loss : 0.0139 | RPN Localization Loss : 0.0097 | FRCNN Classification Loss : 0.4107 | FRCNN Localization Loss : 0.2486


100%|██████████| 1253/1253 [02:38<00:00,  7.92it/s]


Finished epoch 3
RPN Classification Loss : 0.0120 | RPN Localization Loss : 0.0093 | FRCNN Classification Loss : 0.3668 | FRCNN Localization Loss : 0.2206


100%|██████████| 1253/1253 [02:39<00:00,  7.85it/s]


Finished epoch 4
RPN Classification Loss : 0.0111 | RPN Localization Loss : 0.0091 | FRCNN Classification Loss : 0.3416 | FRCNN Localization Loss : 0.2030


100%|██████████| 1253/1253 [02:37<00:00,  7.97it/s]


Finished epoch 5
RPN Classification Loss : 0.0106 | RPN Localization Loss : 0.0089 | FRCNN Classification Loss : 0.3254 | FRCNN Localization Loss : 0.1919


100%|██████████| 1253/1253 [02:37<00:00,  7.97it/s]


Finished epoch 6
RPN Classification Loss : 0.0099 | RPN Localization Loss : 0.0087 | FRCNN Classification Loss : 0.3131 | FRCNN Localization Loss : 0.1840


100%|██████████| 1253/1253 [02:36<00:00,  7.98it/s]


Finished epoch 7
RPN Classification Loss : 0.0094 | RPN Localization Loss : 0.0086 | FRCNN Classification Loss : 0.3033 | FRCNN Localization Loss : 0.1779


100%|██████████| 1253/1253 [02:41<00:00,  7.78it/s]


Finished epoch 8
RPN Classification Loss : 0.0090 | RPN Localization Loss : 0.0085 | FRCNN Classification Loss : 0.2954 | FRCNN Localization Loss : 0.1732


100%|██████████| 1253/1253 [02:33<00:00,  8.15it/s]


Finished epoch 9
RPN Classification Loss : 0.0085 | RPN Localization Loss : 0.0084 | FRCNN Classification Loss : 0.2886 | FRCNN Localization Loss : 0.1691


100%|██████████| 1253/1253 [02:37<00:00,  7.98it/s]


Finished epoch 10
RPN Classification Loss : 0.0081 | RPN Localization Loss : 0.0083 | FRCNN Classification Loss : 0.2817 | FRCNN Localization Loss : 0.1652


100%|██████████| 1253/1253 [02:38<00:00,  7.91it/s]


Finished epoch 11
RPN Classification Loss : 0.0078 | RPN Localization Loss : 0.0082 | FRCNN Classification Loss : 0.2765 | FRCNN Localization Loss : 0.1625


100%|██████████| 1253/1253 [02:35<00:00,  8.06it/s]


Finished epoch 12
RPN Classification Loss : 0.0074 | RPN Localization Loss : 0.0082 | FRCNN Classification Loss : 0.2714 | FRCNN Localization Loss : 0.1595


100%|██████████| 1253/1253 [02:35<00:00,  8.07it/s]


Finished epoch 13
RPN Classification Loss : 0.0072 | RPN Localization Loss : 0.0080 | FRCNN Classification Loss : 0.2658 | FRCNN Localization Loss : 0.1565


100%|██████████| 1253/1253 [02:36<00:00,  8.02it/s]


Finished epoch 14
RPN Classification Loss : 0.0069 | RPN Localization Loss : 0.0080 | FRCNN Classification Loss : 0.2618 | FRCNN Localization Loss : 0.1545


100%|██████████| 1253/1253 [02:36<00:00,  8.01it/s]


Finished epoch 15
RPN Classification Loss : 0.0065 | RPN Localization Loss : 0.0080 | FRCNN Classification Loss : 0.2582 | FRCNN Localization Loss : 0.1522


100%|██████████| 1253/1253 [02:34<00:00,  8.10it/s]


Finished epoch 16
RPN Classification Loss : 0.0066 | RPN Localization Loss : 0.0079 | FRCNN Classification Loss : 0.2549 | FRCNN Localization Loss : 0.1503


100%|██████████| 1253/1253 [02:33<00:00,  8.14it/s]


Finished epoch 17
RPN Classification Loss : 0.0065 | RPN Localization Loss : 0.0078 | FRCNN Classification Loss : 0.2515 | FRCNN Localization Loss : 0.1487


100%|██████████| 1253/1253 [02:34<00:00,  8.09it/s]


Finished epoch 18
RPN Classification Loss : 0.0060 | RPN Localization Loss : 0.0077 | FRCNN Classification Loss : 0.2476 | FRCNN Localization Loss : 0.1468


100%|██████████| 1253/1253 [02:34<00:00,  8.10it/s]


Finished epoch 19
RPN Classification Loss : 0.0061 | RPN Localization Loss : 0.0076 | FRCNN Classification Loss : 0.2452 | FRCNN Localization Loss : 0.1453
Done Training...


In [11]:
def get_iou(det, gt):
  det_x1, det_y1, det_x2, det_y2 = det
  gt_x1, gt_y1, gt_x2, gt_y2 = gt

  x_left = max(det_x1, gt_x1)
  y_top = max(det_y1, gt_y1)
  x_right = min(det_x2, gt_x2)
  y_bottom = min(det_y2, gt_y2)

  if x_right < x_left or y_bottom < y_top:
    return 0.0

  area_intersection = (x_right - x_left) * (y_bottom - y_top)
  det_area = (det_x2 - det_x1) * (det_y2 - det_y1)
  gt_area = (gt_x2 - gt_x1) * (gt_y2 - gt_y1)

  area_union = float(det_area + gt_area - area_intersection + 1E-6)
  iou = area_intersection / area_union

  return iou

In [12]:
def compute_map(det_boxes, gt_boxes, iou_threshold=0.5, method="area"):
  gt_labels = {cls_key for im_gt in gt_boxes for cls_key in im_gt.keys()}
  gt_labels = sorted(gt_labels)
  all_aps = {}

  # Average precisions for ALL classes
  aps = []
  for idx, label in enumerate(gt_labels):
    # Get detection predictions of this class
    cls_dets = [
        [im_idx, im_dets_label] for im_idx, im_dets in enumerate(det_boxes)
        if label in im_dets for im_dets_label in im_dets[label]
    ]

    # Sort them by confidence score
    cls_dets = sorted(cls_dets, key=lambda k: -k[1][-1])

    # For tracking which gt boxes of this class have already been matched
    gt_matched = [[False for _ in im_gts[label]] for im_gts in gt_boxes]
    # Number of gt boxes for this class for recall calculation
    num_gts = sum([len(im_gts[label]) for im_gts in gt_boxes])
    tp = [0] * len(cls_dets)
    fp = [0] * len(cls_dets)

    # For each prediction
    for det_idx, (im_idx, det_pred) in enumerate(cls_dets):
      # Get gt boxes for this image and this label
      im_gts = gt_boxes[im_idx][label]
      max_iou_found = -1
      max_iou_gt_idx = -1

      # Get best matching gt box
      for gt_box_idx, gt_box in enumerate(im_gts):
        gt_box_iou = get_iou(det_pred[:-1], gt_box)
        if gt_box_iou > max_iou_found:
          max_iou_found = gt_box_iou
          max_iou_gt_idx = gt_box_idx

      # TP only if iou >= threshold and this gt has not yet been matched
      if max_iou_found < iou_threshold or gt_matched[im_idx][max_iou_gt_idx]:
        fp[det_idx] = 1
      else:
        tp[det_idx] = 1
        # If tp then we set this gt box as matched
        gt_matched[im_idx][max_iou_gt_idx] = True

    # Cumulative tp and fp
    tp = np.cumsum(tp)
    fp = np.cumsum(fp)

    eps = np.finfo(np.float32).eps
    recalls = tp / np.maximum(num_gts, eps)
    precisions = tp / np.maximum((tp + fp), eps)

    if method == "area":
      recalls = np.concatenate(([0.0], recalls, [1.0]))
      precisions = np.concatenate(([0.0], precisions, [0.0]))

      # Replace precision values with recall r with maximum precision value
      # of any recall value >= r
      # This computes the precision envelope
      for i in range(precisions.size - 1, 0, -1):
        precisions[i - 1] = np.maximum(precisions[i - 1], precisions[i])

      # For computing area, get points where recall changes value
      i = np.where(recalls[1:] != recalls[:-1])[0]

      # Add the rectangular areas to get ap
      ap = np.sum((recalls[i + 1] - recalls[i]) * precisions[i + 1])
    elif method == "interp":
      ap = 0.0
      for interp_pt in np.arange(0, 1 + 1E-3, 0.1):
        # Get precision values for recall values >= interp_pt
        prec_interp_pt = precisions[recalls >= interp_pt]

        # Get max of those precision values
        prec_interp_pt = prec_interp_pt.max() if prec_interp_pt.size > 0.0 else 0.0
        ap += prec_interp_pt
      ap = ap / 11.0
    else:
      raise ValueError("Method can only be area or interp")
    if num_gts > 0:
      aps.append(ap)
      all_aps[label] = ap
    else:
      all_aps[label] = np.nan

  # compute mAP at provided iou threshold
  mean_ap = sum(aps) / len(aps)

  return mean_ap, all_aps

In [13]:
def load_model_and_dataset(use_resnet50_fpn=True):
  dataset_config = dataset_params
  train_config = train_params

  # Set random seed
  seed = train_config["seed"]
  torch.manual_seed(seed)
  np.random.seed(seed)
  random.seed(seed)
  if device.type == "cuda":
    torch.cuda.manual_seed_all(seed)

  voc = VOCDataset("test",
                   im_dir=dataset_config["im_test_path"],
                   ann_dir=dataset_config["ann_test_path"])

  # Use a Batch Size of 1 for now
  test_dataset = DataLoader(voc,
                            batch_size=1,
                            shuffle=False,
                            num_workers=os.cpu_count(),
                            pin_memory=True)

  if use_resnet50_fpn:
    faster_rcnn_model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True,
                                                                             min_size=600,
                                                                             max_size=1000)
    faster_rcnn_model.roi_heads.box_predictor = FastRCNNPredictor(faster_rcnn_model.roi_heads.box_predictor.cls_score.in_features, num_classes=21)
  else:
    backbone = torchvision.models.resnet34(pretrained=True, norm_layer=torchvision.ops.FrozenBatchNorm2d)
    backbone = nn.Sequential(*list(backbone.children())[:-3])
    backbone.out_channels = 256
    roi_align = torchvision.ops.MultiScaleRoIAlign(featmap_names=["0"], output_size=7, sampling_ratio=2)
    rpn_anchor_generator = AnchorGenerator()
    faster_rcnn_model = torchvision.models.detection.FasterRCNN(backbone,
                                                                num_classes=21,
                                                                min_size=600,
                                                                max_size=1000,
                                                                rpn_anchor_generator=rpn_anchor_generator,
                                                                box_roi_pool=roi_align,
                                                                rpn_pre_nms_top_n_train=12000,
                                                                rpn_pre_nms_top_n_test=6000,
                                                                box_batch_size_per_image=128,
                                                                box_score_thresh=0.7,
                                                                rpn_post_nms_top_n_test=300)

  faster_rcnn_model.eval()
  faster_rcnn_model.to(device)

  path_base = "tv_frcnn_r50fpn_" if use_resnet50_fpn else "tv_frcnn_"

  checkpoint = torch.load(os.path.join(train_config["task_name"], path_base + train_config["ckpt_name"]),
                          map_location=device)

  faster_rcnn_model.load_state_dict(checkpoint["model_state_dict"])

  return faster_rcnn_model, voc, test_dataset

In [14]:
def infer(use_resnet50_fpn=True):
  if use_resnet50_fpn:
    output_dir = "samples_tv_r50fpn"
  else:
    output_dir = "samples_tv"

  if not os.path.exists(output_dir):
      os.mkdir(output_dir)
  faster_rcnn_model, voc, test_dataset = load_model_and_dataset(use_resnet50_fpn)

  for sample_count in tqdm(range(10)):
    random_idx = random.randint(0, len(voc) - 1)
    im, target, fname = voc[random_idx]
    im = im.unsqueeze(0).float().to(device)

    gt_im = cv2.imread(fname)
    gt_im_copy = gt_im.copy()

    # Saving images with ground truth boxes
    for idx, box in enumerate(target["bboxes"]):
      x1, y1, x2, y2 = box.detach().cpu().numpy()
      x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)

      cv2.rectangle(gt_im, (x1, y1), (x2, y2), thickness=2, color=[0, 255, 0])
      cv2.rectangle(gt_im_copy, (x1, y1), (x2, y2), thickness=2, color=[0, 255, 0])

      text = voc.idx2label[target["labels"][idx].detach().cpu().item()]
      text_size, _ = cv2.getTextSize(text, cv2.FONT_HERSHEY_PLAIN, 1, 1)
      text_w, text_h = text_size

      cv2.rectangle(gt_im_copy , (x1, y1), (x1 + 10 + text_w, y1 + 10 + text_h), [255, 255, 255], -1)
      cv2.putText(gt_im, text=voc.idx2label[target["labels"][idx].detach().cpu().item()],
                  org=(x1+5, y1+15),
                  thickness=1,
                  fontScale=1,
                  color=[0, 0, 0],
                  fontFace=cv2.FONT_HERSHEY_PLAIN)

      cv2.putText(gt_im_copy, text=text,
                  org=(x1 + 5, y1 + 15),
                  thickness=1,
                  fontScale=1,
                  color=[0, 0, 0],
                  fontFace=cv2.FONT_HERSHEY_PLAIN)

    cv2.addWeighted(gt_im_copy, 0.7, gt_im, 0.3, 0, gt_im)
    cv2.imwrite(f"{output_dir}/output_frcnn_gt_{sample_count}.png", gt_im)

    # Getting predictions from trained model
    frcnn_output = faster_rcnn_model(im, None)[0]
    boxes = frcnn_output["boxes"]
    labels = frcnn_output["labels"]
    scores = frcnn_output["scores"]

    im = cv2.imread(fname)
    im_copy = im.copy()

    # Saving images with predicted boxes
    for idx, box in enumerate(boxes):
      x1, y1, x2, y2 = box.detach().cpu().numpy()
      x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)

      cv2.rectangle(im, (x1, y1), (x2, y2), thickness=2, color=[0, 0, 255])
      cv2.rectangle(im_copy, (x1, y1), (x2, y2), thickness=2, color=[0, 0, 255])
      text = f"{voc.idx2label[labels[idx].detach().cpu().item()]} : {scores[idx].detach().cpu().item():.2f}"

      text_size, _ = cv2.getTextSize(text, cv2.FONT_HERSHEY_PLAIN, 1, 1)
      text_w, text_h = text_size

      cv2.rectangle(im_copy , (x1, y1), (x1 + 10+text_w, y1 + 10+text_h), [255, 255, 255], -1)
      cv2.putText(im, text=text,
                  org=(x1+5, y1+15),
                  thickness=1,
                  fontScale=1,
                  color=[0, 0, 0],

                  fontFace=cv2.FONT_HERSHEY_PLAIN)
      cv2.putText(im_copy, text=text,
                  org=(x1 + 5, y1 + 15),
                  thickness=1,
                  fontScale=1,
                  color=[0, 0, 0],
                  fontFace=cv2.FONT_HERSHEY_PLAIN)

    cv2.addWeighted(im_copy, 0.7, im, 0.3, 0, im)
    cv2.imwrite(f"{output_dir}/output_frcnn_{sample_count}.jpg", im)

infer(use_resnet50_fpn=True)

{0: 'background', 1: 'aeroplane', 2: 'bicycle', 3: 'bird', 4: 'boat', 5: 'bottle', 6: 'bus', 7: 'car', 8: 'cat', 9: 'chair', 10: 'cow', 11: 'diningtable', 12: 'dog', 13: 'horse', 14: 'motorbike', 15: 'person', 16: 'pottedplant', 17: 'sheep', 18: 'sofa', 19: 'train', 20: 'tvmonitor'}


100%|██████████| 4952/4952 [00:00<00:00, 12460.14it/s]


Total 4952 images found


100%|██████████| 10/10 [00:01<00:00,  6.82it/s]


In [15]:
def evaluate_map(use_resnet50_fpn=True):
  faster_rcnn_model, voc, test_dataset = load_model_and_dataset(use_resnet50_fpn)
  gts = []
  preds = []

  for im, target, fname in tqdm(test_dataset):
    im_name = fname
    im = im.float().to(device)

    target_boxes = target["bboxes"].float().to(device)[0]
    target_labels = target["labels"].long().to(device)[0]
    frcnn_output = faster_rcnn_model(im, None)[0]

    boxes = frcnn_output["boxes"]
    labels = frcnn_output["labels"]
    scores = frcnn_output["scores"]

    pred_boxes = {}
    gt_boxes = {}
    for label_name in voc.label2idx:
      pred_boxes[label_name] = []
      gt_boxes[label_name] = []

    for idx, box in enumerate(boxes):
      x1, y1, x2, y2 = box.detach().cpu().numpy()
      label = labels[idx].detach().cpu().item()
      score = scores[idx].detach().cpu().item()
      label_name = voc.idx2label[label]
      pred_boxes[label_name].append([x1, y1, x2, y2, score])

    for idx, box in enumerate(target_boxes):
      x1, y1, x2, y2 = box.detach().cpu().numpy()
      label = target_labels[idx].detach().cpu().item()
      label_name = voc.idx2label[label]
      gt_boxes[label_name].append([x1, y1, x2, y2])

    gts.append(gt_boxes)
    preds.append(pred_boxes)

  mean_ap, all_aps = compute_map(preds, gts, method="interp")
  print("Class Wise Average Precisions")

  for idx in range(len(voc.idx2label)):
    print(f"AP for class {voc.idx2label[idx]} = {all_aps[voc.idx2label[idx]]:.4f}")

  print(f"Mean Average Precision: {mean_ap:.4f}")

evaluate_map(use_resnet50_fpn=True)

{0: 'background', 1: 'aeroplane', 2: 'bicycle', 3: 'bird', 4: 'boat', 5: 'bottle', 6: 'bus', 7: 'car', 8: 'cat', 9: 'chair', 10: 'cow', 11: 'diningtable', 12: 'dog', 13: 'horse', 14: 'motorbike', 15: 'person', 16: 'pottedplant', 17: 'sheep', 18: 'sofa', 19: 'train', 20: 'tvmonitor'}


100%|██████████| 4952/4952 [00:00<00:00, 13028.70it/s]


Total 4952 images found


100%|██████████| 4952/4952 [02:41<00:00, 30.60it/s]


Class Wise Average Precisions
AP for class background = nan
AP for class aeroplane = 0.8470
AP for class bicycle = 0.8192
AP for class bird = 0.7963
AP for class boat = 0.6439
AP for class bottle = 0.6775
AP for class bus = 0.8220
AP for class car = 0.8580
AP for class cat = 0.8565
AP for class chair = 0.6200
AP for class cow = 0.8694
AP for class diningtable = 0.6332
AP for class dog = 0.8515
AP for class horse = 0.8647
AP for class motorbike = 0.8279
AP for class person = 0.8531
AP for class pottedplant = 0.5479
AP for class sheep = 0.8204
AP for class sofa = 0.6962
AP for class train = 0.8408
AP for class tvmonitor = 0.7832
Mean Average Precision: 0.7764


In [16]:
!zip -r samples_tv_r50fpn.zip samples_tv_r50fpn

  adding: samples_tv_r50fpn/ (stored 0%)
  adding: samples_tv_r50fpn/output_frcnn_gt_8.png (deflated 3%)
  adding: samples_tv_r50fpn/output_frcnn_gt_0.png (deflated 8%)
  adding: samples_tv_r50fpn/output_frcnn_2.jpg (deflated 0%)
  adding: samples_tv_r50fpn/output_frcnn_gt_7.png (deflated 1%)
  adding: samples_tv_r50fpn/output_frcnn_7.jpg (deflated 1%)
  adding: samples_tv_r50fpn/output_frcnn_gt_2.png (deflated 2%)
  adding: samples_tv_r50fpn/output_frcnn_gt_1.png (deflated 2%)
  adding: samples_tv_r50fpn/output_frcnn_gt_4.png (deflated 0%)
  adding: samples_tv_r50fpn/output_frcnn_5.jpg (deflated 0%)
  adding: samples_tv_r50fpn/output_frcnn_3.jpg (deflated 4%)
  adding: samples_tv_r50fpn/output_frcnn_8.jpg (deflated 0%)
  adding: samples_tv_r50fpn/output_frcnn_gt_9.png (deflated 2%)
  adding: samples_tv_r50fpn/output_frcnn_9.jpg (deflated 1%)
  adding: samples_tv_r50fpn/output_frcnn_gt_3.png (deflated 5%)
  adding: samples_tv_r50fpn/output_frcnn_0.jpg (deflated 0%)
  adding: samples_tv

In [17]:
from google.colab import files

files.download("samples_tv_r50fpn.zip")

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>