# Faster R-CNN training

Training utilizes pytorch-lightning and the Python COCO API.

A more detailed description for regular PyTorch training implementation can be found at https://pytorch.org/tutorials/intermediate/torchvision_tutorial.html

The above-mentioned guide is used as a reference for this implementation.

# Setup

In [None]:
!pip install pytorch-lightning

In [None]:
import time
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import LambdaLR
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks import ModelCheckpoint
import os
import torch.utils.data
from torch.utils.data import random_split
import torchvision
from torchvision import transforms
from torchvision.datasets import CocoDetection
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from PIL import Image, ImageDraw
import matplotlib.pyplot as plt
from IPython.display import display
from pycocotools.coco import COCO
import gdown, zipfile 
from contextlib import redirect_stdout
import math

# Utility methods

In [None]:
import json
import tempfile

import numpy as np
import copy
import time
import torch
import torch._six

from pycocotools.cocoeval import COCOeval
from pycocotools.coco import COCO
import pycocotools.mask as mask_util

from collections import defaultdict
import torch.distributed as dist

#import utils
####### UTILS #######
def is_dist_avail_and_initialized():
    """Return True only if torch.distributed is available and already initialized."""
    return bool(dist.is_available() and dist.is_initialized())


def get_world_size():
    """Return the distributed world size, or 1 when no process group is active."""
    return dist.get_world_size() if is_dist_avail_and_initialized() else 1

def all_gather(data):
    """
    Run all_gather on arbitrary picklable data (not necessarily tensors)
    Args:
        data: any picklable object
    Returns:
        list[data]: list of data gathered from each rank
    """
    # Local import: `pickle` is not imported at file level, so without this the
    # multi-process path below raises NameError.
    import pickle

    world_size = get_world_size()
    if world_size == 1:
        return [data]

    # serialize the object into a byte tensor on the GPU
    payload = pickle.dumps(data)
    storage = torch.ByteStorage.from_buffer(payload)
    tensor = torch.ByteTensor(storage).to("cuda")

    # obtain the payload size of each rank; keep the local size as a plain int
    # so it can be compared and used as a shape below
    local_numel = tensor.numel()
    local_size = torch.tensor([local_numel], device="cuda")
    size_list = [torch.tensor([0], device="cuda") for _ in range(world_size)]
    dist.all_gather(size_list, local_size)
    size_list = [int(size.item()) for size in size_list]
    max_size = max(size_list)

    # receive buffers are all allocated at the largest payload size, so the
    # local payload is padded up to it before gathering
    tensor_list = [
        torch.empty((max_size,), dtype=torch.uint8, device="cuda")
        for _ in size_list
    ]
    if local_numel != max_size:
        padding = torch.empty(size=(max_size - local_numel,), dtype=torch.uint8, device="cuda")
        tensor = torch.cat((tensor, padding), dim=0)
    dist.all_gather(tensor_list, tensor)

    # strip the padding and deserialize each rank's payload
    # NOTE(review): pickle.loads executes arbitrary code from its input; this is
    # only acceptable because the bytes come from peer ranks of the same job.
    data_list = []
    for size, received in zip(size_list, tensor_list):
        raw = received.cpu().numpy().tobytes()[:size]
        data_list.append(pickle.loads(raw))

    return data_list

def reduce_dict(input_dict, average=True):
    """
    Reduce every value of ``input_dict`` across all processes.

    Args:
        input_dict (dict): values to reduce; they are stacked with
            ``torch.stack`` before the reduction
        average (bool): divide the reduced values by the world size when True,
            otherwise leave the all_reduce result as is
    Returns:
        dict: same keys as ``input_dict`` with the reduced values. When fewer
        than two processes are running, ``input_dict`` itself is returned.
    """
    world_size = get_world_size()
    if world_size < 2:
        return input_dict
    with torch.no_grad():
        # a sorted key order keeps the stacked layout identical on every process
        names = sorted(input_dict.keys())
        values = torch.stack([input_dict[name] for name in names], dim=0)
        dist.all_reduce(values)
        if average:
            values /= world_size
        reduced_dict = dict(zip(names, values))
    return reduced_dict
####### UTILS #######

