In [1]:
%load_ext autoreload
%autoreload 2
# %flow mode reactive

from importlib import reload
from pathlib import Path

from dotmap import DotMap
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objs as go
import seaborn as sns

import aeon
from aeon.analysis.utils import visits, distancetravelled
from aeon.io import reader
from aeon.io.device import Device, register
from aeon.schema import core, foraging, social
from aeon.schema.schemas import exp02, social01, social02
import datajoint as dj

In [2]:
# Prettify pandas output display.

pd.set_option("display.max_columns", 20)
pd.set_option("display.max_rows", 50)

## Create VirtualModule to access `aeon_test_analysis` schema
Currently, the analysis is on `aeon_test_`, will move to `aeon_` soon (once ready for production)

In [None]:
analysis_vm = dj.create_virtual_module('aeon_test_analysis', 'aeon_test_analysis')

## Browse Block and BlockAnalysis

In [None]:
# View ERD around analysis

dj.Diagram(analysis_vm) - 1

In [None]:
# View Block table

analysis_vm.Block()

In [None]:
# Fetch BlockAnalysis table

blocks = analysis_vm.BlockAnalysis().fetch(format="frame")
blocks

In [None]:
# Spec block(s) of interest

# block_key = {
#     "experiment_name": "social0.1-aeon3",
#     "block_start": "2023-12-05 11:06:01.001984",
# } 
block_start = "block_start LIKE '2023-12-03%'"
experiment_name = "experiment_name LIKE 'social0.1-aeon3'"

In [None]:
# Given above keys, restrict BlockAnalysis table, and select info of first of returned blocks

block_df = (analysis_vm.BlockAnalysis & block_start & experiment_name).fetch(format="frame")
display(block_df)
block_key = {
    "experiment_name": block_df.index[0][0],
    "block_start": block_df.index[0][1]
}
print(block_key)

In [None]:
# Return BlockAnalysis' part tables

analysis_vm.BlockAnalysis.parts()

In [None]:
# For each part table, display and fetch info for the block of interest

display(analysis_vm.BlockAnalysis.Patch & block_key)
display(analysis_vm.BlockAnalysis.Subject & block_key)
block_patch_data = (analysis_vm.BlockAnalysis.Patch & block_key).fetch(as_dict=True)
block_subject_data = (analysis_vm.BlockAnalysis.Subject & block_key).fetch(as_dict=True)


## Corral some data

In [18]:
"""Create a blocks df."""

exp_start = pd.Timestamp("2024-01-31 00:00:00")
exp_end = pd.Timestamp("2024-02-17 00:00:00")
# roots = [
#     Path("/ceph/aeon/aeon/data/raw/AEON3/social0.2"),
#     Path("/ceph/aeon/aeon/data/raw/AEON4/social0.2"),
# ]
roots = [
    Path(r"Z:\aeon\data\raw\AEON3\social0.2"),
    Path(r"Z:\aeon\data\raw\AEON4\social0.2"),
]
social02.CameraTop.Pose._model_root = Path(r"Z:\aeon\data\processed")
arenas = ["AEON3", "AEON4"]
patches = ["Patch1", "Patch2", "Patch3"]
patch_locs = pd.DataFrame(index=arenas, columns=patches)
blocks_df = pd.DataFrame()
block_ts_tol = pd.Timedelta("2s")  # Tolerance for block start and end times
good_block_pel_ct = 4  # Min pellets for good block


for root in roots:
    # Pull out info from metadata
    for arena in arenas:
        if arena in str(root):
            break
    metadata = aeon.load(root, social02.Metadata, exp_start, exp_end).iloc[0].metadata
    patch_locs.loc[arena, patches] = (
        (metadata.Devices.Patch1Rfid.Location.X, metadata.Devices.Patch1Rfid.Location.Y),
        (metadata.Devices.Patch2Rfid.Location.X, metadata.Devices.Patch2Rfid.Location.Y),
        (metadata.Devices.Patch3Rfid.Location.X, metadata.Devices.Patch3Rfid.Location.Y)
    )
    block_info = aeon.load(root, social02.Environment.BlockState, exp_start, exp_end)

    # Block end if pellet_ct == 0 and preceding pellet_ct > 0
    # OR pellet_ct == 0, preceding pellet_ct == 0, and preceding pellet_ct time diff > 1s
    possible_block_end_indxs = np.where(block_info.pellet_ct == 0)[0]
    drop_indxs = []

    for i in possible_block_end_indxs:
        if block_info.pellet_ct[i - 1] > 0:
            continue
        elif block_info.pellet_ct[i - 1] == 0 and block_info.index[i] - block_info.index[i - 1] > pd.Timedelta("1s"):
            continue
        else:  # drop i from `possible_block_ends`
            drop_indxs.append(i)

    block_end_indxs = np.setdiff1d(possible_block_end_indxs, drop_indxs)
    # Start from first complete block to last complete block
    block_start_indxs = block_end_indxs[0:-1] 
    block_end_indxs = block_end_indxs[1:]
    block_start_times = block_info.index[block_start_indxs] - block_ts_tol
    block_end_times = block_info.index[block_end_indxs] + block_ts_tol
    sleap_model_dir = (
        Path(r"Z:/aeon/data/processed/test-node1/4310907/2024-01-12T19-00-00/topdown-multianimal-id-133")
        if "AEON3" in str(root) else 
        Path(r"Z:/aeon/data/processed/test-node1/4350621/2024-01-22T19-00-00/topdown-multianimal-id-133")
    )

    # Create a `blocks` df with columns 'start', 'end', and 'root'
    blocks_df = pd.concat(
        [
         blocks_df,
         pd.DataFrame(
             {   
                 "root": [root] * len(block_start_times),
                 "sleap_model_dir": [sleap_model_dir] * len(block_start_times),
                 "start": block_start_times, 
                 "end": block_end_times, 
             }
        )
       ], ignore_index=True
    )

# Add columns to `blocks_df`
new_cols = [
    "block_duration",
    "subjects",  # list of subjects in block
    "patch_info",  # df: index: patch; cols: rate, offset
    "pellet_info",  # df: index: del ts; cols: patch, thresh, id {for each pel del, get last thresh}
    "cum_wheel_dist",  # DotMap: patch: df
]
for col in new_cols:
    blocks_df[col] = None

blocks_df = blocks_df.sort_values(by="start").reset_index(drop=True)
display(blocks_df)

  if block_info.pellet_ct[i - 1] > 0:
  elif block_info.pellet_ct[i - 1] == 0 and block_info.index[i] - block_info.index[i - 1] > pd.Timedelta("1s"):
  data = pd.concat([reader.read(file) for _, file in files])
  data = pd.concat([reader.read(file) for _, file in files])
  if block_info.pellet_ct[i - 1] > 0:
  elif block_info.pellet_ct[i - 1] == 0 and block_info.index[i] - block_info.index[i - 1] > pd.Timedelta("1s"):


Unnamed: 0,root,sleap_model_dir,start,end,block_duration,subjects,patch_info,pellet_info,cum_wheel_dist
0,Z:\aeon\data\raw\AEON4\social0.2,Z:\aeon\data\processed\test-node1\4350621\2024...,2024-01-31 12:59:06.005983829,2024-01-31 14:58:11.045983791,,,,,
1,Z:\aeon\data\raw\AEON3\social0.2,Z:\aeon\data\processed\test-node1\4310907\2024...,2024-01-31 12:59:14.001984119,2024-01-31 14:45:59.000000000,,,,,
2,Z:\aeon\data\raw\AEON3\social0.2,Z:\aeon\data\processed\test-node1\4310907\2024...,2024-01-31 14:45:55.000000000,2024-01-31 16:18:11.001984119,,,,,
3,Z:\aeon\data\raw\AEON4\social0.2,Z:\aeon\data\processed\test-node1\4350621\2024...,2024-01-31 14:58:07.045983791,2024-01-31 17:49:26.000000000,,,,,
4,Z:\aeon\data\raw\AEON3\social0.2,Z:\aeon\data\processed\test-node1\4310907\2024...,2024-01-31 16:18:07.001984119,2024-01-31 17:56:23.000000000,,,,,
...,...,...,...,...,...,...,...,...,...
269,Z:\aeon\data\raw\AEON4\social0.2,Z:\aeon\data\processed\test-node1\4350621\2024...,2024-02-13 11:46:05.000000000,2024-02-13 13:02:51.049983978,,,,,
270,Z:\aeon\data\raw\AEON3\social0.2,Z:\aeon\data\processed\test-node1\4310907\2024...,2024-02-13 12:40:27.000000000,2024-02-13 15:17:20.001984119,,,,,
271,Z:\aeon\data\raw\AEON4\social0.2,Z:\aeon\data\processed\test-node1\4350621\2024...,2024-02-13 13:02:47.049983978,2024-02-13 15:32:22.001984119,,,,,
272,Z:\aeon\data\raw\AEON3\social0.2,Z:\aeon\data\processed\test-node1\4310907\2024...,2024-02-13 15:17:16.001984119,2024-02-13 18:03:15.001984119,,,,,


