# Behavior Classification using Explainable Active Learning Model

## 1. Load Data

In [15]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from itertools import combinations
import seaborn as sns
sns.set_theme(style="whitegrid")

import os
from dotenv import load_dotenv
load_dotenv()
DLC_DATASET_PATH = os.getenv('DLC_DATASET_PATH')

In [25]:
FRAME_RATE = 30
SAMPLE_RATE = 10
MULTI_ANIMAL = False
LIKELIHOOD_THRESHOLD = 0.6
BODY_PARTS = ["tailbase", "earR", "earL", "msBase", "msTop", "centroid", "cleft", "cright"]
RULE_BASED_LABELS = ["shuttles_label_naive1", "shuttles_label_naive2", "shuttles_label_naive3", "shuttles_label_hardcoded", "freezing_label"]
labels = RULE_BASED_LABELS+["other"]
TIME_WINDOW = 10 # frames

In [3]:
# Multi-level header
df_pose_labels = pd.read_csv(DLC_DATASET_PATH + "/labelled_DLC.csv", header=[0, 1], index_col=0)

# Flatten the multi-level header into single strings
df_pose_labels.columns = ['_'.join(col).strip() for col in df_pose_labels.columns.values]

# Remove second level header if it is empty
df_pose_labels.columns = [col.split("_Unnamed")[0] if "_Unnamed" in col else col for col in df_pose_labels.columns.values]

# Remove NaN columns and rows
df_pose_labels = df_pose_labels.dropna(axis=1, how='all')
df_pose_labels = df_pose_labels.dropna(axis=0, how='any')

print("Shape of df_pose_labels:", df_pose_labels.shape)
df_pose_labels.head()

Shape of df_pose_labels: (20000, 32)


Unnamed: 0,tailbase_x,tailbase_y,tailbase_likelihood,earR_x,earR_y,earR_likelihood,earL_x,earL_y,earL_likelihood,msBase_x,...,cright_y,cright_likelihood,calc_centroid_x,calc_centroid_y,speed,shuttles_label_naive1,shuttles_label_naive2,shuttles_label_naive3,shuttles_label_hardcoded,freezing_label
0,215.430298,318.572968,0.999996,335.296844,346.048737,0.999918,351.505981,309.215546,0.999891,351.370636,...,345.997833,0.999949,313.40094,326.633553,0.0,0,0,0,0,0
1,216.124664,319.105774,0.999997,334.252045,349.912659,0.99995,351.527374,307.669678,0.99998,351.550171,...,347.752991,0.99995,313.363564,327.412323,0.779667,0,0,0,0,0
2,215.643936,319.362427,0.999998,333.88031,349.696228,0.999942,351.895233,308.665619,0.999976,352.242249,...,348.114502,0.999924,313.415432,327.838448,0.42927,0,0,0,0,0
3,216.000183,319.383545,0.999997,334.263062,349.323456,0.99992,351.224792,307.465485,0.999981,351.896088,...,347.783112,0.999937,313.346031,327.274757,0.567946,0,0,0,0,0
4,215.784775,319.112488,0.999998,334.187714,349.681763,0.999959,351.621399,307.867065,0.999985,351.576569,...,347.832947,0.99997,313.292614,327.455528,0.188498,0,0,0,0,0


In [4]:
df_labels = df_pose_labels[RULE_BASED_LABELS]

print("Shape of df_labels:", df_labels.shape)
df_labels.head()

Shape of df_labels: (20000, 5)


Unnamed: 0,shuttles_label_naive1,shuttles_label_naive2,shuttles_label_naive3,shuttles_label_hardcoded,freezing_label
0,0,0,0,0,0
1,0,0,0,0,0
2,0,0,0,0,0
3,0,0,0,0,0
4,0,0,0,0,0


In [5]:
# all unlabeled are placed into "other"
df_labels_other = df_labels.copy()
df_labels_other["other"] = 0

# find all unlabeled
unlabeled_data = df_labels_other.sum(axis=1) == 0

