In [17]:
# #!/usr/bin/env python
# import os
# import logging
# import numpy as np
# import pandas as pd
# import torch
# import torch.nn as nn
# import torch.nn.functional as F
# from torch.utils.data import Dataset, DataLoader
# from torchvision import transforms
# from monai.networks.nets import resnet18  # MONAI’s 3D ResNet18
# from fastai.learner import Learner
# from fastai.data.core import DataLoaders
# from fastai.metrics import accuracy
# from fastai.losses import CrossEntropyLossFlat
# from fastai.callback.all import SaveModelCallback, EarlyStoppingCallback
# import matplotlib.pyplot as plt
# from matplotlib.backends.backend_pdf import PdfPages
# import seaborn as sns

# # =============================================================================
# # Configure Logging
# # =============================================================================
# logging.basicConfig(
#     level=logging.INFO,
#     format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
#     handlers=[
#         logging.FileHandler("heart_disease_model.log"),
#         logging.StreamHandler()
#     ]
# )
# logger = logging.getLogger(__name__)

# # =============================================================================
# # Data Loading: Custom Dataset for .npy Files
# # =============================================================================
# class NPYDataset(Dataset):
#     """
#     Custom PyTorch Dataset for loading 3D medical imaging data from .npy files.
#     """
#     def __init__(self, dataframe, image_column_name, label_column_name, custom_transform=None):
#         """
#         Args:
#             dataframe (pd.DataFrame): DataFrame with file paths and labels.
#             image_column_name (str): Column name for image file paths.
#             label_column_name (str): Column name for labels.
#             custom_transform (callable, optional): Custom transform pipeline.
#         """
#         self.dataframe = dataframe
#         self.image_column_name = image_column_name
#         self.label_column_name = label_column_name

#         # Default transform (applied to the tensor)
#         default_transform = transforms.Compose([
#             transforms.RandomHorizontalFlip(p=0.5),
#             transforms.RandomRotation(15),
#             transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
#             transforms.Resize((112, 112)),  # Med3D default size
#         ])
#         self.transform = custom_transform if custom_transform is not None else default_transform

#     def __len__(self):
#         return len(self.dataframe)

#     def __getitem__(self, idx):
#         try:
#             npy_path = self.dataframe[self.image_column_name].iloc[idx]
#             label = self.dataframe[self.label_column_name].iloc[idx]
#             # Load image: assuming shape [D, H, W, Channels]. We take the first channel.
#             image = np.load(npy_path)[:, :, :, 0]
#             image = image[17:33, :, :]  # Select frames 17 to 32
#             # Convert to tensor and add a channel dimension: (1, D, H, W)
#             image = torch.tensor(image, dtype=torch.float32).unsqueeze(0)
#             # Apply transforms
#             image = self.transform(image)
#             return image, label
#         except Exception as e:
#             logger.error(f"Error loading image at index {idx}: {e}")
#             raise

# # =============================================================================
# # Model Definition (Only Loading – No Training)
# # =============================================================================
# class HeartDiseaseModel:
#     """
#     Class to load data and the model. (Training is skipped; we only load saved weights.)
#     """
#     def __init__(self, config):
#         self.config = self._validate_config(config)
#         self.logger = logging.getLogger(self.__class__.__name__)
#         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#         self.logger.info(f"Using device: {self.device}")
#         self._prepare_data()
#         self._prepare_model()

#     def _validate_config(self, config):
#         default_config = {
#             'train_dataframe_path': None,
#             'test_dataframe_path': None,
#             'image_column_name': 'FilePath',
#             'label_column_name': 'CAD',
#             'pretrained_weights_path': None,  # initial weights file (used to initialize the model)
#             'batch_size': 8,
#             'split_ratio': 0.85,
#             'model_name': 'heart_ch0_3channel_MONAI_resnet18',  # name used by fastai to save the model
#             'learning_rate': 1e-5,
#             'epochs': 50,
#             'early_stopping_patience': 20,
#             'weight_decay': 1e-4
#         }
#         default_config.update(config)
#         required_paths = [
#             'train_dataframe_path',
#             'test_dataframe_path',
#             'pretrained_weights_path'
#         ]
#         for key in required_paths:
#             if not default_config[key] or not os.path.exists(default_config[key]):
#                 raise ValueError(f"Invalid or missing path for {key}: {default_config[key]}")
#         return default_config

