In [None]:
# Import Libraries
import matplotlib.pyplot as plt
import os
import numpy as np
import nibabel as nib
import cv2
import tensorflow as tf
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.utils import Sequence
from tensorflow.keras.models import load_model
from sklearn.model_selection import train_test_split
import ipywidgets as widgets
from IPython.display import display
from skimage.metrics import structural_similarity as ssim

In [None]:
# CLAHE Equalization Function
def clahe_equalization(image, clipLimit=1.1, tileGridSize=(8, 8)):
    """Apply CLAHE independently to every 2D slice along the last axis.

    Args:
        image: Array whose last axis indexes the slices. Values are scaled by 255
            and cast to uint8 before equalization, so intensities are expected
            to lie in [0, 1]; anything outside that range is clipped.
        clipLimit: Contrast limit passed to ``cv2.createCLAHE``.
        tileGridSize: Tile grid size passed to ``cv2.createCLAHE``.

    Returns:
        A new floating-point array of the same shape with each slice equalized
        and rescaled by 1/255. The input array is left untouched.
    """
    clahe = cv2.createCLAHE(clipLimit=clipLimit, tileGridSize=tileGridSize)

    # Work on a floating-point copy: the caller's array is not modified, and an
    # integer input no longer truncates the equalized values on assignment.
    source = np.asarray(image)
    equalized = np.array(source, dtype=np.result_type(source.dtype, np.float32))

    for i in range(equalized.shape[-1]):
        # Clip before the uint8 cast; values outside [0, 1] would otherwise wrap around.
        slice_u8 = (np.clip(equalized[..., i], 0.0, 1.0) * 255).astype(np.uint8)
        equalized[..., i] = clahe.apply(slice_u8) / 255.0
    return equalized

# Preprocess Image Function
def preprocess_image(image, brightness_correct=True, clipLimit=1.1, tileGridSize=(8, 8)):
    """Scale an image by its maximum and optionally apply CLAHE.

    Args:
        image: Array of intensities.
        brightness_correct: When True, the scaled image is passed through
            ``clahe_equalization``.
        clipLimit: Forwarded to ``clahe_equalization``.
        tileGridSize: Forwarded to ``clahe_equalization``.

    Returns:
        The image divided by ``max(image) + 1e-8`` (the epsilon avoids a
        division by zero for an all-zero image), equalized if requested.
    """
    peak = np.max(image)
    scaled = image / (peak + 1e-8)
    if not brightness_correct:
        return scaled
    return clahe_equalization(scaled, clipLimit=clipLimit, tileGridSize=tileGridSize)

# Pads the image to have exactly 192 slices along the last axis.
def pad_image_to_192_slices(image, target_slices=192):
    """Zero-pad or crop axis 2 of ``image`` to exactly ``target_slices`` entries.

    Args:
        image: Array with at least three dimensions; axis 2 is the slice axis.
        target_slices: Desired length of axis 2. Defaults to 192.

    Returns:
        The input array itself when axis 2 already has the target length;
        otherwise a zero-padded copy (padding appended at the end of axis 2)
        or a cropped view keeping the first ``target_slices`` slices.
    """
    depth = image.shape[2]
    if depth < target_slices:
        # Pad only the slice axis; every other axis (including any trailing ones) is untouched.
        pad_width = [(0, 0)] * image.ndim
        pad_width[2] = (0, target_slices - depth)
        image = np.pad(image, pad_width, mode='constant')
    elif depth > target_slices:
        image = image[:, :, :target_slices]
    return image

# Swap the axes to change from sagittal to axial
def reorient_sagittal_to_axial(image):
    """Return ``image`` with its first and third axes exchanged.

    The result comes straight from ``np.swapaxes``, so for an ndarray input it
    is a view on the same data rather than a copy.
    """
    reoriented = np.swapaxes(image, axis1=0, axis2=2)
    return reoriented

