# MABe Mouse Behavior - Optimized Complete Solution

## Full Pipeline:
```
Phase 1: Lightweight Transformer
  ├─ Load a random sample of up to 120 labeled videos
  ├─ Extract sequences (aggressive subsampling)
  ├─ Train small transformer (8 epochs)
  └─ Save checkpoint

Phase 2: Feature Extraction
  ├─ Extract transformer embeddings
  ├─ Compute essential handcrafted features
  └─ Cache combined features

Phase 3: XGBoost Training
  ├─ Load features
  ├─ Train per action
  └─ Fast threshold optimization

Phase 4: Test Inference
  ├─ Extract test features
  ├─ Predict
  └─ Create submission
```

In [None]:
import numpy as np
import pandas as pd
import json, gc, warnings, itertools, pickle, os, time
from pathlib import Path
from tqdm.notebook import tqdm

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

from sklearn.model_selection import StratifiedGroupKFold
from sklearn.metrics import f1_score
from xgboost import XGBClassifier
import optuna

# Quiet down third-party warnings and Optuna's per-trial logging.
warnings.filterwarnings('ignore')
optuna.logging.set_verbosity(optuna.logging.WARNING)

# Select the compute device once; later cells move models/tensors to `device`.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Device: {device}")
if device.type == 'cuda':
    print(f"GPU: {torch.cuda.get_device_name(0)}")

In [None]:
# Resolve the dataset root: on Kaggle, take the first input directory whose
# path contains "mabe" (case-insensitive); otherwise use the local data folder.
if not os.path.exists('/kaggle/input'):
    DATA_DIR = "data/mabe-mouse-behavior-detection"
else:
    input_dirs = [str(entry) for entry in Path('/kaggle/input').iterdir() if entry.is_dir()]
    mabe_dirs = [path for path in input_dirs if 'mabe' in path.lower()]
    DATA_DIR = mabe_dirs[0] if mabe_dirs else "/kaggle/input/mabe-mouse-behavior-detection"

print(f"Data: {DATA_DIR}")

class Config:
    # Central table of paths and hyperparameters read by the cells below.
    # All paths are derived from the module-level DATA_DIR at class-definition time.
    TRAIN_PATH = f"{DATA_DIR}/train.csv"
    TEST_PATH = f"{DATA_DIR}/test.csv"
    TRAIN_ANNOTATION_DIR = f"{DATA_DIR}/train_annotation"
    TRAIN_TRACKING_DIR = f"{DATA_DIR}/train_tracking"
    TEST_TRACKING_DIR = f"{DATA_DIR}/test_tracking"
    
    # Checkpoints
    TRANSFORMER_PATH = "transformer.pth"
    FEATURES_PATH = "features.pkl"
    
    # Phase 1: Transformer
    # NOTE(review): the model constructions visible in this file pass
    # d_model=128, n_heads=8, n_layers=3 and dropout=0.1 as literals, so
    # EMBED_DIM / N_HEADS / N_LAYERS / DROPOUT are not what the model uses.
    SEQ_LEN = 30
    EMBED_DIM = 64
    N_HEADS = 4
    N_LAYERS = 2
    DROPOUT = 0.1
    TRANSFORMER_EPOCHS = 8  # Increased from 5 for better training
    BATCH_SIZE = 256
    LR = 1e-3
    MAX_TRAIN_VIDEOS = 120  # Subset for transformer
    
    # Feature engineering
    FPS_DEFAULT = 30.0  # Default frames per second
    
    # Phase 2: Feature extraction
    MAX_FRAMES = 6000  # Per video; longer videos are subsampled by STRIDE
    STRIDE = 2
    KEY_PARTS = ['nose', 'body_center', 'tail_base', 'ear_left', 'ear_right']
    N_FEATURES = len(KEY_PARTS) * 2  # one x and one y channel per key part
    
    # Phase 3: XGBoost
    # (consumed outside the portion of the file reviewed here)
    N_SPLITS = 2
    XGB_PARAMS = {
        'n_estimators': 150,
        'learning_rate': 0.1,
        'max_depth': 5,
        'subsample': 0.8,
        'colsample_bytree': 0.7,
        'tree_method': 'hist',
        'random_state': 42,
        'n_jobs': 4,
        'verbosity': 0
    }
    
    # Body parts removed from a tracking file when it contains more than
    # five distinct body parts.
    DROP_PARTS = ['headpiece_bottombackleft', 'headpiece_bottombackright',
                  'headpiece_bottomfrontleft', 'headpiece_bottomfrontright',
                  'headpiece_topbackleft', 'headpiece_topbackright',
                  'headpiece_topfrontleft', 'headpiece_topfrontright',
                  'spine_1', 'spine_2', 'tail_middle_1', 'tail_middle_2', 'tail_midpoint']

# Single shared instance; every cell reads settings through `cfg`.
cfg = Config()

In [None]:
# Read the train/test metadata tables and keep only training rows whose
# lab_id does not start with "MABe22".
train, test = (pd.read_csv(path) for path in (cfg.TRAIN_PATH, cfg.TEST_PATH))
is_mabe22 = train['lab_id'].str.startswith('MABe22')
train_labeled = train[~is_mabe22].reset_index(drop=True)

print(f"Labeled: {len(train_labeled)}, Test: {len(test)}")

# Phase 1: Transformer

In [None]:
# Time Series Transformer (based on "A Transformer-based Framework for Multivariate Time Series")
# Reference: Zerveas et al. 2020 - https://arxiv.org/abs/2010.02803

