# AEON

#### create training dataset from aeon raw data

In [None]:
import os
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import numpy.typing as npt
from pathlib import Path
import re
import warnings
import aeon
import aeon.io.api as aeon_api

from aeon.schema.schemas import exp02, social02
from aeon.analysis.utils import *
from dotmap import DotMap
from scipy import stats

import datajoint as dj
from aeon.dj_pipeline.analysis.block_analysis import *


def get_experiment_times(
    root: str | os.PathLike, start_time: pd.Timestamp, end_time: pd.Timestamp
) -> DotMap:
    """
    Retrieve experiment start and stop times from environment states
    (i.e. times outside of maintenance mode) occurring within the
    given start and end times.

    Args:
        root (str or os.PathLike): The root path where epoch data is stored.
        start_time (pandas.Timestamp): Start time.
        end_time (pandas.Timestamp): End time.

    Returns:
        DotMap: A DotMap object containing two keys: 'start' and 'stop',
        corresponding to pairs of experiment start and stop times.

    Warns:
        UserWarning: If no environment states are found, if the last state in
        the range is not 'Maintenance', or if the first retained state is
        'Maintenance'.

    Notes:
        This function uses the last 'Maintenance' event (if available, otherwise
        `end_time`) as the last 'Experiment' stop time. If the first retrieved state
        is 'Maintenance' (e.g. 'Experiment' mode entered before `start_time`),
        `start_time` is used as the first 'Experiment' start time.
    """

    experiment_times = DotMap()
    env_states = aeon.load(
        root,
        social02.Environment.EnvironmentState,
        # aeon.io.reader.Csv("Environment_EnvironmentState_*", ["state"]),
        start_time,
        end_time,
    )
    if env_states.empty:
        # Nothing to segment: treat the whole requested range as one experiment window.
        warnings.warn(
            "The environment state df is empty. "
            "Using input `start_time` and `end_time` as experiment times."
        )
        experiment_times.start = [start_time]
        experiment_times.stop = [end_time]
        return experiment_times
    if env_states["state"].iloc[-1] != "Maintenance":
        warnings.warn(
            "No 'Maintenance' event at the end of the search range. "
            "Using input `end_time` as last experiment stop time."
        )
        # Pad with a "Maintenance" event at the end
        env_states = pd.concat(
            [
                env_states,
                pd.DataFrame(
                    "Maintenance",
                    index=[end_time],
                    columns=env_states.columns,
                ),
            ]
        )
    # Use the last "Maintenance" event as end time
    end_time = (env_states[env_states.state == "Maintenance"]).index[-1]
    # Drop repeated timestamps so the positional lookups below are unambiguous
    env_states = env_states[~env_states.index.duplicated(keep="first")]
    # Retain only events between visit start and stop times
    env_states = env_states.iloc[
        env_states.index.get_indexer([start_time], method="bfill")[
            0
        ] : env_states.index.get_indexer([end_time], method="ffill")[0] + 1
    ]
    # Retain only events where state changes (experiment-maintenance pairs)
    env_states = env_states[env_states["state"].ne(env_states["state"].shift())]
    if env_states["state"].iloc[0] == "Maintenance":
        # Fixed message: this branch pads the *start* of the range, not the end.
        warnings.warn(
            "No 'Experiment' event at the start of the search range. "
            "Using input `start_time` as first experiment start time."
        )
        # Pad with an "Experiment" event at the start
        env_states = pd.concat(
            [
                pd.DataFrame(
                    "Experiment",
                    index=[start_time],
                    columns=env_states.columns,
                ),
                env_states,
            ]
        )
    experiment_times.start = env_states[
        env_states["state"] == "Experiment"
    ].index.values
    experiment_times.stop = env_states[
        env_states["state"] == "Maintenance"
    ].index.values

    return experiment_times


def exclude_maintenance_data(
    data: pd.DataFrame, experiment_times: DotMap
) -> pd.DataFrame:
    """
    Excludes rows not in experiment times (i.e., corresponding to maintenance times)
    from the given dataframe.

    Args:
        data (pandas.DataFrame): The data to filter. Expected to have a pandas.DateTimeIndex.
        experiment_times (DotMap): A DotMap object containing experiment start and stop times
            under the `start` and `stop` attributes; the two are paired element-wise.

    Returns:
        pandas.DataFrame: A pandas DataFrame containing filtered data. If there are no
        experiment windows, an empty DataFrame with the columns of `data` is returned.
    """
    # Label-based slicing keeps rows from `start` to `stop`, both ends included.
    windows = [
        data.loc[start:stop]
        for start, stop in zip(experiment_times.start, experiment_times.stop)
    ]
    if not windows:
        # pd.concat raises on an empty list; with no windows no row qualifies.
        return data.iloc[0:0]
    return pd.concat(windows)

