In [1]:
import os
import aeon
import aeon.io.api as aeon_api
import pandas as pd
import warnings
import numpy as np
import plotly.graph_objects as go

from aeon.io.video import frames
from aeon.analysis.movies import gridframes
from aeon.schema.schemas import social02
from dotmap import DotMap

In [2]:
# Per-session configuration.
#   "subjects": per-subject raw-data "root" plus the [start, end] window to load.
#       For AEON3 the windows match the subject's first SubjectState row and its
#       "Exit" row (see the SubjectState cell below). The "multi_animal" entry
#       points at model-training data and is skipped by the sampling loop.
#   "patch_camera_order": camera number used for patches 1..3, i.e. patch i is
#       filmed by CameraPatch{patch_camera_order[i-1]}. Defaults to [1, 2, 3]
#       when absent.
#   "session": name used to build output file names.
aeon3_social_02 = {
    "subjects": {
        "BAA-1104047": { # no tattoo
            "root": "/ceph/aeon/aeon/data/raw/AEON3/social0.2/",
            "start": pd.Timestamp("2024-02-05 15:43:11.581535816"),
            "end": pd.Timestamp("2024-02-08 14:49:41.552000046"),
        },
        "BAA-1104045": { # tattooed
            "root": "/ceph/aeon/aeon/data/raw/AEON3/social0.2/",
            "start": pd.Timestamp("2024-01-31 11:28:45.543519974"),
            "end": pd.Timestamp("2024-02-03 16:28:29.139999866"),
        },
        "multi_animal": {
            "root": "/ceph/aeon/aeon/data/raw/AEON3/model-training/", # different root
            "start": pd.Timestamp("2024-01-09 15:37:00"), # placeholder
            "end": pd.Timestamp("2024-01-09 16:09:00"),
        }
    },
    "patch_camera_order": [2,1,3], # patch 1 and 2 rfids are swapped in AEON3, so patch 1 -> CameraPatch2, patch 2 -> CameraPatch1
    "session": "aeon3_social_02",
}

# Same layout as aeon3_social_02. No "patch_camera_order" key is given, so the
# sampling loop falls back to [1, 2, 3].
# NOTE(review): the per-subject windows below span only ~30 minutes each, on the
# same day as the "multi_animal" placeholder; confirm they are not placeholders too.
aeon4_social_02 = {
    "subjects": {
        "BAA-1104049": { # no tattoo
            "root": "/ceph/aeon/aeon/data/raw/AEON4/social0.2/",
            "start": pd.Timestamp("2024-01-09 16:46:00"),
            "end": pd.Timestamp("2024-01-09 17:15:00"),
        },
        "BAA-1104048": { # tattooed
            "root": "/ceph/aeon/aeon/data/raw/AEON4/social0.2/",
            "start": pd.Timestamp("2024-01-09 16:10:00"),
            "end": pd.Timestamp("2024-01-09 16:39:00"),
        },
        "multi_animal": {
            "root": "/ceph/aeon/aeon/data/raw/AEON4/model-training/", # different root
            "start": pd.Timestamp("2024-01-21 19:12:00"), # placeholder
            "end": pd.Timestamp("2024-01-21 19:43:00"),
        }
    },
    "session": "aeon4_social_02",
}

