In [None]:

import numpy as np
from typing import Callable, List, Optional, Tuple
import torch
import ttach as tta
import math
from typing import Dict, List
import cv2
import matplotlib
from matplotlib import pyplot as plt
from matplotlib.lines import Line2D
from scipy.ndimage import zoom
from torchvision.transforms import Compose, Normalize, ToTensor
import torchvision
from sklearn.decomposition import KernelPCA
import os
import shutil
from typing import List, Optional, Tuple, Union
from PIL import Image
from ultralytics.nn.tasks import attempt_load_weights
from ultralytics.utils.ops import non_max_suppression, xywh2xyxy
from torchvision.ops import box_iou

class BaseCAM:
    """Shared machinery for the CAM methods defined in this file.

    The class registers forward hooks on ``target_layers`` (through
    ``ActivationsAndGradients``), runs the model, optionally back-propagates
    the summed target scores, and turns the recorded activations/gradients
    into one saliency map per input image.

    Subclasses normally override either ``get_cam_weights`` (per-channel
    weights) or ``get_cam_image`` (the whole per-layer map).
    """

    def __init__(
        self,
        model: torch.nn.Module,
        target_layers: List[torch.nn.Module],
        reshape_transform: Callable = None,
        compute_input_gradient: bool = False,
        uses_gradients: bool = True,
        tta_transforms: Optional[tta.Compose] = None,
        detach: bool = True,
    ) -> None:
        """
        Args:
            model: Network to explain; it is switched to eval mode.
            target_layers: Layers whose outputs (and output gradients) are recorded.
            reshape_transform: Optional callable applied to every recorded
                activation and gradient before it is stored.
            compute_input_gradient: If True, the input is made to require gradients.
            uses_gradients: If True, ``forward`` back-propagates the target scores.
            tta_transforms: Test-time-augmentation pipeline; a horizontal flip plus
                a multiply transform is used when None.
            detach: If True, activations/gradients are converted to numpy arrays;
                otherwise the stored tensors are used as they are.
        """
        self.model = model.eval()
        self.target_layers = target_layers

        # Use the same device as the model.
        self.device = next(self.model.parameters()).device
        self.reshape_transform = reshape_transform
        self.compute_input_gradient = compute_input_gradient
        self.uses_gradients = uses_gradients
        if tta_transforms is None:
            self.tta_transforms = tta.Compose(
                [
                    tta.HorizontalFlip(),
                    tta.Multiply(factors=[0.9, 1, 1.1]),
                ]
            )
        else:
            self.tta_transforms = tta_transforms

        self.detach = detach
        # Optional HPU helper module exposing mark_step(). Nothing in this file
        # imports one, so it stays None and forward() skips the call instead of
        # failing on a missing attribute.
        self.__htcore = None
        self.activations_and_grads = ActivationsAndGradients(self.model, target_layers, reshape_transform)

    def get_cam_weights(
        self,
        input_tensor: torch.Tensor,
        target_layers: List[torch.nn.Module],
        targets: List[torch.nn.Module],
        activations: torch.Tensor,
        grads: torch.Tensor,
    ) -> np.ndarray:
        """Get a vector of weights for every channel in the target layer.

        Methods that return per-channel weights typically only need to
        implement this function.

        Raises:
            NotImplementedError: Always, in the base class.
        """
        raise NotImplementedError("Not Implemented")

    def get_cam_image(
        self,
        input_tensor: torch.Tensor,
        target_layer: torch.nn.Module,
        targets: List[torch.nn.Module],
        activations: torch.Tensor,
        grads: torch.Tensor,
        eigen_smooth: bool = False,
    ) -> np.ndarray:
        """Combine the activations of one layer into a CAM using the channel weights.

        Returns:
            The channel-weighted activations summed over the channel axis, or
            their ``get_2d_projection`` when ``eigen_smooth`` is True.

        Raises:
            ValueError: If the activations are neither 4-D (2D conv) nor 5-D (3D conv).
        """
        weights = self.get_cam_weights(input_tensor, target_layer, targets, activations, grads)
        if isinstance(activations, torch.Tensor):
            activations = activations.cpu().detach().numpy()
        # 2D conv
        if len(activations.shape) == 4:
            weighted_activations = weights[:, :, None, None] * activations
        # 3D conv
        elif len(activations.shape) == 5:
            weighted_activations = weights[:, :, None, None, None] * activations
        else:
            raise ValueError(f"Invalid activation shape. Get {len(activations.shape)}.")

        if eigen_smooth:
            cam = get_2d_projection(weighted_activations)
        else:
            cam = weighted_activations.sum(axis=1)
        return cam

    def forward(
        self, input_tensor: torch.Tensor, targets: List[torch.nn.Module], eigen_smooth: bool = False
    ) -> np.ndarray:
        """Run the model, back-propagate the targets if needed and build the CAM.

        Args:
            input_tensor: Model input; it is moved to the model's device.
            targets: One callable per model output, each mapping an output to a score.
                May be None only when the hooked model call returns a single tensor.
            eigen_smooth: Forwarded to ``get_cam_image``.

        Raises:
            ValueError: If ``targets`` is None and the model output is not a tensor,
                because default classifier targets cannot be derived from it.
        """
        input_tensor = input_tensor.to(self.device)

        if self.compute_input_gradient:
            input_tensor = torch.autograd.Variable(input_tensor, requires_grad=True)

        self.outputs = outputs = self.activations_and_grads(input_tensor)
        if targets is None:
            if not isinstance(outputs, torch.Tensor):
                # The hooked model call in this file returns a nested list, for which
                # an arg-max over class scores is undefined.
                raise ValueError(
                    "targets must be provided: the model output is not a single tensor, "
                    "so default ClassifierOutputTarget targets cannot be inferred."
                )
            target_categories = np.argmax(outputs.cpu().data.numpy(), axis=-1)
            targets = [ClassifierOutputTarget(category) for category in target_categories]
        if self.uses_gradients:
            self.model.zero_grad()
            loss = sum([target(output) for target, output in zip(targets, outputs)])

            if self.detach:
                loss.backward(retain_graph=True)
            else:
                # keep the computational graph, create_graph = True is needed for hvp
                torch.autograd.grad(loss, input_tensor, retain_graph = True, create_graph = True)
            if 'hpu' in str(self.device):
                # getattr with the mangled name also covers subclasses that never
                # ran BaseCAM.__init__.
                htcore = getattr(self, "_BaseCAM__htcore", None)
                if htcore is not None:
                    htcore.mark_step()

        cam_per_layer = self.compute_cam_per_layer(input_tensor, targets, eigen_smooth)
        return self.aggregate_multi_layers(cam_per_layer)

    def get_target_width_height(self, input_tensor: torch.Tensor) -> Tuple[int, ...]:
        """Return the spatial size of the input, last axis first.

        A 4-D input yields ``(size(-1), size(-2))``; a 5-D input yields
        ``(size(-1), size(-2), size(-3))``.

        Raises:
            ValueError: For any other number of dimensions.
        """
        if len(input_tensor.shape) == 4:
            width, height = input_tensor.size(-1), input_tensor.size(-2)
            return width, height
        elif len(input_tensor.shape) == 5:
            depth, width, height = input_tensor.size(-1), input_tensor.size(-2), input_tensor.size(-3)
            return depth, width, height
        else:
            raise ValueError("Invalid input_tensor shape. Only 2D or 3D images are supported.")

    def compute_cam_per_layer(
        self, input_tensor: torch.Tensor, targets: List[torch.nn.Module], eigen_smooth: bool
    ) -> List[np.ndarray]:
        """Build one rescaled, non-negative CAM per target layer.

        Layers without a recorded activation or gradient receive None for
        that argument of ``get_cam_image``.
        """
        if self.detach:
            activations_list = [a.cpu().data.numpy() for a in self.activations_and_grads.activations]
            grads_list = [g.cpu().data.numpy() for g in self.activations_and_grads.gradients]
        else:
            activations_list = [a for a in self.activations_and_grads.activations]
            grads_list = [g for g in self.activations_and_grads.gradients]
        target_size = self.get_target_width_height(input_tensor)

        cam_per_target_layer = []
        # Loop over the saliency image from every layer
        for i in range(len(self.target_layers)):
            target_layer = self.target_layers[i]
            layer_activations = None
            layer_grads = None
            if i < len(activations_list):
                layer_activations = activations_list[i]
            if i < len(grads_list):
                layer_grads = grads_list[i]

            cam = self.get_cam_image(input_tensor, target_layer, targets, layer_activations, layer_grads, eigen_smooth)
            cam = np.maximum(cam, 0)
            scaled = scale_cam_image(cam, target_size)
            # Insert a "layer" axis so the per-layer maps can be concatenated later.
            cam_per_target_layer.append(scaled[:, None, :])

        return cam_per_target_layer

    def aggregate_multi_layers(self, cam_per_target_layer: np.ndarray) -> np.ndarray:
        """Average the per-layer CAMs (clipped at zero) and rescale the result."""
        cam_per_target_layer = np.concatenate(cam_per_target_layer, axis=1)
        cam_per_target_layer = np.maximum(cam_per_target_layer, 0)
        result = np.mean(cam_per_target_layer, axis=1)
        return scale_cam_image(result)

    def __call__(
        self,
        input_tensor: torch.Tensor,
        targets: List[torch.nn.Module] = None,
        aug_smooth: bool = False,
        eigen_smooth: bool = False,
    ) -> np.ndarray:
        """Compute the CAM for ``input_tensor``; see ``forward``.

        ``aug_smooth`` is accepted for interface compatibility only:
        test-time-augmentation smoothing is not implemented in this class,
        so a warning is emitted and the plain CAM is returned.
        """
        if aug_smooth is True:
            import warnings

            warnings.warn(
                "aug_smooth=True is ignored: augmentation smoothing is not implemented here.",
                stacklevel=2,
            )

        return self.forward(input_tensor, targets, eigen_smooth)

    def _release_hooks(self) -> None:
        """Remove the forward hooks if they were ever registered."""
        # __init__ can fail before activations_and_grads exists (for example on a
        # model without parameters); cleanup must not raise in that case.
        hooks = getattr(self, "activations_and_grads", None)
        if hooks is not None:
            hooks.release()

    def __del__(self):
        self._release_hooks()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, exc_tb):
        self._release_hooks()
        if isinstance(exc_value, IndexError):
            # IndexError raised inside the with-block is reported and suppressed.
            print(f"An exception occurred in CAM with block: {exc_type}. Message: {exc_value}")
            return True

