In [1]:
import matplotlib.pyplot as plt
import os
import numpy as np
import matplotlib
import nibabel as nib

In [2]:
def find_highest_variance_time_points(img_data, num_points=5):
    """
    Identify time points with the highest variance in 4D PET scan data.

    Parameters:
    - img_data: 4D numpy array of the PET scan data.
    - num_points: Number of time points to select based on highest variance.

    Returns:
    - top_variance_time_points: Indices of time points with the highest variance.
    """
    variance_list = []
    for time_point in range(img_data.shape[-1]):
        time_point_data = img_data[..., time_point]
        variance_list.append(np.var(time_point_data))

    # Get indices of the top `num_points` variances
    top_variance_indices = np.argsort(variance_list)[-num_points:]
    # get lowest variance
    # top_variance_indices = np.argsort(variance_list)[:num_points]

    # Sort the indices to maintain the temporal order
    top_variance_time_points = sorted(top_variance_indices)
    
    # Optionally, print the variance values for the selected time points
    # for idx in top_variance_time_points:
    #     print(f"Time point {idx} has variance: {variance_list[idx]}")

    return top_variance_time_points

In [3]:
def guess_orientation_and_extract_all_slices(img_data):

    # Guess orientations based on dimensions
    dimensions = img_data.shape
    sorted_dims = np.argsort(dimensions)  # Ascending order of dimensions
    
    # Assuming the smallest dimension is Sagittal, next is Coronal, and largest is Axial
    if len(sorted_dims) == 3:
        sagittal_index, coronal_index, axial_index = sorted_dims
    else:
        sagittal_index, coronal_index, axial_index, fourth_index = sorted_dims
    
    # Initialize dictionaries to hold all slices for each orientation
    sagittal_slices = {}
    coronal_slices = {}
    axial_slices = {}
    
    # Extract all slices for each orientation
    for i in range(dimensions[sagittal_index]):
        sagittal_slices[i] = img_data.take(i, axis=sagittal_index)
    for i in range(dimensions[coronal_index]):
        coronal_slices[i] = img_data.take(i, axis=coronal_index)
    for i in range(dimensions[axial_index]):
        axial_slices[i] = img_data.take(i, axis=axial_index)
    
    return {'Sagittal': sagittal_slices, 'Coronal': coronal_slices, 'Axial': axial_slices}

In [4]:
def extract_all_slices_optimized(img_data):
    """
    Optimized function to extract all slices from a 3D or 4D PET scan data in Sagittal, Coronal, and Axial orientations.

    Parameters:
    - img_data: 3D or 4D numpy array of the PET scan data.

    Returns:
    A dictionary containing all slices for Sagittal, Coronal, and Axial orientations.
    """
    # print(img_data.shape)
    # Guess orientations based on dimensions (excluding time if 4D)
    spatial_dimensions = img_data.shape[:-1] if img_data.ndim == 4 else img_data.shape
    # print(spatial_dimensions)
    sorted_dims = np.argsort(spatial_dimensions)  # Ascending order of dimensions
    
    # Assuming the smallest dimension is Sagittal, next is Coronal, and largest is Axial
    if len(spatial_dimensions) == 3:
        sagittal_index, coronal_index, axial_index = sorted_dims

    orientations = {}

    useful_time_points = find_highest_variance_time_points(img_data, 1) 
    # print(f"Useful time points: {useful_time_points}")
    
    for time_point in useful_time_points:

    # time_point = img_data.shape[-1] - 1

        # Extract slices for each orientation
        if img_data.ndim == 3:  # For 3D data
            orientations['Sagittal'] = img_data.swapaxes(0, sagittal_index)
            orientations['Coronal'] = img_data.swapaxes(0, coronal_index)
            orientations['Axial'] = img_data.swapaxes(0, axial_index)
        elif img_data.ndim == 4:  # For 4D data, selecting the first time point for simplicity
            orientations[f'Sagittal_time_point_{time_point}'] = img_data[:,:,:,time_point].swapaxes(0, sagittal_index)
            orientations[f'Coronal_time_point_{time_point}'] = img_data[:,:,:,time_point].swapaxes(0, coronal_index)
            orientations[f'Axial_time_point{time_point}'] = img_data[:,:,:,time_point].swapaxes(0, axial_index)

    return orientations

