## Locate data
---

In [None]:
# Root directory shared by both datasets -- fill in before running.
p = ''
# Fixed / moving time-series folders, each written as p + '<subfolder>'.
fix_directory, mov_directory = (
    p + '',  # frames averaged into the fixed (reference) image
    p + '',  # frames to be motion corrected against the fixed image
)

## Get fixed image mean
---

In [None]:
from CircuitSeeker import motion_correct
import nrrd
import numpy as np

# Description of the fixed-image time frames handed to CircuitSeeker: the folder
# plus file-name prefix/suffix and the dataset path inside each file
# (presumably an HDF5 dataset, given the '.h5' suffix -- confirm).
frames = dict(
    folder=fix_directory,
    prefix='TM',
    suffix='.h5',
    dataset_path='/default',
)

# Voxel spacing of the fixed image, one value per axis of `fix` (z, y, x).
fix_spacing = np.array([1.0, 0.406, 0.406])

# Average all frames into a single fixed (reference) image.
fix = motion_correct.distributed_image_mean(frames, cluster_kwargs={})

# Save for inspection; reverse the zyx axes to xyz for visualization tools.
nrrd.write('./fix.nrrd', fix.transpose(2, 1, 0), compression_level=2)

# # load saved results
# fix, _ = nrrd.read('./fix.nrrd')
# fix = fix.transpose(2,1,0)  # switch back to zyx
# fix_spacing = np.array([1.0, 0.406, 0.406])

## Compute fixed image mask
---

In [None]:
from CircuitSeeker.level_set import brain_detection
from scipy.ndimage import zoom, binary_dilation, binary_closing

# Segment a downsampled copy for speed: halve z and quarter y and x, and scale
# the voxel spacing by the inverse factors so it stays in physical units.
fix_small = zoom(fix, [.5, .25, .25], order=1)
fix_small_spacing = fix_spacing * [2, 4, 4]

# Brain segmentation on the small image.
fix_mask = brain_detection(
    fix_small,
    fix_small_spacing,
    smooth_sigmas=[3., 1.5, 0.75],
    lambda2=48.,  # tuning notes: 32. almost perfect but a little too tight; 4. too tight
    mask_smoothing=1,
)

# Bring the mask back to the full-resolution grid. order=0 is nearest-neighbour,
# so no new mask values are interpolated. Then smooth the boundary (closing) and
# grow it slightly (dilation) with a 5x5x5 cube. Change the cube size to
# tighten or loosen the mask.
upsample_factors = np.array(fix.shape) / fix_small.shape
fix_mask = zoom(fix_mask, upsample_factors, order=0)
cube = np.ones((5, 5, 5))
fix_mask = binary_closing(fix_mask, cube).astype(np.uint8)
fix_mask = binary_dilation(fix_mask, cube).astype(np.uint8)

# Save the mask in xyz order, like fix.nrrd.
nrrd.write('./fix_mask.nrrd', fix_mask.transpose(2, 1, 0), compression_level=2)

# # load saved results
# fix_mask, _ = nrrd.read('./fix_mask.nrrd')
# fix_mask = fix_mask.transpose(2,1,0)

## Motion correct
---

In [None]:
from CircuitSeeker import motion_correct

# Moving time series: same file layout as the fixed frames, different folder.
# Note this rebinds `frames`, which the resampling cell below reuses.
frames = {
    'folder': mov_directory,
    'prefix': 'TM',
    'suffix': '.h5',
    'dataset_path': '/default',
}

# Voxel spacing of each moving frame; z is 4x coarser than in the fixed image.
mov_spacing = np.array([4.0, 0.406, 0.406])

# Compute resources requested for this step.
cluster_settings = {
    'ncpus': 6, 'threads': 5,
    'min_workers': 40, 'max_workers': 40,
}

# Register the moving frames to the masked fixed image.
transforms = motion_correct.motion_correct(
    fix, frames,
    fix_spacing, mov_spacing,
    fix_mask=fix_mask,
    time_stride=10,  # presumably registers every 10th frame -- confirm
    sigma=0.25,
    cluster_kwargs=cluster_settings,
    metric='MI',  # presumably mutual information
    alignment_spacing=1.6,
    sampling='regular',
    sampling_percentage=1.0,
    optimizer='GD',  # presumably gradient descent
    estimate_learning_rate='never',
    learning_rate=0.2,
    iterations=100,
)

# Persist the transforms so this step can be skipped on re-runs.
motion_correct.save_transforms('./rigid_transforms.json', transforms)

# # load precomputed results
# transforms = motion_correct.read_transforms('./rigid_transforms.json')

## Apply transforms
---

In [None]:
# Destination zarr store for the motion-corrected series.
write_path = './motion_corrected.zarr'

# Resample the moving frames (`frames` as set in the motion-correction cell)
# with their transforms, using the fixed-image mask. This is the heavy step.
aligned_frames = motion_correct.resample_frames(
    frames, mov_spacing, transforms, write_path,
    mask=fix_mask,
    time_stride=1,
    cluster_kwargs=dict(ncpus=6, threads=5, min_workers=40, max_workers=40),
)

# # load precomputed result
# import zarr
# aligned_frames = zarr.open(write_path)

## Write a single plane over time
---

In [None]:
# how to slice in time/space: the z-plane to extract, and keep every
# `stride`-th time point (axis 0 of `aligned_frames` is time)
plane = 27
stride = 10

# Time points to sample. Taking the count from the data avoids an IndexError
# on shorter series and silently dropped frames on longer ones.
time_points = range(0, aligned_frames.shape[0], stride)

# One 2D slice per sampled time point. Shape and dtype come from the data, so a
# different field of view or a non-uint16 result needs no edit here and is not
# silently cast.
slice_over_time = np.empty(
    (len(time_points),) + tuple(aligned_frames.shape[2:]),
    dtype=aligned_frames.dtype,
)
for i, t in enumerate(time_points):
    slice_over_time[i] = aligned_frames[t, plane, ...]

# Write out in a format you can read with Fiji/Icy etc. (axes reversed, as for
# the other nrrd outputs). The file name is built from `plane` and `stride` so
# it stays accurate when they change.
nrrd.write(
    './slice{}_timestride{}_corrected.nrrd'.format(plane, stride),
    slice_over_time.transpose(2, 1, 0),
    compression_level=2,
)