# Track Identity Assignment Problem

Chang Huan

In [None]:
%load_ext autoreload
%autoreload 2
import numpy as np
import pandas as pd
import cv2
import plotly
import plotly.graph_objs as go
from plotly.subplots import make_subplots
from collections import deque
import datajoint as dj
import aeon
from aeon.io import video
from aeon.io import api
from aeon.schema.schemas import social02
from shapely.geometry import Point, Polygon
from aeon.dj_pipeline.analysis.block_analysis import * 

In [None]:
## Setup
"""Define functions"""


def plot_xy(df):
    """Plot the x and y positions of each subject over time.

    Two stacked subplots share the time axis: x position on top, y position
    below. Each class (subject) is coloured via ``subject_colors_dict`` and the
    hover text shows the row's ``speed``. The x and y traces of a class share a
    legend group, so the class appears once in the legend and clicking it
    toggles both subplots together.

    Args:
        df: DataFrame indexed by time with ``class``, ``x``, ``y`` and
            ``speed`` columns.

    Returns:
        A plotly Figure.
    """
    fig = make_subplots(rows=2, cols=1, shared_xaxes=True)
    for class_ in df["class"].unique():
        data = df[df["class"] == class_]
        color = subject_colors_dict[class_]
        hover_speed = data["speed"].tolist()
        # Same trace order as before: x trace (row 1) then y trace (row 2)
        for row, coord, symbol in ((1, "x", "circle"), (2, "y", "square")):
            fig.add_trace(
                go.Scatter(
                    x=data.index,
                    y=data[coord],
                    mode="markers",
                    name=class_,  # Use the class as the name of the trace
                    # One legend entry per class that controls both subplots
                    legendgroup=class_,
                    showlegend=(row == 1),
                    marker=dict(color=color, symbol=symbol),
                    hovertemplate="Speed: %{text}",
                    text=hover_speed,
                ),
                row=row,
                col=1,
            )
    fig.update_yaxes(title_text="x position", row=1, col=1)
    fig.update_yaxes(title_text="y position", row=2, col=1)
    return fig


def plot_speed(df):
    """Plot each subject's speed over time as a scatter of markers.

    Args:
        df: DataFrame indexed by time with ``class`` and ``speed`` columns.

    Returns:
        A plotly Figure with one trace per class, coloured via
        ``subject_colors_dict``.
    """
    fig = go.Figure()
    for subject in df["class"].unique():
        subset = df.loc[df["class"] == subject]
        marker_style = dict(color=subject_colors_dict[subject], symbol="circle")
        fig.add_trace(
            go.Scatter(
                x=subset.index,
                y=subset["speed"],
                mode="markers",
                name=subject,  # legend label is the class ID
                marker=marker_style,
            )
        )
    fig.update_yaxes(title_text="speed")
    return fig


def compute_class_speed(df):
    """Compute the instantaneous speed of each class (subject).

    Speed is the Euclidean distance between consecutive positions of the same
    class divided by the time elapsed between them (position units per
    second). The first row of each class has no predecessor and is NaN; a zero
    time step within a class yields inf/NaN.

    Args:
        df: DataFrame whose index is a datetime index named ``time`` and which
            has ``class``, ``x`` and ``y`` columns. Rows are expected to be in
            time order within each class.

    Returns:
        A Series of speeds aligned row-for-row with ``df``.
    """
    displacement = df.groupby("class")[["x", "y"]].diff()
    # Vectorized Euclidean norm (instead of a row-wise apply of np.linalg.norm)
    distance = np.hypot(displacement["x"], displacement["y"])
    # reset_index exposes the ``time`` index as a column so it can be diffed
    # per class; ``.values`` divides positionally, which is valid because
    # groupby().diff() preserves the original row order.
    elapsed_s = df.reset_index().groupby("class")["time"].diff().dt.total_seconds().values
    return distance / elapsed_s

