In [1]:
# Jupyter settings and Imports

# %load_ext autoreload
# %autoreload 2
# %flow mode reactive

from datetime import date
import ipdb
from itertools import product
from pathlib import Path

from dotmap import DotMap
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objs as go
import seaborn as sns

import aeon.io.api as api
from aeon.io import reader
from aeon.schema.dataset import exp02, exp01
from aeon.analysis.utils import visits, distancetravelled

In [2]:
# Get sessions: load subject enter/exit events from both data roots and turn them into visits
pd.set_option("display.max_columns", 100)
pd.set_option("display.max_rows", 100)

roots = [
    Path("/ceph/aeon/aeon/data/raw/AEON3/presocial0.1"),
    Path("/ceph/aeon/aeon/data/raw/AEON2/presocial0.1"),
]
# Only warns; loading is still attempted below even if a root is missing
if not all(data_root.exists() for data_root in roots):
    print("Cannot find root paths. Check path names or connection.")
subject_events = api.load(roots, exp02.ExperimentalMetadata.SubjectState)
# Keep only subjects whose id starts with "BAA"
is_baa_subject = subject_events.id.str.startswith("BAA")
sessions = visits(subject_events[is_baa_subject])

In [3]:
# Prettify sessions: keep sessions entered on or after 2023-03-10, round weights to 0.1 and
# times to whole seconds, keep only the summary columns, and order rows by entry time.
#
# Built as one `assign` chain on the filtered frame, so nothing is written into a slice of
# the original frame and the global SettingWithCopy warning does not need to be silenced.
sessions = (
    sessions[sessions.enter.dt.date >= date(2023, 3, 10)]
    .assign(
        # weights are cast to float before rounding, so both columns end up as float dtype
        weight_enter=lambda df: df["weight_enter"].astype(float).round(1),
        weight_exit=lambda df: df["weight_exit"].astype(float).round(1),
        # floor the entry and ceil the exit so the displayed window always covers the session
        enter=lambda df: df["enter"].dt.floor("1s"),
        exit=lambda df: df["exit"].dt.ceil("1s"),
        duration=lambda df: df["duration"].dt.round("1s"),
    )[["id", "enter", "exit", "duration", "weight_enter", "weight_exit"]]
    .sort_values(by="enter")
    .reset_index(drop=True)  # fresh 0..n-1 index; the old index is discarded
)
display(sessions)

Unnamed: 0,id,enter,exit,duration,weight_enter,weight_exit
0,BAA-1103050,2023-03-10 09:41:48,2023-03-10 12:55:19,0 days 03:13:30,23.2,23.9
1,BAA-1103045,2023-03-10 12:12:45,2023-03-10 15:22:14,0 days 03:09:28,23.0,23.7
2,BAA-1103048,2023-03-10 13:08:24,2023-03-10 16:16:31,0 days 03:08:06,22.5,24.6
3,BAA-1103047,2023-03-10 15:27:05,2023-03-10 19:10:44,0 days 03:43:38,19.8,21.0
4,BAA-1103049,2023-03-10 16:22:29,2023-03-10 19:21:50,0 days 02:59:20,20.9,22.6
5,BAA-1103048,2023-03-15 09:26:49,2023-03-15 10:59:51,0 days 01:33:01,26.4,25.7
6,BAA-1103044,2023-03-17 14:44:00,2023-03-17 19:15:43,0 days 04:31:42,25.0,23.0
7,BAA-1103048,2023-03-23 07:56:44,2023-03-23 11:08:34,0 days 03:11:49,24.8,24.3
8,BAA-1103045,2023-03-23 10:16:38,2023-03-23 13:19:43,0 days 03:03:03,23.6,25.1
9,BAA-1103049,2023-03-23 11:15:29,2023-03-23 14:23:41,0 days 03:08:11,22.0,24.4


In [None]:
# Get bad sessions: show the 'DispenserBroken' and 'Annotation' messages logged in each arena,
# so that bad sessions can be picked out by hand in the next cell.
message_types = ["DispenserBroken", "Annotation"]

