In [None]:
# Enable IPython's autoreload extension in mode 2, so imported modules are
# reloaded before each cell runs and on-disk edits take effect without a kernel restart.
%load_ext autoreload
%autoreload 2

In [None]:
import os
from lightning_model import SegmentationModel 
from datasets import SACropTypeDataModule
from lightning.pytorch.utilities import model_summary
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from torchcam.utils import overlay_mask
from torchvision.transforms.functional import to_pil_image
from torchcam.methods import GradCAM
from comet_ml import Experiment
from torchmetrics import JaccardIndex, Precision, Recall, F1Score
from torchmetrics.wrappers import ClasswiseWrapper
from skimage.filters import threshold_otsu

In [None]:
# Comet experiment used further down to log prediction figures and CAM overlays.
# Credentials are taken from the environment (COMET_API_KEY, COMET_PROJECT_NAME,
# COMET_WORKSPACE) so real keys never need to be saved inside the notebook; the
# original placeholder strings are kept as fallbacks when a variable is not set.
experiment = Experiment(
    api_key=os.environ.get("COMET_API_KEY", "your_api_key"),
    project_name=os.environ.get("COMET_PROJECT_NAME", "your_project_name"),
    workspace=os.environ.get("COMET_WORKSPACE", "your_workspace"),
)

# NOTE(review): torch is imported here rather than in the imports cell, presumably
# so that comet_ml is imported before it; confirm before moving this line.
import torch

In [None]:
# Checkpoint files passed to SegmentationModel.load_from_checkpoint in the next cells.
# Both are placeholders and must be replaced with real paths before running.
CHECKPOINT_UKAN = "your_checkpoint_path"
CHECKPOINT_UNET = "your_checkpoint_path"

In [None]:
# Restore the UKAN model from its checkpoint, mapping all tensors to the CPU.
model_ukan = SegmentationModel.load_from_checkpoint(CHECKPOINT_UKAN, map_location="cpu")
# summarize() only builds the summary object; as it is not the last expression of
# the cell it has to be printed explicitly, otherwise nothing is displayed.
print(model_summary.summarize(model_ukan))
# Switch to evaluation mode; being the last expression, the module repr is displayed.
model_ukan.eval()

In [None]:
# Restore the UNET model from its checkpoint, mapping all tensors to the CPU.
model_unet = SegmentationModel.load_from_checkpoint(CHECKPOINT_UNET, map_location="cpu")
# summarize() only builds the summary object; as it is not the last expression of
# the cell it has to be printed explicitly, otherwise nothing is displayed.
print(model_summary.summarize(model_unet))
# Switch to evaluation mode; being the last expression, the module repr is displayed.
model_unet.eval()

In [None]:
# Data module providing the test batches for the evaluation loop below
# (dm.test_dataloader()) and the plotting helper used by plot_predictions
# (dm.dataset_test.plot). The data path is a placeholder to be replaced.
# NOTE(review): binarize=True presumably collapses the labels to two classes,
# matching num_classes=2 in the metric helpers; confirm in SACropTypeDataModule.
dm = SACropTypeDataModule(num_workers=0, binarize=True, batch_size=16, path="your_data_path")
# Prepare the 'test' stage only.
dm.setup('test')

# UKAN and UNET test

In [None]:
def compute_Iou_cam(activation_map, y):
    """Mean intersection-over-union between two batches of masks.

    Both inputs are converted with ``.bool()``, so every non-zero element counts
    as foreground. Intersection and union are summed over dims 1 and 2, giving
    one IoU value per entry along dim 0.

    An entry whose two masks are both empty has a union of zero; its IoU is
    defined as 0 here instead of the NaN that 0/0 would produce (a single NaN
    would otherwise turn the whole batch mean into NaN).

    Args:
        activation_map: tensor of predicted maps, indexed (batch, dim1, dim2).
        y: tensor of reference maps, broadcastable against ``activation_map``.

    Returns:
        The mean IoU over the batch as a Python float.
    """
    preds = activation_map.bool()
    targets = y.bool()

    # Pixel counts per batch entry.
    intersection = (preds & targets).float().sum(dim=(1, 2))
    union = (preds | targets).float().sum(dim=(1, 2))

    # `union` is a count, hence either 0 or >= 1: clamping to 1 leaves every
    # non-empty case untouched and turns the empty case 0/0 into 0/1 = 0.
    iou = intersection / union.clamp(min=1.0)

    return iou.mean().item()