def compute_speed_mask(df, threshold):
    """Return a boolean row mask of speed violations shared by several subjects.

    A row is flagged when its ``speed`` is finite and greater than
    ``threshold`` AND more than one row with the same index label (timestamp)
    is also flagged that way.

    Args:
        df: DataFrame indexed by timestamp with a ``speed`` column.
        threshold: speed above which a row counts as a violation.

    Returns:
        A boolean Series aligned with ``df``.
    """
    speed = df["speed"]
    is_fast = speed.gt(threshold) & np.isfinite(speed.values)
    # count flagged rows per timestamp; keep only timestamps with >1 of them
    fast_per_timestamp = is_fast.groupby(level=0).transform("sum")
    return is_fast & fast_per_timestamp.gt(1)

"""Standardize subject colors for plotting"""

subject_colors = plotly.colors.qualitative.Plotly
subject_colors_dict = {
    "BAA-1104045": subject_colors[0],
    "BAA-1104047": subject_colors[1],
    "BAA-1104048": subject_colors[2],
    "BAA-1104049": subject_colors[3],
}

Get metadata:

In [None]:
# Load experiment metadata (e.g. ``ActiveRegion``) from an arbitrary solo
# block of the raw data and keep only the first record's ``metadata`` field.
exp_start = pd.Timestamp("2024-01-31")
metadata = api.load(
    "/ceph/aeon/aeon/data/raw/AEON3/social0.2",
    social02.Metadata,
    exp_start,
    pd.Timestamp("2024-02-11 15:57:42"),
).iloc[0].metadata
metadata

Use the processed data instead, because it includes likelihood scores:

In [None]:
root = "/ceph/aeon/aeon/data/processed/AEON3/social0.2/"
# Pick start and end time as you like
vid_start = pd.Timestamp("2024-02-18 17:00:00")
vid_end = vid_start + pd.Timedelta("1h") 
# All files within the root directory with the following pattern will be loaded
print("Pattern:", aeon.io.reader.Pose(pattern="CameraTop_202*").pattern)
df = api.load(root, aeon.io.reader.Pose(pattern="CameraTop_202*"), start=vid_start, end=vid_end) 
df= df[df['part'] == 'spine2']
df = df.drop(columns=['part_likelihood', 'part'])
pose_df = df
pose_df

In [None]:
# Rename pose columns to the generic names used by the helper functions
# (``class``, ``time``, ``class_likelihood``).
pose_df.rename(
    columns={
        "identity": "class",
        "position_timestamps": "time",
        "identity_likelihood": "class_likelihood",
    },
    inplace=True,
)
pose_df


In [None]:
unique_classes = pose_df["class"].unique()  # subject IDs present in the pose data
fps = 50  # frame rate of the raw pose data (frames per second)
cm2px = 5.4  # pixels per centimetre: multiply a cm value to get pixels

In [None]:
# Downsample from 50 Hz to 10 Hz by keeping every 5th *timestamp*.
# Several rows (one per detected subject) can share a timestamp, so striding
# over rows (``iloc[::5]``) would keep an irregular mix of timestamps
# (~20 Hz for two subjects). Stride over the sorted unique timestamps instead.
downsample_factor = 5
pose_df = pose_df.reset_index()

kept_timestamps = pose_df["time"].drop_duplicates().sort_values().iloc[::downsample_factor]

# Keep every row (all subjects) at the selected timestamps
pose_df = pose_df[pose_df["time"].isin(kept_timestamps)]

# Restore ``time`` as the index; a stable sort keeps per-timestamp row order
pose_df = pose_df.set_index("time").sort_index(kind="stable")
pose_df

In [None]:
# Per-subject instantaneous speed (position units per second)
class_speed = compute_class_speed(pose_df)
pose_df["speed"] = class_speed
pose_df

In [None]:
## Visualise data: x/y tracks of the first 10,000 rows
fig = plot_xy(pose_df.iloc[:10000])
fig.show()

Issues
- Track IDs swap between frames (temporal discontinuities)
- Same Track ID is assigned to multiple animals in the same frame/timestamp

In [None]:

"""Assign the row with duplicated ID with lower likelihood to another ID"""