#     def _prepare_data(self):
#         try:
#             train_df = pd.read_csv(self.config['train_dataframe_path'])
#             test_df = pd.read_csv(self.config['test_dataframe_path'])
#             # Create dataset from training CSV
#             dataset = NPYDataset(
#                 train_df,
#                 self.config['image_column_name'],
#                 self.config['label_column_name']
#             )
#             # Split training data into training and validation subsets
#             train_size = int(self.config['split_ratio'] * len(dataset))
#             val_size = len(dataset) - train_size
#             self.train_dataset, self.val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])
#             self.train_loader = DataLoader(
#                 self.train_dataset,
#                 batch_size=self.config['batch_size'],
#                 shuffle=True,
#                 num_workers=8
#             )
#             self.val_loader = DataLoader(
#                 self.val_dataset,
#                 batch_size=self.config['batch_size'],
#                 shuffle=False,
#                 num_workers=8
#             )
#             # Prepare test dataset
#             self.test_dataset = NPYDataset(
#                 test_df,
#                 self.config['image_column_name'],
#                 self.config['label_column_name']
#             )
#             self.test_loader = DataLoader(
#                 self.test_dataset,
#                 batch_size=self.config['batch_size'],
#                 shuffle=False,
#                 num_workers=8
#             )
#             self.logger.info("Data preparation completed successfully.")
#         except Exception as e:
#             self.logger.error(f"Error in data preparation: {e}")
#             raise

#     def _prepare_model(self):
#         try:
#             # Initialize the MONAI ResNet18 model for 3D data.
#             self.model = resnet18(
#                 spatial_dims=3,
#                 n_input_channels=1,   # assuming grayscale input
#                 num_classes=2         # binary classification (e.g., CAD vs. No CAD)
#             )
#             self.logger.info("Loading initial pretrained weights...")
#             state_dict = torch.load(self.config['pretrained_weights_path'], map_location=self.device)
#             self.model.load_state_dict(state_dict, strict=False)
#             # Wrap in DataParallel and send to device.
#             self.model = nn.DataParallel(self.model)
#             self.model.to(self.device)
#             # Create FastAI DataLoaders and Learner.
#             self.dls = DataLoaders(self.train_loader, self.val_loader)
#             self.learn = Learner(
#                 self.dls,
#                 self.model,
#                 loss_func=CrossEntropyLossFlat(),
#                 metrics=[accuracy],
#                 wd=self.config['weight_decay'],
#                 cbs=[
#                     SaveModelCallback(
#                         fname=self.config['model_name'],
#                         monitor='valid_loss'
#                     ),
#                     EarlyStoppingCallback(
#                         monitor='valid_loss',
#                         patience=self.config['early_stopping_patience']
#                     )
#                 ]
#             ).to_fp16()
#             self.logger.info("Model preparation completed successfully.")
#         except Exception as e:
#             self.logger.error(f"Error in model preparation: {e}")
#             raise

# # =============================================================================
# # GradCAM Implementation for 3D Data
# # =============================================================================
# class GradCam3D:
#     """
#     A simple Grad-CAM implementation for 3D models.
#     Registers forward and full-backward hooks on a target convolutional layer.
#     """
#     def __init__(self, model, target_layer):
#         """
#         Args:
#             model (torch.nn.Module): The underlying model (not the DataParallel wrapper).
#             target_layer (torch.nn.Module): The layer on which to register hooks.
#         """
#         self.model = model
#         self.target_layer = target_layer
#         self.gradients = None
#         self.activations = None
#         self.hook_handles = []
#         self._register_hooks()

#     def _register_hooks(self):
#         def forward_hook(module, input, output):
#             self.activations = output.detach()

#         def backward_hook(module, grad_input, grad_output):
#             self.gradients = grad_output[0].detach()

#         self.hook_handles.append(self.target_layer.register_forward_hook(forward_hook))
#         self.hook_handles.append(self.target_layer.register_full_backward_hook(backward_hook))