def IoU_images(y_hat, y):
    """Per-class Jaccard index (IoU) of `y_hat` against `y` for the two-class task.

    A fresh metric is created on every call, so the result reflects only the
    given batch. Returns the mapping produced by ``ClasswiseWrapper.compute()``;
    the evaluation loop reads its ``'multiclassjaccardindex_1'`` entry.
    """
    per_class_iou = ClasswiseWrapper(
        JaccardIndex(task="multiclass", num_classes=2, average="none")
    )
    per_class_iou.update(y_hat, y)
    return per_class_iou.compute()

def Precision_images(y_hat, y):
    """Per-class precision of `y_hat` against `y` for the two-class task.

    A fresh metric is created on every call, so the result reflects only the
    given batch. Returns the mapping produced by ``ClasswiseWrapper.compute()``;
    the evaluation loop reads its ``'multiclassprecision_1'`` entry.
    """
    per_class_precision = ClasswiseWrapper(
        Precision(task="multiclass", num_classes=2, average="none")
    )
    per_class_precision.update(y_hat, y)
    return per_class_precision.compute()

def Recall_images(y_hat, y):
    """Per-class recall of `y_hat` against `y` for the two-class task.

    A fresh metric is created on every call, so the result reflects only the
    given batch. Returns the mapping produced by ``ClasswiseWrapper.compute()``;
    the evaluation loop reads its ``'multiclassrecall_1'`` entry.
    """
    per_class_recall = ClasswiseWrapper(
        Recall(task="multiclass", num_classes=2, average="none")
    )
    per_class_recall.update(y_hat, y)
    return per_class_recall.compute()

def F1Score_images(y_hat, y):
    """Per-class F1 score of `y_hat` against `y` for the two-class task.

    A fresh metric is created on every call, so the result reflects only the
    given batch. Returns the mapping produced by ``ClasswiseWrapper.compute()``;
    the evaluation loop reads its ``'multiclassf1score_1'`` entry.
    """
    per_class_f1 = ClasswiseWrapper(
        F1Score(task="multiclass", num_classes=2, average="none")
    )
    per_class_f1.update(y_hat, y)
    return per_class_f1.compute()

def get_grad_cam(model, target_layer, x):
    """Run `model` on `x` under a GradCAM extractor and return the CAMs for class 1.

    Args:
        model: network to explain.
        target_layer: layer specification handed to GradCAM; callers in this
            notebook pass a one-element list and index the result with ``[0]``.
        x: input batch fed to the model.

    Returns:
        Whatever the extractor returns for class index 1.
    """
    explained_class = 1  # index of the class to be explained

    # The extractor is only active inside the context manager, so the forward
    # pass and the CAM retrieval both have to happen within it.
    with GradCAM(model, target_layer) as extractor:
        scores = model(x)
        cams = extractor(class_idx=explained_class, scores=scores)
    return cams


def plot_raw_cam(activation_map):
    """Show the raw CAM heatmap found at ``activation_map[0][0]`` without axes."""
    heatmap = activation_map[0][0].numpy()
    plt.imshow(heatmap)
    plt.axis('off')
    plt.tight_layout()
    plt.show()


def image_transf(x, rgb=True, normalize=True):
    """Prepare a batch of images for display.

    Args:
        x: 4-D tensor indexed (batch, channel, dim2, dim3).
        rgb: when true, keep only channels 3, 2 and 1, in that order.
        normalize: when true, min-max rescale using the minimum and maximum
            of the whole (possibly channel-reduced) batch.

    Returns:
        The transformed tensor; `x` itself when both options are off.
    """
    bands = x[:, (3, 2, 1), :, :] if rgb else x
    if not normalize:
        return bands

    lowest, highest = bands.min(), bands.max()
    return (bands - lowest) / (highest - lowest)