In [14]:
def get_experiment_times(
    root: str | os.PathLike, start_time: pd.Timestamp, end_time: pd.Timestamp
) -> DotMap:
    """
    Retrieve experiment start and stop times from environment states
    (i.e. times outside of maintenance mode) occurring within the
    given start and end times.

    Args:
        root (str or os.PathLike): The root path where epoch data is stored.
        start_time (pandas.Timestamp): Start time.
        end_time (pandas.Timestamp): End time.

    Returns:
        DotMap: A DotMap object containing two keys: 'start' and 'stop',
        corresponding to pairs of experiment start and stop times.
        If no environment states are found, both are single-element lists
        holding the input `start_time` / `end_time` (pandas.Timestamp).
        Otherwise, both are numpy ``datetime64[ns]`` arrays taken from the
        index of the state-change events.

    Warns:
        UserWarning: if no environment states are loaded, or if no
        'Maintenance' event is present in the loaded range.

    Notes:
    This function uses the last 'Maintenance' event as the last 'Experiment'
    stop time. If the first retrieved state is 'Maintenance' (e.g.
    'Experiment' mode entered before `start_time`), `start_time` is used
    as the first 'Experiment' start time.
    Any 'Experiment' period that starts after the last 'Maintenance' event
    is dropped.
    """

    experiment_times = DotMap()
    env_states = aeon.load(
        root,
        social02.Environment.EnvironmentState,
        # Alternative (unused) explicit reader: aeon.io.reader.Csv("Environment_EnvironmentState_*", ["state"]),
        start_time,
        end_time,
    )
    if env_states.empty:
        # Nothing to segment: treat the whole requested window as one experiment.
        warnings.warn("The environment state df is empty. "
                      "Using input `start_time` and `end_time` as experiment times.")
        experiment_times.start = [start_time]
        experiment_times.stop = [end_time]
        return experiment_times
    if "Maintenance" not in env_states["state"].values:
        warnings.warn("No 'Maintenance' events found. "
                      "Using input `end_time` as last experiment stop time.")
        # Pad with a "Maintenance" event at the end (fills every column with
        # "Maintenance") so the final experiment period gets a stop time.
        env_states = pd.concat(
            [
                env_states,
                pd.DataFrame(
                    "Maintenance",
                    index=[end_time],
                    columns=env_states.columns,
                ),
            ]
        )
    # Use the last "Maintenance" event as end time (rebinds the `end_time` argument)
    end_time = (env_states[env_states.state == "Maintenance"]).index[-1]
    # Drop rows with repeated timestamps, keeping the first occurrence.
    # NOTE(review): dedup runs after `end_time` is chosen; if the last Maintenance
    # row shares a timestamp with an earlier non-Maintenance row, the kept row
    # is not Maintenance and start/stop counts can diverge. Confirm this can't occur.
    env_states = env_states[~env_states.index.duplicated(keep="first")]
    # Retain only events from the first one at/after `start_time` (bfill) through
    # the last one at/before the (rebound) `end_time` (ffill, +1 for an inclusive end).
    # get_indexer with a fill method requires a monotonic (sorted) index.
    env_states = env_states.iloc[
        env_states.index.get_indexer([start_time], method="bfill")[
            0
        ] : env_states.index.get_indexer([end_time], method="ffill")[0] + 1
    ]
    # Retain only events where state changes (experiment-maintenance pairs)
    env_states = env_states[env_states["state"].ne(env_states["state"].shift())]
    if env_states["state"].iloc[0] == "Maintenance":
        # Pad with an "Experiment" event at the start so the first Maintenance
        # event has a matching start time.
        env_states = pd.concat(
            [
                pd.DataFrame(
                    "Experiment",
                    index=[start_time],
                    columns=env_states.columns,
                ),
                env_states,
            ]
        )
    # Alternating Experiment/Maintenance rows give the paired start/stop times.
    experiment_times.start = env_states[
        env_states["state"] == "Experiment"
    ].index.values
    experiment_times.stop = env_states[
        env_states["state"] == "Maintenance"
    ].index.values

    return experiment_times


def exclude_maintenance_data(
    data: pd.DataFrame, experiment_times: DotMap
) -> pd.DataFrame:
    """
    Exclude rows not in experiment times (i.e., corresponding to maintenance times)
    from the given dataframe.

    Each (start, stop) pair selects ``data.loc[start:stop]`` (label slicing,
    both ends inclusive), and the selected slices are concatenated in order.

    Args:
        data (pandas.DataFrame): The data to filter. Expected to have a sorted
            DateTimeIndex (required for label-based slicing).
        experiment_times (DotMap): A DotMap object with equal-length 'start' and
            'stop' sequences of experiment start and stop times.

    Returns:
        pandas.DataFrame: The filtered data. If there are no experiment
        periods, an empty frame with the same columns as `data` is returned.

    Raises:
        ValueError: If 'start' and 'stop' have different lengths.
    """
    # strict=True: a start without a matching stop (or vice versa) indicates
    # malformed experiment times; fail loudly rather than silently dropping it.
    periods = list(zip(experiment_times.start, experiment_times.stop, strict=True))
    if not periods:
        # pd.concat([]) raises ValueError; with no experiment periods, no rows survive.
        return data.iloc[0:0]
    filtered_data = pd.concat([data.loc[start:stop] for start, stop in periods])
    return filtered_data

