# MABe – Dataset processing (windowing + labels + metadata)

Goal: build a reusable, window-level processed dataset for training.

Outputs saved to `data/data_processed/`:
- `X_windows.npy`                     (N, T, D)
- `y_windows.npy`                     (N,)
- `video_id_windows.npy`              (N,)
- `category_windows.npy`              (N,)
- `mouse_id_windows.npy`              (N,)
- `window_start_frame_windows.npy`    (N,)
- `class_mappings.json`               (label ↔ id)
- `file_info.pkl`                     (per-file processing summary)

We will proceed in steps:
1) Setup and paths
2) Define label mapping and windowing rule
3) Identify common bodyparts (stable feature set)
4) Process each tracking file (X) and annotation file (segments)
5) Build window-level labels (multi-class)
6) Save arrays and metadata
7) Sanity checks

In [1]:
%pip install pyarrow fastparquet numpy pandas scikit-learn reservoirpy

Note: you may need to restart the kernel to use updated packages.





## Imports

We import libraries for:
- file system navigation (Path)
- parquet loading (pandas + pyarrow)
- numeric arrays (numpy)
- saving metadata (json, pickle)

In [2]:
import json
import pickle
from pathlib import Path

import numpy as np
import pandas as pd

## Data paths

We use the following directory layout:
- raw tracking: `data/data_raw/train_tracking/<category>/<video_id>.parquet`
- raw annotation: `data/data_raw/train_annotation/<category>/<video_id>.parquet`
- processed output: `data/data_processed/`

Cell objective: create `data/data_processed/` if needed and print paths for verification.
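`load_annotation` (defined below) quietly returns an empty table when an annotation file is missing, so every window of that video would be labelled "none". It can help to list unmatched tracking files up front. This is a minimal sketch that assumes only the directory layout above. `find_unannotated` is a hypothetical helper, not part of the notebook.

```python
from pathlib import Path


def find_unannotated(track_root: Path, annot_root: Path) -> list[tuple[str, str]]:
    """Return (category, video_id) pairs that have a tracking file but no annotation file.

    Assumes <root>/<category>/<video_id>.parquet for both roots.
    """
    missing = []
    for track_path in sorted(track_root.glob("*/*.parquet")):
        category = track_path.parent.name
        video_id = track_path.stem
        if not (annot_root / category / f"{video_id}.parquet").exists():
            missing.append((category, video_id))
    return missing
```

Once the paths in the next cell are defined, it would be called as `find_unannotated(TRACK_ROOT, ANNOT_ROOT)`.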

In [3]:
TRACK_ROOT = Path("data/data_raw/train_tracking")
ANNOT_ROOT = Path("data/data_raw/train_annotation")
PROCESSED_ROOT = Path("data/data_processed")
PROCESSED_ROOT.mkdir(parents=True, exist_ok=True)

print("TRACK_ROOT exists:", TRACK_ROOT.exists(), TRACK_ROOT.resolve())
print("ANNOT_ROOT exists:", ANNOT_ROOT.exists(), ANNOT_ROOT.resolve())
print("PROCESSED_ROOT:", PROCESSED_ROOT.resolve())

TRACK_ROOT exists: True C:\perso\Ensc\3A\IA\reservoir\reservoir-computing-mice-behavior\data\data_raw\train_tracking
ANNOT_ROOT exists: True C:\perso\Ensc\3A\IA\reservoir\reservoir-computing-mice-behavior\data\data_raw\train_annotation
PROCESSED_ROOT: C:\perso\Ensc\3A\IA\reservoir\reservoir-computing-mice-behavior\data\data_processed


## Global configuration

We decide:
- which mouse_id we process (start simple with one mouse)
- window size and step
- which categories/files to process (debug subset first, then all)
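The windowing used below iterates over `range(0, T - WINDOW_SIZE + 1, STEP)`. A file with T frames for the selected mouse therefore yields `(T - W) // S + 1` windows when T ≥ W and none otherwise. Trailing frames that do not fill a whole window are dropped. The small helper below is hypothetical and is not used elsewhere in the notebook. It only makes the arithmetic explicit:

```python
def n_windows(n_frames: int, window_size: int, step: int) -> int:
    """Number of windows produced by range(0, n_frames - window_size + 1, step)."""
    if n_frames < window_size:
        return 0
    return (n_frames - window_size) // step + 1


# Cross-check the closed form against the loop bounds used for windowing
for T in (0, 199, 200, 399, 400, 594_000, 594_199):
    assert n_windows(T, 200, 200) == len(range(0, T - 200 + 1, 200))
```

With W = S = 200, the 2970 windows reported below for each BoisterousParrot file correspond to between 594,000 and 594,199 tracked frames of mouse 1 per file.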

In [4]:
# Minimal, reproducible defaults
MOUSE_ID = 1

WINDOW_SIZE = 200
STEP = 200

# Debug mode: process only a few files per category first
DEBUG = True
MAX_FILES_PER_CATEGORY = 5  # used only when DEBUG is True (None = all files)

# Categories to process
categories = sorted([d.name for d in TRACK_ROOT.iterdir() if d.is_dir()])
print("Found categories:", categories)

Found categories: ['AdaptableSnail', 'BoisterousParrot', 'CRIM13', 'CalMS21_supplemental', 'CalMS21_task1', 'CalMS21_task2', 'CautiousGiraffe', 'DeliriousFly', 'ElegantMink', 'GroovyShrew', 'InvincibleJellyfish', 'JovialSwallow', 'LyricalHare', 'MABe22_keypoints', 'MABe22_movies', 'NiftyGoldfinch', 'PleasantMeerkat', 'ReflectiveManatee', 'SparklingTapir', 'TranquilPanther', 'UppityFerret']


## Label mapping and multi-class rule

Annotations are segments:
(agent_id, target_id, action, start_frame, stop_frame)

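We build a window label by checking which annotated actions overlap the window. A single overlapping frame is enough.
If several actions overlap, a priority order resolves the conflict. Actions outside the label vocabulary are ignored, so a window that overlaps only such actions falls back to "none".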

Next cell objective: define the class mapping (including "none") and a priority order.
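You can sketch the rule in plain Python before bringing pandas in. This is a toy illustration only. `resolve_label` is a hypothetical helper, and the segments are made up, including the out-of-vocabulary action `"sniff"`. The segments are assumed to be already filtered to one agent.

```python
def resolve_label(segments, w_start: int, w_end: int, priority_order: list[str]) -> str:
    """Return the action label of the inclusive window [w_start, w_end].

    segments: iterable of (action, start_frame, stop_frame) for one agent.
    Overlap test (inclusive): stop_frame >= w_start and start_frame <= w_end.
    """
    overlapping = {action for action, start, stop in segments if stop >= w_start and start <= w_end}
    for action in priority_order:
        if action in overlapping:
            return action
    return "none"  # nothing overlaps, or only actions outside priority_order


prio = ["attack", "chaseattack", "chase", "avoid", "none"]
toy_segments = [("chase", 150, 260), ("attack", 390, 420), ("sniff", 0, 1000)]
for w_start in (0, 200, 400, 600):
    print(w_start, resolve_label(toy_segments, w_start, w_start + 199, prio))
# 0 chase / 200 attack / 400 attack / 600 none
```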


In [5]:
# Define label vocabulary
# "none" means no annotated action overlaps the window
classes = ["none", "chase", "avoid", "attack", "chaseattack"]

class_to_id = {c: i for i, c in enumerate(classes)}
id_to_class = {i: c for c, i in class_to_id.items()}

# Priority order when multiple actions overlap a window
# The first action in this list that overlaps a window wins: put the most specific / important actions first
priority_order = ["attack", "chaseattack", "chase", "avoid", "none"]

print("class_to_id:", class_to_id)
print("priority_order:", priority_order)

class_to_id: {'none': 0, 'chase': 1, 'avoid': 2, 'attack': 3, 'chaseattack': 4}
priority_order: ['attack', 'chaseattack', 'chase', 'avoid', 'none']


## Helper functions (I/O)

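We define:
- `load_parquet_df(path)`: reads a parquet file and raises an explicit `FileNotFoundError` if it is missing
- `load_annotation(category, video_id)`: loads the matching annotation file, or returns an empty table when there is none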

Next cell objective: implement these functions.

In [6]:
def load_parquet_df(path: Path) -> pd.DataFrame:
    if not path.exists():
        raise FileNotFoundError(str(path))
    return pd.read_parquet(path)

def load_annotation(category: str, video_id: str) -> pd.DataFrame:
    annot_path = ANNOT_ROOT / category / f"{video_id}.parquet"
    if not annot_path.exists():
        # Missing annotation file: return an empty table, so every window of this video is labelled "none"
        return pd.DataFrame(columns=["agent_id", "target_id", "action", "start_frame", "stop_frame"])
    return pd.read_parquet(annot_path)

## Stable feature set (common bodyparts)

Tracking files can differ in which bodyparts are present.
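To build a consistent feature matrix, we keep only the bodyparts present in every inspected file. This is the intersection over the first `MAX_FILES_PER_CATEGORY` files of each category in debug mode, and over the first 10 otherwise.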

Next cell objective: implement `find_common_bodyparts` and run it on a subset of files.
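The intersection can shrink quickly (to 3 bodyparts in the output below). To see which bodyparts each file loses, one option is the sketch below. `common_and_dropped` is a hypothetical helper and the input sets are made up. In the notebook, each set would come from the same `pd.read_parquet(f, columns=["bodypart"])` call used in `find_common_bodyparts`.

```python
from functools import reduce


def common_and_dropped(bodyparts_per_file: dict[str, set[str]]) -> tuple[list[str], dict[str, list[str]]]:
    """Intersect per-file bodypart sets and report which bodyparts each file loses."""
    if not bodyparts_per_file:
        return [], {}
    common = reduce(set.intersection, bodyparts_per_file.values())
    dropped = {name: sorted(bps - common) for name, bps in bodyparts_per_file.items()}
    return sorted(common), dropped


# Toy input (made up)
toy_bodyparts = {
    "video_a": {"ear_left", "ear_right", "nose", "tail_base"},
    "video_b": {"ear_left", "ear_right", "neck", "tail_base"},
}
common, dropped = common_and_dropped(toy_bodyparts)
print(common)   # ['ear_left', 'ear_right', 'tail_base']
print(dropped)  # {'video_a': ['nose'], 'video_b': ['neck']}
```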

In [7]:
def find_common_bodyparts(categories, max_files_per_category=5):
    common = None
    
    for category in categories:
        files = sorted((TRACK_ROOT / category).glob("*.parquet"))
        if max_files_per_category is not None:
            files = files[:max_files_per_category]
        
        for f in files:
            df = pd.read_parquet(f, columns=["bodypart"])
            bps = set(df["bodypart"].unique())
            common = bps if common is None else (common & bps)
    
    return sorted(list(common)) if common is not None else []

common_bodyparts = find_common_bodyparts(
    categories=categories,
    max_files_per_category=(MAX_FILES_PER_CATEGORY if DEBUG else 10)
)

print("Common bodyparts count:", len(common_bodyparts))
print("Common bodyparts:", common_bodyparts)

Common bodyparts count: 3
Common bodyparts: ['ear_left', 'ear_right', 'tail_base']


## Build X(t) from tracking (single file)

We convert the long tracking format into a wide matrix:
- index = video_frame
- columns = x_<bodypart>, y_<bodypart>

Then we:
- keep only `common_bodyparts` (stable features)
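- fill missing values: linear interpolation along time, then forward/backward fill at the sequence edges
- return:
  - X_final: NumPy array (T, D)
  - frames: NumPy array (T,) with frame indices
  - cols_keep: the D feature column names, in order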

Next cell objective: implement `build_X_for_file(...)`.
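`build_X_for_file` keeps only the common-bodypart columns that are present in the current file. D can therefore shrink for a file (or mouse) that lacks one of them, for example when `common_bodyparts` was computed on a subset of files, and the final `np.concatenate` would then fail. One possible alternative, sketched here under our own assumptions, is to reindex to a fixed column list so that D always equals `2 * len(common_bodyparts)`. `to_fixed_wide` is a hypothetical helper and the toy frame is made up. Zero-filling a bodypart that is never observed is an arbitrary choice; skipping such files is another option.

```python
import numpy as np
import pandas as pd


def to_fixed_wide(df_mouse: pd.DataFrame, bodyparts: list[str]) -> pd.DataFrame:
    """Long rows (video_frame, bodypart, x, y) for one mouse -> wide frame with fixed columns."""
    wide = (
        df_mouse.pivot(index="video_frame", columns="bodypart", values=["x", "y"])
        .sort_index()
    )
    wide.columns = [f"{coord}_{part}" for coord, part in wide.columns]
    # Fixed order x_<bp>, y_<bp> per bodypart (same order as build_X_for_file)
    cols = [f"{coord}_{bp}" for bp in bodyparts for coord in ("x", "y")]
    wide = wide.reindex(columns=cols)  # absent bodyparts become all-NaN columns
    # Interpolate interior gaps, fill the edges, then (assumption) zero-fill columns with no data at all
    return wide.interpolate(method="linear").ffill().bfill().fillna(0.0)


toy_tracking = pd.DataFrame({
    "video_frame": [0, 0, 1, 2, 2],
    "bodypart": ["ear_left", "tail_base", "ear_left", "ear_left", "tail_base"],
    "x": [0.0, 10.0, 1.0, 2.0, 12.0],
    "y": [0.0, 20.0, 1.0, 2.0, 22.0],
})
wide_toy = to_fixed_wide(toy_tracking, ["ear_left", "ear_right", "tail_base"])
X_toy = wide_toy.to_numpy(dtype=np.float32)
print(X_toy.shape)  # (3, 6): D is always 2 * len(bodyparts)
```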

In [8]:
def build_X_for_file(tracking_df: pd.DataFrame, mouse_id: int, common_bodyparts: list[str]) -> tuple[np.ndarray, np.ndarray, list[str]]:
    # Filter one mouse
    df_mouse = tracking_df[tracking_df["mouse_id"] == mouse_id].copy()
    if df_mouse.empty:
        return np.zeros((0, 0), dtype=np.float32), np.zeros((0,), dtype=np.int64), []

    # Pivot to wide
    df_wide = (
        df_mouse
        .pivot(index="video_frame", columns="bodypart", values=["x", "y"])
        .sort_index()
    )

    # Flatten columns (MultiIndex -> strings)
    df_wide.columns = [f"{coord}_{part}" for coord, part in df_wide.columns]

    # Keep only common bodyparts
    cols_keep = []
    for bp in common_bodyparts:
        cx = f"x_{bp}"
        cy = f"y_{bp}"
        if cx in df_wide.columns and cy in df_wide.columns:
            cols_keep.extend([cx, cy])

    df_feat = df_wide[cols_keep].copy()

    # Interpolate time-wise and fill edges
    df_feat = df_feat.interpolate(method="linear").ffill().bfill()

    X_final = df_feat.to_numpy(dtype=np.float32)
    frames = df_feat.index.to_numpy(dtype=np.int64)

    return X_final, frames, cols_keep

## Build window labels from segment annotations (multi-class)

We label each window [start_frame, end_frame] based on segment overlap:
- filter segments where agent_id == MOUSE_ID
- check overlap with the window
- if no overlap => "none"
- if overlap with multiple actions => pick using `priority_order`

Next cell objective: implement `label_window_from_segments` (label one window) and `make_windows_multiclass` (slice X into windows and label each one). Together they return X_windows, y_windows and the window start frames.
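`make_windows_multiclass` calls `label_window_from_segments` once per window and re-filters the annotation table each time. For long videos with thousands of windows per file, you can evaluate the same inclusive-overlap and priority rule for all windows at once with NumPy broadcasting. This is a sketch that assumes the segments are already filtered to `agent_id == MOUSE_ID`. `label_windows_vectorized` is a hypothetical name. It returns action strings, which you would map with `class_to_id` to get ids.

```python
import numpy as np


def label_windows_vectorized(w_starts, w_ends, seg_starts, seg_stops, seg_actions,
                             priority_order: list[str]) -> list[str]:
    """Label every inclusive window [w_start, w_end] at once (same rule as the per-window loop)."""
    w_starts, w_ends = np.asarray(w_starts), np.asarray(w_ends)
    seg_starts, seg_stops = np.asarray(seg_starts), np.asarray(seg_stops)
    n_prio = len(priority_order)
    # rank[j] = position of segment j's action in priority_order (n_prio if not listed)
    rank_of = {a: r for r, a in enumerate(priority_order)}
    rank = np.array([rank_of.get(a, n_prio) for a in seg_actions], dtype=np.int64)
    # overlap[i, j] is True when segment j overlaps window i: (N windows, S segments)
    overlap = (seg_stops[None, :] >= w_starts[:, None]) & (seg_starts[None, :] <= w_ends[:, None])
    # Non-overlapping pairs get rank n_prio so they never win the minimum
    masked = np.where(overlap, rank[None, :], n_prio)
    best = masked.min(axis=1) if masked.shape[1] else np.full(len(w_starts), n_prio)
    return [priority_order[r] if r < n_prio else "none" for r in best]


prio = ["attack", "chaseattack", "chase", "avoid", "none"]
# Toy segments (made up): "chaseattack" straddles the boundary between the first two windows
print(label_windows_vectorized([0, 200, 400], [199, 399, 599],
                               [0, 180], [50, 210], ["avoid", "chaseattack"], prio))
# ['chaseattack', 'chaseattack', 'none']
```

The (N, S) boolean matrix is small for a single video's annotations, which is why broadcasting per file is reasonable here.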

In [9]:
def label_window_from_segments(ann: pd.DataFrame, mouse_id: int, w_start: int, w_end: int,
                               class_to_id: dict, priority_order: list[str]) -> int:
    """
    Return class id for one window.
    Overlap definition: segment overlaps if stop_frame >= w_start and start_frame <= w_end
    """
    if ann.empty:
        return class_to_id["none"]

    ann_sel = ann[ann["agent_id"] == mouse_id]
    if ann_sel.empty:
        return class_to_id["none"]

    overlaps = ann_sel[(ann_sel["stop_frame"] >= w_start) & (ann_sel["start_frame"] <= w_end)]
    if overlaps.empty:
        return class_to_id["none"]

    actions = set(overlaps["action"].astype(str).unique())

    for a in priority_order:
        if a in actions:
            return class_to_id[a]

    return class_to_id["none"]

def make_windows_multiclass(X: np.ndarray, frames: np.ndarray, ann: pd.DataFrame, mouse_id: int,
                            window_size: int, step: int,
                            class_to_id: dict, priority_order: list[str]):
    """
    Build:
    - X_windows: (N, window_size, D)
    - y_windows: (N,)
    - window_start_frames: (N,)
    """
    T = X.shape[0]
    if T < window_size:
        return (
            np.zeros((0, window_size, X.shape[1]), dtype=np.float32),
            np.zeros((0,), dtype=np.int64),
            np.zeros((0,), dtype=np.int64),
        )

    Xw = []
    yw = []
    w_starts = []

    for i0 in range(0, T - window_size + 1, step):
        i1 = i0 + window_size - 1
        w_start = int(frames[i0])
        w_end = int(frames[i1])

        y_id = label_window_from_segments(ann, mouse_id, w_start, w_end, class_to_id, priority_order)

        Xw.append(X[i0:i0 + window_size])
        yw.append(y_id)
        w_starts.append(w_start)

    return (
        np.stack(Xw).astype(np.float32),
        np.array(yw, dtype=np.int64),
        np.array(w_starts, dtype=np.int64),
    )

## Process all files

For each category and each tracking file:
1) Load tracking
2) Load the matching annotation file (same video_id)
3) Build X(t) for one mouse
4) Build window dataset + window labels
5) Append to global arrays
6) Record per-file summary in `file_info`

Next cell objective: implement the processing loop (with debug limits) and store metadata arrays aligned with windows.
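After the loop below has run, you can turn `file_info` into a table that shows how many files were kept or skipped per category. This is a small sketch: `summarize_file_info` is a hypothetical helper and the example list is made up.

```python
import pandas as pd


def summarize_file_info(file_info: list[dict]) -> pd.DataFrame:
    """Number of files and windows per (category, status) from the per-file summary dicts."""
    df = pd.DataFrame(file_info)
    return (
        df.groupby(["category", "status"])
        .agg(n_files=("video_id", "count"), n_windows=("n_windows", "sum"))
        .reset_index()
    )


example_info = [
    {"category": "A", "video_id": "1", "status": "ok", "n_windows": 10},
    {"category": "A", "video_id": "2", "status": "skip_too_short", "n_windows": 0},
    {"category": "B", "video_id": "3", "status": "ok", "n_windows": 5},
]
print(summarize_file_info(example_info))
```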

In [10]:
all_X = []
all_y = []
all_video_id = []
all_category = []
all_mouse_id = []
all_w_start = []

file_info = []

for category in categories:
    category_path = TRACK_ROOT / category
    tracking_files = sorted(category_path.glob("*.parquet"))
    if DEBUG and MAX_FILES_PER_CATEGORY is not None:
        tracking_files = tracking_files[:MAX_FILES_PER_CATEGORY]

    print(f"\nProcessing category: {category} ({len(tracking_files)} files)")

    for tracking_path in tracking_files:
        video_id = tracking_path.stem

        # Load raw data
        tracking_df = load_parquet_df(tracking_path)
        ann = load_annotation(category, video_id)

        # Build X(t)
        X_final, frames, cols_keep = build_X_for_file(tracking_df, MOUSE_ID, common_bodyparts)

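        if X_final.shape[0] == 0:
            file_info.append({
                "category": category,
                "video_id": video_id,
                "status": "skip_empty_mouse",
                "n_windows": 0
            })
            continue

        # A file missing a common bodypart would give a different D and break np.concatenate later
        if X_final.shape[1] != 2 * len(common_bodyparts):
            file_info.append({
                "category": category,
                "video_id": video_id,
                "status": "skip_missing_bodyparts",
                "n_windows": 0
            })
            continue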

        # Windowing + labels
        Xw, yw, w_starts = make_windows_multiclass(
            X=X_final,
            frames=frames,
            ann=ann,
            mouse_id=MOUSE_ID,
            window_size=WINDOW_SIZE,
            step=STEP,
            class_to_id=class_to_id,
            priority_order=priority_order
        )

        if Xw.shape[0] == 0:
            file_info.append({
                "category": category,
                "video_id": video_id,
                "status": "skip_too_short",
                "n_windows": 0
            })
            continue

        # Append global
        all_X.append(Xw)
        all_y.append(yw)
        all_video_id.append(np.array([video_id] * len(yw), dtype=object))
        all_category.append(np.array([category] * len(yw), dtype=object))
        all_mouse_id.append(np.array([MOUSE_ID] * len(yw), dtype=np.int8))
        all_w_start.append(w_starts)

        # Per-file summary
        class_counts = np.bincount(yw, minlength=len(classes))
        file_info.append({
            "category": category,
            "video_id": video_id,
            "status": "ok",
            "n_windows": int(len(yw)),
            "features": int(Xw.shape[2]),
            "class_counts": class_counts.tolist()
        })

        print(f"  {video_id}: windows={len(yw)}  shape={Xw.shape}")


Processing category: AdaptableSnail (5 files)
  1212811043: windows=367  shape=(367, 200, 6)
  1260392287: windows=269  shape=(269, 200, 6)
  1351098077: windows=400  shape=(400, 200, 6)
  1408652858: windows=92  shape=(92, 200, 6)
  143861384: windows=417  shape=(417, 200, 6)

Processing category: BoisterousParrot (5 files)
  1059582964: windows=2970  shape=(2970, 200, 6)
  1184291605: windows=2970  shape=(2970, 200, 6)
  1201849558: windows=2970  shape=(2970, 200, 6)
  1459695188: windows=2970  shape=(2970, 200, 6)
  1985626297: windows=2970  shape=(2970, 200, 6)

Processing category: CRIM13 (5 files)
  1009459450: windows=55  shape=(55, 200, 6)
  1057221056: windows=41  shape=(41, 200, 6)
  1149348188: windows=40  shape=(40, 200, 6)
  1213233769: windows=73  shape=(73, 200, 6)
  1313797424: windows=44  shape=(44, 200, 6)

Processing category: CalMS21_supplemental (5 files)
  1006083669: windows=91  shape=(91, 200, 6)
  1012566686: windows=133  shape=(133, 200, 6)
  1012566850: wind

## Concatenate and save processed arrays

We concatenate all per-file arrays into one dataset and save everything into `data/data_processed/`.

Next cell objective: concatenate arrays, write `.npy` outputs and save mappings and file_info.

In [11]:
if len(all_X) == 0:
    raise RuntimeError("No windows were produced. Check paths, mouse_id, and processing settings.")

X_all = np.concatenate(all_X, axis=0)
y_all = np.concatenate(all_y, axis=0)
video_id_all = np.concatenate(all_video_id, axis=0)
category_all = np.concatenate(all_category, axis=0)
mouse_id_all = np.concatenate(all_mouse_id, axis=0)
w_start_all = np.concatenate(all_w_start, axis=0)

print("X_all:", X_all.shape)
print("y_all:", y_all.shape)
print("video_id_all:", video_id_all.shape)
print("w_start_all:", w_start_all.shape)

np.save(PROCESSED_ROOT / "X_windows.npy", X_all)
np.save(PROCESSED_ROOT / "y_windows.npy", y_all)
np.save(PROCESSED_ROOT / "video_id_windows.npy", video_id_all, allow_pickle=True)
np.save(PROCESSED_ROOT / "category_windows.npy", category_all, allow_pickle=True)
np.save(PROCESSED_ROOT / "mouse_id_windows.npy", mouse_id_all)
np.save(PROCESSED_ROOT / "window_start_frame_windows.npy", w_start_all)

with open(PROCESSED_ROOT / "class_mappings.json", "w", encoding="utf-8") as f:
    json.dump(
        {
            "classes": classes,
            "class_to_id": class_to_id,
            "id_to_class": id_to_class,
            "priority_order": priority_order
        },
        f,
        indent=2
    )

with open(PROCESSED_ROOT / "file_info.pkl", "wb") as f:
    pickle.dump(file_info, f)

print("\nSaved to:", PROCESSED_ROOT.resolve())

X_all: (26498, 200, 6)
y_all: (26498,)
video_id_all: (26498,)
w_start_all: (26498,)

Saved to: C:\perso\Ensc\3A\IA\reservoir\reservoir-computing-mice-behavior\data\data_processed
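To reload the dataset in a training notebook, two details matter. The string metadata was saved as object arrays, which `np.load` refuses without `allow_pickle=True`. JSON object keys always come back as strings, so `id_to_class` needs its integer keys restored. This is a sketch: `load_processed` is a hypothetical helper that assumes the file names written above.

```python
import json
from pathlib import Path

import numpy as np


def load_processed(root: Path) -> dict:
    """Reload the arrays and class mappings saved in data/data_processed/."""
    out = {
        "X": np.load(root / "X_windows.npy"),
        "y": np.load(root / "y_windows.npy"),
        # video_id / category were saved as object arrays: np.load needs allow_pickle=True
        "video_id": np.load(root / "video_id_windows.npy", allow_pickle=True),
        "category": np.load(root / "category_windows.npy", allow_pickle=True),
        "mouse_id": np.load(root / "mouse_id_windows.npy"),
        "window_start_frame": np.load(root / "window_start_frame_windows.npy"),
    }
    with open(root / "class_mappings.json", encoding="utf-8") as f:
        mappings = json.load(f)
    out["class_to_id"] = mappings["class_to_id"]
    # JSON keys are strings ({"0": "none", ...}): restore integer keys
    out["id_to_class"] = {int(k): v for k, v in mappings["id_to_class"].items()}
    return out
```

In the notebook this would be `data = load_processed(PROCESSED_ROOT)`. Saving `video_id_all.astype(str)` instead would produce a fixed-width Unicode array that loads without `allow_pickle`. That is a possible design change, not what the cell above currently does.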


## Sanity checks

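We check that:
- every array has one entry per window (metadata aligned with X/y)
- every window has `WINDOW_SIZE` time steps
- X contains no NaNs

We also print the class distribution, which we expect to be extremely imbalanced (mostly "none").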

Next cell objective: run sanity checks and print a compact summary.

In [12]:
assert X_all.shape[0] == y_all.shape[0] == len(video_id_all) == len(category_all) == len(mouse_id_all) == len(w_start_all)
assert X_all.shape[1] == WINDOW_SIZE
assert not np.isnan(X_all).any()

counts = np.bincount(y_all, minlength=len(classes))
total = counts.sum()

print("\n=== DATASET SUMMARY ===")
print("Total windows:", int(total))
print("Window shape:", X_all.shape[1:], "(timesteps × features)")
print("Class distribution:")
for i, c in enumerate(classes):
    print(f"  {c:>10s}: {counts[i]:>8d}  ({counts[i]/total:.5f})")

print("\nFirst 5 samples metadata:")
for k in range(min(5, len(y_all))):
    print(video_id_all[k], category_all[k], int(mouse_id_all[k]), int(w_start_all[k]), id_to_class[int(y_all[k])])


=== DATASET SUMMARY ===
Total windows: 26498
Window shape: (200, 6) (timesteps × features)
Class distribution:
        none:    25902  (0.97751)
       chase:       42  (0.00159)
       avoid:       51  (0.00192)
      attack:      502  (0.01894)
  chaseattack:        1  (0.00004)

First 5 samples metadata:
1212811043 AdaptableSnail 1 0 chase
1212811043 AdaptableSnail 1 200 chase
1212811043 AdaptableSnail 1 400 none
1212811043 AdaptableSnail 1 600 none
1212811043 AdaptableSnail 1 800 chase
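The global counts do not show which categories the rare classes come from, so a per-category breakdown is a useful extra check. This sketch uses `pd.crosstab`. `class_counts_by_category` is a hypothetical helper and the toy arrays are made up; in the notebook you would pass `category_all`, `y_all` and `classes`. It assumes that a class id equals its index in `classes`, which holds for the `enumerate`-based mapping above.

```python
import numpy as np
import pandas as pd


def class_counts_by_category(category_all, y_all, classes: list[str]) -> pd.DataFrame:
    """Rows = category, columns = class names (in `classes` order), values = window counts."""
    labels = pd.Series([classes[int(i)] for i in y_all], name="label")
    cats = pd.Series(np.asarray(category_all), name="category")
    # Classes that never occur still get a column of zeros
    return pd.crosstab(cats, labels).reindex(columns=classes, fill_value=0)


class_names = ["none", "chase", "avoid", "attack", "chaseattack"]
toy_category = np.array(["A", "A", "B", "B", "B"], dtype=object)
toy_y = np.array([0, 3, 0, 0, 1])
print(class_counts_by_category(toy_category, toy_y, class_names))
```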
