In [None]:
import caiman as cm
from caiman.motion_correction import MotionCorrect
from caiman.source_extraction.cnmf import cnmf, params

import matplotlib.pyplot as plt
import numpy as np
import scipy.io as sio

import os

from src.caiman_preprocessing import copy_data, find_local_max, replace_rows
from src.caiman_preprocessing_hyperparams import Hyperparams

## Hyperparameter Setup

In [None]:
# Hyperparameters for F147.
# Each set_* call configures one group of settings; later cells read them back as
# attributes of `hyp` (e.g. hyp.path_src, hyp.local_max_thr, hyp.proxy_slices).
F147 = Hyperparams(name='F147')

# Raw TIFF (path_orig) and the resaved copy (path_src) that the rest of the notebook loads
F147.set_paths(
    path_orig='data/2p_raw/F147/F147_20210526_fish4_blk1_LT_9dpf_00001.tif',
    path_src='results/F147.tif'
)

# CaImAn parameters. NOTE: k is NOT switched automatically with proc_index; edit it by hand
# before running each piece: proc_index = 0: (tau = 2, k = 1250), proc_index = 1: (tau = 2, k = 250)
F147.set_params_dict(tau=2, k=1250)

# Line-removal settings, as used by this notebook:
#   local_max_thr / local_max_rad   -> passed to find_local_max on the per-frame mean fluorescence
#   channel_thr                     -> frames whose mean is not above it are skipped / treated as blank
#   correction_thr                  -> row-mean threshold that flags a frame for correction
#   correction_rad                  -> half-width (in frames) of the window around each local maximum
F147.set_lr_params(local_max_thr=20, local_max_rad=50,
                   channel_thr=0, correction_thr=25, correction_rad=50)

# (rows, columns) rectangle set to NaN before computing the diagnostic means.
# NOTE(review): presumably this also sets hyp.lr_proxy, which is read later — confirm in Hyperparams.
F147.set_lr_proxy_params(proxy_slices=[(slice(247, 311), slice(39, 152))])

# MATLAB metadata file and the variable inside it whose entries are removed along with blank frames
F147.set_blank_params(
    path_image_meta='data/imfinfo/F147_imfinfo.mat',
    image_meta_var='image'
)

# Two (rows, columns) pieces (rows 0-246 and rows 247-319); hyp.proc_index selects which one CNMF runs on.
# NOTE(review): presumably this also sets hyp.piecewise_proc — confirm in Hyperparams.
F147.set_piecewise_processing(proc_slices=[
    (slice(0, 247), slice(0, 256)),
    (slice(247, 320), slice(0, 256))
])

In [None]:
# Hyperparameters for F201.
# Unlike F147, no proxy rectangles and no piecewise slices are configured here.
# NOTE(review): this relies on Hyperparams defaulting lr_proxy / piecewise_proc (and
# proc_index, which is used in the memmap file name) when those setters are not called — confirm.
F201 = Hyperparams(name='F201')

# Raw TIFF (path_orig) and the resaved copy (path_src) that the rest of the notebook loads
F201.set_paths(
    path_orig='data/2p_raw/F201/F201_20210812_fish2_blk1_RT_9dpf_00001.tif',
    path_src='results/F201.tif'
)

# CaImAn parameters
F201.set_params_dict(tau=2, k=1500)  # proc_index = 0: (tau = 2, k = 1500)

# MATLAB metadata file and the variable inside it whose entries are removed along with blank frames
F201.set_blank_params(
    path_image_meta='data/imfinfo/F201_imfinfo.mat',
    image_meta_var='image'
)

# Line-removal thresholds (same meaning as for F147: local-maximum search, channel cut-off,
# row-mean correction threshold, and frame window radius around each maximum)
F201.set_lr_params(local_max_thr=50, local_max_rad=50,
                   channel_thr=0, correction_thr=55, correction_rad=50)

In [None]:
# The recording whose hyperparameters drive every cell below
hyp = F147

# Piecewise recordings are processed one rectangle at a time;
# proc_index picks the entry of hyp.proc_slices that CNMF will see
if hyp.piecewise_proc:
    hyp.proc_index = 0

## File Setup

In [None]:
# Move from the scripts directory to the main project directory.
# All data paths are relative to the project root ('data/...', 'results/...'). Step up only
# when 'data/' is not visible from here, so re-running this cell (or re-running after the
# later chdir into results/) never climbs above the project root.
if not os.path.isdir('data'):
    os.chdir('../')

