In [None]:
# Social exp timeline

# Aeon3
#  2024-01-31 : 2024-02-03 - BAA-1104045 pre solo
#  2024-02-05 : 2024-02-08 - BAA-1104047 pre solo (dominant)
#  2024-02-09 : 2024-02-23 - BAA-1104045, BAA-1104047 social
#  2024-02-25 : 2024-02-28 - BAA-1104045 post solo
#  2024-02-28 : 2024-03-02 - BAA-1104047 post solo

# Aeon4
#  2024-01-31 : 2024-02-03 - BAA-1104048 pre solo (dominant)
#  2024-02-05 : 2024-02-08 - BAA-1104049 pre solo
#  2024-02-09 : 2024-02-23 - BAA-1104048, BAA-1104049 social
#  2024-02-25 : 2024-02-28 - BAA-1104048 post solo
#  2024-02-28 : 2024-03-02 - BAA-1104049 post solo

In [None]:
"""Imports and settings."""

%load_ext autoreload
%autoreload 2
# %flow mode reactive

from colorsys import hls_to_rgb, rgb_to_hls
from pathlib import Path

import aeon
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly
import plotly.express as px
import plotly.graph_objs as go
import seaborn as sns
from aeon.schema.schemas import social02
from numpy.lib.stride_tricks import as_strided

## Load and Clean Data

In [None]:
import pathlib
from contextlib import contextmanager


# .pkl files written on Windows are incompatible with Unix systems
@contextmanager
def set_windows_path_as_posix():
    """Temporarily make ``pathlib.WindowsPath`` an alias of ``pathlib.PosixPath``.

    Pickles written on Windows reference ``pathlib.WindowsPath``, which cannot be
    instantiated on a POSIX system. While this context is active, unpickling builds
    ``PosixPath`` objects instead. The original class is restored on exit, even if the
    body raises.
    """
    saved_cls = pathlib.WindowsPath
    pathlib.WindowsPath = pathlib.PosixPath
    try:
        yield
    finally:
        pathlib.WindowsPath = saved_cls

In [None]:
"""Load data."""

with set_windows_path_as_posix():
    root_prefix = "/ceph/aeon"  # or "Z:"
    roots = [
        Path(f"{root_prefix}/aeon/data/raw/AEON3/social0.2"),
        Path(f"{root_prefix}/aeon/data/raw/AEON4/social0.2"),
    ]
    # Block summaries (.pkl) and skipped-block masks (.npy), one pair per date range.
    # NOTE: unpickling can execute arbitrary code -- only load these from the trusted scratchpad.
    data_dir = Path(f"{root_prefix}/aeon/code/scratchpad/jai/social02")

    blocks_df_2024_01_31 = pd.read_pickle(data_dir / "social02_2024-01-31_2024-02-14.pkl")
    blocks_df_2024_02_14 = pd.read_pickle(data_dir / "social02_2024-02-14_2024-02-16.pkl")
    blocks_df_2024_02_16 = pd.read_pickle(data_dir / "social02_2024-02-16_2024-02-23.pkl")
    blocks_df_2024_02_25 = pd.read_pickle(data_dir / "social02_2024-02-25_2024-02-28.pkl")
    blocks_df_2024_02_29 = pd.read_pickle(data_dir / "social02_2024-02-29_2024-03-02.pkl")

    skipped_blocks_2024_01_31 = np.load(data_dir / "skipped_blocks_2024-01-31_2024-02-14.npy")
    skipped_blocks_2024_02_14 = np.load(data_dir / "skipped_blocks_2024-02-14_2024-02-16.npy")
    skipped_blocks_2024_02_16 = np.load(data_dir / "skipped_blocks_2024-02-16_2024-02-23.npy")
    skipped_blocks_2024_02_25 = np.load(data_dir / "skipped_blocks_2024-02-25_2024-02-28.npy")
    skipped_blocks_2024_02_29 = np.load(data_dir / "skipped_blocks_2024-02-29_2024-03-02.npy")

In [None]:
"""Clean data."""

# Concatenate loaded data with a fresh 0..n-1 index, so it lines up positionally with
# `skipped_blocks`
blocks_df = pd.concat(
    [
        blocks_df_2024_01_31,
        blocks_df_2024_02_14,
        blocks_df_2024_02_16,
        blocks_df_2024_02_25,
        blocks_df_2024_02_29,
    ],
    ignore_index=True,
)
skipped_blocks = np.concatenate(
    [
        skipped_blocks_2024_01_31,
        skipped_blocks_2024_02_14,
        skipped_blocks_2024_02_16,
        skipped_blocks_2024_02_25,
        skipped_blocks_2024_02_29,
    ]
)
assert len(skipped_blocks) == len(blocks_df), "skipped-block mask must have one entry per block"
# Remove skipped blocks (blocks with few pellets)
good_blocks_df = blocks_df[~skipped_blocks].reset_index(drop=True)


def to_local_ceph_path(path):
    """Map a Windows ceph path (drive "S:" or "Z:") onto `root_prefix` (my ceph network map)."""
    drive_normalized = Path(str(path).replace("S:", "Z:"))
    return Path(str(drive_normalized).replace("Z:\\", root_prefix))


