# <p style="text-align: center;">Gradient-weighted Class Activation Mapping of Pneumonia Identification in Children by CNNs</p>
##### This script provides visual explanations of the decisions made by CNNs for identification of pneumonia in children.

In [45]:
import torch
from torchvision import models, transforms
from torchvision.transforms.functional import to_pil_image
import PIL
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import cm
from torch.nn import functional as F
import os
import cv2

In [70]:
# Deserialize the trained ResNet-18 classifier onto the CPU.
# To inspect the network trained on preprocessed images, load
# 'resnet18_preprocessed_model.pt' instead.
# NOTE(review): torch>=2.6 defaults torch.load to weights_only=True, which rejects a
# fully pickled nn.Module; pass weights_only=False only for trusted checkpoint files.
model = torch.load(
    'resnet18_model.pt',
    map_location=torch.device('cpu'),
)

# Switch to inference behaviour (e.g. batch-norm layers use their running statistics
# instead of per-batch statistics).
model.eval()

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [71]:
class ModifiedResNet18(torch.nn.Module):
    """
    A ResNet-18 wrapper adapted for Grad-CAM analysis.

    The wrapper re-uses the layers of a pre-trained ResNet-18 and exposes the
    feature maps of the last convolutional stage together with their gradients,
    which is what Grad-CAM needs. ``features`` holds every child module of the
    original model except the last two (the global average pooling layer and the
    fully connected classifier); those two are kept separately as ``pooling`` and
    ``fc`` so the forward pass can intercept the feature maps in between.

    Attributes:
        features (torch.nn.Sequential): All child modules of the original model except
            the final average pooling and fully connected layers.
        pooling (torch.nn.Module): The average pooling layer from the original model.
        fc (torch.nn.Module): The fully connected layer from the original model.
        gradients (torch.Tensor or None): Gradient of the backpropagated quantity with
            respect to the output of ``features``, recorded during the most recent
            backward pass; ``None`` until a backward pass has run.

    Methods:
        activations_hook(grad): Stores the gradients of the feature maps during backpropagation.
        forward(x): Runs the full classifier and records feature-map gradients when autograd is active.
        get_activations_gradient(): Returns the stored feature-map gradients.
        get_activations(x): Computes the feature maps for an input x.
    """

    def __init__(self, model):
        super().__init__()
        # Assumes avgpool and fc are the last two registered children of ``model``,
        # as in torchvision's ResNet; everything before them forms the feature extractor.
        self.features = torch.nn.Sequential(*list(model.children())[:-2])
        self.pooling = model.avgpool
        self.fc = model.fc
        self.gradients = None

    def activations_hook(self, grad):
        """
        Tensor hook that records the gradient of the feature maps.

        Args:
            grad (torch.Tensor): Gradient with respect to the output of ``features``.

        The gradient is stored in ``self.gradients`` for later retrieval.
        """
        self.gradients = grad

    def forward(self, x):
        """
        Forward pass through features, pooling and the classifier.

        Args:
            x (torch.Tensor): Input batch.

        Returns:
            torch.Tensor: Classifier output of ``fc``, one row per input sample.

        When autograd tracks the feature maps, a hook is attached so that the next
        backward pass stores their gradient in ``self.gradients``.
        """
        x = self.features(x)
        # register_hook raises on tensors that do not require grad (e.g. under
        # torch.no_grad()), so only attach it when a backward pass is possible.
        if x.requires_grad:
            x.register_hook(self.activations_hook)
        x = self.pooling(x)
        x = torch.flatten(x, 1)
        return self.fc(x)

    def get_activations_gradient(self):
        """
        Return the feature-map gradients recorded by the last backward pass.

        Returns:
            torch.Tensor or None: The stored gradients, or ``None`` if no backward pass
            has gone through a hooked forward pass yet.
        """
        return self.gradients

    def get_activations(self, x):
        """
        Compute the feature maps of the last convolutional stage for an input.

        Args:
            x (torch.Tensor): Input batch.

        Returns:
            torch.Tensor: Output of ``features`` for ``x`` (no hook is registered here).
        """
        return self.features(x)

