In [None]:
import os, json

import cv2
import torch
import numpy as np

In [None]:
def cv2_imshow(img):
    """Render an OpenCV image inline on a fresh 10x10-inch matplotlib figure.

    Args:
        img: Image array in BGR channel order (as produced by ``cv2.imread``);
            it is converted to RGB before being handed to matplotlib.
    """
    from matplotlib import pyplot

    pyplot.figure(figsize=(10, 10))
    rgb_view = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    pyplot.imshow(rgb_view)

In [None]:
# original_dict (parsed via_region_data.json) : {
#     img_name  : {
#         'filename'
#         'regions' : {
#             'id'  : {
#                 'shape_attributes' : {
#                     'all_points_x'
#                     'all_points_y'
                    
#                 }
#             }
#         }
#     }
# }

# Required (list of records returned by get_balloon_dicts) = [
#     {
#         'file_name'
#         'image_id'
#         'height'
#         'width'
#         'annotations' = [
#             {
#                 'bbox'
#                 'bbox_mode'
#                 'segmentation'
#                 'category_id'
#             }
#         ]
#     }
# ]

In [None]:
from detectron2.structures import BoxMode

In [None]:
def get_balloon_dicts(IMG_DIR):
    """Convert a VIA-annotated image folder into a list of dataset records.

    Reads ``via_region_data.json`` from ``IMG_DIR`` and builds one record per
    annotated image. Every polygon region of an image becomes one annotation
    with ``category_id`` 0.

    Args:
        IMG_DIR: Folder containing the images and their
            ``via_region_data.json`` annotation file.

    Returns:
        A list of dicts with the keys ``file_name``, ``image_id``, ``height``,
        ``width`` and ``annotations``. Each annotation holds ``bbox``
        (``[x_min, y_min, x_max, y_max]``), ``bbox_mode``
        (``BoxMode.XYXY_ABS``), ``segmentation`` (a list with one flat
        ``[x0, y0, x1, y1, ...]`` polygon) and ``category_id``.

    Raises:
        FileNotFoundError: If the annotation file is missing, or an image
            listed in it is missing or cannot be decoded.
    """
    with open(os.path.join(IMG_DIR, 'via_region_data.json')) as file:
        original_dict = json.load(file)

    dataset = []
    # loop over images
    for idx, blob in enumerate(original_dict.values()):
        filename = os.path.join(IMG_DIR, blob['filename'])

        # cv2.imread signals failure by returning None instead of raising,
        # so fail here with the offending path rather than on `.shape`.
        image = cv2.imread(filename)
        if image is None:
            raise FileNotFoundError(
                f"Image is missing or not readable: {filename}")
        height, width = image.shape[:2]

        # 'regions' is a dict keyed by region id in some VIA exports and a
        # plain list in others; only the region entries themselves are needed.
        regions = blob['regions']
        if isinstance(regions, dict):
            regions = regions.values()

        objects = []
        # loop over objects in an image
        for region in regions:
            shape = region["shape_attributes"]
            px = shape["all_points_x"]
            py = shape["all_points_y"]

            # Flat [x0, y0, x1, y1, ...] polygon with every vertex shifted by
            # half a pixel (presumably to address pixel centres).
            poly = [
                coord
                for x, y in zip(px, py)
                for coord in (x + 0.5, y + 0.5)
            ]

            objects.append({
                # Built-in min/max keep the values plain Python numbers, so
                # the record stays JSON-serializable.
                "bbox": [min(px), min(py), max(px), max(py)],
                "bbox_mode": BoxMode.XYXY_ABS,
                "segmentation": [poly],
                "category_id": 0,
            })

        dataset.append(dict(
            file_name   = filename,
            image_id    = idx,
            height      = height,
            width       = width,
            annotations = objects,
        ))
    return dataset

In [None]:
# Root folder of the balloon dataset; the cells below read its 'train' and 'val'
# sub-folders. Set the BALLOON_DATASET_DIR environment variable to point at a
# different location instead of editing this machine-specific default.
DIR = os.environ.get(
    'BALLOON_DATASET_DIR',
    '/home/l-ashwin/Datasets/balloon_dataset/balloon/',
)