# Make paths consistent with my ceph network map
good_blocks_df["root"] = good_blocks_df["root"].apply(to_local_ceph_path)
good_blocks_df["sleap_model_dir"] = good_blocks_df["sleap_model_dir"].apply(to_local_ceph_path)
# Round timestamps / durations to nearest ms. All three results are assigned back;
# `block_duration` was previously rounded and then discarded.
for col in ("start", "end", "block_duration"):
    good_blocks_df[col] = good_blocks_df[col].apply(lambda x: x.round("ms"))
# Create df we can iterate through (pandas doesn't like format of "cum_wheel_dist" col for some reason)
good_blocks_df_cp = good_blocks_df.drop(columns=["cum_wheel_dist"])

In [None]:
"""See percentage of 'good' blocks (> 3 pellets)."""

pct_blocks_foraging = len(good_blocks_df_cp) / len(blocks_df)
# Show both terms of the ratio: kept ("good") blocks and all loaded blocks
print(f"{len(good_blocks_df_cp)=}\n{len(blocks_df)=}\n{pct_blocks_foraging=:.3f}")

display(good_blocks_df_cp)

In [None]:
"""View a particular block."""

# Positional index into the good blocks; the wheel-distance cell reuses `b_i` to pick the same block
b_i = 250
block = good_blocks_df_cp.iloc[b_i]
print(block)
# Example for getting cum_wheel_dist (only `good_blocks_df` still has that column)
# good_blocks_df["cum_wheel_dist"].iloc[b_i]  # have to use `iloc`

## Block Plots

In [None]:
"""Standardize subject colors, patch colors, and markers."""

# One fixed color per subject, so a given animal is drawn the same way in every plot
subject_colors = plotly.colors.qualitative.Plotly
subject_colors_dict = {
    subj: subject_colors[subj_i]
    for subj_i, subj in enumerate(
        ["BAA-1104045", "BAA-1104047", "BAA-1104048", "BAA-1104049"]
    )
}
patch_colors = plotly.colors.qualitative.Pastel2
# Plotly marker-symbol names, paired position-by-position with a unicode glyph
# (the glyphs are used inside legend names)
patch_markers = [
    "circle",
    "bowtie",
    "square",
    "hourglass",
    "diamond",
    "cross",
    "x",
    "triangle",
    "star",
]
patch_markers_symbols = ["●", "⧓", "■", "⧗", "♦", "✖", "×", "▲", "★"]
patch_markers_dict = dict(zip(patch_markers, patch_markers_symbols))

### 1. x,y animal location, over time, per subject

In [None]:
"""Get pose data."""

# Change model root path to my network mapped ceph path
social02.CameraTop.Pose._model_root = Path(f"{root_prefix}/aeon/data/processed")
pose_df = aeon.load(block.root, social02.CameraTop.Pose, block.start, block.end)
# Downsample to a 100 ms grid by snapping timestamps. Several rows can then share a
# (time, class); the most likely one is kept below. (`resample("100ms").first()` is not
# used: it keeps one row per bin, collapsing all body parts and identities together.)
pose_df.index = pose_df.index.round("100ms")
pose_df = social02.CameraTop.Pose.class_int2str(pose_df, block.sleap_model_dir)
# Simplify to centroid only.
centroid_df = pose_df[pose_df["part"] == "centroid"].drop(
    columns=["part", "part_likelihood"]
)
centroid_df["x"] = centroid_df["x"].astype(np.int32)
centroid_df["y"] = centroid_df["y"].astype(np.int32)
# For each time point and class, keep the row with the highest likelihood
# (idxmax yields positions in the reset 0..n-1 index, hence `.iloc`)
centroid_df = centroid_df.iloc[
    centroid_df.reset_index().groupby(["time", "class"])["class_likelihood"].idxmax()
]
# Instantaneous speed per identity: displacement between consecutive samples (pixels,
# presumably) divided by elapsed seconds. Both diffs are taken within each class in row
# order, so they line up position-by-position.
step = centroid_df.groupby("class")[["x", "y"]].diff()
elapsed_s = (
    centroid_df.reset_index().groupby("class")["time"].diff().dt.total_seconds().to_numpy()
)
centroid_df["speed"] = np.hypot(step["x"], step["y"]) / elapsed_s
display(centroid_df)

In [None]:
"""Create function for generating arrays of hex color values from a single initial hex color."""