In [None]:
# Set RESAVE_DATA = True to load the raw data (hyp.path_orig) and resave it to hyp.path_src.
# See GitHub Issue #377 - https://github.com/flatironinstitute/CaImAn/issues/377#issuecomment-426740429
RESAVE_DATA = False
if RESAVE_DATA:
    copy_data(hyp.path_orig, hyp.path_src)

## CaImAn Parameter Setup

In [None]:
# Point the parameter dictionary at the source movie, then wrap it in a CaImAn params object
hyp.set_fname(hyp.path_src)
opts = params.CNMFParams(
    params_dict=hyp.params_dict
)

## Movie

In [None]:
# Load the (resaved) source movie and play it back for a visual check
movie_orig = cm.load(hyp.path_src)
movie_orig.play()

## Fluorescence Diagnostics

In [None]:
# Create an array used for diagnostics and line removal.
# Regions are masked with NaN below, which requires a floating-point dtype, so promote
# integer data to float32 (data that is already floating point is copied unchanged).
movie_dgn = np.copy(movie_orig)
if not np.issubdtype(movie_dgn.dtype, np.floating):
    movie_dgn = movie_dgn.astype(np.float32)

# Mask every proxy rectangle with NaN so the nan-aware statistics below ignore it
if hyp.lr_proxy:
    for rectangle in hyp.proxy_slices:
        movie_dgn[:, rectangle[0], rectangle[1]] = np.nan

In [None]:
# Per-frame mean fluorescence over both spatial axes; NaN-masked pixels are ignored
movie_dgn_means = np.nanmean(
    movie_dgn,
    axis=(1, 2),
)

In [None]:
# Diagnostic plot of the mean fluorescence of every frame.
# Uses an explicit Figure/Axes so the plot never draws onto a leftover "current" figure.
fig, ax = plt.subplots()
ax.plot(movie_dgn_means, '.')

# Horizontal lines mark the thresholds used to find local maxima and to identify channels
local_max_threshold = ax.axhline(hyp.local_max_thr, c='y')
channel_threshold = ax.axhline(hyp.channel_thr, c='g')

# Title, labels, and a legend for the threshold lines
ax.set_title("Mean Fluorescences of All Frames")
ax.set_xlabel("Frame")
ax.set_ylabel("Mean Fluorescence")
ax.legend([local_max_threshold, channel_threshold],
          ["Local Maximum\nThreshold", "Channel\nThreshold"], loc='upper right')

plt.show()

In [None]:
# Frame indices at which the per-frame mean fluorescence has a local maximum
# (threshold and search radius come from the line-removal hyperparameters)
local_max = find_local_max(movie_dgn_means,
                           hyp.local_max_thr,
                           hyp.local_max_rad)

In [None]:
# Diagnostic plots of mean fluorescence in a window of frames around each local maximum.
# The window is clipped to the valid frame range. Without clipping, a maximum near either
# end of the movie gives a negative or too-large bound. range() and the slice then have
# different lengths, so plot() raises, or a negative start makes the slice wrap around.
n_frames = movie_dgn_means.size
for point in local_max:

    # Frames within correction_rad of the local maximum, clipped to [0, n_frames)
    lb = max(point - hyp.correction_rad, 0)
    ub = min(point + hyp.correction_rad + 1, n_frames)

    fig, ax = plt.subplots()
    ax.plot(range(lb, ub), movie_dgn_means[lb:ub], '.')

    # Horizontal line showing the channel threshold
    channel_threshold = ax.axhline(hyp.channel_thr, c='g')

    # Title, labels, and legend
    ax.set_title("Mean Fluorescences of Frames Near " + str(point))
    ax.set_xlabel("Frame")
    ax.set_ylabel("Mean Fluorescence")
    ax.legend([channel_threshold], ["Channel\nThreshold"], loc='upper right')

    plt.show()

In [None]:
# Diagnostic plots of mean row fluorescence for each frame near a local maximum that has
# at least one row above the correction threshold. The frame window is clipped to
# [0, n_frames): a negative index would silently wrap to the end of the movie, and an
# index past the end would raise IndexError.
n_frames = movie_dgn_means.size
for point in local_max:

    lb = max(point - hyp.correction_rad, 0)
    ub = min(point + hyp.correction_rad + 1, n_frames)
    for i in range(lb, ub):

        # Skip frames that are not the correct channel (NaN means also fail this test)
        if not movie_dgn_means[i] > hyp.channel_thr:
            continue

        # Mean of each row of the current frame, ignoring NaN-masked pixels
        frame_row_means = np.nanmean(movie_dgn[i], axis=1)

        # Only frames with at least one row above the correction threshold are plotted
        if not np.any(frame_row_means > hyp.correction_thr):
            continue

        fig, ax = plt.subplots()
        ax.plot(frame_row_means, '.')

        # Horizontal line showing the correction threshold
        correction_threshold = ax.axhline(hyp.correction_thr, c='r')

        ax.set_title("Mean Row Fluorescences of Frame " + str(i))
        ax.set_xlabel("Row Index")
        ax.set_ylabel("Mean Fluorescence")
        ax.legend([correction_threshold], ["Correction\nThreshold"], loc='upper right')

        plt.show()

