## Inputs
---

In [1]:
# reading data and working with arrays
import numpy as np
import nrrd
from tifffile import imread

# data paths (all inputs live under one directory)
p = '/groups/ahrens/ahrenslab/fleishmang/data/forj/EM-LM_3d_3d'
fix_path = f"{p}{'/2P_overview_stack_green_gauss45_gamma70.tif'}"
mov_path = f"{p}{'/EM_multiSEM_thumbnails-cropped_iso.tif'}"

# load data and make the obvious axis-order corrections:
# fix has its axes fully reversed (-> xyz); mov has its first two axes swapped to match fix
fix = imread(fix_path).transpose(2, 1, 0)
mov = imread(mov_path).transpose(1, 0, 2)

# voxel spacings per axis (presumably micrometers -- confirm);
# fix is anisotropic, mov is isotropic
fix_spacing = np.array([0.6, 0.6, 2.0])
mov_spacing = np.full(3, 0.512)

# write the reoriented volumes for inspection
nrrd.write('./fix.nrrd', fix)
nrrd.write('./mov.nrrd', mov)

## Foreground detection
---
Brain masks help speed up alignment and tell downstream algorithms which areas to align and which areas to ignore.

### fixed

In [None]:
# tools for coarse whole brain segmentation
from CircuitSeeker import level_set
from scipy.ndimage import zoom, binary_closing

# per-axis (x, y, z) subsampling used for the level-set segmentation;
# z is not subsampled (its spacing is already the coarsest)
FIX_SKIP = (4, 4, 1)
# edge length, in voxels, of the cubic structuring element used to smooth the mask
CLOSING_WIDTH = 5

# get small mask on the subsampled volume
fix_skip = fix[::FIX_SKIP[0], ::FIX_SKIP[1], ::FIX_SKIP[2]]
fix_skip_spacing = fix_spacing * FIX_SKIP
fix_mask_small = level_set.brain_detection(
    fix_skip, fix_skip_spacing,
    mask_smoothing=1,
    smooth_sigmas=[6,3,1.5],
    lambda2=8,
)

# enlarge back to the full fix grid; order=0 (nearest neighbor) introduces no new values
fix_mask = zoom(fix_mask_small, np.array(fix.shape) / fix_skip.shape, order=0)

# smooth with a binary closing. binary_closing treats everything outside the array
# as background (border_value=0), so its erosion pass would strip a
# CLOSING_WIDTH//2-voxel shell of mask wherever the brain touches the volume
# boundary (e.g. the first/last z-slices). Edge-pad first and crop afterwards so
# only the interior is smoothed.
pad = CLOSING_WIDTH // 2
crop = tuple(slice(pad, pad + n) for n in fix_mask.shape)
fix_mask = binary_closing(
    np.pad(fix_mask, pad, mode='edge'),
    np.ones((CLOSING_WIDTH,) * 3),
)[crop].astype(np.uint8)

# write result
nrrd.write('./fix_mask.nrrd', fix_mask)

### moving

In [None]:
# tools for coarse whole brain segmentation
from CircuitSeeker import level_set
from scipy.ndimage import zoom, binary_closing

# background is bright - temporarily remove it (on a copy; mov itself is untouched)
BACKGROUND_THRESHOLD = 200
mov_bgs = np.copy(mov)
mov_bgs[mov_bgs > BACKGROUND_THRESHOLD] = 0

# per-axis (x, y, z) subsampling used for the level-set segmentation
MOV_SKIP = (4, 4, 4)
# edge length, in voxels, of the cubic structuring element used to smooth the mask
CLOSING_WIDTH = 5

# get small mask on the subsampled, background-suppressed volume
mov_skip = mov_bgs[::MOV_SKIP[0], ::MOV_SKIP[1], ::MOV_SKIP[2]]
mov_skip_spacing = mov_spacing * MOV_SKIP
mov_mask_small = level_set.brain_detection(
    mov_skip, mov_skip_spacing,
    mask_smoothing=1,
    smooth_sigmas=[6,3,1.5],
    lambda2=8,
)

# enlarge back to the full mov grid; order=0 (nearest neighbor) introduces no new values
mov_mask = zoom(mov_mask_small, np.array(mov.shape) / mov_skip.shape, order=0)