class ActivationsAndGradients:
    """Records the outputs of chosen layers, and the gradients flowing into those
    outputs, while the wrapped model is being run."""

    def __init__(self, model: torch.nn.Module,
                 target_layers: List[torch.nn.Module],
                 reshape_transform: Optional[callable]) -> None:  # type: ignore
        """
        Attach the recording hooks.

        Args:
            model (torch.nn.Module): Model that ``__call__`` will execute.
            target_layers (List[torch.nn.Module]): Layers to observe.
            reshape_transform (Optional[callable]): Applied to each activation and
                gradient before it is stored; skipped when None.
        """
        self.model = model
        self.gradients = []
        self.activations = []
        self.reshape_transform = reshape_transform
        self.handles = []
        # Because of https://github.com/pytorch/pytorch/issues/61519 gradients are
        # captured from a forward hook (which attaches a tensor hook) rather than
        # from a module backward hook.
        for layer in target_layers:
            for hook in (self.save_activation, self.save_gradient):
                self.handles.append(layer.register_forward_hook(hook))

    def save_activation(self, module: torch.nn.Module,
                        input: Union[torch.Tensor, Tuple[torch.Tensor, ...]],
                        output: torch.Tensor) -> None:
        """
        Forward hook: store a detached CPU copy of the layer output.

        Args:
            module (torch.nn.Module): Layer that produced the output.
            input (Union[torch.Tensor, Tuple[torch.Tensor, ...]]): Layer input (unused).
            output (torch.Tensor): Layer output to record.
        """
        recorded = output if self.reshape_transform is None else self.reshape_transform(output)
        self.activations.append(recorded.cpu().detach())

    def save_gradient(self, module: torch.nn.Module,
                      input: Union[torch.Tensor, Tuple[torch.Tensor, ...]],
                      output: torch.Tensor) -> None:
        """
        Forward hook: arrange for the gradient w.r.t. the layer output to be stored.

        Args:
            module (torch.nn.Module): Layer that produced the output.
            input (Union[torch.Tensor, Tuple[torch.Tensor, ...]]): Layer input (unused).
            output (torch.Tensor): Layer output whose gradient is wanted.
        """
        # Tensor hooks can only be attached to outputs that require grad.
        if not getattr(output, "requires_grad", False):
            return

        def _store_grad(grad: torch.Tensor) -> None:
            if self.reshape_transform is not None:
                grad = self.reshape_transform(grad)
            # Gradients arrive in reverse layer order, so each one is put in front.
            self.gradients = [grad.cpu().detach(), *self.gradients]

        output.register_hook(_store_grad)

    def post_process(self, result: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, np.ndarray]:
        """
        Order the predictions of the first batch item by their best class score.

        Args:
            result (torch.Tensor): Prediction tensor whose dim 1 holds four box
                values followed by the class scores.
                # assumes layout (batch, 4 + classes, candidates) — TODO confirm

        Returns:
            Tuple[torch.Tensor, torch.Tensor, np.ndarray]: class scores and boxes of
            the first batch item, one row per candidate in descending order of the
            maximum class score, plus the same boxes passed through ``xywh2xyxy``
            as a numpy array.
        """
        class_scores = result[:, 4:]
        box_values = result[:, :4]
        _, ranking = torch.sort(class_scores.max(1)[0], descending=True)
        order = ranking[0]

        # Only the first batch item is used; transpose to one row per candidate.
        scores_by_candidate = torch.transpose(class_scores[0], dim0=0, dim1=1)
        boxes_by_candidate = torch.transpose(box_values[0], dim0=0, dim1=1)

        ranked_scores = scores_by_candidate[order]
        ranked_boxes = boxes_by_candidate[order]
        ranked_boxes_xyxy = xywh2xyxy(boxes_by_candidate[order]).cpu().detach().numpy()
        return ranked_scores, ranked_boxes, ranked_boxes_xyxy

    def __call__(self, x: torch.Tensor) -> List[List[Union[torch.Tensor, np.ndarray]]]:
        """
        Run the model on ``x`` with fresh activation/gradient buffers.

        Args:
            x (torch.Tensor): The input tensor.

        Returns:
            List[List[Union[torch.Tensor, np.ndarray]]]: A one-element list holding
            ``[ranked class scores, ranked boxes]`` from ``post_process`` applied to
            the first element of the model output.
        """
        self.gradients = []
        self.activations = []
        raw_output = self.model(x)
        ranked_scores, ranked_boxes, _ = self.post_process(raw_output[0])
        return [[ranked_scores, ranked_boxes]]

    def release(self) -> None:
        """Detach every hook registered in ``__init__``."""
        for hook_handle in self.handles:
            hook_handle.remove()

class ScoreCAM(BaseCAM):
    """Score-CAM: weights every channel by the target score obtained when the
    input is masked with that channel's (upsampled, min-max normalised) activation.

    No gradients are used.
    """

    def __init__(
            self,
            model,
            target_layers,
            reshape_transform=None):
        super(ScoreCAM, self).__init__(model,
                                       target_layers,
                                       reshape_transform=reshape_transform,
                                       uses_gradients=False)

    def get_cam_weights(self,
                        input_tensor,
                        target_layer,
                        targets,
                        activations,
                        grads):
        """Return softmax-normalised per-channel scores of shape (batch, channels).

        Args:
            input_tensor: Model input, already on ``self.device``.
            target_layer: Unused.
            targets: One scoring callable per image. Each is called with
                ``[post_result, pre_post_boxes]``, the same structure that
                ``BaseCAM.forward`` passes to targets in this file.
            activations: Numpy activations of the target layer.
            grads: Unused.

        An optional ``self.batch_size`` attribute sets how many masked inputs are
        pushed through the model at once (default 16).
        """
        with torch.no_grad():
            upsample = torch.nn.UpsamplingBilinear2d(
                size=input_tensor.shape[-2:])
            activation_tensor = torch.from_numpy(activations).to(self.device)

            upsampled = upsample(activation_tensor)

            # Min-max normalise every channel map to [0, 1].
            flat = upsampled.view(upsampled.size(0), upsampled.size(1), -1)
            maxs = flat.max(dim=-1)[0][:, :, None, None]
            mins = flat.min(dim=-1)[0][:, :, None, None]
            upsampled = (upsampled - mins) / (maxs - mins + 1e-8)

            # One masked copy of the input per (image, channel) pair.
            input_tensors = input_tensor[:, None,
                                         :, :] * upsampled[:, :, None, :, :]

            batch_size = getattr(self, "batch_size", 16)

            # The forward hooks also fire for the scoring passes below; drop what
            # they record so the stored activation list does not grow per batch.
            recorded = self.activations_and_grads.activations
            recorded_count = len(recorded)

            scores = []
            for target, tensor in zip(targets, input_tensors):
                for start in range(0, tensor.size(0), batch_size):
                    batch = tensor[start: start + batch_size, :]
                    # Same indexing as ActivationsAndGradients.__call__: element 0
                    # of the model output is the prediction tensor.
                    predictions = self.model(batch)[0]
                    del recorded[recorded_count:]

                    # One score per masked input, so that the total number of
                    # scores equals batch * channels as the reshape below needs.
                    for sample in range(predictions.shape[0]):
                        post_result, pre_post_boxes, _ = self.activations_and_grads.post_process(
                            predictions[sample: sample + 1])
                        scores.append(float(target([post_result, pre_post_boxes])))
            scores = torch.Tensor(scores)
            scores = scores.view(activations.shape[0], activations.shape[1])
            weights = torch.nn.Softmax(dim=-1)(scores).numpy()
            return weights

