One-paragraph abstract (150 words or less, in a Word file)  
Digital headshot (high-resolution JPG preferred)  
Letter of tax determination from your institution  
W-9 Form from your institution 


# Imports

In [19]:
# imports

import os, sys
import numpy as np
import SimpleITK as sitk

from helpers_general import sitk2np, np2sitk, print_sitk_info, round_tuple, lrange, get_roi_range, numbers2groups
from helpers_preprocess import mask2bbox, print_bbox, get_bbox_size, print_bbox_size, get_data_dict, folder2objs
from helpers_metrics import compute_dice_coefficient, compute_coverage_coefficient
from helpers_viz import viz_axis

In [14]:
# auto-reload when local helper fns change
%load_ext autoreload
%autoreload 2

%matplotlib inline

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


# Load Data

In [3]:
PROJ_PATH = "."

# Folders containing MR train data
train_path = f"{PROJ_PATH}/train_data/train_data"
train_data_dict = get_data_dict(train_path)

# print train data dict
print(f"Train data folders: {numbers2groups(sorted([int(x) for x in os.listdir(train_path)]))}")
print(f"Training data: key = train folder, value = full path to (segm obj, nii file)\n")

Train data folders: [range(50002, 50017), range(50019, 50020), range(50022, 50049), range(50455, 50464)]
Training data: key = train folder, value = full path to (segm obj, nii file)



### Load Atlas


In [16]:
# set atlas MRs (10 MRs labelled by Dr. Hollon, need nii LPS=>RAS adjustment)

atlas_range, ras_adj = [50460], True
atlas_folders = [str(i) for i in atlas_range]
atlas_objs, atlas_mask_objs = zip(*[folder2objs(atlas_folder, train_data_dict, ras_adj) \
                               for atlas_folder in atlas_folders])

### Load Input

Samir folders: range(50002, 50017), range(50019, 50020), range(50022, 50049)

In [17]:
# set input MRs (Samir's, no RAS adj needed for mask)

input_range, ras_adj = [50012], False
input_folders = [str(i) for i in input_range]
input_objs, input_mask_objs = zip(*[folder2objs(input_folder, train_data_dict, ras_adj) \
                               for input_folder in input_folders])

### Compare metadata

In [26]:
print("Input MR"); print_sitk_info(input_objs[0]), print();
print("Atlas MR"); print_sitk_info(atlas_objs[0]);

Input MR
Size:  (176, 256, 256)
Origin:  (-87.51664733886719, 132.53253173828125, -127.22270202636719)
Spacing:  (1.0500000715255737, 1.05078125, 1.05078125)
Direction:  (1.0, 0.0, 0.0, 0.0, -1.0, 0.0, 0.0, 0.0, 1.0)
Pixel type: 2 = 16-bit signed integer

Atlas MR
Size:  (160, 480, 512)
Origin:  (-95.70238494873047, 77.71624755859375, -118.06993103027344)
Spacing:  (1.2000000476837158, 0.5, 0.5)
Direction:  (1.0, 0.0, 0.0, 0.0, -1.0, 0.0, 0.0, 0.0, 1.0)
Pixel type: 2 = 16-bit signed integer


### Resample to Standard Reference Domain

In [57]:
# src: https://github.com/SimpleITK/ISBI2018_TUTORIAL/blob/master/python/03_data_augmentation.ipynb

dimension = 3 # 3D MRs
pixel_id = 2 # 16-bit signed integer

# Physical image size corresponds to the largest physical size in the training set, or any other arbitrary size.
reference_physical_size = np.zeros(dimension)

img_data = [(o.GetSize(), o.GetSpacing()) for o in (atlas_objs[0], input_objs[0])]
for img_sz, img_spc in img_data:
    reference_physical_size[:] = [(sz-1)*spc if sz*spc>mx else mx \
                                  for sz, spc, mx in zip(img_sz, img_spc, reference_physical_size)]
    
# Create the reference image with a zero origin, identity direction cosine matrix and dimension     
reference_origin = np.zeros(dimension)
reference_direction = np.identity(dimension).flatten()


# Isotropic (1,1,1) pixels
reference_spacing = np.ones(dimension)
reference_size = [int(phys_sz/(spc) + 1) for phys_sz,spc in zip(reference_physical_size, reference_spacing)]