# roots[0] is the AEON3 data root
message_log_aeon3 = api.load(str(roots[0]), exp02.ExperimentalMetadata.MessageLog)
print("Aeon3 messages:\n")
display(message_log_aeon3[message_log_aeon3["type"].isin(message_types)])
print("\n\n")
# roots[1] is the AEON2 data root (loading roots[0] here would show the AEON3 log twice)
message_log_aeon2 = api.load(str(roots[1]), exp02.ExperimentalMetadata.MessageLog)
print("Aeon2 messages:\n")
display(message_log_aeon2[message_log_aeon2["type"].isin(message_types)])

In [4]:
# Based on above, manually decide which are bad sessions, and drop these from `sessions`

# A bad session is identified by the pair (subject id, entry date); the two tuples below are
# parallel, i.e. ids[k] goes with dates[k].
# Bad sessions reasons:
# 0: bugs in workflow
# 1: rfid session
# 2: stayed on only one patch since session start
# 3: poop stuck on wheel made it hard to turn
bad_sessions = DotMap()
bad_sessions.ids = (
    "BAA-1103048",
    "BAA-1103044",
    "BAA-1103050",
    "BAA-1103048",
)
bad_sessions.dates = (
    date(2023, 3, 15),  # bugs in workflow
    date(2023, 3, 17),  # rfid session
    date(2023, 3, 24),  # only stayed on one patch from beginning
    date(2023, 3, 24),  # poop stuck on wheel
)

# Flag every row matching one of the (id, date) pairs, then keep the unflagged rows
is_bad = pd.Series(False, index=sessions.index)
for bad_id, bad_date in zip(bad_sessions.ids, bad_sessions.dates):
    is_bad |= (sessions.id == bad_id) & (sessions.enter.dt.date == bad_date)
sessions = sessions[~is_bad].sort_values(by="enter").reset_index(drop=True)

In [10]:
# Declare some set-up variables to help with analysis

# Subject-id suffixes of the animals housed in each room
in_b2_210 = ("48", "49", "50")
in_465 = ("45", "47")

# Per-session metric columns: created empty (NaN) here so they can be filled in row by row later
new_cols = (
    "post_thresh_dur", "post_thresh_both_p_sampled_dur",
    "pre_sampling_both_p_dur", "easy_patch", "hard_patch",
    "post_easy_rate", "post_hard_rate", "pre_easy_n_pel", "pre_hard_n_pel",
    "post_easy_n_pel", "post_hard_n_pel", "pre_easy_wheel_dist", "pre_hard_wheel_dist",
    "post_easy_wheel_dist", "post_hard_wheel_dist", "pre_easy_pref", "post_easy_pref",
    "pre_hard_pref", "post_hard_pref", "post_pre_easy_pref", "post_easy_pel_thresh",
    "post_easy_pel_thresh_idx", "post_hard_pel_thresh", "post_hard_pel_thresh_idx", "init_pref_by_pel_ct",
)
# Columns converted to object dtype (presumably because each cell will hold a whole array)
array_cols = (
    "post_easy_pel_thresh",
    "post_hard_pel_thresh",
    "post_easy_pel_thresh_idx",
    "post_hard_pel_thresh_idx",
    "init_pref_by_pel_ct",
)
sessions = sessions.assign(**dict.fromkeys(new_cols, np.nan))
for col in array_cols:
    sessions[col] = sessions[col].astype(object)
display(sessions)