In [12]:
"""Get subject env visits."""

subject_env_visits = {}

for root, arena in zip(roots, arenas):
    subject_visits = aeon.load(root, social02.Environment.SubjectVisits, exp_start, exp_end)
    # Find all rows where:
    #  - 'id' column starts with "*AA"
    #  - 'type' column is either "Enter" or "Exit" or "Remain",
    #  - 'region' column is "Environment"
    subject_visits = subject_visits[
        (subject_visits.id.str.contains("^.*AA"))
        & (subject_visits.type.isin(["Enter", "Exit", "Remain"]))
        & (subject_visits.region == "Environment")
    ]
    subject_env_visits[arena] = subject_visits

  data = pd.concat([reader.read(file) for _, file in files])


In [13]:
subject_env_visits

{'AEON3':                                         id    type       region
 time                                                           
 2024-01-31 11:28:45.543519974  BAA-1104045  Remain  Environment
 2024-02-01 22:36:53.196512222  BAA-1104045  Remain  Environment
 2024-02-02 00:15:06.000000000  BAA-1104045  Remain  Environment
 2024-02-03 16:28:29.139999866  BAA-1104045    Exit  Environment
 2024-02-05 15:43:11.581535816  BAA-1104047  Remain  Environment
 2024-02-08 14:49:41.552000046  BAA-1104047    Exit  Environment
 2024-02-09 16:25:49.935999870  BAA-1104045   Enter  Environment
 2024-02-09 16:26:07.579999924  BAA-1104047   Enter  Environment,
 'AEON4':                                         id    type       region
 time                                                           
 2024-01-31 10:22:40.191999912  BAA-1104048   Enter  Environment
 2024-02-01 20:46:53.905536175  BAA-1104048  Remain  Environment
 2024-02-01 23:34:58.098527908  BAA-1104048  Remain  Environment
 2024-

In [None]:
"""Fill out blocks df."""

for i, block in enumerate(blocks_df.itertuples()):
    # Compute block duration
    blocks_df.at[block.Index, "block_duration"] = block.end - block.start
    # <s Get subjects within the block:
    # Get all unique subjects that visited the environment over the entire exp;
    # For each subject, see 'type' of visit most recent to start of block;
    # If "Exit", this animal was not in the block.
    for arena in arenas:
        if arena in str(block.root):
            cur_env_visits = subject_env_visits[arena]
            break
    possible_subjects = cur_env_visits.id.unique().tolist()
    subjects = possible_subjects.copy()
    for subject in possible_subjects:
        subj_visit_in_time = np.logical_and(
            cur_env_visits.id == subject, cur_env_visits.index < block.start
        )
        if not np.any(subj_visit_in_time):  # if no subject visits prior to block start, drop it
            subjects.remove(subject)
        else:  # if visits, get most recent visit type before block; if "Exit", drop it
            pre_block_visit = cur_env_visits[subj_visit_in_time].iloc[-1]
            # last_visit = cur_env_visits[cur_env_visits.id == subject].iloc[-1]
            if pre_block_visit.type == "Exit":
                subjects.remove(subject)
    blocks_df.at[block.Index, "subjects"] = subjects
    # /s>
    # <s See if we should continue with analyzing this block
    cum_pel_ct = 0
    for patch in patches:
        r = eval(f"social02.{patch}.DepletionState")
        patch_df = aeon.load(block.root, r, block.start, block.end)
        cum_pel_ct += sum(np.diff(patch_df.index) > pd.Timedelta("1s"))
    if cum_pel_ct < good_block_pel_ct:
        continue
    # /s>
    # <s Get pose-tracking info in order to do subject-specific assignments
    pose_df = aeon.load(block.root, social02.CameraTop.Pose, block.start, block.end)
    pose_df = reader.Pose.class_int2str(pose_df, block.sleap_model_dir)
    if len(subjects) == 1:  # fix mistaken sleap assignments for single-subject blocks
        pose_df["class"] = subjects[0]
    # /s>
    # <s Get per patch data (fill in `patch_info`, `cum_wheel_dist`, `pellet_info` cols of `blocks_df`)
    patch_stats_df = pd.DataFrame(index=patches, columns=["mean", "offset"])  # -> patch_info
    cum_wheel_dist_dm = DotMap()  # -> cum_wheel_dist
    pellets_stats_df = pd.DataFrame(columns=["time", "patch", "threshold", "id"])  # -> pellet_info
    for i, patch in enumerate(patches):
        # <ss Get wheel data
        r = eval(f"social02.{patch}.Encoder")
        wheel_df = aeon.load(block.root, r, block.start, block.end)[::50]
        cum_wheel_dist = -distancetravelled(wheel_df.angle)
        # /ss>
        # <ss Get pellets data
        r = eval(f"social02.{patch}.DepletionState")
        patch_df = aeon.load(block.root, r, block.start, block.end)
        rate, offset = patch_df[["rate", "offset"]].iloc[0]
        patch_stats_df.loc[patch, ["mean", "offset"]] = (1 / rate, offset)
        patch_df_good_indxs = np.concatenate((np.diff(patch_df.index) > pd.Timedelta("1s"), (True,)))
        patch_df_for_pellets_df = patch_df[patch_df_good_indxs].reset_index()[["time", "threshold"]]
        patch_df_for_pellets_df["patch"] = patch
        patch_df_for_pellets_df["id"] = None
        patch_df_for_pellets_df.dropna(subset=["threshold"], inplace=True)
        # drop 1st val as is from block start
        patch_df_for_pellets_df = patch_df_for_pellets_df.iloc[1:].reset_index(drop=True)
        # /ss>
        # <ss Assign data to subjects
        if len(subjects) == 1:
            cum_wheel_dist_dm[patch] = cum_wheel_dist.to_frame(name=subjects[0])
            patch_df_for_pellets_df["id"] = subjects[0]
        else:
            # <sss Assign id based on which subject was closest to patch at time of event
            # <ssss Get distance-to-patch at each pose data timestep
            patch_xy = np.array(patch_locs[patch][arena]).astype(np.uint32)
            subjects_xy = pose_df[pose_df["part"] == "centroid"][["x", "y"]].values
            dist_to_patch = np.sqrt(np.sum((subjects_xy - patch_xy) ** 2, axis=1))
            dist_to_patch_df = pose_df[["class"]].copy()
            dist_to_patch_df["dist_to_patch"] = dist_to_patch
            # /ssss>
            # <ssss Get distance-to-patch at each wheel ts and pel del ts, organized by subject
            dist_to_patch_wheel_ts_id_df = pd.DataFrame(index=cum_wheel_dist.index, columns=subjects)
            dist_to_patch_pel_ts_id_df = pd.DataFrame(
                index=patch_df_for_pellets_df["time"], columns=subjects
            )
            for subject in subjects:
                # Find closest match between pose_df indices and wheel indices
                dist_to_patch_wheel_ts_subj = pd.merge_asof(
                    left=dist_to_patch_wheel_ts_id_df,
                    right=dist_to_patch_df,
                    left_index=True,
                    right_index=True,
                    direction="forward",
                    tolerance=pd.Timedelta("200ms"),
                )
                dist_to_patch_wheel_ts_id_df[subject] = dist_to_patch_pel_ts_subj["dist_to_patch"]
                # Find closest match between pose_df indices and pel indices
                dist_to_patch_pel_ts_subj = pd.merge_asof(
                    left=dist_to_patch_pel_ts_id_df,
                    right=dist_to_patch_df,
                    left_index=True,
                    right_index=True,
                    direction="forward",
                    tolerance=pd.Timedelta("200ms"),
                )
                dist_to_patch_pel_ts_id_df[subject] = dist_to_patch_pel_ts_subj["dist_to_patch"]
            # /ssss>
            # <ssss Get closest subject to patch at each pel del timestep
            patch_df_for_pellets_df["id"] = dist_to_patch_pel_ts_id_df.idxmin(axis=1).values
            # /ssss>
            # <ssss Get closest subject to patch at each wheel timestep
            cum_wheel_dist_subj_df = pd.DataFrame(index=cum_wheel_dist.index, columns=subjects, data=0.)
            closest_subjects = dist_to_patch_wheel_ts_id_df.idxmin(axis=1)
            wheel_dist = cum_wheel_dist.diff().fillna(cum_wheel_dist.iloc[0])
            # Assign wheel dist to closest subject for each wheel timestep
            for subject in subjects:
                subj_idxs = cum_wheel_dist_subj_df[closest_subjects == subject].index
                cum_wheel_dist_subj_df.loc[subj_idxs, subject] = wheel_dist[closest_subjects == subject]
            cum_wheel_dist_dm[patch] = cum_wheel_dist_subj_df.cumsum(axis=0)
            # /ssss> #/sss> #/ss>
        pellets_stats_df = pd.concat([pellets_stats_df, patch_df_for_pellets_df], ignore_index=True)
        # /s>
    blocks_df.at[block.Index, "patch_info"] = patch_stats_df
    blocks_df.at[block.Index, "pellet_info"] = pellets_stats_df
    blocks_df.at[block.Index, "cum_wheel_dist"] = cum_wheel_dist_dm

display(blocks_df)

In [63]:
block = list(blocks_df.itertuples())[4]

In [82]:
    blocks_df.at[block.Index, "block_duration"] = block.end - block.start
    # /s>
    # <s Get subjects within the block:
    # Get all unique subjects that visited the environment over the entire exp;
    # For each subject, see 'type' of visit most recent to start of block;
    # If "Exit", this animal was not in the block.
    for arena in arenas:
        if arena in str(block.root):
            cur_env_visits = subject_env_visits[arena]
            break
    possible_subjects = cur_env_visits.id.unique().tolist()
    subjects = possible_subjects.copy()
    for subject in possible_subjects:
        subj_visit_in_time = np.logical_and(
            cur_env_visits.id == subject, cur_env_visits.index < block.start
        )
        if not np.any(subj_visit_in_time):  # if no subject visits prior to block start, drop it
            subjects.remove(subject)
        else:  # if visits, get most recent visit type before block; if "Exit", drop it
            pre_block_visit = cur_env_visits[subj_visit_in_time].iloc[-1]
            # last_visit = cur_env_visits[cur_env_visits.id == subject].iloc[-1]
            if pre_block_visit.type == "Exit":
                subjects.remove(subject)
    blocks_df.at[block.Index, "subjects"] = subjects
    # /s>
    # <s Get pose-tracking info
    pose_df = aeon.load(block.root, social02.CameraTop.Pose, block.start, block.end)
    pose_df = reader.Pose.class_int2str(pose_df, block.sleap_model_dir)
    pose_df_subjects = pose_df["class"].unique()
    # Fix mistaken sleap assignments for single-subject blocks
    if len(subjects) == 1:
        pose_df["class"] = subjects[0]
    # /s>
    # <s Get per patch data
    cum_wheel_dist_dm = DotMap()
    patch_stats_df = pd.DataFrame(index=patches, columns=["mean", "offset"])
    pellets_stats_df = pd.DataFrame(columns=["time", "patch", "threshold", "id"])
    for i, patch in enumerate(patches):
        # <ss Get wheel data
        r = eval(f"social02.{patch}.Encoder")
        wheel_df = aeon.load(block.root, r, block.start, block.end)[::50]
        cum_wheel_dist = -distancetravelled(wheel_df.angle)
        # /ss>
        # <ss Get pellets data
        r = eval(f"social02.{patch}.DepletionState")
        patch_df = aeon.load(block.root, r, block.start, block.end)
        rate, offset = patch_df[["rate", "offset"]].iloc[0]
        patch_stats_df.loc[patch, ["mean", "offset"]] = (1 / rate, offset)
        patch_df_good_indxs = np.concatenate((np.diff(patch_df.index) > pd.Timedelta("1s"), (True,)))
        patch_df_for_pellets_df = patch_df[patch_df_good_indxs].reset_index()[["time", "threshold"]]
        patch_df_for_pellets_df["patch"] = patch
        patch_df_for_pellets_df["id"] = None
        patch_df_for_pellets_df.dropna(subset=["threshold"], inplace=True)
        # drop 1st val as is from block start
        patch_df_for_pellets_df = patch_df_for_pellets_df.iloc[1:].reset_index(drop=True)
        # /ss>
        # <ss Assign data to subjects
        if len(subjects) == 1:  
            cum_wheel_dist_dm[patch] = cum_wheel_dist.to_frame(name=subjects[0])
            patch_df_for_pellets_df["id"] = subjects[0]
        else:
            # <ss Assign id based on which subject was closest to patch at time of delivery
            # <sss Get distance-to-patch at each pose data timestep
            patch_xy = np.array(patch_locs[patch][arena]).astype(np.uint32)
            subjects_xy = pose_df[pose_df["part"] == "centroid"][["x", "y"]].values
            dist_to_patch = np.sqrt(np.sum((subjects_xy - patch_xy) ** 2, axis=1))
            dist_to_patch_df = pose_df[["class"]].copy()
            dist_to_patch_df["dist_to_patch"] = dist_to_patch
            # /sss>
            # <sss Get distance-to-patch at each pel del ts and wheel ts, organized by subject
            dist_to_patch_wheel_ts_id_df = pd.DataFrame(index=cum_wheel_dist.index, columns=subjects)
            dist_to_patch_pel_ts_id_df = pd.DataFrame(index=patch_df_for_pellets_df["time"], columns=subjects)
            for subject in subjects:
                # Find closest match between pose_df indices and wheel / pel subj data indices
                dist_to_patch_wheel_ts_subj = pd.merge_asof(
                    left=dist_to_patch_wheel_ts_id_df, 
                    right=dist_to_patch_df,
                    left_index=True,
                    right_index=True, 
                    direction="forward", 
                    tolerance=pd.Timedelta("200ms")
                )
                dist_to_patch_wheel_ts_id_df[subject] = dist_to_patch_pel_ts_subj["dist_to_patch"]
                dist_to_patch_pel_ts_subj = pd.merge_asof(
                    left=dist_to_patch_pel_ts_id_df, 
                    right=dist_to_patch_df,
                    left_index=True,
                    right_index=True, 
                    direction="forward", 
                    tolerance=pd.Timedelta("200ms")
                )
                dist_to_patch_pel_ts_id_df[subject] = dist_to_patch_pel_ts_subj["dist_to_patch"]
            # /sss>
            # Get closest subject to patch at each pel del timestep
            patch_df_for_pellets_df["id"] = dist_to_patch_pel_ts_id_df.idxmin(axis=1).values
        pellets_stats_df = pd.concat([pellets_stats_df, patch_df_for_pellets_df], ignore_index=True)
        # /ss> /s>
    blocks_df.at[block.Index, "patch_info"] = patch_stats_df
    blocks_df.at[block.Index, "pellet_info"] = pellets_stats_df
    blocks_df.at[block.Index, "cum_wheel_dist"] = cum_wheel_dist_dm


  data.loc[data["class"] == i, "class"] = subj
  distance = distance - distance[0]
  pellets_stats_df = pd.concat([pellets_stats_df, patch_df_for_pellets_df], ignore_index=True)


KeyboardInterrupt: 

In [93]:
start, end = pd.Timestamp("2024-02-14 08:00:00"), pd.Timestamp("2024-02-14 09:00:00")
pose_df = aeon.load(block.root, social02.CameraTop.Pose, start, end)
pose_df = reader.Pose.class_int2str(pose_df, block.sleap_model_dir)
pose_df

  data.loc[data["class"] == i, "class"] = subj


Unnamed: 0_level_0,class,class_likelihood,part,x,y,part_likelihood
time,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
2024-02-14 08:00:00.019999981,BAA-1104045,0.987112,centroid,1259.260010,536.543274,0.538822
2024-02-14 08:00:00.019999981,BAA-1104047,0.587085,centroid,1269.920776,555.007690,0.538822
2024-02-14 08:00:00.079999924,BAA-1104045,0.997696,centroid,1256.660889,536.355957,0.524928
2024-02-14 08:00:00.079999924,BAA-1104045,0.525753,centroid,1272.174805,555.194458,0.524928
2024-02-14 08:00:00.139999866,BAA-1104045,0.993752,centroid,1256.811279,539.052246,0.495777
...,...,...,...,...,...,...
2024-02-14 08:59:59.840000153,BAA-1104045,0.764098,centroid,595.052551,368.538788,0.951113
2024-02-14 08:59:59.840000153,BAA-1104045,0.847785,centroid,1299.120605,498.892639,0.951113
2024-02-14 08:59:59.900000095,BAA-1104045,0.668274,centroid,595.155579,368.437073,0.983769
2024-02-14 08:59:59.900000095,BAA-1104047,0.853135,centroid,1299.102173,496.867737,0.983769


In [105]:
cum_wheel_dist_dm = DotMap()
patch_stats_df = pd.DataFrame(index=patches, columns=["mean", "offset"])
pellets_stats_df = pd.DataFrame(columns=["time", "patch", "threshold", "id"])

r = eval(f"social02.{patch}.Encoder")
wheel_df = aeon.load(block.root, r, start, end)[::50].round(1).astype(np.float32)
cum_wheel_dist = -distancetravelled(wheel_df.angle)
# /ss>
# <ss Get pellets data
r = eval(f"social02.{patch}.DepletionState")
patch_df = aeon.load(block.root, r, start, end)
rate, offset = patch_df[["rate", "offset"]].iloc[0]
patch_stats_df.loc[patch, ["mean", "offset"]] = (1 / rate, offset)
patch_df_good_indxs = np.concatenate((np.diff(patch_df.index) > pd.Timedelta("1s"), (True,)))
patch_df_for_pellets_df = patch_df[patch_df_good_indxs].reset_index()[["time", "threshold"]]
patch_df_for_pellets_df["patch"] = patch
patch_df_for_pellets_df["id"] = None
patch_df_for_pellets_df.dropna(subset=["threshold"], inplace=True)
# drop 1st val as is from block start
patch_df_for_pellets_df = patch_df_for_pellets_df.iloc[1:].reset_index(drop=True)
# /ss>
# <ss Assign data to subjects
if len(subjects) == 1:  
    cum_wheel_dist_dm[subjects[0]][patch] = cum_wheel_dist.values
    patch_df_for_pellets_df["id"] = subjects[0]
else:
    # <ss Assign id based on which subject was closest to patch at time of delivery
    # <sss Get distance-to-patch at each pose data timestep
    patch_xy = np.array(patch_locs[patch][arena]).astype(np.uint32)
    subjects_xy = pose_df[pose_df["part"] == "centroid"][["x", "y"]].values
    dist_to_patch = np.sqrt(np.sum((subjects_xy - patch_xy) ** 2, axis=1))
    dist_to_patch_df = pose_df[["class"]].copy()
    dist_to_patch_df["dist_to_patch"] = dist_to_patch
    # /sss>
    # <sss Get distance-to-patch at each pel del ts and wheel ts, organized by subject
    dist_to_patch_wheel_ts_id_df = pd.DataFrame(index=cum_wheel_dist.index, columns=subjects)
    dist_to_patch_pel_ts_id_df = pd.DataFrame(index=patch_df_for_pellets_df["time"], columns=subjects)


  distance = distance - distance[0]


In [110]:
for subject in subjects:
    # Find closest match between pose_df indices and wheel / pel subj data indices
    dist_to_patch_wheel_ts_subj = pd.merge_asof(
        left=dist_to_patch_wheel_ts_id_df, 
        right=dist_to_patch_df[dist_to_patch_df["class"] == subject],
        left_index=True,
        right_index=True, 
        direction="forward", 
        tolerance=pd.Timedelta("100ms")
    )
    dist_to_patch_wheel_ts_id_df[subject] = dist_to_patch_wheel_ts_subj["dist_to_patch"]
    dist_to_patch_pel_ts_subj = pd.merge_asof(
        left=dist_to_patch_pel_ts_id_df,
        right=dist_to_patch_df[dist_to_patch_df["class"] == subject],
        left_index=True,
        right_index=True,
        direction="forward",
        tolerance=pd.Timedelta("200ms"),
    )
    dist_to_patch_pel_ts_id_df[subject] = dist_to_patch_pel_ts_subj["dist_to_patch"]
# Get closest subject to patch at each wheel / pel del ts
patch_df_for_pellets_df["id"] = dist_to_patch_pel_ts_id_df.idxmin(axis=1).values


cum_wheel_dist_subj_df = pd.DataFrame(index=cum_wheel_dist.index, columns=subjects, data=0.)
closest_subjects = dist_to_patch_wheel_ts_id_df.idxmin(axis=1)
wheel_dist = cum_wheel_dist.diff().fillna(cum_wheel_dist.iloc[0])
for subject in subjects:
    subj_idxs = cum_wheel_dist_subj_df[closest_subjects == subject].index
    cum_wheel_dist_subj_df.loc[subj_idxs, subject] = wheel_dist[closest_subjects == subject]
cum_wheel_dist_dm[patch] = cum_wheel_dist_subj_df.cumsum(axis=0)

# Iterate back through subjects to see

In [163]:
# Convert series to df with column name "subject"
cum_wheel_dist.to_frame(name=subjects[0])

Unnamed: 0_level_0,BAA-1104045
time,Unnamed: 1_level_1
2024-02-14 08:00:00.000000000,-0.000000
2024-02-14 08:00:00.099999905,0.001534
2024-02-14 08:00:00.199999809,-0.006136
2024-02-14 08:00:00.300000191,-0.001534
2024-02-14 08:00:00.400000095,-0.003068
...,...
2024-02-14 08:59:59.500000000,4394.620038
2024-02-14 08:59:59.599999905,4394.618504
2024-02-14 08:59:59.699999809,4394.615436
2024-02-14 08:59:59.800000191,4394.615436


In [155]:
cum_wheel_dist_dm[subjects[0]][patch]

Unnamed: 0_level_0,BAA-1104045,BAA-1104047
time,Unnamed: 1_level_1,Unnamed: 2_level_1
2024-02-14 08:00:00.000000000,-0.000000,0.000000
2024-02-14 08:00:00.099999905,0.001534,0.000000
2024-02-14 08:00:00.199999809,-0.006136,0.000000
2024-02-14 08:00:00.300000191,-0.001534,0.000000
2024-02-14 08:00:00.400000095,-0.003068,0.000000
...,...,...
2024-02-14 08:59:59.500000000,1247.826872,3146.785496
2024-02-14 08:59:59.599999905,1247.825338,3146.785496
2024-02-14 08:59:59.699999809,1247.825338,3146.782427
2024-02-14 08:59:59.800000191,1247.825338,3146.782427


In [150]:
cum_wheel_dist_subj_df = pd.DataFrame(index=cum_wheel_dist.index, columns=subjects, data=0.0)
closest_subjects = dist_to_patch_wheel_ts_id_df.idxmin(axis=1)
wheel_dist = cum_wheel_dist.diff().fillna(cum_wheel_dist.iloc[0])
for subject in subjects:
    subj_idxs = cum_wheel_dist_subj_df[closest_subjects == subject].index
    cum_wheel_dist_subj_df.loc[subj_idxs, subject] = wheel_dist[closest_subjects == subject]
cum_wheel_dist_subj_df.cumsum(axis=1)

  closest_subjects = dist_to_patch_wheel_ts_id_df.idxmin(axis=1)


Unnamed: 0_level_0,BAA-1104045,BAA-1104047
time,Unnamed: 1_level_1,Unnamed: 2_level_1
2024-02-14 08:00:00.000000000,-0.000000,0.000000
2024-02-14 08:00:00.099999905,0.001534,0.001534
2024-02-14 08:00:00.199999809,-0.007670,-0.007670
2024-02-14 08:00:00.300000191,0.004602,0.004602
2024-02-14 08:00:00.400000095,-0.001534,-0.001534
...,...,...
2024-02-14 08:59:59.500000000,-0.001534,-0.001534
2024-02-14 08:59:59.599999905,-0.001534,-0.001534
2024-02-14 08:59:59.699999809,0.000000,-0.003068
2024-02-14 08:59:59.800000191,0.000000,0.000000


In [136]:
cum_wheel_dist_subj_df[subject][closest_subjects == subject]

AttributeError: 'DataFrame' object has no attribute 'subject'

In [144]:
subj_idxs = cum_wheel_dist_subj_df[closest_subjects == subject].index
cum_wheel_dist_subj_df.loc[subj_idxs, subject] = wheel_dist[closest_subjects == subject]

 -0.00306815]' has dtype incompatible with int64, please explicitly cast to a compatible dtype first.
  cum_wheel_dist_subj_df.loc[subj_idxs, subject] = wheel_dist[closest_subjects == subject]