#     def remove_hooks(self):
#         for handle in self.hook_handles:
#             handle.remove()

#     def __call__(self, input_tensor, target_class=None):
#         """
#         Generates the CAM for the input.
        
#         Args:
#             input_tensor (torch.Tensor): Input tensor of shape (B, C, D, H, W).
#             target_class (torch.Tensor or None): If provided, these class indices will be used.
        
#         Returns:
#             torch.Tensor: CAMs with shape (B, D, H, W) normalized between 0 and 1.
#         """
#         # Forward pass.
#         output = self.model(input_tensor)
#         if target_class is None:
#             target_class = output.argmax(dim=1)
#         one_hot = torch.zeros_like(output)
#         for i, tc in enumerate(target_class):
#             one_hot[i, tc] = 1

#         self.model.zero_grad()
#         output.backward(gradient=one_hot, retain_graph=True)

#         # Retrieve stored activations and gradients.
#         gradients = self.gradients  # shape: (B, C, d, h, w)
#         activations = self.activations  # shape: (B, C, d, h, w)
#         # Global average pooling over depth, height, and width.
#         weights = torch.mean(gradients, dim=(2, 3, 4), keepdim=True)  # shape: (B, C, 1, 1, 1)
#         cam = torch.sum(weights * activations, dim=1)  # shape: (B, d, h, w)
#         cam = F.relu(cam)  # Apply ReLU

#         # Normalize each CAM individually.
#         cams = []
#         for i in range(cam.shape[0]):
#             cam_i = cam[i]
#             cam_i = cam_i - cam_i.min()
#             if cam_i.max() != 0:
#                 cam_i = cam_i / cam_i.max()
#             cams.append(cam_i)
#         cams = torch.stack(cams)  # shape: (B, d, h, w)
#         # Upsample the CAM to match the input resolution.
#         target_size = input_tensor.shape[2:]  # (D, H, W)
#         cams = F.interpolate(cams.unsqueeze(1), size=target_size, mode='trilinear', align_corners=False)
#         cams = cams.squeeze(1)  # shape: (B, D, H, W)
#         return cams

# # =============================================================================
# # Advanced Visualization: Montage of Multiple Slices
# # =============================================================================
# def create_montage_for_sample(model, gradcam, sample, n_slices=5):
#     """
#     Creates a montage (grid) of multiple slices from a sample.
    
#     For each of n evenly spaced slices in the 3D volume, this function plots:
#       - The original image slice.
#       - The same slice overlaid with the CAM.
    
#     Args:
#         model (torch.nn.Module): The underlying trained model.
#         gradcam (GradCam3D): An instance of GradCam3D.
#         sample (torch.Tensor): Image tensor of shape (C, D, H, W).
#         n_slices (int): Number of slices to display.
        
#     Returns:
#         fig (matplotlib.figure.Figure): The figure containing the montage.
#     """
#     device = next(model.parameters()).device
#     input_tensor = sample.unsqueeze(0).to(device)
#     cams = gradcam(input_tensor)  # shape: (1, D, H, W)
#     cam = cams[0].cpu().numpy()
#     original = sample.squeeze(0).cpu().numpy()  # shape: (D, H, W)
#     D = original.shape[0]
#     indices = np.linspace(0, D-1, n_slices, dtype=int)
    
#     fig, axs = plt.subplots(n_slices, 2, figsize=(8, 2.5 * n_slices))
#     for i, idx in enumerate(indices):
#         axs[i, 0].imshow(original[idx], cmap='gray')
#         axs[i, 0].set_title(f"Slice {idx} Original")
#         axs[i, 0].axis('off')
        
#         axs[i, 1].imshow(original[idx], cmap='gray')
#         axs[i, 1].imshow(cam[idx], cmap='jet', alpha=0.5)
#         axs[i, 1].set_title(f"Slice {idx} CAM Overlay")
#         axs[i, 1].axis('off')
#     fig.tight_layout()
#     return fig

# def create_combined_montage(model, gradcam, sample_control, sample_case, n_slices=5, title="Comparison"):
#     """
#     Creates a side-by-side montage comparing control and case samples.
    
