In [None]:
# Test setup. Ignore warnings during production runs.
# Executes ./setup_tests.py through IPython's `%run` magic; its contents are not
# visible from this notebook.

%run ./setup_tests.py

# Specify input data

* `data_dir` (`str`): Directory containing the input data. Leave it empty to use the current directory (the usual case); change it only if the data is stored elsewhere.
* `data` (`str`): HDF5 file to use as input data.
* `dataset` (`str`): HDF5 dataset to use as input data.

<br>
* `num_workers` (`int`): Number of workers for the IPython cluster. (Default: all cores except one, which is left for the client.)

In [None]:
data_dir = ""
data = "reg.h5"
dataset = "images"

num_workers = None


import os

# Resolve to an absolute path ("" means the current directory).
data_dir = os.path.abspath(data_dir)


def _with_suffix(filename, suffix):
    """Return ``filename`` with ``suffix`` inserted just before its extension."""
    stem, ext = os.path.splitext(filename)
    return stem + suffix + ext


# Every intermediate result lives next to the input, tagged by a stage suffix.
data_sub = _with_suffix(data, "_sub")
data_f_f0 = _with_suffix(data, "_f_f0")
data_wt = _with_suffix(data, "_wt")
data_norm = _with_suffix(data, "_norm")
data_dict = _with_suffix(data, "_dict")
data_post = _with_suffix(data, "_post")
data_rois = _with_suffix(data, "_rois")
data_traces = _with_suffix(data, "_traces")

data_proj = _with_suffix(data, "_proj")
# Same stem as the projection file, but with an .html extension.
data_proj_html = "{0}{1}html".format(os.path.splitext(data_proj)[0], os.path.extsep)

# Configure and startup Cluster

In [None]:
from nanshe_workflow.par import cleanup_cluster_files, get_client, set_num_workers

# Resolve the requested worker count (None selects the helper's default).
num_workers = set_num_workers(num_workers)

# Presumably removes stale files left by a previous "sge" cluster run -- confirm in nanshe_workflow.par.
cleanup_cluster_files("sge")

# Start a daemonized ipyparallel cluster with the "sge" profile, using the same
# Python interpreter as this kernel.
from sys import executable as PYTHON
!$PYTHON -m ipyparallel.apps.ipclusterapp start --daemon --profile=sge
del PYTHON

# Connect a client to the "sge" cluster just started.
client = get_client("sge")

In [None]:
# Put this kernel and every engine in the data directory so relative file names
# resolve identically everywhere (blocking call; output suppressed by `;`).
os.chdir(data_dir)
client[:].apply_sync(os.chdir, os.getcwd());

# Define functions for computation.

In [None]:
# Interactive (notebook) matplotlib backend so the viewer figures can be navigated in place.
%matplotlib notebook

import matplotlib
import matplotlib.pyplot
import matplotlib.pyplot as plt

# Figure subclass passed as `FigureClass` below to display image stacks.
from mplview.core import MatplotlibViewer as MPLViewer

In [None]:
# Serialize with cloudpickle so locally defined functions/closures can be sent to engines.
client[:].use_cloudpickle().get()

# Every import inside this block runs both in this kernel and on every engine.
with client[:].sync_imports():
    import collections
    import contextlib
    import copy
    import functools
    import gc
    import inspect
    import itertools
    import logging
    import math
    import numbers
    import os
    import sys

    from contextlib import contextmanager

    from builtins import range

    import numpy
    import scipy
    import scipy.ndimage
    import h5py

    import numpy as np
    import scipy as sp
    import scipy.ndimage as spim
    import h5py as hp

    from toolz import sliding_window

    import imgroi
    import imgroi.core
    from imgroi.core import label_mask_stack

    import nanshe
    from nanshe.imp.segment import extract_f0, normalize_data, generate_dictionary
    from nanshe.imp.filters.wavelet import transform as wavelet_transform

    import nanshe_workflow
    from nanshe_workflow.data import DataBlocks, LazyDataset

# Runs outside the sync block, so this only configures logging in this kernel.
logging.getLogger("nanshe").setLevel(logging.INFO)