####### COCO_UTILS ######
def convert_to_coco_api(ds):
    """
    Build an in-memory COCO ground-truth object from a detection dataset.

    Args:
        ds: indexable dataset with ``len()``; each item is ``(img, targets)``
            where ``img`` exposes ``.shape`` and ``targets`` holds the tensors
            ``image_id``, ``boxes``, ``labels``, ``area`` and ``iscrowd``
            (optionally ``masks`` and ``keypoints``).
    Returns:
        COCO: object whose ``dataset`` dict holds the images, annotations and
        categories collected from ``ds``, with its index created.
    """
    coco_ds = COCO()
    ann_id = 0
    dataset = {'images': [], 'categories': [], 'annotations': []}
    categories = set()
    for img_idx in range(len(ds)):
        # find better way to get target
        # targets = ds.get_annotations(img_idx)
        img, targets = ds[img_idx]
        image_id = targets["image_id"].item()
        dataset['images'].append({
            'id': image_id,
            'height': img.shape[-2],
            'width': img.shape[-1],
        })
        # clone first: the subtraction below is in-place and must not alter
        # the tensor owned by the dataset
        bboxes = targets["boxes"].clone()
        # convert (xmin, ymin, xmax, ymax) to (xmin, ymin, width, height)
        bboxes[:, 2:] -= bboxes[:, :2]
        bboxes = bboxes.tolist()
        labels = targets['labels'].tolist()
        areas = targets['area'].tolist()
        iscrowd = targets['iscrowd'].tolist()
        has_masks = 'masks' in targets
        has_keypoints = 'keypoints' in targets
        if has_masks:
            masks = targets['masks']
            # make masks Fortran contiguous for the RLE encoder
            masks = masks.permute(0, 2, 1).contiguous().permute(0, 2, 1)
        if has_keypoints:
            keypoints = targets['keypoints']
            keypoints = keypoints.reshape(keypoints.shape[0], -1).tolist()
        for i, bbox in enumerate(bboxes):
            ann = {
                'image_id': image_id,
                'bbox': bbox,
                'category_id': labels[i],
            }
            categories.add(labels[i])
            ann['area'] = areas[i]
            ann['iscrowd'] = iscrowd[i]
            ann['id'] = ann_id
            if has_masks:
                # the file imports pycocotools.mask as `mask_util`
                ann["segmentation"] = mask_util.encode(masks[i].numpy())
            if has_keypoints:
                ann['keypoints'] = keypoints[i]
                # every third value, starting at index 2, is counted when non-zero
                ann['num_keypoints'] = sum(k != 0 for k in keypoints[i][2::3])
            dataset['annotations'].append(ann)
            ann_id += 1
    dataset['categories'] = [{'id': i} for i in sorted(categories)]
    coco_ds.dataset = dataset
    coco_ds.createIndex()
    return coco_ds


def get_coco_api_from_dataset(dataset):
    """
    Return a COCO API object for ``dataset``.

    Up to ten levels of ``torch.utils.data.Subset`` wrapping are peeled off.
    If a ``torchvision.datasets.CocoDetection`` is found its ``coco`` attribute
    is returned; otherwise the dataset is converted with ``convert_to_coco_api``.
    """
    coco_cls = torchvision.datasets.CocoDetection
    depth = 0
    while depth < 10 and not isinstance(dataset, coco_cls):
        if isinstance(dataset, torch.utils.data.Subset):
            dataset = dataset.dataset
        depth += 1
    if isinstance(dataset, coco_cls):
        return dataset.coco
    return convert_to_coco_api(dataset)
####### COCO_UTILS ######

####### COCO_EVAL ######

