In [None]:
import numpy as np
from skimage import io
from os.path import expanduser
from tqdm import tqdm
HOME = expanduser("~")
import os, sys
import cv2
import SimpleITK as sitk
from collections import OrderedDict
from shutil import copyfile
import subprocess
%load_ext autoreload
%autoreload 2

In [None]:
# Make the project's `utilities` package importable. The checkout location can be
# overridden with the PIPELINE_UTILITY_PATH environment variable; the default is the
# original author's checkout.
PATH = os.environ.get('PIPELINE_UTILITY_PATH', '/home/eddyod/programming/pipeline_utility')
if PATH not in sys.path:  # keep re-runs of this cell from appending duplicates
    sys.path.append(PATH)
from utilities.sqlcontroller import SqlController
from utilities.utilities_registration import (start_plot, end_plot, update_multires_iterations, 
                                              plot_values, command_iteration)
from utilities.alignment_utility import convert_resolution_string_to_um, SCALING_FACTOR

In [None]:
# Animal (brain) being processed and the per-animal directories this notebook uses.
animal = 'DK39'
DIR = f'/net/birdstore/Active_Atlas_Data/data_root/pipeline_data/{animal}/preps'
# Channel-1 thumbnails: registration input and aligned output.
INPUT, OUTPUT = (os.path.join(DIR, 'CH1', sub) for sub in ('thumbnail_cleaned', 'thumbnail_aligned'))
# Other per-animal directories (masks are currently unused by the registration below).
MASKED, ELASTIX = (os.path.join(DIR, sub) for sub in ('rotated_masked', 'elastix'))

In [None]:
def register_with_masks(fixed_index, moving_index):
    """Rigidly register section ``moving_index`` onto section ``fixed_index``.

    Both images are read from ``INPUT`` as ``'<index>.tif'`` (as float32) and
    registered with a single-resolution Euler 2D transform that is initialised by
    ``sitk.CenteredTransformInitializer`` and optimised with regular-step gradient
    descent on joint-histogram mutual information (100 bins, regular sampling,
    linear interpolation).

    NOTE: despite the name, masks are currently NOT applied. To restrict the metric,
    read the masks from ``MASKED`` as ``sitk.sitkUInt8`` and pass them to
    ``R.SetMetricFixedMask`` / ``R.SetMetricMovingMask``.

    Args:
        fixed_index: file stem of the reference (fixed) section, e.g. ``'012'``.
        moving_index: file stem of the section being aligned (moving).

    Returns:
        The transform returned by ``ImageRegistrationMethod.Execute``; being Euler 2D,
        its parameters are (angle, x shift, y shift) as unpacked by ``create_matrix``.
    """
    fixed_file = os.path.join(INPUT, f'{fixed_index}.tif')
    moving_file = os.path.join(INPUT, f'{moving_index}.tif')
    fixed = sitk.ReadImage(fixed_file, sitk.sitkFloat32)
    moving = sitk.ReadImage(moving_file, sitk.sitkFloat32)

    R = sitk.ImageRegistrationMethod()
    R.SetMetricAsJointHistogramMutualInformation(100)
    R.SetOptimizerAsRegularStepGradientDescent(learningRate=0.01, minStep=1e-6,
                                               numberOfIterations=100, gradientMagnitudeTolerance=1e-8)
    # NOTE(review): no optimizer scales are set, so the rotation and the translations
    # share one step size (SetOptimizerScalesFromPhysicalShift is not called).
    R.SetMetricSamplingStrategy(R.REGULAR)
    # Rigid body (rotation + translation) starting from the centred initial alignment.
    transformation = sitk.CenteredTransformInitializer(fixed, moving, sitk.Euler2DTransform())
    R.SetInitialTransform(transformation)
    R.SetInterpolator(sitk.sitkLinear)

    return R.Execute(fixed, moving)


