# CaImAn Motion Correction Pipeline


## User Input

In [None]:
is_this_a_test_run = False  # True -> only a subset of the movie files is processed (see subset_slice)

# Root folder that is searched recursively for movies stored in 'raw' subfolders
experiments_folder = '/Users/priscilla/Documents/Local - Moss Lab/20251003/sid260'

# --- Image calibration ---
um_per_pixels = 2.0

# --- Motion correction settings (distances in micrometers, as measured in Fiji) ---
pw_rigid = True
max_shift_um = 128.0
strides_um = 128.0
overlap_um = 96.0
max_deviation_um = 12.0

# How shifts are applied: True -> OpenCV bicubic interpolation (faster), False -> FFT
shifts_opencv = False
# How borders are filled after shifting: 'copy' replicates the boundary values (True would fill with NaN)
border_nan = 'copy'

# --- Test-run subset (only used when is_this_a_test_run is True) ---
# slice(first_file, last_file, step), applied to the sorted list of movie files
subset_slice = slice(0, None, 10)

## Motion Correction

Import the same packages as CaImAn's `demo_motion_correction.ipynb`, plus `datetime` and `pathlib` for timestamps, file paths, and logging.

In [None]:
from datetime import datetime
from pathlib import Path

import cv2
from IPython import get_ipython
import matplotlib.pyplot as plt
import numpy as np
import os.path
import logging

# Disable OpenCV's internal threading (per the OpenCV docs, 0 runs its functions
# sequentially); presumably to avoid oversubscription with the process cluster below.
# Best effort: a failure here must not stop the notebook.
try:
    cv2.setNumThreads(0)
except Exception:
    pass

# Enable module autoreloading when running under IPython/Jupyter.
# __IPYTHON__ only exists inside IPython, so a plain interpreter raises NameError.
try:
    if __IPYTHON__:
        get_ipython().run_line_magic('load_ext', 'autoreload')
        get_ipython().run_line_magic('autoreload', '2')
except NameError:
    pass

import caiman as cm
from caiman.motion_correction import MotionCorrect, tile_and_correct, motion_correction_piecewise

Get the file names and start a log of the current session.

In [None]:
# Get movie paths and set up log filename
experiments_path = Path(experiments_folder)
movie_paths = sorted([p for p in experiments_path.rglob("**/*.tif")])
movie_paths = [p for p in movie_paths if p.parent.name == 'raw']

if movie_paths == []:
    raise ValueError(f'No .tif files found in {experiments_folder}"')

current_time = datetime.now()
time_filename = current_time.strftime("%Y%m%d_%H%M%S")
filename_start = '_'.join(movie_paths[0].stem.split('_')[:-2])

log_filename = time_filename + '_mcor_' + filename_start + '.log'

# Get subset of files if test run and adjust log filename
if is_this_a_test_run:
    movie_paths = movie_paths[subset_slice]
    log_filename = './logs/TEST_' + log_filename
else:
    log_filename = './logs/' + log_filename

# Set up logging
logfile = Path(log_filename)
logger = logging.getLogger('caiman')

# I'm setting to INFO to get detailed logs, change to WARNING to reduce output
logger.setLevel(logging.INFO)
logfmt = logging.Formatter('%(relativeCreated)12d [%(filename)s:%(funcName)20s():%(lineno)s] [%(process)d] %(message)s')
if logfile is not None:
    handler = logging.FileHandler(logfile)
else:
    handler = logging.StreamHandler()
handler.setFormatter(logfmt)
logger.addHandler(handler)

# Log start time and files to be processed
time_logging = current_time.strftime("%Y-%m-%d %H:%M:%S")

if is_this_a_test_run:
    logger.info(f"Test Run started at {time_logging}")
else:
    logger.info(f"Motion Correction started at {time_logging}")

logger.info("User Input:")
logger.info(f"  experiments_folder: {experiments_folder}")
logger.info(f"  um_per_pixels: {um_per_pixels}")
logger.info(f"  pw_rigid: {pw_rigid}")
logger.info(f"  max_shift_um: {max_shift_um}")
logger.info(f"  strides_um: {strides_um}")
logger.info(f"  overlap: {overlap_um}")
logger.info(f"  max_deviation_um: {max_deviation_um}")
logger.info(f"  shifts_opencv: {shifts_opencv}")
logger.info(f"  border_nan: {border_nan}")

