# Notebook 1: The Observer's Log - A Deep Dive into the Data (EDA)

**Overall Goal:** Before writing a single line of model code, our objective is to build a strong intuition for the MABe dataset. We will explore the structure of the data, visualize the mouse movements, and identify the core challenges of this competition, such as class imbalance and data variability. This deep understanding will guide all of our future feature engineering and modeling decisions.

---

# **Step 1: Setup and Metadata Exploration**

**Goal:** Based on the file structure, the primary data is stored in efficient `.parquet` files located in separate folders for tracking and annotations. The `train.csv` and `test.csv` files likely serve as metadata indexes, providing a list of all video IDs and potentially other high-level information.

Our first step is to load these metadata files to understand the scope of the dataset (how many training/test videos are there?) and how we can use them to access the individual data files.

**Action:** Please run the code block below to import libraries and inspect the head and info of the `train.csv` and `test.csv` metadata files.

In [None]:
# Import necessary libraries
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import os
from tqdm.auto import tqdm
from matplotlib.animation import FuncAnimation
from IPython.display import HTML

In [None]:
# Set some display options for pandas for better readability
pd.set_option('display.max_columns', 100)
sns.set_style('whitegrid')

# Define the path to your data
DATA_PATH = '/kaggle/input/MABe-mouse-behavior-detection/'

# Load the metadata files
print("Loading train.csv (metadata)...")
df_train_meta = pd.read_csv(DATA_PATH + 'train.csv')

print("Loading test.csv (metadata)...")
df_test_meta = pd.read_csv(DATA_PATH + 'test.csv')

print("\n--- Train Metadata ---")
print(f"Shape: {df_train_meta.shape}")
display(df_train_meta.head())

print("\n--- Test Metadata ---")
print(f"Shape: {df_test_meta.shape}")
display(df_test_meta.head())

## What We Learned in Step 1

*   **Metadata Confirmed:** `train.csv` and `test.csv` are indeed high-level metadata files, not the raw tracking data. They serve as an index for the entire dataset.
*   **Dataset Scale:** The training set is substantial, with **8,790 videos**. This means our methods need to be efficient.
*   **Code Competition Structure:** The public test set is tiny (just **1 video**). This is a classic sign of a code competition where our submitted notebook will be re-run on a much larger, hidden test set. This emphasizes the need for a **generalizable solution**, not one that is overfit to this single test video.
*   **Rich Metadata:** We have a treasure trove of information for each video:
    *   `lab_id`: Tells us which lab the data came from. This is a crucial feature for handling data variability.
    *   `mouse_...`: Details about the strain, sex, age, and condition of up to four mice.
    *   `video_...` / `arena_...`: Technical details about the recording setup (FPS, resolution, arena size).
*   **File Path Keys:** The `lab_id` and `video_id` columns are the keys we need to construct the paths to the actual `.parquet` data files.

# Step 2: Loading and Inspecting a Single Data Sample

**Goal:** Now that we understand the map (the metadata), it's time to explore the territory (the actual data). We will select the very first video from our training metadata and load its corresponding tracking and annotation files. This will reveal the low-level data structure we'll be working with for our models.

**Action:** The code below will:
1.  Select the first video from `df_train_meta`.
2.  Construct the file paths for its tracking and annotation data.
3.  Load the two `.parquet` files into new pandas DataFrames.
4.  Display the first few rows, shape, and column information for both the tracking and annotation data.

In [None]:
# Select the first video from the metadata as our sample
sample_video_meta = df_train_meta.iloc[0]
sample_lab_id = sample_video_meta['lab_id']
sample_video_id = sample_video_meta['video_id']

print(f"Loading sample video...\n  Lab ID: {sample_lab_id}\n  Video ID: {sample_video_id}")

In [None]:
# Construct the file paths using the lab and video IDs
tracking_path = os.path.join(DATA_PATH, 'train_tracking', sample_lab_id, f'{sample_video_id}.parquet')
annotation_path = os.path.join(DATA_PATH, 'train_annotation', sample_lab_id, f'{sample_video_id}.parquet')

# Load the actual data from the parquet files
df_tracking_sample = pd.read_parquet(tracking_path)
df_annot_sample = pd.read_parquet(annotation_path)

print("\n--- Sample Tracking Data ---")
print(f"Shape: {df_tracking_sample.shape}")
print("Info:")
df_tracking_sample.info()
print("\nFirst 5 rows:")
display(df_tracking_sample.head())

print("\n\n--- Sample Annotation Data ---")
print(f"Shape: {df_annot_sample.shape}")
print("Info:")
df_annot_sample.info()
print("\nFirst 5 rows:")
display(df_annot_sample.head())

## What We Learned in Step 2

*   **"Long" Data Format:** The tracking data is in a "long" or "tidy" format. This means every single row represents just **one bodypart** for **one mouse** in **one frame**. While this is efficient for storage, it's not immediately useful for machine learning, where we typically want to see all the information for a single frame in one row.
*   **The Reshaping Challenge:** Our next major task will be to "pivot" or reshape this data. We need to transform it into a "wide" format where each row represents a single `video_frame`, and the columns contain all the x and y coordinates for all mice (e.g., `mouse1_head_x`, `mouse1_head_y`, `mouse2_head_x`, etc.).
*   **Annotation Structure:** The annotation data is very clear. It defines discrete behavioral events with a start and end frame. We can see some actions are individual (where `agent_id` == `target_id`, like "rear") and some are social (where they are different, like "avoid").
*   **Data Granularity:** For a video that is ~615 seconds long at 30 FPS (from the metadata), we'd expect around 18,450 frames. The tracking data has `1,087,658` rows. Dividing this by the number of bodyparts and mice should give us the frame count, confirming the data's structure.

# Step 3: Reshaping the Data for Analysis (Pivoting)

**Goal:** To transform our "long" tracking data into a "wide" format. This is the most critical data manipulation step in our EDA. A wide format (one row per frame) will allow us to easily calculate features, visualize entire scenes, and feed the data into a model later.

**Action:** The code below will:
1.  First, list the unique bodyparts being tracked in this video to see what we're working with.
2.  Use the `pivot` function in pandas to reshape the data.
3.  We will pivot the `x` and `y` coordinates separately and then merge them together to create one comprehensive DataFrame for the sample video.
4.  Display the head of the new, wide DataFrame to see the result.

In [None]:
# 1. See what bodyparts are available
unique_bodyparts = df_tracking_sample['bodypart'].unique()
print(f"Unique bodyparts tracked: {unique_bodyparts}\n")

In [None]:
# 2. Pivot the table to get a "wide" format
# We want one row per video_frame, and columns for each mouse's bodypart's x and y coordinates.

print("Pivoting data from long to wide format...")

In [None]:
# Create a pivot for the 'x' coordinates
pivot_x = df_tracking_sample.pivot(
    index='video_frame', 
    columns=['mouse_id', 'bodypart'], 
    values='x'
)
# Rename columns for clarity, e.g., (1, 'nose') -> 'mouse1_nose_x'
pivot_x.columns = [f"mouse{m}_{bp}_x" for m, bp in pivot_x.columns]