In [None]:
from nanshe_workflow.par import (
    halo_block_generate_dictionary_parallel,
    halo_block_parallel,
)
from nanshe_workflow.imp import (
    block_postprocess_data_parallel,
    extract_f0_halo,
    normalize_data_halo,
    wavelet_transform_halo,
)

# Wrap each nanshe step so it runs block-wise on the cluster; the *_halo
# arguments presumably describe how much overlap each step needs per block.
par_extract_f0 = halo_block_parallel(client, extract_f0_halo)(extract_f0)
par_wavelet_transform = halo_block_parallel(client, wavelet_transform_halo)(wavelet_transform)
par_normalize_data = halo_block_parallel(client, normalize_data_halo)(normalize_data)

# Dictionary learning gets no halo (None).
par_generate_dictionary = halo_block_generate_dictionary_parallel(client, None)(generate_dictionary)

par_postprocess_data = block_postprocess_data_parallel(client)

In [None]:
from nanshe_workflow.par import (
    frame_stack_calculate_parallel,
    stack_compute_subtract_parallel,
)
from nanshe_workflow.proj import (
    stack_compute_adj_harmonic_mean_projection_parallel,
    stack_compute_harmonic_mean_projection_parallel,
    stack_compute_max_projection_parallel,
    stack_compute_min_projection_parallel,
    stack_compute_moment_projections_parallel,
    stack_compute_quantile_projection_parallel,
    stack_compute_traces_parallel,
    stack_norm_layer_parallel,
)


# Bind every frame-stack routine to the cluster client. Each par_* callable is
# used below as par_*(num_frames=...)(stack, ...).

# Stack arithmetic.
par_compute_subtract = frame_stack_calculate_parallel(client, stack_compute_subtract_parallel)

# Projections.
par_compute_harmonic_mean_projection = frame_stack_calculate_parallel(client, stack_compute_harmonic_mean_projection_parallel)
par_compute_adj_harmonic_mean_projection = frame_stack_calculate_parallel(client, stack_compute_adj_harmonic_mean_projection_parallel)
par_compute_quantile_projection = frame_stack_calculate_parallel(client, stack_compute_quantile_projection_parallel)
par_compute_min_projection = frame_stack_calculate_parallel(client, stack_compute_min_projection_parallel)
par_compute_max_projection = frame_stack_calculate_parallel(client, stack_compute_max_projection_parallel)
par_compute_moment_projections = frame_stack_calculate_parallel(client, stack_compute_moment_projections_parallel)

# ROI traces.
par_compute_traces = frame_stack_calculate_parallel(client, stack_compute_traces_parallel)

# uint-normalized preview layers.
par_norm_layer = frame_stack_calculate_parallel(client, stack_norm_layer_parallel)

# Begin workflow. Set parameters and run each cell.

### View Input Data

* `norm_frames` (`int`): number of frames for use during normalization of each full frame block (run in parallel).

In [None]:
norm_frames = 100

if __IPYTHON__:
    # Show the raw input stack with the colormap spanning its global range.
    result_image_stack = LazyDataset(data, dataset)

    mplsv = plt.figure(FigureClass=MPLViewer)
    stack_min = par_compute_min_projection(num_frames=norm_frames)(result_image_stack).min()
    stack_max = par_compute_max_projection(num_frames=norm_frames)(result_image_stack).max()
    mplsv.set_images(result_image_stack, vmin=stack_min, vmax=stack_max)

### Projections

* `block_frames` (`int`): number of frames to work with in each block (run in parallel).

In [None]:
%%time


block_frames = 100


# Somehow we can't overwrite the file in the container so this is needed.
if os.path.exists(data_proj):
    os.remove(data_proj)

images = LazyDataset(data, dataset)
with images.astype(numpy.float32) as images:
    with h5py.File(data_proj, "w") as f:
        imgproj_hmean = par_compute_adj_harmonic_mean_projection(num_frames=block_frames)(images)

        imgproj_max = par_compute_max_projection(num_frames=block_frames)(images)

        imgproj_mean, imgproj_std = par_compute_moment_projections(num_frames=block_frames)(images, 3)[1:]
        imgproj_std -= imgproj_mean**2
        numpy.sqrt(imgproj_std, out=imgproj_std)

        f["hmean"] = imgproj_hmean
        f["mean"] = imgproj_mean
        f["max"] = imgproj_max
        f["std"] = imgproj_std

        imgproj_hmean = None
        imgproj_mean = None
        imgproj_max = None
        imgproj_std = None

        del imgproj_hmean
        del imgproj_mean
        del imgproj_max
        del imgproj_std