def get_single_frame(
    root: str | os.PathLike,
    video_reader: aeon.io.reader.Video,
    time: pd.Timestamp,
    window: pd.Timedelta = pd.Timedelta(seconds=1),
) -> np.ndarray:
    """
    Retrieve a single frame from the given root directory,
    Video reader, and time.

    Video metadata is loaded for ``[time, time + window]`` and the first
    frame in that range is decoded.

    Args:
        root (str or os.PathLike): The root path where epoch data
            is stored.
        video_reader (aeon.io.reader.Video): The Video reader.
        time (pd.Timestamp): The timestamp of the frame to retrieve.
        window (pd.Timedelta): How far after `time` to search for a frame.
            Defaults to one second.

    Returns:
        numpy.ndarray: The raw frame, with singleton dimensions squeezed out.

    Raises:
        ValueError: If no video frame exists in the search window.
    """
    end = time + window
    vdata = aeon.load(root, video_reader, start=time, end=end)
    if vdata.empty:
        # Without this check np.squeeze([]) silently returns an empty float array.
        raise ValueError(
            f"No video frames found between {time} and {end} under {root}."
        )
    vframe = frames(vdata.iloc[:1])
    return np.squeeze(list(vframe))


def show_frame(raw_frame: np.ndarray, width: int = 1440, height: int = 1080):
    """
    Render a raw image array as a margin-free Plotly figure with hidden axes
    and display it.

    Args:
        raw_frame (numpy.ndarray): Image data, passed to ``go.Image`` as ``z``.
        width (int): Figure width in pixels.
        height (int): Figure height in pixels.

    Returns:
        None
    """
    hidden_x_axis = dict(visible=False)
    # Anchor the y scale to x so the image keeps its aspect ratio.
    hidden_y_axis = dict(visible=False, scaleanchor="x")
    no_margins = dict(l=0, r=0, t=0, b=0)

    frame_layout = go.Layout(
        width=width,
        height=height,
        xaxis=hidden_x_axis,
        yaxis=hidden_y_axis,
        margin=no_margins,
    )
    figure = go.Figure(data=[go.Image(z=raw_frame)], layout=frame_layout)
    figure.show()

In [4]:
# Subject state events (Enter / Remain / Exit) show when each subject entered
# and left the arena; used to pick the per-subject windows in the session config.
root = "/ceph/aeon/aeon/data/raw/AEON3/social0.2/"
aeon_api.load(
    root,
    social02.Environment.SubjectState,
)


Unnamed: 0_level_0,id,weight,type
time,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
2024-01-31 11:28:45.543519974,BAA-1104045,29.1,Remain
2024-02-01 22:36:53.196512222,BAA-1104045,29.1,Remain
2024-02-02 00:15:06.000000000,BAA-1104045,29.1,Remain
2024-02-03 16:28:29.139999866,BAA-1104045,29.1,Exit
2024-02-05 15:43:11.581535816,BAA-1104047,33.200001,Remain
2024-02-08 14:49:41.552000046,BAA-1104047,30.4,Exit
2024-02-08 14:56:42.000000000,test,740000.0,Enter
2024-02-08 15:03:59.465536118,test,740000.0,Remain
2024-02-08 15:07:03.751999855,test,740000.0,Exit
2024-02-08 15:07:19.808000088,Test2,74.0,Enter


In [13]:
# Environment states (e.g. Experiment / Maintenance) from BAA-1104045's start time
# up to 2024-02-03 16:22:52.
env_states = aeon_api.load(
    root,
    social02.Environment.EnvironmentState,
    pd.Timestamp("2024-01-31 11:28:45.543519974"),
    pd.Timestamp("2024-02-03 16:22:52.943999767"),
)
env_states

Unnamed: 0_level_0,state
time,Unnamed: 1_level_1
2024-01-31 11:28:45.543519974,Experiment
2024-02-01 22:36:53.196512222,Experiment
2024-02-02 00:15:06.000000000,Experiment


In [9]:
# load subject region visits if subject states are not available
subj_visits = aeon_api.load(root, social02.Environment.SubjectVisits) #aeon.io.reader.Csv("Environment_SubjectVisits_*", columns=["id", "state", "region"]))
subj_visits[(subj_visits["region"] == "Environment")]