#     For each sample, n slices are shown; for each slice, the original and CAM overlay are
#     displayed. The resulting figure has 4 columns:
#       - Column 1: Control Original.
#       - Column 2: Control CAM Overlay.
#       - Column 3: Case Original.
#       - Column 4: Case CAM Overlay.
    
#     Args:
#         model (torch.nn.Module): The underlying trained model.
#         gradcam (GradCam3D): Instance of GradCam3D.
#         sample_control (tuple): (image, label) for the control sample.
#         sample_case (tuple): (image, label) for the case sample.
#         n_slices (int): Number of slices to display.
#         title (str): Title for the figure.
    
#     Returns:
#         fig (matplotlib.figure.Figure): The combined figure.
#     """
#     device = next(model.parameters()).device

#     # Get CAM and original for control.
#     input_control = sample_control[0].unsqueeze(0).to(device)
#     cams_control = gradcam(input_control)[0].cpu().numpy()
#     orig_control = sample_control[0].squeeze(0).cpu().numpy()

#     # Get CAM and original for case.
#     input_case = sample_case[0].unsqueeze(0).to(device)
#     cams_case = gradcam(input_case)[0].cpu().numpy()
#     orig_case = sample_case[0].squeeze(0).cpu().numpy()

#     D = orig_control.shape[0]
#     indices = np.linspace(0, D-1, n_slices, dtype=int)

#     fig, axs = plt.subplots(n_slices, 4, figsize=(16, 2.5 * n_slices))
#     fig.suptitle(title, fontsize=16)
#     for i, idx in enumerate(indices):
#         # Control Original.
#         axs[i, 0].imshow(orig_control[idx], cmap='gray')
#         axs[i, 0].set_title(f"Control Slice {idx}\nOriginal", fontsize=10)
#         axs[i, 0].axis('off')
#         # Control CAM.
#         axs[i, 1].imshow(orig_control[idx], cmap='gray')
#         axs[i, 1].imshow(cams_control[idx], cmap='jet', alpha=0.5)
#         axs[i, 1].set_title(f"Control Slice {idx}\nCAM", fontsize=10)
#         axs[i, 1].axis('off')
#         # Case Original.
#         axs[i, 2].imshow(orig_case[idx], cmap='gray')
#         axs[i, 2].set_title(f"Case Slice {idx}\nOriginal", fontsize=10)
#         axs[i, 2].axis('off')
#         # Case CAM.
#         axs[i, 3].imshow(orig_case[idx], cmap='gray')
#         axs[i, 3].imshow(cams_case[idx], cmap='jet', alpha=0.5)
#         axs[i, 3].set_title(f"Case Slice {idx}\nCAM", fontsize=10)
#         axs[i, 3].axis('off')
#     fig.tight_layout(rect=[0, 0.03, 1, 0.95])
#     return fig

# # =============================================================================
# # Utility to Get a Sample by Label from a Dataset
# # =============================================================================
# def get_sample_by_label(dataset, target_label):
#     """
#     Returns the first sample from the dataset with the specified label.
    
#     Args:
#         dataset (torch.utils.data.Dataset): The dataset to search.
#         target_label (int): The label to search for.
    
#     Returns:
#         tuple: (image, label) or None if not found.
#     """
#     for i in range(len(dataset)):
#         try:
#             image, label = dataset[i]
#             if label == target_label:
#                 return image, label
#         except Exception as e:
#             logger.error(f"Error accessing sample {i}: {e}")
#     return None

# # =============================================================================
# # Main Execution Function (Load Trained Model and Generate Advanced Visualizations)
# # =============================================================================
# def main():
#     try:
#         # Update these paths for your environment.
#         config = {
#             'train_dataframe_path': 'Final_Datasets/train_resnet_heart.csv',
#             'test_dataframe_path': 'Final_Datasets/test_data_incidence.csv',
#             'pretrained_weights_path': '../Med3D/resnet_18_23dataset.pth',  # initial weights file path
#             'model_name': 'heart_ch0_3channel_MONAI_resnet18',  # name used by fastai's SaveModelCallback
#             'epochs': 50,
#             'learning_rate': 1e-5,
#             'batch_size': 8,
#             'split_ratio': 0.85,
#             'early_stopping_patience': 20,
#             'weight_decay': 1e-4
#         }