class GradCAM(BaseCAM):
    """Grad-CAM: each channel weight is the mean of that channel's gradients
    over the spatial axes."""

    def __init__(self, model, target_layers,
                 reshape_transform=None):
        super().__init__(model, target_layers, reshape_transform)

    def get_cam_weights(self,
                        input_tensor,
                        target_layer,
                        target_category,
                        activations,
                        grads):
        """Average ``grads`` over axes 2.. for 4-D (2D image) or 5-D (3D image) input.

        Raises:
            ValueError: If ``grads`` has any other number of dimensions.
        """
        rank = len(grads.shape)
        if rank not in (4, 5):
            raise ValueError("Invalid grads shape." 
                             "Shape of grads should be 4 (2D image) or 5 (3D image).")

        spatial_axes = (2, 3) if rank == 4 else (2, 3, 4)
        return np.mean(grads, axis=spatial_axes)

class EigenCAM(BaseCAM):
    """Eigen-CAM: the CAM is ``get_2d_projection`` of the raw activations.

    Gradients are not used.
    """

    def __init__(self, model, target_layers, 
                 reshape_transform=None):
        super().__init__(model,
                         target_layers,
                         reshape_transform,
                         uses_gradients=False)

    def get_cam_image(self,
                      input_tensor,
                      target_layer,
                      target_category,
                      activations,
                      grads,
                      eigen_smooth):
        """Project the activations with ``get_2d_projection``; other arguments are unused."""
        projected = get_2d_projection(activations)
        return projected

class EigenGradCAM(BaseCAM):
    """Eigen-Grad-CAM: the CAM is ``get_2d_projection`` of the element-wise
    product of gradients and activations."""

    def __init__(self, model, target_layers, 
                 reshape_transform=None):
        super().__init__(model, target_layers, reshape_transform)

    def get_cam_image(self,
                      input_tensor,
                      target_layer,
                      target_category,
                      activations,
                      grads,
                      eigen_smooth):
        """Project ``grads * activations``; ``eigen_smooth`` is not consulted."""
        gradient_weighted = grads * activations
        return get_2d_projection(gradient_weighted)


class XGradCAM(BaseCAM):
    """XGrad-CAM: channel weights are the gradients scaled by each activation's
    share of its channel's total activation."""

    def __init__(
            self,
            model,
            target_layers,
            reshape_transform=None):
        super().__init__(model, target_layers, reshape_transform)

    def get_cam_weights(self,
                        input_tensor,
                        target_layer,
                        target_category,
                        activations,
                        grads):
        """Return ``sum(grads * activations / (channel_sum + eps))`` over axes (2, 3)."""
        eps = 1e-7
        channel_totals = np.sum(activations, axis=(2, 3))
        # eps keeps the division finite for channels whose activations sum to zero.
        denominators = channel_totals[:, :, None, None] + eps
        normalised = grads * activations / denominators
        return normalised.sum(axis=(2, 3))


class RandomCAM(BaseCAM):
    """CAM with channel weights drawn uniformly at random from [-1, 1)."""

    def __init__(self, model, target_layers, 
                 reshape_transform=None):
        super().__init__(model, target_layers, reshape_transform)

    def get_cam_weights(self,
                        input_tensor,
                        target_layer,
                        target_category,
                        activations,
                        grads):
        """Draw one weight per (sample, channel); only the shape of ``grads`` is used."""
        n_samples, n_channels = grads.shape[0], grads.shape[1]
        return np.random.uniform(-1, 1, size=(n_samples, n_channels))
        
class LayerCAM(BaseCAM):
    """Layer-CAM: activations are weighted element-wise by the positive part of
    their gradients."""

    def __init__(
            self,
            model,
            target_layers,
            reshape_transform=None):
        super().__init__(model, target_layers, reshape_transform)

    def get_cam_image(self,
                      input_tensor,
                      target_layer,
                      target_category,
                      activations,
                      grads,
                      eigen_smooth):
        """Return ``max(grads, 0) * activations`` reduced over channels.

        The reduction is a channel sum, or ``get_2d_projection`` when
        ``eigen_smooth`` is true.
        """
        positive_grads = np.maximum(grads, 0)
        weighted = positive_grads * activations
        if not eigen_smooth:
            return weighted.sum(axis=1)
        return get_2d_projection(weighted)

class KPCA_CAM(BaseCAM):
    """CAM built from a kernel-PCA projection of the activations.

    Gradients are not used. ``kernel`` and ``gamma`` are stored and handed
    to ``get_2d_projection_kernel``.
    """

    def __init__(self, model, target_layers, 
                 reshape_transform=None, kernel='sigmoid', gamma=None):
        super().__init__(model,
                         target_layers,
                         reshape_transform,
                         uses_gradients=False)
        self.kernel = kernel
        self.gamma = gamma

    def get_cam_image(self,
                      input_tensor,
                      target_layer,
                      target_category,
                      activations,
                      grads,
                      eigen_smooth):
        """Project the activations with the configured kernel; other arguments are unused."""
        projected = get_2d_projection_kernel(activations, self.kernel, self.gamma)
        return projected

class HiResCAM(BaseCAM):
    """HiResCAM: the CAM is the element-wise product of gradients and
    activations, reduced over channels."""

    def __init__(self, model, target_layers, 
                 reshape_transform=None):
        super().__init__(model, target_layers, reshape_transform)

    def get_cam_image(self,
                      input_tensor,
                      target_layer,
                      target_category,
                      activations,
                      grads,
                      eigen_smooth):
        """Return ``grads * activations`` summed over channels.

        With ``eigen_smooth`` a warning is printed and ``get_2d_projection``
        replaces the channel sum.
        """
        product = grads * activations

        if not eigen_smooth:
            return product.sum(axis=1)

        print(
            "Warning: HiResCAM's faithfulness guarantees do not hold if smoothing is applied")
        return get_2d_projection(product)

class GradCAMPlusPlus(BaseCAM):
    """Grad-CAM++ channel weights (see https://arxiv.org/abs/1710.11063)."""

    def __init__(self, model, target_layers,
                 reshape_transform=None):
        super().__init__(model, target_layers, reshape_transform)

    def get_cam_weights(self,
                        input_tensor,
                        target_layers,
                        target_category,
                        activations,
                        grads):
        """Return the sum over axes (2, 3) of ``relu(grads) * aij``."""
        eps = 0.000001
        grads_squared = grads**2
        grads_cubed = grads_squared * grads
        activation_totals = np.sum(activations, axis=(2, 3))

        # Equation 19 in https://arxiv.org/abs/1710.11063
        denominator = (2 * grads_squared +
                       activation_totals[:, :, None, None] * grads_cubed + eps)
        # The coefficient is forced to zero wherever the gradient is exactly zero.
        aij = np.where(grads != 0, grads_squared / denominator, 0)

        # ReLU on the gradients (eq. 7 in the paper) before the spatial sum.
        positive_grads = np.maximum(grads, 0)
        return np.sum(positive_grads * aij, axis=(2, 3))

class GradCAMElementWise(BaseCAM):
    """Grad-CAM variant that applies ReLU to ``grads * activations`` element-wise
    before reducing over channels."""

    def __init__(self, model, target_layers, 
                 reshape_transform=None):
        super().__init__(model, target_layers, reshape_transform)

    def get_cam_image(self,
                      input_tensor,
                      target_layer,
                      target_category,
                      activations,
                      grads,
                      eigen_smooth):
        """Return ``max(grads * activations, 0)`` reduced over channels.

        The reduction is a channel sum, or ``get_2d_projection`` when
        ``eigen_smooth`` is true.
        """
        rectified = np.maximum(grads * activations, 0)
        if not eigen_smooth:
            return rectified.sum(axis=1)
        return get_2d_projection(rectified)
        
