## AEON env

In [1]:
import os
import warnings

import aeon
import numpy as np
import pandas as pd
import plotly.graph_objects as go
import plotly.express as px
from aeon.analysis.movies import gridframes
from aeon.io.video import frames
from aeon.schema.schemas import social02
from dotmap import DotMap

In [2]:
# Per-session configuration. Each session lists its single-animal subject periods
# (with RFID tag numbers) and the multi-animal period.
# NOTE: AEON3 stores the raw-data root per subject, AEON4 stores it once per session.
AEON3_SOCIAL_02_ROOT = "/ceph/aeon/aeon/data/raw/AEON3/social0.2/"

aeon3_social_02 = {
    "subjects": {
        # tattooed
        "BAA-1104045": {
            "root": AEON3_SOCIAL_02_ROOT,
            "start": pd.Timestamp("2024-01-31 11:28:45.543519974"),
            "end": pd.Timestamp("2024-02-03 16:28:29.139999866"),
            "rfid": 977200010377711,
        },
        # no tattoo
        "BAA-1104047": {
            "root": AEON3_SOCIAL_02_ROOT,
            "start": pd.Timestamp("2024-02-05 15:43:11.581535816"),
            "end": pd.Timestamp("2024-02-08 14:49:41.552000046"),
            "rfid": 977200010164158,
        },
        "multi_animal": {
            "root": AEON3_SOCIAL_02_ROOT,
            "start": pd.Timestamp("2024-02-09 16:27:00"),
            "end": pd.Timestamp("2024-02-15 17:03:00"),
        },
    },
    # Camera number to use for patches 1, 2, 3: patch 1 and 2 rfids are swapped in AEON3
    "patch_camera_order": [2, 1, 3],
    "session": "aeon3_social_02",
}

aeon4_social_02 = {
    "root": "/ceph/aeon/aeon/data/raw/AEON4/social0.2/",
    "subjects": {
        # tattooed
        "BAA-1104048": {
            "start": pd.Timestamp("2024-01-31 10:23:00"),
            "end": pd.Timestamp("2024-02-03 16:34:00"),
            "rfid": 977200010379293,
        },
        # no tattoo
        "BAA-1104049": {
            "start": pd.Timestamp("2024-02-05 15:00:00"),
            "end": pd.Timestamp("2024-02-08 14:54:00"),
            "rfid": 977200010378675,
        },
        "multi_animal": {
            "start": pd.Timestamp("2024-02-09 16:49:00"),
            "end": pd.Timestamp("2024-02-13 15:06:00"),
        },
    },
    "session": "aeon4_social_02",
}

