# In [None]:
# ===== Imports =====
from functools import partial
import os
import cv2
import numpy as np
import torch
from detectron2.data import DatasetCatalog, MetadataCatalog
from detectron2.config import get_cfg
from detectron2.engine import DefaultTrainer, DefaultPredictor
from detectron2.utils.visualizer import Visualizer, ColorMode
from detectron2.structures import BoxMode
import pycocotools.mask as mask_util
from detectron2 import model_zoo
import matplotlib.pyplot as plt

# ===== Object Classes =====
# Contiguous class ids 0..N-1 mapped to display names. Each name carries its
# id as a two-digit prefix, so the names sort in id order.
OBJ_CLASSES = dict(enumerate([
    '00_person',
    '01_table',
    '02_leg',
    '03_tool1',
    '04_tool2',
]))

# ===== Helper Functions =====
def create_sub_masks(mask_image, color_map=None):
    """
    Convert a colour-coded mask image into class-wise binary masks.

    Args:
        mask_image: H x W x C array. Each pixel is compared channel-wise
            against the entries of ``color_map``. With the default palette the
            channels must be in RGB order. An image loaded with ``cv2.imread``
            is BGR and must be converted first, e.g.
            ``cv2.cvtColor(img, cv2.COLOR_BGR2RGB)``.
        color_map: Optional sequence of colour tuples. The position of each
            tuple is the class id it encodes. Defaults to the five RGB colours
            used for this dataset.

    Returns:
        dict mapping class id -> uint8 array of shape (H, W). Pixels that
        exactly match that class's colour are 255; all others are 0. Classes
        with no matching pixel are omitted.

    Raises:
        ValueError: if ``mask_image`` is None (e.g. a failed ``cv2.imread``).
    """
    if mask_image is None:
        raise ValueError("mask_image is None; was the mask file readable?")
    if color_map is None:
        # RGB colour per class id (list index == class id).
        color_map = [(0, 128, 128), (128, 0, 0), (128, 128, 0), (0, 128, 0), (0, 0, 128)]

    image = np.asarray(mask_image)
    sub_masks = {}
    for class_id, color in enumerate(color_map):
        # Exact colour match on every channel. This is the same result as
        # cv2.inRange(image, color, color): uint8, 255 where matched, else 0.
        matches = np.all(image == np.asarray(color), axis=-1)
        if matches.any():
            sub_masks[class_id] = matches.astype(np.uint8) * np.uint8(255)
    return sub_masks

def _contour_annotations(class_id, mask):
    """Turn one class's binary mask into detectron2 polygon annotations, one per outer contour."""
    annotations = []
    contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    for contour in contours:
        polygon = contour.flatten().tolist()
        # A valid polygon needs at least 3 points, i.e. >= 6 flattened coordinates.
        if len(polygon) < 6:
            continue
        x, y, w, h = cv2.boundingRect(contour)
        annotations.append({
            "bbox": [x, y, x + w, y + h],
            "bbox_mode": BoxMode.XYXY_ABS,
            "category_id": class_id,
            "segmentation": [polygon],
        })
    return annotations


def dexycb_hand_seg_func_mivos(num_samples=-1, dir='dataset'):
    """
    Build detectron2 dataset dicts from an images/masks directory pair.

    Layout expected under ``dir``:
        images/<stem>.jpg|.png   -- the colour images (sorted by file name)
        masks/<stem>_mask.png    -- colour-coded class mask for each image

    Args:
        num_samples: if > 0, use only the first ``num_samples`` images.
        dir: dataset root. The name shadows the builtin but is kept because
            callers pass it by keyword.

    Returns:
        list of dicts with ``file_name``, ``height``, ``width``, ``image_id``
        (the image's index in sorted order) and ``annotations``. There is one
        polygon annotation per outer contour of each class present in the mask.

    Raises:
        FileNotFoundError: if an image or its mask is missing or unreadable.
    """
    images_dir = os.path.join(dir, 'images')
    masks_dir = os.path.join(dir, 'masks')
    image_files = sorted(f for f in os.listdir(images_dir) if f.endswith(('.jpg', '.png')))
    if num_samples > 0:
        image_files = image_files[:num_samples]

    records = []
    for image_id, img_file in enumerate(image_files):
        color_file = os.path.join(images_dir, img_file)
        mask_file = os.path.join(masks_dir, os.path.splitext(img_file)[0] + '_mask.png')

        # cv2.imread returns None instead of raising, so check explicitly
        # (an assert would be stripped under `python -O`).
        img = cv2.imread(color_file)
        if img is None:
            raise FileNotFoundError(f"{color_file} missing or unreadable")
        seg_img = cv2.imread(mask_file)
        if seg_img is None:
            raise FileNotFoundError(f"{mask_file} missing or unreadable")
        height, width = img.shape[:2]

        # cv2.imread yields BGR pixels, while the create_sub_masks palette is RGB.
        # Without this conversion, classes with asymmetric colours get swapped.
        sub_masks = create_sub_masks(cv2.cvtColor(seg_img, cv2.COLOR_BGR2RGB))

        annotations = []
        for class_id, mask in sub_masks.items():
            annotations.extend(_contour_annotations(class_id, mask))

        records.append({
            "file_name": color_file,
            "height": height,
            "width": width,
            "image_id": image_id,
            "annotations": annotations,
        })
    return records
 

