This notebook is intended for running tests on the MGA-YOLOv8 architecture.

In [21]:
from typing import List, Dict, Optional, Tuple, Union
import torch
import torch.nn as nn
import os
from pathlib import Path
import torchvision.transforms as transforms
from PIL import Image

class CAM(nn.Module):
    """Channel attention block.

    Collapses the spatial dimensions of the input with global max pooling and
    global average pooling, sends both descriptors through one shared
    bottleneck MLP, and rescales every input channel by the sigmoid of the
    summed MLP outputs.

    Args:
        channels (int): Number of channels of the incoming feature map.
        r (int): Reduction ratio; the MLP hidden layer has ``channels // r``
            units.

    Raises:
        ValueError: If ``channels`` or ``r`` is not positive, or if
            ``channels`` is not divisible by ``r``.
    """

    def __init__(self, channels: int, r: int) -> None:
        super().__init__()
        if channels <= 0 or r <= 0 or channels % r != 0:
            message = (
                f"Invalid parameters: channels={channels}, r={r}. "
                "Ensure channels > 0, r > 0, and channels is divisible by r."
            )
            raise ValueError(message)

        self.channels = channels
        self.r = r

        # Bottleneck: channels -> channels // r -> channels
        hidden_units = self.channels // self.r
        squeeze = nn.Linear(self.channels, hidden_units, bias=True)
        activation = nn.ReLU(inplace=True)
        excite = nn.Linear(hidden_units, self.channels, bias=True)
        self.mlp = nn.Sequential(squeeze, activation, excite)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` scaled channel-wise by its attention weights.

        Args:
            x: 4-D input; it is unpacked as (batch, channels, H, W).

        Returns:
            Tensor with the same shape as ``x``.
        """
        n, c, _h, _w = x.shape
        pool = torch.nn.functional

        # One scalar per (sample, channel) for each pooling flavour
        peak_descriptor = pool.adaptive_max_pool2d(x, output_size=1).reshape(n, c)
        mean_descriptor = pool.adaptive_avg_pool2d(x, output_size=1).reshape(n, c)

        # The same MLP scores both descriptors; results go back to (n, c, 1, 1)
        peak_score = self.mlp(peak_descriptor).reshape(n, c, 1, 1)
        mean_score = self.mlp(mean_descriptor).reshape(n, c, 1, 1)

        gate = torch.sigmoid(peak_score + mean_score)
        return gate * x


class SAM(nn.Module):
    """Spatial attention block.

    Reduces the channel axis with a max and a mean, stacks the two resulting
    single-channel maps, and turns them into one sigmoid-activated map with a
    7x7 convolution. The input is then scaled by that map.

    Args:
        bias (bool, optional): Whether the convolution has a bias term.
            Default: False.
    """

    def __init__(self, bias: bool = False) -> None:
        super().__init__()
        self.bias = bias
        # padding=3 with a 7x7 kernel and stride 1 keeps H and W unchanged
        self.conv = nn.Conv2d(2, 1, kernel_size=7, stride=1, padding=3, bias=self.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` scaled location-wise by its spatial attention map."""
        # Per-location statistics across channels, each keeping a size-1 channel axis
        channel_max = x.max(dim=1, keepdim=True).values
        channel_mean = x.mean(dim=1, keepdim=True)

        # Two-channel descriptor: max first, mean second
        descriptor = torch.cat([channel_max, channel_mean], dim=1)

        gate = torch.sigmoid(self.conv(descriptor))
        return gate * x


class CBAM(nn.Module):
    """Convolutional Block Attention Module (CBAM).

    Applies channel attention followed by spatial attention to focus on
    'what' and 'where' information, then adds the refined features back onto
    the input through a residual (skip) connection.

    Reference:
        "CBAM: Convolutional Block Attention Module"
        https://arxiv.org/abs/1807.06521

    Args:
        channels (int): Number of input channels.
        r (int): Reduction ratio for the channel attention module.
    """

    def __init__(self, channels: int, r: int) -> None:
        super(CBAM, self).__init__()
        self.channels = channels
        self.reduction_ratio = r
        self.channel_attention = CAM(channels=self.channels, r=self.reduction_ratio)
        self.spatial_attention = SAM(bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply channel attention, then spatial attention, then the skip connection.

        Args:
            x: Input feature map.

        Returns:
            ``x + SAM(CAM(x))``, with the same shape as ``x``.

        Raises:
            ValueError: If the refined tensor does not have the shape of ``x``.
        """
        # CAM and SAM each return their input already multiplied by the
        # attention map, so their outputs are chained directly. Multiplying by
        # the features again would square them instead of re-weighting them.
        channel_refined = self.channel_attention(x)
        x_refined = self.spatial_attention(channel_refined)

        # Guard the residual addition against silent broadcasting
        if x.shape != x_refined.shape:
            raise ValueError(
                f"Input shape {x.shape} does not match refined shape {x_refined.shape}."
            )

        # Residual connection: identity path plus attention-refined features
        return x + x_refined


