In [1]:
import matplotlib.pyplot as plt
import os
import numpy as np
import matplotlib
import random


import nilearn as nl
import nilearn.plotting as nlplt
import nibabel as nib
from nilearn import image
from nilearn import plotting
from nilearn import datasets
from nilearn import surface
import h5py




In [2]:
def save_all_slices(nii_gz_file_path, output_directory, filename):
    """
    Save every slice along the first axis of a .nii.gz file as a PNG image.

    Iterating the loaded array yields one entry per index of its first axis.
    When such an entry is itself 3D (i.e. the file holds 4D data), each of its
    2D sub-slices is written as a separate image.

    Parameters:
    - nii_gz_file_path: The path to the .nii.gz file.
    - output_directory: The directory where the slice images will be saved.
      It is created if it does not exist yet.
    - filename: Identifier of the source scan. It is embedded in every image
      name so that slices from different scans do not overwrite each other.
    """
    img = nib.load(nii_gz_file_path)
    data = img.get_fdata()

    # savefig does not create missing directories.
    if output_directory:
        os.makedirs(output_directory, exist_ok=True)

    def _write_png(pixels, png_name):
        # Render one 2D array as a borderless grayscale image and release the figure.
        plt.imshow(pixels, cmap="gray", origin="lower")
        plt.axis('off')
        plt.savefig(os.path.join(output_directory, png_name), bbox_inches='tight', pad_inches=0)
        plt.close()

    for i, first_axis_slice in enumerate(data):
        if first_axis_slice.ndim == 3:
            for j, sub_slice in enumerate(first_axis_slice):
                # The scan identifier is part of the name here too; without it
                # every scan processed into the same directory wrote the same files.
                _write_png(sub_slice, f"slice_{i:03d}_{j:03d}_{filename}.png")
        else:
            _write_png(first_axis_slice, f"slice_{i:03d}_{filename}.png")

