## Locate data
---

In [1]:
# file paths
p = '/nrs/ahrens/ahrenslab/fleishmang/functional_alex_early_march/20210303'
fix_directory = p + '/green_anatomy_after_20210303_123132'
mov_directory = p + '/spont4um_20210303_121349'

## Get fixed image mean
---

In [None]:
from CircuitSeeker import motion_correct
import nrrd

# path information for folder of time frames
frames = {'folder':fix_directory,
          'prefix':'TM',
          'suffix':'.h5',
          'dataset_path':'/default'}

# compute mean from all frames
fix = motion_correct.distributed_image_mean(
    frames,
    cluster_kwargs={'project':'ahrens', 'cores':4, 'processes':1},
)

# store output - switch to xyz axis order
nrrd.write('./fix.nrrd', fix.transpose(2,1,0), compression_level=2)
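`distributed_image_mean` spreads the work over a dask cluster. The quantity it returns, the mean over all time frames, can be sketched in a single process as below. This is an illustration, not CircuitSeeker's code. `streaming_mean` is a hypothetical helper that keeps only one frame in memory at a time. In practice you would feed it a generator that reads the `/default` dataset of each `TM*.h5` file (for example with h5py); that reader is left out here.

```python
import numpy as np

def streaming_mean(frames):
    """Hypothetical helper: mean of an iterable of equally shaped arrays.

    Accumulates a running sum in float64 so that only one frame
    needs to be held in memory at a time.
    """
    total = None
    count = 0
    for frame in frames:
        frame = np.asarray(frame, dtype=np.float64)
        if total is None:
            total = np.zeros_like(frame)
        total += frame
        count += 1
    if count == 0:
        raise ValueError("no frames given")
    # float32 is usually enough precision for a mean image
    return (total / count).astype(np.float32)
```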

## Compute fixed image mask (optional)
---
Later, when we apply the transforms that motion correct all the time frames, we have the option to apply a mask to each frame. Masking reduces the size of the dataset on disk by about 60%; depending on your data, that can amount to several terabytes.

You should run this cell a few times, varying the value of `lambda2`. Larger values make the mask bigger and smaller values make it smaller. After each run, overlay the mask (`mask.nrrd`) on the fixed image written out above (`fix.nrrd`), and proceed only with a mask that covers the entire brain.
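Visual inspection is the real test. A quick number per `lambda2` setting can still help you compare runs. The sketch below uses a hypothetical helper, `mask_report`, which is not part of CircuitSeeker. It reports the fraction of the volume inside the mask and the fraction of the brightest voxels (above a chosen percentile) that the mask covers. Treating "bright" as a proxy for "brain" is an assumption that may not hold for every stain. Pass it `fix` and the zyx-ordered `mask` from the end of the cell below.

```python
import numpy as np

def mask_report(image, mask, percentile=99.0):
    """Hypothetical helper: summarize how a binary mask relates to an image.

    mask_fraction: fraction of all voxels inside the mask
    bright_voxels_inside: fraction of voxels at or above the given intensity
        percentile that fall inside the mask (ideally close to 1.0)
    """
    image = np.asarray(image)
    mask = np.asarray(mask).astype(bool)
    if image.shape != mask.shape:
        raise ValueError("image and mask must have the same shape and axis order")
    bright = image >= np.percentile(image, percentile)
    return {
        "mask_fraction": float(mask.mean()),
        "bright_voxels_inside": float(mask[bright].mean()),
    }
```

A mask whose `bright_voxels_inside` drops well below 1.0 is probably cutting into the brain, so try a larger `lambda2`.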

In [None]:
from CircuitSeeker.level_set import brain_detection
from scipy.ndimage import zoom, binary_dilation, binary_closing
import numpy as np

# the voxel spacings for the fixed and time series data
# these are essential to get right - make sure you know them
fix_spacing = np.array([1.0, 0.406, 0.406])
mov_spacing = np.array([4.0, 0.406, 0.406])

# we segment on downsampled data to make the segmentation sufficiently fast
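# downsample in zyx order (z by 2, y and x by 4), then switch to xyz order
fix_small = zoom(fix, [0.5, 0.25, 0.25], order=1).transpose(2,1,0)
# spacing of fix_small in xyz order: reverse the zyx spacing, then multiply by the downsampling factors
fix_small_spacing = fix_spacing[::-1] * [4, 4, 2]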

# segment
mask = brain_detection(
    fix_small,
    fix_small_spacing,
    smooth_sigmas=[12, 6, 4],
    lambda2=4,
    mask_smoothing=2,
)

# dilate the boundaries a little, go back to original sampling, and smooth boundaries
# you can also play with the dilation/closing element size here to adjust mask boundaries
mask = binary_dilation(mask, np.ones((10,10,10))).astype(np.uint8)
mask = zoom(mask, np.array(fix.shape[::-1]) / fix_small.shape, order=0)
mask = binary_closing(mask, np.ones((5,5,5))).astype(np.uint8)

# save the result
nrrd.write('./mask.nrrd', mask, compression_level=2)
mask = mask.transpose(2,1,0)
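The cell above mixes two axis orders: `fix` and `fix_spacing` are zyx, while `fix_small` and its spacing are xyz. The sketch below checks that bookkeeping. It assumes only that zooming an axis by a factor `f` multiplies its voxel spacing by `1/f`. This is approximate, because `scipy.ndimage.zoom` rounds the output shape. `zoomed_spacing` is a hypothetical helper.

```python
import numpy as np

def zoomed_spacing(spacing, zoom_factors):
    """Hypothetical helper: approximate voxel spacing after zooming each axis."""
    return np.asarray(spacing, dtype=float) / np.asarray(zoom_factors, dtype=float)

fix_spacing = np.array([1.0, 0.406, 0.406])                # z, y, x (micrometers)
zyx_small = zoomed_spacing(fix_spacing, [0.5, 0.25, 0.25])  # still z, y, x
xyz_small = zyx_small[::-1]                                 # reversed to x, y, z

# same numbers as `fix_spacing[::-1] * [4, 4, 2]` in the cell above
assert np.allclose(xyz_small, fix_spacing[::-1] * [4, 4, 2])
```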

## Motion correct
---
Two important parameters here are `time_stride` and `sigma`. `time_stride` sets the sub-sampling in time: `time_stride=10` means that only every 10th frame is rigidly aligned, and the transforms for the frames in between are interpolated. `sigma` is the standard deviation of a Gaussian applied to the transform parameters _over time_, which stabilizes the motion correction. When you increase `time_stride`, you should _decrease_ `sigma`.
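To make the roles of the two parameters concrete, here is a minimal numpy/scipy sketch of "align a subset of frames, smooth over time, interpolate the rest". It is not CircuitSeeker's implementation. It rests on three assumptions: each aligned frame yields a vector of rigid parameters (for example three rotations and three translations), smoothing is applied to the strided samples before interpolation, and `sigma` is measured in strided samples. Under those assumptions, the smoothing width in raw frames is roughly `sigma * time_stride`, which is 7.5 frames for the values in the cell below. That is one way to read the advice to decrease `sigma` as `time_stride` grows. `smooth_and_interpolate` is a hypothetical name.

```python
import numpy as np
from scipy.ndimage import gaussian_filter1d

def smooth_and_interpolate(params_strided, time_stride, sigma, n_frames):
    """Hypothetical sketch: temporal smoothing plus interpolation of rigid parameters.

    params_strided: (n_samples, n_params) array, one row per aligned frame
        (frames 0, time_stride, 2 * time_stride, ...)
    sigma: Gaussian standard deviation, assumed to be in units of strided samples
    Returns an (n_frames, n_params) array with parameters for every frame.
    """
    params_strided = np.asarray(params_strided, dtype=float)
    # smooth each parameter independently along the time axis
    smoothed = gaussian_filter1d(params_strided, sigma=sigma, axis=0, mode="nearest")
    sample_times = np.arange(params_strided.shape[0]) * time_stride
    all_times = np.arange(n_frames)
    # linearly interpolate each parameter to every frame
    # (np.interp holds the end values constant outside the sampled range)
    columns = [np.interp(all_times, sample_times, smoothed[:, j])
               for j in range(smoothed.shape[1])]
    return np.stack(columns, axis=1)
```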

In [3]:
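# numpy is imported in the optional mask cell, so import it again here in case that cell was skipped
import numpy as np

# information regarding the time series data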
frames = {'folder':mov_directory,
          'prefix':'TM',
          'suffix':'.h5',
          'dataset_path':'/default'}

# the voxel spacings for the fixed and time series data
# these are essential to get right - make sure you know them
fix_spacing = np.array([1.0, 0.406, 0.406])
mov_spacing = np.array([4.0, 0.406, 0.406])

# motion correct
# this will launch a dask cluster and print some useful information
# you should watch the dashboard to get a sense of all the computations happening
# `transforms` will contain a 4x4 rigid transform matrix for every time frame
transforms = motion_correct.motion_correct(
    fix, frames,
    fix_spacing, mov_spacing,
    time_stride=10,
    sigma=0.75,
    cluster_kwargs={
        'project':'ahrens',
        'cores':1, 'processes':1,
        'max_workers':100,
    },
)

# write the transforms out as individual files for storage
transforms_folder = './transforms'
motion_correct.save_transforms(
    transforms, transforms_folder,
)

Cluster dashboard link:  http://10.36.110.12:8787/status
Scaling cluster to 100 workers with 1 cores per worker
*** This cluster costs 7.0 dollars per hour starting now ***
Waiting 30 seconds for cluster to scale
Wait time complete
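The cell states that `transforms` holds a 4x4 rigid matrix for every time frame. The sketch below shows how such a homogeneous matrix acts on 3D points, and how to check that a matrix really is rigid (a proper rotation plus a translation). The helper names are hypothetical. The cell does not say whether the matrices map fixed to moving coordinates or the reverse, or whether they work in voxel or physical units. Converting voxel indices to physical coordinates by multiplying with `mov_spacing` is an assumption to verify before relying on it.

```python
import numpy as np

def apply_rigid(matrix, points):
    """Apply a 4x4 homogeneous transform to an (N, 3) array of points."""
    matrix = np.asarray(matrix, dtype=float)
    points = np.atleast_2d(np.asarray(points, dtype=float))
    homogeneous = np.hstack([points, np.ones((points.shape[0], 1))])
    return (homogeneous @ matrix.T)[:, :3]

def is_rigid(matrix, atol=1e-6):
    """True if the 3x3 block is a proper rotation and the last row is [0, 0, 0, 1]."""
    matrix = np.asarray(matrix, dtype=float)
    rotation = matrix[:3, :3]
    return bool(
        np.allclose(rotation @ rotation.T, np.eye(3), atol=atol)
        and np.isclose(np.linalg.det(rotation), 1.0, atol=atol)
        and np.allclose(matrix[3], [0, 0, 0, 1], atol=atol)
    )

# example: a pure translation by (1, 2, 3)
shift = np.eye(4)
shift[:3, 3] = [1, 2, 3]
moved = apply_rigid(shift, [[0, 0, 0]])  # -> [[1, 2, 3]]
```

For instance, `all(is_rigid(t) for t in transforms)` is a cheap sanity check to run before the expensive resampling step, assuming `transforms` can be iterated as one 4x4 array per frame.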


## Apply transforms
---
Note that we pass `mask` here to reduce the overall size of the written dataset.
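Why does a mask shrink the files? Voxels outside the mask are set to zero, and long runs of zeros compress very well. The sketch below demonstrates the effect with `zlib` on synthetic data. It assumes the output is stored compressed (zarr compresses chunks by default). The exact savings on real frames will differ from this toy example and from the roughly 60% quoted above.

```python
import zlib
import numpy as np

rng = np.random.default_rng(0)
# synthetic stand-in for one noisy frame, axis order (z, y, x)
frame = rng.integers(100, 200, size=(16, 64, 64), dtype=np.uint16)
# synthetic stand-in for a brain mask covering 25% of the voxels
mask = np.zeros(frame.shape, dtype=bool)
mask[:, 16:48, 16:48] = True

def compressed_bytes(array):
    """Size in bytes of the array's raw buffer after zlib compression."""
    return len(zlib.compress(np.ascontiguousarray(array).tobytes(), 2))

unmasked = compressed_bytes(frame)
masked = compressed_bytes(frame * mask)  # zeros outside the mask
print(unmasked, masked)
```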

An important parameter here is the `config` entry of `cluster_kwargs`. It lets you change the [dask configuration](https://docs.dask.org/en/latest/configuration-reference.html). Very large resample jobs may need configuration changes so that the cluster can handle the volume of computation without shutting down. Always test these functions on a small dataset first.
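As an illustration, here is the kind of dictionary that could go in the `config` entry. The keys are real settings from dask's distributed configuration reference (connection timeouts and how many worker failures a task may be involved in before it is marked as failed). The values are hypothetical starting points, and whether they help depends on how your cluster is failing. We also assume that `cluster_kwargs['config']` is applied as ordinary dask configuration, as the link above suggests.

```python
# hypothetical values; check each key against the dask configuration reference
config = {
    "distributed.comm.timeouts.connect": "180s",
    "distributed.comm.timeouts.tcp": "360s",
    "distributed.scheduler.allowed-failures": 10,
}

# used in place of the empty dict in the cell below, e.g.
# cluster_kwargs={'project': 'ahrens', 'cores': 1, 'processes': 1,
#                 'max_workers': 200, 'config': config}
```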

For example, you can pass the keyword argument `subset`, a Python `slice` object that specifies which frames to transform. `subset = slice(100, 200, 10)` transforms only time points 100, 110, 120, ..., 190; as with every Python slice, the stop value (200) is excluded.
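You can check which time points a `subset` selects before launching a cluster. `slice.indices` resolves the slice against a sequence length. The frame count here is a hypothetical placeholder; use the number of frames in your own time series.

```python
n_frames = 1000  # hypothetical number of time points
subset = slice(100, 200, 10)
selected = list(range(*subset.indices(n_frames)))
print(selected)  # [100, 110, 120, 130, 140, 150, 160, 170, 180, 190]
```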

In [5]:
# the motion corrected dataset will be written here as a zarr file
write_path = './motion_corrected.zarr'

# This requires some heavy computation
# set `cores` and `max_workers` carefully - you need to be aware of the
# resource cost of your job
aligned_frames = motion_correct.resample_frames(
    frames,
    mov_spacing,
    transforms,
    write_path,
    mask=mask,  # from the optional mask cell; remove this argument if you skipped that step
    cluster_kwargs={
        'project':'ahrens',
        'cores':1, 'processes':1,
        'max_workers':200,
        'config':{},
    },
)

Cluster dashboard link:  http://10.36.110.12:8787/status
Scaling cluster to 200 workers with 1 cores per worker
*** This cluster costs 14.0 dollars per hour starting now ***
Waiting 30 seconds for cluster to scale
Wait time complete


## Write a time slice (optional)
---
Take a look at one plane of your data over time. We take only every other time point so that the file isn't huge.

In [6]:
# how to slice in time/space
time_stride, plane = 2, 30

# get the data from the zarr file
slice_over_time = aligned_frames[::time_stride, plane, :, :]

# write out in a format you can read with Fiji/Icy etc.
nrrd.write('./slice_over_time.nrrd', slice_over_time.transpose(2,1,0), compression_level=2)
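The indexing above assumes `aligned_frames` is ordered (time, z, y, x), which we infer from how the cell indexes it. The sketch below traces the shapes with a small dummy numpy array instead of the zarr file. The sizes are hypothetical, and the name `demo_frames` avoids overwriting the real `aligned_frames`. The final transpose yields (x, y, time), matching the xyz convention used for the other NRRD files in this notebook.

```python
import numpy as np

# dummy stand-in for the zarr array: (time, z, y, x), hypothetical sizes
demo_frames = np.zeros((101, 40, 64, 48), dtype=np.uint16)

time_stride, plane = 2, 30
demo_slice = demo_frames[::time_stride, plane, :, :]  # (time, y, x)
demo_xyt = demo_slice.transpose(2, 1, 0)              # (x, y, time)
print(demo_slice.shape, demo_xyt.shape)  # (51, 64, 48) (48, 64, 51)
```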