# Register MRI data

In [1]:
import os
import numpy as np
import nibabel as nib
import glob
from pprint import pprint
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import SimpleITK as sitk
from scipy.ndimage import binary_fill_holes
import ants

## Functions

In [2]:
import time 

def n4_bias_correction(img, bg_mask, shrink_factor: float=15, show: bool=False) -> np.ndarray:
    """
    Apply N4 bias field correction to an image.

    The bias field is estimated on an (optionally downsampled) copy of the
    image, restricted by a mask, and the resulting field is then applied to
    the full resolution image.

    Parameters:
    - img: The input image to correct (NumPy array). It is copied, not modified.
    - bg_mask: Mask array for `img`. It is cast to uint8 and passed through
      `sitk.LiThreshold` before being handed to the N4 filter.
    - shrink_factor: The shrink factor for downsampling the image for bias
      correction. It is truncated to an integer because `sitk.Shrink` takes
      integer factors; a value that truncates to 1 or less disables downsampling.
    - show: Whether to print the run time and plot the intermediate results.

    Returns:
    - corrected_image_full_resolution: The bias corrected image as a NumPy
      array at the resolution of `img`.
    """
    # Build the mask image that restricts where the bias field is estimated
    bg_mask = bg_mask.astype(np.uint8)
    mask_img = sitk.GetImageFromArray(bg_mask)
    mask_img = sitk.LiThreshold(mask_img, 0, 1)

    # Use the raw image and convert it to float32
    raw_img = sitk.GetImageFromArray(img.copy())
    raw_img = sitk.Cast(raw_img, sitk.sitkFloat32)

    # Downsample image and mask by the same integer factor for bias estimation
    factor = int(shrink_factor)
    if factor > 1:
        shrink_factors = [factor] * raw_img.GetDimension()
        inputImage = sitk.Shrink(raw_img, shrink_factors)
        maskImage = sitk.Shrink(mask_img, shrink_factors)
    else:
        inputImage = raw_img
        maskImage = mask_img

    # Run bias correction. The corrected low resolution image returned by
    # Execute is not needed; the call is made to estimate the bias field.
    start_time = time.time()
    bias_corrector = sitk.N4BiasFieldCorrectionImageFilter()
    bias_corrector.Execute(inputImage, maskImage)

    # Apply the estimated bias field to the full resolution image
    log_bias_field = bias_corrector.GetLogBiasFieldAsImage(raw_img)
    corrected_image_full_resolution = raw_img / sitk.Exp(log_bias_field)
    end_time = time.time()
    corrected_image_full_resolution = sitk.GetArrayFromImage(corrected_image_full_resolution)

    # Show the process if True
    if show:
        print(f"Time taken for bias correction: {end_time - start_time:.2f} seconds")

        # Show the mask at full resolution and as used for the estimation
        plt.figure(figsize=(15, 5))
        plt.subplot(1, 2, 1)
        plt.imshow(sitk.GetArrayFromImage(mask_img), cmap='gray')
        plt.title("Full resolution brain mask")
        plt.subplot(1, 2, 2)
        plt.imshow(sitk.GetArrayFromImage(maskImage), cmap='gray')
        plt.title(f"Downsampled brain mask (shrink factor={shrink_factor})")
        plt.show()

        # Show the log bias field
        plt.figure(figsize=(10, 5))
        plt.imshow(sitk.GetArrayFromImage(log_bias_field))
        plt.colorbar()
        plt.title("Log bias field")
        plt.show()

        # Show the original next to the bias corrected image
        plt.figure(figsize=(15, 5))
        plt.subplot(1, 2, 1)
        plt.imshow(img, cmap='gray')
        plt.title("Original raw image")
        plt.subplot(1, 2, 2)
        plt.imshow(corrected_image_full_resolution, cmap='gray')
        plt.title("Corrected bias raw image")
        plt.show()

        # NOTE(review): preview_alpha only appears in the titles below; no
        # contrast adjustment is applied to the images that are displayed.
        preview_alpha = 0.25
        plt.figure(figsize=(15, 5))
        plt.subplot(1, 2, 1)
        plt.imshow(img, cmap='gray')
        plt.title(f"Original contrast image (alpha={preview_alpha})")
        plt.subplot(1, 2, 2)
        plt.imshow(corrected_image_full_resolution, cmap='gray')
        plt.title(f"Corrected bias contrast image (alpha={preview_alpha})")
        plt.show()

    return corrected_image_full_resolution


