Optimisations, unit testing, Travis (#60)
zsinnema authored Jun 12, 2020
1 parent f3501c8 commit 086ed66
Showing 12 changed files with 237 additions and 113 deletions.
55 changes: 55 additions & 0 deletions .travis.yml
@@ -0,0 +1,55 @@
language: bash

os:
- windows
- linux

# jobs:
# include:
# - os: linux
# python: 3.8.1

branches:
only:
- master
- dev
- "/^pull.$"
- "/^hotfix-.+$/"

before_install:
- . scripts/travis/before_install_nix_win.sh

install:
# Install miniconda
- . scripts/travis/install_nix_win.sh

- source $MINICONDA_PATH/etc/profile.d/conda.sh;
- hash -r

# Set up the conda env and install dependencies
- conda env create -q -n minian -f environment.yml
- conda activate minian
- conda list
- conda install -c conda-forge -y jupyterlab
- jupyter labextension install @pyviz/jupyterlab_pyviz
- conda env export
- conda install -y pytest-cov
- conda install -c anaconda -y black

script:
# The test/check scripts go here
- travis_fold start "Black-check code quality"
- black --check minian
- travis_fold end "Black-check code quality"

- travis_fold start "pytest"
- pytest -v --color=yes --cov=minian --pyargs minian
- travis_fold end "pytest"

# - travis_fold start "Jupyter notebook pipeline"
# - papermill pipeline.ipynb pipeline_output.ipynb
# - travis_fold end "Jupyter notebook pipeline"

after_success:
- echo conda install codecov
- echo codecov
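
For reference, the checks above can be reproduced outside of Travis. A minimal local sketch, assuming conda is installed and the commands are run from the repository root, reusing the same commands as this CI config:

```
# Recreate the CI environment and run the same checks locally
conda env create -q -n minian -f environment.yml
conda activate minian
conda install -y pytest-cov
conda install -c anaconda -y black

black --check minian                                 # code style check
pytest -v --color=yes --cov=minian --pyargs minian   # unit tests with coverage
```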
4 changes: 4 additions & 0 deletions README.md
@@ -1,3 +1,7 @@
[![Build Status](https://img.shields.io/travis/MetaCell/minian/master.svg?style=flat&label=master)](https://travis-ci.org/MetaCell/minian)
[![Build Status](https://img.shields.io/travis/MetaCell/minian/dev.svg?style=flat&label=dev)](https://travis-ci.org/MetaCell/minian)


# About MiniAn

MiniAn is an analysis pipeline and visualization tool, inspired by both the [CaImAn](https://github.com/flatironinstitute/CaImAn) and [MIN1PIPE](https://github.com/JinghaoLu/MIN1PIPE) packages, built specifically for [Miniscope](http://miniscope.org/index.php/Main_Page) data.
Expand Down
3 changes: 2 additions & 1 deletion cross-registration.ipynb
@@ -64,6 +64,7 @@
"import holoviews as hv\n",
"import pandas as pd\n",
"from holoviews.operation.datashader import datashade, regrid\n",
"from dask.diagnostics import ProgressBar\n",
"from minian.cross_registration import (calculate_centroids, calculate_centroid_distance, calculate_mapping,\n",
" group_by_session, resolve_mapping, fill_mapping)\n",
"from minian.motion_correction import estimate_shifts, apply_shifts\n",
@@ -445,7 +446,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.7"
"version": "3.8.1"
}
},
"nbformat": 4,
1 change: 1 addition & 0 deletions environment.yml
@@ -45,5 +45,6 @@ dependencies:
- pip:
- ffmpeg-python==0.2.0
- medpy==0.4.0
- pytest==5.4.3
- simpleitk==1.2.4
- sk-video==1.1.10
9 changes: 9 additions & 0 deletions minian/README.md
@@ -0,0 +1,9 @@
# Minian

## Running unit tests
Minian uses pytest for unit testing.

To run the unit tests, run the following command from the root of the Minian project:
```
pytest --pyargs minian
```
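
Pytest's standard selection flags also apply here. A hypothetical example that uses the `-k` expression filter to run only the denoise test added in this commit:

```
pytest -v --pyargs minian -k denoise
```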
50 changes: 0 additions & 50 deletions minian/cross_registration.py
@@ -10,7 +10,6 @@
from scipy.stats import pearsonr
from scipy.spatial.distance import cdist
from .preprocessing import remove_background
from .motion_correction import estimate_shift_fft, apply_shifts
from .utilities import xrconcat_recursive
from .visualization import centroid
from IPython.core.debugger import set_trace
@@ -62,55 +61,6 @@ def get_minian_list(path, pattern=r'^minian.nc$'):
mnlist += mn_paths
return mnlist

def estimate_shifts(minian_df, by='session', to='first', temp_var='org', template=None, rm_background=False):
if template is not None:
minian_df['template'] = template

def get_temp(row):
ds, temp = row['minian'], row['template']
try:
return ds.isel(frame=temp).drop('frame')
except TypeError:
func_dict = {
'mean': lambda v: v.mean('frame'),
'max': lambda v: v.max('frame')}
try:
return func_dict[temp](ds)
except KeyError:
raise NotImplementedError(
"template {} not understood".format(temp))

minian_df['template'] = minian_df.apply(get_temp, axis='columns')
grp_dims = list(minian_df.index.names)
grp_dims.remove(by)
temp_dict, shift_dict, corr_dict, tempsh_dict = [dict() for _ in range(4)]
for idxs, df in minian_df.groupby(level=grp_dims):
try:
temp_ls = [t[temp_var] for t in df['template']]
except KeyError:
raise KeyError(
"variable {} not found in dataset".format(temp_var))
temps = (xr.concat(temp_ls, dim=by).expand_dims(grp_dims)
.reset_coords(drop=True))
res = estimate_shift_fft(temps, dim=by, on=to)
shifts = res.sel(variable=['height', 'width'])
corrs = res.sel(variable='corr')
temps_sh = apply_shifts(temps, shifts)
temp_dict[idxs] = temps
shift_dict[idxs] = shifts
corr_dict[idxs] = corrs
tempsh_dict[idxs] = temps_sh
temps = xrconcat_recursive(temp_dict, grp_dims).rename('temps')
shifts = xrconcat_recursive(shift_dict, grp_dims).rename('shifts')
corrs = xrconcat_recursive(corr_dict, grp_dims).rename('corrs')
temps_sh = xrconcat_recursive(tempsh_dict, grp_dims).rename('temps_shifted')
with ProgressBar():
temps = temps.compute()
shifts = shifts.compute()
corrs = corrs.compute()
temps_sh = temps_sh.compute()
return xr.merge([temps, shifts, corrs, temps_sh])


def estimate_shifts_old(mn_list,
temp_list,
Empty file added minian/test/__init__.py
Empty file.
42 changes: 42 additions & 0 deletions minian/test/test_pre_processing.py
@@ -0,0 +1,42 @@
import pytest
import numpy as np
import holoviews as hv

from ..utilities import load_videos
from ..preprocessing import denoise

dpath = "./demo_movies"

param_load_videos = {
    'pattern': r'msCam[0-9]+\.avi$',
    'dtype': np.uint8,
    'downsample': dict(frame=2, height=1, width=1),
    'downsample_strategy': 'subset'
}

param_denoise = {
'method': 'median',
'ksize': 7
}

@pytest.fixture
def varr():
return load_videos(dpath, **param_load_videos)

def test_can_load_videos(varr):
    assert varr.shape[0] == 1000  # frames
    assert varr.shape[1] == 480   # height
    assert varr.shape[2] == 752   # width

def test_can_init_holoviews():
hv.notebook_extension('bokeh')

def test_subset_part_video(varr):
subset = dict(frame=slice(0,None))
varr_ref = varr.sel(subset)
assert varr_ref.all() == varr.all()

def test_denoise(varr):
varr_ref = denoise(varr, **param_denoise)
    assert varr_ref.all() != varr.all()  # if output equals input, denoise did nothing --> fail
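
A note on the fixture: `load_videos` reads from `./demo_movies` relative to the working directory (`dpath` above), and with the 2x frame downsample the loaded array is expected to come out at 1000 frames of 480 x 752 pixels. The empty `minian/test/__init__.py` added above makes `minian/test` an importable package, which is what lets `pytest --pyargs minian` discover this suite by package name. To run just this module during development, a sketch:

```
pytest minian/test/test_pre_processing.py -v
```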
66 changes: 34 additions & 32 deletions pipeline.ipynb
@@ -47,7 +47,7 @@
"from holoviews.operation.datashader import datashade, regrid, dynspread\n",
"from datashader.colors import Sets1to3\n",
"from dask.diagnostics import ProgressBar\n",
"from IPython.core.display import display, HTML"
"from IPython.core.display import display, HTML\n"
]
},
{
@@ -164,7 +164,7 @@
" 'max_iters': 500,\n",
" 'use_smooth': True,\n",
" 'scs_fallback': False,\n",
" 'post_scal': True}"
" 'post_scal': True}\n"
]
},
{
@@ -216,12 +216,10 @@
"outputs": [],
"source": [
"dpath = os.path.abspath(dpath)\n",
"hv.notebook_extension('bokeh')\n",
"if interactive:\n",
" hv.notebook_extension('bokeh')\n",
" pbar = ProgressBar(minimum=2)\n",
" pbar.register()\n",
"else:\n",
" hv.notebook_extension('matplotlib')"
" pbar.register()"
]
},
{
@@ -973,10 +971,10 @@
"hv.output(size=output_size)\n",
"opts_im = dict(plot=dict(height=b_init.sizes['height'], width=b_init.sizes['width'], colorbar=True), style=dict(cmap='Viridis'))\n",
"opts_cr = dict(plot=dict(height=b_init.sizes['height'], width=b_init.sizes['height'] * 2))\n",
"(regrid(hv.Image(b_init, kdims=['width', 'height'])).opts(**opts_im).relabel('Background Spatial Initial')\n",
" + datashade(hv.Curve(f_init, kdims=['frame'])).opts(**opts_cr).relabel('Background Temporal Initial')\n",
" + regrid(hv.Image(b_spatial, kdims=['width', 'height'])).opts(**opts_im).relabel('Background Spatial First Update')\n",
" + datashade(hv.Curve(f_spatial, kdims=['frame'])).opts(**opts_cr).relabel('Background Temporal First Update')\n",
"(regrid(hv.Image(b_init.compute(), kdims=['width', 'height'])).opts(**opts_im).relabel('Background Spatial Initial')\n",
" + datashade(hv.Curve(f_init.compute(), kdims=['frame'])).opts(**opts_cr).relabel('Background Temporal Initial')\n",
" + regrid(hv.Image(b_spatial.compute(), kdims=['width', 'height'])).opts(**opts_im).relabel('Background Spatial First Update')\n",
" + datashade(hv.Curve(f_spatial.compute(), kdims=['frame'])).opts(**opts_cr).relabel('Background Temporal First Update')\n",
").cols(2)"
]
},
@@ -1023,7 +1021,7 @@
" sparse_penal=cur_sprs, p=cur_p, use_spatial=False, use_smooth=True,\n",
" add_lag = cur_add, noise_freq=cur_noise)\n",
" YA_dict[ks], C_dict[ks], S_dict[ks], g_dict[ks], sig_dict[ks], A_dict[ks] = (\n",
" YrA, cur_C, cur_S, cur_g, cur_sig, A_sub)\n",
" YrA.compute(), cur_C.compute(), cur_S.compute(), cur_g.compute(), cur_sig.compute(), A_sub.compute())\n",
" hv_res = visualize_temporal_update(\n",
" YA_dict, C_dict, S_dict, g_dict, sig_dict, A_dict,\n",
" kdims=['p', 'sparse penalty', 'additional lag', 'noise frequency'])"
@@ -1067,10 +1065,10 @@
"source": [
"hv.output(size=output_size)\n",
"opts_im = dict(frame_width=500, aspect=2, colorbar=True, cmap='Viridis', logz=True)\n",
"(regrid(hv.Image(C_init.rename('ci'), kdims=['frame', 'unit_id'])).opts(**opts_im).relabel(\"Temporal Trace Initial\")\n",
"(regrid(hv.Image(C_init.compute().rename('ci'), kdims=['frame', 'unit_id'])).opts(**opts_im).relabel(\"Temporal Trace Initial\")\n",
" + hv.Div('')\n",
" + regrid(hv.Image(C_temporal.rename('c1'), kdims=['frame', 'unit_id'])).opts(**opts_im).relabel(\"Temporal Trace First Update\")\n",
" + regrid(hv.Image(S_temporal.rename('s1'), kdims=['frame', 'unit_id'])).opts(**opts_im).relabel(\"Spikes First Update\")\n",
" + regrid(hv.Image(C_temporal.compute().rename('c1'), kdims=['frame', 'unit_id'])).opts(**opts_im).relabel(\"Temporal Trace First Update\")\n",
" + regrid(hv.Image(S_temporal.compute().rename('s1'), kdims=['frame', 'unit_id'])).opts(**opts_im).relabel(\"Spikes First Update\")\n",
").cols(2)"
]
},
@@ -1089,7 +1087,7 @@
" bad_units.sort()\n",
" if len(bad_units)>0:\n",
" hv_res = (hv.NdLayout({\n",
" \"Spatial Footprin\": regrid(hv.Dataset(A_spatial.sel(unit_id=bad_units).rename('A'))\n",
" \"Spatial Footprin\": regrid(hv.Dataset(A_spatial.sel(unit_id=bad_units).compute().rename('A'))\n",
" .to(hv.Image, kdims=['width', 'height'])).opts(**im_opts),\n",
" \"Spatial Footprints of Accepted Units\": regrid(hv.Image(A_temporal.sum('unit_id').compute().rename('A'), kdims=['width', 'height'])).opts(**im_opts)\n",
" })\n",
@@ -1108,7 +1106,8 @@
"source": [
"hv.output(size=output_size)\n",
"if interactive:\n",
" display(visualize_temporal_update(YrA, C_temporal, S_temporal, g_temporal, sig_temporal, A_temporal))"
" display(visualize_temporal_update(YrA.compute(), C_temporal.compute(), S_temporal.compute(), \n",
" g_temporal.compute(), sig_temporal.compute(), A_temporal.compute()))"
]
},
{
@@ -1137,8 +1136,8 @@
"source": [
"hv.output(size=output_size)\n",
"opts_im = dict(frame_width=500, aspect=2, colorbar=True, cmap='Viridis', logz=True)\n",
"(regrid(hv.Image(sig_temporal.rename('c1'), kdims=['frame', 'unit_id'])).relabel(\"Temporal Signals Before Merge\").opts(**opts_im) +\n",
"regrid(hv.Image(sig_mrg.rename('c2'), kdims=['frame', 'unit_id'])).relabel(\"Temporal Signals After Merge\").opts(**opts_im))"
"(regrid(hv.Image(sig_temporal.compute().rename('c1'), kdims=['frame', 'unit_id'])).relabel(\"Temporal Signals Before Merge\").opts(**opts_im) +\n",
"regrid(hv.Image(sig_mrg.compute().rename('c2'), kdims=['frame', 'unit_id'])).relabel(\"Temporal Signals After Merge\").opts(**opts_im))"
]
},
{
@@ -1177,8 +1176,8 @@
" Y, A_sub, b_init, sig_sub, f_init,\n",
" sn_spatial, dl_wnd=param_second_spatial['dl_wnd'], sparse_penal=cur_sprs)\n",
" if cur_A.sizes['unit_id']:\n",
" A_dict[cur_sprs] = cur_A\n",
" C_dict[cur_sprs] = cur_C\n",
" A_dict[cur_sprs] = cur_A.compute()\n",
" C_dict[cur_sprs] = cur_C.compute()\n",
" hv_res = visualize_spatial_update(A_dict, C_dict, kdims=['sparse penalty'])"
]
},
@@ -1234,10 +1233,10 @@
"hv.output(size=output_size)\n",
"opts_im = dict(aspect=b_spatial_it2.sizes['width'] / b_spatial_it2.sizes['height'], frame_width=500, colorbar=True, cmap='Viridis')\n",
"opts_cr = dict(aspect=2, frame_height=int(500 * b_spatial_it2.sizes['height'] / b_spatial_it2.sizes['width']))\n",
"(regrid(hv.Image(b_spatial, kdims=['width', 'height'])).opts(**opts_im).relabel('Background Spatial First Update')\n",
" + datashade(hv.Curve(f_spatial, kdims=['frame'])).opts(**opts_cr).relabel('Background Temporal First Update')\n",
" + regrid(hv.Image(b_spatial_it2, kdims=['width', 'height'])).opts(**opts_im).relabel('Background Spatial Second Update')\n",
" + datashade(hv.Curve(f_spatial_it2, kdims=['frame'])).opts(**opts_cr).relabel('Background Temporal Second Update')\n",
"(regrid(hv.Image(b_spatial.compute(), kdims=['width', 'height'])).opts(**opts_im).relabel('Background Spatial First Update')\n",
" + datashade(hv.Curve(f_spatial.compute(), kdims=['frame'])).opts(**opts_cr).relabel('Background Temporal First Update')\n",
" + regrid(hv.Image(b_spatial_it2.compute(), kdims=['width', 'height'])).opts(**opts_im).relabel('Background Spatial Second Update')\n",
" + datashade(hv.Curve(f_spatial_it2.compute(), kdims=['frame'])).opts(**opts_cr).relabel('Background Temporal Second Update')\n",
").cols(2)"
]
},
@@ -1284,7 +1283,8 @@
" sparse_penal=cur_sprs, p=cur_p, use_spatial=False, use_smooth=True,\n",
" add_lag = cur_add, noise_freq=cur_noise)\n",
" YA_dict[ks], C_dict[ks], S_dict[ks], g_dict[ks], sig_dict[ks], A_dict[ks] = (\n",
" YrA, cur_C, cur_S, cur_g, cur_sig, A_sub)\n",
" YrA.compute(), cur_C.compute(), cur_S.compute(), cur_g.compute(), cur_sig.compute(), \n",
" A_sub.compute())\n",
" hv_res = visualize_temporal_update(\n",
" YA_dict, C_dict, S_dict, g_dict, sig_dict, A_dict,\n",
" kdims=['p', 'sparse penalty', 'additional lag', 'noise frequency'])"
@@ -1331,10 +1331,10 @@
"source": [
"hv.output(size=output_size)\n",
"opts_im = dict(frame_width=500, aspect=2, colorbar=True, cmap='Viridis', logz=True)\n",
"(regrid(hv.Image(C_mrg.rename('c1'), kdims=['frame', 'unit_id'])).opts(**opts_im).relabel(\"Temporal Trace First Update\")\n",
" + regrid(hv.Image(S_mrg.rename('s1'), kdims=['frame', 'unit_id'])).opts(**opts_im).relabel(\"Spikes First Update\")\n",
" + regrid(hv.Image(C_temporal_it2.rename('c2').rename(unit_id='unit_id_it2'), kdims=['frame', 'unit_id_it2'])).opts(**opts_im).relabel(\"Temporal Trace Second Update\")\n",
" + regrid(hv.Image(S_temporal_it2.rename('s2').rename(unit_id='unit_id_it2'), kdims=['frame', 'unit_id_it2'])).opts(**opts_im).relabel(\"Spikes Second Update\")).cols(2)"
"(regrid(hv.Image(C_mrg.compute().rename('c1'), kdims=['frame', 'unit_id'])).opts(**opts_im).relabel(\"Temporal Trace First Update\")\n",
" + regrid(hv.Image(S_mrg.compute().rename('s1'), kdims=['frame', 'unit_id'])).opts(**opts_im).relabel(\"Spikes First Update\")\n",
" + regrid(hv.Image(C_temporal_it2.compute().rename('c2').rename(unit_id='unit_id_it2'), kdims=['frame', 'unit_id_it2'])).opts(**opts_im).relabel(\"Temporal Trace Second Update\")\n",
" + regrid(hv.Image(S_temporal_it2.compute().rename('s2').rename(unit_id='unit_id_it2'), kdims=['frame', 'unit_id_it2'])).opts(**opts_im).relabel(\"Spikes Second Update\")).cols(2)"
]
},
{
@@ -1352,11 +1352,11 @@
" bad_units.sort()\n",
" if len(bad_units)>0:\n",
" hv_res = (hv.NdLayout({\n",
" \"Spatial Footprin\": regrid(hv.Dataset(A_spatial_it2.sel(unit_id=bad_units).rename('A'))\n",
" \"Spatial Footprin\": regrid(hv.Dataset(A_spatial_it2.sel(unit_id=bad_units).compute().rename('A'))\n",
" .to(hv.Image, kdims=['width', 'height'])).opts(**im_opts),\n",
" \"Spatial Footprints of Accepted Units\": regrid(hv.Image(A_temporal_it2.sum('unit_id').compute().rename('A'), kdims=['width', 'height'])).opts(**im_opts)\n",
" })\n",
" + datashade(hv.Dataset(YrA.sel(unit_id=bad_units).rename('raw'))\n",
" + datashade(hv.Dataset(YrA.sel(unit_id=bad_units).compute().rename('raw'))\n",
" .to(hv.Curve, kdims=['frame'])).opts(**cr_opts).relabel(\"Temporal Trace\")).cols(1)\n",
" display(hv_res)\n",
" else:\n",
@@ -1371,7 +1371,9 @@
"source": [
"hv.output(size=output_size)\n",
"if interactive:\n",
" display(visualize_temporal_update(YrA, C_temporal_it2, S_temporal_it2, g_temporal_it2, sig_temporal_it2, A_temporal_it2))"
" display(visualize_temporal_update(YrA.compute(), C_temporal_it2.compute(), S_temporal_it2.compute(), \n",
" g_temporal_it2.compute(), sig_temporal_it2.compute(), \n",
" A_temporal_it2.compute()))"
]
},
{
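The recurring optimisation in the pipeline.ipynb hunks above is to call `.compute()` on dask-backed xarray objects once, before handing them to holoviews/datashader, so that plotting reads in-memory arrays rather than re-triggering lazy evaluation inside the render path. A minimal self-contained sketch of the pattern, with hypothetical variable names:

```
import dask.array as da
import xarray as xr

# Hypothetical stand-in for a lazy result such as C_temporal in the notebook.
lazy = xr.DataArray(
    da.random.random((1000, 10), chunks=(500, 10)),
    dims=['frame', 'unit_id'],
    name='c1',
)

# Compute once up front; downstream plotting (e.g. hv.Image(eager, ...))
# then works on the in-memory array instead of re-running the dask graph.
eager = lazy.compute()
```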
