 < Read the notes in their numbered order for a better understanding :D >

1. This notebook performs instance segmentation of chest CT scan images with a pretrained Mask R-CNN model. Each image is segmented into Normal = background, GGO = Ground Glass Opacity (red) and C = Consolidation (blue, labelled `CL` in the plots), which are the most common lesions seen in chest CT scans. The goal of this code is to mask these lesions over the images and to show the masks together with their probabilities (confidence scores).

2. To start, let's prepare the Colab environment:

In [None]:
from google.colab import drive
drive.mount('/content/gdrive')

import os
import sys

# path to the project folder on Google Drive (used as working directory and for the imports)
path_to_module = '/content/gdrive/MyDrive/Colab Notebooks/CVFOLDER/CVPROJECT'
os.chdir(path_to_module)         # same as %cd, but independent of the current directory
sys.path.append(path_to_module)  # lets Python find the local "models" package

3. Then import the libraries we need 🧑

In [None]:

# project-local Mask R-CNN implementation (importable thanks to sys.path)
import models.mask_net as mask_net
from models.mask_net.rpn_segmentation import AnchorGenerator
from models.mask_net.faster_rcnn import FastRCNNPredictor, TwoMLPHead
from models.mask_net.covid_mask_net import MaskRCNNHeads, MaskRCNNPredictor

# standard library
import copy
import os
import time

# third-party
import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision
from matplotlib.patches import Rectangle
from PIL import Image as PILImage
from torchvision import transforms
from tqdm import tqdm

4. Now it's time to read the config dict, load the pretrained model, build the Mask R-CNN model and run the segmentation. The values in the config dict are the hyperparameters, which are explained in the last step.

In [None]:
# main method
def main(config):
    if config["device"] == 'cuda' and torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')
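    confidence_threshold = config["confidence_th"]
    mask_threshold = config["mask_logits_th"]
    save_dir = config["save_dir"]         # where the segmented images are saved
    data_dir = config["test_data_dir"]    # path to the test folder
    img_dir = config["test_img_dir"]      # image sub-folder inside the test folder
    mask_type = config["mask_type"]
    rpn_nms = config["rpn_nms_th"]
    roi_nms = config["roi_nms_th"]
    truncation = config["truncation"]
    backbone_name = 'resnet50'
    # number of classes predicted by the heads, background included
    if mask_type == "both":
        n_c = 3    # background, GGO, consolidation
    else:
        n_c = 2    # background, one lesion class
    # load the pretrained checkpoint (.pth file); it also stores a pickled AnchorGenerator object,
    # so weights_only=False is needed on PyTorch >= 2.6 (only load checkpoints you trust)
    ckpt = torch.load(config["ckpt"], map_location=device, weights_only=False)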

    sizes = ckpt['anchor_generator'].sizes
    aspect_ratios = ckpt['anchor_generator'].aspect_ratios
    anchor_generator = AnchorGenerator(sizes, aspect_ratios)
    print("Anchors: ", anchor_generator.sizes, anchor_generator.aspect_ratios)

    # box head: the RoI Align output (7 x 7 cells x 256 FPN channels) is flattened and fed to a two-layer MLP
    box_head = TwoMLPHead(in_channels=7 * 7 * 256, representation_size=128)
    box_predictor = FastRCNNPredictor(in_channels=128, num_classes=n_c)
    mask_predictor = MaskRCNNPredictor(in_channels=256, dim_reduced=256, num_classes=n_c)

    # keyword arguments
    maskrcnn_args = {'num_classes': None, 'min_size': 512, 'max_size': 1024, 'box_detections_per_img': 100,
                     'box_nms_thresh': roi_nms, 'box_score_thresh': confidence_threshold, 'rpn_nms_thresh': rpn_nms,
                     'box_head': box_head, 'rpn_anchor_generator': anchor_generator, 'mask_head':None,
                     'mask_predictor': mask_predictor, 'box_predictor': box_predictor}

    # Instantiate the segmentation model (backbone weights come from the checkpoint, not from ImageNet)
    maskrcnn_model = mask_net.maskrcnn_resnet_fpn(backbone_name, truncation, pretrained_backbone=False, **maskrcnn_args)
    # Load the trained weights
    maskrcnn_model.load_state_dict(ckpt['model_weights'])
    print(maskrcnn_model)
    # Set to evaluation mode and move the model to the selected device
    maskrcnn_model.eval().to(device)

    start_time = time.time()
    # pick the class names and mask colors that match the mask type of the checkpoint
    if mask_type == "ggo":
        ct_classes = {0: '__bgr', 1: 'GGO'}
        ct_colors = {1: 'red', 'mask_cols': np.array([[255, 0, 0]])}
    elif mask_type == "merge":
        ct_classes = {0: '__bgr', 1: 'Lesion'}
        ct_colors = {1: 'red', 'mask_cols': np.array([[255, 0, 0]])}
    elif mask_type == "both":
        ct_classes = {0: '__bgr', 1: 'GGO', 2: 'CL'}
        ct_colors = {1: 'red', 2: 'blue', 'mask_cols': np.array([[255, 0, 0], [0, 0, 255]])}
    else:
        raise ValueError("mask_type must be 'ggo', 'merge' or 'both', got {!r}".format(mask_type))

    # run the inference with the provided hyperparameters on the images in the provided directory
    test_img_path = os.path.join(data_dir, img_dir)
    test_ims = os.listdir(test_img_path)
    os.makedirs(save_dir, exist_ok=True)   # make sure the output folder exists before saving

    for j, ims in enumerate(tqdm(test_ims)):
        test_step(os.path.join(test_img_path, ims), device, maskrcnn_model, confidence_threshold,
                  mask_threshold, save_dir, ct_classes, ct_colors, j)
        # NOTE: only the first image is processed; delete the next two lines to run on all images in the dir
        if True:
            break
    end_time = time.time()
    print("Inference took {0:.1f} seconds".format(end_time - start_time))
    print("Figures saved in", save_dir)
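The if/elif chain in `main` and the separate `n_c` computation both have to agree with each other and with the checkpoint. The sketch below keeps them in one place by deriving the number of classes from the class dictionary itself. It assumes a hypothetical `CLASS_MAPS` table and `get_class_maps` helper, neither of which is part of the project's code:

```python
import numpy as np

# Hypothetical lookup table mirroring the if/elif chain in main():
# mask_type -> (class names, plotting colors). Index 0 is always the background.
CLASS_MAPS = {
    "ggo": ({0: '__bgr', 1: 'GGO'},
            {1: 'red', 'mask_cols': np.array([[255, 0, 0]])}),
    "merge": ({0: '__bgr', 1: 'Lesion'},
              {1: 'red', 'mask_cols': np.array([[255, 0, 0]])}),
    "both": ({0: '__bgr', 1: 'GGO', 2: 'CL'},
             {1: 'red', 2: 'blue', 'mask_cols': np.array([[255, 0, 0], [0, 0, 255]])}),
}


def get_class_maps(mask_type):
    """Return (ct_classes, ct_colors, n_c) for a mask_type, failing loudly on typos."""
    if mask_type not in CLASS_MAPS:
        raise ValueError(f"unknown mask_type {mask_type!r}, expected one of {sorted(CLASS_MAPS)}")
    ct_classes, ct_colors = CLASS_MAPS[mask_type]
    n_c = len(ct_classes)  # classes predicted by the heads, background included
    return ct_classes, ct_colors, n_c


ct_classes, ct_colors, n_c = get_class_maps("both")
print(n_c)                                     # 3, same as main() for mask_type == "both"
print(ct_colors['mask_cols'][2 - 1].tolist())  # RGB mask color of class 2 (CL): [0, 0, 255]
```

With a single table, adding a new mask type cannot leave `n_c` and the color map out of sync.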

5. Here is what happens inside the model when `test_step` passes an image through it. First, the image goes through a convolutional backbone (here ResNet-50 with a Feature Pyramid Network), which produces feature maps.
These feature maps are passed to a Region Proposal Network (RPN), which scores a set of predefined anchor boxes for the presence of any of the objects to be detected and refines the best ones into region proposals (Regions of Interest, RoIs).
The RoIs are sent to the RoI Align stage (one of the key features of Mask R-CNN, which preserves the exact spatial alignment between each RoI and the feature map). It converts every RoI into a feature map of the same fixed size required for further processing.
This output is sent to fully connected layers, which predict the class of the object in that region and refine the location of its bounding box.
In parallel, the RoI Align output is sent to a small convolutional network that generates a pixel mask of the object. Several hyperparameters control how this model is built and how its outputs are filtered.
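To make the anchor step more concrete, the NumPy sketch below builds anchors from `sizes` and `aspect_ratios`, which are the two values the notebook reads from `ckpt['anchor_generator']`. It follows torchvision's `AnchorGenerator` convention, where aspect ratio = height / width and every anchor keeps an area of about size². Whether the project's own `AnchorGenerator` uses exactly the same convention is an assumption. `base_anchors` and `grid_anchors` are illustrative names:

```python
import numpy as np


def base_anchors(sizes, aspect_ratios):
    """Zero-centred anchors (x1, y1, x2, y2) for one feature level.

    Assumed convention (as in torchvision): aspect_ratio = height / width,
    and each anchor keeps an area of roughly size**2.
    """
    sizes = np.asarray(sizes, dtype=float)
    ratios = np.asarray(aspect_ratios, dtype=float)
    h_ratios = np.sqrt(ratios)
    w_ratios = 1.0 / h_ratios
    ws = (w_ratios[:, None] * sizes[None, :]).reshape(-1)
    hs = (h_ratios[:, None] * sizes[None, :]).reshape(-1)
    return np.round(np.stack([-ws, -hs, ws, hs], axis=1) / 2)


def grid_anchors(base, feat_h, feat_w, stride):
    """Copy the base anchors to every cell of a feat_h x feat_w feature map."""
    shift_x = np.arange(feat_w) * stride
    shift_y = np.arange(feat_h) * stride
    sy, sx = np.meshgrid(shift_y, shift_x, indexing="ij")
    shifts = np.stack([sx, sy, sx, sy], axis=-1).reshape(-1, 1, 4)
    return (shifts + base[None, :, :]).reshape(-1, 4)


base = base_anchors(sizes=[32], aspect_ratios=[0.5, 1.0, 2.0])
anchors = grid_anchors(base, feat_h=2, feat_w=2, stride=16)
print(anchors.shape)   # (12, 4): 2 x 2 cells x 3 anchors per cell
```

The RPN then predicts, for each of these anchors, an objectness score and a box offset. Those refined boxes are the RoIs that RoI Align receives.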

In [None]:
def test_step(image, device, model, theta_conf, theta_mask, save_dir, cls, cols, num):
    im = PILImage.open(image)
    # convert the image to RGB
    if im.mode != 'RGB':
        im = im.convert(mode='RGB')
    img = np.array(im)
    # copy of the image used as background for plotting (despite the name, it is RGB)
    bgr_img = copy.deepcopy(img)

    t_ = transforms.Compose([transforms.ToPILImage(), transforms.ToTensor()])
    img = t_(img).to(device)
    # inference only: no gradients needed
    with torch.no_grad():
        out = model([img])
    # scores + bounding boxes + labels + masks, moved to the CPU for NumPy/matplotlib
    scores = out[0]['scores'].cpu()
    bboxes = out[0]['boxes'].cpu()
    classes = out[0]['labels'].cpu()
    mask = out[0]['masks'].cpu()
    # keep only the detections whose confidence is above the threshold
    keep = scores > theta_conf
    best_scores = scores[keep]
    # Are there any detections with confidence above the threshold?
    if len(best_scores):
        best_bboxes = bboxes[keep]
        best_classes = classes[keep]
        best_masks = mask[keep]
        # RGB canvas (H x W x 3) that collects the colored masks of all detections
        mask_array = np.zeros([best_masks[0].shape[1], best_masks[0].shape[2], 3], dtype=np.uint8)
        fig, ax = plt.subplots(1, 1)
        fig.set_size_inches(12, 6)
        ax.axis("off")
        # plot predictions
        for idx, dets in enumerate(best_bboxes.tolist()):
            found_masks = best_masks[idx][0].numpy()       # per-pixel mask values, shape H x W
            pred_class = best_classes[idx].item()
            pred_col_n = cols[pred_class]                  # matplotlib color name for box and text
            pred_class_txt = cls[pred_class]
            pred_col = cols['mask_cols'][pred_class - 1]   # RGB mask color (background has no entry)
            mask_array[found_masks > theta_mask] = pred_col
            rect = Rectangle((dets[0], dets[1]), dets[2] - dets[0], dets[3] - dets[1], linewidth=1,
                             edgecolor=pred_col_n, facecolor='none', linestyle="--")
            ax.text(dets[0] + 40, dets[1], '{0:}'.format(pred_class_txt), fontsize=10, color=pred_col_n)
            ax.text(dets[0], dets[1], '{0:.2f}'.format(best_scores[idx].item()), fontsize=10, color=pred_col_n)
            ax.add_patch(rect)

        # blend the image (weight 0.5) with the colored masks (weight 0.75)
        added_image = cv2.addWeighted(bgr_img, 0.5, mask_array, 0.75, gamma=0)
        ax.imshow(added_image)
        #fig.savefig(os.path.join(save_dir, str(num) + ".png"), bbox_inches='tight', pad_inches=0.0)
        fig.savefig(os.path.join(save_dir, "__MASKED__" + os.path.basename(image)), bbox_inches='tight', pad_inches=0.0)

    else:
        print(image, " : No detections")
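The post-processing in `test_step` can be tried without the model. The sketch below works on synthetic data: it thresholds each detection's mask, paints the pixels in the class color and blends the result with the image. The blend is a NumPy stand-in for `cv2.addWeighted(image, alpha, mask_array, beta, gamma=0)`, computed as a weighted sum that is rounded and clipped to the uint8 range. It is meant to approximate OpenCV's saturating arithmetic, not reproduce it bit for bit. `overlay_masks` is a hypothetical helper:

```python
import numpy as np


def overlay_masks(image, mask_probs, labels, mask_cols, theta_mask=0.5, alpha=0.5, beta=0.75):
    """Paint thresholded masks in their class colors and blend them with the image.

    image:      H x W x 3 uint8 RGB image
    mask_probs: N x H x W per-pixel mask values (one per detection)
    labels:     N class ids (0 is background, so they start at 1)
    mask_cols:  RGB colors, row k - 1 is the color of class k
    """
    mask_array = np.zeros_like(image, dtype=np.uint8)
    for probs, label in zip(mask_probs, labels):
        # later detections overwrite earlier ones where masks overlap, as in test_step()
        mask_array[probs > theta_mask] = mask_cols[label - 1]
    # weighted sum, rounded and saturated to uint8 (like cv2.addWeighted with gamma=0)
    blended = image.astype(np.float64) * alpha + mask_array.astype(np.float64) * beta
    return np.clip(np.rint(blended), 0, 255).astype(np.uint8)


img = np.full((2, 2, 3), 100, dtype=np.uint8)                 # tiny gray "CT slice"
probs = np.array([[[0.9, 0.1], [0.2, 0.3]],                   # detection 1: top-left pixel only
                  [[0.0, 0.0], [0.0, 0.8]]])                  # detection 2: bottom-right pixel only
labels = [1, 2]                                               # GGO (red) and CL (blue) for "both"
cols = np.array([[255, 0, 0], [0, 0, 255]])
out = overlay_masks(img, probs, labels, cols)
print(out[0, 0].tolist())   # [241, 50, 50]: 100 * 0.5 + 255 * 0.75 = 241.25 in the red channel
```

Because the mask weight (0.75) is larger than the image weight (0.5), lesion pixels stand out clearly, while pixels outside every mask are simply darkened to half intensity.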

6. (The last note, but everything starts from here :D)
This is where we define the hyperparameters in the config dict. The most important ones are:

**Backbone**: The backbone is the ConvNet architecture used in the first step of Mask R-CNN. A popular choice is ResNet-50, which is also the backbone this notebook builds (`backbone_name='resnet50'` in `main`). The choice of backbone is a trade-off between training time and accuracy. ResNet-50 trains faster than deeper backbones such as ResNet-101 and ResNeXt-101. It also has several open-source weights pretrained on large datasets such as COCO, which can considerably reduce the training time of new instance segmentation projects. ResNet-101 and ResNeXt-101 take longer to train because they have more layers. They tend to be more accurate, however, especially when no pretrained weights are used and basic parameters such as the **learning rate** and **number of epochs** are well tuned.
A sensible approach is to start from available pretrained weights (for example ResNet-50 trained on COCO) and evaluate the model's performance. This is faster, and it works best when the objects to detect resemble the real-world objects in the COCO dataset. If accuracy is of utmost importance and enough computing power is available, other backbones can be explored.

**Train_ROIs_Per_Image**
This is the maximum number of RoIs from the Region Proposal Network that are passed on, per training image, to the classification and masking heads in the next stage. If the number of instances per image is unknown, start with the default value. If there are only a few instances per image, it can be reduced to shorten training time. It is a training-time setting and does not appear in the inference config below.
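During training, the proposals are subsampled before they reach the heads, with a fixed share reserved for positives. The sketch below is a minimal NumPy version of that balanced sampling. It assumes the common rule that a proposal is positive when its best IoU with a ground-truth box is at least 0.5. `sample_rois` and its defaults are illustrative, not the project's code. In torchvision's Mask R-CNN, the analogous constructor arguments are `box_batch_size_per_image` (default 512) and `box_positive_fraction` (default 0.25):

```python
import numpy as np


def sample_rois(max_iou, rois_per_image=8, positive_fraction=0.25, fg_iou=0.5, rng=None):
    """Pick at most rois_per_image proposal indices for the training heads.

    max_iou: best IoU of each RPN proposal with any ground-truth box.
    Proposals with IoU >= fg_iou count as positives (lesions), the rest as negatives.
    """
    rng = np.random.default_rng(0) if rng is None else rng
    pos = np.flatnonzero(max_iou >= fg_iou)
    neg = np.flatnonzero(max_iou < fg_iou)
    n_pos = min(len(pos), int(rois_per_image * positive_fraction))  # cap the positives
    n_neg = min(len(neg), rois_per_image - n_pos)                   # fill up with negatives
    keep_pos = rng.choice(pos, size=n_pos, replace=False)
    keep_neg = rng.choice(neg, size=n_neg, replace=False)
    return np.concatenate([keep_pos, keep_neg])


# made-up IoUs of 12 proposals: 3 positives, 9 negatives
ious = np.array([0.9, 0.7, 0.6, 0.1, 0.2, 0.05, 0.3, 0.0, 0.4, 0.15, 0.45, 0.02])
idx = sample_rois(ious, rois_per_image=8, positive_fraction=0.25)
print(len(idx), int((ious[idx] >= 0.5).sum()))   # 8 2
```

Lowering `rois_per_image` shrinks the work done by the heads per image. This is why the parameter can be reduced when images contain only a few lesions.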

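**Max_GT_Instances:**
This is the maximum number of ground-truth instances used per image during training. If the images contain only a few instances, it can be set to the largest number of instances that occurs in one image, which reduces training time. At inference time, the corresponding cap in this notebook is `box_detections_per_img` (100) in `maskrcnn_args`. Lowering it limits how many detections, and therefore how many potential false positives, are returned per image.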
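On the inference side, such a cap keeps only the highest-scoring detections after NMS. The sketch below shows that top-k step in isolation. `cap_detections` is a hypothetical helper playing the role of `box_detections_per_img`:

```python
import numpy as np


def cap_detections(scores, max_per_image):
    """Indices of the max_per_image highest-scoring detections, best first."""
    order = np.argsort(-np.asarray(scores), kind="stable")
    return order[:max_per_image]


scores = np.array([0.30, 0.95, 0.10, 0.80, 0.55])   # made-up detection scores
print(cap_detections(scores, 3).tolist())           # [1, 3, 4]
```

If the cap is larger than the number of detections, every detection is kept, so a generous value such as 100 only matters for unusually crowded images.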

**Detection_Min_Confidence:**
This is the minimum confidence score a detection must reach to be kept. In this notebook it is `confidence_th`, which is passed to the model as `box_score_thresh` and used again in `test_step`. Start with the default value, then lower or raise it depending on how many instances the model detects. If detecting everything matters and false positives are acceptable (which is true in our case), lower the threshold to catch every possible instance. If precision matters, raise it so that the model keeps only the instances it predicts with very high confidence and false positives are minimal.
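The trade-off can be seen by sweeping the threshold over a fixed set of scores. The sketch below uses the same strict `>` comparison as `test_step`. The scores are made up, and `keep_confident` is a hypothetical helper:

```python
import numpy as np


def keep_confident(scores, theta_conf):
    """Indices of detections whose score is strictly above theta_conf (as in test_step)."""
    return np.flatnonzero(np.asarray(scores) > theta_conf)


# made-up scores for one CT slice, from most to least confident
scores = np.array([0.92, 0.64, 0.31, 0.12, 0.07, 0.03])
for theta_conf in (0.05, 0.3, 0.5, 0.9):
    print(theta_conf, len(keep_confident(scores, theta_conf)))
# 0.05 5
# 0.3 3
# 0.5 2
# 0.9 1
```

With the low value used in this notebook (0.05), almost every candidate lesion is drawn. With 0.9, only the single most certain one would remain.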

In [None]:
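# run the inference
if __name__ == '__main__':
    config_dict = {
        "backbone_name": "resnet50",   # informational only: main() always builds a ResNet-50 backbone
        "ckpt": "segmentation_folder/pretrained_model_for_instance_segmentation/segmentation_model_both_classes.pth",
        "confidence_th": 0.05,         # theta_conf: minimum detection score that is kept
        "device": "cpu",               # "cuda" is used only if a GPU is available
        "gt_dir": "masks",             # not used by this inference script
        "mask_logits_th": 0.5,         # theta_mask: threshold on the predicted mask values
        "mask_type": "both",           # "both" (GGO + CL), "ggo" or "merge"; must match the checkpoint
        "model_name": None,            # not used by this inference script
        "roi_nms_th": 0.5,             # NMS IoU threshold for the final boxes
        "rpn_nms_th": 0.75,            # NMS IoU threshold for the RPN proposals
        "test_data_dir": "segmentation_folder",
        "test_img_dir": "images",
        "save_dir": "segmentation_folder/results",
        "model_args": None,            # not used by this inference script
        "truncation": '0'              # passed to mask_net.maskrcnn_resnet_fpn
    }
    main(config_dict)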
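`roi_nms_th` and `rpn_nms_th` are IoU thresholds for non-maximum suppression (NMS). When two boxes overlap by more than the threshold, the lower-scoring one is dropped. The keyword names they are passed under (`box_nms_thresh`, `rpn_nms_thresh`) match torchvision's `MaskRCNN` arguments. This suggests the project's module is derived from torchvision, where NMS is done with `torchvision.ops.nms`/`batched_nms`, but that is an assumption. The sketch below is a plain NumPy greedy NMS for illustration, not the project's implementation:

```python
import numpy as np


def box_iou(box, boxes):
    """IoU between one box and an array of boxes, all as (x1, y1, x2, y2)."""
    x1 = np.maximum(box[0], boxes[:, 0])
    y1 = np.maximum(box[1], boxes[:, 1])
    x2 = np.minimum(box[2], boxes[:, 2])
    y2 = np.minimum(box[3], boxes[:, 3])
    inter = np.clip(x2 - x1, 0, None) * np.clip(y2 - y1, 0, None)
    area = (box[2] - box[0]) * (box[3] - box[1])
    areas = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
    return inter / (area + areas - inter)


def nms(boxes, scores, iou_th):
    """Greedy NMS: keep the best box, drop boxes overlapping it by more than iou_th, repeat."""
    order = np.argsort(-scores, kind="stable")
    keep = []
    while order.size > 0:
        best = order[0]
        keep.append(int(best))
        rest = order[1:]
        order = rest[box_iou(boxes[best], boxes[rest]) <= iou_th]
    return keep


boxes = np.array([[0, 0, 10, 10], [1, 1, 11, 11], [20, 20, 30, 30]], dtype=float)
scores = np.array([0.9, 0.8, 0.7])
print(nms(boxes, scores, 0.5))    # [0, 2]: IoU of boxes 0 and 1 is 81/119 ~ 0.68 > 0.5
print(nms(boxes, scores, 0.75))   # [0, 1, 2]: 0.68 <= 0.75, so both overlapping boxes survive
```

This is why the higher RPN value (0.75) keeps more overlapping proposals for the heads to examine, while the stricter final value (0.5) removes duplicate lesion boxes.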