class FEM(BaseCAM):
    """Gradient-free CAM using a k-sigma threshold on each channel.

    A channel's weight is the fraction of its positions whose activation
    exceeds ``mean + k * std`` of that channel.
    """

    def __init__(self, model, target_layers, 
                 reshape_transform=None, k=2):
        super().__init__(model,
                         target_layers,
                         reshape_transform,
                         uses_gradients=False)
        self.k = k

    def get_cam_image(self,
                      input_tensor,
                      target_layer,
                      target_category,
                      activations,
                      grads,
                      eigen_smooth):
        """Return the channel sum of the weighted activations.

        Raises:
            ValueError: If ``activations`` is neither 4-D (2D image) nor 5-D (3D image).
        """
        # Spatial axes by rank: 4-D -> 2D image, 5-D -> 3D image.
        spatial_axes = {4: (2, 3), 5: (2, 3, 4)}.get(len(activations.shape))
        if spatial_axes is None:
            raise ValueError("Invalid activations shape." 
                             "Shape of activations should be 4 (2D image) or 5 (3D image).")

        channel_mean = np.mean(activations, axis=spatial_axes)
        channel_std = np.std(activations, axis=spatial_axes)

        # Trailing singleton axes let per-channel values broadcast over positions.
        broadcast_shape = list(channel_mean.shape) + [1] * len(spatial_axes)

        # k sigma rule
        threshold = (channel_mean + self.k * channel_std).reshape(broadcast_shape)
        above_threshold = activations > threshold
        channel_weights = above_threshold.mean(axis=spatial_axes)
        return (channel_weights.reshape(broadcast_shape) * activations).sum(axis=1)

class FinerCAM:
    """Wraps a base CAM method and drives it with ``FinerWeightedTarget`` targets,
    which contrast a main category against a set of comparison categories."""

    def __init__(self, model: torch.nn.Module, target_layers: List[torch.nn.Module], reshape_transform: Callable = None, base_method=GradCAM):
        """
        Args:
            model: Network to explain.
            target_layers: Layers observed by the base CAM.
            reshape_transform: Forwarded to the base CAM.
            base_method: CAM class instantiated as
                ``base_method(model, target_layers, reshape_transform)``.
        """
        self.base_cam = base_method(model, target_layers, reshape_transform)
        self.compute_input_gradient = self.base_cam.compute_input_gradient
        self.uses_gradients = self.base_cam.uses_gradients

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)

    def forward(self, input_tensor: torch.Tensor, targets: List[torch.nn.Module] = None, eigen_smooth: bool = False,
                alpha: float = 1, comparison_categories: List[int] = [1, 2, 3], target_idx: int = None
                ) -> np.ndarray:
        """Compute the CAM through the wrapped base method.

        Args:
            input_tensor: Model input; moved to the base CAM's device.
            targets: Scoring callables, one per model output. When None they are
                built from the model output, which then has to be a single tensor.
            eigen_smooth: Forwarded to the base CAM.
            alpha: Comparison strength passed to ``FinerWeightedTarget``.
            comparison_categories: Ranks (in the similarity ordering described
                below) of the categories to compare against. The default list is
                only read, never modified.
            target_idx: Column used as the reference logit; the row maximum is
                used when None.

        Raises:
            ValueError: If ``targets`` is None and the model output is not a tensor.
        """
        input_tensor = input_tensor.to(self.base_cam.device)

        if self.compute_input_gradient:
            input_tensor = torch.autograd.Variable(input_tensor, requires_grad=True)

        outputs = self.base_cam.activations_and_grads(input_tensor)

        if targets is None:
            if not isinstance(outputs, torch.Tensor):
                # The hooked model call in this file returns a nested list, from
                # which class logits cannot be read.
                raise ValueError(
                    "targets must be provided: the model output is not a single tensor, "
                    "so FinerWeightedTarget targets cannot be inferred."
                )
            output_data = outputs.detach().cpu().numpy()
            target_logits = np.max(output_data, axis=-1) if target_idx is None else output_data[:, target_idx]
            # Sort class indices for each sample based on the absolute difference
            # between the class scores and the target logit, in ascending order.
            # The most similar classes (smallest difference) appear first.
            sorted_indices = np.argsort(np.abs(output_data - target_logits[:, None]), axis=-1)
            targets = [FinerWeightedTarget(int(sorted_indices[i, 0]),
                                           [int(sorted_indices[i, idx]) for idx in comparison_categories],
                                           alpha)
                       for i in range(output_data.shape[0])]

        if self.uses_gradients:
            self.base_cam.model.zero_grad()
            loss = sum([target(output) for target, output in zip(targets, outputs)])
            if self.base_cam.detach:
                loss.backward(retain_graph=True)
            else:
                # keep the computational graph, create_graph = True is needed for hvp
                torch.autograd.grad(loss, input_tensor, retain_graph = True, create_graph = True)
                # When using the following loss.backward() method, a warning is raised: "UserWarning: Using backward() with create_graph=True will create a reference cycle"
                # loss.backward(retain_graph=True, create_graph=True)
            if 'hpu' in str(self.base_cam.device):
                # Writing self.base_cam.__htcore here would be name-mangled to
                # _FinerCAM__htcore, which never exists. Look the attribute up under
                # the name BaseCAM would store it as and skip the call if it is absent.
                htcore = getattr(self.base_cam, "_BaseCAM__htcore", None)
                if htcore is not None:
                    htcore.mark_step()

        cam_per_layer = self.base_cam.compute_cam_per_layer(input_tensor, targets, eigen_smooth)
        return self.base_cam.aggregate_multi_layers(cam_per_layer)


def get_2d_projection(activation_batch):
    """Project each sample's channels onto their first principal direction.

    For every sample the array is reshaped to (positions, channels), centred
    per channel, and multiplied by the first right-singular vector of the
    centred matrix.

    Args:
        activation_batch: Array indexed as (sample, channel, *spatial). NaNs are
            treated as zero; the caller's array is left untouched.

    Returns:
        ``np.float32`` array with one spatial map per sample.
    """
    # TBD: use pytorch batch svd implementation
    nan_mask = np.isnan(activation_batch)
    if nan_mask.any():
        # Work on a cleaned copy: the input may share memory with activations
        # that the caller still holds.
        activation_batch = np.where(nan_mask, 0, activation_batch)

    projections = []
    for activations in activation_batch:
        reshaped_activations = (activations).reshape(
            activations.shape[0], -1).transpose()
        # Centering before the SVD seems to be important here,
        # Otherwise the image returned is negative
        reshaped_activations = reshaped_activations - \
            reshaped_activations.mean(axis=0)
        # Only the first row of VT is needed, so the reduced SVD is enough and
        # avoids building a (positions x positions) U matrix.
        _, _, VT = np.linalg.svd(reshaped_activations, full_matrices=False)
        projection = reshaped_activations @ VT[0, :]
        projection = projection.reshape(activations.shape[1:])
        projections.append(projection)
    return np.float32(projections)



def get_2d_projection_kernel(activation_batch, kernel='sigmoid', gamma=None):
    """Project each sample's channels onto one kernel-PCA component.

    For every sample the array is reshaped to (positions, channels), centred
    per channel, and reduced to a single component with ``KernelPCA``.

    Args:
        activation_batch: Array indexed as (sample, channel, *spatial). NaNs are
            treated as zero; the caller's array is left untouched.
        kernel: Kernel name passed to ``KernelPCA``.
        gamma: Kernel coefficient passed to ``KernelPCA``.

    Returns:
        ``np.float32`` array with one spatial map per sample.
    """
    nan_mask = np.isnan(activation_batch)
    if nan_mask.any():
        # Work on a cleaned copy: the input may share memory with activations
        # that the caller still holds.
        activation_batch = np.where(nan_mask, 0, activation_batch)

    projections = []
    for activations in activation_batch:
        reshaped_activations = activations.reshape(activations.shape[0], -1).transpose()
        reshaped_activations = reshaped_activations - reshaped_activations.mean(axis=0)
        # Apply Kernel PCA
        kpca = KernelPCA(n_components=1, kernel=kernel, gamma=gamma)
        projection = kpca.fit_transform(reshaped_activations)
        projection = projection.reshape(activations.shape[1:])
        projections.append(projection)
    return np.float32(projections)
    
class ClassifierOutputTarget:
    """Target that picks the raw score of one category from the model output."""

    def __init__(self, category):
        self.category = category

    def __call__(self, model_output):
        """Index a 1-D output directly, otherwise take the category column."""
        is_single_sample = len(model_output.shape) == 1
        if not is_single_sample:
            return model_output[:, self.category]
        return model_output[self.category]


class ClassifierOutputSoftmaxTarget:
    """Target that picks the softmax probability of one category."""

    def __init__(self, category):
        self.category = category

    def __call__(self, model_output):
        """Softmax over the last axis, then select the category (column for batched input)."""
        probabilities = torch.softmax(model_output, dim=-1)
        if len(model_output.shape) == 1:
            return probabilities[self.category]
        return probabilities[:, self.category]


