In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torch.utils.data.dataloader import DataLoader
from torch.optim.lr_scheduler import MultiStepLR

import math
import glob
import os
import random
import cv2
import numpy as np
from PIL import Image
from tqdm import tqdm
from torch.utils.data.dataset import Dataset
import xml.etree.ElementTree as ET
import albumentations as albu

torch.backends.cudnn.benchmark = True
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

  check_for_updates()


device(type='cuda')

In [2]:
dataset_params = {
  "train_im_sets": ["data/VOC2007", "data/VOC2012"],
  "test_im_sets": ["data/VOC2007-test"],
  "num_classes": 20, # 20 foreground classes
  "im_size": 448
}

model_params = {
  "im_channels": 3,
  "backbone_channels": 512,
  "conv_spatial_size": 7, # Size after all conv layers
  "yolo_conv_channels": 1024,
  "leaky_relu_slope": 0.1,
  "fc_dim": 4096,
  "fc_dropout": 0.5,
  "S": 7, # 49 grid cells
  "B": 2, # 2 boxes per grid cell
  "use_sigmoid": True,
  "use_conv": True
}

train_params = {
  "task_name": "voc",
  "seed": 1111,
  "acc_steps": 1, # Increase if you want to get gradients from >1 steps (kind of mimicking >1 batch size)
  "log_steps": 100,
  "num_epochs": 135,
  "batch_size": 64,
  "lr_steps": [50, 75, 100, 125],
  "lr": 0.001,
  "infer_conf_threshold": 0.2,
  "eval_conf_threshold": 0.001,
  "nms_threshold": 0.5,
  "ckpt_name": "yolo_voc2007.pth"
}

In [3]:
from google.colab import drive
drive.mount("/content/drive")

Mounted at /content/drive


In [4]:
!unzip -q "/content/drive/MyDrive/Machine Learning/advanced-deep-learning/object-detection/YOLO/data.zip" -d "/content"

In [5]:
def load_images_and_anns(im_sets, label2idx, ann_fname, split):
  """
  Method to get the xml files and for each file
  get all the objects and their ground truth detection
  information for the dataset

  im_sets: Sets of images to consider
  label2idx: Class Name to index mapping for dataset
  ann_fname: txt file containing image names{trainval.txt/test.txt}
  split: train/test
  """
  im_infos = []
  ims = []

  for im_set in im_sets:
    im_names = []

    # Fetch all image names in txt file for this imageset
    for line in open(os.path.join(im_set, "ImageSets", "Main", f"{ann_fname}.txt")):
      im_names.append(line.strip())

    # Set annotation and image path
    ann_dir = os.path.join(im_set, "Annotations")
    im_dir = os.path.join(im_set, "JPEGImages")

    for im_name in im_names:
      ann_file = os.path.join(ann_dir, f"{im_name}.xml")
      im_info = {}
      ann_info = ET.parse(ann_file)
      root = ann_info.getroot()
      size = root.find("size")

      width = int(size.find("width").text)
      height = int(size.find("height").text)

      im_info["img_id"] = os.path.basename(ann_file).split(".xml")[0]
      im_info["filename"] = os.path.join(im_dir, f"{im_info['img_id']}.jpg")

      im_info["width"] = width
      im_info["height"] = height
      detections = []

      # We will keep an image only if there are valid rois in it
      any_valid_object = False
      for obj in ann_info.findall("object"):
        det = {}

        label = label2idx[obj.find("name").text]
        difficult = int(obj.find("difficult").text)
        bbox_info = obj.find("bndbox")

        # Boxes are in x_1, y_1, x_2, y_2 format
        bbox = [
            int(float(bbox_info.find("xmin").text))-1,
            int(float(bbox_info.find("ymin").text))-1,
            int(float(bbox_info.find("xmax").text))-1,
            int(float(bbox_info.find("ymax").text))-1
        ]

        det["label"] = label
        det["bbox"] = bbox
        det["difficult"] = difficult

        # Ignore difficult rois during training
        # At test time eval does the job of ignoring difficult examples.
        if difficult == 0 or split == "test":
          detections.append(det)
          any_valid_object = True

      if any_valid_object:
        im_info["detections"] = detections
        im_infos.append(im_info)

  print(f"Total {len(im_infos)} images found")
  return im_infos

In [12]:
class VOCDataset(Dataset):
  def __init__(self, split, im_sets, im_size=448, S=7, B=2, C=20):
    self.split = split
    # ImageSets for this dataset instance (VOC2007/VOC2007+VOC2012/VOC2007-test)
    self.im_sets = im_sets
    self.fname = "trainval" if self.split == "train" else "test"
    self.im_size = im_size
    # Grid size, B and C parameter for target setting
    self.S = S
    self.B = B
    self.C = C

    # Train and test augmentations
    self.transforms = {
        # Training Augmentations: Horizontal Flip, Affine, ColorJitter
        "train": albu.Compose([
            albu.HorizontalFlip(p=0.5),
            albu.Affine(
                scale=(0.8, 1.2),
                translate_percent=(-0.2, 0.2),
                always_apply=True
            ),
            albu.ColorJitter(
                brightness=(0.8, 1.2),
                contrast=(0.8, 1.2),
                saturation=(0.8, 1.2),
                hue=(-0.2, 0.2),
                always_apply=False,
                p=0.5,
            ),
            albu.Resize(self.im_size, self.im_size)],
            bbox_params=albu.BboxParams(format="pascal_voc",
                                        label_fields=["labels"])),
        "test": albu.Compose([
            albu.Resize(self.im_size, self.im_size),
            ],
            bbox_params=albu.BboxParams(format="pascal_voc",
                                        label_fields=["labels"]))
    }

    classes = [
        "person", "bird", "cat", "cow", "dog", "horse", "sheep",
        "aeroplane", "bicycle", "boat", "bus", "car", "motorbike", "train",
        "bottle", "chair", "diningtable", "pottedplant", "sofa", "tvmonitor"
    ]
    classes = sorted(classes)

    self.label2idx = {classes[idx]: idx for idx in range(len(classes))}
    self.idx2label = {idx: classes[idx] for idx in range(len(classes))}

    print(self.idx2label)

    self.images_info = load_images_and_anns(self.im_sets,
                                            self.label2idx,
                                            self.fname,
                                            self.split)

  def __len__(self):
    return len(self.images_info)

  def __getitem__(self, index):
    im_info = self.images_info[index]
    im = cv2.imread(im_info["filename"])
    im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)

    # Get annotations for this image
    bboxes = [detection["bbox"] for detection in im_info["detections"]]
    labels = [detection["label"] for detection in im_info["detections"]]
    difficult = [detection["difficult"] for detection in im_info["detections"]]

    # Transform Image and ann according to augmentations list
    transformed_info = self.transforms[self.split](image=im,
                                                    bboxes=bboxes,
                                                    labels=labels)
    im = transformed_info["image"]
    bboxes = torch.as_tensor(transformed_info["bboxes"])
    labels = torch.as_tensor(transformed_info["labels"])
    difficult = torch.as_tensor(difficult)

    # Convert image to tensor and normalize for ResNet pretrained on ImageNet
    im_tensor = torch.from_numpy(im / 255.).permute((2, 0, 1)).float()
    im_tensor_channel_0 = (torch.unsqueeze(im_tensor[0], 0) - 0.485) / 0.229
    im_tensor_channel_1 = (torch.unsqueeze(im_tensor[1], 0) - 0.456) / 0.224
    im_tensor_channel_2 = (torch.unsqueeze(im_tensor[2], 0) - 0.406) / 0.225

    im_tensor = torch.cat((im_tensor_channel_0,
                            im_tensor_channel_1,
                            im_tensor_channel_2), 0)

    bboxes_tensor = torch.as_tensor(bboxes)
    labels_tensor = torch.as_tensor(labels)

    # Build Target for Yolo
    # S x S x (5B +C)
    target_dim = 5 * self.B + self.C
    h, w = im.shape[:2]
    yolo_targets = torch.zeros(self.S, self.S, target_dim)

    # Height and width of grid cells is H // S
    cell_pixels = h // self.S

    if len(bboxes) > 0:
      # Convert x_1, y_1, x_2, y_2 to center_x, center_y, width, height format
      box_widths = bboxes_tensor[:, 2] - bboxes_tensor[:, 0]
      box_heights = bboxes_tensor[:, 3] - bboxes_tensor[:, 1]
      box_center_x = bboxes_tensor[:, 0] + 0.5 * box_widths
      box_center_y = bboxes_tensor[:, 1] + 0.5 * box_heights

      # Get the grid cell i, j from x_center, y_center
      box_i = torch.floor(box_center_x / cell_pixels).long()
      box_j = torch.floor(box_center_y / cell_pixels).long()

      # x_center and y_center offsets from the top left of grid cell and normalized to [0, 1]
      box_xc_cell_offset = (box_center_x - box_i*cell_pixels) / cell_pixels
      box_yc_cell_offset = (box_center_y - box_j*cell_pixels) / cell_pixels

      # width and height targets normalized to [0, 1]
      box_w_label = box_widths / w
      box_h_label = box_heights / h

      # Update the target array for all bboxes
      for idx, b in enumerate(range(bboxes_tensor.size(0))):
        # Make target of the exact same shape as prediction
        for k in range(self.B):
          s = 5 * k

          # target_ij = [xc_offset,yc_offset,sqrt(w),sqrt(h), conf, cls_label]
          yolo_targets[box_j[idx], box_i[idx], s] = box_xc_cell_offset[idx]
          yolo_targets[box_j[idx], box_i[idx], s+1] = box_yc_cell_offset[idx]

          # Use the square root for localization loss function
          yolo_targets[box_j[idx], box_i[idx], s+2] = box_w_label[idx].sqrt()
          yolo_targets[box_j[idx], box_i[idx], s+3] = box_h_label[idx].sqrt()
          yolo_targets[box_j[idx], box_i[idx], s+4] = 1.0

        label = int(labels[b])
        cls_target = torch.zeros((self.C,))
        cls_target[label] = 1.
        yolo_targets[box_j[idx], box_i[idx], 5 * self.B:] = cls_target

    # For training, we use yolo_targets(xoffset, yoffset, sqrt(w), sqrt(h))
    # For evaluation we use bboxes_tensor (x1, y1, x2, y2)
    # Normalize bboxes tensor to be between [0, 1] as thats what evaluation script expects so (x1/w, y1/h, x2/w, y2/h)
    if len(bboxes) > 0:
      bboxes_tensor /= torch.Tensor([[w, h, w, h]]).expand_as(bboxes_tensor)

    targets = {
        "bboxes": bboxes_tensor,
        "labels": labels_tensor,
        "yolo_targets": yolo_targets,
        "difficult": difficult,
    }

    return im_tensor, targets, im_info["filename"]

In [13]:
class YOLOV1(nn.Module):
  """
  Model with three components:
  1. Backbone of ResnNt34 pretrained on 224x224 images from Imagenet
  2. 4 Conv, BatchNorm, and LeakyReLU layers for Yolo Detection Head
  3. Linear layers with final layer having S * S * (5B + C) output dimensions

  The final layer predicts the x_offset_boxB, y_offset_boxB, sqrt_w_boxB, sqrt_h_boxB, conf_boxB,
  and class conditional probabilities for each S * S grid cell
  """
  def __init__(self, im_size, num_classes, model_config):
    super().__init__()

    self.im_size = im_size # 448 x 448
    self.im_channels = model_config["im_channels"] # 3
    self.backbone_channels = model_config["backbone_channels"] # ResNet34 gives feature map with 512 channels
    self.yolo_conv_channels = model_config["yolo_conv_channels"] # 1024 channels
    self.conv_spatial_size = model_config["conv_spatial_size"] # size of final feature map after conv layers
    self.leaky_relu_slope = model_config["leaky_relu_slope"]
    self.yolo_fc_hidden_dim = model_config["fc_dim"]
    self.yolo_fc_dropout_prob = model_config["fc_dropout"]
    self.use_conv = model_config["use_conv"]
    self.S = model_config["S"] # 7
    self.B = model_config["B"] # 2
    self.C = num_classes

    backbone = torchvision.models.resnet34(weights=torchvision.models.ResNet34_Weights.IMAGENET1K_V1)

    # Backbone Layers
    self.features = nn.Sequential(
        backbone.conv1,
        backbone.bn1,
        backbone.relu,
        backbone.maxpool,
        backbone.layer1,
        backbone.layer2,
        backbone.layer3,
        backbone.layer4,
    ) # ResNet34 layers without last linear layers

    # Detection Convolutional Layers
    self.conv_yolo_layers = nn.Sequential(
        nn.Conv2d(in_channels=self.backbone_channels,
                  out_channels=self.yolo_conv_channels,
                  kernel_size=3,
                  stride=1,
                  padding=1,
                  bias=False),
        nn.BatchNorm2d(self.yolo_conv_channels),
        nn.LeakyReLU(self.leaky_relu_slope),

        nn.Conv2d(in_channels=self.yolo_conv_channels,
                  out_channels=self.yolo_conv_channels,
                  kernel_size=3,
                  stride=2,
                  padding=1,
                  bias=False),
        nn.BatchNorm2d(self.yolo_conv_channels),
        nn.LeakyReLU(self.leaky_relu_slope),

        nn.Conv2d(in_channels=self.yolo_conv_channels,
                  out_channels=self.yolo_conv_channels,
                  kernel_size=3,
                  stride=1,
                  padding=1,
                  bias=False),
        nn.BatchNorm2d(self.yolo_conv_channels),
        nn.LeakyReLU(self.leaky_relu_slope),

        nn.Conv2d(in_channels=self.yolo_conv_channels,
                  out_channels=self.yolo_conv_channels,
                  kernel_size=3,
                  stride=1,
                  padding=1,
                  bias=False),
        nn.BatchNorm2d(self.yolo_conv_channels),
        nn.LeakyReLU(self.leaky_relu_slope),
    )

    # Detection Layers
    if self.use_conv:
      self.linear_yolo_layers = nn.Sequential(
          nn.Conv2d(in_channels=self.yolo_conv_channels,
                    out_channels=(5 * self.B + self.C),
                    kernel_size=1,
                    stride=1,
                    padding=0)
      )
    else:
      self.linear_yolo_layers = nn.Sequential(
          nn.Flatten(),
          nn.Linear(in_features=(self.conv_spatial_size * self.conv_spatial_size * self.yolo_conv_channels),
                    out_features=self.yolo_fc_hidden_dim),
          nn.LeakyReLU(self.leaky_relu_slope),
          nn.Dropout(self.yolo_fc_dropout_prob),
          nn.Linear(in_features=self.yolo_fc_hidden_dim,
                    out_features=(self.S * self.S * (5 * self.B + self.C)))
      )

  def forward(self, x):
    # Backbone
    out = self.features(x)

    out = self.conv_yolo_layers(out)
    out = self.linear_yolo_layers(out)

    if self.use_conv:
      # Reshape to (batch_size, S, S, 5 * B + C)
      out = out.permute(0, 2, 3, 1)

    return out

