## Inputs and basic corrections (axis reorientation, expansion-factor spacing)
---

In [None]:
# reading data and working with arrays
import nrrd
import zarr
import numpy as np
from scipy.ndimage import zoom

# data paths
p = ''
fix_path = p + ''
mov_path = p + ''

# load fix data from dataset '/c2/s3' (presumably channel 2, scale level 3);
# transpose the axes and flip axis 1 to match the functional data
fix_zarr = zarr.open(store=zarr.N5Store(fix_path), mode='r')
fix_meta = fix_zarr['/c2/s3'].attrs.asdict()
fix = fix_zarr['/c2/s3'][...].transpose(1,2,0)[:, ::-1, :]

# load mov data from dataset '/c3/s3'; same reorientation as fix
mov_zarr = zarr.open(store=zarr.N5Store(mov_path), mode='r')
mov_meta = mov_zarr['/c3/s3'].attrs.asdict()
mov = mov_zarr['/c3/s3'][...].transpose(1,2,0)[:, ::-1, :]

# voxel spacing at this scale = base pixel resolution * downsampling factors
# NOTE(review): N5 attributes are typically listed in x,y,z order, while the
# arrays above were reordered with transpose(1,2,0); confirm that the spacing
# order matches the array axes (harmless only while the in-plane spacings are equal)
fix_spacing = np.array(fix_meta['pixelResolution']) * fix_meta['downsamplingFactors']
mov_spacing = np.array(mov_meta['pixelResolution']) * mov_meta['downsamplingFactors']

# divide BOTH spacings by the expansion factor so they are expressed in
# pre-expansion size
exp_factor = 2
fix_spacing = fix_spacing / exp_factor
mov_spacing = mov_spacing / exp_factor

# write results; spacing is stored in the NRRD 'spacings' header field so the
# files can be reloaded without hard-coding it
nrrd.write('./fix.nrrd', fix, header={'spacings': fix_spacing}, compression_level=2)
nrrd.write('./mov.nrrd', mov, header={'spacings': mov_spacing}, compression_level=2)

# check shapes and spacings
print(fix.shape, mov.shape)
print(fix_spacing, mov_spacing)

# # load precomputed data (spacing is read back from the 'spacings' header)
# fix, fix_header = nrrd.read('./fix.nrrd')
# mov, mov_header = nrrd.read('./mov.nrrd')
# fix_spacing = np.asarray(fix_header['spacings'], dtype=float)
# mov_spacing = np.asarray(mov_header['spacings'], dtype=float)

## Foreground detection
---

### fixed

In [None]:
# tools for coarse whole brain segmentation
from CircuitSeeker import level_set
from scipy.ndimage import zoom, binary_closing, binary_dilation

# subsampling factor for the coarse segmentation; defined once so the array
# stride and the matching spacing cannot drift apart
mask_skip = 4
# cubic structuring element for the closing / dilation cleanup below
mask_struct = np.ones((5, 5, 5))

# get small mask on a subsampled copy (cheaper than full resolution)
fix_skip = fix[::mask_skip, ::mask_skip, ::mask_skip]
skip_spacing = fix_spacing * mask_skip
fix_mask_small = level_set.brain_detection(
    fix_skip, skip_spacing,
    mask_smoothing=2,
    iterations=[80,40,10],
    smooth_sigmas=[12,6,3],
    lambda2=64.0,
)

# enlarge back to the full fix shape with order=0 (nearest neighbor, so no
# new values are interpolated), then close small gaps and grow the mask
fix_mask = zoom(fix_mask_small, np.array(fix.shape) / fix_skip.shape, order=0)
fix_mask = binary_closing(fix_mask, mask_struct).astype(np.uint8)
fix_mask = binary_dilation(fix_mask, mask_struct).astype(np.uint8)

# write result, keeping the voxel spacing in the header
nrrd.write('./fix_mask.nrrd', fix_mask, header={'spacings': fix_spacing})

# # load precomputed mask
# fix_mask, _ = nrrd.read('./fix_mask.nrrd')

### moving

In [None]:
# tools for coarse whole brain segmentation
from CircuitSeeker import level_set
from scipy.ndimage import zoom, binary_closing, binary_dilation

