<a href="https://colab.research.google.com/github/JWKKWJ123/google-colab/blob/main/Heatmaps_of_models_on_AD_CN_tasks_on_ADNI.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

This Google Colab notebook provides an interactive interface for visualizing 3D heatmaps of AD-vs-CN classification models in the transverse (axial) plane. After cropping, the transverse slice index (z) ranges from 0 to 149. Detailed explanations of the models and visualization methods are available in the paper.

In [31]:
import nibabel as nib
import math
import pandas as pd
import numpy as np
import os
import glob
import matplotlib.pyplot as plt
import matplotlib as mpl
from matplotlib.gridspec import GridSpec
from matplotlib.colors import ListedColormap, LinearSegmentedColormap
from scipy import ndimage
from scipy.ndimage import gaussian_filter
import ipywidgets as widgets
from ipywidgets import interactive_output
from ipywidgets import GridBox, Layout, Output
from IPython.display import display


In [29]:
###Load data

!gdown --id 1kHQlVTazUwTMSeGzrT3kgfsG-lM6LNuA -O /content/aparcaseg_in_MNI.nii.gz
!gdown --id 14ls7KNATP1P3eTokGOQB7VxJ-AeCiEei -O /content/ICNN_ADCN_ADNI.npy
!gdown --id 1_QoODC4r3oilIDfDZY-G-Ov3H_8YogoB -O /content/VolEBM_ADCN_ADNI.npy
!gdown --id 1hqCSFUkAg_7sn72JCr8Yx0yJMVYGjocx -O /content/VGG_OCC_ADCN_ADNI.npy
!gdown --id 1m2Iq26iWcA54FnI1qk9WdezX8duvikCa -O /content/DenseNet_OCC_ADCN_ADNI.npy
!gdown --id 1buhyBl67_EdpiQlXbSwPBDhSIzSiWeJ2 -O /content/VolSVM_pmap_ADCN_ADNI.npy
!gdown --id 1xafHk76pTFzZQ36StJu9WuvzC6PUZmmE -O /content/VBMSVM_pmap_ADCN_ADNI.npy
!gdown --id 1OOcL-gs4Y-gFsP6sFW6AoWew1wVyXt5l -O /content/GLCNN_OCC_ADCN_ADNI.npy




Downloading...
From (original): https://drive.google.com/uc?id=1kHQlVTazUwTMSeGzrT3kgfsG-lM6LNuA
From (redirected): https://drive.google.com/uc?id=1kHQlVTazUwTMSeGzrT3kgfsG-lM6LNuA&confirm=t&uuid=350d5912-8a76-4761-8ebd-edb4b17fe431
To: /content/aparcaseg_in_MNI.nii.gz
100% 336k/336k [00:00<00:00, 32.7MB/s]
Downloading...
From: https://drive.google.com/uc?id=14ls7KNATP1P3eTokGOQB7VxJ-AeCiEei
To: /content/ICNN_ADCN_ADNI.npy
100% 32.4M/32.4M [00:00<00:00, 90.5MB/s]
Downloading...
From: https://drive.google.com/uc?id=1_QoODC4r3oilIDfDZY-G-Ov3H_8YogoB
To: /content/VolEBM_ADCN_ADNI.npy
100% 32.4M/32.4M [00:00<00:00, 148MB/s]
Downloading...
From: https://drive.google.com/uc?id=1hqCSFUkAg_7sn72JCr8Yx0yJMVYGjocx
To: /content/VGG_OCC_ADCN_ADNI.npy
100% 32.4M/32.4M [00:00<00:00, 49.9MB/s]


In [36]:
###Functions to plot the heatmaps