# Create a pivot for the 'y' coordinates
pivot_y = df_tracking_sample.pivot(
    index='video_frame', 
    columns=['mouse_id', 'bodypart'], 
    values='y'
)
# Rename columns for clarity
pivot_y.columns = [f"mouse{m}_{bp}_y" for m, bp in pivot_y.columns]


# 3. Merge the x and y pivots into a single DataFrame
df_wide_sample = pd.concat([pivot_x, pivot_y], axis=1)

# Sort columns alphabetically for consistent order
df_wide_sample = df_wide_sample.sort_index(axis=1)


print("Pivoting complete.\n")
print("--- Reshaped Wide DataFrame ---")
print(f"Shape: {df_wide_sample.shape}")
display(df_wide_sample.head())

## What We Learned in Step 3

*   **Pivoting is Key:** We have successfully transformed the data from a long, stacked format into a wide, intuitive format. The shape `(18451, 142)` tells us we have 18,451 frames of data, and 142 feature columns (a mix of x and y coordinates for all tracked bodyparts on all mice).
*   **Missing Data (`NaN`):** Notice the presence of `NaN` (Not a Number) values. This is completely normal in tracking data. It means the tracking algorithm (e.g., DeepLabCut) was not confident enough to assign a coordinate for that bodypart in that specific frame. This could be due to one mouse blocking another (occlusion) or fast movements causing motion blur. We will need to keep this in mind when engineering features.
*   **Complexity of Bodyparts:** The list of unique bodyparts shows a mix of standard anatomical points (`nose`, `ear_left`, `tail_base`) and some experiment-specific ones (`headpiece_...`). For general-purpose features, we'll focus on the standard anatomical points first.

# Step 4: Visualizing the Data - A Static Snapshot

**Goal:** Before we animate the mice, let's make sure we can plot a single frame correctly. This helps us understand the coordinate system and see the posture of all mice at one moment in time.

**Action:** We will write a function that takes a single frame's data (one row from our wide DataFrame) and plots the keypoints for each mouse. We'll connect some keypoints with lines to form a simple "skeleton" for better visualization.
1.  Define a list of standard, anatomical bodyparts we want to focus on.
2.  Define the connections between these parts to draw skeletons.
3.  Create the plotting function.
4.  Use the function to plot frame `1000` of our sample video.

In [None]:
# 1. Define the core bodyparts we want to visualize
# We will ignore the 'headpiece' parts for this general visualization
ANATOMICAL_BODYPARTS = [
    'nose', 'ear_left', 'ear_right', 'neck', 'body_center', 
    'lateral_left', 'lateral_right', 'tail_base'
]

In [None]:
# 2. Define connections to draw a simple skeleton
# Each tuple represents a line from one bodypart to another
SKELETON_CONNECTIONS = [
    ('nose', 'ear_left'), ('nose', 'ear_right'), ('ear_left', 'ear_right'),
    ('nose', 'neck'), ('neck', 'body_center'),
    ('body_center', 'lateral_left'), ('body_center', 'lateral_right'),
    ('body_center', 'tail_base')
]

# Define a color for each mouse for consistent plotting
MOUSE_COLORS = {1: 'blue', 2: 'orange', 3: 'green', 4: 'red'}

In [None]:
# 3. Create the plotting function
def plot_frame(frame_data):
    """Plots the skeletons of all mice for a single frame of data."""
    
    plt.figure(figsize=(8, 8))
    
    # Iterate through each mouse
    for mouse_id in range(1, 5): # Assumes up to 4 mice
        
        # Check if data for this mouse exists in the frame
        if f'mouse{mouse_id}_nose_x' not in frame_data or pd.isna(frame_data[f'mouse{mouse_id}_nose_x']):
            continue # Skip if this mouse isn't tracked in this frame
            
        # Plot the keypoints (bodyparts)
        for part in ANATOMICAL_BODYPARTS:
            col_x = f'mouse{mouse_id}_{part}_x'
            col_y = f'mouse{mouse_id}_{part}_y'
            if col_x in frame_data and col_y in frame_data:
                plt.scatter(frame_data[col_x], frame_data[col_y], color=MOUSE_COLORS[mouse_id], label=f'Mouse {mouse_id}' if part == 'nose' else "")

        # Plot the skeleton connections
        for part1, part2 in SKELETON_CONNECTIONS:
            col1_x, col1_y = f'mouse{mouse_id}_{part1}_x', f'mouse{mouse_id}_{part1}_y'
            col2_x, col2_y = f'mouse{mouse_id}_{part2}_x', f'mouse{mouse_id}_{part2}_y'

            # Check if both points for the line exist
            if all(c in frame_data for c in [col1_x, col1_y, col2_x, col2_y]) and \
               pd.notna(frame_data[col1_x]) and pd.notna(frame_data[col2_x]):
                
                plt.plot([frame_data[col1_x], frame_data[col2_x]], 
                         [frame_data[col1_y], frame_data[col2_y]], 
                         color=MOUSE_COLORS[mouse_id], alpha=0.7)

    plt.title(f"Mouse Positions at Frame {frame_data.name}")
    plt.xlabel("X-coordinate")
    plt.ylabel("Y-coordinate")
    
    # Invert the y-axis because image coordinates (0,0) are usually at the top-left
    plt.gca().invert_yaxis()
    plt.legend()
    plt.axis('equal') # Ensure aspect ratio is maintained
    plt.show()

In [None]:
# 4. Use the function to plot a specific frame
FRAME_TO_PLOT = 1000
plot_frame(df_wide_sample.loc[FRAME_TO_PLOT])

## What We Learned in Step 4

*   **Visualization Success:** Our plotting function works perfectly! We can now take any single frame of data and instantly visualize the entire scene.
*   **Relative Positions are Clear:** We can see that at frame 1000, Mouse 3 (green) and Mouse 1 (blue) are relatively close, while Mouse 4 (red) is further away. This ability to see spatial relationships is the foundation for understanding social behavior.
*   **Missing Mice Handled:** Notice that Mouse 2 is not plotted. Our code correctly handled the missing data for this mouse at this frame, which is essential for creating a robust visualization tool.
*   **Static is Not Enough:** A single frame shows posture, but behavior is defined by **movement through time**. To truly understand what's happening, we need to see the sequence of these frames.

# Step 5: The Dynamic View - Animating a Behavior

**Goal:** This is the most intuitive part of our EDA. We will create an animation—a mini-movie—of the mice. This will allow us to see how their positions and postures change over time, giving us a true feel for their behavior.

**Action:**
1.  We will pick the first labeled behavior from our `df_annot_sample` DataFrame.
2.  We'll extract the `start_frame` and `stop_frame` for that behavior.
3.  We will create an animation of the mouse movements during that specific time window.
4.  We will display the animation directly in the notebook.

This will be our first look at a specific, labeled action as it actually happened.

