In [1]:
import pickle as pkl
import shutil
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import SimpleITK as sitk
from tqdm import tqdm

import utils_registration

# Atlas Registration [Test Dataset]

In [4]:
# Paths used by the test-set atlas registration.
# Everything is resolved relative to the notebook's working directory via ../Lab_03.
base_path       = Path('../Lab_03/').resolve()

# Test-set inputs and the folder where registration results are written
test_set_path   = base_path.joinpath('data', 'test-set')
test_imgs_dir   = test_set_path.joinpath('testing-images')
test_labels_dir = test_set_path.joinpath('testing-labels')
test_masks_dir  = test_set_path.joinpath('testing-masks')
output_path     = test_set_path.joinpath('testing-outputs')

# Atlas built from the training set with the Par0009 affine parameters
our_atlas_path  = base_path.joinpath('data', 'training-set', 'atlases', 'Par0009.affine')
# MNI atlas (disabled here):
#mni_atlas_path  = base_path / 'data' / 'MNITemplateAtlas'

# Elastix parameter map used for atlas -> test-image registration
params_path     = base_path.joinpath('parameter')
param_file_path = params_path.joinpath('Par0009.affine.txt')

WindowsPath('D:/Lab_03/data/training-set/atlases/Par0009.affine')

In [16]:
# Separating Labels for MNI atlas
#mni_atlas_template = sitk.ReadImage(str(mni_atlas_path / 'template.nii.gz'))
#mni_atlas_labels = sitk.ReadImage(str(mni_atlas_path / 'atlas.nii.gz')) # probability maps image 
#mni_atlas_labels_array = sitk.GetArrayFromImage(mni_atlas_labels) # probability maps array

#atlas_background = mni_atlas_labels_array[0, :, :, :]
#atlas_csf = mni_atlas_labels_array[1, :, :, :]
#atlas_gm = mni_atlas_labels_array[2, :, :, :]
#atlas_wm = mni_atlas_labels_array[3, :, :, :]
 
#utils.save_segementations(atlas_background, mni_atlas_template , str(mni_atlas_path/'p_atlas_background.nii.gz'))
#utils.save_segementations(atlas_csf, mni_atlas_template , str(mni_atlas_path/'p_atlas_csf.nii.gz'))
#utils.save_segementations(atlas_gm, mni_atlas_template , str(mni_atlas_path/'p_atlas_gm.nii.gz'))
#utils.save_segementations(atlas_wm, mni_atlas_template , str(mni_atlas_path/'p_atlas_wm.nii.gz'))

In [6]:
# Registration of the atlases to each test image.
# For every test image, the atlas template is registered with utils_registration.elastix,
# and the resulting transform is applied with utils_registration.transformix to the test
# brain mask and to every atlas probability map. Results go to
# output_path / <atlas_name> / <image_id>/.

our_atlas_template_path = our_atlas_path / 'mean_volume.nii.gz'
#mni_atlas_template_path = mni_atlas_path / 'template.nii.gz'

atlas_path = {'our_atlas': our_atlas_path}
atlas_template_path = {'our_atlas': our_atlas_template_path}
atlas_map_names = ['p_atlas_background', 'p_atlas_csf', 'p_atlas_gm', 'p_atlas_wm']

# Fail before the (slow) registration loop if any atlas input is missing.
# The probability maps are stored as '<map_name>.nii.gz'.
required_inputs = []
for name in atlas_template_path:
    required_inputs.append(atlas_template_path[name])
    required_inputs += [atlas_path[name] / f'{label}.nii.gz' for label in atlas_map_names]
missing_inputs = [p for p in required_inputs if not p.exists()]
if missing_inputs:
    raise FileNotFoundError('Missing atlas inputs:\n' + '\n'.join(map(str, missing_inputs)))

# Parameter overrides: NIfTI output and final B-spline interpolation order 0
# (nearest neighbour in elastix). The same pairs are reused for every transform file.
field_value_pairs = [('ResultImageFormat', 'nii.gz'), ('FinalBSplineInterpolationOrder', '0.0')]
utils_registration.modify_parameter(field_value_pairs, param_file_path)