def get_raw_tracking_data(
    root: str | os.PathLike,
    subj_id: str,
    start: pd.Timestamp,
    end: pd.Timestamp,
    source_reader: aeon.io.reader.Video = exp02.CameraTop.Video,
) -> pd.DataFrame:
    """
    Retrieve pos tracking and video data and assigns subject ID.

    Pose data is looked up in this order: the DataJoint pipeline, the
    "processed" mirror of `root`, and finally `root` itself. Video frame
    metadata is loaded from `root`, falling back to the "processed" mirror.
    Each identity's positions are matched to the nearest video frame within 1 ms.

    Args:
        root (str or os.PathLike): The root path, or prioritised sequence of paths, where epoch data is stored.
            The last two path components are used as acquisition computer and experiment name.
        subj_id (str): The subject ID string to be assigned. If it contains "multi_", the
            inter-animal distance is computed for every timestamp.
        start (pandas.Timestamp): The left bound of the time range to extract.
        end (pandas.Timestamp): The right bound of the time range to extract.
        source_reader (aeon.io.reader.Video, optional): The frame source reader. Default is exp02.CameraTop.Video.
    Returns:
        pandas.DataFrame: A pandas DataFrame containing pos tracking and video data, and subject ID
        (columns "x", "y", "identity_name", "_frame", "_path", "distance", "id").
    Raises:
        ValueError: If no tracking data is found in any source, or if more than two
            identities share a timestamp for a "multi_" subject.
    """

    subj_video = aeon_api.load(root, source_reader, start=start, end=end)
    path = Path(root)
    acquisition_computer = path.parts[-2].lower()
    experiment = path.parts[-1].rstrip('/')
    # Insert the period expected in pipeline experiment names (e.g. "social02" -> "social0.2")
    experiment = re.sub(r"(social0)(\d+)", r"\1.\2", experiment)
    experiment_name = f"{experiment}-{acquisition_computer}"
    key = {"experiment_name": experiment_name}
    # NOTE(review): this restricts to the chunk starting exactly at `start`; `end` is not
    # part of the DataJoint restriction - confirm that is intended.
    chunk_restriction = {"chunk_start": start}
    pose_query = (
        streams.SpinnakerVideoSource
        * tracking.SLEAPTracking.PoseIdentity.proj("identity_name", "identity_likelihood", anchor_part="part_name")
        * tracking.SLEAPTracking.Part
        & {"spinnaker_video_source_name": "CameraTop"}
        & key
        & chunk_restriction
        & {"part_name": "spine2"}
    )
    subj_pos = fetch_stream(pose_query)
    # replace "raw" in root with "processed"
    # os.fspath keeps this working when `root` is a path object rather than a str
    processed_root = os.fspath(root).replace("raw", "processed")
    if subj_video.empty:
        subj_video = aeon_api.load(processed_root, source_reader, start=start, end=end)
        warnings.warn("subj_video data from raw is empty, retrieving data from processed")
    if subj_pos.empty:
        warnings.warn("subj_pos data from DJ is empty, retrieving data from processed")
        subj_pos = aeon_api.load(processed_root, aeon.io.reader.Pose(pattern="CameraTop_202_gpu-partition*"), start=start, end=end)
        if subj_pos.empty:
            warnings.warn("subj_pos data from processed is empty, retrieving data from raw")
            subj_pos = aeon_api.load(root, aeon.io.reader.Pose(pattern="CameraTop_202*"), start=start, end=end)
            display(subj_pos.sort_index())
            if subj_pos.empty:
                raise ValueError("No tracking data found.")
        # Prefer the "spine2" body part; fall back to "centroid" when it is absent
        subj_pos = subj_pos[subj_pos["part"] == ("spine2" if "spine2" in subj_pos["part"].values else "centroid")]
        # Non-inplace drop/rename: the filtered frame above may be a view of the loaded data
        subj_pos = subj_pos.drop(columns=["part"]).rename(
            columns={"identity": "identity_name", "part_likelihood": "likelihood"}
        )
    subj_data = pd.DataFrame()
    for subj in subj_pos["identity_name"].unique():
        sample = subj_pos[subj_pos["identity_name"] == subj]
        # Attach each identity's position to the video frame closest in time (within 1 ms)
        subj_data_group = pd.merge_asof(
            subj_video,
            sample,
            left_index=True,
            right_index=True,
            direction="nearest",
            tolerance=pd.Timedelta("1ms"),
        )[["x", "y", "identity_name", "_frame", "_path"]]
        subj_data = pd.concat([subj_data, subj_data_group])
    # Frames without a matching position carry NaNs after the merge
    subj_data.dropna(inplace=True)
    subj_data["distance"] = np.nan
    if "multi_" in subj_id:
        group_sizes = subj_data.groupby(level=0).size()
        if (group_sizes > 2).any():
            times_with_more_than_two = group_sizes[group_sizes > 2].index.tolist()
            raise ValueError(f"More than two identities found for time(s): {times_with_more_than_two}")
        # Add a row number within each group to pivot the data
        subj_data['row_num'] = subj_data.groupby(level=0).cumcount()
        # Pivot the DataFrame to get 'x' and 'y' coordinates for each identity as separate columns
        subj_data_pivot = subj_data.pivot_table(
            index=subj_data.index.get_level_values(0),
            columns='row_num',
            values=['x', 'y']
        )
        # Extract coordinates
        x0 = subj_data_pivot['x'][0]
        y0 = subj_data_pivot['y'][0]
        x1 = subj_data_pivot['x'][1]
        y1 = subj_data_pivot['y'][1]
        # Calculate distances using vectorized operations
        distance = np.sqrt((x0 - x1)**2 + (y0 - y1)**2)
        # Map the calculated distances back to the original DataFrame
        subj_data['distance'] = subj_data.index.get_level_values(0).map(distance)
        subj_data.drop(columns='row_num', inplace=True)
    subj_data["id"] = subj_id
    # Make x and y columns numeric
    subj_data["x"] = pd.to_numeric(subj_data["x"])
    subj_data["y"] = pd.to_numeric(subj_data["y"])

    return subj_data


def sample_n_from_bins(
    subj_data: pd.DataFrame,
    n_samples: int = 1,
    n_bins: int = 50,
    range: npt.ArrayLike = [[0, 1440], [0, 1080]],
) -> pd.DataFrame:
    """
    Uniformly samples n number of data from x number of bins.

    Rows are assigned to cells of an `n_bins` x `n_bins` grid over (x, y) and
    `n_samples` rows are drawn with replacement from every occupied cell;
    rows drawn more than once are returned only once.

    Args:
        subj_data (pandas.DataFrame): A pandas DataFrame containing pos tracking and video data, and subject ID.
            Must have "x" and "y" columns. The input is not modified.
        n_samples (int, optional): The number of samples to take from each bin. Default is 1.
        n_bins (int, optional): The number of bins to use for sampling. Default is 50.
        range (list of lists, optional): The leftmost and rightmost edges of the bins along each dimension
            (if not specified explicitly in the bins parameters): [[xmin, xmax], [ymin, ymax]]. All values
            outside of this range will be considered outliers and not tallied in the histogram. Default is
            [[0, 1440], [0, 1080]].

    Returns:
        pandas.DataFrame: A pandas DataFrame containing uniformly-sampled pos tracking and video data, and subject ID,
        plus a "bin" column holding the bin number of each row.
    """

    # Only the bin numbers are needed. With statistic="count" the values are not
    # referenced, so pass None instead of the whole frame (which would make SciPy
    # allocate one result row per input row and choke on non-numeric columns).
    hist_data = stats.binned_statistic_2d(
        subj_data.x,
        subj_data.y,
        values=None,
        statistic="count",
        bins=n_bins,
        range=range,
    )
    subj_data = subj_data.copy()
    subj_data["bin"] = hist_data.binnumber
    # Sampling with replacement lets sparsely populated bins contribute without raising
    sampled_data = (
        subj_data.groupby(["bin"]).sample(n=n_samples, replace=True).drop_duplicates()
    )
    return sampled_data