def gen_hex_grad(hex_col, vals, min_l=0.3):
    """Generates an array of hex color values based on a gradient defined by unit-normalized values.

    Every output color keeps the hue and saturation of `hex_col`. Its HLS lightness is
    interpolated linearly between `min_l` (val = 0) and the lightness of `hex_col` (val = 1),
    then clamped to the interval between those two lightnesses.

    Args:
        hex_col: Base color as a 6-digit hex string, with or without a leading "#".
        vals: Sized iterable of values, expected in [0, 1].
        min_l: HLS lightness (0-1) that a value of 0 maps to.

    Returns:
        numpy array of lowercase "#rrggbb" strings, one per entry of `vals`.
    """
    hex_digits = hex_col.lstrip("#")
    # Convert hex -> rgb (unit floats) -> hls
    h, l, s = rgb_to_hls(*[int(hex_digits[i : i + 2], 16) / 255 for i in (0, 2, 4)])
    # Clamp bounds in either order: a base color darker than `min_l` must not collapse to `min_l`
    lo_l, hi_l = min(min_l, l), max(min_l, l)
    grad = np.empty(shape=(len(vals),), dtype="<U10")  # init grad
    for i, val in enumerate(vals):
        cur_l = (l * val) + (min_l * (1 - val))  # interpolate lightness
        cur_l = max(min(cur_l, hi_l), lo_l)  # set min, max bounds
        cur_rgb_col = hls_to_rgb(h, cur_l, s)  # convert to rgb
        # Round (not truncate): the hls round trip leaves tiny float error that would
        # otherwise knock channels down by one (e.g. 0x63 -> 0x62)
        grad[i] = "#%02x%02x%02x" % tuple(int(round(c * 255)) for c in cur_rgb_col)

    return grad

In [None]:
"""Create position over time scatter plot."""

fig = go.Figure()
for id_i, (id_val, id_grp) in enumerate(centroid_df.groupby("class")):
    # Time within the block scaled to [0, 1]; drives the marker lightness gradient
    # (early samples at `gen_hex_grad`'s `min_l`, late ones at the subject's base color)
    time_span = id_grp.index[-1] - id_grp.index[0]
    if time_span > pd.Timedelta(0):
        norm_time = ((id_grp.index - id_grp.index[0]) / time_span).values.round(3)
    else:
        # Single timestamp: avoid 0/0 -> NaN, which gen_hex_grad cannot turn into a color
        norm_time = np.ones(len(id_grp))
    # Standard subject color; falls back to palette position for ids not in the dict
    base_color = subject_colors_dict.get(id_val, subject_colors[id_i])
    colors = gen_hex_grad(base_color, norm_time)
    fig.add_trace(
        go.Scatter(
            x=id_grp["x"],
            y=id_grp["y"],
            mode="markers",
            name=id_val,
            marker={"color": colors},
        )
    )
fig.update_layout(
    title="Position Tracking over Time",
    xaxis_title="X Coordinate",
    yaxis_title="Y Coordinate",
)
fig.show()

### 2. x,y animal location heatmap, per subject

In [None]:
def conv2d(arr, kernel):
    """Performs "valid" 2d convolution using numpy `as_strided` and `einsum`.

    The kernel is flipped along both axes, so this is a true convolution, matching
    ``scipy.signal.convolve2d(arr, kernel, mode="valid")``. For symmetric kernels such as a
    moving-average box it is identical to cross-correlation.

    Args:
        arr: 2d array to convolve.
        kernel: 2d array, no larger than `arr` along either axis.

    Returns:
        Array of shape ``arr.shape - kernel.shape + 1``: element (k, l) is the sum of the
        flipped kernel times the kernel-sized window of `arr` starting at (k, l).

    Raises:
        ValueError: if either input is not 2d, or `kernel` is larger than `arr`.
    """
    arr = np.asarray(arr)
    kernel = np.asarray(kernel)
    if arr.ndim != 2 or kernel.ndim != 2:
        raise ValueError(
            f"conv2d expects 2d inputs, got {arr.ndim}d arr and {kernel.ndim}d kernel"
        )
    if kernel.shape[0] > arr.shape[0] or kernel.shape[1] > arr.shape[1]:
        raise ValueError(f"kernel {kernel.shape} is larger than arr {arr.shape}")
    out_shape = tuple(np.subtract(arr.shape, kernel.shape) + 1)
    # Read-only "new view" of `arr` as submatrices: sub_mats[i, j, k, l] == arr[k + i, l + j]
    sub_mats = as_strided(
        arr, shape=kernel.shape + out_shape, strides=arr.strides * 2, writeable=False
    )
    return np.einsum("ij, ijkl -> kl", kernel[::-1, ::-1], sub_mats)

In [None]:
"""Create position heatmaps per subject."""