def plot_selected_slices(struct_arr, num_slices=6, cmap='gray', vmin=None, vmax=None, overlay=None,
                         overlay_cmap='hot', overlay_vmin=None, overlay_vmax=None, _class=None, iteration=0):
    """
    Plot slices of a 3D image (and an optional overlay) along every axis,
    restricted to the middle 25%-75% region of each axis.

    The figure has one row per axis (x, y, z) and ``num_slices`` columns; every
    slice is rotated by 90 degrees before display.

    Parameters:
    - struct_arr: 3D array drawn as the background image
    - num_slices: Number of evenly spaced slices shown per axis
    - cmap: Colormap for the background image
    - vmin, vmax: Intensity range for the background image (default: its min/max)
    - overlay: Optional 3D array drawn on top of the background with alpha=0.7
    - overlay_cmap: Colormap for the overlay
    - overlay_vmin, overlay_vmax: Intensity range for the overlay (default: its min/max)
    - _class, iteration: Unused; accepted only for backward compatibility
    """
    if vmin is None:
        vmin = struct_arr.min()
    if vmax is None:
        vmax = struct_arr.max()
    if overlay is not None:
        if overlay_vmin is None:
            overlay_vmin = overlay.min()
        if overlay_vmax is None:
            overlay_vmax = overlay.max()

    # GridSpec gives precise control over placement; the small hspace keeps the rows close.
    fig = plt.figure(figsize=(num_slices * 2, 6))
    gs = GridSpec(3, num_slices, figure=fig, wspace=0.05, hspace=0.02)

    # Evenly spaced slice indices between 25% and 75% of each axis length.
    slice_ranges = [
        np.linspace(int(struct_arr.shape[axis] * 0.25), int(struct_arr.shape[axis] * 0.75),
                    num_slices, dtype=int)
        for axis in range(3)
    ]

    # Dedicated axis on the right-hand side for the overlay colorbar.
    cbar_ax = fig.add_axes([0.92, 0.2, 0.015, 0.6]) if overlay is not None else None

    for axis, axis_label in enumerate(['x', 'y', 'z']):
        for i, i_slice in enumerate(slice_ranges[axis]):
            ax = fig.add_subplot(gs[axis, i])
            ax.axis('off')
            # NOTE(review): interpolation=None defers to rcParams["image.interpolation"];
            # pass 'none' instead if raw, un-resampled voxels are wanted.
            ax.imshow(ndimage.rotate(np.take(struct_arr, i_slice, axis=axis), 90),
                      vmin=vmin, vmax=vmax, cmap=cmap, interpolation=None)
            ax.text(0.03, 0.97, '{}={}'.format(axis_label, i_slice), color='white',
                    horizontalalignment='left', verticalalignment='top', transform=ax.transAxes)

            if overlay is not None:
                im = ax.imshow(ndimage.rotate(np.take(overlay, i_slice, axis=axis), 90),
                               cmap=overlay_cmap, vmin=overlay_vmin, vmax=overlay_vmax,
                               interpolation=None, alpha=0.7)

    # All overlay panels share overlay_vmin/overlay_vmax, so one colorbar describes them all.
    if overlay is not None:
        fig.colorbar(im, cax=cbar_ax)

    plt.show()



