In [None]:
%matplotlib widget

import pandas as pd
from pathlib import Path
import matplotlib.pyplot as plt
import ipywidgets as widgets
from functools import partial
import networkx as nx
import json
import numpy as np
import h5py
import multiprocessing

from ophys_etl.modules.segmentation.graph_utils import plotting as guplot
from ophys_etl.modules.segmentation.qc_utils.roi_utils import add_roi_boundaries_to_img, add_labels_to_axes, convert_keys

In [None]:
# Root directory holding one `ophys_experiment_<id>` folder per experiment.
basedir = Path(
    "/allen/programs/braintv/workgroups/nc-ophys/danielk/deepinterpolation/experiments"
)

# Experiments offered by this notebook, in the order they are loaded.
experiment_ids = [
    785569470, 1048483611, 1048483613, 1048483616,
    785569447, 788422859, 795901850, 788422825,
    795897800, 850517348, 951980473, 951980484,
    795901895, 806862946, 803965468, 806928824,
]

# One metadata.csv path per experiment, aligned with `experiment_ids`.
metadata_files = list(
    map(
        lambda eid: basedir / f"ophys_experiment_{eid}" / "metadata.csv",
        experiment_ids,
    )
)

### Available Experiments

In [None]:
# Every metadata.csv is read without a header; column 0 supplies the field
# names, so transposing yields a single row per file. The rows are stacked,
# `depth` is cast to int, and the table is ordered by code / rig / depth.
frames = (
    pd.concat(
        [
            pd.read_csv(csv_path, header=None).set_index(0).T
            for csv_path in metadata_files
        ]
    )
    .set_index("ophys_experiment_id")
    .assign(depth=lambda table: table["depth"].astype(int))
    .sort_values(by=["code", "rig", "depth"])
    .reset_index()
)
display(frames)

In [None]:
def get_max_ave_projections(result_dir: Path):
    """Load the four projection PNGs found directly inside ``result_dir``.

    Parameters
    ----------
    result_dir: Path
        Directory containing ``noised_maxp.png``, ``noised_avgp.png``,
        ``denoised_maxp.png`` and ``denoised_avgp.png``.

    Returns
    -------
    tuple
        Four arrays in the order
        ``(noised max, denoised max, noised avg, denoised avg)``.

    Notes
    -----
    NOTE(review): ``Image`` is not imported in the cells visible here --
    presumably ``PIL.Image``; confirm it is in scope before calling.
    """
    # Files are opened in this order; the return order below is different.
    filenames = (
        "noised_maxp.png",
        "noised_avgp.png",
        "denoised_maxp.png",
        "denoised_avgp.png",
    )
    loaded = {
        filename: np.array(Image.open(result_dir / filename))
        for filename in filenames
    }
    return (
        loaded["noised_maxp.png"],
        loaded["denoised_maxp.png"],
        loaded["noised_avgp.png"],
        loaded["denoised_avgp.png"],
    )

In [None]:
def new_background_selector(nrows, ncols):
    """Build one background-image dropdown per grid cell.

    The choices are the ``*.png`` files found recursively under the
    ``backgrounds`` folder of the experiment currently chosen in
    ``experiment_selector``, preceded by a ``None`` entry.

    Parameters
    ----------
    nrows, ncols: int
        Grid dimensions.

    Returns
    -------
    list
        ``nrows * ncols`` dropdowns in row-major order, each described by
        its ``(row, col)`` position.
    """
    image_paths = []
    for extension in ["png"]:
        background_dir = experiment_selector.value / "backgrounds"
        image_paths += sorted(background_dir.rglob(f"*.{extension}"))

    dropdowns = []
    for row in range(nrows):
        for col in range(ncols):
            # each dropdown gets its own freshly built options list
            choices = [(None, None)]
            choices += [(path.name, path) for path in image_paths]
            dropdowns.append(
                widgets.Dropdown(options=choices, description=f"({row}, {col})")
            )
    return dropdowns