def show_slice(data: np.ndarray, idx: int, title: str = "") -> None:
    """
    Display a single slice of a volume in grayscale.

    Parameters:
    - data: Array indexed as data[idx, :, :].
    - idx: Index along the first axis of the slice to display.
    - title: Title placed above the image.
    """
    fig = plt.figure(figsize=(10, 5))
    slice_2d = data[idx, :, :]
    plt.imshow(slice_2d, cmap='gray')
    plt.title(title)
    plt.show()
    # Clear and close the figure once it has been shown
    fig.clear()
    plt.close(fig)


def show_2_slices(data1: np.ndarray, idx1: int, segmentation1: np.ndarray, 
                  data2: np.ndarray, idx2: int, segmentation2: np.ndarray = None,
                  cmap1="gray", cmap2="gray",
                  title1: str = "", title2: str = "") -> None:
    """
    Show two slices side by side, each with an optional segmentation contour.

    Parameters:
    - data1, data2: Volumes indexed as data[idx, :, :] for the left and right panel.
    - idx1, idx2: Slice index along the first axis for each panel.
    - segmentation1, segmentation2: Segmentation volumes whose 0.5 level is
      drawn as a red contour on the matching panel; pass None to skip the contour.
    - cmap1, cmap2: Colormap used for each panel.
    - title1, title2: Title of each panel.
    """
    fig, axes = plt.subplots(1, 2, figsize=(20, 10))
    panels = (
        (data1, idx1, segmentation1, cmap1, title1),
        (data2, idx2, segmentation2, cmap2, title2),
    )
    # Left panel first, then right panel, drawn with the same steps
    for axis, (volume, index, segmentation, cmap, title) in zip(axes, panels):
        axis.imshow(volume[index, :, :], cmap=cmap)
        if segmentation is not None:
            axis.contour(segmentation[index, :, :], levels=[0.5], colors='r')
        axis.set_title(title)
    plt.show()
    fig.clear()
    plt.close(fig)
    
    
def show_3_slices(data1: np.ndarray, idx1: int, segmentation1: np.ndarray, 
                  data2: np.ndarray, idx2: int, 
                  data3: np.ndarray, idx3: int,
                  title1: str = "", title2: str = "", title3: str = "") -> None:
    """
    Show three slices side by side.

    Parameters:
    - data1, idx1: Volume and slice index for the left panel (grayscale).
    - segmentation1: Segmentation volume whose 0.5 level at slice idx1 is
      drawn as a red contour on the left panel.
    - data2, idx2: Volume and slice index for the middle panel (grayscale).
    - data3, idx3: Volume and slice index for the right panel, shown with
      matplotlib's default colormap.
    - title1, title2, title3: Titles of the three panels.
    """
    fig, (ax_overlay, ax_plain, ax_color) = plt.subplots(1, 3, figsize=(20, 10))

    # Left: grayscale slice with the segmentation outline
    ax_overlay.imshow(data1[idx1, :, :], cmap='gray')
    ax_overlay.contour(segmentation1[idx1, :, :], levels=[0.5], colors='r')
    ax_overlay.set_title(title1)

    # Middle: grayscale slice without overlay
    ax_plain.imshow(data2[idx2, :, :], cmap='gray')
    ax_plain.set_title(title2)

    # Right: slice rendered with the default colormap
    ax_color.imshow(data3[idx3, :, :])
    ax_color.set_title(title3)

    plt.show()
    fig.clear()
    plt.close(fig)
    
    