# Wrap the loaded classifier so Grad-CAM can read its feature maps and their gradients.
modified_model = ModifiedResNet18(model=model)

In [72]:
def grad_cam(modified_model, input_image, target_class):
    """
    Compute a Grad-CAM heatmap for ``target_class``.

    A forward pass lets the model record the gradient of the target score with
    respect to the feature maps of its last convolutional stage. Those gradients
    are averaged over the spatial dimensions to give one weight per channel. The
    feature maps are weighted by these values, averaged over the channels, and
    negative values are removed (ReLU). The result is scaled so its maximum is 1.

    Args:
        modified_model (torch.nn.Module): Model exposing ``get_activations_gradient()``
            and ``get_activations(x)`` whose forward pass records feature-map gradients
            (e.g. ``ModifiedResNet18``).
        input_image (torch.Tensor): Input batch, expected shape (N, C, H, W); typically N == 1.
        target_class (int): Index of the output logit to explain.

    Returns:
        numpy.ndarray: Heatmap with values in [0, 1]. For a single image the shape is
        (h, w), the spatial size of the feature maps. For N > 1 each sample gets its own
        channel weights, the shape is (N, h, w), and scaling uses the batch-wide maximum.
        If no value is positive, the map is returned as all zeros rather than NaN.

    Raises:
        RuntimeError: If the model recorded no gradients during the backward pass.
    """
    # Forward pass; the model's hook captures d(target)/d(feature maps) on backward.
    output = modified_model(input_image)
    target = output[:, target_class]

    # sum() makes the target a scalar so backward() also works for batches;
    # for a single image it leaves the value unchanged.
    modified_model.zero_grad()
    target.sum().backward()

    gradients = modified_model.get_activations_gradient()
    if gradients is None:
        raise RuntimeError("The model did not record feature-map gradients during backward().")

    # One weight per (sample, channel): the spatially averaged gradient.
    weights = gradients.mean(dim=(2, 3), keepdim=True)

    # The feature maps are only needed as values, so skip building an autograd graph.
    with torch.no_grad():
        activations = modified_model.get_activations(input_image)

    heatmap = (activations * weights).mean(dim=1).squeeze()
    heatmap = torch.clamp(heatmap, min=0)

    # Guard against 0/0: a map with no positive evidence stays all zeros.
    max_value = heatmap.max()
    if max_value > 0:
        heatmap = heatmap / max_value

    return heatmap.detach().cpu().numpy()

In [73]:
def preprocess_image(img_path, size=(224, 224)):
    """
    Load an image from disk and turn it into a model-ready batch tensor.

    Args:
        img_path (str): Path to the image file.
        size (tuple[int, int]): Target (height, width) passed to ``transforms.Resize``.
            Defaults to (224, 224).

    Returns:
        torch.Tensor: Tensor of shape (1, 3, height, width) produced by ``ToTensor``
        (values in [0, 1] for 8-bit images). No mean/std normalization is applied.
        NOTE(review): confirm this matches the preprocessing used during training.
    """
    transform = transforms.Compose([
        transforms.Resize(size),
        transforms.ToTensor(),
    ])
    # The context manager closes the file handle right after decoding; convert()
    # returns a new, fully loaded image, so it remains valid once the file is closed.
    with Image.open(img_path) as img:
        rgb = img.convert('RGB')
    return transform(rgb).unsqueeze(0)

# Output index explained by Grad-CAM.
# NOTE(review): presumably 1 corresponds to PNEUMONIA in the training label order — confirm.
target_class: int = 1