In [3]:
def get_experiment_times(
    root: str | os.PathLike, start_time: pd.Timestamp, end_time: pd.Timestamp
) -> DotMap:
    """
    Retrieve experiment start and stop times from environment states
    (i.e. times outside of maintenance mode) occurring within the
    given start and end times.

    Args:
        root (str or os.PathLike): The root path where epoch data is stored.
        start_time (pandas.Timestamp): Start time.
        end_time (pandas.Timestamp): End time.

    Returns:
        DotMap: A DotMap object containing two keys: 'start' and 'stop',
        corresponding to pairs of experiment start and stop times. Normally these
        are the ``.index.values`` arrays of the matching state rows. If no
        environment states are found, they are one-element lists holding
        ``start_time`` and ``end_time``.

    Notes:
    This function uses the last 'Maintenance' event (if available, otherwise
    `end_time`) as the last 'Experiment' stop time. If the first retrieved state
    is 'Maintenance' (e.g. 'Experiment' mode entered before `start_time`),
    `start_time` is used as the first 'Experiment' start time. A warning is
    emitted whenever an input time is substituted for a missing event.
    """

    experiment_times = DotMap()
    env_states = aeon.load(
        root,
        social02.Environment.EnvironmentState,
        # equivalent explicit reader:
        # aeon.io.reader.Csv("Environment_EnvironmentState_*", ["state"]),
        start_time,
        end_time,
    )
    if env_states.empty:
        warnings.warn(
            "The environment state df is empty. "
            "Using input `start_time` and `end_time` as experiment times."
        )
        experiment_times.start = [start_time]
        experiment_times.stop = [end_time]
        return experiment_times
    if env_states["state"].iloc[-1] != "Maintenance":
        warnings.warn(
            "No 'Maintenance' event at the end of the search range. "
            "Using input `end_time` as last experiment stop time."
        )
        # Pad with a "Maintenance" event at the end
        env_states = pd.concat(
            [
                env_states,
                pd.DataFrame(
                    "Maintenance",
                    index=[end_time],
                    columns=env_states.columns,
                ),
            ]
        )
    # Use the last "Maintenance" event as end time
    end_time = (env_states[env_states.state == "Maintenance"]).index[-1]
    # get_indexer below needs a unique index
    env_states = env_states[~env_states.index.duplicated(keep="first")]
    # Retain only events between the search start time and the last "Maintenance" event
    env_states = env_states.iloc[
        env_states.index.get_indexer([start_time], method="bfill")[
            0
        ] : env_states.index.get_indexer([end_time], method="ffill")[0]
        + 1
    ]
    # Retain only events where state changes (experiment-maintenance pairs)
    env_states = env_states[env_states["state"].ne(env_states["state"].shift())]
    if env_states["state"].iloc[0] == "Maintenance":
        warnings.warn(
            "No 'Experiment' event at the start of the search range. "
            "Using input `start_time` as first experiment start time."
        )
        # Pad with an "Experiment" event at the start
        env_states = pd.concat(
            [
                pd.DataFrame(
                    "Experiment",
                    index=[start_time],
                    columns=env_states.columns,
                ),
                env_states,
            ]
        )
    experiment_times.start = env_states[
        env_states["state"] == "Experiment"
    ].index.values
    experiment_times.stop = env_states[
        env_states["state"] == "Maintenance"
    ].index.values

    return experiment_times


def exclude_maintenance_data(
    data: pd.DataFrame, experiment_times: "DotMap"
) -> pd.DataFrame:
    """
    Exclude rows not in experiment times (i.e., corresponding to maintenance times)
    from the given dataframe.

    Args:
        data (pandas.DataFrame): The data to filter. Expected to have a sorted
            DateTimeIndex, because label-based slicing is used.
        experiment_times (DotMap): An object with parallel ``start`` and ``stop``
            sequences of experiment start and stop times, e.g. the output of
            ``get_experiment_times``. Extra unmatched entries are ignored.

    Returns:
        pandas.DataFrame: The rows of ``data`` that fall inside any
        ``[start, stop]`` interval (both ends inclusive), in interval order.
        If there are no intervals, an empty frame with the same columns is
        returned instead of raising.
    """
    chunks = [
        data.loc[start:stop]
        for start, stop in zip(experiment_times.start, experiment_times.stop)
    ]
    if not chunks:
        # pd.concat raises "No objects to concatenate" on an empty list
        return data.iloc[:0]
    return pd.concat(chunks)


def get_single_frame(
    root: str | os.PathLike,
    video_reader: aeon.io.reader.Video,
    time: pd.Timestamp,
    window: pd.Timedelta = pd.Timedelta(seconds=1),
) -> np.ndarray:
    """
    Retrieve a single frame from the given root directory,
    Video reader, and time.

    The first frame found in ``[time, time + window]`` is returned.

    Args:
        root (str or os.PathLike): The root path where epoch data
            is stored.
        video_reader (aeon.io.reader.Video): The Video reader.
        time (pd.Timestamp): The timestamp of the frame to retrieve.
        window (pd.Timedelta): How far after ``time`` to search for a
            frame. Defaults to 1 second.

    Returns:
        numpy.ndarray: The raw frame, with singleton dimensions squeezed out.

    Raises:
        ValueError: If no video frame exists within the search window.
    """
    vdata = aeon.load(root, video_reader, start=time, end=time + window)
    if vdata.empty:
        # Without this check an empty array is returned and fails later in plotting
        raise ValueError(
            f"No video frame found between {time} and {time + window} under {root}."
        )
    vframe = frames(vdata.iloc[:1])
    return np.squeeze(list(vframe))


