In [36]:
from collections import deque

import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objs as go
from aeon.dj_pipeline import acquisition, fetch_stream, streams, tracking
from aeon.dj_pipeline.analysis import block_analysis
from plotly.subplots import make_subplots


In [38]:
# Pick the single block (experiment + block start time) to analyse.
key = {
    "experiment_name": "social0.2-aeon3",
    "block_start": "2024-02-11 13:36:10",
}
block_start, block_end = (block_analysis.Block & key).fetch1("block_start", "block_end")

# Chunk-level restriction derived from the block's time span
# (presumably the acquisition chunks that overlap the block — confirm).
chunk_restriction = acquisition.create_chunk_restriction(
    key["experiment_name"], block_start, block_end
)

# Pose data for ALL subjects seen by the top camera in this block; the
# identity's tracked part is attached under the renamed column ``anchor_part``.
pos_query = (
    (
        streams.SpinnakerVideoSource
        * tracking.SLEAPTracking.PoseIdentity.proj("identity_name", anchor_part="part_name")
        * tracking.SLEAPTracking.Part
    )
    & key
    & {"spinnaker_video_source_name": "CameraTop"}
    & chunk_restriction
)

# Sort by time, then keep only rows inside [block_start, block_end].
all_pos_df = fetch_stream(pos_query).sort_index().loc[block_start:block_end]



In [65]:
"""Standardize subject colors for plotting"""

subject_colors = px.colors.qualitative.Plotly
subject_colors_dict = {
    name: subject_colors[i]
    for i, name in enumerate(all_pos_df["identity_name"].unique())
}


def plot_xy(df: pd.DataFrame) -> go.Figure:
    """Plot the x and y positions of each subject.

    The x position is drawn in the top subplot and the y position in the
    bottom subplot; both subplots share the x-axis, which is taken from the
    DataFrame index. The two traces of a subject share one legend group, so
    each subject appears once in the legend and a single click toggles it in
    both subplots.

    Args:
        df (pandas.DataFrame): DataFrame with columns ``x``,
            ``y``, and ``identity_name``. Its index is used as the
            shared x-axis (presumably time).

    Returns:
        plotly.graph_objs.Figure: Plotly figure object with
            x and y positions of each subject.
    """
    fig = make_subplots(rows=2, cols=1, shared_xaxes=True)
    fallback_palette = px.colors.qualitative.Plotly
    # (subplot row, column to plot, marker symbol)
    panels = ((1, "x", "circle"), (2, "y", "square"))
    for i, identity_name in enumerate(df["identity_name"].unique()):
        data = df[df["identity_name"] == identity_name]
        # Identities absent from ``subject_colors_dict`` still get one
        # consistent color for both of their traces instead of a KeyError.
        color = subject_colors_dict.get(
            identity_name, fallback_palette[i % len(fallback_palette)]
        )
        for row, column, symbol in panels:
            fig.add_trace(
                go.Scatter(
                    x=data.index,
                    y=data[column],
                    mode="markers",
                    name=identity_name,  # Use the identity as the trace name
                    legendgroup=identity_name,
                    # Show each subject once in the legend, not once per subplot.
                    showlegend=row == 1,
                    marker=dict(color=color, symbol=symbol),
                ),
                row=row,
                col=1,
            )
    fig.update_yaxes(title_text="x position", row=1, col=1)
    fig.update_yaxes(title_text="y position", row=2, col=1)
    return fig