class CocoEvaluator:
    """
    Collects per-batch predictions and evaluates them with one ``COCOeval``
    object per requested IoU type.

    Usage in this file: call ``update`` for each batch, then
    ``synchronize_between_processes``, ``accumulate`` and ``summarize``.
    """

    def __init__(self, coco_gt, iou_types):
        """
        Args:
            coco_gt: ground-truth COCO object; a deep copy is stored
            iou_types (list | tuple): any of "bbox", "segm", "keypoints"
        """
        assert isinstance(iou_types, (list, tuple))
        coco_gt = copy.deepcopy(coco_gt)
        self.coco_gt = coco_gt

        self.iou_types = iou_types
        self.coco_eval = {
            iou_type: COCOeval(coco_gt, iouType=iou_type) for iou_type in iou_types
        }

        self.img_ids = []
        self.eval_imgs = {k: [] for k in iou_types}

    def update(self, predictions):
        """
        Evaluate one batch of predictions and store the per-image results.

        Args:
            predictions (dict): maps image id to that image's prediction dict
        """
        img_ids = list(np.unique(list(predictions.keys())))
        self.img_ids.extend(img_ids)

        for iou_type in self.iou_types:
            results = self.prepare(predictions, iou_type)
            # an empty result list cannot be loaded, so fall back to an empty COCO object
            coco_dt = loadRes(self.coco_gt, results) if results else COCO()
            coco_eval = self.coco_eval[iou_type]

            coco_eval.cocoDt = coco_dt
            coco_eval.params.imgIds = list(img_ids)
            # the returned ids are discarded so `img_ids` is not rebound
            # between IoU types
            _, eval_imgs = evaluate(coco_eval)

            self.eval_imgs[iou_type].append(eval_imgs)

    def synchronize_between_processes(self):
        """Concatenate the stored per-batch results along axis 2 and merge them into each COCOeval."""
        for iou_type in self.iou_types:
            self.eval_imgs[iou_type] = np.concatenate(self.eval_imgs[iou_type], 2)
            create_common_coco_eval(self.coco_eval[iou_type], self.img_ids, self.eval_imgs[iou_type])

    def accumulate(self):
        """Call ``accumulate()`` on every COCOeval object."""
        for coco_eval in self.coco_eval.values():
            coco_eval.accumulate()

    def summarize(self):
        """Print a header for each IoU type and call ``summarize()`` on its COCOeval."""
        for iou_type, coco_eval in self.coco_eval.items():
            print("IoU metric: {}".format(iou_type))
            coco_eval.summarize()

    def prepare(self, predictions, iou_type):
        """
        Convert ``predictions`` into a list of COCO-style result dicts.

        Raises:
            ValueError: if ``iou_type`` is not "bbox", "segm" or "keypoints"
        """
        if iou_type == "bbox":
            return self.prepare_for_coco_detection(predictions)
        elif iou_type == "segm":
            return self.prepare_for_coco_segmentation(predictions)
        elif iou_type == "keypoints":
            return self.prepare_for_coco_keypoint(predictions)
        else:
            raise ValueError("Unknown iou type {}".format(iou_type))

    def prepare_for_coco_detection(self, predictions):
        """Build one result dict (image_id, category_id, bbox as xywh, score) per predicted box."""
        coco_results = []
        for original_id, prediction in predictions.items():
            if len(prediction) == 0:
                continue

            boxes = convert_to_xywh(prediction["boxes"]).tolist()
            scores = prediction["scores"].tolist()
            labels = prediction["labels"].tolist()

            coco_results.extend(
                [
                    {
                        "image_id": original_id,
                        "category_id": labels[k],
                        "bbox": box,
                        "score": scores[k],
                    }
                    for k, box in enumerate(boxes)
                ]
            )
        return coco_results

    def prepare_for_coco_segmentation(self, predictions):
        """Build one result dict (image_id, category_id, RLE segmentation, score) per predicted mask."""
        coco_results = []
        for original_id, prediction in predictions.items():
            if len(prediction) == 0:
                continue

            # binarize the predicted masks at 0.5
            masks = prediction["masks"] > 0.5

            scores = prediction["scores"].tolist()
            labels = prediction["labels"].tolist()

            rles = [
                mask_util.encode(np.array(mask[0, :, :, np.newaxis], order="F"))[0]
                for mask in masks
            ]
            for rle in rles:
                # the encoder returns bytes; store the counts as text
                rle["counts"] = rle["counts"].decode("utf-8")

            coco_results.extend(
                [
                    {
                        "image_id": original_id,
                        "category_id": labels[k],
                        "segmentation": rle,
                        "score": scores[k],
                    }
                    for k, rle in enumerate(rles)
                ]
            )
        return coco_results

    def prepare_for_coco_keypoint(self, predictions):
        """Build one result dict (image_id, category_id, flattened keypoints, score) per detection."""
        coco_results = []
        for original_id, prediction in predictions.items():
            if len(prediction) == 0:
                continue

            scores = prediction["scores"].tolist()
            labels = prediction["labels"].tolist()
            keypoints = prediction["keypoints"].flatten(start_dim=1).tolist()

            coco_results.extend(
                [
                    {
                        "image_id": original_id,
                        "category_id": labels[k],
                        'keypoints': keypoint,
                        "score": scores[k],
                    }
                    for k, keypoint in enumerate(keypoints)
                ]
            )
        return coco_results


def convert_to_xywh(boxes):
    """Convert boxes from (xmin, ymin, xmax, ymax) columns to (xmin, ymin, width, height)."""
    left, top, right, bottom = boxes.unbind(1)
    widths = right - left
    heights = bottom - top
    return torch.stack((left, top, widths, heights), dim=1)