class ClassifierOutputReST:
    """
    Target combining the pre-softmax score with the cross-entropy loss,
    as proposed in https://arxiv.org/abs/2501.06261
    """

    def __init__(self, category):
        self.category = category

    def __call__(self, model_output):
        """Return the category logit minus the cross-entropy towards that category.

        A 1-D output is treated as a batch of one and yields a scalar; a batched
        output yields one value per row (the cross-entropy term is the default
        reduction of ``torch.nn.functional.cross_entropy``).
        """
        single_sample = len(model_output.shape) == 1
        logits = model_output.unsqueeze(0) if single_sample else model_output

        labels = torch.tensor([self.category] * logits.shape[0], device=model_output.device)
        loss = torch.nn.functional.cross_entropy(logits, labels)

        if single_sample:
            return logits[0][self.category] - loss
        return logits[:, self.category] - loss


class BinaryClassifierOutputTarget:
    """Target for single-output binary classifiers: keeps the output for
    category 1 and negates it for any other category."""

    def __init__(self, category):
        self.category = category

    def __call__(self, model_output):
        """Return ``model_output`` multiplied by +1 (category 1) or -1 (otherwise)."""
        sign = 1 if self.category == 1 else -1
        return model_output * sign


class SoftmaxOutputTarget:
    """Target that returns the softmax of the whole model output."""

    def __init__(self):
        """No configuration is needed."""

    def __call__(self, model_output):
        """Apply softmax along the last axis."""
        probabilities = torch.softmax(model_output, dim=-1)
        return probabilities


class RawScoresOutputTarget:
    """Identity target: hands the model output back unchanged."""

    def __init__(self):
        """No configuration is needed."""

    def __call__(self, model_output):
        """Return ``model_output`` itself."""
        return model_output


class SemanticSegmentationTarget:
    """Target for segmentation outputs: sums the scores of one category over
    the pixels selected by a spatial mask."""

    def __init__(self, category, mask):
        """
        Args:
            category: Index of the category channel to read.
            mask: Numpy array multiplied element-wise with the category scores;
                it is converted to a tensor once, here.
        """
        self.category = category
        self.mask = torch.from_numpy(mask)

    def __call__(self, model_output):
        """Return the masked sum of ``model_output[category]``."""
        # The mask follows the output's device on every call.
        mask_on_device = self.mask.to(model_output.device)
        category_scores = model_output[self.category, :, :]
        return (category_scores * mask_on_device).sum()


class FasterRCNNBoxScoreTarget:
    """ For every original detected bounding box specified in "bounding boxes",
        assign a score on how the current bounding boxes match it,
            1. In IOU
            2. In the classification score.
        If there is not a large enough overlap, or the category changed,
        assign a score of 0.

        The total score is the sum of all the box scores.
    """

    def __init__(self, labels, bounding_boxes, iou_threshold=0.5):
        """
        Args:
            labels: Label of each reference box.
            bounding_boxes: Reference boxes, one array per box.
            iou_threshold: Minimum IoU for a detection to count as a match.
        """
        self.labels = labels
        self.bounding_boxes = bounding_boxes
        self.iou_threshold = iou_threshold

    @staticmethod
    def _resolve_device(boxes):
        """Pick the device the score should live on.

        Detections that are tensors dictate the device, so reference boxes and
        detections always end up on the same one. For anything else the
        availability-based choice (CUDA, then MPS, then CPU) is used.
        """
        if isinstance(boxes, torch.Tensor):
            return boxes.device
        if torch.cuda.is_available():
            return torch.device("cuda")
        if torch.backends.mps.is_available():
            return torch.device("mps")
        return torch.device("cpu")

    def __call__(self, model_outputs):
        """Return a one-element tensor holding the summed match score.

        Args:
            model_outputs: Mapping with "boxes", "labels" and "scores" entries.
        """
        detected_boxes = model_outputs["boxes"]
        device = self._resolve_device(detected_boxes)
        output = torch.Tensor([0]).to(device)

        if len(detected_boxes) == 0:
            return output

        for box, label in zip(self.bounding_boxes, self.labels):
            box = torch.Tensor(box[None, :]).to(device)

            ious = torchvision.ops.box_iou(box, detected_boxes)
            index = ious.argmax()
            # Only the best-overlapping detection is considered for each reference box.
            if ious[0, index] > self.iou_threshold and model_outputs["labels"][index] == label:
                score = ious[0, index] + model_outputs["scores"][index]
                output = output + score
        return output

class FinerWeightedTarget:
    """
    Softmax-weighted contrast between one main category and several others.

    For every comparison category the term
    ``main_score - alpha * comparison_score`` is weighted by that category's
    softmax probability. The weighted terms are summed and divided by the sum
    of the weights (plus a small epsilon).
    """
    def __init__(self, main_category, comparison_categories, alpha):
        self.alpha = alpha
        self.comparison_categories = comparison_categories
        self.main_category = main_category

    def __call__(self, model_output):
        is_vector = model_output.ndim == 1

        def pick(tensor, idx):
            # 1-D outputs are indexed directly, batched ones along the last axis.
            if is_vector:
                return tensor[idx]
            return tensor[..., idx]

        main_score = pick(model_output, self.main_category)
        probabilities = torch.softmax(model_output, dim=-1)

        weights = []
        for idx in self.comparison_categories:
            weights.append(pick(probabilities, idx))

        weighted_total = 0
        for weight, idx in zip(weights, self.comparison_categories):
            contrast = main_score - self.alpha * pick(model_output, idx)
            weighted_total = weighted_total + weight * contrast

        weight_total = 0
        for weight in weights:
            weight_total = weight_total + weight

        return weighted_total / (weight_total + 1e-9)
        

def preprocess_image(img: np.ndarray, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) -> torch.Tensor:
    """Turn an image array into a normalized tensor with a leading batch axis.

    A copy of ``img`` is passed through torchvision's ``ToTensor`` followed by
    ``Normalize(mean, std)``; the result gets one extra leading dimension.

    :param img: The input image array.
    :param mean: Per-channel mean handed to ``Normalize``.
    :param std: Per-channel standard deviation handed to ``Normalize``.
    :returns: The normalized tensor with a batch dimension of size 1.
    """
    pipeline = Compose([ToTensor(), Normalize(mean=mean, std=std)])
    normalized = pipeline(img.copy())
    return normalized.unsqueeze(0)


def deprocess_image(img):
    """Map an arbitrary-range array to a displayable ``uint8`` image.

    The array is standardized (zero mean, unit standard deviation), squeezed
    to a standard deviation of 0.1 around 0.5, clipped to [0, 1] and scaled
    to 0..255.

    See https://github.com/jacobgil/keras-grad-cam/blob/master/grad-cam.py#L65
    """
    centered = img - np.mean(img)
    standardized = centered / (np.std(centered) + 1e-5)
    recentred = standardized * 0.1 + 0.5
    bounded = np.clip(recentred, 0, 1)
    return np.uint8(bounded * 255)


def show_cam_on_image(img: np.ndarray,
                      mask: np.ndarray,
                      use_rgb: bool = False,
                      colormap: int = cv2.COLORMAP_JET,
                      image_weight: float = 0.5) -> np.ndarray:
    """ This function overlays the cam mask on the image as a heatmap.
    By default the heatmap is in BGR format.

    :param img: The base image in RGB or BGR format, as floats in the range [0, 1].
    :param mask: The cam mask.
    :param use_rgb: Whether to use an RGB or BGR heatmap, this should be set to True if 'img' is in RGB format.
    :param colormap: The OpenCV colormap to be used.
    :param image_weight: The final result is image_weight * img + (1-image_weight) * heatmap.
    :returns: The image with the cam overlay, as uint8 in the range [0, 255].
    :raises Exception: If ``img`` has values above 1 or ``image_weight`` is outside [0, 1].
    """
    heatmap = cv2.applyColorMap(np.uint8(255 * mask), colormap)
    if use_rgb:
        heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
    heatmap = np.float32(heatmap) / 255

    if np.max(img) > 1:
        raise Exception(
            "The input image should np.float32 in the range [0, 1]")

    if image_weight < 0 or image_weight > 1:
        raise Exception(
            f"image_weight should be in the range [0, 1].\
                Got: {image_weight}")

    cam = (1 - image_weight) * heatmap + image_weight * img
    # Stretch the blend so its brightest value becomes 1. An all-zero blend
    # (e.g. a black image with image_weight=1) has nothing to stretch and
    # would otherwise turn into NaN through a division by zero.
    peak = np.max(cam)
    if peak != 0:
        cam = cam / peak
    return np.uint8(255 * cam)


