#### Author: Madhusudhanan Balasubramanian (MB), Ph.D., The University of Memphis
#### Visualize model outputs

#### July 12, 2024:
    * Code segment copied from axon_net-Copy22-3_4.ipynb
#### April 6, 2025:
    * Final CenterMask2 example outputs
    * Display detections and ground truth on the same image
    * Save each annotated image to an output file
    * Run on all images in the dataset

In [3]:
# Work around to enhance error display involving unicode characters
%colors nocolor
%xmode plain

# import various libraries
import logging
import os, re
import numpy as np
import cv2
import matplotlib.pyplot as plt
#
import torch
#
from detectron2.data import MetadataCatalog, DatasetCatalog
from detectron2.data.datasets import register_coco_instances, load_coco_json
#
from detectron2.engine import DefaultTrainer, DefaultPredictor
from detectron2.engine import default_argument_parser, default_setup, hooks, launch
#
#MB: 03/31/2024
#from detectron2.evaluation import COCOEvaluator, inference_on_dataset
from centermask.evaluation import COCOEvaluator, inference_on_dataset
#
#from detectron2.utils.visualizer import Visualizer, ColorMode
from centermask.utils.axon_visualizer import Visualizer, ColorMode
#
from centermask.config import get_cfg

Exception reporting mode: Plain


## Config

In [4]:
# Build the CenterMask2 config from the AxonClass YAML and point it at a saved checkpoint.
cfg = get_cfg()
cfg.merge_from_file("configs/AxonClass/axon_V_39_eSE_FPN_ms_3x.yaml")
# Checkpoint file name, resolved inside cfg.OUTPUT_DIR (presumably set by the YAML above).
model_file = "model_0039999.pth"
cfg.MODEL.WEIGHTS = os.path.join(cfg.OUTPUT_DIR, model_file)
# NOTE(review): num_gpus / resume are added as extra lowercase keys on cfg; nothing
# visible in this notebook reads them — confirm they are still needed.
cfg.num_gpus = 1
cfg.resume = True
#
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.7   # Not used by centermask2
cfg.MODEL.ROI_HEADS.IOU_THRESHOLDS = [0.5] #IoU threshold for classing training proposals as true positive or false positive
cfg.MODEL.ROI_HEADS.IOU_LABELS = [0, 1] #Label 0 - false positive proposals, Label 1 - true positive proposals
# Make the config immutable; later cells must call cfg.defrost() before changing it.
cfg.freeze()
# Logs rank and environment info (see the cell output). An empty string is passed
# in place of parsed command-line arguments.
default_setup(cfg, '')

[04/07 11:20:50 detectron2]: Rank of current process: 0. World size: 1
[04/07 11:20:51 detectron2]: Environment info:
----------------------  ---------------------------------------------------------------------------------------------------
sys.platform            linux
Python                  3.7.12 | packaged by conda-forge | (default, Oct 26 2021, 06:08:21) [GCC 9.4.0]
numpy                   1.21.6
detectron2              0.6 @/opt/anaconda3/envs/venv_mamba/envs/venv_py37_D2_6a/lib/python3.7/site-packages/detectron2
Compiler                GCC 7.3
CUDA compiler           CUDA 11.3
detectron2 arch flags   3.7, 5.0, 5.2, 6.0, 6.1, 7.0, 7.5, 8.0, 8.6
DETECTRON2_ENV_MODULE   <not set>
PyTorch                 1.10.0 @/opt/anaconda3/envs/venv_mamba/envs/venv_py37_D2_6a/lib/python3.7/site-packages/torch
PyTorch debug build     False
GPU available           Yes
GPU 0,1                 Quadro RTX 8000 (arch=7.5)
Driver version          460.32.03
CUDA_HOME               /usr/local/cuda
Pill

## Register datasets

In [5]:
def register_datasets(dataset_names, datasets_root="./datasets"):
    """Register COCO-format datasets with Detectron2, skipping ones already registered.

    Each dataset is expected at ``<datasets_root>/<name>/`` (image root) with its
    annotation file at ``<datasets_root>/<name>/<name>.json``.

    ``register_coco_instances`` only records a lazy loader and never touches the
    filesystem, so a missing annotation file would otherwise go unnoticed until
    ``DatasetCatalog.get`` is called. The paths are therefore checked up front, and
    datasets whose files are missing are reported and skipped.

    Args:
        dataset_names (list): Names of the datasets to register.
        datasets_root (str): Directory containing one sub-directory per dataset.
            Defaults to "./datasets" (relative to the current working directory).
    """
    # Snapshot once; names registered by this call are added as we go so that a
    # name repeated in dataset_names is not registered twice.
    registered = set(DatasetCatalog.list())

    for dataset_name in dataset_names:
        if dataset_name in registered:
            print(f"Dataset '{dataset_name}' already registered, skipping.")
            continue

        image_root = os.path.join(datasets_root, dataset_name)
        json_file = os.path.join(image_root, f"{dataset_name}.json")

        if not (os.path.isfile(json_file) and os.path.isdir(image_root)):
            print(f"Warning: Could not find annotation file or image directory for '{dataset_name}'")
            continue

        register_coco_instances(dataset_name, {}, json_file, image_root)
        registered.add(dataset_name)
        print(f"Registered dataset '{dataset_name}'")