## Line Removal

In [None]:
# Fresh, unmasked copy of the source movie; line and blank removal are applied to this one
movie_edit = cm.load(hyp.path_src)

In [None]:
# Remove lines around every local maximum.
# NOTE(review): replace_rows' return value is discarded, so it presumably edits movie_edit
# in place. Reload movie_edit before re-running this cell, and do not re-run it after blank
# removal, because the frame indices would no longer line up with movie_dgn.
for peak_frame in local_max:
    replace_rows(
        movie_edit, movie_dgn, peak_frame,
        hyp.channel_thr, hyp.correction_thr, hyp.correction_rad,
    )

## Blank Removal

In [None]:
# Read the per-frame metadata .mat file (a dict of MATLAB variables)
meta_path = hyp.path_image_meta
image_metadata = sio.loadmat(meta_path)

In [None]:
# Indices of blank frames: frames whose mean fluorescence is not above the channel threshold.
# NaN means count as blank, matching the `not ... >` test used elsewhere.
# np.flatnonzero always returns an integer array, even when there are no blank frames.
# Building the array from an empty Python list gives float64, which np.delete rejects as an index.
blank_idx = np.flatnonzero(~(movie_dgn_means > hyp.channel_thr))

In [None]:
# Remove all blank frames from the data.
# blank_idx indexes the frames of the original movie. Delete only while movie_edit still
# has one frame per entry of movie_dgn_means. Re-running the cell otherwise removes a
# second, unrelated set of frames.
if movie_edit.shape[0] == movie_dgn_means.size:
    movie_edit = np.delete(movie_edit, blank_idx, axis=0)
else:
    print("Blank frames already removed from movie_edit; skipping.")

In [None]:
# Remove the entries of the blank frames from the metadata variable.
# NOTE(review): np.delete without `axis` flattens the loaded array first. Re-running this
# cell also deletes another set of entries, so reload the metadata before re-running.
meta_key = hyp.image_meta_var
image_metadata[meta_key] = np.delete(image_metadata[meta_key], blank_idx)

## Editing Results

In [None]:
# Diagnostic plot of per-frame mean fluorescence after line and blank removal.
# Uses an explicit Figure/Axes, consistent with the earlier diagnostic plots.
fig, ax = plt.subplots()
ax.plot(np.mean(movie_edit, axis=(1, 2)), '.')

ax.set_title("Corrected Mean Fluorescences of All Frames")
ax.set_xlabel("Frame")
ax.set_ylabel("Mean Fluorescence")

plt.show()

In [None]:
# Switch to the results directory (debug). Later cells write and read files relative to it.
# Guarded so that re-running the cell does not try to enter results/results/.
if os.path.basename(os.path.normpath(os.getcwd())) != 'results':
    os.chdir('results/')

In [None]:
# Write the corrected movie to '<name>_edit.tif' in the current working directory;
# path_edit is reused as the motion-correction input
path_edit = '{}_edit.tif'.format(hyp.name)
movie_edit.save(path_edit)

In [None]:
# Write the blank-free metadata to '<name>_imfinfo_edit.mat' in the current working directory
path_meta_edit = hyp.name + '_imfinfo_edit.mat'
sio.savemat(path_meta_edit, image_metadata)

## Motion Correction

In [None]:
# (Re)start the parallel cluster, shutting down one left from an earlier run of this cell
if 'dview' in globals():
    cm.stop_server(dview=dview)
dview, n_processes = cm.cluster.setup_cluster()[1:3]

In [None]:
# Motion-correct the edited movie with the 'motion' parameter group
# (rigid vs. piecewise-rigid is decided by that group's settings); the result is saved to disk
motion_params = opts.get_group('motion')
mc = MotionCorrect(path_edit, dview=dview, **motion_params)
mc.motion_correct(save_movie=True)

In [None]:
# Side-by-side playback: line-removed movie (offset like the corrected one) next to the
# motion-corrected movie, concatenated along the width axis
movie_mc = cm.load(mc.mmap_file)
cm.concatenate(
    [movie_edit - mc.min_mov * mc.nonneg_movie, movie_mc],
    axis=2,
).play()