In [None]:
from detectron2.data import DatasetCatalog, MetadataCatalog

In [None]:
# Register both splits with Detectron2's catalogs.
for cat in ['train', 'val']:
    dataset_name = f'balloon_{cat}'
    # Bind the split through a default argument so each loader keeps its own
    # value of `cat`; a plain closure would see only the last loop value.
    func = lambda x=cat:get_balloon_dicts(os.path.join(DIR, x))
    # Registering an already-registered name raises, so skip it when this cell
    # is re-run. The loader looks up get_balloon_dicts and DIR at call time, so
    # the existing registration still picks up later redefinitions of either.
    if dataset_name not in DatasetCatalog.list():
        DatasetCatalog.register(dataset_name, func)
    MetadataCatalog.get(dataset_name).set(thing_classes=["balloon"])

In [None]:
# Metadata of the training split (thing_classes=["balloon"] was set at
# registration); kept for the visualisation cells and displayed for inspection.
balloon_metadata = MetadataCatalog.get('balloon_train')
balloon_metadata

In [None]:
# Display the catalog to check that the balloon datasets are registered.
DatasetCatalog

In [None]:
# for cat in ['train', 'val']:
#     DatasetCatalog.remove(f'balloon_{cat}')
#     MetadataCatalog.remove(f'balloon_{cat}')

In [None]:
# Build the training records directly with the same loader that backs the
# registered 'balloon_train' dataset.
dataset_train = get_balloon_dicts(os.path.join(DIR, 'train'))

In [None]:
# Pick one training record at random (no seed is set, so the sample changes
# between runs) and load its image from disk.
record = np.random.choice(dataset_train)
img = cv2.imread(record['file_name'])

In [None]:
# Show the sampled image without annotations.
cv2_imshow(img)

In [None]:
from detectron2.utils.visualizer import Visualizer

In [None]:
# [:, :, ::-1] reverses the channel axis: the cv2-loaded image is flipped before
# being given to the Visualizer, and the rendered result is flipped back so that
# cv2_imshow (which performs its own BGR -> RGB conversion) shows correct colours.
visualizer = Visualizer(img[:,:,::-1], metadata=balloon_metadata)
out        = visualizer.draw_dataset_dict(record).get_image()[:,:,::-1]

In [None]:
# Show the same image with the record's annotations drawn on top.
cv2_imshow(out)

In [None]:
from detectron2.engine import DefaultTrainer
from detectron2.config import get_cfg
from detectron2 import model_zoo

In [None]:
# Start from Detectron2's default config and overlay the model-zoo settings for
# the COCO instance-segmentation Mask R-CNN (R_50_FPN_3x) model.
cfg = get_cfg()
cfg.merge_from_file(model_zoo.get_config_file(
    "COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml"))

In [None]:
# Train on the registered balloon training split; no test dataset is configured.
cfg.DATASETS.TRAIN = ("balloon_train",)
cfg.DATASETS.TEST = ()
cfg.DATALOADER.NUM_WORKERS = 2  # number of data-loading workers

In [None]:
# Initialise from the model-zoo checkpoint published for the same config file
# (fine-tuning rather than training from scratch).
cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url(
    "COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml")

In [None]:
# Optimiser schedule for a short fine-tuning run.
cfg.SOLVER.IMS_PER_BATCH = 2  # images per training batch
cfg.SOLVER.BASE_LR = 0.00025  # base learning rate
cfg.SOLVER.MAX_ITER = 300     # total training iterations
cfg.SOLVER.STEPS = []  # empty list of LR-decay steps, i.e. presumably no learning-rate decay -- confirm against the config reference

In [None]:
# ROI-head settings for the single-class balloon dataset.
cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = 128   # ROI batch size per image (see Detectron2 config reference)
cfg.MODEL.ROI_HEADS.NUM_CLASSES = 1  # one class, matching thing_classes=["balloon"]

In [None]:
# Create the output folder if needed, build the trainer from the config and
# run training. resume=False requests a fresh start rather than resuming a
# previous run (see DefaultTrainer.resume_or_load for the exact semantics).
os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)
trainer = DefaultTrainer(cfg) 
trainer.resume_or_load(resume=False)
trainer.train()