def create_session_dataset(
    session: dict,
    subj_ids: list = None,
    plot_dist: bool = True,
) -> pd.DataFrame:
    """
    Build the sampled dataset for every requested subject of a session.

    Args:
        session (dict): A dictionary containing the root path, subject IDs, and their start and end times.
        subj_ids (list, optional): A list of subject ids. If None (or empty), all subjects are selected.
        plot_dist (bool, optional): Whether to plot the 1d and 2d histograms of x, y pos tracking for each subject.
    Returns:
        pandas.DataFrame: A pandas dataframe containing uniformly-sampled pos tracking and video data, and subject ID.
    """
    subjects = session["subjects"]
    selected_ids = subj_ids if subj_ids else subjects.keys()
    # Seed with an empty frame so the result is still a DataFrame when nothing is selected
    per_subject = [pd.DataFrame()]
    for subj in selected_ids:
        subj_entry = subjects[subj]
        subj_dict = {
            "id": subj,
            # A session-level root takes precedence over a per-subject one
            "root": session.get("root", subj_entry.get("root")),
            "start": subj_entry["start"],
            "end": subj_entry["end"],
        }
        if "multi_" in subj:
            # sample fewer points for manual annotation
            sampled = create_subject_dataset(
                subj_dict,
                max_dist=300,
                n_samples=2,
                n_bins=50,
            )
        else:
            sampled = create_subject_dataset(subj_dict)
        per_subject.append(sampled)
    all_subj_data = pd.concat(per_subject)
    if plot_dist:
        plot_position_histograms(all_subj_data).show()
    return all_subj_data


def plot_position_histograms(data: pd.DataFrame, n_bins: int = 50) -> plt.Figure:
    """
    Plots the 1d and 2d histograms of x, y pos tracking for each subject in a given DataFrame.

    The figure has one column per subject: the top row holds overlaid 1d
    histograms of x and y, the bottom row a 2d histogram of (x, y).

    Args:
        data (pandas.DataFrame): A pandas DataFrame containing x, y pos tracking and subject ID(s)
            in the "x", "y" and "id" columns.
        n_bins (int, optional): The number of bins to use for plotting histograms. Default is 50.
    Returns:
        matplotlib.figure.Figure: A figure containing 1d and 2d histograms of x, y pos tracking for each subject.
    """
    subj_ids = data["id"].unique()
    # squeeze=False keeps the axes array 2-D, so a single subject needs no special case
    fig, axes = plt.subplots(2, len(subj_ids), squeeze=False)
    for col, subj_id in enumerate(subj_ids):
        subj_data = data[data["id"] == subj_id]
        subj_data[["x", "y"]].plot.hist(
            bins=n_bins, alpha=0.5, ax=axes[0, col], title=subj_id
        )
        axes[1, col].hist2d(
            subj_data.x, subj_data.y, bins=(n_bins, n_bins), cmap=plt.cm.jet
        )

    plt.tight_layout()
    return fig


def create_subject_dataset(
    subject: dict,
    max_dist: float = None,
    n_samples: int = 1,
    n_bins: int = 50,
) -> pd.DataFrame:
    """
    Creates a dataset for a given subject dict.

    Args:
        subject (dict): A dictionary containing the root path, subject ID, and their start and end times
            under the keys "root", "id", "start" and "end".
        max_dist (float, optional): If given, only rows whose "distance" value is at most this
            are kept before sampling. Default is None (no distance filter).
        n_samples (int, optional): The number of samples to take from each bin. Default is 1.
        n_bins (int, optional): The number of bins to use for sampling. Default is 50.

    Returns:
        pandas.DataFrame: A pandas DataFrame containing uniformly-sampled pos tracking and video data, and subject ID.
    """
    subj_data = get_raw_tracking_data(
        subject["root"],
        subject["id"],
        subject["start"],
        subject["end"],
    )
    # Compare against None so that an explicit max_dist of 0 is still applied
    if max_dist is not None:
        subj_data = subj_data.loc[subj_data.distance <= max_dist]
    subj_data_sampled = sample_n_from_bins(subj_data, n_samples=n_samples, n_bins=n_bins)
    # Ensures that for multi-animal data, both subjects are retrieved for any sampled time
    subj_data = subj_data.loc[subj_data.index.isin(subj_data_sampled.index)]
    return subj_data


def create_fully_labelled_dataset(session: dict, subj_ids: list = None) -> pd.DataFrame:
    """
    Collect the unsampled tracking data for all or selected subjects of a
    session dict. Useful to "bookmark" frames for use with SLEAP's
    predict "only-labeled-frames" option.

    Args:
        session (dict): A session dictionary.
        subj_ids (list, optional): A list of subject ids. If None (or empty),
            all subjects are selected.
    Returns:
        pandas.DataFrame: A pandas DataFrame containing pos tracking,
            video data, and subject ID.
    """
    subjects = session["subjects"]
    selected_ids = subj_ids if subj_ids else subjects.keys()
    # Seed with an empty frame so the result is still a DataFrame when nothing is selected
    per_subject = [pd.DataFrame()]
    for subj in selected_ids:
        window = subjects[subj]
        # A session-level root takes precedence over a per-subject one
        root = session.get("root", window.get("root"))
        per_subject.append(
            get_raw_tracking_data(
                root,
                subj,
                window["start"],
                window["end"],
                source_reader=exp02.CameraTop.Video,
            )
        )

    return pd.concat(per_subject)