def plot_overlay_cam(activation_map, x, idx=0):
    """Blend the CAM at position `idx` over the matching image of `x`.

    The image comes from ``image_transf(x)`` (channel selection plus batch-wide
    min-max scaling); the CAM is ``activation_map[idx]`` with a leading
    singleton dimension removed, converted to a mode-'F' PIL image.

    Returns:
        A PIL image of the overlay, blended with alpha 0.5.
    """
    background = to_pil_image(image_transf(x)[idx])
    heatmap = to_pil_image(activation_map[idx].squeeze(0), mode='F')
    blended = overlay_mask(background, heatmap, alpha=0.5)
    # Round-trip through numpy to return a plain PIL image built from the pixel data.
    return Image.fromarray(np.array(blended))

def calculate_threshold(activation_map):
    """Binarise each activation map with its own Otsu threshold.

    Args:
        activation_map: iterable of tensors, one map per element.

    Returns:
        A boolean tensor stacking the per-map masks, true where the map is
        strictly above its Otsu threshold.
    """
    binary_masks = []
    for cam in activation_map:
        # Work on a detached CPU copy so the original map is left untouched.
        cam_np = cam.clone().detach().cpu().numpy()
        otsu_cutoff = threshold_otsu(cam_np)
        binary_masks.append(torch.tensor(cam_np > otsu_cutoff))

    return torch.stack(binary_masks)

def apply_mask_original_img(x, mask):
    """Multiply every image in `x` by the mask at the same batch position.

    Images and masks are paired positionally; the product relies on
    broadcasting, so a per-image 2-D mask is applied to every channel.

    Returns:
        The masked images stacked back into one tensor.
    """
    return torch.stack([image * image_mask for image, image_mask in zip(x, mask)])


def write_metric_results(file, results, type='default'):
    """Write metric values to `file`, one value per line under a header line.

    Args:
        file: object with a ``write`` method.
        results: for ``type='default'`` a mapping with the keys 'iou',
            'precision', 'recall' and 'f1_score', each an iterable of values;
            for any other `type` an iterable whose elements expose ``.item()``.
        type: 'default' writes the four metric sections in a fixed order;
            anything else writes the single saliency section.
    """
    if type != 'default':
        file.write("Average Saliency per channel, per image, per batch\n")
        for value in results:
            file.write(f"{value.item()}\n")
        return

    sections = (
        ("IOU", 'iou'),
        ("Precision", 'precision'),
        ("Recall", 'recall'),
        ("F1 Score", 'f1_score'),
    )
    for header, key in sections:
        file.write(f"{header}\n")
        for value in results[key]:
            file.write(f"{value}\n")
    

def calculate_sufficiency(model, target_layer, x, y):
    """Sufficiency metrics of the Grad-CAM explanation of `model`.

    The Grad-CAM maps for `target_layer` are binarised with Otsu's threshold,
    the input images are multiplied by the resulting masks, and the masked
    images are fed back through the same `model`. The prediction on the masked
    input is then scored against `y`.

    The forward pass on the masked images uses the `model` argument. It used to
    go through the global ``model_ukan`` whatever model was being explained,
    which made the result wrong for any other model.

    Args:
        model: network being explained and re-evaluated.
        target_layer: layer whose Grad-CAM is used to build the mask.
        x: input batch.
        y: reference labels for the batch.

    Returns:
        Dict with keys 'iou', 'precision', 'recall' and 'f1_score', each holding
        the class-wise result of the corresponding ``*_images`` helper.
    """
    cams = get_grad_cam(model, [target_layer], x)

    # One list entry per target layer; a single layer was passed, hence [0].
    thresholded_tensor = calculate_threshold(cams[0])

    # Keep only the pixels the explanation marks as relevant.
    masked_original_imgs = apply_mask_original_img(x, thresholded_tensor)

    # Forward pass of the explained model on the masked images.
    y_hat_masked = model(masked_original_imgs)

    return {
        'iou': IoU_images(y_hat_masked, y),
        'precision': Precision_images(y_hat_masked, y),
        'recall': Recall_images(y_hat_masked, y),
        'f1_score': F1Score_images(y_hat_masked, y),
    }