In [None]:
# MRI Data Generator Class
class MRI_DataGenerator(Sequence):
    """Keras ``Sequence`` that yields batches of preprocessed MRI volumes.

    Every batch is returned as ``(X, X)`` — input and target are the same
    array — with ``X`` of shape ``(batch_size, dim[0], dim[1], dim[2], 1)``.

    Args:
        list_IDs: Paths of the NIfTI files to serve.
        batch_size: Number of volumes per batch. Files that do not fill a
            complete batch are dropped (see ``__len__``).
        dim: Target volume shape as ``(rows, cols, slices)``.
        shuffle: Reshuffle the sample order at construction and after every epoch.
        augment: Randomly flip volumes (see ``__augment``).
        brightness_correct: Forwarded to ``preprocess_image``.
        clipLimit: Forwarded to ``preprocess_image``.
        tileGridSize: Forwarded to ``preprocess_image``.
    """

    def __init__(self, list_IDs, batch_size=4, dim=(160, 160, 192), shuffle=True, augment=False, brightness_correct=True, clipLimit=2.0, tileGridSize=(8, 8)):
        # Newer Keras versions expect Sequence subclasses to call the base
        # initializer; on older versions this is a harmless no-op.
        super().__init__()
        self.dim = dim
        self.batch_size = batch_size
        self.list_IDs = list_IDs
        self.shuffle = shuffle
        self.augment = augment
        self.brightness_correct = brightness_correct
        self.clipLimit = clipLimit
        self.tileGridSize = tileGridSize
        self.on_epoch_end()

    def __len__(self):
        """Number of complete batches; a trailing partial batch is not served."""
        return int(np.floor(len(self.list_IDs) / self.batch_size))

    def __getitem__(self, index):
        """Return batch ``index`` as an ``(X, X)`` pair.

        Raises:
            IndexError: If ``index`` is outside ``[0, len(self))``. Without this
                check an out-of-range index would select no files and hand back
                a batch that was never filled with data.
        """
        if not 0 <= index < len(self):
            raise IndexError(f"Batch index {index} out of range for generator with {len(self)} batches")
        indexes = self.indexes[index * self.batch_size : (index + 1) * self.batch_size]
        list_IDs_temp = [self.list_IDs[k] for k in indexes]
        X = self.__data_generation(list_IDs_temp)
        return X, X

    def on_epoch_end(self):
        """Reset the sample order, shuffling it when ``shuffle`` is enabled."""
        self.indexes = np.arange(len(self.list_IDs))
        if self.shuffle:
            np.random.shuffle(self.indexes)

    def __fit_depth(self, image):
        """Zero-pad (at the end) or crop axis 2 of a 3D volume to ``dim[2]`` slices."""
        target = self.dim[2]
        depth = image.shape[2]
        if depth < target:
            image = np.pad(image, ((0, 0), (0, 0), (0, target - depth)), mode='constant')
        elif depth > target:
            image = image[:, :, :target]
        return image

    def __data_generation(self, list_IDs_temp):
        """Load and preprocess the given files into one batch array.

        A file that fails at any step is reported on stdout and left as an
        all-zero volume so that the batch keeps its shape.
        """
        rows, cols, depth = self.dim[0], self.dim[1], self.dim[2]
        # Zero-initialised so that a failed sample is a blank volume rather than
        # uninitialised memory.
        X = np.zeros((len(list_IDs_temp), rows, cols, depth, 1))

        for i, ID in enumerate(list_IDs_temp):
            path = ID
            try:
                nii = nib.load(path)
                image = nii.get_fdata()

                # Reorient files whose path contains one of these folder names
                if "Healthy_sorted" in path or "MCI_sorted" in path or "AD_sorted" in path:
                    image = reorient_sagittal_to_axial(image)
                    image = np.rot90(image, k=2, axes=(0, 1))  # Rotate 180 degrees in the plane of axes 0 and 1
                    image = image[:, :, ::-1]  # Reverse the slice order

                # Fit the slice count first, then resize every slice in-plane
                image = self.__fit_depth(image)
                # cv2.resize takes dsize as (width, height) == (cols, rows)
                resized_slices = [cv2.resize(image[:, :, slice_idx], (cols, rows)) for slice_idx in range(image.shape[2])]
                image = np.stack(resized_slices, axis=-1)
                image = preprocess_image(image, brightness_correct=self.brightness_correct, clipLimit=self.clipLimit, tileGridSize=self.tileGridSize)

                if self.augment:
                    image = self.__augment(image)
                X[i, ..., 0] = image  # Adding the channel dimension

            except Exception as e:
                # Best effort: report the failure and keep a blank volume for this sample.
                print(f"Error processing file {path}: {e}")
                X[i] = 0.0

        return X

    def __augment(self, image):
        """Flip the volume along axis 1 and/or axis 0, each with probability 0.5."""
        if np.random.rand() < 0.5:
            image = np.fliplr(image)
        if np.random.rand() < 0.5:
            image = np.flipud(image)
        return image