# Only NIfTI volumes, in a stable order (iterdir() order is filesystem-dependent
# and would also pick up any non-image files in the folder).
test_img_paths = sorted(test_imgs_dir.glob('*.nii.gz'))

for fixed_img_path in test_img_paths:

    print(fixed_img_path)
    # Remove the exact '.nii.gz' suffix. str.rstrip('.nii.gz') strips a *set* of
    # characters and would also eat trailing '.', 'n', 'i', 'g', 'z' of the stem.
    fix_name = fixed_img_path.name[:-len('.nii.gz')]
    mask_path = test_masks_dir / f'{fix_name}_1C.nii.gz'

    # For each atlas
    for atlas_name in atlas_template_path:
        result_path = output_path / atlas_name / fix_name
        result_path.mkdir(exist_ok=True, parents=True)

        res_img_path  = result_path / f'{atlas_name}_template.nii.gz'
        res_mask_path = result_path / f'{atlas_name}_1C.nii.gz'

        # Register; the returned path is the transform parameter file used below.
        # NOTE(review): argument roles presumably (fixed, moving, output, params) — confirm in utils_registration.
        transform_map_path = utils_registration.elastix(fixed_img_path, atlas_template_path[atlas_name], res_img_path, param_file_path)

        # Apply the same overrides to the generated transform parameter file
        utils_registration.modify_parameter(field_value_pairs, transform_map_path)

        # Transform brain_mask
        utils_registration.transformix(mask_path, res_mask_path, transform_map_path)

        for label_name in atlas_map_names:
            res_lab_path = result_path / f'{atlas_name}_{label_name}.nii.gz'
            # Probability maps are stored with the .nii.gz extension; without it the
            # input path does not point at the saved file.
            atlas_lab_path = str(atlas_path[atlas_name] / f'{label_name}.nii.gz')

            # Transform labels
            utils_registration.transformix(atlas_lab_path, res_lab_path, transform_map_path)

D:\Lab_03\data\test-set\testing-images\1003.nii.gz
D:\Lab_03\data\test-set\testing-images\1004.nii.gz
D:\Lab_03\data\test-set\testing-images\1005.nii.gz
D:\Lab_03\data\test-set\testing-images\1018.nii.gz
D:\Lab_03\data\test-set\testing-images\1019.nii.gz
D:\Lab_03\data\test-set\testing-images\1023.nii.gz
D:\Lab_03\data\test-set\testing-images\1024.nii.gz
D:\Lab_03\data\test-set\testing-images\1025.nii.gz
D:\Lab_03\data\test-set\testing-images\1038.nii.gz
D:\Lab_03\data\test-set\testing-images\1039.nii.gz
D:\Lab_03\data\test-set\testing-images\1101.nii.gz
D:\Lab_03\data\test-set\testing-images\1104.nii.gz
D:\Lab_03\data\test-set\testing-images\1107.nii.gz
D:\Lab_03\data\test-set\testing-images\1110.nii.gz
D:\Lab_03\data\test-set\testing-images\1113.nii.gz
D:\Lab_03\data\test-set\testing-images\1116.nii.gz
D:\Lab_03\data\test-set\testing-images\1119.nii.gz
D:\Lab_03\data\test-set\testing-images\1122.nii.gz
D:\Lab_03\data\test-set\testing-images\1125.nii.gz
D:\Lab_03\data\test-set\testing

# Build Tissue Model (intensity posteriors from the training set)

In [7]:
# Accumulate per-class intensity histograms over the registered training images,
# once using the topological atlas (t_atlas) as labels and once using the
# ground-truth label volumes. Only voxels inside each image's brain mask are counted.
base_path    = Path('../Lab_03/').resolve()
data_path    = base_path / 'data' /  'training-set'
pm_path      = data_path / 'training-outputs' / 'Par0009.affine'
img_path     = pm_path / 'training-images'
lab_path     = pm_path / 'training-labels'
bm_path      = pm_path / 'training-masks'
atlases_path = data_path / 'atlases' / 'Par0009.affine'

# Reference values
ref_img_path = img_path / 'r_1001.nii.gz'