# smooth with a binary closing. binary_closing treats everything outside the array
# as background (border_value=0), so its erosion pass would strip a
# CLOSING_WIDTH//2-voxel shell of mask wherever the brain touches the volume
# boundary. Edge-pad first and crop afterwards so only the interior is smoothed.
pad = CLOSING_WIDTH // 2
crop = tuple(slice(pad, pad + n) for n in mov_mask.shape)
mov_mask = binary_closing(
    np.pad(mov_mask, pad, mode='edge'),
    np.ones((CLOSING_WIDTH,) * 3),
)[crop].astype(np.uint8)

# save output
nrrd.write('./mov_mask.nrrd', mov_mask)

## Coarse affine alignment
---
This step puts the datasets on top of each other and matches their orientation and size. We expect there to be significant residual deformation required to obtain single cell accuracy.

In [None]:
# alignment functions
from CircuitSeeker.align import affine_align
from CircuitSeeker.transform import apply_transform

# by eye, the optimization should start with mov centered inside fix:
# offset = half the difference of the physical extents (spacing * voxel count)
fix_size = fix_spacing * fix.shape
mov_size = mov_spacing * mov.shape
mov_origin = (fix_size - mov_size) / 2

# settings shared by the rigid and the affine stage
align_kwargs = dict(
    fix_mask=fix_mask,
    mov_mask=mov_mask,
    mov_origin=mov_origin,
    shrink_factors=[4, 2],
    smooth_sigmas=[6, 4],
    iterations=100,
)

print("RIGID ALIGNMENT")
rigid = affine_align(fix, mov, fix_spacing, mov_spacing, **align_kwargs)

# the affine stage additionally receives the rigid result via `rigid=`
print("AFFINE ALIGNMENT")
affine = affine_align(fix, mov, fix_spacing, mov_spacing, rigid=rigid, **align_kwargs)

# fold the mov_origin offset into the matrix so it can be applied on its own
# (the resampling below does not pass mov_origin)
mov_origin_trans = np.eye(4)
mov_origin_trans[:3, 3] = -mov_origin
affine = mov_origin_trans @ affine

# resample the moving image and its mask onto the fix grid
affine_aligned, mov_mask_affine = [
    apply_transform(fix, vol, fix_spacing, mov_spacing, transform_list=[affine])
    for vol in (mov, mov_mask)
]
# re-binarize the resampled mask: any nonzero value counts as inside
mov_mask_affine = (mov_mask_affine > 0).astype(np.uint8)

# write results
np.savetxt('affine.mat', affine)
nrrd.write('./affine.nrrd', affine_aligned, compression_level=2)
nrrd.write('./mov_mask_affine.nrrd', mov_mask_affine)

## Manual correction of moving mask
---
For this dataset, the moving image contains cells that the fixed image does not. We remove these areas from the moving image mask (by hand, outside this notebook) so they don't affect the alignment.

In [None]:
# load the hand-corrected mask (presumably edited from mov_mask_affine.nrrd outside
# this notebook); nrrd.read returns (data, header) and only the data is kept
mov_mask_modified = nrrd.read('./mov_mask_affine_modified.nrrd')[0]

## Prepare modified-in-place variables
---
The next few steps will iteratively update the values held in `deform`, `current_moving`, and `current_mask`. Here we initialize them with the output of the coarse alignment.

In [None]:
# useful conversion functions
from CircuitSeeker.utility import matrix_to_displacement_field

# running state for the refinement stages below, seeded from the coarse affine result.
# Re-run this cell to reset the state before re-running any of those stages.
deform = matrix_to_displacement_field(fix, affine, fix_spacing)  # affine as a field over the fix grid
current_moving, current_mask = affine_aligned, mov_mask_modified

## Twisting alignment
---
This alignment approach is meant to remove low spatial frequency but high amplitude differences that are non-affine. These are things like large bends and twists that can happen after dissecting and handling a sample. This is a "meta" registration algorithm since it composes the results of many registrations into a single transform. Future versions of `CircuitSeeker` will make this algorithm available as a single function call, but for now you should read and understand what this code is doing.

In [None]:
# function for overlapping affine alignments
from CircuitSeeker.align import piecewise_affine_align

