# In [None]:
import detectron2
from detectron2.utils.logger import setup_logger
setup_logger()


import numpy as np
import os, json, cv2, random
import matplotlib.pyplot as plt


from detectron2 import model_zoo
from detectron2.engine import DefaultPredictor, DefaultTrainer
from detectron2.config import get_cfg
from detectron2.utils.visualizer import Visualizer, ColorMode
from detectron2.data import MetadataCatalog, DatasetCatalog

# One-time data preparation: skip this cell when the setup has already been completed.
import os

# Move into ../lib before importing — presumably so the helper modules there
# resolve from the working directory (confirm against how the notebook is launched).
os.chdir('../lib')
import importlib
import setup_coco_json
import process_coco_json

# Reload both helper modules so the names imported below come from their current source.
importlib.reload(setup_coco_json)
importlib.reload(process_coco_json)

from setup_coco_json import setup_rgb, setup_grayscale, setup_rgbd
from process_coco_json import get_coco_rgb, get_coco_grayscale, get_coco_rgbd

# The relative paths handed to the helpers below are resolved from ../data.
os.chdir('../data')
setup_rgb(
    './useable_data',
    coco_json_dir='./coco_json',
    per_train=70,  # NOTE(review): the three per_* values look like a percentage split — confirm in setup_coco_json
    per_val=15,
    per_test=15,
)
get_coco_rgb("./coco_json/rgb/")

from detectron2.data.datasets import register_coco_instances

# Register the train/val/test splits as COCO-format datasets.
# Each split keeps its annotation file (<split>.json) inside its own images folder.
for split_name in ("train", "val", "test"):
    dataset_name = f"my_dataset_{split_name}"
    image_root = f"../data/coco_json/rgb/{split_name}/images/"
    # detectron2 refuses to register the same name twice, so re-running this
    # cell in a live kernel would fail without this guard.
    if dataset_name not in DatasetCatalog.list():
        register_coco_instances(dataset_name, {}, f"{image_root}{split_name}.json", image_root)

# Metadata and the loaded list of per-image dicts for each split.
train_metadata = MetadataCatalog.get("my_dataset_train")
train_dataset_dicts = DatasetCatalog.get("my_dataset_train")

val_metadata = MetadataCatalog.get("my_dataset_val")
val_dataset_dicts = DatasetCatalog.get("my_dataset_val")

test_metadata = MetadataCatalog.get("my_dataset_test")
test_dataset_dicts = DatasetCatalog.get("my_dataset_test")

mask_folder = "../data/coco_json/rgb/train/masks/Tumor"  # Update with your binary mask folder path

# Sanity check: show one random training image with its annotations next to its binary mask.
# min(...) keeps random.sample from raising when the dataset is empty.
for d in random.sample(train_dataset_dicts, min(1, len(train_dataset_dicts))):
    print(f"Image file: {d['file_name']}")

    img = cv2.imread(d["file_name"])
    if img is None:
        print(f"Failed to load image: {d['file_name']}")
        continue

    # cv2 loads images as BGR, while Visualizer expects RGB, hence the channel flip.
    visualizer = Visualizer(img[:, :, ::-1], metadata=train_metadata, scale=0.5)
    vis = visualizer.draw_dataset_dict(d)

    # Left panel: annotated image. get_image() is already RGB, which is what
    # matplotlib expects, so it must not be flipped back to BGR here.
    plt.subplot(1, 2, 1)  # 1 row, 2 columns, first subplot
    plt.imshow(vis.get_image())
    plt.title("Image with Annotations")

    # Masks are assumed to share the image's base name with a .png extension.
    # splitext strips only the final extension, so names containing dots survive.
    image_base_name = os.path.splitext(os.path.basename(d["file_name"]))[0]
    mask_path = os.path.join(mask_folder, f"{image_base_name}.png")
    print(f"Mask file: {mask_path}")

    mask_img = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
    if mask_img is None:
        # Still fall through to plt.show() so the half-built figure is not left dangling.
        print(f"Failed to load mask: {mask_path}")
    else:
        # Right panel: binary mask.
        plt.subplot(1, 2, 2)  # 1 row, 2 columns, second subplot
        plt.imshow(mask_img, cmap="gray")
        plt.title("Binary Mask")

    plt.show()  # Display the panels side by side

