# Module 1: The Bio-Physics Data Engine

# In [1]:
import torch
from torch.utils.data import Dataset
from torch.nn.utils.rnn import pad_sequence
import pandas as pd
import numpy as np
from pathlib import Path
import ast
import random

# --- CONFIGURATION ---
# Per-lab constants keyed by lab_id (the value read from the metadata CSV).
#   "thresh": not referenced by the code visible here -- NOTE(review): confirm consumer.
#   "window": not read by BioPhysicsDataset, which hard-codes a 256-frame window.
#   "pix_cm": divisor applied to raw coordinates in BioPhysicsDataset._geo_feats
#             (presumably pixels per centimetre -- confirm against the data source).
# "DEFAULT" is the fallback entry for lab_ids missing from this table.
# NOTE: key ORDER matters. BioPhysicsDataset turns a lab_id into an integer with
# list(LAB_CONFIGS.keys()).index(lab): the 20 named labs occupy indices 0-19,
# matching the num_labs=20 lab tables in the model modules; "DEFAULT" sits at 20.
LAB_CONFIGS = {
    "AdaptableSnail":       {"thresh": 718.59, "window": 120, "pix_cm": 14.5},
    "BoisterousParrot":     {"thresh": 50.93,  "window": 292, "pix_cm": 5.5},
    "CRIM13":               {"thresh": 207.95, "window": 117, "pix_cm": 14.5},
    "CalMS21_supplemental": {"thresh": 206.05, "window": 196, "pix_cm": 18.3},
    "CalMS21_task1":        {"thresh": 154.32, "window": 140, "pix_cm": 18.3},
    "CalMS21_task2":        {"thresh": 177.51, "window": 122, "pix_cm": 18.3},
    "CautiousGiraffe":      {"thresh": 119.97, "window": 67,  "pix_cm": 21.0},
    "DeliriousFly":         {"thresh": 97.31,  "window": 172, "pix_cm": 16.0},
    "ElegantMink":          {"thresh": 88.58,  "window": 391, "pix_cm": 18.4},
    "GroovyShrew":          {"thresh": 254.45, "window": 115, "pix_cm": 11.3},
    "InvincibleJellyfish":  {"thresh": 249.33, "window": 158, "pix_cm": 32.0},
    "JovialSwallow":        {"thresh": 99.68,  "window": 62,  "pix_cm": 15.3},
    "LyricalHare":          {"thresh": 198.80, "window": 361, "pix_cm": 10.9},
    "NiftyGoldfinch":       {"thresh": 303.02, "window": 78,  "pix_cm": 13.5},
    "PleasantMeerkat":      {"thresh": 150.58, "window": 32,  "pix_cm": 15.8},
    "ReflectiveManatee":    {"thresh": 117.76, "window": 97,  "pix_cm": 15.0},
    "SparklingTapir":       {"thresh": 281.60, "window": 252, "pix_cm": 40.0},
    "TranquilPanther":      {"thresh": 133.98, "window": 105, "pix_cm": 12.3},
    "UppityFerret":         {"thresh": 228.77, "window": 55,  "pix_cm": 12.7},
    "DEFAULT":              {"thresh": 150.00, "window": 128, "pix_cm": 15.0}
}

# Every behaviour label, sorted alphabetically so each name gets a stable index.
ACTION_LIST = sorted([
    "allogroom", "approach", "attack", "attemptmount", "avoid", "biteobject",
    "chase", "chaseattack", "climb", "defend", "dig", "disengage", "dominance",
    "dominancegroom", "dominancemount", "ejaculate", "escape", "exploreobject",
    "flinch", "follow", "freeze", "genitalgroom", "huddle", "intromit", "mount",
    "rear", "reciprocalsniff", "rest", "run", "selfgroom", "shepherd", "sniff",
    "sniffbody", "sniffface", "sniffgenital", "submit", "tussle"
])
# Behaviour name -> column index in the dataset's [T, NUM_CLASSES] target matrix.
ACTION_TO_IDX = {a: i for i, a in enumerate(ACTION_LIST)}
NUM_CLASSES = len(ACTION_LIST)  # 37 labels
# Tracked keypoints. List order defines the node axis (size 11) of every pose
# array; index 9 ("tail_base") is the origin used by BioPhysicsDataset._geo_feats.
BODY_PARTS = [
    "ear_left", "ear_right", "nose", "neck", "body_center",
    "lateral_left", "lateral_right", "hip_left", "hip_right",
    "tail_base", "tail_tip"
]
# Body-part name -> node index.
PART_TO_IDX = {p: i for i, p in enumerate(BODY_PARTS)}

class BioPhysicsDataset(Dataset):
    """Per-video dataset yielding fixed-length windows of geometric pose features.

    Each item is built from one video's tracking parquet file
    (``<root>/<mode>_tracking/<lab_id>/<video_id>.parquet``) and, in
    ``'train'`` mode, its annotation parquet file
    (``<root>/<mode>_annotation/<lab_id>/<video_id>.parquet``). Missing or
    unreadable per-video parquet files never raise: tracking falls back to an
    all-zero window and annotations to an all-zero target.

    Per-node feature layout produced by ``_geo_feats`` (16 channels), matching
    the channels the downstream interaction modules slice:
        0-1   agent node position relative to the agent's tail base (x, y)
        2-3   frame-to-frame change of that position (x, y)
        4-5   partner node position relative to the agent's tail base (x, y)
        6     distance between the agent node and the same partner node
        7     norm of channels 2-3
        8-15  zero padding
    All positional quantities are divided by the lab's ``pix_cm``.

    Items are 5-tuples ``(global_feats, local_feats, target, weights, lab_idx)``:
    both feature tensors are the same ``[local_window, 11, 16]`` array,
    ``target`` is a ``[local_window, NUM_CLASSES]`` multi-hot tensor,
    ``weights`` is ``[local_window]`` with 1.0 on real frames and 0.0 on
    padding, and ``lab_idx`` is an int.
    """

    def __init__(self, data_root, mode='train', video_ids=None):
        self.root = Path(data_root)
        self.mode = mode
        # Per-video parquet files live under <root>/<mode>_tracking/<lab_id>/
        # and <root>/<mode>_annotation/<lab_id>/.
        self.tracking_dir = self.root / f"{mode}_tracking"
        self.annot_dir = self.root / f"{mode}_annotation"

        self.metadata = pd.read_csv(self.root / f"{mode}.csv")
        if video_ids is not None:
            # Compare as strings so numeric CSV ids match string ids (e.g. a train/val split).
            self.metadata = self.metadata[self.metadata['video_id'].astype(str).isin(video_ids)]

        # One sample per metadata row. File existence is deliberately NOT checked
        # here; _load falls back to dummy data for missing files.
        self.samples = [
            {'video_id': str(vid), 'lab_id': lab}
            for vid, lab in zip(self.metadata['video_id'], self.metadata['lab_id'])
        ]

        # Fixed window length (frames) for every item; LAB_CONFIGS["window"] is not used.
        self.local_window = 256
        self.max_global_tokens = 2048  # stored, but not used by this class

        # (sample index, centre frame) pairs used for action-centred sampling.
        self.action_windows = []
        if self.mode == 'train':
            self._scan_actions_safe()

    def _scan_actions_safe(self, max_videos=501):
        """Collect centre frames of known actions from the first ``max_videos`` samples.

        Each annotated action contributes ``(sample_index, (start + stop) // 2)``
        to ``self.action_windows``. Missing or malformed annotation files are
        skipped; if nothing is found, sampling falls back to random windows.
        """
        count = 0
        print("Scanning subset of annotations for sampling...")
        known_actions = list(ACTION_TO_IDX)
        for i, s in enumerate(self.samples[:max_videos]):
            p = self.annot_dir / s['lab_id'] / f"{s['video_id']}.parquet"
            if not p.exists():
                continue
            try:
                df = pd.read_parquet(p)
                df = df[df['action'].isin(known_actions)]
                centers = [int(c) for c in ((df['start_frame'] + df['stop_frame']) // 2).values]
            except Exception:
                # Best effort: an unreadable/malformed file just contributes no windows.
                continue
            self.action_windows.extend((i, c) for c in centers)
            count += len(centers)
        if count == 0:
            print("Warning: No actions scanned. Falling back to random sampling.")

    def _fix_teleport(self, pos):
        """Fill dropped keypoints by linear interpolation over time.

        A keypoint counts as missing in a frame when |x| + |y| < 1e-6
        (presumably how dropouts are encoded in the tracking files). Each node
        is interpolated independently from its valid frames; leading/trailing
        gaps take the nearest valid value (``np.interp`` clamps). Nodes missing
        in every frame are left untouched.

        Args:
            pos: array of shape [T, N, 2].
        Returns:
            A cleaned copy of ``pos``.
        """
        T, N, _ = pos.shape
        missing = (np.abs(pos).sum(axis=2) < 1e-6)
        cleaned = pos.copy()
        for n in range(N):
            m = missing[:, n]
            if np.any(m) and not np.all(m):
                valid_t = np.where(~m)[0]
                missing_t = np.where(m)[0]
                cleaned[missing_t, n, 0] = np.interp(missing_t, valid_t, pos[valid_t, n, 0])
                cleaned[missing_t, n, 1] = np.interp(missing_t, valid_t, pos[valid_t, n, 1])
        return cleaned

    def _geo_feats(self, pos, other, pix_cm):
        """Build the 16-channel per-node feature map described on the class.

        Args:
            pos: agent keypoints, shape [T, N, 2].
            other: partner keypoints, shape [T, N, 2].
            pix_cm: divisor applied to all coordinates.
        Returns:
            float32 array of shape [T, N, 16].
        """
        pos = pos / pix_cm
        other = other / pix_cm

        # Egocentric frame: everything relative to the agent's tail base (node 9).
        origin = pos[:, 9:10, :]
        centered = pos - origin
        other_centered = other - origin

        # First frame gets zero velocity (prepend duplicates frame 0).
        vel = np.diff(centered, axis=0, prepend=centered[0:1])
        speed = np.sqrt((vel ** 2).sum(axis=-1))
        dist = np.sqrt(((pos - other) ** 2).sum(axis=-1))

        feat = np.zeros(pos.shape[:2] + (16,), dtype=np.float32)
        feat[..., 0:2] = centered
        feat[..., 2:4] = vel
        feat[..., 4:6] = other_centered  # read as "neighbour X/Y" downstream
        feat[..., 6] = dist              # read as "neighbour dist" downstream
        feat[..., 7] = speed
        return feat

    @staticmethod
    def _long_format_parts(df, mouse_id):
        """Return one ``[n_i, 2]`` (x, y) array per BODY_PARTS entry for ``mouse_id``."""
        rows = df[df['mouse_id'] == mouse_id]
        return [rows.loc[rows['bodypart'] == bp, ['x', 'y']].values for bp in BODY_PARTS]

    def _read_tracking(self, fpath):
        """Read both mice's keypoints from one tracking parquet file.

        Returns:
            ``(raw_m1, raw_m2)`` float32 arrays of shape ``[L, 11, 2]``, or
            ``None`` when the file is missing, unreadable, or has no usable
            rows. With a single mouse id, that mouse fills both arrays.
        """
        if not fpath.exists():
            return None
        try:
            df = pd.read_parquet(fpath)
            mids = df['mouse_id'].unique()
            m1_id = mids[0]
            m2_id = mids[1] if len(mids) > 1 else m1_id

            if 'bodypart' not in df.columns:
                # Wide-format files are not parsed yet: keep one frame per row
                # but leave every coordinate at zero.
                n_frames = len(df)
                return (np.zeros((n_frames, len(BODY_PARTS), 2), dtype=np.float32),
                        np.zeros((n_frames, len(BODY_PARTS), 2), dtype=np.float32))

            tracks = [self._long_format_parts(df, mid) for mid in (m1_id, m2_id)]
            # Long format stacks rows for every mouse and bodypart, so len(df)
            # overstates the frame count; use the longest per-bodypart track.
            n_frames = max((len(xy) for parts in tracks for xy in parts), default=0)
            if n_frames == 0:
                return None

            arrays = []
            for parts in tracks:
                arr = np.zeros((n_frames, len(BODY_PARTS), 2), dtype=np.float32)
                for node, xy in enumerate(parts):
                    arr[:len(xy), node] = xy
                arrays.append(arr)
            return arrays[0], arrays[1]
        except Exception:
            # Best effort: any read/parse failure falls back to dummy data.
            return None

    def _fill_targets(self, target, lab, video_id, start):
        """Mark annotated actions overlapping the window ``[start, start + len(target))``.

        Writes 1.0 into ``target`` (in place) for each known action's
        ``[start_frame, stop_frame)`` span clipped to the window. Missing or
        malformed annotation files leave ``target`` as it is.
        """
        ap = self.annot_dir / lab / f"{video_id}.parquet"
        if not ap.exists():
            return
        window = target.shape[0]
        try:
            adf = pd.read_parquet(ap)
            for action, start_f, stop_f in zip(adf['action'], adf['start_frame'], adf['stop_frame']):
                col = ACTION_TO_IDX.get(action)
                if col is None:
                    continue
                st = max(0, int(start_f) - start)
                et = min(window, int(stop_f) - start)
                if st < et:
                    target[st:et, col] = 1.0
        except Exception:
            pass  # best effort: keep whatever rows were already written

    def _load(self, idx, center=None):
        """Build one item for sample ``idx`` around frame ``center`` (random if None)."""
        sample = self.samples[idx]
        lab = sample['lab_id']
        conf = LAB_CONFIGS.get(lab, LAB_CONFIGS['DEFAULT'])

        tracks = self._read_tracking(self.tracking_dir / lab / f"{sample['video_id']}.parquet")
        if tracks is None:
            # Dummy all-zero window so a bad file never crashes the DataLoader.
            raw_m1 = np.zeros((self.local_window, 11, 2), dtype=np.float32)
            raw_m2 = np.zeros((self.local_window, 11, 2), dtype=np.float32)
        else:
            raw_m1, raw_m2 = tracks

        raw_m1 = self._fix_teleport(raw_m1)
        raw_m2 = self._fix_teleport(raw_m2)

        # Window of at most local_window frames, clamped to the sequence.
        seq_len = len(raw_m1)
        if center is None:
            center = random.randint(0, seq_len)
        s = max(0, min(center - self.local_window // 2, seq_len - self.local_window))
        e = min(s + self.local_window, seq_len)
        n_valid = e - s

        feats = self._geo_feats(raw_m1[s:e], raw_m2[s:e], conf['pix_cm'])

        # Real frames get weight 1; zero-padded tail frames keep weight 0.
        weights = torch.zeros(self.local_window, dtype=torch.float32)
        weights[:n_valid] = 1.0
        if n_valid < self.local_window:
            pad_f = np.zeros((self.local_window - n_valid, 11, 16), dtype=np.float32)
            feats = np.concatenate([feats, pad_f], axis=0)

        target = torch.zeros((self.local_window, NUM_CLASSES), dtype=torch.float32)
        if self.mode == 'train':
            self._fill_targets(target, lab, sample['video_id'], s)

        # NOTE: unknown labs map to index 0 rather than to DEFAULT's index (20);
        # the models in this file size their lab tables with num_labs=20.
        lab_idx = list(LAB_CONFIGS.keys()).index(lab) if lab in LAB_CONFIGS else 0
        return torch.tensor(feats), torch.tensor(feats), target, weights, lab_idx

    def __getitem__(self, idx):
        # Training: 90% of draws ignore idx and use a random pre-scanned
        # action-centred window; otherwise a random window from video idx.
        if self.mode == 'train' and random.random() < 0.9 and len(self.action_windows) > 0:
            i, c = self.action_windows[random.randint(0, len(self.action_windows) - 1)]
            return self._load(i, c)
        return self._load(idx)

    def __len__(self):
        return len(self.samples)

def pad_collate_dual(batch):
    """Collate BioPhysicsDataset items into batched tensors.

    Each item is ``(global_x, local_x, target, weights, lab_idx)`` and is
    already padded to the same window length, so the first four fields are
    stacked along a new leading batch dimension and the integer lab indices
    become a 1-D tensor.
    """
    global_feats, local_feats, targets, weights, lab_ids = map(list, zip(*batch))
    batched = tuple(torch.stack(field) for field in (global_feats, local_feats, targets, weights))
    return batched + (torch.tensor(lab_ids),)

# Module 2: The Morphological & Interaction Core

# In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# ==============================================================================
# 1. CANONICAL GRAPH ADAPTER (Signal Refinement)
# ==============================================================================
class CanonicalGraphAdapter(nn.Module):
    """Per-lab learnable node remapping followed by a shared per-node MLP.

    For lab ``k`` the adapter mixes skeleton nodes with a learnable matrix
    ``P_k`` of shape ``[input_nodes, canonical_nodes]`` (initialised near the
    identity), adds a learnable per-lab bias, and refines each node's feature
    vector with a small MLP::

        out[b, t, m, :] = refine(sum_n x[b, t, n, :] * P[lab[b], n, m]
                                 + bias[lab[b], m, :])

    Args:
        input_nodes: number of nodes in the input skeleton.
        canonical_nodes: number of nodes in the output skeleton; may differ
            from ``input_nodes`` (the projection is then rectangular).
        feat_dim: per-node feature size (unchanged by the adapter).
        num_labs: number of rows in the per-lab parameter tables.
    """

    def __init__(self, input_nodes=11, canonical_nodes=11, feat_dim=16, num_labs=20):
        super().__init__()

        # (NumLabs, input_nodes, canonical_nodes). torch.eye(n, m) is the
        # rectangular identity; for n == m it is the usual square identity, so
        # default-sized checkpoints keep the same parameter shape.
        eye = torch.eye(input_nodes, canonical_nodes)
        self.projection = nn.Parameter(eye.unsqueeze(0).repeat(num_labs, 1, 1))
        # Identity initialisation with slight noise.
        self.projection.data += torch.randn_like(self.projection) * 0.01

        # Per-lab additive offset per canonical node/feature; the singleton
        # axis broadcasts over time.
        self.bias = nn.Parameter(torch.zeros(num_labs, 1, canonical_nodes, feat_dim))

        # Shared per-node refinement MLP.
        self.refine = nn.Sequential(
            nn.Linear(feat_dim, feat_dim * 2),
            nn.LayerNorm(feat_dim * 2),
            nn.GELU(),
            nn.Linear(feat_dim * 2, feat_dim)
        )

    def forward(self, x, lab_idx):
        """Project ``x`` onto the canonical skeleton of each sample's lab.

        Args:
            x: ``[B, T, input_nodes, feat_dim]`` node features.
            lab_idx: ``[B]`` integer lab indices.
        Returns:
            ``[B, T, canonical_nodes, feat_dim]`` tensor.
        """
        proj = self.projection[lab_idx]  # [B, N_in, N_out]
        bias = self.bias[lab_idx]        # [B, 1, N_out, F]

        # Node mixing for every time step at once, without materialising a
        # per-time-step copy of the projection matrix.
        mixed = torch.einsum('btnf,bnm->btmf', x, proj)
        return self.refine(mixed + bias)

# ==============================================================================
# 2. SOCIAL INTERACTION BLOCK (Updated for Geo-Features)
# ==============================================================================
class SocialInteractionBlock(nn.Module):
    """Summarise the agent/target interaction of each frame as a 16-d embedding.

    ``forward`` builds five scalars per frame and feeds them to
    ``relational_mlp``:

    * [0:3] agent channels 4-6 averaged over nodes (the relational channels
      of the geometric feature map),
    * [3]   norm of (agent mean velocity - target mean velocity), where the
      mean velocity is channels 2-3 averaged over nodes,
    * [4]   dot product of those two mean velocities.

    ``fusion`` is constructed but not used by ``forward``.
    """

    def __init__(self, node_dim=16, hidden_dim=64):
        super().__init__()

        # Five relational scalars -> 16-d embedding.
        self.relational_mlp = nn.Sequential(
            nn.Linear(5, 32),
            nn.GELU(),
            nn.Linear(32, 16)
        )

        # NOTE(review): unused in forward(); kept so parameter names (and
        # checkpoints) stay stable.
        self.fusion = nn.Linear(node_dim * 2 + 16, hidden_dim)

    def forward(self, agent_canon, target_canon):
        """Return ``(agent_canon, target_canon, rel_embed)``.

        Args:
            agent_canon, target_canon: ``[B, T, N, F]`` node features with F >= 7.
        Returns:
            The two inputs unchanged, plus ``rel_embed`` of shape ``[B, T, 16]``.
        """
        def node_mean(tensor, lo, hi):
            # Average channels [lo, hi) over the node axis -> [B, T, hi - lo].
            return tensor[..., lo:hi].mean(dim=2)

        relation = node_mean(agent_canon, 4, 7)
        v_agent = node_mean(agent_canon, 2, 4)
        v_target = node_mean(target_canon, 2, 4)

        rel_speed = torch.norm(v_agent - v_target, dim=-1, keepdim=True)
        alignment = torch.sum(v_agent * v_target, dim=-1, keepdim=True)

        embedding = self.relational_mlp(torch.cat((relation, rel_speed, alignment), dim=-1))
        return agent_canon, target_canon, embedding

# ==============================================================================
# WRAPPER: MORPHOLOGICAL INTERACTION CORE
# ==============================================================================
class MorphologicalInteractionCore(nn.Module):
    """Adapt both skeletons to the canonical layout and fuse them per frame.

    ``forward(agent_x, target_x, lab_idx)`` returns
    ``(fused, agent_canon, target_canon)``: ``fused`` is ``[B, T, 128]``,
    produced by a linear layer over both flattened canonical skeletons plus
    the relation embedding; the canonical skeletons are returned unflattened.
    """

    def __init__(self, num_labs=20):
        super().__init__()
        # 11 input nodes mapped to 11 canonical nodes, one projection per lab.
        self.adapter = CanonicalGraphAdapter(input_nodes=11, canonical_nodes=11, num_labs=num_labs)
        self.interaction = SocialInteractionBlock()

        # Two flattened skeletons (11 nodes x 16 features each) plus the 16-d
        # relation embedding: 2 * 176 + 16 = 368 inputs.
        nodes, feats, rel_dim = 11, 16, 16
        fused_in = 2 * nodes * feats + rel_dim
        self.frame_fusion = nn.Linear(fused_in, 128)

    def forward(self, agent_x, target_x, lab_idx):
        # Same adapter (and lab projection) for both animals, agent first.
        agent_canon, target_canon = (self.adapter(x, lab_idx) for x in (agent_x, target_x))

        # Only the relation embedding is needed from the interaction block.
        rel_embed = self.interaction(agent_canon, target_canon)[2]

        batch, steps, _, _ = agent_canon.shape
        combined = torch.cat(
            [
                agent_canon.view(batch, steps, -1),
                target_canon.view(batch, steps, -1),
                rel_embed,
            ],
            dim=-1,
        )  # [B, T, 368]
        return self.frame_fusion(combined), agent_canon, target_canon

# Module 3: The Split-Stream Interaction Block


# In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class SplitStreamInteractionBlock(nn.Module):
    """Two parallel per-frame encoders: a self-only stream and a pair stream.

    * Self stream ("me"): sees only channels 0-3 of the agent's nodes, so it
      cannot read the relational channels (4 and up).
    * Pair stream ("us"): sees both full skeletons, an embedding of the
      node-averaged agent channels 4-6, and a 2-d role token (always the
      first row of ``role_embedding``).

    Both streams output ``hidden_dim`` features per frame.
    """

    def __init__(self, node_dim=16, hidden_dim=128):
        super(SplitStreamInteractionBlock, self).__init__()
        n_nodes = 11

        # Self stream input: 11 nodes x channels 0-3.
        self.self_input_size = n_nodes * 4
        self.self_projector = self._mlp(self.self_input_size, hidden_dim)

        # Embedding of the 3 node-averaged relational channels (4, 5, 6).
        self.relational_mlp = nn.Sequential(
            nn.Linear(3, 32),
            nn.GELU(),
            nn.Linear(32, 32)
        )

        # Pair stream input: agent (176) + target (176) + relation (32) + role (2).
        full_node_dim = n_nodes * node_dim
        pair_input_dim = 2 * full_node_dim + 32 + 2
        self.pair_projector = self._mlp(pair_input_dim, hidden_dim)

        # Role tokens: row 0 = acting, row 1 = receiving (only row 0 is used).
        self.role_embedding = nn.Parameter(torch.tensor([[1.0, 0.0], [0.0, 1.0]]))

    @staticmethod
    def _mlp(in_dim, hidden_dim):
        """Linear -> LayerNorm -> GELU -> Linear projection used by both streams."""
        return nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, hidden_dim)
        )

    def forward(self, agent_c, target_c):
        """Encode one frame sequence into self and pair features.

        Args:
            agent_c:  ``[B, T, 11, F]`` agent node features (F == node_dim).
            target_c: ``[B, T, 11, F]`` target node features.
        Returns:
            ``(self_feat, pair_feat)``, each ``[B, T, hidden_dim]``.
        """
        batch, time, _, _ = agent_c.shape

        # Self stream: strictly channels 0-3, so no relational channel leaks in.
        proprio = agent_c[..., :4].contiguous().view(batch, time, -1)
        self_feat = self.self_projector(proprio)

        # Pair stream: full skeletons + relation embedding + "acting" role token.
        relation = self.relational_mlp(agent_c[..., 4:7].mean(dim=2))
        role = self.role_embedding[0].expand(batch, time, 2)
        pair_in = torch.cat(
            (agent_c.view(batch, time, -1), target_c.view(batch, time, -1), relation, role),
            dim=-1,
        )
        return self_feat, self.pair_projector(pair_in)

# Module 4: The Local-Global Chronos Encoder

# In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class LocalGlobalChronosEncoder(nn.Module):
    """Temporal encoder: global Transformer memory plus two local dilated-TCN streams.

    * Global stream: ``global_feat`` is projected, given sinusoidal position
      encodings, and passed through a 2-layer Transformer encoder to form a
      memory bank.
    * Local self / pair streams: each runs a 5-layer dilated 1-D TCN
      (kernel 3, dilations 1, 2, 4, 8, 16; padding keeps the length, giving a
      63-frame receptive field), then cross-attends to the global memory,
      with a residual connection and LayerNorm.

    Shapes (batch-first): ``global_feat`` [B, T_g, input_dim];
    ``local_self`` / ``local_pair`` [B, T_l, input_dim]; both outputs
    [B, T_l, hidden_dim].
    """

    def __init__(self, input_dim=128, hidden_dim=128):
        super(LocalGlobalChronosEncoder, self).__init__()

        # --- Global "narrative" memory ---
        self.global_proj = nn.Linear(input_dim, hidden_dim)
        self.pos_encoder = PositionalEncoding(hidden_dim, max_len=5000)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_dim,
            nhead=4,
            dim_feedforward=512,
            batch_first=True,
            dropout=0.1,
            activation="gelu"
        )
        self.global_transformer = nn.TransformerEncoder(encoder_layer, num_layers=2)

        # --- Local "me" stream ---
        self.self_tcn = self._dilated_tcn(input_dim, hidden_dim)
        self.self_attn = nn.MultiheadAttention(embed_dim=hidden_dim, num_heads=4, batch_first=True)
        self.self_norm = nn.LayerNorm(hidden_dim)

        # --- Local "us" stream ---
        self.pair_tcn = self._dilated_tcn(input_dim, hidden_dim)
        self.pair_attn = nn.MultiheadAttention(embed_dim=hidden_dim, num_heads=4, batch_first=True)
        self.pair_norm = nn.LayerNorm(hidden_dim)

    @staticmethod
    def _dilated_tcn(input_dim, hidden_dim, dilations=(1, 2, 4, 8, 16)):
        """Stack of Conv1d(k=3) -> BatchNorm1d -> GELU, one triple per dilation.

        ``padding == dilation`` keeps the time length unchanged.
        """
        layers = []
        in_channels = input_dim
        for d in dilations:
            layers += [
                nn.Conv1d(in_channels, hidden_dim, kernel_size=3, padding=d, dilation=d),
                nn.BatchNorm1d(hidden_dim),
                nn.GELU(),
            ]
            in_channels = hidden_dim
        return nn.Sequential(*layers)

    @staticmethod
    def _local_branch(x, tcn, attn, norm, memory):
        """TCN over time, then cross-attention (query = TCN output) to ``memory``."""
        # Conv1d expects [B, C, T]; transpose in and back out.
        h = tcn(x.transpose(1, 2)).transpose(1, 2)
        ctx, _ = attn(h, memory, memory)
        return norm(h + ctx)

    def forward(self, global_feat, local_self, local_pair):
        """
        global_feat: [Batch, T_g, 128]
        local_self:  [Batch, T_l, 128]
        local_pair:  [Batch, T_l, 128]
        """
        memory = self.global_transformer(self.pos_encoder(self.global_proj(global_feat)))
        self_out = self._local_branch(local_self, self.self_tcn, self.self_attn, self.self_norm, memory)
        pair_out = self._local_branch(local_pair, self.pair_tcn, self.pair_attn, self.pair_norm, memory)
        return self_out, pair_out

class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a batch-first sequence.

    ``pe[p, 2i] = sin(p * w_i)`` and ``pe[p, 2i + 1] = cos(p * w_i)`` with
    ``w_i = 10000 ** (-2i / d_model)``.

    Args:
        d_model: feature size of the input (odd sizes are supported).
        max_len: number of positions precomputed into the ``pe`` buffer.
    """

    def __init__(self, d_model, max_len=10000):
        super().__init__()
        position = torch.arange(0, max_len, dtype=torch.float)
        self.register_buffer('pe', self._encode(position, d_model))

    @staticmethod
    def _encode(position, d_model):
        """Return a ``[len(position), d_model]`` sinusoidal table for ``position``."""
        position = position.unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2, device=position.device).float()
            * (-math.log(10000.0) / d_model)
        )
        pe = torch.zeros(position.size(0), d_model, device=position.device)
        pe[:, 0::2] = torch.sin(position * div_term)
        # With odd d_model there is one fewer cosine column than sine columns.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        return pe

    def forward(self, x):
        """Return ``x + pe[:L]`` for ``x`` of shape ``[B, L, d_model]``."""
        seq_len = x.size(1)
        if seq_len <= self.pe.size(0):
            return x + self.pe[:seq_len, :]
        # Longer than the precomputed table: compute the exact encodings rather
        # than tiling the table (tiling gave positions p and p + max_len the
        # same code).
        positions = torch.arange(seq_len, dtype=torch.float, device=x.device)
        return x + self._encode(positions, self.pe.size(1))

# Module 5: The Multi-Task Logic Head

# In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiTaskLogicHead(nn.Module):
    """Per-frame classification heads for self and pair behaviours.

    Both heads see their stream feature concatenated with a 32-d learned lab
    embedding. Pair probabilities are multiplied by a learned scalar gate
    computed from the distance between agent and target node 0 (channels 0-1).
    A separate regressor produces a per-frame "centre" score.

    Args:
        input_dim: size of ``self_feat`` / ``pair_feat`` per frame.
        num_labs: number of lab embeddings.
        num_self: number of self-behaviour outputs (default 11).
        num_pair: number of pair-behaviour outputs (default 26).
    """

    def __init__(self, input_dim=128, num_labs=20, num_self=11, num_pair=26):
        super(MultiTaskLogicHead, self).__init__()

        # 1. Domain (lab) embedding.
        self.lab_embedding = nn.Embedding(num_labs, 32)

        # 2. Feature sizes.
        fusion_dim = input_dim + 32
        expanded_dim = 256

        # 3. Head A: self behaviours.
        self.self_classifier = nn.Sequential(
            nn.Linear(fusion_dim, expanded_dim),
            nn.LayerNorm(expanded_dim),
            nn.GELU(),
            nn.Linear(expanded_dim, num_self)
        )

        # 4. Head B: pair behaviours.
        self.pair_classifier = nn.Sequential(
            nn.Linear(fusion_dim, expanded_dim),
            nn.LayerNorm(expanded_dim),
            nn.GELU(),
            nn.Linear(expanded_dim, num_pair)
        )

        # 5. Centre regressor (fed with the pair input).
        self.center_regressor = nn.Sequential(
            nn.Linear(fusion_dim, 64),
            nn.GELU(),
            nn.Linear(64, 1)
        )

        # 6. Distance gate: gate = sigmoid(w * dist + b).
        self.gate_control = nn.Linear(1, 1)

        with torch.no_grad():
            # b = 2.0 -> sigmoid(2.0) ~= 0.88, so at zero distance the gate starts mostly open.
            self.gate_control.bias.fill_(2.0)

            # Final classifier biases at logit(0.01) ~= -4.59, so initial
            # probabilities sit near 0.01 (mostly "background").
            nn.init.constant_(self.self_classifier[3].bias, -4.59)
            nn.init.constant_(self.pair_classifier[3].bias, -4.59)

            # Zero bias: sigmoid(0) = 0.5; the score starts near 0.5 only
            # insofar as the (randomly initialised) weight term is small.
            nn.init.constant_(self.center_regressor[2].bias, 0.0)

    def forward(self, self_feat, pair_feat, lab_idx, agent_c, target_c):
        """Compute per-frame behaviour probabilities.

        Args:
            self_feat, pair_feat: ``[B, T, input_dim]`` stream features.
            lab_idx: ``[B]`` integer lab indices.
            agent_c, target_c: ``[B, T, N, F]`` node tensors; node 0,
                channels 0-1 feed the distance gate.
        Returns:
            ``(self_probs [B, T, num_self], pair_probs [B, T, num_pair],
            center_score [B, T, 1])``, all in (0, 1). ``pair_probs`` already
            includes the gate.
        """
        batch, time, _ = self_feat.shape

        # A. Lab context broadcast over time.
        lab_context = self.lab_embedding(lab_idx).unsqueeze(1).expand(-1, time, -1)
        self_input = torch.cat([self_feat, lab_context], dim=-1)
        pair_input = torch.cat([pair_feat, lab_context], dim=-1)

        # B. Raw logits.
        self_logits = self.self_classifier(self_input)
        pair_logits = self.pair_classifier(pair_input)

        center_score = torch.sigmoid(self.center_regressor(pair_input))

        # C. Distance gate.
        # NOTE(review): if agent_c/target_c come from CanonicalGraphAdapter,
        # channels 0-1 are refine-MLP outputs rather than raw coordinates, so
        # this "distance" is a learned proxy -- confirm this is intended.
        a_pos = agent_c[:, :, 0, :2]
        t_pos = target_c[:, :, 0, :2]
        dist = torch.norm(a_pos - t_pos, dim=-1, keepdim=True)
        gate = torch.sigmoid(self.gate_control(dist))

        # D. Activation; only the pair head is gated.
        self_probs = torch.sigmoid(self_logits)
        pair_probs = torch.sigmoid(pair_logits) * gate

        return self_probs, pair_probs, center_score

# Module 6: Final Assembly (EthoSwarmNet V3)

# In [6]:
import torch
import torch.nn as nn

# ==============================================================================
# BEHAVIOR DEFINITIONS (For Output Stitching)
# ==============================================================================
# All 37 actions sorted alphabetically (Competition Standard)
# NOTE: redefines the same 37 names as the ACTION_LIST in the data-engine cell
# (ACTION_TO_IDX / NUM_CLASSES there were built from that earlier copy) --
# keep the two lists in sync.
ACTION_LIST = sorted([
    "allogroom", "approach", "attack", "attemptmount", "avoid", "biteobject",
    "chase", "chaseattack", "climb", "defend", "dig", "disengage", "dominance",
    "dominancegroom", "dominancemount", "ejaculate", "escape", "exploreobject",
    "flinch", "follow", "freeze", "genitalgroom", "huddle", "intromit", "mount",
    "rear", "reciprocalsniff", "rest", "run", "selfgroom", "shepherd", "sniff",
    "sniffbody", "sniffface", "sniffgenital", "submit", "tussle"
])

# Subset: 11 Self Behaviors (Agent only)
# Size matches the 11 outputs of MultiTaskLogicHead's self head; EthoSwarmNet
# maps each name (in this sorted order) to its ACTION_LIST position via
# _get_indices. Together with PAIR_BEHAVIORS this partitions ACTION_LIST:
# every action appears in exactly one subset (11 + 26 = 37).
SELF_BEHAVIORS = sorted([
    "biteobject", "climb", "dig", "exploreobject", "freeze", "genitalgroom",
    "huddle", "rear", "rest", "run", "selfgroom"
])

# Subset: 26 Pair Behaviors (Agent + Target)
# Size matches the 26 outputs of MultiTaskLogicHead's pair head.
PAIR_BEHAVIORS = sorted([
    "allogroom", "approach", "attack", "attemptmount", "avoid", "chase",
    "chaseattack", "defend", "disengage", "dominance", "dominancegroom",
    "dominancemount", "ejaculate", "escape", "flinch", "follow", "intromit",
    "mount", "reciprocalsniff", "shepherd", "sniff", "sniffbody", "sniffface",
    "sniffgenital", "submit", "tussle"
])

class EthoSwarmNet(nn.Module):
    """EthoSwarmNet V3: topology -> split-stream -> time -> logic -> stitch.

    The logic head returns separate probability tensors for the self-directed
    behaviors (``SELF_BEHAVIORS``) and the agent/target behaviors
    (``PAIR_BEHAVIORS``). ``forward`` scatters both into one tensor whose last
    axis follows ``ACTION_LIST`` order, which is what the training loop's
    targets and lab masks are indexed by.

    Args:
        num_classes: Width of the stitched output (default 37, one column per
            ``ACTION_LIST`` entry). Must be large enough to hold every
            ``ACTION_LIST`` index of ``SELF_BEHAVIORS`` and ``PAIR_BEHAVIORS``.
        input_dim: Kept for interface compatibility. The sub-modules below are
            currently built with a fixed width of 128 regardless of this value.

    Raises:
        ValueError: If a behavior in ``SELF_BEHAVIORS``/``PAIR_BEHAVIORS`` is
            missing from ``ACTION_LIST``, or if ``num_classes`` is too small to
            hold the stitched indices.
    """

    def __init__(self, num_classes=37, input_dim=128):
        super().__init__()
        self.num_classes = num_classes
        self.input_dim = input_dim  # NOTE: not wired into the sub-modules yet

        # ----------------------------------------------------------------------
        # 1. Morphological Core (Module 2)
        # ----------------------------------------------------------------------
        self.morph_core = MorphologicalInteractionCore(num_labs=20)

        # ----------------------------------------------------------------------
        # 2. Split-Stream Block (Module 3)
        # ----------------------------------------------------------------------
        self.split_interaction = SplitStreamInteractionBlock(hidden_dim=128)

        # ----------------------------------------------------------------------
        # 3. Local-Global Chronos (Module 4)
        # ----------------------------------------------------------------------
        self.chronos = LocalGlobalChronosEncoder(input_dim=128, hidden_dim=128)

        # ----------------------------------------------------------------------
        # 4. Multi-Task Logic Head (Module 5)
        # ----------------------------------------------------------------------
        self.logic_head = MultiTaskLogicHead(
            input_dim=128,
            num_labs=20
        )

        # ----------------------------------------------------------------------
        # 5. Output Stitching Maps (column positions in the full output)
        # ----------------------------------------------------------------------
        self_indices = self._get_indices(SELF_BEHAVIORS)
        pair_indices = self._get_indices(PAIR_BEHAVIORS)
        all_indices = torch.cat([self_indices, pair_indices])
        # Fail at construction rather than inside index_copy_ during forward.
        if all_indices.numel() and int(all_indices.max()) >= num_classes:
            raise ValueError(
                f"num_classes={num_classes} cannot hold stitched index "
                f"{int(all_indices.max())}; expected at least {len(ACTION_LIST)}"
            )
        self.register_buffer('self_indices', self_indices)
        self.register_buffer('pair_indices', pair_indices)

    def _get_indices(self, subset_list):
        """Map behavior names to their column indices in ``ACTION_LIST``.

        Raises:
            ValueError: If any name is not in ``ACTION_LIST``. Dropping it
                silently would leave the index buffer shorter than the head's
                output and only surface later as an ``index_copy_`` size
                mismatch.
        """
        missing = [beh for beh in subset_list if beh not in ACTION_LIST]
        if missing:
            raise ValueError(f"Behaviors not found in ACTION_LIST: {missing}")
        return torch.tensor([ACTION_LIST.index(beh) for beh in subset_list], dtype=torch.long)

    def forward(self, global_agent, global_target, local_agent, local_target, lab_idx,
                return_center_score=False):
        """Run the V3 pipeline and stitch the heads into one output tensor.

        Global/Local Streams -> Topology -> Split -> Time -> Logic -> Stitch

        Args:
            global_agent, global_target: Global-stream inputs to the
                morphological core.
            local_agent, local_target: Local-stream inputs to the
                morphological core.
            lab_idx: Lab indices, passed to the morphological core and the
                logic head.
            return_center_score: If True, also return the logic head's
                center score.

        Returns:
            ``final_output`` of shape ``[batch, time, num_classes]``, holding
            the head probabilities at the ``self_indices``/``pair_indices``
            columns and 0 elsewhere. If ``return_center_score`` is True,
            returns ``(final_output, center_score)`` instead.
        """
        # --- A. TOPOLOGY (Module 2) ---
        g_out, _, _ = self.morph_core(global_agent, global_target, lab_idx)
        _, l_ac, l_tc = self.morph_core(local_agent, local_target, lab_idx)

        # --- B. SPLIT-STREAM (Module 3) ---
        l_self, l_pair = self.split_interaction(l_ac, l_tc)

        # --- C. TIME & CONTEXT (Module 4) ---
        t_self, t_pair = self.chronos(g_out, l_self, l_pair)

        # --- D. LOGIC & PHYSICS (Module 5) ---
        # center_score is the logic head's regression output (sigmoid, 0-1).
        p_self, p_pair, center_score = self.logic_head(t_self, t_pair, lab_idx, l_ac, l_tc)

        # --- E. OUTPUT STITCHING ---
        batch, n_frames, _ = p_self.shape
        # new_zeros keeps p_self's device and dtype.
        final_output = p_self.new_zeros(batch, n_frames, self.num_classes)
        final_output.index_copy_(2, self.self_indices, p_self)
        final_output.index_copy_(2, self.pair_indices, p_pair)

        # By default center_score is discarded, so it contributes no gradient.
        # To train the center regressor, request it here and add it to the loss.
        if return_center_score:
            return final_output, center_score
        return final_output

# Module 7: The Training Loop & Validation

In [7]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
import numpy as np
import pandas as pd
import os
import glob
import json
import torch.nn.functional as F

# ==============================================================================
# UTILS & METRICS
# ==============================================================================
def load_lab_vocabulary(vocab_path, action_to_idx, num_classes, device, lab_names=None):
    """Build a per-lab mask of which actions each lab annotates.

    Args:
        vocab_path: Path to a JSON file mapping lab name -> list of action names.
        action_to_idx: Mapping action name -> column index in the mask.
        num_classes: Number of mask columns.
        device: Device for the returned tensor.
        lab_names: Row order of the mask. Defaults to ``sorted(LAB_CONFIGS)``,
            which matches the Module 1 lab-index order.

    Returns:
        Float tensor ``[len(lab_names), num_classes]`` with 1.0 where the lab
        annotates the action. Unknown action names in the file are ignored.
        Labs absent from the file get an all-ones ("allow all") row, and so
        does every lab when the file is missing.
    """
    if lab_names is None:
        # Must sort keys to match Module 1 index order
        lab_names = sorted(LAB_CONFIGS.keys())
    else:
        lab_names = list(lab_names)

    # Default to "Allow All" if the file is missing. Use the same shape as the
    # file-present path so that lab_masks[lid] behaves identically either way.
    if not os.path.exists(vocab_path):
        return torch.ones(len(lab_names), num_classes, device=device)

    with open(vocab_path, 'r', encoding='utf-8') as f:
        vocab = json.load(f)

    mask = torch.zeros(len(lab_names), num_classes, device=device)
    for row, name in enumerate(lab_names):
        if name not in vocab:
            mask[row, :] = 1.0
            continue
        for action in vocab[name]:
            col = action_to_idx.get(action)
            if col is not None:
                mask[row, col] = 1.0
    return mask

def get_batch_f1(probs_in, targets, batch_vocab_mask, temporal_weights, threshold=0.4):
    """Micro-averaged F1 over the valid (frame, class) cells of a batch.

    Args:
        probs_in: Probabilities in [0, 1], shape ``[B, T, C]``. The model
            already emits probabilities, so no sigmoid is applied here.
        targets: Binary targets, shape ``[B, T, C]``.
        batch_vocab_mask: Per-sample class mask, shape ``[B, C]``
            (1 = the sample's lab annotates that class).
        temporal_weights: Per-frame weights, shape ``[B, T]`` (presumably 0
            on padded frames).
        threshold: Probability strictly above which a cell counts as a
            positive prediction.

    Returns:
        F1 as a Python float. It is ~0 when there are no true positives, and
        the 1e-6 term prevents division by zero.
    """
    # 1. Binarize predictions.
    preds = (probs_in > threshold).float()

    # 2. Combine masks: [B, T, 1] * [B, 1, C] -> [B, T, C].
    valid_pixels = temporal_weights.unsqueeze(-1) * batch_vocab_mask.unsqueeze(1)

    # 3. Count only valid cells.
    tp = (preds * targets * valid_pixels).sum()
    fp = (preds * (1 - targets) * valid_pixels).sum()
    fn = ((1 - preds) * targets * valid_pixels).sum()

    f1 = 2 * tp / (2 * tp + fp + fn + 1e-6)
    return f1.item()

# ==============================================================================
# LOSS FUNCTION
# ==============================================================================
class DualStreamMaskedLoss(nn.Module):
    """Masked BCE summed over the self-behavior and pair-behavior column groups.

    The model already outputs probabilities (sigmoid + physics gate), so this
    applies ``F.binary_cross_entropy`` to clamped probabilities rather than the
    with-logits variant. Each group's loss is a weighted mean over the cells
    selected by both the temporal mask and the sample's lab vocabulary mask.

    Args:
        model_self_indices: Output columns holding the self behaviors.
        model_pair_indices: Output columns holding the pair behaviors.
    """

    def __init__(self, model_self_indices, model_pair_indices):
        super().__init__()
        self.self_idx = model_self_indices
        self.pair_idx = model_pair_indices

    @staticmethod
    def _group_loss(probs, target, columns, lab_vocab_mask, time_mask):
        """Masked-mean BCE for one group of output columns."""
        # Clamp away from 0/1 so log() stays finite.
        group_probs = torch.clamp(probs[:, :, columns], 1e-7, 1 - 1e-7)
        group_target = target[:, :, columns]
        class_mask = lab_vocab_mask[:, columns].unsqueeze(1)  # [B, 1, n_cols]
        elementwise = F.binary_cross_entropy(group_probs, group_target, reduction='none')
        numerator = (elementwise * class_mask * time_mask).sum()
        denominator = (class_mask * time_mask).sum() + 1e-6
        return numerator / denominator

    def forward(self, model_output_probs, target, weight_mask, lab_vocab_mask):
        """Return self-group loss + pair-group loss.

        Args:
            model_output_probs: Probabilities ``[B, T, C]`` (already 0-1).
            target: Targets ``[B, T, C]``.
            weight_mask: Temporal weights ``[B, T]``.
            lab_vocab_mask: Per-sample class mask ``[B, C]``.
        """
        time_mask = weight_mask.unsqueeze(-1)  # [B, T, 1]
        self_term = self._group_loss(model_output_probs, target, self.self_idx, lab_vocab_mask, time_mask)
        pair_term = self._group_loss(model_output_probs, target, self.pair_idx, lab_vocab_mask, time_mask)
        return self_term + pair_term

# ==============================================================================
# TRAINING CONTROLLER
# ==============================================================================
def _resolve_mabe_data_path():
    """Return the MABe dataset root directory, or None if it cannot be found.

    A notebook-level ``mabe_mouse_behavior_detection_path`` variable takes
    precedence over the default Kaggle input location.
    """
    if 'mabe_mouse_behavior_detection_path' in globals():
        return globals()['mabe_mouse_behavior_detection_path']
    if os.path.exists('/kaggle/input/MABe-mouse-behavior-detection'):
        return '/kaggle/input/MABe-mouse-behavior-detection'
    return None


def _mean_validation_loss(model, loader, loss_fn, lab_masks, device):
    """Average ``loss_fn`` over every batch of ``loader`` with the model in eval mode.

    Returns None when the loader yields no batches (e.g. an empty validation
    split) instead of dividing by zero.
    """
    model.eval()
    loss_sum = 0.0
    n_batches = 0
    with torch.no_grad():
        for batch in loader:
            gx, lx, tgt, weights, lid = [b.to(device) for b in batch]
            probs = model(gx, gx, lx, lx, lid)
            loss_sum += loss_fn(probs, tgt, weights, lab_masks[lid]).item()
            n_batches += 1
    return loss_sum / n_batches if n_batches else None


def train_ethoswarm_v3(num_epochs=3, learning_rate=3e-4, seed=None):
    """Train EthoSwarmNet V3 on a video-level 90/10 train/validation split.

    A checkpoint ``ethoswarm_v3_ep{N}.pth`` is written after every epoch. When
    the model is wrapped in ``nn.DataParallel``, the inner module's state dict
    is saved so the checkpoint loads into a plain ``EthoSwarmNet``.

    Args:
        num_epochs: Number of passes over the training split.
        learning_rate: Peak learning rate for AdamW / OneCycleLR.
        seed: Optional seed for the video shuffle. ``None`` uses NumPy's
            global RNG (as before), so the split can differ between runs.
    """
    # --- 1. SETUP & PATHS ---
    DATA_PATH = _resolve_mabe_data_path()
    if DATA_PATH is None:
        print("Dataset not found.")
        return

    VOCAB_PATH = '/kaggle/input/mabe-metadata/results/lab_vocabulary.json'

    gpu_count = torch.cuda.device_count()
    DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
    BATCH_SIZE = 8 * max(1, gpu_count)

    print(f"Start Training on {gpu_count} GPU(s) | Batch Size: {BATCH_SIZE}")

    # --- 2. DATA PREP (Strict Video Split: no video appears in both splits) ---
    meta = pd.read_csv(f"{DATA_PATH}/train.csv")
    vids = meta['video_id'].astype(str).unique()
    if seed is None:
        np.random.shuffle(vids)
    else:
        np.random.default_rng(seed).shuffle(vids)

    split = int(len(vids) * 0.90)
    train_ids = vids[:split]
    val_ids = vids[split:]

    # Loaders - using Module 1 (Cached)
    train_ds = BioPhysicsDataset(DATA_PATH, 'train', video_ids=train_ids)
    val_ds = BioPhysicsDataset(DATA_PATH, 'train', video_ids=val_ids)

    train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, collate_fn=pad_collate_dual, num_workers=2)
    val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, collate_fn=pad_collate_dual, num_workers=2)

    # --- 3. MODEL INITIALIZATION ---
    model = EthoSwarmNet(num_classes=NUM_CLASSES, input_dim=128)
    model.to(DEVICE)
    if gpu_count > 1:
        # Single-process nn.DataParallel, not DistributedDataParallel.
        print(f"--> Activating DataParallel on {gpu_count} GPUs")
        model = nn.DataParallel(model)

    optimizer = optim.AdamW(model.parameters(), lr=learning_rate)
    scheduler = optim.lr_scheduler.OneCycleLR(optimizer, max_lr=learning_rate, steps_per_epoch=len(train_loader), epochs=num_epochs)

    # Load Masks
    lab_masks = load_lab_vocabulary(VOCAB_PATH, ACTION_TO_IDX, NUM_CLASSES, DEVICE)

    # Reuse the module-level behavior split instead of a second hand-copied list.
    # Assumes ACTION_TO_IDX follows ACTION_LIST order, which is the order
    # EthoSwarmNet stitches its heads into.
    self_indices = [ACTION_TO_IDX[a] for a in SELF_BEHAVIORS]
    pair_indices = [ACTION_TO_IDX[a] for a in PAIR_BEHAVIORS]
    loss_fn = DualStreamMaskedLoss(self_indices, pair_indices)

    # --- 4. EPOCH LOOP ---
    for epoch in range(num_epochs):
        model.train()
        loop = tqdm(train_loader, desc=f"Ep {epoch+1}")

        run_loss = 0.0
        run_f1 = 0.0

        for i, batch in enumerate(loop):
            # Move the 5 batch items to the device.
            gx, lx, tgt, weights, lid = [b.to(DEVICE) for b in batch]

            optimizer.zero_grad()

            # Output is already PROBABILITIES (from Module 5+6).
            # NOTE(review): the same tensors are passed as agent and target;
            # confirm this matches the dataset's combined-feature layout.
            probs = model(gx, gx, lx, lx, lid)

            loss = loss_fn(probs, tgt, weights, lab_masks[lid])

            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            scheduler.step()

            with torch.no_grad():
                f1 = get_batch_f1(probs, tgt, lab_masks[lid], weights)

            # Exponential moving averages for a smoother progress display.
            run_loss = 0.9 * run_loss + 0.1 * loss.item() if i > 0 else loss.item()
            run_f1 = 0.9 * run_f1 + 0.1 * f1 if i > 0 else f1

            if i % 20 == 0:
                loop.set_postfix({'Loss': f"{run_loss:.4f}", 'F1': f"{run_f1:.3f}"})

        # Validation
        print("Validating...")
        val_loss = _mean_validation_loss(model, val_loader, loss_fn, lab_masks, DEVICE)
        if val_loss is None:
            print("Val Loss: n/a (validation split is empty)")
        else:
            print(f"Val Loss: {val_loss:.4f}")

        state = model.module.state_dict() if isinstance(model, nn.DataParallel) else model.state_dict()
        torch.save(state, f"ethoswarm_v3_ep{epoch+1}.pth")

if __name__ == '__main__':
    # Training is intentionally disabled in this cell; uncomment to run it.
    # train_ethoswarm_v3()
    pass

# Module 8: Final Inference

In [8]:
# # ==============================================================================
# # MODULE 8: FINAL INFERENCE (CALIBRATED THRESHOLDS)
# # ==============================================================================
# import os
# import gc
# import glob
# import ast
# import numpy as np
# import pandas as pd
# import torch
# from torch.utils.data import Dataset, DataLoader
# from tqdm.auto import tqdm
# from scipy.ndimage import median_filter
# from pathlib import Path

# # --- 1. CONFIGURATION ---
# try:
#     DEVICE
# except NameError:
#     DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# TEST_CSV = '/kaggle/input/MABe-mouse-behavior-detection/test.csv'
# TEST_TRACKING = '/kaggle/input/MABe-mouse-behavior-detection/test_tracking'
# SUBMISSION_PATH = 'submission.csv'

# def find_weights(filename='ethoswarm_v3_ep3.pth'):
#     files = glob.glob(f"/kaggle/input/**/{filename}", recursive=True)
#     if files: return files[0]
#     return filename 

# WEIGHTS_PATH = find_weights('ethoswarm_v3_ep3.pth')
# INFERENCE_CHUNK_SIZE = 4000
# GPU_BATCH_SIZE = 8
# NUM_WORKERS = 0 

# # --- 2. CALIBRATED THRESHOLDS ---
# # Derived from Forensic Report: Threshold = ~0.4 * Max_Observed_Conf
# TH = {
#     "attack": 0.12,       # Max was 0.30
#     "rear": 0.15,         # Max was 0.46
#     "sniff": 0.05,        # Max was 0.13
#     "approach": 0.05,     # Max was 0.14
#     "mount": 0.06,        # Max was 0.14
#     "intromit": 0.06,     # Max was 0.14
#     "chase": 0.03,        # Max was 0.07 -> LOWERED
#     "escape": 0.01,       # Max was 0.015 -> LOWERED
#     "submit": 0.015,      # Max was 0.036 -> LOWERED (Critical Fix)
#     "avoid": 0.04,        # Max was 0.10
#     "biteobject": 0.02,   # Max was 0.04
#     "climb": 0.02,        # Max was 0.05
#     "dominance": 0.04,    # Max was 0.09
#     "tussle": 0.01        # Max was 0.012
# }
# DEF_TH = 0.02 # Catch-all for very weak signals

# BODY_PARTS = ["ear_left", "ear_right", "nose", "neck", "body_center","lateral_left", "lateral_right", "hip_left", "hip_right","tail_base", "tail_tip"]

# ACTION_LIST = sorted([
#     "allogroom", "approach", "attack", "attemptmount", "avoid", "biteobject",
#     "chase", "chaseattack", "climb", "defend", "dig", "disengage", "dominance",
#     "dominancegroom", "dominancemount", "ejaculate", "escape", "exploreobject",
#     "flinch", "follow", "freeze", "genitalgroom", "huddle", "intromit", "mount",
#     "rear", "reciprocalsniff", "rest", "run", "selfgroom", "shepherd", "sniff",
#     "sniffbody", "sniffface", "sniffgenital", "submit", "tussle"
# ])

# # Separation for Masking
# SELF_BEHAVIORS = sorted(["biteobject", "climb", "dig", "exploreobject", "freeze", "genitalgroom", "huddle", "rear", "rest", "run", "selfgroom"])
# PAIR_BEHAVIORS = sorted(list(set(ACTION_LIST) - set(SELF_BEHAVIORS)))
# SELF_IDXS = [ACTION_LIST.index(a) for a in SELF_BEHAVIORS]
# PAIR_IDXS = [ACTION_LIST.index(a) for a in PAIR_BEHAVIORS]

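# # NOTE(review): this LAB_CONFIGS is truncated to two entries. As a result,
# # LAB_NAME_TO_IDX below maps every lab except AdaptableSnail to index 0 (via
# # .get(job['lab'], 0)), and _geo_features falls back to DEFAULT pix_cm for
# # those labs. Training indexes labs over the full sorted LAB_CONFIGS, so the
# # lab index and pixel scale will not match for other labs. Restore the full
# # table before enabling this cell.
# LAB_CONFIGS = {
#     "AdaptableSnail":       {"thresh": 718.59, "window": 120, "pix_cm": 14.5},
#     "DEFAULT":              {"thresh": 150.00, "window": 128, "pix_cm": 15.0}
# }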
# LAB_NAME_TO_IDX = {name: i for i, name in enumerate(sorted(LAB_CONFIGS.keys()))}

# # ==============================================================================
# # 8.1 PARSER
# # ==============================================================================
# def parse_test_tasks(test_csv_path):
#     if not os.path.exists(test_csv_path): return []
#     df = pd.read_csv(test_csv_path)
#     tasks = []
    
#     print(f"Scanning {len(df)} videos for tasks...")
#     for i, row in df.iterrows():
#         vid = str(row['video_id'])
#         lab = row['lab_id']
#         raw = row.get('behaviors_labeled', None)
#         try:
#             if pd.isna(raw):
#                 tasks.append({'vid': vid, 'lab': lab, 'ag': 'mouse1', 'tg': 'mouse2'})
#                 tasks.append({'vid': vid, 'lab': lab, 'ag': 'mouse2', 'tg': 'mouse1'})
#                 continue
            
#             labels = ast.literal_eval(raw)
#             seen = set()
#             for item in labels:
#                 if isinstance(item, str): parts = item.split(',')
#                 else: parts = item 
                
#                 if len(parts) >= 2:
#                     ag = parts[0].strip()
#                     tg = parts[1].strip()
#                     if tg.lower() == 'self' or tg == '': tg = 'self'
                    
#                     if (ag, tg) not in seen:
#                         tasks.append({'vid': vid, 'lab': lab, 'ag': ag, 'tg': tg})
#                         seen.add((ag, tg))
#         except:
#             tasks.append({'vid': vid, 'lab': lab, 'ag': 'mouse1', 'tg': 'mouse2'})
            
#     print(f"Total Unique Tasks: {len(tasks)}")
#     return tasks

# # ==============================================================================
# # 8.2 DATASET
# # ==============================================================================
# class TaskInferenceDataset(Dataset):
#     def __init__(self, tasks, tracking_dir):
#         self.samples = tasks
#         self.tracking_dir = Path(tracking_dir)
#         self.bp_idx = {b:i for i,b in enumerate(BODY_PARTS)}
#         self.ALIASES = {'hip_left': ['lateral_left', 'tail_base'], 'hip_right': ['lateral_right', 'tail_base'], 'neck': ['headpiece_bottombackright', 'nose'], 'body_center': ['tail_base']}

#     def __len__(self): return len(self.samples)

#     def _fix_teleport(self, pos):
#         T, N, _ = pos.shape
#         missing = (np.abs(pos).sum(axis=2) < 1e-6)
#         cleaned = pos.copy()
#         cols = np.where(missing.any(axis=0))[0]
#         if len(cols) > 0:
#             t = np.arange(T)
#             for n in cols:
#                 mask = missing[:, n]
#                 if np.any(~mask):
#                     cleaned[mask, n, 0] = np.interp(t[mask], t[~mask], pos[~mask, n, 0])
#                     cleaned[mask, n, 1] = np.interp(t[mask], t[~mask], pos[~mask, n, 1])
#         return cleaned

#     def _geo_features(self, m1, m2, lab):
#         # NORMALIZED SCALE
#         conf = LAB_CONFIGS.get(lab, LAB_CONFIGS['DEFAULT'])
#         px = float(conf['pix_cm']) + 1e-6
#         m1 = m1.astype(np.float32) / px
#         m2 = m2.astype(np.float32) / px
        
#         origin = m1[:, 9:10]; vec = m1[:, 3:4] - origin
#         ang = np.arctan2(vec[..., 0], vec[..., 1]).flatten()
#         c, s = np.cos(ang), np.sin(ang)
        
#         def rotate(p):
#             p_c = p - origin
#             px = p_c[...,0] * c[:,None] - p_c[...,1] * s[:,None]
#             py = p_c[...,0] * s[:,None] + p_c[...,1] * c[:,None]
#             return np.stack([px, py], axis=-1)

#         # Rotate
#         r_m1 = rotate(m1)
#         r_m2 = rotate(m2) 
        
#         rx, ry = r_m1[..., 0], r_m1[..., 1]
#         ox, oy = r_m2[..., 0], r_m2[..., 1]
        
#         vel_x = np.diff(rx, axis=0, prepend=rx[0:1])
#         vel_y = np.diff(ry, axis=0, prepend=ry[0:1])
#         spd = np.sqrt(vel_x**2 + vel_y**2)
#         dist = np.sqrt((rx - ox)**2 + (ry - oy)**2)
        
#         feats = np.stack([rx, ry, vel_x, vel_y, spd, dist, ox, oy], axis=-1)
#         L, N, _ = feats.shape
#         padded = np.zeros((L, N, 16), dtype=np.float32)
#         padded[..., :8] = feats
        
#         padded = np.nan_to_num(padded, nan=0.0)
#         padded = np.clip(padded, -30.0, 30.0) 
        
#         return padded

#     def __getitem__(self, idx):
#         job = self.samples[idx]
#         job_out = job.copy()
        
#         try:
#             p = self.tracking_dir / job['lab'] / f"{job['vid']}.parquet"
#             if not p.exists(): return torch.zeros(1), 0, job_out

#             df = pd.read_parquet(p)
#             max_f = int(df.video_frame.max() + 1) if 'video_frame' in df else len(df)
#             limit = min(max_f, 3000000)
            
#             uids = df.mouse_id.unique()
#             def resolve(req):
#                 if req == 'self': return None
#                 for u in uids:
#                     if str(u) == str(req): return u
#                     if f"mouse{u}" == str(req): return u
#                 return uids[0]

#             real_ag = resolve(job['ag'])
#             real_tg = resolve(job['tg'])
            
#             if real_tg is None:
#                 real_tg = real_ag
#                 job_out['is_self'] = True
#             else:
#                 job_out['is_self'] = False
                
#             job_out['sub_ag'] = f"mouse{real_ag}" if str(real_ag).isdigit() else str(real_ag)
#             job_out['sub_tg'] = 'self' if job_out['is_self'] else (f"mouse{real_tg}" if str(real_tg).isdigit() else str(real_tg))

#             raw_ag = np.zeros((limit, 11, 2), dtype=np.float32)
#             raw_tg = np.zeros_like(raw_ag)
#             avail = set(df['bodypart'].unique())
            
#             def extract(mid, dest):
#                 m_rows = df[df.mouse_id == mid]
#                 groups = dict(tuple(m_rows.groupby('bodypart')))
#                 for bp, idx in self.bp_idx.items():
#                     src = bp
#                     if src not in avail and bp in self.ALIASES:
#                         for a in self.ALIASES[bp]:
#                             if a in avail: src = a; break
#                     if src in groups:
#                         g = groups[src]
#                         f = g.video_frame.values.astype(int) if 'video_frame' in g else np.arange(len(g))
#                         v = f < limit
#                         dest[f[v], idx] = g[['x','y']].values[v]

#             extract(real_ag, raw_ag)
#             extract(real_tg, raw_tg)

#             raw_ag = self._fix_teleport(raw_ag)
#             raw_tg = self._fix_teleport(raw_tg)
            
#             for m in [raw_ag, raw_tg]:
#                 tail_base = m[:, 9]
#                 for i in [3, 4, 7, 8]: m[:, i] = tail_base if m[:, i].sum() == 0 else m[:, i]

#             feats = self._geo_features(raw_ag, raw_tg, job['lab'])
#             lid = LAB_NAME_TO_IDX.get(job['lab'], 0)
            
#             return torch.tensor(feats), lid, job_out

#         except: return torch.zeros(1), 0, job_out

# def collate_fn(b): return b[0]

# # ==============================================================================
# # 8.3 MAIN EXECUTION
# # ==============================================================================
# if __name__ == "__main__":
#     print(f"--- STARTING MODULE 8: FINAL INFERENCE (CALIBRATED) ---")
    
#     if not os.path.exists(TEST_CSV):
#         pd.DataFrame(columns=['row_id','video_id','action']).to_csv(SUBMISSION_PATH, index=False); exit()

#     tasks = parse_test_tasks(TEST_CSV)
#     ds = TaskInferenceDataset(tasks, TEST_TRACKING)
#     loader = DataLoader(ds, batch_size=1, collate_fn=collate_fn, num_workers=NUM_WORKERS)
    
#     try: model = EthoSwarmNet(37, 128).to(DEVICE)
#     except: print("ERROR: EthoSwarmNet not found."); exit()

#     if os.path.exists(WEIGHTS_PATH):
#         try:
#             model.load_state_dict(torch.load(WEIGHTS_PATH, map_location=DEVICE), strict=False)
#             print(f"Weights loaded from {WEIGHTS_PATH}")
#         except: pass
#     model.eval()
    
#     first_pass = True
#     with open(SUBMISSION_PATH, 'w') as f:
#         f.write("row_id,video_id,agent_id,target_id,action,start_frame,stop_frame\n")
#         row_id = 0
        
#         with torch.no_grad():
#             for feat, lid, job in tqdm(loader, desc="Inferring"):
#                 if feat.dim() < 3 or feat.shape[0] < 5: continue
                
#                 T = feat.shape[0]
#                 rem = T % INFERENCE_CHUNK_SIZE
#                 if rem > 0:
#                     pad = INFERENCE_CHUNK_SIZE - rem
#                     feat = torch.cat([feat, feat[-1:].repeat(pad,1,1)], 0)
                
#                 c_f = feat.view(-1, INFERENCE_CHUNK_SIZE, 11, 16)
#                 preds_list = []
#                 for i in range(len(c_f)):
#                     ba = c_f[i:i+1].to(DEVICE)
#                     l_idx = torch.tensor([lid], device=DEVICE)
#                     out = model(ba, ba, ba, ba, l_idx)
#                     preds_list.append(out.cpu())
                
#                 # --- RAW PROBABILITIES ---
#                 probs = torch.cat(preds_list, 1).flatten(0, 1)[:T].numpy()
#                 row_max = probs.max(axis=1, keepdims=True) + 1e-6
#                 probs_norm = probs / row_max 
                
#                 # --- DEBUG SNAPSHOT ---
#                 if first_pass:
#                     print(f"\n--- DEBUG SNAPSHOT ({job['vid']}) ---")
#                     print(f"Task: {job['sub_ag']} -> {job['sub_tg']}")
#                     mid = T // 2
#                     top3 = np.argsort(probs_norm[mid])[-3:][::-1]
#                     print(f"Pred (Norm): {[f'{ACTION_LIST[i]}({probs_norm[mid,i]:.3f})' for i in top3]}")
#                     first_pass = False
                
#                 aid, tid = job['sub_ag'], job['sub_tg']
                
#                 # --- STRICT MASKING ---
#                 if job['is_self']:
#                     probs_norm[:, PAIR_IDXS] = 0.0 
#                 else:
#                     probs_norm[:, SELF_IDXS] = 0.0
                
#                 # --- TOP-K GATING ---
#                 topk_vals = np.sort(probs_norm, axis=1)[:, -3:] 
#                 min_k = topk_vals[:, 0:1] 
#                 mask_k = (probs_norm >= min_k).astype(np.float32)
#                 probs_norm *= mask_k

#                 for cls_idx, act in enumerate(ACTION_LIST):
#                     trace = probs_norm[:, cls_idx]
#                     if trace.max() < 0.01: continue
                    
#                     # Use Calibrated Threshold
#                     thresh = TH.get(act, DEF_TH)
                    
#                     smoothed = median_filter(trace, size=7)
#                     binary = (smoothed > thresh).astype(np.int8)
                    
#                     if binary.sum() == 0: continue
#                     diffs = np.diff(np.concatenate(([0], binary, [0])))
#                     starts = np.where(diffs == 1)[0]
#                     ends = np.where(diffs == -1)[0]
                    
#                     for s, e in zip(starts, ends):
#                         if (e - s) < 2: continue
#                         f.write(f"{row_id},{job['vid']},{aid},{tid},{act},{s},{e}\n")
#                         row_id += 1
                
#                 del feat, c_f, probs, preds_list
    
#     print(f"Success! Generated {row_id} rows.")
#     print(f"Submission saved to {SUBMISSION_PATH}")

In [9]:
# # ==============================================================================
# # MODULE 9: FORENSIC VALIDATION (SCALE 1.0 ONLY)
# # ==============================================================================
# import pandas as pd
# import numpy as np
# import torch
# import glob
# import os
# import random
# from torch.utils.data import Dataset, DataLoader
# from pathlib import Path

# # --- CONFIGURATION ---
# LAB_ID = 'AdaptableSnail'
# MODEL_PATH = '/kaggle/input/mabe-separated/ethoswarm_v3_ep3.pth'
# TRACKING_PATH = '/kaggle/input/MABe-mouse-behavior-detection/train_tracking'
# ANNOTATION_PATH = '/kaggle/input/MABe-mouse-behavior-detection/train_annotation'
# SCALE = 1.0 # Locked per your request

# # --- CONSTANTS ---
# ACTION_LIST = sorted([
#     "allogroom", "approach", "attack", "attemptmount", "avoid", "biteobject",
#     "chase", "chaseattack", "climb", "defend", "dig", "disengage", "dominance",
#     "dominancegroom", "dominancemount", "ejaculate", "escape", "exploreobject",
#     "flinch", "follow", "freeze", "genitalgroom", "huddle", "intromit", "mount",
#     "rear", "reciprocalsniff", "rest", "run", "selfgroom", "shepherd", "sniff",
#     "sniffbody", "sniffface", "sniffgenital", "submit", "tussle"
# ])
# BODY_PARTS = ["ear_left", "ear_right", "nose", "neck", "body_center","lateral_left", "lateral_right", "hip_left", "hip_right","tail_base", "tail_tip"]
# LAB_CONFIGS = {"AdaptableSnail": {"thresh": 718.59, "window": 120, "pix_cm": 14.5}, "DEFAULT": {"thresh": 150.0, "window": 128, "pix_cm": 15.0}}

# # --- DATASET ---
# class ValidationDataset(Dataset):
#     def __init__(self, vid_id, lab_id, tracking_dir, agent_id, target_id):
#         self.vid = vid_id; self.lab = lab_id; self.tracking_dir = Path(tracking_dir)
#         self.agent_id = agent_id; self.target_id = target_id
#         self.p_path = self.tracking_dir / self.lab / f"{self.vid}.parquet"
#         self.ALIASES = {'hip_left': ['lateral_left', 'tail_base'], 'hip_right': ['lateral_right', 'tail_base'], 'neck': ['headpiece_bottombackright', 'nose'], 'body_center': ['tail_base']}
#         self.bp_idx = {b:i for i,b in enumerate(BODY_PARTS)}

#     def _geo_features(self, m1, m2, lab):
#         conf = LAB_CONFIGS.get(lab, LAB_CONFIGS['DEFAULT'])
#         px = float(conf['pix_cm']) + 1e-6
#         m1 = m1 / px; m2 = m2 / px
        
#         origin = m1[:, 9:10]; vec = m1[:, 3:4] - origin
#         ang = np.arctan2(vec[..., 0], vec[..., 1]).flatten(); c, s = np.cos(ang), np.sin(ang)
#         def rotate(p):
#             px = p[...,0] * c[:,None] - p[...,1] * s[:,None]
#             py = p[...,0] * s[:,None] + p[...,1] * c[:,None]
#             return np.stack([px, py], axis=-1)

#         rot_m1 = rotate(m1 - origin)
#         rot_m2 = rotate(m2 - origin)
#         v1 = np.diff(rot_m1, axis=0, prepend=rot_m1[0:1])
#         s1 = np.linalg.norm(v1, axis=-1, keepdims=True)
#         v2 = np.diff(rot_m2, axis=0, prepend=rot_m2[0:1])
#         s2 = np.linalg.norm(v2, axis=-1, keepdims=True)
#         dist = np.linalg.norm(rot_m2, axis=-1, keepdims=True)
        
#         lid = 0; lab_token = np.full_like(s1, lid); zeros = np.zeros_like(s1)
#         # RAW ALIGNMENT
#         feat = np.concatenate([rot_m1, v1, rot_m2, v2, s1, s2, dist, lab_token, zeros, zeros, zeros, zeros], axis=-1)
#         return feat.astype(np.float32)

#     def __getitem__(self, idx):
#         if not self.p_path.exists(): return None
#         df = pd.read_parquet(self.p_path)
#         limit = min(len(df), 150000) 
#         raw_m1 = np.zeros((limit, 11, 2), dtype=np.float32)
#         raw_m2 = np.zeros_like(raw_m1)
#         avail = set(df.bodypart.unique())
#         uids = df.mouse_id.unique()
        
#         def resolve_id(req_id):
#             if req_id in uids: return req_id
#             for u in uids:
#                 if str(u) == str(req_id): return u
#                 if f"mouse{u}" == str(req_id): return u
#             return uids[0]

#         real_ag = resolve_id(self.agent_id)
#         real_tg = resolve_id(self.target_id)

#         def load_mouse(mid, dest):
#             m_rows = df[df.mouse_id == mid]
#             groups = dict(tuple(m_rows.groupby('bodypart')))
#             for bp, idx in self.bp_idx.items():
#                 src = bp
#                 if src not in avail and bp in self.ALIASES:
#                     for a in self.ALIASES[bp]:
#                         if a in avail: src = a; break
#                 if src in groups:
#                     g = groups[src]
#                     f = g.video_frame.values.astype(int) if 'video_frame' in g else np.arange(len(g))
#                     v = f < limit
#                     dest[f[v], idx] = g[['x','y']].values[v]

#         load_mouse(real_ag, raw_m1)
#         load_mouse(real_tg, raw_m2)
#         for m in [raw_m1, raw_m2]:
#             tail_base = m[:, 9]
#             for i in [3, 4, 7, 8]: 
#                 if m[:, i].sum() == 0: m[:, i] = tail_base

#         f_ag = self._geo_features(raw_m1, raw_m2, self.lab)
#         f_tg = self._geo_features(raw_m2, raw_m1, self.lab)
#         return torch.tensor(f_ag), torch.tensor(f_tg), 0

#     def __len__(self): return 1

# # --- MULTI-FILE SCANNER ---
# def get_diverse_events():
#     print("Scanning for events...")
#     annotation_files = glob.glob(f"{ANNOTATION_PATH}/{LAB_ID}/*.parquet")
#     random.shuffle(annotation_files)
#     events = []
#     TARGETS = ['rear', 'chase', 'attack', 'sniff'] # Focused list
#     found = {t:0 for t in TARGETS}
    
#     # Deep scan (100 files max)
#     for p in annotation_files[:100]:
#         if all(v >= 3 for v in found.values()): break 
#         try:
#             df = pd.read_parquet(p)
#             vid_id = Path(p).stem
#             for act in TARGETS:
#                 if found[act] >= 3: continue 
#                 rows = df[df['action'] == act]
#                 if not rows.empty:
#                     r = rows.iloc[0]
#                     events.append({
#                         'video_id': vid_id, 'action': act,
#                         'frame': (r['start_frame'] + r['stop_frame']) // 2,
#                         'ag': r['agent_id'], 'tg': r['target_id']
#                     })
#                     found[act] += 1
#         except: continue
#     print(f"Stats: {found}")
#     return events

# # --- RUNNER ---
# def run_comparison():
#     print(f"--- FORENSIC VALIDATION (SCALE {SCALE}) ---")
#     EVENTS = get_diverse_events()
    
#     DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#     try: model = EthoSwarmNet(37, 128).to(DEVICE)
#     except: print("Model class not found."); return

#     if os.path.exists(MODEL_PATH):
#         model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE), strict=False)
#     else: print("❌ Weights NOT found."); return
#     model.eval()
    
#     print(f"\n{'VIDEO':<12} | {'ACTION':<10} | {'PREDICTION (Top 3)':<50} | {'INPUT STATS (Speed/Dist)'}")
#     print("-" * 110)
    
#     for evt in EVENTS:
#         ds = ValidationDataset(evt['video_id'], LAB_ID, TRACKING_PATH, evt['ag'], evt['tg'])
#         item = ds[0]
#         if item is None: continue
#         fa_raw, ft_raw, lid = item
        
#         # Apply Scale 1.0 (Raw)
#         fa = (fa_raw / SCALE).to(DEVICE)
#         ft = (ft_raw / SCALE).to(DEVICE)
#         fa = torch.clamp(fa, -5.0, 5.0)
#         ft = torch.clamp(ft, -5.0, 5.0)
        
#         mid = evt['frame']
#         start = max(0, mid - 50); end = min(fa.shape[0], mid + 50)
#         if start >= end: continue
        
#         l_bat = torch.tensor([lid], device=DEVICE)
        
#         with torch.no_grad():
#             ba = fa[start:end].unsqueeze(0)
#             bt = ft[start:end].unsqueeze(0)
#             lbl = l_bat.repeat(ba.shape[0])
            
#             out = model(ba, bt, ba, bt, lbl)
#             mid_idx = min(50, out.shape[1]-1)
#             probs = out[0, mid_idx].float().cpu().numpy()
            
#             # --- DIAGNOSTICS ---
#             # Extract Speed (Ch 8) and Dist (Ch 10) at midpoint
#             # Shape is [1, T, 11, 16], we want average over nodes
#             curr_spd = ba[0, mid_idx, :, 8].mean().item()
#             curr_dist = ba[0, mid_idx, :, 10].mean().item()
#             stats = f"Spd:{curr_spd:.2f} Dist:{curr_dist:.2f}"
            
#             # Top 3
#             top3_idx = np.argsort(probs)[-3:][::-1]
#             top3 = [f"{ACTION_LIST[i]}({probs[i]:.2f})" for i in top3_idx]
            
#             mark = "✅" if evt['action'] in [ACTION_LIST[i] for i in top3_idx] else "❌"
            
#             print(f"{evt['video_id'][:10]:<12} | {evt['action']:<10} | {mark} {', '.join(top3):<47} | {stats}")

# if __name__ == '__main__':
#     run_comparison()

In [10]:
# ==============================================================================
# MODULE 8: FINAL INFERENCE (WINNER-TAKES-ALL)
# ==============================================================================
import os
import gc
import glob
import ast
import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from tqdm.auto import tqdm
from scipy.ndimage import median_filter
from pathlib import Path

# --- 1. CONFIGURATION ---
# Reuse a DEVICE defined by an earlier cell when there is one; otherwise
# prefer CUDA if it is available.
if "DEVICE" not in globals():
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Competition test inputs and the submission output file.
TEST_CSV = "/kaggle/input/MABe-mouse-behavior-detection/test.csv"
TEST_TRACKING = "/kaggle/input/MABe-mouse-behavior-detection/test_tracking"
SUBMISSION_PATH = "submission.csv"

def find_weights(filename='ethoswarm_v3_ep3.pth', search_root='/kaggle/input'):
    """Locate a weights file anywhere under ``search_root``.

    Searches recursively for ``search_root/**/filename``. The matches are sorted
    so that the lexicographically first one is returned. This makes the choice
    deterministic when several attached datasets contain a file with the same
    name (``glob`` itself returns matches in arbitrary order).

    Args:
        filename: Base name of the weights file to look for.
        search_root: Directory to search under (defaults to the Kaggle input mount).

    Returns:
        The first matching path, or ``filename`` unchanged when nothing matches.
        Callers can then still try it relative to the working directory.
    """
    matches = sorted(glob.glob(f"{search_root}/**/{filename}", recursive=True))
    return matches[0] if matches else filename

WEIGHTS_PATH = find_weights('ethoswarm_v3_ep3.pth')
# Frames per model forward pass; each video is padded to a multiple of this.
INFERENCE_CHUNK_SIZE = 4000
# NOTE(review): not referenced by this module's inference loop, which feeds one
# chunk per forward pass.
GPU_BATCH_SIZE = 8
# DataLoader worker processes (0 = load tasks in the main process).
NUM_WORKERS = 0

# --- 2. THRESHOLDS ---
# Per-action cut-offs applied to the median-filtered score of each action;
# any action without an entry here uses DEF_TH.
TH = dict(
    sniff=0.08, approach=0.08,
    rear=0.20, chase=0.10,
    attack=0.20, mount=0.20,
    escape=0.10, avoid=0.10,
    climb=0.15, biteobject=0.15,
    sniffface=0.08, sniffgenital=0.08,
)
DEF_TH = 0.05

# Node order of the per-mouse keypoint arrays (index 9 is tail_base).
BODY_PARTS = (
    "ear_left ear_right nose neck body_center "
    "lateral_left lateral_right hip_left hip_right "
    "tail_base tail_tip"
).split()

# Model output classes, in alphabetical order (37 entries).
ACTION_LIST = sorted("""
    allogroom approach attack attemptmount avoid biteobject
    chase chaseattack climb defend dig disengage dominance
    dominancegroom dominancemount ejaculate escape exploreobject
    flinch follow freeze genitalgroom huddle intromit mount
    rear reciprocalsniff rest run selfgroom shepherd sniff
    sniffbody sniffface sniffgenital submit tussle
""".split())

# Self vs. pairwise behaviours, used to mask class scores per task type.
SELF_BEHAVIORS = sorted(
    "biteobject climb dig exploreobject freeze genitalgroom huddle rear rest run selfgroom".split()
)
PAIR_BEHAVIORS = sorted(set(ACTION_LIST).difference(SELF_BEHAVIORS))
SELF_IDXS = list(map(ACTION_LIST.index, SELF_BEHAVIORS))
PAIR_IDXS = list(map(ACTION_LIST.index, PAIR_BEHAVIORS))

# Integer lab id = position of the lab name among the sorted LAB_CONFIGS keys.
LAB_NAME_TO_IDX = {lab: rank for rank, lab in enumerate(sorted(LAB_CONFIGS))}

# ==============================================================================
# 8.1 PARSER
# ==============================================================================
def parse_test_tasks(test_csv_path):
    """Expand ``test.csv`` into a list of (video, agent, target) inference tasks.

    Each task is a dict with keys ``vid`` (str), ``lab``, ``ag`` and ``tg``.

    * Rows with a missing ``behaviors_labeled`` value yield both
      ``mouse1 -> mouse2`` and ``mouse2 -> mouse1``.
    * Otherwise the value is parsed with ``ast.literal_eval`` as a list of
      ``"agent,target[,...]"`` strings (or sequences). Each distinct
      (agent, target) pair becomes one task. A target that is empty or equals
      "self" (any case) is normalised to ``'self'``.
    * Rows whose labels cannot be parsed yield exactly one fallback task,
      ``mouse1 -> mouse2``. Any pairs parsed before the failure are discarded,
      so the fallback never duplicates them.

    Returns an empty list when ``test_csv_path`` does not exist.
    """
    if not os.path.exists(test_csv_path):
        return []
    df = pd.read_csv(test_csv_path)
    tasks = []
    n_fallback = 0

    print(f"Scanning {len(df)} videos for tasks...")
    for _, row in df.iterrows():
        vid = str(row['video_id'])
        lab = row['lab_id']
        raw = row.get('behaviors_labeled', None)

        row_tasks = []
        try:
            if pd.isna(raw):
                # Nothing labelled for this video: score both directed pairs.
                row_tasks = [
                    {'vid': vid, 'lab': lab, 'ag': 'mouse1', 'tg': 'mouse2'},
                    {'vid': vid, 'lab': lab, 'ag': 'mouse2', 'tg': 'mouse1'},
                ]
            else:
                seen = set()
                for item in ast.literal_eval(raw):
                    parts = item.split(',') if isinstance(item, str) else item
                    if len(parts) < 2:
                        continue
                    ag = parts[0].strip()
                    tg = parts[1].strip()
                    if tg.lower() == 'self' or tg == '':
                        tg = 'self'
                    if (ag, tg) not in seen:
                        seen.add((ag, tg))
                        row_tasks.append({'vid': vid, 'lab': lab, 'ag': ag, 'tg': tg})
        except (ValueError, TypeError, SyntaxError, AttributeError, MemoryError, RecursionError):
            # Malformed labels (these are what literal_eval / item unpacking raise):
            # drop any partial result for this row and use a single default pair.
            row_tasks = [{'vid': vid, 'lab': lab, 'ag': 'mouse1', 'tg': 'mouse2'}]
            n_fallback += 1
        tasks.extend(row_tasks)

    if n_fallback:
        print(f"  {n_fallback} row(s) had unparseable behaviors_labeled; used mouse1 -> mouse2.")
    print(f"Total Unique Tasks: {len(tasks)}")
    return tasks

# ==============================================================================
# 8.2 DATASET
# ==============================================================================
class TaskInferenceDataset(Dataset):
    """Load tracking data for one (video, agent, target) task and build model features.

    ``tasks`` is a list of dicts with keys ``vid``, ``lab``, ``ag`` and ``tg``
    (the format produced by ``parse_test_tasks``). Tracking data is read from
    ``<tracking_dir>/<lab>/<vid>.parquet``. ``__getitem__`` returns
    ``(features, lab_index, job_out)``:

    * ``features``: float32 tensor of shape (frames, 11, 16) from
      ``_geo_features``. It is ``torch.zeros(1)`` when the task could not be
      loaded, so callers can skip it.
    * ``lab_index``: value from ``LAB_NAME_TO_IDX``; 0 for unknown labs or on failure.
    * ``job_out``: a copy of the task dict. On success it also carries
      ``is_self``, ``sub_ag`` and ``sub_tg`` (mouse ids in submission format).
    """

    def __init__(self, tasks, tracking_dir):
        self.samples = tasks
        self.tracking_dir = Path(tracking_dir)
        self.bp_idx = {b: i for i, b in enumerate(BODY_PARTS)}
        # Substitute keypoints, tried in order, when a canonical body part is
        # not present in a video's tracking data.
        self.ALIASES = {
            'hip_left': ['lateral_left', 'tail_base'],
            'hip_right': ['lateral_right', 'tail_base'],
            'neck': ['headpiece_bottombackright', 'nose'],
            'body_center': ['tail_base'],
        }

    def __len__(self):
        return len(self.samples)

    def _fix_teleport(self, pos):
        """Fill missing keypoints by linear interpolation over time.

        ``pos`` is a (T, N, 2) array. An entry counts as missing when
        |x| + |y| < 1e-6. A node with at least one present frame is
        interpolated with ``np.interp``; frames before the first or after the
        last present frame take that nearest value. Nodes missing in every frame
        stay zero. Returns a new array; ``pos`` is not modified.
        """
        T = pos.shape[0]
        missing = np.abs(pos).sum(axis=2) < 1e-6
        cleaned = pos.copy()
        t = np.arange(T)
        for n in np.where(missing.any(axis=0))[0]:
            mask = missing[:, n]
            present = ~mask
            if not present.any():
                continue
            for c in (0, 1):
                cleaned[mask, n, c] = np.interp(t[mask], t[present], pos[present, n, c])
        return cleaned

    def _geo_features(self, m1, m2, lab):
        """Build per-node features of mouse ``m1`` relative to mouse ``m2``.

        Both (T, N, 2) inputs are divided by the lab's ``pix_cm``. Positions are
        centred on ``m1``'s tail base (node 9), with no rotation, to match
        training. Output shape is (T, N, 16):

        * channels 0-5 are ``[pos_x, pos_y, vel_x, vel_y, speed, dist]``, where
          ``dist`` is the distance to the same node of ``m2``;
        * channels 6-15 are zero padding.

        NaNs become 0 and all values are clipped to [-30, 30].
        """
        conf = LAB_CONFIGS.get(lab, LAB_CONFIGS['DEFAULT'])
        px = float(conf['pix_cm']) + 1e-6
        m1 = m1.astype(np.float32) / px
        m2 = m2.astype(np.float32) / px

        centered = m1 - m1[:, 9:10]
        # Frame-to-frame displacement; frame 0 is prepended so its velocity is 0.
        vel = np.diff(centered, axis=0, prepend=centered[0:1])
        speed = np.sqrt((vel ** 2).sum(axis=-1))
        dist = np.sqrt(((m1 - m2) ** 2).sum(axis=-1))

        L, N, _ = m1.shape
        padded = np.zeros((L, N, 16), dtype=np.float32)
        padded[..., :6] = np.stack(
            [centered[..., 0], centered[..., 1], vel[..., 0], vel[..., 1], speed, dist],
            axis=-1,
        )
        padded = np.nan_to_num(padded, nan=0.0)
        return np.clip(padded, -30.0, 30.0)

    def __getitem__(self, idx):
        """Return ``(features, lab_index, job_out)`` for task ``idx``; see the class docstring."""
        job = self.samples[idx]
        job_out = job.copy()

        try:
            p = self.tracking_dir / job['lab'] / f"{job['vid']}.parquet"
            if not p.exists():
                print(f"WARNING: tracking file not found: {p}")
                return torch.zeros(1), 0, job_out

            df = pd.read_parquet(p)
            max_f = int(df.video_frame.max() + 1) if 'video_frame' in df else len(df)
            # Upper bound on the dense per-mouse frame buffer.
            limit = min(max_f, 3000000)

            uids = df.mouse_id.unique()

            def resolve(req):
                # 'self' -> None. Otherwise match "1" / "mouse1"-style ids against
                # the parquet's mouse_id values, falling back to the first mouse.
                if req == 'self':
                    return None
                for u in uids:
                    if str(u) == str(req) or f"mouse{u}" == str(req):
                        return u
                return uids[0]

            real_ag = resolve(job['ag'])
            real_tg = resolve(job['tg'])

            # Self tasks use the agent as the "other" mouse too.
            job_out['is_self'] = real_tg is None
            if job_out['is_self']:
                real_tg = real_ag

            def sub_id(mid):
                # Numeric mouse ids are written as "mouseN"; others are kept as-is.
                return f"mouse{mid}" if str(mid).isdigit() else str(mid)

            job_out['sub_ag'] = sub_id(real_ag)
            job_out['sub_tg'] = 'self' if job_out['is_self'] else sub_id(real_tg)

            raw_ag = np.zeros((limit, 11, 2), dtype=np.float32)
            raw_tg = np.zeros_like(raw_ag)
            avail = set(df['bodypart'].unique())

            def extract(mid, dest):
                # Scatter this mouse's (x, y) rows into dest[frame, node], using
                # ALIASES when a canonical body part is not tracked.
                groups = dict(tuple(df[df.mouse_id == mid].groupby('bodypart')))
                for bp, node in self.bp_idx.items():
                    src = bp
                    if src not in avail and bp in self.ALIASES:
                        src = next((a for a in self.ALIASES[bp] if a in avail), src)
                    if src in groups:
                        g = groups[src]
                        f = g.video_frame.values.astype(int) if 'video_frame' in g else np.arange(len(g))
                        v = f < limit
                        dest[f[v], node] = g[['x', 'y']].values[v]

            extract(real_ag, raw_ag)
            extract(real_tg, raw_tg)

            raw_ag = self._fix_teleport(raw_ag)
            raw_tg = self._fix_teleport(raw_tg)
            # Neck, body_center and hips that are still all-zero copy the tail base.
            for m in (raw_ag, raw_tg):
                tail_base = m[:, 9]
                for i in (3, 4, 7, 8):
                    if m[:, i].sum() == 0:
                        m[:, i] = tail_base

            feats = self._geo_features(raw_ag, raw_tg, job['lab'])
            lid = LAB_NAME_TO_IDX.get(job['lab'], 0)
            return torch.tensor(feats), lid, job_out

        except Exception as err:
            # Best-effort per task: report the failure and return the placeholder
            # so the inference loop skips this task instead of aborting the run.
            print(f"WARNING: could not build features for video {job.get('vid')} "
                  f"({job.get('ag')} -> {job.get('tg')}): {err!r}")
            return torch.zeros(1), 0, job_out

def collate_fn(b):
    """Unwrap a DataLoader batch by returning its first (and, with batch_size=1, only) sample."""
    single_sample = b[0]
    return single_sample

# ==============================================================================
# 8.3 MAIN EXECUTION
# ==============================================================================
def top1_gate_probs(probs):
    """Zero every class score below its frame's maximum, in place.

    ``probs`` is a (T, C) array. All scores tied with the row maximum are kept,
    so a frame whose scores are all 0 stays all 0. Returns ``probs``.
    """
    row_max = probs.max(axis=1, keepdims=True)
    probs *= (probs >= row_max).astype(probs.dtype)
    return probs


def probs_to_segments(probs, actions, thresholds, default_thresh,
                      filter_size=7, min_len=2, min_peak=0.01):
    """Turn per-frame class scores into ``(action, start, stop)`` intervals.

    For each column ``i`` of the (T, C) array ``probs`` (named ``actions[i]``):

    * columns whose peak is below ``min_peak`` are skipped;
    * otherwise the trace is median-filtered with window ``filter_size`` and
      compared (strict ``>``) with ``thresholds.get(action, default_thresh)``.

    Each run of above-threshold frames becomes ``(action, start, stop)`` with
    ``stop`` exclusive. Runs shorter than ``min_len`` frames are dropped.
    Results are ordered by action, then by start frame.
    """
    segments = []
    for cls_idx, act in enumerate(actions):
        trace = probs[:, cls_idx]
        if trace.max() < min_peak:
            continue
        smoothed = median_filter(trace, size=filter_size)
        binary = (smoothed > thresholds.get(act, default_thresh)).astype(np.int8)
        if not binary.any():
            continue
        # +1 / -1 edges of the zero-padded mask give run starts / exclusive stops.
        diffs = np.diff(np.concatenate(([0], binary, [0])))
        starts = np.where(diffs == 1)[0]
        stops = np.where(diffs == -1)[0]
        segments.extend(
            (act, int(s), int(e)) for s, e in zip(starts, stops) if e - s >= min_len
        )
    return segments


if __name__ == "__main__":
    print("--- STARTING MODULE 8: FINAL INFERENCE (WINNER-TAKES-ALL) ---")

    SUBMISSION_COLUMNS = ["row_id", "video_id", "agent_id", "target_id",
                          "action", "start_frame", "stop_frame"]

    if not os.path.exists(TEST_CSV):
        # Still emit a well-formed (empty) submission with the full header.
        pd.DataFrame(columns=SUBMISSION_COLUMNS).to_csv(SUBMISSION_PATH, index=False)
        raise SystemExit(0)

    tasks = parse_test_tasks(TEST_CSV)
    ds = TaskInferenceDataset(tasks, TEST_TRACKING)
    loader = DataLoader(ds, batch_size=1, collate_fn=collate_fn, num_workers=NUM_WORKERS)

    try:
        model = EthoSwarmNet(len(ACTION_LIST), 128).to(DEVICE)
    except NameError:
        # NOTE(review): EthoSwarmNet is expected to come from an earlier cell.
        print("ERROR: EthoSwarmNet not found.")
        raise SystemExit(1)

    if os.path.exists(WEIGHTS_PATH):
        try:
            load_info = model.load_state_dict(
                torch.load(WEIGHTS_PATH, map_location=DEVICE), strict=False)
        except Exception as err:
            print(f"WARNING: could not load weights from {WEIGHTS_PATH}: {err!r} "
                  "-- predictions will use untrained weights.")
        else:
            print(f"Weights loaded from {WEIGHTS_PATH}")
            # strict=False ignores key mismatches silently; surface them.
            if load_info.missing_keys or load_info.unexpected_keys:
                print(f"WARNING: {len(load_info.missing_keys)} missing and "
                      f"{len(load_info.unexpected_keys)} unexpected state_dict keys.")
    else:
        print(f"WARNING: weights file not found: {WEIGHTS_PATH} "
              "-- predictions will use untrained weights.")
    model.eval()

    first_pass = True
    with open(SUBMISSION_PATH, 'w') as f:
        f.write(",".join(SUBMISSION_COLUMNS) + "\n")
        row_id = 0

        with torch.no_grad():
            for feat, lid, job in tqdm(loader, desc="Inferring"):
                # Placeholder tensors from failed loads have dim < 3; also skip tiny clips.
                if feat.dim() < 3 or feat.shape[0] < 5:
                    continue

                T = feat.shape[0]
                # Pad to a whole number of chunks by repeating the last frame.
                rem = T % INFERENCE_CHUNK_SIZE
                if rem > 0:
                    pad = INFERENCE_CHUNK_SIZE - rem
                    feat = torch.cat([feat, feat[-1:].repeat(pad, 1, 1)], 0)

                c_f = feat.view(-1, INFERENCE_CHUNK_SIZE, *feat.shape[1:])
                l_idx = torch.tensor([lid], device=DEVICE)
                preds_list = []
                for chunk in c_f.split(1):
                    ba = chunk.to(DEVICE)
                    out = model(ba, ba, ba, ba, l_idx)
                    preds_list.append(out.float().cpu())

                # Raw probabilities for the original (unpadded) frames.
                probs = torch.cat(preds_list, 1).flatten(0, 1)[:T].numpy()

                # 1. Strict masking: self tasks keep only self behaviours, and vice versa.
                if job['is_self']:
                    probs[:, PAIR_IDXS] = 0.0
                else:
                    probs[:, SELF_IDXS] = 0.0

                # 2. Top-1 gating: keep only the winning class per frame.
                probs = top1_gate_probs(probs)

                # 3. Debug snapshot of the first processed task.
                if first_pass:
                    print(f"\n--- DEBUG SNAPSHOT ({job['vid']}) ---")
                    print(f"Task: {job['sub_ag']} -> {job['sub_tg']}")
                    mid = T // 2
                    top = np.argmax(probs[mid])
                    print(f"Frame {mid} Winner: {ACTION_LIST[top]} ({probs[mid, top]:.3f})")
                    first_pass = False

                aid, tid = job['sub_ag'], job['sub_tg']
                for act, s, e in probs_to_segments(probs, ACTION_LIST, TH, DEF_TH):
                    f.write(f"{row_id},{job['vid']},{aid},{tid},{act},{s},{e}\n")
                    row_id += 1

                del feat, c_f, probs, preds_list

    print(f"Success! Generated {row_id} rows.")
    print(f"Submission saved to {SUBMISSION_PATH}")

--- STARTING MODULE 8: FINAL INFERENCE (WINNER-TAKES-ALL) ---
Scanning 1 videos for tasks...
Total Unique Tasks: 16
Weights loaded from /kaggle/input/mabe-separated/ethoswarm_v3_ep3.pth


Inferring:   0%|          | 0/16 [00:00<?, ?it/s]


--- DEBUG SNAPSHOT (438887472) ---
Task: mouse1 -> mouse2
Frame 9211 Winner: approach (0.073)
Success! Generated 506 rows.
Submission saved to submission.csv
