# Importing Necessary Libraries

In [None]:
import detectron2
import contextlib
import datetime
import io
import os
import json
import logging
import cv2
import random
import numpy as np
import copy,torch,torchvision
import PIL
from PIL import Image
import xml.etree.ElementTree as X
import math
from itertools import repeat
import re
import shutil
import io
import ast

from fvcore.common.file_io import PathManager
from fvcore.common.timer import Timer

from detectron2.structures import Boxes, BoxMode, PolygonMasks
from detectron2.config import *
from detectron2.modeling import build_model
from detectron2 import model_zoo
from detectron2.data import transforms as T
from detectron2.data import detection_utils as utils
from detectron2.data import DatasetCatalog, MetadataCatalog, build_detection_test_loader, build_detection_train_loader
from detectron2.data.datasets import register_coco_instances
from detectron2.engine import DefaultTrainer, DefaultPredictor
from detectron2.checkpoint import DetectionCheckpointer
from detectron2.evaluation import RotatedCOCOEvaluator,DatasetEvaluators, inference_on_dataset, coco_evaluation,DatasetEvaluator
from detectron2.utils.logger import setup_logger
from detectron2.utils.visualizer import Visualizer

import matplotlib.pyplot as plt
from platform import python_version

import glob
import time
import shutil
from multiprocessing.pool import ThreadPool
import concurrent.futures

import torch
from torch.utils.cpp_extension import CUDA_HOME

# Print the CUDA diagnostics first so they are visible even when no GPU is usable,
# then pin GPU 0 only when CUDA is actually available (set_device fails otherwise).
print(torch.cuda.is_available(), CUDA_HOME)
if torch.cuda.is_available():
    torch.cuda.set_device(0)

setup_logger()

# Custom Functions for Preparing the Training Set

In [None]:
def get_rbbox(mask):
    """Return the minimum-area rotated rectangle enclosing the foreground of ``mask``.

    Args:
        mask (np.ndarray): single-channel uint8 image; non-zero pixels are foreground.

    Returns:
        tuple: ``((center_x, center_y), (width, height), angle)`` exactly as
        returned by ``cv2.minAreaRect``.

    Raises:
        ValueError: if ``mask`` contains no foreground pixels.
    """
    import cv2

    # findContours returns (contours, hierarchy) in OpenCV 2/4 but
    # (image, contours, hierarchy) in OpenCV 3; index -2 is the contours in all.
    cnts = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]
    if not cnts:
        raise ValueError("mask has no foreground pixels")
    # Fit all outer-contour points rather than just cnts[0]: a rasterised polygon
    # can split into several components, and contour order is not by size.
    return cv2.minAreaRect(np.concatenate(cnts))



def make_rbbox_cotton_dicts(Train_data_path, image_id = 1):
    """Build detectron2 dataset dicts whose boxes are rotated (``BoxMode.XYWHA_ABS``).

    Every instance polygon produced by ``make_seg_cotton_dicts(Train_data_path)``
    is rasterised to a mask and replaced by its minimum-area rotated rectangle
    (see ``get_rbbox``).

    Args:
        Train_data_path (str): dataset location passed to ``make_seg_cotton_dicts``.
            If the string contains ``'train'`` every frame is kept; otherwise only
            frames whose ``fr_name`` contains ``'aug'`` are kept.
        image_id (int): id given to the first kept image; incremented per kept image.

    Returns:
        list[dict]: one record per kept frame with keys ``file_name``, ``height``,
        ``width``, ``image_id``, ``fr_name`` and ``annotations``; each annotation
        holds ``bbox`` ``[cx, cy, w, h, angle]``, ``bbox_mode`` and ``category_id``.
    """
    padded_seg_dicts = make_seg_cotton_dicts(Train_data_path)
    keep_all_frames = 'train' in Train_data_path

    dataset_list = []
    for file in padded_seg_dicts:
        img_height = file['height']
        img_width = file['width']
        frame_name = file['fr_name']

        # Decide up front whether the frame is kept, so rejected frames do not pay
        # for mask rasterisation and contour fitting.
        if not (keep_all_frames or 'aug' in frame_name):
            continue

        annotations = []
        for variable in file['annotations']:
            mask = detectron2.structures.polygons_to_bitmask(
                variable['segmentation'], img_height, img_width
            )
            (cent_x, cent_y), (w, h), angle = get_rbbox((255 * mask).astype('uint8'))
            # The angle is negated because "-angle works best (for now)"; presumably
            # this maps OpenCV's minAreaRect angle convention onto detectron2's
            # XYWHA convention — NOTE(review): confirm for the installed OpenCV.
            annotations.append({
                "bbox": [cent_x, cent_y, w, h, -angle],
                "bbox_mode": BoxMode.XYWHA_ABS,
                "category_id": variable['category_id'],
            })

        dataset_list.append({
            "file_name": file['file_name'],
            "height": img_height,
            "width": img_width,
            "image_id": image_id,
            "fr_name": frame_name,
            "annotations": annotations,
        })
        image_id += 1

    return dataset_list