class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position information to a sequence, then apply dropout.

    Args:
        d_model: Size of the last (feature) dimension of the input. Both even
            and odd values are supported.
        max_len: Number of positions for which the table is precomputed;
            inputs must not be longer than this.
        dropout: Dropout probability applied after the encoding is added.
    """

    def __init__(self, d_model, max_len=5000, dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model))
        angles = position * div_term  # [max_len, ceil(d_model / 2)]

        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one fewer odd (cosine) column than even
        # (sine) column, so trim the angles to fit instead of failing on a
        # shape mismatch.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])

        # Stored as a buffer: moves with .to(device) and is saved in the
        # state_dict, but is not a trainable parameter.
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        """Return ``dropout(x + pe)`` for ``x`` of shape [batch, seq_len, d_model]."""
        x = x + self.pe[:, :x.size(1), :]
        return self.dropout(x)

class TSTransformerEncoder(nn.Module):
    """
    Transformer encoder over multivariate time-series windows.

    Serves two purposes depending on how ``forward`` is called: a binary
    classifier that scores the middle time step of each window, or an
    embedding extractor that pools the encoded window over time.
    Architecture follows "A Transformer-based Framework for Multivariate
    Time Series Representation Learning".
    """

    def __init__(self, n_features, d_model=128, n_heads=8, n_layers=3,
                 dim_feedforward=256, dropout=0.1, max_len=1000):
        super().__init__()
        self.d_model = d_model

        # Input projection and positional information.
        self.project_inp = nn.Linear(n_features, d_model)
        self.pos_enc = PositionalEncoding(d_model, max_len=max_len, dropout=dropout)

        # Stack of pre-LayerNorm, batch-first encoder layers with GELU.
        layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=n_heads,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            activation='gelu',
            batch_first=True,
            norm_first=True,
        )
        self.transformer_encoder = nn.TransformerEncoder(layer, num_layers=n_layers)

        # Classification head (one logit per window).
        self.output_layer = nn.Linear(d_model, 1)

        # Temporal pooling operators used for the embedding output.
        self.gap = nn.AdaptiveAvgPool1d(1)
        self.gmp = nn.AdaptiveMaxPool1d(1)

        self._init_weights()

    def _init_weights(self):
        """Xavier-uniform initialise every parameter with more than one dimension."""
        for param in self.parameters():
            if param.dim() <= 1:
                continue
            nn.init.xavier_uniform_(param)

    def forward(self, x, return_embedding=False):
        """
        Encode a batch of windows.

        Args:
            x: [batch_size, seq_len, n_features]
            return_embedding: when True, return pooled embeddings, otherwise logits
        Returns:
            return_embedding=True:  [batch_size, d_model * 2], average-pooled
                features followed by max-pooled features
            return_embedding=False: [batch_size] logits taken from the middle step
        """
        encoded = self.transformer_encoder(self.pos_enc(self.project_inp(x)))

        if not return_embedding:
            centre = encoded[:, encoded.size(1) // 2, :]  # [batch, d_model]
            return self.output_layer(centre).squeeze(-1)

        # The 1d pooling layers reduce the last axis, so put time last.
        channels_first = encoded.transpose(1, 2)  # [batch, d_model, seq_len]
        pooled = [pool(channels_first).squeeze(-1) for pool in (self.gap, self.gmp)]
        return torch.cat(pooled, dim=1)

# Build the Phase 1 network and place it on the selected device.
model = TSTransformerEncoder(
    n_features=cfg.N_FEATURES, d_model=128, n_heads=8, n_layers=3,
    dim_feedforward=512, dropout=0.1, max_len=cfg.SEQ_LEN * 2,
)
model = model.to(device)

n_params = sum(p.numel() for p in model.parameters())
print(f"Model: TSTransformerEncoder")
print(f"Params: {n_params:,}")

In [None]:
def extract_pose(data, parts):
    """Build a per-column standardised (n_frames, 2 * len(parts)) float32 pose matrix.

    For every name in ``parts`` the x and y tracks are taken from ``data``;
    parts absent from ``data.columns`` contribute two all-zero columns.
    Non-finite values are replaced by 0 before each column is shifted to zero
    mean and divided by its standard deviation (plus 1e-6).
    """
    n_rows = len(data)
    columns = []
    for part in parts:
        if part in data.columns:
            columns += [data[part]['x'].values, data[part]['y'].values]
        else:
            columns += [np.zeros(n_rows), np.zeros(n_rows)]

    pose = np.stack(columns, axis=1).astype(np.float32)
    pose = np.nan_to_num(pose, nan=0.0, posinf=0.0, neginf=0.0)

    col_mean = pose.mean(axis=0, keepdims=True)
    col_std = pose.std(axis=0, keepdims=True) + 1e-6
    return (pose - col_mean) / col_std

def make_seqs(feat, lab, slen, stride):
    """Cut fixed-length sliding windows out of a feature array.

    Args:
        feat: Array indexed by frame along axis 0.
        lab: Per-frame labels aligned with ``feat``.
        slen: Window length in frames.
        stride: Step between consecutive window starts.

    Returns:
        ``(seqs, tgts)``: the stacked windows and, for each window, the label
        of its middle frame (``start + slen // 2``). Both are empty arrays
        when ``feat`` is shorter than ``slen``.
    """
    half = slen // 2
    seqs, tgts = [], []
    # ``+ 1`` so the final full window (starting at len(feat) - slen) is kept;
    # this also matches the windowing used when embeddings are extracted.
    for start in range(0, len(feat) - slen + 1, stride):
        seqs.append(feat[start:start + slen])
        tgts.append(lab[start + half])
    return np.array(seqs), np.array(tgts)

class SeqData(Dataset):
    """Dataset pairing each window with its label, both held as float32 tensors."""

    def __init__(self, seqs, labs):
        self.seqs, self.labs = torch.FloatTensor(seqs), torch.FloatTensor(labs)

    def __len__(self):
        return self.seqs.shape[0]

    def __getitem__(self, i):
        window, target = self.seqs[i], self.labs[i]
        return window, target

In [None]:
# FPS-AWARE FEATURE ENGINEERING UTILITIES

def _scale(n_frames_at_30fps, fps, ref=30.0):
    """Scale window size based on FPS"""
    return max(1, int(round(n_frames_at_30fps * float(fps) / ref)))

def _scale_signed(n_frames_at_30fps, fps, ref=30.0):
    """Scale with sign preservation for offsets"""
    if n_frames_at_30fps == 0:
        return 0
    s = 1 if n_frames_at_30fps > 0 else -1
    mag = max(1, int(round(abs(n_frames_at_30fps) * float(fps) / ref)))
    return s * mag

def add_curvature_features(X, center_x, center_y, fps):
    """Add path-curvature and turn-rate columns to ``X``.

    Args:
        X: Feature frame to extend; modified in place and returned.
        center_x, center_y: Per-frame position Series.
        fps: Frame rate used to rescale the nominal (30 FPS) window lengths.

    Columns written: ``curv_mean_{25,50,75}`` (rolling mean curvature) and
    ``turn_rate_30`` (rolling sum of absolute heading change).
    """
    vel_x = center_x.diff()
    vel_y = center_y.diff()
    acc_x = vel_x.diff()
    acc_y = vel_y.diff()

    # Planar curvature: |v x a| / |v|^3, with a small epsilon for stationary frames.
    cross_prod = vel_x * acc_y - vel_y * acc_x
    vel_mag = np.sqrt(vel_x**2 + vel_y**2)
    curvature = np.abs(cross_prod) / (vel_mag**3 + 1e-6)

    for w in [25, 50, 75]:
        ws = _scale(w, fps)
        X[f'curv_mean_{w}'] = curvature.rolling(ws, min_periods=max(1, ws // 5)).mean()

    angle = np.arctan2(vel_y, vel_x)
    # arctan2 jumps by 2*pi at the +/-pi branch cut; wrap the frame-to-frame
    # difference into [-pi, pi) so a tiny turn across the cut is not counted
    # as an almost full rotation.
    heading_delta = (angle.diff() + np.pi) % (2 * np.pi) - np.pi
    angle_change = np.abs(heading_delta)
    w = 30
    ws = _scale(w, fps)
    X[f'turn_rate_{w}'] = angle_change.rolling(ws, min_periods=max(1, ws // 5)).sum()

    return X

def add_multiscale_features(X, center_x, center_y, fps):
    """Add rolling speed mean/std at several time scales plus a short/long ratio.

    Args:
        X: Feature frame to extend; modified in place and returned.
        center_x, center_y: Per-frame position Series.
        fps: Frame rate; scales both the speed and the window lengths.

    Columns written (always, in this order): ``sp_m{s}``, ``sp_s{s}`` for
    s in 20, 40, 60, 80, then ``sp_ratio``.
    """
    speed = np.sqrt(center_x.diff()**2 + center_y.diff()**2) * float(fps)

    scales = [20, 40, 60, 80]
    for scale in scales:
        ws = _scale(scale, fps)
        # No "series is long enough" guard: rolling() with min_periods copes
        # with clips shorter than the window, and always emitting the column
        # keeps the feature layout identical for every video.
        window = speed.rolling(ws, min_periods=max(1, ws // 4))
        X[f'sp_m{scale}'] = window.mean()
        X[f'sp_s{scale}'] = window.std()

    X['sp_ratio'] = X[f'sp_m{scales[0]}'] / (X[f'sp_m{scales[-1]}'] + 1e-6)

    return X

def add_state_features(X, center_x, center_y, fps):
    """Add speed-state occupancy and state-transition counts.

    The smoothed speed is binned into four states (thresholds 0.5, 2 and 5
    times ``fps``). For each window in 20, 40, 60, 80 the function writes the
    fraction of frames spent in each state (``s{state}_{window}``) and the
    number of state changes (``trans_{window}``).

    Args:
        X: Feature frame to extend; modified in place and returned.
        center_x, center_y: Per-frame position Series.
        fps: Frame rate; scales speed, thresholds and window lengths.

    The same columns are always written, in the same order. If the binning
    cannot be computed (for example the bin edges are not increasing because
    ``fps`` is not positive) every column is filled with 0.0 instead of being
    silently omitted.
    """
    speed = np.sqrt(center_x.diff()**2 + center_y.diff()**2) * float(fps)
    w_ma = _scale(15, fps)
    speed_ma = speed.rolling(w_ma, min_periods=max(1, w_ma // 3)).mean()

    windows = [20, 40, 60, 80]
    states = [0, 1, 2, 3]
    # Fixed output layout: four state columns followed by the transition count, per window.
    names = [name
             for window in windows
             for name in [f's{state}_{window}' for state in states] + [f'trans_{window}']]

    columns = {}
    try:
        bins = [-np.inf, 0.5 * fps, 2.0 * fps, 5.0 * fps, np.inf]
        speed_states = pd.cut(speed_ma, bins=bins, labels=states).astype(float)
        state_changes = (speed_states != speed_states.shift(1)).astype(float)

        for window in windows:
            ws = _scale(window, fps)
            min_periods = max(1, ws // 5)
            for state in states:
                columns[f's{state}_{window}'] = (
                    (speed_states == state).astype(float)
                    .rolling(ws, min_periods=min_periods).mean()
                )
            columns[f'trans_{window}'] = state_changes.rolling(ws, min_periods=min_periods).sum()
    except (ValueError, TypeError):
        # Best effort: keep the column layout stable rather than dropping
        # a partial set of features.
        columns = dict.fromkeys(names, 0.0)

    for name in names:
        X[name] = columns[name]

    return X

def add_longrange_features(X, center_x, center_y, fps):
    """Add long-window position averages and a rolling speed percentile.

    Args:
        X: Feature frame to extend; modified in place and returned.
        center_x, center_y: Per-frame position Series.
        fps: Frame rate; scales speed and window lengths.

    Columns written (always, in this order) for w in 30, 60, 120:
    ``x_ml{w}``/``y_ml{w}`` (rolling means), ``x_e{w}``/``y_e{w}``
    (exponentially weighted means) and ``sp_pct{w}`` (percentile rank of the
    current speed within the window).
    """
    windows = [30, 60, 120]

    def _window_and_min_periods(nominal):
        ws = _scale(nominal, fps)
        # pandas rejects min_periods > window, which max(5, ws // 6) alone
        # would produce for very low frame rates (ws < 5).
        return ws, min(ws, max(5, ws // 6))

    # Columns are emitted even for clips shorter than the window so that the
    # feature layout does not depend on video length.
    for window in windows:
        ws, min_periods = _window_and_min_periods(window)
        X[f'x_ml{window}'] = center_x.rolling(ws, min_periods=min_periods).mean()
        X[f'y_ml{window}'] = center_y.rolling(ws, min_periods=min_periods).mean()

    for span in windows:
        s = _scale(span, fps)
        X[f'x_e{span}'] = center_x.ewm(span=s, min_periods=1).mean()
        X[f'y_e{span}'] = center_y.ewm(span=s, min_periods=1).mean()

    speed = np.sqrt(center_x.diff()**2 + center_y.diff()**2) * float(fps)
    for window in windows:
        ws, min_periods = _window_and_min_periods(window)
        X[f'sp_pct{window}'] = speed.rolling(ws, min_periods=min_periods).rank(pct=True)

    return X

def add_interaction_features(X, mouse_pair, avail_A, avail_B, fps):
    """Append leader/follower, chase and speed-correlation columns for a pair.

    ``mouse_pair`` maps 'A' and 'B' to each animal's tracking frame;
    ``avail_A`` / ``avail_B`` list the body parts each one has. ``X`` is
    returned untouched unless both animals have a 'body_center'.
    """
    if not ('body_center' in avail_A and 'body_center' in avail_B):
        return X

    centre_a = mouse_pair['A']['body_center']
    centre_b = mouse_pair['B']['body_center']

    # Vector from B to A and its length.
    off_x = centre_a['x'] - centre_b['x']
    off_y = centre_a['y'] - centre_b['y']
    separation = np.sqrt(off_x**2 + off_y**2)

    # Frame-to-frame displacement of each animal.
    a_dx, a_dy = centre_a['x'].diff(), centre_a['y'].diff()
    b_dx, b_dy = centre_b['x'].diff(), centre_b['y'].diff()
    a_speed = np.sqrt(a_dx**2 + a_dy**2)
    b_speed = np.sqrt(b_dx**2 + b_dy**2)

    # Cosine-like alignment of each animal's motion with the B->A offset
    # (A) or the A->B offset (B).
    lead_a = (a_dx * off_x + a_dy * off_y) / (a_speed * separation + 1e-6)
    lead_b = (b_dx * (-off_x) + b_dy * (-off_y)) / (b_speed * separation + 1e-6)

    for window in [30, 60]:
        ws = _scale(window, fps)
        min_p = max(1, ws // 6)
        X[f'A_ld{window}'] = lead_a.rolling(ws, min_periods=min_p).mean()
        X[f'B_ld{window}'] = lead_b.rolling(ws, min_periods=min_p).mean()

    # Closing speed weighted by B's alignment.
    closing = -separation.diff()
    chase_signal = closing * lead_b
    w = 30
    ws = _scale(w, fps)
    X[f'chase_{w}'] = chase_signal.rolling(ws, min_periods=max(1, ws // 6)).mean()

    for window in [60, 120]:
        ws = _scale(window, fps)
        X[f'sp_cor{window}'] = a_speed.rolling(ws, min_periods=max(1, ws // 6)).corr(b_speed)

    return X

In [None]:
# ENHANCED FEATURE EXTRACTION FUNCTIONS

def enhanced_hand_feat(data, fps=30.0):
    """
    Compute per-frame handcrafted features for a single animal.

    Args:
        data: Tracking frame for one animal, indexed by frame, whose columns
            are addressed as ``data[part]['x']`` / ``data[part]['y']``.
        fps: Frame rate; every window/lag below is specified at 30 FPS and
            rescaled through ``_scale`` / ``_scale_signed``.

    Returns:
        float32 DataFrame on ``data.index`` with NaNs replaced by 0.

    Body parts missing from ``data`` are substituted by all-zero tracks, so
    every column computed directly in this function is always present. The
    ``add_*`` helpers are called for further columns; the total width is
    constant only if those helpers always emit the same columns.

    Feature groups, in column order:
    - Basic speed (15, 30 frame windows)
    - Rolling position statistics, displacement and activity (5, 15, 30, 60)
    - Curvature and turn rate
    - Multiscale speed analysis
    - State-based features
    - Long-range temporal features
    - Distances between every pair of cfg.KEY_PARTS
    - Body shape (body angle, elongation)
    - Lagged ear displacement
    - Nose-tail features
    - Ear features
    """
    n_frames = len(data)  # NOTE(review): not used anywhere below
    X = pd.DataFrame(index=data.index)
    available_parts = data.columns.get_level_values(0).unique()
    
    # Helper to get part data or zeros
    def get_part(part_name):
        if part_name in available_parts:
            return data[part_name]['x'], data[part_name]['y']
        else:
            return pd.Series(0.0, index=data.index), pd.Series(0.0, index=data.index)

    # Body center features (ALWAYS added, zeros if missing)
    cx, cy = get_part('body_center')
    
    # Basic speed features
    speed = np.sqrt(cx.diff()**2 + cy.diff()**2) * fps
    for w in [15, 30]:
        ws = _scale(w, fps)
        X[f's{w}'] = speed.rolling(ws, min_periods=1, center=True).mean()

    # Rolling statistics
    for w in [5, 15, 30, 60]:
        ws = _scale(w, fps)
        roll_params = dict(min_periods=1, center=True)
        X[f'cx_m{w}'] = cx.rolling(ws, **roll_params).mean()
        X[f'cy_m{w}'] = cy.rolling(ws, **roll_params).mean()
        X[f'cx_s{w}'] = cx.rolling(ws, **roll_params).std()
        X[f'cy_s{w}'] = cy.rolling(ws, **roll_params).std()
        X[f'x_rng{w}'] = cx.rolling(ws, **roll_params).max() - cx.rolling(ws, **roll_params).min()
        X[f'y_rng{w}'] = cy.rolling(ws, **roll_params).max() - cy.rolling(ws, **roll_params).min()

        # Displacement and activity
        # (these use trailing windows, unlike the centered statistics above)
        X[f'disp{w}'] = np.sqrt(
            cx.diff().rolling(ws, min_periods=1).sum()**2 +
            cy.diff().rolling(ws, min_periods=1).sum()**2
        )
        X[f'act{w}'] = np.sqrt(
            cx.diff().rolling(ws, min_periods=1).var() +
            cy.diff().rolling(ws, min_periods=1).var()
        )

    # Advanced features
    X = add_curvature_features(X, cx, cy, fps)
    X = add_multiscale_features(X, cx, cy, fps)
    X = add_state_features(X, cx, cy, fps)
    X = add_longrange_features(X, cx, cy, fps)

    # Distance pairs between ALL body parts (ALWAYS added)
    for p1, p2 in itertools.combinations(cfg.KEY_PARTS, 2):
        p1x, p1y = get_part(p1)
        p2x, p2y = get_part(p2)
        X[f"{p1}+{p2}"] = np.sqrt((p1x - p2x)**2 + (p1y - p2y)**2)

    # Body shape features (ALWAYS added)
    nose_x, nose_y = get_part('nose')
    tail_x, tail_y = get_part('tail_base')
    
    # Cosine of the angle at the body center between the nose and tail-base directions
    v1x = nose_x - cx
    v1y = nose_y - cy
    v2x = tail_x - cx
    v2y = tail_y - cy
    X['body_ang'] = (v1x * v2x + v1y * v2y) / (
        np.sqrt(v1x**2 + v1y**2) * np.sqrt(v2x**2 + v2y**2) + 1e-6
    )

    # Elongation (ALWAYS added): nose-to-tail length relative to ear-to-ear width
    nose_tail_dist = np.sqrt((nose_x - tail_x)**2 + (nose_y - tail_y)**2)
    ear_l_x, ear_l_y = get_part('ear_left')
    ear_r_x, ear_r_y = get_part('ear_right')
    ear_dist = np.sqrt((ear_l_x - ear_r_x)**2 + (ear_l_y - ear_r_y)**2)
    X['elong'] = nose_tail_dist / (ear_dist + 1e-6)

    # Speed with lags (ALWAYS added)
    # sp_lf / sp_rt: distance each ear moved over `lag` frames;
    # sp_lf2 / sp_rt2: distance from each ear to where the tail base was `lag` frames ago.
    lag = _scale(10, fps)
    X['sp_lf'] = np.sqrt((ear_l_x - ear_l_x.shift(lag))**2 + 
                         (ear_l_y - ear_l_y.shift(lag))**2)
    X['sp_rt'] = np.sqrt((ear_r_x - ear_r_x.shift(lag))**2 + 
                         (ear_r_y - ear_r_y.shift(lag))**2)
    X['sp_lf2'] = np.sqrt((ear_l_x - tail_x.shift(lag))**2 + 
                          (ear_l_y - tail_y.shift(lag))**2)
    X['sp_rt2'] = np.sqrt((ear_r_x - tail_x.shift(lag))**2 + 
                          (ear_r_y - tail_y.shift(lag))**2)

    # Nose-tail features with multiple lags (ALWAYS added)
    nt_dist = nose_tail_dist
    for lag in [10, 20, 40]:
        l = _scale(lag, fps)
        X[f'nt_lg{lag}'] = nt_dist.shift(l)
        X[f'nt_df{lag}'] = nt_dist - nt_dist.shift(l)

    # Ear features (ALWAYS added)
    ear_d = ear_dist
    for off in [-30, -20, -10, 10, 20, 30]:
        o = _scale_signed(off, fps)
        # shift(-o): a positive offset reads the value `o` frames ahead
        X[f'ear_o{off}'] = ear_d.shift(-o)
    w = _scale(30, fps)
    # Coefficient of variation of the ear distance over a centered window
    X['ear_con'] = ear_d.rolling(w, min_periods=1, center=True).std() / \
                   (ear_d.rolling(w, min_periods=1, center=True).mean() + 1e-6)

    return X.fillna(0).astype(np.float32)


def extract_pair_features(pvid, agent_id, target_id, fps=30.0):
    """
    Compute per-frame interaction features between an agent and a target animal.

    Returns None when either id is absent from the 'mouse_id' column level of
    ``pvid``; otherwise a float32 DataFrame on ``pvid.index`` with NaNs set to 0.
    Feature groups are only written when the body parts they need exist for
    both animals. In column order:
    - Cross-animal distances for every (agent part, target part) in cfg.KEY_PARTS
    - Relative body orientation (nose/tail_base)
    - Approach/retreat of the noses over a 10-frame lag
    - Body-center distance categories, rolling distance statistics,
      interaction intensity and velocity coordination (5, 15, 30, 60)
    - Leader/follower, chase and speed-correlation features
    - Lagged nose-to-nose distance and closeness persistence
    - Velocity alignment at several temporal offsets and distance consistency
    """
    mouse_ids = pvid.columns.get_level_values('mouse_id')
    if agent_id not in mouse_ids or target_id not in mouse_ids:
        return None

    agent_data = pvid.loc[:, agent_id]
    target_data = pvid.loc[:, target_id]
    X = pd.DataFrame(index=pvid.index)

    def both_have(*parts):
        """True when every listed part exists for the agent and for the target."""
        return all(p in agent_data.columns for p in parts) and \
            all(p in target_data.columns for p in parts)

    def euclid(a, b):
        """Per-frame distance between two parts given as frames with x/y columns."""
        return np.sqrt((a['x'] - b['x'])**2 + (a['y'] - b['y'])**2)

    has_nose = both_have('nose')
    has_center = both_have('body_center')

    # Inter-mouse distances for all body part pairs
    for p1, p2 in itertools.product(cfg.KEY_PARTS, repeat=2):
        if p1 in agent_data.columns and p2 in target_data.columns:
            X[f"12+{p1}+{p2}"] = euclid(agent_data[p1], target_data[p2])

    # Relative orientation: cosine between the two tail->nose body axes
    if both_have('nose', 'tail_base'):
        axis_a = agent_data['nose'] - agent_data['tail_base']
        axis_b = target_data['nose'] - target_data['tail_base']
        dot = axis_a['x'] * axis_b['x'] + axis_a['y'] * axis_b['y']
        norm_a = np.sqrt(axis_a['x']**2 + axis_a['y']**2)
        norm_b = np.sqrt(axis_b['x']**2 + axis_b['y']**2)
        X['rel_ori'] = dot / (norm_a * norm_b + 1e-6)

    # Nose-to-nose distance is shared by the approach and closeness features
    nose_gap = euclid(agent_data['nose'], target_data['nose']) if has_nose else None

    # Approach/retreat
    if has_nose:
        lag = _scale(10, fps)
        X['appr'] = nose_gap.shift(lag) - nose_gap  # Positive = approaching

    # Body-center kinematics shared by the coordination and alignment features
    if has_center:
        a_center = agent_data['body_center']
        b_center = target_data['body_center']
        a_vx, a_vy = a_center['x'].diff(), a_center['y'].diff()
        b_vx, b_vy = b_center['x'].diff(), b_center['y'].diff()

    # Distance statistics and categories
    if has_center:
        center_dist = euclid(a_center, b_center)

        # Distance categories
        X['v_cls'] = (center_dist < 5.0).astype(float)
        X['cls'] = ((center_dist >= 5.0) & (center_dist < 15.0)).astype(float)
        X['med'] = ((center_dist >= 15.0) & (center_dist < 30.0)).astype(float)
        X['far'] = (center_dist >= 30.0).astype(float)

        # Coordination signal (velocity dot product) does not depend on the window
        coord = a_vx * b_vx + a_vy * b_vy

        roll_params = dict(min_periods=1, center=True)
        for w in [5, 15, 30, 60]:
            ws = _scale(w, fps)
            dist_roll = center_dist.rolling(ws, **roll_params)
            coord_roll = coord.rolling(ws, **roll_params)

            # Distance rolling statistics
            X[f'd_m{w}'] = dist_roll.mean()
            X[f'd_s{w}'] = dist_roll.std()
            X[f'd_mn{w}'] = dist_roll.min()
            X[f'd_mx{w}'] = dist_roll.max()

            # Interaction intensity: high when the distance is stable
            X[f'int{w}'] = 1 / (1 + dist_roll.var())

            # Coordination
            X[f'co_m{w}'] = coord_roll.mean()
            X[f'co_s{w}'] = coord_roll.std()

        # Chase and leader/follower features
        X = add_interaction_features(
            X,
            {'A': agent_data, 'B': target_data},
            agent_data.columns.get_level_values(0).unique(),
            target_data.columns.get_level_values(0).unique(),
            fps,
        )

    # Nose-to-nose features
    if has_nose:
        is_close = (nose_gap < 10.0).astype(float)
        for lag in [10, 20, 40]:
            n_lag = _scale(lag, fps)
            lagged = nose_gap.shift(n_lag)
            X[f'nn_lg{lag}'] = lagged
            X[f'nn_ch{lag}'] = nose_gap - lagged
            X[f'cl_ps{lag}'] = is_close.rolling(n_lag, min_periods=1).mean()

    # Velocity alignment at multiple offsets
    if has_center:
        a_speed = np.sqrt(a_vx**2 + a_vy**2)
        b_speed = np.sqrt(b_vx**2 + b_vy**2)
        alignment = (a_vx * b_vx + a_vy * b_vy) / (a_speed * b_speed + 1e-6)

        for off in [-30, -20, -10, 0, 10, 20, 30]:
            shift_by = _scale_signed(off, fps)
            X[f'va_{off}'] = alignment.shift(-shift_by)

        # Coefficient of variation of the squared center distance
        w = _scale(30, fps)
        center_dist_sq = (a_center['x'] - b_center['x'])**2 + \
                         (a_center['y'] - b_center['y'])**2
        sq_roll = center_dist_sq.rolling(w, min_periods=1, center=True)
        X['int_con'] = sq_roll.std() / (sq_roll.mean() + 1e-6)

    return X.fillna(0).astype(np.float32)

In [None]:
# ---------------------------------------------------------------------------
# Phase 1, step 1: build the transformer training set.
# For a random sample of labeled videos, take mouse 1's pose, label each frame
# 1.0 where an annotation with agent_id == target_id == 1 covers it, and cut
# the result into overlapping windows.
# ---------------------------------------------------------------------------
print("\n" + "="*80)
print("PHASE 1: TRANSFORMER TRAINING")
print("="*80)
t1 = time.time()

all_seq, all_lab = [], []
sample = train_labeled.sample(n=min(cfg.MAX_TRAIN_VIDEOS, len(train_labeled)), random_state=42)
n_failed = 0

for _, row in tqdm(sample.iterrows(), total=len(sample), desc="Loading"):
    # Rows without a behaviors_labeled string carry no usable labels.
    if not isinstance(row.behaviors_labeled, str):
        continue
    try:
        vid = pd.read_parquet(f"{cfg.TRAIN_TRACKING_DIR}/{row.lab_id}/{row.video_id}.parquet")
        if len(vid.bodypart.unique()) > 5:
            vid = vid[~vid.bodypart.isin(cfg.DROP_PARTS)]
        # Reshape to columns (mouse_id, bodypart, x|y) and convert pixels to cm.
        pvid = vid.pivot(columns=['mouse_id','bodypart'], index='video_frame', values=['x','y'])
        pvid = pvid.reorder_levels([1,2,0], axis=1).T.sort_index().T / row.pix_per_cm_approx

        annot = pd.read_parquet(f"{cfg.TRAIN_ANNOTATION_DIR}/{row.lab_id}/{row.video_id}.parquet")

        if 1 not in pvid.columns.get_level_values('mouse_id'):
            continue
        md = pvid.loc[:, 1]

        # Decide once, on the ORIGINAL length, whether this video is
        # subsampled. Re-testing len(md) after subsampling is wrong: a video
        # of MAX_FRAMES+1 .. 2*MAX_FRAMES frames is shorter than MAX_FRAMES
        # afterwards, so its annotation frames would not be rescaled.
        downsampled = len(md) > cfg.MAX_FRAMES
        if downsampled:
            md = md.iloc[::cfg.STRIDE]

        feat = extract_pose(md, cfg.KEY_PARTS)
        lab = np.zeros(len(md))

        self_annot = annot[(annot.agent_id==1) & (annot.target_id==1)]
        for _, a in self_annot.iterrows():
            s, e = a.start_frame, a.stop_frame
            if downsampled:
                s, e = s//cfg.STRIDE, e//cfg.STRIDE
            # Annotations reaching past the tracked frames are ignored.
            if s < len(lab) and e <= len(lab):
                lab[s:e] = 1.0

        seqs, tgts = make_seqs(feat, lab, cfg.SEQ_LEN, cfg.SEQ_LEN//2)
        if len(seqs) > 0:
            all_seq.append(seqs)
            all_lab.append(tgts)
        del vid, pvid
        gc.collect()
    except Exception:
        # Best effort: an unreadable or malformed video is skipped, but
        # KeyboardInterrupt/SystemExit still propagate.
        n_failed += 1
        continue

if n_failed:
    print(f"Skipped {n_failed} videos that could not be loaded or parsed")

if not all_seq:
    raise ValueError("No training sequences were built; check the tracking/annotation paths")

X = np.concatenate(all_seq)
y = np.concatenate(all_lab)
print(f"\nSequences: {len(X):,}, Pos: {y.mean():.2%}")

ds = SeqData(X, y)
loader = DataLoader(ds, batch_size=cfg.BATCH_SIZE, shuffle=True, num_workers=2, pin_memory=True)

# Keep the positive rate for the class-weighted loss, then free the raw arrays.
pos_ratio = y.mean()
del all_seq, all_lab, X, y
gc.collect()

# ---------------------------------------------------------------------------
# Phase 1, step 2: train the transformer as a binary classifier with a
# class-weighted BCE loss, then save its weights to cfg.TRANSFORMER_PATH.
# ---------------------------------------------------------------------------
opt = torch.optim.AdamW(model.parameters(), lr=cfg.LR, weight_decay=0.01)

# Weight positives by the negative:positive ratio to counter class imbalance.
# NOTE(review): this divides by pos_ratio, so it assumes at least one positive window.
pos_weight = torch.tensor([(1 - pos_ratio) / pos_ratio]).to(device)
crit = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
# Mixed-precision gradient scaling is only used on CUDA.
scaler = torch.cuda.amp.GradScaler() if device.type == 'cuda' else None

print(f"Positive class weight: {pos_weight.item():.2f}x")

for ep in range(cfg.TRANSFORMER_EPOCHS):
    model.train()
    # Running totals for the per-epoch report.
    loss_sum = 0
    correct = 0
    total = 0
    pos_pred = 0
    pos_true = 0
    
    for bx, by in tqdm(loader, desc=f"Epoch {ep+1}"):
        bx, by = bx.to(device), by.to(device)
        opt.zero_grad()
        if scaler:
            with torch.cuda.amp.autocast():
                out = model(bx, return_embedding=False)
                loss = crit(out, by)
            scaler.scale(loss).backward()
            # Unscale before clipping so the norm threshold applies to the true gradients.
            scaler.unscale_(opt)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            scaler.step(opt)
            scaler.update()
        else:
            out = model(bx, return_embedding=False)
            loss = crit(out, by)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            opt.step()
        
        loss_sum += loss.item()
        # Track accuracy and predictions
        preds = (torch.sigmoid(out) > 0.5).float()
        correct += (preds == by).sum().item()
        total += by.size(0)
        pos_pred += preds.sum().item()
        pos_true += by.sum().item()
    
    acc = correct / total
    pred_rate = pos_pred / total
    true_rate = pos_true / total
    print(f"Epoch {ep+1} Loss: {loss_sum/len(loader):.4f}, Acc: {acc:.4f}, "
          f"PredPos: {pred_rate:.2%}, TruePos: {true_rate:.2%}")

# Checkpoint: weights plus the d_model needed to rebuild the network.
torch.save({
    'model': model.state_dict(),
    'd_model': model.d_model,
    'model_type': 'TSTransformerEncoder'
}, cfg.TRANSFORMER_PATH)
model.eval()
print(f"\nPhase 1: {(time.time()-t1)/60:.1f}min")
# Release the training data and any cached GPU memory before Phase 2.
del loader, ds
gc.collect()
torch.cuda.empty_cache()

# Phase 2: Extract Features

In [None]:
# ---------------------------------------------------------------------------
# Phase 2 setup: reload the trained transformer and build a video_id -> FPS
# lookup from the training metadata.
# ---------------------------------------------------------------------------
print("\n" + "="*80)
print("PHASE 2: FEATURE EXTRACTION (Enhanced with FPS-Aware Features)")
print("="*80)
t2 = time.time()

# Load transformer
# map_location makes the checkpoint loadable on the current device even if it
# was written on a different one (e.g. saved on GPU, reloaded on CPU).
ckpt = torch.load(cfg.TRANSFORMER_PATH, map_location=device)
# Apart from d_model, the architecture arguments are not stored in the
# checkpoint and must match the values used when the model was trained.
model = TSTransformerEncoder(
    n_features=cfg.N_FEATURES,
    d_model=ckpt['d_model'],
    n_heads=8,
    n_layers=3,
    dim_feedforward=512,
    dropout=0.1,
    max_len=cfg.SEQ_LEN * 2
).to(device)
model.load_state_dict(ckpt['model'])
model.eval()

# Build FPS lookup from metadata
# Videos without a recorded frame rate are left out; callers fall back to
# cfg.FPS_DEFAULT for them.
fps_lookup = {}
if 'frames_per_second' in train_labeled.columns:
    for video_id, fps_value in zip(train_labeled['video_id'], train_labeled['frames_per_second']):
        if pd.notnull(fps_value):
            fps_lookup[video_id] = float(fps_value)

print(f"FPS lookup: {len(fps_lookup)} videos with FPS metadata (default: {cfg.FPS_DEFAULT} FPS)")

@torch.no_grad()
def get_emb(data, model, slen):
    """Compute a transformer embedding for every frame of one animal's track.

    The pose matrix is cut into windows of ``slen`` frames with a hop of
    ``slen // 2``; each window is embedded with
    ``model(..., return_embedding=True)`` and the embedding is written to all
    frames of that window (where windows overlap, the later window wins).

    Args:
        data: Tracking frame for one animal, as accepted by ``extract_pose``.
        model: Network called in embedding mode.
        slen: Window length in frames.

    Returns:
        Array of shape (n_frames, embedding_dim), or None when the track is
        shorter than one window.
    """
    feat = extract_pose(data, cfg.KEY_PARTS)
    if len(feat) < slen:
        return None

    hop = slen // 2
    starts = range(0, len(feat) - slen + 1, hop)

    # Stack into a single ndarray first: building a tensor from a Python list
    # of ndarrays is very slow.
    windows = np.stack([feat[s:s + slen] for s in starts])
    seqs_t = torch.FloatTensor(windows).to(device)

    # Embed in mini-batches of 64 windows to bound device memory.
    embs = np.concatenate([
        model(seqs_t[i:i + 64], return_embedding=True).cpu().numpy()
        for i in range(0, len(seqs_t), 64)
    ])

    frame_emb = np.zeros((len(feat), embs.shape[1]))
    for start, emb in zip(starts, embs):
        frame_emb[start:start + slen] = emb

    # Frames after the last full window are covered by no window at all; give
    # them the last window's embedding instead of leaving them all-zero.
    covered = starts[-1] + slen
    if covered < len(feat):
        frame_emb[covered:] = embs[-1]

    return frame_emb

def proc_video(row, mode='train', expected_dim=None, fps_lookup=None):
    """Build per-frame features (and training labels) for mouse 1 of one video.

    Uses the module-level ``model`` for the transformer embeddings.

    Args:
        row: Metadata row; reads ``lab_id``, ``video_id``,
            ``pix_per_cm_approx`` and, in train mode, ``behaviors_labeled``.
        mode: ``'train'`` reads the training tracking/annotation directories
            and attaches labels; any other value reads the test tracking
            directory and returns features only.
        expected_dim: If given, the feature matrix is zero-padded or truncated
            along the column axis to exactly this width.
        fps_lookup: Optional ``{video_id: fps}``; videos missing from it use
            ``cfg.FPS_DEFAULT``.

    Returns:
        ``{'X': ndarray, 'vid': video_id}``, plus ``'y'`` (DataFrame of 0/1
        labels, one column per action) in train mode. Returns None when the
        video cannot be used: no mouse 1, too short for one window, no
        mouse-1 self behaviours labelled, or any error while processing
        (the error is reported, not raised).
    """
    tdir = cfg.TRAIN_TRACKING_DIR if mode == 'train' else cfg.TEST_TRACKING_DIR
    try:
        actions = None
        if mode == 'train':
            # Check the label schema first: it is cheap, and a video without
            # mouse-1 self behaviours is discarded anyway, so there is no
            # point computing embeddings for it.
            behav = json.loads(row.behaviors_labeled)
            behav = [b.replace("'", "").split(',') for b in behav if '1,self' in b or '1,1' in b]
            # dict.fromkeys de-duplicates while keeping order, so the label
            # frame never ends up with repeated column names.
            actions = list(dict.fromkeys(
                b[2] for b in behav if b[0] == 'mouse1' and b[1] in ['self', 'mouse1']
            ))
            if not actions:
                return None

        vid = pd.read_parquet(f"{tdir}/{row.lab_id}/{row.video_id}.parquet")
        if vid.bodypart.nunique() > 5:
            vid = vid[~vid.bodypart.isin(cfg.DROP_PARTS)]
        pvid = vid.pivot(columns=['mouse_id', 'bodypart'], index='video_frame', values=['x', 'y'])
        # Reorder to (mouse_id, bodypart, coord) and convert pixels to cm.
        pvid = pvid.reorder_levels([1, 2, 0], axis=1).T.sort_index().T / row.pix_per_cm_approx

        fps = cfg.FPS_DEFAULT
        if fps_lookup and row.video_id in fps_lookup:
            fps = fps_lookup[row.video_id]

        if 1 not in pvid.columns.get_level_values('mouse_id'):
            return None
        md = pvid.loc[:, 1]

        # Long videos are subsampled: every cfg.STRIDE-th frame is kept.
        downsampled = len(md) > cfg.MAX_FRAMES
        if downsampled:
            md = md.iloc[::cfg.STRIDE].reset_index(drop=True)

        emb = get_emb(md, model, cfg.SEQ_LEN)
        if emb is None:
            return None

        hand = enhanced_hand_feat(md, fps=fps)
        combined = np.concatenate([emb, hand.values], axis=1)

        if expected_dim is not None:
            width = combined.shape[1]
            if width < expected_dim:
                padding = np.zeros((combined.shape[0], expected_dim - width))
                combined = np.concatenate([combined, padding], axis=1)
            elif width > expected_dim:
                combined = combined[:, :expected_dim]

        if mode != 'train':
            return {'X': combined, 'vid': row.video_id}

        annot = pd.read_parquet(f"{cfg.TRAIN_ANNOTATION_DIR}/{row.lab_id}/{row.video_id}.parquet")
        n_rows = len(md)
        labels = pd.DataFrame(0.0, columns=actions, index=range(n_rows))

        self_events = annot[(annot.agent_id == 1) & (annot.target_id == 1)]
        for _, a in self_events.iterrows():
            if a.action not in actions:
                continue
            s, e = int(a.start_frame), int(a.stop_frame)

            if downsampled:
                # Row k holds original frame k*STRIDE (frames are treated as
                # 0-based positions throughout this function), so the rows
                # that fall inside [s, e) are ceil(s/STRIDE)..ceil(e/STRIDE)-1.
                # Floor division labelled a row *before* the event and dropped
                # short events that do contain a sampled frame.
                s = -(-s // cfg.STRIDE)
                e = -(-e // cfg.STRIDE)

            s = max(0, min(s, n_rows))
            e = max(0, min(e, n_rows))

            if s < e and s < n_rows:
                # .loc slicing is end-inclusive, hence e - 1.
                labels.loc[s:e - 1, a.action] = 1.0

        return {'X': combined, 'y': labels, 'vid': row.video_id}
    except Exception as ex:
        # Best effort: an unusable video is skipped, but the reason is reported
        # so that systematic failures do not silently shrink the dataset.
        print(f"proc_video: skipping video {getattr(row, 'video_id', '?')} "
              f"({type(ex).__name__}: {ex})")
        return None

print("Extracting training features with FPS-aware feature engineering...")
# Accumulators filled while walking the labelled training videos.
train_data, feature_dims_list = [], []
total_pos_count = 0
videos_with_labels = 0

for _, row in tqdm(train_labeled.iterrows(), total=len(train_labeled)):
    # Rows without a behaviours string carry no label schema; skip them.
    if type(row.behaviors_labeled) is not str:
        continue
    res = proc_video(row, 'train', fps_lookup=fps_lookup)
    if not res:
        continue
    train_data.append(res)
    feature_dims_list.append(res['X'].shape[1])
    # Total number of positive (frame, action) cells in this video.
    pos_in_video = res['y'].sum().sum()
    videos_with_labels += int(pos_in_video > 0)
    total_pos_count += pos_in_video

# Cache the extracted features for the XGBoost phase.
with open(cfg.FEATURES_PATH, 'wb') as f:
    pickle.dump(train_data, f)

if feature_dims_list:
    # Report the most frequent feature width across videos.
    feature_dim = max(set(feature_dims_list), key=feature_dims_list.count)
    print(f"\nFeature dimension: {feature_dim} (256 transformer + ~{feature_dim-256} handcrafted)")
print(f"Features saved: {len(train_data)} videos")
print(f"Videos with positive labels: {videos_with_labels}/{len(train_data)}")
print(f"Total positive labels: {int(total_pos_count)}")
print(f"Phase 2: {(time.time()-t2)/60:.1f}min")

# The encoder is no longer needed in this phase; free it and cached GPU memory.
del model
gc.collect()
torch.cuda.empty_cache()

# Phase 3: XGBoost

In [None]:
print("\n" + "="*80)
print("PHASE 3: XGBOOST TRAINING")
print("="*80)
t3 = time.time()

# Features were written by Phase 2 of this same notebook (trusted local file).
with open(cfg.FEATURES_PATH, 'rb') as f:
    data = pickle.load(f)

# Fail with a clear message instead of an IndexError further down.
if not data:
    raise ValueError("No training features available: Phase 2 produced no usable videos.")

all_actions = set()
for d in data:
    all_actions.update(d['y'].columns)
print(f"Actions: {sorted(all_actions)}")

# Every video must share one feature width, otherwise vstack below would fail.
feature_dims = [d['X'].shape[1] for d in data]
unique_dims = set(feature_dims)
if len(unique_dims) > 1:
    print(f"ERROR: Inconsistent feature dimensions found: {unique_dims}")
    print(f"Feature dimension counts: {[(dim, feature_dims.count(dim)) for dim in unique_dims]}")
    raise ValueError("Feature dimensions are inconsistent! Check enhanced_hand_feat() function.")

expected_dim = feature_dims[0]
print(f"Feature dimension: {expected_dim} (256 transformer + {expected_dim-256} handcrafted)")

# action -> list of per-fold classifiers, and action -> decision threshold.
models = {}
thresholds = {}

for action in sorted(all_actions):
    # Gather every video in which this action is part of the label schema.
    Xs, ys, vids = [], [], []
    for d in data:
        if action in d['y'].columns:
            Xs.append(d['X'])
            ys.append(d['y'][action].values)
            vids.extend([d['vid']]*len(d['X']))

    if len(Xs) == 0:
        continue

    X = np.vstack(Xs)
    y = np.concatenate(ys)
    vids = np.array(vids)

    if np.any(np.isnan(X)) or np.any(np.isinf(X)):
        print(f"{action}: WARNING - NaN/Inf detected in features, replacing with 0")
        X = np.nan_to_num(X, nan=0.0, posinf=0.0, neginf=0.0)

    pos_count = y.sum()
    pos_rate = y.mean()

    if pos_count < 10:
        print(f"{action}: Skip (only {int(pos_count)} positives)")
        continue

    # Grouped CV needs at least one video per fold; with fewer videos some
    # folds would have an empty validation (or training) split.
    n_groups = len(np.unique(vids))
    if n_groups < cfg.N_SPLITS:
        print(f"{action}: Skip (only {n_groups} videos, need at least {cfg.N_SPLITS} for grouped CV)")
        continue

    print(f"{action}: {len(X):,} samples, {int(pos_count)} pos ({pos_rate:.2%})")

    cv = StratifiedGroupKFold(n_splits=cfg.N_SPLITS)
    oof = np.zeros(len(X))
    folds = []

    xgb_params = cfg.XGB_PARAMS.copy()
    # Re-weight positives by the negative/positive ratio to counter imbalance.
    xgb_params['scale_pos_weight'] = (1 - pos_rate) / pos_rate

    for fold_idx, (train_i, val_i) in enumerate(cv.split(X, y, vids)):
        # A fold with nothing to validate on, or with no positive example to
        # learn from, cannot yield a useful binary model; skip it. Its
        # out-of-fold predictions stay at 0.
        if len(val_i) == 0 or y[train_i].sum() == 0:
            print(f"  Fold {fold_idx}: skipped (empty validation split or no positives in training split)")
            continue
        m = XGBClassifier(**xgb_params)
        m.fit(X[train_i], y[train_i])
        oof[val_i] = m.predict_proba(X[val_i])[:, 1]
        folds.append(m)

    if not folds:
        print(f"{action}: Skip (no usable CV folds)")
        continue

    oof_min, oof_max, oof_mean = oof.min(), oof.max(), oof.mean()
    oof_unique = len(np.unique(oof))

    if oof_unique < 10:
        print(f"  WARNING: Only {oof_unique} unique OOF predictions!")

    if oof_max - oof_min < 0.01:
        print(f"  WARNING: OOF predictions have very low variance [{oof_min:.4f}, {oof_max:.4f}]")

    # Grid-search the threshold that maximises F1 on the out-of-fold
    # predictions; 0.5 is kept when no threshold achieves F1 > 0.
    best_thr, best_f1 = 0.5, 0
    for thr in np.arange(0.05, 0.95, 0.01):
        preds = (oof >= thr).astype(int)
        f1 = f1_score(y, preds, zero_division=0)
        if f1 > best_f1:
            best_f1, best_thr = f1, thr

    models[action] = folds
    thresholds[action] = best_thr

    # Confusion counts at the chosen threshold, for the log line only.
    final_preds = (oof >= best_thr).astype(int)
    tp = ((final_preds == 1) & (y == 1)).sum()
    fp = ((final_preds == 1) & (y == 0)).sum()
    fn = ((final_preds == 0) & (y == 1)).sum()

    print(f"  F1={best_f1:.4f}, thr={best_thr:.3f}, OOF=[{oof_min:.3f},{oof_max:.3f}], "
          f"TP={tp}, FP={fp}, FN={fn}")

import pickle as pkl
# Persist everything inference needs: fold models, thresholds, feature width.
with open('models.pkl', 'wb') as f:
    pkl.dump({'models': models, 'thresholds': thresholds, 'expected_dim': expected_dim}, f)

print(f"\nPhase 3: {(time.time()-t3)/60:.1f}min")

# Phase 4: Inference

In [None]:
print("\n" + "="*80)
print("PHASE 4: INFERENCE")
print("="*80)
t4 = time.time()


def _required_behaviors(behaviors_json):
    """Parse a ``behaviors_labeled`` JSON string into (agent, target, action) triples."""
    triples = []
    for entry in json.loads(behaviors_json):
        parts = entry.replace("'", "").split(',')
        if len(parts) >= 3:
            triples.append((parts[0], parts[1], parts[2]))
    return triples


def _frame_bounds(row, n_rows):
    """Return n_rows + 1 original frame numbers delimiting each feature row.

    Element i is the original frame of feature row i; the last element is the
    exclusive end of the final row. The tracking file's frame column is re-read
    and the same rule as proc_video is applied (keep every cfg.STRIDE-th frame
    when there are more than cfg.MAX_FRAMES frames). If that cannot be
    reconciled with ``n_rows``, the previous size-based heuristic is used.
    """
    if n_rows == 0:
        return np.zeros(1, dtype=int)
    frames, step = None, 1
    try:
        raw = pd.read_parquet(
            f"{cfg.TEST_TRACKING_DIR}/{row.lab_id}/{row.video_id}.parquet",
            columns=['video_frame', 'bodypart'],
        )
        # Mirror proc_video's body-part filter so the frame set is identical.
        if raw.bodypart.nunique() > 5:
            raw = raw[~raw.bodypart.isin(cfg.DROP_PARTS)]
        frames = np.unique(raw.video_frame.to_numpy())  # sorted unique frames
        if len(frames) > cfg.MAX_FRAMES:
            frames, step = frames[::cfg.STRIDE], cfg.STRIDE
    except Exception as ex:
        print(f"Frame lookup failed for video {row.video_id} ({type(ex).__name__}: {ex})")
        frames = None
    if frames is None or len(frames) != n_rows:
        # Legacy fallback: guess from the feature-row count alone.
        step = cfg.STRIDE if n_rows < 6000 else 1
        frames = np.arange(n_rows) * step
    return np.append(frames, frames[-1] + step)


# Load transformer (proc_video reads the module-level `model`).
# map_location lets a GPU-saved checkpoint load on a CPU-only machine.
ckpt = torch.load(cfg.TRANSFORMER_PATH, map_location=device)
model = TSTransformerEncoder(
    n_features=cfg.N_FEATURES,
    d_model=ckpt['d_model'],
    n_heads=8,
    n_layers=3,
    dim_feedforward=512,
    dropout=0.1,
    max_len=cfg.SEQ_LEN * 2
).to(device)
model.load_state_dict(ckpt['model'])
model.eval()

import pickle as pkl
# models.pkl is written by Phase 3 of this notebook (trusted local file).
with open('models.pkl', 'rb') as f:
    saved = pkl.load(f)
models = saved['models']
thresholds = saved['thresholds']
expected_dim = saved['expected_dim']

print(f"Loaded models for {len(models)} actions, expected feature dim: {expected_dim}")

# Training features use each video's real FPS; give test videos the same
# treatment instead of always falling back to cfg.FPS_DEFAULT.
test_fps = {}
if 'frames_per_second' in test.columns:
    fps_known = test[test['frames_per_second'].notna()]
    test_fps = {
        vid: float(fps)
        for vid, fps in zip(fps_known['video_id'], fps_known['frames_per_second'])
    }

submissions = []
all_required = set()  # every (video, agent, target, action) that must appear

for _, row in tqdm(test.iterrows(), total=len(test), desc="Test"):
    if not isinstance(row.behaviors_labeled, str):
        continue

    required_behaviors = _required_behaviors(row.behaviors_labeled)
    all_required.update(
        (row.video_id, agent, target, action)
        for agent, target, action in required_behaviors
    )

    res = proc_video(row, 'test', expected_dim=expected_dim, fps_lookup=test_fps)

    if res is None:
        # No features for this video: emit one placeholder row per behaviour.
        for agent_id, target_id, action in required_behaviors:
            submissions.append({
                'video_id': row.video_id,
                'agent_id': agent_id,
                'target_id': target_id,
                'action': action,
                'start_frame': 0,
                'stop_frame': 1
            })
        continue

    # Same sanitisation as in training (NaN/Inf -> 0) so that the models see
    # inputs encoded the way they were fitted.
    X_test = np.nan_to_num(res['X'], nan=0.0, posinf=0.0, neginf=0.0)
    n_frames = len(X_test)
    if n_frames == 0:
        continue  # nothing to predict; placeholders are added below
    bounds = _frame_bounds(row, n_frames)

    agent_target_groups = {}
    for agent_id, target_id, action in required_behaviors:
        agent_target_groups.setdefault((agent_id, target_id), []).append(action)

    # The features do not depend on the (agent, target) pair, so each action's
    # fold-averaged probability is computed once per video and reused.
    prob_cache = {}

    for (agent_id, target_id), actions in agent_target_groups.items():
        # Actions without a trained model/threshold can never be predicted.
        scored = [a for a in actions if a in thresholds and models.get(a)]
        if not scored:
            continue

        for a in scored:
            if a not in prob_cache:
                prob_cache[a] = np.mean(
                    [m.predict_proba(X_test)[:, 1] for m in models[a]], axis=0
                )

        probs = np.stack([prob_cache[a] for a in scored])  # (n_actions, n_frames)
        cutoffs = np.array([thresholds[a] for a in scored], dtype=float)[:, None]

        # Per frame: the action with the highest probability among those at or
        # above their own threshold (and > 0); ties go to the first listed
        # action. -1 marks frames with no eligible action.
        eligible = (probs >= cutoffs) & (probs > 0.0)
        winner = np.where(eligible, probs, -np.inf).argmax(axis=0)
        winner[~eligible.any(axis=0)] = -1

        # Run-length encode the per-frame winners into [start, stop) segments.
        breaks = np.flatnonzero(np.diff(winner)) + 1
        run_starts = np.concatenate(([0], breaks))
        run_stops = np.concatenate((breaks, [n_frames]))

        for s, e in zip(run_starts, run_stops):
            if winner[s] < 0:
                continue
            submissions.append({
                'video_id': row.video_id,
                'agent_id': agent_id,
                'target_id': target_id,
                'action': scored[winner[s]],
                'start_frame': int(bounds[s]),
                'stop_frame': int(bounds[e])
            })

# Make sure every required behaviour appears at least once in the submission.
submitted = {
    (s['video_id'], s['agent_id'], s['target_id'], s['action'])
    for s in submissions
}

for vid, agent, target, action in all_required:
    if (vid, agent, target, action) not in submitted:
        submissions.append({
            'video_id': vid,
            'agent_id': agent,
            'target_id': target,
            'action': action,
            'start_frame': 0,
            'stop_frame': 1
        })

if len(submissions) > 0:
    sub = pd.DataFrame(submissions)
else:
    # Nothing at all to submit: write a single dummy row.
    sub = pd.DataFrame([{
        'video_id': test.iloc[0].video_id,
        'agent_id': 'mouse1',
        'target_id': 'mouse2',
        'action': 'sniff',
        'start_frame': 0,
        'stop_frame': 1
    }])

sub = sub.sort_values(['video_id', 'agent_id', 'target_id', 'action', 'start_frame']).reset_index(drop=True)
sub.index.name = 'row_id'
sub.to_csv('submission.csv')

print(f"\nSubmission: {len(sub)} rows")
print(f"Unique videos: {sub.video_id.nunique()}")
print(f"Unique behaviors: {len(sub.groupby(['agent_id', 'target_id', 'action']))}")
print(f"Phase 4: {(time.time()-t4)/60:.1f}min")
print(f"\nTOTAL TIME: {(time.time()-t1)/60:.1f}min")