In [None]:
# Load NIfTI Files from Folder
def load_nii_files_from_folder(folder, extensions=(".nii",)):
    """List the NIfTI files directly inside ``folder``.

    Args:
        folder: Directory to scan (sub-directories are not searched).
        extensions: File-name suffix or tuple of suffixes to accept.
            Defaults to uncompressed ``.nii`` files only.

    Returns:
        Paths (``folder`` joined with the file name) of the matching regular
        files, sorted by file name. ``os.listdir`` returns entries in arbitrary
        order, so sorting keeps any seeded split of the result reproducible.

    Raises:
        OSError: If ``folder`` cannot be listed (e.g. it does not exist).
    """
    suffixes = (extensions,) if isinstance(extensions, str) else tuple(extensions)
    nii_files = []
    for filename in sorted(os.listdir(folder)):
        path = os.path.join(folder, filename)
        # Skip directories whose name happens to end with an accepted suffix.
        if filename.endswith(suffixes) and os.path.isfile(path):
            nii_files.append(path)
    return nii_files

# Dataset locations for the evaluation cohorts
healthy_path = 'CamCAN/CamCAN_subset'
mci_path = 'ADNI/MCI_sorted'
ad_path = 'ADNI/AD_sorted'

# Collect the evaluation cohort files
healthy_files, mci_files, ad_files = (
    load_nii_files_from_folder(cohort_path)
    for cohort_path in (healthy_path, mci_path, ad_path)
)

# Training data: 80% train, then the remaining 20% split evenly into validation and test
folder_path = 'Human_Connectome'
nii_files = load_nii_files_from_folder(folder_path)
train_files, remaining_files = train_test_split(nii_files, test_size=0.2, random_state=42)
val_files, test_files = train_test_split(remaining_files, test_size=0.5, random_state=42)

# Separate test files in Human_Connectome/vali_test, used for evaluation and visualization
eval_test_files_path = 'Human_Connectome/vali_test'
eval_test_files = load_nii_files_from_folder(eval_test_files_path)

# Data generators: every generator produces volumes of the same shape
volume_dim = (160, 160, 192)
train_gen = MRI_DataGenerator(list_IDs=train_files, batch_size=4, dim=volume_dim, shuffle=True, augment=True)

# Validation/test generators use batches of two, in fixed order
split_gen_kwargs = dict(batch_size=2, dim=volume_dim, shuffle=False)
val_gen = MRI_DataGenerator(list_IDs=val_files, **split_gen_kwargs)
test_gen = MRI_DataGenerator(list_IDs=test_files, **split_gen_kwargs)

# Evaluation generators serve one volume at a time, in fixed order
eval_gen_kwargs = dict(batch_size=1, dim=volume_dim, shuffle=False)
eval_test_gen = MRI_DataGenerator(list_IDs=eval_test_files, **eval_gen_kwargs)
healthy_gen = MRI_DataGenerator(list_IDs=healthy_files, **eval_gen_kwargs)
mci_gen = MRI_DataGenerator(list_IDs=mci_files, **eval_gen_kwargs)
ad_gen = MRI_DataGenerator(list_IDs=ad_files, **eval_gen_kwargs)