In [None]:
# Root folder of the project and the split whose rotated-box dicts are built here.
Base_path = 'Cotton Fiber Project'
Train_data_path = 'train_average'

rbbox_train_dicts = make_rbbox_cotton_dicts(Train_data_path=Train_data_path)

In [None]:
# Register each split under two names:
#   "CFH_<split>"                -> polygon dicts from make_seg_cotton_dicts
#   "CFH_<split>_padded_rotated" -> rotated-box dicts from make_rbbox_cotton_dicts,
#                                   the name used by cfg.DATASETS.TRAIN and evaluation.
# A name that is already registered is skipped, so the cell can be re-run
# (DatasetCatalog.register refuses duplicate names).
# NOTE(review): both use os.path.join(Base_path, d), while rbbox_train_dicts is
# built from the bare split name — confirm which location is correct.
for d in ["train_average"]:  # add "val", "test" here to register those splits too
    data_dir = os.path.join(Base_path, d)
    seg_name = "CFH_" + d
    rot_name = seg_name + "_padded_rotated"
    registered = DatasetCatalog.list()
    # Default arguments bind the current path, avoiding late-binding closures.
    if seg_name not in registered:
        DatasetCatalog.register(seg_name, lambda p=data_dir: make_seg_cotton_dicts(p))
    if rot_name not in registered:
        DatasetCatalog.register(rot_name, lambda p=data_dir: make_rbbox_cotton_dicts(p))
    for name in (seg_name, rot_name):
        MetadataCatalog.get(name).thing_classes = ["fiber"]

In [None]:
# Metadata (class names) of the registered training split, reused for visualisation.
metadata_train = MetadataCatalog.get(name="CFH_train_average")

# Custom Dataset Mapper

In [None]:
def my_transform_instance_annotations(annotation, transforms, image_size, *, keypoint_hflip_indices=None):
    """Apply ``transforms`` to one annotation's box in place and return the annotation.

    Rotated boxes (``BoxMode.XYWHA_ABS``) go through ``transforms.apply_rotated_box``
    and keep their mode. Any other box is converted to ``XYXY_ABS``, transformed
    with ``apply_box`` and stored as ``XYXY_ABS``.

    ``image_size`` and ``keypoint_hflip_indices`` mirror the signature of
    detectron2's ``transform_instance_annotations`` but are unused here: boxes are
    not clipped, and segmentation/keypoints are left untouched.
    """
    mode = annotation["bbox_mode"]
    if mode != BoxMode.XYWHA_ABS:
        # Axis-aligned path: normalise to XYXY_ABS before transforming.
        xyxy = BoxMode.convert(annotation["bbox"], mode, BoxMode.XYXY_ABS)
        annotation["bbox"] = transforms.apply_box([xyxy])[0]
        annotation["bbox_mode"] = BoxMode.XYXY_ABS
        return annotation

    # Rotated path: apply_rotated_box works on a batch, so wrap the single box.
    batch = np.asarray([annotation["bbox"]])
    annotation["bbox"] = transforms.apply_rotated_box(batch)[0]
    return annotation

def mapper(dataset_dict):
    """Turn one dataset record into model input with rotated ground-truth boxes.

    Similar to detectron2's default ``DatasetMapper``, but the image is always
    resized to a fixed 800x800 and the instances are built as rotated boxes.

    Args:
        dataset_dict (dict): one record in detectron2's dataset format. It is
            deep-copied, so the caller's record is not modified.

    Returns:
        dict: the copy with an ``"image"`` float32 tensor laid out (C, 800, 800).
        When the record has ``"annotations"``, the annotations are replaced by
        ``"instances"`` holding the non-crowd, non-empty rotated boxes. Records
        without annotations (e.g. unlabeled inference data) are returned with only
        the image instead of raising ``KeyError``.
    """
    dataset_dict = copy.deepcopy(dataset_dict)  # it will be modified by code below
    image = utils.read_image(dataset_dict["file_name"], format="BGR")
    image, transforms = T.apply_transform_gens([T.Resize((800, 800))], image)
    image_shape = image.shape[:2]  # (height, width) after resizing
    dataset_dict["image"] = torch.as_tensor(image.transpose(2, 0, 1).astype("float32"))

    if "annotations" not in dataset_dict:
        return dataset_dict

    annos = [
        my_transform_instance_annotations(obj, transforms, image_shape)
        for obj in dataset_dict.pop("annotations")
        if obj.get("iscrowd", 0) == 0
    ]
    instances = utils.annotations_to_instances_rotated(annos, image_shape)
    dataset_dict["instances"] = utils.filter_empty_instances(instances)
    return dataset_dict