In [14]:
def get_iou(boxes1, boxes2):
  """IOU between two sets of boxes"""

  # Area of boxes (x_2 - x_1) * (y_2 - y_1)
  area1 = (boxes1[..., 2] - boxes1[..., 0]) * (boxes1[..., 3] - boxes1[..., 1])
  area2 = (boxes2[..., 2] - boxes2[..., 0]) * (boxes2[..., 3] - boxes2[..., 1])

  # Get top left coordinates: (x_1, y_1)
  x_left = torch.max(boxes1[..., 0], boxes2[..., 0])
  y_top = torch.max(boxes1[..., 1], boxes2[..., 1])

  # Get bottom right coordinates: (x_2, y_2)
  x_right = torch.min(boxes1[..., 2], boxes2[..., 2])
  y_bottom = torch.min(boxes1[..., 3], boxes2[..., 3])

  intersection_area = (x_right - x_left).clamp(min=0) * (y_bottom - y_top).clamp(min=0)
  union = area1.clamp(min=0) + area2.clamp(min=0) - intersection_area

  iou = intersection_area / (union + 1E-6)
  return iou

In [15]:
class YOLOV1Loss(nn.Module):
  """
  Loss module for YoloV1 which caters to the following components:
  1. Localization Loss for responsible predictor boxes
  2. Confidence Loss for responsible predictor boxes
  2. Confidence Loss for non-responsible predictor boxes of cells assigned with objects
  2. Confidence Loss for ALL predictor boxes of cells not assigned with objects
  3. Classification Loss
  """
  def __init__(self, S=7, B=2, C=20):
    super().__init__()

    self.S = S
    self.B = B
    self.C = C
    self.lambda_coord = 5
    self.lambda_noobj = 0.5

  def forward(self, preds, targets, use_sigmoid=False):
    """
    Compute the YOLO Loss
    The target element for each cell has been duplicated 5B times(done in VOCDataset)

    preds: (batch_size, S * S * (5B+C)) tensor
    targets: (batch_size, S, S, (5B + C)) tensor
    use_sigmoid: Whether to use sigmoid activation for box predicitons or not
    """
    batch_size = preds.size(0)

    # Reshape preds to same shape as targets
    preds = preds.reshape(batch_size, self.S, self.S, 5 * self.B + self.C)
    # preds shape: (batch_size, S, S, 5B + C)

    # Sigmoid leads to quicker convergence
    if use_sigmoid:
      preds[..., :5 * self.B] = torch.sigmoid(preds[..., :5 * self.B])

    # For localization and confidence loss, we need to get the responsible predictor box for each grid cell

    # Define shifts for all grid cell locations and normalize between 0 and 1
    # Will use these for converting x_center_offset and y_center_offset
    # Values for x_1, y_2, x_2, and y_2 are normalized between [0, 1]
    # S cells = 1 => each cell adds 1/S pixels of shift
    shifts_x = torch.arange(0,
                            self.S,
                            dtype=torch.int32,
                            device=preds.device) * 1 / float(self.S)
    shifts_y = torch.arange(0,
                            self.S,
                            dtype=torch.int32,
                            device=preds.device) * 1 / float(self.S)

    # Create a grid using these shifts
    shifts_y, shifts_x = torch.meshgrid(shifts_y, shifts_x, indexing="ij")

    # shifts shape: (1, S, S, B)
    shifts_x = shifts_x.reshape((1, self.S, self.S, 1)).repeat(1, 1, 1, self.B)
    shifts_y = shifts_y.reshape((1, self.S, self.S, 1)).repeat(1, 1, 1, self.B)

    # pred_boxes shape: (batch_size, S, S, B, 5)
    pred_boxes = preds[..., :5 * self.B].reshape(batch_size, self.S, self.S, self.B, -1)

    # xc_offset, yc_offset, width, height -> x_1, y_1, x_2, y_2 normalized betwen [0, 1]
    # x_center = (xc_offset / S + shift_x)
    # x_1 = x_center - 0.5 * w
    # x_2 = x_center + 0.5 * w

    # Use squared value because the model predicts the square root of width and height
    pred_boxes_x1 = ((pred_boxes[..., 0] / self.S + shifts_x) - 0.5 * torch.square(pred_boxes[..., 2]))
    pred_boxes_x1 = pred_boxes_x1.unsqueeze(dim=-1)

    pred_boxes_y1 = ((pred_boxes[..., 1] / self.S + shifts_y) - 0.5 * torch.square(pred_boxes[..., 3]))
    pred_boxes_y1 = pred_boxes_y1.unsqueeze(dim=-1)

    pred_boxes_x2 = ((pred_boxes[..., 0] / self.S + shifts_x) + 0.5 * torch.square(pred_boxes[..., 2]))
    pred_boxes_x2 = pred_boxes_x2.unsqueeze(dim=-1)

    pred_boxes_y2 = ((pred_boxes[..., 1] / self.S + shifts_y) + 0.5 * torch.square(pred_boxes[..., 3]))
    pred_boxes_y2 = pred_boxes_y2.unsqueeze(dim=-1)

    pred_boxes_x1y1x2y2 = torch.cat([
        pred_boxes_x1,
        pred_boxes_y1,
        pred_boxes_x2,
        pred_boxes_y2], dim=-1)

    # Do the same for the target boxes
    # target_boxes shape: (batch_size, S, S, B, 5)
    target_boxes = targets[..., :5*self.B].reshape(batch_size, self.S, self.S, self.B, -1)

    target_boxes_x1 = ((target_boxes[..., 0] / self.S + shifts_x) - 0.5 * torch.square(target_boxes[..., 2]))
    target_boxes_x1 = target_boxes_x1.unsqueeze(dim=-1)

    target_boxes_y1 = ((target_boxes[..., 1] / self.S + shifts_y) - 0.5 * torch.square(target_boxes[..., 3]))
    target_boxes_y1 = target_boxes_y1.unsqueeze(dim=-1)

    target_boxes_x2 = ((target_boxes[..., 0] / self.S + shifts_x) + 0.5 * torch.square(target_boxes[..., 2]))
    target_boxes_x2 = target_boxes_x2.unsqueeze(dim=-1)

    target_boxes_y2 = ((target_boxes[..., 1] / self.S + shifts_y) + 0.5 * torch.square(target_boxes[..., 3]))
    target_boxes_y2 = target_boxes_y2.unsqueeze(dim=-1)

    target_boxes_x1y1x2y2 = torch.cat([
        target_boxes_x1,
        target_boxes_y1,
        target_boxes_x2,
        target_boxes_y2], dim=-1)

    # pred_boxes_x1y1x2y2 shape: (batch_size, S, S, B, 4)
    # target_boxes_x1y1x2y2 shape: (batch_size, S, S, B, 4)

    # iou shape: (batch_size, S, S, B)
    iou = get_iou(pred_boxes_x1y1x2y2, target_boxes_x1y1x2y2)

    # Get the max along the last dimension which is the box index that has the maximum IOU at every grid cell
    # max_iou_val and max_iou_idx shape: (batch_size, S, S, 1)
    max_iou_val, max_iou_idx = iou.max(dim=-1, keepdim=True)

    # For localization and confidence loss, we need to find the cells that have objects assigned and which are the responsible predictor boxes

    # Indicator Definitions
    max_iou_idx = max_iou_idx.repeat(1, 1, 1, self.B) # (batch_size, S, S, 1) -> (batch_size, S, S, B)

    bb_idxs = (torch.arange(self.B).reshape(1, 1, 1, self.B).expand_as(max_iou_idx).to(preds.device)) # (batch_size, S, S, B)

    # Only the index which of the box that has the highest IOU will be 1, the others will be 0
    is_max_iou_box = (max_iou_idx == bb_idxs).long() # (batch_size, S, S, B)

    obj_indicator = targets[..., 4:5] # (batch_size, S, S, 1)

    # Classification Loss
    cls_target = targets[..., 5 * self.B:]
    cls_preds = preds[..., 5 * self.B:]
    cls_mse = (cls_preds - cls_target)**2

    # Only keep losses from cells with objects assigned
    cls_mse = (obj_indicator * cls_mse).sum()

    # Confidence Loss for the responsible predictor boxes

    # Filters out the cells which were not assigned an object and boxes that are not the responsible predictor
    is_max_box_obj_indicator = is_max_iou_box * obj_indicator
    obj_mse = (pred_boxes[..., 4] - max_iou_val)**2

    # Only keep losses from boxes of cells with object assigned and that box which is the responsible predictor
    obj_mse = (is_max_box_obj_indicator * obj_mse).sum()

    # Localization Loss for x, y, w, h
    x_mse = (pred_boxes[..., 0] - target_boxes[..., 0])**2
    y_mse = (pred_boxes[..., 1] - target_boxes[..., 1])**2
    w_sqrt_mse = (pred_boxes[..., 2] - target_boxes[..., 2])**2
    h_sqrt_mse = (pred_boxes[..., 3] - target_boxes[..., 3])**2

    # Only keep losses from boxes of cells with object assigned and that box which is the responsible predictor
    x_mse = (is_max_box_obj_indicator * x_mse).sum()
    y_mse = (is_max_box_obj_indicator * y_mse).sum()
    w_sqrt_mse = (is_max_box_obj_indicator * w_sqrt_mse).sum()
    h_sqrt_mse = (is_max_box_obj_indicator * h_sqrt_mse).sum()

    # Confidence Loss for background boxes and cells
    # Boxes of cells assigned with object that aren't responsible predictor boxes and for boxes of cell not assigned with object
    no_object_indicator = 1 - is_max_box_obj_indicator
    no_obj_mse = (pred_boxes[..., 4] - torch.zeros_like(pred_boxes[..., 4])) ** 2
    no_obj_mse = (no_object_indicator * no_obj_mse).sum()

    # Total Loss
    loss = self.lambda_coord * (x_mse + y_mse + w_sqrt_mse + h_sqrt_mse)
    loss += cls_mse + obj_mse
    loss += self.lambda_noobj * no_obj_mse
    loss = loss / batch_size
    return loss

In [16]:
def collate_function(data):
  return list(zip(*data))