In [3]:
def save_axial_slices(nii_gz_file_path, output_directory, fillename, num_slices=20):
    """
    Save a window of third-axis (axial) slices around the middle of a .nii.gz file.

    Each selected slice `data[:, :, k]` is transposed before plotting. When the
    transposed slice is 3D (the file holds 4D data) every entry along its first
    axis is written as a separate image.

    Parameters:
    - nii_gz_file_path: The path to the .nii.gz file.
    - output_directory: The directory where the slice images will be saved.
      It is created if it does not exist yet.
    - fillename: Identifier of the source scan, embedded in every image name so
      that slices from different scans do not overwrite each other.
    - num_slices: Number of slices to save around the middle axial slice.
    """
    img = nib.load(nii_gz_file_path)
    data = img.get_fdata()

    depth = data.shape[2]
    axial_middle = depth // 2
    # Clamp the window to the valid index range: an unclamped negative start
    # would silently wrap around to the last slices, and an end past the depth
    # would raise IndexError.
    first = max(axial_middle - num_slices // 2, 0)
    last = min(axial_middle + num_slices // 2, depth)
    slices = range(first, last)

    # savefig does not create missing directories.
    if output_directory:
        os.makedirs(output_directory, exist_ok=True)

    def _write_png(pixels, png_name):
        # Render one 2D array as a borderless grayscale image and release the figure.
        plt.imshow(pixels, cmap="gray", origin="lower")
        plt.axis('off')
        plt.savefig(os.path.join(output_directory, png_name), bbox_inches='tight', pad_inches=0)
        plt.close()

    # `i` is the position inside the window, not the absolute slice index.
    for i, slice_index in enumerate(slices):
        slice_data = data[:, :, slice_index].T
        if slice_data.ndim == 3:
            for j, sub_slice in enumerate(slice_data):
                _write_png(sub_slice, f"slice_{i:03d}_{j:03d}_{fillename}.png")
        else:
            # The scan identifier is part of the name here too; without it
            # every scan processed into the same directory wrote the same files.
            _write_png(slice_data, f"slice_{i:03d}_{fillename}.png")

In [4]:
def bulk_convert_nii_gz_to_images(input_directory, output_base_directory, num_slices=20):
    """
    Convert all .nii.gz files under the input directory into image slices and save them.

    The input directory is walked recursively. For every file ending in
    ".nii.gz", all slices are written below "<output_base_directory>/All" and
    the central axial window is written below "<output_base_directory>/Part".

    Parameters:
    - input_directory: Directory containing .nii.gz files.
    - output_base_directory: Base directory where output images will be saved.
    - num_slices: Number of slices to save around the middle axial slice.
    """
    part_directory = os.path.join(output_base_directory, "Part")
    all_directory = os.path.join(output_base_directory, "All")

    for root, _dirs, files in os.walk(input_directory):
        for file in files:
            if not file.endswith(".nii.gz"):
                continue

            nii_gz_file_path = os.path.join(root, file)
            # Strip both extensions (".gz", then ".nii") to get the scan name.
            scan_name = os.path.splitext(os.path.splitext(file)[0])[0]

            save_all_slices(nii_gz_file_path, all_directory, scan_name)
            # Honour the caller's num_slices; a hard-coded 60 used to be passed
            # here, which made the parameter a no-op.
            save_axial_slices(nii_gz_file_path, part_directory, scan_name, num_slices)

In [5]:
# Example configuration: locations read by the cells below.
# NOTE(review): both values are absolute, machine-specific paths — change them before running elsewhere.
input_directory = "/Users/izzymohamed/Desktop/MLPData/PET/OASIS3/Original"  # Directory containing the source .nii.gz files
output_base_directory = "/Users/izzymohamed/Desktop/MLPData/PET/OASIS3/Preprocessed/Part1"  # Directory where the image slices are written


In [6]:
# bulk_convert_nii_gz_to_images(input_directory, output_base_directory)


In [None]:
def analyze_time_points_greater_than_threshold(img_data, variance_threshold=100000):
    """
    Report the variance of every time point and collect those above a threshold.

    The last axis of `img_data` is treated as time. For each index along it the
    variance of the whole sub-array is computed and printed together with the
    outcome of the threshold comparison.

    Parameters:
    - img_data: 4D numpy array of the PET scan data.
    - variance_threshold: A time point is kept only if its variance is strictly greater than this value.

    Returns:
    - List of the time point indices whose variance exceeds the threshold, in increasing order.
    """
    selected = []
    frame_count = img_data.shape[-1]

    for time_point in range(frame_count):
        time_point_variance = np.var(img_data[..., time_point])
        exceeds = float(time_point_variance) > float(variance_threshold)
        print(f"Time point {time_point} variance: {time_point_variance} if: {exceeds}")

        if not time_point_variance > variance_threshold:
            continue
        selected.append(time_point)

    return selected

In [None]:
def find_highest_variance_time_points(img_data, num_points=5):
    """
    Identify time points with the highest variance in 4D PET scan data.

    The last axis of `img_data` is treated as time; the variance of each time
    point is computed over all remaining axes.

    Parameters:
    - img_data: 4D numpy array of the PET scan data.
    - num_points: Number of time points to select based on highest variance.
      Values <= 0 select nothing; values larger than the number of time points
      select all of them.

    Returns:
    - top_variance_time_points: List of plain-int indices of the selected time
      points, sorted in temporal (ascending) order.
    """
    if num_points <= 0:
        # Slicing with [-0:] would return *every* index instead of none.
        return []

    variance_list = [
        np.var(img_data[..., time_point])
        for time_point in range(img_data.shape[-1])
    ]

    # Indices of the `num_points` largest variances (argsort is ascending).
    top_variance_indices = np.argsort(variance_list)[-num_points:]

    # Restore temporal order and convert numpy integers to plain ints so the
    # result compares and formats like an ordinary list of indices.
    top_variance_time_points = sorted(int(idx) for idx in top_variance_indices)

    # Report the variance of each selected time point.
    for idx in top_variance_time_points:
        print(f"Time point {idx} has variance: {variance_list[idx]}")

    return top_variance_time_points

In [7]:
def display_axial_slices(scan_data, start_slice, end_slice):
    """
    Display axial slices from a 4D PET scan within the specified range.

    One figure is shown per time point (last axis of `scan_data`); each figure
    holds a near-square grid with one subplot per slice index in
    [start_slice, end_slice].

    Parameters:
    - scan_data: 4D numpy array with shape (X, Y, Z, time).
    - start_slice: Starting index of the axial slice range to be displayed.
    - end_slice: Ending index (inclusive) of the axial slice range to be displayed.

    Raises:
    - ValueError: If `scan_data` is not 4D, or if the slice range is empty.
    """
    if scan_data.ndim != 4:
        raise ValueError("Scan data should be a 4D numpy array.")
    if end_slice < start_slice:
        # An empty range used to surface as a ZeroDivisionError in the grid maths.
        raise ValueError("end_slice must be greater than or equal to start_slice.")

    # The grid layout depends only on the slice range, so compute it once.
    slice_indices = range(start_slice, end_slice + 1)
    num_slices = len(slice_indices)
    cols = int(np.ceil(np.sqrt(num_slices)))
    rows = int(np.ceil(num_slices / cols))

    for time_point in range(scan_data.shape[-1]):
        volume = scan_data[:, :, :, time_point]

        fig = plt.figure(figsize=(20, 15))
        for i, slice_index in enumerate(slice_indices, start=1):
            ax = fig.add_subplot(rows, cols, i)
            # Transpose so the first array axis runs horizontally in the image.
            ax.imshow(volume[:, :, slice_index].T, cmap='hot', origin='lower')
            ax.axis('off')
            ax.set_title(f'Slice {slice_index}', fontsize=10)

        fig.suptitle(f'Axial Slices {start_slice} to {end_slice} at Time Point {time_point}', fontsize=16)
        fig.tight_layout(rect=[0, 0.03, 1, 0.95])
        plt.show()
        # Release the figure; otherwise one figure per time point stays alive.
        plt.close(fig)

In [8]:
def extract_all_slices_optimized(img_data):
    """
    Return views of the scan arranged for Sagittal, Coronal, and Axial slicing.

    No data is copied: 'Sagittal' is `img_data` with its third axis moved to
    the front, 'Coronal' is `img_data` with its second axis moved to the front,
    and 'Axial' is `img_data` itself.

    NOTE(review): a later cell defines another function with this same name;
    once both cells have run, the later definition is the one in effect.

    Parameters:
    - img_data: 3D (or 4D) numpy array of the PET scan data.

    Returns:
    A dictionary with the keys 'Sagittal', 'Coronal' and 'Axial'.

    Raises:
    - ValueError: If `img_data` is neither 3D nor 4D.
    """
    # Explicit rank check; previously this was only enforced as a side effect
    # of unpacking an argsort result whose values were never used.
    if img_data.ndim not in (3, 4):
        raise ValueError("img_data must be a 3D or 4D numpy array.")

    return {
        'Sagittal': np.moveaxis(img_data, 2, 0),  # third axis becomes the first
        'Coronal': np.moveaxis(img_data, 1, 0),   # second axis becomes the first
        'Axial': img_data,                        # returned unchanged
    }

In [9]:
def extract_all_slices_optimized(img_data):
    """
    Extract Sagittal, Coronal, and Axial views from 3D or 4D PET scan data.

    The three spatial axes are ranked by length: the shortest is labelled
    Sagittal, the next Coronal and the longest Axial. Each view is the volume
    with the corresponding axis swapped into first position.

    For 4D data the last axis is treated as time, and views are produced only
    for the time point(s) chosen by `find_highest_variance_time_points`
    (currently the single highest-variance one).

    Parameters:
    - img_data: 3D or 4D numpy array of the PET scan data.

    Returns:
    A dictionary of views. Keys are 'Sagittal', 'Coronal', 'Axial' for 3D data,
    and 'Sagittal_time_point_<t>', 'Coronal_time_point_<t>',
    'Axial_time_point<t>' for 4D data.

    Raises:
    - ValueError: If `img_data` is neither 3D nor 4D (this used to return an
      empty dictionary silently).
    """
    if img_data.ndim not in (3, 4):
        raise ValueError("img_data must be a 3D or 4D numpy array.")

    # Rank the spatial axes by length (the time axis, if any, is excluded).
    spatial_dimensions = img_data.shape[:3]
    sagittal_index, coronal_index, axial_index = np.argsort(spatial_dimensions)

    if img_data.ndim == 3:
        # A 3D volume has no time axis, so no time-point selection is run.
        return {
            'Sagittal': img_data.swapaxes(0, sagittal_index),
            'Coronal': img_data.swapaxes(0, coronal_index),
            'Axial': img_data.swapaxes(0, axial_index),
        }

    orientations = {}

    useful_time_points = find_highest_variance_time_points(img_data, 1)
    print(f"Useful time points2: {useful_time_points}")

    for time_point in useful_time_points:
        volume = img_data[:, :, :, time_point]
        orientations[f'Sagittal_time_point_{time_point}'] = volume.swapaxes(0, sagittal_index)
        orientations[f'Coronal_time_point_{time_point}'] = volume.swapaxes(0, coronal_index)
        # NOTE(review): unlike the two keys above there is no underscore before
        # the index; left as-is because it ends up in saved file names.
        orientations[f'Axial_time_point{time_point}'] = volume.swapaxes(0, axial_index)

    return orientations

In [11]:
def get_slice_range(scan, tracer):
    """
    Determine the range of slices to analyze based on the tracer used.

    Parameters:
    - scan: Either the number of slices, or an object exposing a `shape` with
      at least three entries (e.g. a numpy array), in which case `shape[2]` is
      used as the number of slices.
    - tracer: Tracer name, case-insensitive: 'AV45', 'PIB' or 'FDG'.

    Returns:
    - (start_slice, end_slice): 40%-70% of the slice count for AV45/PIB and
      30%-80% for FDG, each truncated to an int.

    Raises:
    - ValueError: If the tracer is not one of the known names.
    """
    # Callers pass both plain counts and whole scans; accept either. Scalars
    # (including numpy scalars, whose shape is the empty tuple) are counts.
    shape = getattr(scan, "shape", None)
    if shape is not None and len(shape) >= 3:
        num_slices = shape[2]
    else:
        num_slices = scan

    print(f"Number of slices: {num_slices}")

    tracer_key = tracer.lower()
    if tracer_key in ('av45', 'pib'):
        # Middle to upper slices for cortical amyloid plaques
        lower_fraction, upper_fraction = 0.4, 0.7
    elif tracer_key == 'fdg':
        # Broad range for hypometabolism in Alzheimer's
        lower_fraction, upper_fraction = 0.3, 0.8
    else:
        raise ValueError("Unknown tracer. Please use 'AV45', 'PIB', or 'FDG'.")

    start_slice = int(num_slices * lower_fraction)
    end_slice = int(num_slices * upper_fraction)
    return start_slice, end_slice

In [12]:
def extract_and_select_slices(img_data, tracer):
    """
    Build the per-orientation views of a scan and keep a tracer-dependent window of the axial ones.

    Every entry returned by `extract_all_slices_optimized` is logged. For each
    entry whose key contains 'Axial', the third axis is cut down to the
    inclusive [start, end] window supplied by `get_slice_range`.

    Returns a dictionary mapping each axial key to its windowed array.
    """
    all_views = extract_all_slices_optimized(img_data)

    windowed_by_key = {}

    for key, view in all_views.items():
        print(key, view.shape)

        # Only the axial orientation is of interest here.
        if 'Axial' not in key:
            continue

        print(view.shape)

        if img_data.ndim == 4:
            depth = view.shape[-1]
        else:
            depth = view.shape[2]

        first, last = get_slice_range(depth, tracer)

        # The end index is inclusive, hence the +1.
        windowed_by_key[key] = view[:, :, first:last + 1]

    return windowed_by_key

In [16]:
def display_selected_slices(selected_axial_slices, filename, key, output_base_directory):
    """
    Render each selected axial slice to its own PNG file.

    Despite the name, nothing is shown on screen: one image per index of the
    third axis is written to `output_base_directory` and the figure is closed.

    Parameters:
    - selected_axial_slices: Array whose third axis indexes the slices.
    - filename: Name of the source scan; embedded in each image name.
    - key: Orientation label; used in the plot title and the image name.
    - output_base_directory: Directory the images are written to. It is
      created if it does not exist yet.
    """
    print(selected_axial_slices.shape)
    num_slices = selected_axial_slices.shape[2]

    # savefig does not create missing directories.
    if output_base_directory:
        os.makedirs(output_base_directory, exist_ok=True)

    for i in range(num_slices):
        fig, ax = plt.subplots(figsize=(20, 15))
        # Transpose so the first array axis runs horizontally in the image.
        ax.imshow(selected_axial_slices[:, :, i].T, cmap='hot', origin='lower')
        ax.axis('off')
        ax.set_title(f'{key} Slice {i}', fontsize=16)
        fig.tight_layout()
        fig.savefig(f"{output_base_directory}/Slice_{i}{filename}_{key}_.png")
        # Close explicitly so figures do not accumulate across the loop.
        plt.close(fig)

In [None]:
def return_tracer(file_path):
    """
    Infer the PET tracer from the file name.

    Only the final path component is inspected, case-insensitively. The
    tracers are tried in the order AV45, PIB, FDG and the first one whose name
    occurs in the file name is returned.

    Raises ValueError when none of the tracer names occurs in the file name.
    """
    lowered_name = os.path.basename(file_path).lower()

    for tracer in ('AV45', 'PIB', 'FDG'):
        if tracer.lower() in lowered_name:
            return tracer

    raise ValueError("Unknown tracer. Please use 'AV45', 'PIB', or 'FDG'.")

In [14]:
# %matplotlib inline
# Select the non-interactive Agg backend: figures in this cell are only saved to disk.
matplotlib.use('Agg')

# For every .nii.gz file found below the walked directory: load the data, pick the
# tracer-dependent axial window, and write one PNG per selected slice.
# NOTE(review): this walks output_base_directory (not input_directory) for the scans
# and also writes the PNGs there — confirm that is intended.
# The loop variables (file, img_data, result, ...) remain defined at module level afterwards.
for root, dirs, files in os.walk(output_base_directory):
    for file in files:
        if file.endswith(".nii.gz"):
            nii_gz_file_path = os.path.join(root, file)
            img_data = nib.load(nii_gz_file_path).get_fdata()
            # The tracer is inferred from the file name.
            result = extract_and_select_slices(img_data, return_tracer(nii_gz_file_path))
            for key, value in result.items():
                print(key, value.shape)
                display_selected_slices(value, file, key, output_base_directory)
                

(256, 256, 127, 4)
(256, 256, 127, 4)
(256, 256, 127)
Sagittal_0 (127, 256, 256)
Coronal_0 (256, 256, 127)
Axial_0 (256, 256, 127)
(256, 256, 127)
Number of slices: 127
Sagittal_1 (127, 256, 256)
Coronal_1 (256, 256, 127)
Axial_1 (256, 256, 127)
(256, 256, 127)
Number of slices: 127
Sagittal_2 (127, 256, 256)
Coronal_2 (256, 256, 127)
Axial_2 (256, 256, 127)
(256, 256, 127)
Number of slices: 127
Sagittal_3 (127, 256, 256)
Coronal_3 (256, 256, 127)
Axial_3 (256, 256, 127)
(256, 256, 127)
Number of slices: 127


dict_keys(['Axial_0', 'Axial_1', 'Axial_2', 'Axial_3'])

In [None]:
# def save_single_slice(scan_data, start_slice, end_slice, filename, slices, output_directory):
#     """
#     Saves slices of MRI data as PNG images in specific orientation folders.

#     Parameters:
#     - file_name: Base name for the output files.
#     - slices: Dictionary of slices for each orientation.
#     - output_directory: The base directory to save the images.
#     - label: Label for the subdirectory structure.
#     """

#     # Ensure the save directory exists
#     os.makedirs(output_directory, exist_ok=True)

#     # Iterate over each orientation to save the slices
#     for orientation, slice_list in slices.items():
#         if orientation == 'Axial':
#             orientation_path = os.path.join(output_directory, 'Axial')
#             # Iterate through the specified range of slices and save each as an image
#             for time_point in range(slice_list[0].shape[-1]):
#                 volume = scan_data[:, :, :, time_point]
#                 for slice_index in range(start_slice, end_slice):
#                     if slice_index in slice_list:
#                         slice_img = slice_list[slice_index][:, :, time_point]
#                         slice_filename = f"Slice_{slice_index:03d}_{filename}_Axial_{time_point:03d}.png"
#                         save_axial_slices(slice_img, orientation_path, slice_filename)
                        
#                     else:
#                         print(f"Slice index {slice_index} not found in {orientation} slices. Length: {len(slice_list)}")


In [None]:
# def save_axial_slices(scan_data,filename, save_dir, volume, slice_index, time_point):
#     """
#     Save axial slices from a 4D PET scan within the specified range to a directory.
    
#     Parameters:
#     - scan_data: 4D numpy array with shape (X, Y, Z, time).
#     - start_slice: Starting index of the axial slice range to be saved.
#     - end_slice: Ending index of the axial slice range to be saved.
#     - save_dir: Directory where the slice images will be saved.
#     - time_point: Time point from which the slices should be taken.
#     """
#     if scan_data.ndim != 4:
#         raise ValueError("Scan data should be a 4D numpy array.")
    
    
#     # Generate the file path
#     img_filename = f"slice_{slice_index}_{filename}_time_{time_point}.png"
#     print(img_filename)
#     img_file_path = os.path.join(save_dir, img_filename)
#     print(img_file_path)
    
#     # Extract the specific slice
#     slice_img = volume[:, :, slice_index]
    
#     # Save the slice image
#     plt.figure(figsize=(10, 10))
#     plt.imshow(slice_img.T, cmap='hot', origin='lower')  # Transpose to correct orientation
#     plt.axis('off')
#     plt.title(f'Slice {slice_index}', fontsize=10)
#     plt.savefig(img_file_path)
#     plt.close()  # Close the figure to free memory

#         # print(f"Saved slices {start_slice} to {end_slice} at time point {time_point} to {save_dir}")


In [None]:
def load_pet_scan(file_path):
    """
    Open the PET scan at `file_path` with nibabel, requesting memory mapping.

    Returns the object produced by `nib.load`; no voxel data is read here explicitly.

    NOTE(review): memory mapping presumably has no effect on gzip-compressed
    (.nii.gz) files — confirm against the nibabel documentation.
    """
    scan_image = nib.load(file_path, mmap=True)
    return scan_image

In [None]:
# Create a function that walks every file in input_directory and returns, for each scan, the file path, the loaded image, the loaded data, and the tracer
# def get_pet_scans(input_directory):
#     """
#     Get all PET scans from the input directory.
#     """
#     pet_scans = []
    
#     for root, dirs, files in os.walk(input_directory):
#         for file in files:
#             if file.endswith(".nii.gz"):
#                 file_path = os.path.join(root, file)
#                 img, data = load_pet_scan(file_path)
#                 # tracer = return_tracer(file_path)
                
#                 pet_scans.append((file_path, img, data, "AV45"))
    
#     return pet_scans

In [None]:
def get_pet_scans_generator(input_directory):
    """
    A generator that yields PET scans from the input directory to reduce memory usage.

    The directory is walked recursively and only files ending in ".nii.gz" are
    considered. Scans are opened one at a time, as the generator is advanced.

    Yields:
    - (file, file_path, img, tracer, start_slice, end_slice), where `img` is the
      object returned by `load_pet_scan`, `tracer` is inferred from the file
      name, and the slice range comes from `get_slice_range`.

    Raises:
    - ValueError: Propagated from `return_tracer` when the tracer cannot be
      inferred from a file name.
    """
    for root, dirs, files in os.walk(input_directory):
        for file in files:
            if file.endswith(".nii.gz"):
                file_path = os.path.join(root, file)
                # Resolve the tracer first so an unrecognised file fails before it is opened.
                tracer = return_tracer(file_path)
                img = load_pet_scan(file_path)
                # get_slice_range works on a slice count; passing the image
                # object itself raised a TypeError in its arithmetic.
                start_slice, end_slice = get_slice_range(img.shape[2], tracer)
                yield file, file_path, img, tracer, start_slice, end_slice

In [None]:
# Create the scan generator. Nothing is opened until it is iterated, and it can be
# consumed only once — re-run this cell before iterating again.
pet_scans_generator = get_pet_scans_generator(input_directory)

In [None]:
# Count the .nii.gz files the generator will visit. This uses the same recursive
# walk and suffix filter as get_pet_scans_generator, so the progress total matches.
total_nii_files = sum(
    1
    for _root, _dirs, names in os.walk(input_directory)
    for name in names
    if name.endswith(".nii.gz")
)

count = 0

for file, file_path, img, tracer, start_slice, end_slice in pet_scans_generator:
    # print(f"Saving slices {start_slice} to {end_slice} from {file_path}")

    data = img.get_fdata()

    # Use the data of *this* scan (previously a stale `img_data` left over from
    # an earlier cell was passed here).
    orientations = extract_all_slices_optimized(data)

    # The axial key is 'Axial' for 3D scans and 'Axial_time_point<t>' for 4D
    # scans, so look it up by substring; None if no axial view was produced.
    axial_data = next((value for key, value in orientations.items() if 'Axial' in key), None)

    # get_slice_range works on a slice count, not on the array itself.
    start_slice, end_slice = get_slice_range(data.shape[2], tracer)

    # Display slices within the specified range for the first time point
    # display_axial_slices(img, start_slice, end_slice)

    # Save slices within the specified range for the first time point
    # save_axial_slices(data, start_slice, end_slice, file.split(".")[0], output_base_directory)

    count += 1

    # Display completed scans from input_directory
    print(f"Completed {count}/{total_nii_files} scans - {file}")