def create_labels_legend(concept_scores: np.ndarray,
                         labels: Dict[int, str],
                         top_k=2):
    """Build one legend string per concept from its ``top_k`` best categories.

    Each legend line has the form ``<label>:<score>``, where ``<label>`` keeps
    at most the first three comma-separated parts of the category name and the
    score is printed with two decimals. Lines of one concept are joined with
    newlines, best category first.

    :param concept_scores: Array indexed as [concept, category].
    :param labels: Mapping from category index to category name.
    :param top_k: Number of categories listed per concept.
    :returns: A list with one multi-line string per concept.
    """
    ranked_categories = np.argsort(concept_scores, axis=1)[:, ::-1]
    top_categories = ranked_categories[:, :top_k]

    def describe(concept_index, category):
        short_name = ','.join(labels[category].split(',')[:3])
        score = concept_scores[concept_index, category]
        return f"{short_name}:{score:.2f}"

    return [
        "\n".join(describe(row, category) for category in top_categories[row, :])
        for row in range(top_categories.shape[0])
    ]


def show_factorization_on_image(img: np.ndarray,
                                explanations: np.ndarray,
                                colors: List[np.ndarray] = None,
                                image_weight: float = 0.5,
                                concept_labels: List = None) -> np.ndarray:
    """ Color code the different component heatmaps on top of the image.
        Every component color code will be magnified according to the heatmap intensity
        (by modifying the V channel in the HSV color space),
        and optionally create a legend that shows the labels.

        Since different factorization component heatmaps can overlap in principle,
        we need a strategy to decide how to deal with the overlaps.
        This keeps the component that has a higher value in its heatmap.

        The ``explanations`` array passed in is left untouched.

    :param img: The base image RGB format.
    :param explanations: A tensor of shape num_components x height x width, with the component visualizations.
    :param colors: List of R, G, B colors to be used for the components.
                   If None, will use the gist_rainbow cmap as a default.
    :param image_weight: The final result is image_weight * img + (1-image_weight) * visualization.
    :param concept_labels: A list of strings for every component. If this is passed, a legend that shows
                           the labels and their colors will be added to the image.
    :returns: The visualized image.
    """
    n_components = explanations.shape[0]
    if colors is None:
        # taken from https://github.com/edocollins/DFF/blob/master/utils.py
        # plt.cm.get_cmap no longer exists in recent matplotlib releases;
        # prefer the colormap registry and fall back for old versions.
        registry = getattr(matplotlib, "colormaps", None)
        if registry is not None:
            _cmap = registry['gist_rainbow']
        else:
            _cmap = plt.cm.get_cmap('gist_rainbow')
        colors = [
            np.array(_cmap(i))
            for i in np.arange(0, 1, 1.0 / n_components)]
    concept_per_pixel = explanations.argmax(axis=0)
    masks = []
    for i in range(n_components):
        mask = np.zeros(shape=(img.shape[0], img.shape[1], 3))
        mask[:, :, :] = colors[i][:3]
        # Work on a copy: explanations[i] is a view, and zeroing it in place
        # would silently modify the caller's array.
        explanation = np.array(explanations[i], copy=True)
        explanation[concept_per_pixel != i] = 0
        mask = np.uint8(mask * 255)
        mask = cv2.cvtColor(mask, cv2.COLOR_RGB2HSV)
        # Brightness (V channel) encodes how strongly the component fires.
        mask[:, :, 2] = np.uint8(255 * explanation)
        mask = cv2.cvtColor(mask, cv2.COLOR_HSV2RGB)
        mask = np.float32(mask) / 255
        masks.append(mask)

    mask = np.sum(np.float32(masks), axis=0)
    result = img * image_weight + mask * (1 - image_weight)
    result = np.uint8(result * 255)

    if concept_labels is not None:
        px = 1 / plt.rcParams['figure.dpi']  # pixel in inches
        fig = plt.figure(figsize=(result.shape[1] * px, result.shape[0] * px))
        # Passed to the legend directly rather than written into plt.rcParams,
        # so later figures in the session keep their own legend font size.
        legend_fontsize = int(
            14 * result.shape[0] / 256 / max(1, n_components / 6))
        lw = 5 * result.shape[0] / 256
        lines = [Line2D([0], [0], color=colors[i], lw=lw)
                 for i in range(n_components)]
        plt.legend(lines,
                   concept_labels,
                   mode="expand",
                   fancybox=True,
                   shadow=True,
                   fontsize=legend_fontsize)

        plt.tight_layout(pad=0, w_pad=0, h_pad=0)
        plt.axis('off')
        fig.canvas.draw()
        if hasattr(fig.canvas, "buffer_rgba"):
            # canvas.tostring_rgb() is gone in recent matplotlib releases;
            # read the RGBA buffer and drop the alpha channel instead.
            data = np.ascontiguousarray(
                np.asarray(fig.canvas.buffer_rgba())[..., :3])
        else:
            data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
            data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
        plt.close(fig=fig)
        data = cv2.resize(data, (result.shape[1], result.shape[0]))
        result = np.hstack((result, data))
    return result


def scale_cam_image(cam, target_size=None):
    """Min-max scale every item of ``cam`` and optionally resize it.

    ``cam`` is iterated along its first axis. Each item is shifted so its
    minimum is 0 and divided by its maximum (plus 1e-7). When ``target_size``
    is given, items with more than two dimensions are resized with
    ``scipy.ndimage.zoom`` (``target_size`` is read in reverse order to match
    the item's axes) and all other items with ``cv2.resize``.

    :returns: A float32 array stacking the processed items.
    """
    def rescale(item):
        shifted = item - np.min(item)
        unit = shifted / (1e-7 + np.max(shifted))
        if target_size is None:
            return unit
        if len(unit.shape) > 2:
            factors = [(t_s / i_s) for i_s, t_s in zip(unit.shape, target_size[::-1])]
            return zoom(np.float32(unit), factors)
        return cv2.resize(np.float32(unit), target_size)

    return np.float32([rescale(item) for item in cam])


def scale_accross_batch_and_channels(tensor, target_size):
    """Apply ``scale_cam_image`` to every (batch, channel) slice of ``tensor``.

    The first two axes are merged so each slice is scaled on its own, then the
    result is reshaped to ``(batch, channel, target_size[1], target_size[0])``.
    """
    n_batch, n_channels = tensor.shape[:2]
    spatial_dims = tensor.shape[2:]
    flattened = tensor.reshape(n_batch * n_channels, *spatial_dims)
    scaled = scale_cam_image(flattened, target_size)
    first_extent, second_extent = target_size[0], target_size[1]
    return scaled.reshape(n_batch, n_channels, second_extent, first_extent)



def letterbox(
    im: np.ndarray,
    new_shape=(640, 640),
    color=(114, 114, 114),
    auto=True,
    scaleFill=False,
    scaleup=True,
    stride=32,
):
    """
    Resize an image to fit ``new_shape`` and pad the remainder with a border.

    Args:
        im (numpy.ndarray): Input image; its first two axes are read as height and width.
        new_shape (tuple | int, optional): Target (height, width); an int means a square. Defaults to (640, 640).
        color (tuple, optional): Border color. Defaults to (114, 114, 114).
        auto (bool, optional): Reduce the padding modulo ``stride`` (minimum rectangle). Defaults to True.
        scaleFill (bool, optional): Stretch to ``new_shape`` without padding; only consulted
            when ``auto`` is False. Defaults to False.
        scaleup (bool, optional): Allow enlarging the image. Defaults to True.
        stride (int, optional): Modulus applied to the padding when ``auto`` is True. Defaults to 32.

    Returns:
        numpy.ndarray: Letterboxed image.
        tuple: (width, height) scale ratios that were applied.
        tuple: (dw, dh) padding per side, before rounding to whole pixels.

    """
    src_h, src_w = im.shape[:2]
    if isinstance(new_shape, int):
        new_shape = (new_shape, new_shape)
    dst_h, dst_w = new_shape[0], new_shape[1]

    # Uniform scale (new / old) that makes the image fit inside the target.
    r = min(dst_h / src_h, dst_w / src_w)
    if not scaleup:
        # Shrink only; never enlarge.
        r = min(r, 1.0)

    ratio = (r, r)
    new_unpad = (int(round(src_w * r)), int(round(src_h * r)))
    dw = dst_w - new_unpad[0]
    dh = dst_h - new_unpad[1]
    if auto:
        dw = np.mod(dw, stride)
        dh = np.mod(dh, stride)
    elif scaleFill:
        dw = dh = 0.0
        new_unpad = (dst_w, dst_h)
        ratio = (dst_w / src_w, dst_h / src_h)

    # Half of the padding goes on each side.
    dw = dw / 2
    dh = dh / 2

    if (src_w, src_h) != new_unpad:
        im = cv2.resize(im, new_unpad, interpolation=cv2.INTER_LINEAR)

    # The +-0.1 offsets send an odd padding pixel to the bottom/right edge.
    top = int(round(dh - 0.1))
    bottom = int(round(dh + 0.1))
    left = int(round(dw - 0.1))
    right = int(round(dw + 0.1))
    im = cv2.copyMakeBorder(im, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color)

    return im, ratio, (dw, dh)

