In [1]:
# Jupyter settings and Imports

%load_ext autoreload
%autoreload 2
#%flow mode reactive

from datetime import date
from pathlib import Path

from dotmap import DotMap
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.graph_objs as go
import seaborn as sns

import aeon.io.api as api
from aeon.io import reader
from aeon.schema.dataset import exp02, exp01
from aeon.analysis.utils import visits, distancetravelled

In [2]:
def display_df(df):
    df.style.set_table_styles([
        {
            "selector": "thead", 
            "props": [("position", "sticky"), ("top", "0")]
        },
        #{
        #    "selector": "tbody tr:nth-child(-n+2)",
        #    "props": [('position', 'sticky'), ("top", "43px")]
        #},
        {
            "selector": "tbody td:nth-child(-n+3)",
            "props": [("position", "sticky"), ("left", "0")]
        }
    ])
    fig = go.Figure(
        data=[go.Table(
            header=dict(values=list(df.columns), align='left'),
            cells=dict(values=[df[i] for i in df.columns], align='left'),
        )
    ])
    fig.layout.template = "plotly_dark"
    return fig

In [3]:
# Get sessions
pd.set_option('display.max_columns', 100)
pd.set_option('display.max_rows', 100)

roots = [Path("/ceph/aeon/aeon/data/raw/AEON3/presocial0.1"), Path("/ceph/aeon/aeon/data/raw/AEON2/presocial0.1")]
if not np.all([path.exists() for path in roots]):
    print("Cannot find root paths. Check path names or connection.")
subject_events = api.load(roots, exp02.ExperimentalMetadata.SubjectState)
sessions = visits(subject_events[subject_events.id.str.startswith("BAA")])

In [4]:
# Prettify sessions

pd.options.mode.chained_assignment = None  # turn off "SettingWithCopy" warning for this cell

sessions = sessions[sessions.enter.dt.date == date(2023, 3, 10)]
sessions.loc[:, ("weight_enter")] = sessions["weight_enter"].astype(float).round(1)
sessions.loc[:, ("weight_exit")] = sessions["weight_exit"].astype(float).round(1)
sessions.loc[:, ("enter")] = sessions["enter"].dt.floor("1s")
sessions.loc[:, ("exit")] = sessions["exit"].dt.ceil("1s")
sessions.loc[:, ("duration")] = sessions["duration"].round("1s")
sessions = sessions[["id", "enter", "exit", "duration", "weight_enter", "weight_exit"]]
sessions = sessions.sort_values(by="enter")
sessions = sessions.reset_index()
sessions = sessions.drop(columns=["index"])
pd.options.mode.chained_assignment = "warn"
display(sessions)

Unnamed: 0,id,enter,exit,duration,weight_enter,weight_exit
0,BAA-1103050,2023-03-10 09:41:48,2023-03-10 12:55:19,0 days 03:13:30,23.2,23.9
1,BAA-1103045,2023-03-10 12:12:45,2023-03-10 15:22:14,0 days 03:09:28,23.0,23.7
2,BAA-1103048,2023-03-10 13:08:24,2023-03-10 16:16:31,0 days 03:08:06,22.5,24.6
3,BAA-1103047,2023-03-10 15:27:05,2023-03-10 19:10:44,0 days 03:43:38,19.8,21.0
4,BAA-1103049,2023-03-10 16:22:29,2023-03-10 19:21:50,0 days 02:59:20,20.9,22.6


In [5]:
# Get bad sessions based on Get 'DispenserBroken' and 'Annotation' messages
message_log_aeon3 = api.load(str(roots[0]), exp02.ExperimentalMetadata.MessageLog)
print(f"Aeon3 messages:\n")
display(message_log_aeon3[np.logical_or(message_log_aeon3.type == "DispenserBroken", message_log_aeon3.type == "Annotation")])
print(f"\n\n")
message_log_aeon2 = api.load(str(roots[0]), exp02.ExperimentalMetadata.MessageLog)
print(f"Aeon2 messages:\n")
display(message_log_aeon2[np.logical_or(message_log_aeon2.type == "DispenserBroken", message_log_aeon2.type == "Annotation")])