def transform_coordinates(homography_matrix: np.ndarray, coordinate: np.ndarray) -> np.ndarray:
    """
    Map a 2D point through the inverse of a homography matrix.

    Args:
        homography_matrix (numpy.ndarray): The homography matrix; it is inverted before use.
        coordinate (numpy.ndarray): The (x, y) coordinate to transform.

    Returns:
        numpy.ndarray: The transformed (x, y) coordinate.
    """
    inverse = np.linalg.inv(homography_matrix)
    # Homogeneous form: append a trailing 1 to the flattened point
    homogeneous = np.concatenate((np.ravel(coordinate), np.ravel(1)), axis=0)
    projected = inverse @ homogeneous
    # Divide by the last component to return to cartesian coordinates
    scale = projected[2]
    cartesian = projected[:2] / scale
    return np.array(cartesian)


In [4]:
# Session dictionaries. Every session shares the same layout:
#   "root":        raw data directory of the arena/experiment (with trailing slash)
#   "working_dir": directory holding the composite videos and output CSVs
#   "subjects":    subject id -> {"start": pandas.Timestamp}
#   "session":     session name, used for output file names
_RAW_DATA_ROOT = "/ceph/aeon/aeon/data/raw"
_WORKING_DIR = "/ceph/aeon/aeon/code/scratchpad/sleap/multi_point_tracking/multi_animal_CameraNSEW/"


def _make_session(name: str, arena: str, experiment: str, starts: dict) -> dict:
    """Build one session dict from its name, arena, experiment and per-subject start times."""
    return {
        "root": f"{_RAW_DATA_ROOT}/{arena}/{experiment}/",
        "working_dir": _WORKING_DIR,
        "subjects": {
            subj: {"start": pd.Timestamp(start)} for subj, start in starts.items()
        },
        "session": name,
    }


aeon3_social02 = _make_session(
    "aeon3_social02",
    "AEON3",
    "social0.2",
    {
        "BAA-1104045": "2024-01-31 12:00:00",  # tattooed
        "BAA-1104047": "2024-02-05 16:00:00",
        "multi_animal": "2024-02-10 11:00:00",
    },
)

aeon3_social02_EVAL = _make_session(
    "aeon3_social02_EVAL",
    "AEON3",
    "social0.2",
    {
        "BAA-1104045": "2024-02-25 18:00:00",  # tattooed
        "BAA-1104047": "2024-02-28 15:00:00",
        # "multi_animal": "2024-02-17 17:00:00",
    },
)

aeon4_social02 = _make_session(
    "aeon4_social02",
    "AEON4",
    "social0.2",
    {
        "BAA-1104048": "2024-02-28 11:00:00",  # tattooed
        "BAA-1104049": "2024-02-05 17:00:00",
        "multi_animal": "2024-02-10 11:00:00",
    },
)

aeon4_social02_EVAL = _make_session(
    "aeon4_social02_EVAL",
    "AEON4",
    "social0.2",
    {
        "BAA-1104048": "2024-02-25 18:00:00",  # tattooed
        "BAA-1104049": "2024-02-28 15:00:00",
        # "multi_animal": "2024-02-17 12:00:00",
    },
)

aeon3_social03 = _make_session(
    "aeon3_social03",
    "AEON3",
    "social0.3",
    {
        "BAA-1104516": "2024-07-07 16:00:00",  # tattooed
        "BAA-1104519": "2024-06-10 15:00:00",
        "multi_animal": "2024-06-25 18:00:00",
    },
)

aeon3_social03_EVAL = _make_session(
    "aeon3_social03_EVAL",
    "AEON3",
    "social0.3",
    {
        "BAA-1104516": "2024-07-10 11:00:00",  # tattooed
        "BAA-1104519": "2024-07-14 11:00:00",
        # "multi_animal": "2024-07-05 12:00:00",
    },
)

aeon4_social03 = _make_session(
    "aeon4_social03",
    "AEON4",
    "social0.3",
    {
        "BAA-1104568": "2024-07-04 14:00:00",  # tattooed
        "BAA-1104569": "2024-06-09 13:00:00",
        "multi_animal": "2024-06-30 11:00:00",
    },
)

aeon4_social03_EVAL = _make_session(
    "aeon4_social03_EVAL",
    "AEON4",
    "social0.3",
    {
        "BAA-1104568": "2024-07-08 13:00:00",  # tattooed
        "BAA-1104569": "2024-07-13 10:00:00",
        # "multi_animal": "2024-07-02 13:00:00",
    },
)

aeon3_social04 = _make_session(
    "aeon3_social04",
    "AEON3",
    "social0.4",
    {
        "BAA-1104792": "2024-09-13 09:00:00",  # tattooed
        "BAA-1104794": "2024-09-22 09:00:00",
        "multi_animal": "2024-08-28 14:00:00",
    },
)

aeon3_social04_EVAL = _make_session(
    "aeon3_social04_EVAL",
    "AEON3",
    "social0.4",
    {
        "BAA-1104792": "2024-09-09 18:00:00",  # tattooed
        "BAA-1104794": "2024-09-17 14:00:00",
        # "multi_animal": "2024-09-08 11:00:00",
    },
)

aeon4_social04 = _make_session(
    "aeon4_social04",
    "AEON4",
    "social0.4",
    {
        "BAA-1104795": "2024-08-16 17:00:00",  # tattooed
        "BAA-1104797": "2024-08-20 11:00:00",
        "multi_animal": "2024-08-28 14:00:00",
    },
)

aeon4_social04_EVAL = _make_session(
    "aeon4_social04_EVAL",
    "AEON4",
    "social0.4",
    {
        "BAA-1104795": "2024-09-09 16:00:00",  # tattooed
        "BAA-1104797": "2024-09-17 13:00:00",
        # "multi_animal": "2024-09-08 11:00:00",
    },
)

# Select the training session and its evaluation counterpart, then give every
# subject a one-hour window starting at its "start" time.
session = aeon3_social03
session_EVAL = aeon3_social03_EVAL
for sess in [session, session_EVAL]:
    for subj in sess["subjects"]:
        sess["subjects"][subj]["end"] = pd.Timestamp(sess["subjects"][subj]["start"]) + pd.Timedelta("1h")