# change all unlabeled to 1 in other
df_labels_other.loc[unlabeled_data, "other"] = 1

print("Shape of df_labels_other:", df_labels_other.shape)
df_labels_other.head()

Shape of df_labels_other: (20000, 6)


Unnamed: 0,shuttles_label_naive1,shuttles_label_naive2,shuttles_label_naive3,shuttles_label_hardcoded,freezing_label,other
0,0,0,0,0,0,1
1,0,0,0,0,0,1
2,0,0,0,0,0,1
3,0,0,0,0,0,1
4,0,0,0,0,0,1


In [6]:
df_pose = df_pose_labels.drop(columns=RULE_BASED_LABELS)

print("Shape of df_pose:", df_pose.shape)
df_pose.head()

Shape of df_pose: (20000, 27)


Unnamed: 0,tailbase_x,tailbase_y,tailbase_likelihood,earR_x,earR_y,earR_likelihood,earL_x,earL_y,earL_likelihood,msBase_x,...,centroid_likelihood,cleft_x,cleft_y,cleft_likelihood,cright_x,cright_y,cright_likelihood,calc_centroid_x,calc_centroid_y,speed
0,215.430298,318.572968,0.999996,335.296844,346.048737,0.999918,351.505981,309.215546,0.999891,351.370636,...,0.999951,281.177429,278.16452,0.999803,275.792725,345.997833,0.999949,313.40094,326.633553,0.0
1,216.124664,319.105774,0.999997,334.252045,349.912659,0.99995,351.527374,307.669678,0.99998,351.550171,...,0.999983,281.244598,276.559937,0.999884,277.471069,347.752991,0.99995,313.363564,327.412323,0.779667
2,215.643936,319.362427,0.999998,333.88031,349.696228,0.999942,351.895233,308.665619,0.999976,352.242249,...,0.999982,282.179657,276.486542,0.999849,276.880859,348.114502,0.999924,313.415432,327.838448,0.42927
3,216.000183,319.383545,0.999997,334.263062,349.323456,0.99992,351.224792,307.465485,0.999981,351.896088,...,0.999978,283.015259,276.601379,0.999904,277.634735,347.783112,0.999937,313.346031,327.274757,0.567946
4,215.784775,319.112488,0.999998,334.187714,349.681763,0.999959,351.621399,307.867065,0.999985,351.576569,...,0.999984,282.534119,277.098724,0.999917,276.700134,347.832947,0.99997,313.292614,327.455528,0.188498


In [7]:
df_pose_labels_other = pd.concat([df_pose, df_labels_other], axis=1)

print("Shape of df_pose_labels_other:", df_pose_labels_other.shape)
df_pose_labels_other.head()

Shape of df_pose_labels_other: (20000, 33)


Unnamed: 0,tailbase_x,tailbase_y,tailbase_likelihood,earR_x,earR_y,earR_likelihood,earL_x,earL_y,earL_likelihood,msBase_x,...,cright_likelihood,calc_centroid_x,calc_centroid_y,speed,shuttles_label_naive1,shuttles_label_naive2,shuttles_label_naive3,shuttles_label_hardcoded,freezing_label,other
0,215.430298,318.572968,0.999996,335.296844,346.048737,0.999918,351.505981,309.215546,0.999891,351.370636,...,0.999949,313.40094,326.633553,0.0,0,0,0,0,0,1
1,216.124664,319.105774,0.999997,334.252045,349.912659,0.99995,351.527374,307.669678,0.99998,351.550171,...,0.99995,313.363564,327.412323,0.779667,0,0,0,0,0,1
2,215.643936,319.362427,0.999998,333.88031,349.696228,0.999942,351.895233,308.665619,0.999976,352.242249,...,0.999924,313.415432,327.838448,0.42927,0,0,0,0,0,1
3,216.000183,319.383545,0.999997,334.263062,349.323456,0.99992,351.224792,307.465485,0.999981,351.896088,...,0.999937,313.346031,327.274757,0.567946,0,0,0,0,0,1
4,215.784775,319.112488,0.999998,334.187714,349.681763,0.999959,351.621399,307.867065,0.999985,351.576569,...,0.99997,313.292614,327.455528,0.188498,0,0,0,0,0,1


