# MABe – Dataset Processing and Scaling

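This notebook processes MABe tracking and annotation files into a windowed, multi-class dataset for Reservoir Computing training.
It loads the tracking files, keeps the bodyparts common to all of them, and cuts each recording into fixed-length labelled windows.
As configured below it runs on a small test subset (two categories, five files per category); lift those limits to process the entire dataset.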

**Prerequisites:** Run `01_exploration.ipynb` first to understand the data structure.

In [None]:
%pip install pyarrow fastparquet
%pip install reservoirpy
%pip install matplotlib
%pip install scikit-learn

Note: you may need to restart the kernel to use updated packages.



[notice] A new release of pip is available: 23.0.1 -> 25.3
[notice] To update, run: python.exe -m pip install --upgrade pip


In [2]:
import pandas as pd
import numpy as np
from pathlib import Path
import pickle
import json

## Data Paths

In [3]:
# Input directories (tracking + annotations) and the output directory for processed arrays.
# All paths are relative to the notebook's working directory.
TRACK_ROOT = Path("data/data_raw/train_tracking")
ANNOT_ROOT = Path("data/data_raw/train_annotation")
PROCESSED_ROOT = Path("data/data_processed")
# parents=True: without it mkdir raises FileNotFoundError when "data/" itself does not exist yet.
PROCESSED_ROOT.mkdir(parents=True, exist_ok=True)

print("Data directories:")
print(f"  Tracking: {TRACK_ROOT}")
print(f"  Annotations: {ANNOT_ROOT}")
print(f"  Processed: {PROCESSED_ROOT}")

Data directories:
  Tracking: data\data_raw\train_tracking
  Annotations: data\data_raw\train_annotation
  Processed: data\data_processed


## Helper Functions

In [4]:
def load_annotation(path: Path) -> pd.DataFrame:
    """Read an annotation table, picking the pandas reader from the file extension.

    Supported extensions (matched case-insensitively): ``.parquet``, ``.csv``,
    and tab-separated ``.tsv`` / ``.txt``.

    Raises:
        ValueError: if the extension is not one of the supported ones.
    """
    extension = path.suffix.lower()

    if extension == ".parquet":
        table = pd.read_parquet(path)
    elif extension == ".csv":
        table = pd.read_csv(path)
    elif extension in (".tsv", ".txt"):
        # Both extensions are treated as tab-separated text.
        table = pd.read_csv(path, sep="\t")
    else:
        raise ValueError(f"Unsupported annotation file type: {path}")

    return table

def find_common_bodyparts(categories, max_files_per_category=10, mouse_id=1, bad_bodyparts=("neck",)):
    """Return the bodyparts tracked for one mouse in every sampled tracking file.

    Args:
        categories: Names of sub-directories of ``TRACK_ROOT`` to scan.
        max_files_per_category: At most this many ``*.parquet`` files are read
            per category (``None`` reads all of them).
        mouse_id: Only rows whose ``mouse_id`` equals this value are considered.
        bad_bodyparts: Bodypart names removed from the result (exact match).

    Returns:
        A ``set`` of bodypart names present in all sampled files, minus
        ``bad_bodyparts``. An empty set is returned when no file was found.
    """
    common = None
    for category in categories:
        category_path = TRACK_ROOT / category
        # Sort before slicing: glob order is filesystem-dependent, so an unsorted
        # slice would sample a different subset of files on different machines.
        tracking_files = sorted(category_path.glob("*.parquet"))[:max_files_per_category]

        for tracking_path in tracking_files:
            df = pd.read_parquet(tracking_path)
            bodyparts = set(df.loc[df["mouse_id"] == mouse_id, "bodypart"].unique())
            common = bodyparts if common is None else common & bodyparts

    # No file was read at all (no categories / empty directories): there is
    # nothing in common, rather than a TypeError from `None - set`.
    if common is None:
        return set()

    return common - set(bad_bodyparts)

def process_tracking_file(tracking_path: Path, mouse_id: int = 1, bad_bodyparts=("neck",), common_bodyparts: set = None):
    """Turn one long-format tracking file into a (frames x features) float32 matrix.

    The file is read with ``pd.read_parquet`` and must provide the columns
    ``mouse_id``, ``video_frame``, ``bodypart``, ``x`` and ``y``.

    Args:
        tracking_path: Path of the tracking parquet file.
        mouse_id: Only rows with this ``mouse_id`` are kept.
        bad_bodyparts: Bodypart names to drop. Names are matched exactly, so
            excluding ``"tail"`` does not also remove ``"tail_base"``.
        common_bodyparts: If given, only bodyparts in this set are kept.

    Returns:
        ``(X_final, frames)`` where ``X_final`` is a float32 array with one row
        per ``video_frame`` and columns ``x_<bodypart>...``, ``y_<bodypart>...``,
        and ``frames`` holds the ``video_frame`` value of each row.
    """
    # Load tracking data and keep a single mouse
    df = pd.read_parquet(tracking_path)
    df_mouse = df[df["mouse_id"] == mouse_id]

    # Long -> wide: one row per frame, columns are (coord, bodypart) pairs
    df_wide = (
        df_mouse
        .pivot(index="video_frame", columns="bodypart", values=["x", "y"])
        .sort_index()
    )

    # Decide per column, on the bodypart name itself, whether to keep it.
    # Matching on the flattened "x_<part>" string would hit any bodypart that
    # merely contains a bad name as a substring.
    excluded = set(bad_bodyparts)
    keep_mask = [
        part not in excluded and (common_bodyparts is None or part in common_bodyparts)
        for _coord, part in df_wide.columns
    ]
    df_clean = df_wide.loc[:, keep_mask].copy()
    df_clean.columns = [f"{coord}_{part}" for coord, part in df_clean.columns]

    # Fill gaps: linear interpolation inside the series, then forward/backward
    # fill for NaNs remaining at the edges
    df_interp = df_clean.interpolate(method="linear").ffill().bfill()

    X_final = df_interp.to_numpy(dtype=np.float32)

    return X_final, df_interp.index.to_numpy()  # frames

def process_annotation_file(tracking_path: Path, frames: np.ndarray, mouse_id: int = 1, action_name: str = "chase"):
    """Build a per-frame 0/1 label vector for one action of one mouse.

    The annotation file is looked up at
    ``ANNOT_ROOT / <tracking category dir> / <tracking file stem>.parquet``.
    A frame is labelled 1 when it lies inside any matching segment, with both
    ``start_frame`` and ``stop_frame`` treated as inclusive bounds.

    Returns:
        An int8 array with one entry per element of ``frames``. It is all
        zeros (after printing a warning) when the annotation file is missing.
    """
    annot_path = ANNOT_ROOT / tracking_path.parent.name / f"{tracking_path.stem}.parquet"
    labels = np.zeros(len(frames), dtype=np.int8)

    if not annot_path.exists():
        print(f"Warning: No annotation for {tracking_path}")
        return labels

    table = load_annotation(annot_path)
    is_match = (table["agent_id"] == mouse_id) & (table["action"] == action_name)
    bouts = table.loc[is_match, ["start_frame", "stop_frame"]]

    # Mark every frame covered by a matching bout
    for first, last in zip(bouts["start_frame"], bouts["stop_frame"]):
        labels[(frames >= first) & (frames <= last)] = 1

    return labels

def create_windows(X, y, window_size=200, step=200):
    """Cut a time series into fixed-length windows with one binary label each.

    Args:
        X: Array whose first axis is time.
        y: Per-timestep labels aligned with ``X``; a window is labelled 1 when
            any of its timesteps is truthy.
        window_size: Number of timesteps per window.
        step: Offset between the starts of consecutive windows.

    Returns:
        ``(Xw, yw)``: ``Xw`` has shape ``(n_windows, window_size, ...)`` and
        ``yw`` is an int8 vector of length ``n_windows``. When the series is
        shorter than one window both arrays are empty (``n_windows == 0``).
    """
    starts = range(0, X.shape[0] - window_size + 1, step)

    # np.stack refuses an empty list, so short recordings get explicit empty
    # arrays with the trailing dimensions a non-empty result would have.
    if len(starts) == 0:
        empty_X = np.empty((0, window_size) + tuple(X.shape[1:]), dtype=X.dtype)
        return empty_X, np.empty(0, dtype=np.int8)

    Xw = np.stack([X[s:s + window_size] for s in starts])
    yw = np.array([1 if y[s:s + window_size].any() else 0 for s in starts], dtype=np.int8)
    return Xw, yw

# Multi-class classification setup
# (This block used to appear twice in the cell; the second, identical copy
# silently shadowed the first and has been removed.)
actions = ["chase", "avoid", "attack", "chaseattack"]
classes = actions + ["none"]  # Add "none" class for no action
class_to_id = {cls: i for i, cls in enumerate(classes)}
id_to_class = {i: cls for i, cls in enumerate(classes)}
priority_order = ["chaseattack", "attack", "chase", "avoid"]  # Highest priority first

def process_annotation_multiclass(tracking_path: Path, mouse_id: int = 1):
    """Load the annotation segments of one mouse for a tracking file.

    The annotation file is looked up at
    ``ANNOT_ROOT / <tracking category dir> / <tracking file stem>.parquet``.

    Returns:
        A copy of the annotation rows whose ``agent_id`` equals ``mouse_id``
        (all actions), or an empty DataFrame — after printing a warning —
        when the annotation file does not exist.
    """
    file_id = tracking_path.stem
    category = tracking_path.parent.name
    annot_path = ANNOT_ROOT / category / f"{file_id}.parquet"

    if not annot_path.exists():
        print(f"Warning: No annotation for {tracking_path}")
        return pd.DataFrame()  # Return empty DataFrame

    ann = load_annotation(annot_path)
    ann_sel = ann[ann["agent_id"] == mouse_id].copy()
    return ann_sel

def create_windows_multiclass(X, ann_segments, frames, window_size=200, step=200):
    """Cut a time series into windows and give each one a multi-class label.

    Args:
        X: Array whose first axis is time.
        ann_segments: DataFrame with ``start_frame``, ``stop_frame`` and
            ``action`` columns; may be empty (then every window is "none").
        frames: ``frames[i]`` is the frame number of row ``i`` of ``X``.
        window_size: Number of rows per window.
        step: Offset (in rows) between the starts of consecutive windows.

    Returns:
        ``(Xw, yw, window_starts, window_ends)``: the stacked windows, their
        int64 class ids (see ``class_to_id``), and the first / last frame
        number of each window. All four are empty when ``X`` is shorter than
        one window.

    Labelling: among the actions whose segment overlaps the window, the first
    one listed in ``priority_order`` wins. A window that overlaps no segment,
    or only actions absent from ``priority_order``, is labelled "none".
    """
    # Pull the segment columns out once instead of iterating the DataFrame
    # row by row for every window.
    if len(ann_segments) > 0:
        seg_starts = ann_segments["start_frame"].to_numpy()
        seg_stops = ann_segments["stop_frame"].to_numpy()
        seg_actions = ann_segments["action"].to_numpy()
    else:
        seg_starts = seg_stops = np.empty(0, dtype=np.int64)
        seg_actions = np.empty(0, dtype=object)

    Xw, yw, window_starts, window_ends = [], [], [], []

    for start in range(0, X.shape[0] - window_size + 1, step):
        end = start + window_size

        # First and last frame number covered by this window (both inclusive)
        window_start_frame = frames[start]
        window_end_frame = frames[end - 1]

        # A segment that starts exactly on the window's last frame still
        # overlaps it, hence `<=` on the start side.
        # NOTE(review): the stop side keeps the original strict `>`, i.e. it
        # treats stop_frame as exclusive — confirm against the annotation spec.
        hits = (seg_starts <= window_end_frame) & (seg_stops > window_start_frame)
        overlapping_actions = set(seg_actions[hits])

        # Default to "none" so an overlap with only non-prioritised actions can
        # neither reuse the previous window's label nor leave `label` unbound.
        label = next((a for a in priority_order if a in overlapping_actions), "none")

        Xw.append(X[start:end])
        yw.append(class_to_id[label])
        window_starts.append(window_start_frame)
        window_ends.append(window_end_frame)

    # np.stack refuses an empty list; return well-formed empty arrays so the
    # caller can test len(X_windows) == 0.
    if not Xw:
        return (np.empty((0, window_size) + tuple(X.shape[1:]), dtype=X.dtype),
                np.empty(0, dtype=np.int64),
                np.empty(0, dtype=np.int64),
                np.empty(0, dtype=np.int64))

    return (np.stack(Xw),
            np.array(yw, dtype=np.int64),
            np.array(window_starts, dtype=np.int64),
            np.array(window_ends, dtype=np.int64))


## Batch Processing Configuration

In [5]:
# Batch configuration: which mouse to extract and how recordings are windowed.
# `action_name` is echoed below; the multi-class cells in this notebook do not read it.
mouse_id, action_name = 1, "chase"
window_size = step = 200  # step equal to window_size -> windows do not overlap
bad_bodyparts = ["neck"]

# Categories to process (uncomment for all)
# categories = [d.name for d in TRACK_ROOT.iterdir() if d.is_dir()]
categories = ["AdaptableSnail", "BoisterousParrot"]  # Test with 2 categories

print(
    f"Will process {len(categories)} categories: {categories}",
    f"Parameters: mouse_id={mouse_id}, action='{action_name}', window_size={window_size}",
    sep="\n",
)

Will process 2 categories: ['AdaptableSnail', 'BoisterousParrot']
Parameters: mouse_id=1, action='chase', window_size=200


## Find Common Bodyparts

In [6]:
# Sample up to 10 files per category and keep the bodyparts they all share.
print("Finding common bodyparts across files...")
common_bodyparts = find_common_bodyparts(categories, max_files_per_category=10)
print(
    f"Common bodyparts ({len(common_bodyparts)}): {sorted(common_bodyparts)}",
    # Two features (x and y) per bodypart
    f"Total features per mouse: {2 * len(common_bodyparts)} (x,y coordinates)",
    sep="\n",
)

Finding common bodyparts across files...
Common bodyparts (5): ['body_center', 'ear_left', 'ear_right', 'nose', 'tail_base']
Total features per mouse: 10 (x,y coordinates)


## Process All Files

In [7]:
# Cap on tracking files per category while testing; set to None to process every file.
MAX_FILES_PER_CATEGORY = 5

# Per-file window arrays, concatenated after the loop
all_X_windows = []
all_y_windows = []
file_info = []

# Per-window metadata, kept parallel to the concatenated windows
video_id_windows = []
category_windows = []
mouse_id_windows = []
window_start_frame_windows = []

for category in categories:
    print(f"\nProcessing category: {category}")
    category_path = TRACK_ROOT / category
    # Sort before slicing: glob order is filesystem-dependent, so an unsorted
    # slice would pick a different subset of files on different machines.
    tracking_files = sorted(category_path.glob("*.parquet"))[:MAX_FILES_PER_CATEGORY]

    for tracking_path in tracking_files:
        print(f"  Processing: {tracking_path.name}")

        # Best effort: a file that cannot be processed is reported and skipped
        # so one bad recording does not abort the whole batch.
        try:
            # Process tracking
            X_final, frames = process_tracking_file(tracking_path, mouse_id, bad_bodyparts, common_bodyparts)

            # Process annotation (multi-class)
            ann_segments = process_annotation_multiclass(tracking_path, mouse_id)

            # Create windows (multi-class)
            X_windows, y_windows, window_starts, window_ends = create_windows_multiclass(
                X_final, ann_segments, frames, window_size, step)
        except Exception as e:
            print(f"    Error processing {tracking_path.name}: {e}")
            continue

        n_new = len(X_windows)
        if n_new == 0:
            print(f"    Warning: No windows created for {tracking_path.name}")
            continue

        # Windows from all files are concatenated below, which requires identical
        # (timesteps, features) shapes. Skip a mismatching file now rather than
        # failing in np.concatenate after the whole loop has run.
        if all_X_windows and X_windows.shape[1:] != all_X_windows[0].shape[1:]:
            print(f"    Warning: Skipping {tracking_path.name}: window shape "
                  f"{X_windows.shape[1:]} does not match {all_X_windows[0].shape[1:]}")
            continue

        all_X_windows.append(X_windows)
        all_y_windows.append(y_windows)
        file_info.append({
            'category': category,
            'file': tracking_path.name,
            'n_windows': n_new,
            'class_counts': np.bincount(y_windows, minlength=len(classes)),
            'n_features': X_final.shape[1]
        })

        # One metadata entry per window
        video_id_windows.extend([tracking_path.stem] * n_new)
        category_windows.extend([category] * n_new)
        mouse_id_windows.extend([mouse_id] * n_new)
        window_start_frame_windows.extend(window_starts.tolist())

# Combine all data
if all_X_windows:
    X_all = np.concatenate(all_X_windows, axis=0)
    y_all = np.concatenate(all_y_windows, axis=0)

    # Metadata must stay aligned one-to-one with the windows
    assert len(video_id_windows) == len(X_all) == len(y_all)

    print("\n=== DATASET SUMMARY ===")
    print(f"Total windows: {len(X_all):,}")
    print(f"Window shape: {X_all.shape[1:]} (timesteps × features)")
    print(f"Classes: {classes}")
    class_counts = np.bincount(y_all, minlength=len(classes))
    for cls, count in zip(classes, class_counts):
        ratio = count / len(y_all)
        print(f"  {cls}: {count:,} ({ratio:.3f})")
    print(f"Files processed: {len(file_info)}")

    # Save processed data
    np.save(PROCESSED_ROOT / "X_windows.npy", X_all)
    np.save(PROCESSED_ROOT / "y_windows.npy", y_all)

    # Save metadata arrays
    # (the two dtype=object arrays are pickled by np.save; loading them needs allow_pickle=True)
    np.save(PROCESSED_ROOT / "video_id_windows.npy", np.array(video_id_windows, dtype=object))
    np.save(PROCESSED_ROOT / "category_windows.npy", np.array(category_windows, dtype=object))
    np.save(PROCESSED_ROOT / "mouse_id_windows.npy", np.array(mouse_id_windows, dtype=np.int8))
    np.save(PROCESSED_ROOT / "window_start_frame_windows.npy", np.array(window_start_frame_windows, dtype=np.int32))

    # Save class mappings
    # (json.dump writes the integer keys of id_to_class as strings)
    class_mappings = {
        'classes': classes,
        'class_to_id': class_to_id,
        'id_to_class': id_to_class,
        'priority_order': priority_order
    }
    with open(PROCESSED_ROOT / "class_mappings.json", 'w') as f:
        json.dump(class_mappings, f, indent=2)

    with open(PROCESSED_ROOT / "file_info.pkl", 'wb') as f:
        pickle.dump(file_info, f)

    print(f"\n✅ Data saved to {PROCESSED_ROOT}")
    print("✅ Class mappings saved to class_mappings.json")
    print("Ready for training with 03_training.ipynb")
else:
    print("❌ No data processed!")


Processing category: AdaptableSnail
  Processing: 1212811043.parquet
  Processing: 1260392287.parquet
  Processing: 1351098077.parquet
  Processing: 1408652858.parquet
  Processing: 143861384.parquet

Processing category: BoisterousParrot
  Processing: 1059582964.parquet
  Processing: 1184291605.parquet
  Processing: 1201849558.parquet
  Processing: 1459695188.parquet
  Processing: 1985626297.parquet

=== DATASET SUMMARY ===
Total windows: 16,395
Window shape: (200, 10) (timesteps × features)
Classes: ['chase', 'avoid', 'attack', 'chaseattack', 'none']
  chase: 22 (0.001)
  avoid: 52 (0.003)
  attack: 111 (0.007)
  chaseattack: 27 (0.002)
  none: 16,183 (0.987)
Files processed: 10

✅ Data saved to data\data_processed
✅ Class mappings saved to class_mappings.json
Ready for training with 03_training.ipynb


## Processing Summary

In [8]:
# Per-file recap: window count plus the classes that actually occur in each file.
if all_X_windows:
    print("File breakdown:")
    for info in file_info:
        # Only classes with at least one window are listed
        nonzero = [
            f"{classes[class_id]}:{n_windows}"
            for class_id, n_windows in enumerate(info['class_counts'])
            if n_windows > 0
        ]
        class_counts_str = ", ".join(nonzero)
        print(f"  {info['category']}/{info['file']}: {info['n_windows']} windows [{class_counts_str}]")

    grand_total = sum(entry['n_windows'] for entry in file_info)
    print(f"\nTotal: {grand_total} windows from {len(file_info)} files")

File breakdown:
  AdaptableSnail/1212811043.parquet: 367 windows [chase:14, avoid:26, attack:68, chaseattack:27, none:232]
  AdaptableSnail/1260392287.parquet: 269 windows [avoid:11, attack:1, none:257]
  AdaptableSnail/1351098077.parquet: 400 windows [avoid:3, attack:2, none:395]
  AdaptableSnail/1408652858.parquet: 92 windows [avoid:6, none:86]
  AdaptableSnail/143861384.parquet: 417 windows [chase:8, avoid:6, attack:40, none:363]
  BoisterousParrot/1059582964.parquet: 2970 windows [none:2970]
  BoisterousParrot/1184291605.parquet: 2970 windows [none:2970]
  BoisterousParrot/1201849558.parquet: 2970 windows [none:2970]
  BoisterousParrot/1459695188.parquet: 2970 windows [none:2970]
  BoisterousParrot/1985626297.parquet: 2970 windows [none:2970]

Total: 16395 windows from 10 files


## Sanity Check

In [9]:
# Sanity check for multi-class dataset: class balance and a few labelled examples.
if 'y_all' not in locals() or len(y_all) == 0:
    print("No processed data available for sanity check")
else:
    print("=== MULTI-CLASS DATASET SANITY CHECK ===")

    # Class distribution (absolute counts, then fractions of all windows)
    class_counts = np.bincount(y_all, minlength=len(classes))
    print(f"Class counts: {dict(zip(classes, class_counts))}")
    print(f"Class ratios: {dict(zip(classes, class_counts / len(y_all)))}")

    # Positions of every window whose label is a real action
    none_id = class_to_id["none"]
    non_none_indices = np.where(y_all != none_id)[0]

    if len(non_none_indices) == 0:
        print("\nNo windows with actions found in this dataset!")
    else:
        print(f"\nExamples of windows with actions (showing first 5):")
        for idx in non_none_indices[:5]:
            print(f"  Window {idx}: class '{id_to_class[y_all[idx]]}'")

    n_action_windows = len(non_none_indices)
    print(f"\nTotal windows: {len(y_all)}")
    print(f"Windows with actions: {n_action_windows} ({n_action_windows/len(y_all):.3f})")

=== MULTI-CLASS DATASET SANITY CHECK ===
Class counts: {'chase': np.int64(22), 'avoid': np.int64(52), 'attack': np.int64(111), 'chaseattack': np.int64(27), 'none': np.int64(16183)}
Class ratios: {'chase': np.float64(0.0013418725221103994), 'avoid': np.float64(0.0031716986886245807), 'attack': np.float64(0.00677035681610247), 'chaseattack': np.float64(0.001646843549862763), 'none': np.float64(0.9870692284232998)}

Examples of windows with actions (showing first 5):
  Window 0: class 'chase'
  Window 1: class 'chase'
  Window 4: class 'chase'
  Window 5: class 'chase'
  Window 7: class 'chase'

Total windows: 16395
Windows with actions: 212 (0.013)