pose_df_cp = pose_df.reset_index().copy()
classes = np.array(pose_df_cp["class"].unique())
# Mask for rows with multiple assignments of the same ID at the same time
many_to_one_mask = pose_df_cp.groupby(["time", "class"]).transform("size") > 1
duplicated_data = pose_df_cp.loc[many_to_one_mask]
print(duplicated_data.shape)
# Indices for rows with lower likelihood
low_likelihood_idx = duplicated_data.loc[
    ~duplicated_data.index.isin(
        duplicated_data.groupby(["time", "class"])["class_likelihood"].idxmax()
    )
].index
# This assigns another class randomly (in 2-animal case, it's the other animal, but in >2-animal case, it may assign duplicate IDs again)
pose_df_cp.loc[low_likelihood_idx, "class"] = pose_df_cp.loc[low_likelihood_idx].apply(
    lambda x: np.random.choice(classes[classes != x["class"]]), axis=1
)


In [None]:
# Drop every timestamp at which only a single class was detected
n_rows_before = len(pose_df_cp)
n_classes_per_time = pose_df_cp.groupby("time")["class"].nunique()
multi_class_times = n_classes_per_time.index[n_classes_per_time > 1]
pose_df_cp = pose_df_cp[pose_df_cp["time"].isin(multi_class_times)]

# Report how many rows were discarded
rows_removed = n_rows_before - len(pose_df_cp)
print(f"Number of rows removed: {rows_removed}")

In [None]:
# (rows, columns) of the ID-corrected, filtered frame
pose_df_cp.shape

In [None]:
# Index by time again and recompute per-subject speed on the corrected IDs
pose_df_cp = pose_df_cp.set_index("time")
pose_df_cp["speed"] = compute_class_speed(pose_df_cp)


Temporal discontinuities
- Typically we use distance between consecutive frames to determine potential swaps
- However, this is not always reliable as there can be missing data (e.g. occlusions) &rarr; use _speed_ 

Pseudocode
```
Define a speed threshold for speed violation
Start with small time window t (e.g. 3s)
While speed violation exists
    Use consecutive pairs of violation timestamps as "start" and "end" of a potential swap duration
        If swap duration exceeds t
            Discard the current "start" and move on to the next iteration, using "end" as the next "start"
        Else 
            Flip IDs between "start" and "end"
    Recompute speed and speed violation mask
    Increase t   
```

In [None]:
# Box plot of the speed distribution for each class
fig = go.Figure()
for subject in pose_df_cp["class"].unique():
    subject_speed = pose_df_cp[pose_df_cp["class"] == subject]["speed"]
    box = go.Box(
        y=subject_speed,
        name=subject,
        marker=dict(color=subject_colors_dict[subject]),
    )
    fig.add_trace(box)
fig.show()

In [None]:
# Iteratively undo ID swaps. Consecutive speed-violation timestamps are treated
# as the start/end of a swap and the IDs between them are flipped. The
# allowed swap window starts at ``swap_window_s`` seconds and doubles after
# every pass.
speed_threshold = 700
speed_mask = compute_speed_mask(pose_df_cp, threshold=speed_threshold)
classes = pose_df_cp["class"].unique()
swap_window_s = 3
n_iter = 0  # renamed from ``iter`` so the builtin is not shadowed
# Passes run for n_iter = 0..max_iter inclusive, so the largest swap window
# tried is 3 * 2**max_iter = 48 seconds.
max_iter = 4
while speed_mask.sum() > 2 and n_iter <= max_iter:
    print(f"Iteration {n_iter}: {speed_mask.sum()} rows with speed > {speed_threshold}")
    violation_times = deque(pose_df_cp[speed_mask].index.unique())
    # Need a (start, next) pair; a lone trailing violation cannot be paired
    while len(violation_times) > 1:
        start = violation_times.popleft()
        next_violation = violation_times[0]
        # Too far apart to be one swap: discard ``start`` and let
        # ``next_violation`` become the next candidate start.
        if (next_violation - start) > pd.Timedelta(swap_window_s, unit="s"):
            continue
        violation_times.popleft()
        # ``end`` needs to be exclusive: flip up to the last timestamp before it
        end = pose_df_cp.index[pose_df_cp.index < next_violation].max()
        # With 2 classes this swaps to the other ID; with >2 it picks randomly
        pose_df_cp.loc[start:end, "class"] = pose_df_cp.loc[start:end].apply(
            lambda x: np.random.choice(classes[classes != x["class"]]), axis=1
        )
    # recompute speed and speed_mask on the flipped IDs
    pose_df_cp["speed"] = compute_class_speed(pose_df_cp)
    speed_mask = compute_speed_mask(pose_df_cp, threshold=speed_threshold)
    swap_window_s *= 2
    n_iter += 1


