# Light Beads Microscopy Demo Pipeline 

## Overview
### Pre-Processing:
- Extract ScanImage metadata
- Correct Bi-Directional Offset for each ROI
- Calculate and correct the MROI seams (IN PROGRESS)
### Motion Correction
- Apply the nonrigid motion correction (NoRMCorre) algorithm.
- View the pre- and post-correction movies
- Use quality metrics to evaluate registration quality
### Segmentation
- Apply the constrained nonnegative matrix factorization (CNMF) source separation algorithm to extract initial estimates of neuronal spatial footprints and calcium traces.
- Apply quality control metrics to evaluate the initial estimates, and narrow down to the final set of estimates.

### Setup
- Import necessary libraries

Notable: NumPy, OpenCV (`cv2`), Zarr, Dask, Matplotlib

In [1]:
import os

# Limit BLAS thread pools. These must be set *before* NumPy (or anything that
# imports it) is loaded, otherwise they have no effect.
os.environ["MKL_NUM_THREADS"] = "1"
os.environ["OPENBLAS_NUM_THREADS"] = "1"
os.environ["VECLIB_MAXIMUM_THREADS"] = "1"

import sys
import time
from pathlib import Path

import cv2
import numpy as np
import pandas as pd
import psutil

# Give this notebook access to the root package
sys.path.append('../../')  # TODO: Take this out when we upload to pypi
print(sys.path[0])

import core.io
import core.util  # scan-phase helpers used in the phase correction section
import scanreader

import zarr
import bokeh.plotting as bpl
import holoviews as hv
import panel as pn
from IPython import get_ipython
import logging
import matplotlib.pyplot as plt

try:
    import dask.array as da
    has_dask = True
except ImportError:
    has_dask = False

try:
    cv2.setNumThreads(0)  # 0 disables OpenCV's internal threading
except Exception:
    pass

try:
    if __IPYTHON__:
        get_ipython().run_line_magic('load_ext', 'autoreload')
        get_ipython().run_line_magic('autoreload', '2')
except NameError:
    pass

bpl.output_notebook()
hv.notebook_extension('bokeh')

# logging
logging.basicConfig(format="{asctime} - {levelname} - [{filename} {funcName}() {lineno}] - pid {process} - {message}",
                    filename=None,
                    level=logging.WARNING, style="{") # shows only warnings and errors
                    # level=logging.DEBUG, style="{") # shows the detailed information developers use to track their program
                    # (be careful when playing movies, there will be a lot of debug messages)

# this session had output planes ordered differently, we need to reorder them
chan_order = np.array([ 1, 5, 6, 7, 8, 9, 2, 10, 11, 12, 13, 14, 15, 16, 17, 3, 18, 19, 20, 21, 22, 23, 4, 24, 25, 26, 27, 28, 29, 30])  # this is specific to our dataset
chan_order = [x-1 for x in chan_order]  # convert to 0-based indexing

## Extract data using scanreader, joining contiguous ROIs, and plot our mean image

Our ScanReader object contains all of the properties needed to keep track of our raw data.
- ScanImage metadata is stored alongside the header metadata; this ScanImage-specific data is what's needed to assemble frames from their constituent ROIs.
- We calculate the frame rate and the time/pixels between scans and ROIs using the following metadata:

![frame rate calculation](../../docs/img/FrameRate1eq.png)


#### Joining Contiguous ROIs

Setting `join_contiguous=True` will combine ROIs that meet the following constraints:

1) Must be the same size/shape
2) Must be located in the same scanning depth
3) Must be located in the same slice
4) Must be directly left, right, above, or below an adjacent ROI
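To make these constraints concrete, here is a minimal sketch of a pairwise contiguity check and join. The `Roi` record and its fields (`y`, `x`, `depth`, `slice_id`) are hypothetical and are not scanreader's internal representation; this only illustrates the rules above, not how `join_contiguous` is implemented.

```python
from dataclasses import dataclass

import numpy as np


@dataclass
class Roi:
    """Hypothetical ROI record (not scanreader's class)."""
    data: np.ndarray  # 2D image, shape (height, width)
    y: int            # row of the top-left corner in the full field of view
    x: int            # column of the top-left corner
    depth: float      # scanning depth
    slice_id: int     # slice index


def adjacency(a, b):
    """Return the side of `a` that `b` touches ('right', 'left', 'below', 'above'), or None."""
    # constraints 1-3: same shape, same depth, same slice
    if a.data.shape != b.data.shape or a.depth != b.depth or a.slice_id != b.slice_id:
        return None
    h, w = a.data.shape
    # constraint 4: directly adjacent with aligned edges
    if b.y == a.y and b.x == a.x + w:
        return 'right'
    if b.y == a.y and b.x + w == a.x:
        return 'left'
    if b.x == a.x and b.y == a.y + h:
        return 'below'
    if b.x == a.x and b.y + h == a.y:
        return 'above'
    return None


def join(a, b):
    """Concatenate two contiguous ROIs into one larger ROI."""
    side = adjacency(a, b)
    if side is None:
        raise ValueError("ROIs are not contiguous")
    if side in ('left', 'above'):
        a, b = b, a                                    # put the left/top ROI first
        side = 'right' if side == 'left' else 'below'
    axis = 1 if side == 'right' else 0                 # columns for left/right, rows for above/below
    return Roi(np.concatenate([a.data, b.data], axis=axis), a.y, a.x, a.depth, a.slice_id)
```

For example, two `(4, 2)` strips at `x=0` and `x=2` with the same depth and slice join into one `(4, 4)` image, while a strip at a different depth is rejected.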

In [20]:
overwrite = False                                  # flag to overwrite previously extracted data

datapath = Path('/data2/fpo/data/raw/high_res/')   # directory containing your raw data
savepath = Path('/data2/fpo/data/extraction/')     # directory where extracted data will be saved
savepath.mkdir(exist_ok=True, parents=True)

htiffs = sorted(datapath.glob('*.tif'))            # every .tif file in datapath, in a stable order
reader = scanreader.read_scan(str(htiffs[0]), join_contiguous=True, lbm=True, x_cut=(6,6))  # this should take < 2s, no actual data is being read yet

In [21]:
reader

In [22]:
data = reader[0]                                 # here, data is being read

### Data Storage: Zarr

[Zarr documentation](https://zarr.readthedocs.io/en/stable/tutorial.html)

Deciding how to save data on a host operating system is far from straightforward. Read/write performance varies widely between data saved in a **single file** structure and data split into smaller chunks, e.g. one image per file, one image per epoch, etc.

The former strategy is clean/concise and easy to handle but is *not* feasible with large (>10GB) datasets.

The latter strategy of spreading files across nested groups of directories, each with their own metadata/attributes, has been widely adopted as the more sensible approach. HDF5 has been the frontrunner in scientific data I/O, but its usage is widely inconsistent within academia.

- Zarr, similar to HDF5, is a hierarchical data storage specification (or in non-alien speak: "rules of how data is stored on disk").
- Zarr nicely hides the complexities inherent in linking filesystem hierarchy with efficient data I/O.
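For intuition about what Zarr hides: a chunked store maps every array index to the single chunk that contains it. The sketch below assumes zarr v2's default layout, where each chunk is stored under its chunk coordinates joined with `.` (e.g. `0.0.5.0`); zarr v3 encodes keys differently by default (e.g. `c/0/0/5/0`). `chunk_key` is a hypothetical helper, not part of the zarr API.

```python
def chunk_key(index, chunks, sep="."):
    """Return the key of the chunk holding element `index` (zarr v2-style naming)."""
    if len(index) != len(chunks):
        raise ValueError("index and chunks must have the same number of dimensions")
    # integer-divide each coordinate by the chunk length along that axis
    return sep.join(str(i // c) for i, c in zip(index, chunks))


# an [x, y, z, t] array chunked one plane per chunk
print(chunk_key((10, 20, 5, 400), (300, 300, 1, 1750)))  # 0.0.5.0
# the same element when chunked one frame per chunk
print(chunk_key((10, 20, 5, 400), (300, 300, 30, 1)))    # 0.0.0.400
```

Reading pixel `(10, 20)` of plane 5 at frame 400 therefore opens one file in either layout, but which file (and how much unrelated data comes with it) depends entirely on the chunk shape.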


In [None]:
data.shape

In [4]:
# Save raw data to a persistent zarr store
save_str = savepath / "extracted.zarr"

store = zarr.DirectoryStore(str(save_str))  # save data to persistent disk storage
z = zarr.zeros(data.shape, chunks=(data.shape[0], data.shape[1], 1, data.shape[3]), dtype='int16', store=store, overwrite=True)  # one chunk per plane
z[:] = data             # zarr splits the write into the chunks specified above

Our raw data can now be accessed through the filepath stored in `save_str`.

In [9]:
save_str

And its parent directory organizes the types of data we're storing.

### Leverage zarr directory storage to compare pixel trimming

- `zarr.open` is a convenience method that handles chunking and compression for persistent storage.
- We want to keep this data as `16 bit` integers because no calculations should be done yet.

In general, we want each chunk to be roughly 1 MB to optimize write speed.

```python
name = '/path/to/folder'
chunksize = [300, 300, 1, 1]   # one 300x300 int16 image per chunk: 300*300*2 bytes = 0.18 MB
z1 = zarr.open(name, mode='w', shape=data.shape, chunks=chunksize, dtype='int16')
```
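As a quick check against the ~1 MB guideline, the uncompressed size of a chunk is the product of its shape times the bytes per element. These helpers are illustrative, not zarr functions; `frames_per_chunk` is a hypothetical way to choose how many frames of one plane to group so that a chunk lands near a target size.

```python
import numpy as np


def chunk_nbytes(chunks, dtype="int16"):
    """Uncompressed size of one chunk in bytes."""
    return int(np.prod(chunks)) * np.dtype(dtype).itemsize


def frames_per_chunk(height, width, target_bytes=1_000_000, dtype="int16"):
    """How many whole frames of a single plane fit in ~target_bytes (at least 1)."""
    per_frame = chunk_nbytes((height, width, 1, 1), dtype)
    return max(1, target_bytes // per_frame)


print(chunk_nbytes((300, 300, 1, 1)))     # 180000 bytes, ~0.18 MB
print(chunk_nbytes((300, 300, 1, 1750)))  # 315000000 bytes, ~315 MB
print(frames_per_chunk(300, 300))         # 5, i.e. chunks of (300, 300, 1, 5), ~0.9 MB
```

Compression usually makes the on-disk size smaller than these numbers, so they are an upper bound per chunk.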

In [None]:
px_save_dir = savepath / "pixel_trims"
px_save_dir.mkdir(exist_ok=True, parents=True)

# create a zarr dataset for each number of cut pixels (1 through 5)
x_pixels_to_cut = range(1, 6)
for i in x_pixels_to_cut:
    # output store for this number of cut pixels
    px_save_str = px_save_dir / f'px_{i}.zarr'
    store = zarr.DirectoryStore(str(px_save_str))  # save data to persistent disk storage

    reader = scanreader.read_scan(str(htiffs[0]), join_contiguous=True, lbm=True, x_cut=(1,i))  # this should take < 2s, no actual data is being read yet
    data = reader[0]                                         # here, data is being read

    z = zarr.zeros(data.shape, chunks=(data.shape[0], data.shape[1], 1, data.shape[3]), dtype='int16', store=store, overwrite=True)
    z[:] = data             # zarr splits the write into the chunks specified above

In [None]:
px_save_dir

In [11]:
# helpful command to print the contents of the pixel-trim directory
# (IPython substitutes the Python variable inside the braces)
!tree -L 1 {px_save_dir}

## Visualize Initial Pixel Cuts

We can see that there is a `px_N.zarr` store for each number N of X pixels trimmed on each ROI.

Using [HoloViews](https://holoviews.org/getting_started/), we can create an interactive plot to compare the different numbers of pixels cut on each ROI.

In [None]:
@pn.cache
def get_plot(i):
    arr = zarr.open(str(px_save_dir / f'px_{i}.zarr'), mode='r')
    return hv.Image(arr[:,:,5,400]).opts(
                width=600,
                height=600,
                title=f"pixels_cut: {i}",
                cmap='gray', 
                tools=['wheel_zoom'],
            )

# Widgets
image_widget = pn.widgets.IntSlider(name="Number of cut pixels: ", value=1, start=1, end=5)  # matches x_pixels_to_cut = range(1, 6)
bound_plot = pn.bind(get_plot, i=image_widget) 

# Layout of widgets and plot
layout = pn.Column(
    pn.Row(image_widget, sizing_mode="fixed", width=500),
    bound_plot
)

# Display the layout
layout.servable()

## Benchmark: Chunk Sizes

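The section below demonstrates how to search for the optimal data chunking/partitioning scheme for our datasets. We benchmark three schemes on the same data:

- Chunking by **Plane** `[x, y, 1, t]`:

  One chunk per z-plane, holding every frame of that plane.

- Chunking by **Frame** `[x, y, z, 1]`:

  One chunk per timepoint, holding every plane of that volume.

- Chunking by **Image** `[x, y, 1, 1]`:

  The smallest chunks we have. Each image can be loaded in parallel, which only pays off with many cores.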
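The trade-off between these schemes shows up as the number of chunks (files, in a `DirectoryStore`) a read has to touch. A small sketch, assuming a `[300, 300, 30, 1750]` int16 array like the one benchmarked here and reads aligned to chunk boundaries; `n_chunks` is a hypothetical helper.

```python
import math


def n_chunks(shape, chunks):
    """Number of chunks needed to cover `shape` (a whole array or an aligned selection)."""
    return math.prod(math.ceil(s / c) for s, c in zip(shape, chunks))


array_shape = (300, 300, 30, 1750)
one_plane_timeseries = (300, 300, 1, 1750)   # every frame of a single plane
one_volume = (300, 300, 30, 1)               # every plane of a single frame

schemes = {
    "chunked_by_plane": (300, 300, 1, 1750),
    "chunked_by_frame": (300, 300, 30, 1),
    "chunked_by_image": (300, 300, 1, 1),
}

for name, chunks in schemes.items():
    print(f"{name}: {n_chunks(array_shape, chunks)} chunks total, "
          f"{n_chunks(one_plane_timeseries, chunks)} per plane time series, "
          f"{n_chunks(one_volume, chunks)} per volume")
```

In this notebook, motion correction works on one plane's full time series, which touches a single chunk when chunking by plane but 1750 chunks with the other two schemes.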

In [12]:
# directory to save our benchmarks
benchmarks_savedir = datapath / 'benchmarks'
benchmarks_savedir.mkdir(exist_ok=True, parents=True)

In [None]:
# Helper function to benchmark writing and reading a dataset with a given chunk shape.
# Data can be any numpy or numpy-like array (numpy, zarr, dask).

def benchmark_chunk_sizes(data, chunk_shape, savepath, overwrite=True):
    savepath = Path(savepath).with_suffix('.zarr')

    # benchmark write (includes the time to read `data` from its source)
    start = time.perf_counter()
    store = zarr.DirectoryStore(str(savepath))  # save data to persistent disk storage
    z = zarr.zeros(data.shape, chunks=chunk_shape, dtype='int16', store=store, overwrite=overwrite)

    if hasattr(data, 'compute'):
        z[:] = data.compute()   # dask: materialize the array, then write it chunk by chunk
    else:
        z[:] = data[:]          # numpy/zarr: load into memory, then write it chunk by chunk

    write = time.perf_counter() - start

    # benchmark read
    start = time.perf_counter()
    _ = z[:]
    read = time.perf_counter() - start

    chunksize_nbytes = np.prod(chunk_shape) * z.dtype.itemsize  # 2 bytes per int16
    return [
        str(data.shape),
        str(z.chunks),
        z.nbytes / 1e6,            # array size (MB)
        chunksize_nbytes / 1e6,    # chunk size (MB)
        z.dtype,
        z.order,
        f"{read:.2f}",
        f"{write:.2f}",
        z.store.path
    ]

Use our raw dataset to get the image shape, then benchmark read/write operations on the **same dataset** with different chunk sizes.

In [None]:
# we use dimensions from our initial raw data store
zinf = zarr.open(str(save_str), mode='r')  # read-only: the raw data must not be modified here
zinf.chunks

In [None]:
labels = [
    'Array  [x,y,z,t]',
    'Chunks [x,y,z,t]',
    'Array Size (MB)',
    'Chunk Size (MB)',
    'Data Type',
    'Order',
    'Read Time (s)',
    'Write Time (s)',
    'Save Path'
]

# the chunk sizes we want to benchmark
chunksizes = [
    (zinf.shape[0], zinf.shape[1], 1, zinf.shape[3]),  # [300x300x1x1750]
    (zinf.shape[0], zinf.shape[1], zinf.shape[2], 1),  # [300x300x30x1  ]
    (zinf.shape[0], zinf.shape[1], 1, 1)               # [300x300x1x1   ]
]

# give each dataset a name; this will be the column header
names = [
    'chunked_by_plane',  # [300x300x1x1750]
    'chunked_by_frame',  # [300x300x30x1  ]
    'chunked_by_image'   # [300x300x1x1   ]
]

benchmark_chunks_dir = benchmarks_savedir / "chunks"
benchmark_chunks_dir.mkdir(exist_ok=True, parents=True)

vals = {}
for chunksize, dataset in zip(chunksizes, names):
    # save the same raw data (zinf) to a new store with a different chunk shape
    chunks_save = benchmark_chunks_dir / f'{dataset}.zarr'
    vals[dataset] = benchmark_chunk_sizes(zinf, chunksize, savepath=chunks_save, overwrite=True)

df = pd.DataFrame(vals, index=labels)


In [None]:
df

### Dask
At this point, our zarr array is a "view" onto data stored on disk. We can easily convert it to other array types in the NumPy ecosystem, like Dask!

Dask is a library that lets us use NumPy-like operations on zarr arrays, with the added benefit of loading our data lazily, i.e. only when we need it for a computation.

In [None]:
chunks_save = benchmark_chunks_dir / 'chunked_by_plane.zarr'
data_da = da.from_zarr(str(chunks_save), chunks=zinf.chunks)
data_da

## Scan Phase Correction




### *Methods:*

**1) Linear interpolation**

**2) Phase cross-correlation**
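A minimal sketch of the phase cross-correlation idea, assuming rows of the image are scan lines and the offset is a whole number of pixels: compare even and odd lines in the Fourier domain, keep only the phase, and find the lag where they line up. `estimate_bidi_offset` and `correct_bidi_offset` are hypothetical helpers for illustration; they are not `core.util.return_scan_offset` / `fix_scan_phase`, whose internals may differ (e.g. in edge handling or sub-pixel offsets).

```python
import numpy as np


def estimate_bidi_offset(image, max_offset=10):
    """Integer shift to np.roll the odd scan lines (rows) by so they align with the even lines."""
    even = image[0::2].astype(np.float64)
    odd = image[1::2].astype(np.float64)
    n = min(len(even), len(odd))
    even, odd = even[:n], odd[:n]
    # cross-power spectrum along each line, normalised to keep only the phase
    cross = np.fft.fft(even, axis=1) * np.conj(np.fft.fft(odd, axis=1))
    cross /= np.abs(cross) + 1e-12
    # back to the lag domain, averaged over all line pairs
    corr = np.fft.ifft(cross, axis=1).real.mean(axis=0)
    # FFT lag order is 0, 1, ..., -2, -1; restrict the search to +/- max_offset
    lags = np.rint(np.fft.fftfreq(corr.size, d=1.0 / corr.size)).astype(int)
    allowed = np.abs(lags) <= max_offset
    return int(lags[allowed][np.argmax(corr[allowed])])


def correct_bidi_offset(image, shift):
    """Roll every odd scan line by `shift` pixels (circularly, for simplicity)."""
    out = image.copy()
    out[1::2] = np.roll(image[1::2], shift, axis=1)
    return out
```

On a real frame the lines are not exact copies of each other, but averaging the normalised correlation over all line pairs still gives a clear peak at the offset; the circular `np.roll` is a simplification that wraps pixels around the edge instead of cropping them.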


### Phase correction via Linear Phase Interpolation 

In [None]:
# the 5-pixel cut looked the best, let's get that one
dataset = px_save_dir / 'px_5.zarr'
dataset

In [None]:
za = zarr.open(str(dataset), mode='r')
za.info

In [None]:
array = da.from_array(za, chunks=za.chunks)
array

### Computing the phase offset

Until now our data has been in `int16`. Now that we are performing correlations across averaged pixels, we want a more precise datatype.
`compute_raster_phase` will load in the data and convert it to a float internally.
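A tiny illustration of why integer data is risky once we start averaging pixels, using a made-up scan line: in `int16` the sum of two bright pixels silently wraps past 32767 and halves are truncated, while converting to `float32` first keeps both. (`np.mean` itself already returns floats; the problem is element-wise arithmetic on `int16` arrays.)

```python
import numpy as np

row = np.array([30000, 30001, 2, 3], dtype=np.int16)  # toy scan line

# half-pixel linear interpolation: average each pair of neighbouring pixels
interp_int = (row[:-1] + row[1:]) // 2                       # stays int16
interp_float = (row[:-1].astype(np.float32) + row[1:]) / 2    # promoted to float32

print(interp_int)    # values -2768, 15001, 2: 30000 + 30001 overflowed, .5 dropped
print(interp_float)  # values 30000.5, 15001.5, 2.5
```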

In [None]:
# # TODO: debug why this takes 15min + (way too long)
# phase_angle = core.util.compute_raster_phase(array[:,:, 5, 400].compute(), reader.temporal_fill_fraction)
# corrected_li = core.util.correct_raster(array, phase_angle, reader.temporal_fill_fraction)
# phase_angle

In [None]:
corr = core.util.return_scan_offset(array[:,:,5,400].compute(), 1)
corrected_pc = core.util.fix_scan_phase(array, corr, 1)

## Motion Correction: CaImAn - NoRMCorre

Load pre-processed data as a CaImAn `movie`

In [None]:
import caiman as cm
from caiman.motion_correction import MotionCorrect
from caiman.source_extraction.cnmf import cnmf, params
from caiman.utils.utils import download_demo
from caiman.utils.visualization import plot_contours, nb_view_patches, nb_plot_contour
from caiman.utils.visualization import view_quilt

Before motion correction, let's see what the volumes look like interactively:

In [13]:
@pn.cache
def get_plot_plane(plane=1, frame=1, title=""):
    return hv.Image(data[:,:,plane,frame]).opts(
                width=600,
                height=600,
                title=f"{title}",
                tools=['wheel_zoom'],
                cmap='gray', 
            )

# Widgets
plane_slider = pn.widgets.IntSlider(name="plane: ", value=1, start=1, end=9)
frame_slider = pn.widgets.IntSlider(name="frame: ", value=1, start=1, end=9)

bound_plot = pn.bind(get_plot_plane, frame=frame_slider, plane=plane_slider) 

# Layout of widgets and plot
layout = pn.Column(
    pn.Row(plane_slider, width=200),
    pn.Row(frame_slider, width=200),
    bound_plot
)

# Display the layout
layout.servable()

In [None]:
help(MotionCorrect)

## Correlation metrics

Create a couple of summary images of the movie, including:
- maximum projection (the maximum value of each pixel across time)
- correlation image (how correlated each pixel is with its neighbors)

If a pixel comes from an active neural component it will tend to be highly correlated with its neighbors.
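The correlation image can be sketched in a few lines of NumPy: z-score every pixel's trace over time, then average its Pearson correlation with each of its 8 neighbours. This sketch takes a `(frames, rows, cols)` movie and is written for clarity rather than memory efficiency; it is not CaImAn's `local_correlations`, which has extra options such as `eight_neighbours` and `swap_dim`.

```python
import numpy as np


def local_correlation_image(movie):
    """Mean Pearson correlation of each pixel with its 8 neighbours; movie is (T, Y, X)."""
    m = movie.astype(np.float64)
    m = m - m.mean(axis=0)
    std = m.std(axis=0)
    std[std == 0] = np.inf       # flat pixels get zero correlation instead of NaN
    z = m / std                  # z-scored traces
    _, Y, X = z.shape
    total = np.zeros((Y, X))
    count = np.zeros((Y, X))
    for dy in (-1, 0, 1):
        for dx in (-1, 0, 1):
            if dy == 0 and dx == 0:
                continue
            # pixel (y, x) in the target region is paired with neighbour (y - dy, x - dx)
            ys = slice(max(dy, 0), Y + min(dy, 0))
            xs = slice(max(dx, 0), X + min(dx, 0))
            ys_n = slice(max(-dy, 0), Y + min(-dy, 0))
            xs_n = slice(max(-dx, 0), X + min(-dx, 0))
            # mean product of z-scores over time is the Pearson correlation
            total[ys, xs] += (z[:, ys, xs] * z[:, ys_n, xs_n]).mean(axis=0)
            count[ys, xs] += 1
    return total / count
```

Edge and corner pixels have fewer neighbours, which is why the sum is divided by a per-pixel count rather than by 8.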

In [None]:
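# move time to the first axis, (x, y, t) -> (t, x, y), as expected with swap_dim=False below
plane_arr = np.moveaxis(data[:,:,5,5:1000], -1, 0)
max_projection = np.max(plane_arr, axis=0)  # 3D -> 2D: maximum over time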

In [None]:
correlation_image = cm.local_correlations(plane_arr, swap_dim=False)
correlation_image[np.isnan(correlation_image)] = 0 

In [None]:
%matplotlib inline
f, (ax_max, ax_corr) = plt.subplots(1,2)
ax_max.imshow(max_projection.T, 
              cmap='viridis',
              vmin=np.percentile(np.ravel(max_projection),50), 
              vmax=np.percentile(np.ravel(max_projection),99.5));
ax_max.set_title("Max Projection Orig", fontsize=12)
ax_corr.imshow(correlation_image.T, 
               cmap='viridis', 
               vmin=np.percentile(np.ravel(correlation_image),50), 
               vmax=np.percentile(np.ravel(correlation_image),99.5))
ax_corr.set_title('Correlation Image Orig', fontsize=12)
plt.show()

### Parameter Selection

In [None]:
max_shifts = (6, 6)  # maximum allowed rigid shift in pixels (view the movie to get a sense of motion)
strides =  (48, 48)  # create a new patch every x pixels for pw-rigid correction
overlaps = (24, 24)  # overlap between patches (size of patch strides+overlaps)
max_deviation_rigid = 3   # maximum deviation allowed for patch with respect to rigid shifts
pw_rigid = False  # flag for performing rigid or piecewise rigid motion correction
shifts_opencv = True  # flag for correcting motion using bicubic interpolation (otherwise FFT interpolation is used)
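border_nan = 'copy'  # replicate values along the boundary (if True, fill in with NaN)
downsample_ratio = 0.2  # temporal resize factor used only for movie playback below (adjust as needed)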
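With these settings each patch is `strides + overlaps` = 72 pixels wide and a new patch starts every 48 pixels. The hypothetical helper below estimates where patches start along one image axis, adding a final patch flush with the far edge. CaImAn's internal tiling may place edge patches differently, so treat this as an approximation for choosing `strides`/`overlaps`, not a reproduction of `MotionCorrect`.

```python
def patch_starts(dim, stride, overlap):
    """Approximate start pixels of pw-rigid patches along one axis of length `dim`."""
    size = stride + overlap            # patch size, per the comment on `overlaps`
    if size >= dim:
        return [0]                     # a single patch covers the whole axis
    starts = list(range(0, dim - size + 1, stride))
    if starts[-1] + size < dim:        # cover the remaining pixels at the far edge
        starts.append(dim - size)
    return starts


starts = patch_starts(300, stride=48, overlap=24)
print(starts)                                   # [0, 48, 96, 144, 192, 228]
print(len(starts) ** 2, "patches on a 300 x 300 plane")  # 36
```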

In [None]:
file_in = [x for x in savepath.glob("*.zarr")]
file_in[0]

## Save as a tiff, for now 

TODO: why is caiman not recognizing the zarr group?

In [None]:
import tifffile
fpath = Path('/data2/fpo/data/raw/raw.tiff')
fpath.parent.mkdir(exist_ok=True, parents=True)

# CaImAn expects movies as (frames, x, y): move the time axis to the front.
# np.moveaxis reorders the axes; reshape would only reinterpret the memory layout.
data_plane = np.moveaxis(data[:,:,5,2:1002], -1, 0)
tifffile.imwrite(fpath, data_plane)

In [None]:
#%% start the cluster (if a cluster already exists, terminate it)
if 'dview' in locals():
    cm.stop_server(dview=dview)
c, dview, n_processes = cm.cluster.setup_cluster(
    backend='multiprocessing', n_processes=None, single_thread=False)

# create a motion correction object
mc = MotionCorrect(str(fpath), dview=dview, max_shifts=max_shifts,
                  strides=strides, overlaps=overlaps,
                  max_deviation_rigid=max_deviation_rigid,
                  shifts_opencv=shifts_opencv, nonneg_movie=True,
                  border_nan=border_nan, pw_rigid=pw_rigid)


In [None]:
%%capture
# perform rigid motion correction and save the result (in memory-mapped form)
mc.motion_correct(save_movie=True)

### View rigid template

In [None]:
# load motion corrected movie
m_rig = cm.load(mc.mmap_file)
bord_px_rig = np.ceil(np.max(np.abs(mc.shifts_rig))).astype(int)  # largest shift in either direction
#%% visualize templates
plt.figure(figsize = (20,10))
plt.imshow(mc.total_template_rig, cmap = 'gray');

Rigid-corrected movie

In [None]:
#%% inspect movie
m_rig.resize(1, 1, downsample_ratio).play(
    q_max=99.5, fr=30, magnification=2, bord_px = 0*bord_px_rig) # press q to exit

Rigid shifts (per frame, relative to the template)

In [None]:
#%% plot rigid shifts
plt.close()
plt.figure(figsize = (20,10))
plt.plot(mc.shifts_rig)
plt.legend(['x shifts','y shifts'])
plt.xlabel('frames')
plt.ylabel('pixels');


## Piecewise rigid registration

While rigid registration corrected for a lot of the movement, there is still non-uniform motion present in the registered file.

- To correct for that, we can use piecewise rigid registration directly on the original file by setting `mc.pw_rigid = True`.
- As before, the registered file is saved in a memory-mapped format in the location given by `mc.mmap_file`.
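Conceptually, piecewise rigid registration estimates a separate rigid shift for each overlapping patch of size `strides + overlaps` instead of one shift for the whole frame. The sketch below shows only that per-patch estimation step, using integer-pixel phase correlation on a `(rows, cols)` frame, with edge patches omitted for brevity. `rigid_shift` and `piecewise_shifts` are hypothetical helpers; NoRMCorre additionally constrains patch shifts (`max_deviation_rigid`), refines them to sub-pixel precision, and upsamples them into a smooth shift field, none of which is done here.

```python
import numpy as np


def rigid_shift(template, frame):
    """Integer (dy, dx) to np.roll `frame` by so that it lines up with `template`."""
    t = template - template.mean()
    f = frame - frame.mean()
    cross = np.fft.fft2(t) * np.conj(np.fft.fft2(f))
    cross /= np.abs(cross) + 1e-12          # phase correlation: keep only the phase
    corr = np.fft.ifft2(cross).real
    dy, dx = np.unravel_index(np.argmax(corr), corr.shape)
    h, w = corr.shape
    # peaks past the midpoint correspond to negative shifts
    if dy > h // 2:
        dy -= h
    if dx > w // 2:
        dx -= w
    return int(dy), int(dx)


def piecewise_shifts(template, frame, stride=48, overlap=24):
    """Rigid shift of each overlapping patch, as {(row0, col0): (dy, dx)}."""
    size = stride + overlap
    H, W = template.shape
    shifts = {}
    for y0 in range(0, H - size + 1, stride):
        for x0 in range(0, W - size + 1, stride):
            patch = (slice(y0, y0 + size), slice(x0, x0 + size))
            shifts[(y0, x0)] = rigid_shift(template[patch], frame[patch])
    return shifts
```

With purely rigid motion every patch reports the same shift; non-uniform motion shows up as patches disagreeing, which is what the x/y shift traces plotted below visualize.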


In [None]:
%%capture
#%% motion correct piecewise rigid
mc.pw_rigid = True  # turn the flag to True for pw-rigid motion correction

# reuse the rigid template to save computation (optional)
mc.motion_correct(save_movie=True, template=mc.total_template_rig)
m_els = cm.load(mc.fname_tot_els)
m_els.resize(1, 1, downsample_ratio).play(
    q_max=99.5, fr=30, magnification=2,bord_px = bord_px_rig)

Visualize non-rigid shifts for the entire FOV.

TODO: Interactively visualize rigid+non-rigid shifts independently

In [None]:
plt.close()
plt.figure(figsize = (20,10))
plt.subplot(2, 1, 1)
plt.plot(mc.x_shifts_els)
plt.ylabel('x shifts (pixels)')
plt.subplot(2, 1, 2)
plt.plot(mc.y_shifts_els)
plt.ylabel('y_shifts (pixels)')
plt.xlabel('frames')
#%% compute borders to exclude
bord_px_els = np.ceil(np.maximum(np.max(np.abs(mc.x_shifts_els)),
                                 np.max(np.abs(mc.y_shifts_els)))).astype(int)

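## Motion Correction: Optical Flow

The cell below reads the `*_metrics.npz` files that CaImAn's `cm.motion_correction.compute_metrics_motion_correction` saves for each movie, so run that function on the raw, rigid, and piecewise-rigid movies first.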

In [None]:
#%% plot the results of Residual Optical Flow
fls = [cm.paths.fname_derived_presuffix(mc.fname_tot_els[0], 'metrics', swapsuffix='npz'),
       cm.paths.fname_derived_presuffix(mc.fname_tot_rig[0], 'metrics', swapsuffix='npz'),
       cm.paths.fname_derived_presuffix(mc.fname[0],         'metrics', swapsuffix='npz'),
      ]

plt.figure(figsize = (20,10))
for cnt, (fl, metr) in enumerate(zip(fls, ['pw_rigid', 'rigid', 'raw'])):
    with np.load(fl) as ld:
        print(list(ld.keys()))
        print(fl)
        print(f"{np.mean(ld['norms'])} +/- {np.std(ld['norms'])} ; "
              f"{ld['smoothness']} ; {ld['smoothness_corr']}")

        plt.subplot(len(fls), 3, 1 + 3 * cnt)
        plt.ylabel(metr)
        base = fl[:-12]  # strip the '_metrics.npz' suffix
        print(f"Loading data with base {base}")
        # the movie may be stored as .mmap, .tif or .hdf5: try each in turn
        for ext in ('.mmap', '.tif', '.hdf5'):
            try:
                mean_img = np.mean(cm.load(base + ext), 0)[12:-12, 12:-12]
                break
            except Exception:
                continue
        else:
            raise FileNotFoundError(f"no movie found for base {base}")
                    
        lq, hq = np.nanpercentile(mean_img, [.5, 99.5])
        plt.imshow(mean_img, vmin=lq, vmax=hq)
        plt.title('Mean')
        plt.subplot(len(fls), 3, 3 * cnt + 2)
        plt.imshow(ld['img_corr'], vmin=0, vmax=.35)
        plt.title('Corr image')
        plt.subplot(len(fls), 3, 3 * cnt + 3)
        flows = ld['flows']
        plt.imshow(np.mean(
        np.sqrt(flows[:, :, :, 0]**2 + flows[:, :, :, 1]**2), 0), vmin=0, vmax=0.3)
        plt.colorbar()
        plt.title('Mean optical flow');  

In [None]:
## Run CNMF on patches in parallel

## Cleanup

Make sure our parallel cluster is shut down.

In [None]:
if 'dview' in locals():
    cm.stop_server(dview=dview)
elif 'cluster' in locals():
    cm.stop_server(dview=cluster)