## 📘 MABe: 2.5D Pose-based Social Action Detection - Inference Notebook

**Goal.** Generate `submission.csv` for the **MABe Challenge - Social Action Recognition in Mice** using a trained 2.5D EfficientNet model. The model consumes top-down 2D keypoints (pose) and predicts frame-wise probabilities for action classes; the notebook then converts those probabilities into action bouts.

**What this notebook does**

1. **Preprocess pose**

   * Drop head-mounted parts (e.g., `headpiece_*`).
   * Map body parts to grayscale intensities (mouse-ID × body-part position) and rasterize per frame.
   * Build **2.5D clips** by stacking frames at fixed temporal **offsets** around a center frame.
2. **Model inference**

   * Uses a trained **EfficientNet (timm)** with custom `in_chans = len(OFFSETS)`; loads weights and a `class_to_idx` map.
   * Iterates test frames (optionally **subsampled** by `INFER_FRAME_STEP`), batches clips, and produces per-class probabilities.
3. **Post-process**

   * Threshold probabilities → binary sequences → merge consecutive frames into bouts.
   * Enforce **whitelists** from `behaviors_labeled` (only allowed agent/target/action triplets per video).
   * Drop invalid bouts and resolve overlaps consistently.
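
The thresholding and bout-merging step listed above can be sketched on a toy probability trace. This is a minimal illustration rather than the notebook's exact routine: the helper name `probs_to_bouts` and the sample values are made up for the example, and it assumes a single class and consecutive frame indices.

```python
import numpy as np

def probs_to_bouts(prob, threshold=0.5, min_len=3):
    """Return inclusive (start, stop) index pairs for runs where prob > threshold
    that last at least min_len frames. Hypothetical helper for illustration."""
    active = (np.asarray(prob) > threshold).astype(np.int8)
    # Pad with zeros so runs touching either end still produce a rising/falling edge.
    edges = np.diff(np.concatenate(([0], active, [0])))
    starts = np.where(edges == 1)[0]
    stops = np.where(edges == -1)[0] - 1  # inclusive last index of each run
    return [(int(s), int(e)) for s, e in zip(starts, stops) if e - s + 1 >= min_len]

# Toy trace: a 2-frame blip (dropped) followed by a 4-frame bout (kept).
toy = [0.1, 0.9, 0.8, 0.2, 0.7, 0.9, 0.95, 0.6, 0.1]
bouts = probs_to_bouts(toy, threshold=0.5, min_len=3)
print(bouts)
```

Padding the binary sequence before `np.diff` is what lets a bout that starts on the first frame or ends on the last frame be detected without special cases.
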

> 🧭 Why this approach?  
> Many baselines convert pose into **tabular/hand-crafted features** for sequence models. Here we take the opposite route: **turn pose into images**, then apply a **2D CNN** over a **2.5D clip** (a stack of offset frames, one per input channel). This keeps preprocessing simple and leverages CNN priors.
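
To make the "pose to image, then stack offsets" idea concrete, here is a small NumPy-only sketch. It is an illustration under simplifying assumptions, not the code used for inference: `rasterize`, `build_clip`, the 8×8 canvas, and the toy keypoints are all hypothetical, and coordinates are assumed to be in pixel units already.

```python
import numpy as np

def rasterize(points, img_size=8):
    """Draw (x, y, gray) keypoints onto a blank single-channel image."""
    img = np.zeros((img_size, img_size), dtype=np.float32)
    for x, y, gray in points:
        xi = int(np.clip(x, 0, img_size - 1))
        yi = int(np.clip(y, 0, img_size - 1))
        img[yi, xi] = gray
    return img

def build_clip(frames, center, offsets, img_size=8):
    """Stack one rasterized image per temporal offset.
    Frames missing from `frames` contribute an all-zero channel."""
    return np.stack([rasterize(frames.get(center + off, []), img_size) for off in offsets])

# Toy data: a single keypoint moving one pixel to the right per frame (frames 0..4).
frames = {t: [(t, 3, 0.9)] for t in range(5)}
clip = build_clip(frames, center=2, offsets=[-2, 0, 2, 4])
print(clip.shape)  # channels index time offsets, not colors
```

Because each channel is a different moment in time, a plain 2D CNN can pick up motion cues from how the keypoint pattern shifts across channels.
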

## Related tools & notebooks

- **Visualizing Mouse Pose Track**: quick look at keypoint tracks and how they rasterize  
  https://www.kaggle.com/code/kmatsu01/visualizing-mouse-pose-track

In [None]:

# ===================================================================
# 0. Libraries
# ===================================================================
import os, gc, re, json, warnings
from pathlib import Path
import pandas as pd
import numpy as np
from PIL import Image
from tqdm.auto import tqdm

import torch
import torch.nn as nn
import torchvision.transforms as T
import timm

warnings.filterwarnings('ignore')

# ===================================================================
# 1. Configuration
# ===================================================================
class CFG:
    # Base dataset path (Kaggle competition input)
    KAGGLE_INPUT_DIR = Path("/kaggle/input/MABe-mouse-behavior-detection")

    # <<< Replace with your uploaded weights and class map >>>
    MODEL_PATH = "/kaggle/input/2-5d-cnn-for-mabe/model_fold0_best.pth"
    CLASS_MAP_PATH = "/kaggle/input/2-5d-cnn-for-mabe/class_to_idx.json"

    # Backbone and input settings (must match training)
    MODEL_NAME = "efficientnet_b0"
    IMG_SIZE = 128
    BATCH_SIZE = 256
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

    # Post-processing
    # PREDICTION_THRESHOLD = 0.3
    PREDICTION_THRESHOLD = 0.05
    MIN_BOUT_LENGTH = 3

    # Frame subsampling for speed (e.g., use every 5th frame)
    # INFER_FRAME_STEP = 5
    INFER_FRAME_STEP = 2

    # Temporal offsets in frames (one 2.5D channel each); the count must match the trained model
    OFFSETS = [-25, -15, -10, -5, -2, 0, 2, 5, 10, 15, 25]

print(f"Using device: {CFG.DEVICE}")
if torch.cuda.is_available():
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")

# ===================================================================
# 2. Utilities: pose cleaning and rasterization
# ===================================================================
def drop_body_parts(df: pd.DataFrame) -> pd.DataFrame:
    """Remove head-mounted parts and strip '_left'/'_right' so paired parts share one name."""
    # .copy() avoids assigning into a view of the caller's DataFrame
    df = df[~df['bodypart'].str.contains('headpiece', na=False)].copy()
    df['bodypart'] = (
        df['bodypart']
        .str.replace('_left',  '', regex=False)
        .str.replace('_right', '', regex=False)
    )
    return df

# Ordered coarse body parts (for a simple grayscale mapping)
BODYPART_ORDER = ['nose','ear','neck','lateral','body_center','hip','tail_base','tail_midpoint','tail_tip']
_bodypart_pos_map = {p: i for i, p in enumerate(BODYPART_ORDER)}

def _gray(mouse_id: int, bodypart: str) -> tuple:
    """Legacy grayscale mapper used by the fallback PIL path."""
    base = {1: 0.9, 2: 0.7, 3: 0.5, 4: 0.3}.get(mouse_id, 0.1)
    rng = 0.1
    pos = next((_bodypart_pos_map[p] for p in BODYPART_ORDER if bodypart.startswith(p)), -1)
    ratio = 0.5 if pos == -1 else pos / (len(BODYPART_ORDER) - 1)
    v = np.clip(base - ratio * rng, 0.0, 1.0)
    return (v, v, v)

def keypoints_to_image_numpy(frame_df: pd.DataFrame, width: int, height: int, img_size: int) -> Image.Image:
    """Fallback (slower) single-frame rasterization via PIL; returns 1-channel image."""
    img = np.zeros((img_size, img_size), dtype=np.uint8)
    if frame_df is not None and len(frame_df) > 0:
        width  = width  if (width  is not None and width  > 0) else 1
        height = height if (height is not None and height > 0) else 1
        x = (frame_df['x'].values * (img_size / width)).astype(int)
        y = (frame_df['y'].values * (img_size / height)).astype(int)
        x = np.clip(x, 0, img_size - 1)
        y = np.clip(y, 0, img_size - 1)
        vals = [int(_gray(int(r['mouse_id']), r['bodypart'])[0] * 255) for _, r in frame_df.iterrows()]
        img[y, x] = vals
    return Image.fromarray(img, mode='L')

# ===== Fast path used in inference (matches training approach) =====
def _precompute_tracking_buffers(df: pd.DataFrame):
    """
    Precompute arrays for fast per-frame rasterization from a cleaned tracking DataFrame:
      - grayscale per row (uint8) from mouse-id and body-part position,
      - a mapping: video_frame -> row indices for that frame.
    """
    key = df['bodypart'].str.extract(
        r'^(nose|ear|neck|lateral|body_center|hip|tail_base|tail_midpoint|tail_tip)'
    )[0]
    pos = key.map(_bodypart_pos_map).fillna(-1).astype(np.int16).to_numpy()

    base = df['mouse_id'].map({1: 0.9, 2: 0.7, 3: 0.5, 4: 0.3}).fillna(0.1).astype(np.float32).to_numpy()
    ratio = np.where(pos >= 0, pos / (len(BODYPART_ORDER) - 1), 0.5).astype(np.float32)
    gray = np.clip(base - 0.1 * ratio, 0.0, 1.0)
    gray_u8 = (gray * 255.0).astype(np.uint8)

    groups = df.groupby('video_frame', sort=True).indices
    frame_to_rows = {int(k): v for k, v in groups.items()}

    return {
        'x': df['x'].to_numpy(np.float32),
        'y': df['y'].to_numpy(np.float32),
        'gray': gray_u8,
        'frame_to_rows': frame_to_rows
    }

def _rasterize_frame_tensor(buf: dict, frame: int, width: int, height: int, img_size: int) -> torch.Tensor:
    """
    Rasterize a single frame directly into a torch tensor (1, H, W) in [0, 1].
    Avoids PIL for speed; uses precomputed buffers.
    """
    img = np.zeros((img_size, img_size), dtype=np.uint8)
    idx = buf['frame_to_rows'].get(int(frame), None)
    if idx is not None and len(idx) > 0:
        sx = img_size / max(int(width), 1)
        sy = img_size / max(int(height), 1)
        x = (buf['x'][idx] * sx).astype(np.int32)
        y = (buf['y'][idx] * sy).astype(np.int32)
        x = np.clip(x, 0, img_size - 1)
        y = np.clip(y, 0, img_size - 1)
        img[y, x] = buf['gray'][idx]
    return torch.from_numpy(img).unsqueeze(0).to(dtype=torch.float32).div_(255.0)

def build_2p5d_clip(
    tracking_df_indexed,
    center_frame_idx: int,
    width: int,
    height: int,
    img_size: int,
    offsets: list = CFG.OFFSETS
) -> torch.Tensor:
    """
    Build a 2.5D clip by stacking frames at specified temporal offsets around a center frame.
    Accepts either:
      - dict buffer from _precompute_tracking_buffers (fast path), or
      - a DataFrame indexed by 'video_frame' (fallback path).
    Returns: torch.FloatTensor [C=len(offsets), H, W] with values in [0, 1].
    """
    imgs = []

    # Fast path (dict buffer)
    if isinstance(tracking_df_indexed, dict) and 'frame_to_rows' in tracking_df_indexed:
        buf = tracking_df_indexed
        for off in offsets:
            f = center_frame_idx + off
            t = _rasterize_frame_tensor(buf, f, width, height, img_size)  # (1, H, W)
            imgs.append(t)
        return torch.cat(imgs, dim=0)

    # Fallback path (DataFrame + PIL)
    for off in offsets:
        f = center_frame_idx + off
        if f in tracking_df_indexed.index:
            frame_df = tracking_df_indexed.loc[f]
            if isinstance(frame_df, pd.Series):
                frame_df = frame_df.to_frame().T
        else:
            frame_df = None  # empty frame
        pil = keypoints_to_image_numpy(frame_df, width, height, img_size)
        t = T.ToTensor()(pil)  # (1, H, W) in [0, 1]
        imgs.append(t)
    return torch.cat(imgs, dim=0)

# ===================================================================
# 3. Model: 2.5D EfficientNet
# ===================================================================
class MABeEfficientNet2p5D(nn.Module):
    """
    EfficientNet backbone with a custom input channel count (= len(OFFSETS)).
    The classifier head is replaced to match the number of classes.
    """
    def __init__(self, model_name, n_classes, pretrained=False, in_chans: int = None):
        super().__init__()
        if in_chans is None:
            in_chans = len(CFG.OFFSETS)
        m = timm.create_model(model_name, pretrained=pretrained, in_chans=in_chans)
        in_features = m.classifier.in_features
        m.classifier = nn.Linear(in_features, n_classes)
        self.model = m

    def forward(self, x):
        return self.model(x)

# ===================================================================
# 4. Label decoding and whitelist helpers
# ===================================================================
_PATTERNS = [
    re.compile(r'^(\d+)[_\-\.](\d+)[_\-\.](.+)$'),
    re.compile(r'^(\d+)[_\-\.]self[_\-\.](.+)$'),
    re.compile(r'^mouse(\d+)[_\-\.]mouse(\d+)[_\-\.](.+)$'),
    re.compile(r'^mouse(\d+)[_\-\.]self[_\-\.](.+)$'),
]

def decode_class_label(label: str):
    """
    Convert a class label into an (agent_id, target_id, action) triple, with mouse IDs in
    'mouse{n}' form and 'self' as the target for self-directed actions.
    Accepts a few common label patterns used in training; returns None if none match.
    """
    s = label.strip().lower()
    for pat in _PATTERNS:
        m = pat.match(s)
        if m:
            if len(m.groups()) == 3:
                a, t, act = m.groups()
                return f"mouse{int(a)}", f"mouse{int(t)}", act
            elif len(m.groups()) == 2:
                a, act = m.groups()
                return f"mouse{int(a)}", "self", act
    return None

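def parse_behaviors_whitelist(behaviors_labeled_str: str):
    """
    Parse test.csv 'behaviors_labeled' into a set of lowercase (agent, target, action) triplets
    that are allowed for the video. Missing or unparseable input yields an empty set.
    """
    import ast  # local import; only needed for the Python-literal fallback below

    allow = set()
    if not isinstance(behaviors_labeled_str, str):
        return allow  # e.g. NaN when the metadata cell is empty
    s = behaviors_labeled_str.strip()
    try:
        items = json.loads(s)
    except Exception:
        if s.startswith('[') and s.endswith(']'):
            # Not valid JSON; try a Python-style list (e.g. single-quoted strings)
            try:
                items = ast.literal_eval(s)
            except Exception:
                items = []
        else:
            items = [x.strip() for x in s.split(',') if x.strip()]
    if not isinstance(items, (list, tuple)):
        items = [items]
    for it in items:
        it = str(it).replace("'", "").strip().lower()
        parts = [p.strip() for p in it.split(',')]
        if len(parts) == 3:
            agent, target, action = parts
            allow.add((agent, target, action))
    return allow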

# ===================================================================
# 5. Post-processing: thresholding, boutization, and de-duplication
# ===================================================================
def post_process_and_submit(
    predictions: np.ndarray,
    video_id: int,
    idx_to_class: dict,
    threshold: float,
    min_len: int,
    whitelist: set,
    present_mice_nums: list,
    video_max_frame: int,
    frame_ids: np.ndarray
) -> pd.DataFrame:
    """
    Convert per-frame probabilities into valid action bouts:
      1) Decode class -> (agent, target, action)
      2) Mask by existing mice and whitelist
      3) Threshold and merge consecutive 1's
      4) Resolve overlaps by keeping the earliest interval
    Returns a DataFrame with (video_id, agent_id, target_id, action, start_frame, stop_frame).
    """
    def mouse_exists(agent_id: str) -> bool:
        if not agent_id.startswith('mouse'):
            return False
        try:
            m = int(agent_id[5:])
        except Exception:
            return False
        return m in present_mice_nums

    n_frames, n_classes = predictions.shape
    bouts = []

    for ci in range(n_classes):
        label = idx_to_class[ci]
        decoded = decode_class_label(label)
        if decoded is None:
            continue
        agent_id, target_id, action = decoded

        # Mouse presence checks
        if not mouse_exists(agent_id):
            continue
        if target_id != 'self' and not mouse_exists(target_id):
            continue

        # Enforce per-video whitelist
        if (agent_id.lower(), target_id.lower(), action.lower()) not in whitelist:
            continue

        prob = predictions[:, ci]
        bin_ = (prob > threshold).astype(np.int8)
        if bin_.sum() < min_len:
            continue

        diff = np.diff(np.concatenate(([0], bin_, [0])))
        starts = np.where(diff == 1)[0]
        stops  = np.where(diff == -1)[0] - 1

        for s_idx, e_idx in zip(starts, stops):
            if e_idx - s_idx + 1 >= min_len:
                s_frame = int(frame_ids[s_idx])
                e_frame = int(frame_ids[e_idx])
                s_frame = max(0, s_frame)
                e_frame = min(int(video_max_frame), e_frame)
                if e_frame > s_frame:
                    bouts.append({
                        'video_id': int(video_id),
                        'agent_id': agent_id,
                        'target_id': target_id,
                        'action': action,
                        'start_frame': s_frame,
                        'stop_frame': e_frame
                    })

    if not bouts:
        return pd.DataFrame(columns=['video_id','agent_id','target_id','action','start_frame','stop_frame'])

    # Within (agent, target): keep earliest interval if overlaps occur
    df = pd.DataFrame(bouts).sort_values(
        ['video_id','agent_id','target_id','start_frame','stop_frame','action']
    )
    cleaned_groups = []
    for (vid, ag, tg), g in df.groupby(['video_id','agent_id','target_id'], sort=False):
        g = g.sort_values(['start_frame','stop_frame']).reset_index(drop=True)
        keep_rows = []
        last_stop = -1
        for r in g.itertuples(index=False):
            if r.start_frame >= last_stop:
                keep_rows.append(r)
                last_stop = r.stop_frame
            else:
                # If overlapped, drop the later one (mirrors sample behavior)
                pass
        if keep_rows:
            cleaned_groups.append(pd.DataFrame(keep_rows))
    if cleaned_groups:
        out = pd.concat(cleaned_groups, ignore_index=True)
        out.columns = ['video_id','agent_id','target_id','action','start_frame','stop_frame']
    else:
        out = pd.DataFrame(columns=['video_id','agent_id','target_id','action','start_frame','stop_frame'])

    out = out.dropna()
    out = out[(out['stop_frame'] > out['start_frame'])].copy()
    return out

# ===================================================================
# 6. Inference loop
# ===================================================================
print(f"Starting inference on {CFG.DEVICE}...")

test_meta_df = pd.read_csv(CFG.KAGGLE_INPUT_DIR / "test.csv")

def present_mice_nums_from_row(row: pd.Series):
    """Infer which mouse indices (1..4) exist in the video from metadata."""
    nums = []
    for i in [1, 2, 3, 4]:
        if not pd.isna(row.get(f'mouse{i}_strain', np.nan)):
            nums.append(i)
    return nums

# Load class map (class -> index), then invert
with open(CFG.CLASS_MAP_PATH, 'r') as f:
    class_to_idx = json.load(f)
idx_to_class = {v: k for k, v in class_to_idx.items()}
n_classes = len(idx_to_class)

# Build model and load weights
model = MABeEfficientNet2p5D(
    CFG.MODEL_NAME, n_classes=n_classes, pretrained=False, in_chans=len(CFG.OFFSETS)
).to(CFG.DEVICE)
model.load_state_dict(torch.load(CFG.MODEL_PATH, map_location=CFG.DEVICE))
model.eval()

# Normalization (must match training)
norm_transform = T.Normalize(mean=[0.5] * len(CFG.OFFSETS), std=[0.5] * len(CFG.OFFSETS))

submissions = []

for _, row in tqdm(test_meta_df.iterrows(), total=len(test_meta_df), desc="Inferencing test videos"):
    video_id = int(row['video_id'])
    lab_id   = str(row['lab_id'])

    # Tracking file path (some datasets may omit the lab subfolder)
    fpath = CFG.KAGGLE_INPUT_DIR / "test_tracking" / lab_id / f"{video_id}.parquet"
    if not fpath.exists():
        fpath = CFG.KAGGLE_INPUT_DIR / "test_tracking" / f"{video_id}.parquet"
    if not fpath.exists():
        print(f"[WARN] file not found: {video_id}")
        continue

    try:
        tracking_df = pd.read_parquet(fpath)
        tracking_df = drop_body_parts(tracking_df)

        # Precompute fast buffers for rasterization
        tracking_buf = _precompute_tracking_buffers(tracking_df)

        # Video frame size in pixels (fallback to 1 to avoid division by zero)
        width_val  = row.get('video_width_pix',  np.nan)
        height_val = row.get('video_height_pix', np.nan)
        width  = int(width_val)  if pd.notna(width_val)  and width_val  else 1
        height = int(height_val) if pd.notna(height_val) and height_val else 1

        # Frame IDs present in the file, optionally subsampled for speed
        all_frames_full = np.array(sorted(tracking_buf['frame_to_rows'].keys()))
        if len(all_frames_full) == 0:
            continue
        frame_ids = all_frames_full[::CFG.INFER_FRAME_STEP] if CFG.INFER_FRAME_STEP > 1 else all_frames_full
        if len(frame_ids) == 0:
            frame_ids = all_frames_full[[-1]]

        # For clipping bounds and final validation
        video_max_frame = int(all_frames_full.max())

        whitelist = parse_behaviors_whitelist(row['behaviors_labeled'])
        present_mice_nums = present_mice_nums_from_row(row)

        # Run model
        preds_list = []
        with torch.no_grad():
            for i in range(0, len(frame_ids), CFG.BATCH_SIZE):
                frames_batch = frame_ids[i:i + CFG.BATCH_SIZE]
                clips = []
                for fr in frames_batch:
                    clip = build_2p5d_clip(
                        tracking_df_indexed=tracking_buf,
                        center_frame_idx=int(fr),
                        width=width, height=height,
                        img_size=CFG.IMG_SIZE,
                        offsets=CFG.OFFSETS
                    )  # (C, H, W) in [0, 1]
                    clips.append(clip)
                if not clips:
                    continue
                x = torch.stack(clips).to(CFG.DEVICE, non_blocking=True)  # (B, C, H, W)
                x = norm_transform(x)
                logits = model(x)
                probs = torch.sigmoid(logits).cpu().numpy()
                preds_list.append(probs)

        # Build per-video submission rows
        if len(preds_list) == 0:
            video_df = pd.DataFrame(columns=['video_id','agent_id','target_id','action','start_frame','stop_frame'])
        else:
            preds = np.concatenate(preds_list, axis=0)  # (len(frame_ids), n_classes)
            video_df = post_process_and_submit(
                preds, video_id, idx_to_class,
                CFG.PREDICTION_THRESHOLD, CFG.MIN_BOUT_LENGTH,
                whitelist, present_mice_nums, video_max_frame,
                frame_ids=frame_ids
            )

        if not video_df.empty:
            submissions.append(video_df)

        # Cleanup
        del tracking_df, tracking_buf, preds_list
        gc.collect()

    except Exception as e:
        print(f"[ERROR] {video_id} :: {e}")

# ===================================================================
# 7. Save submission.csv (schema strictly matches sample)
# ===================================================================
if len(submissions) > 0:
    submission_df = pd.concat(submissions, ignore_index=True)
else:
    submission_df = pd.DataFrame(columns=['video_id','agent_id','target_id','action','start_frame','stop_frame'])

# Column order and dtypes
submission_df = submission_df[['video_id','agent_id','target_id','action','start_frame','stop_frame']].copy()
submission_df['video_id'] = submission_df['video_id'].astype(int)
submission_df['agent_id'] = submission_df['agent_id'].astype(str)
submission_df['target_id'] = submission_df['target_id'].astype(str)
submission_df['action'] = submission_df['action'].astype(str)
submission_df['start_frame'] = submission_df['start_frame'].astype(int)
submission_df['stop_frame']  = submission_df['stop_frame'].astype(int)

# Keep only valid intervals (start < stop)
submission_df = submission_df[submission_df['stop_frame'] > submission_df['start_frame']].reset_index(drop=True)

# Write with index named 'row_id' (as in the sample)
submission_df.index.name = 'row_id'
submission_df.to_csv('submission.csv')

print("\nsubmission.csv created successfully")
print("Shape (rows, cols):", submission_df.shape)
print(submission_df.head())
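
As a final sanity check before submitting, the schema rules applied in section 7 can be verified on any DataFrame. The following is a minimal sketch, assuming the six-column layout used above; `check_submission`, `EXPECTED_COLS`, and the toy rows are hypothetical and not part of the competition tooling.

```python
import pandas as pd

EXPECTED_COLS = ['video_id', 'agent_id', 'target_id', 'action', 'start_frame', 'stop_frame']

def check_submission(df: pd.DataFrame) -> list:
    """Return a list of human-readable problems; an empty list means the checks passed."""
    problems = []
    if list(df.columns) != EXPECTED_COLS:
        problems.append(f"unexpected columns: {list(df.columns)}")
        return problems  # the remaining checks rely on these columns
    if (df['stop_frame'] <= df['start_frame']).any():
        problems.append("some bouts have stop_frame <= start_frame")
    if df.duplicated().any():
        problems.append("duplicate rows present")
    return problems

# Toy rows: the second bout has zero length and should be flagged.
toy = pd.DataFrame(
    [[1, 'mouse1', 'mouse2', 'sniff', 10, 25],
     [1, 'mouse1', 'mouse2', 'sniff', 40, 40]],
    columns=EXPECTED_COLS,
)
print(check_submission(toy))
```

In the notebook this could be called as `check_submission(submission_df)` right before writing the CSV; it only reports problems and does not modify the data.
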