def plot_overlay(
    img: np.ndarray,
    pose: pd.DataFrame = None,
    roi_coords: "list[list[tuple[int, int]]]" = None,
) -> None:
    """
    Display a frame with optional pose points and ROI polygons drawn on top.

    Args:
        img (numpy.ndarray): The raw frame to display.
        pose (pandas.DataFrame): Pose data with "x", "y" and "id" columns.
            Points are coloured by "id" at 20% opacity.
        roi_coords (list[list[tuple[int, int]]]): One list of (x, y) vertices
            per ROI. Each ROI is drawn as a closed, filled outline.

    Returns:
        None: the figure is shown via ``fig.show()``.
    """
    figure = px.imshow(img)

    if pose is not None:
        pose_figure = px.scatter(pose, x="x", y="y", color="id", opacity=0.2)
        for pose_trace in pose_figure.data:
            figure.add_trace(pose_trace)

    if roi_coords is not None:
        for polygon in roi_coords:
            # Repeat the first vertex so the outline is closed
            closed_polygon = polygon + [polygon[0]]
            x_vals, y_vals = zip(*closed_polygon)
            figure.add_scatter(x=x_vals, y=y_vals, fill="toself")

    figure.show()

### Identify start and end times to add to the subject dicts

In [None]:
# Load subject enter/exit events (SubjectState) for the whole AEON4 social0.2 dataset.
# They give the "start"/"end" times for the subject dicts above.
root = aeon4_social_02["root"]
# Equivalent explicit reader: aeon.io.reader.Subject("Environment_SubjectState_*")
aeon.load(root, social02.Environment.SubjectState)

In [None]:
# Fallback when SubjectState is not available: use subject region visits and keep
# only visits to the "Environment" region.
# Equivalent explicit reader:
# aeon.io.reader.Csv("Environment_SubjectVisits_*", columns=["id", "state", "region"])
subj_visits = aeon.load(root, social02.Environment.SubjectVisits)
subj_visits.loc[subj_visits["region"] == "Environment"]

### Curate frames from single-animal videos and export them as CSV

In [5]:
# For each single-animal subject in the session, get its experiment times and sample
# patch-camera frames around RFID reads.
session = aeon3_social_02
# Camera number for patches 1, 2, 3 (patch 1 and 2 rfids are swapped in AEON3)
camera_idx = session.get("patch_camera_order", [1, 2, 3])
n_top_hours = 3  # number of hours with the most RFID reads to sample from
n_frames_per_patch = 100  # max frames sampled per patch per subject
random_seed = 42  # fixed seed so the sampled frames are reproducible

subj_ids = [subj for subj in session["subjects"] if "multi_" not in subj]
camera_frames = []
for subj in subj_ids:
    subj_dict = session["subjects"][subj]
    # A session-level root takes precedence, otherwise use the subject's own root.
    # (dict.get evaluates its default eagerly, so `session.get("root", subj_dict["root"])`
    # raises KeyError for sessions that only define a session-level root, e.g. AEON4.)
    root = session["root"] if "root" in session else subj_dict["root"]
    start_time = subj_dict["start"]
    end_time = subj_dict["end"]
    experiment_times = get_experiment_times(root, start_time, end_time)
    print(
        f"{subj}"
        f"\nexp start times: {experiment_times.start}"
        f"\nexp stop times: {experiment_times.stop}"
    )
    for i in range(1, 4):
        rfid_reads = aeon.load(
            root,
            aeon.io.reader.Harp(pattern=f"Patch{i}Rfid_32_*", columns=["Rfid"]),
            start=start_time,
            end=end_time,
        )
        rfid_reads = exclude_maintenance_data(rfid_reads, experiment_times)
        # Keep reads from the clock hours with the most RFID reads. Selecting by
        # floored hour avoids the end-inclusive `.loc[h : h + 1h]` slice, which
        # double-counts reads that fall exactly on the boundary of adjacent hours.
        most_reads = (
            rfid_reads.groupby(pd.Grouper(freq="h")).size().nlargest(n_top_hours).index
        )
        rfid_most_reads = rfid_reads[rfid_reads.index.floor("h").isin(most_reads)]
        # Randomly sample at most n_frames_per_patch reads per patch
        rfid_most_reads = rfid_most_reads.sample(
            n=min(n_frames_per_patch, len(rfid_most_reads)), random_state=random_seed
        )
        camera = aeon.load(
            root,
            aeon.io.reader.Video(f"CameraPatch{camera_idx[i - 1]}_*"),
            time=rfid_most_reads.index,
        )
        camera["id"] = subj
        camera_frames.append(camera)