#         # Instantiate the model and load data (no training is performed).
#         heart_disease_model = HeartDiseaseModel(config)

#         # Load the trained model weights saved by FastAI.
#         heart_disease_model.learn.load(config['model_name'])
#         logger.info("Loaded trained model weights.")

#         # Retrieve the underlying model (if wrapped in DataParallel, use .module).
#         trained_model = heart_disease_model.learn.model
#         if isinstance(trained_model, torch.nn.DataParallel):
#             base_model = trained_model.module
#         else:
#             base_model = trained_model

#         # Select a target layer for Grad-CAM.
#         target_layer = base_model.layer3[1].conv2

#         # Initialize GradCAM with the underlying module.
#         gradcam = GradCam3D(base_model, target_layer)

#         # Get one control and one case sample from training data.
#         train_control = get_sample_by_label(heart_disease_model.train_dataset, target_label=0)
#         train_case = get_sample_by_label(heart_disease_model.train_dataset, target_label=1)

#         # Get one control and one case sample from test data.
#         test_control = get_sample_by_label(heart_disease_model.test_dataset, target_label=0)
#         test_case = get_sample_by_label(heart_disease_model.test_dataset, target_label=1)

#         pdf_filename = "advanced_gradcam_comparison.pdf"
#         with PdfPages(pdf_filename) as pdf:
#             # Advanced montage for Training Data.
#             if train_control is not None and train_case is not None:
#                 fig_train = create_combined_montage(base_model, gradcam, train_control, train_case,
#                                                     n_slices=5, title="Training Data Comparison: Control vs. Case")
#                 pdf.savefig(fig_train)
#                 plt.close(fig_train)
#             else:
#                 print("Missing training samples for comparison.")

#             # Advanced montage for Test Data.
#             if test_control is not None and test_case is not None:
#                 fig_test = create_combined_montage(base_model, gradcam, test_control, test_case,
#                                                    n_slices=5, title="Test Data Comparison: Control vs. Case")
#                 pdf.savefig(fig_test)
#                 plt.close(fig_test)
#             else:
#                 print("Missing test samples for comparison.")
#         print(f"Saved advanced comparison figures to {pdf_filename}")

#         # Optionally, you can also save individual montages:
#         # For example:
#         # fig_montage = create_montage_for_sample(base_model, gradcam, train_control, n_slices=7)
#         # fig_montage.savefig("train_control_montage.pdf")
#         # plt.close(fig_montage)

#         # Remove hooks when finished.
#         gradcam.remove_hooks()

#     except Exception as e:
#         logger.error(f"Critical error in main execution: {e}")
#         raise

# if __name__ == "__main__":
#     main()

In [18]:
#!/usr/bin/env python
import os
import logging
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from monai.networks.nets import resnet18  # MONAI’s 3D ResNet18
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages

# =============================================================================
# Configure Logging
# =============================================================================
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# =============================================================================
# Data Loading: Custom Dataset for .npy Files
# =============================================================================
class NPYDataset(Dataset):
    """ Custom PyTorch Dataset for loading 3D medical imaging data from .npy files. """
    def __init__(self, dataframe, image_column_name, label_column_name):
        self.dataframe = dataframe
        self.image_column_name = image_column_name
        self.label_column_name = label_column_name

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        try:
            npy_path = self.dataframe[self.image_column_name].iloc[idx]
            label = self.dataframe[self.label_column_name].iloc[idx]
            image = np.load(npy_path)[:, :, :, 0]  # Load single-channel data
            image = image[17:33, :, :]  # Select slices 17 to 32
            image = torch.tensor(image, dtype=torch.float32).unsqueeze(0)  # Shape (1, D, H, W)
            return image, label
        except Exception as e:
            logger.error(f"Error loading image at index {idx}: {e}")
            raise