In [None]:
# Define the UNet model
def unet_model(input_size=(160, 160, 192, 1)):
    """Build a 3D U-Net with four pooling levels and a single-channel sigmoid output.

    The encoder applies one 3x3x3 ReLU convolution per level (64, 128, 256 and
    512 filters), each followed by 2x2x2 max pooling; the bottleneck uses 1024
    filters. The decoder mirrors the encoder: a stride-2 transposed convolution,
    concatenation with the matching encoder output, then a 3x3x3 ReLU convolution.

    Args:
        input_size: Shape of one input sample, without the batch dimension.

    Returns:
        An uncompiled ``tf.keras`` ``Model``.
    """
    layers = tf.keras.layers

    def conv_block(tensor, filters):
        # Single 3x3x3 convolution that keeps the spatial size
        return layers.Conv3D(filters, (3, 3, 3), activation='relu', padding='same')(tensor)

    inputs = layers.Input(input_size)

    # Encoder: remember each level's output for the skip connections
    skip_connections = []
    x = inputs
    for filters in (64, 128, 256, 512):
        x = conv_block(x, filters)
        skip_connections.append(x)
        x = layers.MaxPooling3D((2, 2, 2))(x)

    # Bottleneck
    x = conv_block(x, 1024)

    # Decoder: upsample, merge with the matching encoder output, convolve
    for filters, skip in zip((512, 256, 128, 64), reversed(skip_connections)):
        x = layers.Conv3DTranspose(filters, (2, 2, 2), strides=(2, 2, 2), padding='same')(x)
        x = layers.concatenate([x, skip], axis=-1)
        x = conv_block(x, filters)

    outputs = layers.Conv3D(1, (1, 1, 1), activation='sigmoid')(x)
    return tf.keras.models.Model(inputs=[inputs], outputs=[outputs])

# Build the U-Net and configure it for reconstruction training (MSE loss, Adam)
model = unet_model()
model.compile(
    loss='mean_squared_error',
    optimizer=Adam(learning_rate=0.0001),
    metrics=['mse'],
)


In [None]:
# Load pre-trained model (if available); on failure keep the model already in memory
try:
    model = load_model('Saved_models/Trained_model.h5', compile=False)
    model.compile(optimizer=Adam(learning_rate=0.0001), loss='mean_squared_error', metrics=['mse'])
except Exception as e:
    print(f"Error loading pre-trained model: {e}")

# Training setup
checkpoint_path = 'best_model.h5'                # written by ModelCheckpoint during training
saved_best_path = 'Saved_models/best_model.h5'   # final location of the best model
early_stopping = EarlyStopping(monitor='val_loss', patience=10, restore_best_weights=True)
model_checkpoint = ModelCheckpoint(checkpoint_path, save_best_only=True, monitor='val_loss')
reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=5)

# Train the model
history = None  # stays None when training fails, so later code can tell
try:
    history = model.fit(
        train_gen,
        validation_data=val_gen,
        epochs=5,
        callbacks=[early_stopping, model_checkpoint, reduce_lr],
        verbose=2
    )
except Exception as e:
    print(f"Error during model training: {e}")

# Keep the best model after training
if history is None:
    # Do not overwrite a previously saved best model with one whose training failed.
    print(f"Training did not complete; keeping the current model and not writing {saved_best_path}")
else:
    # With fewer epochs than the early-stopping patience, the in-memory weights may be
    # those of the last epoch, so take the best epoch from the checkpoint file. The file
    # is only trusted when this run actually produced a val_loss.
    if 'val_loss' in history.history and os.path.exists(checkpoint_path):
        model = load_model(checkpoint_path, compile=False)
        model.compile(optimizer=Adam(learning_rate=0.0001), loss='mean_squared_error', metrics=['mse'])
    os.makedirs(os.path.dirname(saved_best_path), exist_ok=True)
    model.save(saved_best_path)