# Build the frame once instead of growing it inside the loop
cameras = pd.concat(camera_frames) if camera_frames else pd.DataFrame()

BAA-1104045
exp start times: ['2024-01-31T11:28:45.543519974']
exp stop times: ['2024-02-03T16:23:52.943999767']
BAA-1104047
exp start times: ['2024-02-05T15:43:11.581535816' '2024-02-08T14:31:46.447999954']
exp stop times: ['2024-02-08T14:23:15.659999847' '2024-02-08T14:41:38.703999996']


In [None]:
# Number of sampled frames per subject
frames_per_subject = cameras.groupby("id").size()
frames_per_subject

In [7]:
# Export the sampled single-animal frames for the SLEAP section
single_animal_csv = (
    "/ceph/aeon/aeon/code/scratchpad/sleap/social0.2/"
    f"AEON3/{aeon3_social_02['session']}_rfid_patch_frames_tail2.csv"
)
cameras.to_csv(single_animal_csv)

In [None]:
# Optional sanity check: show 25 randomly sampled frames in a grid to make sure
# they contain animals
sample_frames = list(frames(cameras.sample(n=25)))
frame_grid = gridframes(sample_frames, width=1440, height=1080, shape=25)
fig = px.imshow(frame_grid)
fig.show()

### Curate multi-animal frames

In [7]:
multi_animal = aeon3_social_02["subjects"]["multi_animal"]
# Prefer a session-level root, otherwise use the multi-animal entry's own root.
# (`dict.get(key, default)` evaluates the default eagerly, and the previous default
# borrowed the root of an unrelated single-animal subject.)
root = aeon3_social_02["root"] if "root" in aeon3_social_02 else multi_animal["root"]
start_time = multi_animal["start"]
end_time = multi_animal["end"]
experiment_times = get_experiment_times(root, start_time, end_time)


No 'Experiment' event at the start of the search range. Using input `end_time` as last experiment stop time.



In [8]:
# RFID reads from all three patches over the multi-animal period, each tagged with
# its patch number, sorted chronologically
patch_rfid_reads = [
    aeon.load(
        root,
        aeon.io.reader.Harp(pattern=f"Patch{patch}Rfid_32_*", columns=["Rfid"]),
        start=start_time,
        end=end_time,
    ).assign(Patch=patch)
    for patch in range(1, 4)
]
rfid_times = pd.concat(patch_rfid_reads).sort_index()
# Hour chunks (read times rounded to the nearest hour) containing at least one RFID read
rfid_hour_chunks = rfid_times.index.round("h").unique()