def calculate_plausibility(model, target_layer, x, y):
    """Plausibility metrics of the Grad-CAM explanation of `model`.

    The Grad-CAM maps for `target_layer` are binarised with Otsu's threshold and
    the resulting masks are scored directly against the labels `y`.

    Returns:
        Dict with keys 'iou', 'precision', 'recall' and 'f1_score', each holding
        the class-wise result of the corresponding ``*_images`` helper.
    """
    cams = get_grad_cam(model, [target_layer], x)
    # A single layer was passed, so its maps are the first entry of the result.
    cam_mask = calculate_threshold(cams[0])

    return {
        'iou': IoU_images(cam_mask, y),
        'precision': Precision_images(cam_mask, y),
        'recall': Recall_images(cam_mask, y),
        'f1_score': F1Score_images(cam_mask, y),
    }

    
def plot_predictions(x, idx, y, y_hat, i, name="ukan"):
    """Plot image, mask and class-1 prediction of one sample and log the figure.

    Args:
        x: input batch.
        idx: position of the sample inside the batch.
        y: label batch.
        y_hat: model output batch; channel 1 of sample `idx` is plotted.
        i: batch counter, used as suffix of the logged figure name.
        name: prefix of the logged figure name. Defaults to "ukan", the prefix
            that was previously hard-coded; pass a different value to keep
            figures of different models apart.
    """
    fig, _ = dm.dataset_test.plot({"image": x[idx], "mask": y[idx], "prediction": y_hat[idx, 1].detach().numpy()})
    experiment.log_figure(figure=fig, figure_name=f"{name}_{i}")
    # This helper runs for every batch: close the figure once it has been
    # logged so open matplotlib figures do not pile up in memory.
    plt.close(fig)

def save_overlay_cam(name, model, target_layer, activation_map, x, idx=0):
    """Compute the Grad-CAM overlay of sample `idx` and log it as an image.

    Args:
        name: prefix of the logged image name (``f"{name}_cam_{idx}"``).
        model: network to explain.
        target_layer: layer to explain.
        activation_map: unused; the maps are always recomputed from `model`,
            `target_layer` and `x`. Kept so existing positional calls still work.
        x: input batch.
        idx: position of the sample inside the batch.
    """
    cams = get_grad_cam(model, [target_layer], x)
    # get_grad_cam returns one entry per target layer. plot_overlay_cam indexes
    # its first argument by sample, so it must receive the maps of the single
    # layer (cams[0]), not the list itself.
    cam = plot_overlay_cam(cams[0], x, idx=idx)
    experiment.log_image(cam, name=f"{name}_cam_{idx}")

def compute_saliency_per_channel(x, activation_map_original, n, model=None, target_layer=None, num_channels=12):
    """Measure how much the Grad-CAM of each image changes when one channel is zeroed.

    For every image in `x` and every channel, the channel is set to zero, the
    Grad-CAM is recomputed on the modified image and compared (via
    ``compute_Iou_cam``) with the image's original activation map. Each overlay
    is also logged to the Comet experiment.

    Args:
        x: input batch, indexed (batch, channel, dim2, dim3).
        activation_map_original: original activation maps, indexed by image.
        n: batch counter, only used in the logged image names.
        model: network used to recompute the Grad-CAM. Defaults to the global
            ``model_ukan``, which is what was previously hard-coded.
        target_layer: layer to explain. Defaults to ``model.model.final``.
        num_channels: number of leading channels to mask one at a time
            (12 by default, the previously hard-coded value).

    Returns:
        Tensor of shape (num_channels, len(x)) with the IoU per channel and image.

    NOTE(review): the logged image names always contain "unet", whatever model
    is used; the string is left unchanged here.
    """
    import math

    if model is None:
        model = model_ukan
    if target_layer is None:
        target_layer = model.model.final

    running_saliency = torch.empty(num_channels, len(x))

    for num_img, img in enumerate(x):
        act_map = activation_map_original[num_img]
        for c in range(num_channels):
            image_to_mask = img.clone()
            image_to_mask[c, :, :] = 0  # Mask one channel at a time

            # Grad-CAM of the image with channel c removed (batch of one).
            masked_cams = get_grad_cam(model, [target_layer], image_to_mask.unsqueeze(0))

            # Overlay on the image being analysed. Previously the overlay was
            # always drawn on x[0], whichever image the CAM belonged to.
            overlay_cam = plot_overlay_cam(masked_cams[0], img.unsqueeze(0), idx=0)

            res = compute_Iou_cam(masked_cams[0], act_map)

            # NaN never compares equal to anything (itself included), so an
            # equality test against float('nan') can never succeed.
            if math.isnan(res):
                res = 0.0
            running_saliency[c, num_img] = res

            experiment.log_image(overlay_cam, name=f"img_{num_img}_unet_channel_masked_{c}_batch_{n}")

    return running_saliency