# Custom Trainer, Predictor and Visualizer Classes

In [None]:
class MyTrainer(DefaultTrainer):
    """``DefaultTrainer`` that trains with the rotated-box ``mapper`` and evaluates
    with ``RotatedCOCOEvaluator``."""

    @classmethod
    def build_train_loader(cls, cfg):
        """Build the training loader, mapping every record through ``mapper``."""
        return build_detection_train_loader(cfg, mapper=mapper)

    @classmethod
    def build_evaluator(cls, cfg, dataset_name):
        """Return a distributed ``RotatedCOCOEvaluator`` for ``dataset_name``.

        The evaluator is wrapped in ``DatasetEvaluators`` and writes its results
        under ``<cfg.OUTPUT_DIR>/inference``.
        """
        # NOTE(review): this is the unpatched RotatedCOCOEvaluator, not the
        # MyRotatedCOCOEvaluator workaround defined later in the notebook.
        inference_dir = os.path.join(cfg.OUTPUT_DIR, "inference")
        rotated_eval = RotatedCOCOEvaluator(dataset_name, cfg, distributed=True, output_dir=inference_dir)
        return DatasetEvaluators([rotated_eval])

class RotatedPredictor(DefaultPredictor):
    """Single-image predictor for the rotated-box model trained in this notebook.

    Unlike ``DefaultPredictor``, it neither builds a model nor loads weights from
    ``cfg``; it runs an already-constructed model.
    """

    def __init__(self, cfg, model=None):
        """
        Args:
            cfg (CfgNode): config; a clone is kept, and ``INPUT.MIN_SIZE_TEST``,
                ``INPUT.MAX_SIZE_TEST`` and ``INPUT.FORMAT`` are read from it.
            model (torch.nn.Module, optional): model to run. When omitted, the
                notebook-global ``trainer.model`` is used (the original behaviour),
                so ``trainer`` must exist in that case.
        """
        # DefaultPredictor.__init__ is intentionally not called: it would build a
        # new model from cfg and load cfg.MODEL.WEIGHTS.
        self.cfg = cfg.clone()  # cfg can be modified by model
        self.model = trainer.model if model is None else model
        # eval() switches the model object itself to inference mode; if it is the
        # trainer's model, call .train() on it again before training further.
        self.model.eval()

        # NOTE(review): the training/evaluation `mapper` resizes to a fixed 800x800,
        # whereas this resizes the shortest edge. The two agree only for square
        # images with MIN_SIZE_TEST == 800 (and MAX_SIZE_TEST >= 800) — confirm.
        self.transform_gen = T.ResizeShortestEdge(
            [cfg.INPUT.MIN_SIZE_TEST, cfg.INPUT.MIN_SIZE_TEST], cfg.INPUT.MAX_SIZE_TEST
        )

        self.input_format = cfg.INPUT.FORMAT
        assert self.input_format in ["RGB", "BGR"], self.input_format

    def __call__(self, original_image):
        """
        Args:
            original_image (np.ndarray): an image of shape (H, W, C) (in BGR order).
        Returns:
            predictions (dict):
                the output of the model for one image only.
                See :doc:`/tutorials/models` for details about the format.
        """
        with torch.no_grad():  # https://github.com/sphinx-doc/sphinx/issues/4258
            # Apply pre-processing to image.
            if self.input_format == "RGB":
                # whether the model expects BGR inputs or RGB
                original_image = original_image[:, :, ::-1]
            height, width = original_image.shape[:2]
            image = self.transform_gen.get_transform(original_image).apply_image(original_image)
            image = torch.as_tensor(image.astype("float32").transpose(2, 0, 1))

            # height/width are the original size, so predictions are reported in
            # original-image coordinates.
            inputs = {"image": image, "height": height, "width": width}
            predictions = self.model([inputs])[0]
            return predictions