## Piecewise Processing

In [None]:
# Region CNMF runs on: the selected (rows, columns) rectangle for piecewise processing,
# otherwise the entire field of view
if hyp.piecewise_proc:
    proc_slice = hyp.proc_slices[hyp.proc_index]
    hslice, vslice = proc_slice[0], proc_slice[1]
    movie_piece = movie_mc[:, hslice, vslice]
else:
    movie_piece = movie_mc

## Memory Mapping

In [None]:
# Save the selected piece as a C-order memory-mapped file named after recording and piece.
# NOTE(review): hyp.proc_index is used even for non-piecewise recordings; this relies on a
# Hyperparams default — confirm.
base_name = '{}_{}_memmap_'.format(hyp.name, hyp.proc_index)
fname_mmap = cm.save_memmap(
    [movie_piece], base_name=base_name, order='C', dview=dview
)

In [None]:
# Load the memory-mapped file and view it as a (frames, *dims) movie;
# Yr.T is reshaped in Fortran order to match how CaImAn lays out the data
Yr, dims, T = cm.load_memmap(fname_mmap)
images = Yr.T.reshape((T,) + tuple(dims), order='F')

## Source Extraction

In [None]:
# Restart the cluster before CNMF so workers from motion correction release their memory
if 'dview' in globals():
    cm.stop_server(dview=dview)
dview, n_processes = cm.cluster.setup_cluster()[1:3]

In [None]:
# Rebuild the CaImAn parameters so they refer to the memory-mapped file instead of the TIFF
hyp.set_fname(fname_mmap)
opts = params.CNMFParams(
    params_dict=hyp.params_dict
)

In [None]:
# First CNMF pass on the memory-mapped movie (fit returns the fitted CNMF object)
cnm_orig = cnmf.CNMF(n_processes, params=opts, dview=dview).fit(images)

In [None]:
# Local correlation image used as the background for contour plots
# (time axis moved last before the call); NaN pixels are set to 0
corr_img = cm.local_correlations(np.transpose(images, (1, 2, 0)))
nan_pixels = np.isnan(corr_img)
corr_img[nan_pixels] = 0

In [None]:
# Contours of the components found by the first CNMF pass, over the correlation image
cnm_orig.estimates.plot_contours_nb(
    img=corr_img
)

In [None]:
# Second CNMF pass, seeded with the first pass's results
cnm = cnm_orig.refit(
    images, dview=dview
)

In [None]:
# Contours after the refit, over the same correlation image
cnm.estimates.plot_contours_nb(
    img=corr_img
)

## Component Evaluation

In [None]:
# Score each component; results are stored on cnm.estimates
# (idx_components / idx_components_bad are read below)
cnm.estimates.evaluate_components(
    images, cnm.params, dview=dview
)

In [None]:
# Contour plot of all components, with the accepted set passed as idx.
# NOTE(review): CaImAn is expected to draw accepted and rejected components separately
# when idx is given — confirm for the installed version.
accepted_idx = cnm.estimates.idx_components
cnm.estimates.plot_contours_nb(img=corr_img, idx=accepted_idx)

In [None]:
# Interactive view of the components that passed evaluation
cnm.estimates.nb_view_components(
    img=corr_img, idx=cnm.estimates.idx_components
)

In [None]:
# Interactive view of the rejected components; skipped when nothing was rejected
rejected_idx = cnm.estimates.idx_components_bad
if len(rejected_idx) != 0:
    cnm.estimates.nb_view_components(img=corr_img, idx=rejected_idx)

In [None]:
# Drop rejected components from cnm.estimates so only the accepted ones remain
cnm.estimates.select_components(
    use_object=True
)

## Final Results

In [None]:
# Interactive view of the remaining (accepted) components, denoised trace drawn in red
cnm.estimates.nb_view_components(
    img=corr_img, denoised_color='red'
)

In [None]:
# Save CNMF results next to the memory-mapped file, replacing its extension with .hdf5.
# os.path.splitext removes whatever the extension is. The old slicing [:-4] assumed an
# extension of exactly four characters. For a '.mmap' file both give the same path.
cnm.save(os.path.splitext(cnm.mmap_file)[0] + '.hdf5')

In [None]:
# Play back the movie together with the CNMF reconstruction
cnm.estimates.play_movie(
    images
)

In [None]:
# Shut down the parallel workers started for CNMF
cm.stop_server(
    dview=dview
)