In [None]:
# Use the stored trained model for everything below; this replaces whatever
# model is currently in memory.
model = load_model('Saved_models/Trained_model.h5', compile=False)
model.compile(
    loss='mean_squared_error',
    optimizer=Adam(learning_rate=0.0001),
    metrics=['mse'],
)

In [None]:
# Visualization of original and reconstructed images
def visualize_reconstruction(generator, model, num_samples=2):
    """Plot the middle slice of originals (top row) and reconstructions (bottom row).

    Args:
        generator: Indexable batch source; only its first batch is used.
        model: Object with a ``predict`` method that reconstructs a batch.
        num_samples: Maximum number of volumes to show; capped at the size of
            the first batch.
    """
    batch, _ = generator[0]
    shown = min(num_samples, batch.shape[0])
    originals = batch[:shown]

    # Reconstruct one volume per predict call
    reconstructions = [
        model.predict(np.expand_dims(originals[idx], axis=0))[0]
        for idx in range(shown)
    ]

    plt.figure(figsize=(20, 4))
    for idx in range(shown):
        mid = originals[idx].shape[2] // 2
        panels = (
            (idx + 1, originals[idx], f"Original {idx+1}"),
            (idx + 1 + shown, reconstructions[idx], f"Reconstructed {idx+1}"),
        )
        for position, volume, title in panels:
            plt.subplot(2, shown, position)
            plt.imshow(volume[:, :, mid, 0], cmap='gray')
            plt.title(title)
            plt.axis('off')

    plt.show()

# Visualize reconstructions for the first validation batch
print("Visualizing original and reconstructed images:")
visualize_reconstruction(generator=val_gen, model=model, num_samples=2)

In [None]:
def calculate_reconstruction_error(model, images):
    """Return the per-sample mean squared reconstruction error.

    Args:
        model: Object with a ``predict`` method.
        images: 5D batch; the error is averaged over axes 1 to 4.

    Returns:
        1D array with one error value per entry along axis 0.
    """
    residual = images - model.predict(images)
    per_sample_axes = (1, 2, 3, 4)
    return np.mean(np.square(residual), axis=per_sample_axes)

# Evaluate and visualize slices
def _plot_slice_comparison(subset_name, sample_image, reconstructed_image, slice_idx):
    """Show original, reconstruction, absolute error and error overlay for one slice."""
    original_slice = sample_image[:, :, slice_idx]
    reconstructed_slice = reconstructed_image[:, :, slice_idx]
    error_image = np.abs(original_slice - reconstructed_slice)

    plt.figure(figsize=(20, 5))

    plt.subplot(1, 4, 1)
    plt.title(f"{subset_name} Original - Slice {slice_idx}")
    plt.imshow(original_slice, cmap='gray')
    plt.axis('off')

    plt.subplot(1, 4, 2)
    plt.title(f"{subset_name} Reconstructed - Slice {slice_idx}")
    plt.imshow(reconstructed_slice, cmap='gray')
    plt.axis('off')

    plt.subplot(1, 4, 3)
    plt.title(f"{subset_name} Error (Original - Reconstructed) - Slice {slice_idx}")
    plt.imshow(error_image, cmap='jet')
    plt.axis('off')

    plt.subplot(1, 4, 4)
    plt.title(f"{subset_name} Original with Error Heatmap - Slice {slice_idx}")
    plt.imshow(original_slice, cmap='gray')
    plt.imshow(error_image, cmap='jet', alpha=0.6)
    plt.axis('off')

    plt.show()