## 2. Preprocessing using A-SOiD Pipeline

### 2.1 Filter by Likelihood

Smooths out unreliable keypoint coordinates based on confidence values (likelihoods). This is a "hold last good value" filter to avoid jittery or missing points.

The `adp_filt` function is rewrite from the `adp_filt` function in A-SOiD (likelihood adaptive filtering).

In [8]:
# likelihood adaptive filtering
def adp_filt(df_pose_labels_other, idx_coord, idx_llh, llh_value, labels):
    """
    For body parts with likelihood values below the threshold, copy the previous valid row's x, y coordinates.
    Labels are not modified during filtering and are set back after filtering.

    Parameters:
        df_pose_labels_other (pd.DataFrame): The input DataFrame containing x, y, likelihood, and pose labels.
        idx_coord (list): Indices of selected body parts (x and y columns).
        idx_llh (list): Indices of likelihood columns.
        llh_value (float): Threshold for likelihood filtering.
        labels (list): List of pose labels to retain in the output.

    Returns:
        pd.DataFrame: DataFrame with invalid x, y coordinates corrected.
        dict: Statistics on filtered body parts and likelihood values below the threshold.
    """
    # Convert DataFrame to numpy arrays for x, y, likelihood, and pose labels
    data_x_coord = np.array(df_pose_labels_other.iloc[:, idx_coord[::2]])
    data_y_coord = np.array(df_pose_labels_other.iloc[:, idx_coord[1::2]])
    data_llh = np.array(df_pose_labels_other.iloc[:, idx_llh])
    original_labels = df_pose_labels_other[labels].copy()  # Preserve original labels

    # Initialize statistics
    llh_below_threshold = 0
    total_llh = data_llh.size  # Total number of likelihood values
    body_part_stats = {df_pose_labels_other.columns[idx_coord[::2][i]]: 0 for i in range(len(idx_coord[::2]))}

    # Iterate through body parts and correct invalid x, y coordinates
    for x in range(data_llh.shape[1]):  # Iterate over each body part
        for i in range(1, data_llh.shape[0]):  # Start from the second row
            if data_llh[i, x] < llh_value:  # If likelihood is below the threshold
                llh_below_threshold += 1
                body_part_stats[df_pose_labels_other.columns[idx_coord[::2][x]]] += 1
                # Copy the previous row's x, y coordinates for this body part
                data_x_coord[i, x] = data_x_coord[i - 1, x]
                data_y_coord[i, x] = data_y_coord[i - 1, x]

    # Replace the x, y columns in the DataFrame with corrected values
    for idx, col_idx in enumerate(idx_coord[::2]):  # Update x columns
        df_pose_labels_other.iloc[:, col_idx] = data_x_coord[:, idx]
    for idx, col_idx in enumerate(idx_coord[1::2]):  # Update y columns
        df_pose_labels_other.iloc[:, col_idx] = data_y_coord[:, idx]

    # Restore the original labels
    df_pose_labels_other[labels] = original_labels

    # Calculate likelihood ratio
    llh_ratio = llh_below_threshold / total_llh

    # Prepare statistics
    stats = {
        "llh_below_threshold": llh_below_threshold,
        "total_llh": total_llh,
        "llh_below_threshold_ratio": llh_ratio,
        "body_part_stats": body_part_stats
    }

    return df_pose_labels_other, stats

In [9]:
# Get numeric indices of columns ending with "_likelihood"
idx_llh = [i for i, col in enumerate(df_pose_labels_other.columns) if col.endswith("_likelihood")]
labels = RULE_BASED_LABELS+["other"]
# Get numeric indices of columns of x and y coordinates
idx_coord = [i for i, col in enumerate(df_pose_labels_other.columns) if not col.endswith("_likelihood") and col not in labels]