In [None]:
# 1. Pick a behavior to animate from our annotation sample
behavior_to_animate = df_annot_sample.iloc[0]
start_frame = behavior_to_animate['start_frame']
stop_frame = behavior_to_animate['stop_frame']
action = behavior_to_animate['action']
agent = behavior_to_animate['agent_id']

print(f"Preparing to animate behavior: '{action}' by Mouse {agent}")
print(f"Frame range: {start_frame} to {stop_frame}")

In [None]:
# Add a small buffer before and after to see the context
ANIM_START = max(0, start_frame - 20)
ANIM_STOP = stop_frame + 20

# Slice our wide dataframe to get only the frames we need for the animation
anim_df = df_wide_sample.loc[ANIM_START:ANIM_STOP]

In [None]:
# --- Animation Setup ---

# Set up the figure and axis
fig, ax = plt.subplots(figsize=(8, 8))

# Determine axis limits from the entire animation sequence to prevent jittering
x_min, x_max = anim_df.filter(like='_x').min().min(), anim_df.filter(like='_x').max().max()
y_min, y_max = anim_df.filter(like='_y').min().min(), anim_df.filter(like='_y').max().max()
padding = 50 # Add some padding to the plot
ax.set_xlim(x_min - padding, x_max + padding)
ax.set_ylim(y_min - padding, y_max + padding)


# The function that will draw each frame of the animation
def update(frame_num):
    ax.clear() # Clear the previous frame
    
    # Get the data for the current frame
    frame_data = anim_df.iloc[frame_num]
    current_real_frame = anim_df.index[frame_num]
    
    # Plot each mouse for the current frame
    for mouse_id in range(1, 5):
        if f'mouse{mouse_id}_nose_x' not in frame_data or pd.isna(frame_data[f'mouse{mouse_id}_nose_x']):
            continue

        # Plot keypoints
        for part in ANATOMICAL_BODYPARTS:
            col_x, col_y = f'mouse{mouse_id}_{part}_x', f'mouse{mouse_id}_{part}_y'
            if col_x in frame_data and col_y in frame_data:
                ax.scatter(frame_data[col_x], frame_data[col_y], color=MOUSE_COLORS[mouse_id])

        # Plot skeleton
        for part1, part2 in SKELETON_CONNECTIONS:
            col1_x, col1_y = f'mouse{mouse_id}_{part1}_x', f'mouse{mouse_id}_{part1}_y'
            col2_x, col2_y = f'mouse{mouse_id}_{part2}_x', f'mouse{mouse_id}_{part2}_y'
            if all(c in frame_data for c in [col1_x, col1_y, col2_x, col2_y]) and \
               pd.notna(frame_data[col1_x]) and pd.notna(frame_data[col2_x]):
                ax.plot([frame_data[col1_x], frame_data[col2_x]], [frame_data[col1_y], frame_data[col2_y]], color=MOUSE_COLORS[mouse_id], alpha=0.7)

    # Set titles and labels for the frame
    ax.set_title(f"Behavior: '{action}' by Mouse {agent} | Frame: {current_real_frame}")
    ax.set_xlabel("X-coordinate")
    ax.set_ylabel("Y-coordinate")
    ax.set_xlim(x_min - padding, x_max + padding)
    ax.set_ylim(y_min - padding, y_max + padding)
    ax.invert_yaxis() # Invert y-axis for image coordinates
    return ax,

In [None]:
# Create the animation
# frames=len(anim_df) specifies how many times to call the update function
# interval=50 is the delay between frames in milliseconds
ani = FuncAnimation(fig, update, frames=len(anim_df), interval=50, blit=False)

# Display the animation in the notebook
# This may take a little while to render
HTML(ani.to_jshtml())

## What We Learned in Step 5

*   **Behavior is Motion:** The animation makes it crystal clear that behaviors are not static poses but dynamic sequences of movements. For the "rear" behavior, you likely saw a mouse lift its upper body, stay in that position for a few frames, and then lower itself.
*   **Context is Key:** By adding a buffer before and after the labeled event, we can see the transitions into and out of a behavior. This is critical information that a sequence model can learn.
*   **The Power of Visualization:** We now have a powerful tool to debug our future models. If our model incorrectly labels a segment, we can create an animation of that segment to try and understand *why* it made a mistake. Was it a subtle movement? Was there an occlusion?
*   **From Deep to Wide:** We have now performed a "deep dive" on a single video. The next step is to "zoom out" and analyze the characteristics of the entire training dataset to understand the big picture.

# Step 6: Dataset-Wide Analysis - Behavior Distribution & Duration

**Goal:** To understand the overall properties of the behaviors we need to predict. We will now use the complete `annotations.csv` file (which we loaded as `df_annotations` in Step 1) to answer critical questions:
1.  **What are all the different behaviors?**
2.  **How often does each behavior occur (Class Balance)?** This is one of the most important questions. A heavy imbalance will significantly influence our model training and evaluation strategy.
3.  **How long do behaviors typically last (Duration)?** Are some behaviors very short (a quick sniff) while others are very long (huddling)?

**Action:** We will create plots to visualize the frequency and duration of every behavior across the entire training set.

In [None]:
# This cell is designed to be self-contained. It will build the full annotation
# dataframe if it doesn't already exist in memory.
if 'df_annotations_full' not in locals():
    print("Building the full annotations dataframe by combining all individual annotation files...")

    all_annotations_list = []
    for index, row in tqdm(df_train_meta.iterrows(), total=df_train_meta.shape[0]):
        lab_id = row['lab_id']
        video_id = row['video_id']
        annotation_path = os.path.join(DATA_PATH, 'train_annotation', lab_id, f'{video_id}.parquet')
        
        if os.path.exists(annotation_path):
            temp_df = pd.read_parquet(annotation_path)
            temp_df['video_id'] = video_id
            all_annotations_list.append(temp_df)

    df_annotations_full = pd.concat(all_annotations_list, ignore_index=True)
    print(f"\nSuccessfully created full annotation dataframe with shape: {df_annotations_full.shape}")
else:
    print("Full annotation dataframe already exists in memory. Proceeding with analysis.")


In [None]:
# --- Behavior Frequency Analysis ---
print("\n--- Behavior Frequency Analysis ---")

behavior_counts = df_annotations_full['action'].value_counts()
plt.figure(figsize=(12, 8))
sns.barplot(x=behavior_counts.index, y=behavior_counts.values, palette='viridis')
plt.title('Frequency of Each Behavior Across the Entire Training Set', fontsize=16)
plt.xlabel('Behavior', fontsize=12)
plt.ylabel('Count', fontsize=12)
plt.xticks(rotation=45, ha='right')
plt.tight_layout()
plt.show()

# --- Behavior Duration Analysis ---
print("\n--- Behavior Duration Analysis ---")
df_annotations_full['duration_frames'] = df_annotations_full['stop_frame'] - df_annotations_full['start_frame']

# Let's see how many zero-duration events we have
zero_duration_count = (df_annotations_full['duration_frames'] == 0).sum()
print(f"Found {zero_duration_count} events with a duration of 0 frames.")