def evaluate_and_visualize_slices(generator, subset_name, model, slice_indices=(30, 96, 162)):
    """Compute per-volume MSE and mean per-slice SSIM, plotting the first volume.

    Args:
        generator: Indexable batch source with ``len()``; each item is a pair
            whose first element is a 5D batch with a trailing channel axis.
        subset_name: Label used in progress messages and plot titles.
        model: Object with a ``predict`` method that reconstructs a batch.
        slice_indices: Slice indices (along the last volume axis) to plot for the
            first volume only. Indices outside the volume are skipped with a message.

    Returns:
        ``(mse_list, ssim_list)`` with one entry per evaluated volume. Both are
        empty when the generator yields no batches.
    """
    mse_list = []
    ssim_list = []
    visualization_done = False  # Flag to ensure visualization happens only once
    num_batches = len(generator)

    for i in range(num_batches):
        print(f"Processing {subset_name} batch {i+1}/{num_batches}...")

        # Get a batch of images
        batch_images, _ = generator[i]

        for j in range(batch_images.shape[0]):
            sample_image = batch_images[j, ..., 0]  # Drop the channel axis

            # Predict the reconstructed image (re-adding batch and channel axes)
            reconstructed_image = model.predict(sample_image[np.newaxis, ..., np.newaxis])
            reconstructed_image = reconstructed_image[0, :, :, :, 0]

            mse = np.mean(np.square(sample_image - reconstructed_image))

            # A constant volume (e.g. one zero-filled after a failed load) has no
            # dynamic range; a zero data_range would make SSIM undefined (NaN).
            data_range = sample_image.max() - sample_image.min()
            if data_range == 0:
                data_range = 1.0
            ssim_value = np.mean([
                ssim(sample_image[:, :, slice_idx], reconstructed_image[:, :, slice_idx], data_range=data_range)
                for slice_idx in range(sample_image.shape[-1])
            ])

            mse_list.append(mse)
            ssim_list.append(ssim_value)

            # Visualize selected slices only once
            if not visualization_done:
                depth = sample_image.shape[-1]
                for slice_idx in slice_indices:
                    if not -depth <= slice_idx < depth:
                        print(f"Skipping slice {slice_idx}: volume has only {depth} slices")
                        continue
                    _plot_slice_comparison(subset_name, sample_image, reconstructed_image, slice_idx)

                visualization_done = True  # Set the flag to True to prevent further visualization

    # Averaging an empty list would emit a runtime warning; report NaN directly instead.
    average_mse = np.mean(mse_list) if mse_list else float('nan')
    average_ssim = np.mean(ssim_list) if ssim_list else float('nan')
    print(f"\nAverage MSE for {subset_name}: {average_mse}")
    print(f"Average SSIM for {subset_name}: {average_ssim}")
    return mse_list, ssim_list

In [None]:
# Evaluate reconstruction error and SSIM for the held-out test files and for the
# Healthy, MCI and AD cohorts, in that order
evaluation_targets = [
    ("Test", eval_test_gen),
    ("Healthy", healthy_gen),
    ("MCI", mci_gen),
    ("AD", ad_gen),
]

evaluation_results = {}
for subset_label, subset_gen in evaluation_targets:
    print(f"\nEvaluating {subset_label} Data:")
    evaluation_results[subset_label] = evaluate_and_visualize_slices(subset_gen, subset_label, model)

# Keep the per-cohort metric lists available under their individual names
test_mse, test_ssim = evaluation_results["Test"]
healthy_mse, healthy_ssim = evaluation_results["Healthy"]
mci_mse, mci_ssim = evaluation_results["MCI"]
ad_mse, ad_ssim = evaluation_results["AD"]

# Summarize the average MSE and SSIM per cohort
print("\nResults Summary:")
for subset_label, (mse_values, ssim_values) in evaluation_results.items():
    print(f"{subset_label} Data - Average MSE: {np.mean(mse_values)}, Average SSIM: {np.mean(ssim_values)}")