Unnamed: 0,id,enter,exit,duration,weight_enter,weight_exit,post_thresh_dur,post_thresh_both_p_sampled_dur,pre_sampling_both_p_dur,easy_patch,hard_patch,post_easy_rate,post_hard_rate,pre_easy_n_pel,pre_hard_n_pel,post_easy_n_pel,post_hard_n_pel,pre_easy_wheel_dist,pre_hard_wheel_dist,post_easy_wheel_dist,post_hard_wheel_dist,pre_easy_pref,post_easy_pref,pre_hard_pref,post_hard_pref,post_pre_easy_pref,post_easy_pel_thresh,post_easy_pel_thresh_idx,post_hard_pel_thresh,post_hard_pel_thresh_idx,init_pref_by_pel_ct
0,BAA-1103050,2023-03-10 09:41:48,2023-03-10 12:55:19,0 days 03:13:30,23.2,23.9,,,,,,,,,,,,,,,,,,,,,,,,,
1,BAA-1103045,2023-03-10 12:12:45,2023-03-10 15:22:14,0 days 03:09:28,23.0,23.7,,,,,,,,,,,,,,,,,,,,,,,,,
2,BAA-1103048,2023-03-10 13:08:24,2023-03-10 16:16:31,0 days 03:08:06,22.5,24.6,,,,,,,,,,,,,,,,,,,,,,,,,
3,BAA-1103047,2023-03-10 15:27:05,2023-03-10 19:10:44,0 days 03:43:38,19.8,21.0,,,,,,,,,,,,,,,,,,,,,,,,,
4,BAA-1103049,2023-03-10 16:22:29,2023-03-10 19:21:50,0 days 02:59:20,20.9,22.6,,,,,,,,,,,,,,,,,,,,,,,,,
5,BAA-1103048,2023-03-23 07:56:44,2023-03-23 11:08:34,0 days 03:11:49,24.8,24.3,,,,,,,,,,,,,,,,,,,,,,,,,
6,BAA-1103045,2023-03-23 10:16:38,2023-03-23 13:19:43,0 days 03:03:03,23.6,25.1,,,,,,,,,,,,,,,,,,,,,,,,,
7,BAA-1103049,2023-03-23 11:15:29,2023-03-23 14:23:41,0 days 03:08:11,22.0,24.4,,,,,,,,,,,,,,,,,,,,,,,,,
8,BAA-1103047,2023-03-23 13:29:36,2023-03-23 16:32:51,0 days 03:03:14,23.2,22.4,,,,,,,,,,,,,,,,,,,,,,,,,
9,BAA-1103050,2023-03-23 14:30:26,2023-03-23 17:29:18,0 days 02:58:51,23.5,24.9,,,,,,,,,,,,,,,,,,,,,,,,,