In [None]:
# Result files, one per model, opened for writing (existing content is
# truncated). The handles are filled in by the evaluation cell below.
task = 'channel_relevance'
model_name_ukan, model_name_unet = 'UKAN', 'UNET'

f_ukan = open(f"results_{model_name_ukan}_{task}.csv", "w")
f_unet = open(f"results_{model_name_unet}_{task}.csv", "w")


In [None]:
# Evaluation of both models on the test set: per-batch IoU, prediction plots,
# sufficiency / plausibility of the Grad-CAM explanations, per-channel saliency
# and CAM overlays. Results are written to f_ukan / f_unet and logged to Comet.

idx = 0  # Index (within each batch) of the image that is plotted and overlaid
target_layer_ukan = model_ukan.model.final  # Layer to be explained for UKAN
target_layer_unet = model_unet.model.up4  # Layer to be explained for UNET

running_IoU_ukan = 0
running_IoU_unet = 0

# Per-batch class-1 values of each metric, for sufficiency and plausibility.
iou_ukan_suff, precision_ukan_suff, recall_ukan_suff, f1_score_ukan_suff = [], [], [], []
iou_unet_suff, precision_unet_suff, recall_unet_suff, f1_score_unet_suff = [], [], [], []
iou_ukan_plaus, precision_ukan_plaus, recall_ukan_plaus, f1_score_ukan_plaus = [], [], [], []
iou_unet_plaus, precision_unet_plaus, recall_unet_plaus, f1_score_unet_plaus = [], [], [], []

# One (channels, batch_size) saliency tensor per batch.
saliency_ukan = []
saliency_unet = []


def _record_metrics(metrics, iou_list, precision_list, recall_list, f1_list):
    """Append the class-1 value of every metric in `metrics` to its running list."""
    iou_list.append(metrics['iou']['multiclassjaccardindex_1'].item())
    precision_list.append(metrics['precision']['multiclassprecision_1'].item())
    recall_list.append(metrics['recall']['multiclassrecall_1'].item())
    f1_list.append(metrics['f1_score']['multiclassf1score_1'].item())