# One entry per twisting round. Each entry lists block grids (blocks per axis);
# every grid in a round is aligned independently and the resulting displacement
# fields are averaged, so cuts along different axes contribute equally.
blocks = [ [(2, 1, 1), (1, 2, 1)],
           [(3, 1, 1), (1, 2, 1)],
           [(4, 1, 1), (1, 2, 1)],
           [(5, 1, 1), (1, 3, 1), (1, 1, 2)],
           [(6, 1, 1), (1, 3, 1), (1, 1, 2)],
           [(1, 1, 1)], ]

# the first N_COARSE_ROUNDS rounds use a coarser shrink/smooth pyramid
N_COARSE_ROUNDS = 3
# voxels of edge padding used when resampling deform, so lookups that land
# slightly outside the grid still read (replicated) boundary values
PAD = 25

# nested twisting; mutates deform, current_moving and current_mask in place,
# so re-running this cell continues from the current state
for block_iter, block_list in enumerate(blocks):

    # the pyramid schedule depends only on the round, so choose it once per round
    coarse = block_iter < N_COARSE_ROUNDS
    sf = [4, 2] if coarse else [2, 1]
    ss = [6, 4] if coarse else [4, 2]

    d = np.zeros_like(deform)
    for ax, nblocks in enumerate(block_list):
        print("PIECEWISE AXIS:", ax)
        d += piecewise_affine_align(
            fix, current_moving, fix_spacing, fix_spacing,
            nblocks=nblocks,
            fix_mask=fix_mask,
            mov_mask=current_mask,
            shrink_factors=sf,
            smooth_sigmas=ss,
            iterations=100,
            cluster_kwargs={'project':'ahrens'}
        )

    # take mean over the block grids of this round
    d = d / len(block_list)

    # compose with deform: resample each component of deform through d
    # (mov_origin shifted by the padding so the padded array lines up with
    # the unpadded grid), then add d.
    # NOTE(review): components are reversed ([..., ::-1]) before apply_transform,
    # presumably because it expects the opposite axis order -- confirm.
    for i in range(3):
        padded = np.pad(deform[..., i], [(PAD, PAD),]*3, mode='edge')
        deform[..., i] = apply_transform(
            deform[..., i], padded, fix_spacing, fix_spacing,
            transform_list=[d[..., ::-1]],
            mov_origin=fix_spacing * -PAD,
        )

    deform = deform + d

    # update current moving: always resampled from the original mov with the full deform
    current_moving = apply_transform(
        fix, mov, fix_spacing, mov_spacing,
        transform_list=[deform[..., ::-1]],
    )

    # update current moving mask: warped incrementally by this round's d only.
    # NOTE(review): if apply_transform interpolates linearly, thresholding at > 0
    # slightly grows the mask every round -- confirm the interpolation mode.
    current_mask = apply_transform(
        fix, current_mask, fix_spacing, fix_spacing,
        transform_list=[d[..., ::-1]],
    )
    current_mask = (current_mask > 0).astype(np.uint8)

    # write intermediate result
    nrrd.write(f'twist_{block_iter}.nrrd', current_moving, compression_level=2)
    nrrd.write(f'twist_mask_{block_iter}.nrrd', current_mask, compression_level=2)
    nrrd.write(f'twist_deform_{block_iter}.nrrd', deform, compression_level=2)


## Piecewise affine refinement
---
This is similar to the previous step but allows the moving image a lot more freedom since blocks are cut along multiple axes at the same time.

In [None]:
# function for overlapping affine alignments
from CircuitSeeker.align import piecewise_affine_align

# block grids (blocks per axis) for the successive rounds; each grid cuts
# along several axes at once
blocks = [(4, 2, 1), (4, 3, 1), (4, 3, 2),]

