# This notebook shows how to run rigid and/or piecewise-rigid motion correction on the `demoMovie.tif` dataset found in the `datasets` folder of this repo

In [1]:
import numpy as np
from jnormcorre import motion_correction
# %matplotlib widget
import matplotlib.pyplot as plt
#Things that the mc function was missing -- put this in a better location
import os
import tifffile
%load_ext autoreload


# Specify dataset location

In [2]:
# Path to the demo movie; relative paths are resolved against the directory the notebook is run from.
filename = "../datasets/demoMovie.tif"

# Run Motion Correction

In [3]:
def get_shape(filename):
    import tifffile
    with tifffile.TiffFile(filename) as tffl:
      num_frames = len(tffl.pages)
      for page in tffl.pages[0:1]:
          image = page.asarray()
          x, y = page.shape
    return (x,y,num_frames)

def resolve_dataformats(filename):
    '''
    Function for managing bad data formats (such as single-page tif files) which are tough to load. Resolves these issues by loading the data into memmap format and then saving the data (in small batches) into a better format
    Input: 
        filename: str. String describing the full filepath of the datafile
    Returns: 
        file_output: list of strings. In this list, each string is a filename. These files, taken together, form the entire dataset
    '''
    _, extension = os.path.splitext(filename)[:2]
    if extension in ['.tif', '.tiff', '.btf']:  # load tif file
        with tifffile.TiffFile(filename) as tffl:
            multi_page = True if tffl.series[0].shape[0] > 1 else False
            if len(tffl.pages) == 1:
                display("Data is saved as single page tiff file. We will re-save data as sequence of smaller tifs to improve performance, but this will take time. To avoid this issue, save your data as multi-page tiff files")
                file_output = chunk_singlepage_data(filename)
                return file_output

    file_output = [filename]
    return file_output

def motion_correct(filename,
                   outdir,
                   dxy=(2., 2.),
                   max_shift_um=(12., 12.),
                   max_deviation_rigid=3,
                   patch_motion_um=(100., 100.),
                   overlaps=(24, 24),
                   border_nan='copy',
                   niter_rig=4,
                   splits=200,
                   pw_rigid=True,
                   gSig_filt=None,
                   save_movie=True,
                   dtype='int16',
                   sketch_template=False,
                   **params):
    """
    Run jnormcorre motion correction on a movie and save the estimated shifts.

    Parameters
    ----------
    filename : str
        Full path to the input movie. Single-page tiffs are first re-chunked by
        ``resolve_dataformats``.
    outdir : str
        Existing directory in which ``shifts.npz`` is written. It is checked
        before motion correction starts, so a bad path fails fast.
    dxy : tuple (2 elements)
        Spatial resolution in x and y (um per pixel). Used to convert the ``*_um``
        arguments into pixels.
    max_shift_um : tuple (2 elements)
        Maximum shift in um, converted to ``max_shifts`` in pixels.
    max_deviation_rigid : int
        Maximum deviation allowed for a patch with respect to the rigid shifts.
    patch_motion_um : tuple (2 elements)
        Patch stride for piecewise-rigid correction in um, converted to
        ``strides`` in pixels.
    overlaps : tuple (2 elements)
        Overlap between patches. Passed to the corrector unchanged.
    border_nan :
        Currently not forwarded to the corrector; kept for interface compatibility.
    niter_rig : int
        Number of rigid registration passes (used to estimate the template).
    splits : int
        Approximate number of frames per temporal chunk, e.g. 200 means chunks of
        ~200 frames. Converted to a chunk count using the first file's length.
    pw_rigid : bool
        Whether to run piecewise-rigid motion correction.
    gSig_filt :
        Forwarded unchanged as ``gSig_filt`` to ``MotionCorrect``.
    save_movie : bool
        Forwarded to ``MotionCorrect.motion_correct``.
    dtype :
        Currently unused; kept for interface compatibility.
    sketch_template : bool
        If True, estimate the rigid and pw-rigid templates from 5 chunks instead of
        the whole dataset.
    **params :
        Currently ignored.

    Returns
    -------
    tuple
        ``(corrector_obj, target_file)`` as returned by
        ``MotionCorrect.motion_correct``, with ``corrector_obj.batching`` set to 10.

    Raises
    ------
    ValueError
        If ``splits`` is not positive.
    FileNotFoundError
        If ``outdir`` is not an existing directory.
    """
    from jnormcorre import motion_correction
    import math

    if splits <= 0:
        raise ValueError(f"splits must be a positive number of frames per chunk, got {splits}")
    if not os.path.isdir(outdir):
        raise FileNotFoundError(f"Output directory does not exist: {outdir}")

    display("Running motion correction...")
    target = resolve_dataformats(filename)

    # `splits` is given as frames-per-chunk; the corrector expects a number of chunks.
    total_frames_firstfile = get_shape(target[0])[2]
    num_chunks = math.ceil(total_frames_firstfile / splits)
    display("Number of chunks is {}".format(num_chunks))

    # Convert physical (um) settings into pixels.
    max_shifts = [int(a / b) for a, b in zip(max_shift_um, dxy)]
    strides = tuple(int(a / b) for a, b in zip(patch_motion_um, dxy))

    # None means the template is estimated from the entire dataset.
    num_splits_to_process = 5 if sketch_template else None

    mc_dict = {
        'max_deviation_rigid': max_deviation_rigid,          # maximum deviation between rigid and non-rigid
        'max_shifts': max_shifts,                            # maximum shifts per dimension (in pixels)
        'min_mov': -5,                                       # minimum value of movie
        'niter_rig': niter_rig,                              # number of iterations of rigid motion correction
        'niter_els': 1,                                      # number of iterations of piecewise rigid motion correction
        'nonneg_movie': True,                                # flag for producing a non-negative movie
        'num_splits_to_process_els': num_splits_to_process,  # chunks used to estimate the pw-rigid template
        'num_splits_to_process_rig': num_splits_to_process,  # chunks used to estimate the rigid template
        'overlaps': overlaps,                                # overlap between patches in pw-rigid motion correction
        'pw_rigid': pw_rigid,                                # flag for performing pw-rigid motion correction
        'splits_els': num_chunks,                            # number of splits across time for pw-rigid registration
        'splits_rig': num_chunks,                            # number of splits across time for rigid registration
        'strides': strides,                                  # how often to start a new patch in pw-rigid registration
        'upsample_factor_grid': 4,                           # motion field upsampling factor during FFT shifts
        'indices': (slice(None), slice(None)),               # part of FOV to be corrected
        'gSig_filt': gSig_filt,
    }

    corrector = motion_correction.MotionCorrect(target, **mc_dict)
    corrector_obj, target_file = corrector.motion_correct(
        save_movie=save_movie
    )

    display("Motion correction completed.")

    # Save frame-wise shifts. The pw-rigid entries are stored as None for rigid-only
    # runs, which np.savez keeps as object arrays (np.load then needs allow_pickle=True).
    display(f"Saving computed shifts to ({outdir})...")
    np.savez(os.path.join(outdir, "shifts.npz"),
             shifts_rig=corrector.shifts_rig,
             x_shifts_els=corrector.x_shifts_els if pw_rigid else None,
             y_shifts_els=corrector.y_shifts_els if pw_rigid else None)
    display('Shifts saved as "shifts.npz".')

    corrector_obj.batching = 10  # HACK: long term we need to avoid overriding this here
    return corrector_obj, target_file

