# GradCam/GradCam++ for detectron2

This code can run on the CPU.

Reference: https://github.com/alexriedel1/detectron2-GradCAM

## Register datasets

In [None]:
import matplotlib
import matplotlib.pyplot as plt
import os
import gc  # Import garbage collection module

from Gradcam.detectron2_gradcam import Detectron2GradCAM
from detectron2.utils.visualizer import Visualizer
from detectron2.data import MetadataCatalog
from detectron2 import model_zoo
from Gradcam.gradcam import GradCAM, GradCamPlusPlus
from detectron2.data.datasets import register_coco_instances
from detectron2.data import MetadataCatalog

# Imported here so this cell stays self-contained when it is re-run on its own.
from detectron2.data import DatasetCatalog

# COCO-format splits to register: (dataset name, annotation JSON, image root).
_pollutant_splits = [
    (
        "Pollutant_train_80%",
        "/scratch/tjian/Data/Pollutant_SSL/labeled/SL_train_80%/annotations/train.json",
        "/scratch/tjian/Data/Pollutant_SSL/labeled/SL_train_80%/train/",
    ),
    (
        "Pollutant_val_80%",
        "/scratch/tjian/Data/Pollutant_SSL/labeled/SL_train_80%/annotations/val.json",
        "/scratch/tjian/Data/Pollutant_SSL/labeled/SL_train_80%/val/",
    ),
]

for _name, _json_file, _image_root in _pollutant_splits:
    # detectron2 refuses to register a dataset name twice, so skip names that
    # are already present; this lets the cell be re-run in the same kernel.
    if _name not in DatasetCatalog.list():
        register_coco_instances(_name, {}, _json_file, _image_root)


# Get and modify metadata
metadata = MetadataCatalog.get("Pollutant_train_80%")
metadata.thing_classes = ["entrapped particle", "free particle"]  # Replace with actual class names

# print("thing_classes:", getattr(metadata, 'thing_classes', 'No thing_classes attribute'))

## Load model

In [None]:
# Model checkpoint and config.
#
# In every experiment listed here the checkpoint and its config.yaml sit in the
# same run directory, so both paths are derived from one directory value.
#
#   experiment               run directory                    checkpoint file
#   SC_FTAL_F2_80%_best      SSL_SC_FTAL_F2/Train_80%         model_best_89.1765.pth   <- active
#   baseline_FTAL_F2_80%_best  SL_F2_2/Train_80%              model_best_91.1003.pth
#   SW_FTAL_F2_80%_best      SSL_SwAV_FTAL_F2_1/Train_80%     model_best_89.6902.pth
_weights_root = "/scratch/tjian/PythonProject/deep_pollutant_SSL/checkpoints/train_weights"
_run_dir = f"{_weights_root}/SSL_SC_FTAL_F2/Train_80%"

model_checkpoint_path = f"{_run_dir}/model_best_89.1765.pth"
model_config_path = f"{_run_dir}/config.yaml"


# Flat list of alternating config keys and values (all values given as strings).
config_list = [
    "MODEL.ROI_HEADS.SCORE_THRESH_TEST", "0.5",
    "MODEL.ROI_HEADS.NUM_CLASSES", "2",
    "MODEL.WEIGHTS", model_checkpoint_path,
]


# Layer used for CAM. Other layers that have been tried:
#   "backbone.res4.5.conv3"
#   "backbone.res2.0.conv1"
layer_name = "roi_heads.res5.2.conv3"

## Select GradCAM type

In [None]:
# CAM variant to compute. Choices offered by this notebook:
#   (1) "GradCAM"
#   (2) "GradCAM++"   <- selected
grad_cam_type = "GradCAM++"

## Run CAM
Note: a CAM is generated per detected object instance, not per class. If the model predicts 2 instances (`num_instances` = 2) for an image, two CAM images are saved, one per instance, e.g., "XXX_instance_0.jpg" and "XXX_instance_1.jpg".

### (1) On one image

In [None]:
# define the input image and output folder
img = r"/scratch/tjian/Data/Peng_SSL/test/A10-P-27.jpg"
output_path = r"/scratch/tjian/Data/Peng_SSL/"

# savefig does not create missing directories, so make sure the target exists.
os.makedirs(output_path, exist_ok=True)

cam_extractor = Detectron2GradCAM(config_file=model_config_path, model_file=model_checkpoint_path)

# Check the number of detected instances
num_instances = cam_extractor.check_num_instances(img)
print(f"Number of detected instances: {num_instances}")