session

In [None]:
# Extract single- and multi-animal frames for the selected session
all_subj_data = create_session_dataset(session)
# Count the non-null entries of every column per (subject id, tracked identity) pair
overview = all_subj_data.groupby(["id", "identity_name"]).count()
display(overview)

In [None]:
# Keep a snapshot of the sampled data; the next cell resets `all_subj_data` from it
all_subj_data_copy = all_subj_data.copy()

In [None]:
# Reset the working frame from the snapshot, then re-plot and re-count per subject/identity
all_subj_data = all_subj_data_copy.copy()
fig = plot_position_histograms(all_subj_data)
overview = all_subj_data.groupby(["id", "identity_name"]).count()
display(overview)

In [None]:
# Top up every subject whose frame count in `overview` is below 1000.
# `overview` is not refreshed inside the loop, so every subject is judged
# against the counts computed before this cell ran.
frames_added = False
for subj in session["subjects"]:
    # Count in the first column for the first identity listed under this subject
    n_frames = overview.loc[subj].iloc[0, 0]
    if n_frames >= 1000:
        continue
    frames_added = True
    subj_dict = {
        "id": subj,
        "root": session["root"],
        "start": session["subjects"][subj]["start"],
        "end": session["subjects"][subj]["end"],
    }
    if "multi_" in subj:
        print("Not enough multi-animal frames, adding more...")
        # adjust n_samples and n_bins to get the desired number of frames
        subj_data = create_subject_dataset(
            subj_dict,
            max_dist=600,
            n_samples=1,
            n_bins=30,
        )
        # Multi-animal frames are added on top of the existing ones; drop duplicates
        all_subj_data = pd.concat([all_subj_data, subj_data])
        all_subj_data.drop_duplicates(inplace=True)
    else:
        print(f"Not enough frames for {subj}, adding more...")
        # adjust n_samples and n_bins to get the desired number of frames
        if n_frames > 700:
            subj_data = create_subject_dataset(
                subj_dict,
                n_samples=2,
                n_bins=35
            )
        else:  # sample more
            subj_data = create_subject_dataset(
                subj_dict,
                n_samples=2,
                n_bins=55
            )
        # Single-animal frames are replaced by the new, denser sample
        all_subj_data = all_subj_data[all_subj_data["id"] != subj]
        all_subj_data = pd.concat([all_subj_data, subj_data])
if frames_added:
    fig = plot_position_histograms(all_subj_data)
    overview = all_subj_data.groupby(["id", "identity_name"]).count()
    display(overview)
    

In [None]:
# # Optional sanity check
# # Choose f so that it is one of the frames from all_subj_data
# from aeon.io.video import frames
# from aeon.analysis.movies import gridframes
# import plotly.express as px
# f = 162681
# fig = px.imshow(gridframes(list(frames(all_subj_data[all_subj_data['_frame'] == f])), width=1440, height=1080, shape=1))
# fig.add_trace(px.scatter(all_subj_data[all_subj_data['_frame'] == f], x='x', y='y').data[0])
# fig.show()

In [None]:
# The last two components of the session root are the acquisition computer and
# the experiment name; Path.parts gives them with or without a trailing slash.
session_root = Path(session["root"])
acquisition_computer, experiment = session_root.parts[-2:]
experiment_no_period = experiment.replace(".", "")
quadrant_cameras = ['CameraSouth', 'CameraNorth', 'CameraEast', 'CameraWest']
# Paths for homographies
homography_paths = [f'/ceph/aeon/aeon/code/scratchpad/Orsi/pixel_mapping/pixel_mapping_results/{experiment}/{acquisition_computer}/H_{camera}.npy' 
                    for camera in quadrant_cameras]
# Load homographies (same order as `quadrant_cameras`)
homographies = [np.load(path) for path in homography_paths]

In [None]:
# Re-express every sampled top-camera position in the coordinates of the
# quadrant camera(s) that make up the matching composite-video frame.
all_subj_data_composite_vid = {'time': [], 'x': [], 'y': [], 'identity_name': [], 'distance': [], '_frame': [], '_path': [], 'id': []}
for subject in session["subjects"]:
    subj_data = all_subj_data[all_subj_data["id"] == subject]
    start = session["subjects"][subject]["start"].strftime("%Y-%m-%dT%H-%M-%S")
    # The frame-info CSV and the composite video share this per-subject prefix
    composite_prefix = f'{session["working_dir"]}{acquisition_computer}_{experiment_no_period}_{start}'
    composite_vid_frames_info = pd.read_csv(f'{composite_prefix}_composite_vid_frames_info.csv', index_col=0)
    composite_vid_path = f'{composite_prefix}_composite_video.avi'
    for timestamp, row in subj_data.iterrows():
        # Find the corresponding frames in the composite video (there can be >1 if the mice are far apart)
        frame = row["_frame"]
        composite_vid_frame_info = composite_vid_frames_info[composite_vid_frames_info["_frame"] == frame]
        for composite_frame, composite_vid_frame_info_row in composite_vid_frame_info.iterrows():
            # Find the quadrant camera for the frame
            camera = composite_vid_frame_info_row["_path"].split("/")[-2]
            # Find the homography for the camera
            homography = homographies[quadrant_cameras.index(camera)]
            # Convert the top camera x and y to the quadrant camera x and y 
            transformed_coordinate = transform_coordinates(homography, np.array([row["x"], row["y"]]))
            x = transformed_coordinate[0]
            y = transformed_coordinate[1]
            # If outside the quadrant camera dimensions, warn and skip this camera
            if x < 0 or x > 1440 or y < 0 or y > 1080:
                warnings.warn(f'Coordinate {row["x"], row["y"]} on frame {row["_frame"]} for subject {subject} is outside the dimensions of the {camera} camera')
                continue
            # Else, append to the all_subj_data_composite_vid df
            all_subj_data_composite_vid["time"].append(timestamp)
            all_subj_data_composite_vid["x"].append(x)
            all_subj_data_composite_vid["y"].append(y)
            all_subj_data_composite_vid["identity_name"].append(row["identity_name"])
            all_subj_data_composite_vid["distance"].append(row["distance"])
            # Frame number and path refer to the composite video, not the source camera
            all_subj_data_composite_vid["_frame"].append(composite_frame)
            all_subj_data_composite_vid["_path"].append(composite_vid_path)
            all_subj_data_composite_vid["id"].append(row["id"])