filt_df_pose_labels_other, stats = adp_filt(
    df_pose_labels_other=df_pose_labels_other,
    idx_coord=idx_coord,
    idx_llh=idx_llh,
    llh_value=LIKELIHOOD_THRESHOLD,
    labels=labels
)

print("Filtered DataFrame's shape:")
print(filt_df_pose_labels_other.shape)
print("\nStatistics:")
print(stats)

Filtered DataFrame's shape:
(20000, 33)

Statistics:
{'llh_below_threshold': 1658, 'total_llh': 160000, 'llh_below_threshold_ratio': 0.0103625, 'body_part_stats': {'tailbase_x': 287, 'earR_x': 261, 'earL_x': 258, 'msBase_x': 279, 'msTop_x': 284, 'centroid_x': 14, 'cleft_x': 158, 'cright_x': 117, 'calc_centroid_x': 0, 'speed': 0}}


In [10]:
filt_df_pose = filt_df_pose_labels_other.drop(columns=RULE_BASED_LABELS + ["other"])

print("Filtered DataFrame's shape without labels:")
print(filt_df_pose.shape)
filt_df_pose.head()

Filtered DataFrame's shape without labels:
(20000, 27)


Unnamed: 0,tailbase_x,tailbase_y,tailbase_likelihood,earR_x,earR_y,earR_likelihood,earL_x,earL_y,earL_likelihood,msBase_x,...,centroid_likelihood,cleft_x,cleft_y,cleft_likelihood,cright_x,cright_y,cright_likelihood,calc_centroid_x,calc_centroid_y,speed
0,215.430298,318.572968,0.999996,335.296844,346.048737,0.999918,351.505981,309.215546,0.999891,351.370636,...,0.999951,281.177429,278.16452,0.999803,275.792725,345.997833,0.999949,313.40094,326.633553,0.0
1,216.124664,319.105774,0.999997,334.252045,349.912659,0.99995,351.527374,307.669678,0.99998,351.550171,...,0.999983,281.244598,276.559937,0.999884,277.471069,347.752991,0.99995,313.363564,327.412323,0.779667
2,215.643936,319.362427,0.999998,333.88031,349.696228,0.999942,351.895233,308.665619,0.999976,352.242249,...,0.999982,282.179657,276.486542,0.999849,276.880859,348.114502,0.999924,313.415432,327.838448,0.42927
3,216.000183,319.383545,0.999997,334.263062,349.323456,0.99992,351.224792,307.465485,0.999981,351.896088,...,0.999978,283.015259,276.601379,0.999904,277.634735,347.783112,0.999937,313.346031,327.274757,0.567946
4,215.784775,319.112488,0.999998,334.187714,349.681763,0.999959,351.621399,307.867065,0.999985,351.576569,...,0.999984,282.534119,277.098724,0.999917,276.700134,347.832947,0.99997,313.292614,327.455528,0.188498


In [11]:
filt_df_labels_other = filt_df_pose_labels_other[RULE_BASED_LABELS + ["other"]]

print("Filtered DataFrame's shape with labels:")
print(filt_df_labels_other.shape)
filt_df_labels_other.head()

Filtered DataFrame's shape with labels:
(20000, 6)


Unnamed: 0,shuttles_label_naive1,shuttles_label_naive2,shuttles_label_naive3,shuttles_label_hardcoded,freezing_label,other
0,0,0,0,0,0,1
1,0,0,0,0,0,1
2,0,0,0,0,0,1
3,0,0,0,0,0,1
4,0,0,0,0,0,1


In [12]:
filt_df_pose_labels_other.info()

<class 'pandas.core.frame.DataFrame'>
Index: 20000 entries, 0 to 19999
Data columns (total 33 columns):
 #   Column                    Non-Null Count  Dtype  