def resolve_duplicate_identities(df: pd.DataFrame) -> pd.DataFrame:
    """
    Reassign ID of the row with duplicated ID and lower likelihood.

    This function checks for duplicated ``identity_name`` for each
    unique DatetimeIndex value. Within each duplicated group the row with
    the highest ``likelihood`` keeps its identity; every other row is
    assigned an ``identity_name`` chosen at random (``np.random.choice``)
    from the other identity names present in the DataFrame. This is useful
    in a 2-subject case, but is not guaranteed to work in a >2-subject case.

    If the DataFrame contains fewer than two distinct identities there is
    no alternative identity to assign, so the data are returned unchanged.

    Args:
        df (pandas.DataFrame): DataFrame with columns
            ``identity_name`` and ``likelihood``, indexed by a
            DatetimeIndex named ``time``.

    Returns:
        pandas.DataFrame: New DataFrame (indexed by ``time``) without
            duplicate identities per unique DatetimeIndex value, in the
            2-subject case. The input DataFrame is not modified.
    """
    # reset_index returns a new frame, so ``df`` itself is never modified.
    df_cp = df.reset_index()
    names = df_cp["identity_name"].unique()
    if len(names) < 2:
        # No other identity to reassign to; np.random.choice([]) would raise.
        return df_cp.set_index("time")

    group_keys = ["time", "identity_name"]
    # Mask for rows with multiple assignments of the same ID at the same time
    many_to_one_mask = df_cp.groupby(group_keys)["likelihood"].transform("size") > 1
    duplicated_data = df_cp.loc[many_to_one_mask]
    # Within each duplicated group keep the most likely row; the rest are reassigned.
    best_idx = duplicated_data.groupby(group_keys)["likelihood"].idxmax()
    low_likelihood_idx = duplicated_data.index[~duplicated_data.index.isin(best_idx)]

    if len(low_likelihood_idx) > 0:
        # This assigns another identity randomly (in the 2-animal case it is the
        # other animal; with >2 animals it may create duplicates again).
        df_cp.loc[low_likelihood_idx, "identity_name"] = [
            np.random.choice(names[names != name])
            for name in df_cp.loc[low_likelihood_idx, "identity_name"]
        ]
    return df_cp.set_index("time")


def compute_class_speed(df: pd.DataFrame) -> pd.Series:
    """Compute the instantaneous speed of each class.

    For every row, speed is the Euclidean distance from the previous row of
    the same ``identity_name`` divided by the elapsed time in seconds between
    those two rows. The first row of each identity has no predecessor and
    gets NaN; a zero elapsed time yields inf (or NaN for zero distance).

    Args:
        df (pandas.DataFrame): DataFrame with columns ``x``,
            ``y``, ``identity_name``, and a DatetimeIndex (its name is not
            used, so it does not have to be ``time``). Rows of each identity
            are assumed to be in chronological order.

    Returns:
        pandas.Series: Series with the instantaneous speed of each class,
            aligned row-for-row with ``df`` (same index, same order).
    """
    identities = df["identity_name"].to_numpy()
    # Per-identity displacement between consecutive rows; keeps ``df`` row order.
    step = df.groupby("identity_name")[["x", "y"]].diff()
    # Vectorised Euclidean norm (a row-wise ``apply`` is very slow on large frames).
    distance = np.hypot(step["x"].to_numpy(), step["y"].to_numpy())
    # Per-identity elapsed seconds, grouped positionally so duplicate
    # timestamps and an unnamed index are both handled.
    elapsed = (
        pd.Series(df.index)
        .groupby(identities)
        .diff()
        .dt.total_seconds()
        .to_numpy()
    )
    # Match pandas' silent division semantics (x/0 -> inf, 0/0 -> NaN).
    with np.errstate(divide="ignore", invalid="ignore"):
        speed = distance / elapsed
    return pd.Series(speed, index=df.index)


def compute_speed_mask(df: pd.DataFrame, threshold: float) -> pd.Series:
    """Flag rows where a subject is too fast at the same time as another one.

    A row is flagged when its ``speed`` is finite and strictly greater than
    ``threshold`` AND at least one other row with the same index value
    (timestamp) passes that same speed test. A single fast subject at a
    timestamp is therefore not flagged.

    Args:
        df (pandas.DataFrame): DataFrame with a ``speed`` column.
        threshold (float): Speed threshold.

    Returns:
        pandas.Series: Boolean mask aligned with ``df``.
    """
    speed = df["speed"]
    # inf speeds (e.g. zero elapsed time) are excluded explicitly.
    is_fast = speed.gt(threshold) & np.isfinite(speed.to_numpy())
    # Number of fast rows sharing each timestamp, broadcast back to every row.
    fast_per_timestamp = is_fast.groupby(level=0).transform("sum")
    return is_fast & fast_per_timestamp.gt(1)