In [None]:
# high speed frames that remain:
# number of rows still flagged by the most recent compute_speed_mask call
speed_mask.sum()

In [None]:
# Tracking errors can still produce unrealistic speed spikes. Mask speeds above
# a physical maximum and refill them by linear interpolation within each class;
# leading/trailing gaps are back-/forward-filled. Only ``speed`` is repaired,
# x/y are left unchanged.
max_speed_threshold = 400 * cm2px  # 400 cm/s expressed in pixels/s
pose_df_cp = pose_df_cp.reset_index()
# groupby().transform keeps the original rows and the ``class`` column, and
# bfill()/ffill() replace the deprecated fillna(method=...).
pose_df_cp["speed"] = pose_df_cp.groupby("class")["speed"].transform(
    lambda s: s.mask(s > max_speed_threshold).interpolate().bfill().ffill()
)
pose_df_cp = pose_df_cp.set_index("time").sort_index(kind="stable")

pose_df_cp

In [None]:
# Smooth x/y with a centred rolling mean, computed separately for each subject.
window_size = 5  # rows; 0.5 s at the intended 10 Hz sampling rate
smoothed_pose_df_cp = pose_df_cp.copy()
# transform() returns one value per input row in the original order, so the
# result can be written back positionally without mutating group slices.
smoothed_xy = pose_df_cp.groupby("class")[["x", "y"]].transform(
    lambda coords: coords.rolling(window=window_size, min_periods=1, center=True).mean()
)
smoothed_pose_df_cp[["x", "y"]] = smoothed_xy.to_numpy()
smoothed_pose_df_cp = smoothed_pose_df_cp.sort_index(kind="stable")
smoothed_pose_df_cp

In [None]:
# x/y tracks of the first 10,000 smoothed rows
fig = plot_xy(smoothed_pose_df_cp.iloc[:10000])
fig.show()

## Detect sleeping

In [None]:
# Condition 1: the subject is inside the nest.
# Build the nest polygon from the first four corner points of the nest region.
nest_corners = metadata.ActiveRegion.NestRegion.ArrayOfPoint
nest_polygon = Polygon(
    [(int(nest_corners[i]["X"]), int(nest_corners[i]["Y"])) for i in range(4)]
)


def is_in_nest(x, y):
    """Return True if the point (x, y) lies inside ``nest_polygon``.

    Uses shapely ``contains``, so points exactly on the boundary are excluded.
    """
    return nest_polygon.contains(Point(x, y))


# Boolean mask with one entry per row of smoothed_pose_df_cp
in_nest_mask = smoothed_pose_df_cp.apply(lambda row: is_in_nest(row["x"], row["y"]), axis=1)

# Fraction of all rows (all subjects pooled) that are in the nest
in_nest_time_ratio = in_nest_mask.sum() / len(smoothed_pose_df_cp)
print(f'In nest time ratio is {in_nest_time_ratio}')