def plot_transverse_heatmaps(img_npy, heatmaps, z_coord, cmap='gray', vmin=None, vmax=None,
                             overlay_cmap='hot', overlay_vmin=None, overlay_vmax=None, max_cols=4):
    """
    Plot multiple heatmaps on a specified z slice, each with its own independent colorbar,
    and automatically wrap to a new row (with a maximum of max_cols plots per row).

    Parameters:
    - img_npy: 3D array (base anatomical image), indexed as img_npy[:, :, z]
    - heatmaps: Dictionary {model_name: 3D heatmap array}, one panel per entry
    - z_coord: Specified z slice to display
    - cmap: Colormap for the base image
    - vmin, vmax: Intensity range for the base image
    - overlay_cmap: Colormap for the overlay heatmaps
    - overlay_vmin, overlay_vmax: Intensity range applied to every heatmap; when None,
      each heatmap uses the min/max of its whole 3D volume
    - max_cols: Maximum number of plots per row; if exceeded, new rows are created

    Raises:
    - ValueError: if ``heatmaps`` is empty or ``max_cols`` is less than 1
    """
    num_models = len(heatmaps)
    # Without these checks the grid computation below divides by zero.
    if num_models == 0:
        raise ValueError("heatmaps must contain at least one model")
    if max_cols < 1:
        raise ValueError("max_cols must be at least 1")

    ncols = min(num_models, max_cols)
    nrows = math.ceil(num_models / ncols)

    # squeeze=False always yields a 2D array of Axes, so flattening works for every grid size.
    fig, axes = plt.subplots(nrows, ncols, figsize=(ncols * 3, nrows * 4),
                             constrained_layout=True, squeeze=False)
    axes = axes.ravel()

    # The background slice is identical for every panel, so rotate it only once.
    base_slice = ndimage.rotate(img_npy[:, :, z_coord], 90)

    for ax, (model_name, heatmap) in zip(axes, heatmaps.items()):
        ax.axis('off')

        # Limits come from the whole 3D heatmap (not just this slice), so each model's
        # colorbar stays fixed while the z slice changes.
        local_vmin = overlay_vmin if overlay_vmin is not None else heatmap.min()
        local_vmax = overlay_vmax if overlay_vmax is not None else heatmap.max()

        # Base anatomical image (rotated 90 degrees to match the original orientation).
        ax.imshow(base_slice, vmin=vmin, vmax=vmax, cmap=cmap, interpolation=None)
        ax.text(0.03, 0.97, f'{model_name}\n(z={z_coord})', color='white',
                horizontalalignment='left', verticalalignment='top', transform=ax.transAxes)

        # Semi-transparent heatmap overlay.
        im = ax.imshow(ndimage.rotate(heatmap[:, :, z_coord], 90), cmap=overlay_cmap,
                       vmin=local_vmin, vmax=local_vmax, interpolation=None, alpha=0.7)

        # Independent colorbar per panel.
        cbar = fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
        cbar.ax.tick_params(labelsize=7)

    # Hide any grid cells left over in the last row.
    for ax in axes[num_models:]:
        ax.axis('off')

    plt.show()