# Set reference image attributes
reference_image = sitk.Image(reference_size, pixel_id)
reference_image.SetOrigin(reference_origin)
reference_image.SetSpacing(reference_spacing)
reference_image.SetDirection(reference_direction)

In [60]:
def resample(arr, interpolator = sitk.sitkLinear, default_intensity_value = 0):
    return [sitk.Resample(img, reference_image, sitk.Transform(), interpolator, default_intensity_value) for img in arr]

atlas_objs, atlas_mask_objs = resample(atlas_objs), resample(atlas_mask_objs)
input_objs, input_mask_objs = resample(input_objs), resample(input_mask_objs)

In [62]:
print("Input"); print_sitk_info(input_objs[0]); print();
print("Atlas"); print_sitk_info(atlas_objs[0]);

Input
Size:  (191, 268, 268)
Origin:  (0.0, 0.0, 0.0)
Spacing:  (1.0, 1.0, 1.0)
Direction:  (1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0)
Pixel type: 2 = 16-bit signed integer

Atlas
Size:  (191, 268, 268)
Origin:  (0.0, 0.0, 0.0)
Spacing:  (1.0, 1.0, 1.0)
Direction:  (1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0)
Pixel type: 2 = 16-bit signed integer


# Crop to same shape

- Input: 185 x 269 x 269
- Atlas: 192 x 240 x 256

In sagittal view, the last two dims determine 
- how much of the right side of the image is retained and 
- how much the top of the image is retained.

IMPORTANT:
- Cropping the input to 185 x 240 x 256 retains important structures. 
- Also does not affect bbox coordinates

In [None]:
print_info(input_objs[0])
print_info(atlas_objs[0])

In [None]:
atlas_size = atlas_objs[0].GetSize()
input_size = input_objs[0].GetSize()
print(atlas_size, input_size)

In [None]:
shape0, shape1, shape2 = (min(atlas_d, input_d) for atlas_d, input_d in zip(atlas_size, input_size))
print(shape0, shape1, shape2)

In [None]:
window = np.s_[0:shape0, 0:shape1, -shape2:]

In [None]:
atlas_objs      = [o[window]  for o in atlas_objs]
atlas_mask_objs = [o[window]  for o in atlas_mask_objs]

input_objs      = [o[window]  for o in input_objs]
input_mask_objs = [o[window]  for o in input_mask_objs]

In [None]:
print_info(input_objs[0])
print_info(atlas_objs[0])

In [None]:
atlas_arrs      = [sitk2np(o) for o in atlas_objs]
atlas_mask_arrs = [sitk2np(o) for o in atlas_mask_objs]

input_arrs      = [sitk2np(o) for o in input_objs]
input_mask_arrs = [sitk2np(o) for o in input_mask_objs]

In [None]:
input_arrs[0].shape, atlas_arrs[0].shape

In [None]:
# Viz
index = 0
slice_range = lrange(80,85)

viz_axis(input_arrs[index], \
        bin_mask_arr  = input_mask_arrs[index], color1 = "yellow", alpha1=0.3, \
        bin_mask_arr2 = atlas_mask_arrs[index], color2 = "blue", alpha2=0.3,
        slices=slice_range, fixed_axis=0, \
        axis_fn = np.rot90, \
        grid = [1, 5], hspace=0.3, fig_mult=2)

Original dice

In [None]:
orig_dice = compute_dice_coefficient(input_mask_arrs[index], atlas_mask_arrs[index])
print(f"Unaligned dice: {orig_dice:.3f}")

Align

In [None]:
# set moving and fixed images (resample moving=>fixed using T:fixed=>moving)
fixed_obj = input_objs[0]
moving_obj = atlas_objs[0]

In [None]:
elastixImageFilter = sitk.ElastixImageFilter()
elastixImageFilter.SetFixedImage(fixed_obj)
elastixImageFilter.SetMovingImage(moving_obj)

# set parameter map
param_folder = "ElastixParamFiles"
param_files = ["affine.txt"]

parameterMapVector = sitk.VectorOfParameterMap()
for param_file in param_files:
    parameterMapVector.append(sitk.ReadParameterFile(f"{param_folder}/{param_file}"))
elastixImageFilter.SetParameterMap(parameterMapVector)