In [5]:
def get_slice_range(scan, tracer):
    """
    Determine the ideal range of slices to analyze based on the tracer used.
    """
    num_slices = scan # scan.shape[2]

    # print(f"Number of slices: {num_slices}")
    
    if tracer.lower() == 'av45' or tracer.lower() == 'pib':
        # Middle to upper slices for cortical amyloid plaques
        start_slice = int(num_slices * 0.2) # 0.4
        end_slice = int(num_slices * 0.9) # 0.7
    elif tracer.lower() == 'fdg':
        # Broad range for hypometabolism in Alzheimer's
        start_slice = int(num_slices * 0.3)
        end_slice = int(num_slices * 0.8)
    else:
        raise ValueError("Unknown tracer. Please use 'AV45', 'PIB', or 'FDG'.")
        
    return start_slice, end_slice

In [6]:
def extract_and_select_slices(img_data, tracer):
    """
    Extract slices for Sagittal, Coronal, and Axial orientations and select a range based on the tracer type.
    """
    
    # Extract slices for each orientation
    orientations = extract_all_slices_optimized(img_data)

    selected_axial_slices_dictionary = {}
    
    # Assuming axial orientation is of interest
    for key, value in orientations.items():
        # print(key, value.shape)
        if 'Axial' in key:
            axial_slices = orientations[key]
            # print(axial_slices.shape)
            num_axial_slices = axial_slices.shape[-1] if img_data.ndim == 4 else axial_slices.shape[2]
            
            # Get start and end slice based on tracer
            start_slice, end_slice = get_slice_range(num_axial_slices, tracer)
            
            # Select the slice range for axial orientation
            selected_axial_slices = axial_slices[:, :, start_slice:end_slice+1]

            # save in selected_axial_slices_dictionary
            selected_axial_slices_dictionary[key] = selected_axial_slices
    
    return selected_axial_slices_dictionary

In [7]:
def display_selected_slices(selected_axial_slices, filename, key, output_base_directory):
    """
    Display the selected axial slices.
    """
    # print(selected_axial_slices.shape)
    num_slices = selected_axial_slices.shape[2]
    cols = 5  # Number of columns in the plot grid
    rows = num_slices // cols + (1 if num_slices % cols else 0)  # Calculate rows needed

    # fig, axs = plt.subplots(rows, cols, figsize=(20, 4 * rows))
    # axs = axs.flatten()

    for i in range(num_slices):
        plt.figure(figsize=(20, 15))
        plt.imshow(selected_axial_slices[:, :, i].T, cmap='hot', origin='lower')
        plt.axis('off')
        plt.title(f'{key} Slice {i}', fontsize=16)
        plt.tight_layout()
        # plt.show()
        plt.savefig(f"{output_base_directory}/Slice_{i}_{filename}_{key}.png")
        plt.close()

In [8]:
def return_tracer(file_path):
    """
    Return the tracer used in the PET scan based on the file name.
    """
    file_name = os.path.basename(file_path)
    
    if 'av45' in file_name.lower():
        return 'AV45'
    elif 'pib' in file_name.lower():
        return 'PIB'
    elif 'fdg' in file_name.lower():
        return 'FDG'
    else:
        raise ValueError("Unknown tracer. Please use 'AV45', 'PIB', or 'FDG'.")

In [9]:
# Example usage
root_directory_path = "/Users/izzymohamed/Desktop/MLPData/PET/OASIS3/AV45"  # Update this to the directory containing your .nii.gz files


original_directory_path = root_directory_path + "/Original2"
processed_directory_path = root_directory_path + "/Processed"

os.makedirs(processed_directory_path, exist_ok=True)


In [10]:
# %matplotlib inline
matplotlib.use('Agg')

completed_files_count = 0  # Initialize the counter

# get length of files that end with .nii.gz
scan_length = 0
for root, dirs, files in os.walk(original_directory_path):
    for file in files:
        if file.endswith(".nii.gz"):
            scan_length += 1

# Check if file is already processed by checking if the OAS30001 and the d0000 are already in the processed directory
# If it is already processed, then skip it
# for root, dirs, files in os.walk(original_directory_path):
#     for file in files:
#         if file.endswith(".nii.gz"):
#             patientid = file.split("_")[0]
#             sessionid = file.split("_")[2]
#             for root, dirs, files in os.walk(processed_directory_path):
#                 for file1 in files:
#                     if patientid in file1 and sessionid in file1:
#                         print(f"Already processed: {patientid}_{sessionid}")