all_subj_data_composite_vid = pd.DataFrame(all_subj_data_composite_vid).set_index('time')
display(all_subj_data_composite_vid.sort_values(by='_frame'))

In [None]:
# # Optional sanity check
# # Since extracting frames from the composite vids fails for some reason, you need to change a couple of lines in the code above for this to work
# # Append composite_vid_frame_info_row["_frame"] to all_subj_data_composite_vid["_frame"]
# # And append composite_vid_frame_info_row["_path"] to all_subj_data_composite_vid["_path"]
# fig = px.imshow(gridframes(list(frames(all_subj_data_composite_vid[all_subj_data_composite_vid['_frame'] == f].iloc[0:1])), width=1440, height=1080, shape=1))
# fig.add_trace(px.scatter(all_subj_data_composite_vid[all_subj_data_composite_vid['_frame'] == f].iloc[0:1], x='x', y='y').data[0])
# fig.show()

In [None]:
# Save the composite-video labels to <working_dir><session name>.csv
all_subj_data_composite_vid.to_csv(f'{session["working_dir"]}{session["session"]}.csv')

In [None]:
# Build the fully labelled dataset used for SLEAP model validation
eval_subject_ids = list(session_EVAL["subjects"])
data = create_fully_labelled_dataset(session_EVAL, subj_ids=eval_subject_ids)
# Keep a random 10% of the rows (no seed is set, so the selection differs between runs)
data = data.sample(frac=0.1)
display(data)
data.to_csv(f'{session_EVAL["working_dir"]}{session_EVAL["session"]}_frames_top_cam.csv')

In [None]:
# Re-express every evaluation position in the coordinates of the quadrant
# camera(s) that make up the matching composite-video frame, then save to CSV.
composite_vid_data = {'time': [], 'x': [], 'y': [], 'identity_name': [], 'distance': [], '_frame': [], '_path': [], 'id': []}
for subject in data["id"].unique():
    subj_data = data[data["id"] == subject]
    start = session_EVAL["subjects"][subject]["start"].strftime("%Y-%m-%dT%H-%M-%S")
    composite_vid_frames_info = pd.read_csv(f'{session_EVAL["working_dir"]}{acquisition_computer}_{experiment_no_period}_{start}_composite_vid_frames_info.csv', index_col=0)
    for timestamp, row in subj_data.iterrows():
        # All composite-video frames built from this top-camera frame
        # (there can be >1 if the mice are far apart)
        matches = composite_vid_frames_info[composite_vid_frames_info["_frame"] == row["_frame"]]
        n_matches = len(matches)
        n_outside = 0
        for composite_frame, match in matches.iterrows():
            # The quadrant camera is the parent folder of the source video path
            camera = match["_path"].split("/")[-2]
            homography = homographies[quadrant_cameras.index(camera)]
            # Convert the top camera x and y to the quadrant camera x and y
            x, y = transform_coordinates(homography, np.array([row["x"], row["y"]]))[:2]
            if x < 0 or x > 1440 or y < 0 or y > 1080:
                n_outside += 1
                # Warn only once the point has fallen outside every candidate camera
                if n_outside == n_matches:
                    warnings.warn(f'Coordinate {row["x"], row["y"]} on frame {row["_frame"]} is outside the dimensions of the quadrant cameras')
                continue
            # Inside this camera's view: record the point against the composite video
            record = {
                "time": timestamp,
                "x": x,
                "y": y,
                "identity_name": row["identity_name"],
                "distance": row["distance"],
                "_frame": composite_frame,
                "_path": f'{session_EVAL["working_dir"]}{acquisition_computer}_{experiment_no_period}_{start}_composite_video.avi',
                "id": row["id"],
            }
            for column, value in record.items():
                composite_vid_data[column].append(value)

composite_vid_data = pd.DataFrame(composite_vid_data).set_index('time')
composite_vid_data.to_csv(f'{session_EVAL["working_dir"]}{session_EVAL["session"]}_frames.csv')

# SLEAP

In [2]:
import pandas as pd
import random
import sleap
import numpy as np

from pathlib import Path
from typing import Optional, Dict
from sleap.io.pathutils import fix_path_separator
from sleap.gui.suggestions import VideoFrameSuggestions
from sleap.nn.config import *
from sleap.nn.inference import TopDownMultiClassPredictor

In [3]:
def remove_nan_tracks(labels: sleap.Labels) -> sleap.Labels:
    """
    Removes instances from SLEAP Labels object where track = None.

    The labeled frames of the input `labels` are filtered in place and then
    re-wrapped in a new Labels object, so the input is modified as well.

    Args:
        labels (sleap.Labels): A SLEAP Labels object.
    Returns:
        sleap.Labels: A SLEAP Labels object with only tracked instances in each frame.
    """
    lfs = []
    for lf in labels.labeled_frames:
        # `LabeledFrame.remove_untracked()` filters the frame's instances in place
        # and returns None in SLEAP, so collecting its return value would yield a
        # list of None. Keep the (now filtered) frame itself, unless the call does
        # hand back a replacement object.
        filtered = lf.remove_untracked()
        lfs.append(lf if filtered is None else filtered)

    return sleap.Labels(
        labeled_frames=lfs,
        videos=labels.videos,
        skeletons=labels.skeletons,
        tracks=labels.tracks,
    )