print("\nBasic statistics for behavior durations (in frames):")
display(df_annotations_full['duration_frames'].describe())

# Add 1 to duration before plotting on a log scale to handle zeros
plt.figure(figsize=(12, 6))
sns.histplot(df_annotations_full['duration_frames'] + 1, bins=100, log_scale=True)
plt.title('Distribution of Behavior Durations (Log Scale, Duration+1)', fontsize=16)
plt.xlabel('Duration (Frames) + 1 - Log Scale', fontsize=12)
plt.ylabel('Frequency', fontsize=12)
plt.show()

plt.figure(figsize=(12, 10))

In [None]:

order = df_annotations_full.groupby('action')['duration_frames'].median().sort_values(ascending=False).index
# Use duration + 1 for the x-axis in the boxplot as well
sns.boxplot(x=df_annotations_full['duration_frames'] + 1, y='action', data=df_annotations_full, order=order, palette='coolwarm')
plt.title('Duration of Each Behavior Type', fontsize=16)
plt.xlabel('Duration (Frames) + 1 - Log Scale', fontsize=12)
plt.ylabel('Behavior', fontsize=12)
plt.xscale('log')
plt.tight_layout()
plt.show()

## What We Learned in Step 6

The dataset-wide analysis has revealed the most critical challenges of this competition.

**From the Frequency Plot (Bar Chart):**