# =============================================================================
# Grad-CAM Implementation for 3D Data
# =============================================================================
class GradCam3D:
    """ Grad-CAM implementation for 3D models. """
    def __init__(self, model, target_layer):
        self.model = model
        self.target_layer = target_layer
        self.gradients = None
        self.activations = None
        self.hook_handles = []
        self._register_hooks()

    def _register_hooks(self):
        def forward_hook(module, input, output):
            self.activations = output.detach()
        def backward_hook(module, grad_input, grad_output):
            self.gradients = grad_output[0].detach()
        self.hook_handles.append(self.target_layer.register_forward_hook(forward_hook))
        self.hook_handles.append(self.target_layer.register_full_backward_hook(backward_hook))

    def remove_hooks(self):
        for handle in self.hook_handles:
            handle.remove()

    def __call__(self, input_tensor, target_class=None):
        output = self.model(input_tensor)
        if target_class is None:
            target_class = output.argmax(dim=1)
        one_hot = torch.zeros_like(output)
        for i, tc in enumerate(target_class):
            one_hot[i, tc] = 1

        self.model.zero_grad()
        output.backward(gradient=one_hot, retain_graph=True)

        gradients = self.gradients
        activations = self.activations
        weights = torch.mean(gradients, dim=(2, 3, 4), keepdim=True)
        cam = torch.sum(weights * activations, dim=1)
        cam = F.relu(cam)

        cams = []
        for i in range(cam.shape[0]):
            cam_i = cam[i] - cam[i].min()
            if cam_i.max() != 0:
                cam_i /= cam_i.max()
            cams.append(cam_i)
        cams = torch.stack(cams)
        target_size = input_tensor.shape[2:]
        cams = F.interpolate(cams.unsqueeze(1), size=target_size, mode='trilinear', align_corners=False).squeeze(1)
        return cams

# =============================================================================
# Model Definition
# =============================================================================
def load_model(config):
    model = resnet18(spatial_dims=3, n_input_channels=1, num_classes=2)
    state_dict = torch.load(config['pretrained_weights_path'], map_location="cpu")
    model.load_state_dict(state_dict, strict=False)
    return model

# =============================================================================
# Utility Functions for Visualization
# =============================================================================
def get_sample_by_label(dataset, target_label):
    for i in range(len(dataset)):
        image, label = dataset[i]
        if label == target_label:
            return image, label
    return None

def generate_gradcam_plot(pdf, model, gradcam, sample, title):
    device = next(model.parameters()).device
    input_tensor = sample[0].unsqueeze(0).to(device)
    cam = gradcam(input_tensor)[0].cpu().numpy()  # Shape: (D, H, W)
    original = sample[0].squeeze(0).cpu().numpy()  # Shape: (D, H, W)

    # Define slices used in dataset (17 to 32)
    original_indices = np.arange(17, 33)  # The real slice numbers
    num_slices = len(original_indices)

    fig, axs = plt.subplots(num_slices, 2, figsize=(12, 2.5 * num_slices))
    fig.suptitle(title, fontsize=16)

    for i, idx in enumerate(original_indices):
        slice_idx = i  # Adjusted index for our dataset slices (0-based)
        
        # Original MRI Image
        axs[i, 0].imshow(original[slice_idx], cmap='gray')
        axs[i, 0].set_title(f"Original Slice {idx}", fontsize=10)
        axs[i, 0].axis('off')

        # CAM Overlay
        axs[i, 1].imshow(original[slice_idx], cmap='gray')
        axs[i, 1].imshow(cam[slice_idx], cmap='jet', alpha=0.5)
        axs[i, 1].set_title(f"Grad-CAM Overlay (Slice {idx})", fontsize=10)
        axs[i, 1].axis('off')

    plt.tight_layout()
    pdf.savefig(fig)
    plt.close(fig)