# Custom segmentation overlap loss function
def compute_segmentation_overlap_loss(pred_masks, target_masks, eps=1e-6):
    """Return ``1 - IoU`` between two masks.

    Both inputs are interpreted element-wise as booleans (non-zero is
    foreground) and must be broadcastable against each other.

    Args:
        pred_masks: Array-like predicted mask(s).
        target_masks: Array-like ground-truth mask(s).
        eps: Small constant added to the union area in the denominator.

    Returns:
        A float in ``[0, 1]``: close to 0 for identical masks, 1 for disjoint
        masks. When both masks are entirely empty they agree perfectly, so
        the loss is 0.0.
    """
    intersection_area = np.sum(np.logical_and(pred_masks, target_masks))
    union_area = np.sum(np.logical_or(pred_masks, target_masks))

    # Without this guard an empty/empty pair would evaluate to 0 / eps == 0 IoU,
    # i.e. the maximum loss for a perfect prediction.
    if union_area == 0:
        return 0.0

    iou = intersection_area / (union_area + eps)
    return float(1 - iou)

# Trainer subclass that adds an IoU-based overlap term to the loss dictionary
class CustomTrainer(DefaultTrainer):
    """DefaultTrainer variant whose ``build_loss`` appends a segmentation overlap term.

    NOTE(review): nothing in this file calls ``build_loss``, and it is unclear
    whether ``DefaultTrainer`` defines or invokes such a hook — confirm that
    this method actually participates in training.
    """

    def build_loss(self, outputs, targets):
        """Extend the parent's losses with ``segmentation_overlap_loss``.

        Args:
            outputs: Mapping expected to hold predicted masks under
                ``outputs['instances'].pred_masks``.
            targets: Mapping expected to hold ground-truth masks under
                ``targets['instances'].gt_masks``.

        Returns:
            The mapping returned by the parent ``build_loss`` with one extra
            entry, ``'segmentation_overlap_loss'``.
        """
        loss_dict = super().build_loss(outputs, targets)

        # The overlap helper works on NumPy arrays, so both mask tensors are
        # moved to the CPU and converted first.
        # NOTE(review): a NumPy value carries no gradient — verify this term is
        # meant for monitoring only rather than for back-propagation.
        predicted = outputs['instances'].pred_masks.cpu().numpy()
        ground_truth = targets['instances'].gt_masks.cpu().numpy()

        loss_dict['segmentation_overlap_loss'] = compute_segmentation_overlap_loss(
            predicted, ground_truth
        )
        return loss_dict

# ---- Model: Mask R-CNN (R50-FPN, 3x schedule) from the model zoo ----
cfg = get_cfg()
cfg.merge_from_file(model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml"))
# Start training from the matching model-zoo checkpoint.
cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml")
cfg.MODEL.ROI_HEADS.NUM_CLASSES = 1  # single class: tumor
cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = 128  # RoI head batch size (default: 512); smaller is faster

# ---- Data ----
cfg.DATASETS.TRAIN = ("my_dataset_train",)
cfg.DATASETS.TEST = ()  # no evaluation dataset during training
cfg.DATALOADER.NUM_WORKERS = 2

# ---- Solver ----
cfg.SOLVER.IMS_PER_BATCH = 2  # images per iteration, i.e. the usual "batch size"
cfg.SOLVER.BASE_LR = 0.00025
cfg.SOLVER.MAX_ITER = 300  # short run; a practical dataset will need longer training
cfg.SOLVER.STEPS = []  # empty list: the learning rate is not decayed

os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)

# ---- Train with the custom trainer, starting fresh rather than resuming ----
trainer = CustomTrainer(cfg)
trainer.resume_or_load(resume=False)
trainer.train()

# ---- Inference settings: use the weights written by the run above ----
cfg.MODEL.WEIGHTS = os.path.join(cfg.OUTPUT_DIR, "model_final.pth")
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5  # custom score threshold at test time