Unnamed: 0_level_0,id,type,region
time,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
2024-01-31 11:28:45.543519974,BAA-1104045,Remain,Environment
2024-02-01 22:36:53.196512222,BAA-1104045,Remain,Environment
2024-02-02 00:15:06.000000000,BAA-1104045,Remain,Environment
2024-02-03 16:28:29.139999866,BAA-1104045,Exit,Environment
2024-02-05 15:43:11.581535816,BAA-1104047,Remain,Environment
2024-02-08 14:49:41.552000046,BAA-1104047,Exit,Environment
2024-02-08 14:56:42.000000000,test,Enter,Environment
2024-02-08 15:03:59.465536118,test,Remain,Environment
2024-02-08 15:07:03.751999855,test,Exit,Environment
2024-02-08 15:07:19.808000088,Test2,Enter,Environment


In [9]:
# Experiment periods for BAA-1104045 over its configured window
# (same timestamps as the session config entry).
get_experiment_times(
    root,
    aeon3_social_02["subjects"]["BAA-1104045"]["start"],
    aeon3_social_02["subjects"]["BAA-1104045"]["end"],
)

DotMap(start=array(['2024-01-31T11:28:45.543519974'], dtype='datetime64[ns]'), stop=array(['2024-02-03T16:23:52.943999767'], dtype='datetime64[ns]'), _ipython_display_=DotMap(), _repr_mimebundle_=DotMap())

In [11]:
# For each single-animal subject in the session: get its experiment periods,
# randomly sample RFID reads per patch within those periods, and load the
# patch-camera frames at those read times. "multi_*" entries are skipped.
N_RFID_SAMPLES_PER_PATCH = 100  # upper bound on sampled RFID reads per patch
RANDOM_SEED = 0  # fixes which reads are sampled so the saved CSV is reproducible

subj_ids = [subj for subj in aeon3_social_02["subjects"] if "multi_" not in subj]
# Patch i is filmed by CameraPatch{camera_idx[i-1]}; patch 1 and 2 rfids are swapped in AEON3.
camera_idx = aeon3_social_02.get("patch_camera_order", [1, 2, 3])
camera_chunks = []
for subj in subj_ids:
    subj_dict = aeon3_social_02["subjects"][subj]
    root = aeon3_social_02.get("root", subj_dict["root"])
    start_time = subj_dict["start"]
    end_time = subj_dict["end"]
    experiment_times = get_experiment_times(root, start_time, end_time)
    print(f"{subj} \nexp start times: {experiment_times.start} \nexp stop times: {experiment_times.stop}")
    for i in range(1, 4):
        rfid_reads = aeon_api.load(
            root,
            aeon.io.reader.Harp(pattern=f"Patch{i}Rfid_32_*", columns=["Rfid"]),
            start=start_time,
            end=end_time,
        )
        rfid_reads = exclude_maintenance_data(rfid_reads, experiment_times)
        if rfid_reads.empty:
            warnings.warn(f"No RFID reads for {subj} at patch {i}; skipping.")
            continue
        # .sample(n=k) raises if k > len(rfid_reads), so cap it.
        rfid_reads = rfid_reads.sample(
            n=min(N_RFID_SAMPLES_PER_PATCH, len(rfid_reads)),
            random_state=RANDOM_SEED,
        )
        camera = aeon_api.load(
            root,
            aeon.io.reader.Video(f"CameraPatch{camera_idx[i-1]}_*"),
            time=rfid_reads.index,
        )
        camera["id"] = subj
        camera_chunks.append(camera)
# Concatenate once (concat inside the loop is quadratic); pd.concat([]) would raise.
cameras = pd.concat(camera_chunks) if camera_chunks else pd.DataFrame()

BAA-1104047 
exp start times: ['2024-02-05T15:43:11.581535816' '2024-02-08T14:31:46.447999954'] 
exp stop times: ['2024-02-08T14:23:15.659999847' '2024-02-08T14:41:38.703999996']
BAA-1104045 
exp start times: ['2024-01-31T11:28:45.543519974'] 
exp stop times: ['2024-02-03T16:23:52.943999767']


