In [None]:
%matplotlib widget

import pandas as pd
from pathlib import Path
import matplotlib.pyplot as plt
import ipywidgets as widgets
from functools import partial
import networkx as nx
import json
import numpy as np
import h5py
import multiprocessing
import tempfile
import sqlite3

from IPython.display import Video

from ophys_etl.types import ExtractROI
from ophys_etl.modules.segmentation.qc_utils.roi_utils import add_list_of_roi_boundaries_to_img, add_labels_to_axes, convert_roi_keys
from ophys_etl.modules.segmentation.qc_utils.graph_plotting import draw_graph_edges
from ophys_etl.modules.segmentation.qc_utils.video_generator import VideoGenerator
from ophys_etl.modules.segmentation.qc_utils.video_display_generator import VideoDisplayGenerator
from ophys_etl.modules.segmentation.processing_log import SegmentationProcessingLog

### SQLite DB interface for creating the notebook's inspection manifest

To use it, install `evaldb` with its `READ` extras:
```
pip install 'evaldb[READ] @ git+https://github.com/AllenInstitute/ophys_segmentation_eval_db'
```

In [None]:
from evaldb.reader import EvalDBReader
# NOTE(review): absolute path on a shared filesystem; this cell presumably only
# works on machines where /allen is mounted. Consider moving it to a config cell.
sqlite_path = Path("/allen/aibs/informatics/danielk/segmentation_tracking.db")
# reader used by the next cells to list metadata and build the inspection manifest
dbreader = EvalDBReader(sqlite_path)

In [None]:
# metadata held in the tracking database (presumably one entry per tracked
# experiment), shown as this cell's output
all_metadata = dbreader.get_all_metadata()
all_metadata

In [None]:
# experiment to inspect; change this id to inspect a different experiment
ophys_experiment_id = 788422825
# the manifest's "videos", "backgrounds" and "processing_logs" entries feed
# the widget cells below
inspection_manifest = dbreader.get_inspection_manifest(ophys_experiment_id)
# show the manifest's "metadata" dict as a one-row table
pd.DataFrame.from_records([inspection_manifest["metadata"]])

### Inspection manifest specification

The SQLite interface above is designed to supply this inspection notebook with an inspection manifest.

To specify the inspection data sources manually instead, use this format:

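```
inspection_manifest = {
    "metadata": dictionary (optional),
    "videos": list of pathlib.Path objects pointing to videos,
    "backgrounds": list of pathlib.Path objects pointing to png or pkl background images/graphs,
    "processing_logs": list of pathlib.Path objects pointing to hdf5 processing logs,
}
```
Entries must be `pathlib.Path` objects rather than plain strings, because the cells below read their `.name` and `.suffix` attributes.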
An empty manifest:
```
inspection_manifest = {
    "videos": [],
    "backgrounds": [],
    "processing_logs": []}
```

In [None]:
def new_background_selector(nrows, ncols, background_paths):
    """Create one background dropdown per subplot, in row-major order.

    Every dropdown starts with a blank ``(None, None)`` choice followed by a
    ``(file name, path)`` choice for each entry of ``background_paths``, and is
    labelled with its ``(row, col)`` grid position.

    Returns
    -------
    list of widgets.Dropdown
        ``nrows * ncols`` dropdowns.
    """
    selectors = []
    for row in range(nrows):
        for col in range(ncols):
            choices = [(None, None)]
            choices.extend((path.name, path) for path in background_paths)
            selectors.append(
                widgets.Dropdown(
                    options=choices,
                    description=f"({row}, {col})",
                    layout=widgets.Layout(display='flex', align_items='flex-start'),
                )
            )
    return selectors

def new_processing_log_selector(nrows, ncols, processing_logs):
    """Create the per-subplot processing-log and dataset dropdowns.

    Returns
    -------
    tuple of (list, list)
        ``nrows * ncols`` log dropdowns, each with a blank ``(None, None)``
        choice plus one ``(file name, path)`` choice per log and labelled by
        ``(row, col)``. Also the same number of initially empty, 150px-wide
        dataset dropdowns. Both lists are in row-major order.
    """
    cells = [(row, col) for row in range(nrows) for col in range(ncols)]

    log_dropdowns = []
    for row, col in cells:
        entries = [(None, None)] + [(log.name, log) for log in processing_logs]
        log_dropdowns.append(
            widgets.Dropdown(options=entries, description=f"({row}, {col})"))

    dataset_dropdowns = [
        widgets.Dropdown(options=[], layout=widgets.Layout(width='150px'))
        for _cell in cells
    ]
    return log_dropdowns, dataset_dropdowns

def new_plot_update_buttons(nrows, ncols):
    """Return one "Update" button per subplot (``nrows * ncols``, row-major)."""
    return [
        widgets.Button(description="Update")
        for _row in range(nrows)
        for _col in range(ncols)
    ]

def update_plot(widget, fig, axes, background_widget, log_widget, dataset_widget, label_widget):
    """Redraw one subplot from the current values of its selector widgets.

    Used as a ``Button.on_click`` callback, with every argument except
    ``widget`` bound through ``functools.partial``. ``widget`` is the clicked
    button and is ignored.

    Parameters
    ----------
    widget : object
        The widget that triggered the callback (unused).
    fig : matplotlib.figure.Figure
        Figure owning ``axes``; re-laid-out after drawing.
    axes : matplotlib.axes.Axes
        Subplot that is cleared and redrawn.
    background_widget : widget
        ``.value`` is None (blank 512x512 white canvas), a ``.pkl`` path
        (graph drawn with ``draw_graph_edges``), or an image path read with
        ``plt.imread``.
    log_widget : widget
        ``.value`` is None or the path of a processing log.
    dataset_widget : widget
        ``.value`` is None or the log group whose ROI boundaries are drawn.
    label_widget : widget
        When ``.value`` is truthy and ROIs were loaded, ROI labels are drawn.
    """
    background_path = background_widget.value
    if background_path is None:
        im = np.ones((512, 512, 3), dtype="uint8") * 255
    else:
        if background_path.suffix == ".pkl":
            # Graph backgrounds are drawn directly and the function returns
            # early; ROI overlays and labels are not applied to them.
            # NOTE(review): nx.read_gpickle unpickles the file (only load
            # trusted files) and was removed in networkx 3.0.
            graph = nx.read_gpickle(background_path)
            # color edges by the first attribute found on the first edge
            edge = list(graph.edges(data=True))[0]
            attribute_name = list(edge[2].keys())[0]
            axes.cla()
            draw_graph_edges(fig, axes, graph, attribute_name, colorbar=False)
            title = f"{background_path.name}"
            axes.set_title(title, fontsize=10)
            fig.tight_layout()
            return

        im = plt.imread(background_path)
        if im.ndim == 2:
            # expand a single-channel image to 3 identical channels
            im = np.dstack([im, im, im])

    log_path = log_widget.value
    dataset = dataset_widget.value
    # Bound up front so the label step below cannot hit an unbound name when
    # no log/dataset is selected (previously an UnboundLocalError).
    rois = None
    if log_path is not None and dataset is not None:
        processing_log = SegmentationProcessingLog(log_path)
        rois = processing_log.get_rois_from_group(dataset)
        im = add_list_of_roi_boundaries_to_img(im, rois)
    axes.cla()
    axes.imshow(im)
    title = ""
    if background_path is not None:
        title += f"{background_path.name}"
    if log_path is not None:
        if title != "":
            title += "\n"
        title += f"{log_path.name} - {dataset}"
    axes.set_title(title, fontsize=10)

    # labels only make sense when ROIs were actually loaded
    if label_widget.value and rois is not None:
        add_labels_to_axes(axes, rois, (255, 0, 0), fontsize=6)
    fig.tight_layout()

In [None]:
@widgets.interact(nrows=[1, 2, 3], ncols=[1, 2, 3])
def update(nrows=1, ncols=1):
    """Rebuild the nrows x ncols comparison figure and its selector widgets.

    ``widgets.interact`` calls this whenever the row/column choice changes.
    Each subplot gets the following controls:

    - a background dropdown;
    - a processing-log dropdown;
    - a dataset dropdown, filled when a log is chosen;
    - an "include labels" checkbox;
    - an "Update" button that redraws that subplot through ``update_plot``.
    """
    # erase old figure (figure number 1 is reused for every layout)
    fig = plt.figure(1)
    plt.close(fig)

    # make new figure
    fig, axes = plt.subplots(nrows, ncols, clear=True, sharex=True, sharey=True, num=1, squeeze=False)
    plt.show()
    fig.tight_layout()

    # make selectors for each axis and attach to callbacks
    backgrounds = new_background_selector(nrows, ncols, inspection_manifest["backgrounds"])
    processing_logs, datasets = new_processing_log_selector(nrows, ncols, inspection_manifest["processing_logs"])

    def on_change_logs(index):
        def on_change(change):
            """List the selected log's groups that contain ROIs in the dataset widget.

            Choosing the blank (None) entry clears the dataset choices instead
            of trying to open ``None`` as an HDF5 file.
            """
            if change['type'] == 'change' and change['name'] == 'value':
                log_path = change['new']
                options = []
                if log_path is not None:
                    with h5py.File(log_path, "r") as f:
                        for key in f.keys():
                            if isinstance(f[key], h5py.Group) and "rois" in f[key]:
                                options.append(key)
                datasets[index].options = options
        return on_change

    for i in range(len(processing_logs)):
        processing_logs[i].observe(on_change_logs(i))

    label_checks = [widgets.Checkbox(description="include labels") for i in range(nrows*ncols)]
    partials = []
    for ax, bgw, logw, dataw, lw in zip(axes.flat, backgrounds, processing_logs, datasets, label_checks):
        partials.append(partial(update_plot,
                                fig=fig,
                                axes=ax,
                                log_widget=logw,
                                dataset_widget=dataw,
                                background_widget=bgw,
                                label_widget=lw))
    update_buttons = new_plot_update_buttons(nrows, ncols)
    for partial_fun, button in zip(partials, update_buttons):
        button.on_click(partial_fun)

    # group the selectors and display
    background_box = widgets.VBox([widgets.Label("backgrounds")] + backgrounds)
    log_selection = widgets.VBox([widgets.Label("processing logs")] + processing_logs)
    dataset_selection = widgets.VBox([widgets.Label("datasets")] + datasets)
    button_box = widgets.VBox([widgets.Label("update buttons")] + update_buttons)
    label_box = widgets.VBox([widgets.Label("include labels")] + label_checks)
    selector_box = widgets.HBox([background_box, log_selection, dataset_selection, label_box, button_box])
    display(selector_box)

In [None]:
def all_roi_dicts():
    """Collect the ROIs of every ROI-bearing group in every processing log.

    Iterates ``inspection_manifest["processing_logs"]``. For each log, every
    top-level HDF5 group that has a ``"rois"`` member is read with
    ``SegmentationProcessingLog.get_rois_from_group``.

    Returns
    -------
    dict
        Maps ``"<log file name>-<group name>"`` to that group's ROIs.
    """
    results = dict()
    for log in inspection_manifest["processing_logs"]:
        # Accept plain string paths as well as pathlib.Path objects;
        # `.name` below would fail on a str.
        log = Path(log)
        with h5py.File(log, "r") as f:
            groups = [
                key for key in f.keys()
                if isinstance(f[key], h5py.Group) and "rois" in f[key]
            ]
        splog = SegmentationProcessingLog(log)
        for group in groups:
            results[f"{log.name}-{group}"] = splog.get_rois_from_group(group)
    return results
    
# One checkbox per movie in the manifest, checked by default. The full path is
# stored in the tooltip and read back by plot_callback.
movie_widget_list = [
    widgets.Checkbox(
        value=True,
        description=f.name,
        description_tooltip=str(f),
        layout={'width': 'max-content'}
    )
    for f in inspection_manifest["videos"]
]

# "<log file name>-<group>" -> ROIs, for every ROI-bearing group of every log
rois_dict = all_roi_dicts()

movie_list = widgets.VBox(movie_widget_list)
# One dropdown per ROI source: -1 (no ROI selected) plus every ROI id, sorted
roi_drops = [
    widgets.Dropdown(
        options=np.sort([-1] + [i["id"] for i in v]),
        description=k,
        layout={'width': 'max-content'},
        style={'description_width': 'initial'}
    )
    for k, v in rois_dict.items()]
roi_list = widgets.VBox(roi_drops)
# How plot_callback groups traces into subplots: 0 = one subplot per ROI,
# 1 = one subplot per movie
trace_grouping = widgets.Dropdown(
    options=[
        ("group traces by ROI", 0),
        ("group traces by movie", 1)
    ])

In [None]:
def extents_from_roi(roi):
    """Return an ROI's bounding box as ``(xmin, xmax, ymin, ymax)``.

    ``roi`` must provide ``"x"``, ``"y"``, ``"width"`` and ``"height"``.
    The max values are origin plus size (exclusive), so they can be used
    directly as slice stops.
    """
    left = roi["x"]
    top = roi["y"]
    return left, left + roi["width"], top, top + roi["height"]


def get_trace(movie_path, roi):
    """Mean trace of one ROI in one movie.

    Parameters
    ----------
    movie_path : path-like
        HDF5 file with a 3D ``"data"`` dataset, sliced here as
        ``[:, y, x]`` (first axis presumably time/frames).
    roi : dict
        Needs ``"x"``, ``"y"``, ``"width"``, ``"height"`` and a
        ``"mask_matrix"`` with ``height * width`` entries.

    Returns
    -------
    numpy.ndarray
        One value per index of the first axis: the mean over the pixels
        selected by the mask. If the mask selects no pixels, the result is
        NaN-filled.
    """
    xmin, xmax, ymin, ymax = extents_from_roi(roi)
    with h5py.File(movie_path, "r") as f:
        # read only the ROI's bounding box, not the full field of view
        data = f["data"][:, ymin: ymax, xmin: xmax]
    data = data.reshape(data.shape[0], -1)
    # Cast to bool: a 0/1 integer mask would otherwise be treated as fancy
    # indices (columns 0 and 1) instead of a boolean selection.
    mask = np.asarray(roi["mask_matrix"], dtype=bool).reshape(data.shape[1])
    npix = np.count_nonzero(mask)
    if npix == 0:
        # same NaN result the 0/0 division gave, without the RuntimeWarning
        return np.full(data.shape[0], np.nan)
    trace = data[:, mask].sum(axis=1) / npix
    return trace


def plot_callback():
    """Plot mean ROI traces for every selected (ROI, movie) combination.

    Reads the widgets built in an earlier cell:

    - ``roi_drops``: one per ROI source; ``-1`` means no ROI from that source.
    - ``movie_widget_list``: movie checkboxes; the path is in the tooltip.
    - ``trace_grouping``: 0 = one subplot per ROI (lines labelled by movie),
      1 = one subplot per movie (lines labelled by ROI).

    Traces are extracted with ``get_trace`` in a 4-process pool. If no ROI or
    no movie is selected, a message is printed and nothing is plotted.
    """
    # determine which ROIs are selected: source label -> selected ROI id
    rois_lookup = dict()
    for roi_select in roi_drops:
        if roi_select.value != -1:
            rois_lookup[roi_select.description] = int(roi_select.value)
    # replace each selected id with its ROI dict (keys converted by convert_roi_keys)
    for k, v in list(rois_lookup.items()):
        for candidate in convert_roi_keys(rois_dict[k]):
            if candidate["id"] == v:
                rois_lookup[k] = candidate

    # determine which movie paths are selected
    movie_paths = [
        Path(movie_widget.description_tooltip)
        for movie_widget in movie_widget_list
        if movie_widget.value
    ]

    # get all combinations of ROIs and movie paths
    trace_list = [
        {
            "roi_source": roi_source,
            "roi": roi,
            "roi_id": roi["id"],
            "movie_path": movie_path,
            "movie_label": movie_path.name,
            "roi_label": f"{roi_source}_{roi['id']}"
        }
        for roi_source, roi in rois_lookup.items()
        for movie_path in movie_paths
    ]
    if not trace_list:
        # Nothing to plot: grouping an empty frame raises KeyError and
        # plt.subplots(0, 1) is invalid, so stop here with a hint.
        print("Select at least one ROI and one movie to plot traces.")
        return

    # load traces in parallel
    args = [(i["movie_path"], i["roi"]) for i in trace_list]
    with multiprocessing.Pool(4) as pool:
        results = pool.starmap(get_trace, args)
    for i, result in enumerate(results):
        trace_list[i]["trace"] = result

    # group according to selected method
    df = pd.DataFrame.from_records(trace_list)
    if trace_grouping.value == 0:
        groups = df.groupby(["roi_source", "roi_id"])
        label = "movie_label"
    elif trace_grouping.value == 1:
        groups = df.groupby(["movie_label"])
        label = "roi_label"

    fig2, axes2 = plt.subplots(len(groups), 1, clear=True, sharex=True, sharey=False, squeeze=False)
    for group, ax in zip(groups, axes2.flat):
        # group keys are tuples for multi-key (and, in recent pandas, list) groupbys
        if isinstance(group[0], tuple):
            ylab = "\n".join([f"{i}" for i in group[0]])
        else:
            ylab = group[0]
        ax.set_ylabel(ylab, fontsize=6)
        for entry in group[1].iterrows():
            ax.plot(entry[1]["trace"], linewidth=0.4, label=entry[1][label])

    axes2.flat[0].legend(fontsize=6)
    fig2.tight_layout()
    plt.show()

# movie checkboxes and ROI dropdowns side by side, then the grouping choice
display(widgets.HBox(
    [widgets.VBox([widgets.HTML(value="<b>available movies</b>"),
                   movie_list]),
     widgets.VBox([widgets.HTML(value="<b>available ROIs</b>"),
                   roi_list])],
    layout={'width': 'max-content'}))
display(trace_grouping)
# interact_manual treats extra keyword arguments as widget abbreviations for
# plot_callback's parameters. plot_callback takes none, so a `description=`
# kwarg was silently ignored; the button label is set via manual_name instead.
b = widgets.interact_manual.options(manual_name="plot traces")(plot_callback)

In [None]:
# NOTE(review): this rebinds `movie_list`, which the trace-plotting cell bound to
# the VBox of movie checkboxes and passes to widgets.HBox. Re-running that cell
# after this one would hand it a plain list of paths. Consider a distinct name.
movie_list = inspection_manifest["videos"]

In [None]:
# turns generated thumbnail videos into keyword arguments for IPython's Video
display_generator = VideoDisplayGenerator()

In [None]:
%%time
# thumbnail generator for the first movie in the manifest
# (raises IndexError if the manifest lists no videos)
video_generator = VideoGenerator(movie_list[0])

Display the full field of view

In [None]:
%%time
# origin (0, 0) with frame_shape=None; per the heading above this is the full
# field of view (assumes None means "whole frame" — confirm in VideoGenerator)
full_fov = video_generator.get_thumbnail_video(origin=(0,0), frame_shape=None, quality=4)

In [None]:
%%time
# render the full-field-of-view video at 512x512
Video(**display_generator.display_video(full_fov, width=512, height=512))

Get and display a thumbnail at a hand-picked location

In [None]:
%%time
# 64x64 thumbnail at the fixed, hand-picked origin (100, 200)
by_hand_thumbnail = video_generator.get_thumbnail_video(origin=(100, 200), frame_shape=(64, 64),
                                                        quality=5)

In [None]:
%%time
# render with display_video's default size (no width/height given)
Video(**display_generator.display_video(by_hand_thumbnail))

Get and display a thumbnail containing an ROI

In [None]:
# FIXME: `list_roi_files` and `experiment_selector` are not defined in any
# visible cell of this notebook, so this cell raises NameError on a fresh
# kernel. It also rebinds `roi_list`, which was the VBox of ROI dropdowns; the
# loaded list is not used by the visible cells below, which take ROIs from
# `rois_dict`.
roi_file_list = list_roi_files(experiment_selector.value)
# the last ROI file in the list is used
roi_file_path = roi_file_list[-1]
with open(roi_file_path, 'rb') as in_file:
    roi_list = json.load(in_file)

In [None]:
# example ROI: index 6 of the first ROI source in rois_dict
# (presumably an arbitrary pick for demonstration)
keys = list(rois_dict.keys())
raw_roi = rois_dict[keys[0]][6]

In [None]:
# show the key names convert_roi_keys produces for the example ROI
convert_roi_keys([raw_roi])[0].keys()

In [None]:
# ExtractROI for the same example ROI (first source, index 6), built from the
# converted keys; used by the ROI thumbnail cells below
keys = list(rois_dict.keys())
raw_roi = convert_roi_keys([rois_dict[keys[0]][6]])[0]
roi = ExtractROI(x=raw_roi['x'], y=raw_roi['y'], width=raw_roi['width'], height=raw_roi['height'],
                 mask=raw_roi['mask_matrix'])

In [None]:
%%time
# thumbnail around the example ROI; roi_color=(255,0,0) presumably draws its
# outline in red (the "without border" cell below omits roi_color)
roi_thumbnail = video_generator.get_thumbnail_video_from_roi(roi, roi_color=(255,0,0), quality=9)

In [None]:
%%time
# render the ROI thumbnail at 512x512
Video(**display_generator.display_video(roi_thumbnail, width=512, height=512))

Use the `padding` kwarg to add extra pixels on every side of the ROI

In [None]:
%%time
# same ROI thumbnail with padding=20 (extra pixels around the ROI, per the heading above)
padded_roi_thumbnail = video_generator.get_thumbnail_video_from_roi(roi, padding=20, roi_color=(255,0,0), quality=7)

In [None]:
%%time
# render the padded ROI thumbnail at 512x512
Video(**display_generator.display_video(padded_roi_thumbnail, width=512, height=512))

Display the ROI thumbnail without the ROI's border

In [None]:
%%time
# roi_color omitted: per the heading above, the ROI border is not drawn
no_border_thumbnail = video_generator.get_thumbnail_video_from_roi(roi, quality=9)

In [None]:
%%time
# render the border-free ROI thumbnail at 512x512
Video(**display_generator.display_video(no_border_thumbnail, width=512, height=512))

Focus on timesteps where we know (from above) there is activity

In [None]:
%%time
# window start/end, presumably in seconds (4 min 26 s to 4 min 38 s)
t0 = 4*60+26
t1 = 4*60+38
# NOTE(review): 31 is presumably the movie frame rate in Hz, converting
# seconds to frame indices; confirm for this experiment instead of hard-coding
timesteps = np.arange(t0*31,t1*31)
# ROI thumbnail restricted to those timesteps
active_thumbnail = video_generator.get_thumbnail_video_from_roi(roi, roi_color=(255, 0, 0),
                                                                quality=9, timesteps=timesteps)

In [None]:
%%time
# render the time-restricted ROI thumbnail at 512x512
Video(**display_generator.display_video(active_thumbnail, width=512, height=512))