In [None]:
# Compute per-session metrics (durations, pellet counts, wheel distances, patch preferences)
# and write them into the matching row of `sessions`.
#
# For every session this loads pellet-trigger events, patch state updates and wheel encoder
# data for both patches, cleans them, locates the threshold-change timestamp, and derives
# "pre" (before the change) and "post" (after the change) values for the easy and hard patch.
#
# Raises:
#     Exception: if, after cleaning, the number of patch-state updates minus the number of
#         pellet triggers is not 1 or 2 for either patch.
for s in sessions.itertuples():
    # Data root for the current session: ids ending in one of `in_b2_210` use roots[0], all others roots[1]
    root = str(roots[0]) if np.any([s.id.endswith(sid) for sid in in_b2_210]) else str(roots[1])
    # The first TriggerPellet value recorded on Patch1 is used as the bitmask for both patches' event readers
    harp_reader = reader.Harp(pattern="Patch1_35", columns=["TriggerPellet"])
    new_pellet_trig_bitmask = api.load(root, harp_reader, start=s.enter, end=s.exit).iloc[0, 0]
    new_pellet_trig_reader_p1 = reader.BitmaskEvent("Patch1_35", new_pellet_trig_bitmask, "TriggerPellet")
    new_pellet_trig_reader_p2 = reader.BitmaskEvent("Patch2_35", new_pellet_trig_bitmask, "TriggerPellet")
    p1 = api.load(root, new_pellet_trig_reader_p1, start=s.enter, end=s.exit)
    p2 = api.load(root, new_pellet_trig_reader_p2, start=s.enter, end=s.exit)
    pstate1 = api.load(root, exp02.Patch1.DepletionState, start=s.enter, end=s.exit)
    pstate2 = api.load(root, exp02.Patch2.DepletionState, start=s.enter, end=s.exit)
    # Wheel distance traces; the sign is flipped (presumably so that running counts as positive — TODO confirm)
    encoder1 = api.load(root, exp02.Patch1.Encoder, start=s.enter, end=s.exit)
    w1 = -distancetravelled(encoder1.angle)
    encoder2 = api.load(root, exp02.Patch2.Encoder, start=s.enter, end=s.exit)
    w2 = -distancetravelled(encoder2.angle)

    # PelletTrig cleaning: remove repeated deliveries (events <1.5 s apart) and manual deliveries (201)
    # (1.5e9 is 1.5 s assuming the index differences are in nanoseconds; the earlier event of each close pair is dropped)
    p1 = p1.drop(p1.index[np.where(np.diff(p1.index).astype("float64") < 1.5e9)[0]])
    p2 = p2.drop(p2.index[np.where(np.diff(p2.index).astype("float64") < 1.5e9)[0]])
    harp_reader = reader.Harp(pattern="Patch1_201", columns=["ExperimenterDeliveries"])
    user_p1 = api.load(root, harp_reader, start=s.enter, end=s.exit)
    harp_reader = reader.Harp(pattern="Patch2_201", columns=["ExperimenterDeliveries"])
    user_p2 = api.load(root, harp_reader, start=s.enter, end=s.exit)
    # For each manual delivery, drop the pellet trigger closest in time to it.
    # `drop` returns a new frame, so the result has to be assigned back for the removal to take effect.
    if not user_p1.empty:
        user_p1_idxs = np.abs(np.subtract.outer(user_p1.index, p1.index)).argmin(axis=1)
        p1 = p1.drop(p1.index[user_p1_idxs])
    if not user_p2.empty:
        user_p2_idxs = np.abs(np.subtract.outer(user_p2.index, p2.index)).argmin(axis=1)
        p2 = p2.drop(p2.index[user_p2_idxs])

    # PatchState cleaning: remove NaNs; remove updates <1.5s apart (bug updates)
    # (an update is kept when the next one is more than 1.5 s later; the last update is always kept)
    pstate1 = pstate1.dropna()
    good_indxs = np.concatenate((np.diff(pstate1.index).astype("float64") > 1.5e9, [True]))
    pstate1 = pstate1[good_indxs]
    pstate2 = pstate2.dropna()
    good_indxs = np.concatenate((np.diff(pstate2.index).astype("float64") > 1.5e9, [True]))
    pstate2 = pstate2[good_indxs]

    # Clean known issues in particular sessions
    if s.enter == pd.Timestamp("2023-03-24 14:22:48"):  # last threshold update of 75 for some reason
        pstate1 = pstate1.drop(pstate1.index[-1])
        pstate2 = pstate2.drop(pstate2.index[-1])
    if s.enter == pd.Timestamp("2023-03-10 13:08:24"):  # TriggerPellet at very end of session for some reason
        p2 = p2.drop(p2.index[-1])
    if s.enter == pd.Timestamp("2023-03-24 15:19:57"):  # some really weird bug around 18:00 in both arenas this day
        pstate1 = pstate1[pstate1.index < pd.Timestamp("2023-03-24 18:00:00")]
        pstate2 = pstate2[pstate2.index < pd.Timestamp("2023-03-24 18:00:00")]

    # Merge pellets of both patches only after ALL pellet cleaning, so the merged table matches p1/p2
    both_pellet_data = pd.concat([p1, p2]).sort_index()

    # Check lengths of PelletTrigger and PatchState events
    if ((len(pstate1) - len(p1)) not in (1, 2)) or ((len(pstate2) - len(p2)) not in (1, 2)):
        raise Exception(
            f"PelletTrigger-PatchState mismatch: \n"
            f"len(p1) = {len(p1)} \n"
            f"len(p2) = {len(p2)} \n"
            f"len(pstate1) = {len(pstate1)} \n"
            f"len(pstate2) = {len(pstate2)} \n"
        )
    both_state_data = pd.concat([pstate1, pstate2]).sort_index()

    # Find threshold-change ts
    # NOTE(review): `[0][1]` takes the SECOND update whose threshold rises by more than 1 over the
    # previous update (both patches merged) — presumably one jump per patch; confirm this is intended.
    thresh_change_idx = np.where(np.diff(both_state_data.threshold) > 1)[0][1]
    # `safe_change_ts` is currently the same timestamp as `change_ts`
    safe_change_ts = change_ts = both_state_data.index[thresh_change_idx]
    sessions.loc[s.Index, "post_thresh_dur"] = post_thresh_dur = (s.exit - change_ts).round("1s")

    # Find both-patches-sampled ts: the later of the two patches' first pellets
    both_patches_sampled_ts = pd.Series((p1.index[0], p2.index[0])).max()
    sessions.loc[s.Index, "pre_sampling_both_p_dur"] = pre_sampling_b_patches_dur = (both_patches_sampled_ts - s.enter).round("1s")
    # Same measure after the threshold change; left as NaN unless both patches gave a pellet after it
    if (np.any(p1.index > safe_change_ts) and np.any(p2.index > safe_change_ts)):
        both_patches_sampled_ts_post = (
            pd.Series((p1.index[p1.index > safe_change_ts][0],
                       p2.index[p2.index > safe_change_ts][0])).max()
        )
        sessions.loc[s.Index, "post_thresh_both_p_sampled_dur"] = (
            (both_patches_sampled_ts_post - safe_change_ts).round("1s")
        )

    # The hard patch is the one whose final `delta` is smaller; the other patch is the easy one
    sessions.loc[s.Index, "hard_patch"] = hard_patch = 1 if (pstate1["delta"][-1] < pstate2["delta"][-1]) else 2
    sessions.loc[s.Index, "easy_patch"] = easy_patch = 1 if (hard_patch == 2) else 2
    sessions.loc[s.Index, "post_hard_rate"] = post_hard_rate = pstate1["delta"][-1] if (hard_patch == 1) else pstate2["delta"][-1]
    sessions.loc[s.Index, "post_easy_rate"] = post_easy_rate = pstate1["delta"][-1] if (hard_patch == 2) else pstate2["delta"][-1]

    # Pellet counts before / after the change
    p1_pre_n_pel = len(p1[p1.index <= (safe_change_ts + pd.Timedelta("1s"))])  # ensure we don't count last pellet in pre as first pellet in post
    p1_post_n_pel = len(p1[p1.index > (safe_change_ts + pd.Timedelta("1s"))])
    p2_pre_n_pel = len(p2[p2.index <= (safe_change_ts + pd.Timedelta("1s"))])
    p2_post_n_pel = len(p2[p2.index > (safe_change_ts + pd.Timedelta("1s"))])
    sessions.loc[s.Index, "pre_easy_n_pel"] = pre_easy_n_pel = p1_pre_n_pel if (easy_patch == 1) else p2_pre_n_pel
    sessions.loc[s.Index, "pre_hard_n_pel"] = pre_hard_n_pel = p1_pre_n_pel if (hard_patch == 1) else p2_pre_n_pel
    sessions.loc[s.Index, "post_easy_n_pel"] = post_easy_n_pel = p1_post_n_pel if (easy_patch == 1) else p2_post_n_pel
    sessions.loc[s.Index, "post_hard_n_pel"] = post_hard_n_pel = p1_post_n_pel if (hard_patch == 1) else p2_post_n_pel

    # Wheel distance before the change: first sample after the change minus the first sample
    p1_pre_wheel_dist = w1[w1.index > safe_change_ts][0] - w1[0]
    p2_pre_wheel_dist = w2[w2.index > safe_change_ts][0] - w2[0]
    sessions.loc[s.Index, "pre_easy_wheel_dist"] = pre_easy_wheel_dist = p1_pre_wheel_dist if (easy_patch == 1) else p2_pre_wheel_dist
    sessions.loc[s.Index, "pre_hard_wheel_dist"] = pre_hard_wheel_dist = p1_pre_wheel_dist if (hard_patch == 1) else p2_pre_wheel_dist
    # NOTE(review): this equals the distance run after the change only if w[0] == 0 — confirm
    # that `distancetravelled` returns a zero-based trace.
    p1_post_wheel_dist = w1[-1] - p1_pre_wheel_dist
    p2_post_wheel_dist = w2[-1] - p2_pre_wheel_dist
    sessions.loc[s.Index, "post_easy_wheel_dist"] = post_easy_wheel_dist = p1_post_wheel_dist if (easy_patch == 1) else p2_post_wheel_dist
    sessions.loc[s.Index, "post_hard_wheel_dist"] = post_hard_wheel_dist = p1_post_wheel_dist if (hard_patch == 1) else p2_post_wheel_dist

    # Preferences: share of wheel distance spent on each patch, and the post/pre ratio for the easy patch
    sessions.loc[s.Index, "pre_easy_pref"] = pre_easy_pref = pre_easy_wheel_dist / (pre_easy_wheel_dist + pre_hard_wheel_dist)
    sessions.loc[s.Index, "post_easy_pref"] = post_easy_pref = post_easy_wheel_dist / (post_easy_wheel_dist + post_hard_wheel_dist)
    sessions.loc[s.Index, "pre_hard_pref"] = pre_hard_pref = 1 - pre_easy_pref
    sessions.loc[s.Index, "post_hard_pref"] = post_hard_pref = 1 - post_easy_pref
    sessions.loc[s.Index, "post_pre_easy_pref"] = post_pre_easy_pref = post_easy_pref / pre_easy_pref

    # Find each pstate update prior to each pellet threshold crossing
    # (thresholds from the change onwards, without the final update)
    p1_post_pel_thresh = pstate1[pstate1.index >= safe_change_ts].threshold[:-1]
    p2_post_pel_thresh = pstate2[pstate2.index >= safe_change_ts].threshold[:-1]
    post_easy_pel_thresh = p1_post_pel_thresh if (easy_patch == 1) else p2_post_pel_thresh
    post_hard_pel_thresh = p1_post_pel_thresh if (hard_patch == 1) else p2_post_pel_thresh
    # `.at` is used because each of these cells stores a whole array
    sessions.at[s.Index, "post_easy_pel_thresh"] = post_easy_pel_thresh.values.round(3)
    sessions.at[s.Index, "post_hard_pel_thresh"] = post_hard_pel_thresh.values.round(3)
    sessions.at[s.Index, "post_easy_pel_thresh_idx"] = np.array(post_easy_pel_thresh.index.round("1s"))
    sessions.at[s.Index, "post_hard_pel_thresh_idx"] = np.array(post_hard_pel_thresh.index.round("1s"))

    # Hard-patch share of cumulative wheel distance at the time of merged-pellet positions 8..17.
    # Entries stay NaN once the pellet falls after the threshold change, or when the session
    # has fewer pellets than the position asked for.
    whard = w1 if (hard_patch == 1) else w2
    weasy = w1 if (easy_patch == 1) else w2
    init_pref_by_pel_ct = np.ones((10,)) * np.nan
    for i, pel_ct in enumerate(range(8, 18)):
        if pel_ct >= len(both_pellet_data):  # not enough pellets in this session
            break
        cur_pel_ct_ts = both_pellet_data.index[pel_ct]
        if cur_pel_ct_ts > (safe_change_ts + pd.Timedelta("1s")):
            break
        cur_whard_dist = whard[whard.index > cur_pel_ct_ts][0] - whard[0]
        cur_weasy_dist = weasy[weasy.index > cur_pel_ct_ts][0] - weasy[0]
        init_pref_by_pel_ct[i] = cur_whard_dist / (cur_whard_dist + cur_weasy_dist)
    sessions.at[s.Index, "init_pref_by_pel_ct"] = init_pref_by_pel_ct

