In [2]:
%load_ext autoreload
%autoreload 2
!which python

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
/home/jonas/.cache/pypoetry/virtualenvs/data-quality-wearables-JR_qNb0v-py3.9/bin/python


In [3]:
import sys

import numpy as np
import pandas as pd
import plotly.graph_objects as go
from functional import seq
from IPython.display import display
from plotly.subplots import make_subplots
from plotly_resampler import FigureWidgetResampler
from plotly_resampler.aggregation import MinMaxLTTB, MedDiffGapHandler

pd.options.display.max_columns = None
pd.options.display.max_colwidth = None

sys.path.append("../..")
from code_utils.path_conf import mbrain_metadata_path, processed_mbrain_path

## Widgets

In [4]:
import ipywidgets as widgets; from datetime import datetime

# 1. File selection widgets
patient_widget = widgets.Dropdown(
    options=sorted(
        seq(mbrain_metadata_path.iterdir()).filter(lambda x: x.is_dir()).map(lambda x: x.name)
    )
)

t_start_widget = widgets.Dropdown(options=[])
t_end_widget = widgets.Dropdown(options=[])
wearable_sensor_widget = widgets.SelectMultiple(options=[])


def update_time_range_widget(change):
    min_time, max_time = (
        pd.to_datetime(
            pd.read_csv(mbrain_metadata_path.joinpath(patient_widget.value, "event_dump.csv"))[
                "time"
            ]
            .quantile([0, 1])
            .values,
            unit="ms",
        )
        .tz_localize("utc")
        .tz_convert("Europe/Brussels")
    )

    min_time, max_time = min_time.date(), max_time.date()

    t_end_widget.options = (
        seq(
            pd.date_range(
                min_time + pd.Timedelta(days=1),
                max_time + pd.Timedelta(days=2),
                freq="D",
            )
        )
        .map(lambda x: x.strftime("%Y_%m_%d"))
        .to_list()
    )
    t_end_widget.value = t_end_widget.options[-1]

    t_start_widget.options = (
        seq(pd.date_range(min_time - pd.Timedelta(days=1), max_time, freq="D"))
        .map(lambda x: x.strftime("%Y_%m_%d"))
        .to_list()
    )


patient_widget.observe(update_time_range_widget, "value")


def update_t_end_widget(change):
    if t_start_widget.value is None or t_end_widget.value is None:
        return
    t_start_date = datetime.strptime(t_start_widget.value, "%Y_%m_%d")
    t_end_date = datetime.strptime(t_end_widget.value, "%Y_%m_%d")

    if t_start_date >= t_end_date:
        t_end_date = t_start_date + pd.Timedelta(days=1)
        t_end_widget.value = t_end_date.strftime("%Y_%m_%d")


t_start_widget.observe(update_t_end_widget, "value")


Columns (49,51,52,53,54,55,56,57,59,61,63) have mixed types. Specify dtype option on import or set low_memory=False.



In [5]:
def _filter_glob_dates(glob, t_start: pd.Timestamp, t_end: pd.Timestamp):
    return (
        seq(glob)
        .filter(
            lambda x: t_start
            <= datetime.strptime("_".join(x.stem.split("_")[-3:]), "%Y_%m_%d")
            <= t_end
        )
        .to_list()
    )


def get_user_data_interval(user, t_start, t_end, modality, empatica=True):
    glob = processed_mbrain_path.glob(f"{user}*.E4*/{modality}*.parquet")
    if not empatica:
        glob = set(processed_mbrain_path.glob(f"{user}/{modality}*.parquet")) - set(
            glob
        )
    glob = _filter_glob_dates(glob, t_start, t_end)
    if len(glob):
        out = pd.concat([pd.read_parquet(f) for f in glob])
        if "timestamp" in out.columns:
            out = out.set_index("timestamp")
        out.sort_index(inplace=True)
        return out[~out.index.duplicated(keep="first")]