### Subtract Projection

* `block_frames` (`int`): number of frames to work with in each block (run in parallel).
* `norm_frames` (`int`): number of frames for use during normalization of each full frame block (run in parallel).

In [None]:
%%time


block_frames = 100
norm_frames = 100


# Somehow we can't overwrite the file in the container so this is needed.
if os.path.exists(data_sub):
    os.remove(data_sub)

images = LazyDataset(data, dataset)
image_back = LazyDataset(data_proj, "hmean")
with images.astype(numpy.float32) as images:
    with h5py.File(data_sub, "w") as f2:
        result = f2.create_dataset("images", shape=images.shape, dtype=images.dtype, chunks=True)
        par_compute_subtract(num_frames=block_frames)(images, image_back, out=result)

        result_j = f2.create_dataset("images_j", shape=images.shape, dtype=numpy.uint16, chunks=True)
        par_norm_layer(num_frames=norm_frames)(result, out=result_j)


if __IPYTHON__:
    result_image_stack = LazyDataset(data_sub, "images")

    mplsv = plt.figure(FigureClass=MPLViewer)
    mplsv.set_images(
        result_image_stack,
        vmin=par_compute_min_projection(num_frames=block_frames)(result_image_stack).min(),
        vmax=par_compute_max_projection(num_frames=block_frames)(result_image_stack).max()
    )

### Background Subtraction

* `half_window_size` (`int`): the rank filter window size is `2*half_window_size+1`.
* `which_quantile` (`float`): which quantile to return from the rank filter.
* `temporal_smoothing_gaussian_filter_stdev` (`float`): stdev for gaussian filter to convolve over time.
* `temporal_smoothing_gaussian_filter_window_size` (`float`): window for gaussian filter to convolve over time. (Measured in standard deviations)
* `spatial_smoothing_gaussian_filter_stdev` (`float`): stdev for gaussian filter to convolve over space.
* `spatial_smoothing_gaussian_filter_window_size` (`float`): window for gaussian filter to convolve over space. (Measured in standard deviations)

<br>
* `block_frames` (`int`): number of frames to work with in each block (run in parallel).
* `block_space` (`int`): extent of each spatial dimension for each block (run in parallel).
* `norm_frames` (`int`): number of frames for use during normalization of each full frame block (run in parallel).

In [None]:
%%time


half_window_size = 100
which_quantile = 0.5
temporal_smoothing_gaussian_filter_stdev = 0.0
temporal_smoothing_gaussian_filter_window_size = 0
spatial_smoothing_gaussian_filter_stdev = 0.0
spatial_smoothing_gaussian_filter_window_size = 0

block_frames = 1000
block_space = 100
norm_frames = 100


# Somehow we can't overwrite the file in the container so this is needed.
if os.path.exists(data_f_f0):
    os.remove(data_f_f0)

image_stack = LazyDataset(data_sub, "images")
block_shape = (block_frames,) + (block_space,) * (len(image_stack.shape) - 1)

bias = 1 - par_compute_min_projection(num_frames=norm_frames)(image_stack).min()

with h5py.File(data_f_f0, "w") as f2:
    result = f2.create_dataset("images", shape=image_stack.shape, dtype=image_stack.dtype, chunks=True)

    par_extract_f0(block_shape)(
        image_stack,
        half_window_size=half_window_size,
        which_quantile=which_quantile,
        temporal_smoothing_gaussian_filter_stdev=temporal_smoothing_gaussian_filter_stdev,
        temporal_smoothing_gaussian_filter_window_size=temporal_smoothing_gaussian_filter_window_size,
        spatial_smoothing_gaussian_filter_stdev=spatial_smoothing_gaussian_filter_stdev,
        spatial_smoothing_gaussian_filter_window_size=spatial_smoothing_gaussian_filter_window_size,
        bias=bias,
        out=result
    )

    result_j = f2.create_dataset("images_j", shape=image_stack.shape, dtype=numpy.uint16, chunks=True)
    par_norm_layer(num_frames=norm_frames)(result, out=result_j)