def merge(img_ids, eval_imgs):
    """
    Gather image ids and evaluation arrays from all ranks and de-duplicate them.

    Returns:
        (merged_img_ids, merged_eval_imgs): the sorted unique image ids and the
        evaluation array restricted, along its last axis, to the first
        occurrence of each id.
    """
    gathered_ids = all_gather(img_ids)
    gathered_evals = all_gather(eval_imgs)

    flat_ids = [img_id for chunk in gathered_ids for img_id in chunk]
    merged_img_ids = np.array(flat_ids)
    merged_eval_imgs = np.concatenate(list(gathered_evals), 2)

    # keep only unique (and in sorted order) images
    merged_img_ids, first_idx = np.unique(merged_img_ids, return_index=True)
    merged_eval_imgs = merged_eval_imgs[..., first_idx]

    return merged_img_ids, merged_eval_imgs


def create_common_coco_eval(coco_eval, img_ids, eval_imgs):
    """Merge results across ranks and store them on ``coco_eval`` (evalImgs, params.imgIds, _paramsEval)."""
    merged_ids, merged_evals = merge(img_ids, eval_imgs)

    coco_eval.evalImgs = list(merged_evals.flatten())
    coco_eval.params.imgIds = list(merged_ids)
    # snapshot the params after the image ids have been updated
    coco_eval._paramsEval = copy.deepcopy(coco_eval.params)


#################################################################
# From pycocotools, just removed the prints
#################################################################

def createIndex(self):
    """
    Build the lookup tables of a COCO-like object from ``self.dataset``.

    Sets ``anns``, ``imgToAnns``, ``catToImgs``, ``imgs`` and ``cats`` on
    ``self``. Sections missing from the dataset dict yield empty tables.
    """
    dataset = self.dataset
    has_categories = 'categories' in dataset
    ann_list = dataset.get('annotations', [])

    ann_by_id = {}
    anns_by_image = defaultdict(list)
    for ann in ann_list:
        anns_by_image[ann['image_id']].append(ann)
        ann_by_id[ann['id']] = ann

    img_by_id = {}
    if 'images' in dataset:
        img_by_id = {img['id']: img for img in dataset['images']}

    cat_by_id = {}
    if has_categories:
        cat_by_id = {cat['id']: cat for cat in dataset['categories']}

    # the category -> image-id table is only filled when categories are present
    imgs_by_cat = defaultdict(list)
    if has_categories:
        for ann in ann_list:
            imgs_by_cat[ann['category_id']].append(ann['image_id'])

    # create class members
    self.anns = ann_by_id
    self.imgToAnns = anns_by_image
    self.catToImgs = imgs_by_cat
    self.imgs = img_by_id
    self.cats = cat_by_id


# Alias so the code copied from pycocotools below can keep using the name `maskUtils`.
maskUtils = mask_util


def loadRes(self, resFile):
    """
    Load results and return a result api object.

    :param   resFile : path of a JSON result file (str), a numpy array of
                       annotations, or an already-loaded list of result dicts
    :return: res (obj) : COCO object holding the results; an empty result list
                         yields an object with images but no annotations
    :raises AssertionError: if the results are not a list, or reference image
                            ids that are not part of ``self``
    """
    res = COCO()
    res.dataset['images'] = [img for img in self.dataset['images']]

    # `str` replaces the private torch._six.string_classes helper
    if isinstance(resFile, str):
        # context manager so the file handle is closed after parsing
        with open(resFile) as result_file:
            anns = json.load(result_file)
    elif isinstance(resFile, np.ndarray):
        anns = self.loadNumpyAnnotations(resFile)
    else:
        anns = resFile
    assert isinstance(anns, list), 'results in not an array of objects'
    annsImgIds = [ann['image_id'] for ann in anns]
    assert set(annsImgIds) == (set(annsImgIds) & set(self.getImgIds())), \
        'Results do not correspond to current coco set'

    # The first entry decides which kind of results these are; with no
    # entries none of the branches below applies.
    first = anns[0] if anns else {}
    if 'caption' in first:
        imgIds = {img['id'] for img in res.dataset['images']} & {ann['image_id'] for ann in anns}
        res.dataset['images'] = [img for img in res.dataset['images'] if img['id'] in imgIds]
        for ann_id, ann in enumerate(anns, start=1):
            ann['id'] = ann_id
    elif 'bbox' in first and not first['bbox'] == []:
        res.dataset['categories'] = copy.deepcopy(self.dataset['categories'])
        for ann_id, ann in enumerate(anns, start=1):
            bb = ann['bbox']
            x1, x2, y1, y2 = [bb[0], bb[0] + bb[2], bb[1], bb[1] + bb[3]]
            if 'segmentation' not in ann:
                # use the box outline as a polygon when no segmentation is given
                ann['segmentation'] = [[x1, y1, x1, y2, x2, y2, x2, y1]]
            ann['area'] = bb[2] * bb[3]
            ann['id'] = ann_id
            ann['iscrowd'] = 0
    elif 'segmentation' in first:
        res.dataset['categories'] = copy.deepcopy(self.dataset['categories'])
        for ann_id, ann in enumerate(anns, start=1):
            ann['area'] = maskUtils.area(ann['segmentation'])
            if 'bbox' not in ann:
                ann['bbox'] = maskUtils.toBbox(ann['segmentation'])
            ann['id'] = ann_id
            ann['iscrowd'] = 0
    elif 'keypoints' in first:
        res.dataset['categories'] = copy.deepcopy(self.dataset['categories'])
        for ann_id, ann in enumerate(anns, start=1):
            s = ann['keypoints']
            # keypoints are stored flat; every third value from 0 is x, from 1 is y
            x = s[0::3]
            y = s[1::3]
            x0, x1, y0, y1 = np.min(x), np.max(x), np.min(y), np.max(y)
            ann['area'] = (x1 - x0) * (y1 - y0)
            ann['id'] = ann_id
            ann['bbox'] = [x0, y0, x1 - x0, y1 - y0]

    res.dataset['annotations'] = anns
    createIndex(res)
    return res