Aeon3 messages:



Unnamed: 0_level_0,priority,type,message
time,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
2023-02-27 12:37:32.836480141,Alert,DispenserBroken,Patch2
2023-02-27 12:38:31.825503826,Alert,DispenserBroken,Patch2
2023-02-27 12:43:57.997504234,Alert,DispenserBroken,Patch2
2023-02-27 12:51:15.807487965,Alert,DispenserBroken,Patch2
2023-02-27 12:55:49.355487823,Alert,DispenserBroken,Patch1
2023-02-27 13:01:22.381504059,Alert,DispenserBroken,Patch2
2023-02-27 13:21:45.391488075,Alert,DispenserBroken,Patch1
2023-02-27 13:29:47.673503876,Alert,DispenserBroken,Patch2
2023-02-27 13:51:05.439487934,Alert,DispenserBroken,Patch2
2023-02-27 13:51:28.847487926,Alert,DispenserBroken,Patch1





Aeon2 messages:



Unnamed: 0_level_0,priority,type,message
time,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
2023-02-27 12:37:32.836480141,Alert,DispenserBroken,Patch2
2023-02-27 12:38:31.825503826,Alert,DispenserBroken,Patch2
2023-02-27 12:43:57.997504234,Alert,DispenserBroken,Patch2
2023-02-27 12:51:15.807487965,Alert,DispenserBroken,Patch2
2023-02-27 12:55:49.355487823,Alert,DispenserBroken,Patch1
2023-02-27 13:01:22.381504059,Alert,DispenserBroken,Patch2
2023-02-27 13:21:45.391488075,Alert,DispenserBroken,Patch1
2023-02-27 13:29:47.673503876,Alert,DispenserBroken,Patch2
2023-02-27 13:51:05.439487934,Alert,DispenserBroken,Patch2
2023-02-27 13:51:28.847487926,Alert,DispenserBroken,Patch1


In [6]:
# Based on above, manually decide which are bad sessions, and drop these from `sessions`

bad_sessions = DotMap()
bad_sessions.ids = ("BAA-1103048",)
bad_sessions.dates = (date(2023, 3, 15),)

for i in range(len(bad_sessions.ids)):
    i_bad_sesh = np.where(np.logical_and(
        sessions.id == bad_sessions.ids[i], sessions.enter.dt.date == bad_sessions.dates[i]))[0]
    sessions.drop(index=sessions.iloc[i_bad_sesh].index, inplace=True)

In [7]:
# Declare some set-up variables to help with analysis

# Specify which animals in which room
in_b2_210 = ("48", "49", "50")
in_465 = ("45", "47")

# Columns to add to table
new_cols = (
    "post_thresh_dur", #"post_thresh_both_p_sampled_dur",
    "pre_sampling_both_p_dur", "easy_patch", "hard_patch", 
    "post_easy_rate", "post_hard_rate", "pre_easy_n_pel", "pre_hard_n_pel", 
    "post_easy_n_pel", "post_hard_n_pel", "pre_easy_wheel_dist", "pre_hard_wheel_dist",
    "post_easy_wheel_dist", "post_hard_wheel_dist", "pre_easy_pref", "post_easy_pref",
    "post_pre_easy_pref", "post_easy_pel_thresh", "post_hard_pel_thresh", "init_pref_by_pel_ct",
)
for col in new_cols:
    sessions[col] = np.nan
sessions["post_easy_pel_thresh"] = sessions["post_easy_pel_thresh"].astype(object)
sessions["post_hard_pel_thresh"] = sessions["post_hard_pel_thresh"].astype(object)
sessions["init_pref_by_pel_ct"] = sessions["init_pref_by_pel_ct"].astype(object)
display(sessions)