if num_instances > 0:
    # Derive the output name from the input file so that changing `img` never
    # writes results under another image's name ("A10-P-27" for the path above).
    image_stem = os.path.splitext(os.path.basename(img))[0]
    # The metadata lookup does not depend on the instance; fetch it once.
    vis_metadata = MetadataCatalog.get(cam_extractor.cfg.DATASETS.TRAIN[0])

    # One CAM overlay is produced and saved per detected instance.
    for i in range(num_instances):
        image_dict, cam_orig = cam_extractor.get_cam(
            img=img,
            target_instance=i,
            layer_name=layer_name,
            grad_cam_type=grad_cam_type,
        )
        v = Visualizer(image_dict["image"], vis_metadata, scale=1.0)
        out = v.draw_instance_predictions(image_dict["output"][0]["instances"][i].to("cpu"))

        # With dpi=100 below, dividing the pixel size by 100 gives a figure of
        # roughly the image's own resolution.
        height, width = image_dict["image"].shape[:2]
        fig = plt.figure(figsize=(width / 100, height / 100))
        plt.imshow(out.get_image(), interpolation='none')
        plt.imshow(image_dict["cam"], cmap='jet', alpha=0.5)  # semi-transparent heat map on top
        plt.axis('off')  # Hide axes
        plt.title(f"CAM for Instance {i} (class {image_dict['label']})")

        # Save the figure as "<image name>_instance_<i>.jpg"
        output_img_name = f"{image_stem}_instance_{i}.jpg"
        output_img_path = os.path.join(output_path, output_img_name)
        plt.savefig(output_img_path, dpi=100, bbox_inches='tight', pad_inches=0)
        plt.close(fig)  # release the figure so repeated runs do not accumulate open figures
else:
    print("No instances detected.")


### (2) On images in a folder

In [None]:
# define the input image path and output folder
input_path = r"/scratch/tjian/Data/Peng_SSL/test/"
# Output folders used for other runs:
# output_path = r"/scratch/tjian/Data/Peng_SSL/SL_CAM/Gradcam++/backbone.res4.5.conv3/"
# output_path = r"/scratch/tjian/Data/Peng_SSL/SL_CAM/Gradcam++/roi_heads.res5.2.conv3/"
output_path = r"/scratch/tjian/Data/Peng_SSL/SSL_SC/GradCAM++/backbone.res2.0.conv1/"
# NOTE(review): this folder is named after "backbone.res2.0.conv1", while the
# notebook as written sets layer_name to "roi_heads.res5.2.conv3" -- confirm the
# folder matches the `layer_name` actually in effect before running.

# savefig does not create missing directories, so make sure the target exists.
os.makedirs(output_path, exist_ok=True)

cam_extractor = Detectron2GradCAM(config_file=model_config_path, model_file=model_checkpoint_path)
# The metadata lookup does not depend on the image or instance; fetch it once.
vis_metadata = MetadataCatalog.get(cam_extractor.cfg.DATASETS.TRAIN[0])

# os.listdir returns entries in arbitrary order; sort for a reproducible run order.
file_names = sorted(os.listdir(input_path))
for filename in file_names:
    if not filename.endswith('.jpg'):
        continue

    image_path = os.path.join(input_path, filename)
    # Check the number of detected instances
    num_instances = cam_extractor.check_num_instances(image_path)
    if num_instances <= 0:
        print(f"No instances detected in {filename}.")
        continue

    # Strip only the final extension: "a.b.jpg" -> "a.b". Splitting on the first
    # dot would map "a.b.jpg" and "a.c.jpg" to the same output name.
    image_stem = os.path.splitext(filename)[0]

    # One CAM overlay is produced and saved per detected instance.
    for i in range(num_instances):
        image_dict, cam_orig = cam_extractor.get_cam(
            img=image_path,
            target_instance=i,
            layer_name=layer_name,
            grad_cam_type=grad_cam_type,
        )
        v = Visualizer(image_dict["image"], vis_metadata, scale=1.0)
        out = v.draw_instance_predictions(image_dict["output"][0]["instances"][i].to("cpu"))

        # With dpi=100 below, dividing the pixel size by 100 gives a figure of
        # roughly the image's own resolution.
        height, width = image_dict["image"].shape[:2]
        fig = plt.figure(figsize=(width / 100, height / 100))
        plt.imshow(out.get_image(), interpolation='none')
        plt.imshow(image_dict["cam"], cmap='jet', alpha=0.5)  # semi-transparent heat map on top
        plt.axis('off')  # Hide axes
        plt.title(f"CAM for Instance {i} (class {image_dict['label']})")

        # Save the figure as "<image name>_instance_<i>.jpg"
        output_img_name = f"{image_stem}_instance_{i}.jpg"
        output_img_path = os.path.join(output_path, output_img_name)
        plt.savefig(output_img_path, dpi=100, bbox_inches='tight', pad_inches=0)
        plt.close(fig)  # release the figure; many images are processed in this loop