if is_this_a_test_run:
    logger.info("Test run settings:")
    logger.info(f"  first file: {subset_slice.start}")
    logger.info(f"  last file: {subset_slice.stop}")
    logger.info(f"  step: {subset_slice.step}")

logger.info(f"List of files being processed ({len(movie_paths)} files):")

for mp in movie_paths:
    logger.info(f"  {mp.resolve()}")

# Set environment variables (so that linear algebra libraries don't use other threads)
os.environ["MKL_NUM_THREADS"] = "1"
os.environ["OPENBLAS_NUM_THREADS"] = "1"
os.environ["VECLIB_MAXIMUM_THREADS"] = "1"

Convert the motion correction parameters from micrometers to pixels (same as in Matlab). Loading the first movie to read its dimensions is currently commented out.

In [None]:
# Convert the micrometer settings into integer pixel counts (int() truncates toward zero).
# The movie dimensions are not needed for this conversion, so loading the first movie stays disabled:
# first_movie = cm.load(movie_paths[0])
# frames, height, width = np.shape(first_movie)

# Same value along both image axes
max_shifts, strides, overlaps = (
    (int(size_um / um_per_pixels),) * 2
    for size_um in (max_shift_um, strides_um, overlap_um)
)
max_deviation_rigid = int(max_deviation_um / um_per_pixels)

# Record the pixel-based parameters in the session log
logger.info("Motion correction parameters:")
for param_name, param_value in (
    ("pw_rigid", pw_rigid),
    ("max_shifts", max_shifts),
    ("strides", strides),
    ("overlaps", overlaps),
    ("max_deviation_rigid", max_deviation_rigid),
    ("shifts_opencv", shifts_opencv),
    ("border_nan", border_nan),
):
    logger.info(f"  {param_name}: {param_value}")

Start a cluster to run the computations in parallel and create the motion correction object.

In [None]:
#%% (Re)start the parallel cluster, stopping one left over from an earlier run of this cell
if 'dview' in locals():
    cm.stop_server(dview=dview)
c, dview, n_processes = cm.cluster.setup_cluster(backend='multiprocessing',
                                                 n_processes=None,
                                                 single_thread=False)

# Build the motion correction object. It is created with pw_rigid=False so that the
# first pass is rigid; piecewise-rigid correction is switched on later on this object.
mc = MotionCorrect(
    movie_paths,
    dview=dview,
    pw_rigid=False,
    max_shifts=max_shifts,
    strides=strides,
    overlaps=overlaps,
    max_deviation_rigid=max_deviation_rigid,
    shifts_opencv=shifts_opencv,
    border_nan=border_nan,
    nonneg_movie=True,
)

Run motion correction using NoRMCorre: a rigid pass first, then a piecewise-rigid pass that starts from the rigid template.

In [None]:
%%time
mc.motion_correct(save_movie=True)

time_logging = current_time.strftime("%Y-%m-%d %H:%M:%S")
logger.info(f"Finished motion correction at {time_logging}")

In [None]:
%%time
mc.pw_rigid = True
mc.motion_correct(save_movie=True, template=mc.total_template_rig)

time_logging = current_time.strftime("%Y-%m-%d %H:%M:%S")
logger.info(f"Finished motion correction at {time_logging}")

## Quality Control

The plots below help check whether the motion correction was successful.

The "Patch Shifts" plot below shows how much each patch was shifted. Thick plots (~4) usually mean wobbly images, because patches are moving relatively independently. Sharp spikes usually imply a sudden shift in the image. Steps can mean that the image changed "shape".

In [None]:
#%% Plot the piecewise-rigid (elastic) shift of every patch over time
plt.close()
fig, (ax_x, ax_y) = plt.subplots(2, 1, figsize=(20, 10))
fig.suptitle('Patch Shifts')
ax_x.plot(mc.x_shifts_els)
ax_x.set_ylabel('x shifts (pixels)')
ax_y.plot(mc.y_shifts_els)
ax_y.set_ylabel('y shifts (pixels)')
ax_y.set_xlabel('frames')

#%% Border to exclude (pixels): the largest absolute shift along either axis, rounded up
largest_shift = np.maximum(np.abs(mc.x_shifts_els).max(), np.abs(mc.y_shifts_els).max())
bord_px_els = np.ceil(largest_shift).astype(int)