def interactive_slider(img_npy, heatmaps, overlay_cmap='hot'):
    """
    Display a z-slice slider together with the transverse heatmap panels it controls,
    giving each model its own independent colorbar.

    Parameters:
    - img_npy: 3D base anatomical image; the slider spans indices 0 .. img_npy.shape[2] - 1
    - heatmaps: Dictionary {model_name: 3D heatmap array}
    - overlay_cmap: Colormap for heatmaps
    """
    max_z = img_npy.shape[2] - 1  # Last valid slice index along the z axis
    z_slider = widgets.IntSlider(value=max_z // 2, min=0, max=max_z, step=1, description="Z Slice")

    def update_plot(z_coord):
        # Colorbars are independent per model and fixed across slices.
        plot_transverse_heatmaps(img_npy, heatmaps, z_coord, overlay_cmap=overlay_cmap)

    # interactive_output observes the slider, clears and redraws its own Output widget on
    # every value change (and once immediately), so no hidden `interactive` container or
    # nested Output context is needed.
    output = widgets.interactive_output(update_plot, {"z_coord": z_slider})

    # Display slider and output
    display(z_slider, output)


In [38]:
# Define color lists for the colormaps.
# The following line defines the active color list.
colors = ['black', 'blue', 'orange', 'gold', 'tomato', 'red']

# Alternative color list can be:
# colors = ['blue', 'green', 'black', 'gold', 'red']

# Discrete colormap with one band per color (not referenced by the cells visible here).
color_list = ListedColormap(colors)

# Continuous colormap interpolating between the colors; used for the heatmap overlays.
mycmap = LinearSegmentedColormap.from_list('mycmap', colors)

# Load the brain atlas used in this study, which serves as the background of the heatmaps.
# The download cell saves the atlas under /content, so read it from there rather than from
# a personal Google Drive folder (which is not mounted in a fresh Colab session).
ATLAS_PATH = "/content/aparcaseg_in_MNI.nii.gz"
img_seg = nib.load(ATLAS_PATH).get_fdata()
img_npy = np.array(img_seg)

# Crop the segmentation image to the region of interest.
# NOTE(review): the heatmap .npy files are assumed to share this cropped shape, since they
# are masked voxel-wise against img_npy when they are loaded.
img_npy = img_npy[16:166, 19:199, 16:166]


def _minmax_normalize(arr):
    """Linearly rescale ``arr`` to [0, 1]; a constant array maps to all zeros."""
    lo, hi = np.min(arr), np.max(arr)
    if hi == lo:
        # Avoid 0/0 -> NaN when the region holds a single value.
        return np.zeros_like(arr, dtype=float)
    return (arr - lo) / (hi - lo)


# Calculate the midpoint along the first axis (used for splitting hemispheres).
mid_slice = img_npy.shape[0] // 2

# Split the image into left and right hemispheres.
left_hemisphere = img_npy[:mid_slice, :, :]
right_hemisphere = img_npy[mid_slice:, :, :]

# Normalize the intensities of each hemisphere independently.
left_norm = _minmax_normalize(left_hemisphere)
right_norm = _minmax_normalize(right_hemisphere)

# Concatenate the normalized left and right hemispheres back together along the first axis.
img_npy = np.concatenate((left_norm, right_norm), axis=0)


In [37]:
# Load the absolute values of each model's heatmap from the downloaded .npy files.
heatmap_GLCNN = np.abs(np.load("/content/GLCNN_OCC_ADCN_ADNI.npy"))
heatmap_VGG = np.abs(np.load("/content/VGG_OCC_ADCN_ADNI.npy"))
heatmap_DenseNet = np.abs(np.load("/content/DenseNet_OCC_ADCN_ADNI.npy"))
heatmap_Vol_EBM = np.abs(np.load("/content/VolEBM_ADCN_ADNI.npy"))
heatmap_Vol_SVM = np.abs(np.load("/content/VolSVM_pmap_ADCN_ADNI.npy"))
heatmap_VBM_SVM = np.abs(np.load("/content/VBMSVM_pmap_ADCN_ADNI.npy"))
heatmap_ICNN = np.abs(np.load("/content/ICNN_ADCN_ADNI.npy"))

# Set heatmap values to zero wherever the anatomical image (img_npy) is zero, masking out
# non-brain / background voxels. Every model is masked the same way so the panels are
# directly comparable.
background_mask = img_npy == 0
heatmap_GLCNN = np.where(background_mask, 0, heatmap_GLCNN)
heatmap_VGG = np.where(background_mask, 0, heatmap_VGG)
heatmap_DenseNet = np.where(background_mask, 0, heatmap_DenseNet)
heatmap_Vol_EBM = np.where(background_mask, 0, heatmap_Vol_EBM)
heatmap_Vol_SVM = np.where(background_mask, 0, heatmap_Vol_SVM)
heatmap_VBM_SVM = np.where(background_mask, 0, heatmap_VBM_SVM)
heatmap_ICNN = np.where(background_mask, 0, heatmap_ICNN)

# Organize the heatmaps into a dictionary with model names as keys (one panel per entry).
heatmaps = {
    "GLCNN": heatmap_GLCNN,
    "VGG": heatmap_VGG,
    "DenseNet": heatmap_DenseNet,
    "Vol-EBM": heatmap_Vol_EBM,
    "Vol-SVM": heatmap_Vol_SVM,
    "VBM-SVM": heatmap_VBM_SVM,
    "ICNN": heatmap_ICNN,
}

# Run the interactive UI using the custom colormap (mycmap) to overlay the heatmaps on the anatomical image.
interactive_slider(img_npy, heatmaps, overlay_cmap=mycmap)


IntSlider(value=74, description='Z Slice', max=149)

Output()

SyntaxError: invalid decimal literal (<ipython-input-10-2ea4add428bd>, line 1)