# As of detectron2 v0.3, the visualizer does not support XYWHA_ABS boxes in draw_dataset_dict; this was fixed on the master branch (as of 19/11/20).
class myVisualizer(Visualizer):
    """``Visualizer`` whose ``draw_dataset_dict`` can draw rotated boxes.

    Each annotation box is converted to ``BoxMode.XYWHA_ABS`` before it is handed
    to ``overlay_instances``.
    """

    def draw_dataset_dict(self, dic):
        """Draw the annotations and semantic segmentation of one dataset record.

        Args:
            dic (dict): a record in detectron2's dataset format.

        Returns:
            ``self.output`` after drawing.
        """
        annotations = dic.get("annotations", None)
        if annotations:
            first = annotations[0]
            masks = [a["segmentation"] for a in annotations] if "segmentation" in first else None

            keypts = None
            if "keypoints" in first:
                raw_keypoints = [a["keypoints"] for a in annotations]
                keypts = np.array(raw_keypoints).reshape(len(annotations), -1, 3)

            boxes = [
                BoxMode.convert(a["bbox"], a["bbox_mode"], BoxMode.XYWHA_ABS)
                for a in annotations
            ]

            class_ids = [a["category_id"] for a in annotations]
            class_names = self.metadata.get("thing_classes", None)
            if class_names:
                class_ids = [class_names[c] for c in class_ids]

            labels = []
            for text, a in zip(class_ids, annotations):
                crowd_suffix = "|crowd" if a.get("iscrowd", 0) else ""
                labels.append("{}".format(text) + crowd_suffix)

            self.overlay_instances(labels=labels, boxes=boxes, masks=masks, keypoints=keypts)

        sem_seg = dic.get("sem_seg", None)
        if sem_seg is None and "sem_seg_file_name" in dic:
            sem_seg = cv2.imread(dic["sem_seg_file_name"], cv2.IMREAD_GRAYSCALE)
        if sem_seg is not None:
            self.draw_sem_seg(sem_seg, area_threshold=0, alpha=0.5)
        return self.output

## Function to Save the Detectron2 Config to Disk

In [None]:
def cfg2yaml(cfg):
    """Write ``cfg.dump()`` to ``<cfg.OUTPUT_DIR>/Config.yaml`` and return that path.

    The output directory must already exist. An existing Config.yaml is overwritten.
    """
    yaml_path = os.path.join(cfg.OUTPUT_DIR, "Config.yaml")
    # Write Config.yaml directly rather than writing Config.txt and renaming it:
    # os.rename refuses to replace an existing file on Windows, which broke re-runs.
    with open(yaml_path, "w", encoding="utf-8") as file:
        file.write(cfg.dump())
    return yaml_path

# Set Up Detectron2's Config

In [None]:
# Build a rotated-box detector config: start from the COCO Faster R-CNN R50-FPN 3x
# zoo config and weights, switch the proposal generator to RRPN and the ROI heads
# to RROIHeads, set the solver/dataloader options, and save the result to
# <OUTPUT_DIR>/Config.yaml via cfg2yaml.
# Statement order matters: merge_from_file overrides values assigned before it.
cfg = get_cfg()

cfg.OUTPUT_DIR = 'FasterRCNN Test'

cfg.merge_from_file(model_zoo.get_config_file("COCO-Detection/faster_rcnn_R_50_FPN_3x.yaml"))
cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url("COCO-Detection/faster_rcnn_R_50_FPN_3x.yaml") # Let training initialize from model zoo
# Must name a dataset registered in DatasetCatalog before MyTrainer(cfg) is created.
cfg.DATASETS.TRAIN = ("CFH_train_average_padded_rotated",)
cfg.DATASETS.TEST = ()  # no evaluation during training

cfg.MODEL.MASK_ON=False
cfg.MODEL.PROPOSAL_GENERATOR.NAME = "RRPN"
cfg.MODEL.RPN.HEAD_NAME = "StandardRPNHead"
cfg.MODEL.RPN.BBOX_REG_WEIGHTS = (10,10,5,5,1)  # five weights: one per (x, y, w, h, angle) box delta
cfg.MODEL.ANCHOR_GENERATOR.NAME = "RotatedAnchorGenerator"
# NOTE(review): -90 and 90 degrees describe the same rotated box, so one anchor
# orientation is duplicated. Changing the list would presumably change the RPN
# head's anchor count and break resuming from existing checkpoints.
cfg.MODEL.ANCHOR_GENERATOR.ANGLES = [[-90,-72,-54,-36,-18,0,18,36,54,72,90]]
# cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.9 
cfg.MODEL.ROI_HEADS.NAME = "RROIHeads"
# cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = 512   #this is far lower than usual.  
cfg.MODEL.ROI_HEADS.NUM_CLASSES =1  # single class: "fiber"
cfg.MODEL.ROI_BOX_HEAD.POOLER_TYPE = "ROIAlignRotated"
cfg.MODEL.ROI_BOX_HEAD.BBOX_REG_WEIGHTS = (10,10,5,5,1)
# cfg.MODEL.ROI_BOX_HEAD.NUM_CONV=4
# cfg.MODEL.ROI_MASK_HEAD.NUM_CONV=8
cfg.SOLVER.IMS_PER_BATCH = 14  # images per iteration; author notes up to 24 fit on a P100. NOTE(review): the earlier "(6 default)" remark may be outdated — check detectron2's defaults
cfg.SOLVER.CHECKPOINT_PERIOD=2500
cfg.SOLVER.BASE_LR = 0.00015
# cfg.SOLVER.GAMMA=0.5
cfg.SOLVER.STEPS=(17500, 19000)  # iterations at which the learning rate is decayed
cfg.SOLVER.MAX_ITER=20000


