In [None]:
import numpy as np
import nibabel as nb
import SimpleITK as sitk
import glob
from tqdm import tqdm
import matplotlib.pyplot as plt
import os
from loguru import logger
import sys
import maximum_expectation_algorithm
import random

In [None]:
def print_shapes(data, name):
    """
    Print the shape of every image in a collection of volumes.

    One line is printed per image, numbered from 1, e.g.
    ``Shape of WM image 1: (182, 218, 182)``.

    :param data: Iterable of images exposing a ``.shape`` attribute
        (e.g. NumPy arrays returned by ``sitk.GetArrayFromImage``).
    :param name: Label inserted into each printed line (e.g. ``'WM'``).
    """
    for idx, image in enumerate(data, start=1):
        print(f"Shape of {name} image {idx}: {image.shape}")

In [None]:

root_dir = '../Register_testing_volumes/mni_atlas_REGISTERED_AFFINE10/'
test_case_id = '1005'  # id of the target T1 volume / label map loaded below

# Collect the registered atlas volumes (*.nii) from every subject folder whose
# name starts with "100". The pattern contains no "**", so this searches exactly
# one directory level. glob's order is filesystem-dependent; sorting makes the
# `[:number_of_brains]` selections below pick the same subjects everywhere.
all_nii_files = sorted(glob.glob(os.path.join(root_dir, '100*', '*.nii')))

# Categorize files by the atlas-type substring in their file name
file_categories = {
    'mni_atlas_1C': [],
    'mni_atlas_p_atlas_background': [],
    'mni_atlas_p_atlas_csf': [],
    'mni_atlas_p_atlas_gm': [],
    'mni_atlas_p_atlas_wm': [],
    'mni_atlas_template': []
    }

for file in all_nii_files:
    for category in file_categories.keys():
        if category in os.path.basename(file):
            file_categories[category].append(file)

# How many registered atlases to load per category, and which index along the
# first array axis (SimpleITK arrays are ordered z, y, x) the plots below show.
number_of_brains = 7
brain_layer = 127


def _load_category_arrays(category, limit):
    """Read the first `limit` files of `category` and return them as NumPy arrays."""
    return [sitk.GetArrayFromImage(sitk.ReadImage(path))
            for path in file_categories[category][:limit]]


# Check file paths (add this for debugging)
print(f"Found mni_atlas_1C files: {file_categories['mni_atlas_1C'][:number_of_brains]}")
print(' ')

# Load the probability maps for each category using SimpleITK
mni_atlas_p_atlas_background = _load_category_arrays('mni_atlas_p_atlas_background', number_of_brains)
mni_atlas_p_atlas_csf = _load_category_arrays('mni_atlas_p_atlas_csf', number_of_brains)
mni_atlas_p_atlas_gm = _load_category_arrays('mni_atlas_p_atlas_gm', number_of_brains)
mni_atlas_p_atlas_wm = _load_category_arrays('mni_atlas_p_atlas_wm', number_of_brains)
# NOTE(review): there is no separate mask category, so the "mask" list is built
# from the background probability maps (independent copies rather than a second
# read of the same files). Confirm this is the intended mask.
mni_atlas_p_atlas_mask = [volume.copy() for volume in mni_atlas_p_atlas_background]

# Target T1 volume and its reference label map
test_dir = os.path.join('..', 'Register_testing_volumes', 'test', 'test')
t1 = sitk.ReadImage(os.path.join(test_dir, 'testing-images', f'{test_case_id}.nii.gz'))
t1_array = sitk.GetArrayFromImage(t1)
t1_mask_image = sitk.ReadImage(os.path.join(test_dir, 'testing-labels', f'{test_case_id}_3C.nii.gz'))
t1_mask = sitk.GetArrayFromImage(t1_mask_image)

# Report how many volumes were loaded for each category
for category_name, volumes in (
    ('mni_atlas_p_atlas_background', mni_atlas_p_atlas_background),
    ('mni_atlas_p_atlas_csf', mni_atlas_p_atlas_csf),
    ('mni_atlas_p_atlas_gm', mni_atlas_p_atlas_gm),
    ('mni_atlas_p_atlas_wm', mni_atlas_p_atlas_wm),
    ('mni_atlas_p_atlas_mask', mni_atlas_p_atlas_mask),
):
    logger.success(f'Imported {len(volumes)} {category_name} images')

# For debugging
print_shapes(mni_atlas_p_atlas_mask, 'mask')


In [None]:
# Stack the first atlas's maps along a new leading axis in the order
# background, CSF, GM, WM -> shape (4, *volume_shape).
stacked_data = np.stack(
    [tissue_maps[0] for tissue_maps in (mni_atlas_p_atlas_background,
                                        mni_atlas_p_atlas_csf,
                                        mni_atlas_p_atlas_gm,
                                        mni_atlas_p_atlas_wm)],
    axis=0,
)
print(stacked_data.shape)