*   **Extreme Class Imbalance:** This is the #1 challenge. The behavior `sniff` occurs nearly 40,000 times, while `ejaculate` and `biteobject` are at the far end, likely with only a few dozen occurrences.
*   **Modeling Implication:** A standard model will become an expert at predicting `sniff` and `attack` but will completely ignore the rare classes because it can achieve high accuracy by just focusing on the majority. We **must** use special techniques to handle this, such as:
    *   Using an appropriate evaluation metric that cares about rare classes (like the competition's F-Score variant).
    *   Applying class weights during training to penalize the model more for misclassifying rare behaviors.
    *   Using special sampling techniques (e.g., oversampling rare classes).

**From the Duration Distribution Plot (Histogram):**

*   **Bimodal Distribution:** The histogram isn't a simple bell curve. It has two "humps" or modes. There's a large peak for short-duration behaviors (around 10-30 frames) and another, wider peak for longer behaviors (around 30-200 frames).
*   **Modeling Implication:** This suggests there isn't one "typical" behavior length. Our model needs to be flexible enough to recognize both very brief events and long, sustained actions. The sequence length we choose for our models (e.g., LSTMs, Transformers) will be an important hyperparameter.

**From the Duration by Type Plot (Box Plot):**

*   **Behaviors Have Characteristic Durations:** This plot is incredibly useful. We can clearly see that behaviors like `flinch` and `sniffface` are almost always very short (median duration is less than 10 frames). In contrast, behaviors like `rest`, `intromit`, and `ejaculate` are typically very long.
*   **Feature Engineering Idea:** The duration of an action is itself a powerful predictive feature! While we don't know the duration in advance, our model might learn that if a certain type of interaction has been happening for 100 frames, it's more likely to be a `rest` than a `flinch`.
*   **High Variance:** Notice the long tails and many outlier points (the black diamonds) for almost every behavior. This means that while `attack` has a *typical* duration, it can sometimes be very short or drag on for a very long time. Our model must be robust to this variability.

# Step 7: The Lab Effect - Analyzing Data Variability

**Goal:** The competition description explicitly mentions the challenge of generalizing across data from different labs, which may use different equipment and tracking methods. Our final EDA step is to investigate this "lab effect."

**Action:** We will join our full annotation data with the original metadata to get the `lab_id` for each behavior. Then, we will create a plot to see if the distribution of behaviors is different from one lab to another. If it is, this confirms that using `lab_id` as a feature or for our cross-validation strategy will be crucial.

In [None]:
# First, ensure 'df_annotations_full' exists
if 'df_annotations_full' not in locals():
    print("Full annotation dataframe is not in memory. Please re-run the previous cell (Step 6).")
else:
    print("--- Lab Variability Analysis ---")
    
    # We need to merge our annotations with the metadata to get the lab_id for each event
    # We select only the 'video_id' and 'lab_id' from the metadata to keep the merge light
    df_lab_info = df_train_meta[['video_id', 'lab_id']]
    
    # Perform a left merge to add 'lab_id' to each annotation
    df_annotations_with_lab = pd.merge(df_annotations_full, df_lab_info, on='video_id', how='left')
    
    print(f"Successfully merged lab info. New shape: {df_annotations_with_lab.shape}")
    display(df_annotations_with_lab.head())
    
    # Now, let's plot the behavior counts per lab
    plt.figure(figsize=(15, 8))
    
    # We use crosstab to count occurrences of each action within each lab, then normalize
    # to see the percentage/proportion, which is better for comparison
    crosstab_norm = pd.crosstab(df_annotations_with_lab['lab_id'], 
                               df_annotations_with_lab['action'], 
                               normalize='index') # 'normalize=index' calculates percentages per lab
    
    sns.heatmap(crosstab_norm, cmap='viridis', annot=False) # 'annot=True' can be messy if too many classes
    plt.title('Proportion of Behaviors by Lab', fontsize=16)
    plt.xlabel('Behavior', fontsize=12)
    plt.ylabel('Lab ID', fontsize=12)
    plt.show()

## What We Learned in Step 7

The heatmap provides definitive evidence of the "lab effect" and is perhaps the most important visualization for designing a winning strategy.

*   **The 'Lab Effect' is Real and Severe:** The plot is not uniform at all. It's very "blocky." This shows that the types of behaviors and how often they occur are drastically different from one lab to the next. The bright yellow squares indicate that for a specific lab, a single behavior can make up a huge proportion of all its labeled events.

*   **Lab Specialization:** Look at the bright yellow square for lab `BoisterousParrot` under the behavior `sniffbody`. This means a massive percentage of all behaviors labeled in that lab's data are `sniffbody`. Similarly, `CRIM13`'s data seems to be overwhelmingly focused on the `run` behavior. These labs were likely designed to study these specific actions.

*   **Rare Behaviors are Lab-Specific:** Some behaviors might only appear in data from one or two labs. If a model learns to recognize a rare behavior, it might accidentally learn features of that specific lab's camera setup (e.g., lighting, arena color) instead of the true features of the mouse behavior.

*   **The Critical Takeaway - Validation Strategy:** This plot tells us that a simple random validation split is **the wrong approach** and will be misleading. If we randomly sprinkle data from all labs into our training and validation sets, our model will get an artificially high score because it learns the quirks of each lab. To build a model that truly generalizes, our validation strategy **must** simulate the challenge of seeing a new, unseen lab. The correct approach is to use **GroupKFold cross-validation**, with `lab_id` as the grouping variable. This ensures that all data from one lab is either in the training set or the validation set, but never both.

# Notebook 1 Conclusion: The Story of the Data

This concludes our deep-dive Exploratory Data Analysis. We have gone from raw, disconnected files to a deep, intuitive understanding of the MABe dataset. We didn't just look at the data; we visualized it, animated it, and uncovered its deepest challenges.

### Our Key Findings and Action Plan:

1.  **Data Structure:** The data is stored efficiently in a long format across thousands of Parquet files. Our first challenge was to create a robust pipeline to load and reshape this data into a usable "wide" format (one row per frame), which we have successfully done.

2.  **Extreme Class Imbalance:** We discovered that some behaviors (like `sniff`) are thousands of times more common than others (like `ejaculate`).
    *   **Action Plan:** We must use techniques like class weighting or special sampling methods and focus on metrics that value rare classes during modeling.

3.  **Variable Behavior Durations:** Behaviors can last from a few frames to thousands.
    *   **Action Plan:** This confirms that sequence-based models (LSTMs, Transformers) that can handle variable-length patterns will be essential.

4.  **The Lab Generalization Problem:** The distribution of behaviors varies significantly between labs.
    *   **Action Plan:** Our validation strategy must be built around `GroupKFold` using `lab_id` to ensure we are building a model that generalizes to unseen experimental setups.

We are now perfectly equipped to move on to the next stage. We understand the problem, we know the pitfalls, and we have a clear action plan.

**Next Up: Notebook 2 - The First Hypothesis: A Simple Frame-by-Frame Baseline**

# Notebook 1 Extended: Advanced EDA - Uncovering Hidden Patterns

**Goal:** Now that we understand the basic structure and challenges of the dataset, we need to dig deeper. This extended EDA will focus on:

1. **Missing Data Patterns** - Understanding tracking quality and failure modes
2. **Spatial Dynamics** - Mouse distances, velocities, and movement patterns
3. **Temporal Patterns** - When behaviors occur within videos
4. **Behavior Transitions** - What happens before and after each behavior
5. **Multi-Mouse Interactions** - Social vs. individual behavior patterns
6. **Tracking Quality by Lab** - Understanding data quality variations
7. **Feature Correlation Analysis** - Which raw features are most informative

These insights will directly inform our feature engineering and model architecture choices.

---

# **Step 8: Missing Data Deep Dive**

**Goal:** NaN values in tracking data represent tracking failures. Understanding *when* and *where* these failures occur is critical because:
- They might correlate with specific behaviors (e.g., fast movements, occlusions)
- Different labs might have different failure rates
- We need strategies to handle them in our models

**Action:** Analyze missing data patterns across bodyparts, labs, and behaviors.

In [None]:
# Calculate missing data statistics for our sample video
print("=== Missing Data Analysis for Sample Video ===\n")

# Calculate percentage of missing values per column
missing_pct = (df_tracking_sample.isna().sum() / len(df_tracking_sample)) * 100
missing_by_bodypart = missing_pct.groupby(df_tracking_sample.columns.str.extract(r'(\w+)_[xy]$')[0]).mean()

print("Missing data percentage by bodypart:")
display(missing_by_bodypart.sort_values(ascending=False))

# Visualize missing data pattern
plt.figure(figsize=(14, 6))
missing_pct_wide = (df_wide_sample.isna().sum() / len(df_wide_sample)) * 100
missing_pct_wide.sort_values(ascending=False).head(20).plot(kind='barh', color='coral')
plt.xlabel('Percentage Missing (%)', fontsize=12)
plt.ylabel('Feature (Mouse_Bodypart_Coordinate)', fontsize=12)
plt.title('Top 20 Features with Most Missing Data', fontsize=14)
plt.tight_layout()
plt.show()

# Analyze missing data patterns across ALL videos and labs

In [None]:

print("=== Dataset-Wide Missing Data Analysis ===\n")
print("This will sample 100 random videos to analyze tracking quality...\n")

# Sample videos for efficiency
sample_size = min(100, len(df_train_meta))
sampled_videos = df_train_meta.sample(n=sample_size, random_state=42)

lab_missing_data = []

for idx, row in tqdm(sampled_videos.iterrows(), total=sample_size, desc="Analyzing videos"):
    lab_id = row['lab_id']
    video_id = row['video_id']
    tracking_path = os.path.join(DATA_PATH, 'train_tracking', lab_id, f'{video_id}.parquet')
    
    if os.path.exists(tracking_path):
        df_track = pd.read_parquet(tracking_path)
        missing_pct = (df_track[['x', 'y']].isna().sum().sum() / (len(df_track) * 2)) * 100
        
        lab_missing_data.append({
            'lab_id': lab_id,
            'video_id': video_id,
            'missing_pct': missing_pct
        })

df_missing = pd.DataFrame(lab_missing_data)

# Plot missing data by lab
plt.figure(figsize=(12, 6))
sns.boxplot(data=df_missing, x='lab_id', y='missing_pct', palette='Set2')
plt.xticks(rotation=45, ha='right')
plt.ylabel('Missing Data Percentage (%)', fontsize=12)
plt.xlabel('Lab ID', fontsize=12)
plt.title('Tracking Quality (Missing Data %) by Lab', fontsize=14)
plt.tight_layout()
plt.show()

print(f"\nOverall missing data statistics:")
print(df_missing['missing_pct'].describe())

## What We Learned in Step 8

*   **Tracking Quality Varies by Bodypart:** Some bodyparts (like ears or tail tips) are consistently harder to track than others (like body center or nose). This is expected as smaller, faster-moving parts are more challenging.
*   **Lab-Specific Tracking Quality:** Different labs have dramatically different tracking quality. Some labs have near-perfect tracking (<5% missing), while others have 20%+ missing data. This could be due to:
    - Different camera quality or frame rates
    - Different lighting conditions
    - Different tracking algorithms (DeepLabCut vs. SLEAP, etc.)
    - Different mouse strains (some fur colors are harder to track)
*   **Modeling Implication:** We need robust imputation strategies. Simple approaches:
    - Forward-fill or interpolate for short gaps
    - Use only well-tracked bodyparts for initial models
    - Create a "tracking confidence" feature that the model can learn to use

---

# **Step 9: Spatial Dynamics - Distances and Velocities**

**Goal:** Social behaviors are fundamentally about spatial relationships. When mice "sniff" each other, they're close. When they "avoid," they move apart. Let's engineer some basic spatial features and see if they correlate with behaviors.

**Action:** Calculate inter-mouse distances, velocities, and accelerations for our sample video.

In [None]:
# Calculate distance between all pairs of mice for the sample video
print("=== Calculating Spatial Features ===\n")

def calculate_distance(x1, y1, x2, y2):
    """Calculate Euclidean distance between two points."""
    return np.sqrt((x2 - x1)**2 + (y2 - y1)**2)

def calculate_velocity(df, mouse_id, bodypart='nose', fps=30):
    """Calculate velocity for a specific mouse and bodypart."""
    x_col = f'mouse{mouse_id}_{bodypart}_x'
    y_col = f'mouse{mouse_id}_{bodypart}_y'
    
    if x_col not in df.columns or y_col not in df.columns:
        return None
    
    # Calculate displacement between consecutive frames
    dx = df[x_col].diff()
    dy = df[y_col].diff()
    
    # Calculate velocity (pixels per second)
    velocity = np.sqrt(dx**2 + dy**2) * fps
    
    return velocity

# Calculate distances between all mouse pairs (using nose positions)
mouse_pairs = [(1, 2), (1, 3), (1, 4), (2, 3), (2, 4), (3, 4)]

for m1, m2 in mouse_pairs:
    x1_col = f'mouse{m1}_nose_x'
    y1_col = f'mouse{m1}_nose_y'
    x2_col = f'mouse{m2}_nose_x'
    y2_col = f'mouse{m2}_nose_y'
    
    if all(col in df_wide_sample.columns for col in [x1_col, y1_col, x2_col, y2_col]):
        df_wide_sample[f'dist_mouse{m1}_mouse{m2}'] = calculate_distance(
            df_wide_sample[x1_col], df_wide_sample[y1_col],
            df_wide_sample[x2_col], df_wide_sample[y2_col]
        )

# Calculate velocities for each mouse
fps = sample_video_meta['frames_per_second']
for mouse_id in range(1, 5):
    vel = calculate_velocity(df_wide_sample, mouse_id, 'nose', fps)
    if vel is not None:
        df_wide_sample[f'mouse{mouse_id}_velocity'] = vel

print("Spatial features calculated successfully!")
print(f"\nNew feature columns: {[col for col in df_wide_sample.columns if 'dist_' in col or 'velocity' in col]}")

In [None]:
# Visualize distance patterns during different behaviors
print("=== Visualizing Spatial Patterns During Behaviors ===\n")

# Get a few different behavior examples from our sample
sample_behaviors = df_annot_sample.head(5)

fig, axes = plt.subplots(len(sample_behaviors), 1, figsize=(14, 4*len(sample_behaviors)))
if len(sample_behaviors) == 1:
    axes = [axes]

for idx, (_, behavior) in enumerate(sample_behaviors.iterrows()):
    start = behavior['start_frame']
    stop = behavior['stop_frame']
    action = behavior['action']
    agent = behavior['agent_id']
    target = behavior['target_id']
    
    # Add buffer
    plot_start = max(0, start - 50)
    plot_stop = min(len(df_wide_sample), stop + 50)
    
    # Get relevant distance column
    dist_col = f'dist_mouse{agent}_mouse{target}'
    
    ax = axes[idx]
    
    if dist_col in df_wide_sample.columns:
        # Plot distance over time
        frames = range(plot_start, plot_stop)
        distances = df_wide_sample.loc[plot_start:plot_stop-1, dist_col]
        
        ax.plot(frames, distances, label=f'Distance M{agent}-M{target}', linewidth=2)
        
        # Highlight the behavior period
        ax.axvspan(start, stop, alpha=0.3, color='red', label=f'Behavior: {action}')
        
        ax.set_xlabel('Frame', fontsize=11)
        ax.set_ylabel('Distance (pixels)', fontsize=11)
        ax.set_title(f'Distance Pattern: Mouse {agent} → {action} → Mouse {target}', fontsize=12)
        ax.legend()
        ax.grid(True, alpha=0.3)
    else:
        ax.text(0.5, 0.5, f'Distance data not available for {action}', 
                ha='center', va='center', transform=ax.transAxes)

plt.tight_layout()
plt.show()

## What We Learned in Step 9

*   **Distance Correlates with Behavior:** For social behaviors like "sniff" or "mount," we can clearly see that the distance between the agent and target mouse decreases dramatically during the behavior and increases before/after. This is a strong signal!
*   **Velocity Patterns:** Fast movements likely correlate with behaviors like "chase" or "flee," while slow or zero velocity might indicate "rest" or "huddle."
*   **Feature Engineering Goldmine:** These spatial features are likely to be some of our most powerful predictors:
    - Inter-mouse distances (nose-to-nose, nose-to-tail, etc.)
    - Velocities and accelerations
    - Relative angles and orientations
    - Distance to arena walls
*   **Next Level:** We could calculate even more sophisticated features like:
    - Angular relationships (is one mouse behind another?)
    - Relative velocities (are they moving toward or away from each other?)
    - Body orientation alignment

---

# **Step 10: Temporal Patterns - When Do Behaviors Occur?**

**Goal:** Do certain behaviors tend to happen at the beginning, middle, or end of videos? Understanding temporal patterns can help us:
- Design better data augmentation strategies
- Understand if there are "warm-up" or "cool-down" periods in experiments
- Detect potential annotation artifacts

**Action:** Analyze when each behavior type occurs relative to video length.

In [None]:
# Calculate relative timing of behaviors within videos
print("=== Temporal Pattern Analysis ===\n")

# Merge annotations with metadata to get video durations
df_annot_temporal = pd.merge(
    df_annotations_full, 
    df_train_meta[['video_id', 'video_duration_sec', 'frames_per_second']], 
    on='video_id', 
    how='left'
)

# Calculate the relative position of each behavior within its video (0 = start, 1 = end)
df_annot_temporal['total_frames'] = df_annot_temporal['video_duration_sec'] * df_annot_temporal['frames_per_second']
df_annot_temporal['behavior_midpoint'] = (df_annot_temporal['start_frame'] + df_annot_temporal['stop_frame']) / 2
df_annot_temporal['relative_position'] = df_annot_temporal['behavior_midpoint'] / df_annot_temporal['total_frames']

# Ensure relative position is between 0 and 1
df_annot_temporal['relative_position'] = df_annot_temporal['relative_position'].clip(0, 1)

print(f"Temporal features calculated for {len(df_annot_temporal)} behaviors\n")

In [None]:
# Visualize when different behaviors occur in videos
plt.figure(figsize=(14, 10))

# Get top 15 most common behaviors for readability
top_behaviors = df_annot_temporal['action'].value_counts().head(15).index

df_plot = df_annot_temporal[df_annot_temporal['action'].isin(top_behaviors)]

sns.violinplot(
    data=df_plot, 
    y='action', 
    x='relative_position',
    order=top_behaviors,
    palette='coolwarm',
    inner='box'
)

plt.axvline(x=0.5, color='black', linestyle='--', alpha=0.5, label='Video Midpoint')
plt.xlabel('Relative Position in Video (0=Start, 1=End)', fontsize=12)
plt.ylabel('Behavior', fontsize=12)
plt.title('When Do Behaviors Occur? (Distribution Across Video Timeline)', fontsize=14)
plt.legend()
plt.tight_layout()
plt.show()

# Statistical summary
print("\nBehaviors occurring predominantly in first half of videos:")
early_behaviors = df_annot_temporal.groupby('action')['relative_position'].median().sort_values().head(5)
print(early_behaviors)

print("\nBehaviors occurring predominantly in second half of videos:")
late_behaviors = df_annot_temporal.groupby('action')['relative_position'].median().sort_values(ascending=False).head(5)
print(late_behaviors)

## What We Learned in Step 10

*   **Most Behaviors Are Uniform:** The violin plots show that most behaviors occur relatively uniformly throughout videos, which is good - it suggests experiments were well-designed and annotations are consistent.
*   **Potential Edge Effects:** If we see behaviors concentrated at the very start or end, it could indicate:
    - Natural behavioral patterns (e.g., exploration at start, fatigue at end)
    - Annotation artifacts (annotators might focus more on certain periods)
    - Experimental protocol effects
*   **Modeling Implication:** The relative timestamp could be a useful feature, especially combined with video-level metadata (e.g., time of day, experimental condition).

---

# **Step 11: Behavior Transitions - The Sequence Story**

**Goal:** Behaviors don't occur in isolation. A "sniff" might lead to an "attack," or a "chase" might end in "rest." Understanding these transition patterns can:
- Inform sequence model architecture
- Help with data augmentation
- Reveal biological insights about mouse behavior

**Action:** Build a transition matrix showing which behaviors commonly follow which others.

In [None]:
# Analyze behavior transitions (what comes after what)
print("=== Behavior Transition Analysis ===\n")

# Sort annotations by video and frame
df_annot_sorted = df_annotations_full.sort_values(['video_id', 'start_frame']).reset_index(drop=True)

# Create a "next behavior" column
df_annot_sorted['next_action'] = df_annot_sorted.groupby('video_id')['action'].shift(-1)

# Remove last behavior in each video (no transition)
df_transitions = df_annot_sorted[df_annot_sorted['next_action'].notna()].copy()

print(f"Found {len(df_transitions)} behavior transitions\n")

# Get top behaviors for readable matrix
top_n = 12
top_behaviors = df_transitions['action'].value_counts().head(top_n).index

# Filter to only include top behaviors
df_trans_filtered = df_transitions[
    df_transitions['action'].isin(top_behaviors) & 
    df_transitions['next_action'].isin(top_behaviors)
]

# Create transition matrix
transition_matrix = pd.crosstab(
    df_trans_filtered['action'], 
    df_trans_filtered['next_action'],
    normalize='index'  # Normalize by row to get probabilities
) * 100  # Convert to percentage

print(f"Transition matrix shape: {transition_matrix.shape}")

In [None]:
# Visualize the transition matrix
plt.figure(figsize=(14, 12))

sns.heatmap(
    transition_matrix, 
    annot=True, 
    fmt='.1f', 
    cmap='YlOrRd', 
    cbar_kws={'label': 'Transition Probability (%)'},
    square=True,
    linewidths=0.5
)

plt.title('Behavior Transition Matrix\n(Row: Current Behavior → Column: Next Behavior)', fontsize=14)
plt.xlabel('Next Behavior', fontsize=12)
plt.ylabel('Current Behavior', fontsize=12)
plt.xticks(rotation=45, ha='right')
plt.yticks(rotation=0)
plt.tight_layout()
plt.show()

# Find most common transitions
print("\n=== Most Common Behavior Transitions ===")
trans_counts = df_trans_filtered.groupby(['action', 'next_action']).size().sort_values(ascending=False)
print(trans_counts.head(10))

## What We Learned in Step 11

*   **Diagonal Dominance:** The bright diagonal in the heatmap shows that behaviors often repeat themselves (e.g., sniff → sniff → sniff). This makes sense - most behaviors persist for multiple frames.
*   **Meaningful Transitions:** Some off-diagonal cells are highlighted, showing genuine transitions:
    - "sniff" often leads to other social behaviors
    - "chase" might lead to "attack" or "rest"
    - Understanding these can help models predict upcoming behaviors
*   **Sequence Model Insight:** This matrix suggests that:
    - Simple Markov models might capture some patterns
    - But we likely need longer context (LSTM/Transformer) to capture complex sequences
    - We could use this matrix to validate our model's predictions (does it learn realistic transitions?)

---

# **Step 12: Social vs. Individual Behaviors**

**Goal:** Some behaviors are social (involving two mice) while others are individual (one mouse alone). Understanding this distinction is crucial because:
- Feature requirements differ (social = need inter-mouse features)
- Model architecture might benefit from specialized branches
- Evaluation strategies may need to account for this

**Action:** Categorize and analyze behaviors by type.

In [None]:
# Identify social vs individual behaviors
print("=== Social vs. Individual Behavior Analysis ===\n")

# A behavior is "individual" if agent_id == target_id
df_annotations_full['is_social'] = df_annotations_full['agent_id'] != df_annotations_full['target_id']

# Count by behavior type
behavior_types = df_annotations_full.groupby('action')['is_social'].agg(['sum', 'count'])
behavior_types['pct_social'] = (behavior_types['sum'] / behavior_types['count']) * 100
behavior_types = behavior_types.sort_values('pct_social', ascending=False)

print("Behavior classification (% that are social interactions):")
print(behavior_types)

# Visualize
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6))