In [None]:
# NOTE: disabled exploratory analysis (inter-subject distance / "swapping
# speed"), kept as a bare triple-quoted string so none of it executes.
# Running this cell only displays the string.
'''# Sidenote: it can be that the ids aer swappin over two mice bein very close to eac othr while sleeping?? - NO, doesn't help!
# Ensure there are exactly two unique classes
unique_classes = smoothed_pose_df_cp["class"].unique()
if len(unique_classes) != 2:
    raise ValueError("There should be exactly two unique classes to compute the distance.")

class1, class2 = unique_classes

# Pivot the DataFrame to separate x and y coordinates for each class
merged_df = smoothed_pose_df_cp.pivot(columns='class', values=['x', 'y'])

# Compute the Euclidean distance for each time point
merged_df['distance'] = np.sqrt(
    (merged_df['x'][class1] - merged_df['x'][class2])**2 +
    (merged_df['y'][class1] - merged_df['y'][class2])**2
)
# Set distance to NaN if any x or y values are None
merged_df['distance'] = merged_df['distance'].where(
    merged_df['x'][class1].notna() & merged_df['x'][class2].notna() &
    merged_df['y'][class1].notna() & merged_df['y'][class2].notna(), np.nan
)
# Compute the time difference between consecutive frames
merged_df['time_diff'] = merged_df.index.to_series().diff().dt.total_seconds()

# Compute the speed (distance/time)
merged_df['swapping_speed'] = merged_df['distance'] / merged_df['time_diff']
# Set swapping_speed to NaN if time_diff is NaN
merged_df['swapping_speed'] = merged_df['swapping_speed'].where(merged_df['time_diff'].notna(), np.nan)

# Add the computed columns to the original DataFrame
smoothed_pose_df_cp['distance'] = merged_df['distance']
smoothed_pose_df_cp['swapping_speed'] = merged_df['swapping_speed']
# Set swapping_speed to NaN where speed is NaN
smoothed_pose_df_cp['swapping_speed'] = smoothed_pose_df_cp['swapping_speed'].where(smoothed_pose_df_cp['speed'].notna(), np.nan)

# Apply smoothing to swapping_speed
window_size = 100  # You can adjust the window size as needed
smoothed_pose_df_cp['swapping_speed'] = smoothed_pose_df_cp['swapping_speed'].rolling(window=window_size, min_periods=1, center = True).mean()

fig = plot_speed(smoothed_pose_df_cp)


# Add the smoothed swapping speed as a new trace
fig.add_trace(
    go.Scatter(
        x=smoothed_pose_df_cp.index,
        y=smoothed_pose_df_cp["swapping_speed"],
        mode="markers",
        name="Smoothed Swapping Speed"
    )
)

fig.update_layout(
    title="Speed and Smoothed Swapping Speed over time",
    xaxis_title="Time",
    yaxis_title="Speed"
)

fig.show()'''

In [None]:
# Condition 2: speed stays below a threshold for a sustained period.
# Open question: should in-nest and out-of-nest use different thresholds,
# since an awake animal may move less inside the nest?
speed_threshold = 20 * cm2px  # 20 cm/s expressed in pixels/s
speed_crossing_frame_ratio = 0.05  # at most 5% of rows may exceed the threshold
# The tracks were downsampled from ``fps`` by keeping every 5th timestamp, so
# durations must be counted in rows at the downsampled rate. Using the raw
# ``fps`` (4 * 60 * 50 = 12000 rows) would require 20 min at 10 Hz, not 4 min.
downsampled_fps = fps / 5
min_sleeping_time = int(4 * 60 * downsampled_fps)  # 4 minutes, in rows

In [None]:
# Speed over time with the inactivity threshold drawn as a dashed line
fig = plot_speed(smoothed_pose_df_cp)
first_time = smoothed_pose_df_cp.index.min()
last_time = smoothed_pose_df_cp.index.max()
threshold_line_style = dict(color="black", width=2, dash="dash")
fig.add_shape(
    type="line",
    x0=first_time,
    x1=last_time,
    y0=speed_threshold,
    y1=speed_threshold,
    line=threshold_line_style,
)

fig.show()