def export_gif(volume: np.ndarray, output_path: str, fps: int = 15, cmap='gray') -> None:
    """
    Save the frames of a volume as an animated GIF.

    Parameters:
    - volume: Sequence of 2D frames; volume[i] is drawn as frame i.
    - output_path: Path of the GIF file to write.
    - fps: Frames per second passed to the Pillow writer.
    - cmap: Colormap used to draw each frame.
    """
    fig, _ = plt.subplots()

    def draw_frame(frame_idx):
        # Start from an empty figure, then draw and title the requested frame
        plt.clf()
        plt.imshow(volume[frame_idx], cmap=cmap)
        plt.title(f"Frame {frame_idx+1}")

    gif_writer = animation.PillowWriter(fps=fps)
    movie = animation.FuncAnimation(
        fig, draw_frame, frames=len(volume), interval=50, repeat_delay=1000
    )
    movie.save(output_path, writer=gif_writer)
    plt.close(fig)
    
    
    
from PIL import Image
import numpy as np

def resize_and_pad_grayscale_np(image_array, scale, target_size):
    """
    Scale a 2D grayscale array and centre it on a zero-filled canvas.

    Args:
        image_array (np.ndarray): Grayscale image as a 2D NumPy array (H, W).
        scale (float): Factor applied to both width and height (bilinear resampling).
        target_size (tuple): (height, width) of the returned image.

    Returns:
        np.ndarray: The scaled image pasted at the centre of a mode "F"
        canvas of size `target_size`, converted back to a NumPy array.
    """
    source = Image.fromarray(image_array)

    # PIL sizes are given as (width, height)
    scaled_width = int(source.width * scale)
    scaled_height = int(source.height * scale)
    scaled = source.resize((scaled_width, scaled_height), Image.BILINEAR)

    # Zero-filled float canvas of the requested output size
    target_height, target_width = target_size
    canvas = Image.new("F", (target_width, target_height), 0)

    # Top-left corner that centres the scaled image on the canvas
    offset = (
        (target_width - scaled_width) // 2,
        (target_height - scaled_height) // 2,
    )
    canvas.paste(scaled, offset)

    return np.array(canvas)


from matplotlib.colors import ListedColormap
from scipy.ndimage import center_of_mass, label


def show_labels(idx: int, image: np.ndarray, labels: np.ndarray, alpha: float = 0.5, title: str = ""):
    """
    Show one slice of an image with its label map overlaid.

    Every non-zero label is outlined in red, its value is written at the
    centre of mass of each of its connected components, and the label map is
    blended on top with one colour per label (label 0 is black).

    Parameters:
    - idx: Index along the first axis of the slice to show.
    - image: Volume indexed as image[idx, :, :], drawn in grayscale.
    - labels: Label volume indexed as labels[idx, :, :].
    - alpha: Opacity of the contours and of the colour overlay.
    - title: Title prefix; " (slice <idx>)" is appended.
    """
    curr_img = image[idx, :, :]
    curr_label = labels[idx, :, :]

    # One colour per label present in the slice. A private RandomState keeps
    # the colours reproducible without reseeding the global NumPy generator.
    unique_values = np.unique(curr_label)
    num_labels = len(unique_values)
    rng = np.random.RandomState(142)
    colors = rng.rand(num_labels, 3)
    colors[unique_values == 0] = 0  # Background (label 0) is always black
    custom_cmap = ListedColormap(colors)

    # Replace each label value by its position in unique_values so that the
    # colormap is indexed per label instead of by the raw label magnitude.
    label_indices = np.searchsorted(unique_values, curr_label)

    f = plt.figure(figsize=(10, 10))
    plt.imshow(curr_img, cmap='gray')
    plt.title(title + f" (slice {idx})")

    # Outline and annotate every non-background label
    for value in unique_values:
        if value == 0:  # Skip background
            continue
        mask = curr_label == value
        plt.contour(mask, levels=[0.5], colors='r', linewidths=1, alpha=alpha)

        # A label may consist of several disconnected regions; annotate each
        labeled_components, num_components = label(mask)
        for component_idx in range(1, num_components + 1):
            component_mask = labeled_components == component_idx
            center = center_of_mass(component_mask)
            # center is (row, col) while plt.text expects (x, y)
            plt.text(center[1], center[0], str(int(value)), color='white', fontsize=6, ha='center', va='center')

    # vmin/vmax place index k in the middle of colour bin k
    plt.imshow(label_indices, cmap=custom_cmap, vmin=-0.5, vmax=num_labels - 0.5, alpha=alpha)
    plt.show()
    f.clear()
    plt.close(f)
    
    