def evaluate(self):
    '''
    Run per-image evaluation on a COCOeval-like object.

    Normalizes ``self.params`` (unique image/category ids, sorted maxDets),
    calls ``self._prepare()``, stores the IoU table in ``self.ious`` and a
    copy of the params in ``self._paramsEval``.
    :return: (imgIds, evalImgs) where evalImgs is an array reshaped to
             (number of categories, number of area ranges, number of images)
    '''
    params = self.params
    if params.useSegm is not None:
        if params.useSegm == 1:
            params.iouType = 'segm'
        else:
            params.iouType = 'bbox'
        print('useSegm (deprecated) is not None. Running {} evaluation'.format(params.iouType))
    params.imgIds = list(np.unique(params.imgIds))
    if params.useCats:
        params.catIds = list(np.unique(params.catIds))
    params.maxDets = sorted(params.maxDets)
    self.params = params

    self._prepare()
    # a single placeholder category is used when categories are ignored
    cat_ids = params.catIds if params.useCats else [-1]

    if params.iouType in ('segm', 'bbox'):
        computeIoU = self.computeIoU
    elif params.iouType == 'keypoints':
        computeIoU = self.computeOks

    iou_table = {}
    for img_id in params.imgIds:
        for cat_id in cat_ids:
            iou_table[(img_id, cat_id)] = computeIoU(img_id, cat_id)
    self.ious = iou_table

    evaluateImg = self.evaluateImg
    # evaluate with the largest detection limit
    max_det = params.maxDets[-1]
    per_image = []
    for cat_id in cat_ids:
        for area_rng in params.areaRng:
            for img_id in params.imgIds:
                per_image.append(evaluateImg(img_id, cat_id, area_rng, max_det))
    # this is NOT in the pycocotools code, but could be done outside
    per_image = np.asarray(per_image).reshape(
        len(cat_ids), len(params.areaRng), len(params.imgIds)
    )
    self._paramsEval = copy.deepcopy(self.params)
    return params.imgIds, per_image

#################################################################
# end of straight copy from pycocotools
#################################################################

# Set parameters

In [None]:
### Colab compatibility ###
# Reset argv so that parse_args() below sees no command-line arguments.
import sys
sys.argv = ['']
del sys
### Colab compatibility ###

import argparse


def _build_parser():
    """Create the argument parser holding all training options and their defaults."""
    # (flag, type, default, help) for every option, in registration order
    option_table = (
        ('--num_workers', int, 2, "Number of data loader workers"),
        ('--lr', float, 0.005, "Learning rate"),
        ('--momentum', float, 0.9, "Optimizer momentum"),
        ('--weight_decay', float, 0.0005, "Optimizer weight decay"),
        ('--lr_schedule_step', int, 3, "Learning rate scheduler step size"),
        ('--lr_schedule_gamma', float, 0.1, "Learning rate scheduler gamma value"),
        ('--lr_warmup', int, 1, "Learning rate warmup"),
        ('--batch_size', int, 2, "Batch Size of Train loaders"),
        ('--num_epochs', int, 30, "Training Epochs"),
        ('--eval_epoch', int, 1, "Validation frequency"),
        ('--log_path', str, 'log/', "Logging directory"),
        ('--model_path', str, 'models/', "Checkpoint directory"),
        ('--model', str, 'fastrcnn', "[fastrcnn|yolov5|yolov4]"),
        ('--dataset', str, 'rvk', ""),
        ('--dataset_dir', str, 'data/', "Dataset directory"),
        ('--annotation_dir', str, 'annotation/', "Annotation directory"),
        ('--num_classes', int, 2, "Number of object detection classes"),
        ('--val_split', float, .3, "Validation split"),
    )
    built = argparse.ArgumentParser(description='VCC Arguments', add_help=False)
    for flag, kind, default, text in option_table:
        built.add_argument(flag, type=kind, default=default, help=text)
    return built


