This notebook demonstrates how to visualize the mouse pose tracking data.

## 🎯 Goal

The goal of this code is to animate the movement of the tracked mouse body parts over a sequence of video frames. Watching the raw keypoints move helps in understanding the tracking data and in checking its quality.

## 📝 Process

1.  **Setup**: We import the necessary libraries, locate the dataset root directory, and select a single video by its `VIDEO_ID` and the `LAB` subfolder of `train_tracking` that contains it.
2.  **Data Loading**:
      * The tracking data (x, y coordinates for each body part at each frame) is loaded from a `.parquet` file.
      * The corresponding video metadata (such as the frame width and height in pixels) is loaded from the `train.csv` file.
3.  **Data Preparation**:
      * To keep rendering time manageable, we keep only the first `N_FRAMES_TO_ANIMATE` frames.
      * We assign a unique color to each body part using the `rainbow` colormap. This makes it easy to distinguish different parts in the animation.
4.  **Animation**:
      * We use `matplotlib.animation.FuncAnimation` to generate the animation.
      * The plot's axis limits are set to the video's width and height in pixels, with the y-axis inverted so that the origin is in the top-left corner, as in image coordinates.
      * An `update` function is defined to redraw the positions of all body parts for each new frame.
      * A legend is added to the plot to show the mapping between colors and body parts.
5.  **Display**: The final animation is rendered as an interactive HTML object directly within the notebook, complete with play and pause controls.
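The frame-subsetting and color-assignment steps can be sketched on a tiny table. This is a minimal sketch, assuming the tracking data is in long format (one row per frame and body part) with the `video_frame`, `bodypart`, `x`, `y` columns that the notebook code uses; the values and the two body-part names are made up. It reads the `rainbow` colormap from the `matplotlib.colormaps` registry, which holds the same colormap as `cm.rainbow`.

```python
import numpy as np
import pandas as pd
from matplotlib import colormaps

# Made-up long-format tracking table: one row per (frame, body part)
tracking = pd.DataFrame({
    "video_frame": [0, 0, 1, 1, 2, 2],
    "bodypart": ["tail_base", "nose"] * 3,
    "x": [40.0, 10.0, 41.0, 12.0, 43.0, 14.0],
    "y": [25.0, 20.0, 26.0, 21.0, 27.0, 22.0],
})

# Keep only the first N frames (assumes frame numbering starts at 0)
N_FRAMES_TO_ANIMATE = 2
df_sub = tracking[tracking["video_frame"] < N_FRAMES_TO_ANIMATE].copy()

# One evenly spaced 'rainbow' color per body part, in sorted name order
bodyparts = sorted(df_sub["bodypart"].unique())
colors = colormaps["rainbow"](np.linspace(0, 1, len(bodyparts)))  # RGBA array, shape (n_parts, 4)
part_to_color = dict(zip(bodyparts, colors))

# Per-row RGBA colors as a 2-D array, the shape scatter.set_color(...) accepts
row_colors = np.array([part_to_color[p] for p in df_sub["bodypart"]])
print(bodyparts, row_colors.shape)
```

Sorting the body-part names before assigning colors keeps the mapping stable across videos that track the same parts.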

This visualization is a crucial first step for exploring the dataset, debugging tracking issues, and developing features for behavior classification models.
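A numeric check pairs well with the animation when looking for tracking issues: how often is each body part actually tracked? The sketch below uses a hypothetical helper, `bodypart_coverage`, and assumes that missing detections show up either as absent rows or as NaN coordinates; this is an assumption about the data, not something the notebook verifies. If the table has a `mouse_id` column (not used in this notebook), passing `by=("mouse_id", "bodypart")` gives per-mouse values.

```python
import numpy as np
import pandas as pd


def bodypart_coverage(df: pd.DataFrame, by=("bodypart",)) -> pd.Series:
    """Fraction of frames in which each group has finite (x, y) coordinates.

    Hypothetical helper. Assumes 'video_frame', 'x', 'y' columns plus the
    grouping columns in `by`; missing detections may be absent rows or NaNs.
    """
    by = list(by)
    n_frames = df["video_frame"].nunique()
    has_xy = np.isfinite(df["x"]) & np.isfinite(df["y"])
    # Distinct frames with a valid detection, per group
    frames_seen = df[has_xy].groupby(by)["video_frame"].nunique()
    # Groups that never have a valid detection are filled with 0
    all_groups = df.groupby(by)["video_frame"].nunique().index
    return (frames_seen.reindex(all_groups, fill_value=0) / n_frames).sort_values()


# Made-up example: 'tail_base' is NaN in two frames, 'ear' appears once,
# and 'tail_tip' is never tracked
nan = np.nan
demo = pd.DataFrame({
    "video_frame": [0, 1, 2, 3, 0, 1, 2, 3, 0, 0],
    "bodypart": ["nose"] * 4 + ["tail_base"] * 4 + ["ear", "tail_tip"],
    "x": [1.0, 1.0, 1.0, 1.0, 2.0, nan, 2.0, nan, 3.0, nan],
    "y": [1.0, 1.0, 1.0, 1.0, 2.0, nan, 2.0, 2.0, 3.0, nan],
})
coverage = bodypart_coverage(demo)
print(coverage)
```

Sorting ascending puts the least reliable body parts first, which are the ones worth watching closely in the animation.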

# Import necessary libraries
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.cm as cm
from matplotlib.animation import FuncAnimation
from IPython.display import HTML
from pathlib import Path

# ===================================================================
# 1. Set data paths and target video ID
# ===================================================================
VIDEO_ID = "44566106"
LAB = "AdaptableSnail"

# --- Automatically find the root path of the dataset ---
CANDIDATE_BASES = ["/kaggle/input/MABe-mouse-behavior-detection/"]
BASE = None
for b in CANDIDATE_BASES:
    p = Path(b)
    if (p / "train_tracking").exists():
        BASE = p
        break
if BASE is None:
    raise FileNotFoundError("Dataset path not found. Please update CANDIDATE_BASES.")

TRACKING_FILE = BASE / "train_tracking" / LAB / f"{VIDEO_ID}.parquet"
META_FILE = BASE / "train.csv"

print(f"Tracking file: {TRACKING_FILE}")
print(f"Meta file: {META_FILE}")

# ===================================================================
# 2. Load and preprocess data
# ===================================================================
# --- Load tracking data (coordinates) ---
df = pd.read_parquet(TRACKING_FILE)

# --- Load metadata (video information) ---
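meta_df = pd.read_csv(META_FILE)
meta_df["video_id"] = meta_df["video_id"].astype(str)
matches = meta_df[meta_df["video_id"] == VIDEO_ID]
if matches.empty:
    raise ValueError(f"video_id {VIDEO_ID} not found in {META_FILE}")
video_meta = matches.iloc[0]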

# --- Keep only the first N_FRAMES_TO_ANIMATE frames (rendering every frame is slow) ---
# Assumes frame numbering starts at 0, matching frames=range(N_FRAMES_TO_ANIMATE) below
N_FRAMES_TO_ANIMATE = 500
df_sub = df[df['video_frame'] < N_FRAMES_TO_ANIMATE].copy()

# --- Assign colors to each body part ---
bodyparts = sorted(df_sub['bodypart'].unique())
# Use the 'rainbow' colormap to generate colors based on the number of body parts
colors = cm.rainbow(np.linspace(0, 1, len(bodyparts)))
# Create a dictionary to map body part names to colors
part_to_color = dict(zip(bodyparts, colors))
# Add color information to each row
df_sub['color'] = df_sub['bodypart'].map(part_to_color)

print(f"\nAnimating first {N_FRAMES_TO_ANIMATE} frames...")
print(f"{len(bodyparts)} body parts will be colored.")

# ===================================================================
# 3. Prepare the animation
# ===================================================================
# --- Create the plotting area ---
fig, ax = plt.subplots(figsize=(10, 8))

# --- Set the plot limits to match the video resolution ---
width, height = video_meta['video_width_pix'], video_meta['video_height_pix']
ax.set_xlim(0, width)
ax.set_ylim(height, 0)  # Invert the y-axis to set the origin to the top-left corner
ax.set_title(f"Mice Pose - video_id: {VIDEO_ID} (Frame: 0)")
ax.set_xlabel("x (pixels)")
ax.set_ylabel("y (pixels)")
ax.set_aspect('equal', adjustable='box')  # Set the aspect ratio to 1:1

# --- Initialize the plot (initially empty) ---
# s=20 is the marker area in points**2
scatter = ax.scatter([], [], s=20, alpha=0.8)

# --- Create a legend (to show which color corresponds to which body part) ---
for part, color in part_to_color.items():
    ax.scatter([], [], c=[color], label=part, s=30)  # Empty dummy plots that only supply legend entries
ax.legend(bbox_to_anchor=(1.02, 1), loc='upper left', borderaxespad=0.)
# Run tight_layout after the legend exists so the space it needs outside the axes is accounted for
fig.tight_layout()


# ===================================================================
# 4. Define the animation update logic
# ===================================================================
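def update(frame_num):
    """Redraw the scatter plot for one frame (called once per frame by FuncAnimation)."""
    # Extract data for the current frame
    current_frame_data = df_sub[df_sub['video_frame'] == frame_num]

    if current_frame_data.empty:
        # No tracking rows for this frame: clear the points instead of showing the previous frame's
        scatter.set_offsets(np.empty((0, 2)))
    else:
        # Update the coordinates and colors; np.vstack turns the per-row RGBA arrays into an (n, 4) array
        scatter.set_offsets(current_frame_data[['x', 'y']].to_numpy())
        scatter.set_color(np.vstack(current_frame_data['color'].to_numpy()))

    # Update the title
    ax.set_title(f"Mice Pose - video_id: {VIDEO_ID} (Frame: {frame_num})")

    # Return the updated artists (required when blit=True)
    return scatter,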

# ===================================================================
# 5. Generate and display the animation
# ===================================================================
# --- Create the animation object ---
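# interval is the delay between frames in milliseconds: 33 ms gives roughly 30 fps playback (1000 ms / 30),
# which is a display choice and need not match the recording's frame rate
# blit=True lets interactive backends redraw only the artists returned by update
ani = FuncAnimation(fig, update, frames=range(N_FRAMES_TO_ANIMATE),
                    interval=33, blit=True)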

# --- Prevent the static plot from being displayed twice ---
plt.close(fig)

# --- Display the animation in the Jupyter Notebook ---
# Using to_jshtml() displays the animation with play/pause controls
HTML(ani.to_jshtml())
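Filtering `df_sub` inside `update` scans the whole table once per frame. A minimal sketch of an alternative, assuming the same `video_frame`, `bodypart`, `x`, `y` columns and a `part_to_color` dictionary like the one built above: group the rows by frame once, store coordinates and colors as arrays, and let the callback only swap arrays in. The helpers `precompute_frames` and `make_update` are hypothetical names, and the demo builds a `matplotlib.figure.Figure` directly so it runs without a display.

```python
import numpy as np
import pandas as pd
from matplotlib.figure import Figure


def precompute_frames(df_sub, part_to_color, n_frames):
    """Return per-frame (offsets, colors) lists for frames 0..n_frames-1.

    Frames without rows get empty arrays, so the scatter is cleared
    instead of keeping the previous frame's points.
    """
    offsets = [np.empty((0, 2))] * n_frames  # entries are replaced, never mutated
    colors = [np.empty((0, 4))] * n_frames
    for frame, group in df_sub.groupby("video_frame"):
        if 0 <= frame < n_frames:
            offsets[frame] = group[["x", "y"]].to_numpy(dtype=float)
            colors[frame] = np.array([part_to_color[p] for p in group["bodypart"]])
    return offsets, colors


def make_update(scatter, title, frame_offsets, frame_colors, video_id):
    """Build a FuncAnimation callback that only swaps precomputed arrays."""
    def update(frame_num):
        scatter.set_offsets(frame_offsets[frame_num])
        scatter.set_color(frame_colors[frame_num])
        title.set_text(f"Mice Pose - video_id: {video_id} (Frame: {frame_num})")
        return scatter, title
    return update


# --- Demo on made-up data (frame 1 intentionally has no rows) ---
demo = pd.DataFrame({
    "video_frame": [0, 0, 2],
    "bodypart": ["nose", "tail_base", "nose"],
    "x": [10.0, 40.0, 15.0],
    "y": [20.0, 25.0, 22.0],
})
part_to_color = {"nose": np.array([1.0, 0.0, 0.0, 1.0]),
                 "tail_base": np.array([0.0, 0.0, 1.0, 1.0])}
frame_offsets, frame_colors = precompute_frames(demo, part_to_color, n_frames=3)

fig = Figure()
ax = fig.add_subplot()
scatter = ax.scatter([], [], s=20, alpha=0.8)
title = ax.set_title("")
update = make_update(scatter, title, frame_offsets, frame_colors, video_id="demo")
update(1)  # frame 1 has no rows, so the scatter is emptied
```

In the notebook, the callback could be built as `make_update(scatter, ax.title, *precompute_frames(df_sub, part_to_color, N_FRAMES_TO_ANIMATE), VIDEO_ID)` and passed to `FuncAnimation` in place of `update`. The per-frame work then no longer grows with the size of `df_sub`.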