cfg.DATALOADER.NUM_WORKERS = 1
cfg.DATALOADER.FILTER_EMPTY_ANNOTATIONS = True  # drop images that have no annotations
# cfg.DATALOADER.SAMPLER_TRAIN= "RepeatFactorTrainingSampler"
# cfg.DATALOADER.REPEAT_THRESHOLD=0.01
os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)  # make sure the output dir exists; cfg2yaml writes into it
# cfg.MODEL.BACKBONE.FREEZE_AT=6
cfg2yaml(cfg)

In [None]:
# Build the trainer, resume from the last checkpoint in cfg.OUTPUT_DIR when there
# is one (detectron2 resume_or_load semantics), then run training.
trainer = MyTrainer(cfg=cfg)
trainer.resume_or_load(resume=True)
trainer.train()

# Custom COCO Evaluator Class

In [None]:
# Work around a bug in RotatedCOCOEvaluator: its _eval_predictions is passed an img_ids argument that it does not accept.
class MyRotatedCOCOEvaluator(RotatedCOCOEvaluator):
    """``RotatedCOCOEvaluator`` whose ``_eval_predictions`` tolerates ``img_ids``.

    NOTE(review): in detectron2 v0.3, ``COCOEvaluator.evaluate`` presumably calls
    ``_eval_predictions(tasks, predictions, img_ids=...)``, while the rotated
    override only accepts ``(tasks, predictions)``. Later releases changed both
    signatures. This override therefore forwards only the arguments the parent
    method actually accepts, so it works with either API.
    """

    def _eval_predictions(self, *args, **kwargs):
        import inspect

        parent = super()._eval_predictions
        params = inspect.signature(parent).parameters
        if "img_ids" not in params:
            # The parent cannot take img_ids: drop it whether it arrived as a
            # keyword or as an extra positional argument.
            kwargs.pop("img_ids", None)
            args = args[:len(params)]
        return parent(*args, **kwargs)

## Evaluate Model

In [None]:
# Build the COCO-style rotated evaluator from the dataset in detectron2's default
# dict format; image ids in predictions and ground truth must overlap.
# NOTE(review): this is the same dataset as cfg.DATASETS.TRAIN, i.e. the model is
# scored on its training data — use a held-out split once one is registered.
eval_dataset_name = "CFH_train_average_padded_rotated"
evaluator = MyRotatedCOCOEvaluator(eval_dataset_name, cfg, distributed=False, output_dir=cfg.OUTPUT_DIR)
val_loader = build_detection_test_loader(cfg, eval_dataset_name, mapper=mapper)
outputs = inference_on_dataset(trainer.model, val_loader, evaluator)

# Visualize Model Output and Performance

In [None]:
# Show predictions on up to three random records.
# NOTE(review): `test_dict` and `predictor` are not defined in the visible cells;
# `predictor` is presumably `RotatedPredictor(cfg)` — confirm before running.
for d in random.sample(test_dict, min(3, len(test_dict))):
    im = cv2.imread(d["file_name"])  # cv2 loads images in BGR order
    outputs = predictor(im)  # format is documented at https://detectron2.readthedocs.io/tutorials/models.html#model-output-format
    # The Visualizer takes an RGB image. Use the training metadata so labels
    # show the class name ("fiber") rather than an index.
    v = myVisualizer(im[:, :, ::-1], metadata=metadata_train, scale=0.5)
    out = v.draw_instance_predictions(outputs["instances"].to("cpu"))
    # out.get_image() is already RGB, which matplotlib expects; flipping the
    # channels again would display the image with red and blue swapped.
    plt.imshow(out.get_image())
    plt.show()