# subsampling factor for the coarse segmentation; defined once so the array
# stride and the matching spacing cannot drift apart
mask_skip = 4
# cubic structuring element for the closing / dilation cleanup below
mask_struct = np.ones((5, 5, 5))

# get small mask on a subsampled copy (cheaper than full resolution)
mov_skip = mov[::mask_skip, ::mask_skip, ::mask_skip]
skip_spacing = mov_spacing * mask_skip
mov_mask_small = level_set.brain_detection(
    mov_skip, skip_spacing,
    mask_smoothing=2,
    iterations=[80,40,10],
    smooth_sigmas=[12,6,3],
    lambda2=64.0,
)

# enlarge back to the full mov shape with order=0 (nearest neighbor, so no
# new values are interpolated), then close small gaps and grow the mask
mov_mask = zoom(mov_mask_small, np.array(mov.shape) / mov_skip.shape, order=0)
mov_mask = binary_closing(mov_mask, mask_struct).astype(np.uint8)
mov_mask = binary_dilation(mov_mask, mask_struct).astype(np.uint8)

# save output, keeping the voxel spacing in the header
nrrd.write('./mov_mask.nrrd', mov_mask, header={'spacings': mov_spacing})

# # load precomputed mask
# mov_mask, _ = nrrd.read('./mov_mask.nrrd')

## Moments alignment (principal axes of the foreground masks)
---

In [None]:
from CircuitSeeker.axisalign import principal_axes, align_modes
from CircuitSeeker.transform import apply_transform

# principal axes of each foreground mask (fix first, then mov); each mask is
# passed together with its own voxel spacing
(fix_mean, fix_evals, fix_evecs), (mov_mean, mov_evals, mov_evecs) = (
    principal_axes(mask, spacing)
    for mask, spacing in ((fix_mask, fix_spacing), (mov_mask, mov_spacing))
)

# build the moments transform from the two means and eigenvector sets
# (presumably mapping the mov axes onto the fix axes)
modes = align_modes(fix_mean, fix_evecs, mov_mean, mov_evecs)

# resample mov onto the fix grid using only the moments transform
modes_aligned = apply_transform(
    fix, mov, fix_spacing, mov_spacing, transform_list=[modes],
)

# save the transform matrix and the resampled image
np.savetxt('modes.mat', modes)
nrrd.write('./modes.nrrd', modes_aligned, compression_level=2)

# # reload a previously saved moments transform instead of recomputing it
# modes = np.loadtxt('./modes.mat')

## Whole-image alignment
---

In [None]:
# alignment functions
from CircuitSeeker.align import alignment_pipeline
from CircuitSeeker.align import deformable_align  # NOTE(review): not used in this cell
from CircuitSeeker.transform import apply_transform

# rigid -> affine -> deform registration of mov to fix, initialized with the
# moments transform from the previous section
affine, deform_output = alignment_pipeline(
    fix, mov, fix_spacing, mov_spacing,
    steps=['rigid', 'affine', 'deform'],
    initial_transform=modes,
    alignment_spacing=2.0,
    shrink_factors=[2,],
    smooth_sigmas=[2.,],
    iterations=400,
    deform_kwargs={
        'control_point_spacing':100.0,
        'control_point_levels':[1,],
    }
)

# we don't need bspline params, just field; the raw pipeline output keeps its
# own name so `deform` always refers to the displacement field (re-running
# these lines cannot index into an already-extracted field)
deform = deform_output[1]

# apply affine only
affine_aligned = apply_transform(
    fix, mov,
    fix_spacing, mov_spacing,
    transform_list=[affine,],
)

# apply affine and deform
deform_aligned = apply_transform(
    fix, mov,
    fix_spacing, mov_spacing,
    transform_list=[affine, deform],
)

# write results
np.savetxt('affine.mat', affine)
nrrd.write('./deform.nrrd', deform, compression_level=2)
nrrd.write('./affine.nrrd', affine_aligned, compression_level=2)
nrrd.write('./deformed.nrrd', deform_aligned, compression_level=2)