# Smoothing window. Mice can go ~450 cm/s, we've downsampled to 10 frames/s, we have
# 200 px / 1000 cm, so 45 cm/frame ~= 9 px/frame
win_sz = 9  # in pixels (keep odd so the window is centered on each pixel)
kernel = np.ones((win_sz, win_sz)) / win_sz**2  # moving-average (box) kernel
max_x, max_y = centroid_df["x"].max(), centroid_df["y"].max()
for id_i, (id_val, id_grp) in enumerate(centroid_df.groupby("class")):
    # Occupancy grid: sample count at each integer (x, y), scaled so the busiest pixel is 1
    img_grid = np.zeros((max_x + 1, max_y + 1))
    np.add.at(img_grid, (id_grp["x"].to_numpy(), id_grp["y"].to_numpy()), 1)
    img_grid /= img_grid.max()
    # Edge-pad by half a window so the "valid" convolution output matches `img_grid`'s shape
    img_grid_p = np.pad(img_grid, win_sz // 2, mode="edge")
    img_grid_smooth = conv2d(img_grid_p, kernel)
    fig = px.imshow(
        img_grid_smooth.T,
        zmin=0,
        zmax=(img_grid_smooth.max() / 1000),
        x=np.arange(img_grid.shape[0]),
        y=np.arange(img_grid.shape[1]),
        labels=dict(x="X", y="Y", color="Norm Freq / 1e3"),
        aspect="auto",
    )
    fig.update_layout(title=f"Position Heatmap ({id_val})")
    fig.show()

### 3. Patch mean, next to boxplots of each pellet threshold per patch.

In [None]:
n_patches = len(block.patch_info)
block_pellets_sorted = block.pellet_info.sort_values("time")
# Drop updates for each patch at new block (the last `n_patches` rows by time).
# Use an explicit stop: `.iloc[:-n_patches]` would return an EMPTY frame if n_patches were 0.
pellet_info = block_pellets_sorted.iloc[: max(len(block_pellets_sorted) - n_patches, 0)]
display(pellet_info)
display(block.patch_info)

In [None]:
# One synthetic "mean" row per patch, timestamped at block start, holding that patch's
# mean + offset threshold so it can be plotted alongside the per-pellet thresholds
mean_pellet_info = pd.DataFrame(index=np.arange(n_patches), columns=pellet_info.columns)
mean_pellet_info["time"] = block.start
mean_pellet_info["id"] = "mean"
mean_pellet_info["patch"] = block.patch_info.index.to_numpy()
mean_pellet_info["threshold"] = (
    block.patch_info["mean"] + block.patch_info["offset"]
).to_numpy()
pellet_info_plus = pd.concat((mean_pellet_info, pellet_info)).reset_index(drop=True)
# Time since block start (first row), scaled by the last pellet's time to [0, 1]
pellet_info_plus["norm_time"] = (
    (pellet_info_plus["time"] - pellet_info_plus["time"].iloc[0])
    / (pellet_info_plus["time"].iloc[-1] - pellet_info_plus["time"].iloc[0])
).round(3)

In [None]:
# Mean rows (one per patch, at block start) followed by the per-pellet rows, with norm_time
pellet_info_plus

In [None]:
# Explicit id -> color map: "mean" rows near-black, each subject in its standard color.
# A positional `color_discrete_sequence` assigns colors by order of first appearance in
# the plotted data, and after sorting by patch "mean" is not guaranteed to come first.
box_color_map = {"mean": "#0A0A0A"}
for subj_i, subj in enumerate(block.subjects):
    box_color_map[subj] = subject_colors_dict.get(subj, subject_colors[subj_i])

fig = px.box(
    pellet_info_plus.sort_values("patch"),
    x="patch",
    y="threshold",
    color="id",
    hover_data=["norm_time"],
    color_discrete_map=box_color_map,
    points="all",
)

fig.update_layout(
    title="Patch Stats", xaxis_title="Patch", yaxis_title="Threshold (cm)"
)

fig.show()

### 4a. Cumulative pellet count over time, per patch, per subject

In [None]:
def cumsum_helper(group):
    """Return a copy of `group` with a 1-based running row count in a "counter" column.

    Meant for ``groupby(...).apply``. It returns a new frame instead of writing into
    `group`, because pandas advises against mutating the object passed to ``apply``.
    """
    return group.assign(counter=np.arange(1, len(group) + 1))


# Running pellet count per subject: `cumsum_helper` numbers each subject's pellets 1..n
# (pellet_info was sorted by time, so the numbering follows delivery order)
cum_pel_ct = (
    pellet_info.groupby("id", group_keys=False)
    .apply(cumsum_helper)
    .reset_index(drop=True)
)
# Attach each patch's mean threshold (mean + offset, from `mean_pellet_info`) as "mean_thresh"
cum_pel_ct = cum_pel_ct.merge(
    mean_pellet_info[["patch", "threshold"]].rename(
        columns={"threshold": "mean_thresh"}
    ),
    on="patch",
    how="left",
)
# Legend/axis label combining the patch name and its mean threshold
cum_pel_ct["patch_label"] = cum_pel_ct.apply(
    lambda row: f"{row['patch']} μ: {row['mean_thresh']}", axis=1
)
display(cum_pel_ct)

In [None]:
# Scale thresholds to [0, 1] across the block for the marker lightness gradient. If every
# threshold is equal the range is 0; use 1 (base color) instead of 0/0 = NaN, which
# `gen_hex_grad` cannot convert to a color.
thresh_min, thresh_max = cum_pel_ct["threshold"].min(), cum_pel_ct["threshold"].max()
thresh_range = thresh_max - thresh_min
if thresh_range > 0:
    cum_pel_ct["norm_thresh_val"] = (
        (cum_pel_ct["threshold"] - thresh_min) / thresh_range
    ).round(3)
else:
    cum_pel_ct["norm_thresh_val"] = 1.0

fig = go.Figure()
for id_i, (id_val, id_grp) in enumerate(cum_pel_ct.groupby("id")):
    # Add lines by subject, in the subject's standard color
    fig.add_trace(
        go.Scatter(
            x=id_grp["time"],
            y=id_grp["counter"],
            mode="lines",
            line=dict(
                width=2, color=subject_colors_dict.get(id_val, subject_colors[id_i])
            ),
            name=id_val,
        )
    )
for patch_i, (patch_val, patch_grp) in enumerate(cum_pel_ct.groupby("patch_label")):
    # Add markers by patch; marker lightness encodes the pellet's normalized threshold
    fig.add_trace(
        go.Scatter(
            x=patch_grp["time"],
            y=patch_grp["counter"],
            mode="markers",
            marker={
                "symbol": patch_markers[patch_i],
                "color": gen_hex_grad(subject_colors[-1], patch_grp["norm_thresh_val"]),
                "size": 8,
            },
            name=patch_val,
            customdata=np.stack((patch_grp["threshold"],), axis=-1),
            hovertemplate="Threshold: %{customdata[0]:.2f} cm",
        )
    )

fig.update_layout(
    title="Cumulative Pellet Count", xaxis_title="Time", yaxis_title="Count"
)
fig.show()

### 4b. Pellet delivery over time, per patch, per subject

In [None]:
fig = go.Figure()
for id_i, (id_val, id_grp) in enumerate(cum_pel_ct.groupby("id")):
    # One line+marker trace per subject showing which patch each pellet came from.
    # Marker symbol distinguishes subjects; marker lightness encodes normalized threshold.
    subj_color = subject_colors_dict.get(id_val, subject_colors[id_i])
    fig.add_trace(
        go.Scatter(
            x=id_grp["time"],
            y=id_grp["patch_label"],
            mode="lines+markers",
            line=dict(width=2, color=subj_color),
            marker={
                "symbol": patch_markers[id_i],
                "color": gen_hex_grad(subj_color, id_grp["norm_thresh_val"]),
                "size": 8,
            },
            name=id_val,
            customdata=np.stack((id_grp["threshold"],), axis=-1),
            hovertemplate="Threshold: %{customdata[0]:.2f} cm",
        )
    )


fig.update_layout(
    title="Pellet Delivery Over Time",
    xaxis_title="Time",
    yaxis_title="Patch",
    yaxis={
        "categoryorder": "array",
        "categoryarray": cum_pel_ct.sort_values("mean_thresh")[
            "patch_label"
        ].unique(),  # sort y-axis by patch threshold mean
    },
)
fig.show()

### 5. Pellet threshold vals over time, per patch, per subject

In [None]:
fig = go.Figure()
for id_i, (id_val, id_grp) in enumerate(cum_pel_ct.groupby("id")):
    # Add lines by subject, in the subject's standard color
    fig.add_trace(
        go.Scatter(
            x=id_grp["time"],
            y=id_grp["threshold"],
            mode="lines",
            line={"color": subject_colors_dict.get(id_val, subject_colors[id_i])},
            name=id_val,
        )
    )
for patch_i, (patch_val, patch_grp) in enumerate(cum_pel_ct.groupby("patch_label")):
    # Add markers by patch (symbol identifies the patch)
    fig.add_trace(
        go.Scatter(
            x=patch_grp["time"],
            y=patch_grp["threshold"],
            mode="markers",
            marker={
                "symbol": patch_markers[patch_i],
                "color": "black",
            },
            name=patch_val,
        )
    )

fig.update_layout(
    title="Pellet Thresholds", xaxis_title="Time", yaxis_title="Threshold (cm)"
)
fig.show()

### 6. Cumulative wheel distance over time, per patch, per subject

In [None]:
# Same positional block as `block` (good_blocks_df and good_blocks_df_cp share rows)
cum_wheel_dist = good_blocks_df["cum_wheel_dist"].iloc[b_i]
patches = block.patch_info.index

fig = go.Figure()
for subj_i, subj in enumerate(block.subjects):
    # Standard per-subject color; falls back to palette position for unknown ids
    subj_color = subject_colors_dict.get(subj, subject_colors[subj_i])
    for patch_i, p in enumerate(patches):
        cur_cum_wheel_dist = cum_wheel_dist[p][subj]
        fig.add_trace(
            go.Scatter(
                x=cur_cum_wheel_dist.index,
                y=cur_cum_wheel_dist,
                mode="lines",
                line={"width": 2, "color": subj_color},
                name=f"{p} {subj}",
                legendgroup=subj,
                showlegend=False,
            )
        )
        # Add markers for each pellet: attach the first wheel-distance sample at or up to
        # 0.1 s after each pellet time
        cur_cum_pel_ct = pd.merge_asof(
            cum_pel_ct[(cum_pel_ct["id"] == subj) & (cum_pel_ct["patch"] == p)],
            cur_cum_wheel_dist.reset_index(name="cum_wheel_dist"),
            on="time",
            direction="forward",
            tolerance=pd.Timedelta("0.1s"),
        )
        if not cur_cum_pel_ct.empty:
            fig.add_trace(
                go.Scatter(
                    x=cur_cum_pel_ct["time"],
                    y=cur_cum_pel_ct["cum_wheel_dist"],
                    mode="markers",
                    marker={
                        "symbol": patch_markers[patch_i],
                        "color": subj_color,
                        "size": 7,
                    },
                    name=cur_cum_pel_ct["patch_label"].iloc[0],
                    legendgroup=subj,
                    legendgrouptitle_text=subj,
                    customdata=np.stack((cur_cum_pel_ct["threshold"],), axis=-1),
                    hovertemplate="Threshold: %{customdata[0]:.2f} cm",
                )
            )
fig.update_layout(
    title="Cumulative Wheel Distance", xaxis_title="Time", yaxis_title="Distance (cm)"
)
fig.show()

## Overall Plots (per experiment and room)

### 1. Threshold vals per block per subject

In [None]:
# Consider only blocks with more than `pel_thresh` (15) pellets; index reset to 0..n-1

pel_thresh = 15
foraging_blocks = good_blocks_df_cp.loc[
    good_blocks_df_cp["pellet_info"].map(len) > pel_thresh
].reset_index(drop=True)

In [None]:
# Create column "block_type" with values ["presocial", "social", "postsocial"]

# Social blocks are those with more than one subject
mask_social = foraging_blocks["subjects"].apply(len) > 1
# String form of each block's subject list, used as the key for each social grouping
subj_str = foraging_blocks["subjects"].apply(str)
# Per social group: start of its first social block (solo blocks before this are "presocial")
presocial_condition = foraging_blocks[mask_social].groupby(subj_str)[
    "start"].min()
# ... and end of its last social block (solo blocks after this are "postsocial")
postsocial_condition = foraging_blocks[mask_social].groupby(subj_str)[
    "end"].max()


def assign_block_type(row):
    """Label a foraging block "social", "presocial", or "postsocial".

    Multi-subject blocks are "social". For a single-subject block, take the first social
    group whose key (string form of its subject list) contains the subject. The block is
    "presocial" if it starts before that group's first social start, and "postsocial" if it
    starts after that group's last social end. Otherwise (no matching group, or a start
    inside the social period) returns None.
    """
    subjects = row["subjects"]
    if len(subjects) > 1:
        return "social"
    solo_subj = subjects[0]
    matching_keys = [key for key in presocial_condition.index if solo_subj in key]
    if not matching_keys:
        return None
    group_key = matching_keys[0]
    if row["start"] < presocial_condition[group_key]:
        return "presocial"
    if row["start"] > postsocial_condition[group_key]:
        return "postsocial"
    return None


foraging_blocks["block_type"] = foraging_blocks.apply(assign_block_type, axis=1)

In [None]:
# Threshold vals obtained from halfway thru block (lower better, obvs)

fig = go.Figure()
legend_added = set()  # legend entries already shown, so each subject/block-type appears once
learning_rows, learning_index = [], []  # one entry per (block, subject); framed after the loop
block_type_markers = {
    "presocial": patch_markers[0],
    "social": patch_markers[1],
    "postsocial": patch_markers[2],
}
# Loop names `fblock` / `fblock_pellets` deliberately differ from the single-block globals
# `block` / `pellet_info` used by the "Block Plots" cells. Otherwise running this cell would
# silently change what those cells plot when they are re-run.
for fblock in foraging_blocks.itertuples():
    fblock_pellets = fblock.pellet_info.sort_values("time")
    # Halfway point (in terms of pellets obtained) onwards
    pellets_post = fblock_pellets.iloc[(len(fblock_pellets) // 2):]
    # Get threshold vals, per subject.
    for subj_val, subj_grp in pellets_post.groupby("id"):
        threshold_vals = subj_grp["threshold"].values
        learning_rows.append({"threshold_vals": threshold_vals, "id": subj_val})
        learning_index.append(fblock.Index)
        grouping_factor = f"{patch_markers_dict.get(block_type_markers[fblock.block_type])} {subj_val} {fblock.block_type}"
        show_in_legend = grouping_factor not in legend_added
        legend_added.add(grouping_factor)
        fig.add_trace(
            go.Box(
                y=threshold_vals,
                x=[fblock.Index] * len(threshold_vals),
                name=grouping_factor,
                legendgroup=grouping_factor,
                boxpoints="all",
                marker={
                    "color": subject_colors_dict[subj_val],
                    "symbol": block_type_markers[fblock.block_type],
                },
                showlegend=show_in_legend,  # legend only for unique subjs
            )
        )

# Built once instead of concat-in-a-loop (quadratic). One row per (block, subject): the
# post-halfway threshold array and subject id, indexed by foraging-block index.
learning_threshold_vals = pd.DataFrame(
    learning_rows, index=learning_index, columns=["threshold_vals", "id"]
)

fig.update_layout(
    title="Pellet Thresholds",
    xaxis_title="Block Index",
    yaxis_title="Threshold (cm)",
)
fig.update_layout(legend_tracegroupgap=0)
fig.show()

### 2. Patch preference per block per subject

In [None]:
# Patch preference by wheel dist spun from halfway through block
# one_best_patch = easy / all; two_best_patches = (easy + medium) / all


def wheel_dist_after(cum_dist, start_ts):
    """Wheel distance spun strictly after `start_ts`, from a cumulative-distance series.

    Returns the last minus the first cumulative value after `start_ts`. A negative result is
    replaced by 1, as in the original analysis. Raises IndexError if no samples fall after
    `start_ts`.
    """
    after = cum_dist[cum_dist.index > start_ts]
    dist = after.iloc[-1] - after.iloc[0]
    return 1 if dist < 0 else dist


# `foraging_blocks` was re-indexed (0..n-1) after dropping low-pellet blocks, so its index is
# NOT a label in `good_blocks_df`. Looking up `good_blocks_df[...].loc[block.Index]` returns
# another block's wheel data whenever an earlier block was filtered out. Re-apply the same
# pellet filter to `good_blocks_df` (same rows/order as `good_blocks_df_cp`) so that
# position i here is foraging block i.
foraging_cum_wheel_dist = good_blocks_df.loc[
    good_blocks_df["pellet_info"].apply(len) > pel_thresh, "cum_wheel_dist"
].reset_index(drop=True)
assert len(foraging_cum_wheel_dist) == len(foraging_blocks)

pp_rows, pp_index = [], []
# `fblock` rather than `block`, so the single-block global used by earlier cells is untouched
for fblock in foraging_blocks.itertuples():
    easy_p, med_p, hard_p = fblock.patch_info.sort_values("mean").index
    fblock_pellets = fblock.pellet_info.sort_values("time")
    halfway_ts = fblock_pellets["time"].iloc[len(fblock_pellets) // 2]
    fblock_wheel_dist = foraging_cum_wheel_dist.iloc[fblock.Index]
    for subj in fblock.subjects:
        easy_dist, med_dist, hard_dist = (
            wheel_dist_after(fblock_wheel_dist[p][subj], halfway_ts)
            for p in (easy_p, med_p, hard_p)
        )
        total_dist = easy_dist + med_dist + hard_dist
        pp_rows.append(
            {
                "one_best_patch": np.round(easy_dist / total_dist, 3),
                "two_best_patches": np.round((easy_dist + med_dist) / total_dist, 3),
                "id": subj,
                "block_type": fblock.block_type,
            }
        )
        pp_index.append(fblock.Index)

# Built once (not concat-in-a-loop); indexed by foraging-block index
pp_overall = pd.DataFrame(
    pp_rows,
    index=pp_index,
    columns=["one_best_patch", "two_best_patches", "id", "block_type"],
)

In [None]:
fig = go.Figure()
# Loop name `subj_id` (not `id`): a notebook-level `id` would shadow the builtin for the
# rest of the kernel session
for block_type, block_type_symbol in block_type_markers.items():
    for subj_id, group in pp_overall[pp_overall["block_type"] == block_type].groupby("id"):
        fig.add_trace(
            go.Scatter(
                x=group.index,
                y=group["one_best_patch"],
                mode="markers",
                name=block_type,
                legendgroup=subj_id,
                legendgrouptitle_text=subj_id,
                marker={"color": subject_colors_dict[subj_id]},
                marker_symbol=block_type_symbol,
            )
        )
fig.update_layout(
    title="Patch Preference: Easy Only",
    xaxis_title="Block Index",
    yaxis_title="Preference",
    legend_tracegroupgap=0,
)
fig.update_traces(
    legendgrouptitle_font_size=12,
)
fig.show()

In [None]:
fig = px.scatter(
    pp_overall,
    x=pp_overall.index,
    y=["one_best_patch"],
    color="id",
    symbol="block_type",
    title="Patch Preference: Easy Only",
    labels={"index": "Block Index", "value": "Preference"},
).update_layout(legend_title_text="")
# px names traces "<id>, <block_type>"; drop the comma for a cleaner legend
fig.for_each_trace(lambda tr: tr.update(name=tr.name.replace(",", "")))
fig.show()

In [None]:
fig = px.scatter(
    pp_overall,
    x=pp_overall.index,
    y=["two_best_patches"],
    color="id",
    symbol="block_type",
    title="Patch Preference: Easy + Medium",
    labels={"index": "Block Index", "value": "Preference"},
).update_layout(legend_title_text="")
# px names traces "<id>, <block_type>"; drop the comma for a cleaner legend
fig.for_each_trace(lambda tr: tr.update(name=tr.name.replace(",", "")))
fig.show()

### 2b. Patch Preference by Block Type (presocial, social, postsocial)

In [None]:
# Per-subject distribution of easy-patch preference, split by block type
fig = px.box(
    pp_overall,
    x="block_type",
    y="one_best_patch",
    color="id",
    points="all",
    title="Patch Preference: Easy Only by Block Type",
    labels={"block_type": "Block Type", "one_best_patch": "Preference"},
).update_layout(legend_title_text="ID")
fig.show()

In [None]:
# Per-subject distribution of easy+medium-patch preference, split by block type
fig = px.box(
    pp_overall,
    x="block_type",
    y="two_best_patches",
    color="id",
    points="all",
    title="Patch Preference: Easy + Medium by Block Type",
    labels={"block_type": "Block Type", "two_best_patches": "Preference"},
).update_layout(legend_title_text="ID")
fig.show()

### 2c. Patch Preference Difference between First- and Second-Half of Block

### 3a. Cumulative Wheel Distance by Patch Difficulty

### 3b. Cumulative Wheel Distance by Patch ID

### 4. Social Distancing =p

In [None]:
# Get position for each individual animal for first day of solo
# Sample 100k points from each of these
# Create "synthetic" distribution of euclidean distances away from each other (solo
# positions paired at random), and compare it with the "true" distribution: distances
# between the two animals at the SAME timestamp during the social period.

n_dist_samples = 100_000
dist_seed = 1  # fixed so every sampled distribution is reproducible


def load_pose_day(root, day):
    """Load raw pose data from `root` between midnight of `day` ("YYYY-MM-DD") and the next midnight."""
    start = pd.Timestamp(day)
    return aeon.load(root, social02.CameraTop.Pose, start, start + pd.Timedelta(days=1))


def unpaired_distances(pos_a, pos_b, n=n_dist_samples, random_state=dist_seed):
    """Distances between `n` random rows of `pos_a` and `n` random rows of `pos_b`, paired by draw order."""
    a = pos_a.sample(n=n, random_state=random_state)
    b = pos_b.sample(n=n, random_state=random_state)
    return np.sqrt(
        ((a["x"].values - b["x"].values) ** 2) + ((a["y"].values - b["y"].values) ** 2)
    )


def paired_distances(pos, class_a, class_b, n=n_dist_samples, random_state=dist_seed):
    """Distances between identities `class_a` and `class_b` at the same timestamp and body part.

    Keeps the highest-`class_likelihood` row per (time, part, class), inner-joins the two
    identities on (time, part), then samples up to `n` matched pairs. Sampling each
    identity's rows independently would not measure how far apart the animals are at any
    moment.
    """
    best = (
        pos.reset_index()
        .sort_values("class_likelihood", ascending=False)
        .drop_duplicates(["time", "part", "class"])
    )
    cols = ["time", "part", "x", "y"]
    matched = best.loc[best["class"] == class_a, cols].merge(
        best.loc[best["class"] == class_b, cols], on=["time", "part"], suffixes=("_a", "_b")
    )
    matched = matched.sample(n=min(n, len(matched)), random_state=random_state)
    return np.sqrt(
        ((matched["x_a"].values - matched["x_b"].values) ** 2)
        + ((matched["y_a"].values - matched["y_b"].values) ** 2)
    )


baa45_solo_pos = load_pose_day(roots[0], "2024-02-01")
baa47_solo_pos = load_pose_day(roots[0], "2024-02-06")

baa48_solo_pos = load_pose_day(roots[1], "2024-02-01")
baa49_solo_pos = load_pose_day(roots[1], "2024-02-06")

baa45_47_joint_pos = load_pose_day(roots[0], "2024-02-15")
baa48_49_joint_pos = load_pose_day(roots[1], "2024-02-15")

# "True": identity classes 0 and 1 matched on timestamp (distance is symmetric, so which
# class is which subject does not matter here)
distances2_45_47 = paired_distances(baa45_47_joint_pos, 0.0, 1.0)
distances2_48_49 = paired_distances(baa48_49_joint_pos, 0.0, 1.0)

# "Synthetic": random solo-session positions of each animal, paired by draw order
distances0_48_49 = unpaired_distances(baa48_solo_pos, baa49_solo_pos)
distances0_45_47 = unpaired_distances(baa45_solo_pos, baa47_solo_pos)

In [None]:
# Each histplot call gets a `label`, so `ax.legend()` pairs each name with its histogram.
# `ax.legend([...])` with bare strings labels artists in creation order, which here are
# individual bars of the first histogram.
sns.set_theme(style="whitegrid")
fig, ax = plt.subplots(figsize=(10, 6))
sns.histplot(distances0_45_47, kde=True, bins=1000, ax=ax, label="Synthetic")
sns.histplot(distances2_45_47, kde=True, bins=1000, ax=ax, label="True")
ax.set_title("Distances between subjects 45-47")
ax.set_xlabel("Distance (px)")
ax.legend()
plt.show()

In [None]:
# Each histplot call gets a `label`, so `ax.legend()` pairs each name with its histogram.
# `ax.legend([...])` with bare strings labels artists in creation order, which here are
# individual bars of the first histogram.
sns.set_theme(style="whitegrid")
fig, ax = plt.subplots(figsize=(10, 6))
sns.histplot(distances0_48_49, kde=True, bins=1000, ax=ax, label="Synthetic")
sns.histplot(distances2_48_49, kde=True, bins=1000, ax=ax, label="True")
ax.set_title("Distances between subjects 48-49")
ax.set_xlabel("Distance (px)")
ax.legend()
plt.show()