Unnamed: 0,id,enter,exit,duration,weight_enter,weight_exit,post_thresh_dur,pre_sampling_both_p_dur,easy_patch,hard_patch,post_easy_rate,post_hard_rate,pre_easy_n_pel,pre_hard_n_pel,post_easy_n_pel,post_hard_n_pel,pre_easy_wheel_dist,pre_hard_wheel_dist,post_easy_wheel_dist,post_hard_wheel_dist,pre_easy_pref,post_easy_pref,post_pre_easy_pref,post_easy_pel_thresh,post_hard_pel_thresh,init_pref_by_pel_ct
0,BAA-1103050,2023-03-10 09:41:48,2023-03-10 12:55:19,0 days 03:13:30,23.2,23.9,,,,,,,,,,,,,,,,,,,,
1,BAA-1103045,2023-03-10 12:12:45,2023-03-10 15:22:14,0 days 03:09:28,23.0,23.7,,,,,,,,,,,,,,,,,,,,
2,BAA-1103048,2023-03-10 13:08:24,2023-03-10 16:16:31,0 days 03:08:06,22.5,24.6,,,,,,,,,,,,,,,,,,,,
3,BAA-1103047,2023-03-10 15:27:05,2023-03-10 19:10:44,0 days 03:43:38,19.8,21.0,,,,,,,,,,,,,,,,,,,,
4,BAA-1103049,2023-03-10 16:22:29,2023-03-10 19:21:50,0 days 02:59:20,20.9,22.6,,,,,,,,,,,,,,,,,,,,