def export_gif_with_labels(volume: np.ndarray, annotation: np.ndarray, output_path: str, alpha: float = 0.5, fps: int = 15, fontsize=6, cmap='gray') -> None:
    """
    Save a volume as an animated GIF with label outlines drawn on every frame.

    Parameters:
    - volume: Volume indexed as volume[i, :, :]; one GIF frame per index.
    - annotation: Label volume indexed as annotation[i, :, :] for the same i.
    - output_path: Path of the GIF file to write.
    - alpha: Opacity of the red label contours.
    - fps: Frames per second passed to the Pillow writer.
    - fontsize: Font size of the label numbers; 0 disables the numbers.
    - cmap: Colormap used to draw each frame of `volume`.
    """

    def update(i):
        curr_img = volume[i, :, :]
        curr_label = annotation[i, :, :]

        # Clear the previous frame and draw the current image
        plt.clf()
        plt.imshow(curr_img, cmap=cmap)

        # Outline every non-background label present in this frame
        for value in np.unique(curr_label):
            if value == 0:  # Skip background
                continue
            mask = curr_label == value
            plt.contour(mask, levels=[0.5], colors='r', linewidths=1, alpha=alpha)

            if fontsize != 0:
                # A label may consist of several disconnected regions; write
                # its value at the centre of mass of each of them
                labeled_components, num_components = label(mask)
                for component_idx in range(1, num_components + 1):
                    component_mask = labeled_components == component_idx
                    center = center_of_mass(component_mask)
                    # center is (row, col) while plt.text expects (x, y)
                    plt.text(center[1], center[0], str(int(value)), color='white', fontsize=fontsize, ha='center', va='center')

        plt.title(f"Frame {i+1}")

    fig, _ = plt.subplots()
    ani = animation.FuncAnimation(fig, update, frames=len(volume), interval=50, repeat_delay=1000)
    writer = animation.PillowWriter(fps=fps)
    ani.save(output_path, writer=writer)
    plt.close(fig)
    
    
def add_masked_suffix(filename):
    """
    Insert "_MASKED" in front of a file name's extension.

    The double extension ".nii.gz" is treated as a single extension; for any
    other name the extension is whatever os.path.splitext reports.
    """
    double_ext = '.nii.gz'
    if not filename.endswith(double_ext):
        stem, extension = os.path.splitext(filename)
        return f"{stem}_MASKED{extension}"
    # Strip ".nii.gz" as one unit so "_MASKED" lands before both suffixes
    stem = filename[:-len(double_ext)]
    return f"{stem}_MASKED{double_ext}"

## MRI AAVretro

In [7]:
import os

# Demo: build "<containing folder>/<file name without .nii.gz>" from a path
file_path = "/path/to/parent/child/file.nii.gz"
parts = os.path.normpath(file_path).split(os.sep)
child = parts[-2]
# splitext removes one suffix per call, so ".nii.gz" needs two passes
file_base = os.path.splitext(parts[-1])[0]
file_base = os.path.splitext(file_base)[0]
child_and_file_base = os.path.join(child, file_base)
print(child_and_file_base)  # child/file

child/file