In [11]:
# Persist the sampled camera frames; the SLEAP cells below read this file back.
cameras.to_csv(
    "/ceph/aeon/aeon/code/scratchpad/sleap/social0.2/"
    f"{aeon3_social_02['session']}_rfid_patch_frames.csv"
)

In [None]:
# Spot-check 25 randomly chosen sampled frames, tiled with `gridframes`.
show_frame(
    gridframes(
        list(frames(cameras.sample(n=25))),
        width=1440,
        height=1080,
        shape=25,
    )
)

In [2]:
import pandas as pd
import sleap

from sleap.gui.suggestions import SuggestionFrame

In [5]:
# Reload the sampled frames (columns include `_path` and `_frame`).
cameras = pd.read_csv(
    "/ceph/aeon/aeon/code/scratchpad/sleap/social0.2/"
    f"{aeon3_social_02['session']}_rfid_patch_frames.csv"
)
cameras

Unnamed: 0,time,hw_counter,hw_timestamp,_frame,_path,_epoch,id
0,2024-02-05 18:51:27.980959892,1432926,11626032384296,385997,/ceph/aeon/aeon/data/raw/AEON3/social0.2/2024-...,2024-02-05T15-43-07,BAA-1104047
1,2024-02-05 19:36:28.116576195,1770443,14326126636184,273514,/ceph/aeon/aeon/data/raw/AEON3/social0.2/2024-...,2024-02-05T15-43-07,BAA-1104047
2,2024-02-06 12:22:05.946847916,9312672,74663028331296,165743,/ceph/aeon/aeon/data/raw/AEON3/social0.2/2024-...,2024-02-05T15-43-07,BAA-1104047
3,2024-02-06 13:18:01.029983997,9732057,78018056524320,135128,/ceph/aeon/aeon/data/raw/AEON3/social0.2/2024-...,2024-02-05T15-43-07,BAA-1104047
4,2024-02-06 13:18:35.917119980,9736418,78052943978048,139489,/ceph/aeon/aeon/data/raw/AEON3/social0.2/2024-...,2024-02-05T15-43-07,BAA-1104047
...,...,...,...,...,...,...,...
595,2024-02-03 14:42:48.592991829,17306575,139328963531320,321074,/ceph/aeon/aeon/data/raw/AEON3/social0.2/2024-...,2024-02-02T00-15-00,BAA-1104045
596,2024-02-03 14:38:13.809375763,17272227,139054182581296,286726,/ceph/aeon/aeon/data/raw/AEON3/social0.2/2024-...,2024-02-02T00-15-00,BAA-1104045
597,2024-02-03 14:44:46.485727787,17321311,139446850242024,335810,/ceph/aeon/aeon/data/raw/AEON3/social0.2/2024-...,2024-02-02T00-15-00,BAA-1104045
598,2024-02-03 14:43:59.825600147,17315479,139400194739832,329978,/ceph/aeon/aeon/data/raw/AEON3/social0.2/2024-...,2024-02-02T00-15-00,BAA-1104045


In [23]:
# Build a SLEAP Labels object whose suggestion list holds every unique sampled
# (video path, frame index) pair, for labelling in the SLEAP GUI.
cameras = pd.read_csv(f'/ceph/aeon/aeon/code/scratchpad/sleap/social0.2/{aeon3_social_02["session"]}_rfid_patch_frames.csv')
videos_dict = {
    video: sleap.Video.from_filename(video, grayscale=True)
    for video in cameras._path.unique()
}

unique_frames = cameras.drop_duplicates(subset=["_path", "_frame"])
# One SuggestionFrame per unique row, collected into a single list.
sfs = [
    SuggestionFrame(
        video=videos_dict[video_path],
        # int() guards against numpy integer scalars being stored in the labels.
        frame_idx=int(frame_idx),
    )
    for video_path, frame_idx in zip(unique_frames["_path"], unique_frames["_frame"])
]
labels_suggest = sleap.Labels(videos=list(videos_dict.values()), suggestions=sfs)
labels_suggest

In [24]:
# Save the suggestion-only SLEAP project next to the frames CSV.
slp_path = (
    "/ceph/aeon/aeon/code/scratchpad/sleap/social0.2/"
    f"{aeon3_social_02['session']}_rfid_patch_frames.slp"
)
sleap.Labels.save_file(labels_suggest, slp_path)