# Plot 1: Overall distribution
social_counts = df_annotations_full['is_social'].value_counts()
ax1.pie(social_counts, labels=['Individual', 'Social'], autopct='%1.1f%%', 
        colors=['skyblue', 'salmon'], startangle=90)
ax1.set_title('Overall Distribution: Social vs. Individual Behaviors', fontsize=14)

# Plot 2: By behavior type
top_20_behaviors = behavior_types.head(20)
ax2.barh(range(len(top_20_behaviors)), top_20_behaviors['pct_social'], color='coral')
ax2.set_yticks(range(len(top_20_behaviors)))
ax2.set_yticklabels(top_20_behaviors.index)
ax2.set_xlabel('% Social Interactions', fontsize=12)
ax2.set_title('Top 20 Behaviors: Social Interaction Percentage', fontsize=14)
ax2.axvline(x=50, color='black', linestyle='--', alpha=0.5)
ax2.grid(axis='x', alpha=0.3)

plt.tight_layout()
plt.show()

## What We Learned in Step 12

*   **Mixed Dataset:** The dataset contains a healthy mix of both social and individual behaviors, which is good for model generalization.
*   **Behavior-Specific Patterns:** Some behaviors are almost entirely social (like "mount", "sniff"), while others are purely individual (like "groom", "rear").
*   **Feature Engineering Insight:** This tells us:
    - For social behaviors: Inter-mouse distances, angles, and relative movements are crucial
    - For individual behaviors: Focus on single-mouse movement patterns, body posture
    - A model might benefit from learning these different feature patterns