---  ------                    --------------  -----  
 0   tailbase_x                20000 non-null  float64
 1   tailbase_y                20000 non-null  float64
 2   tailbase_likelihood       20000 non-null  float64
 3   earR_x                    20000 non-null  float64
 4   earR_y                    20000 non-null  float64
 5   earR_likelihood           20000 non-null  float64
 6   earL_x                    20000 non-null  float64
 7   earL_y                    20000 non-null  float64
 8   earL_likelihood           20000 non-null  float64
 9   msBase_x                  20000 non-null  float64
 10  msBase_y                  20000 non-null  float64
 11  msBase_likelihood         20000 non-null  float64
 12  msTop_x                   20000 non-null  float64
 13  msTop_y                   20000 non-null  float64
 14  msTop_likel

## 3. Feature Engineering

In [19]:
# Extract x and y coordinates
pose_xy = {}
for part in BODY_PARTS:
    x_col = f"{part}_x"
    y_col = f"{part}_y"
    pose_xy[part] = filt_df_pose[[x_col, y_col]].values  # shape: (n_frames, 2)

### 3.1 Euclidean Distance Between Body Parts

In [20]:
def compute_pairwise_distances(pose_xy, pairs):
    dist_feats = {}
    for (a, b) in pairs:
        pa = pose_xy[a]
        pb = pose_xy[b]
        dist = np.linalg.norm(pa - pb, axis=1)
        dist_feats[f"dist_{a}_{b}"] = dist
    return pd.DataFrame(dist_feats)

In [None]:
# Example useful anatomical pairs
pairs = [("tailbase", "centroid"), ("earL", "earR"), ("msBase", "msTop")]

In [21]:
# All unique pairs
pairs = list(combinations(BODY_PARTS, 2))

print("Number of pairs:", len(pairs))
print(pairs)

Number of pairs: 28
[('tailbase', 'earR'), ('tailbase', 'earL'), ('tailbase', 'msBase'), ('tailbase', 'msTop'), ('tailbase', 'centroid'), ('tailbase', 'cleft'), ('tailbase', 'cright'), ('earR', 'earL'), ('earR', 'msBase'), ('earR', 'msTop'), ('earR', 'centroid'), ('earR', 'cleft'), ('earR', 'cright'), ('earL', 'msBase'), ('earL', 'msTop'), ('earL', 'centroid'), ('earL', 'cleft'), ('earL', 'cright'), ('msBase', 'msTop'), ('msBase', 'centroid'), ('msBase', 'cleft'), ('msBase', 'cright'), ('msTop', 'centroid'), ('msTop', 'cleft'), ('msTop', 'cright'), ('centroid', 'cleft'), ('centroid', 'cright'), ('cleft', 'cright')]


In [24]:
df_distances = compute_pairwise_distances(pose_xy, pairs)
df_distances

Unnamed: 0,dist_tailbase_earR,dist_tailbase_earL,dist_tailbase_msBase,dist_tailbase_msTop,dist_tailbase_centroid,dist_tailbase_cleft,dist_tailbase_cright,dist_earR_earL,dist_earR_msBase,dist_earR_msTop,...,dist_msBase_msTop,dist_msBase_centroid,dist_msBase_cleft,dist_msBase_cright,dist_msTop_centroid,dist_msTop_cleft,dist_msTop_cright,dist_centroid_cleft,dist_centroid_cright,dist_cleft_cright
0,122.975229,136.397042,136.672099,166.476268,63.320757,77.172067,66.300421,40.242019,20.895854,43.994094,...,31.926015,75.445604,88.886857,76.739390,106.687303,120.547023,103.466712,35.038341,33.008714,68.046700
1,122.078426,135.884797,136.132436,169.519095,62.794644,77.786593,67.705572,45.638870,24.219367,48.560720,...,35.583866,75.734094,90.133090,75.541452,110.697430,125.312568,105.364732,36.198076,35.110380,71.292990
2,122.065473,136.670544,137.341359,170.888230,63.621163,79.153924,67.650888,44.811252,24.398770,49.836089,...,35.601540,76.226383,90.410654,76.740822,111.296469,125.513767,106.854020,36.436717,35.396018,71.823686
3,121.993880,135.748794,136.569066,169.831756,63.867636,79.506818,67.862754,45.164035,24.078669,48.821714,...,35.468056,75.132623,88.978047,75.732860,109.993800,123.993730,105.467191,36.319353,35.081536,71.384796
4,122.285471,136.301313,136.516542,170.248654,63.037435,78.870978,67.346461,45.303447,23.985716,48.965926,...,35.932480,75.916639,88.937146,76.300409,111.238702,124.435479,106.478351,35.762921,35.246282,70.974401
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
19995,114.211038,116.513559,124.011370,151.460259,53.391755,57.646419,72.330235,46.852324,26.019420,40.756516,...,29.408292,70.880384,85.787525,72.952174,98.068735,115.193691,93.127008,39.794623,34.884857,74.653139
19996,111.687333,115.461766,121.515750,149.879128,52.597994,58.936623,70.859876,45.830029,26.966468,41.709922,...,30.659844,69.088407,83.256589,71.249996,97.307651,113.914055,91.958028,39.816424,34.594374,74.346044
19997,111.311136,111.425706,120.730248,150.484771,53.882642,57.895413,71.337304,41.908644,26.339697,42.030294,...,32.382504,67.153725,82.191668,70.627390,96.612096,114.570203,91.972156,39.777909,33.814905,73.588233
19998,109.734464,106.992409,119.855866,148.677283,52.354342,55.990388,69.843360,38.449068,26.371002,41.847776,...,31.365284,68.442968,81.178398,73.915220,96.424730,112.524650,94.183856,40.559783,32.703445,73.260289