t_atlas = sitk.GetArrayFromImage(sitk.ReadImage(str(atlases_path/'t_atlas.nii.gz')))
gt_labels = sitk.GetArrayFromImage(sitk.ReadImage(str(lab_path/'r_1001_3C.nii.gz')))

n_classes = 4
labels_keys = {0: 'background', 1: 'csf', 2: 'wm', 3: 'gm'}
n_bins = 256  # one histogram bin per uint8 intensity value


def _min_max_norm(img, max_val, mask=None, dtype='uint8'):
    """Linearly rescale `img` so the min/max inside `mask` map to 0 and `max_val`.

    Stand-in for `utils.min_max_norm`, which this notebook called but never imported.
    NOTE(review): assumes that helper used in-mask statistics — confirm.

    Args:
        img: intensity array.
        max_val: value the in-mask maximum is mapped to.
        mask: optional array; non-zero voxels define the region used for min/max.
            When None, the whole image is used.
        dtype: output dtype (values are clipped to [0, max_val] first so the
            cast cannot wrap around).

    Returns:
        Rescaled array of `dtype`; all zeros if the region is empty or constant.
    """
    img = np.asarray(img, dtype=np.float64)
    region = img[mask != 0] if mask is not None else img
    scaled = np.zeros_like(img)
    if region.size:
        lo, hi = region.min(), region.max()
        if hi > lo:
            scaled = (img - lo) * (max_val / (hi - lo))
    return np.clip(scaled, 0, max_val).astype(dtype)


t_histograms = np.zeros((n_classes, n_bins))
gt_histograms = np.zeros((n_classes, n_bins))

# Only NIfTI volumes, in a stable order (skips transform .txt files and anything else).
train_img_paths = sorted(img_path.glob('*.nii.gz'))
for img_filepath in tqdm(train_img_paths):
    # Remove the exact suffix; str.rstrip('.nii.gz') strips a set of characters.
    mov_img_id = img_filepath.name[:-len('.nii.gz')]
    mov_lab_path = lab_path / f'{mov_img_id}_3C.nii.gz'
    mov_bm_path = bm_path / f'{mov_img_id}_1C.nii.gz'
    img_array = sitk.GetArrayFromImage(sitk.ReadImage(str(img_filepath)))
    img_labels = sitk.GetArrayFromImage(sitk.ReadImage(str(mov_lab_path)))
    img_bm = sitk.GetArrayFromImage(sitk.ReadImage(str(mov_bm_path)))
    img_array = _min_max_norm(img_array, n_bins - 1, img_bm, 'uint8')

    # Boolean-mask indexing already returns 1-D arrays of the in-brain voxels.
    in_brain = img_bm != 0
    t_atlas_temp = t_atlas[in_brain]
    img_array = img_array[in_brain]
    img_labels = img_labels[in_brain]
    for c in range(n_classes):
        t_histograms[c, :] += np.histogram(img_array[t_atlas_temp == c], bins=n_bins, range=[0, n_bins])[0]
        gt_histograms[c, :] += np.histogram(img_array[img_labels == c], bins=n_bins, range=[0, n_bins])[0]

# Turn the accumulated class histograms into P(class | intensity) lookup tables.
# Row 0 (background) is dropped after normalising each class to a density.


def _safe_divide(num, den):
    """Element-wise num / den (broadcasting), returning 0 where den == 0 instead of NaN/inf."""
    return np.divide(num, den, out=np.zeros(np.broadcast(num, den).shape), where=den != 0)


def _tissue_density(histograms):
    """Normalise each class row to sum to 1, then drop row 0 (background)."""
    return _safe_divide(histograms, histograms.sum(axis=1, keepdims=True))[1:, :]


def _posterior(density):
    """Normalise each intensity bin (column) across classes to sum to 1.

    Bins where no tissue class has any voxels become all-zero instead of NaN.
    """
    return _safe_divide(density, density.sum(axis=0, keepdims=True))


t_histograms_density = _tissue_density(t_histograms)
gt_histograms_density = _tissue_density(gt_histograms)

t_histograms_posterior = _posterior(t_histograms_density)
gt_histograms_posterior = _posterior(gt_histograms_density)

