# MABe Mouse Behavior Detection - Starter Notebook

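This notebook is a starting point for working with the MABe dataset: it loads a sample sequence, visualizes the keypoints, extracts features, and runs a model forward pass as groundwork for training models and preparing submissions.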

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import yaml
from pathlib import Path

# Plot appearance: apply seaborn's 'whitegrid' style to figures created below.
sns.set_style('whitegrid')
# IPython magic that selects matplotlib's inline backend so figures are shown in cell output.
%matplotlib inline

## 1. Load Data

In [None]:
# Data paths: use /vol when that directory exists, otherwise a relative
# 'data' directory resolved against the current working directory.
data_dir = Path('/vol' if Path('/vol').exists() else 'data')
train_dir = data_dir / 'train'
val_dir = data_dir / 'val'

# List data files (sorted so that "the first file" is the same on every run)
train_files = sorted(train_dir.glob('*.npy'))
print(f'Training files: {len(train_files)}')

# Without this guard an empty or missing directory surfaces as a bare
# IndexError on train_files[0], which says nothing about where to look.
if not train_files:
    raise FileNotFoundError(
        f"No '*.npy' files found in {train_dir.resolve()}; "
        "check that the dataset is mounted or downloaded."
    )

# Load a sample
sample_data = np.load(train_files[0])
print(f'Sample shape: {sample_data.shape}')
# NOTE(review): treats axis 1 as the per-frame feature axis - confirm against the data format.
print(f'Features per frame: {sample_data.shape[1]}')

## 2. Visualize Keypoints

In [None]:
# Keypoint layout assumed by this notebook: 2 mice x 7 keypoints x (x, y).
num_mice = 2
num_keypoints = 7
expected_per_frame = num_mice * num_keypoints * 2

# reshape(-1, ...) succeeds whenever the TOTAL element count is divisible by
# the layout size, so a file with a different per-frame width could be
# silently scrambled across frames. Check the per-frame width explicitly
# (only when there is a frame axis to check against).
if sample_data.ndim >= 2:
    values_per_frame = int(np.prod(sample_data.shape[1:]))
    if values_per_frame != expected_per_frame:
        raise ValueError(
            f'Expected {expected_per_frame} values per frame '
            f'({num_mice} mice x {num_keypoints} keypoints x 2 coords), '
            f'got {values_per_frame} from array of shape {sample_data.shape}'
        )

# Reshape to keypoints: (frames, mice, keypoints, xy)
keypoints = sample_data.reshape(-1, num_mice, num_keypoints, 2)

# Plot first frame
frame_idx = 0
fig, ax = plt.subplots(figsize=(10, 8))

for mouse_idx in range(num_mice):
    kpts = keypoints[frame_idx, mouse_idx]
    ax.scatter(kpts[:, 0], kpts[:, 1], label=f'Mouse {mouse_idx}', s=100)

    # Connect keypoints in storage order (not necessarily an anatomical skeleton)
    ax.plot(kpts[:, 0], kpts[:, 1], alpha=0.5)

ax.legend()
ax.set_title(f'Frame {frame_idx} - Mouse Keypoints')
ax.set_xlabel('X coordinate')
ax.set_ylabel('Y coordinate')
# Equal scaling on both axes so the spatial layout is not distorted
ax.axis('equal')
plt.show()

## 3. Analyze Temporal Dynamics

In [None]:
# Per-frame centroid of each mouse: average over that mouse's keypoints.
mouse0_centroid = keypoints[:, 0].mean(axis=1)
mouse1_centroid = keypoints[:, 1].mean(axis=1)

# Euclidean distance between the two centroids, one value per frame.
centroid_offset = mouse0_centroid - mouse1_centroid
inter_dist = np.linalg.norm(centroid_offset, axis=1)

# Distance trace over the whole sequence.
fig, ax = plt.subplots(figsize=(12, 4))
ax.plot(inter_dist)
ax.set_title('Inter-Mouse Distance Over Time')
ax.set_xlabel('Frame')
ax.set_ylabel('Distance')
ax.grid(True)
plt.show()

# Summary statistics of the distance trace.
print(f'Mean distance: {inter_dist.mean():.3f}')
print(f'Std distance: {inter_dist.std():.3f}')

## 4. Feature Engineering

In [None]:
import sys

# Make the project code importable: /vol/code when it exists, otherwise the
# relative 'src' directory. Re-running this cell must not keep prepending
# duplicates, so drop any existing occurrence before putting the entry first.
code_dir = '/vol/code' if Path('/vol/code').exists() else 'src'
sys.path[:] = [code_dir] + [entry for entry in sys.path if entry != code_dir]

from data.feature_engineering import MouseFeatureEngineer

# Create feature engineer (same layout as the keypoint section: 2 mice, 7 keypoints)
feature_engineer = MouseFeatureEngineer(num_mice=2, num_keypoints=7)

# Extract features from the sample sequence; PCA features are switched off,
# temporal features are switched on.
features = feature_engineer.extract_all_features(
    sample_data,
    include_pca=False,
    include_temporal=True
)

print(f'Extracted features shape: {features.shape}')
# NOTE(review): assumes features is 2-D with the feature axis last - confirm in MouseFeatureEngineer.
print(f'Number of features: {features.shape[1]}')

## 5. Load and Test Model

In [None]:
from models.advanced_models import build_advanced_model

# Load config: same /vol/code-or-local fallback used for the project code path.
config_path = '/vol/code/configs/config_advanced.yaml' if Path('/vol/code').exists() else 'configs/config_advanced.yaml'
# Explicit encoding: YAML is UTF-8 by convention, while open() would otherwise
# fall back to the platform's locale encoding.
with open(config_path, 'r', encoding='utf-8') as f:
    config = yaml.safe_load(f)

# safe_load returns None for an empty file (or a non-mapping for odd content);
# fail here with a clear message instead of a TypeError on the assignment below.
if not isinstance(config, dict):
    raise ValueError(
        f'Expected a YAML mapping in {config_path}, got {type(config).__name__}'
    )

# The model's input width must match the engineered feature width.
config['input_dim'] = features.shape[1]

# Build model and move it to the GPU when one is available
model = build_advanced_model(config)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = model.to(device)

print(f'Model: {config["model_type"]}')
print(f'Device: {device}')
print(f'Parameters: {sum(p.numel() for p in model.parameters()):,}')

## 6. Test Inference

In [None]:
# Build a single-sequence batch from the first 100 rows of the feature matrix
# (presumably one row per frame - confirm against the feature extractor).
head_window = features[:100]
test_sequence = torch.FloatTensor(head_window)[None].to(device)  # [None] adds the batch axis
print(f'Test sequence shape: {test_sequence.shape}')

# Forward pass in eval mode with gradient tracking disabled
model.eval()
with torch.no_grad():
    logits = model(test_sequence)
    # Most likely class along the last axis
    predictions = logits.argmax(dim=-1)

print(f'Output shape: {logits.shape}')
print(f'Predictions shape: {predictions.shape}')
print(f'Sample predictions: {predictions[0, :10].cpu().numpy()}')

## 7. Visualize Predictions

In [None]:
# Behavior classes
behavior_classes = {0: 'other', 1: 'close_investigation', 2: 'mount', 3: 'attack'}

# Plot predictions
pred_np = predictions[0].cpu().numpy()

plt.figure(figsize=(14, 4))
plt.plot(pred_np, marker='o', markersize=3, linestyle='-', alpha=0.7)
plt.title('Predicted Behavior Over Time')
plt.xlabel('Frame')
plt.ylabel('Behavior Class')
plt.yticks(range(4), [behavior_classes[i] for i in range(4)])
plt.grid(True, alpha=0.3)
plt.show()

# Count behavior occurrences
unique, counts = np.unique(pred_np, return_counts=True)
for behavior_id, count in zip(unique, counts):
    print(f'{behavior_classes[behavior_id]}: {count} frames ({count/len(pred_np)*100:.1f}%)')