for i, batch in enumerate(dm.test_dataloader()):

    x, y = batch
    y_hat = model_ukan(x)
    y_hat_unet = model_unet(x)

    # ================ COMPUTE GRAD CAM =================
    activation_map_ukan = get_grad_cam(model_ukan, [target_layer_ukan], x)
    activation_map_unet = get_grad_cam(model_unet, [target_layer_unet], x)

    # ================ COMPUTE THE SALIENCY PER CHANNEL =================
    # NOTE(review): called this way, compute_saliency_per_channel recomputes the
    # masked CAMs with model_ukan for both calls; only the reference activation
    # map differs. Verify this is intended for the UNET numbers.
    saliency_ukan.append(compute_saliency_per_channel(x, activation_map_ukan[0], i))
    saliency_unet.append(compute_saliency_per_channel(x, activation_map_unet[0], i))

    # ================ COMPUTE THE IOU UKAN  =================
    iou_ukan = IoU_images(y_hat, y)
    # experiment.log_metric("multiclassjaccardindex_1_ukan", iou_ukan['multiclassjaccardindex_1'].detach(), step=i)
    # .item() writes the plain number; formatting the tensor itself would write
    # its repr ("tensor(...)") into the results file.
    f_ukan.write(f"{iou_ukan['multiclassjaccardindex_1'].item()}\n")

    # ================ COMPUTE THE IOU UNET =================
    iou_unet = IoU_images(y_hat_unet, y)
    # experiment.log_metric("multiclassjaccardindex_1_unet", iou_unet['multiclassjaccardindex_1'].detach(), step=i)
    f_unet.write(f"{iou_unet['multiclassjaccardindex_1'].item()}\n")

    # ================ PLOT PREDICTIONS =================
    # NOTE(review): both figures are logged under the same default name prefix
    # by plot_predictions, so the two models are not distinguishable by name.
    plot_predictions(x, idx, y, y_hat, i)
    plot_predictions(x, idx, y, y_hat_unet, i)

    # ================ SUFFICIENCY =================
    # calculate_sufficiency takes (model, target_layer, x, y) only; the extra
    # `threshold` keyword passed previously raised a TypeError.
    _record_metrics(
        calculate_sufficiency(model_ukan, target_layer_ukan, x, y),
        iou_ukan_suff, precision_ukan_suff, recall_ukan_suff, f1_score_ukan_suff,
    )
    _record_metrics(
        calculate_sufficiency(model_unet, target_layer_unet, x, y),
        iou_unet_suff, precision_unet_suff, recall_unet_suff, f1_score_unet_suff,
    )

    # ================ PLAUSIBILITY =================
    _record_metrics(
        calculate_plausibility(model_ukan, target_layer_ukan, x, y),
        iou_ukan_plaus, precision_ukan_plaus, recall_ukan_plaus, f1_score_ukan_plaus,
    )
    _record_metrics(
        calculate_plausibility(model_unet, target_layer_unet, x, y),
        iou_unet_plaus, precision_unet_plaus, recall_unet_plaus, f1_score_unet_plaus,
    )

    # ================ SAVE THE OVERLAY CAM ================
    # The previous calls to save_overlay_cam passed the model *name* as the model
    # and omitted one positional argument, which raised a TypeError. The CAMs of
    # this batch are already available, so the overlays are built from them
    # directly and logged under the same "<model name>_cam_<idx>" names.
    experiment.log_image(
        plot_overlay_cam(activation_map_ukan[0], x, idx=idx),
        name=f"{model_name_ukan}_cam_{idx}",
    )
    experiment.log_image(
        plot_overlay_cam(activation_map_unet[0], x, idx=idx),
        name=f"{model_name_unet}_cam_{idx}",
    )

# Concatenate the results of per channel saliency along the image dimension
final_salience_ukan = torch.cat(saliency_ukan, dim=-1)
final_salience_unet = torch.cat(saliency_unet, dim=-1)

# Write results for sufficiency (first) and plausibility (second)
write_metric_results(f_ukan, results={'iou': iou_ukan_suff, 'precision': precision_ukan_suff, 'recall': recall_ukan_suff, 'f1_score': f1_score_ukan_suff})
write_metric_results(f_unet, results={'iou': iou_unet_suff, 'precision': precision_unet_suff, 'recall': recall_unet_suff, 'f1_score': f1_score_unet_suff})

write_metric_results(f_ukan, results={'iou': iou_ukan_plaus, 'precision': precision_ukan_plaus, 'recall': recall_ukan_plaus, 'f1_score': f1_score_ukan_plaus})
write_metric_results(f_unet, results={'iou': iou_unet_plaus, 'precision': precision_unet_plaus, 'recall': recall_unet_plaus, 'f1_score': f1_score_unet_plaus})


# Average the saliency over all images, leaving one value per channel
mean_salience_ukan = final_salience_ukan.mean(dim=-1)
mean_salience_unet = final_salience_unet.mean(dim=-1)

# Write the results of the saliency
write_metric_results(f_ukan, results=mean_salience_ukan, type='saliency')
write_metric_results(f_unet, results=mean_salience_unet, type='saliency')

# Everything has been written: close the files so the buffered content reaches
# the disk. Re-run the cell that opens them before running this cell again.
f_ukan.close()
f_unet.close()