In [148]:
cum_wheel_dist

time
2024-02-14 08:00:00.000000000      -0.000000
2024-02-14 08:00:00.099999905       0.001534
2024-02-14 08:00:00.199999809      -0.006136
2024-02-14 08:00:00.300000191      -0.001534
2024-02-14 08:00:00.400000095      -0.003068
                                    ...     
2024-02-14 08:59:59.500000000    4394.620038
2024-02-14 08:59:59.599999905    4394.618504
2024-02-14 08:59:59.699999809    4394.615436
2024-02-14 08:59:59.800000191    4394.615436
2024-02-14 08:59:59.900000095    4394.618504
Name: angle, Length: 36000, dtype: float64

In [146]:
wheel_dist[closest_subjects == subject].cumsum()

time
2024-02-14 08:00:03.699999809       0.001534
2024-02-14 08:00:03.800000191       0.001534
2024-02-14 08:00:03.900000095      -0.003068
2024-02-14 08:00:04.000000000       0.001534
2024-02-14 08:00:04.099999905       0.000000
                                    ...     
2024-02-14 08:59:58.699999809    3146.783961
2024-02-14 08:59:58.900000095    3146.787030
2024-02-14 08:59:59.000000000    3146.785496
2024-02-14 08:59:59.300000191    3146.785496
2024-02-14 08:59:59.699999809    3146.782427
Name: angle, Length: 21992, dtype: float64