class CBAMMaskedFeatureExtractor:
    """
    A hook manager that extracts feature maps from specified layers,
    applies corresponding masks, and enhances them with CBAM.

    The hooks never change the tensors flowing through the network: the
    original output and its mask+CBAM variant are only stored on this object
    and are retrieved with ``get_feature_maps``.

    NOTE: the per-layer CBAM modules are constructed on first use and no
    weights are loaded into them here, so they carry their default
    initialisation.
    """

    def __init__(self, target_layers: List[str], mask_folder: str, alpha: float = 0.5, reduction_ratio: int = 16) -> None:
        """
        Initialize the feature extractor.

        Args:
            target_layers: List of layer names to extract features from
            mask_folder: Path to the folder containing image masks
            alpha: Weight for skip connection (0.0 to 1.0)
                   feature_map = (1-alpha)*feature_map + alpha*(CBAM(feature_map*mask))
            reduction_ratio: Reduction ratio for CBAM module. Must be a
                   positive integer dividing the channel count of every target
                   layer (CAM raises ValueError otherwise). A fractional value
                   would produce a non-integer hidden size for the CAM MLP.
        """
        self.target_layers = target_layers
        self.mask_folder = mask_folder
        self.alpha = max(0.0, min(1.0, alpha))  # Clamp between 0 and 1
        self.reduction_ratio = reduction_ratio
        self.hooks = []  # handles returned by register_forward_hook

        # Storage dictionaries, all keyed by layer name
        self.feature_maps: Dict[str, torch.Tensor] = {}
        self.masked_feature_maps: Dict[str, torch.Tensor] = {}
        self.current_image_path: Optional[str] = None
        self.image_names: Dict[str, str] = {}

        # CBAM modules for each layer (created lazily, once the channel count is known)
        self.cbam_modules: Dict[str, CBAM] = {}

    def set_current_image(self, image_path: str) -> None:
        """
        Set the current image being processed.

        The stem of this path is what the hooks use to look up the mask, so
        call this before running the forward pass.

        Args:
            image_path: Path to the image file
        """
        self.current_image_path = image_path

    def register_hooks(self, model: nn.Module) -> nn.Module:
        """
        Register hooks to extract feature maps from the model.

        Any previously registered hooks, stored feature maps and cached CBAM
        modules are discarded first. Layer names are matched against
        ``model.model.named_modules()``; if ``model`` has no ``model``
        attribute that is an ``nn.Module``, no hook is registered.

        Args:
            model: The YOLO model to attach hooks to

        Returns:
            The model with hooks attached
        """
        # Clear existing hooks and storage
        self.clear_hooks()
        self.feature_maps = {}
        self.masked_feature_maps = {}
        self.image_names = {}
        self.cbam_modules = {}

        # Register hooks on target layers
        if hasattr(model, 'model') and isinstance(model.model, nn.Module):
            for name, module in model.model.named_modules():
                if name in self.target_layers:
                    # Register hook to capture feature map
                    hook = self._create_feature_hook(name)
                    self.hooks.append(module.register_forward_hook(hook))

        return model

    def _create_feature_hook(self, layer_name: str):
        """
        Create a hook function to store feature maps and apply masks with CBAM.

        Args:
            layer_name: Name of the layer for reference

        Returns:
            Hook function to be registered
        """
        def hook(module, input_feat, output):
            # Store a copy of the original feature map
            self.feature_maps[layer_name] = output.clone()

            # Store the image name if available
            if self.current_image_path:
                image_name = Path(self.current_image_path).stem
                self.image_names[layer_name] = image_name

                # Find and apply mask if available
                mask_path = self._find_mask_path(image_name)
                if mask_path:
                    # Resize the mask to this layer's spatial size (H, W)
                    mask_tensor = self._process_mask(mask_path, output.shape[2:])
                    if mask_tensor is not None:
                        mask_tensor = mask_tensor.to(output.device)
                        masked_output = self._apply_mask_with_cbam(output, mask_tensor, layer_name)
                        self.masked_feature_maps[layer_name] = masked_output

            # Don't modify the output passing through the model
            return output

        return hook

    def _find_mask_path(self, image_name: str) -> Optional[str]:
        """
        Find the corresponding mask file for an image.

        An exact ``<image_name><ext>`` match is preferred. Otherwise the first
        regular file (in sorted order) whose name starts with ``image_name``
        is used, so the choice is reproducible across runs and file systems.

        Args:
            image_name: Base name of the image without extension

        Returns:
            Path to the mask file if found, None otherwise
        """
        # Try the exact same filename with different possible extensions
        for ext in ['.png', '.jpg', '.jpeg', '.bmp']:
            mask_path = os.path.join(self.mask_folder, f"{image_name}{ext}")
            if os.path.isfile(mask_path):
                return mask_path

        # If not found, look for any file starting with the image name.
        # os.listdir order is arbitrary, hence the sort; directories are skipped.
        if os.path.isdir(self.mask_folder):
            for filename in sorted(os.listdir(self.mask_folder)):
                if not filename.startswith(image_name):
                    continue
                candidate = os.path.join(self.mask_folder, filename)
                if os.path.isfile(candidate):
                    return candidate

        return None

    def _process_mask(self, mask_path: str, target_size: Tuple[int, int]) -> Optional[torch.Tensor]:
        """
        Load and process a mask to match the feature map dimensions.

        Args:
            mask_path: Path to the mask file
            target_size: Target size as (height, width)

        Returns:
            Processed mask tensor or None if processing failed
        """
        try:
            # Load mask as grayscale image
            mask = Image.open(mask_path).convert("L")

            # Resize to match feature map dimensions; NEAREST avoids
            # introducing intermediate grey levels at mask borders
            resized_mask = transforms.Resize(
                target_size, interpolation=transforms.InterpolationMode.NEAREST
            )(mask)

            # Convert to tensor [1, 1, H, W]
            mask_tensor = transforms.ToTensor()(resized_mask).unsqueeze(0)

            return mask_tensor
        except Exception as e:
            # Best effort: a broken mask must not abort the forward pass
            print(f"Error processing mask {mask_path}: {str(e)}")
            return None

    def _apply_mask_with_cbam(self, feature_map: torch.Tensor, mask: torch.Tensor, layer_name: str) -> torch.Tensor:
        """
        Apply mask to feature map, enhance with CBAM, and combine with skip connection.

        Implements: feature_map = (1-alpha)*feature_map + alpha*(CBAM(feature_map*mask))

        Args:
            feature_map: Input feature map [B, C, H, W]
            mask: Binary mask [1, 1, H, W]
            layer_name: Name of the layer for CBAM module caching

        Returns:
            Modified feature map with same shape as input
        """
        # Expand mask to match feature map channels
        B, C, H, W = feature_map.shape
        expanded_mask = mask.expand(B, C, H, W)

        # Apply mask to get masked feature
        masked_feature = feature_map * expanded_mask

        # Create or reuse CBAM module for this layer
        if layer_name not in self.cbam_modules:
            self.cbam_modules[layer_name] = CBAM(
                channels=C,
                r=self.reduction_ratio
            ).to(feature_map.device)

        # Apply CBAM to the masked feature
        enhanced_masked_feature = self.cbam_modules[layer_name](masked_feature)

        # Blend the untouched features with the enhanced masked ones
        output = (1 - self.alpha) * feature_map + self.alpha * enhanced_masked_feature

        return output

    def clear_hooks(self) -> None:
        """Remove all registered hooks to prevent memory leaks."""
        for hook in self.hooks:
            hook.remove()
        self.hooks = []

    def get_feature_maps(self, include_masked: bool = True) -> Dict[str, Dict[str, Union[torch.Tensor, str]]]:
        """
        Get the captured feature maps with their associated image names.

        Args:
            include_masked: Whether to include masked versions of feature maps

        Returns:
            Dictionary with layer names as keys and dict of data as values.
            Each value has the keys 'original' and 'image_name' (None when no
            image path was set), plus 'masked' when a masked map exists and
            ``include_masked`` is true.
        """
        result = {}
        for layer_name, feature_map in self.feature_maps.items():
            layer_data = {
                'original': feature_map,
                'image_name': self.image_names.get(layer_name)
            }

            if include_masked and layer_name in self.masked_feature_maps:
                layer_data['masked'] = self.masked_feature_maps[layer_name]

            result[layer_name] = layer_data

        return result

    def set_alpha(self, alpha: float) -> None:
        """
        Set the weight for the skip connection.

        Args:
            alpha: Value between 0.0 and 1.0 (values outside are clamped)
        """
        self.alpha = max(0.0, min(1.0, alpha))

