## AEON env

In [1]:
import os
import warnings

import aeon
import aeon.io.api as aeon_api
import numpy as np
import pandas as pd
import plotly.graph_objects as go
from aeon.analysis.movies import gridframes
from aeon.io.video import frames
from aeon.schema.schemas import social02
from dotmap import DotMap
import glob
import pickle

In [3]:
aeon3_social_02 = {
    "subjects": {
        "BAA-1104047": {  # no tattoo
            "root": "/ceph/aeon/aeon/data/raw/AEON3/social0.2/",
            "start": pd.Timestamp("2024-02-05 15:43:11.581535816"),
            "end": pd.Timestamp("2024-02-08 14:49:41.552000046"),
        },
        "BAA-1104045": {  # tattooed
            "root": "/ceph/aeon/aeon/data/raw/AEON3/social0.2/",
            "start": pd.Timestamp("2024-01-31 11:28:45.543519974"),
            "end": pd.Timestamp("2024-02-03 16:28:29.139999866"),
        },
        "multi_animal": {
            "root": "/ceph/aeon/aeon/data/raw/AEON3/social0.2/",  # different root
            "start": pd.Timestamp("2024-02-09 16:26:07.579999924"),  # placeholder
            "end": pd.Timestamp("2024-02-23 12:07:24.863999844"),
        },
    },
    # patch 1 and 2 rfids are swapped in AEON3
    "patch_camera_order": [2, 1, 3],
    "session": "aeon3_social_02",
}

aeon4_social_02 = {
    "subjects": {
        "BAA-1104049": {  # no tattoo
            "root": "/ceph/aeon/aeon/data/raw/AEON4/social0.2/",
            "start": pd.Timestamp("2024-01-09 16:46:00"),
            "end": pd.Timestamp("2024-01-09 17:15:00"),
        },
        "BAA-1104048": {  # tattooed
            "root": "/ceph/aeon/aeon/data/raw/AEON4/social0.2/",
            "start": pd.Timestamp("2024-01-09 16:10:00"),
            "end": pd.Timestamp("2024-01-09 16:39:00"),
        },
        "multi_animal": {
            "root": "/ceph/aeon/aeon/data/raw/AEON4/model-training/", 
            "start": pd.Timestamp("2024-01-21 19:12:00"), 
            "end": pd.Timestamp("2024-01-21 19:43:00"),
        },
    },
    "session": "aeon4_social_02",
}

In [3]:
def get_experiment_times(
    root: str | os.PathLike, start_time: pd.Timestamp, end_time: pd.Timestamp
) -> DotMap:
    """
    Retrieve experiment start and stop times from environment states
    (i.e. times outside of maintenance mode) occurring within the
    given start and end times.

    Args:
        root (str or os.PathLike): The root path where epoch data is stored.
        start_time (pandas.Timestamp): Start time.
        end_time (pandas.Timestamp): End time.

    Returns:
        DotMap: A DotMap object containing two keys: 'start' and 'stop',
        corresponding to pairs of experiment start and stop times.

    Notes:
    This function uses the last 'Maintenance' event (if available, otherwise
    `end_time`) as the last 'Experiment' stop time. If the first retrieved state
    is 'Maintenance' (e.g. 'Experiment' mode entered before `start_time`),
    `start_time` is used as the first 'Experiment' start time.
    """

    experiment_times = DotMap()
    env_states = aeon.load(
        root,
        social02.Environment.EnvironmentState,
        # aeon.io.reader.Csv("Environment_EnvironmentState_*", ["state"]),
        start_time,
        end_time,
    )
    if env_states.empty:
        warnings.warn(
            "The environment state df is empty. "
            "Using input `start_time` and `end_time` as experiment times."
        )
        experiment_times.start = [start_time]
        experiment_times.stop = [end_time]
        return experiment_times
    if env_states["state"].iloc[-1] != "Maintenance":
        warnings.warn(
            "No 'Maintenance' event at the end of the search range. "
            "Using input `end_time` as last experiment stop time."
        )
        # Pad with a "Maintenance" event at the end
        env_states = pd.concat(
            [
                env_states,
                pd.DataFrame(
                    "Maintenance",
                    index=[end_time],
                    columns=env_states.columns,
                ),
            ]
        )
    # Use the last "Maintenance" event as end time
    end_time = (env_states[env_states.state == "Maintenance"]).index[-1]
    env_states = env_states[~env_states.index.duplicated(keep="first")]
    # Retain only events between visit start and stop times
    env_states = env_states.iloc[
        env_states.index.get_indexer([start_time], method="bfill")[
            0
        ]: env_states.index.get_indexer([end_time], method="ffill")[0]
        + 1
    ]
    # Retain only events where state changes (experiment-maintenance pairs)
    env_states = env_states[env_states["state"].ne(
        env_states["state"].shift())]
    if env_states["state"].iloc[0] == "Maintenance":
        warnings.warn(
            "No 'Experiment' event at the start of the search range. "
            "Using input `end_time` as last experiment stop time."
        )
        # Pad with an "Experiment" event at the start
        env_states = pd.concat(
            [
                pd.DataFrame(
                    "Experiment",
                    index=[start_time],
                    columns=env_states.columns,
                ),
                env_states,
            ]
        )
    experiment_times.start = env_states[
        env_states["state"] == "Experiment"
    ].index.values
    experiment_times.stop = env_states[
        env_states["state"] == "Maintenance"
    ].index.values

    return experiment_times