In [None]:
# Show which voxels carry any non-zero reference label on the chosen slice,
# then list the label values present in the whole label volume.
plt.imshow(t1_mask[brain_layer, :, :] != 0, cmap='gray')
print(np.unique(t1_mask))

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import random

def plot_different_brains_rotated(csf_data, gm_data, wm_data, layer, num_brains=3, seed=None):
    """
    Plot one slice of several randomly chosen brains in a grid with one row per
    brain and one column per tissue type (CSF, GM, WM). Each slice is rotated
    by 90 degrees before display.

    The same brains are used for all three columns, so every row shows the
    three tissue maps of a single brain.

    :param csf_data: List of cerebrospinal fluid (CSF) brain images.
    :param gm_data: List of gray matter (GM) brain images.
    :param wm_data: List of white matter (WM) brain images.
    :param layer: Index along the first array axis of the slice to plot.
    :param num_brains: Number of distinct brains (rows) to plot; capped at the
        number of available brains. Defaults to 3.
    :param seed: Optional seed for a private ``random.Random`` so the brain
        selection is reproducible; ``None`` uses the global ``random`` state.
    :raises ValueError: If no brains are available or ``num_brains`` < 1.
    """
    # Only indices valid for all three lists can be sampled.
    n_available = min(len(csf_data), len(gm_data), len(wm_data))
    if n_available == 0:
        raise ValueError("No brain images to plot.")
    if num_brains < 1:
        raise ValueError("num_brains must be at least 1.")
    num_rows = min(num_brains, n_available)

    sampler = random if seed is None else random.Random(seed)
    brains_indices = sampler.sample(range(n_available), num_rows)

    tissues = (('CSF', csf_data), ('GM', gm_data), ('WM', wm_data))
    # squeeze=False keeps `axes` 2-D even when only one row is plotted.
    fig, axes = plt.subplots(num_rows, len(tissues), figsize=(15, 5 * num_rows), squeeze=False)

    for row, idx in enumerate(brains_indices):
        for col, (tissue_name, tissue_data) in enumerate(tissues):
            ax = axes[row, col]
            ax.imshow(np.rot90(tissue_data[idx][layer, :, :]), cmap='gray')
            ax.set_title(f'Brain {idx+1} {tissue_name}')
            ax.axis('off')

    plt.show()

# Show the chosen slice for a random sample of the loaded atlases
plot_different_brains_rotated(
    csf_data=mni_atlas_p_atlas_csf,
    gm_data=mni_atlas_p_atlas_gm,
    wm_data=mni_atlas_p_atlas_wm,
    layer=brain_layer,
)


In [None]:
# Collect one stacked tissue-probability map per loaded atlas.
probmaps = []

# For atlas i, stack its CSF, GM and WM maps along a new last axis, giving an
# array of shape volume_shape + (3,). Assumes the three lists are aligned by
# index (same subject at the same position) — TODO confirm.
for i in range(len(mni_atlas_p_atlas_csf)):
    prob_map = np.stack((mni_atlas_p_atlas_csf[i], mni_atlas_p_atlas_gm[i], mni_atlas_p_atlas_wm[i]), axis=-1)
    probmaps.append(prob_map)


In [None]:
# Segmentation helper from the local `maximum_expectation_algorithm` module;
# its `min_max_normalization` and `tissue_segmentation` methods are used below.
EM = (
    maximum_expectation_algorithm
    .maximum_expectation_algorithm()
)

In [None]:
# Flatten the tissue-first stack (n_tissues, *volume_shape) into a
# (n_voxels, n_tissues) table in which each row holds one voxel's tissue
# probabilities. The tissue axis must be moved last first; a plain
# `reshape(-1, n_tissues)` on the tissue-first array would fill each row with
# consecutive voxels of a single tissue instead.
ll = np.moveaxis(stacked_data, 0, -1).reshape(-1, stacked_data.shape[0])
print(ll.shape)

In [None]:
# Clamp the stacked atlas probabilities into [0, 1] (values outside that range
# presumably come from registration interpolation — confirm).
tpm_mni = np.clip(stacked_data, 0, 1)

# Min-max normalise the target T1 intensities and segment the normalised
# volume. The result is stored in `t1` (replacing the SimpleITK image read
# earlier) so later cells referring to `t1` see the normalised data.
t1 = EM.min_max_normalization(t1_array)

# NOTE(review): the mode string 'label_propragation' is passed verbatim; it must
# match whatever the local maximum_expectation_algorithm module checks for —
# verify the spelling there before changing it here.
t1_segmentation, t1_segmentation_time = EM.tissue_segmentation(t1, type='label_propragation', label_pro=tpm_mni)

In [None]:
# Inspect the segmentation: overall shape, the chosen slice, and the label set.
print(t1_segmentation.shape)
plt.imshow(t1_segmentation[brain_layer])
print(np.unique(t1_segmentation))