if __IPYTHON__:
    result_image_stack = LazyDataset(data_f_f0, "images")

    mplsv = plt.figure(FigureClass=MPLViewer)
    mplsv.set_images(
        result_image_stack,
        vmin=par_compute_min_projection(num_frames=norm_frames)(result_image_stack).min(),
        vmax=par_compute_max_projection(num_frames=norm_frames)(result_image_stack).max()
    )

### Wavelet Transform

* `scale` (`int`): the scale of wavelet transform to apply.

<br>
* `block_frames` (`int`): number of frames to work with in each block (run in parallel).
* `block_space` (`int`): extent of each spatial dimension for each block (run in parallel).
* `norm_frames` (`int`): number of frames for use during normalization of each full frame block (run in parallel).

In [None]:
%%time


scale = 3

block_frames = 250
block_space = 250
norm_frames = 100


# Somehow we can't overwrite the file in the container so this is needed.
if os.path.exists(data_wt):
    os.remove(data_wt)

result = LazyDataset(data_f_f0, "images")
block_shape = (block_frames,) + (block_space,) * (len(result.shape) - 1)
with h5py.File(data_wt, "w") as f2:
    new_result = f2.create_dataset("images", shape=result.shape, dtype=result.dtype, chunks=True)

    par_wavelet_transform(block_shape)(
        result,
        scale=scale,
        out=new_result
    )

    result_j = f2.create_dataset("images_j", shape=new_result.shape, dtype=numpy.uint16, chunks=True)
    par_norm_layer(num_frames=norm_frames)(result, out=result_j)


if __IPYTHON__:
    result_image_stack = LazyDataset(data_wt, "images")

    mplsv = plt.figure(FigureClass=MPLViewer)
    mplsv.set_images(
        result_image_stack,
        vmin=par_compute_min_projection(num_frames=norm_frames)(result_image_stack).min(),
        vmax=par_compute_max_projection(num_frames=norm_frames)(result_image_stack).max()
    )

### Normalize Data
* `block_frames` (`int`): number of frames to work with in each full frame block (run in parallel).
* `norm_frames` (`int`): number of frames for use during normalization of each full frame block (run in parallel).

In [None]:
%%time


block_frames = 40
norm_frames = 100


# Somehow we can't overwrite the file in the container so this is needed.
if os.path.exists(data_norm):
    os.remove(data_norm)

result = LazyDataset(data_wt, "images")
block_shape = (block_frames,) + result.shape[1:]
with h5py.File(data_norm, "w") as f2:
    new_result = f2.create_dataset("images", shape=result.shape, dtype=result.dtype, chunks=True)

    result = par_normalize_data(block_shape)(
        result,
        out=new_result
    )

    result_j = f2.create_dataset("images_j", shape=new_result.shape, dtype=numpy.uint16, chunks=True)
    par_norm_layer(num_frames=norm_frames)(result, out=result_j)


if __IPYTHON__:
    result_image_stack = LazyDataset(data_norm, "images")

    mplsv = plt.figure(FigureClass=MPLViewer)
    mplsv.set_images(
        result_image_stack,
        vmin=par_compute_min_projection(num_frames=norm_frames)(result_image_stack).min(),
        vmax=par_compute_max_projection(num_frames=norm_frames)(result_image_stack).max()
    )

### Dictionary Learning

* `n_components` (`int`): number of basis images in the dictionary.
* `batchsize` (`int`): minibatch size to use.
* `iters` (`int`): number of iterations to run before getting dictionary.
* `lambda1` (`float`): weight of the L<sup>1</sup> penalty that enforces sparsity on the sparse code.
* `lambda2` (`float`): weight of the L<sup>2</sup> penalty on the sparse code. (Currently not passed to the dictionary learner.)

<br>
* `block_frames` (`int`): number of frames to work with in each full frame block (run in parallel).
* `norm_frames` (`int`): number of frames for use during normalization of each full frame block (run in parallel).

In [None]:
%%time


n_components = 50
batchsize = 256
iters = 100
lambda1 = 0.2
lambda2 = 0.0

block_frames = 51
norm_frames = 100


# Somehow we can't overwrite the file in the container so this is needed.
if os.path.exists(data_dict):
    os.remove(data_dict)