def generate_slp_dataset(
    subj_data: pd.DataFrame,
    skeleton: sleap.Skeleton,
    tracks_dict: Optional[Dict[str, sleap.Track]] = None,
) -> sleap.Labels:
    """
    Builds a SLEAP Labels object from a table of labeled centroid positions.

    Rows of `subj_data` are grouped by (`_path`, `_frame`). Each group becomes one
    LabeledFrame, and each row in the group becomes one Instance with a single
    "centroid" point at (`x`, `y`), assigned to the track of its `identity_name`.

    Args:
        subj_data (pandas.DataFrame): Labeled data for a session. The columns read here
            are `_path`, `_frame`, `identity_name`, `x`, `y` and (only when
            `tracks_dict` is not given) `id`.
        skeleton (sleap.Skeleton): The skeleton attached to every instance.
        tracks_dict (dict, optional): Maps track names to sleap Track objects. If None
            or empty, one track is created per unique `id` value that does not
            contain "multi_". Default is None.
    Returns:
        sleap.Labels: A SLEAP Labels object containing the labeled frames.
    """

    # Fall back to one track per non-multi-animal subject id.
    # NOTE(review): tracks are keyed by the "id" column here but looked up by
    # "identity_name" below — presumably the two hold the same values; confirm.
    if not tracks_dict:
        tracks_dict = {}
        for subj_id in subj_data["id"].unique():
            if "multi_" in subj_id:
                continue
            tracks_dict[subj_id] = sleap.Track(spawned_on=0, name=subj_id)

    # One Video object per distinct video file referenced by the data.
    videos_dict = {}
    for video_path in subj_data._path.unique():
        videos_dict[video_path] = sleap.Video.from_filename(video_path, grayscale=True)

    labeled_frames = []
    for (video_path, frame_idx), frame_rows in subj_data.groupby(["_path", "_frame"]):
        frame_instances = [
            sleap.Instance(
                skeleton=skeleton,
                track=tracks_dict[row.identity_name],
                points={"centroid": sleap.instance.Point(row.x, row.y)},
            )
            for _, row in frame_rows.iterrows()
        ]
        labeled_frames.append(
            sleap.instance.LabeledFrame(
                video=videos_dict[video_path],
                frame_idx=frame_idx,
                instances=frame_instances,
            )
        )

    return sleap.Labels(labeled_frames=labeled_frames)


def update_slp_video_paths(
    labels: sleap.Labels, old_path: str, new_path: str
) -> sleap.Labels:
    """
    Updates video paths in a SLEAP labels object (e.g., to move training from local to remote machine).

    Each video filename has its path separators normalised with `fix_path_separator`,
    then every occurrence of `old_path` in it is replaced by `new_path` (plain
    `str.replace`, not prefix-only). The labeled frames are rebuilt to point at the
    new Video objects; their instances are reused as-is.

    Args:
        labels (sleap.Labels): A SLEAP Labels object.
        old_path (str): Old path to video files.
        new_path (str): New path to video files.

    Returns:
        sleap.Labels: A SLEAP Labels object with updated video paths.

    """

    old_videos = labels.videos
    # One replacement video per original, in the same order, so positions line up.
    videos = [
        sleap.Video.from_filename(
            fix_path_separator(vid.filename).replace(old_path, new_path), grayscale=True
        )
        for vid in old_videos
    ]

    # Re-point each frame at the replacement video found at the same position
    # as its original video.
    lfs = [
        sleap.instance.LabeledFrame(
            video=videos[old_videos.index(lf.video)],
            frame_idx=lf.frame_idx,
            instances=lf.instances,
        )
        for lf in labels.labeled_frames
    ]

    return sleap.Labels(
        labeled_frames=lfs,
        videos=videos,
        skeletons=labels.skeletons,
        tracks=labels.tracks,
    )


In [None]:
# create a new skeleton with a single node; the node name "centroid" matches the
# point key used when instances are built in `generate_slp_dataset`
skeleton = sleap.Skeleton()
skeleton.add_node("centroid")

In [None]:
# Generate the .slp training dataset for all subjects.
# The working dir is rewritten from the /ceph/aeon/ prefix to the Z:/ prefix before
# the session CSV is read and the .slp file is written.
session["working_dir"] = session["working_dir"].replace("/ceph/aeon/", "Z:/")
subj_data = pd.read_csv(f'{session["working_dir"]}{session["session"]}.csv')

# One track per subject, keyed by the subject name.
tracks_dict = {
    subj: sleap.Track(spawned_on=0, name=subj)
    for subj in ("BAA-1104045", "BAA-1104047")
}

labels = generate_slp_dataset(subj_data, skeleton, tracks_dict)
sleap.Labels.save_file(labels, f'{session["working_dir"]}{session["session"]}.slp')

### manual annotation
- Open the `.slp` file in the SLEAP GUI.
- Go through all the videos and adjust any labels that are not on the animal (this happens occasionally because the homography is not perfect).

In [None]:
# after manual annotation on local machine, update video paths in labels to point to ceph
# NOTE(review): the labelled file is loaded from, and the result saved to, a path relative
# to the current working directory rather than session["working_dir"] — confirm intended
labels = sleap.Labels.load_file(f'{session["session"]}_labelled.slp')
# "Z:" -> "/ceph/aeon" reverses the "/ceph/aeon/" -> "Z:/" mapping applied to working_dir
labels = update_slp_video_paths(labels=labels, old_path="Z:", new_path="/ceph/aeon")
# alternative remapping (moves videos into a "videos" subfolder), kept for reference:
# labels = update_slp_video_paths(labels=labels, old_path="Z:/aeon/code/scratchpad/sleap/multi_point_tracking/multi_animal_CameraNSEW", new_path='Z:/aeon/code/scratchpad/sleap/multi_point_tracking/multi_animal_CameraNSEW/videos')
sleap.Labels.save_file(labels, f'{session["session"]}_ceph.slp')