# ===== Register Dataset =====
# detectron2 refuses to register the same dataset name twice, so re-running
# this cell would fail. Drop any earlier registration, and its metadata, first.
if "Mivos" in DatasetCatalog.list():
    DatasetCatalog.remove("Mivos")
    MetadataCatalog.remove("Mivos")
DatasetCatalog.register("Mivos", partial(dexycb_hand_seg_func_mivos, num_samples=-1, dir='dataset'))
meta = MetadataCatalog.get("Mivos")
# Class names listed in id order, so category_id indexes into this list.
meta.thing_classes = [OBJ_CLASSES[i] for i in range(len(OBJ_CLASSES))]

# ===== Configure Mask R-CNN =====
# Fine-tune the COCO-pretrained Mask R-CNN R50-FPN 3x baseline on "Mivos".
_ZOO_CONFIG = "COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml"

cfg = get_cfg()
cfg.merge_from_file(model_zoo.get_config_file(_ZOO_CONFIG))

# Model: start from the zoo checkpoint and size the ROI heads to our classes.
cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url(_ZOO_CONFIG)
cfg.MODEL.ROI_HEADS.NUM_CLASSES = len(OBJ_CLASSES)
cfg.MODEL.DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Data: training split only, and keep images that have no annotations.
cfg.DATASETS.TRAIN = ("Mivos",)
cfg.DATASETS.TEST = ()
cfg.DATALOADER.NUM_WORKERS = 2
cfg.DATALOADER.FILTER_EMPTY_ANNOTATIONS = False

# Solver schedule.
cfg.SOLVER.IMS_PER_BATCH = 2
cfg.SOLVER.BASE_LR = 0.00025
cfg.SOLVER.MAX_ITER = 500

cfg.OUTPUT_DIR = "./output_mivos"
os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)

# ===== Train Mask R-CNN =====
# resume=False: initialise from cfg.MODEL.WEIGHTS instead of any checkpoint
# already present in cfg.OUTPUT_DIR.
trainer = DefaultTrainer(cfg)
trainer.resume_or_load(resume=False)
trainer.train()

# ===== Inference and Save Masks =====
# Reload the fine-tuned weights and keep only detections scoring >= 0.5.
cfg.MODEL.WEIGHTS = os.path.join(cfg.OUTPUT_DIR, "model_final.pth")
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5
predictor = DefaultPredictor(cfg)

output_masks_dir = os.path.join(cfg.OUTPUT_DIR, "predicted_masks")
os.makedirs(output_masks_dir, exist_ok=True)

dataset_dicts = dexycb_hand_seg_func_mivos(dir='dataset')

for d in dataset_dicts:
    img = cv2.imread(d['file_name'])
    if img is None:  # cv2.imread signals failure with None, not an exception
        raise FileNotFoundError(f"could not read {d['file_name']}")
    stem = os.path.splitext(os.path.basename(d['file_name']))[0]

    outputs = predictor(img)
    instances = outputs["instances"].to("cpu")  # move to CPU once, reuse below
    masks = instances.pred_masks.numpy()
    classes = instances.pred_classes.numpy()
    scores = instances.scores.numpy()

    # One binary PNG per predicted instance (255 = instance, 0 = background).
    for idx, mask in enumerate(masks):
        save_path = os.path.join(output_masks_dir, f"{stem}_mask{idx}.png")
        cv2.imwrite(save_path, mask.astype(np.uint8) * 255)

    # Combined label map, so the predicted classes are not lost. Pixel value is
    # class_id + 1 (0 = background). Instances are painted in ascending score
    # order, so the highest-scoring one wins where instances overlap.
    label_map = np.zeros(masks.shape[1:], dtype=np.uint8)
    for k in np.argsort(scores, kind="stable"):
        label_map[masks[k]] = classes[k] + 1
    cv2.imwrite(os.path.join(output_masks_dir, f"{stem}_labels.png"), label_map)

    # Visualize. Visualizer takes and returns RGB, which is what plt.imshow
    # expects, so the rendered image is shown without another channel flip.
    v = Visualizer(img[:, :, ::-1], metadata=meta, scale=0.8, instance_mode=ColorMode.IMAGE_BW)
    v = v.draw_instance_predictions(instances)
    fig = plt.figure(figsize=(8, 6))
    plt.imshow(v.get_image())
    plt.axis('off')
    plt.show()
    plt.close(fig)  # avoid accumulating figures on non-interactive backends