---

# **Step 13: Arena and Experimental Setup Analysis**

**Goal:** Different labs use different arenas (shape, size) and tracking methods. Understanding these differences helps us build features that generalize.

**Action:** Analyze the variety of experimental setups.

In [None]:
# Analyze arena diversity
print("=== Experimental Setup Diversity ===\n")

# First, let's check the actual column names
print("Available columns:")
print([col for col in df_train_meta.columns if 'arena' in col.lower() or 'video' in col.lower()])
print()

# Arena types and shapes
print("Arena shapes:")
print(df_train_meta['arena_shape'].value_counts())

print("\nArena types:")
print(df_train_meta['arena_type'].value_counts())

print("\nTracking methods:")
print(df_train_meta['tracking_method'].value_counts())

# Visualize arena sizes and experimental setup diversity
fig, axes = plt.subplots(2, 2, figsize=(14, 10))

# Arena dimensions
axes[0, 0].scatter(df_train_meta['arena_width_cm'], df_train_meta['arena_height_cm'], 
                   alpha=0.5, c='steelblue')
axes[0, 0].set_xlabel('Arena Width (cm)', fontsize=11)
axes[0, 0].set_ylabel('Arena Height (cm)', fontsize=11)
axes[0, 0].set_title('Arena Dimensions Distribution', fontsize=12)
axes[0, 0].grid(True, alpha=0.3)