In [74]:
def gradcam_process(folder_path, model, target_class, output_folder, extensions=('.jpeg',)):
    """
    Apply Grad-CAM to every matching image in a folder and save one overlay per image.

    For each image whose name ends with one of ``extensions`` (case-insensitive),
    the function:
        - preprocesses it with ``preprocess_image``,
        - computes a Grad-CAM heatmap with ``grad_cam``,
        - resizes the heatmap to the input size, squares it and applies the 'jet' colormap,
        - draws it with 40% opacity over the preprocessed image and saves the figure.

    Args:
        folder_path (str): Folder containing the images to process.
        model (torch.nn.Module): Model modified to capture gradients and activations.
        target_class (int): Class index for which Grad-CAM is computed.
        output_folder (str): Folder where the visualizations are saved; created if missing.
        extensions (tuple[str, ...]): File-name suffixes to process. Defaults to ('.jpeg',).

    Each output file is named 'gradcam_' + the original file name. Files are
    processed in sorted order so runs are reproducible.
    """
    os.makedirs(output_folder, exist_ok=True)
    suffixes = tuple(ext.lower() for ext in extensions)

    for img_file in sorted(os.listdir(folder_path)):
        if not img_file.lower().endswith(suffixes):
            continue

        img_path = os.path.join(folder_path, img_file)
        input_image = preprocess_image(img_path)

        heatmap = grad_cam(model, input_image, target_class)

        # PIL's resize takes (width, height); the tensor is laid out (..., height, width).
        height, width = input_image.shape[-2:]
        overlay = to_pil_image(torch.from_numpy(heatmap), mode='F').resize(
            (width, height), resample=PIL.Image.BICUBIC)

        # Bicubic interpolation can overshoot [0, 1]; clip before squaring so negative
        # overshoot does not turn into spurious positive values. Squaring de-emphasises
        # weak activations before the colormap is applied.
        intensity = np.clip(np.asarray(overlay), 0.0, 1.0) ** 2
        overlay = (255 * cm.jet(intensity)[:, :, :3]).astype(np.uint8)

        fig, ax = plt.subplots()
        ax.axis('off')
        # Preprocessed input image first, then the semi-transparent heatmap on top.
        ax.imshow(to_pil_image(input_image.squeeze(), mode='RGB'))
        ax.imshow(PIL.Image.fromarray(overlay), alpha=0.4, interpolation='nearest')

        output_img_path = os.path.join(output_folder, 'gradcam_' + img_file)
        fig.savefig(output_img_path, bbox_inches='tight', pad_inches=0)
        plt.close(fig)

In [75]:
def heatmap_process(folder_path, model, target_class, output_folder,
                    output_size=(224, 224), output_filename='average_heatmap.png',
                    extensions=('.jpeg',)):
    """
    Average the Grad-CAM heatmaps of all matching images in a folder and save the result.

    Every image whose name ends with one of ``extensions`` (case-insensitive) is
    preprocessed and passed through ``grad_cam``. The heatmaps are averaged, scaled to
    [0, 255], resized to ``output_size``, coloured with the 'jet' colormap and saved as
    ``output_filename`` inside ``output_folder``.

    Args:
        folder_path (str): Folder containing the images to process.
        model (torch.nn.Module): Model modified for Grad-CAM usage.
        target_class (int): Class index for which Grad-CAM is computed.
        output_folder (str): Folder where the averaged heatmap is saved; created if missing.
        output_size (tuple[int, int]): (width, height) of the saved image, as expected by
            PIL's ``resize``. Defaults to (224, 224).
        output_filename (str): Name of the saved file. Defaults to 'average_heatmap.png'.
        extensions (tuple[str, ...]): File-name suffixes to process. Defaults to ('.jpeg',).

    Raises:
        ValueError: If no matching image is found in ``folder_path``.
    """
    suffixes = tuple(ext.lower() for ext in extensions)
    heatmaps = []

    for img_file in sorted(os.listdir(folder_path)):
        if not img_file.lower().endswith(suffixes):
            continue
        input_image = preprocess_image(os.path.join(folder_path, img_file))
        heatmap = grad_cam(model, input_image, target_class)
        # A heatmap with no positive evidence may come back as NaN (0/0 normalisation);
        # count it as zero evidence instead of letting it poison the average.
        heatmaps.append(np.nan_to_num(heatmap, nan=0.0))

    if not heatmaps:
        raise ValueError(f"No images ending with {suffixes} found in {folder_path!r}")

    average_heatmap = np.mean(heatmaps, axis=0)

    # Scale to [0, 1]; an all-zero average stays zero instead of dividing by zero.
    peak = np.max(average_heatmap)
    if peak > 0:
        average_heatmap = average_heatmap / peak
    average_heatmap_uint8 = (255 * np.clip(average_heatmap, 0.0, 1.0)).astype(np.uint8)

    average_heatmap_resized = Image.fromarray(average_heatmap_uint8).resize(output_size, Image.BICUBIC)

    # Integer input indexes the 256-entry colormap lookup table directly (0..255).
    average_heatmap_colored = cm.jet(np.array(average_heatmap_resized))[:, :, :3]

    # The folder may not exist yet (this function can run before gradcam_process).
    os.makedirs(output_folder, exist_ok=True)
    output_img_path = os.path.join(output_folder, output_filename)
    plt.imsave(output_img_path, average_heatmap_colored)