In [3]:
# Glob pattern matching every .nii.gz file exactly one folder below data/aav-mri
folder_path = "data/aav-mri/*/*.nii.gz"

In [8]:
# Collect the scans, leaving out files whose base name already ends in "_MASKED",
# and derive the path of the mask that belongs to each scan
nii_files = glob.glob(folder_path)
nii_files = sorted([f for f in nii_files if not os.path.splitext(os.path.splitext(os.path.basename(f))[0])[0].endswith('_MASKED')])
masked_files = [add_masked_suffix(f) for f in nii_files]

# Load the atlas
atlas_path = "../../../CCF_DATA/average_template_25.nii.gz"
atlas = nib.load(atlas_path).get_fdata()

# Load atlas annotation
annotation_path = "../../../CCF_DATA/annotation_25.nii.gz"
annotation = nib.load(annotation_path).get_fdata()

# Subset the atlas and annotation for partial scanned images:
# every 10th slice between indices 100 and 500 along the first axis
atlas_subset = atlas[100:500:10, :, :]
annotation_subset = annotation[100:500:10, :, :]
fixed_image = ants.from_numpy(atlas_subset)
# The annotation does not change between scans, so convert it only once
annotation_ants = ants.from_numpy(annotation_subset)

for i, (data_path, mask_path) in enumerate(zip(nii_files, masked_files)):
    print(f"\nProcessing file {i+1}/{len(nii_files)}")

    # Make output directory based on the file structure
    parts = os.path.normpath(data_path).split(os.sep)
    child = parts[-2]
    file_base = os.path.splitext(os.path.splitext(parts[-1])[0])[0]  # Handles .nii.gz
    child_and_file_base = os.path.join(child, file_base)
    output_directory = f"output/{child_and_file_base}"

    # The last file written below marks a scan as fully processed
    last_file = f"{output_directory}/registered_atlas_inv.gif"
    if os.path.exists(last_file):
        print(f"Skipping {data_path} as the output already exists.")
        continue

    # Load the data and reorder the axes
    data = nib.load(data_path).get_fdata()
    data = np.transpose(data, (2, 1, 0))

    # Load the segmentation mask
    mask = nib.load(mask_path).get_fdata()
    mask = np.transpose(mask, (2, 1, 0))

    # Fill holes in the mask
    mask = binary_fill_holes(mask > 0).astype(np.uint8)

    # Without any mask voxel the bounding box below cannot be computed
    if not mask.any():
        print(f"Skipping {data_path} as the mask {mask_path} is empty.")
        continue

    # Perform N4 bias correction
    data = n4_bias_correction(data, mask, shrink_factor=1, show=False)

    # Mask over the data
    data[mask == 0] = 0

    # First and last occupied index of the mask along the last axis (x) and
    # the middle axis (y); both ends are inclusive
    xmin, xmax = np.where(mask.sum(axis=(0, 1)) > 0)[0][[0, -1]]
    ymin, ymax = np.where(mask.sum(axis=(0, 2)) > 0)[0][[0, -1]]

    # Print information about the data
    print("Data path:")
    print("\t" + data_path)

    print("Mask path:")
    print("\t" + mask_path)

    # Crop to the mask bounding box. xmax/ymax are inclusive indices, so the
    # slice end needs +1 to keep the last occupied row and column.
    cropped_data = data[:, ymin:ymax + 1, xmin:xmax + 1].copy()

    # Upscale every slice and zero-pad it to a fixed in-plane size
    # (presumably the in-plane size of the atlas -- TODO confirm)
    resized = np.array([
        resize_and_pad_grayscale_np(slice_2d, scale=4, target_size=(320, 456))
        for slice_2d in cropped_data
    ])

    # Make sure the output directory exists and export the masked out data
    os.makedirs(output_directory, exist_ok=True)
    export_gif(resized, f"{output_directory}/data_01_resized.gif", fps=6)
    export_gif(data, f"{output_directory}/data_02_masked.gif", fps=6)

    # Register
    moving_image = ants.from_numpy(resized)

    # Forward mapping: register the resized scan (moving) to the atlas subset (fixed)
    fwd_output_path = f"{output_directory}/registered_data_fwd.nii.gz"
    result = ants.registration(fixed_image, moving_image, type_of_transform = 'SyN' )
    print("Saving to", fwd_output_path)
    # Save Moving image warped to space of fixed image
    ants.image_write(result['warpedmovout'], fwd_output_path)

    # Perform inverse mapping
    inv_output_path = f"{output_directory}/registered_data_inv.nii.gz"
    inv_annotation_path = f"{output_directory}/registered_labels_inv.nii.gz"

    # Apply inverse transform to fixed image (atlas_subset)
    warped_fixed = ants.apply_transforms(
        fixed=moving_image,           # target space: moving image
        moving=fixed_image,           # image to warp: fixed image (atlas_subset)
        transformlist=result['invtransforms'],
        interpolator='linear'
    )
    ants.image_write(warped_fixed, inv_output_path)

    # Apply inverse transform to annotation (use nearest neighbor for labels)
    warped_annotation = ants.apply_transforms(
        fixed=moving_image,
        moving=annotation_ants,
        transformlist=result['invtransforms'],
        interpolator='nearestNeighbor'
    )
    ants.image_write(warped_annotation, inv_annotation_path)
    print("Inverse mapping done! Saved to", inv_output_path, "and", inv_annotation_path)

    # Export the results as GIFs with labels, reading back the files written above
    fwd_registered_data = nib.load(fwd_output_path).get_fdata()
    inv_annotation = nib.load(inv_annotation_path).get_fdata()
    inv_registered_data = nib.load(inv_output_path).get_fdata()
    export_gif_with_labels(fwd_registered_data, annotation_subset, f"{output_directory}/registered_data_fwd.gif", alpha=0.5, fps=6, fontsize=4)
    export_gif_with_labels(resized, inv_annotation, f"{output_directory}/registered_data_inv.gif", alpha=0.5, fps=6, fontsize=4)
    export_gif_with_labels(inv_registered_data, inv_annotation, f"{output_directory}/registered_atlas_inv.gif", alpha=0.5, fps=6, fontsize=4)