In [1]:
# Piecewise-rigid correction with a template sketched from a subset of chunks.
# The registered movie is saved, and shifts.npz is written to the current directory.
obj, registered_filename = motion_correct(
    filename,
    ".",
    pw_rigid=True,
    sketch_template=True,
    overlaps=(10, 10),
    save_movie=True,
)

# Generate a side-by-side visualization (run this only if you used `save_movie=True`, so that the motion-corrected movie was saved)

In [8]:
def motion_correction_diagnostic(original_file, registered_file, frame_list = None):
    """
    Build a side-by-side movie comparing raw and motion-corrected frames.

    Parameters
    ----------
    original_file, registered_file : str
        Tiff files holding the raw and registered movies, read with ``tifffile.imread``.
    frame_list : int, slice, sequence of int, or None
        Passed as ``key`` to ``tifffile.imread`` to select pages. None reads every
        page, and a single integer selects one frame.

    Returns
    -------
    numpy.ndarray
        float32 array of shape ``(d1, 2 * d2, T)``. The original movie fills
        columns ``[:d2]`` and the registered movie fills columns ``[d2:]``.

    Raises
    ------
    ValueError
        If the two movies do not have the same shape after loading.
    """
    import tifffile

    read_kwargs = {} if frame_list is None else {'key': frame_list}

    def _load(path):
        # Read the selected pages and reorder (T, d1, d2) -> (d1, d2, T).
        movie = tifffile.imread(path, **read_kwargs)
        if movie.ndim == 2:
            # A single selected frame (or single-page file) comes back without a time axis.
            movie = movie[np.newaxis]
        return movie.transpose(1, 2, 0)

    original_movie = _load(original_file)
    registered_movie = _load(registered_file)
    # Without this check a 1-frame registered movie would silently broadcast across all frames.
    if original_movie.shape != registered_movie.shape:
        raise ValueError(
            f"Original movie shape {original_movie.shape} does not match "
            f"registered movie shape {registered_movie.shape}"
        )

    d1, d2, T = original_movie.shape
    display_movie = np.zeros((d1, d2 * 2, T), dtype=np.float32)
    display_movie[:, :d2, :] = original_movie
    display_movie[:, d2:, :] = registered_movie

    return display_movie

# Left half of each frame: raw data; right half: registered data.
display_movie = motion_correction_diagnostic(filename, registered_filename)

# tifffile stores frames along the first axis, so move time from the last axis to the front.
tifffile.imwrite("diagnostic.tiff", np.moveaxis(display_movie, -1, 0))