result = LazyDataset(data_norm, "images")
block_shape = (block_frames,) + result.shape[1:]
with h5py.File(data_dict, "w") as f2:
    new_result = f2.create_dataset("images", shape=(n_components,) + result.shape[1:], dtype=result.dtype, chunks=True)

    result = par_generate_dictionary(block_shape)(
        result,
        n_components=n_components,
        out=new_result,
        **{"sklearn.decomposition.dict_learning_online" : {
                "n_jobs" : 1,
                "n_iter" : iters,
                "batch_size" : batchsize,
                "alpha" : lambda1
            }
        }
    )

    result_j = f2.create_dataset("images_j", shape=new_result.shape, dtype=numpy.uint16, chunks=True)
    par_norm_layer(num_frames=norm_frames)(result, out=result_j)


if __IPYTHON__:
    result_image_stack = LazyDataset(data_dict, "images")

    mplsv = plt.figure(FigureClass=MPLViewer)
    mplsv.set_images(
        result_image_stack,
        vmin=par_compute_min_projection(num_frames=norm_frames)(result_image_stack).min(),
        vmax=par_compute_max_projection(num_frames=norm_frames)(result_image_stack).max()
    )
    mplsv.time_nav.stime.label.set_text("Basis Image")

### Postprocessing

* `significance_threshold` (`float`): number of standard deviations below which values are included in the "noise" estimate.
* `wavelet_scale` (`int`): scale of wavelet transform to apply (should be the same as the one used above)
* `noise_threshold` (`float`): number of "noise" units a value must exceed to be considered significant.
* `accepted_region_shape_constraints` (`dict`): if ROIs don't match this, reduce the `wavelet_scale` once.
* `percentage_pixels_below_max` (`float`): upper bound on ratio of ROI pixels not at max intensity vs. all ROI pixels
* `min_local_max_distance` (`float`): minimum allowed Euclidean distance between the intensity maxima of two ROIs.
* `accepted_neuron_shape_constraints` (`dict`): shape constraints for ROI to be kept.

* `alignment_min_threshold` (`float`): similarity measure of the intensity of two ROIs images used for merging.
* `overlap_min_threshold` (`float`): similarity measure of the masks of two ROIs used for merging.

In [None]:
%%time


significance_threshold = 3.0
wavelet_scale = 3
noise_threshold = 2.0
percentage_pixels_below_max = 0.8
min_local_max_distance = 16.0

alignment_min_threshold = 0.6
overlap_min_threshold = 0.6


# Somehow we can't overwrite the file in the container so this is needed.
if os.path.exists(data_post):
    os.remove(data_post)

result = LazyDataset(data_dict, "images")
with h5py.File(data_post, "w") as f2:
    result = par_postprocess_data(result,
                                  **{
                                        "wavelet_denoising" : {
                                            "estimate_noise" : {
                                                "significance_threshold" : significance_threshold
                                            },
                                            "wavelet.transform" : {
                                                "scale" : wavelet_scale
                                            },
                                            "significant_mask" : {
                                                "noise_threshold" : noise_threshold
                                            },
                                            "accepted_region_shape_constraints" : {
                                                "major_axis_length" : {
                                                    "min" : 0.0,
                                                    "max" : 25.0
                                                }
                                            },
                                            "remove_low_intensity_local_maxima" : {
                                                "percentage_pixels_below_max" : percentage_pixels_below_max
                                            },
                                            "remove_too_close_local_maxima" : {
                                                "min_local_max_distance" : min_local_max_distance
                                            },
                                            "accepted_neuron_shape_constraints" : {
                                                "area" : {
                                                    "min" : 25,
                                                    "max" : 600
                                                },
                                                "eccentricity" : {
                                                    "min" : 0.0,
                                                    "max" : 0.9
                                                }
                                            }
                                        },
                                        "merge_neuron_sets" : {
                                            "alignment_min_threshold" : alignment_min_threshold,
                                            "overlap_min_threshold" : overlap_min_threshold,
                                            "fuse_neurons" : {
                                                "fraction_mean_neuron_max_threshold" : 0.01
                                            }
                                        }
                                  }
    )

    result = f2.create_dataset("rois", shape=result.shape, dtype=result.dtype, data=result, chunks=True)

### ROI and trace extraction

