# Importing Necessary Libraries

In [None]:
import detectron2
import contextlib
import datetime
import io
import os
import json
import logging
import cv2
import random
import numpy as np
import copy,torch,torchvision
import PIL
from PIL import Image
import xml.etree.ElementTree as X
import math
from itertools import repeat
import re
import shutil
import io
import ast

from fvcore.common.file_io import PathManager
from fvcore.common.timer import Timer

from detectron2.structures import Boxes, BoxMode, PolygonMasks
from detectron2.config import *
from detectron2.modeling import build_model
from detectron2 import model_zoo
from detectron2.data import transforms as T
from detectron2.data import detection_utils as utils
from detectron2.data import DatasetCatalog, MetadataCatalog, build_detection_test_loader, build_detection_train_loader
from detectron2.data.datasets import register_coco_instances
from detectron2.engine import DefaultTrainer, DefaultPredictor
from detectron2.checkpoint import DetectionCheckpointer
from detectron2.evaluation import RotatedCOCOEvaluator,DatasetEvaluators, inference_on_dataset, coco_evaluation,DatasetEvaluator
from detectron2.utils.logger import setup_logger
from detectron2.utils.visualizer import Visualizer

import matplotlib.pyplot as plt
from platform import python_version

import glob
import time
import shutil
from multiprocessing.pool import ThreadPool
import concurrent.futures

import torch
# Make CUDA device 0 the current device for this process
# (presumably fails on a CPU-only machine — this notebook expects a GPU).
torch.cuda.set_device(0)

from torch.utils.cpp_extension import CUDA_HOME
# Sanity check: is CUDA usable, and which CUDA toolkit location did
# torch.utils.cpp_extension detect (None if none was found)?
print(torch.cuda.is_available(), CUDA_HOME)

# Configure detectron2's logger so training/evaluation progress is printed.
setup_logger()

# Custom Functions for Preparing the Training Set

In [None]:
def get_rbbox(mask):
    """Return the minimum-area rotated rectangle enclosing the foreground of a mask.

    Args:
        mask: 2-D ``uint8`` image whose non-zero pixels belong to the object
            (``cv2.findContours`` needs an 8-bit single-channel input).

    Returns:
        The ``cv2.minAreaRect`` result ``((cx, cy), (w, h), angle)``.

    Raises:
        ValueError: if the mask contains no foreground pixels.
    """
    # OpenCV 2.x/4.x return (contours, hierarchy) while OpenCV 3.x returns
    # (image, contours, hierarchy); the contour list is always second to last.
    contours = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]
    if len(contours) == 0:
        raise ValueError("get_rbbox: mask has no foreground pixels")
    # One instance may rasterise into several disconnected blobs; fit the
    # rectangle to the outline points of all of them rather than to an
    # arbitrary first contour. For a single blob this is the same rectangle.
    points = np.concatenate(contours, axis=0)
    return cv2.minAreaRect(points)



def _rbbox_annotation(annotation, img_height, img_width):
    """Convert one polygon annotation into a rotated-box annotation (None if its mask is empty)."""
    mask = detectron2.structures.polygons_to_bitmask(annotation['segmentation'], img_height, img_width)
    if not mask.any():
        # Degenerate polygon (e.g. zero area): there is no contour to fit a box to.
        return None
    (cent_x, cent_y), (w, h), angle = get_rbbox((255 * mask).astype('uint8'))
    return {
        # detectron2's XYWHA_ABS angle is in degrees, counter-clockwise; negating
        # the cv2.minAreaRect angle was found to work best empirically.
        # NOTE(review): OpenCV >= 4.5.1 changed the minAreaRect angle range —
        # re-check this sign convention if OpenCV is upgraded.
        "bbox": [cent_x, cent_y, w, h, -angle],
        "bbox_mode": BoxMode.XYWHA_ABS,
        "category_id": annotation['category_id'],
    }


def make_rbbox_cotton_dicts(Train_data_path, image_id = 1):
    """Build detectron2 dataset dicts with rotated boxes for the cotton-fiber data.

    Each image dict from ``make_seg_cotton_dicts(Train_data_path)`` is copied
    (``file_name``, ``height``, ``width``, ``fr_name``), and every polygon
    annotation is replaced by the minimum-area rotated rectangle around its
    rasterised mask, stored as ``[cx, cy, w, h, angle]`` in ``BoxMode.XYWHA_ABS``.
    Annotations whose polygon rasterises to an empty mask are dropped.

    Args:
        Train_data_path: path passed to ``make_seg_cotton_dicts``. If it contains
            the substring ``'train'`` every image is kept; otherwise only images
            whose ``fr_name`` contains ``'aug'`` are kept.
        image_id: id given to the first kept image; later kept images get
            consecutive ids.

    Returns:
        List of dicts with keys ``file_name``, ``height``, ``width``,
        ``image_id``, ``fr_name`` and ``annotations``.
    """
    keep_all_frames = 'train' in Train_data_path

    dataset_list = []
    for file in make_seg_cotton_dicts(Train_data_path):
        frame_name = file['fr_name']
        # For paths without 'train' only augmented frames are used; filter them
        # out before doing any (expensive) mask rasterisation.
        if not (keep_all_frames or 'aug' in frame_name):
            continue

        img_height = file['height']
        img_width = file['width']

        annotations = []
        for anno in file['annotations']:
            rbbox_anno = _rbbox_annotation(anno, img_height, img_width)
            if rbbox_anno is not None:
                annotations.append(rbbox_anno)

        dataset_list.append({
            "file_name": file['file_name'],
            "height": img_height,
            "width": img_width,
            "image_id": image_id,
            "fr_name": frame_name,
            "annotations": annotations,
        })
        image_id += 1

    return dataset_list

