(target-dj-compute-behav-bouts)=
# DataJoint pipeline: Computing behavioural bouts

:::{important}
This guide assumes you have a [DataJoint pipeline deployed](target-dj-pipeline-deployment) with [data already ingested](target-dj-data-ingestion-processing).
:::

Using position data from the [Aeon DataJoint pipeline](target-aeon-dj-pipeline), this guide walks through computing sleeping, drinking, and exploring bouts for each subject.

You can also run this notebook online at [`works.datajoint.com`](https://works.datajoint.com/) using the following credentials:
 - Username: aeondemo
 - Password: aeon_djworks 

To access it, go to the Notebook tab at the top and in the File Browser on the left, navigate to `ucl-swc_aeon > docs > examples`, where this notebook `dj_compute_bouts.ipynb` is located.

:::{note}
The examples here use the [social0.2-aeon4](target-full-datasets) dataset. 
If you are using a different dataset, be sure to replace the experiment name and parameters in the code below accordingly.
:::

## Import libraries and define variables and helper functions

In [1]:
import datetime
import sys
import os
import warnings
from pathlib import Path
from typing import Any, Tuple, List, Dict

import numpy as np
import pandas as pd
import plotly
import plotly.express as px
import plotly.graph_objs as go
import statsmodels.api as sm
from plotly.subplots import make_subplots
from tqdm.notebook import tqdm

import datajoint as dj
from aeon.dj_pipeline.analysis.block_analysis import *
from aeon.dj_pipeline import acquisition, streams, subject
from swc.aeon.io import api as aeon_api
from aeon.schema.schemas import social02

[2025-08-28 15:29:59,542][INFO]: DataJoint 0.14.6 connected to chuan@aeon-db2:3306


In [None]:
def correct_swaps(pos_df: pd.DataFrame, max_speed: float) -> pd.DataFrame:
    """Remove samples whose instantaneous speed implies an identity swap.

    The speed of every sample is computed from the displacement and the elapsed
    time since the previous sample and stored in a new ``inst_speed`` column.
    If at least one sample exceeds ``max_speed``, those samples are removed
    together with every sample whose speed is undefined (NaN, e.g. the first
    sample). If no sample exceeds the threshold, no rows are removed.

    Args:
        pos_df (pd.DataFrame): DataFrame containing position data of a single subject,
            indexed by timestamp, with ``x`` and ``y`` columns. It is not modified.
        max_speed (float): Maximum speed (px/s) threshold over which we assume a swap.

    Returns:
        pd.DataFrame: New DataFrame with an ``inst_speed`` column added and the
            samples attributed to swaps removed.
    """
    # Elapsed seconds between consecutive samples (NaN for the first sample)
    dt = pos_df.index.to_series().diff().dt.total_seconds().to_numpy()
    dx = pos_df["x"].diff()
    dy = pos_df["y"].diff()
    # `assign` returns a new frame, so the caller's data is left untouched
    corrected = pos_df.assign(inst_speed=np.sqrt(dx**2 + dy**2) / dt)

    # Identify jumps
    jumps = corrected["inst_speed"] > max_speed
    if jumps.any():
        # Drop the jump samples and, as before, any sample with an undefined speed
        keep = ~jumps & corrected["inst_speed"].notna()
        corrected = corrected[keep]

    return corrected


def ensure_ts_arr_datetime(array):
    """Convert `array` to a numpy array with dtype datetime64[ns].

    An input of length zero always yields an empty one-dimensional array.
    """
    ns_dtype = "datetime64[ns]"
    source = array if len(array) else []
    return np.array(source, dtype=ns_dtype)


def concat_and_reorder(dfs):
    """Stack `dfs` row-wise with a fresh integer index and move the
    `experiment_name` and `period` columns to the front."""
    leading = ["experiment_name", "period"]
    combined = pd.concat(dfs, ignore_index=True)
    trailing = [name for name in combined.columns if name not in leading]
    return combined[leading + trailing]


In [None]:
cm2px = 5.2  # 1 cm = 5.2 px roughly for top camera

# Experiment name plus the start/end timestamps of each of its periods
exp = dict(
    name="social0.2-aeon4",
    presocial_start="2024-01-31 11:00:00",
    presocial_end="2024-02-08 15:00:00",
    social_start="2024-02-09 17:00:00",
    social_end="2024-02-23 12:00:00",
    postsocial_start="2024-02-25 18:00:00",
    postsocial_end="2024-03-02 13:00:00",
)
key = dict(experiment_name=exp["name"])

# Map each period name to its (start, end) pair
periods = {
    phase: (exp[f"{phase}_start"], exp[f"{phase}_end"])
    for phase in ("presocial", "social", "postsocial")
}

## Fetch position data

In [None]:
def load_position_data(
    key: dict[str, str], period_start: str, period_end: str
) -> pd.DataFrame:
    """Loads position data (centroid tracking) for a specified time period.

    The query joins `streams.SpinnakerVideoSource` with
    `tracking.DenoisedTracking.Subject`, restricted to the "CameraTop" video
    source and to the chunks covering the requested period. The array-valued
    columns of each fetched row are exploded into one row per sample.

    This is a best-effort loader: any exception raised while querying or
    reshaping is printed and an empty DataFrame is returned instead.

    Args:
        key (dict): Key to identify experiment data (e.g., {"experiment_name": "Exp1"}).
        period_start (str): Start datetime of the time period.
        period_end (str): End datetime of the time period.

    Returns:
        pd.DataFrame: DataFrame indexed by "time" with the columns
            "experiment_name", "identity_name", "identity_likelihood", "x", "y"
            and "likelihood". Returns an empty DataFrame (no columns) on error.
    """
    output_columns = [
        "time",
        "experiment_name",
        "identity_name",
        "identity_likelihood",
        "x",
        "y",
        "likelihood",
    ]
    try:
        print(f"  Querying data from {period_start} to {period_end}...")

        # Create chunk restriction for the time period
        chunk_restriction = acquisition.create_chunk_restriction(
            key["experiment_name"], period_start, period_end
        )

        # Fetch centroid tracking data for the specified period
        # NOTE(review): `tracking` is not imported explicitly in this notebook;
        # presumably it arrives via the `block_analysis` star import - confirm.
        centroid_df = (
            streams.SpinnakerVideoSource * tracking.DenoisedTracking.Subject
            & key
            & {"spinnaker_video_source_name": "CameraTop"}
            & chunk_restriction
        ).fetch(format="frame")

        # One fetched row holds arrays of samples; explode them to one row per sample
        centroid_df = (
            centroid_df.reset_index()
            .rename(
                columns={
                    "subject_name": "identity_name",
                    "timestamps": "time",
                    "subject_likelihood": "identity_likelihood",
                }
            )
            .explode(["time", "identity_likelihood", "x", "y", "likelihood"])
        )
        # Selecting the output columns also discards every other fetched column
        # (e.g. the video source name), so no further cleanup is needed.
        centroid_df = centroid_df[output_columns].set_index("time")

        if centroid_df.empty:
            print("  No data found for the specified period")
        else:
            print(f"  Retrieved {len(centroid_df)} rows of position data")

        return centroid_df

    except Exception as e:  # deliberate: one failing period must not abort the run
        print(
            f"  Error loading position data for {key['experiment_name']} ({period_start} "
            f"to {period_end}): {e}"
        )
        return pd.DataFrame()


position_data_list = []
for period_name, (period_start, period_end) in periods.items():
    # Fetch this period's tracking data and turn the time index back into a column
    df = load_position_data(key, period_start, period_end).reset_index()
    # Tag every row with the experiment and the period it belongs to
    df = df.assign(experiment_name=exp["name"], period=period_name)
    position_data_list.append(df)

# Single table covering all periods, with the identifying columns first
position_df = concat_and_reorder(position_data_list)

## Sleep bouts

In [None]:
def sleep_bouts(
    pos_df: pd.DataFrame,
    subject: str,
    move_thresh: float = 4 * cm2px,  # cm -> px
    max_speed: float = 100 * cm2px,  # cm/s -> px/s
) -> pd.DataFrame:
    """Returns sleep bouts for a given animal within the specified position data time period.

    The subject's data is cut into consecutive 1-minute windows starting at its
    first timestamp. A window counts as sleep when it holds at least 100 samples
    and the diagonal of the bounding box of its positions is below `move_thresh`.
    Back-to-back sleep windows are merged into bouts, and bouts shorter than
    2 minutes are discarded.

    Args:
        pos_df (pd.DataFrame): DataFrame containing position data, indexed by
            timestamp, with "identity_name", "x", "y" and "period" columns.
        subject (str): Name of the animal to filter by.
        move_thresh (float): Movement (in px) threshold to define sleep bouts.
        max_speed (float): Maximum speed threshold for excising swaps. Currently
            unused because swap excision is disabled for sleep bouts.

    Returns:
        pd.DataFrame: DataFrame containing sleep bouts for the specified animal, with
            columns "subject", "start", "end", "duration" and "period". It is
            empty (but keeps these columns) when no bout is found.
    """
    bout_columns = ["subject", "start", "end", "duration", "period"]

    animal_data = pos_df[pos_df["identity_name"] == subject].copy()
    if animal_data.empty:
        print(f"No position data found for {subject}")
        return pd.DataFrame(columns=bout_columns)

    sleep_win = pd.Timedelta("1m")
    min_samples = 100  # windows with fewer samples are skipped
    min_bout_time = pd.Timedelta("2m")

    # Create time windows based on start and end time
    window_starts = pd.date_range(
        start=animal_data.index.min(), end=animal_data.index.max(), freq=sleep_win
    )

    # <s> Process each time window
    period = animal_data["period"].iloc[0]
    sleep_windows = []  # rows are collected here and turned into a frame once
    pbar = tqdm(window_starts, desc=f"Processing sleep bouts for {subject} in {period}")
    for win_start in pbar:
        win_end = win_start + sleep_win
        in_window = (animal_data.index >= win_start) & (animal_data.index < win_end)
        win_data = animal_data[in_window]
        if len(win_data) < min_samples:  # skip windows with too little data
            continue

        # Diagonal of the bounding box of all positions in the window; this is an
        # upper bound on the distance between any two points in the window.
        dx = win_data["x"].max() - win_data["x"].min()
        dy = win_data["y"].max() - win_data["y"].min()
        displacement = np.sqrt(dx**2 + dy**2)

        # If displacement is less than threshold, consider it a sleep window
        if displacement < move_thresh:
            sleep_windows.append(
                {
                    "subject": subject,
                    "start": win_start,
                    "end": win_end,
                    "duration": sleep_win,
                    "period": win_data["period"].iloc[0],
                }
            )
    # </s>

    # <s> Now merge consecutive sleep windows into continuous bouts
    if not sleep_windows:
        return pd.DataFrame(columns=bout_columns)

    merged_bouts = [dict(sleep_windows[0])]
    for window in sleep_windows[1:]:
        last_bout = merged_bouts[-1]
        if window["start"] == last_bout["end"]:  # continue bout
            last_bout["end"] = window["end"]
            last_bout["duration"] = last_bout["end"] - last_bout["start"]
        else:  # start a new bout
            merged_bouts.append(dict(window))
    sleep_bouts_df = pd.DataFrame(merged_bouts, columns=bout_columns)
    # </s>

    # Keep only bouts that last at least the minimum bout time
    return sleep_bouts_df[sleep_bouts_df["duration"] >= min_bout_time]

In [None]:
"""Save sleep bouts to parquet files for all experiments and periods."""

# For each experiment, for each period, load pos data, get sleep bouts, save to parquet
# NOTE(review): `experiments`, `data_dir`, `load_data_from_parquet` and
# `save_all_experiment_data` are not defined in the cells shown in this notebook;
# confirm they are provided before running.

pbar_exp = tqdm(experiments, desc="Processing experiments")
for exp in pbar_exp:
    key = {"experiment_name": exp["name"]}
    sleep_bouts_data_dict = {exp["name"]: {}}
    periods = {
        "presocial": (exp["presocial_start"], exp["presocial_end"]),
        "social": (exp["social_start"], exp["social_end"]),
        "postsocial": (exp["postsocial_start"], exp["postsocial_end"]),
    }
    # Only the period names are needed here: the data is loaded per period from
    # parquet, so the start/end timestamps are not parsed.
    pbar_period = tqdm(periods, desc="Processing periods", leave=False)
    for period_name in pbar_period:
        print(f"  Loading {period_name} period...")

        # load pos data for this period
        pos_df = load_data_from_parquet(
            experiment_name=exp["name"],
            period=period_name,
            data_type="position",
            data_dir=data_dir,
            set_time_index=True,
        )

        # get sleep bouts for each subject
        # (the loop variable is not called `subject`, which would overwrite the
        # imported `subject` module)
        subjects = pos_df["identity_name"].unique()
        per_subject_bouts = []
        for subject_name in subjects:
            subject_bouts = sleep_bouts(pos_df, subject_name)
            if isinstance(subject_bouts, pd.DataFrame) and not subject_bouts.empty:
                per_subject_bouts.append(subject_bouts)
        # Concatenate once; fall back to an empty frame with the bout columns
        if per_subject_bouts:
            sleep_bouts_df = pd.concat(per_subject_bouts, ignore_index=True)
        else:
            sleep_bouts_df = pd.DataFrame(
                columns=["subject", "start", "end", "duration", "period"]
            )

        # save data dict
        sleep_bouts_data_dict[exp["name"]][period_name] = sleep_bouts_df
        save_all_experiment_data(
            experiments=[exp],
            periods=[period_name],
            data_dict=sleep_bouts_data_dict,
            data_type="sleep",
            data_dir=data_dir,
        )
        print(f"  Saved sleep bouts for {exp['name']} during {period_name} period.")

## Drink bouts

In [None]:
def drink_bouts(
    pos_df: pd.DataFrame,
    subject: str,
    spout_loc: tuple[float, float],  # x,y spout location in px
    start_radius: float = 4 * 5.2,  # must be within X cm of spout, in px
    move_thresh: float = 2.5 * 5.2,  # during bout must move less than X cm, in px
    min_dur: float = 6,  # min duration of bout in seconds
    max_dur: float = 90,  # max duration of bout in seconds
) -> pd.DataFrame:  # cols: subject, start, end, duration, period
    """Returns drink bouts for a given animal within the specified position data time period.

    The subject's numeric columns are resampled to 100 ms means and interpolated.
    A candidate bout starts at any sample within `start_radius` of the spout and
    extends while the animal stays within `move_thresh` of the bout's first
    position and no more than `max_dur` seconds have elapsed. A candidate is
    kept only if its duration lies strictly between `min_dur` and `max_dur`, so
    candidates that run up to the `max_dur` cap are discarded.

    Args:
        pos_df (pd.DataFrame): DataFrame containing position data, indexed by
            timestamp, with "identity_name", "x", "y" and "period" columns.
        subject (str): Name of the animal to filter by.
        spout_loc (tuple[float, float]): x, y location of the spout in px.
        start_radius (float): Max distance (px) from the spout at bout start.
        move_thresh (float): Max displacement (px) from the bout's first position.
        min_dur (float): Exclusive minimum bout duration in seconds.
        max_dur (float): Exclusive maximum bout duration in seconds.

    Returns:
        pd.DataFrame: One row per drink bout with columns "subject", "start",
            "end", "duration" and "period"; empty (with these columns) if none.
    """
    bout_columns = ["subject", "start", "end", "duration", "period"]

    animal_data = pos_df[pos_df["identity_name"] == subject].copy()
    if animal_data.empty:
        print(f"No position data found for {subject}")
        return pd.DataFrame(columns=bout_columns)

    # Smooth position data to 100ms intervals - only numeric columns
    numeric_cols = animal_data.select_dtypes(include=[np.number]).columns
    animal_data = animal_data[numeric_cols].resample("100ms").mean().interpolate()
    animal_data = animal_data.dropna()

    # Read the period from the input frame, so that an empty resampled frame
    # cannot cause an out-of-bounds lookup
    period = pos_df["period"].iloc[0]

    # Plain arrays make the per-sample lookups in the scan below cheap
    times = animal_data.index
    xs = animal_data["x"].to_numpy()
    ys = animal_data["y"].to_numpy()
    n_samples = len(animal_data)

    # Find potential bout starts (within start_radius of spout)
    spout_x, spout_y = spout_loc
    dist_to_spout = np.sqrt((xs - spout_x) ** 2 + (ys - spout_y) ** 2)
    near_spout = dist_to_spout <= start_radius

    bouts = []
    with tqdm(
        total=n_samples, desc=f"Processing drink bouts for {subject} in {period}"
    ) as pbar:
        i = 0
        while i < n_samples:
            # Skip if not near spout
            if not near_spout[i]:
                i += 1
                pbar.update(1)
                continue

            # Found potential bout start
            bout_start_time = times[i]
            start_x = xs[i]
            start_y = ys[i]

            # Extend the bout while the animal stays close to where the bout
            # started and the bout has not exceeded the maximum duration
            j = i
            while j < n_samples:
                displacement = np.sqrt((xs[j] - start_x) ** 2 + (ys[j] - start_y) ** 2)
                if displacement > move_thresh:
                    break

                elapsed_time = (times[j] - bout_start_time).total_seconds()
                if elapsed_time > max_dur:
                    break

                j += 1

            # The bout ends at the last sample that passed both checks
            bout_end_time = times[j - 1] if j > i else bout_start_time
            bout_duration = (bout_end_time - bout_start_time).total_seconds()

            # Check if bout meets duration criteria
            if min_dur < bout_duration < max_dur:
                bouts.append(
                    {
                        "subject": subject,
                        "start": bout_start_time,
                        "end": bout_end_time,
                        "duration": pd.Timedelta(seconds=bout_duration),
                        "period": period,
                    }
                )

            # Move to next potential bout (skip past current bout end) and advance
            # the progress bar by the number of samples consumed
            next_i = max(j, i + 1)
            pbar.update(next_i - i)
            i = next_i

    return pd.DataFrame(bouts, columns=bout_columns)

In [None]:
"""Save drink bouts to parquet files for all experiments and periods."""

# For each experiment, for each period, load pos data, get drink bouts, save to parquet
# NOTE(review): `experiments`, `data_dir`, `load_data_from_parquet` and
# `save_all_experiment_data` are not defined in the cells shown in this notebook;
# confirm they are provided before running.

pbar_exp = tqdm(experiments, desc="Processing experiments")
for exp in pbar_exp:
    key = {"experiment_name": exp["name"]}
    drink_bouts_data_dict = {exp["name"]: {}}
    # The spout location (px) depends only on the experiment, so it is set once here
    spout_loc = (1280, 500) if "aeon3" in exp["name"] else (1245, 535)
    # Iterating `periods` yields the period names only
    pbar_period = tqdm(periods, desc="Processing periods", leave=False)
    for period_name in pbar_period:
        print(f"  Loading {period_name} period...")

        # load pos data for this period
        pos_df = load_data_from_parquet(
            experiment_name=exp["name"],
            period=period_name,
            data_type="position",
            data_dir=data_dir,
            set_time_index=True,
        )

        # get drink bouts for each subject
        # (the loop variable is not called `subject`, which would overwrite the
        # imported `subject` module)
        subjects = pos_df["identity_name"].unique()
        per_subject_bouts = []
        for subject_name in subjects:
            subject_bouts = drink_bouts(pos_df, subject_name, spout_loc)
            if isinstance(subject_bouts, pd.DataFrame) and not subject_bouts.empty:
                per_subject_bouts.append(subject_bouts)
        # Concatenate once; fall back to an empty frame with the bout columns
        if per_subject_bouts:
            drink_bouts_df = pd.concat(per_subject_bouts, ignore_index=True)
        else:
            drink_bouts_df = pd.DataFrame(
                columns=["subject", "start", "end", "duration", "period"]
            )

        # save data dict
        drink_bouts_data_dict[exp["name"]][period_name] = drink_bouts_df
        save_all_experiment_data(
            experiments=[exp],
            periods=[period_name],
            data_dict=drink_bouts_data_dict,
            data_type="drink",
            data_dir=data_dir,
        )
        print(f"  Saved drink bouts for {exp['name']} during {period_name} period.")

In [None]:
# Display the drink bouts table left over from the last experiment and period
# processed by the saving loop.
drink_bouts_df

## Explore bouts

In [None]:
# Given pos_df, animal name, and nest xy, return all explore bouts in a df

# Nest centre as (x, y); presumably in top-camera pixels - TODO confirm
nest_center = np.array((1215, 530))
cm2px = 5.2  # px per cm
nest_radius = 14 * cm2px  # 14 cm, in px


def explore_bouts(
    pos_df: pd.DataFrame,
    subject: str,
    nest_center: np.ndarray,
    nest_radius: float = 14 * 5.2,  # 14 cm, in px
    max_speed: float = 100 * 5.2,  # 100 cm/s, in px/s
) -> pd.DataFrame:
    """Returns exploration bouts for a given animal within the specified position data time period.

    The subject's data is cut into consecutive 1-minute windows starting at its
    first timestamp. Windows with fewer than 100 samples are skipped. After swap
    excision, a window counts as exploration when more than half of its samples
    lie farther than `nest_radius` from `nest_center`. Back-to-back exploration
    windows are merged into bouts.

    Args:
        pos_df (pd.DataFrame): DataFrame containing position data, indexed by
            timestamp, with "identity_name", "x", "y" and "period" columns.
        subject (str): Name of the animal to filter by.
        nest_center (np.ndarray): Coordinates of the nest center.
        nest_radius (float): Radius of the nest area (default: 14 cm in px).
        max_speed (float): Maximum speed threshold for excising swaps (default: 100 cm/s in px/s).

    Returns:
        pd.DataFrame: DataFrame containing exploration bouts for the specified animal,
            with columns "subject", "start", "end", "duration" and "period". It
            is empty (but keeps these columns) when no bout is found.
    """
    bout_columns = ["subject", "start", "end", "duration", "period"]

    animal_data = pos_df[pos_df["identity_name"] == subject].copy()
    if animal_data.empty:
        print(f"No position data found for {subject}")
        return pd.DataFrame(columns=bout_columns)

    explore_win = pd.Timedelta("1m")
    min_samples = 100  # windows with fewer samples are skipped

    # Create time windows based on start and end time
    window_starts = pd.date_range(
        start=animal_data.index.min(), end=animal_data.index.max(), freq=explore_win
    )

    # <s> Process each time window (use tqdm for progress bar)
    period = animal_data["period"].iloc[0]
    explore_windows = []  # rows are collected here and turned into a frame once
    pbar = tqdm(window_starts, desc=f"Processing explore bouts for {subject} in {period}")
    for win_start in pbar:
        win_end = win_start + explore_win
        in_window = (animal_data.index >= win_start) & (animal_data.index < win_end)
        # Work on a copy so swap correction can never touch `animal_data`
        win_data = animal_data[in_window].copy()
        if len(win_data) < min_samples:  # skip windows with too little data
            continue

        # Excise id swaps (based on pos / speed jumps)
        win_data = correct_swaps(win_data, max_speed)
        if win_data.empty:  # nothing left to classify (avoids a 0/0 fraction)
            continue

        # If majority of time in a window is outside nest, consider it an explore bout
        dx = win_data["x"] - nest_center[0]
        dy = win_data["y"] - nest_center[1]
        distance_from_nest = np.sqrt(dx**2 + dy**2)
        frac_out_nest = (distance_from_nest > nest_radius).sum() / len(win_data)
        if frac_out_nest > 0.5:
            explore_windows.append(
                {
                    "subject": subject,
                    "start": win_start,
                    "end": win_end,
                    "duration": explore_win,
                    "period": win_data["period"].iloc[0],
                }
            )
    # </s>

    # <s> Now merge consecutive explore windows into continuous bouts
    if not explore_windows:
        return pd.DataFrame(columns=bout_columns)

    merged_bouts = [dict(explore_windows[0])]
    for window in explore_windows[1:]:
        last_bout = merged_bouts[-1]
        if window["start"] == last_bout["end"]:  # continue bout
            last_bout["end"] = window["end"]
            last_bout["duration"] = last_bout["end"] - last_bout["start"]
        else:  # start a new bout
            merged_bouts.append(dict(window))
    # </s>

    return pd.DataFrame(merged_bouts, columns=bout_columns)

In [None]:
"""Save explore bouts to parquet files for each experiment and period"""

# For each experiment, for each period, load pos data, get explore bouts, save to parquet
# NOTE(review): `experiments`, `data_dir`, `load_data_from_parquet` and
# `save_all_experiment_data` are not defined in the cells shown in this notebook;
# confirm they are provided before running.

pbar_exp = tqdm(experiments, desc="Processing experiments")
for exp in pbar_exp:
    key = {"experiment_name": exp["name"]}

    # get nest center for this exp: the mean of the "NestRegion" vertices stored
    # in the active-region configuration of the experiment's epochs
    epoch_query = acquisition.Epoch & (acquisition.Chunk & key).proj("epoch_start")
    active_region_query = acquisition.EpochConfig.ActiveRegion & epoch_query
    roi_locs = dict(
        zip(*active_region_query.fetch("region_name", "region_data"), strict=True)
    )
    points = roi_locs["NestRegion"]["ArrayOfPoint"]
    vertices = np.array([[float(point["X"]), float(point["Y"])] for point in points])
    nest_center = np.mean(vertices, axis=0)

    # Explore results use their own names, so the sleep results are not overwritten
    explore_bouts_data_dict = {exp["name"]: {}}
    periods = {
        "presocial": (exp["presocial_start"], exp["presocial_end"]),
        "social": (exp["social_start"], exp["social_end"]),
        "postsocial": (exp["postsocial_start"], exp["postsocial_end"]),
    }
    # Only the period names are needed here: the data is loaded per period from
    # parquet, so the start/end timestamps are not parsed.
    pbar_period = tqdm(periods, desc="Processing periods", leave=False)
    for period_name in pbar_period:
        print(f"  Loading {period_name} period...")

        # load pos data for this period
        pos_df = load_data_from_parquet(
            experiment_name=exp["name"],
            period=period_name,
            data_type="position",
            data_dir=data_dir,
            set_time_index=True,
        )

        # get explore bouts for each subject
        # (the loop variable is not called `subject`, which would overwrite the
        # imported `subject` module)
        subjects = pos_df["identity_name"].unique()
        per_subject_bouts = []
        for subject_name in subjects:
            subject_bouts = explore_bouts(pos_df, subject_name, nest_center)
            if isinstance(subject_bouts, pd.DataFrame) and not subject_bouts.empty:
                per_subject_bouts.append(subject_bouts)
        # Concatenate once; fall back to an empty frame with the bout columns
        if per_subject_bouts:
            explore_bouts_df = pd.concat(per_subject_bouts, ignore_index=True)
        else:
            explore_bouts_df = pd.DataFrame(
                columns=["subject", "start", "end", "duration", "period"]
            )

        # save data dict
        explore_bouts_data_dict[exp["name"]][period_name] = explore_bouts_df
        save_all_experiment_data(
            experiments=[exp],
            periods=[period_name],
            data_dict=explore_bouts_data_dict,
            data_type="explore",
            data_dir=data_dir,
        )
        print(f"  Saved explore bouts for {exp['name']} during {period_name} period.")

# Look up the active-region configuration for the social0.2-aeon3 experiment.
# NOTE: this rebinds the notebook-level `key` to a different experiment.
key = {"experiment_name": "social0.2-aeon3"}
# Epochs restricted by the experiment's chunks, projected onto "epoch_start"
epoch_query = acquisition.Epoch & (acquisition.Chunk & key).proj("epoch_start")
active_region_query = acquisition.EpochConfig.ActiveRegion & epoch_query
# Map each region name to its region data; `strict=True` raises if the two
# fetched sequences differ in length
roi_locs = dict(
    zip(*active_region_query.fetch("region_name", "region_data"), strict=True)
)

In [None]:
"""Example usage:"""

# Load the saved explore bouts of one experiment and period, then show them.
# NOTE(review): `load_data_from_parquet` and `data_dir` are not defined in the
# cells shown in this notebook - confirm they are provided before running.
explore_df = load_data_from_parquet(
    experiment_name="social0.2-aeon3",
    period="presocial",
    data_type="explore",
    data_dir=data_dir,
    set_time_index=True,
)
display(explore_df)