* `block_frames` (`int`): number of frames to work with in each block (run in parallel).

In [None]:
%%time


block_frames = 100


# Somehow we can't overwrite the file in the container so this is needed.
if os.path.exists(data_rois):
    os.remove(data_rois)

with h5py.File(data_rois, "w") as f2:
    with h5py.File(data_post, "r") as f1:
        f2.create_dataset(
            "masks",
            shape=f1["rois"].shape + f1["rois"].dtype["mask"].shape,
            dtype=f1["rois"].dtype["mask"].subdtype[0],
            chunks=True
        )
        for i, j in sliding_window(2, range(0, len(f1["rois"]) + block_frames, block_frames)):
            f2["masks"][i:j] = f1["rois"][i:j, "mask"]

    mskimg = f2["masks"]
    mskimg_j = f2.create_dataset("masks_j", shape=mskimg.shape, dtype=numpy.uint8, chunks=True)
    par_norm_layer(num_frames=block_frames)(mskimg, out=mskimg_j)

    lblimg = label_mask_stack(mskimg, np.uint64)
    f2["labels"] = lblimg
    f2["labels_j"] = lblimg.astype(np.uint16)
    lblimg = f2["labels"]

# Somehow we can't overwrite the file in the container so this is needed.
if os.path.exists(data_traces):
    os.remove(data_traces)

images = LazyDataset(data_f_f0, "images")
mskimg = LazyDataset(data_rois, "masks")
with h5py.File(data_traces, "w") as f2:
    traces = f2.create_dataset("traces", shape=(len(mskimg), len(images)), dtype=images.dtype, chunks=True)
    par_compute_traces(num_frames=block_frames)(images, mskimg, out=traces)
    traces_j = f2.create_dataset("traces_j", shape=traces.shape, dtype=numpy.uint16, chunks=True)
    par_norm_layer(num_frames=block_frames)(traces, out=traces_j)


if __IPYTHON__:
    result_image_stack = LazyDataset(data_f_f0, "images")
    lblimg = LazyDataset(data_rois, "labels")

    mplsv = plt.figure(FigureClass=MPLViewer)
    mplsv.set_images(
        result_image_stack,
        vmin=par_compute_min_projection(num_frames=block_frames)(result_image_stack).min(),
        vmax=par_compute_max_projection(num_frames=block_frames)(result_image_stack).max()
    )

    lblimg_msk = numpy.ma.masked_array(lblimg[...][...], mask=(lblimg==0))

    mplsv.viewer.matshow(lblimg_msk, alpha=0.3)


mskimg = None
mskimg_j = None
lblimg = None
traces = None
traces_j = None

del mskimg
del mskimg_j
del lblimg
del traces
del traces_j

# End of workflow. Shutdown cluster.

In [None]:
# Re-imported so this cell also works on its own.
from nanshe_workflow.par import cleanup_cluster_files

# Stop the "sge" ipyparallel cluster using this kernel's Python interpreter.
# NOTE(review): `client` is never closed explicitly before the cluster is stopped.
from sys import executable as PYTHON
!$PYTHON -m ipyparallel.apps.ipclusterapp stop --profile=sge
del PYTHON

# Presumably removes leftover "sge" cluster files -- confirm in nanshe_workflow.par.
cleanup_cluster_files("sge")

# Prepare interactive projection graph

In [None]:
import io
import os
import textwrap

import numpy
import numpy as np

import scipy
import scipy as sp

import scipy.ndimage
import scipy.ndimage as spim

import h5py
import h5py as hp

import bokeh.plotting
import bokeh.plotting as bp

import bokeh.io
import bokeh.io as bio

import bokeh.embed
import bokeh.embed as be

from bokeh.models.mappers import LinearColorMapper

import matplotlib
import matplotlib.cm

from matplotlib.colors import ColorConverter
from matplotlib.cm import gist_rainbow

import webcolors

from bokeh.models import CustomJS, ColumnDataSource, HoverTool
from bokeh.models.layouts import HBox

from builtins import (
    map as imap,
    range as irange
)

from past.builtins import basestring

import nanshe

import xnumpy
import xnumpy.core
from xnumpy.core import expand

import nanshe_workflow
from nanshe_workflow.vis import get_rgb_array, get_rgba_array, get_all_greys, masks_to_contours_2d

