Import all required libraries

In [None]:
import nibabel as nib
import meshio
import pymeshlab as ml
import plotly as ply
import numpy as np

import os
from os.path import join

from skimage.feature import match_descriptors, plot_matches, SIFT

import SimpleITK as sitk

Define plotting utilities: slice viewers for the two volumes and callbacks for monitoring the registration metric

In [11]:
%matplotlib inline
import matplotlib.pyplot as plt

from ipywidgets import interact, fixed
from IPython.display import clear_output

# Callback invoked by the interact IPython method for scrolling through the image stacks of
# the two images (moving and fixed).
def display_images(fixed_image_z, moving_image_z, fixed_npa, moving_npa):
    """
    Show one slice of the fixed and moving volumes side by side.

    Parameters:
    fixed_image_z: index into the first axis of ``fixed_npa`` to display
    moving_image_z: index into the first axis of ``moving_npa`` to display
    fixed_npa, moving_npa: arrays indexed as ``[slice, row, col]``
        (here, views from ``sitk.GetArrayViewFromImage``)
    """
    # Use the Axes returned by plt.subplots directly instead of discarding them
    # and re-selecting each panel through the pyplot state machine.
    _, (fixed_ax, moving_ax) = plt.subplots(1, 2, figsize=(10, 8))

    panels = (
        (fixed_ax, fixed_npa[fixed_image_z, :, :], 'fixed image'),
        (moving_ax, moving_npa[moving_image_z, :, :], 'moving image'),
    )
    for ax, image_slice, title in panels:
        ax.imshow(image_slice, cmap=plt.cm.Greys_r)
        ax.set_title(title)
        ax.axis('off')

    plt.show()

# Callback invoked by the IPython interact method for scrolling and modifying the alpha blending
# of an image stack of two images that occupy the same physical space.
def display_images_with_alpha(image_z, alpha, fixed, moving):
    """
    Display slice ``image_z`` of the linear blend ``(1 - alpha) * fixed + alpha * moving``.

    ``fixed`` and ``moving`` are SimpleITK images sliced with ``[:, :, image_z]``;
    presumably they must share the same grid (TODO confirm callers resample first).
    """
    fixed_weight = 1.0 - alpha
    fixed_slice = fixed[:, :, image_z]
    moving_slice = moving[:, :, image_z]
    blended = fixed_weight * fixed_slice + alpha * moving_slice
    plt.imshow(sitk.GetArrayViewFromImage(blended), cmap=plt.cm.Greys_r)
    plt.axis('off')
    plt.show()
    
# Callback invoked when the StartEvent happens, sets up our new data.
def start_plot():
    """Reset the module-level lists that accumulate metric values and level boundaries."""
    global metric_values, multires_iterations
    metric_values, multires_iterations = [], []

# Callback invoked when the EndEvent happens, do cleanup of data and figure.
def end_plot():
    """
    Discard the accumulated metric data and close the current figure.

    Raises NameError if the globals were never created (start_plot not called)
    or were already deleted by a previous call.
    """
    global metric_values, multires_iterations
    del metric_values, multires_iterations
    # Close figure, we don't want to get a duplicate of the plot later on.
    plt.close()

# Callback invoked when the IterationEvent happens, update our data and display new figure.
def plot_values(registration_method):
    """
    Record the current metric value and redraw the metric-vs-iteration plot.

    registration_method: object exposing GetMetricValue() (presumably a
        sitk.ImageRegistrationMethod -- confirm at the call site)
    """
    global metric_values, multires_iterations

    metric_values.append(registration_method.GetMetricValue())

    # wait=True defers clearing until new output arrives, which reduces flickering.
    clear_output(wait=True)
    plt.plot(metric_values, 'r')
    # Blue stars mark the metric value at each recorded resolution-level start.
    level_start_values = [metric_values[idx] for idx in multires_iterations]
    plt.plot(multires_iterations, level_start_values, 'b*')
    plt.xlabel('Iteration Number', fontsize=12)
    plt.ylabel('Metric Value', fontsize=12)
    plt.show()
    
# Callback invoked when the sitkMultiResolutionIterationEvent happens, update the index into the
# metric_values list.
def update_multires_iterations():
    """Record the index in ``metric_values`` at which the next resolution level's values begin."""
    # metric_values is only read here, so only the appended-to list is declared global.
    global multires_iterations
    multires_iterations.append(len(metric_values))

Define a function that extracts the high-contrast (brightest) regions of a NIfTI image

In [8]:
def segment_high_contrast(nifti_file_path: str, threshold: float = 0.95) -> np.ndarray:
    """
    Keep only the brightest voxels of a NIfTI image.

    Parameters:
    nifti_file_path: file path of the NIfTI image to load and segment
    threshold: fraction of the image's maximum intensity (0.95 = 95%); voxels
        strictly greater than ``threshold * max`` are kept, all others become 0

    Returns:
    numpy array (not a NIfTI image) with the same shape and axis order as the
    nibabel voxel data, holding the original intensity of each kept voxel and
    0 everywhere else
    """
    img_voxels = nib.load(nifti_file_path).get_fdata()
    cutoff = threshold * np.max(img_voxels)
    # Single vectorised select instead of zeros + index lookup + assignment.
    return np.where(img_voxels > cutoff, img_voxels, 0.0)

Define the file paths for the reference and target mouse images

The target mouse is the fixed image, and the reference mouse is the moving image.

In [9]:
# All project directories are resolved relative to the notebook's working directory.
WORKDIR = os.getcwd()
ATLAS_DIR, DATA_DIR, ASSETS_DIR = (join(WORKDIR, sub) for sub in ("atlas", "data", "assets"))
REFERENCE_DIR = join(ATLAS_DIR, "reference")

# Reference (atlas) mouse -- used as the moving image.
REF_NII = join(REFERENCE_DIR, "mouse_nii", "scaled_mouse.nii")
# Target mouse -- used as the fixed image.
TARG_NII = join(ASSETS_DIR, "images", "sample", "CT_TS_HEUHR_In111_free_M1039_0h_220721-selfcal.nii")

Read in the image files, keep only their high-contrast voxels, and convert them to SimpleITK images

In [None]:
def _nifti_to_sitk(nifti_file_path):
    """Threshold a NIfTI volume and wrap it as a SimpleITK image with matching axes and spacing."""
    voxels = segment_high_contrast(nifti_file_path)
    # nibabel voxel data is indexed (i, j, k), but GetImageFromArray reads the
    # first array axis as z. Transpose so SimpleITK's x/y/z follow NIfTI's i/j/k
    # instead of being silently reversed.
    image = sitk.GetImageFromArray(np.ascontiguousarray(voxels.T))
    # Keep the physical voxel size from the header; otherwise every image gets
    # spacing 1.0 and registration between differently sampled volumes is skewed.
    # NOTE(review): origin and direction from the affine are still not copied.
    zooms = nib.load(nifti_file_path).header.get_zooms()[:3]
    image.SetSpacing([float(z) for z in zooms])
    return image

ref_mouse = _nifti_to_sitk(REF_NII)
targ_mouse = _nifti_to_sitk(TARG_NII)

In [14]:
# Sliders span every slice index along SimpleITK's third axis (GetSize()[2]);
# fixed() passes the arrays through to the callback without creating widgets for them.
interact(
    display_images,
    fixed_image_z=(0, targ_mouse.GetSize()[2] - 1),
    moving_image_z=(0, ref_mouse.GetSize()[2] - 1),
    fixed_npa=fixed(sitk.GetArrayViewFromImage(targ_mouse)),
    moving_npa=fixed(sitk.GetArrayViewFromImage(ref_mouse)),
)