parser = _build_parser()

args = parser.parse_args()

In [None]:
class DatasetCOCO(torch.utils.data.Dataset):
    """
    Detection dataset backed by a COCO-format annotation file.

    Each item is ``(img, target)`` where ``target`` is a dict with the keys
    ``boxes``, ``labels``, ``image_id``, ``area``, ``iscrowd`` and
    ``categories``.
    """

    def __init__(self, root, annotation, transforms=None):
        """
        Args:
            root: directory that the images' ``file_name`` entries are relative to
            annotation: path of the COCO annotation file
            transforms: optional callable applied to the loaded image only
        """
        self.root = root
        self.transforms = transforms
        self.coco = COCO(annotation)
        # sorted image ids give a stable index -> image mapping
        self.ids = list(sorted(self.coco.imgs.keys()))

    def __getitem__(self, index):
        """Load the image at ``index`` and build its target dict."""
        coco = self.coco
        img_id = self.ids[index]
        # all annotations that belong to this image
        coco_annotation = coco.loadAnns(coco.getAnnIds(imgIds=img_id))
        # open the input image
        path = coco.loadImgs(img_id)[0]["file_name"]
        img = Image.open(os.path.join(self.root, path)).convert('RGB')
        # category records of the whole annotation file
        cats = coco.loadCats(coco.getCatIds())

        # annotation boxes are (x, y, width, height); convert to corner form
        boxes = []
        for ann in coco_annotation:
            bbox = ann["bbox"]
            boxes.append([bbox[0], bbox[1], bbox[0] + bbox[2], bbox[1] + bbox[3]])
        # the reshape keeps a (0, 4) shape for images without annotations
        boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)

        labels = torch.as_tensor(
            [int(ann['category_id']) for ann in coco_annotation], dtype=torch.int64
        )
        areas = torch.as_tensor(
            [ann["area"] for ann in coco_annotation], dtype=torch.float32
        )
        # crowd flags are set to zero for every object; the annotation file's
        # own iscrowd values are not read
        iscrowd = torch.zeros((len(coco_annotation),), dtype=torch.int64)

        # Annotation is in dictionary format
        my_annotation = {
            "boxes": boxes,
            "labels": labels,
            "image_id": torch.tensor([img_id]),
            "area": areas,
            "iscrowd": iscrowd,
            "categories": cats,
        }

        if self.transforms is not None:
            img = self.transforms(img)
        return img, my_annotation

    def __len__(self):
        """Number of images in the annotation file."""
        return len(self.ids)

    def get_labels(self):
        """Return the ``name`` of every category in the annotation file."""
        return [cat['name'] for cat in self.coco.cats.values()]



# Initialize Dataset

