In [1]:
import numpy as np
from skimage import io
from os.path import expanduser
from tqdm import tqdm
HOME = expanduser("~")
import os, sys
import cv2
import SimpleITK as sitk
from collections import OrderedDict
import subprocess
%load_ext autoreload
%autoreload 2

In [2]:
PATH = '/home/eddyod/programming/pipeline_utility'
sys.path.append(PATH)
from utilities.sqlcontroller import SqlController
from utilities.alignment_utility import convert_resolution_string_to_um, SCALING_FACTOR

Connecting dklab@192.168.1.12:3306


In [3]:
animal = 'DK39'
DIR = f'/net/birdstore/Active_Atlas_Data/data_root/pipeline_data/{animal}/preps'
INPUT = os.path.join(DIR, 'CH1', 'thumbnail_cleaned')
OUTPUT = os.path.join(DIR, 'CH1', 'thumbnail_aligned')
ELASTIX = os.path.join(DIR, 'elastix')

In [4]:
def register(fixed_image, moving_image, filter):
    """
    This code is very similar to the code from the ITK site. I added a variable
    for the filter but it should usually use the MOMENTS filter
    """
    
    initial_transform = sitk.CenteredTransformInitializer(
        fixed_image,
        moving_image,
        sitk.Euler2DTransform(),
        filter)

    registration_method = sitk.ImageRegistrationMethod()

    # Similarity metric settings.
    registration_method.SetMetricAsMattesMutualInformation(numberOfHistogramBins=50)
    registration_method.SetMetricSamplingStrategy(registration_method.RANDOM)
    registration_method.SetMetricSamplingPercentage(0.01)

    registration_method.SetInterpolator(sitk.sitkLinear)

    # Optimizer settings.
    registration_method.SetOptimizerAsGradientDescent(learningRate=0.5, 
                                                      numberOfIterations=100, 
                                                      convergenceMinimumValue=1e-6, 
                                                      convergenceWindowSize=10)
    registration_method.SetOptimizerScalesFromPhysicalShift()

    # Setup for the multi-resolution framework.            
    registration_method.SetShrinkFactorsPerLevel(shrinkFactors = [4,2,1])
    registration_method.SetSmoothingSigmasPerLevel(smoothingSigmas=[2,1,0])
    registration_method.SmoothingSigmasAreSpecifiedInPhysicalUnitsOn()

    # Don't optimize in-place, we would possibly like to run this cell multiple times.
    registration_method.SetInitialTransform(initial_transform, inPlace=False)


    final_transform = registration_method.Execute(sitk.Cast(fixed_image, sitk.sitkFloat32), 
                                                   sitk.Cast(moving_image, sitk.sitkFloat32))
    return final_transform

In [5]:
def parameter_elastix_parameter_file_to_dict(filename):
    """
    Reads the *.tfm files and parses out the rotation and x,y shift
    """
    d = {}
    with open(filename, 'r') as f:
        for line in f.readlines():
            line = line.strip()
            if line.startswith('Parameters:'):
                tokens = line.split(' ')
                tokens = [float(i) for i in tokens[1:4]]
                d['Parameters'] = tokens
            if line.startswith('FixedParameters:'):
                tokens = line.split(' ')
                tokens = [float(i) for i in tokens[1:4]]
                d['FixedParameters'] = tokens
    return d
    

def parse_elastix_parameter_file(filepath):
    """
    Takes rotation, x and y shift from tfm file and puts it
    in a rotation matrix
    """
    d = parameter_elastix_parameter_file_to_dict(filepath)
    rot_rad, x_mm, y_mm = d['Parameters']
    center = np.array(d['FixedParameters'])

    xshift = x_mm 
    yshift = y_mm 

    R = np.array([[np.cos(rot_rad), -np.sin(rot_rad)],
                  [np.sin(rot_rad), np.cos(rot_rad)]])
    shift = center + (xshift, yshift) - np.dot(R, center)
    T = np.vstack([np.column_stack([R, shift]), [0, 0, 1]])
    return T



def load_consecutive_section_transform(moving_fn, fixed_fn):
    """
    Finds the correct tfm file to use for the moving and fixed files
    """
    infile = f'{moving_fn}-{fixed_fn}.tfm'
    param_fp = os.path.join(ELASTIX, infile)
    return parse_elastix_parameter_file(param_fp)