In [14]:
label_dict = {
    "sleep-period": {
        "type": "x-range",
        "plt_kwargs": {"line_width": 0.5, "fillcolor": "green", "opacity": 0.15},
    },
    "incomplete-sleep-period": {
        "type": "x-range",
        "plt_kwargs": {"line_width": 0.5, "fillcolor": "darkgreen", "opacity": 0.15},
    },
    "wake-up": {
        "type": "x-range",
        "plt_kwargs": {"line_width": 0.5, "fillcolor": "purple", "opacity": 0.15},
    },
    "off-wrist": {
        "type": "x-range",
        "plt_kwargs": {"line_width": 0.5, "fillcolor": "red", "opacity": 0.15},
    },
    "snooze": {
        "type": "x-range",
        "plt_kwargs": {"line_width": 0.5, "fillcolor": "gray", "opacity": 0.15},
    },
}


fw_fig = FigureWidgetResampler(default_downsampler=MinMaxLTTB(parallel=True))
meta_cols = ["label", "start", "end", "patient_id"]
try:
    len(meta_list)
    df_meta = pd.DataFrame(meta_list, columns=meta_cols)
except NameError:
    df_meta = pd.read_csv("meta.csv")
    meta_list = list(df_meta.values)

prev_x = []
point_list = []

# Create a label selector
label_selector = widgets.Dropdown()
label_selector.options = list(label_dict.keys())


def update_point(trace, points, selector):
    if not len(points.xs):
        return

    # print(trace, points, selector)
    global prev_x, point_list, meta_list, fw_fig
    config = label_dict[label_selector.value]

    if config.get("type", "") == "x-range":
        prev_x.append(points.xs[0])
        if len(prev_x) == 2:
            fw_fig.add_vrect(prev_x[0], prev_x[1], **config.get("plt_kwargs", {}))
            # add a row to df_meta
            meta_list.append(
                [
                    label_selector.value,
                    prev_x[0],
                    prev_x[1],
                    patient_widget.value,
                    pd.Timestamp(datetime.utcnow(), tz="UTC"),
                ]
            )
            prev_x = []


# ignore jupyter_client warnings
import warnings

warnings.filterwarnings("ignore", module="jupyter_client")
warnings.filterwarnings("ignore", module="pandas")