In [None]:
class FCCDataModule(pl.LightningDataModule):
    """
    Lightning data module that downloads the dataset, builds a ``DatasetCOCO``
    and exposes train/validation data loaders.
    """

    def __init__(self, args):
        """
        Args:
            args: parsed options; this class reads ``dataset``, ``dataset_dir``,
                ``val_split``, ``model``, ``batch_size`` and ``num_workers``
        """
        super().__init__()
        self.args = args
        self.data_dir = args.dataset_dir
        self.train_transform = transforms.Compose([transforms.ToTensor()])
        self.val_transform = transforms.Compose([transforms.ToTensor()])

    def prepare_data(self):
        """
        Download and unpack the dataset and set ``self.dir`` / ``self.annotation``.

        Raises:
            ValueError: if ``args.dataset`` is not "rvk"
            subprocess.CalledProcessError: if the download command fails
        """
        # local import: `subprocess` is not imported at file level
        import subprocess

        print("Dataset download")

        if self.args.dataset != "rvk":
            raise ValueError('No valid dataset')

        # download train images
        subprocess.run(
            ["curl", "-L", "https://app.roboflow.com/ds/lmjqJlEhnK?key=C5kBjzDuqk",
             "-o", "roboflow.zip"],
            check=True,
        )
        with zipfile.ZipFile("roboflow.zip") as archive:
            archive.extractall(self.data_dir)

        # paths are derived from the dataset directory; with the default
        # 'data/' they equal 'data/train' and 'data/annotations/annotation.json'
        self.dir = os.path.join(self.data_dir, "train")
        annotation_dir = os.path.join(self.data_dir, "annotations")
        os.makedirs(annotation_dir, exist_ok=True)
        self.annotation = os.path.join(annotation_dir, "annotation.json")

        # move the exported annotation file out of the image directory
        exported = os.path.join(self.dir, "_annotations.coco.json")
        if os.path.exists(exported):
            os.replace(exported, self.annotation)

    def setup(self, stage=None):
        """Build the dataset and randomly split it into train and validation subsets."""
        print("Dataset setup")
        data_set = DatasetCOCO(root=self.dir, annotation=self.annotation, transforms=self.train_transform)
        train_split_idx = int(len(data_set) * (1 - self.args.val_split))
        # the validation split takes whatever is left so the sizes always add up
        val_split_idx = len(data_set) - train_split_idx
        self.coco_train, self.coco_val = random_split(data_set, [train_split_idx, val_split_idx])

    def collate_fn(self, batch):
        """
        Regroup a list of ``(img, target)`` samples into ``(imgs, targets)``.

        Raises:
            ValueError: if ``args.model`` is neither "fastrcnn" nor "yolov5"
        """
        if self.args.model == "fastrcnn":
            return tuple(zip(*batch))
        elif self.args.model == "yolov5":
            columns = list(zip(*batch))
            return columns[0], columns[1]
        # previously this fell through and returned None as the batch
        raise ValueError("Unsupported model for collate_fn: {}".format(self.args.model))

    def train_dataloader(self):
        """Shuffled loader over the training subset."""
        return torch.utils.data.DataLoader(
            self.coco_train,
            batch_size=self.args.batch_size,
            shuffle=True,
            num_workers=self.args.num_workers,
            collate_fn=self.collate_fn,
        )

    def val_dataloader(self):
        """Unshuffled loader over the validation subset."""
        return torch.utils.data.DataLoader(
            self.coco_val,
            batch_size=self.args.batch_size,
            shuffle=False,
            num_workers=self.args.num_workers,
            collate_fn=self.collate_fn,
        )

# Init DataModule
dm = FCCDataModule(args)
# Download and unpack the dataset; this sets the paths that setup() reads.
dm.prepare_data()
# Build the dataset and create the train/validation split.
dm.setup()

# Define and configure model

In [None]:
# Build the torchvision detector; its box predictor is swapped below so that it
# outputs `args.num_classes` classes.
model = torchvision.models.detection.fasterrcnn_mobilenet_v3_large_fpn(pretrained=True)
# get number of input features for the classifier
in_features = model.roi_heads.box_predictor.cls_score.in_features
# replace the pre-trained head with a new one
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, args.num_classes)