In [None]:
def detect_sleeping_periods(group, threshold=None, min_frames=None, max_above_ratio=None):
    """Flag sustained low-speed (inactive) periods for one subject's rows.

    A row is marked ``inactive`` when it lies inside at least one window of
    ``min_frames`` consecutive rows in which at most
    ``int(max_above_ratio * min_frames)`` rows have speed >= ``threshold``.
    This tolerates brief threshold crossings inside a long still period.
    (Earlier segment-based code could never apply that tolerance, because each
    segment is entirely below or entirely above the threshold. It also marked
    very long above-threshold runs as inactive.)

    Args:
        group: DataFrame for one subject, indexed by ``time``, with a ``speed``
            column. Rows are expected to be in time order.
        threshold: speed threshold; defaults to the global ``speed_threshold``.
        min_frames: minimum window length in rows; defaults to the global
            ``min_sleeping_time``.
        max_above_ratio: maximum fraction of above-threshold rows per window;
            defaults to the global ``speed_crossing_frame_ratio``.

    Returns:
        ``group`` with its index reset (``time`` becomes a column) plus the
        columns ``below_threshold``, ``segment`` (run id of consecutive equal
        ``below_threshold`` values) and ``inactive``.
    """
    if threshold is None:
        threshold = speed_threshold
    if min_frames is None:
        min_frames = min_sleeping_time
    if max_above_ratio is None:
        max_above_ratio = speed_crossing_frame_ratio

    group = group.reset_index()

    # Identify low-speed rows (NaN speed compares False, i.e. counts as above)
    group["below_threshold"] = group["speed"] < threshold
    # Run id: a new segment starts whenever below_threshold changes
    group["segment"] = (group["below_threshold"] != group["below_threshold"].shift(1)).cumsum()

    window = max(int(min_frames), 1)
    allowed_above = int(max_above_ratio * window)
    above = (~group["below_threshold"]).astype(float)
    # True at row i when the window ending at i has few enough crossings;
    # incomplete leading windows give NaN, which compares False.
    window_ok = above.rolling(window, min_periods=window).sum() <= allowed_above
    # Row j is covered if any qualifying window ends within rows j..j+window-1:
    # a backward-looking rolling max on the reversed series.
    covered = (
        window_ok.astype(float)
        .iloc[::-1]
        .rolling(window, min_periods=1)
        .max()
        .iloc[::-1]
    )
    group["inactive"] = covered.to_numpy() > 0

    return group

# Apply to each class (mouse). Iterating the groups and concatenating, rather
# than groupby().apply, keeps the ``class`` column in every group on all
# pandas versions (apply on grouping columns is deprecated).
smoothed_pose_df_cp = pd.concat(
    [
        detect_sleeping_periods(subject_df)
        for _, subject_df in smoothed_pose_df_cp.groupby("class")
    ],
    ignore_index=True,
)

# Set the time column as the index; a stable sort keeps per-timestamp order
smoothed_pose_df_cp = smoothed_pose_df_cp.set_index("time").sort_index(kind="stable")

# Calculate the ratio of inactive rows (all subjects pooled)
sleeping_ratio = smoothed_pose_df_cp['inactive'].sum() / len(smoothed_pose_df_cp)
print(f'Ratio of inactive time is {sleeping_ratio}')


In [None]:
# A subject is "sleeping" when it is both inactive and inside the nest.
# Recompute the in-nest mask on the current rows. The inactivity step rebuilt
# and re-sorted this frame, and because timestamps repeat across subjects, a
# mask built on an earlier frame is not guaranteed to line up row-for-row.
in_nest_mask = smoothed_pose_df_cp.apply(lambda row: is_in_nest(row["x"], row["y"]), axis=1)
smoothed_pose_df_cp['sleeping'] = in_nest_mask & smoothed_pose_df_cp['inactive']
print(f'Ratio of sleeping time in nest is {smoothed_pose_df_cp["sleeping"].sum() / len(smoothed_pose_df_cp)}')

fig = plot_speed(smoothed_pose_df_cp)

# Draw each subject's sleeping rows as a horizontal band of markers above the
# speed traces: y=3200 for the first subject, 200 lower for each next one.
for i, subject in enumerate(unique_classes):
    subject_df = smoothed_pose_df_cp[smoothed_pose_df_cp['class'] == subject]
    band_y = 3200 - 200 * i
    fig.add_trace(
        go.Scatter(
            x=subject_df.index,
            # None leaves a gap, so only sleeping rows are drawn
            y=subject_df['sleeping'].apply(lambda is_sleeping: band_y if is_sleeping else None),
            mode="markers",
            name=f'{subject}_sleeping',
            marker=dict(color=subject_colors_dict[subject], symbol="circle"),
        )
    )
fig.show()



## Questions



- Sleeping detection: better thresholds?
- A lot of data is removed, e.g. timestamps where only one subject is detected
- The ID of a sleeping animal is not always correct:
   - the swap occurs where SLEAP detects only a single instance of the animal
   - the swap never ends (it persists until the end of the dataframe)
   - the animals are close together (do we care in this case?)
- Grooming is detected as sleeping!
- Does the method assume that sleeping only happens in the nest?