In [None]:
# Load everything the interactive plot needs into memory; each file is closed
# immediately after reading.
with h5py.File(data_rois, "r") as f:
    mskimg = f["masks"][...]

with h5py.File(data_traces, "r") as f:
    traces = f["traces"][...]

with h5py.File(data_proj, "r") as f:
    imgproj_mean, imgproj_max, imgproj_std = (
        f[proj_name][...] for proj_name in ("mean", "max", "std")
    )

### Result visualization
* `proj_img` (`str` or `list` of `str`): which projection or projections to plot (e.g. "max", "mean", "std").
* `block_size` (`int`): size of each point on any dimension in the image in terms of pixels.
* `roi_alpha` (`float`): transparency of the ROIs in a range of [0.0, 1.0].
* `roi_border_width` (`int`): width of the line border on each ROI.

<br>
* `trace_plot_width` (`int`): width of the trace plot.

In [None]:
proj_img = "std"
block_size = 1
roi_alpha = 0.3
roi_border_width = 3
trace_plot_width = 500


bio.curdoc().clear()

grey_range = get_all_greys()
grey_cm = LinearColorMapper(grey_range)

# One color per ROI, converted to hex strings for Bokeh.
colors_rgb = get_rgb_array(len(mskimg))
colors_rgb = colors_rgb.tolist()
colors_rgb = list(imap(webcolors.rgb_to_hex, colors_rgb))

mskctr_pts_y, mskctr_pts_x = masks_to_contours_2d(mskimg)

# Store contour coordinates in the smallest integer type that fits the frame,
# keeping the data embedded in the HTML page small.
mskctr_pts_dtype = np.min_scalar_type(max(mskimg.shape[1:]) - 1)
mskctr_pts_y = [np.array(_, dtype=mskctr_pts_dtype) for _ in mskctr_pts_y]
mskctr_pts_x = [np.array(_, dtype=mskctr_pts_dtype) for _ in mskctr_pts_x]

mskctr_srcs = ColumnDataSource(data=dict(x=mskctr_pts_x, y=mskctr_pts_y, color=colors_rgb))


# Accept a single projection name or any iterable of names.
if isinstance(proj_img, basestring):
    proj_img = [proj_img]
else:
    proj_img = list(proj_img)


proj_plot_width = block_size*mskimg.shape[2]
proj_plot_height = block_size*mskimg.shape[1]


def style_dark_plot(plot):
    """Draw the outline and every axis line of ``plot`` in white (for a black page)."""
    plot.outline_line_color = "white"
    for i in irange(len(plot.axis)):
        plot.axis[i].axis_line_color = "white"


def make_proj_plot(title, img):
    """Return a figure showing projection ``img`` in grey with the ROI contours overlaid.

    The ROI patches share ``mskctr_srcs``, so tapping one updates that source's
    selection, which drives the trace plot's callback.
    """
    plot = bp.Figure(plot_width=proj_plot_width, plot_height=proj_plot_height,
                     x_range=[0, mskimg.shape[2]], y_range=[mskimg.shape[1], 0],
                     tools=["tap", "pan", "box_zoom", "resize", "wheel_zoom", "save", "reset"],
                     title=title, border_fill_color="black")
    # y_range runs from the image height down to 0; the image is flipped
    # vertically, presumably to match that inverted axis.
    plot.image(image=[numpy.flipud(img)], x=[0], y=[mskimg.shape[1]],
               dw=[mskimg.shape[2]], dh=[mskimg.shape[1]], color_mapper=grey_cm)
    plot.patches('x', 'y', source=mskctr_srcs, alpha=roi_alpha, line_width=roi_border_width, color="color")
    style_dark_plot(plot)
    return plot


# Build the requested projection plots, always in max -> mean -> std order.
plot_projs = []
for proj_name, proj_title, proj_data in [("max", "Max Projection with ROIs", imgproj_max),
                                         ("mean", "Mean Projection with ROIs", imgproj_mean),
                                         ("std", "Std Dev Projection with ROIs", imgproj_std)]:
    if proj_name in proj_img:
        plot_projs.append(make_proj_plot(proj_title, proj_data))