In [None]:
# Execute alignment
#elastixImageFilter.SetLogToConsole(False)
elastixImageFilter.Execute()

In [None]:
res_img = elastixImageFilter.GetResultImage()

In [None]:
res_img = sitk.Cast(res_img, sitk.sitkInt16)

Resample same physiscal space

In [None]:
res_img.GetSize(), moving_obj.GetSize(), fixed_obj.GetSize()

In [None]:
def print_info(image):    
    print("Size: ", image.GetSize())
    print("Origin: ", image.GetOrigin())
    print("Spacing: ", image.GetSpacing())
    print("Direction: ", image.GetDirection())
    print(f"Pixel type: {image.GetPixelIDValue()} = {image.GetPixelIDTypeAsString()}")

In [None]:
#print_info(res_img), print_info(moving_obj), print_info(fixed_obj)

print("Input"), print_info(input_objs[0])
print("Atlas"), print_info(atlas_objs[0])
print("Res"), print_info(res_img)

In [None]:
simg1 = fixed_obj
simg2 = res_img
cimg = sitk.Cast(sitk.Compose(simg1, 0.5*simg1+0.5*simg2, simg2), sitk.sitkVectorUInt8)

In [None]:
carr = sitk2np(cimg)

In [None]:
carr.shape

In [None]:
import matplotlib.pyplot as plt
plt.imshow(np.rot90(carr[80,:,:,:]))

In [None]:
# original

orig_res_img = sitk.Resample(moving_obj, fixed_obj.GetSize(), sitk.Transform(), sitk.sitkLinear,
                         fixed_obj.GetOrigin(), fixed_obj.GetSpacing(), fixed_obj.GetDirection(), 0,
                         fixed_obj.GetPixelID())

In [None]:
osimg2 = orig_res_img
ocimg = sitk.Cast(sitk.Compose(simg1, 0.5*simg1+0.5*osimg2, osimg2), sitk.sitkVectorUInt8)

In [None]:
ocarr = sitk2np(ocimg)

In [None]:
plt.imshow(np.rot90(ocarr[80,:,:,:]))

Get new dice

In [None]:
moving_mask_obj = atlas_mask_objs[0]

In [None]:
# MAP MOVING (ATLAS BINARY ROI) ONTO FIXED (INPUT) 

# set moving image (atlas)                                                    
transformixImageFilter = sitk.TransformixImageFilter()
transformixImageFilter.SetMovingImage(moving_mask_obj)

# set parameter map (Binary mask => nearest neighbor final interpolation)
transformedParameterMapVector = elastixImageFilter.GetTransformParameterMap()
transformedParameterMapVector[-1]["ResampleInterpolator"] = ["FinalNearestNeighborInterpolator"]
transformixImageFilter.SetTransformParameterMap(transformedParameterMapVector)

# Execute transformation
transformixImageFilter.Execute()

pred_mask_obj = transformixImageFilter.GetResultImage()

In [None]:
pred_dice = compute_dice_coefficient(sitk2np(pred_mask_obj).astype(bool), input_mask_arrs[0].astype(bool))
print(f"Pred dice: {pred_dice:.3f}")

In [None]:
# Viz
index = 0
slice_range = lrange(80,85)

viz_axis(cimg, \
#         bin_mask_arr  = input_mask_arrs[index], color1 = "yellow", alpha1=0.3, \
#         bin_mask_arr2 = atlas_mask_arrs[index], color2 = "blue", alpha2=0.3,
        slices=slice_range, fixed_axis=0, \
        axis_fn = np.rot90, \
        grid = [1, 5], hspace=0.3, fig_mult=2)

In [None]:
im
plt.imshow(sitk.GetArrayViewFromImage(cimg))