In [145]:
cum_wheel_dist_subj_df

Unnamed: 0_level_0,BAA-1104045,BAA-1104047
time,Unnamed: 1_level_1,Unnamed: 2_level_1
2024-02-14 08:00:00.000000000,0,0.000000
2024-02-14 08:00:00.099999905,0,0.000000
2024-02-14 08:00:00.199999809,0,0.000000
2024-02-14 08:00:00.300000191,0,0.000000
2024-02-14 08:00:00.400000095,0,0.000000
...,...,...
2024-02-14 08:59:59.500000000,0,0.000000
2024-02-14 08:59:59.599999905,0,0.000000
2024-02-14 08:59:59.699999809,0,-0.003068
2024-02-14 08:59:59.800000191,0,0.000000


In [138]:
subject

'BAA-1104047'

In [139]:
cum_wheel_dist_subj_df[subject]

time
2024-02-14 08:00:00.000000000    0
2024-02-14 08:00:00.099999905    0
2024-02-14 08:00:00.199999809    0
2024-02-14 08:00:00.300000191    0
2024-02-14 08:00:00.400000095    0
                                ..
2024-02-14 08:59:59.500000000    0
2024-02-14 08:59:59.599999905    0
2024-02-14 08:59:59.699999809    0
2024-02-14 08:59:59.800000191    0
2024-02-14 08:59:59.900000095    0
Name: BAA-1104047, Length: 36000, dtype: int64