In [None]:
def multi_image_interactive_slice_viewer(images_dict, alpha=0.6, cmap='jet'):
    """Display an interactive slice-by-slice comparison for several volumes.

    One row is drawn per dictionary entry with four panels: original slice,
    reconstructed slice, min-max normalized absolute error, and the original
    with the error optionally overlaid as a heatmap.

    Args:
        images_dict: Mapping ``label -> (original_image, reconstructed_image)``;
            both images are indexed as ``image[:, :, slice_index]``.
        alpha: Opacity of the heatmap overlay.
        cmap: Colormap used for the error panels.

    Raises:
        ValueError: If ``images_dict`` is empty.
    """
    if not images_dict:
        raise ValueError("images_dict must contain at least one (original, reconstructed) pair")

    # The slider only covers slices present in every volume, so no position can
    # index past the end of a shallower image.
    num_slices = min(
        min(original.shape[2], reconstructed.shape[2])
        for original, reconstructed in images_dict.values()
    )

    # Slider to choose the slice index
    slice_slider = widgets.IntSlider(min=0, max=num_slices-1, step=1, description='Slice:')

    # Checkbox to toggle the heatmap
    heatmap_checkbox = widgets.Checkbox(value=True, description='Show Heatmap', disabled=False)

    def update_visualization(slice_index, show_heatmap):
        plt.figure(figsize=(20, len(images_dict) * 5))

        for i, (label, (original_image, reconstructed_image)) in enumerate(images_dict.items()):
            original_slice = original_image[:, :, slice_index]
            reconstructed_slice = reconstructed_image[:, :, slice_index]
            # Compute the error for the displayed slice only, not for the whole
            # volume on every slider move.
            error_slice = np.abs(original_slice - reconstructed_slice)

            error_slice_normalized = cv2.normalize(error_slice, None, alpha=0, beta=1, norm_type=cv2.NORM_MINMAX)

            plt.subplot(len(images_dict), 4, i * 4 + 1)
            plt.title(f"{label} - Original Slice")
            plt.imshow(original_slice, cmap='gray')
            plt.axis('off')

            plt.subplot(len(images_dict), 4, i * 4 + 2)
            plt.title(f"{label} - Reconstructed Slice")
            plt.imshow(reconstructed_slice, cmap='gray')
            plt.axis('off')

            plt.subplot(len(images_dict), 4, i * 4 + 3)
            plt.title(f"{label} - Error (Original - Reconstructed)")
            plt.imshow(error_slice_normalized, cmap=cmap)
            plt.axis('off')

            plt.subplot(len(images_dict), 4, i * 4 + 4)
            plt.title(f"{label} - Original with Error Heatmap")
            plt.imshow(original_slice, cmap='gray')
            if show_heatmap:
                plt.imshow(error_slice_normalized, cmap=cmap, alpha=alpha)
            plt.axis('off')

        plt.show()

    interactive_plot = widgets.interactive(update_visualization, slice_index=slice_slider, show_heatmap=heatmap_checkbox)
    display(interactive_plot)

In [None]:
def _sample_and_reconstruct(generator, batch_index=10):
    """Return (batch, first volume without channel axis, its reconstruction)."""
    batch, _ = generator[batch_index]
    volume = batch[0, ..., 0]
    prediction = model.predict(np.expand_dims(volume, axis=(0, -1)))
    return batch, volume, prediction[0, ..., 0]


# Take the first volume of batch 10 from each generator and reconstruct it
sample_images_test, sample_image_test, reconstructed_image_test = _sample_and_reconstruct(eval_test_gen)
sample_images_healthy, sample_image_healthy, reconstructed_image_healthy = _sample_and_reconstruct(healthy_gen)
sample_images_mci, sample_image_mci, reconstructed_image_mci = _sample_and_reconstruct(mci_gen)
sample_images_ad, sample_image_ad, reconstructed_image_ad = _sample_and_reconstruct(ad_gen)

# One (original, reconstruction) pair per cohort, in display order
images_dict = dict([
    ('Test', (sample_image_test, reconstructed_image_test)),
    ('Healthy', (sample_image_healthy, reconstructed_image_healthy)),
    ('MCI', (sample_image_mci, reconstructed_image_mci)),
    ('AD', (sample_image_ad, reconstructed_image_ad)),
])

multi_image_interactive_slice_viewer(images_dict)