In [None]:
def register(fixed_index, moving_index, filter):
    """Multi-resolution rigid registration of section ``moving_index`` onto ``fixed_index``.

    The images ``INPUT/<index>.tif`` are read as 16-bit unsigned and cast to float32
    for the registration. An Euler 2D transform is initialised by
    ``sitk.CenteredTransformInitializer`` using ``filter`` (e.g. the module-level
    ``moments`` or ``geometry``), then optimised with plain gradient descent on
    Mattes mutual information over a 3-level pyramid (shrink 4/2/1, smoothing
    sigmas 2/1/0 in physical units).

    Returns:
        The optimised transform; the initial transform is not modified (inPlace=False).
    """
    def _load(index):
        return sitk.ReadImage(os.path.join(INPUT, f'{index}.tif'), sitk.sitkUInt16)

    fixed_img = _load(fixed_index)
    moving_img = _load(moving_index)

    start_tx = sitk.CenteredTransformInitializer(fixed_img, moving_img, sitk.Euler2DTransform(), filter)

    method = sitk.ImageRegistrationMethod()
    # Similarity metric.
    method.SetMetricAsMattesMutualInformation(numberOfHistogramBins=150)
    method.SetMetricSamplingStrategy(method.REGULAR)
    method.SetMetricSamplingPercentage(0.01)
    method.SetInterpolator(sitk.sitkLinear)
    # Optimizer.
    method.SetOptimizerAsGradientDescent(learningRate=0.1, numberOfIterations=1000,
                                         convergenceMinimumValue=1e-8, convergenceWindowSize=10)
    method.SetOptimizerScalesFromPhysicalShift()
    # Multi-resolution pyramid.
    method.SetShrinkFactorsPerLevel(shrinkFactors=[4, 2, 1])
    method.SetSmoothingSigmasPerLevel(smoothingSigmas=[2, 1, 0])
    method.SmoothingSigmasAreSpecifiedInPhysicalUnitsOn()
    # Keep the initial transform untouched so the registration can be re-run.
    method.SetInitialTransform(start_tx, inPlace=False)

    to_float = lambda img: sitk.Cast(img, sitk.sitkFloat32)
    return method.Execute(to_float(fixed_img), to_float(moving_img))

def create_matrix(final_transform):
    """Turn a fitted 2-D rigid transform into a 3x3 homogeneous matrix.

    ``final_transform.GetParameters()`` must be (angle in radians, x shift, y shift)
    and ``GetFixedParameters()`` the rotation centre c. The returned matrix maps a
    point p to ``Rot @ (p - c) + c + t``, i.e. its translation column is
    ``c + t - Rot @ c``.
    """
    angle, dx, dy = final_transform.GetParameters()
    centre = np.array(final_transform.GetFixedParameters())

    cos_a, sin_a = np.cos(angle), np.sin(angle)
    rotation = np.array([[cos_a, -sin_a],
                         [sin_a, cos_a]])
    offset = centre + (dx, dy) - rotation @ centre

    affine_rows = np.hstack([rotation, offset[:, None]])
    return np.vstack([affine_rows, [0, 0, 1]])

def create_warp_transforms(animal, transforms, transforms_resol, resolution):
    """Rescale 3x3 transforms from ``transforms_resol`` to ``resolution``.

    Only the translation column is scaled: the rotation part of a rigid transform
    does not depend on pixel size. The factor is
    ``um_per_pixel(transforms_resol) / um_per_pixel(resolution)`` as reported by
    ``convert_resolution_string_to_um``.

    Args:
        animal: animal id passed through to ``convert_resolution_string_to_um``.
        transforms: mapping of image name -> transform (anything reshapeable to 3x3).
        transforms_resol: resolution string the transforms were computed at.
        resolution: resolution string to rescale the transforms to.

    Returns:
        dict of image name -> rescaled 3x3 numpy array.
    """
    scale = (convert_resolution_string_to_um(animal, resolution=transforms_resol)
             / convert_resolution_string_to_um(animal, resolution=resolution))
    # Multiplier for the top two rows: leave the linear part alone, scale the shifts.
    row_factor = np.array([[1, 1, scale], [1, 1, scale]])

    rescaled = {}
    for name, tf in transforms.items():
        top_rows = np.reshape(tf, (3, 3))[:2] * row_factor
        rescaled[name] = convert_2d_transform_forms(top_rows)
    return rescaled

def convert_2d_transform_forms(arr):
    """Append the homogeneous row ``[0, 0, 1]`` beneath a 2x3 affine block.

    Returns the resulting 3x3 matrix.
    """
    homogeneous_row = [0, 0, 1]
    return np.vstack((arr, homogeneous_row))



In [None]:
# Initialisation modes for sitk.CenteredTransformInitializer, usable as the
# `filter` argument of register().
moments, geometry = (sitk.CenteredTransformInitializerFilter.MOMENTS,
                     sitk.CenteredTransformInitializerFilter.GEOMETRY)

In [None]:
### Loop through the image stack and register each section onto the one before it.
# Only '.tif' files are used: register_with_masks() reads '<index>.tif' from INPUT.
files = sorted(f for f in os.listdir(INPUT) if f.endswith('.tif'))
anchor_index = len(files) // 2 # middle section of the brain
limit = 2
start = anchor_index - limit
end = anchor_index + limit
#files = files[start:end]  # debug subset; NOTE anchor_index must then be recomputed
transformation_to_previous_section = OrderedDict()