# All traces are sent to the browser flattened (row-major) together with their
# shape, so the JS callback can slice out each selected ROI's trace.
all_tr_shape_srcs = ColumnDataSource(data=dict(traces_shape=traces.shape))
all_tr_srcs = ColumnDataSource(data=dict(traces=traces.flatten()))
tr_srcs = ColumnDataSource(data=dict(times_sel=[], traces_sel=[], colors_sel=[]))
plot_tr = bp.Figure(plot_width=trace_plot_width, plot_height=proj_plot_height,
                    x_range=(0.0, float(traces.shape[1])), y_range=(float(traces.min()), float(traces.max())),
                    tools=["pan", "box_zoom", "resize", "wheel_zoom", "save", "reset"], title="ROI traces",
                    background_fill_color="black", border_fill_color="black")
plot_tr.multi_line("times_sel", "traces_sel", source=tr_srcs, color="colors_sel")
style_dark_plot(plot_tr)

plot_projs.append(plot_tr)


# On ROI selection, copy the selected ROIs' traces and colors into `tr_srcs`.
mskctr_srcs.callback = CustomJS(
    args=dict(
        all_tr_shape_srcs=all_tr_shape_srcs,
        all_tr_srcs=all_tr_srcs,
        tr_srcs=tr_srcs
    ), code="""
    var range = function(n){ return Array.from(Array(n).keys()); };

    var inds = cb_obj.get('selected')['1d'].indices;
    var traces = all_tr_srcs.get('data')['traces'];
    var traces_shape = all_tr_shape_srcs.get('data')['traces_shape'];
    var trace_len = traces_shape[1];
    var colors = cb_obj.get('data')['color'];
    var selected = tr_srcs.get('data');

    var times = range(trace_len);

    selected['times_sel'] = [];
    selected['traces_sel'] = [];
    selected['colors_sel'] = [];

    for (var i = 0; i < inds.length; i++) {
        var inds_i = inds[i];
        var trace_i = traces.slice(trace_len*inds_i, trace_len*(inds_i+1));
        var color_i = colors[inds_i];

        selected['times_sel'].push(times);
        selected['traces_sel'].push(trace_i);
        selected['colors_sel'].push(color_i);
    }

    tr_srcs.trigger('change');
""")


plot_group = HBox(*plot_projs)


# Clear out the old HTML file before writing a new one.
if os.path.exists(data_proj_html):
    os.remove(data_proj_html)


def indent(text, spaces):
    """Return ``text`` with every line prefixed by ``spaces`` space characters."""
    spaces = " " * int(spaces)
    return "\n".join(imap(lambda l: spaces + l, text.splitlines()))

def write_html(filename, title, div, script):
    """Write a standalone, black-background HTML page embedding Bokeh components.

    ``div`` and ``script`` are the strings returned by ``bokeh.embed.components``;
    the Bokeh JS/CSS is loaded from the CDN.
    """
    # Imported explicitly rather than relying on another bokeh module to load it.
    import bokeh.resources

    html_tmplt = textwrap.dedent(u"""\
        <html lang="en">
            <head>
                <meta charset="utf-8">
                <title>{title}</title>
                {cdn}
                <style>
                  html {{
                    width: 100%;
                    height: 100%;
                  }}
                  body {{
                    width: 90%;
                    height: 100%;
                    margin: auto;
                    background-color: black;
                  }}
                </style>
            </head>
            <body>
                {div}
                {script}
            </body>
        </html>
    """)

    html_cont = html_tmplt.format(
        title=title,
        div=indent(div, 8),
        script=indent(script, 8),
        cdn=indent(bokeh.resources.CDN.render(), 8),
    )

    # The page declares charset=utf-8, so encode it that way regardless of locale.
    with io.open(filename, "w", encoding="utf-8") as fh:
        fh.write(html_cont)

script, div = be.components(plot_group)
write_html(data_proj_html, os.path.splitext(data_proj_html)[0], div, script)


if __IPYTHON__:
    from IPython.display import display, IFrame
    display(IFrame(data_proj_html, "100%", 1.05*proj_plot_height))

In [None]:
# Test teardown. Ignore warnings during production runs.
# Executes ./teardown_tests.py through IPython's `%run` magic; its contents are
# not visible from this notebook.

%run ./teardown_tests.py