def display_images1(images):
    """
    Display a list of images side by side in a single row.

    Kept for backward compatibility; it forwards to ``display_images``
    without a save path instead of duplicating the plotting code.

    Args:
        images (list[PIL.Image]): A list of PIL images to display.
            An empty list is a no-op.

    Returns:
        None
    """
    if len(images) == 0:
        # plt.subplots cannot create a grid with zero columns.
        return
    display_images(images)

def display_images(images, save_path=None):
    """
    Display a list of images side by side in a single row.

    Args:
        images (list[PIL.Image]): A list of PIL images to display.
            An empty list is a no-op (nothing is shown or saved).
        save_path (str, optional): Path to save the figure to before showing it.

    Returns:
        None
    """
    if len(images) == 0:
        # plt.subplots raises for zero columns; there is nothing to draw anyway.
        return

    # squeeze=False always yields a 2-D axes array, so a single image needs
    # no special casing.
    fig, axes = plt.subplots(1, len(images), figsize=(15, 7), squeeze=False)
    for ax, img in zip(axes[0], images):
        ax.imshow(img)
        ax.axis('off')

    if save_path:
        fig.savefig(save_path, bbox_inches='tight')
    plt.show()
    
def get_top_predictions(outputs, conf_threshold):
    """
    Return the coordinate sum of the most confident box above a threshold.

    Args:
        outputs: Nested sequence ``[[post_result, pre_post_boxes]]``;
            ``post_result`` holds the confidences and ``pre_post_boxes``
            the boxes, indexed by the same detection index.
        conf_threshold: Confidence threshold (e.g., 0.5).

    Returns:
        Scalar tensor with the sum of the box coordinates of the top
        prediction, or None if no confidence reaches the threshold.
    """
    post_result, pre_post_boxes = outputs[0][0], outputs[0][1]

    # Drop singleton dimensions so the confidences can be indexed per detection.
    confidences = post_result.squeeze()

    valid_indices = torch.where(confidences >= conf_threshold)[0]
    if len(valid_indices) == 0:
        return None  # No predictions above threshold

    # argmax is taken among the surviving detections only, so it indexes
    # into the filtered boxes rather than the original ones.
    top_idx = torch.argmax(confidences[valid_indices])
    top_box = pre_post_boxes[valid_indices][top_idx]
    return torch.sum(top_box)


class GradCAM(BaseCAM):
    """CAM variant whose channel weights are the spatial mean of the gradients."""

    def __init__(self, model, target_layers,
                 reshape_transform=None):
        super().__init__(model, target_layers, reshape_transform)

    def get_cam_weights(self,
                        input_tensor,
                        target_layer,
                        target_category,
                        activations,
                        grads):
        """Average ``grads`` over every axis after the first two.

        Only ``grads`` is used; the remaining parameters are accepted to match
        the signature the base class calls.

        Raises:
            ValueError: If ``grads`` is neither 4-dimensional (2D image)
                nor 5-dimensional (3D image).
        """
        rank = len(grads.shape)

        # 2D image
        if rank == 4:
            return np.mean(grads, axis=(2, 3))

        # 3D image
        if rank == 5:
            return np.mean(grads, axis=(2, 3, 4))

        # The two literals used to be joined without a separating space.
        raise ValueError("Invalid grads shape. "
                         "Shape of grads should be 4 (2D image) or 5 (3D image).")

        