def new_foreground_selector(nrows, ncols):
    """Build one ROI-file dropdown per grid cell.

    The choices are the ``*_rois.json`` files found recursively under the
    ``rois`` folder of the experiment currently chosen in
    ``experiment_selector``, preceded by a ``None`` entry.

    Parameters
    ----------
    nrows, ncols: int
        Grid dimensions.

    Returns
    -------
    list
        ``nrows * ncols`` dropdowns in row-major order, each described by
        its ``(row, col)`` position.
    """
    roi_dir = experiment_selector.value / "rois"
    roi_files = sorted(roi_dir.rglob("*_rois.json"))

    dropdowns = []
    for row in range(nrows):
        for col in range(ncols):
            # each dropdown gets its own freshly built options list
            choices = [(None, None)]
            choices += [(roi_file.name, roi_file) for roi_file in roi_files]
            dropdowns.append(
                widgets.Dropdown(options=choices, description=f"({row}, {col})")
            )
    return dropdowns

def new_plot_update_buttons(nrows, ncols):
    """Create one "Update" button for every cell of an nrows x ncols grid.

    Returns
    -------
    list
        ``nrows * ncols`` buttons, all with the description ``"Update"``.
    """
    buttons = []
    for _row in range(nrows):
        for _col in range(ncols):
            buttons.append(widgets.Button(description="Update"))
    return buttons

def update_plot(widget, axes, background_widget, foreground_widget):
    """Redraw one axes from the current background / ROI selections.

    Parameters
    ----------
    widget:
        Unused inside the function; presumably the clicked button handed
        over by ``Button.on_click`` -- confirm against the caller.
    axes:
        Matplotlib axes to clear and draw into.
    background_widget:
        Widget whose ``value`` is the background image path, or ``None``.
    foreground_widget:
        Widget whose ``value`` is the ROI JSON path, or ``None``.
    """
    background_path = background_widget.value
    if background_path is not None:
        im = plt.imread(background_path)
        if im.ndim == 2:
            # replicate a single-channel image into three channels
            im = np.dstack([im, im, im])
    else:
        # no background chosen: plain white 512 x 512 canvas
        im = np.full((512, 512, 3), 255, dtype="uint8")

    foreground_path = foreground_widget.value
    if foreground_path is not None:
        with open(foreground_path, "r") as roi_file:
            rois = json.load(roi_file)
        im = add_roi_boundaries_to_img(im, rois)

    axes.cla()
    axes.imshow(im)

    # title: background file name, then ROI file name on a second line
    title = f"{background_path.name}" if background_path is not None else ""
    if foreground_path is not None:
        separator = "\n" if title != "" else ""
        title = title + separator + f"{foreground_path.name}"
    axes.set_title(title, fontsize=10)

In [None]:
# One single-row DataFrame per experiment, in the row order of `frames`.
df_per_row = [pd.DataFrame(row).T for _, row in frames.iterrows()]

# Longest rendered value in each column; used to pad the dropdown labels.
col_widths = [
    max([len(f"{cell}") for cell in frames[column].values])
    for column in frames.columns
]

# Dropdown label: every cell of the row padded to its column width plus 4.
options = [
    "".join(
        [
            f"{cell:{width + 4}}"
            for cell, width in zip(single_row.values[0], col_widths)
        ]
    )
    for single_row in df_per_row
]

# Dropdown value: the experiment's directory under `basedir`.
eids = [
    int(single_row["ophys_experiment_id"].values[0])
    for single_row in df_per_row
]
values = [basedir / f"ophys_experiment_{eid}" for eid in eids]

experiment_selector = widgets.Dropdown(
    options=list(zip(options, values)),
    description='Select experiment:',
    disabled=False,
    layout={'width': 'max-content'},  # If the items' names are long
)

def on_change_experiment(change):
    """Observer callback for the experiment dropdown; currently a no-op.

    Parameters
    ----------
    change: dict
        Change notification; its ``'type'`` and ``'name'`` entries are
        inspected, and ``'name'`` only when ``'type'`` is ``'change'``.

    Returns
    -------
    None
        Nothing is done, even for a change of the ``value`` trait.
    """
    if change['type'] != 'change' or change['name'] != 'value':
        return
    # placeholder: nothing happens yet when the selected experiment changes
    return

# Register the callback without a `names=` filter; `on_change_experiment`
# inspects the change's type and name itself.
experiment_selector.observe(on_change_experiment)
# Show the dropdown as this cell's output.
display(experiment_selector)