Processing file 1/19
Data path:
	data/aav-mri/Xu-AAVcontrol-M6/Xu_AAVcontrol_M6_0hrs.nii.gz
Mask path:
	data/aav-mri/Xu-AAVcontrol-M6/Xu_AAVcontrol_M6_0hrs_MASKED.nii.gz
Saving to output/Xu-AAVcontrol-M6/Xu_AAVcontrol_M6_0hrs/registered_data_fwd.nii.gz
Inverse mapping done! Saved to output/Xu-AAVcontrol-M6/Xu_AAVcontrol_M6_0hrs/registered_data_inv.nii.gz and output/Xu-AAVcontrol-M6/Xu_AAVcontrol_M6_0hrs/registered_labels_inv.nii.gz

Processing file 2/19
Data path:
	data/aav-mri/Xu-AAVcontrol-M6/Xu_AAVcontrol_M6_24hrs.nii.gz
Mask path:
	data/aav-mri/Xu-AAVcontrol-M6/Xu_AAVcontrol_M6_24hrs_MASKED.nii.gz
Saving to output/Xu-AAVcontrol-M6/Xu_AAVcontrol_M6_24hrs/registered_data_fwd.nii.gz
Inverse mapping done! Saved to output/Xu-AAVcontrol-M6/Xu_AAVcontrol_M6_24hrs/registered_data_inv.nii.gz and output/Xu-AAVcontrol-M6/Xu_AAVcontrol_M6_24hrs/registered_labels_inv.nii.gz

Processing file 3/19
Data path:
	data/aav-mri/Xu-AAVcontrol-M6/Xu_AAVcontrol_M6_2hrs.nii.gz
Mask path:
	data/aav-mri/Xu-