# =============================================================================
# Main Execution Function
# =============================================================================
def main():
    config = {
        'train_dataframe_path': 'Final_Datasets/train_resnet_heart.csv',
        'test_dataframe_path': 'Final_Datasets/test_data_incidence.csv',
        'pretrained_weights_path': '../Med3D/resnet_18_23dataset.pth',
    }

    # Load trained model
    model = load_model(config)
    model.eval()

    # Load datasets
    train_df = pd.read_csv(config['train_dataframe_path'])
    test_df = pd.read_csv(config['test_dataframe_path'])
    train_dataset = NPYDataset(train_df, 'FilePath', 'CAD')
    test_dataset = NPYDataset(test_df, 'FilePath', 'CAD')

    # Get case and control samples from train and test sets
    train_case = get_sample_by_label(train_dataset, target_label=1)
    train_control = get_sample_by_label(train_dataset, target_label=0)
    test_case = get_sample_by_label(test_dataset, target_label=1)
    test_control = get_sample_by_label(test_dataset, target_label=0)

    # Use layer3 for Grad-CAM
    target_layer = model.layer3[1].conv2
    gradcam = GradCam3D(model, target_layer)

    # Save all plots in a single PDF
    pdf_filename = "GradCAM_Results.pdf"
    with PdfPages(pdf_filename) as pdf:
        if train_case:
            generate_gradcam_plot(pdf, model, gradcam, train_case, "Training Set: Case")
        if train_control:
            generate_gradcam_plot(pdf, model, gradcam, train_control, "Training Set: Control")
        if test_case:
            generate_gradcam_plot(pdf, model, gradcam, test_case, "Test Set: Case")
        if test_control:
            generate_gradcam_plot(pdf, model, gradcam, test_control, "Test Set: Control")

    gradcam.remove_hooks()
    print(f"Saved Grad-CAM results to {pdf_filename}")

if __name__ == "__main__":
    main()


  state_dict = torch.load(config['pretrained_weights_path'], map_location="cpu")


Saved Grad-CAM results to GradCAM_Results.pdf


In [29]:
#!/usr/bin/env python
import os
import logging
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset
from torchvision import transforms
from monai.networks.nets import resnet18  # MONAI’s 3D ResNet18
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages

# =============================================================================
# Configure Logging
# =============================================================================
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# =============================================================================
# Data Loading: Custom Dataset for .npy Files
# =============================================================================
class NPYDataset(Dataset):
    """ Custom PyTorch Dataset for loading 3D medical imaging data from .npy files. """
    def __init__(self, dataframe, image_column_name, label_column_name):
        self.dataframe = dataframe
        self.image_column_name = image_column_name
        self.label_column_name = label_column_name

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        try:
            npy_path = self.dataframe[self.image_column_name].iloc[idx]
            label = self.dataframe[self.label_column_name].iloc[idx]
            image = np.load(npy_path)[:, :, :, 0]  # Load single-channel data
            image = image[17:33, :, :]  # Select slices 17 to 32
            image = torch.tensor(image, dtype=torch.float32).unsqueeze(0)  # Shape (1, D, H, W)
            return image, label
        except Exception as e:
            logger.error(f"Error loading image at index {idx}: {e}")
            raise

# =============================================================================
# Grad-CAM Implementation for 3D Data
# =============================================================================
class GradCam3D:
    """ Grad-CAM implementation for 3D models. """
    def __init__(self, model, target_layer):
        self.model = model
        self.target_layer = target_layer
        self.gradients = None
        self.activations = None
        self.hook_handles = []
        self._register_hooks()

    def _register_hooks(self):
        def forward_hook(module, input, output):
            if isinstance(output, tuple):  # Ensure output is a tensor
                output = output[0]
            self.activations = output.detach()

        def backward_hook(module, grad_input, grad_output):
            if isinstance(grad_output, tuple):  # Ensure gradients are a tensor
                grad_output = grad_output[0]
            self.gradients = grad_output.detach()

        self.hook_handles.append(self.target_layer.register_forward_hook(forward_hook))
        self.hook_handles.append(self.target_layer.register_full_backward_hook(backward_hook))

    def remove_hooks(self):
        for handle in self.hook_handles:
            handle.remove()

    def __call__(self, input_tensor, target_class=None):
        output = self.model(input_tensor)
        if target_class is None:
            target_class = output.argmax(dim=1)
        one_hot = torch.zeros_like(output)
        for i, tc in enumerate(target_class):
            one_hot[i, tc] = 1

        self.model.zero_grad()
        output.backward(gradient=one_hot, retain_graph=True)

        gradients = self.gradients
        activations = self.activations
        weights = torch.mean(gradients, dim=(2, 3, 4), keepdim=True)
        cam = torch.sum(weights * activations, dim=1)
        cam = F.relu(cam)

        cams = []
        for i in range(cam.shape[0]):
            cam_i = cam[i] - cam[i].min()
            if cam_i.max() != 0:
                cam_i /= cam_i.max()
            cams.append(cam_i)
        cams = torch.stack(cams)
        target_size = input_tensor.shape[2:]
        cams = F.interpolate(cams.unsqueeze(1), size=target_size, mode='trilinear', align_corners=False).squeeze(1)
        return cams