In [None]:
# Show the sessions table, including the per-session metric columns
display(sessions)

In [None]:
# Round the wheel-distance and preference columns to 3 decimals
cols_to_round = [
    "pre_easy_wheel_dist",
    "pre_hard_wheel_dist",
    "post_easy_wheel_dist",
    "post_hard_wheel_dist",
    "pre_easy_pref",
    "post_easy_pref",
    "pre_hard_pref",
    "post_hard_pref",
]
sessions[cols_to_round] = sessions[cols_to_round].round(3)

In [None]:
# Save the sessions table to disk as a pickle.
# NOTE(review): absolute, user-specific path — consider a configurable data directory.
presocial_pkl_path = Path(
    "/nfs/nhome/live/jbhagat/ProjectAeon/aeon_analysis/aeon_analysis/presocial/data"
    "/presocial_data.pkl"
)
sessions.to_pickle(presocial_pkl_path)

In [2]:
# Load the pickled sessions table into `df`.
# NOTE(review): absolute, user-specific path — consider a configurable data directory.
presocial_pkl_path = Path(
    "/nfs/nhome/live/jbhagat/ProjectAeon/aeon_analysis/aeon_analysis/presocial/data"
    "/presocial_data.pkl"
)
df = pd.read_pickle(presocial_pkl_path)