In [None]:
%%capture
#% compute metrics for the results (TAKES TIME!!)
fnames = [str(p) for p in movie_paths]
final_size = np.subtract(mc.total_template_els.shape, 2 * bord_px_els) # remove pixels in the boundaries
winsize = 100
swap_dim = False
resize_fact_flow = .2    # downsample for computing ROF

tmpl_orig, correlations_orig, flows_orig, norms_orig, crispness_orig = cm.motion_correction.compute_metrics_motion_correction(
    fnames[0], final_size[0], final_size[1], swap_dim, winsize=winsize, play_flow=False, resize_fact_flow=resize_fact_flow)

tmpl_els, correlations_els, flows_els, norms_els, crispness_els = cm.motion_correction.compute_metrics_motion_correction(
    mc.fname_tot_els[0], final_size[0], final_size[1],
    swap_dim, winsize=winsize, play_flow=False, resize_fact_flow=resize_fact_flow)

In [None]:
plt.close()
plt.figure(figsize=(20, 10))
plt.suptitle('Correlation with Mean Frame')

# Top: per-frame correlation with the mean frame, before and after correction
plt.subplot(211)
plt.plot(correlations_orig)
plt.plot(correlations_els)
plt.legend(['Original', 'PW-Rigid'])

# Bottom-left: frame-by-frame comparison; points above the dashed identity line have a
# higher correlation after correction than before
plt.subplot(223)
plt.scatter(correlations_orig, correlations_els)
plt.xlabel('Original')
plt.ylabel('PW-Rigid')
# Span the identity line over the actual data range instead of a fixed [0.3, 0.7]
corr_values = np.concatenate([np.ravel(correlations_orig), np.ravel(correlations_els)])
corr_min, corr_max = np.nanmin(corr_values), np.nanmax(corr_values)
plt.plot([corr_min, corr_max], [corr_min, corr_max], 'r--');

In [None]:
# Sharpness score of each movie's mean image, as returned by compute_metrics_motion_correction
for movie_label, crispness_value in (('original', crispness_orig), ('elastic', crispness_els)):
    print('Crispness {}: {}'.format(movie_label, int(crispness_value)))

In [None]:
# %% Plot the Residual Optical Flow (ROF) metrics of the corrected and the raw movie
fls = [
    cm.paths.fname_derived_presuffix(str(mc.fname_tot_els[0]), "metrics", swapsuffix="npz"),
    cm.paths.fname_derived_presuffix(str(mc.fname[0]), "metrics", swapsuffix="npz"),
]
# Metrics files are assumed to be named <movie base>_metrics.npz; stripping that suffix
# gives the base name used to reload the matching movie
metrics_suffix = "_metrics.npz"

plt.figure(figsize=(20, 10))
plt.suptitle('Residual Optical Flow')
for cnt, (fl, metr) in enumerate(zip(fls, ["pw_rigid", "raw"])):
    with np.load(fl) as ld:
        print(ld.keys())
        print(fl)
        print(
            str(np.mean(ld["norms"]))
            + "+/-"
            + str(np.std(ld["norms"]))
            + " ; "
            + str(ld["smoothness"])
            + " ; "
            + str(ld["smoothness_corr"])
        )

        # Column 1: mean image of the movie, with a 12-pixel border cropped
        plt.subplot(len(fls), 3, 1 + 3 * cnt)
        plt.ylabel(metr)
        base = fl[:-len(metrics_suffix)]
        print(f"Loading data with base {base}")
        # Try each possible movie format in turn; a failed load moves on to the next one
        load_error = None
        for ext in (".mmap", ".tif", ".hdf5"):
            try:
                movie = cm.load(base + ext)
            except Exception as err:
                load_error = err
            else:
                break
        else:
            raise FileNotFoundError(
                f"No .mmap, .tif or .hdf5 movie found for base {base}") from load_error
        mean_img = np.mean(movie, 0)[12:-12, 12:-12]
        del movie  # only the mean image is needed; free the full movie

        lq, hq = np.nanpercentile(mean_img, [0.5, 99.5])
        plt.imshow(mean_img, vmin=lq, vmax=hq)
        plt.title("Mean")

        # Column 2: correlation image stored in the metrics file
        plt.subplot(len(fls), 3, 3 * cnt + 2)
        plt.imshow(ld["img_corr"], vmin=0, vmax=0.35)
        plt.title("Corr image")

        # Column 3: time-averaged magnitude of the residual optical flow
        plt.subplot(len(fls), 3, 3 * cnt + 3)
        flows = ld["flows"]
        plt.imshow(
            np.mean(np.sqrt(flows[:, :, :, 0] ** 2 + flows[:, :, :, 1] ** 2), 0),
            vmin=0,
            vmax=0.3,
        )
        plt.colorbar()
        plt.title("Mean optical flow")