# =============================================================================
# Model Definition
# =============================================================================
def load_model(config):
    model = resnet18(spatial_dims=3, n_input_channels=1, num_classes=2)
    state_dict = torch.load(config['pretrained_weights_path'], map_location="cpu")
    model.load_state_dict(state_dict, strict=False)
    return model

# =============================================================================
# Utility Functions for Visualization
# =============================================================================
def get_sample_by_label(dataset, target_label):
    for i in range(len(dataset)):
        image, label = dataset[i]
        if label == target_label:
            return image, label
    return None

def generate_gradcam_plot(pdf, model, sample, layer_name, target_layer, sample_type):
    device = next(model.parameters()).device
    input_tensor = sample[0].unsqueeze(0).to(device)
    
    # Compute Grad-CAM for the specified layer
    gradcam = GradCam3D(model, target_layer)
    cam = gradcam(input_tensor)[0].cpu().numpy()  # Shape: (D, H, W)
    original = sample[0].squeeze(0).cpu().numpy()  # Shape: (D, H, W)
    
    # Slice 18 is index 1 in our tensor (since we select slices 17-32)
    slice_idx = 1
    slice_number = 18

    fig, axs = plt.subplots(1, 2, figsize=(10, 5))
    fig.suptitle(f"Grad-CAM for Slice {slice_number} ({layer_name}) - {sample_type}", fontsize=16)

    # Original MRI Image
    axs[0].imshow(original[slice_idx], cmap='gray')
    axs[0].set_title(f"Original Slice {slice_number} - {sample_type}")
    axs[0].axis('off')

    # CAM Overlay
    axs[1].imshow(original[slice_idx], cmap='gray')
    axs[1].imshow(cam[slice_idx], cmap='jet', alpha=0.5)
    axs[1].set_title(f"Grad-CAM Overlay (Layer {layer_name})")
    axs[1].axis('off')

    plt.tight_layout()
    pdf.savefig(fig)
    plt.close(fig)

    gradcam.remove_hooks()  # Cleanup hooks after use

# =============================================================================
# Main Execution Function
# =============================================================================
def main():
    config = {
        'train_dataframe_path': 'Final_Datasets/train_resnet_heart.csv',
        'pretrained_weights_path': '../Med3D/resnet_18_23dataset.pth',
    }

    # Load trained model
    model = load_model(config)
    model.eval()

    # Load dataset
    train_df = pd.read_csv(config['train_dataframe_path'])
    train_dataset = NPYDataset(train_df, 'FilePath', 'CAD')

    # Get one case and one control from the training set
    train_case = get_sample_by_label(train_dataset, target_label=1)
    train_control = get_sample_by_label(train_dataset, target_label=0)

    # Define layers to analyze
    layers_to_analyze = {
        "layer2": model.layer2[1].conv2,
        "layer3": model.layer3[1].conv2,
        "layer4": model.layer4[1].conv2
    }

    # Save all plots in a single PDF
    pdf_filename = "GradCAM_Slice18_Comparison_ch0_Med3D.pdf"
    with PdfPages(pdf_filename) as pdf:
        for layer_name, target_layer in layers_to_analyze.items():
            if train_case:
                generate_gradcam_plot(pdf, model, train_case, layer_name, target_layer, "Case")
            if train_control:
                generate_gradcam_plot(pdf, model, train_control, layer_name, target_layer, "Control")

    print(f"Saved Grad-CAM results for Slice 18 to {pdf_filename}")

if __name__ == "__main__":
    main()

  state_dict = torch.load(config['pretrained_weights_path'], map_location="cpu")


Saved Grad-CAM results for Slice 18 to GradCAM_Slice18_Comparison_ch0_Med3D.pdf