In [76]:
# Image storage locations, grouped as (input folder, matching output folder).

# Preprocessed NORMAL images
normal_test_preprocessed_folder = 'dataset/PRE_NORMAL/TEST'
output_test_preprocessed_normal = 'output/PRE_NORMAL/TEST'
normal_train_preprocessed_folder = 'dataset/PRE_NORMAL/TRAIN'
output_train_preprocessed_normal = 'output/PRE_NORMAL/TRAIN'

# Preprocessed PNEUMONIA images
pneumonia_test_preprocessed_folder = 'dataset/PRE_PNEUMONIA/TEST'
output_test_preprocessed_pneumonia = 'output/PRE_PNEUMONIA/TEST'
pneumonia_train_preprocessed_folder = 'dataset/PRE_PNEUMONIA/TRAIN'
output_train_preprocessed_pneumonia = 'output/PRE_PNEUMONIA/TRAIN'

# Unprocessed NORMAL images
normal_test_unprocessed_folder = 'dataset/NORMAL/TEST'
output_test_unprocessed_normal = 'output/NORMAL/TEST'
normal_train_unprocessed_folder = 'dataset/NORMAL/TRAIN'
output_train_unprocessed_normal = 'output/NORMAL/TRAIN'

# Unprocessed PNEUMONIA images
pneumonia_test_unprocessed_folder = 'dataset/PNEUMONIA/TEST'
output_test_unprocessed_pneumonia = 'output/PNEUMONIA/TEST'
pneumonia_train_unprocessed_folder = 'dataset/PNEUMONIA/TRAIN'
output_train_unprocessed_pneumonia = 'output/PNEUMONIA/TRAIN'

In [95]:
# Produce one averaged Grad-CAM heatmap per (input folder, output folder) pair,
# in the order: preprocessed NORMAL, preprocessed PNEUMONIA, unprocessed NORMAL,
# unprocessed PNEUMONIA (test before train within each group).
for _input_dir, _output_dir in (
    (normal_test_preprocessed_folder, output_test_preprocessed_normal),
    (normal_train_preprocessed_folder, output_train_preprocessed_normal),
    (pneumonia_test_preprocessed_folder, output_test_preprocessed_pneumonia),
    (pneumonia_train_preprocessed_folder, output_train_preprocessed_pneumonia),
    (normal_test_unprocessed_folder, output_test_unprocessed_normal),
    (normal_train_unprocessed_folder, output_train_unprocessed_normal),
    (pneumonia_test_unprocessed_folder, output_test_unprocessed_pneumonia),
    (pneumonia_train_unprocessed_folder, output_train_unprocessed_pneumonia),
):
    heatmap_process(_input_dir, modified_model, target_class=1, output_folder=_output_dir)