# Register the axon datasets: training, validation, testing, and the Quigley evaluation set.
dataset_names= ["Phase3_Training", "Phase3_Validation", "Phase3_Testing", "Quigley_Eval"]
register_datasets(dataset_names)

Registered dataset 'Phase3_Training'
Registered dataset 'Phase3_Validation'
Registered dataset 'Phase3_Testing'
Registered dataset 'Quigley_Eval'


## Visualize model outputs

In [12]:
def _new_visualizer(meta_data, image_mat):
    """Create an axon Visualizer for a BGR image (as returned by ``cv2.imread``)."""
    # The Visualizer draws on an RGB image, so flip OpenCV's BGR channel order.
    return Visualizer(image_mat[:, :, ::-1],
                      metadata=meta_data,
                      # ColorMode.IMAGE_BW (grey out unsegmented pixels) was tried earlier;
                      # per the original note it is only available for segmentation models.
                      instance_mode=ColorMode.SEGMENTATION)


def _write_vis_image(vis_image, output_image_file):
    """Write a rendered VisImage to ``output_image_file`` with OpenCV.

    ``get_image()`` returns an RGB array (detectron2 VisImage convention, assumed
    unchanged in axon_visualizer), while ``cv2.imwrite`` expects BGR, so the channels
    are flipped exactly once. The previous code flipped twice, which wrote RGB data
    as BGR and swapped red/blue in the saved overlays.

    Raises:
        OSError: if ``cv2.imwrite`` reports failure (it returns False instead of raising).
    """
    bgr_image = np.ascontiguousarray(vis_image.get_image()[:, :, ::-1])
    if not cv2.imwrite(output_image_file, bgr_image):
        raise OSError(f"cv2.imwrite could not write '{output_image_file}'")


def save_GT_annotations(meta_data, image_mat, data_dict, output_image_file,
                        gt_outline_linewidth=0.8, gt_outline_linestyle='-'):
    """Save a visualization of the ground-truth (GT) object annotations.

    Args:
        meta_data: Metadata used for drawing; thing_colors are interpreted as RGB
            tuples in [0, 255] (detectron2 metadata convention).
        image_mat: BGR image, as read by ``cv2.imread``.
        data_dict: One record from ``DatasetCatalog.get(...)`` for this image.
        output_image_file: Path of the image file to write.
        gt_outline_linewidth: Outline width forwarded to ``draw_dataset_dict``.
        gt_outline_linestyle: Outline style forwarded to ``draw_dataset_dict``.
    """
    viz_obj = _new_visualizer(meta_data, image_mat)
    visImage_obj = viz_obj.draw_dataset_dict(data_dict, jittering=False,
                                             outline_linewidth=gt_outline_linewidth,
                                             outline_linestyle=gt_outline_linestyle)
    _write_vis_image(visImage_obj, output_image_file)


def save_detected_annotations(meta_data, image_mat, detected_instances, output_image_file,
                              dt_outline_linewidth=0.8):
    """Save a visualization of the objects detected by the model.

    Args:
        meta_data: Metadata used for drawing (thing_colors interpreted as RGB).
        image_mat: BGR image, as read by ``cv2.imread``.
        detected_instances: Model predictions (``outputs["instances"]``) on the CPU.
        output_image_file: Path of the image file to write.
        dt_outline_linewidth: Outline width forwarded to ``draw_instance_predictions``.
    """
    viz_obj = _new_visualizer(meta_data, image_mat)
    visImage_obj = viz_obj.draw_instance_predictions(detected_instances, jittering=False,
                                                     outline_linewidth=dt_outline_linewidth)
    _write_vis_image(visImage_obj, output_image_file)