In [None]:
# Pre-change easy-patch wheel distances of one subject, as a plain list
df.loc[df["id"] == "BAA-1103045", "pre_easy_wheel_dist"].tolist()

In [None]:
# Concatenate the per-session values of column `col` for subject `uid` into one array.
# NOTE(review): `uid` is not defined in any visible cell, and `col` only exists as a leftover
# loop variable from earlier cells — this cell presumably relies on stale kernel state.
# Define both explicitly here, or delete the cell.
np.concatenate(df[df["id"] == uid][col].tolist())

In [4]:
import dash
from dash import Dash, dash_table, dcc, html

# Theme colours and sizes for the dashboard
bg_col, txt_col, plt_bg_col = "#050505", "#f2f2f2", "#0d0d0d"
tab_bg_col, tab_txt_col = "#003399", "#f2f2f2"
table_max_height, table_min_width = "400px", "1200px"
mrkr_sz = 14

# One fixed colour per subject, stored as (r, g, b) and rendered as "rgb(r, g, b)" strings
_subject_rgb = {
    "BAA-1103045": (31, 119, 180),
    "BAA-1103047": (255, 127, 14),
    "BAA-1103048": (44, 160, 44),
    "BAA-1103049": (214, 39, 40),
    "BAA-1103050": (148, 103, 189),
}
color_dict = {subject: f"rgb({r}, {g}, {b})" for subject, (r, g, b) in _subject_rgb.items()}

