## Inspect results of motion correction
Step 2 of the Caiman processing pipeline for multi-layer two-photon calcium imaging movies. This notebook shows the results of the motion correction performed in step 1 and allows the selection of 'bad frames' (i.e. frames with too much residual motion). This is an interactive step that has to be run separately for each dataset.

### Specify analysis folder
Select the folder where the results from the motion correction are stored. This folder should contain a file `caiman_mc_log.yml` with all the important information about files, parameters etc.

In [None]:
# Folder holding the results of step 1 (motion correction) and the name of its YAML log
mc_analysis_folder = "/home/luetcke/neurophys-storage/Luetcke/Gwen/M4.3/20181114/S1"
caiman_logfile = "caiman_mc_log.yml"

### Imports & Setup
The first cells import the various Python modules required by the notebook. In particular, a number of modules are imported from the Caiman package. In addition, we also set up the environment so that everything works as expected.

In [None]:
# General imports
import os, yaml, json, fnmatch
from pprint import pprint
import numpy as np
from scipy.io import savemat
from tifffile import imsave
from IPython.display import clear_output

# Caiman
import caiman as cm
import utils, mc_utils
import caiman_utils as cm_utils

# Import Bokeh library
from bokeh.plotting import Figure, show
from bokeh.layouts import gridplot
from bokeh.models import Range1d, CrosshairTool, HoverTool, Legend
from bokeh.io import output_notebook, export_svgs
from bokeh.models.sources import ColumnDataSource
from bokeh import palettes

In [None]:
# This has to be in a separate cell, otherwise it won't work.
# Route Bokeh output to the notebook, using Bokeh's INLINE resources.
from bokeh import resources
output_notebook(resources=resources.INLINE)

In [None]:
# Load the YAML log written by the motion-correction step; abort if it is missing.
mc_logfile_path = os.path.join(mc_analysis_folder, caiman_logfile)
if not os.path.isfile(mc_logfile_path):
    raise Exception('Could not find %s in %s' % (caiman_logfile, mc_analysis_folder))
with open(mc_logfile_path) as f:
    mc_log = yaml.load(f, Loader=yaml.FullLoader)

In [None]:
# Pull the entries used by the rest of the notebook out of the motion-correction log
(data_folder,
 joined_tif_list,
 stacked_files_by_group,
 trial_indices_list,
 total_frames_list,
 n_groups,
 n_planes,
 metrics_files,
 mmap_files_rig) = (mc_log['data_folder'],
                    mc_log['joined_tif_list'],
                    mc_log['stacked_files_by_group'],
                    mc_log['trial_indices_list'],
                    mc_log['total_frames_list'],
                    mc_log['n_groups'],
                    mc_log['n_planes'],
                    mc_log['metrics_files'],
                    mc_log['mmap_files_rig'])
# The pw-rigid file list is only read when pw-rigid correction was enabled in the config
if mc_log['config']['mc']['pw_rigid']:
    mmap_files_els = mc_log['mmap_files_els']

In [None]:
# get metadata from the JSON file that belongs to the joined movie of group 0
# NOTE: in fnmatch, '[!badFrames]' is a single-character class, not a negated word: it only
# requires that the last character before '.json' is none of b, a, d, F, r, m, e, s.
# That is what keeps files ending in '...badFrames.json' out of the match.
meta_pattern = '%s_%s_Join_G0_*[!badFrames].json' % (
    mc_log['config']['data']['day_folder'], mc_log['config']['data']['area_folder'])
# sort so that the chosen file does not depend on the arbitrary os.listdir order
meta_files = sorted(name for name in os.listdir(data_folder)
                    if fnmatch.fnmatch(name, meta_pattern))
if not meta_files:
    # without this check `meta` would be undefined (or silently stale from a previous run)
    raise Exception('Could not find a metadata file matching %s in %s' % (meta_pattern, data_folder))
# use a context manager so the file handle is closed
with open(os.path.join(data_folder, meta_files[0])) as f:
    meta = json.load(f)
trial_index = np.array(meta['trial_index'])

### Display average signal intensity
This step is a useful sanity check of what the imported data looks like.

In [None]:
# select group (0, 1, ...)
group_ix = 0

# customize plot
width = 1000
height = 400

color_map = palettes.d3['Category10'][10] # colors for different planes (10 available)

# prepare data structure: one x value per frame plus the trial each frame belongs to.
# Trial name = first 8 characters of the stacked file's path relative to data_folder.
trial_names = [x.replace(data_folder + os.path.sep, '')[:8] for x in stacked_files_by_group[group_ix]]
trial_names_frames = [trial_names[x] for x in trial_indices_list[group_ix]]
data = {'x': np.arange(total_frames_list[group_ix]),
        'trial_idx': trial_indices_list[group_ix],
        'trial_name': trial_names_frames,
       }

# add the average intensity trace of each plane as column 'y<plane>'
for i_plane in range(n_planes):
    tiff_file = os.path.join(data_folder, joined_tif_list[group_ix] + '_P%d.tif' % (i_plane))
    mov = cm.load(tiff_file, outtype=np.int16)

    # average over axes 1 and 2 -> one value per frame (must match the length of 'x')
    frame_avg = np.mean(np.mean(mov, axis=1), axis=1)
    data['y%s' % (i_plane)] = frame_avg

data_source = ColumnDataSource(data)

# create figure and plot one line per plane
p = Figure(plot_width=width, plot_height=height, title=('Frame average - Group %d' % (group_ix)))
p.add_tools(CrosshairTool(), utils.getHover())
for i_plane in range(n_planes):
    p.line('x', 'y%s' % (i_plane), source=data_source, line_width=2,
           color=color_map[i_plane], legend='Plane %s' % (i_plane))

show(p)

### Load metrics data
This may take a while as the metrics file is quite large. The metrics are stored in a single list variable:

`[metrics, crispness, norms, corr_mean, corr_min]`

In [None]:
# Load the motion-correction metrics of every group. Each metrics file holds one list
# [metrics, crispness, norms, corr_mean, corr_min]; its entries are split into per-group lists.
corr_mean = []
corr_min = []
crispness = []
norms = []
mtrs = []
for i_group in range(n_groups):
    # The saved list contains Python objects (e.g. per-plane dicts of correlations), which
    # np.load only reads with allow_pickle=True (the default became False in numpy 1.16.3).
    # The file is produced by step 1 of this pipeline, so unpickling it is acceptable.
    mc_metrics = np.load(os.path.join(data_folder, metrics_files[i_group]), allow_pickle=True)
    mtrs.append(mc_metrics[0])
    crispness.append(mc_metrics[1])
    norms.append(mc_metrics[2])
    corr_mean.append(mc_metrics[3])
    corr_min.append(mc_metrics[4])

### Metrics and summary plots
Print different metrics for raw movie and rigid / pw-rigid corrected movies.

In [None]:
# Print the summary metrics of every (group, plane) combination
for i_group in range(n_groups):
    for i_plane in range(n_planes):
        metric_args = (corr_mean[i_group][i_plane],
                       corr_min[i_group][i_plane],
                       crispness[i_group][i_plane],
                       norms[i_group][i_plane])
        print('MC evaluation - Group %d - Plane %d:' % (i_group, i_plane))
        mc_utils.printMetrics(*metric_args)
        print('\n')

Plot correlations of each frame with the template image (binned median) for original, rigid correction and pw-rigid correction.

In [None]:
# select group (0, 1, ...); all planes of this group are plotted, one row per plane
group_ix = 0

metrics = mtrs[group_ix]
pw_rigid = mc_log['config']['mc']['pw_rigid']
frames = np.arange(total_frames_list[group_ix])


def _corr_scatter(x, y, x_label, y_label):
    """Scatter one correlation trace against another, with a dashed identity line."""
    fig = Figure(plot_width=250, plot_height=250)
    fig.circle(x, y, size=5)
    fig.line([0, 1], [0, 1], line_width=1, color='black', line_dash='dashed')
    fig.xaxis.axis_label = x_label
    fig.yaxis.axis_label = y_label
    return fig


gridplot_array = []
for plane_ix in range(n_planes):
    corr_orig = np.array(metrics[plane_ix]['corr_orig'])
    corr_rig = np.array(metrics[plane_ix]['corr_rig'])

    # correlation of each frame with the template, over time
    p_corr = Figure(plot_width=900, plot_height=300,
                    title=('Correlation with template - Group %d - Plane %d' % (group_ix, plane_ix)))
    p_corr.line(frames, corr_orig, line_width=2, legend='Original', color='blue')
    p_corr.line(frames, corr_rig, line_width=2, legend='Rigid', color='orange')
    row = [p_corr, _corr_scatter(corr_orig, corr_rig, 'Original', 'Rigid')]

    if pw_rigid:
        corr_els = np.array(metrics[plane_ix]['corr_els'])
        p_corr.line(frames, corr_els, line_width=2, legend='PW-Rigid', color='green')
        row.append(_corr_scatter(corr_rig, corr_els, 'Rigid', 'PW-Rigid'))

    gridplot_array.append(row)

grid = gridplot(gridplot_array, sizing_mode='fixed', toolbar_location='left')

show(grid)

### Detect frames with bad motion
Identify frames with significant residual motion (low correlation with template). Write a JSON file with criterion and indices of frames matching the criterion. This file can be used in further analysis to exclude the frames corrupted by motion.

In [None]:
# Frames whose correlation with the template is below the threshold are flagged as bad.
# One list per group, containing one threshold per plane of that group.
thresh = [
    [0.07, 0.11, 0.15, 0.15]
]

pw_rigid = mc_log['config']['mc']['pw_rigid']

# fail early (before any JSON file is written) if a group or plane has no threshold
if len(thresh) < n_groups or any(len(group_thresh) < n_planes for group_thresh in thresh[:n_groups]):
    raise ValueError('thresh needs one list per group (%d groups) with one value per plane (%d planes)'
                     % (n_groups, n_planes))

for i_group in range(n_groups):
    metrics = mtrs[i_group]
    for i_plane in range(n_planes):
        plane_thresh = thresh[i_group][i_plane]
        print('Group %d - Plane %d' % (i_group, i_plane))
        if pw_rigid:
            # pw-rigid registration
            criterion = 'corr_els'
            bad_frames = [ix for ix, corr in enumerate(metrics[i_plane][criterion])
                          if corr < plane_thresh]
            print('%1.0f frames matching criterion after pw-rigid registration.' % (len(bad_frames)))
            # was `mc_list[...]`, an undefined name; the pw-rigid files come from the log
            mc_utils.writeJsonBadFrames(criterion, plane_thresh,
                                        bad_frames, mmap_files_els[i_group][i_plane], 'els', data_folder)
        # rigid registration
        criterion = 'corr_rig'
        bad_frames = [ix for ix, corr in enumerate(metrics[i_plane][criterion])
                      if corr < plane_thresh]
        print('\n%1.0f frames matching criterion after rigid registration.' % (len(bad_frames)))
        mc_utils.writeJsonBadFrames(criterion, plane_thresh,
                                    bad_frames, mmap_files_rig[i_group][i_plane], 'rig', data_folder)
        print('\n')

### Remove bad frames

In [None]:
# select group (0, 1, ...)
group_ix = 0

# Start from the full per-frame trial index stored in the metadata, not the global
# `trial_index`: this cell overwrites `trial_index` with the filtered version at the end,
# so re-running it would otherwise remove bad frames again from an already shortened index.
trial_index_all = np.array(meta['trial_index'])

mmap_paths = [os.path.join(data_folder, fname) for fname in mmap_files_rig[group_ix]]
if not mmap_paths:
    raise Exception('No rigid motion-corrected files listed for group %d' % group_ix)

# first, create list of bad frame indices (union over all planes of this group),
# so the same frames are removed from every plane
bad_frames = np.array([], dtype='int64')
for mmap_path in mmap_paths:
    bad_frames = np.concatenate((bad_frames, cm_utils.getBadFrames(mmap_path)))
bad_frames = np.unique(bad_frames)

# then remove these frames from every plane of the group
fname_list = []
images_list = []
for mmap_path in mmap_paths:
    Yr, dims = cm_utils.loadData(mmap_path)
    images, Y, fname_rem, bad_frames_by_trial, trial_idx = cm_utils.removeBadFrames(
        mmap_path, trial_index_all, Yr, dims, bad_frames, data_folder)
    fname_list.append(fname_rem)
    images_list.append(images)
# trial index of the kept frames, taken from the last plane (presumably identical for all
# planes, since the same bad_frames are passed for every plane)
trial_index = trial_idx

### Export data for manual source extraction
The following are exported to the folder where the original data is stored:
- 1 TIFF file per plane of motion corrected images with bad frames removed
- 1 MAT file per plane that contains:
    - motion corrected images with bad frames removed (images)
    - trial index for each frame (trial_index)
    - list of trial names (trial_names)
    - number of frames per trial (trial_frames)
    - frame indices of bad frames (bad_frames)

In [None]:
# Prefix every trial key with 'trial_' so it can be used as a MATLAB struct field name
# (presumably the raw keys are trial numbers, which are not valid field names)
bad_frames_by_trial_copy = {'trial_%s' % (key): value
                            for key, value in bad_frames_by_trial.items()}

for ix_plane, images in enumerate(images_list):
    base_name = fname_list[ix_plane]

    # export to TIFF (BigTIFF, since the movies can be large)
    tiff_name = base_name.replace('.mmap', '.tif')
    imsave(tiff_name, images, bigtiff=True)
    print('\nExported TIFF file for plane %d\n%s' % (ix_plane, tiff_name))

    # export to Matlab: each dictionary key becomes a variable in the .mat file
    mdict = dict(
        images=images,
        trial_index=trial_index,
        trial_names=meta['source_file'],
        trial_frames=meta['source_frames'],
        bad_frames=bad_frames,
        bad_frames_by_trial=bad_frames_by_trial_copy,
    )
    matfile_name = base_name.replace('.mmap', '.mat')
    savemat(matfile_name, mdict=mdict, long_field_names=True)
    print('\nExported MAT file for plane %d\n%s' % (ix_plane, matfile_name))

print('Done!')

### Remove temporary files
Delete files that were created during processing and will not be required for downstream analysis. 

**Warning: there is no `undo` for this. Once the files have been deleted, one needs to re-run the entire motion correction pipeline in order to re-export data! Please make sure this is what you want!**

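The following files in the data folder will be deleted:
- all `.mmap` files whose names end in `_.mmap`
- all `.tif` files whose names do **not** contain `remFrames`, i.e. the concatenated `.tif` files **without** bad frames removed (make sure no other TIFF files you still need are stored directly in this folder)
- file with motion correction metrics results (`*_MC_metrics.npy`)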

In [None]:
def get_files_to_delete(data_folder):
    """
    Collect the temporary files in `data_folder` that may be deleted.

    Only regular files directly inside `data_folder` are considered (no recursion).
    A file is selected when its name ends in '_.mmap' or '_MC_metrics.npy', or when
    it ends in '.tif' and does not contain 'remFrames'.

    Returns a list of full paths, in os.listdir order.
    """
    def _is_temporary(name):
        if name.endswith(('_.mmap', '_MC_metrics.npy')):
            return True
        return name.endswith('.tif') and 'remFrames' not in name

    selected = []
    for name in os.listdir(data_folder):
        full_path = os.path.join(data_folder, name)
        if os.path.isfile(full_path) and _is_temporary(name):
            selected.append(full_path)
    return selected

In [None]:
def get_total_size(file_list):
    """
    Return the combined size of all files in `file_list` in GB (1 GB = 10**9 bytes).
    """
    total_bytes = sum(os.path.getsize(path) for path in file_list)
    return total_bytes / 1000000000  # in GB

In [None]:
# clear any text in the output
clear_output()
# get the list of files to delete
files_to_delete = get_files_to_delete(data_folder)

if not files_to_delete:
    print('No temporary files found in %s - nothing to delete.' % data_folder)
else:
    # show files and ask to confirm
    print('The following files will be deleted and can NOT be restored:')
    pprint(files_to_delete)
    # pass the prompt positionally: the built-in input() does not accept keyword arguments
    a = input("Type 'yes' if you are really sure: ")

    if a.strip() == 'yes':
        print('\nDeleting temporary files')
        # how many GB do we save? (measure before the files are gone)
        saved_disk_space = get_total_size(files_to_delete)
        # delete
        for f in files_to_delete:
            os.remove(f)
        print('\nDone. You saved %1.1f GB disk space!' % saved_disk_space)
    else:
        print('\nSkipping deletion of temporary files')