for i in tqdm(range(1, len(files))):
    # Use the real file stems (not str(i).zfill(3)) so the pair that is registered is
    # exactly the pair named by the key stored below, even if the numbering has gaps.
    fixed_index = os.path.splitext(files[i - 1])[0]
    moving_index = os.path.splitext(files[i])[0]
    final_transform = register_with_masks(fixed_index, moving_index)
    # 3x3 matrix from registering section i (moving) onto section i-1 (fixed),
    # keyed by the moving file name.
    transformation_to_previous_section[files[i]] = create_matrix(final_transform)

In [None]:
# Inspect the pairwise 3x3 matrices, keyed by the moving section's file name.
transformation_to_previous_section

In [None]:
# Chain the pairwise transforms so every section maps onto the anchor section.
# Walking outward from the anchor lets each section reuse its neighbour's composed
# matrix (O(n) instead of O(n^2)). The products are evaluated in exactly the same
# order as composing each section from scratch, so the results are unchanged.
# (The original version referenced the undefined name `moving_idx` at the anchor.)
composed_by_file = {}
if files:
    composed_by_file[files[anchor_index]] = np.eye(3)

    # Sections before the anchor: undo the step from each following section.
    T_composed = np.eye(3)
    for moving_index in range(anchor_index - 1, -1, -1):
        step = transformation_to_previous_section[files[moving_index + 1]]
        T_composed = np.dot(np.linalg.inv(step), T_composed)
        composed_by_file[files[moving_index]] = T_composed

    # Sections after the anchor: apply each section's step to the previous one.
    T_composed = np.eye(3)
    for moving_index in range(anchor_index + 1, len(files)):
        T_composed = np.dot(transformation_to_previous_section[files[moving_index]], T_composed)
        composed_by_file[files[moving_index]] = T_composed

# Keep the same (file-list) insertion order as before.
transformation_to_anchor_section = {name: composed_by_file[name] for name in files}


#### Set the resolution, rescale the dictionary of transforms to it, and look up the image size

In [None]:
# Target resolution for the warps; the transforms above were computed on 'thumbnail' images.
resolution = 'thumbnail'
warp_transforms = create_warp_transforms(animal, transformation_to_anchor_section, 'thumbnail', resolution)

# Image size from the scan_run record (presumably full resolution), scaled by
# SCALING_FACTOR to give the crop size used when writing the aligned images.
sqlController = SqlController(animal)
width, height = sqlController.scan_run.width, sqlController.scan_run.height
max_width, max_height = (int(size * SCALING_FACTOR) for size in (width, height))
# NOTE(review): the crop size ignores `resolution`; confirm it if `resolution` changes.

bgcolor = 'white' # this should be black, but white lets you see the rotation and shift
print(files[anchor_index])  # name of the anchor section

##### Pass each transform from the dictionary to ImageMagick's `convert` tool, which rotates, shifts and crops the image

In [None]:
#OUTPUT = "setme to some place where you can write files"
# Apply each section's transform-to-anchor with ImageMagick. The matrix is inverted
# and its coefficients are passed to `+distort AffineProjection` as sx,rx,ry,sy,tx,ty.
os.makedirs(OUTPUT, exist_ok=True)  # convert cannot write into a missing directory
ordered_transforms = OrderedDict(sorted(warp_transforms.items()))
failed_files = []
for file, arr in tqdm(ordered_transforms.items()):
    output_fp = os.path.join(OUTPUT, file)
    if os.path.exists(output_fp):
        continue  # already written by an earlier run; makes this cell resumable
    input_fp = os.path.join(INPUT, file)

    T = np.linalg.inv(arr)
    sx, ry, tx = T[0]
    rx, sy, ty = T[1]

    # Arguments are passed as a list (no shell) so paths containing spaces or shell
    # metacharacters are handled safely.
    cmd = ['convert', input_fp, '+repage',
           '-virtual-pixel', 'background', '-background', bgcolor,
           '+distort', 'AffineProjection', f'{sx},{rx},{ry},{sy},{tx},{ty}',
           '-crop', f'{max_width}x{max_height}+0.0+0.0!',
           '-flatten', '-compress', 'lzw', output_fp]
    result = subprocess.run(cmd)
    if result.returncode != 0:
        failed_files.append(file)  # record instead of silently ignoring the failure

if failed_files:
    print(f'convert failed for {len(failed_files)} file(s): {failed_files}')