In [None]:
# Produce per-image Grad-CAM overlays for every (input folder, output folder) pair,
# in the same order as the averaged-heatmap cell.
for _input_dir, _output_dir in (
    (normal_test_preprocessed_folder, output_test_preprocessed_normal),
    (normal_train_preprocessed_folder, output_train_preprocessed_normal),
    (pneumonia_test_preprocessed_folder, output_test_preprocessed_pneumonia),
    (pneumonia_train_preprocessed_folder, output_train_preprocessed_pneumonia),
    (normal_test_unprocessed_folder, output_test_unprocessed_normal),
    (normal_train_unprocessed_folder, output_train_unprocessed_normal),
    (pneumonia_test_unprocessed_folder, output_test_unprocessed_pneumonia),
    (pneumonia_train_unprocessed_folder, output_train_unprocessed_pneumonia),
):
    gradcam_process(_input_dir, modified_model, target_class=1, output_folder=_output_dir)

In [100]:
def _load_heatmap_rgb(heatmap_path):
    """Read a heatmap image as an RGB uint8 array, applying the JET colormap only to single-channel images."""
    heatmap = cv2.imread(heatmap_path, cv2.IMREAD_UNCHANGED)
    if heatmap is None:
        raise FileNotFoundError(f"Could not read heatmap image: {heatmap_path}")
    if heatmap.ndim == 2:
        # Grayscale intensities (assumed 8-bit): colour them. applyColorMap returns BGR.
        return cv2.cvtColor(cv2.applyColorMap(heatmap, cv2.COLORMAP_JET), cv2.COLOR_BGR2RGB)
    if heatmap.shape[2] == 4:
        # PNGs written by matplotlib carry an alpha channel; drop it.
        return cv2.cvtColor(heatmap, cv2.COLOR_BGRA2RGB)
    return cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)


def overlay_heatmap_on_lungs(lung_image_path, heatmap_folder, output_folder, alpha=0.5,
                             heatmap_filename='average_heatmap_preprocessed_model.png',
                             output_filename='lung_with_heatmap_overlay.jpg'):
    """
    Blend a saved Grad-CAM heatmap image onto a lung image and save the result.

    A heatmap that is already coloured (e.g. the jet-coloured PNG written by
    ``heatmap_process``) keeps its original colours. A single-channel heatmap is
    coloured with the JET colormap first. The heatmap is resized to the lung image
    and blended as ``(1 - alpha) * lung + alpha * heatmap`` in RGB.

    Args:
        lung_image_path (str): File path to the lung image.
        heatmap_folder (str): Directory containing the heatmap image.
        output_folder (str): Directory where the blended image is saved; created if missing.
        alpha (float, optional): Weight of the heatmap in the blend, in [0, 1]. Default is 0.5.
        heatmap_filename (str, optional): Heatmap file name inside ``heatmap_folder``.
            Default is 'average_heatmap_preprocessed_model.png'.
        output_filename (str, optional): Name of the saved file.
            Default is 'lung_with_heatmap_overlay.jpg'.

    Raises:
        FileNotFoundError: If the lung image or the heatmap image cannot be read.
    """
    heatmap_path = os.path.join(heatmap_folder, heatmap_filename)

    lung_bgr = cv2.imread(lung_image_path)
    if lung_bgr is None:
        raise FileNotFoundError(f"Could not read lung image: {lung_image_path}")
    lung_image = cv2.cvtColor(lung_bgr, cv2.COLOR_BGR2RGB)

    # Both arrays are RGB from here on, so colours are not swapped in the blend.
    heatmap = _load_heatmap_rgb(heatmap_path)
    heatmap = cv2.resize(heatmap, (lung_image.shape[1], lung_image.shape[0]))

    blended = (1 - alpha) * lung_image.astype(np.float32) + alpha * heatmap.astype(np.float32)
    overlayed_image = Image.fromarray(np.clip(blended, 0, 255).astype(np.uint8))

    os.makedirs(output_folder, exist_ok=True)
    overlayed_image.save(os.path.join(output_folder, output_filename))

In [101]:
# Overlay each folder's average heatmap on the same NORMAL lung image and save the
# result back into that folder.
for _heatmap_dir in (
    output_test_unprocessed_normal,
    output_test_preprocessed_normal,
    output_test_unprocessed_pneumonia,
    output_test_preprocessed_pneumonia,
):
    overlay_heatmap_on_lungs('lung_NORMAL.jpeg', _heatmap_dir, _heatmap_dir)