for root, dirs, files in os.walk(original_directory_path):
    for file in files:
        files.sort()
        if file.endswith(".nii.gz"):
            nii_gz_file_path = os.path.join(root, file)
            img_data = nib.load(nii_gz_file_path).get_fdata()
            result = extract_and_select_slices(img_data, return_tracer(nii_gz_file_path))
            for key, value in result.items():
                # print(key, value.shape)
                display_selected_slices(value, file, key, processed_directory_path)
            
            completed_files_count += 1  # Increment the counter after processing each file
            print(f"Completed: {completed_files_count}/{scan_length} - {file}")

print(f"Total files processed: {completed_files_count}")
                

Completed: 1/300 - sub-OAS30895_ses-d0077_acq-AV45_pet.nii.gz
Completed: 2/300 - sub-OAS30775_ses-d2395_acq-AV45_pet.nii.gz
Completed: 3/300 - sub-OAS31031_ses-d4072_acq-AV45_pet.nii.gz
Completed: 4/300 - sub-OAS30959_ses-d3385_acq-AV45_pet.nii.gz
Completed: 5/300 - sub-OAS30943_ses-d0295_acq-AV45_pet.nii.gz
Completed: 6/300 - sub-OAS30746_ses-d0035_acq-AV45_pet.nii.gz
Completed: 7/300 - sub-OAS30872_ses-d3097_acq-AV45_pet.nii.gz
Completed: 8/300 - sub-OAS30561_ses-d0106_acq-AV45_pet.nii.gz
Completed: 9/300 - sub-OAS30728_ses-d0516_acq-AV45_pet.nii.gz
Completed: 10/300 - sub-OAS31013_ses-d0628_acq-AV45_pet.nii.gz
Completed: 11/300 - sub-OAS31018_ses-d1208_acq-AV45_pet.nii.gz
Completed: 12/300 - sub-OAS30673_ses-d3803_acq-AV45_pet.nii.gz
Completed: 13/300 - sub-OAS30775_ses-d3017_acq-AV45_pet.nii.gz
Completed: 14/300 - sub-OAS30606_ses-d2822_acq-AV45_pet.nii.gz
Completed: 15/300 - sub-OAS30584_ses-d0096_acq-AV45_pet.nii.gz
Completed: 16/300 - sub-OAS30680_ses-d6255_acq-AV45_pet.nii.gz
C

In [11]:
# print names of files in the processed directory 
# print in form of patientid_AV45_sessionid
list1 = []
for root, dirs, files in os.walk(processed_directory_path, topdown=True):
    # go through all files in the directory sorted
    
    for file in files:
        if file.endswith(".png") and "Slice_0" in file:
            patientid = file.split("_")[2].split("-")[1]
            sessionid = file.split("_")[3].split("-")[1]
            list1.append(f"{patientid}_AV45_{sessionid}")
            print(f"{patientid}_AV45_{sessionid}")

OAS30843_AV45_d0236
OAS30632_AV45_d1800
OAS30531_AV45_d1347
OAS30283_AV45_d2660
OAS30917_AV45_d1196
OAS30434_AV45_d0054
OAS30296_AV45_d0069
OAS30759_AV45_d1442
OAS30262_AV45_d839
OAS31376_AV45_d1213
OAS31213_AV45_d0114
OAS31020_AV45_d1475
OAS30367_AV45_d4337
OAS30139_AV45_d1702
OAS31267_AV45_d0090
OAS30206_AV45_d3024
OAS31096_AV45_d1308
OAS31307_AV45_d0113
OAS30026_AV45_d0696
OAS31148_AV45_d1290
OAS30635_AV45_d1533
OAS31019_AV45_d1370
OAS31392_AV45_d0144
OAS30293_AV45_d1221
OAS30038_AV45_d5769
OAS31114_AV45_d2658
OAS30173_AV45_d3841
OAS30324_AV45_d2433
OAS30062_AV45_d4447
OAS31073_AV45_d3760
OAS31293_AV45_d0084
OAS30926_AV45_d1520
OAS30748_AV45_d3268
OAS31048_AV45_d3195
OAS30764_AV45_d0055
OAS30776_AV45_d4024
OAS31473_AV45_d0136
OAS30005_AV45_d2384
OAS30643_AV45_d2334
OAS31224_AV45_d0116
OAS31071_AV45_d0068
OAS31386_AV45_d0202
OAS30568_AV45_d2326
OAS31300_AV45_d0212
OAS30024_AV45_d0084
OAS31353_AV45_d0117
OAS30767_AV45_d948
OAS30322_AV45_d5194
OAS31295_AV45_d0120
OAS30584_AV45_d0096
OA