In [133]:
subject

'BAA-1104047'

In [131]:
wheel_dist[closest_subjects == subject]

time
2024-02-14 08:00:03.699999809    0.001534
2024-02-14 08:00:03.800000191    0.000000
2024-02-14 08:00:03.900000095   -0.004602
2024-02-14 08:00:04.000000000    0.004602
2024-02-14 08:00:04.099999905   -0.001534
                                   ...   
2024-02-14 08:59:58.699999809    0.003068
2024-02-14 08:59:58.900000095    0.003068
2024-02-14 08:59:59.000000000   -0.001534
2024-02-14 08:59:59.300000191    0.000000
2024-02-14 08:59:59.699999809   -0.003068
Name: angle, Length: 21992, dtype: float64

In [127]:
cum_wheel_dist_subj_df.subject[closest_subjects == subject]

AttributeError: 'DataFrame' object has no attribute 'subject'

In [125]:
cum_wheel_dist_subj_df = pd.DataFrame(index=cum_wheel_dist.index, columns=subjects, data=0)
closest_subjects = dist_to_patch_wheel_ts_id_df.idxmin(axis=1)
wheel_dist = cum_wheel_dist.diff().fillna(cum_wheel_dist.iloc[0])

  closest_subjects = dist_to_patch_wheel_ts_id_df.idxmin(axis=1)


In [126]:
wheel_dist[closest_subjects == subject]

time
2024-02-14 08:00:03.699999809    0.001534
2024-02-14 08:00:03.800000191    0.000000
2024-02-14 08:00:03.900000095   -0.004602
2024-02-14 08:00:04.000000000    0.004602
2024-02-14 08:00:04.099999905   -0.001534
                                   ...   
2024-02-14 08:59:58.699999809    0.003068
2024-02-14 08:59:58.900000095    0.003068
2024-02-14 08:59:59.000000000   -0.001534
2024-02-14 08:59:59.300000191    0.000000
2024-02-14 08:59:59.699999809   -0.003068
Name: angle, Length: 21992, dtype: float64

In [122]:
cum_wheel_dist

time
2024-02-14 08:00:00.000000000      -0.000000
2024-02-14 08:00:00.099999905       0.001534
2024-02-14 08:00:00.199999809      -0.006136
2024-02-14 08:00:00.300000191      -0.001534
2024-02-14 08:00:00.400000095      -0.003068
                                    ...     
2024-02-14 08:59:59.500000000    4394.620038
2024-02-14 08:59:59.599999905    4394.618504
2024-02-14 08:59:59.699999809    4394.615436
2024-02-14 08:59:59.800000191    4394.615436
2024-02-14 08:59:59.900000095    4394.618504
Name: angle, Length: 36000, dtype: float64

time
2024-02-14 08:00:00.000000000         NaN
2024-02-14 08:00:00.099999905    0.001534
2024-02-14 08:00:00.199999809   -0.007670
2024-02-14 08:00:00.300000191    0.004602
2024-02-14 08:00:00.400000095   -0.001534
                                   ...   
2024-02-14 08:59:59.500000000   -0.001534
2024-02-14 08:59:59.599999905   -0.001534
2024-02-14 08:59:59.699999809   -0.003068
2024-02-14 08:59:59.800000191    0.000000
2024-02-14 08:59:59.900000095    0.003068
Name: angle, Length: 36000, dtype: float64

In [118]:
dist_to_patch_wheel_ts_id_df

Unnamed: 0_level_0,BAA-1104045,BAA-1104047
time,Unnamed: 1_level_1,Unnamed: 2_level_1
2024-02-14 08:00:00.000000000,693.885635,698.566247
2024-02-14 08:00:00.099999905,690.778328,698.661071
2024-02-14 08:00:00.199999809,690.418638,698.643935
2024-02-14 08:00:00.300000191,687.212969,700.757699
2024-02-14 08:00:00.400000095,683.109510,700.926455
...,...,...
2024-02-14 08:59:59.500000000,384.701512,741.401770
2024-02-14 08:59:59.599999905,386.913748,743.527414
2024-02-14 08:59:59.699999809,743.851848,384.906781
2024-02-14 08:59:59.800000191,384.493044,744.545220


In [117]:
dist_to_patch_pel_ts_id_df

Unnamed: 0_level_0,BAA-1104045,BAA-1104047
time,Unnamed: 1_level_1,Unnamed: 2_level_1
2024-02-14 08:02:19.219999790,746.210657,
2024-02-14 08:02:43.841983795,21.425895,749.594314
2024-02-14 08:05:23.169983864,21.432737,380.245593
2024-02-14 08:11:36.375999928,633.40308,23.403707
2024-02-14 08:13:55.437983990,567.547967,23.516178
2024-02-14 08:14:07.480000019,342.014591,23.583132
2024-02-14 08:16:26.765984058,276.627255,23.568197
2024-02-14 08:19:22.361983776,562.156015,23.203993
2024-02-14 08:19:33.853983879,539.382565,24.981456
2024-02-14 08:22:26.412000179,302.219721,23.412274


In [116]:
patch_df_for_pellets_df

Unnamed: 0,time,threshold,patch,id
0,2024-02-14 08:02:19.219999790,361.785049,Patch2,BAA-1104045
1,2024-02-14 08:02:43.841983795,194.950509,Patch2,BAA-1104045
2,2024-02-14 08:05:23.169983864,88.046672,Patch2,BAA-1104045
3,2024-02-14 08:11:36.375999928,85.945754,Patch2,BAA-1104047
4,2024-02-14 08:13:55.437983990,76.270649,Patch2,BAA-1104047
5,2024-02-14 08:14:07.480000019,205.54352,Patch2,BAA-1104047
6,2024-02-14 08:16:26.765984058,88.423802,Patch2,BAA-1104047
7,2024-02-14 08:19:22.361983776,96.462577,Patch2,BAA-1104047
8,2024-02-14 08:19:33.853983879,132.625352,Patch2,BAA-1104047
9,2024-02-14 08:22:26.412000179,372.380471,Patch2,BAA-1104047


In [107]:
dist_to_patch_df

Unnamed: 0_level_0,class,dist_to_patch
time,Unnamed: 1_level_1,Unnamed: 2_level_1
2024-02-14 08:00:00.019999981,BAA-1104045,693.885635
2024-02-14 08:00:00.019999981,BAA-1104047,698.566247
2024-02-14 08:00:00.079999924,BAA-1104045,691.475353
2024-02-14 08:00:00.079999924,BAA-1104045,700.675389
2024-02-14 08:00:00.139999866,BAA-1104045,690.778328
...,...,...
2024-02-14 08:59:59.840000153,BAA-1104045,384.493044
2024-02-14 08:59:59.840000153,BAA-1104045,743.868383
2024-02-14 08:59:59.900000095,BAA-1104045,384.593439
2024-02-14 08:59:59.900000095,BAA-1104047,744.545220


In [103]:
dist_to_patch_wheel_ts_id_df

Unnamed: 0_level_0,BAA-1104045,BAA-1104047
time,Unnamed: 1_level_1,Unnamed: 2_level_1
2024-02-14 08:00:00.000000000,693.885635,693.885635
2024-02-14 08:00:00.099999905,690.778328,690.778328
2024-02-14 08:00:00.199999809,690.418638,690.418638
2024-02-14 08:00:00.300000191,687.212969,687.212969
2024-02-14 08:00:00.400000095,683.109510,683.109510
...,...,...
2024-02-14 08:59:59.500000000,384.701512,384.701512
2024-02-14 08:59:59.599999905,386.913748,386.913748
2024-02-14 08:59:59.699999809,384.906781,384.906781
2024-02-14 08:59:59.800000191,384.493044,384.493044


In [102]:
dist_to_patch_wheel_ts_subj