# piecewise affines; mutates deform, current_moving and current_mask in place
for block_iter, nblocks in enumerate(blocks):

    # align the current state against fix, one affine per (overlapping) block
    d = piecewise_affine_align(
        fix, current_moving, fix_spacing, fix_spacing,
        nblocks=nblocks,
        fix_mask=fix_mask,
        mov_mask=current_mask,
        shrink_factors=[2, 1],
        smooth_sigmas=[4, 2],
        iterations=100,
        cluster_kwargs={'project':'ahrens'}
    )

    # compose with deform: resample every component of deform through d using an
    # edge-padded copy (origin shifted by the padding), then add d
    edge = 25
    for comp in range(3):
        component = deform[..., comp]
        deform[..., comp] = apply_transform(
            component,
            np.pad(component, ((edge, edge),) * 3, mode='edge'),
            fix_spacing, fix_spacing,
            transform_list=[d[..., ::-1]],
            mov_origin=-edge * fix_spacing,
        )
    deform = deform + d

    # refresh the moving image from the original mov using the full deform
    current_moving = apply_transform(
        fix, mov, fix_spacing, mov_spacing,
        transform_list=[deform[..., ::-1]],
    )

    # carry the mask along with this round's increment only, then re-binarize
    warped_mask = apply_transform(
        fix, current_mask, fix_spacing, fix_spacing,
        transform_list=[d[..., ::-1]],
    )
    current_mask = (warped_mask > 0).astype(np.uint8)

    # write intermediate result
    nrrd.write(f'pwa_{block_iter}.nrrd', current_moving, compression_level=2)
    nrrd.write(f'pwa_mask_{block_iter}.nrrd', current_mask, compression_level=2)
    nrrd.write(f'pwa_deform_{block_iter}.nrrd', deform, compression_level=2)
    

## Local Exhaustive Refinement
---
This is the final step. The major difference here is that blocks of the moving image search for an alignment location using brute force (they check all possible locations within a search grid) rather than by gradient descent. This ensures that local minima (which are common for images with high periodicity/high frequency content like cells) don't prevent us from finding the best alignment for each small patch of cells. We can only use this brute force approach here because the previous steps have brought corresponding locations close enough to each other that we don't have to search very far for the correct match.

In [3]:
# import local exhaustive function
from CircuitSeeker.align import piecewise_exhaustive_translation

# query block radii per round (presumably in voxels -- confirm)
blocks = [[24, 24, 6], [16, 16, 4]]

# search-grid step per axis; identical for every round
step_sizes = [2, 2, 1]

# nested exhaustive searches; mutates deform, current_moving and current_mask in place
for block_iter, query_radius in enumerate(blocks):

    # num_steps * step_sizes spans query_radius along each axis;
    # neighboring blocks are placed 1.5 query radii apart
    num_steps = [r // s for r, s in zip(query_radius, step_sizes)]
    stride = [int(1.5 * r) for r in query_radius]

    # compute local exhaustive
    d = piecewise_exhaustive_translation(
        fix, current_moving,
        fix_spacing, fix_spacing,
        stride, query_radius,
        num_steps, step_sizes,
        smooth_sigma=6.0,
        mask=current_mask,
        peak_ratio=1.05,
        bins=64,
        nworkers=300,
        cluster_kwargs={'project':'ahrens'}
    )

    # compose with deform: resample every component of deform through d using an
    # edge-padded copy (origin shifted by the padding), then add d
    edge = 25
    for comp in range(3):
        component = deform[..., comp]
        deform[..., comp] = apply_transform(
            component,
            np.pad(component, ((edge, edge),) * 3, mode='edge'),
            fix_spacing, fix_spacing,
            transform_list=[d[..., ::-1]],
            mov_origin=-edge * fix_spacing,
        )
    deform = deform + d

    # refresh the moving image from the original mov using the full deform
    current_moving = apply_transform(
        fix, mov, fix_spacing, mov_spacing,
        transform_list=[deform[..., ::-1]],
    )

    # carry the mask along with this round's increment only, then re-binarize
    warped_mask = apply_transform(
        fix, current_mask, fix_spacing, fix_spacing,
        transform_list=[d[..., ::-1]],
    )
    current_mask = (warped_mask > 0).astype(np.uint8)

    # write intermediate result
    nrrd.write(f'exhaustive_{block_iter}.nrrd', current_moving, compression_level=2)
    nrrd.write(f'exhaustive_mask_{block_iter}.nrrd', current_mask, compression_level=2)
    nrrd.write(f'exhaustive_deform_{block_iter}.nrrd', deform, compression_level=2)

Cluster dashboard link:  http://10.36.110.12:8787/status
Waiting 30 seconds for cluster to scale
Cluster wait time complete


KeyboardInterrupt: 