def display_original_and_prediction_with_annotations(val_dataset_dicts, predictor, val_metadata):
    """Show one random sample's ground truth next to the model's prediction.

    Args:
        val_dataset_dicts: List of dataset dicts; each needs a ``"file_name"``
            entry and whatever ``Visualizer.draw_dataset_dict`` consumes.
        predictor: Callable that takes the BGR image loaded by ``cv2.imread``
            and returns a dict with an ``"instances"`` entry.
        val_metadata: Dataset metadata handed to ``Visualizer``.

    Samples whose image cannot be read are reported and skipped. An empty
    dataset results in no output.
    """
    # min(...) keeps random.sample from raising on an empty dataset.
    for d in random.sample(val_dataset_dicts, min(1, len(val_dataset_dicts))):
        im = cv2.imread(d["file_name"])
        if im is None:
            print(f"Failed to load image: {d['file_name']}")
            continue

        outputs = predictor(im)

        # cv2 loads BGR; Visualizer expects RGB.
        rgb_image = im[:, :, ::-1]

        # Ground-truth annotations drawn on the original image.
        v_gt = Visualizer(rgb_image, metadata=val_metadata, scale=0.5)
        out_gt = v_gt.draw_dataset_dict(d)

        # Predictions; IMAGE_BW removes the colors of unsegmented pixels.
        v_pred = Visualizer(
            rgb_image,
            metadata=val_metadata,
            scale=0.5,
            instance_mode=ColorMode.IMAGE_BW,
        )
        out_pred = v_pred.draw_instance_predictions(outputs["instances"].to("cpu"))

        fig, ax = plt.subplots(1, 2, figsize=(30, 30))

        # get_image() is already RGB, which matplotlib expects; flipping it
        # back to BGR would swap the red and blue channels on screen.
        ax[0].imshow(out_gt.get_image())
        ax[0].set_title('Original Image with Annotations', fontsize=24)

        ax[1].imshow(out_pred.get_image())
        ax[1].set_title('Predicted Image', fontsize=24)

        for a in ax:
            a.axis("off")

        plt.show()

# Build a predictor from the current config and preview one random test-set sample
predictor = DefaultPredictor(cfg)
display_original_and_prediction_with_annotations(
    val_dataset_dicts=test_dataset_dicts,
    predictor=predictor,
    val_metadata=test_metadata,
)

def save_all_images(val_dataset_dicts, predictor, val_metadata, output_dir):
    """Write a ground-truth vs. prediction comparison figure for every sample.

    Args:
        val_dataset_dicts: List of dataset dicts; each needs a ``"file_name"``
            entry and whatever ``Visualizer.draw_dataset_dict`` consumes.
        predictor: Callable that takes the BGR image loaded by ``cv2.imread``
            and returns a dict with an ``"instances"`` entry.
        val_metadata: Dataset metadata handed to ``Visualizer``.
        output_dir: Directory receiving ``comparison_<index>.png`` files; it
            is created when missing.

    Samples whose image cannot be read are reported and skipped. Their index
    is not reused, so file numbers always match dataset positions.
    """
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(output_dir, exist_ok=True)

    for i, d in enumerate(val_dataset_dicts):
        im = cv2.imread(d["file_name"])
        if im is None:
            print(f"Failed to load image: {d['file_name']}")
            continue

        outputs = predictor(im)

        # cv2 loads BGR; Visualizer expects RGB.
        rgb_image = im[:, :, ::-1]

        v_gt = Visualizer(rgb_image, metadata=val_metadata, scale=0.5)
        out_gt = v_gt.draw_dataset_dict(d)

        v_pred = Visualizer(rgb_image, metadata=val_metadata, scale=0.5, instance_mode=ColorMode.IMAGE_BW)
        out_pred = v_pred.draw_instance_predictions(outputs["instances"].to("cpu"))

        fig, ax = plt.subplots(1, 2, figsize=(30, 30))

        # get_image() is already RGB, which matplotlib expects; flipping it
        # back to BGR would swap the red and blue channels in the saved file.
        ax[0].imshow(out_gt.get_image())
        ax[0].set_title('Original Image with Annotations', fontsize=24)

        ax[1].imshow(out_pred.get_image())
        ax[1].set_title('Predicted Image', fontsize=24)

        for a in ax:
            a.axis("off")

        # Save and close this specific figure so figures do not pile up in memory.
        save_path = os.path.join(output_dir, f"comparison_{i}.png")
        fig.savefig(save_path)
        plt.close(fig)

        print(f"Saved comparison image to {save_path}")

 save all validation set images
output_dir = "./output_images"
save_all_images(test_dataset_dicts, predictor, test_metadata, output_dir)