# Example usage in notebook

from ICA_Detection.external.ultralytics.ultralytics import YOLO
import os
import matplotlib.pyplot as plt

base = "/home/mariopasc/Python/Datasets/COMBINED/YOLO_MGA"

# Image and mask paths
image_path = os.path.join(base, "detection", "images", "val", "arcadetest_p66_v66_00066.png")
mask_folder = os.path.join(base, "masks")
test_path = "/home/mariopasc/Python/Datasets/COMBINED/detection/testing"

# savefig does not create missing directories, so make sure the target exists
os.makedirs(test_path, exist_ok=True)

# Initialize the feature extractor with target layers and mask folder
target_layers = ["model.15", "model.18", "model.21"]
feature_extractor = CBAMMaskedFeatureExtractor(
    target_layers=target_layers,
    mask_folder=mask_folder,
    alpha=0.7,  # Weight for mask application (adjust as needed)
    reduction_ratio=16  # Reduction ratio for CBAM
)

# Load your YOLOv8 model
model = YOLO('yolov8n.pt')

# Register hooks
feature_extractor.register_hooks(model)

try:
    # Set the current image path before inference so the hooks can find the mask
    feature_extractor.set_current_image(image_path)

    # Run inference with your model
    results = model(image_path)

    # Get the extracted feature maps with image names
    feature_maps_data = feature_extractor.get_feature_maps()

    # Visualize the original and masked+CBAM enhanced feature maps
    for layer_name, data in feature_maps_data.items():
        image_name = data['image_name']

        # One panel per available map; the masked panel exists only if a mask was found
        panels = [("Original Feature Map", data['original'])]
        if 'masked' in data:
            panels.append(("Masked+CBAM Feature Map", data['masked']))

        # squeeze=False always yields a 2-D axes array, even for a single panel
        fig, axs = plt.subplots(1, len(panels), figsize=(16, 8), squeeze=False)

        for ax, (title, feature_map) in zip(axs[0], panels):
            # Show channel 0 of the first sample in the batch
            im = ax.imshow(feature_map[0, 0].cpu().detach().numpy(), cmap='viridis')
            ax.set_title(f"{title} - {layer_name}")
            fig.colorbar(im, ax=ax)

        fig.suptitle(f"Feature Maps from {layer_name} (Image: {image_name})")
        fig.tight_layout()

        # Save the figure
        fig.savefig(os.path.join(test_path, f"{image_name}_{layer_name}_cbam_comparison.png"))
        plt.close(fig)