Unnamed: 0_level_0,BAA-1104045,BAA-1104047,class,dist_to_patch
time,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
2024-02-14 08:00:00.000000000,693.885635,,BAA-1104045,693.885635
2024-02-14 08:00:00.099999905,690.778328,,BAA-1104045,690.778328
2024-02-14 08:00:00.199999809,690.418638,,BAA-1104045,690.418638
2024-02-14 08:00:00.300000191,687.212969,,BAA-1104045,687.212969
2024-02-14 08:00:00.400000095,683.109510,,BAA-1104045,683.109510
...,...,...,...,...
2024-02-14 08:59:59.500000000,384.701512,,BAA-1104045,384.701512
2024-02-14 08:59:59.599999905,386.913748,,BAA-1104045,386.913748
2024-02-14 08:59:59.699999809,384.906781,,BAA-1104047,384.906781
2024-02-14 08:59:59.800000191,384.493044,,BAA-1104045,384.493044


In [100]:
dist_to_patch_pel_ts_subj

Unnamed: 0_level_0,BAA-1104045,BAA-1104047,class,dist_to_patch
time,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
2024-02-14 08:02:19.219999790,746.210657,,BAA-1104045,746.210657
2024-02-14 08:02:43.841983795,749.594314,,BAA-1104047,749.594314
2024-02-14 08:05:23.169983864,380.245593,,BAA-1104047,380.245593
2024-02-14 08:11:36.375999928,633.40308,,BAA-1104045,633.40308
2024-02-14 08:13:55.437983990,23.516178,,BAA-1104047,23.516178
2024-02-14 08:14:07.480000019,23.583132,,BAA-1104047,23.583132
2024-02-14 08:16:26.765984058,23.568197,,BAA-1104047,23.568197
2024-02-14 08:19:22.361983776,562.156015,,BAA-1104045,562.156015
2024-02-14 08:19:33.853983879,539.382565,,BAA-1104045,539.382565
2024-02-14 08:22:26.412000179,23.412274,,BAA-1104047,23.412274


In [78]:
subjects_xy = pose_df[pose_df["part"] == "centroid"][["x", "y"]].values
subjects_xy

array([[1259.26   ,  536.5433 ],
       [1269.9208 ,  555.0077 ],
       [1256.6609 ,  536.35596],
       ...,
       [ 595.1556 ,  368.43707],
       [1299.1022 ,  496.86774],
       [ 595.17505,  368.38034]], dtype=float32)

In [79]:
subjects_xy.shape

(110635, 2)

In [80]:
dist_to_patch = np.sqrt(np.sum((subjects_xy - patch_xy) ** 2, axis=1))
dist_to_patch_df = pose_df[["class"]].copy()
dist_to_patch_df["dist_to_patch"] = dist_to_patch

In [81]:
dist_to_patch_df

Unnamed: 0_level_0,class,dist_to_patch
time,Unnamed: 1_level_1,Unnamed: 2_level_1
2024-02-14 08:00:00.019999981,BAA-1104045,696.273687
2024-02-14 08:00:00.019999981,BAA-1104047,711.691849
2024-02-14 08:00:00.079999924,BAA-1104045,693.721146
2024-02-14 08:00:00.079999924,BAA-1104045,713.902905
2024-02-14 08:00:00.139999866,BAA-1104045,694.602667
...,...,...
2024-02-14 08:59:59.840000153,BAA-1104045,21.412034
2024-02-14 08:59:59.840000153,BAA-1104045,725.975112
2024-02-14 08:59:59.900000095,BAA-1104045,21.343971
2024-02-14 08:59:59.900000095,BAA-1104047,725.538902


In [83]:
dist_to_patch_id_df

Unnamed: 0_level_0,BAA-1104045
time,Unnamed: 1_level_1
2024-01-31 17:46:01.929984093,19.174071
2024-01-31 17:46:18.987999916,15.551528
2024-01-31 17:46:37.323999882,19.905172
2024-01-31 17:46:52.657983780,19.984235
2024-01-31 17:47:07.561984062,21.056611
2024-01-31 17:47:18.127999783,17.622595
2024-01-31 17:47:29.808000088,19.765585
2024-01-31 17:47:44.273983955,17.342467
2024-01-31 17:47:58.581984043,21.922041
2024-01-31 17:48:20.599999905,18.808438


In [69]:
patch_df_for_pellets_df

Unnamed: 0,time,threshold,patch,id
0,2024-01-31 17:46:01.929984093,166.161345,Patch3,BAA-1104045
1,2024-01-31 17:46:18.987999916,201.563446,Patch3,BAA-1104045
2,2024-01-31 17:46:37.323999882,148.860696,Patch3,BAA-1104045
3,2024-01-31 17:46:52.657983780,155.151245,Patch3,BAA-1104045
4,2024-01-31 17:47:07.561984062,85.592153,Patch3,BAA-1104045
5,2024-01-31 17:47:18.127999783,132.567869,Patch3,BAA-1104045
6,2024-01-31 17:47:29.808000088,148.151722,Patch3,BAA-1104045
7,2024-01-31 17:47:44.273983955,118.466936,Patch3,BAA-1104045
8,2024-01-31 17:47:58.581984043,104.584493,Patch3,BAA-1104045
9,2024-01-31 17:48:20.599999905,88.933069,Patch3,BAA-1104045


In [68]:
dist_to_patch_id_df

Unnamed: 0_level_0,BAA-1104045
time,Unnamed: 1_level_1
2024-01-31 17:46:01.929984093,19.174071
2024-01-31 17:46:18.987999916,15.551528
2024-01-31 17:46:37.323999882,19.905172
2024-01-31 17:46:52.657983780,19.984235
2024-01-31 17:47:07.561984062,21.056611
2024-01-31 17:47:18.127999783,17.622595
2024-01-31 17:47:29.808000088,19.765585
2024-01-31 17:47:44.273983955,17.342467
2024-01-31 17:47:58.581984043,21.922041
2024-01-31 17:48:20.599999905,18.808438


In [65]:
block.cum_wheel_dist.pprint()