### 3.2 Angular Change (turning angle)

Use 3 points to define angle at B (A-B-C triangle).

In [26]:
def compute_angle(a, b, c):
    v1 = a - b
    v2 = c - b
    cosine = np.einsum('ij,ij->i', v1, v2) / (
        np.linalg.norm(v1, axis=1) * np.linalg.norm(v2, axis=1) + 1e-6
    )
    return np.arccos(np.clip(cosine, -1, 1))  # In radians

In [27]:
# Turning angle at centroid (tailbase - centroid - msTop)
angle_vals = compute_angle(pose_xy["tailbase"], pose_xy["centroid"], pose_xy["msTop"])
df_angles = pd.DataFrame({
    "angle_tailbase_centroid_msTop": np.degrees(angle_vals)  # Convert to degrees
})
df_angles

Unnamed: 0,angle_tailbase_centroid_msTop
0,155.789200
1,154.418971
2,154.370237
3,154.346141
4,154.297815
...,...
19995,179.790712
19996,177.741804
19997,178.624531
19998,175.561642


### 3.3 Total Displacement over Time Window

In [28]:
def compute_displacement(xy_array, window):
    # Euclidean distance between current and past frame
    disp = np.linalg.norm(xy_array - np.roll(xy_array, window, axis=0), axis=1)
    disp[:window] = 0  # zero displacement for initial frames
    return disp

In [29]:
disp_centroid = compute_displacement(pose_xy["centroid"], window=TIME_WINDOW)
df_displacement = pd.DataFrame({"displacement_centroid_w5": disp_centroid})
df_displacement

Unnamed: 0,displacement_centroid_w5
0,0.000000
1,0.000000
2,0.000000
3,0.000000
4,0.000000
...,...
19995,15.239956
19996,12.837550
19997,8.456646
19998,6.843706


### 3.4 Velocity Vector (Δx, Δy per frame)

### 3.5 Combine All Features

In [None]:
df_features = pd.concat([df_distances, df_angles, df_displacement, df_speed], axis=1)

df_pose_features_labels_other = pd.concat([filt_df_pose, df_features, filt_df_labels_other], axis=1)

## 4. Rule-based Labeling (skipped for now)

## 5. Semi-supervised metric learning → low-dimensional vector embedding space

## 6. Clustering (active learning loop skipped for now)

## 7. Train a decision tree classification model on engineered features to predict cluster IDs. This way we can get explainable feature importance (Explainability Model, interpretable)

## 8. Train a classification model (maybe also decision tree) on embedded (maybe: and clustered) data to classify behaviors (Performance Model, accurate but abstract)