In [8]:
for s in sessions.itertuples():
    root = str(roots[0]) if np.any([s.id.endswith(sid) for sid in in_b2_210]) else str(roots[1])  # get root for current session
    harp_reader = reader.Harp(pattern="Patch1_35", columns=["TriggerPellet"])
    new_pellet_trig_bitmask = api.load(root, harp_reader, start=s.enter, end=s.exit).iloc[0, 0]
    new_pellet_trig_reader_p1 = reader.BitmaskEvent("Patch1_35", new_pellet_trig_bitmask, "TriggerPellet")
    new_pellet_trig_reader_p2 = reader.BitmaskEvent("Patch2_35", new_pellet_trig_bitmask, "TriggerPellet")
    p1 = api.load(root, new_pellet_trig_reader_p1, start=s.enter, end=s.exit)
    p2 = api.load(root, new_pellet_trig_reader_p2, start=s.enter, end=s.exit)
    pstate1 = api.load(root, exp02.Patch1.DepletionState, start=s.enter, end=s.exit)
    pstate2 = api.load(root, exp02.Patch2.DepletionState, start=s.enter, end=s.exit)
    encoder1 = api.load(root, exp02.Patch1.Encoder, start=s.enter, end=s.exit)
    w1 = -distancetravelled(encoder1.angle)
    encoder2 = api.load(root, exp02.Patch2.Encoder, start=s.enter, end=s.exit)
    w2 = -distancetravelled(encoder2.angle)
    # PelletTrig cleaning: remove repeated deliveries (events <1.5 s apart) and manual deliveries (201)
    p1 = p1.drop(p1.index[np.where(np.diff(p1.index).astype("float64") < 1.5e9)[0]])
    p2 = p2.drop(p2.index[np.where(np.diff(p2.index).astype("float64") < 1.5e9)[0]])
    harp_reader = reader.Harp(pattern="Patch1_201", columns=["ExperimenterDeliveries"])
    user_p1 = api.load(root, harp_reader, start=s.enter, end=s.exit)
    harp_reader = reader.Harp(pattern="Patch2_201", columns=["ExperimenterDeliveries"])
    user_p2 = api.load(root, harp_reader, start=s.enter, end=s.exit)
    if not user_p1.empty:
        user_p1_idxs = np.abs(np.subtract.outer(user_p1.index, p1.index)).argmin(axis=1)
        p1.drop(p1.index[user_p1_idxs])
    if not user_p2.empty:
        user_p2_idxs = np.abs(np.subtract.outer(user_p2.index, p2.index)).argmin(axis=1)
        p2.drop(p2.index[user_p2_idxs])
    both_pellet_data = pd.concat([p1, p2]).sort_index()
    # PatchState cleaning: remove NaNs; remove updates <1.5s apart (bug updates)
    pstate1.dropna(inplace=True)
    good_indxs = np.concatenate((np.diff(pstate1.index).astype("float64") > 1.5e9, [True]))
    pstate1 = pstate1[good_indxs]
    pstate2.dropna(inplace=True)
    good_indxs = np.concatenate((np.diff(pstate2.index).astype("float64") > 1.5e9, [True]))
    pstate2 = pstate2[good_indxs]
    # Clean known issues in particular sessions
    if s.enter == pd.Timestamp('2023-03-10 13:08:24'):
        p2 = p2.drop(p2.index[-1])
    # Check lengths of PelletTrigger and PatchState events
    if ((len(pstate1) - len(p1)) not in (1, 2)) or ((len(pstate2) - len(p2)) not in (1, 2)):
        raise Exception(
            f"PelletTrigger-PatchState mismatch: \n"
            f"len(p1) = {len(p1)} \n"
            f"len(p2) = {len(p2)} \n"
            f"len(pstate1) = {len(pstate1)} \n"
            f"len(pstate2) = {len(pstate2)} \n"
        )
    both_state_data = pd.concat([pstate1, pstate2]).sort_index()
    # Find threshold-change ts
    safe_change_ts = change_ts = both_state_data[both_state_data.threshold > 76].index[0]
    sessions.loc[s.Index, "post_thresh_dur"] = post_thresh_dur = (s.exit - change_ts).round("1s")
    # if (len(p2[p2.index > change_ts]) > 0) and (len(p1[p1.index > change_ts]) > 0):
    #     safe_change_ts = pd.Series((p1[p1.index > change_ts].index[0], p2[p2.index > change_ts].index[0])).max()
    #     sessions.loc[s.Index, "post_thresh_both_p_sampled_dur"] = post_thresh_both_p_sampled_dur = (s.exit - safe_change_ts).round("1s")
    # else:
    #     safe_change_ts = change_ts
    # Find both-patches-sampled ts
    both_patches_sampled_ts = pd.Series((p1.index[0], p2.index[0])).max()
    sessions.loc[s.Index, "pre_sampling_both_p_dur"] = pre_sampling_b_patches_dur = (both_patches_sampled_ts - s.enter).round("1s")
    sessions.loc[s.Index, "hard_patch"] = hard_patch = 1 if (pstate1["delta"][-1] < pstate2["delta"][-1]) else 2
    sessions.loc[s.Index, "easy_patch"] = easy_patch = 1 if (hard_patch == 2) else 2
    sessions.loc[s.Index, "post_hard_rate"] = post_hard_rate = pstate1["delta"][-1] if (hard_patch == 1) else pstate2["delta"][-1]
    sessions.loc[s.Index, "post_easy_rate"] = post_easy_rate = pstate1["delta"][-1] if (hard_patch == 2) else pstate2["delta"][-1]
    whard = w1 if (hard_patch == 1) else w2
    weasy = w1 if (easy_patch == 1) else w2
    p1_pre_n_pel = len(p1[p1.index <= (safe_change_ts + pd.Timedelta("1s"))])  # ensure we don't count last pellet in pre as first pellet in post
    p1_post_n_pel = len(p1[p1.index > (safe_change_ts + pd.Timedelta("1s"))])
    p2_pre_n_pel = len(p2[p2.index <= (safe_change_ts + pd.Timedelta("1s"))])
    p2_post_n_pel = len(p2[p2.index > (safe_change_ts + pd.Timedelta("1s"))])
    sessions.loc[s.Index, "pre_easy_n_pel"] = pre_easy_n_pel = p1_pre_n_pel if (easy_patch == 1) else p2_pre_n_pel
    sessions.loc[s.Index, "pre_hard_n_pel"] = pre_hard_n_pel = p1_pre_n_pel if (hard_patch == 1) else p2_pre_n_pel
    p1_pre_wheel_dist = w1[w1.index > safe_change_ts][0] - w1[0]
    p2_pre_wheel_dist = w2[w2.index > safe_change_ts][0] - w2[0]
    sessions.loc[s.Index, "pre_easy_wheel_dist"] = pre_easy_wheel_dist = p1_pre_wheel_dist if (easy_patch == 1) else p2_pre_wheel_dist
    sessions.loc[s.Index, "pre_hard_wheel_dist"] = pre_hard_wheel_dist = p1_pre_wheel_dist if (hard_patch == 1) else p2_pre_wheel_dist
    sessions.loc[s.Index, "post_easy_n_pel"] = post_easy_n_pel = p1_post_n_pel if (easy_patch == 1) else p2_post_n_pel
    sessions.loc[s.Index, "post_hard_n_pel"] = post_hard_n_pel = p1_post_n_pel if (hard_patch == 1) else p2_post_n_pel
    p1_post_wheel_dist = w1[-1] - p1_pre_wheel_dist
    p2_post_wheel_dist = w2[-1] - p2_pre_wheel_dist
    sessions.loc[s.Index, "post_easy_wheel_dist"] = post_easy_wheel_dist = p1_post_wheel_dist if (easy_patch == 1) else p2_post_wheel_dist
    sessions.loc[s.Index, "post_hard_wheel_dist"] = post_hard_wheel_dist = p1_post_wheel_dist if (hard_patch == 1) else p2_post_wheel_dist
    sessions.loc[s.Index, "pre_easy_pref"] = pre_easy_pref = pre_easy_wheel_dist / (pre_easy_wheel_dist + pre_hard_wheel_dist)
    sessions.loc[s.Index, "post_easy_pref"] = post_easy_pref = post_easy_wheel_dist / (post_easy_wheel_dist + post_hard_wheel_dist)
    sessions.loc[s.Index, "post_hard_pref"] = post_hard_pref = 1 - post_easy_pref
    sessions.loc[s.Index, "post_pre_easy_pref"] = post_pre_easy_pref = post_easy_pref / pre_easy_pref
    # Find each pstate update prior to each pellet threshold crossing
    p1_post_pel_thresh = pstate1[pstate1.index >= safe_change_ts].threshold[:-1]
    #p1_post_pel_thresh = np.nan if p1_post_pel_thresh.empty else p1_post_pel_thresh
    p2_post_pel_thresh = pstate2[pstate2.index >= safe_change_ts].threshold[:-1]
    #p2_post_pel_thresh = np.nan if p2_post_pel_thresh.empty else p2_post_pel_thresh
    post_easy_pel_thresh = p1_post_pel_thresh.values if (easy_patch == 1) else p2_post_pel_thresh.values
    post_hard_pel_thresh = p1_post_pel_thresh.values if (hard_patch == 1) else p2_post_pel_thresh.values
    sessions.at[s.Index, "post_easy_pel_thresh"] = post_easy_pel_thresh 
    sessions.at[s.Index, "post_hard_pel_thresh"] = post_hard_pel_thresh
    whard = w1 if (hard_patch == 1) else w2
    weasy = w1 if (easy_patch == 1) else w2
    init_pref_by_pel_ct = np.ones((10,)) * np.nan
    for i, pel_ct in enumerate(range(8,18)):
        cur_pel_ct_ts = both_pellet_data.index[pel_ct]
        if cur_pel_ct_ts > (safe_change_ts + pd.Timedelta("1s")):
            break
        cur_whard_dist = whard[whard.index > cur_pel_ct_ts][0] - whard[0]
        cur_weasy_dist = weasy[weasy.index > cur_pel_ct_ts][0] - weasy[0] 
        init_pref_by_pel_ct[i] = cur_whard_dist / (cur_whard_dist + cur_weasy_dist)
    sessions.at[s.Index, "init_pref_by_pel_ct"] = init_pref_by_pel_ct