# Force the brightest intensities to a single tissue class in the GT model
# (same as the previous [0, 1, 0] override: tissue row 1, i.e. label 2).
# NOTE(review): intended as the bright-tissue class; confirm the threshold.
bright_intensity_start = 225
bright_tissue_row = 1
gt_histograms_posterior[:, bright_intensity_start:] = 0
gt_histograms_posterior[bright_tissue_row, bright_intensity_start:] = 1

with open((data_path.parent/'tissue_models_3C_bm.pkl'), 'wb') as f:
    pkl.dump(gt_histograms_posterior, f)

RuntimeError: Exception thrown in SimpleITK ReadImage: C:\dafne\SimpleElastix\Code\IO\src\sitkImageReaderBase.cxx:99:
sitk::ERROR: The file "D:\Lab_03\data\training-set\atlases\t_atlas.nii.gz" does not exist.

In [None]:
# Registration of both atlases (ours and MNI) to each test image.
# NOTE(review): the MNI probability maps p_atlas_*.nii.gz are written by the
# commented-out "Separating Labels for MNI atlas" cell; run it before this one.

# Atlas paths. The MNI paths are commented out in the configuration cell, so they
# are defined here to keep this cell runnable on a fresh kernel.
mni_atlas_path          = base_path / 'data' / 'MNITemplateAtlas'
our_atlas_template_path = our_atlas_path / 'mean_volume.nii.gz'
mni_atlas_template_path = mni_atlas_path / 'template.nii.gz'

atlas_path = {'our_atlas': our_atlas_path, 'mni_atlas': mni_atlas_path}
atlas_template_path = {'our_atlas': our_atlas_template_path, 'mni_atlas': mni_atlas_template_path}
atlas_map_names = ['p_atlas_background', 'p_atlas_csf', 'p_atlas_gm', 'p_atlas_wm']

# Fail before the (slow) registration loop if any atlas input is missing.
required_inputs = []
for name in atlas_template_path:
    required_inputs.append(atlas_template_path[name])
    required_inputs += [atlas_path[name] / f'{label}.nii.gz' for label in atlas_map_names]
missing_inputs = [p for p in required_inputs if not p.exists()]
if missing_inputs:
    raise FileNotFoundError('Missing atlas inputs:\n' + '\n'.join(map(str, missing_inputs)))

# Parameter overrides: NIfTI output and final B-spline interpolation order 0
# (nearest neighbour in elastix). Reused for every generated transform file.
# Uses the utils_registration helpers (the `utils` module is never imported here).
field_value_pairs = [('ResultImageFormat', 'nii.gz'), ('FinalBSplineInterpolationOrder', '0.0')]
utils_registration.modify_parameter(field_value_pairs, param_file_path)

# Only NIfTI volumes, in a stable order.
test_img_paths = sorted(test_imgs_dir.glob('*.nii.gz'))

for fixed_img_path in test_img_paths:

    print(fixed_img_path)
    # Remove the exact '.nii.gz' suffix (str.rstrip strips a character set).
    fix_name = fixed_img_path.name[:-len('.nii.gz')]

    mask_path = test_masks_dir / f'{fix_name}_1C.nii.gz'

    # For each atlas
    for atlas_name in atlas_template_path:
        result_path = output_path / atlas_name / fix_name
        result_path.mkdir(exist_ok=True, parents=True)

        res_img_path = result_path / f'{atlas_name}_template.nii.gz'
        res_mask_path = result_path / f'{atlas_name}_1C.nii.gz'

        # Register; the returned path is the transform parameter file used below.
        transform_map_path = utils_registration.elastix(fixed_img_path, atlas_template_path[atlas_name], res_img_path, param_file_path)

        # Apply the same overrides to the generated transform parameter file
        utils_registration.modify_parameter(field_value_pairs, transform_map_path)

        # Transform brain_mask
        utils_registration.transformix(mask_path, res_mask_path, transform_map_path)

        for label_name in atlas_map_names:
            res_lab_path = result_path / f'{atlas_name}_{label_name}.nii.gz'
            # Probability maps are saved as '<map_name>.nii.gz'.
            atlas_lab_path = str(atlas_path[atlas_name] / f'{label_name}.nii.gz')

            # Transform labels
            utils_registration.transformix(atlas_lab_path, res_lab_path, transform_map_path)