In [None]:
# Get patch ROI polygons (a list of (x, y) vertices per patch) from the session metadata
metadata = aeon.load(
    root, social02.Metadata, start=pd.Timestamp(start_time.date()), end=end_time
)["metadata"].iloc[0]
patch_coords = [
    [
        (int(pt.X), int(pt.Y))
        # getattr does the same attribute lookup as the former eval of a built string
        for pt in getattr(metadata.ActiveRegion, f"Patch{i}Region").ArrayOfPoint
    ]
    for i in range(1, 4)
]
# Visualise patch coordinates
img = get_single_frame(root, social02.CameraTop.Video, start_time)
plot_overlay(img, roi_coords=patch_coords)
# Patch midpoints: mean of each polygon's vertices. Divide by the actual vertex count
# rather than a hard-coded 4, so non-quadrilateral ROIs are handled correctly.
patch_midpoints = [
    (
        sum(x for x, _ in coords) / len(coords),
        sum(y for _, y in coords) / len(coords),
    )
    for coords in patch_coords
]

In [12]:
# Filter pose data to only include frames where all animals are within 100 pixels of the patch midpoint
def within_patch_radius(row, patch_midpoints, radius):
    """
    Name the first patch whose midpoint is within `radius` of the row's position.

    Args:
        row: A mapping (e.g. a DataFrame row) with "x" and "y" entries.
        patch_midpoints: Sequence of (x, y) midpoints for Patch1, Patch2, ...
        radius: Maximum Euclidean distance (inclusive).

    Returns:
        "Patch<n>" (1-based) for the first matching patch, otherwise NaN.
    """
    for patch_num, midpoint in enumerate(patch_midpoints, start=1):
        dx_sq = (row["x"] - midpoint[0]) ** 2
        dy_sq = (row["y"] - midpoint[1]) ** 2
        if not np.sqrt(dx_sq + dy_sq) > radius:
            return f"Patch{patch_num}"
    return np.nan


PATCH_RADIUS_PX = 100  # distance (pixels) from a patch midpoint that counts as "in patch"


def _class_to_subject(cls):
    # Pose class 1.0 is labelled BAA-1104047; every other class value is BAA-1104045
    return "BAA-1104047" if cls == 1.0 else "BAA-1104045"


pose = (
    aeon.load(root, social02.CameraTop.Pose, start=start_time, end=end_time)
    # drop maintenance periods
    .pipe(exclude_maintenance_data, experiment_times)
    # keep only hour chunks that contain RFID reads
    .loc[lambda df: df.index.round("h").isin(rfid_hour_chunks)]
    # label each row with the first patch it is inside (NaN if none)
    .assign(
        in_patch=lambda df: df.apply(
            within_patch_radius, axis=1, args=(patch_midpoints, PATCH_RADIUS_PX)
        )
    )
    .dropna(subset=["in_patch"])
    # keep timestamps with more than one row (presumably one row per animal - TODO confirm)
    .loc[lambda df: df.index.duplicated(keep=False)]
    # keep timestamps where every row is in the same patch
    .groupby(level=0)
    .filter(lambda same_time: same_time["in_patch"].nunique() == 1)
    .assign(id=lambda df: df["class"].apply(_class_to_subject))
)
plot_overlay(img, pose=pose, roi_coords=patch_coords)

In [54]:
# For each patch, take the 3 hours with the most multi-animal co-occurrences and sample
# up to 50 frames from EACH of those hours (i.e. up to 150 frames per patch).
n_top_hours = 3
n_frames_per_hour = 50
random_seed = 42  # fixed seed so the sampled frames are reproducible
most_cooccurrences = (
    pose.groupby(["in_patch", pd.Grouper(freq="h")])
    .size()
    .groupby(level=0, group_keys=False)
    .nlargest(n_top_hours)
    .index
)
camera_frames = []
for patch, hour in most_cooccurrences:
    group = pose.loc[(pose["in_patch"] == patch) & (pose.index.floor("h") == hour)]
    # `pose` keeps only timestamps with several rows, so timestamps repeat. De-duplicate
    # them so the same video frame is not loaded (and possibly sampled) more than once.
    frame_times = group.index.unique()
    camera = aeon.load(
        root,
        aeon.io.reader.Video(f"Camera{patch}_*"),
        time=frame_times,
    )
    # Don't raise when an hour has fewer than n_frames_per_hour frames
    camera_frames.append(
        camera.sample(n=min(n_frames_per_hour, len(camera)), random_state=random_seed)
    )