In [9]:
display(sessions)

Unnamed: 0,id,enter,exit,duration,weight_enter,weight_exit,post_thresh_dur,pre_sampling_both_p_dur,easy_patch,hard_patch,post_easy_rate,post_hard_rate,pre_easy_n_pel,pre_hard_n_pel,post_easy_n_pel,post_hard_n_pel,pre_easy_wheel_dist,pre_hard_wheel_dist,post_easy_wheel_dist,post_hard_wheel_dist,pre_easy_pref,post_easy_pref,post_pre_easy_pref,post_easy_pel_thresh,post_hard_pel_thresh,init_pref_by_pel_ct,post_hard_pref
0,BAA-1103050,2023-03-10 09:41:48,2023-03-10 12:55:19,0 days 03:13:30,23.2,23.9,0 days 02:22:53,0 days 00:17:40,2.0,1.0,0.01,0.0025,1.0,17.0,21.0,18.0,89.731081,1275.490836,3207.803306,8580.92044,0.065726,0.272108,4.140009,"[121.823136965621, 239.6461061848392, 214.0557...","[787.955404447864, 453.6910572464389, 131.5334...","[0.8815044790108325, 0.8833988643958598, 0.893...",0.727892
1,BAA-1103045,2023-03-10 12:12:45,2023-03-10 15:22:14,0 days 03:09:28,23.0,23.7,0 days 02:15:22,0 days 00:07:38,2.0,1.0,0.01,0.0025,9.0,9.0,24.0,14.0,637.378774,671.407613,3490.806286,7168.452099,0.487,0.327491,0.672465,"[129.1453623352006, 79.02768659469507, 191.036...","[1549.340270100893, 194.008876654322, 117.7250...","[0.40413871418162683, 0.3611177037993481, 0.32...",0.672509
2,BAA-1103048,2023-03-10 13:08:24,2023-03-10 16:16:31,0 days 03:08:06,22.5,24.6,0 days 01:28:08,0 days 01:39:59,2.0,1.0,0.01,0.0025,1.0,68.0,1.0,33.0,89.789376,5115.579789,271.89168,13906.200532,0.017249,0.019177,1.111744,[193.26450565881615],"[522.3701548297602, 274.4432472474792, 115.277...","[0.9945817950119978, 0.9947954694754996, 0.995...",0.980823
3,BAA-1103047,2023-03-10 15:27:05,2023-03-10 19:10:44,0 days 03:43:38,19.8,21.0,0 days 03:17:44,0 days 00:06:35,1.0,2.0,0.01,0.0025,7.0,11.0,0.0,51.0,526.377751,825.28755,20.805117,24885.276616,0.389429,0.000835,0.002145,[],"[137.33419585742638, 514.5993127234236, 162.64...","[0.22612623407040333, 0.2995549748907851, 0.36...",0.999165
4,BAA-1103049,2023-03-10 16:22:29,2023-03-10 19:21:50,0 days 02:59:20,20.9,22.6,0 days 02:34:42,0 days 00:03:20,1.0,2.0,0.01,0.0025,8.0,10.0,67.0,1.0,600.211219,756.873967,11424.822884,2429.829679,0.44228,0.82462,1.864476,"[93.1972822496444, 114.71862034176128, 123.426...",[1578.32854895351],"[0.6702724079895558, 0.6987034819223695, 0.722...",0.17538


In [10]:
sessions.to_pickle(Path("/nfs/nhome/live/jbhagat/ProjectAeon/aeon_analysis/aeon_analysis/presocial/data/presocial_data.pkl"))