# # load precomputed results
# affine = np.loadtxt('./affine.mat')
# deform, _ = nrrd.read('./deform.nrrd')

## Wiggle (blockwise local refinement)
---

In [None]:
# function for overlapping affine alignments
from CircuitSeeker.align import nested_distributed_piecewise_alignment_pipeline
# imported here as in the other alignment cells, so this cell does not depend
# on an earlier cell having been run
from CircuitSeeker.transform import apply_transform

# define blocks: each entry is fix.shape / 32 per axis (presumably a block size
# in voxels -- confirm against nested_distributed_piecewise_alignment_pipeline),
# clamped to >= 1 so an axis shorter than 16 voxels never yields a zero entry
block_size = np.maximum(np.round(np.array(fix.shape) / 32), 1).astype(int)
block_schedule = [ [tuple(block_size)], ]

# define parameters
parameter_schedule = [
    {'random_kwargs':{'max_translation':5.,
                      'max_rotation':5. * np.pi/180.,
                      'max_scale':1.05,
                      'max_shear':.05,
                      'random_iterations':2500,
                      'affine_align_best':10,
                      'iterations':24,},
     'affine_kwargs':{},
     'deform_kwargs':{'control_point_spacing':29.0,
                      'control_point_levels':[1,],
                      'iterations':100,
                      'metric':'MI',},
    },
]

# run wiggle, starting from the global affine + deform
wiggle = nested_distributed_piecewise_alignment_pipeline(
    fix,
    mov,
    fix_spacing,
    mov_spacing,
    block_schedule,
    parameter_schedule=parameter_schedule,
    initial_transform_list=[affine, deform,],
    fix_mask=fix_mask,
    mov_mask=mov_mask,
    steps=['random', 'affine', 'deform'],
    bins=256,
    shrink_factors=[1,],
    smooth_sigmas=[1.,],
    iterations=400,
    learning_rate=0.1,
    max_step=0.1,
    estimate_learning_rate='never',
    callback=lambda irm: None,
    intermediates_path='./',
    cluster_kwargs={
        'ncpus':6,
        'threads':6,
        'min_workers':25,
        'max_workers':25,
    },
)

# apply the full chain: affine, deform, wiggle
wiggled = apply_transform(
    fix, mov,
    fix_spacing, mov_spacing,
    transform_list=[affine, deform, wiggle,],
)

# write results
nrrd.write('./wiggle.nrrd', wiggle, compression_level=2)
nrrd.write('./wiggled.nrrd', wiggled, compression_level=2)

# # load precomputed results
# wiggle, _ = nrrd.read('./wiggle.nrrd')

## Invert all transforms
---

In [None]:
from CircuitSeeker.transform import invert_displacement_vector_field
from CircuitSeeker.transform import apply_transform

# the affine is inverted as a plain matrix inverse
affine_inv = np.linalg.inv(affine)
np.savetxt('./affine_inverse.mat', affine_inv)

# both displacement fields are inverted with fix_spacing (they are also
# resampled with transform_spacing=fix_spacing below, i.e. treated as living
# on the fixed-image grid); each inverse is saved right after it is computed
deform_inv = invert_displacement_vector_field(deform, fix_spacing)
nrrd.write('./deform_inverse.nrrd', deform_inv, compression_level=2)
wiggle_inv = invert_displacement_vector_field(wiggle, fix_spacing)
nrrd.write('./wiggle_inverse.nrrd', wiggle_inv, compression_level=2)

# sanity check: roles of fix/mov are swapped and the inverse transforms are
# listed in reverse order, so fix is resampled onto the mov grid
fix_to_mov = apply_transform(
    mov, fix, mov_spacing, fix_spacing,
    transform_list=[wiggle_inv, deform_inv, affine_inv],
    transform_spacing=fix_spacing,
)
nrrd.write('./fix_warped_to_mov.nrrd', fix_to_mov, compression_level=2)

# # reload previously saved inverses instead of recomputing them
# affine_inv = np.loadtxt('./affine_inverse.mat')
# deform_inv, _ = nrrd.read('./deform_inverse.nrrd')
# wiggle_inv, _ = nrrd.read('./wiggle_inverse.nrrd')