def exclude_maintenance_data(
    data: pd.DataFrame, experiment_times: DotMap
) -> pd.DataFrame:
    """
    Exclude rows not in experiment times (i.e., corresponding to maintenance times)
    from the given dataframe.

    Args:
        data (pandas.DataFrame): The data to filter. Expected to have a DateTimeIndex.
        experiment_times (DotMap): A DotMap object containing experiment start and stop times.

    Returns:
        pandas.DataFrame: The filtered data.
    """
    filtered_data = pd.concat(
        [
            data.loc[start:stop]
            for start, stop in zip(experiment_times.start, experiment_times.stop)
        ]
    )
    return filtered_data


def get_single_frame(
    root: str | os.PathLike,
    video_reader: aeon.io.reader.Video,
    time: pd.Timestamp,
) -> np.ndarray:
    """
    Retrieve a single frame from the given root directory,
    Video reader, and time.

    Args:
        root (str or os.PathLike): The root path where epoch data
            is stored.
        video_reader (aeon.io.reader.Video): The Video reader.
        time (pd.Timestamp): The timestamp of the frame to retrieve.

    Returns:
        numpy.ndarray: The raw frame.
    """
    vdata = aeon.load(
        root, video_reader, start=time, end=time + pd.Timedelta(seconds=1)
    )
    vframe = frames(vdata.iloc[:1])
    return np.squeeze(list(vframe))


def show_frame(raw_frame: np.ndarray, width: int = 1440, height: int = 1080):
    """
    Display raw input frame(s).

    Args:
        raw_frame (numpy.ndarray): The raw frame.
        width (int): The width of the display layout.
        height (int): The height of the display layout.

    Returns:
        None
    """
    width = width
    height = height
    fig = go.Figure(
        data=[go.Image(z=raw_frame)],
        layout=go.Layout(
            width=width,
            height=height,
            xaxis=dict(
                visible=False,
            ),
            yaxis=dict(
                visible=False,
                scaleanchor="x",
            ),
            margin=dict(l=0, r=0, t=0, b=0),
        ),
    )
    fig.show()

### Identify start, end times to be added to subject dicts

In [4]:
# load subject states to retrieve enter, exit times
root = "/ceph/aeon/aeon/data/raw/AEON3/social0.2/"
# aeon.io.reader.Subject("Environment_SubjectState_*"))
aeon_api.load(root, social02.Environment.SubjectState)

Unnamed: 0_level_0,id,weight,type
time,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
2024-01-31 11:28:45.543519974,BAA-1104045,29.1,Remain
2024-02-01 22:36:53.196512222,BAA-1104045,29.1,Remain
2024-02-02 00:15:06.000000000,BAA-1104045,29.1,Remain
2024-02-03 16:28:29.139999866,BAA-1104045,29.1,Exit
2024-02-05 15:43:11.581535816,BAA-1104047,33.200001,Remain
2024-02-08 14:49:41.552000046,BAA-1104047,30.4,Exit
2024-02-08 14:56:42.000000000,test,740000.0,Enter
2024-02-08 15:03:59.465536118,test,740000.0,Remain
2024-02-08 15:07:03.751999855,test,740000.0,Exit
2024-02-08 15:07:19.808000088,Test2,74.0,Enter


In [None]:
# load subject region visits if subject states are not available
# aeon.io.reader.Csv("Environment_SubjectVisits_*", columns=["id", "state", "region"]))
subj_visits = aeon_api.load(root, social02.Environment.SubjectVisits)
subj_visits[(subj_visits["region"] == "Environment")]

### Curate frames and export as csv

In [5]:
most_reads = []
subj_dict = aeon3_social_02["subjects"]["multi_animal"]
root = aeon3_social_02.get("root", subj_dict["root"])
# patch 1 and 2 rfids are swapped in AEON3
camera_idx = aeon3_social_02.get("patch_camera_order", [1, 2, 3])
start_time = subj_dict["start"]
end_time = subj_dict["end"]
experiment_times = get_experiment_times(root, start_time, end_time)
print(
    f"multi_animal"
    f"\nexp start times: {experiment_times.start}"
    f"\nexp stop times: {experiment_times.stop}"
)
for i in range(1, 4):
    rfid_reads = aeon_api.load(
        root,
        aeon.io.reader.Harp(
            pattern=f"Patch{i}Rfid_32_*", columns=["Rfid"]),
        start=start_time,
        end=end_time,
    )
    rfid_reads = exclude_maintenance_data(rfid_reads, experiment_times)
    # most_reads.append(rfid_reads.groupby(
    #     pd.Grouper(freq="h")).size().nlargest(1).index)
    most_reads_datetime = rfid_reads.groupby(
        pd.Grouper(freq="h")).size().nlargest(1).index[0].strftime('%Y-%m-%dT%H-%M-%S')
    most_reads_filename = f"CameraPatch{i}_{most_reads_datetime}.avi"
    most_reads_file_path = next(glob.iglob(root + '/**/' + most_reads_filename, recursive=True), None)
    most_reads.append(most_reads_file_path)
    