def parse_elastix():
    """
    Loops through the files and computes the final transform between each moving and fixed image
    """
    image_name_list = sorted(os.listdir(INPUT))
    anchor_idx = len(image_name_list) // 2 # middle section of the brain
    # anchor_idx = len(image_name_list) - 1
    transformation_to_previous_sec = {}

    for i in range(1, len(image_name_list)):
        fixed_fn = os.path.splitext(image_name_list[i - 1])[0]
        moving_fn = os.path.splitext(image_name_list[i])[0]
        transformation_to_previous_sec[i] = load_consecutive_section_transform(moving_fn, fixed_fn)

    transformation_to_anchor_sec = {}
    # Converts every transformation
    for moving_idx in range(len(image_name_list)):
        if moving_idx == anchor_idx:
            transformation_to_anchor_sec[image_name_list[moving_idx]] = np.eye(3)
        elif moving_idx < anchor_idx:
            T_composed = np.eye(3)
            for i in range(anchor_idx, moving_idx, -1):
                T_composed = np.dot(np.linalg.inv(transformation_to_previous_sec[i]), T_composed)
            transformation_to_anchor_sec[image_name_list[moving_idx]] = T_composed
        else:
            T_composed = np.eye(3)
            for i in range(anchor_idx + 1, moving_idx + 1):
                T_composed = np.dot(transformation_to_previous_sec[i], T_composed)
            transformation_to_anchor_sec[image_name_list[moving_idx]] = T_composed


    return transformation_to_anchor_sec


def convert_2d_transform_forms(arr):
    """
    Just creates correct size matrix
    """
    return np.vstack([arr, [0,0,1]])

def create_warp_transforms(animal, transforms, transforms_resol, resolution):
    #transforms_resol = op['resolution']
    transforms_scale_factor = convert_resolution_string_to_um(animal, resolution=transforms_resol) / convert_resolution_string_to_um(animal, resolution=resolution)
    tf_mat_mult_factor = np.array([[1, 1, transforms_scale_factor], [1, 1, transforms_scale_factor]])
    transforms_to_anchor = {
        img_name:
            convert_2d_transform_forms(np.reshape(tf, (3, 3))[:2] * tf_mat_mult_factor) for
        img_name, tf in transforms.items()}

    return transforms_to_anchor


In [6]:
moments = sitk.CenteredTransformInitializerFilter.MOMENTS
geometry = sitk.CenteredTransformInitializerFilter.GEOMETRY

In [7]:
### loops through the image stack and performs the registration between each set of two consecutive files
### then saves the transformation in a file. It would probably be much better if we could just
### use the values from the final transform directly, or put in a dictionary instead of saving
### to file and then rereading them
image_name_list = sorted(os.listdir(INPUT))
for i in tqdm(range(1, len(image_name_list))):
    final_transform = None
    fixed_file = os.path.join(INPUT, image_name_list[i - 1])
    moving_file = os.path.join(INPUT, image_name_list[i])
    previous = str(i-1).zfill(3)
    current = str(i).zfill(3)
    outfile = f'{current}-{previous}.tfm'
    outpath = os.path.join(ELASTIX, outfile)
    if os.path.exists(outpath):
        continue
    
    moving_image = sitk.ReadImage(moving_file, sitk.sitkUInt16)
    fixed_image =  sitk.ReadImage(fixed_file, sitk.sitkUInt16)
    
    try:
        final_transform = register(fixed_image, moving_image, moments)
    except:
        print('Could not create moments transform for ', outfile)
        try:
            final_transform = register(fixed_image, moving_image, geometry)
        except:
            print('Could not create geometry transform for ', outfile)
            
        
    if final_transform is not None:
        sitk.WriteTransform(final_transform, outpath)    


100%|██████████| 468/468 [03:52<00:00,  2.02it/s]


#### Set the resolution and create a dictionary of transforms

In [8]:
resolution = 'thumbnail'
transforms = parse_elastix()
warp_transforms = create_warp_transforms(animal, transforms, 'thumbnail', resolution)

In [14]:
sqlController = SqlController(animal)
width = sqlController.scan_run.width
height = sqlController.scan_run.height
max_width = int(width * SCALING_FACTOR)
max_height = int(height * SCALING_FACTOR)
bgcolor = 'white' # this should be black, but white lets you see the rotation and shift

##### This gets the dictionary of transforms and passes those parameters to imagemagick's convert tool to rotate, shift and crop the image

In [15]:
OUTPUT = "setme to some place where you can write files"
ordered_transforms = OrderedDict(sorted(warp_transforms.items()))
for file, arr in tqdm(ordered_transforms.items()):
    T = np.linalg.inv(arr)
    sx = T[0, 0]
    sy = T[1, 1]
    rx = T[1, 0]
    ry = T[0, 1]
    tx = T[0, 2]
    ty = T[1, 2]
    
    op_str = f" +distort AffineProjection {sx},{rx},{ry},{sy},{tx},{ty}"
    op_str += f' -crop {max_width}x{max_height}+0.0+0.0!'
    input_fp = os.path.join(INPUT, file)
    output_fp = os.path.join(OUTPUT, file)
    if os.path.exists(output_fp):
        continue

    cmd = "convert {} +repage -virtual-pixel background -background {} {} -flatten -compress lzw {}"\
        .format(input_fp, bgcolor, op_str, output_fp)
    subprocess.run(cmd, shell=True)


100%|██████████| 469/469 [01:59<00:00,  3.92it/s]