class yolov8_heatmap:
    """
    Produce CAM heat-map visualizations for a YOLOv8 checkpoint.

     Args:
            weight (str): The path to the checkpoint file.
            device (str): The device to use for inference. Defaults to "cuda:0" if a GPU is available, otherwise "cpu".
            method (str): Name of the CAM class to instantiate. Defaults to "GradCAM".
            layer (list): Indices into ``model.model`` of the layers used for the CAM. Defaults to [17, 15].
            conf_threshold (float): The confidence threshold for detections. Defaults to 0.2.
            ratio (float): Forwarded to ``yolov8_target``. Defaults to 0.02.
            show_box (bool): Whether to draw the detected boxes on the CAM image. Defaults to True.
            renormalize (bool): Whether to renormalize the CAM to be in the range [0, 1] across the entire image. Defaults to False.

    Calling the instance with an image path (or a directory of images) returns
    the visualizations; see ``__call__``.
    """

    def __init__(
            self,
            weight: str,
            device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu"),
            method="GradCAM",
            # layer=[12, 15, 17, 21],
            layer=[17, 15],
            conf_threshold=0.2,
            ratio=0.02,
            show_box=True,
            renormalize=False,
    ) -> None:
        """
        Load the model, build the CAM object and store the configuration.
        """
        backward_type = "all"

        # The raw checkpoint is only needed for the class-name mapping, so it
        # is loaded on the CPU and released right away instead of being kept
        # on the instance next to the real model.
        ckpt = torch.load(weight, map_location="cpu")
        model_names = ckpt['model'].names
        del ckpt

        model = attempt_load_weights(weight, device)
        model.info()
        for param in model.parameters():
            param.requires_grad_(True)
        model.eval()

        target = yolov8_target(backward_type, conf_threshold, ratio)
        target_layers = [model.model[index] for index in layer]
        # NOTE(security): eval() executes whatever expression `method` holds.
        # Only pass trusted CAM class names here.
        cam = eval(method)(model, target_layers)
        colors = np.random.uniform(
            0, 255, size=(len(model_names), 3)).astype(int)

        # Explicit attributes instead of self.__dict__.update(locals()), which
        # also stored `self`, the loop variable and the whole checkpoint.
        self.weight = weight
        self.device = device
        self.layer = layer
        self.conf_threshold = conf_threshold
        self.ratio = ratio
        self.show_box = show_box
        self.renormalize = renormalize
        self.backward_type = backward_type
        self.model_names = model_names
        self.model = model
        self.target = target
        self.target_layers = target_layers
        self.method = cam  # the CAM object, not the name it was built from
        self.colors = colors

    def post_process(self, result):
        """
        Perform non-maximum suppression on the detections and process results.

        Args:
            result (torch.Tensor): The raw detections from the model.

        Returns:
            torch.Tensor: Filtered and processed detections.
        """
        # Perform non-maximum suppression
        processed_result = non_max_suppression(
            result,
            conf_thres=self.conf_threshold,  # Use the class's confidence threshold
            iou_thres=0.5  # Intersection over Union threshold
        )

        # If no detections, return an empty tensor
        if len(processed_result) == 0 or processed_result[0].numel() == 0:
            return torch.empty(0, 6)  # Return an empty tensor with 6 columns

        # Take the first batch of detections (assuming single image)
        detections = processed_result[0]

        # Filter detections based on confidence
        mask = detections[:, 4] >= self.conf_threshold
        filtered_detections = detections[mask]

        return filtered_detections

    def draw_detections(self, box, color, name, img):
        """
        Draw one bounding box and its label on an image (in place).

        Args:
            box (torch.Tensor or np.ndarray): The bounding box coordinates in the format [x1, y1, x2, y2]
            color (list): Three color components, in the channel order of ``img``.
            name (str): The label for the bounding box.
            img (np.ndarray): The image on which to draw the bounding box

        Returns:
            np.ndarray: The image with the bounding box drawn.
        """
        # Ensure box coordinates are integers
        xmin, ymin, xmax, ymax = map(int, box[:4])
        bgr_or_rgb = tuple(int(x) for x in color)

        # Draw rectangle
        cv2.rectangle(img, (xmin, ymin), (xmax, ymax), bgr_or_rgb, 2)

        # Draw label
        cv2.putText(img, name, (xmin, ymin - 5),
                    cv2.FONT_HERSHEY_SIMPLEX,
                    0.8, bgr_or_rgb, 2,
                    lineType=cv2.LINE_AA)

        return img

    @staticmethod
    def _scale_whole_map(cam_map):
        """Min-max scale one CAM map as a whole and return it with its shape unchanged."""
        # scale_cam_image iterates over the first axis; without the extra
        # leading axis a 2-D map would be normalized row by row.
        return scale_cam_image(cam_map[None, ...])[0]

    def renormalize_cam_in_bounding_boxes(
            self,
            boxes: np.ndarray,  # type: ignore
            image_float_np: np.ndarray,  # type: ignore
            grayscale_cam: np.ndarray,  # type: ignore
    ) -> np.ndarray:
        """
        Normalize the CAM to be in the range [0, 1]
        inside every bounding boxes, and zero outside of the bounding boxes.

        Args:
            boxes (np.ndarray): The bounding boxes as integer [x1, y1, x2, y2] rows.
            image_float_np (np.ndarray): The image as a numpy array of floats in the range [0, 1].
            grayscale_cam (np.ndarray): The CAM as a numpy array of floats in the range [0, 1].

        Returns:
            np.ndarray: The image overlaid with the renormalized CAM.
        """
        renormalized_cam = np.zeros(grayscale_cam.shape, dtype=np.float32)
        for x1, y1, x2, y2 in boxes:
            x1, y1 = max(x1, 0), max(y1, 0)
            x2, y2 = min(grayscale_cam.shape[1] - 1,
                         x2), min(grayscale_cam.shape[0] - 1, y2)
            if x2 <= x1 or y2 <= y1:
                # Degenerate box after clipping: nothing to normalize.
                continue
            renormalized_cam[y1:y2, x1:x2] = self._scale_whole_map(
                grayscale_cam[y1:y2, x1:x2].copy())
        renormalized_cam = self._scale_whole_map(renormalized_cam)
        eigencam_image_renormalized = show_cam_on_image(
            image_float_np, renormalized_cam, use_rgb=True)
        return eigencam_image_renormalized

    def renormalize_cam(self, boxes, image_float_np, grayscale_cam):
        """Normalize the CAM to be in the range [0, 1]
        across the entire image.

        ``boxes`` is accepted for signature parity with
        ``renormalize_cam_in_bounding_boxes`` and is not used.
        """
        renormalized_cam = self._scale_whole_map(grayscale_cam)
        eigencam_image_renormalized = show_cam_on_image(
            image_float_np, renormalized_cam, use_rgb=True)
        return eigencam_image_renormalized

    def process(self, img_path):
        """Process one image and generate its CAM visualization.

        Returns:
            ``(cam_image, pure_cam, grayscale_cam, pred)``: the overlay with
            boxes (PIL image), the overlay without boxes (PIL image), the raw
            CAM array and the detections after non-maximum suppression.
            None when the CAM computation raises AttributeError.

        Raises:
            FileNotFoundError: If the image cannot be read.
        """
        import warnings

        img = cv2.imread(img_path)
        if img is None:
            # cv2.imread signals failure with None rather than an exception.
            raise FileNotFoundError(f"Could not read image: {img_path}")
        img = letterbox(img)[0]
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = np.float32(img) / 255.0
        tensor = (
            torch.from_numpy(np.transpose(img, axes=[2, 0, 1]))
            .unsqueeze(0)
            .to(self.device)
        )

        try:
            grayscale_cam = self.method(tensor, [self.target])
        except AttributeError as e:
            # Best effort, as before, but no longer silent.
            # NOTE(review): presumably raised when the target has nothing to
            # back-propagate (yolov8_target returns a plain 0 when no
            # prediction reaches the threshold) - confirm.
            warnings.warn(f"Skipping {img_path}: CAM computation failed ({e})")
            return None

        grayscale_cam = grayscale_cam[0, :]

        pred1 = self.model(tensor)[0]
        pred = non_max_suppression(
            pred1,
            conf_thres=self.conf_threshold,
            iou_thres=0.45
        )[0]

        if self.renormalize:
            cam_image = self.renormalize_cam(
                pred[:, :4].cpu().detach().numpy().astype(np.int32),
                img,
                grayscale_cam
            )
        else:
            cam_image = show_cam_on_image(img, grayscale_cam, use_rgb=True)
        # Snapshot before any box is drawn. This used to be assigned only in
        # the non-renormalized branch, so renormalize=True hit a NameError.
        pure_cam = cam_image.copy()

        if self.show_box and len(pred) > 0:
            for detection in pred:
                detection = detection.cpu().detach().numpy()
                class_index = int(detection[5])

                cam_image = self.draw_detections(
                    detection[:4],  # Box coordinates
                    self.colors[class_index],  # Color for this class
                    f"{self.model_names[class_index]}",  # Class name
                    cam_image,
                )

        cam_image = Image.fromarray(cam_image)
        pure_cam = Image.fromarray(pure_cam)
        return cam_image, pure_cam, grayscale_cam, pred

    def __call__(self, img_path):
        """
        Generate CAM visualizations for one or more images.

        Args:
            img_path (str): Path to the input image or directory containing images.

        Returns:
            Tuple[List[PIL.Image], List[PIL.Image], List[np.ndarray], List[torch.Tensor]]:
            (cam_with_boxes, cam_without_boxes, saliency_maps, bounding_boxes)
            Images for which ``process`` returns None are left out, so the
            lists may be shorter than the number of inputs (or empty).
        """
        cam_images = []
        pure_cams = []
        salient_maps = []
        bounding_boxes = []

        if os.path.isdir(img_path):
            paths = [os.path.join(img_path, name) for name in os.listdir(img_path)]
        else:
            paths = [img_path]

        for path in paths:
            outcome = self.process(path)
            if outcome is None:
                # process() gave up on this image; unpacking None would raise.
                continue
            cam_img, pure_cam, salient_map, pred_boxes = outcome
            cam_images.append(cam_img)
            pure_cams.append(pure_cam)
            salient_maps.append(salient_map)
            bounding_boxes.append(pred_boxes)

        return cam_images, pure_cams, salient_maps, bounding_boxes
    

class yolov8_target(torch.nn.Module):
    """Scalar CAM target built from the predictions that reach a confidence.

    Args:
        ouput_type: ``'class'`` to sum the top scores, ``'box'`` to sum the
            first four box values, ``'all'`` for both.
        conf: Minimum top score (inclusive) a prediction needs to contribute.
        ratio: Stored on the instance; not used by ``forward``.
    """

    def __init__(self, ouput_type, conf, ratio) -> None:
        super().__init__()
        self.ouput_type = ouput_type
        self.conf = conf
        self.ratio = ratio

    def forward(self, data):
        """Sum the selected quantities over all qualifying predictions.

        Args:
            data: Pair ``(post_result, pre_post_boxes)``; row ``i`` of both
                describes the same prediction.

        Returns:
            The summed tensor, or the plain integer ``0`` when no prediction
            reaches ``conf`` (``sum`` of an empty list).
        """
        post_result, pre_post_boxes = data

        use_class = self.ouput_type == 'class' or self.ouput_type == 'all'
        use_box = self.ouput_type == 'box' or self.ouput_type == 'all'

        result = []
        for i in range(post_result.size(0)):
            top_score = post_result[i].max()  # computed once per prediction
            if float(top_score) >= self.conf:
                if use_class:
                    result.append(top_score)
                if use_box:
                    for j in range(4):
                        result.append(pre_post_boxes[i, j])
        return sum(result)

        


In [None]:
def main(
    image_folder='/home/axy9651/Tryp/trypanosome_parasite_detection/chagas_detection/DINO-atten-spa/data_cvat/val2017/images/',
    weight="best_chagas.pt",
    methods=("LayerCAM",),
    layers=([15],),
):
    """Show CAM overlays for every image in a folder.

    All arguments default to the values that used to be hard-coded, so
    ``main()`` behaves as before; pass other values to run on a different
    machine or dataset.

    Args:
        image_folder (str): Directory that contains the input images.
        weight (str): Path to the YOLOv8 checkpoint.
        methods: Names of the CAM classes to run, e.g. "EigenGradCAM",
            "GradCAM", "EigenCAM", "XGradCAM", "RandomCAM", "LayerCAM",
            "KPCA_CAM", "HiResCAM", "GradCAMPlusPlus", "GradCAMElementWise"
            (each must be defined in this notebook).
        layers: One list of layer indices per model configuration to try.
    """
    image_extensions = ('.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.tif')
    # Sorted so repeated runs show the images in the same order.
    image_names = [
        filename for filename in sorted(os.listdir(image_folder))
        if filename.lower().endswith(image_extensions)
    ]

    for method in methods:
        print(method)
        for layer in layers:
            model = yolov8_heatmap(
                weight=weight,
                method=method,
                layer=layer
            )

            for filename in image_names:
                image_path = os.path.join(image_folder, filename)
                cam_with_boxes, cam_without_boxes, salient, bounding_boxes = model(img_path=image_path)
                if cam_with_boxes:
                    # Nothing to show when no visualization was produced.
                    display_images(cam_with_boxes)

# Run the CAM visualization when this cell is executed.
main()