(target-dj-social-analysis)=
# DataJoint pipeline: Social experiment analysis 

:::{important}
This guide assumes you have a [DataJoint pipeline deployed](target-dj-pipeline-deployment) with [data already ingested](target-dj-data-ingestion-processing).
:::

TODO: UPDATE DESCRIPTION
This guide builds upon the [Querying data](target-dj-querying-data) guide and provides examples of how to query data from the [Aeon DataJoint pipeline](target-aeon-dj-pipeline) ... 

You can also run this notebook online at [`works.datajoint.com`](https://works.datajoint.com/) using the following credentials:
 - Username: aeondemo
 - Password: aeon_djworks 

TODO: RENAME NOTEBOOK
To access it, go to the Notebook tab at the top and in the File Browser on the left, navigate to `ucl-swc_aeon > docs > examples`, where this notebook `dj_data_loading.ipynb` is located.

:::{note}
The examples here use the [social0.2-aeon4](target-full-datasets) dataset. 
If you are using a different dataset, be sure to replace the experiment name and  parameters in the code below accordingly.
:::

## Import libraries and define variables and helper functions

In [1]:
import datetime
import sys
import os
import warnings
from pathlib import Path
from typing import Any, Tuple, List, Dict

import numpy as np
import pandas as pd
import plotly
import plotly.express as px
import plotly.graph_objs as go
import statsmodels.api as sm
from plotly.subplots import make_subplots
from tqdm.notebook import tqdm

import datajoint as dj
from aeon.dj_pipeline.analysis.block_analysis import *
from aeon.dj_pipeline import acquisition, streams, subject
from swc.aeon.io import api as aeon_api
from aeon.schema.schemas import social02

[2025-08-28 15:32:58,460][INFO]: DataJoint 0.14.6 connected to chuan@aeon-db2:3306


In [27]:
def ensure_ts_arr_datetime(array):
    """Ensure array is a numpy array of datetime64[ns] type."""
    if len(array) == 0:
        return np.array([], dtype="datetime64[ns]")
    else:
        return np.array(array, dtype="datetime64[ns]")


def concat_and_reorder(dfs):
    """Concatenate dataframes and reorder columns."""

    df = pd.concat(dfs, ignore_index=True)
    first_cols = ["experiment_name", "period"]
    return df[first_cols + [col for col in df.columns if col not in first_cols]]


In [10]:
cm2px = 5.2  # 1 cm = 5.2 px roughly for top camera
exp = {
    "name": "social0.2-aeon4",
    "presocial_start": "2024-01-31 11:00:00",
    "presocial_end": "2024-02-08 15:00:00",
    "social_start": "2024-02-09 17:00:00",
    "social_end": "2024-02-23 12:00:00",
    "postsocial_start": "2024-02-25 18:00:00",
    "postsocial_end": "2024-03-02 13:00:00",
}
key = {"experiment_name": exp["name"]}

# Define periods
periods = {
    "presocial": (exp["presocial_start"], exp["presocial_end"]),
    "social": (exp["social_start"], exp["social_end"]),
    "postsocial": (exp["postsocial_start"], exp["postsocial_end"]),
}


## Fetch Patch data

In [11]:

def load_subject_patch_data(
    key: dict[str, str], period_start: str, period_end: str
) -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
    """Loads subject patch data for a specified time period.

    Args:
        key (dict): The key to filter the subject patch data.
        period_start (str): The start time for the period.
        period_end (str): The end time for the period.

    Returns:
        tuple: A tuple containing:
            - patch_info (pd.DataFrame): Information about patches.
            - block_subject_patch_data (pd.DataFrame): Data for the specified period.
            - block_subject_patch_pref (pd.DataFrame): Preference data for the specified period.
    """
    patch_info = (
        BlockAnalysis.Patch()
        & key
        & f"block_start >= '{period_start}'"
        & f"block_start <= '{period_end}'"
    ).fetch(
        "block_start",
        "patch_name",
        "patch_rate",
        "patch_offset",
        "wheel_timestamps",
        as_dict=True,
    )

    block_subject_patch_data = (
        BlockSubjectAnalysis.Patch()
        & key
        & f"block_start >= '{period_start}'"
        & f"block_start <= '{period_end}'"
    ).fetch(format="frame")

    block_subject_patch_pref = (
        BlockSubjectAnalysis.Preference()
        & key
        & f"block_start >= '{period_start}'"
        & f"block_start <= '{period_end}'"
    ).fetch(format="frame")

    if patch_info:
        patch_info = pd.DataFrame(patch_info)

    if isinstance(block_subject_patch_data, pd.DataFrame) and not block_subject_patch_data.empty:
        block_subject_patch_data.reset_index(inplace=True)

    if isinstance(block_subject_patch_pref, pd.DataFrame) and not block_subject_patch_pref.empty:
        block_subject_patch_pref.reset_index(inplace=True)

    return patch_info, block_subject_patch_data, block_subject_patch_pref


patch_info_list = []
block_subject_patch_data_list = []
block_subject_patch_pref_list = []

for period_name, (start, end) in periods.items():
    start_dt = datetime.strptime(start, "%Y-%m-%d %H:%M:%S")
    end_dt = datetime.strptime(end, "%Y-%m-%d %H:%M:%S")

    patch_info, patch_data, patch_pref = load_subject_patch_data(key, start_dt, end_dt)

    # Drop NaNs in preference columns
    patch_pref = patch_pref.dropna(subset=["final_preference_by_time", "final_preference_by_wheel"])

    # Add metadata using assign
    patch_info = patch_info.assign(experiment_name=exp["name"], period=period_name)
    patch_data = patch_data.assign(experiment_name=exp["name"], period=period_name)
    patch_pref = patch_pref.assign(experiment_name=exp["name"], period=period_name)

    # Validate subject count for pre/post-social blocks
    if period_name in ["presocial", "postsocial"] and not patch_data.empty:
        n_subjects = patch_data.groupby("block_start")["subject_name"].nunique()
        if (n_subjects != 1).any():
            warnings.warn(
                f"{exp['name']} {period_name} blocks have >1 subject. Data may need cleaning."
            )

    # Ensure timestamp arrays are datetime64[ns]
    for col in ["pellet_timestamps", "in_patch_rfid_timestamps", "in_patch_timestamps"]:
        if col in patch_data.columns:
            patch_data[col] = patch_data[col].apply(ensure_ts_arr_datetime)

    patch_info_list.append(patch_info)
    block_subject_patch_data_list.append(patch_data)
    block_subject_patch_pref_list.append(patch_pref)

patch_info = concat_and_reorder(patch_info_list)
block_subject_patch_data = concat_and_reorder(block_subject_patch_data_list)
block_subject_patch_pref = concat_and_reorder(block_subject_patch_pref_list)

In [29]:
patch_info.head()

Unnamed: 0,experiment_name,period,block_start,patch_name,wheel_timestamps,patch_rate,patch_offset
0,social0.2-aeon4,presocial,2024-01-31 12:59:08.005984,Patch1,"[2024-01-31T12:59:08.020000000, 2024-01-31T12:...",0.01,75.0
1,social0.2-aeon4,presocial,2024-01-31 12:59:08.005984,Patch2,"[2024-01-31T12:59:08.020000000, 2024-01-31T12:...",0.002,75.0
2,social0.2-aeon4,presocial,2024-01-31 12:59:08.005984,Patch3,"[2024-01-31T12:59:08.020000000, 2024-01-31T12:...",0.0033,75.0
3,social0.2-aeon4,presocial,2024-01-31 14:58:09.045984,Patch1,"[2024-01-31T14:58:09.060000000, 2024-01-31T14:...",0.01,75.0
4,social0.2-aeon4,presocial,2024-01-31 14:58:09.045984,Patch2,"[2024-01-31T14:58:09.060000000, 2024-01-31T14:...",0.01,75.0


In [34]:
block_subject_patch_data.head()

Unnamed: 0,experiment_name,period,block_start,patch_name,subject_name,in_patch_timestamps,in_patch_time,in_patch_rfid_timestamps,pellet_count,pellet_timestamps,patch_threshold,wheel_cumsum_distance_travelled
0,social0.2-aeon4,presocial,2024-01-31 12:59:08.005984,Patch1,BAA-1104048,"[2024-01-31T12:59:15.540000000, 2024-01-31T12:...",1780.5,"[2024-01-31T13:01:56.638368000, 2024-01-31T13:...",30,"[2024-01-31T13:42:05.553504000, 2024-01-31T13:...","[519.6922767677252, 103.3127392096241, 151.577...","[-0.0, 0.0076703721017885584, 0.00460222326107..."
1,social0.2-aeon4,presocial,2024-01-31 12:59:08.005984,Patch2,BAA-1104048,"[2024-01-31T12:59:12.160000000, 2024-01-31T12:...",346.94,"[2024-01-31T13:04:43.674048000, 2024-01-31T13:...",0,[],[],"[-0.0, 0.006136297681430314, 0.001534074420357..."
2,social0.2-aeon4,presocial,2024-01-31 12:59:08.005984,Patch3,BAA-1104048,"[2024-01-31T12:59:08.020000000, 2024-01-31T12:...",1326.94,"[2024-01-31T12:59:08.040480000, 2024-01-31T12:...",9,"[2024-01-31T13:06:05.353504000, 2024-01-31T13:...","[109.21084511590703, 426.49168400760897, 435.9...","[-0.0, -0.00767037210178767, 0.001534074420357..."
3,social0.2-aeon4,presocial,2024-01-31 14:58:09.045984,Patch1,BAA-1104048,"[2024-01-31T14:58:16.260000000, 2024-01-31T14:...",802.24,"[2024-01-31T15:17:52.416960000, 2024-01-31T15:...",11,"[2024-01-31T15:57:04.845504000, 2024-01-31T15:...","[110.68885894736736, 76.14063915657457, 142.50...","[-0.0, -0.001534074420357634, -0.0046022232610..."
4,social0.2-aeon4,presocial,2024-01-31 14:58:09.045984,Patch2,BAA-1104048,"[2024-01-31T14:59:43.600000000, 2024-01-31T14:...",33.1,"[2024-01-31T15:48:42.393728000, 2024-01-31T15:...",0,[],[],"[-0.0, 0.003068148840715157, 0.001534074420357..."


In [35]:
block_subject_patch_pref.head()

Unnamed: 0,experiment_name,period,block_start,patch_name,subject_name,cumulative_preference_by_wheel,cumulative_preference_by_time,running_preference_by_time,running_preference_by_wheel,final_preference_by_wheel,final_preference_by_time
0,social0.2-aeon4,presocial,2024-01-31 12:59:08.005984,Patch1,BAA-1104048,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",0.60653,0.515433
1,social0.2-aeon4,presocial,2024-01-31 12:59:08.005984,Patch2,BAA-1104048,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",0.0,0.100435
2,social0.2-aeon4,presocial,2024-01-31 12:59:08.005984,Patch3,BAA-1104048,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[5.78975098281023e-06, 1.157950196562046e-05, ...","[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, ...","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",0.393466,0.384133
3,social0.2-aeon4,presocial,2024-01-31 14:58:09.045984,Patch1,BAA-1104048,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",0.356369,0.322589
4,social0.2-aeon4,presocial,2024-01-31 14:58:09.045984,Patch2,BAA-1104048,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",0.0,0.01331


## Fetch foraging bouts

In [None]:
def load_foraging_bouts(
    key: Dict[str, str], period_start: str, period_end: str
) -> pd.DataFrame:
    """Loads foraging bout data for blocks falling within a specified time period.

    Args:
        key (dict): Key to identify experiment data (e.g., {"experiment_name": "Exp1"}).
        period_start (str): Start datetime of the time period (format: '%Y-%m-%d %H:%M:%S').
        period_end (str): End datetime of the time period (format: '%Y-%m-%d %H:%M:%S').

    Returns:
        pd.DataFrame: Concatenated dataframe of foraging bouts for all matching blocks.
                      Returns an empty dataframe with predefined columns if no data found.
    """
    # Fetch block start times within the specified period
    blocks = (
        Block & key & f"block_start >= '{period_start}'" & f"block_end <= '{period_end}'"
    ).fetch("block_start")

    # Retrieve foraging bouts for each block
    bouts = []
    for block_start in blocks:
        block_key = key | {"block_start": str(block_start)}
        bouts.append(get_foraging_bouts(block_key, min_pellets=1))

    # Return concatenated DataFrame or empty fallback
    if bouts:
        return pd.concat(bouts, ignore_index=True)
    else:
        return pd.DataFrame(
            columns=["start", "end", "n_pellets", "cum_wheel_dist", "subject"]
        )

# Load foraging bouts for each period and concatenate into a single DataFrame
foraging_df = concat_and_reorder(
    [
        load_foraging_bouts(key,
                            datetime.strptime(start, "%Y-%m-%d %H:%M:%S"),
                            datetime.strptime(end, "%Y-%m-%d %H:%M:%S")
        ).assign(
            experiment_name=exp["name"],
            period=period_name
        )
        for period_name, (start, end) in periods.items()
    ]
)


In [37]:
foraging_df.head()


Unnamed: 0,experiment_name,period,start,end,n_pellets,cum_wheel_dist,subject
0,social0.2-aeon4,presocial,2024-01-31 13:05:07.440,2024-01-31 13:09:21.700,1,491.692329,BAA-1104048
1,social0.2-aeon4,presocial,2024-01-31 13:27:50.840,2024-01-31 13:35:41.940,4,1509.667296,BAA-1104048
2,social0.2-aeon4,presocial,2024-01-31 13:35:47.100,2024-01-31 13:40:15.720,3,557.384464,BAA-1104048
3,social0.2-aeon4,presocial,2024-01-31 13:41:56.460,2024-01-31 13:52:04.340,9,1538.124377,BAA-1104048
4,social0.2-aeon4,presocial,2024-01-31 14:02:01.400,2024-01-31 14:06:14.680,4,651.983163,BAA-1104048


## Fetch RFID data

In [41]:
def load_rfid_events(
    key: Dict[str, str], period_start: str, period_end: str
) -> pd.DataFrame:
    """Loads RFID events data for chunks falling within a specified time period.

    Args:
        key (dict): Key to identify experiment data (e.g., {"experiment_name": "Exp1"}).
        period_start (str): Start datetime of the time period (format: '%Y-%m-%d %H:%M:%S').
        period_end (str): End datetime of the time period (format: '%Y-%m-%d %H:%M:%S').

    Returns:
        pd.DataFrame: DataFrame containing RFID events for the specified period.
                      Returns an empty dataframe with predefined columns if no data found.
    """
    # Fetch RFID events within the specified period
    rfid_events_df = (
        streams.RfidReader * streams.RfidReaderRfidEvents
        & key
        & f'chunk_start >= "{period_start}"'
        & f'chunk_start <= "{period_end}"'
    ).fetch(format="frame")

    if rfid_events_df.empty or not isinstance(rfid_events_df, pd.DataFrame):
        # Return empty DataFrame with expected columns if no data found
        return pd.DataFrame(
            columns=[
                "experiment_name",
                "chunk_start",
                "rfid_reader_name",
                "sample_count",
                "timestamps",
                "rfid",
            ]
        )

    # Get subject details for RFID mapping
    subject_detail = subject.SubjectDetail.fetch(format="frame")
    subject_detail.reset_index(inplace=True)

    # Create mapping from RFID to subject ID
    rfid_to_lab_id = dict(zip(subject_detail["lab_id"], subject_detail["subject"]))

    rfid_events_df["rfid"] = [
        [rfid_to_lab_id.get(str(rfid)) for rfid in rfid_array]
        for rfid_array in rfid_events_df["rfid"]
    ]

    # Extract experiment_name and chunk_start from the index before resetting
    rfid_events_df["experiment_name"] = [idx[0] for idx in rfid_events_df.index]
    rfid_events_df["chunk_start"] = [
        idx[3] for idx in rfid_events_df.index
    ]  # Assuming chunk_start is at index 3

    # Reset the index and drop the index column
    rfid_events_df = rfid_events_df.reset_index(drop=True)

    # Reorder columns to put experiment_name first and chunk_start second
    cols = ["experiment_name", "chunk_start"] + [
        col
        for col in rfid_events_df.columns
        if col not in ["experiment_name", "chunk_start"]
    ]
    rfid_events_df = rfid_events_df[cols]

    return rfid_events_df

# Load RFID events for each period and concatenate into a single DataFrame
rfid_data_list = []
for period_name, (start, end) in periods.items():
    period_start_str = start
    period_end_str = end

    # Load RFID data for this period
    rfid_df = load_rfid_events(key, period_start_str, period_end_str)
    rfid_df = rfid_df.assign(
        experiment_name=exp["name"],
        period=period_name
    )
    rfid_data_list.append(rfid_df)

rfid_events_df = concat_and_reorder(rfid_data_list)

In [43]:
rfid_events_df.head()

Unnamed: 0,experiment_name,period,chunk_start,rfid_reader_name,sample_count,timestamps,rfid
0,social0.2-aeon4,presocial,2024-01-31 11:00:00,Patch1Rfid,50,"[2024-01-31T11:30:35.493951797, 2024-01-31T11:...","[BAA-1104048, BAA-1104048, BAA-1104048, BAA-11..."
1,social0.2-aeon4,presocial,2024-01-31 12:00:00,Patch1Rfid,25,"[2024-01-31T12:34:34.396575928, 2024-01-31T12:...","[BAA-1104048, BAA-1104048, BAA-1104048, BAA-11..."
2,social0.2-aeon4,presocial,2024-01-31 13:00:00,Patch1Rfid,606,"[2024-01-31T13:01:56.638368130, 2024-01-31T13:...","[BAA-1104048, BAA-1104048, BAA-1104048, BAA-11..."
3,social0.2-aeon4,presocial,2024-01-31 14:00:00,Patch1Rfid,1948,"[2024-01-31T14:02:40.840640068, 2024-01-31T14:...","[BAA-1104048, BAA-1104048, BAA-1104048, BAA-11..."
4,social0.2-aeon4,presocial,2024-01-31 15:00:00,Patch1Rfid,357,"[2024-01-31T15:17:52.416959763, 2024-01-31T15:...","[BAA-1104048, BAA-1104048, BAA-1104048, BAA-11..."


In [45]:
def load_position_data(
    key: Dict[str, str], period_start: str, period_end: str
) -> pd.DataFrame:
    """Loads position data (centroid tracking) for a specified time period.

    Args:
        key (dict): Key to identify experiment data (e.g., {"experiment_name": "Exp1"}).
        period_start (str): Start datetime of the time period.
        period_end (str): End datetime of the time period.

    Returns:
        pd.DataFrame: DataFrame containing position data for the specified period.
                     Returns an empty DataFrame if no data found.
    """
    try:
        print(f"  Querying data from {period_start} to {period_end}...")

        # Create chunk restriction for the time period
        chunk_restriction = acquisition.create_chunk_restriction(
            key["experiment_name"], period_start, period_end
        )

        # Fetch centroid tracking data for the specified period
        centroid_df = (
            streams.SpinnakerVideoSource * tracking.DenoisedTracking.Subject
            & key
            & {"spinnaker_video_source_name": "CameraTop"}
            & chunk_restriction
        ).fetch(format="frame")

        centroid_df = centroid_df.reset_index()
        centroid_df = centroid_df.rename(
            columns={
                "subject_name": "identity_name",
                "timestamps": "time",
                "subject_likelihood": "identity_likelihood",
            }
        )
        centroid_df = centroid_df.explode(
            ["time", "identity_likelihood", "x", "y", "likelihood"]
        )
        centroid_df = centroid_df[
            [
                "time",
                "experiment_name",
                "identity_name",
                "identity_likelihood",
                "x",
                "y",
                "likelihood",
            ]
        ].set_index("time")

        # Clean up the dataframe
        if isinstance(centroid_df, pd.DataFrame) and not centroid_df.empty:
            if "spinnaker_video_source_name" in centroid_df.columns:
                centroid_df.drop(columns=["spinnaker_video_source_name"], inplace=True)
            print(f"  Retrieved {len(centroid_df)} rows of position data")
        else:
            print("  No data found for the specified period")

        return centroid_df

    except Exception as e:
        print(
            f"  Error loading position data for {key['experiment_name']} ({period_start} "
            f"to {period_end}): {e}"
        )
        return pd.DataFrame()


position_data_list = []
for period_name, (period_start, period_end) in periods.items():
    # Load position data for this period
    df = load_position_data(key, period_start, period_end)
    df.reset_index(inplace=True)
    df = df.assign(
        experiment_name=exp["name"], 
        period=period_name
    )
    position_data_list.append(df)

position_df = concat_and_reorder(position_data_list)
    

  Querying data from 2024-01-31 11:00:00 to 2024-02-08 15:00:00...
  Retrieved 26951914 rows of position data
  Querying data from 2024-02-09 17:00:00 to 2024-02-23 12:00:00...
  Retrieved 94316441 rows of position data
  Querying data from 2024-02-25 18:00:00 to 2024-03-02 13:00:00...
  Retrieved 25101935 rows of position data


In [46]:
position_df.head()

Unnamed: 0,experiment_name,period,time,identity_name,identity_likelihood,x,y,likelihood
0,social0.2-aeon4,presocial,2024-01-31 10:15:55.500,BAA-1104048,,1440.638428,229.691986,0.20518
1,social0.2-aeon4,presocial,2024-01-31 10:15:55.520,BAA-1104048,,1440.613281,229.785126,0.215189
2,social0.2-aeon4,presocial,2024-01-31 10:15:55.540,BAA-1104048,,1440.594604,229.80069,0.233931
3,social0.2-aeon4,presocial,2024-01-31 10:15:55.560,BAA-1104048,,1440.530151,229.937073,0.211704
4,social0.2-aeon4,presocial,2024-01-31 10:15:55.580,BAA-1104048,,1440.455933,232.457001,0.202572


In [47]:
def load_weight_data(
    key: Dict[str, str], period_start: str, period_end: str
) -> pd.DataFrame:
    """Loads weight data for a specified time period.

    Args:
        key (dict): Key to identify experiment data (e.g., {"experiment_name": "Exp1"}).
        period_start (str): Start datetime of the time period (format: '%Y-%m-%d %H:%M:%S').
        period_end (str): End datetime of the time period (format: '%Y-%m-%d %H:%M:%S').

    Returns:
        pd.DataFrame: Weight data for the specified period.
                      Returns an empty dataframe if no data found.
    """
    try:
        weight_df = (
            acquisition.Environment.SubjectWeight
            & key
            & f"chunk_start >= '{period_start}'"
            & f"chunk_start <= '{period_end}'"
        ).fetch(format="frame")

        return weight_df if not weight_df.empty and isinstance(weight_df, pd.DataFrame) else pd.DataFrame()
    except Exception as e:
        print(
            f"Error loading weight data for {key} from {period_start} to {period_end}: {e}"
        )
        return pd.DataFrame()

weight_data_list = []
for period_name, (period_start, period_end) in periods.items():
    # Convert to datetime if needed (assuming they're already strings in the right format)
    if isinstance(period_start, str):
        period_start_dt = datetime.strptime(period_start, "%Y-%m-%d %H:%M:%S")
        period_end_dt = datetime.strptime(period_end, "%Y-%m-%d %H:%M:%S")
    else:
        period_start_dt = period_start
        period_end_dt = period_end

    # Load weight data for this period
    weight_df = load_weight_data(key, str(period_start_dt), str(period_end_dt))
    weight_df = weight_df.assign(
        experiment_name=exp["name"],
        period=period_name
    )
    weight_data_list.append(weight_df)
weight_df = concat_and_reorder(weight_data_list)

In [48]:
weight_df.head()

Unnamed: 0,experiment_name,period,sample_count,timestamps,weight,confidence,subject_id,int_id
0,social0.2-aeon4,presocial,0,[],[],[],[],[]
1,social0.2-aeon4,presocial,0,[],[],[],[],[]
2,social0.2-aeon4,presocial,0,[],[],[],[],[]
3,social0.2-aeon4,presocial,0,[],[],[],[],[]
4,social0.2-aeon4,presocial,0,[],[],[],[],[]