In [None]:
# create fully labelled datasets for SLEAP model validation
# (one .slp per frames CSV, with video paths pointing at ceph)
session_EVAL["working_dir"] = session_EVAL["working_dir"].replace("/ceph/aeon/", "Z:/")
csv_files = [
    f'{session_EVAL["session"]}_frames.csv',
    f'{session_EVAL["session"]}_frames_top_cam.csv',
]
# one track per subject, skipping multi-animal entries
tracks_dict = {
    subj: sleap.Track(spawned_on=0, name=subj)
    for subj in session_EVAL["subjects"].keys()
    if "multi_" not in subj
}
for csv_file in csv_files:
    # The frames CSV is written under the session working dir and the .slp output
    # is saved there too, so read from the working dir rather than from whatever
    # the current directory happens to be.
    subj_data = pd.read_csv(f'{session_EVAL["working_dir"]}{csv_file}')
    labels = generate_slp_dataset(subj_data, skeleton, tracks_dict=tracks_dict)
    labels = update_slp_video_paths(labels=labels, old_path="Z:", new_path="/ceph/aeon")
    sleap.Labels.save_file(labels, f'{session_EVAL["working_dir"]}{Path(csv_file).stem}_ceph.slp')

### train model on HPC

### evaluate
- modify `sleap_predict.sh` to predict on EVAL frames

In [25]:
# Evaluate single-animal predictions against the ground-truth labels.
# The metrics match those used to optimise the sleap models with optuna.
session_EVAL["working_dir"] = session_EVAL["working_dir"].replace("/ceph/aeon/", "Z:/")
gt_file = f'{session_EVAL["working_dir"]}{session_EVAL["session"]}_frames_ceph.slp'
pr_file = f'{session_EVAL["working_dir"]}predictions/{session_EVAL["session"]}_frames_ceph_pr.slp'
labels_gt = sleap.load_file(gt_file)
labels_pr = sleap.load_file(pr_file)
crop_size = 112  # make this match the crop size used in the centered-instance model

# The number of ground-truth tracks caps how many instances a frame may hold.
track_names = [track.name for track in labels_gt.tracks]
max_instances = len(track_names)

framepairs = sleap.nn.evals.find_frame_pairs(labels_gt, labels_pr)
matches = sleap.nn.evals.match_frame_pairs(framepairs, scale=crop_size)
positive_pairs = matches[0]

# ---- detection: count-based confusion matrix accumulated over all frame pairs ----
total_tp = total_fp = total_fn = total_tn = 0

for gt_frame, pr_frame in framepairs:
    gt_count = len(gt_frame.instances)
    pr_count = len(pr_frame.instances)

    if gt_count > max_instances:
        raise ValueError(
            f"Ground truth frame {gt_frame.frame_idx} has {gt_count} instances, which is more than the maximum of {max_instances}."
        )
    if pr_count > max_instances:
        raise ValueError(
            f"Predicted frame {pr_frame.frame_idx} has {pr_count} instances, which is more than the maximum of {max_instances}."
        )

    # Per-frame counts: the overlap of the two instance counts is treated as
    # correct detections; any surplus on either side is an error.
    tp = min(gt_count, pr_count)  # correct detections
    fp = pr_count - tp  # extra detections (0 when pr_count <= gt_count)
    fn = gt_count - tp  # missed detections (0 when gt_count <= pr_count)
    tn = max_instances - (gt_count + pr_count - tp)  # unused "slots"

    total_tp += tp
    total_fp += fp
    total_fn += fn
    total_tn += tn

# Detection metrics; each falls back to 0 when its denominator is not positive.
detection_precision = 0
if (total_tp + total_fp) > 0:
    detection_precision = total_tp / (total_tp + total_fp)

detection_recall = 0
if (total_tp + total_fn) > 0:
    detection_recall = total_tp / (total_tp + total_fn)

detection_f1_score = 0
if (detection_precision + detection_recall) > 0:
    detection_f1_score = (
        2 * detection_precision * detection_recall / (detection_precision + detection_recall)
    )

# ---- identity: per-track tally over the matched instance pairs ----
correct_id = dict.fromkeys(track_names, 0)
total_id_checks = dict.fromkeys(track_names, 0)

for positive_pair in positive_pairs:
    # Whichever of the first two entries is the PredictedInstance is the prediction;
    # the other one is taken as the ground truth.
    if isinstance(positive_pair[1], sleap.PredictedInstance):
        gt, pr = positive_pair[0], positive_pair[1]
    else:
        gt, pr = positive_pair[1], positive_pair[0]

    gt_track_name = gt.track.name
    total_id_checks[gt_track_name] += 1
    if gt_track_name == pr.track.name:
        correct_id[gt_track_name] += 1

if sum(total_id_checks.values()) > 0:
    id_accuracy = sum(correct_id.values()) / sum(total_id_checks.values())
else:
    id_accuracy = 0.0

# Harmonic mean for composite metric
if (detection_f1_score + id_accuracy) > 0:
    composite_metric = (2 * detection_f1_score * id_accuracy) / (detection_f1_score + id_accuracy)
else:
    composite_metric = 0

# print for debugging
for label, value in (
    ("Total TP: ", total_tp),
    ("Total FP: ", total_fp),
    ("Total FN: ", total_fn),
    ("Total TN: ", total_tn),
    ("Detection precision: ", round(detection_precision, 3)),
    ("Detection recall: ", round(detection_recall, 3)),
):
    print(label, value)
print("-")
for label, value in (
    ("Correct ID: ", correct_id),
    ("Total ID checks: ", total_id_checks),
):
    print(label, value)
print("-")
for label, value in (
    ("Detection F1 score: ", round(detection_f1_score, 3)),
    ("ID accuracy: ", round(id_accuracy, 3)),
):
    print(label, value)
print(f"Composite metric (harmonic mean of detection F1 and ID accuracy): {round(composite_metric, 3)}")

### export

In [None]:
# Export the trained top-down model pair (centroid + centered-instance).
# Both model directories share the same prefix and differ only in their suffix.
model_prefix = f"{session['working_dir']}/models/{session['session']}_labelled_topdown_top"
export_dir = "/ceph/aeon/aeon/code/bonsai-sleap/example_workflows/match_quad_id_to_top_cam_pose/quad_cam_exported_models"

predictor = TopDownMultiClassPredictor.from_trained_models(
    centroid_model_path=f"{model_prefix}.centroid",
    confmap_model_path=f"{model_prefix}.centered_instance",
    resize_input_layer=False,  # SLEAP 1.3.0+
)

# Optional export arguments left disabled: max_instances=2, unrag_outputs=False
predictor.export_model(export_dir)