# Build the frame once instead of growing it inside the loop
cameras = pd.concat(camera_frames) if camera_frames else pd.DataFrame()

cameras.to_csv(
    "/ceph/aeon/aeon/code/scratchpad/sleap/social0.2/"
    f"AEON3/{aeon3_social_02['session']}_rfid_patch_frames_tail_ma.csv"
)

In [None]:
# Optional sanity check: show 25 randomly sampled multi-animal frames in a grid to
# make sure they contain animals
sample_frames = list(frames(cameras.sample(n=25)))
frame_grid = gridframes(sample_frames, width=1440, height=1080, shape=25)
fig = px.imshow(frame_grid)
fig.show()

## SLEAP env

In [3]:
import pandas as pd
import sleap
from sleap.gui.suggestions import SuggestionFrame
from sleap.io.pathutils import fix_path_separator


def update_slp_video_paths(
    labels: sleap.Labels, old_path: str, new_path: str, grayscale: bool = True
) -> sleap.Labels:
    """
    Update video paths in a SLEAP labels object (e.g., to move training from a local to a remote machine).

    Each video is re-opened from its rewritten path, and each labeled frame is
    re-attached to the corresponding new video. Only labeled frames, videos,
    skeletons and tracks are carried over to the returned object. Other contents
    of ``labels`` (e.g. suggestions) are not.

    Args:
        labels (sleap.Labels): A SLEAP Labels object.
        old_path (str): Old path to video files. Only its first occurrence in
            each (separator-normalised) filename is replaced.
        new_path (str): New path to video files.
        grayscale (bool): Whether to open the videos as grayscale. Defaults to True.

    Returns:
        sleap.Labels: A new SLEAP Labels object with updated video paths. Its
        labeled frames share their instances with ``labels``.
    """
    # New videos, in the same order as labels.videos
    videos = [
        sleap.Video.from_filename(
            # count=1: leave any later occurrence of old_path in the path untouched
            fix_path_separator(vid.filename).replace(old_path, new_path, 1),
            grayscale=grayscale,
        )
        for vid in labels.videos
    ]

    lfs = [
        sleap.instance.LabeledFrame(
            video=videos[labels.videos.index(lf.video)],
            frame_idx=lf.frame_idx,
            instances=lf.instances,
        )
        for lf in labels.labeled_frames
    ]

    return sleap.Labels(
        labeled_frames=lfs,
        videos=videos,
        skeletons=labels.skeletons,
        tracks=labels.tracks,
    )

In [None]:
# NOTE: `aeon3_social_02` is defined in the session-config cell of the AEON env section.
# That cell only needs pandas, so run it in this kernel too if SLEAP uses a separate env.
frames_csv = (
    "/ceph/aeon/aeon/code/scratchpad/sleap/social0.2/"
    f"AEON3/{aeon3_social_02['session']}_rfid_patch_frames_tail2.csv"
)
cameras = pd.read_csv(frames_csv)
cameras

In [5]:
# One grayscale sleap.Video per unique video file referenced in `cameras`
videos_dict = {}
for video_path in cameras["_path"].unique():
    videos_dict[video_path] = sleap.Video.from_filename(video_path, grayscale=True)

In [6]:
# Two-node skeleton with a single edge: spine_bottom -> tail_base
tail_nodes = ["spine_bottom", "tail_base"]
skeleton = sleap.Skeleton()
skeleton.add_nodes(tail_nodes)
skeleton.add_edge(tail_nodes[0], tail_nodes[1])

In [9]:
# Existing two-point labels (their tracks are reused for the suggestions file below)
labels_path = (
    "/ceph/aeon/aeon/code/scratchpad/sleap/social0.2/"
    f"AEON3/{aeon3_social_02['session']}_rfid_patch_frames_tail_2_points.slp"
)
labels = sleap.load_file(labels_path)