In [None]:
simg1 = moving_obj
simg2 = fixed_obj
cimg = sitk.Compose(simg1, simg2, simg1 // 2. + simg2 // 2.)

In [None]:
simg1 = res_img
simg2 = fixed_obj
cimg = sitk.Compose(simg1, simg2, simg1 // 2. + simg2 // 2.)

In [None]:
#selx.SetLogToConsole(False)

def align(fixed_obj, moving_obj, param_folder = "ElastixParamFiles", param_files = ["affine.txt"]):
    
    # ALIGN ATLAS AND INPUT IMAGE
    
    # set moving and fixed images (resample moving=>fixed using T:fixed=>moving)
    elastixImageFilter = sitk.ElastixImageFilter()
    elastixImageFilter.SetFixedImage(fixed_obj)
    elastixImageFilter.SetMovingImage(moving_obj)
    
    # set parameter map
    parameterMapVector = sitk.VectorOfParameterMap()
    for param_file in param_files:
        parameterMapVector.append(sitk.ReadParameterFile(f"{param_folder}/{param_file}"))
    elastixImageFilter.SetParameterMap(parameterMapVector)

    # Execute alignment
    #elastixImageFilter.SetLogToConsole(False)
    elastixImageFilter.Execute()

    # MAP MOVING (ATLAS BINARY ROI) ONTO FIXED (INPUT) 

    # set moving image (atlas)                                                    
    transformixImageFilter = sitk.TransformixImageFilter()
    transformixImageFilter.SetMovingImage(moving_mask_obj)
                    
    # set parameter map (Binary mask => nearest neighbor final interpolation)
    transformedParameterMapVector = elastixImageFilter.GetTransformParameterMap()
    transformedParameterMapVector[-1]["ResampleInterpolator"] = ["FinalNearestNeighborInterpolator"]
    transformixImageFilter.SetTransformParameterMap(transformedParameterMapVector)

    # Execute transformation
    #transformixImageFilter.SetLogToConsole(False)
    transformixImageFilter.Execute()
    
    return transformixImageFilter.GetResultImage()

# Elastix Registration

## Rigid Alignment

- Rigid: "rigid body, which can translate and rotate, but cannot be
scaled/stretched."

- Similarity: "translate, rotate, and scale isotropically."

- Affine: "translated, rotated, scaled,
and sheared."

### ROI Mapping
Elastix convention: Resampling $moving \to fixed$ image involves a transformation $T: fixed \to moving$. $T$ maps coordinates in the fixed image domain to the corresponding coordinates in the moving image. Resampling a moving image onto the fixed image coordinate system involves:
1. Apply $T$ to fixed image voxel coordinates $x$ to get corresponding coordinates $y$ in the moving domain: $y = T(x) \in I_M$.
2. Estimate the voxel intensities $v$ at the moving image coordinates $y \in I_m$ via (linear) interpolation from nearby moving image voxel intensities.
3. Set the voxel intensities $v$ at the fixed image coordinates $x \in I_F$ to the above moving image voxel intensities.

Source: 5.0.1 Elastix Manual

In [None]:
#selx.SetLogToConsole(False)

def align_and_tfm(fixed_obj, moving_obj, moving_mask_obj, \
                  param_folder = "ElastixParamFiles", param_files = ["affine.txt", "bspline.txt"]):
    
    # ALIGN ATLAS AND INPUT IMAGE
    
    # set moving and fixed images (resample moving=>fixed using T:fixed=>moving)
    elastixImageFilter = sitk.ElastixImageFilter()
    elastixImageFilter.SetFixedImage(fixed_obj)
    elastixImageFilter.SetMovingImage(moving_obj)
    
    # set parameter map
    parameterMapVector = sitk.VectorOfParameterMap()
    for param_file in param_files:
        parameterMapVector.append(sitk.ReadParameterFile(f"{param_folder}/{param_file}"))
    elastixImageFilter.SetParameterMap(parameterMapVector)

    # Execute alignment
    #elastixImageFilter.SetLogToConsole(False)
    elastixImageFilter.Execute()

    # MAP MOVING (ATLAS BINARY ROI) ONTO FIXED (INPUT) 

    # set moving image (atlas)                                                    
    transformixImageFilter = sitk.TransformixImageFilter()
    transformixImageFilter.SetMovingImage(moving_mask_obj)
                    
    # set parameter map (Binary mask => nearest neighbor final interpolation)
    transformedParameterMapVector = elastixImageFilter.GetTransformParameterMap()
    transformedParameterMapVector[-1]["ResampleInterpolator"] = ["FinalNearestNeighborInterpolator"]
    transformixImageFilter.SetTransformParameterMap(transformedParameterMapVector)

    # Execute transformation
    #transformixImageFilter.SetLogToConsole(False)
    transformixImageFilter.Execute()
    
    return transformixImageFilter.GetResultImage()

# Rigid only

In [None]:
import time

def get_dice_scores(input_obj, input_mask_arr, atlas_objs = atlas_objs, atlas_mask_objs = atlas_mask_objs):
    dice_scores = {}
    n_votes = len(atlas_objs)
    print(f"N={n_votes} atlases in the vote.")
    
    for i in range(n_votes):
        print(f"Getting pred for atlas {i}:")
        
        start = time.time()
        pred_mask_arr = sitk2np(align_and_tfm(input_obj, atlas_objs[i], atlas_mask_objs[i], \
                                              param_files = ["affine.txt"])).astype(bool)
        end = time.time()
        print(f"{end - start:.0f} sec.")
        # compute dice
        dice_scores[f"dice{i}"] = compute_dice_coefficient(input_mask_arr, pred_mask_arr)
    
        # add atlas vote
        if i == 0: vote_pred_mask_arr = pred_mask_arr.astype(np.uint8)
        else: vote_pred_mask_arr += pred_mask_arr.astype(np.uint8)
    
    # get avg dice score
    dice_scores[f"dice_avg"] = np.mean(list(dice_scores.values()))
    
    # get vote
    print(f"Getting vote dice")
    vote_pred_mask_arr = (vote_pred_mask_arr >= n_votes/2).astype(bool)
    dice_scores[f"dice_vote_all"] = compute_dice_coefficient(input_mask_arr, vote_pred_mask_arr)
    
    print(dice_scores)
    return dice_scores

In [None]:
dice_scores0 = get_dice_scores(input_objs[0], input_mask_arrs[0], atlas_objs[:2], atlas_mask_objs[:2])

In [None]:
dice_scores0

In [None]:
dice_scores0 = get_dice_scores(input_objs[0], input_mask_arrs[0], atlas_objs, atlas_mask_objs)

In [None]:
dice_scores0

In [None]:
input_df = DataFrame(get_dice_scores(input_obj, gt_mask, atlas_objs, atlas_mask_objs) \
                     for input_obj, gt_mask in zip(input_objs, input_mask_arrs))

In [None]:
input_df

In [None]:
input_info_df

In [None]:
atlas_info_df

In [None]:
# Viz
index = 5
slice_range = lrange(110,130)

viz_axis(sitk2np(atlas_objs[index]), \
        bin_mask_arr = atlas_mask_arrs[index], color1 = "yellow", alpha1=0.3,
        slices=slice_range, fixed_axis=2, \
        axis_fn = np.rot90, \
        grid = [4,5], hspace=0.3, fig_mult=2)

In [None]:
# Viz
index = 10
slice_range = lrange(80,100)

viz_axis(sitk2np(input_objs[index]), \
        bin_mask_arr = input_mask_arrs[index], color1 = "yellow", alpha1=0.3,
        slices=slice_range, fixed_axis=0, \
        axis_fn = np.rot90, \
        grid = [4, 5], hspace=0.3, fig_mult=2)

In [None]:
# Viz
index = 12
slice_range = lrange(80,90)

viz_axis(sitk2np(input_objs[index]), \
        bin_mask_arr = input_mask_arrs[index], color1 = "yellow", alpha1=0.3,
        slices=slice_range, fixed_axis=0, \
        axis_fn = np.rot90, \
        grid = [2, 5], hspace=0.3, fig_mult=2)

In [None]:
# Align the atlas and the input MR. Resample atlas ROI onto input ROI (fixed: input, moving: atlas). 
pred_mask_objs = [align_and_tfm(input_obj, atlas_obj, atlas_mask_obj, \
                              param_folder = "ElastixParamFiles", param_files = ["affine.txt"]) for
                 atlas_obj, atlas_mask_obj in zip(atlas_objs, atlas_mask_objs)]

In [None]:
# Evaluate predicted input ROI
gt_mask_arr = input_mask_arr

#gt_mask_arr   = sitk2np(gt_mask_obj).astype(bool)
pred_mask_arr = sitk2np(pred_mask_obj).astype(bool)

In [None]:
dice     = compute_dice_coefficient(gt_mask_arr, pred_mask_arr)
coverage = compute_coverage_coefficient(gt_mask_arr, pred_mask_arr)
bbox_coords = mask2bbox(pred_mask_arr)

In [None]:
print({"dice": f"{dice:0.2f}", "coverage": f"{coverage:0.2f}"})
print_bbox(*bbox_coords)

In [None]:
print("Affine: ", bbox_coords)
print("GT: ", input_bbox_coords)

In [None]:
# Viz affine
slice_range = lrange(77, 82) + lrange(107,112)

viz_axis(sitk2np(input_obj), cmap0="gray",
        crop_coords = bbox_coords, crop_extra=35,
        bin_mask_arr = input_mask_arr, color1 = "yellow", alpha1=0.3,
        bin_mask_arr2 = pred_mask_arr, color2 = "red", alpha2=0.3,
        slices=slice_range, fixed_axis=0, \
        axis_fn = np.rot90, \
        grid = [2, 4], hspace=0.3, fig_mult=2)

In [None]:
slice_range = lrange(110, 150)
viz_axis(sitk2np(input_obj), cmap0="gray",
        crop_coords = bbox_coords, crop_extra=20,
        bin_mask_arr = input_mask_arr, color1 = "yellow", alpha1=0.3,
        bin_mask_arr2 = pred_mask_arr, color2 = "red", alpha2=0.3,
        slices=slice_range, fixed_axis=1, \
        axis_fn = np.rot90, \
        grid = [5, 8], hspace=0.3, fig_mult=2)

In [None]:
# Align the atlas and the input MR. Resample atlas ROI onto input ROI (fixed: input, moving: atlas). 
bspline_pred_mask_obj = align_and_tfm(input_obj, atlas_obj, atlas_mask_obj, \
                              param_folder = "ElastixParamFiles", param_files = ["affine.txt", "bspline.txt"])

In [None]:
# Evaluate predicted input ROI
bspline_pred_mask_arr = sitk2np(bspline_pred_mask_obj).astype(bool)

bspline_dice     = compute_dice_coefficient(gt_mask_arr, bspline_pred_mask_arr)
bspline_coverage = compute_coverage_coefficient(gt_mask_arr, bspline_pred_mask_arr)
bspline_bbox_coords = mask2bbox(bspline_pred_mask_arr)

print({"dice": f"{bspline_dice:0.2f}", "coverage": f"{bspline_coverage:0.2f}"})
print_bbox(*bspline_bbox_coords)

In [None]:
print("Affine: ", bbox_coords)
print("+Bspline: ", bspline_bbox_coords)
print("GT: ", input_bbox_coords)

In [None]:
print("Ground Truth GT")
print_bbox(*input_bbox_coords)

print("Affine only")
print_bbox(*bbox_coords)

print("+Bspline")
print_bbox(*bspline_bbox_coords)

In [None]:
# Viz affine + bspline
slice_range = lrange(77, 82) + lrange(107,112)

viz_axis(sitk2np(input_obj), cmap0="gray",
        crop_coords = bspline_bbox_coords, crop_extra=35,
        bin_mask_arr = input_mask_arr, color1 = "yellow", alpha1=0.3,
        bin_mask_arr2 = bspline_pred_mask_arr, color2 = "red", alpha2=0.3,
        slices=slice_range, fixed_axis=0, \
        axis_fn = np.rot90, \
        grid = [2, 5], hspace=0.3, fig_mult=2)

In [None]:
slice_range = lrange(110, 150)
viz_axis(sitk2np(input_obj), cmap0="gray",
        crop_coords = bbox_coords, crop_extra=20,
        bin_mask_arr = input_mask_arr, color1 = "yellow", alpha1=0.3,
        bin_mask_arr2 = bspline_pred_mask_arr, color2 = "red", alpha2=0.3,
        slices=slice_range, fixed_axis=1, \
        axis_fn = np.rot90, \
        grid = [5, 8], hspace=0.3, fig_mult=2)

In [None]:
slice_range = [77,78,79, 80] + [109,110,111,112]

viz_axis(sitk2np(input_obj), cmap0="gray",
        crop_coords = bspline_bbox_coords, crop_extra=55,
        bin_mask_arr = input_mask_arr, color1 = "yellow", alpha1=0.3,
        bin_mask_arr2 = bspline_pred_mask_arr, color2 = "red", alpha2=0.3,
        slices=slice_range, fixed_axis=0, \
        axis_fn = np.rot90, \
        grid = [2, 4], hspace=0.3, fig_mult=2)

# Affine non-whole brain

Focus alignment on generated input ROI

In [None]:
# pad slices by 5, crop extra bbox by 10vox x 10vox

In [None]:
def crop_extra_mask(bin_mask_arr, crop_coords, slice_pad, vox_pad, fixed_axis=06):
    imin, imax, jmin, jmax, kmin, kmax = crop_coords
    if fixed_axis == 0:   
        jmin -= pad; jmax += pad; kmin -= pad; kmax += pad;
        np_arr = np_arr[:, jmin:jmax, kmin:kmax]
        if bin_mask_arr is not None:  bin_mask_arr  = bin_mask_arr[:, jmin:jmax, kmin:kmax]
        if bin_mask_arr2 is not None: bin_mask_arr2 = bin_mask_arr2[:, jmin:jmax, kmin:kmax]

    elif fixed_axis == 1: 
        imin -= pad; imax += pad; kmin -= pad; kmax += pad;
        np_arr = np_arr[imin:imax, :, kmin:kmax]
        if bin_mask_arr is not None:  bin_mask_arr  = bin_mask_arr[imin:imax, :, kmin:kmax]
        if bin_mask_arr2 is not None: bin_mask_arr2 = bin_mask_arr2[imin:imax, :, kmin:kmax]

    else:
        imin -= pad; imax += pad; jmin -= pad; jmax += pad;
        np_arr = np_arr[imin:imax, jmin:jmax, :]
        if bin_mask_arr is not None:  bin_mask_arr  = bin_mask_arr[imin:imax, jmin:jmax, :]
        if bin_mask_arr2 is not None: bin_mask_arr2 = bin_mask_arr2[imin:imax, jmin:jmax, :]

In [None]:
pred_mask_obj

In [None]:
# Align the atlas and the input MR. Resample atlas ROI onto input ROI (fixed: input, moving: atlas). 
pred_mask_obj = align_focused(input_obj, pred_roi, atlas_obj, atlas_mask_obj, \
                              param_folder = "ElastixParamFiles", param_files = ["affine.txt"])

In [None]:
def align_focused(fixed_obj, fixed_mask_obj, moving_obj, moving_mask_obj, \
                  param_folder = "ElastixParamFiles", param_files = ["affine.txt", "bspline.txt"]):
    
    # ALIGN ATLAS AND INPUT IMAGE
    
    # set moving and fixed images (resample moving=>fixed using T:fixed=>moving)
    elastixImageFilter = sitk.ElastixImageFilter()
    elastixImageFilter.SetFixedImage(fixed_obj)
    elastixImageFilter.SetMovingImage(moving_obj)
    
    # set masks
    elastixImageFilter.SetFixedMask(fixed_mask_obj)
    elastixImageFilter.SetMovingMask(moving_mask_obj)
    
    
    # set parameter map
    parameterMapVector = sitk.VectorOfParameterMap()
    for param_file in param_files:
        parameterMapVector.append(sitk.ReadParameterFile(f"{param_folder}/{param_file}"))
    elastixImageFilter.SetParameterMap(parameterMapVector)

    # Execute alignment
    elastixImageFilter.Execute()

    # MAP MOVING (ATLAS BINARY ROI) ONTO FIXED (INPUT) 

    # set moving image (atlas)                                                    
    transformixImageFilter = sitk.TransformixImageFilter()
    transformixImageFilter.SetMovingImage(moving_mask_obj)
                    
    # set parameter map (Binary mask => nearest neighbor final interpolation)
    transformedParameterMapVector = elastixImageFilter.GetTransformParameterMap()
    transformedParameterMapVector[-1]["ResampleInterpolator"] = ["FinalNearestNeighborInterpolator"]
    transformixImageFilter.SetTransformParameterMap(transformedParameterMapVector)

    # Execute transformation
    transformixImageFilter.Execute()
    
    return transformixImageFilter.GetResultImage()

### Viz Mapped ROI

In [None]:
print(transformed_input_mask_arr.shape)

transformed_input_bbox_coords = mask2bbox(transformed_input_mask_arr)
print_bbox(*transformed_input_bbox_coords)

Compare bounding boxes

In [None]:
print(f"Original Moving Mask"), print_bbox(*input_bbox_coords);
print(f"ROI contains {np.count_nonzero(input_mask_arr)} elements.", "\n");

print(f"Target Fixed Mask"), print_bbox(*atlas_bbox_coords);
print(f"ROI contains {np.count_nonzero(atlas_mask_arr)} elements.", "\n");

print(f"Transformed Moving Mask"), print_bbox(*transformed_input_bbox_coords);
print(f"ROI contains {np.count_nonzero(transformed_input_mask_arr)} elements.", "\n");

In [None]:
viz_axis(input_mask_arr, bin_mask_arr=transformed_input_mask_arr, 
        slices=lrange(76, 80) + lrange(110,112), atlas_axis=0, \
        axis_fn = np.rot90, \
        grid = [2, 3], hspace=0.3, fig_mult=2, cmap0="gray")

### Coronal

In [None]:
sitk2np(iso_atlas_obj).shape

In [None]:
viz_axis(sitk2np(atlas_obj), \
        bin_mask_arr=transformed_input_mask_arr, color1 = "yellow", \
        bin_mask_arr2=atlas_mask_arr, color2 = "red", \
        slices=lrange(120, 126), atlas_axis=1, \
        axis_fn = np.rot90, \
        grid = [2, 3], hspace=0.3, fig_mult=2, cmap0="gray")

## Sagittal

In [None]:
#cmap1 = [white, yellow]; cmap2 = [white, blue]
viz_axis(sitk2np(atlas_obj), \
        bin_mask_arr=transformed_input_mask_arr, color1 = "yellow", \
        bin_mask_arr2=atlas_mask_arr, color2 = "red", \
        slices=lrange(76, 80) + lrange(104,112), atlas_axis=0, \
        axis_fn = np.rot90, \
        grid = [3, 4], hspace=0.3, fig_mult=2, cmap0="gray")

# Old

In [None]:
def affine_align(fixed_obj, fixed_mask_obj, moving_obj, moving_mask_obj, param_file = "AffineParamFile.txt"):
    
    # map moving => fixed (the transform is fixed => moving)
    #parameterMap = sitk.GetDefaultParameterMap('affine')
    parameterMap  = sitk.ReadParameterFile(param_file)
    
    elastixImageFilter = sitk.ElastixImageFilter()
    elastixImageFilter.SetFixedImage(fixed_obj)
    elastixImageFilter.SetMovingImage(moving_obj)
    
    # focus on registering moving mask ROI
    #elastixImageFilter.SetMovingMask(moving_mask_obj)
    #parameterMap["ImageSampler"] = ["RandomSparseMask"]
   
    # print param map
    sitk.PrintParameterMap(parameterMap)
    
    elastixImageFilter.SetParameterMap(parameterMap)
    elastixImageFilter.Execute()

    transformed_moving_obj  = elastixImageFilter.GetResultImage()
    transformedParameterMap = elastixImageFilter.GetTransformParameterMap()[0]
    
    # Binary mask => nearest neighbor
    transformedParameterMap["ResampleInterpolator"] = ["FinalNearestNeighborInterpolator"]
    
    # map ROI of moving => fixed
    transformixImageFilter = sitk.TransformixImageFilter()
    transformixImageFilter.SetTransformParameterMap(transformedParameterMap)
    transformixImageFilter.SetMovingImage(moving_mask_obj)
    transformixImageFilter.Execute()
    
    transformed_moving_mask_obj = transformixImageFilter.GetResultImage()

    # evaluate: dice, coverage
    fixed_mask_arr              = sitk2np(fixed_mask_obj).astype(bool)
    transformed_moving_mask_arr = sitk2np(transformed_moving_mask_obj).astype(bool)

    dice     = compute_dice_coefficient(fixed_mask_arr, transformed_moving_mask_arr)
    coverage = compute_coverage_coefficient(fixed_mask_arr, transformed_moving_mask_arr)
    
    # save bounding box coords
    bbox_coords = mask2bbox(transformed_moving_mask_arr)
        
    return dice, coverage, bbox_coords, transformed_moving_obj, transformed_moving_mask_arr

In [None]:
dice, coverage, bbox_coords, transformed_input_obj, transformed_input_mask_arr = affine_align(atlas_obj, atlas_mask_obj, input_obj, input_mask_obj)

In [None]:
print(f"Dice: {dice}. Coverage {coverage}.")
print_bbox(*bbox_coords)