def resolve_swapped_identities(
    df: pd.DataFrame, threshold: float = 700.0, max_window_length: int = 6
) -> pd.DataFrame:
    """
    Reassign ID of the row with identity swaps.

    This function attempts to identify windows of identity swaps
    based on pairs of speed "violations": timestamps at which more than
    one subject moves faster than ``threshold``. Two consecutive violation
    timestamps no further apart than the current window length are taken
    as the start (inclusive) and end (exclusive) of a swap, and every row
    in between gets a randomly chosen *different* identity (in a 2-subject
    case, the other subject). The window length starts at 3 s and doubles
    on each pass (3 s, 6 s, 12 s, ...) as long as it does not exceed
    ``max_window_length`` seconds. Speeds are recomputed after every pass,
    and the number of violating rows is printed at the start of each pass.
    This method will not resolve all identity swaps, especially if the
    swaps occur for extended durations. It also does not account for more
    than 2 subjects.

    Args:
        df (pandas.DataFrame): DataFrame with columns ``x``,
            ``y``, ``identity_name``, and a sorted DatetimeIndex ``time``.
            It is not modified.
        threshold (float): Speed threshold. Default is 700.0.
        max_window_length (int): Maximum duration in seconds for swapping
            identities. Potential swaps spanning longer durations will be
            ignored. Default is 6 seconds.

    Returns:
        pandas.DataFrame: New DataFrame with resolved identity swaps within
            the specified ``max_window_length``.
    """
    # Work on a copy so the caller's DataFrame does not gain a ``speed``
    # column or have its identities rewritten behind its back.
    df = df.copy()
    names = df["identity_name"].unique()
    df["speed"] = compute_class_speed(df)
    speed_mask = compute_speed_mask(df, threshold=threshold)
    window_seconds = 3
    iteration = 0
    # Windows double each pass; stop once 3 * 2**iteration > max_window_length.
    while speed_mask.sum() > 2 and window_seconds <= max_window_length:
        print(
            f"Iteration {iteration}: {speed_mask.sum()} rows with speed > {threshold}"
        )
        violations = deque(df.index[speed_mask.to_numpy()].unique())
        max_gap = pd.Timedelta(window_seconds, unit="s")
        while len(violations) >= 2:
            start = violations.popleft()
            # Ignore ``start`` if the next violation is more than the window
            # away; that next violation becomes the candidate start instead.
            if violations[0] - start > max_gap:
                continue
            # ``end`` needs to be exclusive: use the last timestamp before it.
            end = df.index[df.index < violations.popleft()].max()
            current = df.loc[start:end, "identity_name"]
            df.loc[start:end, "identity_name"] = [
                np.random.choice(names[names != name]) for name in current
            ]
        # recompute speed and speed_mask with the updated identities
        df["speed"] = compute_class_speed(df)
        speed_mask = compute_speed_mask(df, threshold=threshold)
        window_seconds *= 2
        iteration += 1
    return df.drop(columns=["speed"])


In [74]:
# Step 1: resolve duplicate identities — at each timestamp, lower-likelihood
# rows sharing an identity are reassigned to another identity.
all_pos_df1 = all_pos_df.pipe(resolve_duplicate_identities)

In [75]:
# Step 2: resolve swapped identities using pairs of simultaneous speed
# violations (default threshold and window length).
all_pos_df2 = all_pos_df1.pipe(resolve_swapped_identities)

Iteration 0: 35282 rows with speed > 700.0