In [None]:
@widgets.interact(nrows=[1, 2, 3], ncols=[1, 2, 3])
def update(nrows=1, ncols=1):
    """Rebuild figure 1 as an nrows x ncols grid with per-axes controls.

    For every axes a background dropdown, an ROI dropdown and an "Update"
    button are created; clicking the button redraws that axes through
    ``update_plot`` using the two dropdowns.

    Parameters
    ----------
    nrows, ncols: int
        Grid dimensions, chosen from 1, 2 or 3 through the interact controls.
    """
    # erase old figure
    plt.close(plt.figure(1))

    # make new figure
    fig, axes = plt.subplots(
        nrows,
        ncols,
        num=1,
        clear=True,
        sharex=True,
        sharey=True,
        squeeze=False,
    )
    plt.show()
    fig.tight_layout()

    # one selector pair and one button per axes, all in row-major order
    backgrounds = new_background_selector(nrows, ncols)
    foregrounds = new_foreground_selector(nrows, ncols)
    update_buttons = new_plot_update_buttons(nrows, ncols)
    for ax, bgw, fgw, button in zip(axes.flat, backgrounds, foregrounds, update_buttons):
        button.on_click(
            partial(
                update_plot,
                axes=ax,
                foreground_widget=fgw,
                background_widget=bgw,
            )
        )

    # group the selectors into labelled columns and display them side by side
    labelled_columns = (
        ("backgrounds", backgrounds),
        ("foregrounds", foregrounds),
        ("update buttons", update_buttons),
    )
    selector_box = widgets.HBox(
        [
            widgets.VBox([widgets.Label(heading)] + members)
            for heading, members in labelled_columns
        ]
    )
    display(selector_box)

In [None]:
# One pre-ticked checkbox per *.h5 file under the selected experiment's
# `videos` folder (evaluated once, when this cell runs).
movie_widget_list = [
    widgets.Checkbox(
        description=movie_file.name,
        value=True,
        layout={'width': 'max-content'},
    )
    for movie_file in sorted((experiment_selector.value / "videos").rglob("*.h5"))
]
movie_list = widgets.VBox(movie_widget_list)

# Parsed contents of every *_rois.json under `rois`, keyed by file name.
rois_lists = sorted((experiment_selector.value / "rois").rglob("*_rois.json"))
rois_dict = {}
for r in rois_lists:
    with open(r, "r") as f:
        rois_dict[r.name] = json.load(f)

# One dropdown per ROI file listing its ROI ids in sorted order, plus -1,
# which `plot_callback` treats as "nothing selected".
roi_drops = [
    widgets.Dropdown(
        description=source_name,
        options=np.sort([-1] + [record["id"] for record in records]),
        style={'description_width': 'initial'},
        layout={'width': 'max-content'},
    )
    for source_name, records in rois_dict.items()
]
roi_list = widgets.VBox(roi_drops)

# How `plot_callback` arranges traces across subplots.
trace_grouping = widgets.Dropdown(
    options=[
        ("group traces by ROI", 0),
        ("group traces by movie", 1),
    ]
)

In [None]:
def extents_from_roi(roi):
    """Return the bounding-box limits of an ROI record.

    Parameters
    ----------
    roi: dict
        Must provide ``"x"``, ``"y"``, ``"width"`` and ``"height"``.

    Returns
    -------
    tuple
        ``(xmin, xmax, ymin, ymax)`` where ``xmax = x + width`` and
        ``ymax = y + height``.
    """
    left = roi["x"]
    right = left + roi["width"]
    top = roi["y"]
    bottom = top + roi["height"]
    return (left, right, top, bottom)


def get_trace(movie_path, roi):
    """Compute the mean trace of one ROI from an HDF5 movie.

    Parameters
    ----------
    movie_path:
        Path of an HDF5 file with a ``"data"`` dataset indexed as
        ``[frame, row, column]``.
    roi: dict
        ROI record with ``"x"``, ``"y"``, ``"width"``, ``"height"`` and a
        ``"mask_matrix"`` holding one entry per pixel of that bounding box.

    Returns
    -------
    np.ndarray
        One value per frame: the sum over the masked pixels divided by the
        number of masked pixels.

    Raises
    ------
    ValueError
        If ``mask_matrix`` does not contain exactly as many entries as the
        pixels read for the bounding box (raised by ``reshape``).
    """
    xmin, xmax, ymin, ymax = extents_from_roi(roi)
    # read only the ROI's bounding box from disk
    with h5py.File(movie_path, "r") as f:
        data = f["data"][:, ymin: ymax, xmin: xmax]
    # flatten each frame to a pixel vector so the flat mask can index it
    data = data.reshape(data.shape[0], -1)
    # Force a boolean mask: a 0/1 integer mask would otherwise be treated as
    # integer (fancy) indexing and silently pick pixels 0 and 1 only.
    mask = np.asarray(roi["mask_matrix"], dtype=bool).reshape(data.shape[1])
    npix = np.count_nonzero(mask)
    trace = data[:, mask].sum(axis=1) / npix
    return trace