# Set all relevant app.layout children names (for future color theme updates)
# Every figure exists once per level ("session" / "subject"); names are generated group by group.
_levels = ("session", "subject")
fig_names = [
    *(f"weight_{kind}_{level}" for level in _levels for kind in ("enter", "diff")),
    *(
        f"{dur}_{level}"
        for level in _levels
        for dur in ("duration", "post_thresh_dur", "pre_sampling_both_p_dur")
    ),
    *(f"hard_patch_{level}" for level in _levels),
    *(
        f"{measure}_{level}_{scale}"
        for measure in ("wheel", "pellet")
        for level in _levels
        for scale in ("abs", "norm")
    ),
    *(f"prob_pels_{level}" for level in _levels),
]
tab_names = []

In [5]:


# Dash table showing every column of `df`, styled with the dark-theme constants.
# The table scrolls in both directions inside a bounded box.
table_style = {
    "overflowX": "auto",
    "overflowY": "auto",
    "maxHeight": table_max_height,
    "minWidth": table_min_width,
}
header_style = {"fontWeight": "bold", "backgroundColor": plt_bg_col}
cell_style = {
    "backgroundColor": plt_bg_col,
    "color": txt_col,
    "textAlign": "left",
    "whiteSpace": "normal",
    "height": "auto",
    "minWidth": 60,
}
table_columns = [{"name": col_name, "id": col_name} for col_name in df.columns]