def save_GT_DT_combined_annotations(meta_data, image_mat, data_dict, detected_instances,
                                    output_image_file, gt_outline_linewidth=0.8,
                                    gt_outline_linestyle=':', dt_outline_linewidth=0.4):
    """Save GT annotations and model detections overlaid on the same image.

    The GT is drawn first (dotted by default) and the detections are drawn on top
    (thinner outline by default), both on the same Visualizer canvas.

    Args:
        meta_data: Metadata used for drawing (thing_colors interpreted as RGB).
        image_mat: BGR image, as read by ``cv2.imread``.
        data_dict: One record from ``DatasetCatalog.get(...)`` for this image.
        detected_instances: Model predictions (``outputs["instances"]``) on the CPU.
        output_image_file: Path of the image file to write.
        gt_outline_linewidth, gt_outline_linestyle: GT outline style.
        dt_outline_linewidth: Detection outline width.
    """
    viz_obj = _new_visualizer(meta_data, image_mat)
    viz_obj.draw_dataset_dict(data_dict, jittering=False,
                              outline_linewidth=gt_outline_linewidth,
                              outline_linestyle=gt_outline_linestyle)
    # Drawing again on the same Visualizer overlays the detections on the GT drawing.
    visImage_obj = viz_obj.draw_instance_predictions(detected_instances, jittering=False,
                                                     outline_linewidth=dt_outline_linewidth)
    _write_vis_image(visImage_obj, output_image_file)

In [22]:
import random
import importlib
import centermask.utils.axon_visualizer
importlib.reload(centermask.utils.axon_visualizer)
from centermask.utils.axon_visualizer import Visualizer, ColorMode

# Dataset to visualize; other registered options include 'Phase3_Validation' and 'Phase3_Testing'.
dataset_name = 'Quigley_Eval'
# Set to an int (e.g. 2) to visualize a random subset of images instead of all of them.
num_samples = None

# NOTE(review): detectron2's DefaultPredictor works on its own clone of cfg, so this
# defrost is likely unnecessary for prediction; it leaves the notebook-level cfg mutable.
cfg.defrost()

predictor = DefaultPredictor(cfg)
dataset_dicts = DatasetCatalog.get(dataset_name)

# Output directory: <cfg.OUTPUT_DIR>/Annotated_Examples/<dataset_name>
output_dir = os.path.join(cfg.OUTPUT_DIR, 'Annotated_Examples', dataset_name)
os.makedirs(output_dir, exist_ok=True)

# Drawing metadata (per-class colors, assumed RGB per detectron2 convention). It does
# not depend on the image, so it is set up once rather than on every iteration.
# NOTE(review): the 'test' entry is used instead of MetadataCatalog.get(dataset_name) — confirm intended.
curr_meta_data = MetadataCatalog.get('test')  # (dataset_name)
curr_meta_data.thing_colors = [(0, 0, 255), (0, 255, 0), (255, 0, 0)]

selected_dicts = dataset_dicts if num_samples is None else random.sample(dataset_dicts, num_samples)

for d in selected_dicts:
    # Form input / output file paths; outputs keep the input file's extension.
    input_image_file = d["file_name"]
    base_filename = os.path.basename(input_image_file)
    base_name, extension = os.path.splitext(base_filename)

    image_file = os.path.join(output_dir, f"{base_name}{extension}")
    GT_image_file = os.path.join(output_dir, f"{base_name}_GT{extension}")
    DT_image_file = os.path.join(output_dir, f"{base_name}_DT{extension}")

    # cv2.imread returns None (instead of raising) for missing or unreadable files.
    im = cv2.imread(input_image_file)
    if im is None:
        print(f"Warning: Could not read image '{input_image_file}', skipping.")
        continue

    # Run the model; output format: https://detectron2.readthedocs.io/tutorials/models.html#model-output-format
    outputs = predictor(im)
    detected_instances = outputs["instances"].to("cpu")

    # Save the original image (already BGR) next to its annotated versions.
    if not cv2.imwrite(image_file, im):
        raise OSError(f"cv2.imwrite could not write '{image_file}'")

    # Save GT and detected annotations
    save_GT_annotations(curr_meta_data, im, d, GT_image_file)
    save_detected_annotations(curr_meta_data, im, detected_instances, DT_image_file)

FPN -> in_features: ['stage2', 'stage3', 'stage4']
FPN -> strides: [4, 8, 16]
FPN -> top_block: None
[04/07 12:04:41 fvcore.common.checkpoint]: [Checkpointer] Loading from output/AxonClass/AxonClass-V-39-ms-3x/model_0039999.pth ...
[04/07 12:04:41 d2.data.datasets.coco]: Loaded 51 images in COCO format from ./datasets/Quigley_Eval/Quigley_Eval.json
GeneralizedRCNN.inference()->do_postprocess: True
GeneralizedRCNN.inference()->do_postprocess: True
GeneralizedRCNN.inference()->do_postprocess: True
GeneralizedRCNN.inference()->do_postprocess: True
GeneralizedRCNN.inference()->do_postprocess: True
GeneralizedRCNN.inference()->do_postprocess: True
GeneralizedRCNN.inference()->do_postprocess: True
GeneralizedRCNN.inference()->do_postprocess: True
GeneralizedRCNN.inference()->do_postprocess: True
GeneralizedRCNN.inference()->do_postprocess: True
GeneralizedRCNN.inference()->do_postprocess: True
GeneralizedRCNN.inference()->do_postprocess: True
GeneralizedRCNN.inference()->do_postprocess: True