In [18]:
def train():
  dataset_config = dataset_params
  model_config = model_params
  train_config = train_params

  # Set the random seed
  seed = train_config["seed"]
  torch.manual_seed(seed)
  np.random.seed(seed)
  random.seed(seed)
  if device.type == "cuda":
    torch.cuda.manual_seed_all(seed)

  voc = VOCDataset("train", im_sets=dataset_config["train_im_sets"])
  train_dataset = DataLoader(voc,
                             batch_size=train_config["batch_size"],
                             shuffle=True,
                             collate_fn=collate_function,
                             num_workers=os.cpu_count(),
                             pin_memory=True)

  yolo_model = YOLOV1(im_size=dataset_config["im_size"],
                      num_classes=dataset_config["num_classes"],
                      model_config=model_config)
  yolo_model.train()
  yolo_model.to(device)

  if not os.path.exists(train_config["task_name"]):
    os.mkdir(train_config["task_name"])

  optimizer = torch.optim.SGD(lr=train_config["lr"],
                              params=filter(lambda p: p.requires_grad,
                                            yolo_model.parameters()),
                              weight_decay=5E-4,
                              momentum=0.9)

  scheduler = MultiStepLR(optimizer, milestones=train_config["lr_steps"], gamma=0.5)
  criterion = YOLOV1Loss()
  acc_steps = train_config["acc_steps"]
  num_epochs = train_config["num_epochs"]
  steps = 0

  # Resume training from a saved checkpoint if it exists
  checkpoint_path = os.path.join(train_config["task_name"], train_config["ckpt_name"])
  start_epoch = 0

  if os.path.exists(checkpoint_path):
    print("Resuming training from checkpoint...")
    checkpoint = torch.load(checkpoint_path, map_location=device)

    yolo_model.load_state_dict(checkpoint["model_state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
    start_epoch = checkpoint["epoch"]
    steps = checkpoint["steps"]

  optimizer.zero_grad()

  for epoch_idx in range(start_epoch, num_epochs):
    losses = []
    optimizer.zero_grad()
    for idx, (ims, targets, _) in enumerate(tqdm(train_dataset)):
      yolo_targets = torch.cat([
          target["yolo_targets"].unsqueeze(0).float().to(device)
          for target in targets
      ], dim=0)

      im = torch.cat([im.unsqueeze(0).float().to(device) for im in ims], dim=0)

      yolo_preds = yolo_model(im)

      loss = criterion(yolo_preds, yolo_targets, use_sigmoid=model_config["use_sigmoid"])
      loss = loss / acc_steps
      loss.backward()
      losses.append(loss.item())

      if (idx + 1) % acc_steps == 0:
        optimizer.step()
        optimizer.zero_grad()
      if steps % train_config["log_steps"] == 0:
        print(f"Loss: {np.mean(losses):.4f}")
      if torch.isnan(loss):
        print("Loss is becoming nan. Exiting")
        exit(0)

      steps += 1
    print(f"Finished epoch {epoch_idx + 1} | Loss: {np.mean(losses):.4f}")
    optimizer.step()
    optimizer.zero_grad()
    scheduler.step()

    checkpoint = {
        "epoch": epoch_idx,
        "steps": steps,
        "model_state_dict": yolo_model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "scheduler_state_dict": scheduler.state_dict(),
    }
    torch.save(checkpoint, os.path.join(train_config["task_name"], train_config["ckpt_name"]))

  print("Done Training...")

train()

  albu.Affine(
  albu.ColorJitter(


{0: 'aeroplane', 1: 'bicycle', 2: 'bird', 3: 'boat', 4: 'bottle', 5: 'bus', 6: 'car', 7: 'cat', 8: 'chair', 9: 'cow', 10: 'diningtable', 11: 'dog', 12: 'horse', 13: 'motorbike', 14: 'person', 15: 'pottedplant', 16: 'sheep', 17: 'sofa', 18: 'train', 19: 'tvmonitor'}
Total 16551 images found


  0%|          | 1/259 [00:04<19:30,  4.54s/it]

Loss: 30.8514


 39%|███▉      | 101/259 [00:27<00:45,  3.48it/s]

Loss: 4.6826


 78%|███████▊  | 201/259 [00:50<00:12,  4.72it/s]

Loss: 3.7976


100%|██████████| 259/259 [01:02<00:00,  4.14it/s]


Finished epoch 1 | Loss: 3.5693


 17%|█▋        | 43/259 [00:13<00:42,  5.06it/s]

Loss: 2.5888


 55%|█████▌    | 143/259 [00:36<00:23,  4.88it/s]

Loss: 2.5194


 94%|█████████▍| 243/259 [00:59<00:03,  4.97it/s]

Loss: 2.4678


100%|██████████| 259/259 [01:02<00:00,  4.17it/s]


Finished epoch 2 | Loss: 2.4588


 32%|███▏      | 84/259 [00:23<00:34,  5.03it/s]

Loss: 2.2776


 71%|███████   | 184/259 [00:46<00:17,  4.35it/s]

Loss: 2.2142


100%|██████████| 259/259 [01:03<00:00,  4.07it/s]


Finished epoch 3 | Loss: 2.1929


  9%|▉         | 24/259 [00:10<00:47,  4.90it/s]

Loss: 2.0826


 48%|████▊     | 125/259 [00:33<00:32,  4.08it/s]

Loss: 2.0517


 87%|████████▋ | 225/259 [00:58<00:07,  4.74it/s]

Loss: 2.0335


100%|██████████| 259/259 [01:05<00:00,  3.98it/s]


Finished epoch 4 | Loss: 2.0207


 25%|██▌       | 66/259 [00:19<00:49,  3.92it/s]

Loss: 1.9756


 64%|██████▍   | 166/259 [00:43<00:19,  4.77it/s]

Loss: 1.9255


100%|██████████| 259/259 [01:04<00:00,  4.02it/s]


Finished epoch 5 | Loss: 1.8892


  3%|▎         | 7/259 [00:05<01:35,  2.64it/s]

Loss: 1.8110


 41%|████▏     | 107/259 [00:28<00:31,  4.78it/s]

Loss: 1.7882


 80%|███████▉  | 207/259 [00:51<00:10,  4.93it/s]

Loss: 1.7952


100%|██████████| 259/259 [01:02<00:00,  4.14it/s]


Finished epoch 6 | Loss: 1.7901


 19%|█▊        | 48/259 [00:14<00:41,  5.07it/s]

Loss: 1.6731


 57%|█████▋    | 148/259 [00:37<00:26,  4.21it/s]

Loss: 1.7197


 96%|█████████▌| 248/259 [00:59<00:02,  5.30it/s]

Loss: 1.7021


100%|██████████| 259/259 [01:01<00:00,  4.20it/s]


Finished epoch 7 | Loss: 1.7080


 34%|███▍      | 88/259 [00:24<00:49,  3.46it/s]

Loss: 1.6849


 73%|███████▎  | 189/259 [00:47<00:14,  4.87it/s]

Loss: 1.6408


100%|██████████| 259/259 [01:02<00:00,  4.13it/s]


Finished epoch 8 | Loss: 1.6400


 11%|█         | 29/259 [00:10<00:46,  4.94it/s]

Loss: 1.5817


 50%|████▉     | 129/259 [00:33<00:26,  4.90it/s]

Loss: 1.5925


 88%|████████▊ | 229/259 [00:55<00:06,  4.60it/s]

Loss: 1.5636


100%|██████████| 259/259 [01:01<00:00,  4.19it/s]


Finished epoch 9 | Loss: 1.5714


 27%|██▋       | 71/259 [00:19<00:38,  4.94it/s]

Loss: 1.5023


 66%|██████▌   | 170/259 [00:42<00:17,  4.97it/s]

Loss: 1.4936


100%|██████████| 259/259 [01:01<00:00,  4.19it/s]


Finished epoch 10 | Loss: 1.5056


  5%|▍         | 12/259 [00:06<00:55,  4.45it/s]

Loss: 1.5027


 43%|████▎     | 112/259 [00:29<00:29,  4.91it/s]

Loss: 1.4343


 82%|████████▏ | 212/259 [00:52<00:11,  4.12it/s]

Loss: 1.4548


100%|██████████| 259/259 [01:02<00:00,  4.16it/s]


Finished epoch 11 | Loss: 1.4544


 20%|██        | 53/259 [00:16<00:45,  4.48it/s]

Loss: 1.3442


 59%|█████▉    | 153/259 [00:39<00:22,  4.74it/s]

Loss: 1.3947


 98%|█████████▊| 253/259 [01:01<00:01,  5.44it/s]

Loss: 1.4020


100%|██████████| 259/259 [01:02<00:00,  4.14it/s]


Finished epoch 12 | Loss: 1.4006


 36%|███▋      | 94/259 [00:25<00:36,  4.51it/s]

Loss: 1.3638


 75%|███████▍  | 194/259 [00:47<00:15,  4.17it/s]

Loss: 1.3670


100%|██████████| 259/259 [01:01<00:00,  4.19it/s]


Finished epoch 13 | Loss: 1.3576


 14%|█▎        | 35/259 [00:12<00:45,  4.88it/s]

Loss: 1.3627


 52%|█████▏    | 134/259 [00:35<00:24,  5.00it/s]

Loss: 1.3212


 90%|█████████ | 234/259 [00:58<00:06,  4.15it/s]

Loss: 1.3203


100%|██████████| 259/259 [01:02<00:00,  4.12it/s]


Finished epoch 14 | Loss: 1.3208


 29%|██▉       | 76/259 [00:21<00:52,  3.49it/s]

Loss: 1.2537


 68%|██████▊   | 176/259 [00:44<00:18,  4.47it/s]

Loss: 1.2686


100%|██████████| 259/259 [01:02<00:00,  4.16it/s]


Finished epoch 15 | Loss: 1.2757


  7%|▋         | 17/259 [00:08<00:58,  4.17it/s]

Loss: 1.2642


 45%|████▌     | 117/259 [00:31<00:33,  4.30it/s]

Loss: 1.2358


 83%|████████▎ | 216/259 [00:53<00:09,  4.76it/s]

Loss: 1.2322


100%|██████████| 259/259 [01:02<00:00,  4.12it/s]


Finished epoch 16 | Loss: 1.2448


 22%|██▏       | 57/259 [00:17<00:42,  4.79it/s]

Loss: 1.1990


 61%|██████    | 158/259 [00:40<00:20,  4.90it/s]

Loss: 1.1925


100%|█████████▉| 258/259 [01:03<00:00,  5.39it/s]

Loss: 1.2115


100%|██████████| 259/259 [01:03<00:00,  4.06it/s]


Finished epoch 17 | Loss: 1.2123


 38%|███▊      | 99/259 [00:26<00:43,  3.66it/s]

Loss: 1.1803


 77%|███████▋  | 199/259 [00:49<00:12,  4.62it/s]

Loss: 1.1754


100%|██████████| 259/259 [01:02<00:00,  4.17it/s]


Finished epoch 18 | Loss: 1.1784


 15%|█▌        | 39/259 [00:12<00:43,  5.00it/s]

Loss: 1.1623


 54%|█████▍    | 140/259 [00:35<00:27,  4.31it/s]

Loss: 1.1297


 93%|█████████▎| 240/259 [00:57<00:03,  5.13it/s]

Loss: 1.1398


100%|██████████| 259/259 [01:01<00:00,  4.21it/s]


Finished epoch 19 | Loss: 1.1458


 31%|███▏      | 81/259 [00:22<00:37,  4.79it/s]

Loss: 1.1073


 70%|██████▉   | 181/259 [00:46<00:15,  5.02it/s]

Loss: 1.1123


100%|██████████| 259/259 [01:02<00:00,  4.11it/s]


Finished epoch 20 | Loss: 1.1150


  8%|▊         | 21/259 [00:09<00:50,  4.74it/s]

Loss: 1.1202


 47%|████▋     | 122/259 [00:33<00:30,  4.49it/s]

Loss: 1.1146


 86%|████████▌ | 222/259 [00:55<00:08,  4.24it/s]

Loss: 1.1037


100%|██████████| 259/259 [01:03<00:00,  4.10it/s]


Finished epoch 21 | Loss: 1.0978


 24%|██▍       | 63/259 [00:18<00:55,  3.56it/s]

Loss: 1.0334


 63%|██████▎   | 163/259 [00:41<00:20,  4.80it/s]

Loss: 1.0564


100%|██████████| 259/259 [01:02<00:00,  4.17it/s]


Finished epoch 22 | Loss: 1.0713


  2%|▏         | 4/259 [00:05<03:49,  1.11it/s]

Loss: 1.1072


 40%|████      | 104/259 [00:27<00:32,  4.74it/s]

Loss: 1.0690


 79%|███████▉  | 204/259 [00:50<00:11,  4.99it/s]

Loss: 1.0436


100%|██████████| 259/259 [01:01<00:00,  4.18it/s]


Finished epoch 23 | Loss: 1.0540


 17%|█▋        | 45/259 [00:14<00:50,  4.26it/s]

Loss: 0.9899


 56%|█████▌    | 144/259 [00:36<00:24,  4.61it/s]

Loss: 1.0372


 95%|█████████▍| 245/259 [00:59<00:02,  5.24it/s]

Loss: 1.0281


100%|██████████| 259/259 [01:01<00:00,  4.18it/s]


Finished epoch 24 | Loss: 1.0282


 33%|███▎      | 86/259 [00:23<00:35,  4.92it/s]

Loss: 0.9795


 72%|███████▏  | 186/259 [00:46<00:15,  4.60it/s]

Loss: 0.9935


100%|██████████| 259/259 [01:02<00:00,  4.18it/s]


Finished epoch 25 | Loss: 1.0034


 10%|█         | 27/259 [00:10<00:46,  5.01it/s]

Loss: 0.9977


 49%|████▊     | 126/259 [00:32<00:27,  4.75it/s]

Loss: 0.9839


 88%|████████▊ | 227/259 [00:55<00:06,  4.91it/s]

Loss: 0.9822


100%|██████████| 259/259 [01:01<00:00,  4.19it/s]


Finished epoch 26 | Loss: 0.9827


 26%|██▌       | 67/259 [00:18<00:40,  4.79it/s]

Loss: 0.9259


 64%|██████▍   | 167/259 [00:41<00:19,  4.63it/s]

Loss: 0.9583


100%|██████████| 259/259 [01:01<00:00,  4.18it/s]


Finished epoch 27 | Loss: 0.9652


  3%|▎         | 9/259 [00:06<01:10,  3.53it/s]

Loss: 0.9667


 42%|████▏     | 108/259 [00:28<00:34,  4.36it/s]

Loss: 0.9136


 81%|████████  | 209/259 [00:51<00:10,  4.66it/s]

Loss: 0.9238


100%|██████████| 259/259 [01:02<00:00,  4.16it/s]


Finished epoch 28 | Loss: 0.9419


 19%|█▉        | 50/259 [00:15<00:42,  4.94it/s]

Loss: 0.8950


 58%|█████▊    | 149/259 [00:37<00:22,  4.96it/s]

Loss: 0.9142


 97%|█████████▋| 250/259 [01:00<00:01,  5.32it/s]

Loss: 0.9233


100%|██████████| 259/259 [01:02<00:00,  4.16it/s]


Finished epoch 29 | Loss: 0.9255


 35%|███▌      | 91/259 [00:24<00:37,  4.50it/s]

Loss: 0.9153


 74%|███████▎  | 191/259 [00:47<00:13,  4.86it/s]

Loss: 0.9085


100%|██████████| 259/259 [01:02<00:00,  4.16it/s]


Finished epoch 30 | Loss: 0.9170


 12%|█▏        | 31/259 [00:11<00:46,  4.85it/s]

Loss: 0.8310


 51%|█████     | 132/259 [00:34<00:28,  4.45it/s]

Loss: 0.9061


 89%|████████▉ | 231/259 [00:56<00:05,  4.86it/s]

Loss: 0.8941


100%|██████████| 259/259 [01:02<00:00,  4.17it/s]


Finished epoch 31 | Loss: 0.8997


 28%|██▊       | 73/259 [00:20<00:39,  4.67it/s]

Loss: 0.8475


 67%|██████▋   | 173/259 [00:43<00:17,  4.99it/s]

Loss: 0.8740


100%|██████████| 259/259 [01:02<00:00,  4.13it/s]


Finished epoch 32 | Loss: 0.8794


  5%|▌         | 14/259 [00:08<01:23,  2.94it/s]

Loss: 0.8668


 44%|████▎     | 113/259 [00:30<00:33,  4.42it/s]

Loss: 0.8560


 82%|████████▏ | 213/259 [00:53<00:10,  4.52it/s]

Loss: 0.8520


100%|██████████| 259/259 [01:02<00:00,  4.12it/s]


Finished epoch 33 | Loss: 0.8576


 21%|██        | 55/259 [00:16<00:45,  4.46it/s]

Loss: 0.9153


 60%|█████▉    | 155/259 [00:39<00:21,  4.91it/s]

Loss: 0.8562


 98%|█████████▊| 255/259 [01:01<00:00,  5.42it/s]

Loss: 0.8449


100%|██████████| 259/259 [01:02<00:00,  4.17it/s]


Finished epoch 34 | Loss: 0.8446


 37%|███▋      | 96/259 [00:25<00:33,  4.90it/s]

Loss: 0.7860


 75%|███████▌  | 195/259 [00:49<00:14,  4.32it/s]

Loss: 0.8190


100%|██████████| 259/259 [01:02<00:00,  4.12it/s]


Finished epoch 35 | Loss: 0.8276


 14%|█▍        | 37/259 [00:12<00:48,  4.58it/s]

Loss: 0.8127


 53%|█████▎    | 136/259 [00:35<00:37,  3.29it/s]

Loss: 0.8336


 92%|█████████▏| 237/259 [00:58<00:04,  4.46it/s]

Loss: 0.8290


100%|██████████| 259/259 [01:02<00:00,  4.12it/s]


Finished epoch 36 | Loss: 0.8343


 30%|███       | 78/259 [00:22<00:41,  4.36it/s]

Loss: 0.8089


 68%|██████▊   | 177/259 [00:45<00:16,  4.86it/s]

Loss: 0.8073


100%|██████████| 259/259 [01:03<00:00,  4.11it/s]


Finished epoch 37 | Loss: 0.8159


  7%|▋         | 18/259 [00:09<01:02,  3.89it/s]

Loss: 0.8364


 46%|████▌     | 118/259 [00:31<00:28,  4.88it/s]

Loss: 0.8020


 84%|████████▍ | 218/259 [00:54<00:11,  3.57it/s]

Loss: 0.7907


100%|██████████| 259/259 [01:03<00:00,  4.09it/s]


Finished epoch 38 | Loss: 0.7958


 23%|██▎       | 60/259 [00:18<00:40,  4.96it/s]

Loss: 0.7699


 61%|██████▏   | 159/259 [00:41<00:20,  4.99it/s]

Loss: 0.7784


100%|██████████| 259/259 [01:03<00:00,  4.05it/s]

Loss: 0.7818
Finished epoch 39 | Loss: 0.7818



 39%|███▉      | 101/259 [00:27<00:37,  4.21it/s]

Loss: 0.7574


 77%|███████▋  | 200/259 [00:49<00:12,  4.88it/s]

Loss: 0.7820


100%|██████████| 259/259 [01:02<00:00,  4.16it/s]


Finished epoch 40 | Loss: 0.7857


 16%|█▌        | 42/259 [00:14<01:02,  3.48it/s]

Loss: 0.7452


 55%|█████▍    | 142/259 [00:36<00:25,  4.66it/s]

Loss: 0.7770


 93%|█████████▎| 242/259 [00:59<00:03,  5.13it/s]

Loss: 0.7779


100%|██████████| 259/259 [01:02<00:00,  4.13it/s]


Finished epoch 41 | Loss: 0.7786


 32%|███▏      | 82/259 [00:22<00:40,  4.37it/s]

Loss: 0.7504


 71%|███████   | 183/259 [00:45<00:19,  3.88it/s]

Loss: 0.7751


100%|██████████| 259/259 [01:01<00:00,  4.18it/s]


Finished epoch 42 | Loss: 0.7717


  9%|▉         | 23/259 [00:09<00:48,  4.91it/s]

Loss: 0.7817


 48%|████▊     | 124/259 [00:32<00:30,  4.47it/s]

Loss: 0.7514


 86%|████████▋ | 224/259 [00:55<00:07,  4.80it/s]

Loss: 0.7479


100%|██████████| 259/259 [01:02<00:00,  4.16it/s]


Finished epoch 43 | Loss: 0.7493


 25%|██▍       | 64/259 [00:19<00:57,  3.41it/s]

Loss: 0.7475


 64%|██████▎   | 165/259 [00:42<00:20,  4.70it/s]

Loss: 0.7475


100%|██████████| 259/259 [01:03<00:00,  4.10it/s]


Finished epoch 44 | Loss: 0.7487


  2%|▏         | 6/259 [00:05<01:50,  2.28it/s]

Loss: 0.6996


 41%|████      | 105/259 [00:28<00:32,  4.73it/s]

Loss: 0.7566


 80%|███████▉  | 206/259 [00:51<00:10,  5.03it/s]

Loss: 0.7414


100%|██████████| 259/259 [01:02<00:00,  4.16it/s]


Finished epoch 45 | Loss: 0.7469


 18%|█▊        | 47/259 [00:14<00:41,  5.06it/s]

Loss: 0.7188


 56%|█████▋    | 146/259 [00:37<00:25,  4.35it/s]

Loss: 0.7215


 95%|█████████▌| 247/259 [01:00<00:02,  5.28it/s]

Loss: 0.7250


100%|██████████| 259/259 [01:02<00:00,  4.15it/s]


Finished epoch 46 | Loss: 0.7257


 34%|███▍      | 88/259 [00:24<00:34,  4.94it/s]

Loss: 0.6896


 73%|███████▎  | 188/259 [00:47<00:18,  3.94it/s]

Loss: 0.7110


100%|██████████| 259/259 [01:02<00:00,  4.13it/s]


Finished epoch 47 | Loss: 0.7156


 11%|█         | 29/259 [00:10<00:54,  4.22it/s]

Loss: 0.7303


 49%|████▉     | 128/259 [00:33<00:26,  4.87it/s]

Loss: 0.7171


 88%|████████▊ | 229/259 [00:56<00:06,  4.97it/s]

Loss: 0.7132


100%|██████████| 259/259 [01:02<00:00,  4.17it/s]


Finished epoch 48 | Loss: 0.7192


 27%|██▋       | 70/259 [00:19<00:45,  4.14it/s]

Loss: 0.6878


 66%|██████▌   | 170/259 [00:42<00:18,  4.88it/s]

Loss: 0.6920


100%|██████████| 259/259 [01:02<00:00,  4.17it/s]


Finished epoch 49 | Loss: 0.7037


  4%|▍         | 10/259 [00:06<01:05,  3.78it/s]

Loss: 0.7970


 42%|████▏     | 110/259 [00:28<00:30,  4.88it/s]

Loss: 0.7054


 81%|████████  | 210/259 [00:51<00:14,  3.47it/s]

Loss: 0.6982


100%|██████████| 259/259 [01:01<00:00,  4.20it/s]


Finished epoch 50 | Loss: 0.7012


 20%|█▉        | 51/259 [00:15<00:41,  4.95it/s]

Loss: 0.6578


 59%|█████▊    | 152/259 [00:39<00:31,  3.38it/s]

Loss: 0.6617


 97%|█████████▋| 252/259 [01:01<00:01,  5.41it/s]

Loss: 0.6563


100%|██████████| 259/259 [01:02<00:00,  4.15it/s]


Finished epoch 51 | Loss: 0.6545


 36%|███▌      | 93/259 [00:25<00:35,  4.66it/s]

Loss: 0.6193


 75%|███████▍  | 193/259 [00:47<00:13,  4.83it/s]

Loss: 0.6276


100%|██████████| 259/259 [01:01<00:00,  4.19it/s]


Finished epoch 52 | Loss: 0.6341


 13%|█▎        | 34/259 [00:12<00:50,  4.50it/s]

Loss: 0.6250


 52%|█████▏    | 134/259 [00:34<00:25,  4.97it/s]

Loss: 0.6253


 90%|████████▉ | 233/259 [00:57<00:05,  5.04it/s]

Loss: 0.6192


100%|██████████| 259/259 [01:02<00:00,  4.14it/s]


Finished epoch 53 | Loss: 0.6219


 29%|██▊       | 74/259 [00:21<00:38,  4.83it/s]

Loss: 0.5968


 68%|██████▊   | 175/259 [00:44<00:21,  3.94it/s]

Loss: 0.6148


100%|██████████| 259/259 [01:02<00:00,  4.12it/s]


Finished epoch 54 | Loss: 0.6118


  6%|▌         | 15/259 [00:07<00:52,  4.67it/s]

Loss: 0.6951


 45%|████▍     | 116/259 [00:30<00:31,  4.50it/s]

Loss: 0.6052


 83%|████████▎ | 216/259 [00:53<00:08,  4.94it/s]

Loss: 0.6045


100%|██████████| 259/259 [01:02<00:00,  4.15it/s]


Finished epoch 55 | Loss: 0.6103


 22%|██▏       | 56/259 [00:16<00:39,  5.08it/s]

Loss: 0.5903


 60%|██████    | 156/259 [00:39<00:20,  5.01it/s]

Loss: 0.5957


 99%|█████████▉| 257/259 [01:02<00:00,  5.41it/s]

Loss: 0.6046


100%|██████████| 259/259 [01:03<00:00,  4.10it/s]


Finished epoch 56 | Loss: 0.6046


 37%|███▋      | 97/259 [00:26<00:34,  4.76it/s]

Loss: 0.5833


 76%|███████▌  | 197/259 [00:48<00:13,  4.76it/s]

Loss: 0.6033


100%|██████████| 259/259 [01:02<00:00,  4.18it/s]


Finished epoch 57 | Loss: 0.6015


 15%|█▍        | 38/259 [00:13<01:10,  3.12it/s]

Loss: 0.5735


 54%|█████▎    | 139/259 [00:36<00:26,  4.60it/s]

Loss: 0.5826


 92%|█████████▏| 239/259 [00:59<00:05,  3.88it/s]

Loss: 0.5839


100%|██████████| 259/259 [01:03<00:00,  4.09it/s]


Finished epoch 58 | Loss: 0.5835


 31%|███       | 79/259 [00:21<00:41,  4.29it/s]

Loss: 0.5655


 69%|██████▉   | 179/259 [00:44<00:17,  4.50it/s]

Loss: 0.5789


100%|██████████| 259/259 [01:02<00:00,  4.16it/s]


Finished epoch 59 | Loss: 0.5887


  8%|▊         | 21/259 [00:09<00:48,  4.87it/s]

Loss: 0.5892


 46%|████▋     | 120/259 [00:31<00:27,  5.02it/s]

Loss: 0.5921


 85%|████████▍ | 220/259 [00:54<00:10,  3.89it/s]

Loss: 0.5885


100%|██████████| 259/259 [01:02<00:00,  4.15it/s]


Finished epoch 60 | Loss: 0.5905


 24%|██▍       | 62/259 [00:17<00:39,  4.96it/s]

Loss: 0.5867


 63%|██████▎   | 162/259 [00:41<00:21,  4.60it/s]

Loss: 0.5794


100%|██████████| 259/259 [01:02<00:00,  4.16it/s]


Finished epoch 61 | Loss: 0.5764


  1%|          | 2/259 [00:04<07:18,  1.71s/it]

Loss: 0.8925


 40%|███▉      | 103/259 [00:27<00:44,  3.49it/s]

Loss: 0.5713


 78%|███████▊  | 203/259 [00:50<00:11,  4.67it/s]

Loss: 0.5730


100%|██████████| 259/259 [01:02<00:00,  4.16it/s]


Finished epoch 62 | Loss: 0.5697


 17%|█▋        | 43/259 [00:13<00:57,  3.73it/s]

Loss: 0.5794


 55%|█████▌    | 143/259 [00:36<00:25,  4.46it/s]

Loss: 0.5745


 94%|█████████▍| 244/259 [00:59<00:02,  5.22it/s]

Loss: 0.5635


100%|██████████| 259/259 [01:01<00:00,  4.18it/s]


Finished epoch 63 | Loss: 0.5622


 33%|███▎      | 85/259 [00:23<00:35,  4.91it/s]

Loss: 0.5561


 71%|███████▏  | 185/259 [00:46<00:17,  4.30it/s]

Loss: 0.5664


100%|██████████| 259/259 [01:02<00:00,  4.12it/s]


Finished epoch 64 | Loss: 0.5681


 10%|▉         | 25/259 [00:10<00:48,  4.87it/s]

Loss: 0.5498


 49%|████▊     | 126/259 [00:32<00:28,  4.70it/s]

Loss: 0.5516


 87%|████████▋ | 226/259 [00:55<00:06,  5.02it/s]

Loss: 0.5549


100%|██████████| 259/259 [01:02<00:00,  4.16it/s]


Finished epoch 65 | Loss: 0.5603


 26%|██▌       | 67/259 [00:19<00:44,  4.27it/s]

Loss: 0.5483


 64%|██████▍   | 167/259 [00:42<00:18,  4.92it/s]

Loss: 0.5369


100%|██████████| 259/259 [01:02<00:00,  4.14it/s]


Finished epoch 66 | Loss: 0.5482


  3%|▎         | 8/259 [00:06<01:21,  3.08it/s]

Loss: 0.5750


 42%|████▏     | 108/259 [00:28<00:30,  4.95it/s]

Loss: 0.5456


 80%|████████  | 208/259 [00:51<00:14,  3.52it/s]

Loss: 0.5434


100%|██████████| 259/259 [01:02<00:00,  4.13it/s]


Finished epoch 67 | Loss: 0.5460


 19%|█▉        | 49/259 [00:15<00:41,  5.06it/s]

Loss: 0.5807


 57%|█████▋    | 148/259 [00:37<00:26,  4.22it/s]

Loss: 0.5551


 96%|█████████▌| 249/259 [01:00<00:01,  5.25it/s]

Loss: 0.5496


100%|██████████| 259/259 [01:02<00:00,  4.12it/s]


Finished epoch 68 | Loss: 0.5495


 35%|███▍      | 90/259 [00:24<00:47,  3.54it/s]

Loss: 0.5255


 73%|███████▎  | 190/259 [00:47<00:14,  4.61it/s]

Loss: 0.5413


100%|██████████| 259/259 [01:02<00:00,  4.16it/s]


Finished epoch 69 | Loss: 0.5492


 12%|█▏        | 30/259 [00:11<00:47,  4.86it/s]

Loss: 0.5492


 51%|█████     | 131/259 [00:33<00:25,  5.05it/s]

Loss: 0.5338


 89%|████████▉ | 231/259 [00:56<00:06,  4.07it/s]

Loss: 0.5416


100%|██████████| 259/259 [01:02<00:00,  4.14it/s]


Finished epoch 70 | Loss: 0.5472


 27%|██▋       | 71/259 [00:19<00:43,  4.31it/s]

Loss: 0.5431


 66%|██████▌   | 171/259 [00:42<00:19,  4.42it/s]

Loss: 0.5407


100%|██████████| 259/259 [01:02<00:00,  4.12it/s]


Finished epoch 71 | Loss: 0.5453


  5%|▍         | 12/259 [00:07<00:55,  4.43it/s]

Loss: 0.6282


 43%|████▎     | 112/259 [00:29<00:32,  4.52it/s]

Loss: 0.5294


 82%|████████▏ | 212/259 [00:52<00:10,  4.40it/s]

Loss: 0.5317


100%|██████████| 259/259 [01:02<00:00,  4.14it/s]


Finished epoch 72 | Loss: 0.5340


 21%|██        | 54/259 [00:16<00:40,  5.04it/s]

Loss: 0.5517


 59%|█████▉    | 153/259 [00:38<00:25,  4.16it/s]

Loss: 0.5492


 98%|█████████▊| 254/259 [01:01<00:00,  5.36it/s]

Loss: 0.5415


100%|██████████| 259/259 [01:02<00:00,  4.17it/s]


Finished epoch 73 | Loss: 0.5424


 37%|███▋      | 95/259 [00:26<00:35,  4.68it/s]

Loss: 0.5355


 75%|███████▌  | 195/259 [00:49<00:13,  4.82it/s]

Loss: 0.5363


100%|██████████| 259/259 [01:03<00:00,  4.10it/s]


Finished epoch 74 | Loss: 0.5356


 14%|█▎        | 35/259 [00:12<00:45,  4.97it/s]

Loss: 0.5037


 53%|█████▎    | 136/259 [00:35<00:29,  4.15it/s]

Loss: 0.5356


 91%|█████████ | 236/259 [00:58<00:04,  4.68it/s]

Loss: 0.5274


100%|██████████| 259/259 [01:02<00:00,  4.14it/s]


Finished epoch 75 | Loss: 0.5273


 29%|██▉       | 76/259 [00:22<00:36,  4.97it/s]

Loss: 0.5196


 68%|██████▊   | 177/259 [00:45<00:18,  4.54it/s]

Loss: 0.5159


100%|██████████| 259/259 [01:03<00:00,  4.08it/s]


Finished epoch 76 | Loss: 0.5146


  7%|▋         | 18/259 [00:08<01:00,  3.99it/s]

Loss: 0.5722


 45%|████▌     | 117/259 [00:31<00:30,  4.59it/s]

Loss: 0.5203


 84%|████████▍ | 218/259 [00:53<00:08,  5.01it/s]

Loss: 0.5132


100%|██████████| 259/259 [01:02<00:00,  4.14it/s]


Finished epoch 77 | Loss: 0.5123


 23%|██▎       | 59/259 [00:17<00:40,  4.88it/s]

Loss: 0.5043


 61%|██████    | 158/259 [00:40<00:20,  5.01it/s]

Loss: 0.4977


100%|██████████| 259/259 [01:03<00:00,  6.09it/s]

Loss: 0.5008


100%|██████████| 259/259 [01:03<00:00,  4.07it/s]


Finished epoch 78 | Loss: 0.5008


 38%|███▊      | 99/259 [00:26<00:31,  5.01it/s]

Loss: 0.4922


 77%|███████▋  | 200/259 [00:51<00:12,  4.69it/s]

Loss: 0.5007


100%|██████████| 259/259 [01:04<00:00,  4.03it/s]


Finished epoch 79 | Loss: 0.5009


 16%|█▌        | 41/259 [00:13<00:43,  5.04it/s]

Loss: 0.5126


 54%|█████▍    | 141/259 [00:37<00:25,  4.57it/s]

Loss: 0.5182


 93%|█████████▎| 241/259 [00:59<00:03,  5.08it/s]

Loss: 0.5082


100%|██████████| 259/259 [01:03<00:00,  4.08it/s]


Finished epoch 80 | Loss: 0.5065


 31%|███▏      | 81/259 [00:22<00:40,  4.43it/s]

Loss: 0.4864


 70%|██████▉   | 181/259 [00:46<00:21,  3.67it/s]

Loss: 0.4983


100%|██████████| 259/259 [01:04<00:00,  4.04it/s]


Finished epoch 81 | Loss: 0.4974


  9%|▉         | 23/259 [00:10<00:49,  4.73it/s]

Loss: 0.5280


 47%|████▋     | 122/259 [00:32<00:31,  4.34it/s]

Loss: 0.4982


 86%|████████▌ | 222/259 [00:56<00:08,  4.42it/s]

Loss: 0.4919


100%|██████████| 259/259 [01:03<00:00,  4.05it/s]


Finished epoch 82 | Loss: 0.4930


 25%|██▍       | 64/259 [00:18<00:39,  4.91it/s]

Loss: 0.4723


 63%|██████▎   | 164/259 [00:42<00:23,  4.02it/s]

Loss: 0.4921


100%|██████████| 259/259 [01:03<00:00,  4.08it/s]


Finished epoch 83 | Loss: 0.4826


  2%|▏         | 5/259 [00:05<02:30,  1.69it/s]

Loss: 0.5144


 40%|████      | 104/259 [00:28<00:33,  4.66it/s]

Loss: 0.4845


 79%|███████▉  | 204/259 [00:51<00:11,  4.95it/s]

Loss: 0.4858


100%|██████████| 259/259 [01:02<00:00,  4.11it/s]


Finished epoch 84 | Loss: 0.4913


 18%|█▊        | 46/259 [00:14<00:42,  5.01it/s]

Loss: 0.4895


 56%|█████▋    | 146/259 [00:38<00:30,  3.67it/s]

Loss: 0.4957


 95%|█████████▍| 246/259 [01:00<00:02,  5.30it/s]

Loss: 0.4866


100%|██████████| 259/259 [01:03<00:00,  4.09it/s]


Finished epoch 85 | Loss: 0.4859


 34%|███▎      | 87/259 [00:24<00:57,  2.97it/s]

Loss: 0.4690


 72%|███████▏  | 186/259 [00:47<00:15,  4.70it/s]

Loss: 0.4784


100%|██████████| 259/259 [01:03<00:00,  4.08it/s]


Finished epoch 86 | Loss: 0.4850


 11%|█         | 28/259 [00:10<00:54,  4.25it/s]

Loss: 0.4790


 49%|████▉     | 128/259 [00:34<00:30,  4.34it/s]

Loss: 0.4793


 88%|████████▊ | 228/259 [00:57<00:06,  4.78it/s]

Loss: 0.4821


100%|██████████| 259/259 [01:03<00:00,  4.08it/s]


Finished epoch 87 | Loss: 0.4811


 27%|██▋       | 69/259 [00:21<00:51,  3.67it/s]

Loss: 0.4502


 65%|██████▌   | 169/259 [00:44<00:21,  4.25it/s]

Loss: 0.4753


100%|██████████| 259/259 [01:04<00:00,  4.04it/s]


Finished epoch 88 | Loss: 0.4849


  4%|▍         | 10/259 [00:06<01:06,  3.72it/s]

Loss: 0.5076


 42%|████▏     | 109/259 [00:29<00:29,  5.02it/s]

Loss: 0.4811


 81%|████████  | 210/259 [00:53<00:10,  4.73it/s]

Loss: 0.4772


100%|██████████| 259/259 [01:03<00:00,  4.08it/s]


Finished epoch 89 | Loss: 0.4795


 19%|█▉        | 50/259 [00:15<00:41,  4.99it/s]

Loss: 0.4946


 58%|█████▊    | 150/259 [00:39<00:28,  3.77it/s]

Loss: 0.4784


 97%|█████████▋| 251/259 [01:01<00:01,  5.32it/s]

Loss: 0.4753


100%|██████████| 259/259 [01:03<00:00,  4.10it/s]


Finished epoch 90 | Loss: 0.4758


 35%|███▌      | 91/259 [00:25<00:41,  4.02it/s]

Loss: 0.4828


 74%|███████▎  | 191/259 [00:48<00:18,  3.73it/s]

Loss: 0.4819


100%|██████████| 259/259 [01:03<00:00,  4.09it/s]


Finished epoch 91 | Loss: 0.4829


 13%|█▎        | 33/259 [00:11<00:59,  3.83it/s]

Loss: 0.4680


 51%|█████▏    | 133/259 [00:34<00:25,  4.87it/s]

Loss: 0.4684


 90%|████████▉ | 233/259 [00:59<00:05,  5.00it/s]

Loss: 0.4665


100%|██████████| 259/259 [01:04<00:00,  4.03it/s]


Finished epoch 92 | Loss: 0.4705


 29%|██▊       | 74/259 [00:21<00:37,  4.91it/s]

Loss: 0.4448


 67%|██████▋   | 174/259 [00:45<00:16,  5.02it/s]

Loss: 0.4673


100%|██████████| 259/259 [01:04<00:00,  4.02it/s]


Finished epoch 93 | Loss: 0.4761


  6%|▌         | 15/259 [00:08<01:01,  3.94it/s]

Loss: 0.5500


 44%|████▍     | 115/259 [00:31<00:33,  4.36it/s]

Loss: 0.4611


 83%|████████▎ | 214/259 [00:53<00:09,  4.58it/s]

Loss: 0.4604


100%|██████████| 259/259 [01:03<00:00,  4.08it/s]


Finished epoch 94 | Loss: 0.4643


 21%|██        | 55/259 [00:16<00:41,  4.93it/s]

Loss: 0.4454


 60%|█████▉    | 155/259 [00:39<00:20,  4.96it/s]

Loss: 0.4640


 99%|█████████▉| 256/259 [01:02<00:00,  5.41it/s]

Loss: 0.4691


100%|██████████| 259/259 [01:03<00:00,  4.07it/s]


Finished epoch 95 | Loss: 0.4686


 37%|███▋      | 97/259 [00:26<00:32,  4.94it/s]

Loss: 0.4780


 76%|███████▌  | 196/259 [00:49<00:12,  4.97it/s]

Loss: 0.4754


100%|██████████| 259/259 [01:02<00:00,  4.14it/s]


Finished epoch 96 | Loss: 0.4759


 14%|█▍        | 37/259 [00:12<00:44,  4.99it/s]

Loss: 0.4731


 53%|█████▎    | 137/259 [00:35<00:27,  4.52it/s]

Loss: 0.4666


 92%|█████████▏| 238/259 [00:58<00:04,  4.79it/s]

Loss: 0.4650


100%|██████████| 259/259 [01:02<00:00,  4.15it/s]


Finished epoch 97 | Loss: 0.4642


 31%|███       | 79/259 [00:22<00:41,  4.31it/s]

Loss: 0.4476


 69%|██████▉   | 179/259 [00:45<00:16,  4.75it/s]

Loss: 0.4600


100%|██████████| 259/259 [01:02<00:00,  4.13it/s]


Finished epoch 98 | Loss: 0.4584


  7%|▋         | 19/259 [00:08<00:52,  4.61it/s]

Loss: 0.4694


 46%|████▌     | 119/259 [00:31<00:28,  5.00it/s]

Loss: 0.4516


 85%|████████▍ | 220/259 [00:54<00:09,  3.96it/s]

Loss: 0.4532


100%|██████████| 259/259 [01:02<00:00,  4.13it/s]


Finished epoch 99 | Loss: 0.4584


 24%|██▎       | 61/259 [00:17<00:40,  4.88it/s]

Loss: 0.4636


 62%|██████▏   | 160/259 [00:40<00:19,  4.98it/s]

Loss: 0.4663


100%|██████████| 259/259 [01:02<00:00,  4.16it/s]


Finished epoch 100 | Loss: 0.4638


  0%|          | 1/259 [00:05<21:47,  5.07s/it]

Loss: 0.4218


 39%|███▉      | 101/259 [00:27<00:37,  4.23it/s]

Loss: 0.4504


 78%|███████▊  | 201/259 [00:49<00:13,  4.30it/s]

Loss: 0.4533


100%|██████████| 259/259 [01:02<00:00,  4.15it/s]


Finished epoch 101 | Loss: 0.4570


 16%|█▌        | 42/259 [00:13<01:00,  3.62it/s]

Loss: 0.4233


 55%|█████▌    | 143/259 [00:36<00:24,  4.67it/s]

Loss: 0.4549


 94%|█████████▍| 243/259 [00:59<00:03,  5.13it/s]

Loss: 0.4547


100%|██████████| 259/259 [01:02<00:00,  4.17it/s]


Finished epoch 102 | Loss: 0.4536


 32%|███▏      | 84/259 [00:23<00:37,  4.70it/s]

Loss: 0.4260


 71%|███████   | 184/259 [00:46<00:14,  5.03it/s]

Loss: 0.4404


100%|██████████| 259/259 [01:02<00:00,  4.14it/s]


Finished epoch 103 | Loss: 0.4488


  9%|▉         | 24/259 [00:09<00:48,  4.89it/s]

Loss: 0.4678


 48%|████▊     | 125/259 [00:32<00:32,  4.17it/s]

Loss: 0.4528


 86%|████████▋ | 224/259 [00:55<00:07,  4.80it/s]

Loss: 0.4571


100%|██████████| 259/259 [01:02<00:00,  4.14it/s]


Finished epoch 104 | Loss: 0.4591


 25%|██▌       | 66/259 [00:18<00:41,  4.61it/s]

Loss: 0.4279


 64%|██████▍   | 166/259 [00:42<00:18,  5.00it/s]

Loss: 0.4346


100%|██████████| 259/259 [01:02<00:00,  4.15it/s]


Finished epoch 105 | Loss: 0.4422


  2%|▏         | 6/259 [00:05<01:57,  2.15it/s]

Loss: 0.4088


 41%|████      | 106/259 [00:28<00:34,  4.41it/s]

Loss: 0.4567


 80%|███████▉  | 207/259 [00:51<00:12,  4.30it/s]

Loss: 0.4517


100%|██████████| 259/259 [01:02<00:00,  4.17it/s]


Finished epoch 106 | Loss: 0.4520


 19%|█▊        | 48/259 [00:15<00:42,  5.00it/s]

Loss: 0.4436


 57%|█████▋    | 148/259 [00:38<00:31,  3.48it/s]

Loss: 0.4406


 96%|█████████▌| 248/259 [01:00<00:02,  5.26it/s]

Loss: 0.4438


100%|██████████| 259/259 [01:02<00:00,  4.14it/s]


Finished epoch 107 | Loss: 0.4453


 34%|███▍      | 89/259 [00:24<00:42,  4.04it/s]

Loss: 0.4220


 73%|███████▎  | 189/259 [00:46<00:14,  4.81it/s]

Loss: 0.4370


100%|██████████| 259/259 [01:02<00:00,  4.17it/s]


Finished epoch 108 | Loss: 0.4389


 11%|█         | 29/259 [00:10<00:59,  3.90it/s]

Loss: 0.4310


 50%|████▉     | 129/259 [00:33<00:27,  4.73it/s]

Loss: 0.4318


 88%|████████▊ | 229/259 [00:56<00:06,  5.00it/s]

Loss: 0.4313


100%|██████████| 259/259 [01:02<00:00,  4.13it/s]


Finished epoch 109 | Loss: 0.4350


 27%|██▋       | 71/259 [00:20<00:39,  4.71it/s]

Loss: 0.4281


 66%|██████▌   | 171/259 [00:43<00:17,  4.95it/s]

Loss: 0.4496


100%|██████████| 259/259 [01:02<00:00,  4.16it/s]


Finished epoch 110 | Loss: 0.4483


  5%|▍         | 12/259 [00:07<00:56,  4.37it/s]

Loss: 0.4987


 43%|████▎     | 112/259 [00:30<00:32,  4.48it/s]

Loss: 0.4378


 82%|████████▏ | 212/259 [00:52<00:09,  4.77it/s]

Loss: 0.4411


100%|██████████| 259/259 [01:02<00:00,  4.14it/s]


Finished epoch 111 | Loss: 0.4407


 20%|██        | 52/259 [00:15<00:41,  5.02it/s]

Loss: 0.4428


 59%|█████▊    | 152/259 [00:38<00:25,  4.19it/s]

Loss: 0.4439


 98%|█████████▊| 253/259 [01:01<00:01,  5.29it/s]

Loss: 0.4446


100%|██████████| 259/259 [01:02<00:00,  4.14it/s]


Finished epoch 112 | Loss: 0.4472


 36%|███▋      | 94/259 [00:25<00:37,  4.36it/s]

Loss: 0.4386


 75%|███████▍  | 194/259 [00:48<00:13,  4.95it/s]

Loss: 0.4397


100%|██████████| 259/259 [01:02<00:00,  4.15it/s]


Finished epoch 113 | Loss: 0.4450


 14%|█▎        | 35/259 [00:12<00:46,  4.84it/s]

Loss: 0.4443


 52%|█████▏    | 135/259 [00:35<00:24,  5.02it/s]

Loss: 0.4474


 90%|█████████ | 234/259 [00:58<00:05,  4.48it/s]

Loss: 0.4390


100%|██████████| 259/259 [01:03<00:00,  4.10it/s]


Finished epoch 114 | Loss: 0.4425


 29%|██▉       | 76/259 [00:21<00:53,  3.45it/s]

Loss: 0.4371


 68%|██████▊   | 176/259 [00:44<00:18,  4.42it/s]

Loss: 0.4357


100%|██████████| 259/259 [01:02<00:00,  4.13it/s]


Finished epoch 115 | Loss: 0.4400


  7%|▋         | 17/259 [00:08<00:56,  4.27it/s]

Loss: 0.4948


 45%|████▍     | 116/259 [00:30<00:34,  4.18it/s]

Loss: 0.4415


 83%|████████▎ | 216/259 [00:53<00:09,  4.42it/s]

Loss: 0.4350


100%|██████████| 259/259 [01:02<00:00,  4.15it/s]


Finished epoch 116 | Loss: 0.4401


 22%|██▏       | 58/259 [00:16<00:40,  5.00it/s]

Loss: 0.4358


 61%|██████    | 157/259 [00:40<00:22,  4.58it/s]

Loss: 0.4347


100%|█████████▉| 258/259 [01:02<00:00,  5.39it/s]

Loss: 0.4312


100%|██████████| 259/259 [01:02<00:00,  4.14it/s]


Finished epoch 117 | Loss: 0.4312


 38%|███▊      | 98/259 [00:26<00:32,  4.98it/s]

Loss: 0.4416


 77%|███████▋  | 199/259 [00:49<00:13,  4.30it/s]

Loss: 0.4379


100%|██████████| 259/259 [01:02<00:00,  4.12it/s]


Finished epoch 118 | Loss: 0.4337


 15%|█▌        | 40/259 [00:12<00:46,  4.73it/s]

Loss: 0.4151


 54%|█████▎    | 139/259 [00:35<00:26,  4.58it/s]

Loss: 0.4247


 93%|█████████▎| 240/259 [00:58<00:03,  5.14it/s]

Loss: 0.4358


100%|██████████| 259/259 [01:02<00:00,  4.16it/s]


Finished epoch 119 | Loss: 0.4365


 31%|███       | 80/259 [00:22<00:42,  4.26it/s]

Loss: 0.4330


 69%|██████▉   | 180/259 [00:45<00:18,  4.24it/s]

Loss: 0.4333


100%|██████████| 259/259 [01:02<00:00,  4.16it/s]


Finished epoch 120 | Loss: 0.4334


  8%|▊         | 22/259 [00:09<00:48,  4.88it/s]

Loss: 0.4415


 47%|████▋     | 122/259 [00:32<00:28,  4.80it/s]

Loss: 0.4243


 85%|████████▌ | 221/259 [00:54<00:07,  5.07it/s]

Loss: 0.4381


100%|██████████| 259/259 [01:03<00:00,  4.11it/s]


Finished epoch 121 | Loss: 0.4400


 24%|██▍       | 63/259 [00:18<00:38,  5.09it/s]

Loss: 0.4244


 63%|██████▎   | 162/259 [00:41<00:23,  4.07it/s]

Loss: 0.4346


100%|██████████| 259/259 [01:03<00:00,  4.11it/s]


Finished epoch 122 | Loss: 0.4345


  1%|          | 3/259 [00:04<05:06,  1.20s/it]

Loss: 0.6214


 40%|████      | 104/259 [00:27<00:34,  4.51it/s]

Loss: 0.4350


 78%|███████▊  | 203/259 [00:50<00:11,  4.87it/s]

Loss: 0.4219


100%|██████████| 259/259 [01:02<00:00,  4.16it/s]


Finished epoch 123 | Loss: 0.4316


 17%|█▋        | 45/259 [00:14<00:49,  4.35it/s]

Loss: 0.4402


 56%|█████▌    | 144/259 [00:36<00:23,  4.81it/s]

Loss: 0.4350


 95%|█████████▍| 245/259 [00:59<00:03,  4.23it/s]

Loss: 0.4349


100%|██████████| 259/259 [01:02<00:00,  4.16it/s]


Finished epoch 124 | Loss: 0.4361


 33%|███▎      | 86/259 [00:23<00:43,  4.01it/s]

Loss: 0.4354


 71%|███████▏  | 185/259 [00:46<00:16,  4.57it/s]

Loss: 0.4327


100%|██████████| 259/259 [01:02<00:00,  4.17it/s]


Finished epoch 125 | Loss: 0.4303


 10%|█         | 27/259 [00:10<00:46,  4.94it/s]

Loss: 0.4490


 49%|████▊     | 126/259 [00:34<00:39,  3.33it/s]

Loss: 0.4259


 88%|████████▊ | 227/259 [00:57<00:06,  4.88it/s]

Loss: 0.4220


100%|██████████| 259/259 [01:04<00:00,  4.02it/s]


Finished epoch 126 | Loss: 0.4319


 26%|██▋       | 68/259 [00:19<00:42,  4.52it/s]

Loss: 0.4410


 65%|██████▍   | 168/259 [00:42<00:18,  4.94it/s]

Loss: 0.4441


100%|██████████| 259/259 [01:02<00:00,  4.13it/s]


Finished epoch 127 | Loss: 0.4380


  3%|▎         | 8/259 [00:06<01:19,  3.14it/s]

Loss: 0.4212


 42%|████▏     | 109/259 [00:28<00:35,  4.24it/s]

Loss: 0.4222


 80%|████████  | 208/259 [00:51<00:10,  5.02it/s]

Loss: 0.4230


100%|██████████| 259/259 [01:02<00:00,  4.17it/s]


Finished epoch 128 | Loss: 0.4219


 19%|█▉        | 49/259 [00:14<00:41,  5.05it/s]

Loss: 0.4301


 58%|█████▊    | 150/259 [00:38<00:27,  3.92it/s]

Loss: 0.4108


 97%|█████████▋| 250/259 [01:00<00:01,  5.27it/s]

Loss: 0.4120


100%|██████████| 259/259 [01:01<00:00,  4.18it/s]


Finished epoch 129 | Loss: 0.4142


 35%|███▌      | 91/259 [00:24<00:37,  4.42it/s]

Loss: 0.4147


 73%|███████▎  | 190/259 [00:47<00:16,  4.06it/s]

Loss: 0.4193


100%|██████████| 259/259 [01:02<00:00,  4.15it/s]


Finished epoch 130 | Loss: 0.4257


 12%|█▏        | 32/259 [00:11<00:45,  4.98it/s]

Loss: 0.4488


 51%|█████     | 131/259 [00:34<00:30,  4.22it/s]

Loss: 0.4373


 89%|████████▉ | 231/259 [00:57<00:06,  4.55it/s]

Loss: 0.4212


100%|██████████| 259/259 [01:02<00:00,  4.13it/s]


Finished epoch 131 | Loss: 0.4253


 28%|██▊       | 72/259 [00:20<00:38,  4.83it/s]

Loss: 0.4072


 66%|██████▋   | 172/259 [00:44<00:25,  3.41it/s]

Loss: 0.4143


100%|██████████| 259/259 [01:03<00:00,  4.10it/s]


Finished epoch 132 | Loss: 0.4149


  5%|▌         | 14/259 [00:08<01:13,  3.34it/s]

Loss: 0.4793


 44%|████▎     | 113/259 [00:30<00:33,  4.42it/s]

Loss: 0.4200


 82%|████████▏ | 213/259 [00:52<00:10,  4.43it/s]

Loss: 0.4196


100%|██████████| 259/259 [01:02<00:00,  4.14it/s]


Finished epoch 133 | Loss: 0.4250


 21%|██        | 55/259 [00:16<00:46,  4.43it/s]

Loss: 0.4039


 59%|█████▉    | 154/259 [00:39<00:22,  4.77it/s]

Loss: 0.4130


 98%|█████████▊| 255/259 [01:01<00:00,  5.35it/s]

Loss: 0.4149


100%|██████████| 259/259 [01:02<00:00,  4.15it/s]


Finished epoch 134 | Loss: 0.4160


 37%|███▋      | 95/259 [00:25<00:37,  4.32it/s]

Loss: 0.4057


 75%|███████▌  | 195/259 [00:48<00:14,  4.33it/s]

Loss: 0.4176


100%|██████████| 259/259 [01:02<00:00,  4.16it/s]


Finished epoch 135 | Loss: 0.4206
Done Training...


In [19]:
COLORS = [
    [0, 0, 0],
    [200, 0, 0],
    [0, 200, 0],
    [50, 128, 0],
    [0, 0, 128],
    [128, 0, 128],
    [0, 128, 128],
    [128, 128, 128],
    [50, 0, 0],
    [192, 0, 0],
    [64, 128, 0],
    [192, 128, 0],
    [64, 0, 128],
    [192, 0, 128],
    [50, 128, 200],
    [192, 128, 200],
    [0, 64, 0],
    [128, 64, 0],
    [0, 192, 0],
    [128, 192, 0],
    [0, 64, 128]
]


def visualize_bbox(img, bbox, class_name, score=None, color=(255, 0, 0), thickness=2):
    """
    Draws a single bounding box on image
    """
    x_min, y_min, x_max, y_max = bbox
    x_min, y_min, x_max, y_max = int(x_min), int(y_min), int(x_max), int(y_max )

    cv2.rectangle(img, (x_min, y_min), (x_max, y_max), color=color, thickness=thickness)
    box_text = class_name + ' {:.2f}'.format(score) if score is not None else class_name
    ((text_width, text_height), _) = cv2.getTextSize(box_text, cv2.FONT_HERSHEY_SIMPLEX, 0.45, 1)
    cv2.rectangle(img, (x_min, y_min - int(1.3 * text_height)), (x_min + text_width, y_min), color, -1)

    cv2.putText(
        img,
        text=box_text,
        org=(x_min, y_min - int(0.3 * text_height)),
        fontFace=cv2.FONT_HERSHEY_SIMPLEX,
        fontScale=0.45,
        color=(255, 255, 255),
        lineType=cv2.LINE_AA,
    )

    return img


def visualize(image, bboxes, category_ids, category_id_to_name, scores=None):
    img = image.copy()
    for idx, (bbox, category_id) in enumerate(zip(bboxes, category_ids)):
      class_name = category_id_to_name[category_id]
      img = visualize_bbox(img, bbox, class_name, scores[idx] if scores is not None else None)

    return img


def draw_grid(img, grid_shape, color=(0, 0, 0), thickness=2):
    """
    Draws a grid on image
    """
    grid_im = np.copy(img)
    h, w, _ = grid_im.shape
    rows, cols = grid_shape
    dy, dx = h / rows, w / cols

    # draw vertical lines
    for x in np.linspace(start=dx, stop=w-dx, num=cols-1):
      x = int(round(x))
      cv2.line(grid_im, (x, 0), (x, h), color=color, thickness=thickness)

    # draw horizontal lines
    for y in np.linspace(start=dy, stop=h-dy, num=rows-1):
      y = int(round(y))
      cv2.line(grid_im, (0, y), (w, y), color=color, thickness=thickness)

    return grid_im


def draw_cls_grid(img, cls_idx, grid_shape):
  """
  Draws color coded grid for the entire image
  coded based on the class label
  """
  rect_im = np.copy(img)
  h, w, _ = rect_im.shape
  rows, cols = grid_shape
  dy, dx = h / rows, w / cols
  for i in range(rows):
    for j in range(cols):
          cv2.rectangle(rect_im, (int(i*dx), int(j*dy)), (int((i+1)*dx), int((j+1)*dy)),
                        thickness=-1,
                        color=COLORS[cls_idx[j, i].item()])

  return rect_im


def draw_cls_text(img, cls_idx, cls_idx_label, grid_shape):
  """
  Writes class text name in grid center locations
  """
  rect_im = np.copy(img)
  h, w, _ = rect_im.shape
  rows, cols = grid_shape
  dy, dx = h / rows, w / cols
  for i in range(rows):
    for j in range(cols):
      cls_label = cls_idx_label[cls_idx[j, i].item()]
      cv2.putText(rect_im,
                  cls_label[:6],
                  (int((i+0.1)*dx), int((j+0.5)*dy)),
                  fontFace=cv2.FONT_HERSHEY_SIMPLEX,
                  fontScale=0.45,
                  color=(255, 255, 255),
                  lineType=cv2.LINE_AA)
  return rect_im

In [20]:
def get_iou_map(det, gt):
  det_x1, det_y1, det_x2, det_y2 = det
  gt_x1, gt_y1, gt_x2, gt_y2 = gt

  x_left = max(det_x1, gt_x1)
  y_top = max(det_y1, gt_y1)
  x_right = min(det_x2, gt_x2)
  y_bottom = min(det_y2, gt_y2)

  if x_right < x_left or y_bottom < y_top:
    return 0.0

  area_intersection = (x_right - x_left) * (y_bottom - y_top)
  det_area = (det_x2 - det_x1) * (det_y2 - det_y1)
  gt_area = (gt_x2 - gt_x1) * (gt_y2 - gt_y1)

  area_union = float(det_area + gt_area - area_intersection + 1E-6)
  iou = area_intersection / area_union

  return iou

In [33]:
def compute_map(det_boxes, gt_boxes, iou_threshold=0.5, method="area", difficult=None):
  gt_labels = {cls_key for im_gt in gt_boxes for cls_key in im_gt.keys()}
  gt_labels = sorted(gt_labels)

  all_aps = {}

  # Average precisions for ALL classes
  aps = []
  for idx, label in enumerate(gt_labels):
    # Get detection predictions of this class
    cls_dets = [
        [im_idx, im_dets_label] for im_idx, im_dets in enumerate(det_boxes)
        if label in im_dets for im_dets_label in im_dets[label]
    ]

    # Sort them by confidence score
    cls_dets = sorted(cls_dets, key=lambda k: -k[1][-1])

    # For tracking which gt boxes of this class have already been matched
    gt_matched = [[False for _ in im_gts[label]] for im_gts in gt_boxes]
    # Number of gt boxes for this class for recall calculation
    num_gts = sum([len(im_gts[label]) for im_gts in gt_boxes])
    num_difficults = sum([sum(difficults_label[label]) for difficults_label in difficult])

    tp = [0] * len(cls_dets)
    fp = [0] * len(cls_dets)

    # For each prediction
    for det_idx, (im_idx, det_pred) in enumerate(cls_dets):
      # Get gt boxes for this image and this label
      im_gts = gt_boxes[im_idx][label]
      im_gt_difficults = difficult[im_idx][label]

      max_iou_found = -1
      max_iou_gt_idx = -1

      # Get best matching gt box
      for gt_box_idx, gt_box in enumerate(im_gts):
        gt_box_iou = get_iou_map(det_pred[:-1], gt_box)
        if gt_box_iou > max_iou_found:
          max_iou_found = gt_box_iou
          max_iou_gt_idx = gt_box_idx
      # TP only if iou >= threshold and this gt has not yet been matched
      if max_iou_found >= iou_threshold:
        if not im_gt_difficults[max_iou_gt_idx]:
          if not gt_matched[im_idx][max_iou_gt_idx]:
            # If tp then we set this gt box as matched
            gt_matched[im_idx][max_iou_gt_idx] = True
            tp[det_idx] = 1
          else:
            fp[det_idx] = 1
      else:
        fp[det_idx] = 1

    # Cumulative tp and fp
    tp = np.cumsum(tp)
    fp = np.cumsum(fp)

    eps = np.finfo(np.float32).eps
    # recalls = tp / np.maximum(num_gts, eps)
    recalls = tp / np.maximum(num_gts - num_difficults, eps)
    precisions = tp / np.maximum((tp + fp), eps)

    if method == "area":
      recalls = np.concatenate(([0.0], recalls, [1.0]))
      precisions = np.concatenate(([0.0], precisions, [0.0]))

      # Replace precision values with recall r with maximum precision value
      # of any recall value >= r
      # This computes the precision envelope
      for i in range(precisions.size - 1, 0, -1):
          precisions[i - 1] = np.maximum(precisions[i - 1], precisions[i])
      # For computing area, get points where recall changes value
      i = np.where(recalls[1:] != recalls[:-1])[0]
      # Add the rectangular areas to get ap
      ap = np.sum((recalls[i + 1] - recalls[i]) * precisions[i + 1])
    elif method == "interp":
      ap = 0.0
      for interp_pt in np.arange(0, 1 + 1E-3, 0.1):
        # Get precision values for recall values >= interp_pt
        prec_interp_pt = precisions[recalls >= interp_pt]

        # Get max of those precision values
        prec_interp_pt= prec_interp_pt.max() if prec_interp_pt.size>0.0 else 0.0
        ap += prec_interp_pt
      ap = ap / 11.0
    else:
      raise ValueError("Method can only be area or interp")
    if num_gts > 0:
      aps.append(ap)
      all_aps[label] = ap
    else:
      all_aps[label] = np.nan

  # compute mAP at provided iou threshold
  mean_ap = sum(aps) / len(aps)
  return mean_ap, all_aps

In [30]:
def load_model_and_dataset():
  dataset_config = dataset_params
  model_config = model_params
  train_config = train_params

  voc = VOCDataset("test",
                   im_sets=dataset_config["test_im_sets"],
                   im_size=dataset_config["im_size"],
                   S=model_config["S"],
                   B=model_config["B"],
                   C=dataset_config["num_classes"])

  test_dataset = DataLoader(voc,
                            batch_size=1,
                            shuffle=False,
                            num_workers=os.cpu_count(),
                            pin_memory=True)

  yolo_model = YOLOV1(im_size=dataset_config["im_size"],
                      num_classes=dataset_config["num_classes"],
                      model_config=model_config)
  yolo_model.eval()
  yolo_model.to(device)

  checkpoint = torch.load(os.path.join(train_config["task_name"], train_config["ckpt_name"]),
                          map_location=device)

  yolo_model.load_state_dict(checkpoint["model_state_dict"])

  return yolo_model, voc, test_dataset

In [31]:
def convert_yolo_pred_x1y1x2y2(yolo_pred, S, B, C, use_sigmoid=False):
  """
  Method converts yolo predictions to x_1, y_1, x_2, y_2 format
  """
  out = yolo_pred.reshape((S, S, 5 * B + C))
  if use_sigmoid:
    out[..., :5 * B] = torch.nn.functional.sigmoid(out[..., :5 * B])

  out = torch.clamp(out, min=0., max=1.)
  class_score, class_idx = torch.max(out[..., 5 * B:], dim=-1)

  # Create a grid using these shifts
  # Will use these for converting x_center_offset and y_center_offset
  # values to x_1, y_1, x_2, y_2 that are normalized between [0,1]
  # S cells = 1 => each cell adds 1/S pixels of shift
  shifts_x = torch.arange(0, S, dtype=torch.int32, device=out.device) * 1 / float(S)
  shifts_y = torch.arange(0, S, dtype=torch.int32, device=out.device) * 1 / float(S)
  shifts_y, shifts_x = torch.meshgrid(shifts_y, shifts_x, indexing="ij")

  boxes = []
  confidences = []
  labels = []
  for box_idx in range(B):
    # xc_offset, yc_offset, width, height -> x_1, y_1, x_2, y_2
    boxes_x1 = ((out[..., box_idx * 5] * 1 / float(S) + shifts_x) - 0.5 * torch.square(out[..., 2 + box_idx * 5])).reshape(-1, 1)
    boxes_y1 = ((out[..., 1 + box_idx * 5] * 1 / float(S) + shifts_y) - 0.5 * torch.square(out[..., 3 + box_idx * 5])).reshape(-1, 1)
    boxes_x2 = ((out[..., box_idx * 5] * 1 / float(S) + shifts_x) + 0.5 * torch.square(out[..., 2 + box_idx * 5])).reshape(-1, 1)
    boxes_y2 = ((out[..., box_idx * 5] * 1 / float(S) + shifts_y) + 0.5 * torch.square(out[..., 3 + box_idx * 5])).reshape(-1, 1)

    boxes.append(torch.cat([boxes_x1, boxes_y1, boxes_x2, boxes_y2], dim=-1))
    confidences.append((out[..., 4 + box_idx * 5] * class_score).reshape(-1))
    labels.append(class_idx.reshape(-1))

  boxes = torch.cat(boxes, dim=0)
  scores = torch.cat(confidences, dim=0)
  labels = torch.cat(labels, dim=0)

  return boxes, scores, labels

In [24]:
def infer():
  if not os.path.exists("samples"):
      os.mkdir("samples")

  yolo_model, voc, test_dataset = load_model_and_dataset()
  conf_threshold = train_params["infer_conf_threshold"]
  nms_threshold = train_params["nms_threshold"]

  num_samples = 5

  for i in tqdm(range(num_samples)):
    dataset_idx = random.randint(0, len(voc) - 1)
    im_tensor, targets, fname = voc[dataset_idx]

    out = yolo_model(im_tensor.unsqueeze(0).to(device))

    # Have to convert the output for YOLO format to x_1, y_1, x_2, y_2
    boxes, scores, labels = convert_yolo_pred_x1y1x2y2(out,
                                                       S=yolo_model.S,
                                                       B=yolo_model.B,
                                                       C=yolo_model.C,
                                                       use_sigmoid=model_params["use_sigmoid"])

    # Confidence Score Thresholding
    keep = torch.where(scores > conf_threshold)[0]
    boxes = boxes[keep]
    scores = scores[keep]
    labels = labels[keep]

    # NMS
    keep_mask = torch.zeros_like(scores, dtype=torch.bool)
    for class_id in torch.unique(labels):
      curr_indices = torch.where(labels == class_id)[0]
      curr_keep_indices = torch.ops.torchvision.nms(boxes[curr_indices],
                                                    scores[curr_indices],
                                                    nms_threshold)
      keep_mask[curr_indices[curr_keep_indices]] = True

    keep = torch.where(keep_mask)[0]
    boxes = boxes[keep]
    scores = scores[keep]
    labels = labels[keep]

    # Visualization

    if not os.path.exists("samples/preds"):
      os.mkdir("samples/preds")
    if not os.path.exists("samples/grid_cls"):
      os.mkdir("samples/grid_cls")

    im = cv2.imread(fname)
    h, w = im.shape[:2]
    # Scale prediction boxes x_1, y_1, x_2, y_2 from [0, 1] to 0-w and 0-h
    boxes[..., 0::2] = (w * boxes[..., 0::2])
    boxes[..., 1::2] = (h * boxes[..., 1::2])

    out_img = visualize(image=im,
                        bboxes=boxes.detach().cpu().numpy(),
                        category_ids=labels.detach().cpu().numpy(),
                        category_id_to_name=voc.idx2label,
                        scores=scores.detach().cpu().numpy())

    cv2.imwrite(f"samples/preds/{i}_pred.jpeg", out_img)

    # Below lines of code are only for drawing class prob map
    im = cv2.resize(im, (yolo_model.im_size, yolo_model.im_size))

    # Draw a SxS grid on image
    grid_im = draw_grid(im, (yolo_model.S, yolo_model.S))

    out = out.reshape((yolo_model.S, yolo_model.S, 5 * yolo_model.B + yolo_model.C))
    cls_val, cls_idx = torch.max(out[..., 5 * yolo_model.B:], dim=-1)

    # Draw colored squares for probability mappings on image
    rect_im = draw_cls_grid(im, cls_idx, (yolo_model.S, yolo_model.S))
    # Draw grid again on top of this image
    rect_im = draw_grid(rect_im, (yolo_model.S, yolo_model.S))

    # Overlay image with grid and cls mappings with grid on top of each other
    res = cv2.addWeighted(rect_im, 0.5, grid_im, 0.5, 1.0)

    # Write class labels on grid on this image
    res = draw_cls_text(res, cls_idx, voc.idx2label, (yolo_model.S, yolo_model.S))
    cv2.imwrite(f"samples/grid_cls/{i}_grid_map.jpeg", res)

  print("Done Detecting..")

infer()

  albu.Affine(
  albu.ColorJitter(


{0: 'aeroplane', 1: 'bicycle', 2: 'bird', 3: 'boat', 4: 'bottle', 5: 'bus', 6: 'car', 7: 'cat', 8: 'chair', 9: 'cow', 10: 'diningtable', 11: 'dog', 12: 'horse', 13: 'motorbike', 14: 'person', 15: 'pottedplant', 16: 'sheep', 17: 'sofa', 18: 'train', 19: 'tvmonitor'}
Total 4952 images found


100%|██████████| 5/5 [00:00<00:00, 10.38it/s]

Done Detecting..





In [34]:
def evaluate_map():
  yolo_model, voc, test_dataset = load_model_and_dataset()

  conf_threshold = train_params["eval_conf_threshold"]
  nms_threshold = train_params["nms_threshold"]

  gts = []
  preds = []
  difficults = []
  for im_tensor, target, fname in tqdm(test_dataset):
    im_tensor = im_tensor.float().to(device)
    target_bboxes = target["bboxes"].float().to(device)[0]
    target_labels = target["labels"].long().to(device)[0]
    difficult = target["difficult"].long().to(device)[0]

    out = yolo_model(im_tensor)

    boxes, scores, labels = convert_yolo_pred_x1y1x2y2(out,
                                                        S=yolo_model.S,
                                                        B=yolo_model.B,
                                                        C=yolo_model.C,
                                                        use_sigmoid=model_params["use_sigmoid"])

    # Confidence Score Thresholding
    keep = torch.where(scores > conf_threshold)[0]
    boxes = boxes[keep]
    scores = scores[keep]
    labels = labels[keep]

    # NMS
    keep_mask = torch.zeros_like(scores, dtype=torch.bool)
    for class_id in torch.unique(labels):
      curr_indices = torch.where(labels == class_id)[0]
      curr_keep_indices = torch.ops.torchvision.nms(boxes[curr_indices],
                                                    scores[curr_indices],
                                                    nms_threshold)

      keep_mask[curr_indices[curr_keep_indices]] = True

    keep = torch.where(keep_mask)[0]
    boxes = boxes[keep]
    scores = scores[keep]
    labels = labels[keep]

    pred_boxes = {}
    gt_boxes = {}
    difficult_boxes = {}

    for label_name in voc.label2idx:
      pred_boxes[label_name] = []
      gt_boxes[label_name] = []
      difficult_boxes[label_name] = []

    for idx, box in enumerate(boxes):
      x1, y1, x2, y2 = box.detach().cpu().numpy()
      label = labels[idx].detach().cpu().item()
      score = scores[idx].detach().cpu().item()
      label_name = voc.idx2label[label]
      pred_boxes[label_name].append([x1, y1, x2, y2, score])

    for idx, box in enumerate(target_bboxes):
      x1, y1, x2, y2 = box.detach().cpu().numpy()
      label = target_labels[idx].detach().cpu().item()
      label_name = voc.idx2label[label]
      gt_boxes[label_name].append([x1, y1, x2, y2])
      difficult_boxes[label_name].append(difficult[idx].detach().cpu().item())

      gts.append(gt_boxes)
      preds.append(pred_boxes)
      difficults.append(difficult_boxes)

  mean_ap, all_aps = compute_map(preds, gts, method="area", difficult=difficults)

  print("Class Wise Average Precisions")
  for idx in range(len(voc.idx2label)):
    print(f"AP for class {voc.idx2label[idx]} = {all_aps[voc.idx2label[idx]]:.4f}")

  print(f"Mean Average Precision: {mean_ap:.4f}")

evaluate_map()

  albu.Affine(
  albu.ColorJitter(


{0: 'aeroplane', 1: 'bicycle', 2: 'bird', 3: 'boat', 4: 'bottle', 5: 'bus', 6: 'car', 7: 'cat', 8: 'chair', 9: 'cow', 10: 'diningtable', 11: 'dog', 12: 'horse', 13: 'motorbike', 14: 'person', 15: 'pottedplant', 16: 'sheep', 17: 'sofa', 18: 'train', 19: 'tvmonitor'}
Total 4952 images found


100%|██████████| 4952/4952 [01:14<00:00, 66.66it/s]


Class Wise Average Precisions
AP for class aeroplane = 0.3184
AP for class bicycle = 0.4868
AP for class bird = 0.2774
AP for class boat = 0.2727
AP for class bottle = 0.1549
AP for class bus = 0.5753
AP for class car = 0.4371
AP for class cat = 0.7185
AP for class chair = 0.2702
AP for class cow = 0.4358
AP for class diningtable = 0.4733
AP for class dog = 0.6572
AP for class horse = 0.6561
AP for class motorbike = 0.4671
AP for class person = 0.4352
AP for class pottedplant = 0.1607
AP for class sheep = 0.4112
AP for class sofa = 0.4401
AP for class train = 0.7036
AP for class tvmonitor = 0.3888
Mean Average Precision: 0.4370


In [27]:
!zip -r samples.zip samples

  adding: samples/ (stored 0%)
  adding: samples/grid_cls/ (stored 0%)
  adding: samples/grid_cls/2_grid_map.jpeg (deflated 5%)
  adding: samples/grid_cls/3_grid_map.jpeg (deflated 1%)
  adding: samples/grid_cls/1_grid_map.jpeg (deflated 1%)
  adding: samples/grid_cls/4_grid_map.jpeg (deflated 3%)
  adding: samples/grid_cls/0_grid_map.jpeg (deflated 1%)
  adding: samples/preds/ (stored 0%)
  adding: samples/preds/0_pred.jpeg (deflated 0%)
  adding: samples/preds/2_pred.jpeg (deflated 0%)
  adding: samples/preds/1_pred.jpeg (deflated 0%)
  adding: samples/preds/3_pred.jpeg (deflated 0%)
  adding: samples/preds/4_pred.jpeg (deflated 1%)


In [28]:
from google.colab import files

files.download("samples.zip")

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>