finally:
    # Always clear hooks when done, even if inference or plotting failed
    feature_extractor.clear_hooks()


image 1/1 /home/mariopasc/Python/Datasets/COMBINED/YOLO_MGA/detection/images/val/arcadetest_p66_v66_00066.png: 640x640 1 person, 5.8ms
Speed: 0.9ms preprocess, 5.8ms inference, 0.6ms postprocess per image at shape (1, 3, 640, 640)


In [16]:
from typing import List, Dict, Optional, Tuple, Union
import torch
import torch.nn as nn
import os
from pathlib import Path
import torchvision.transforms as transforms
from PIL import Image

class MaskedFeatureExtractor:
    """
    A hook manager that extracts feature maps from specified layers
    and applies corresponding masks with skip connections.

    The hooks never change the tensors flowing through the network: the
    original output and its masked variant are only stored on this object
    and are retrieved with ``get_feature_maps``.
    """

    def __init__(self, target_layers: List[str], mask_folder: str, alpha: float = 0.5) -> None:
        """
        Initialize the feature extractor.

        Args:
            target_layers: List of layer names to extract features from
            mask_folder: Path to the folder containing image masks
            alpha: Weight for skip connection (0.0 to 1.0)
                   feature_map = (1-alpha)*feature_map + alpha*(feature_map*mask)
        """
        self.target_layers = target_layers
        self.mask_folder = mask_folder
        self.alpha = max(0.0, min(1.0, alpha))  # Clamp between 0 and 1
        self.hooks = []  # handles returned by register_forward_hook

        # Storage dictionaries, all keyed by layer name
        self.feature_maps: Dict[str, torch.Tensor] = {}
        self.masked_feature_maps: Dict[str, torch.Tensor] = {}
        self.current_image_path: Optional[str] = None
        self.image_names: Dict[str, str] = {}

    def set_current_image(self, image_path: str) -> None:
        """
        Set the current image being processed.

        The stem of this path is what the hooks use to look up the mask, so
        call this before running the forward pass.

        Args:
            image_path: Path to the image file
        """
        self.current_image_path = image_path

    def register_hooks(self, model: nn.Module) -> nn.Module:
        """
        Register hooks to extract feature maps from the model.

        Any previously registered hooks and stored feature maps are discarded
        first. Layer names are matched against
        ``model.model.named_modules()``; if ``model`` has no ``model``
        attribute that is an ``nn.Module``, no hook is registered.

        Args:
            model: The YOLO model to attach hooks to

        Returns:
            The model with hooks attached
        """
        # Clear existing hooks and storage
        self.clear_hooks()
        self.feature_maps = {}
        self.masked_feature_maps = {}
        self.image_names = {}

        # Register hooks on target layers
        if hasattr(model, 'model') and isinstance(model.model, nn.Module):
            for name, module in model.model.named_modules():
                if name in self.target_layers:
                    # Register hook to capture feature map
                    hook = self._create_feature_hook(name)
                    self.hooks.append(module.register_forward_hook(hook))

        return model

    def _create_feature_hook(self, layer_name: str):
        """
        Create a hook function to store feature maps and apply masks.

        Args:
            layer_name: Name of the layer for reference

        Returns:
            Hook function to be registered
        """
        def hook(module, input_feat, output):
            # Store a copy of the original feature map
            self.feature_maps[layer_name] = output.clone()

            # Store the image name if available
            if self.current_image_path:
                image_name = Path(self.current_image_path).stem
                self.image_names[layer_name] = image_name

                # Find and apply mask if available
                mask_path = self._find_mask_path(image_name)
                if mask_path:
                    # Resize the mask to this layer's spatial size (H, W)
                    mask_tensor = self._process_mask(mask_path, output.shape[2:])
                    if mask_tensor is not None:
                        mask_tensor = mask_tensor.to(output.device)
                        masked_output = self._apply_mask_with_skip(output, mask_tensor)
                        self.masked_feature_maps[layer_name] = masked_output

            # Don't modify the output passing through the model
            return output

        return hook

    def _find_mask_path(self, image_name: str) -> Optional[str]:
        """
        Find the corresponding mask file for an image.

        An exact ``<image_name><ext>`` match is preferred. Otherwise the first
        regular file (in sorted order) whose name starts with ``image_name``
        is used, so the choice is reproducible across runs and file systems.

        Args:
            image_name: Base name of the image without extension

        Returns:
            Path to the mask file if found, None otherwise
        """
        # Try the exact same filename with different possible extensions
        for ext in ['.png', '.jpg', '.jpeg', '.bmp']:
            mask_path = os.path.join(self.mask_folder, f"{image_name}{ext}")
            if os.path.isfile(mask_path):
                return mask_path

        # If not found, look for any file starting with the image name.
        # os.listdir order is arbitrary, hence the sort; directories are skipped.
        if os.path.isdir(self.mask_folder):
            for filename in sorted(os.listdir(self.mask_folder)):
                if not filename.startswith(image_name):
                    continue
                candidate = os.path.join(self.mask_folder, filename)
                if os.path.isfile(candidate):
                    return candidate

        return None

    def _process_mask(self, mask_path: str, target_size: Tuple[int, int]) -> Optional[torch.Tensor]:
        """
        Load and process a mask to match the feature map dimensions.

        Args:
            mask_path: Path to the mask file
            target_size: Target size as (height, width)

        Returns:
            Processed mask tensor or None if processing failed
        """
        try:
            # Load mask as grayscale image
            mask = Image.open(mask_path).convert("L")

            # Resize to match feature map dimensions; NEAREST avoids
            # introducing intermediate grey levels at mask borders
            resized_mask = transforms.Resize(
                target_size, interpolation=transforms.InterpolationMode.NEAREST
            )(mask)

            # Convert to tensor [1, 1, H, W]
            mask_tensor = transforms.ToTensor()(resized_mask).unsqueeze(0)

            return mask_tensor
        except Exception as e:
            # Best effort: a broken mask must not abort the forward pass
            print(f"Error processing mask {mask_path}: {str(e)}")
            return None

    def _apply_mask_with_skip(self, feature_map: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """
        Apply mask to feature map with skip connection.

        Implements: feature_map = (1-alpha)*feature_map + alpha*(feature_map*mask)

        Args:
            feature_map: Input feature map [B, C, H, W]
            mask: Binary mask [1, 1, H, W]

        Returns:
            Modified feature map with same shape as input
        """
        # Expand mask to match feature map channels
        B, C, H, W = feature_map.shape
        expanded_mask = mask.expand(B, C, H, W)

        # Blend the untouched features with the masked ones
        masked_feature = feature_map * expanded_mask
        output = (1 - self.alpha) * feature_map + self.alpha * masked_feature

        return output

    def clear_hooks(self) -> None:
        """Remove all registered hooks to prevent memory leaks."""
        for hook in self.hooks:
            hook.remove()
        self.hooks = []

    def get_feature_maps(self, include_masked: bool = True) -> Dict[str, Dict[str, Union[torch.Tensor, str]]]:
        """
        Get the captured feature maps with their associated image names.

        Args:
            include_masked: Whether to include masked versions of feature maps

        Returns:
            Dictionary with layer names as keys and dict of data as values.
            Each value has the keys 'original' and 'image_name' (None when no
            image path was set), plus 'masked' when a masked map exists and
            ``include_masked`` is true.
        """
        result = {}
        for layer_name, feature_map in self.feature_maps.items():
            layer_data = {
                'original': feature_map,
                'image_name': self.image_names.get(layer_name)
            }

            if include_masked and layer_name in self.masked_feature_maps:
                layer_data['masked'] = self.masked_feature_maps[layer_name]

            result[layer_name] = layer_data

        return result

    def set_alpha(self, alpha: float) -> None:
        """
        Set the weight for the skip connection.

        Args:
            alpha: Value between 0.0 and 1.0 (values outside are clamped)
        """
        self.alpha = max(0.0, min(1.0, alpha))
# Example usage in notebook cell

from ICA_Detection.external.ultralytics.ultralytics import YOLO
import os
import matplotlib.pyplot as plt

base = "/home/mariopasc/Python/Datasets/COMBINED/YOLO_MGA"

# Image and mask paths
image_path = os.path.join(base, "detection", "images", "val", "arcadetest_p66_v66_00066.png")
mask_folder = os.path.join(base, "masks")
test_path = "/home/mariopasc/Python/Datasets/COMBINED/detection/testing"

# savefig does not create missing directories, so make sure the target exists
os.makedirs(test_path, exist_ok=True)

# Initialize the feature extractor with target layers and mask folder
target_layers = ["model.15", "model.18", "model.21"]
feature_extractor = MaskedFeatureExtractor(
    target_layers=target_layers,
    mask_folder=mask_folder,
    alpha=0.5  # Weight for mask application (adjust as needed)
)

# Load your YOLOv8 model
model = YOLO('yolov8n.pt')

# Register hooks
feature_extractor.register_hooks(model)

try:
    # Set the current image path before inference so the hooks can find the mask
    feature_extractor.set_current_image(image_path)

    # Run inference with your model
    results = model(image_path)

    # Get the extracted feature maps with image names
    feature_maps_data = feature_extractor.get_feature_maps()

    # Visualize the original and masked feature maps
    for layer_name, data in feature_maps_data.items():
        image_name = data['image_name']

        # One panel per available map; the masked panel exists only if a mask was found
        panels = [("Original Feature Map", data['original'])]
        if 'masked' in data:
            panels.append(("Masked Feature Map", data['masked']))

        # squeeze=False always yields a 2-D axes array, even for a single panel
        fig, axs = plt.subplots(1, len(panels), figsize=(16, 8), squeeze=False)

        for ax, (title, feature_map) in zip(axs[0], panels):
            # Show channel 0 of the first sample in the batch
            im = ax.imshow(feature_map[0, 0].cpu().detach().numpy(), cmap='viridis')
            ax.set_title(f"{title} - {layer_name}")
            fig.colorbar(im, ax=ax)

        fig.suptitle(f"Feature Maps from {layer_name} (Image: {image_name})")
        fig.tight_layout()

        # Save the figure
        fig.savefig(os.path.join(test_path, f"{image_name}_{layer_name}_comparison.png"))
        plt.close(fig)
finally:
    # Always clear hooks when done, even if inference or plotting failed
    feature_extractor.clear_hooks()


image 1/1 /home/mariopasc/Python/Datasets/COMBINED/YOLO_MGA/detection/images/val/arcadetest_p66_v66_00066.png: 640x640 1 person, 4.2ms
Speed: 0.9ms preprocess, 4.2ms inference, 0.4ms postprocess per image at shape (1, 3, 640, 640)


Here we have a simple feature extractor that returns the feature map and the name of the input image:

In [None]:
from typing import List, Dict, Optional, Tuple
import torch
import torch.nn as nn
import os
from pathlib import Path

class SimpleFeatureExtractor:
    """
    Read-only forward-hook manager: records the outputs of selected layers
    (copied with ``clone``) together with the stem of the current image path,
    and passes every output on unchanged.
    """

    def __init__(self, target_layers: List[str]) -> None:
        """
        Set up empty storage for the given layers.

        Args:
            target_layers: Names of the layers whose outputs should be recorded
        """
        self.target_layers = target_layers
        self.hooks = []
        self.feature_maps: Dict[str, torch.Tensor] = {}
        self.current_image_path: Optional[str] = None
        self.image_names: Dict[str, str] = {}

    def set_current_image(self, image_path: str) -> None:
        """
        Remember which image is about to be run through the model.

        Args:
            image_path: Path to the image file
        """
        self.current_image_path = image_path

    def register_hooks(self, model: nn.Module) -> nn.Module:
        """
        Attach recording hooks to every target layer found in ``model.model``.

        Previously registered hooks and stored results are dropped first.

        Args:
            model: Object whose ``model`` attribute is the ``nn.Module`` to hook

        Returns:
            The same ``model`` object that was passed in
        """
        # Start from a clean slate
        self.clear_hooks()
        self.feature_maps = {}
        self.image_names = {}

        inner = getattr(model, 'model', None)
        if not isinstance(inner, nn.Module):
            # Nothing to hook into
            return model

        for layer_name, layer in inner.named_modules():
            if layer_name not in self.target_layers:
                continue
            recorder = self._create_feature_hook(layer_name)
            handle = layer.register_forward_hook(recorder)
            self.hooks.append(handle)

        return model

    def _create_feature_hook(self, layer_name: str):
        """
        Build the forward hook that records one layer's output.

        Args:
            layer_name: Key under which the output is stored

        Returns:
            Callable suitable for ``register_forward_hook``
        """
        def record(module, input_feat, output):
            self.feature_maps[layer_name] = output.clone()

            # Tag the capture with the image stem when a path has been set
            current_path = self.current_image_path
            if current_path:
                self.image_names[layer_name] = Path(current_path).stem

            return output

        return record

    def clear_hooks(self) -> None:
        """Detach every registered hook and forget the handles."""
        for handle in self.hooks:
            handle.remove()
        self.hooks = []

    def get_feature_maps(self) -> Dict[str, Tuple[torch.Tensor, Optional[str]]]:
        """
        Pair each recorded feature map with its image name.

        Returns:
            Mapping from layer name to ``(feature_map, image_name)``;
            ``image_name`` is None when none was recorded for that layer
        """
        return {
            layer_name: (feature_map, self.image_names.get(layer_name))
            for layer_name, feature_map in self.feature_maps.items()
        }

    def get_image_names(self) -> Dict[str, str]:
        """
        Return the layer-name to image-name mapping.

        Returns:
            The internal dictionary itself (not a copy)
        """
        return self.image_names