@widgets.interact_manual
def test(
    patient_id=patient_widget,
    t_start=t_start_widget,
    t_end=t_end_widget,
    label=label_selector,
):
    global fw_fig, df_meta
    t_start = pd.Timestamp(datetime.strptime(t_start, "%Y_%m_%d"))
    t_end = pd.Timestamp(datetime.strptime(t_end, "%Y_%m_%d"))

    # fmt: off
    e4_acc = get_user_data_interval(patient_id, t_start, t_end, "acc").add_prefix( "e4_")
    if not len(e4_acc):
        print("data is too short, returning")
        print(f"{e4_acc.shape[0]:,}")
        return


    n_rows = 3
    fw_fig = FigureWidgetResampler(
        make_subplots(rows=n_rows, cols=1, shared_xaxes=True, vertical_spacing=0.08,
                      specs=np.array([[{"secondary_y": True}] * n_rows]).reshape(-1, 1).tolist(),
                      subplot_titles=['E4: skin temperature & skin conductance', 'E4: ACC', 'Phone: accelerometer-x & E4: BVP + <b>Annotations</b>']
                      ), default_n_shown_samples=2000)

    # Sort the series by name
    # proc_series = sorted(proc_series, key=lambda x: x.name)
    e4_tmp = get_user_data_interval(patient_id, t_start, t_end, "tmp").add_prefix( "e4_")
    # e4_ibi = get_user_data_interval(patient_id, t_start, t_end, "ibi").add_prefix( "e4_")
    e4_eda = get_user_data_interval(patient_id, t_start, t_end, "gsr").add_prefix( "e4_")
    e4_bvp = get_user_data_interval(patient_id, t_start, t_end, "bvp").add_prefix( "e4_")
    for c in e4_eda.columns:
        fw_fig.add_trace(go.Scattergl(name=c, opacity=.5, line_width=0, fill='tozeroy', line_shape='vh', legend='legend1'), 
        gap_handler=MedDiffGapHandler(fill_value=0), secondary_y=True, hf_x=e4_eda.index, hf_y=e4_eda[c], row=1, col=1)
    for c in e4_tmp.columns:
        fw_fig.add_trace(go.Scattergl(name=c, legend='legend1'), hf_x=e4_tmp.index, hf_y=e4_tmp[c], row=1, col=1)
    # for c in e4_ibi.columns:
    #     fw_fig.add_trace(go.Scattergl(name=c, legend='legend1'), secondary_y=True, hf_x=e4_ibi.index, hf_y=e4_ibi[c], row=1, col=1)

    ph_acc = get_user_data_interval(patient_id, t_start, t_end, "acc",empatica=False).drop(columns=['y', 'z'], errors='ignore')
    for c in ph_acc.columns:
                fw_fig.add_trace(go.Scattergl(name=c, legend='legend3'), hf_x=ph_acc.index, hf_y=ph_acc[c], row=3, col=1)

    for c in e4_bvp.columns:
        fw_fig.add_trace(go.Scattergl(name=c, visible='legendonly', legend='legend3'), secondary_y=True, hf_x=e4_bvp.index, hf_y=e4_bvp[c], row=3, col=1)

    fw_fig.update_yaxes(range=[0, 1.5], row=1, col=1, secondary_y=True)
    for c in list(set(e4_acc.columns)): #.difference({"e4_ACC_y", "e4_ACC_z"})):
        fw_fig.add_trace(go.Scattergl(name=c, legend='legend2', opacity=0.7), hf_x=e4_acc.index, hf_y=e4_acc[c] / 64, row=2, col=1)

    fw_fig.update_layout(template="plotly_white", height=650, hovermode="x unified")
    fw_fig.update_traces(xaxis='x3')

    for date in pd.date_range(t_start, t_end, freq='D'):
        if date.dayofweek == 5:
            fw_fig.add_vrect(date, date + pd.Timedelta(days=2), fillcolor='gray', opacity=0.15, line_color='black', row=n_rows, col=1)


    for trace in fw_fig.data:
        if 'e4_acc' in trace.name.lower():
            trace.on_click(update_point)
            break

    if len(meta_list):
        df_meta = pd.DataFrame(meta_list, columns=meta_cols[:])
    for _, r in df_meta[(df_meta.patient_id == patient_id)].iterrows():
        fw_fig.add_vrect(r.start, r.end, **label_dict[r.label]["plt_kwargs"], row=n_rows, col=1)
    # fmt: on

    fw_fig.update_layout(
        margin=dict(l=0, r=0, t=20, b=0),
        # legend1={
        #     "x": 0.0,
        #     "y": 1.03,
        #     **{"xref": "paper", "yref": "paper"},
        #     **{"xanchor": "left", "yanchor": "top", "orientation": "h"}
        # },
        # legend2={
        #     "x": 0.0,
        #     "y": 0.68,
        #     **{"xref": "paper", "yref": "paper"},
        #     **{"xanchor": "left", "yanchor": "top", "orientation": "h"}
        # },
        # legend3={
        #     "x": 0.0, "y": 0.34,
        #     **{"xref": "paper", "yref": "paper"},
        #     **{"xanchor": "left", "yanchor": "top", "orientation": "h"}
        # },
    )
    # hide the y axis ticks labels
    fw_fig.update_yaxes(showticklabels=False, row=3, col=1, secondary_y=True)
    fw_fig.update_yaxes(showticklabels=False, row=3, col=1, secondary_y=False)
    # hide the mode bar
    # return fw_fig.show(config={"displayModeBar": False})
    display(fw_fig)


interactive(children=(Dropdown(description='patient_id', index=45, options=('Antds', 'COPIMAC001', 'COPIMAC003…