{'BAA-1104045': {'Patch1': array([-0.00000000e+00, -3.06814884e-03, -1.53407442e-03, ...,
        3.28233171e+03,  3.28233631e+03,  3.28232864e+03]),
                 'Patch2': array([-0.00000000e+00,  1.53407442e-03,  4.60222326e-03, ...,
        5.53662799e+01,  5.53662799e+01,  5.53708821e+01]),
                 'Patch3': array([-0.00000000e+00, -1.53407442e-03, -1.53407442e-03, ...,
        2.38928409e+03,  2.38928256e+03,  2.38927796e+03])},
 '_ipython_display_': {},
 '_repr_html_': {},
 '_repr_javascript_': {},
 '_repr_jpeg_': {},
 '_repr_json_': {},
 '_repr_latex_': {},
 '_repr_markdown_': {},
 '_repr_mimebundle_': {},
 '_repr_pdf_': {},
 '_repr_png_': {},
 '_repr_svg_': {},
 'time': DatetimeIndex(['2024-01-31 16:18:07.001984119',
               '2024-01-31 16:18:07.101984024',
               '2024-01-31 16:18:07.201983929',
               '2024-01-31 16:18:07.301983833',
               '2024-01-31 16:18:07.401984215',
               '2024-01-31 16:18:07.501984118',
            

In [60]:
np.sqrt(np.sum((patch_xy - (patch_xy + 10))**2))

14.142135623730951

In [212]:
# Get part = centroid, return x,y as np.array
pose_df[pose_df["part"] == "centroid"][["x", "y"]]

Unnamed: 0_level_0,x,y
time,Unnamed: 1_level_1,Unnamed: 2_level_1
2024-01-31 16:18:07.059999943,922.973083,538.867432
2024-01-31 16:18:07.119999886,923.023926,538.851929
2024-01-31 16:18:07.179999828,923.019226,538.852417
2024-01-31 16:18:07.239999771,922.998718,538.848206
2024-01-31 16:18:07.300000191,923.012939,538.838379
...,...,...
2024-01-31 17:56:22.659999847,1109.742920,872.181152
2024-01-31 17:56:22.719999790,1104.848389,875.274902
2024-01-31 17:56:22.780000210,1096.728882,885.990417
2024-01-31 17:56:22.840000153,1086.006714,896.699646


In [221]:
patch_df_for_pellets_df

Unnamed: 0,time,threshold,patch,id
0,2024-01-31 17:46:01.929984093,166.161345,Patch3,BAA-1104045
1,2024-01-31 17:46:18.987999916,201.563446,Patch3,BAA-1104045
2,2024-01-31 17:46:37.323999882,148.860696,Patch3,BAA-1104045
3,2024-01-31 17:46:52.657983780,155.151245,Patch3,BAA-1104045
4,2024-01-31 17:47:07.561984062,85.592153,Patch3,BAA-1104045
5,2024-01-31 17:47:18.127999783,132.567869,Patch3,BAA-1104045
6,2024-01-31 17:47:29.808000088,148.151722,Patch3,BAA-1104045
7,2024-01-31 17:47:44.273983955,118.466936,Patch3,BAA-1104045
8,2024-01-31 17:47:58.581984043,104.584493,Patch3,BAA-1104045
9,2024-01-31 17:48:20.599999905,88.933069,Patch3,BAA-1104045


In [223]:
dist_to_patch_id_df = pd.DataFrame(index=patch_df_for_pellets_df["time"], columns=[subjects])

In [224]:
dist_to_patch_id_df

Unnamed: 0_level_0,BAA-1104045
time,Unnamed: 1_level_1
2024-01-31 17:46:01.929984093,
2024-01-31 17:46:18.987999916,
2024-01-31 17:46:37.323999882,
2024-01-31 17:46:52.657983780,
2024-01-31 17:47:07.561984062,
2024-01-31 17:47:18.127999783,
2024-01-31 17:47:29.808000088,
2024-01-31 17:47:44.273983955,
2024-01-31 17:47:58.581984043,
2024-01-31 17:48:20.599999905,


In [231]:
patch_xy = np.array(patch_locs[patch][arena]).astype(np.uint32)
subjects_xy = pose_df[pose_df["part"] == "centroid"][["x", "y"]].values
dist_to_patch = np.sqrt(np.sum((subjects_xy - patch_xy) ** 2, axis=1))
dist_to_patch_df = pose_df[["class"]].copy()
dist_to_patch_df["dist_to_patch"] = dist_to_patch
dist_to_patch_id_df = pd.DataFrame(index=patch_df_for_pellets_df["time"], columns=[subjects])
display(dist_to_patch_df)
display(dist_to_patch_id_df)

Unnamed: 0_level_0,class,dist_to_patch
time,Unnamed: 1_level_1,Unnamed: 2_level_1
2024-01-31 16:18:07.059999943,BAA-1104045,384.666605
2024-01-31 16:18:07.119999886,BAA-1104045,384.703057
2024-01-31 16:18:07.179999828,BAA-1104045,384.699218
2024-01-31 16:18:07.239999771,BAA-1104045,384.679323
2024-01-31 16:18:07.300000191,BAA-1104045,384.686795
...,...,...
2024-01-31 17:56:22.659999847,BAA-1104045,738.876897
2024-01-31 17:56:22.719999790,BAA-1104045,737.643805
2024-01-31 17:56:22.780000210,BAA-1104045,739.744758
2024-01-31 17:56:22.840000153,BAA-1104045,740.328964


Unnamed: 0_level_0,BAA-1104045
time,Unnamed: 1_level_1
2024-01-31 17:46:01.929984093,
2024-01-31 17:46:18.987999916,
2024-01-31 17:46:37.323999882,
2024-01-31 17:46:52.657983780,
2024-01-31 17:47:07.561984062,
2024-01-31 17:47:18.127999783,
2024-01-31 17:47:29.808000088,
2024-01-31 17:47:44.273983955,
2024-01-31 17:47:58.581984043,
2024-01-31 17:48:20.599999905,


In [238]:
pd.merge_asof(dist_to_patch_id_df, dist_to_patch_df, left_index=True, right_index=True, direction="forward", tolerance=pd.Timedelta("200ms"))

Unnamed: 0_level_0,"(BAA-1104045,)",class,dist_to_patch
time,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
2024-01-31 17:46:01.929984093,,BAA-1104045,19.174071
2024-01-31 17:46:18.987999916,,BAA-1104045,15.551528
2024-01-31 17:46:37.323999882,,BAA-1104045,19.905172
2024-01-31 17:46:52.657983780,,BAA-1104045,19.984235
2024-01-31 17:47:07.561984062,,BAA-1104045,21.056611
2024-01-31 17:47:18.127999783,,BAA-1104045,17.622595
2024-01-31 17:47:29.808000088,,BAA-1104045,19.765585
2024-01-31 17:47:44.273983955,,BAA-1104045,17.342467
2024-01-31 17:47:58.581984043,,BAA-1104045,21.922041
2024-01-31 17:48:20.599999905,,BAA-1104045,18.808438


In [240]:
            patch_xy = np.array(patch_locs[patch][arena]).astype(np.uint32)
            subjects_xy = pose_df[pose_df["part"] == "centroid"][["x", "y"]].values
            dist_to_patch = np.sqrt(np.sum((subjects_xy - patch_xy) ** 2, axis=1))
            dist_to_patch_df = pose_df[["class"]].copy()
            dist_to_patch_df["dist_to_patch"] = dist_to_patch
            dist_to_patch_id_df = pd.DataFrame(index=patch_df_for_pellets_df["time"], columns=subjects)
            for subject in subjects:
                # Find closest match between dist_to_patch_id_df indices and pose_df indices
                dist_to_patch_subj = pd.merge_asof(
                    left=dist_to_patch_id_df, 
                    right=dist_to_patch_df,
                    left_index=True,
                    right_index=True, 
                    direction="forward", 
                    tolerance=pd.Timedelta("200ms")
                )
                dist_to_patch_id_df[subject] = dist_to_patch_subj["dist_to_patch"]

In [248]:
dist_to_patch_id_df["BAA-1104047"] = dist_to_patch_id_df["BAA-1104045"]
dist_to_patch_id_df["BAA-1104047"] -= 1
dist_to_patch_id_df["BAA-1104047"].iloc[::2] += 2
dist_to_patch_id_df

Unnamed: 0_level_0,BAA-1104045,BAA-1104047
time,Unnamed: 1_level_1,Unnamed: 2_level_1
2024-01-31 17:46:01.929984093,19.174071,20.174071
2024-01-31 17:46:18.987999916,15.551528,14.551528
2024-01-31 17:46:37.323999882,19.905172,20.905172
2024-01-31 17:46:52.657983780,19.984235,18.984235
2024-01-31 17:47:07.561984062,21.056611,22.056611
2024-01-31 17:47:18.127999783,17.622595,16.622595
2024-01-31 17:47:29.808000088,19.765585,20.765585
2024-01-31 17:47:44.273983955,17.342467,16.342467
2024-01-31 17:47:58.581984043,21.922041,22.922041
2024-01-31 17:48:20.599999905,18.808438,17.808438


In [251]:
dist_to_patch_id_df.idxmin(axis=1).values

array(['BAA-1104045', 'BAA-1104047', 'BAA-1104045', 'BAA-1104047',
       'BAA-1104045', 'BAA-1104047', 'BAA-1104045', 'BAA-1104047',
       'BAA-1104045', 'BAA-1104047', 'BAA-1104045', 'BAA-1104047',
       'BAA-1104045', 'BAA-1104047', 'BAA-1104045'], dtype=object)

In [252]:
patch_df_for_pellets_df["id"] = dist_to_patch_id_df.idxmin(axis=1).values

In [253]:
patch_df_for_pellets_df

Unnamed: 0,time,threshold,patch,id
0,2024-01-31 17:46:01.929984093,166.161345,Patch3,BAA-1104045
1,2024-01-31 17:46:18.987999916,201.563446,Patch3,BAA-1104047
2,2024-01-31 17:46:37.323999882,148.860696,Patch3,BAA-1104045
3,2024-01-31 17:46:52.657983780,155.151245,Patch3,BAA-1104047
4,2024-01-31 17:47:07.561984062,85.592153,Patch3,BAA-1104045
5,2024-01-31 17:47:18.127999783,132.567869,Patch3,BAA-1104047
6,2024-01-31 17:47:29.808000088,148.151722,Patch3,BAA-1104045
7,2024-01-31 17:47:44.273983955,118.466936,Patch3,BAA-1104047
8,2024-01-31 17:47:58.581984043,104.584493,Patch3,BAA-1104045
9,2024-01-31 17:48:20.599999905,88.933069,Patch3,BAA-1104047


In [166]:
patch_df

Unnamed: 0_level_0,threshold,offset,rate
time,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
2024-02-14 08:02:06.921984196,115.006748,75,0.01
2024-02-14 08:02:19.219999790,361.785049,75,0.01
2024-02-14 08:02:43.841983795,194.950509,75,0.01
2024-02-14 08:05:23.169983864,88.046672,75,0.01
2024-02-14 08:11:36.375999928,85.945754,75,0.01
2024-02-14 08:13:55.437983990,76.270649,75,0.01
2024-02-14 08:14:07.480000019,205.54352,75,0.01
2024-02-14 08:16:26.765984058,88.423802,75,0.01
2024-02-14 08:19:22.361983776,96.462577,75,0.01
2024-02-14 08:19:33.853983879,132.625352,75,0.01


In [167]:
sum(np.diff(patch_df.index) > pd.Timedelta("1s"))

23

In [218]:
        patch_xy = np.array(patch_locs[patch][arena]).astype(np.uint32)
        subjects_xy = pose_df[pose_df["part"] == "centroid"][["x", "y"]].values
        dist_to_patch = np.sqrt(np.sum((subjects_xy - patch_xy) ** 2, axis=1))

In [219]:
dist_to_patch

array([384.66660495, 384.70305651, 384.69921817, ..., 739.74475784,
       740.32896414, 743.07779698])

In [None]:
np.convolve([1, 2, 3, 4, 5], [1, 1, 1])

In [None]:
# NOPE

In [None]:
# <s Assigning patch to subject

# <ss Sanity check sleap-tracking:
# If subject in sleap-tracked data does not match one of the subjects actually in env in the block...
# Of the subjects in env, find the one that has less tracking data in a time-period around the non-matching
# sleap-tracked id, and assign this subject to this id.

In [None]:
"""Pose info."""

pose_df = aeon.load(block.root, social02.CameraTop.Pose, block.start, block.end)
pose_df = reader.Pose.class_int2str(pose_df, block.sleap_model_dir)
pose_df_subjects = pose_df["class"].unique()
# Fix mistaken sleap assignments for single-subject blocks
if len(subjects) == 1:  
    for pose_df_subject in pose_df_subjects:
        if pose_df_subject != subjects[0]:
            pose_df[pose_df["class"] == pose_df_subject] = subjects[0]

In [None]:
"""Get strange blocks."""

# Get blocks with no subjects
np.where(blocks_df.subjects.apply(len) == 0)

# Get blocks with < 3 pellets

## Plots

In [None]:
# 1 / patch_rate, next to boxplots of each pellet threshold per patch

In [None]:
# Cumulative pellet count over time, per patch, per subject (0, 0)

In [None]:
# Running cumulative patch preference, per subject: each patch as a line (0, 1)

In [None]:
# Pellet threshold vals over time, per patch, per subject (1, 0)

In [None]:
# Null distribution 2.5th and 97.5th percentiles with per-patch preference vals, per subject (1, 1)


In [None]:
# Pairwise Null distribution 2.5th and 97.5th percentiles with per-patch preference vals, per subject (1, 1)

In [None]:
block_subject_data

In [None]:
block_patch_data

In [None]:
block_wheel_data = [block["wheel_cumsum_distance_travelled"] for block in block_patch_data]

In [None]:
block_wheel_data

In [None]:
# <ss Get patch preference by wheel distance within each session div
wboth = w1 + w2
wboth_quantized = np.linspace(0, wboth[-1], session_divs + 1)
easy_pref_epoch_cum = np.zeros((session_divs,))
easy_pref_epoch = np.zeros((session_divs,))
epoch_thresh_change_idx = 0
epoch_ts_pre = wboth.index[0]
for i in range(1, session_divs):
    epoch_ts_post = wboth[wboth > wboth_quantized[i]].index[0] - pd.Timedelta("1s")
    if (epoch_ts_post > safe_change_ts) and not epoch_thresh_change_idx:
        epoch_thresh_change_idx = i
    weasy_post = weasy[weasy.index > epoch_ts_post][0]
    whard_post = whard[whard.index > epoch_ts_post][0]
    weasy_pre = weasy[weasy.index > epoch_ts_pre][0]
    whard_pre = whard[whard.index > epoch_ts_pre][0]
    weasy_diff = weasy_post - weasy_pre
    whard_diff = whard_post - whard_pre
    easy_pref_epoch_cum[i] = weasy_post / (weasy_post + whard_post)
    easy_pref_epoch[i] = weasy_diff / (weasy_diff + whard_diff)
    epoch_ts_pre = epoch_ts_post
sessions.at[s.Index, "easy_pref_epoch_cum"] = easy_pref_epoch_cum
sessions.at[s.Index, "easy_pref_epoch"] = easy_pref_epoch
sessions.loc[s.Index, "epoch_thresh_change_idx"] = epoch_thresh_change_idx
# /ss>
# <ss Get chunked patch pref compared to synthetic data
# <sss Chunk (downsample) wheel data
weasy_chnkd = np.abs((weasy[(w_chunk_t - 1)::w_chunk_t]).values - (weasy[::w_chunk_t][:-1]).values)
weasy_chnkd_cumsum = weasy_chnkd.cumsum()
whard_chnkd = np.abs((whard[(w_chunk_t - 1)::w_chunk_t]).values - (whard[::w_chunk_t][:-1]).values)
whard_chnkd_cumsum = whard_chnkd.cumsum()
w_all_chnkd_cumsum = weasy_chnkd_cumsum + whard_chnkd_cumsum
n_samples = len(weasy_chnkd)
pref_first_idx = np.where(w_all_chnkd_cumsum > w_chunk_dist)[0][0]
end_idxs = np.arange(pref_first_idx, n_samples, 1).astype(int)
start_idxs = np.zeros((len(end_idxs),)).astype(int)
for i, idx in enumerate(end_idxs):
    start_idxs[i] = np.where((w_all_chnkd_cumsum[0:idx] + w_chunk_dist) > w_all_chnkd_cumsum[idx])[0][0]
# /sss>
# <sss Get true chunked patch pref
weasy_diff = weasy_chnkd_cumsum[end_idxs] - weasy_chnkd_cumsum[start_idxs]
whard_diff = whard_chnkd_cumsum[end_idxs] - whard_chnkd_cumsum[start_idxs]
weasy_pref = weasy_diff / (weasy_diff + whard_diff)
# /sss>
# <sss Generate individual wheel null distributions
w_all_chnkd = np.concatenate((weasy_chnkd, whard_chnkd))
syn_chunk_pref_dists = np.zeros((n_distris, len(weasy_pref)))
for distri_n in range(n_distris):
    # Create synthetic distributions
    weasy_chnkd_gen = np.random.choice(w_all_chnkd, size=n_samples, replace=False)
    whard_chnkd_gen = np.random.choice(w_all_chnkd, size=n_samples, replace=False)
    impossible_idxs = np.where(np.logical_and(weasy_chnkd_gen > 0.1, whard_chnkd_gen > 0.1))[0]
    for ii in impossible_idxs:
        if weasy_chnkd_gen[ii] > whard_chnkd_gen[ii]:
            whard_chnkd_gen[ii] = 0
        else:
            weasy_chnkd_gen[ii] = 0
    weasy_chnkd_gen_cumsum = weasy_chnkd_gen.cumsum()
    whard_chnkd_gen_cumsum = whard_chnkd_gen.cumsum()
    w_all_chnkd_gen_cumsum = weasy_chnkd_gen_cumsum + whard_chnkd_gen_cumsum
    # Get synthetic patch pref
    end_idxs = np.arange(pref_first_idx, n_samples, 1).astype(int)
    start_idxs = np.zeros((len(end_idxs),)).astype(int)
    for i, idx in enumerate(end_idxs):
        start_idxs[i] = np.where(
            (w_all_chnkd_gen_cumsum[0:idx] + w_chunk_dist) 
            > w_all_chnkd_gen_cumsum[idx]
        )[0][0]
    weasy_diff_gen = weasy_chnkd_gen_cumsum[end_idxs] - weasy_chnkd_gen_cumsum[start_idxs]
    whard_diff_gen = whard_chnkd_gen_cumsum[end_idxs] - whard_chnkd_gen_cumsum[start_idxs]
    weasy_pref_gen = weasy_diff_gen / (weasy_diff_gen + whard_diff_gen)
    syn_chunk_pref_dists[distri_n, :] = weasy_pref_gen
# /sss>
# <sss Get the 2.5th and 97.5th percentiles of the null distributions
syn_chunk_pref_dists = np.sort(syn_chunk_pref_dists, axis=0)
low_bound = syn_chunk_pref_dists[3, :]
high_bound = syn_chunk_pref_dists[96, :]
# /sss>
# <sss Check if learning criteria is met
learned_start_idx = None
learned_end_idx = None
pref_idxs = np.where(weasy_pref > high_bound)[0]
# For each pref_idx, find the first earlier idx with `pref_window` less
# cum distance, then see if pref over this window is > `pref_thresh`
for pref_start_idx in pref_idxs:
    pref_end_idx = np.where(
        w_all_chnkd_cumsum[pref_start_idx:] 
        > (w_all_chnkd_cumsum[pref_start_idx] + pref_window)
    )[0]
    if pref_end_idx.size > 0:
        pref_end_idx = pref_end_idx[0] + pref_start_idx
        pref_p = np.sum(
            weasy_pref[pref_start_idx : pref_end_idx] 
            > high_bound[pref_start_idx : pref_end_idx]
        ) / (pref_end_idx - pref_start_idx)
        if pref_p > pref_thresh:
            learned_start_idx = pref_start_idx
            learned_end_idx = pref_end_idx
            break
# /sss>
cont_patch_pref = DotMap(
    w_all_chnkd_cumsum=w_all_chnkd_cumsum.astype('float32'),
    weasy_pref=weasy_pref.astype('float32'),
    low_bound=low_bound.astype('float32'),
    high_bound=high_bound.astype('float32'),
    learned_start_idx=learned_start_idx,
    learned_end_idx=learned_end_idx,
    thresh_change_idx=(safe_change_ts - s.enter).seconds
)
sessions.at[s.Index, "cont_patch_pref"] = cont_patch_pref
if learned_start_idx:
    print(f"Learned: {s.id} {s.enter} ... {post_easy_rate} {post_hard_rate}")
# /ss> /s>

## Questions

- In a block, what percentage of time do they end on easy block?