print(most_reads)



multi_animal
exp start times: ['2024-02-09T16:26:12.840000153' '2024-02-12T08:08:43.440000057'
 '2024-02-12T11:38:51.380000114' '2024-02-12T12:56:08.984000206'
 '2024-02-13T13:57:48.820000172' '2024-02-13T14:12:49.311999798'
 '2024-02-13T15:18:59.231999874' '2024-02-15T17:03:51.800000191'
 '2024-02-19T10:31:41.872000217' '2024-02-20T20:19:07.311999798'
 '2024-02-22T17:16:44.967999935']
exp stop times: ['2024-02-12T08:06:26.831999779' '2024-02-12T11:38:12.688000202'
 '2024-02-12T12:51:14.599999905' '2024-02-13T13:57:14.980000019'
 '2024-02-13T14:10:56.624000072' '2024-02-13T15:14:58.535999774'
 '2024-02-15T17:02:02.432000160' '2024-02-19T10:26:01.576000214'
 '2024-02-20T20:17:21.607999802' '2024-02-22T17:15:01.743999958'
 '2024-02-23T12:07:24.863999844']
['/ceph/aeon/aeon/data/raw/AEON3/social0.2/2024-02-15T17-03-12/CameraPatch1/CameraPatch1_2024-02-21T09-00-00.avi', '/ceph/aeon/aeon/data/raw/AEON3/social0.2/2024-02-09T16-07-32/CameraPatch2/CameraPatch2_2024-02-14T15-00-00.avi', '/ceph/

In [8]:
# Save list so we can change kernel
with open(f'/ceph/aeon/aeon/code/scratchpad/sleap/social0.2/AEON3/{aeon3_social_02["session"]}_rfid_patch_chunks.pkl', 'wb') as f:
    pickle.dump(most_reads, f)

## SLEAP env

In [5]:
import pandas as pd
import sleap
from sleap.gui.suggestions import SuggestionFrame
import pickle

In [6]:
with open(f'/ceph/aeon/aeon/code/scratchpad/sleap/social0.2/AEON3/{aeon3_social_02["session"]}_rfid_patch_chunks.pkl', 'rb') as f:
    most_reads = pickle.load(f)
most_reads

['/ceph/aeon/aeon/data/raw/AEON3/social0.2/2024-02-15T17-03-12/CameraPatch1/CameraPatch1_2024-02-21T09-00-00.avi',
 '/ceph/aeon/aeon/data/raw/AEON3/social0.2/2024-02-09T16-07-32/CameraPatch2/CameraPatch2_2024-02-14T15-00-00.avi',
 '/ceph/aeon/aeon/data/raw/AEON3/social0.2/2024-02-09T16-07-32/CameraPatch3/CameraPatch3_2024-02-11T09-00-00.avi']

In [10]:
videos_dict = {
    video: sleap.Video.from_filename(video, grayscale=True)
    for video in most_reads
}
videos_dict

{'/ceph/aeon/aeon/data/raw/AEON3/social0.2/2024-02-15T17-03-12/CameraPatch1/CameraPatch1_2024-02-21T09-00-00.avi': Video(backend=MediaVideo(filename='/ceph/aeon/aeon/data/raw/AEON3/social0.2/2024-02-15T17-03-12/CameraPatch1/CameraPatch1_2024-02-21T09-00-00.avi', grayscale=True, bgr=True, dataset='', input_format='')),
 '/ceph/aeon/aeon/data/raw/AEON3/social0.2/2024-02-09T16-07-32/CameraPatch2/CameraPatch2_2024-02-14T15-00-00.avi': Video(backend=MediaVideo(filename='/ceph/aeon/aeon/data/raw/AEON3/social0.2/2024-02-09T16-07-32/CameraPatch2/CameraPatch2_2024-02-14T15-00-00.avi', grayscale=True, bgr=True, dataset='', input_format='')),
 '/ceph/aeon/aeon/data/raw/AEON3/social0.2/2024-02-09T16-07-32/CameraPatch3/CameraPatch3_2024-02-11T09-00-00.avi': Video(backend=MediaVideo(filename='/ceph/aeon/aeon/data/raw/AEON3/social0.2/2024-02-09T16-07-32/CameraPatch3/CameraPatch3_2024-02-11T09-00-00.avi', grayscale=True, bgr=True, dataset='', input_format=''))}

### Extract a sample of frames where both mice are present and convert to user labelled

In [None]:
slp_pr_path = (
    "/ceph/aeon/aeon/code/scratchpad/sleap/social0.2/AEON3/"
    f"{aeon3_social_02['session']}_rfid_patch_chunks_pr.slp"
)
labels_pr = sleap.load_file(slp_pr_path)

In [None]:
# create skeleton with just 'tail_base' and 'tattoo'
skeleton = sleap.Skeleton()
skeleton.add_nodes(["tail_base", "tattoo"])
skeleton.add_edge("tail_base", "tattoo")