def plot_callback():
    """Plot the trace of every selected ROI in every selected movie.

    Reads the current state of the module-level widgets ``roi_drops``,
    ``movie_widget_list``, ``trace_grouping`` and ``experiment_selector``,
    computes the traces with ``get_trace`` in a pool of 4 processes, and
    draws one subplot per group (per ROI or per movie, depending on
    ``trace_grouping``).

    A message is printed and nothing is plotted when no ROI/movie
    combination is selected. A selected ROI id that is absent from its
    file is reported and skipped.

    Raises
    ------
    ValueError
        If ``trace_grouping`` holds a value other than 0 or 1.
    """
    # determine which ROIs are selected (-1 means "none" for that file)
    selected_ids = {
        roi_select.description: int(roi_select.value)
        for roi_select in roi_drops
        if roi_select.value != -1
    }

    # replace each selected id with its full ROI record
    rois_lookup = dict()
    for roi_source, roi_id in selected_ids.items():
        with open(experiment_selector.value / "rois" / roi_source, "r") as f:
            candidates = convert_keys(json.load(f))
        matches = [roi for roi in candidates if roi["id"] == roi_id]
        if not matches:
            # without this guard a bare int would be left in the lookup and
            # fail later with an unrelated TypeError
            print(f"ROI {roi_id} not found in {roi_source}; skipping")
            continue
        # the last record with a matching id wins
        rois_lookup[roi_source] = matches[-1]

    # determine which movie paths are selected
    movie_paths = [
        experiment_selector.value / "videos" / movie_widget.description
        for movie_widget in movie_widget_list
        if movie_widget.value
    ]

    # get all combinations of ROIs and movie paths
    trace_list = [
        {
            "roi_source": roi_source,
            "roi": roi,
            "roi_id": roi["id"],
            "movie_path": movie_path,
            "movie_label": movie_path.name,
            "roi_label": f"{roi_source}_{roi['id']}"
        }
        for roi_source, roi in rois_lookup.items()
        for movie_path in movie_paths
    ]
    if not trace_list:
        # an empty table has no columns to group by; stop with a clear hint
        print("Select at least one ROI and one movie to plot traces.")
        return

    # load traces in parallel
    args = [(entry["movie_path"], entry["roi"]) for entry in trace_list]
    with multiprocessing.Pool(4) as pool:
        results = pool.starmap(get_trace, args)
    for entry, result in zip(trace_list, results):
        entry["trace"] = result

    # group according to selected method
    df = pd.DataFrame.from_records(trace_list)
    if trace_grouping.value == 0:
        groups = df.groupby(["roi_source", "roi_id"])
        label = "movie_label"
    elif trace_grouping.value == 1:
        groups = df.groupby(["movie_label"])
        label = "roi_label"
    else:
        raise ValueError(
            f"unexpected trace grouping value: {trace_grouping.value!r}")

    fig2, axes2 = plt.subplots(len(groups), 1, clear=True, sharex=True, sharey=True, squeeze=False)
    for (group_key, members), ax in zip(groups, axes2.flat):
        # multi-part keys go on separate lines of the y label
        if isinstance(group_key, tuple):
            ylab = "\n".join([f"{part}" for part in group_key])
        else:
            ylab = group_key
        ax.set_ylabel(ylab, fontsize=6)
        for _, entry in members.iterrows():
            ax.plot(entry["trace"], linewidth=0.4, label=entry[label])

    axes2.flat[0].legend(fontsize=6)
    fig2.tight_layout()
    plt.show()

# Movie checkboxes and ROI dropdowns side by side, each under a bold heading.
display(
    widgets.HBox(
        [
            widgets.VBox([widgets.HTML(value=heading), column])
            for heading, column in (
                ("<b>available movies</b>", movie_list),
                ("<b>available ROIs</b>", roi_list),
            )
        ],
        layout={'width': 'max-content'},
    )
)
display(trace_grouping)

# NOTE(review): the `description=` keyword does not appear to relabel the
# run button (`plot_callback` takes no parameters), hence the explicit
# assignment below; children[0] is presumably that button -- confirm.
b = widgets.interact_manual(plot_callback, description="plot traces")
b.widget.children[0].description = "plot traces"