In [17]:
# Map subject id -> sleap.Track. Use the ids in `cameras` when present, otherwise
# every single-animal subject of the session.
if "id" in cameras.columns:
    subjs = cameras["id"].unique()
else:
    subjs = [name for name in aeon3_social_02["subjects"].keys() if "multi_" not in name]
tracks_dict = {name: sleap.Track(spawned_on=0, name=name) for name in subjs}

### Create SuggestionFrames to be inferred using the existing patch camera models

In [12]:
# One SuggestionFrame per unique (video, frame) pair in `cameras`
unique_frames = cameras.drop_duplicates(subset=["_path", "_frame"])
sfs = [
    SuggestionFrame(video=videos_dict[row._path], frame_idx=row._frame)
    for _, row in unique_frames.iterrows()
]
labels_suggest = sleap.Labels(
    videos=list(videos_dict.values()),
    suggestions=sfs,
    skeletons=[skeleton],
    # alternative: list(tracks_dict.values())
    tracks=labels.tracks,
)
len(labels_suggest.suggestions)

600

In [13]:
# Save the suggestions file, to be inferred with the existing patch camera models
slp_dir = "/ceph/aeon/aeon/code/scratchpad/sleap/social0.2/AEON3/"
slp_path = f"{slp_dir}{aeon3_social_02['session']}_rfid_patch_frames_tail_2_points_1.slp"
sleap.Labels.save_file(labels_suggest, slp_path)

### Convert predictions to user labels

In [4]:
# Load the predicted labels (presumably the inference output for the suggestions file)
slp_dir = "/ceph/aeon/aeon/code/scratchpad/sleap/social0.2/AEON3/"
slp_pr_path = f"{slp_dir}{aeon3_social_02['session']}_rfid_patch_frames_tail_2_points_2.slp"
labels_pr = sleap.load_file(slp_pr_path)

In [49]:
# Rebuild each predicted frame as a user-labelled frame that uses the 2-node `skeleton`
# and the subject tracks in `tracks_dict`. Nodes not in `skeleton` are dropped.
lfs = []
for lf in labels_pr.labeled_frames:
    instances = []
    # instances_to_show rather than all instances (presumably user instances plus
    # predictions not superseded by them - see SLEAP LabeledFrame docs)
    for inst in lf.instances_to_show:
        # Untracked instances have `track is None`; keep them without a track instead
        # of raising AttributeError on `None.name`
        track = tracks_dict.get(inst.track.name) if inst.track is not None else None
        instances.append(
            sleap.Instance(
                skeleton=skeleton,
                track=track,
                points={
                    node.name: sleap.instance.Point(
                        x=point.x, y=point.y, visible=point.visible
                    )
                    for node, point in inst.nodes_points
                    if node.name in skeleton.node_names
                },
            )
        )
    lfs.append(
        sleap.instance.LabeledFrame(
            video=lf.video,
            frame_idx=lf.frame_idx,
            instances=instances,
        )
    )
labels_user = sleap.Labels(labeled_frames=lfs)
# Re-point the videos from the "Z:" drive prefix to /ceph/aeon
labels_user = update_slp_video_paths(labels_user, "Z:", "/ceph/aeon")
labels_user

Labels(labeled_frames=1299, videos=32, skeletons=1, tracks=2)

In [52]:
# NOTE(review): this is the same path `labels_pr` (slp_pr_path) was loaded from, so the
# predictions file is overwritten by the user labels. Re-running the conversion cell
# after this will load user labels instead of predictions.
# The f-string is kept on one line: a newline inside a replacement field of a
# single-quoted f-string is a SyntaxError before Python 3.12 (PEP 701).
sleap.Labels.save_file(
    labels_user,
    "/ceph/aeon/aeon/code/scratchpad/sleap/social0.2/"
    f"AEON3/{aeon3_social_02['session']}_rfid_patch_frames_tail_2_points_2.slp",
)