## Save Motion Corrected TIFF Files

Run the next cell to save the output of the motion correction.

In [None]:
logger.info("Saving files...")

# One motion corrected memory-mapped file is expected per input movie; zip() would
# otherwise silently drop the unmatched ones
mmap_files = list(mc.mmap_file)
if len(mmap_files) != len(movie_paths):
    message = f"Got {len(mmap_files)} motion corrected files for {len(movie_paths)} movies."
    logger.error(message)
    raise ValueError(message)

# Resolve and check every destination first, so an existing file aborts the save
# before anything has been written
destinations = []
for mmap_filename, movie_path in zip(mmap_files, movie_paths):
    # Sanity check on the pairing: the mmap name is expected to start with the movie name
    if not Path(mmap_filename).stem.startswith(movie_path.stem):
        logger.warning(f"File {mmap_filename} may not match movie {movie_path.name}")

    save_folder = movie_path.parent.parent / 'processed' / 'caiman_mcor'
    mcor_path = save_folder / (movie_path.stem + '_mcor.tif')

    if mcor_path.exists():
        message = f"File {mcor_path.resolve()} already exists."
        logger.error(message)
        raise FileExistsError(message)

    destinations.append((mmap_filename, mcor_path))

for mmap_filename, mcor_path in destinations:
    # Create output directory if it doesn't exist
    mcor_path.parent.mkdir(parents=True, exist_ok=True)

    # Load the motion corrected mmap and save it as a TIFF
    m = cm.load(mmap_filename)
    m.save(str(mcor_path))
    logger.info(f"Saved file {mcor_path.resolve()}")
    del m  # release this movie before loading the next one

In case you restarted the kernel and lost the motion correction object, you can still get the TIFFs by directly converting the files in the temp folder. The next cell (currently commented out) gets the paths of all memory map files in that folder. Filter the list to keep only the files corresponding to `movie_paths` above, and make sure they are sorted in the same order.

In [None]:
# logger.info(f"Saving files...")

# for mmap_filename, movie_path in zip(mc.mmap_file, movie_paths):
#     # Check if mmap filename starts with movie path stem
#     if not Path(mmap_filename).stem.startswith(movie_path.stem):
#         message = f"File {mmap_filename} does not match movie {movie_path.name}"
#         logger.error(message)
#         raise ValueError(message)

#     save_folder = movie_path.parent.parent / 'processed' / 'caiman_mcor'
    
#     # Create output directory if it doesn't exist
#     save_folder.mkdir(parents=True, exist_ok=True)
#     mcor_path = save_folder / (movie_path.stem + '_mcor.tif')

#     # Check if file already exists
#     if mcor_path.exists():
#         message = f"File {mcor_path.resolve()} already exists."
#         logger.error(message)
#         raise FileExistsError(message)

#     # Load the motion corrected mmap and save as tiff
#     m = cm.load(mmap_filename)
#     m.save(mcor_path)
#     logger.info(f"Saved file {mcor_path.resolve()}")

# mcor_folder = experiments_path / 'processed/caiman_mcor'
# logger.info(f"Saving files to folder (directly from temp): {mcor_folder.resolve()}")

# # Create output directory if it doesn't exist
# mcor_folder.mkdir(parents=True, exist_ok=True)

# temp_folder = Path('/Users/priscilla/caiman_data/temp')
# temp_paths = sorted([p for p in temp_folder.rglob("*.mmap")])

## Clean up Resources

In [None]:
# Record when the session actually ended (current time, not the session start time)
logger.info(f"Session finished at {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")

# Stop the worker processes, then flush and close all logging handlers (incl. the log file)
cm.stop_server(dview=dview)
logging.shutdown()