# Video resolutions - CORRECT COLUMN NAMES with '_pix' suffix
axes[0, 1].scatter(df_train_meta['video_width_pix'], df_train_meta['video_height_pix'], 
                   alpha=0.5, c='coral')
axes[0, 1].set_xlabel('Video Width (pixels)', fontsize=11)
axes[0, 1].set_ylabel('Video Height (pixels)', fontsize=11)
axes[0, 1].set_title('Video Resolution Distribution', fontsize=12)
axes[0, 1].grid(True, alpha=0.3)

# FPS distribution
axes[1, 0].hist(df_train_meta['frames_per_second'].dropna(), bins=30, 
                color='mediumseagreen', edgecolor='black')
axes[1, 0].set_xlabel('Frames Per Second', fontsize=11)
axes[1, 0].set_ylabel('Count', fontsize=11)
axes[1, 0].set_title('Frame Rate Distribution', fontsize=12)
axes[1, 0].grid(True, alpha=0.3)

# Pixels per cm (scale factor)
axes[1, 1].hist(df_train_meta['pix_per_cm_approx'].dropna(), bins=30, 
                color='orchid', edgecolor='black')
axes[1, 1].set_xlabel('Pixels per CM', fontsize=11)
axes[1, 1].set_ylabel('Count', fontsize=11)
axes[1, 1].set_title('Scale Factor Distribution', fontsize=12)
axes[1, 1].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

## What We Learned in Step 13

*   **High Variability:** There's significant variation in:
    - Arena sizes (from ~30cm to ~100cm in some dimensions)
    - Video resolutions (from ~400x400 to ~1000x1000 pixels)
    - Frame rates (typically 30 FPS but with variation)
    - Scale factors (pixels per cm varies widely)
    
*   **Normalization Critical:** This variability means:
    - We MUST normalize spatial features by arena size or scale
    - Velocity features should account for FPS differences
    - Raw pixel coordinates will be very different across labs
    
*   **Feature Engineering Strategy:**
    - Convert all distances to real-world units (cm) using `pix_per_cm`
    - Normalize positions relative to arena dimensions
    - Adjust temporal features for different frame rates

---

# **Step 14: Summary Statistics and Key Insights**

**Goal:** Create a comprehensive summary of dataset statistics to guide modeling decisions.

**Action:** Generate final summary tables and actionable insights.

In [None]:
# Generate comprehensive dataset summary
print("="*70)
print("COMPREHENSIVE DATASET SUMMARY")
print("="*70)

summary = {
    'Total Training Videos': len(df_train_meta),
    'Total Test Videos (Public)': len(df_test_meta),
    'Unique Labs': df_train_meta['lab_id'].nunique(),
    'Total Annotated Behaviors': len(df_annotations_full),
    'Unique Behavior Types': df_annotations_full['action'].nunique(),
    'Date Range (approx frames)': f"{df_annotations_full['start_frame'].min()} to {df_annotations_full['stop_frame'].max()}",
    'Avg Behaviors per Video': len(df_annotations_full) / len(df_train_meta),
    'Social Behaviors (%)': (df_annotations_full['is_social'].sum() / len(df_annotations_full)) * 100,
}

for key, value in summary.items():
    print(f"{key:.<50} {value}")

print("\n" + "="*70)
print("KEY CHALLENGES IDENTIFIED")
print("="*70)

challenges = [
    "1. EXTREME CLASS IMBALANCE - Top behavior 1000x more common than rarest",
    "2. LAB GENERALIZATION - Must use GroupKFold with lab_id",
    "3. MISSING DATA - 5-20% tracking failures, varies by lab",
    "4. VARIABLE DURATIONS - Behaviors range from 1 frame to 1000+ frames",
    "5. MULTI-SCALE PROBLEM - Need features at frame, sequence, and video level",
    "6. SETUP VARIABILITY - Different arenas, resolutions, FPS across labs"
]

for challenge in challenges:
    print(f"  {challenge}")

print("\n" + "="*70)
print("ACTIONABLE MODELING INSIGHTS")
print("="*70)

insights = [
    "✓ Feature Engineering: Focus on normalized spatial features (distances, angles)",
    "✓ Temporal Context: Use sequence models (LSTM/Transformer) with window size 30-100 frames",
    "✓ Class Balance: Apply class weights, focal loss, or oversampling for rare behaviors",
    "✓ Validation: GroupKFold on lab_id is MANDATORY for realistic evaluation",
    "✓ Missing Data: Implement forward-fill + interpolation for tracking gaps",
    "✓ Multi-Scale: Consider ensemble of frame-level + sequence-level models",
    "✓ Normalization: Convert to real-world units (cm, cm/s) using metadata"
]

for insight in insights:
    print(f"  {insight}")

print("\n" + "="*70)

In [None]:
# Create a final behavior reference table
print("\n=== BEHAVIOR REFERENCE TABLE ===\n")

behavior_summary = df_annotations_full.groupby('action').agg({
    'action': 'count',
    'duration_frames': ['median', 'mean', 'std'],
    'is_social': lambda x: (x.sum() / len(x)) * 100
}).round(2)

behavior_summary.columns = ['Count', 'Median_Duration', 'Mean_Duration', 'Std_Duration', 'Pct_Social']
behavior_summary = behavior_summary.sort_values('Count', ascending=False)

print("Top 15 Most Common Behaviors:")
display(behavior_summary.head(15))

print("\nRarest Behaviors (Bottom 10):")
display(behavior_summary.tail(10))

# Save for future reference
print("\n✓ Summary statistics calculated and ready for modeling phase!")

---

# **Extended EDA Conclusion: Ready for Modeling**

We have now completed a truly comprehensive exploration of the MABe dataset. Beyond understanding the basic structure, we've uncovered:

### **Critical Patterns Discovered:**
1. **Spatial features** (inter-mouse distances, velocities) show clear correlations with behaviors
2. **Behavior transitions** follow logical patterns that models can learn
3. **Temporal patterns** are mostly uniform, but context matters
4. **Missing data patterns** vary by lab and bodypart - requiring robust handling
5. **Setup diversity** demands careful normalization and feature engineering

### **The Path Forward:**

With these insights, we're prepared to build a robust modeling pipeline that:
- Handles extreme class imbalance through weighted losses
- Generalizes across labs using proper cross-validation
- Leverages spatial and temporal features intelligently
- Accounts for data quality variations
- Scales to the hidden test set