In [None]:
class FCC(pl.LightningModule):
    """
    Lightning wrapper around the detection model: computes the training
    losses and runs COCO evaluation during validation.
    """

    # Names under which the twelve entries of the bbox evaluator's `stats`
    # are logged, in index order.
    _COCO_STAT_NAMES = (
        'Average Precision  (AP) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ]',
        'Average Precision  (AP) @[ IoU=0.5       | area=   all | maxDets=100 ]',
        'Average Precision  (AP) @[ IoU=0.75      | area=   all | maxDets=100 ]',
        'Average Precision  (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ]',
        'Average Precision  (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ]',
        'Average Precision  (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ]',
        'Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=  1 ]',
        'Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets= 10 ]',
        'Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ]',
        'Average Recall     (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ]',
        'Average Recall     (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ]',
        'Average Recall     (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ]',
    )
    # only the first three statistics are shown in the progress bar
    _PROG_BAR_STAT_COUNT = 3

    def __init__(self, dataset, args):
        """
        Args:
            dataset: dataset from which the COCO ground truth is derived
            args: options saved as hyperparameters (``lr``, ``momentum``,
                ``weight_decay`` and ``num_epochs`` are read later)
        """
        super().__init__()
        self.save_hyperparameters(args)

        # captures the module-level detector that exists at construction time
        self.model = model

        self.coco = get_coco_api_from_dataset(dataset)
        self.iou_types = ['bbox'] #coco_utils: _get_iou_types(model)

    def forward(self, imgs, annotations):
        """
        Overloaded forward pass.

        In training mode the wrapped model's output for ``(imgs, annotations)``
        is returned (used as a loss dict by ``training_step``); otherwise a
        pair ``(predictions, annotations)`` is returned.
        """
        if self.training: # For train - returns loss dict
            return self.model(imgs, annotations)
        # For test - return predictions and annotations for val metric calculations
        return self.model(imgs), annotations

    def configure_optimizers(self):
        """Create an SGD optimizer with a cosine LambdaLR schedule."""
        # use the wrapped model rather than the module-level name `model`,
        # which is rebound after this class is instantiated
        params = [p for p in self.model.parameters() if p.requires_grad]
        optimizer = torch.optim.SGD(params, lr=self.hparams.lr,
                            momentum=self.hparams.momentum, weight_decay=self.hparams.weight_decay)

        def cosine_factor(step):
            # cosine decay of the lr multiplier from 1.0 down to 0.15 over num_epochs steps
            return ((1 + math.cos(step * math.pi / self.hparams.num_epochs)) / 2) * (1 - 0.15) + 0.15

        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=cosine_factor)

        return [optimizer], [scheduler]

    def training_step(self, batch, batchidx):
        """Compute the summed loss for one batch and log each loss and the learning rate."""
        imgs, annotations = batch
        loss_dict = self(imgs, annotations)
        losses = sum(loss for loss in loss_dict.values())
        # reduce losses over all GPUs in case of multiple gpus (for logging only)
        loss_dict_reduced = reduce_dict(loss_dict)
        for key in loss_dict_reduced:
            self.log(key, loss_dict_reduced[key], prog_bar=True, on_step=True, on_epoch=True)
        self.log("lr", self.optimizers(use_pl_optimizer=False).param_groups[0]['lr'], prog_bar=True)
        return {'loss': losses}

    def on_validation_epoch_start(self):
        """Create a fresh evaluator for each validation epoch."""
        self.coco_evaluator = CocoEvaluator(self.coco, self.iou_types)

    def validation_step(self, batch, batchidx):
        """Predict one batch, log the per-image prediction time and feed the evaluator."""
        imgs, annotations = batch
        pred_time = time.time()
        outputs, targets = self(imgs, annotations)
        # divide by the number of images; len(batch) would always be 2
        # because batch is the (imgs, annotations) pair
        self.log('pred_time_sec/img', (time.time() - pred_time) / len(imgs), prog_bar=True)

        outputs = [{k: v.to("cpu") for k, v in t.items()} for t in outputs]
        res = {target["image_id"].item(): output for target, output in zip(targets, outputs)}
        self.coco_evaluator.update(res)

    def validation_epoch_end(self, training_step_outputs):
        """Accumulate the evaluator's results and log all twelve bbox statistics."""
        with redirect_stdout(None): # workaround to mute std output
            self.coco_evaluator.synchronize_between_processes()
            self.coco_evaluator.accumulate()
            self.coco_evaluator.summarize()

        stats = self.coco_evaluator.coco_eval['bbox'].stats
        for idx, name in enumerate(self._COCO_STAT_NAMES):
            self.log(name, stats[idx], prog_bar=idx < self._PROG_BAR_STAT_COUNT)

#pl.seed_everything(42, workers=True)
# NOTE: this rebinds the module-level name `model` from the torchvision detector
# to the Lightning wrapper. FCC.__init__ reads the old binding before the
# assignment completes, so the wrapper holds the detector.
model = FCC(dm.train_dataloader().dataset, args)

# Training

In [None]:
# Model checkpoint
# Saves checkpoints into `args.model_path`, ranked by the monitored metric
# (mode='max': higher is better).
# NOTE(review): the filename template references `AP_IoU_50`, but no `self.log`
# call in the visible FCC class uses that name — confirm the placeholder is
# filled as intended.
checkpoint_callback = ModelCheckpoint(
    monitor='Average Precision  (AP) @[ IoU=0.5       | area=   all | maxDets=100 ]',
    dirpath=args.model_path,
    filename='fasterrcnn-{epoch:02d}-{AP_IoU_50:.2f}',
    #save_top_k=3,
    mode='max',
)
#early_stopping_callback = EarlyStopping(monitor="AP_IoU_50")
# Init trainer
# NOTE(review): `auto_lr_find=True` is passed, but no `trainer.tune(...)` call is
# visible here — confirm whether the learning-rate finder ever runs.
trainer = pl.Trainer(#min_epochs=args.num_epochs,
                     max_epochs=args.num_epochs,
                     gpus=[0],
                     #tpu_cores=8,
                     deterministic=True,    
                     check_val_every_n_epoch=args.eval_epoch,
                     logger=True,
                     progress_bar_refresh_rate=5,
                     callbacks=[checkpoint_callback],
                     auto_lr_find=True
                     #fast_dev_run = True
                     #auto_lr_find = "hparams.lr"
                     )


# Fit model
# Train using the loaders provided by the data module.
trainer.fit(model, datamodule=dm)