In [None]:
# Locations of the cotton-fiber (CFH) data.
Base_path = 'Cotton Fiber Project'
Train_data_path = 'train_average'

# Rotated-box dataset dicts for the training split.
# NOTE(review): Train_data_path is used relative to the working directory here,
# while the DatasetCatalog registration joins the split name onto Base_path —
# confirm which location is intended.
rbbox_train_dicts = make_rbbox_cotton_dicts(Train_data_path=Train_data_path)

In [None]:
# Register each split as a lazily built dataset plus its class metadata.
# Add "val" / "test" to the tuple to register those splits as well.
# NOTE(review): this registers the polygon dicts from make_seg_cotton_dicts,
# while the custom mapper builds rotated-box instances — confirm whether
# make_rbbox_cotton_dicts should be registered instead.
for d in ("train_average",):
    DatasetCatalog.register(
        f"CFH_{d}",
        # The split is bound as a default argument so each registered loader
        # keeps its own split (avoids Python's late-binding closure pitfall).
        lambda split=d: make_seg_cotton_dicts(os.path.join(Base_path, split)),
    )
    MetadataCatalog.get(f"CFH_{d}").thing_classes = ["fiber"]

In [None]:
# Metadata (e.g. thing_classes) of the registered training split.
metadata_train = MetadataCatalog.get("CFH_train_average")

# Custom Dataset Mapper

In [None]:
def my_transform_instance_annotations(annotation, transforms, image_size, *, keypoint_hflip_indices=None):
    """Apply ``transforms`` to the box of one instance annotation, in place.

    Rotated boxes (``BoxMode.XYWHA_ABS``) go through ``transforms.apply_rotated_box``
    and keep their mode. Any other box is converted to ``XYXY_ABS`` and passed
    through ``transforms.apply_box``. It is then clipped to the image, as
    detectron2's ``transform_instance_annotations`` does.

    Args:
        annotation: dict with at least ``"bbox"`` and ``"bbox_mode"``; it is
            modified in place.
        transforms: the transform(s) that were applied to the image.
        image_size: ``(height, width)`` of the transformed image.
        keypoint_hflip_indices: unused; accepted only for signature
            compatibility with ``detection_utils.transform_instance_annotations``.

    Returns:
        The same ``annotation`` dict, with ``"bbox"`` (and possibly
        ``"bbox_mode"``) updated.
    """
    if annotation["bbox_mode"] == BoxMode.XYWHA_ABS:
        annotation["bbox"] = transforms.apply_rotated_box(np.asarray([annotation["bbox"]]))[0]
    else:
        bbox = BoxMode.convert(annotation["bbox"], annotation["bbox_mode"], BoxMode.XYXY_ABS)
        # bbox is a single 1-D box; apply_box works on an (N, 4) batch.
        bbox = transforms.apply_box(np.asarray([bbox]))[0]
        # Clip to the image: image_size is (h, w), XYXY upper bounds are (w, h, w, h).
        height, width = image_size[:2]
        annotation["bbox"] = np.minimum(bbox.clip(min=0), [width, height, width, height])
        annotation["bbox_mode"] = BoxMode.XYXY_ABS

    return annotation

def mapper(dataset_dict):
    """Turn one dataset record into model input, with rotated ground-truth boxes.

    This behaves like detectron2's default ``DatasetMapper``, with two differences:
    - Every image is resized to a fixed 800x800, so the aspect ratio is not preserved.
    - Annotations become ``Instances`` with rotated ``gt_boxes`` via
      ``annotations_to_instances_rotated``.

    Args:
        dataset_dict: one record in detectron2 dataset format. It is
            deep-copied, so the caller's dict is not modified.

    Returns:
        The copied dict with ``"image"`` set to a float32 CHW tensor (BGR).
        When the record has ``"annotations"``, ``"instances"`` is also set,
        with crowd annotations and empty boxes removed.
    """
    # Deep copy: the code below pops and mutates entries of the record.
    dataset_dict = copy.deepcopy(dataset_dict)
    image = utils.read_image(dataset_dict["file_name"], format="BGR")
    image, transforms = T.apply_transform_gens([T.Resize((800, 800))], image)
    image_shape = image.shape[:2]  # (height, width) after resizing
    dataset_dict["image"] = torch.as_tensor(image.transpose(2, 0, 1).astype("float32"))

    # Records without ground truth (e.g. for pure inference) have no
    # "annotations"; return just the image instead of failing with KeyError.
    if "annotations" not in dataset_dict:
        return dataset_dict

    annos = [
        my_transform_instance_annotations(obj, transforms, image_shape)
        for obj in dataset_dict.pop("annotations")
        if obj.get("iscrowd", 0) == 0
    ]
    instances = utils.annotations_to_instances_rotated(annos, image_shape)
    dataset_dict["instances"] = utils.filter_empty_instances(instances)
    return dataset_dict