data_table = dash_table.DataTable(
    id="data_table",
    data=df.to_dict("records"),
    columns=table_columns,
    style_table=table_style,
    fixed_columns={"headers": True, "data": 2},  # header plus the first two data columns are fixed
    fixed_rows={"headers": True},
    style_header=header_style,
    style_cell=cell_style,
)

In [11]:
# Show the DataTable object (the recorded output is its repr, not a rendered table)
display(data_table)

DataTable(data=[{'id': 'BAA-1103050', 'enter': Timestamp('2023-03-10 09:41:48'), 'exit': Timestamp('2023-03-10 12:55:19'), 'duration': Timedelta('0 days 03:13:30'), 'weight_enter': 23.2, 'weight_exit': 23.9, 'post_thresh_dur': Timedelta('0 days 02:22:53'), 'post_thresh_both_p_sampled_dur': Timedelta('0 days 00:12:53'), 'pre_sampling_both_p_dur': Timedelta('0 days 00:17:40'), 'easy_patch': 2.0, 'hard_patch': 1.0, 'post_easy_rate': 0.01, 'post_hard_rate': 0.0025, 'pre_easy_n_pel': 1.0, 'pre_hard_n_pel': 17.0, 'post_easy_n_pel': 21.0, 'post_hard_n_pel': 18.0, 'pre_easy_wheel_dist': 89.731, 'pre_hard_wheel_dist': 1275.491, 'post_easy_wheel_dist': 3207.803, 'post_hard_wheel_dist': 8580.92, 'pre_easy_pref': 0.066, 'post_easy_pref': 0.272, 'pre_hard_pref': 0.934, 'post_hard_pref': 0.728, 'post_pre_easy_pref': 4.140009063854826, 'post_easy_pel_thresh': array([121.823, 239.646, 214.056, 109.924, 160.179,  84.302,  82.351,
       223.876,  84.995, 249.337, 101.766, 142.513, 247.148, 229.395,
   

In [None]:
import dash

In [None]:
# NOTE(review): scratch cell — only echoes the imported `dash` module; can be deleted.
dash

In [None]:
# NOTE(review): scratch cell — only echoes the `dash.dash` attribute; can be deleted.
dash.dash

In [None]:
from dash.dependencies import Input, Output, State, ClientsideFunction


In [None]:
import seaborn as sns


In [None]:
# NOTE(review): not valid Python — presumably relies on IPython automagic (`%clear`); confirm
# what this was meant to do, or delete the cell.
clear dash

In [6]:
# Debug helper: take the last session row as a namedtuple, in the same form the loop over sessions sees it
s = next(sessions.iloc[[-1]].itertuples())

In [7]:
# Echo the session tuple selected in the previous cell
s

Pandas(Index=12, id='BAA-1103045', enter=Timestamp('2023-03-24 15:19:57'), exit=Timestamp('2023-03-24 18:42:44'), duration=Timedelta('0 days 03:22:45'), weight_enter=22.7, weight_exit=26.2)

In [11]:
    root = str(roots[0]) if np.any([s.id.endswith(sid) for sid in in_b2_210]) else str(roots[1])  # get root for current session
    harp_reader = reader.Harp(pattern="Patch1_35", columns=["TriggerPellet"])
    new_pellet_trig_bitmask = api.load(root, harp_reader, start=s.enter, end=s.exit).iloc[0, 0]
    new_pellet_trig_reader_p1 = reader.BitmaskEvent("Patch1_35", new_pellet_trig_bitmask, "TriggerPellet")
    new_pellet_trig_reader_p2 = reader.BitmaskEvent("Patch2_35", new_pellet_trig_bitmask, "TriggerPellet")
    p1 = api.load(root, new_pellet_trig_reader_p1, start=s.enter, end=s.exit)
    p2 = api.load(root, new_pellet_trig_reader_p2, start=s.enter, end=s.exit)