# Olfactrack: 2D Tracking Pipeline for Olfactometer Videos

## Overview

This Python pipeline enables 2D tracking of bumblebees in olfactometer videos. The notebook processes videos through five sequential stages, each handling a specific part of the workflow.

### Workflow Stages

1. **Import & Setup**: Import required libraries and configure the environment
2. **ROI Selection**: Define the Y-tube structure by selecting four key points
3. **Tracking**: Track bumblebee movement through the Y-tube apparatus
4. **Visualization**: Generate plots and animations of tracking results
5. **Analysis**: Calculate behavioral metrics and generate statistics

### Getting Started

1. Ensure all required libraries are installed (see dependencies below)
2. Configure the parameters at the beginning of each section
3. Run the cells in order from top to bottom
4. When prompted, mark the Y-tube structure by clicking on four points in the middle frame of each video

### Dependencies

This notebook requires the following libraries:
- OpenCV (`opencv-python`)
- NumPy
- Pandas
- Matplotlib
- Seaborn

You can install them, along with Jupyter, using:

    pip install opencv-python numpy pandas matplotlib seaborn jupyter

### Outputs

The notebook generates several files during processing:
- `selected_points.csv`: Contains ROI coordinates for each video
- `processed_<video_name>.csv`: Tracking data with X,Y coordinates and behavioral metrics
- `results.csv`: Summary file with tracking metrics for all videos
- Various visualization plots and animations when requested
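
Because each point is written to `selected_points.csv` as a text tuple such as `(400, 700)`, reading the file back means parsing those strings. The following is a minimal sketch of that round trip using only the standard library. The file name `example_video.mp4` and all coordinates are made-up illustration values, and `parse_points` is a hypothetical helper, not part of the notebook.

```python
import ast
import csv
import io

# Hypothetical file content in the column layout written by the ROI selection step.
csv_text = (
    'Video File,"Point 1 (x, y)","Point 2 (x, y)","Point 3 (x, y)","Point 4 (x, y)"\n'
    'example_video.mp4,"(400, 700)","(400, 400)","(200, 150)","(600, 150)"\n'
)

def parse_points(row):
    """Convert the four '(x, y)' strings of one CSV row into integer tuples."""
    return [ast.literal_eval(row[f"Point {i} (x, y)"]) for i in range(1, 5)]

rows = list(csv.DictReader(io.StringIO(csv_text)))
points_by_video = {row["Video File"]: parse_points(row) for row in rows}
print(points_by_video["example_video.mp4"])
# [(400, 700), (400, 400), (200, 150), (600, 150)]
```

`ast.literal_eval` only evaluates Python literals, so it is a safer choice than `eval` for text read from a file.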

## 1. Import & Setup

This section imports all required libraries and configures the environment for the tracking pipeline. Run this cell first: every later section relies on these imports.

### Required Libraries:
- **OpenCV**: Video processing and computer vision
- **NumPy & Pandas**: Data processing and analysis
- **Matplotlib & Seaborn**: Data visualization
- **`concurrent.futures`**: Parallel processing of videos

In [None]:
# Import necessary libraries
import cv2                 # OpenCV for video processing and computer vision
import os                  # Operating system utilities
import pandas as pd        # Data analysis and manipulation
import numpy as np         # Numerical computing
import ast                 # Safe parsing of "(x, y)" strings via ast.literal_eval
import concurrent.futures  # Parallel processing
from concurrent.futures import ThreadPoolExecutor
import time                # Time utilities
import matplotlib.pyplot as plt  # Plotting
import seaborn as sns      # Enhanced visualization
import re                  # Regular expressions

# Jupyter-specific settings
%matplotlib inline
plt.rcParams['figure.max_open_warning'] = 50  # Avoid warnings when displaying many figures

## 2. ROI Selection

This section allows you to select Regions of Interest (ROIs) by manually clicking on specific points in each video. For each olfactometer video, you will need to mark four key points that define the Y-tube structure:

1. **Point 1**: Bottom entrance point (where the bumblebee enters the Y-tube)
2. **Point 2**: Y-junction center point (where the tube branches)
3. **Point 3**: Left arm endpoint (tip of the left branch)
4. **Point 4**: Right arm endpoint (tip of the right branch)

The notebook will show you the middle frame of each video and prompt you to click on these points. The coordinates are then saved to a CSV file for later use in the tracking phase.

### Process Flow:
1. Load each video file
2. Display the middle frame
3. User clicks to mark the 4 key points
4. Coordinates are saved to CSV
5. Process repeats for each video in the directory

### When selecting points:
- Click precisely on the key locations of the Y-tube
- Points are numbered as you click them
- Maintain consistent selection order across all videos
- Press ESC to skip the current video; a video with fewer than four points is not written to the CSV
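
Since the later stages label the arms "left" and "right" purely by click order, a quick sanity check on the saved points can catch a swapped selection. The sketch below is one possible check, not something the notebook does. It assumes image coordinates (y grows downward), and the helper name `arm_side` and the sample coordinates are hypothetical.

```python
def arm_side(entrance, junction, arm_end):
    """Classify an arm endpoint as 'left' or 'right' relative to the direction
    of travel from the entrance to the junction.

    Uses the sign of the 2D cross product in image coordinates, where y grows
    downward, so a positive value means the arm lies to the right.
    """
    dx, dy = junction[0] - entrance[0], junction[1] - entrance[1]
    ax, ay = arm_end[0] - junction[0], arm_end[1] - junction[1]
    cross = dx * ay - dy * ax
    if cross == 0:
        return "collinear"
    return "right" if cross > 0 else "left"

# Hypothetical selection: entrance at the bottom of the frame, arms at the top.
p1, p2, p3, p4 = (400, 700), (400, 400), (200, 150), (600, 150)
order_ok = arm_side(p1, p2, p3) == "left" and arm_side(p1, p2, p4) == "right"
print(order_ok)  # True
```

With the entrance at the bottom of the frame, the direction of travel points up the screen, so "left" here coincides with the left side of the image. For other camera orientations the result follows the direction of travel instead.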

In [None]:
# ===== ROI SELECTION CONFIGURATION =====
# These parameters control the region of interest selection process

# File settings
VIDEO_EXTENSIONS = ['.mp4']    # File extensions to process
OUTPUT_CSV = 'selected_points.csv'  # Name of the output CSV file

# UI settings
MARKER_SIZE = 5                # Size of the marker circle when selecting points
MARKER_COLOR = (0, 255, 0)     # Color of the marker (BGR format)
TEXT_COLOR = (0, 0, 255)       # Color of the text labels (BGR format)
FONT_SIZE = 0.7                # Size of the text labels
FONT_THICKNESS = 2             # Thickness of the text labels
WINDOW_NAME = "Select 4 points (click to mark)"  # Window title
# =======================================

import cv2
import os
import pandas as pd

def select_points_from_videos(video_directory):
    """
    Process all video files in a directory and allow users to select 4 points from each video's middle frame.
    Points are saved to a CSV file in the same directory.
    
    Args:
        video_directory (str): Path to the directory containing video files
    
    Returns:
        str: Path to the generated CSV file
    """
    # Path to save the CSV file in the same folder as videos
    csv_file_path = os.path.join(video_directory, OUTPUT_CSV)
    
    # List to store data for the DataFrame
    data_selected_points = []
    
    # List all video files in the directory (case-insensitive match on VIDEO_EXTENSIONS)
    video_files = [f for f in os.listdir(video_directory) if f.lower().endswith(tuple(VIDEO_EXTENSIONS))]
    print(f"Found {len(video_files)} video files to process")
    
    for video_file in video_files:
        points = process_single_video(video_directory, video_file)
        
        if points and len(points) == 4:
            # Store the points and video name in the data list
            data_selected_points.append({
                'Video File': video_file,
                'Point 1 (x, y)': points[0],
                'Point 2 (x, y)': points[1],
                'Point 3 (x, y)': points[2],
                'Point 4 (x, y)': points[3]
            })
            print(f"Selected points for {video_file}: {points}")
    
    # Create a DataFrame from the collected data
    df_selected_points = pd.DataFrame(data_selected_points)
    
    # Save the DataFrame to a CSV file
    df_selected_points.to_csv(csv_file_path, index=False)
    print(f"Process finished. All points have been saved to {csv_file_path}.")
    
    return csv_file_path

def process_single_video(video_directory, video_filename):
    """
    Process a single video file to select 4 points from its middle frame.
    
    Args:
        video_directory (str): Directory containing the video file
        video_filename (str): Name of the video file to process
    
    Returns:
        list: List of 4 (x, y) coordinate tuples if successful, empty list otherwise
    """
    video_path = os.path.join(video_directory, video_filename)
    print(f"Processing video: {video_filename}")
    
    # Initialize points list for this video
    points = []
    
    # Open the video
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        print(f"Error: Cannot open video {video_filename}.")
        return []
    
    # Get total frames and compute the middle frame index
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    mid_frame_idx = total_frames // 2
    
    # Set the video position to the middle frame
    cap.set(cv2.CAP_PROP_POS_FRAMES, mid_frame_idx)
    
    # Read the middle frame
    ret, frame = cap.read()
    if not ret:
        print(f"Error: Could not read the middle frame of {video_filename}.")
        cap.release()
        return []
    
    # Display the middle frame for point selection
    cv2.imshow(WINDOW_NAME, frame)
    
    # Define the mouse callback function to capture points
    def click_event(event, x, y, flags, param):
        if event == cv2.EVENT_LBUTTONDOWN:
            # Append the clicked point (x, y) to the points list
            points.append((x, y))
            print(f"Point {len(points)} selected: ({x}, {y})")
            
            # Draw a small circle at the clicked point
            cv2.circle(param, (x, y), MARKER_SIZE, MARKER_COLOR, -1)
            # Add point number text
            cv2.putText(param, str(len(points)), (x+10, y-10), 
                       cv2.FONT_HERSHEY_SIMPLEX, FONT_SIZE, TEXT_COLOR, FONT_THICKNESS)
            cv2.imshow(WINDOW_NAME, param)
            
            # Stop after 4 points
            if len(points) == 4:
                print("4 points selected. Continuing to next video.")
                cv2.destroyAllWindows()
    
    # Set the mouse callback function to capture points
    cv2.setMouseCallback(WINDOW_NAME, click_event, param=frame)
    
    # Wait until 4 points are selected
    while len(points) < 4 and cv2.getWindowProperty(WINDOW_NAME, cv2.WND_PROP_VISIBLE) >= 1:
        key = cv2.waitKey(100)
        # Allow user to exit with ESC key
        if key == 27:  # ESC key
            print("Selection canceled by user")
            break
    
    cap.release()
    cv2.destroyAllWindows()
    
    # Return an empty list for incomplete selections, as documented above
    return points if len(points) == 4 else []

### ROI Selection Execution

Execute the ROI selection process for all videos in the specified directory. The results will be saved to `selected_points.csv`.

**Instructions:**
1. Set the video directory path in the cell below
2. Run the cell to start the ROI selection process
3. For each video, click on the 4 key points in the following order:
   - Point 1: Bottom entrance point
   - Point 2: Y-junction center point
   - Point 3: Left arm endpoint
   - Point 4: Right arm endpoint
4. Press ESC to cancel the selection for the current video

In [None]:
# ===== ROI SELECTION EXECUTION =====
# Set the directory path containing your olfactometer videos
VIDEO_DIRECTORY = r'C:\Users\YourUsername\Videos'  # Change this path to your video folder
# ==================================

if __name__ == "__main__":
    # Define the directory containing the video files
    video_dir = VIDEO_DIRECTORY
    
    # Uncomment the block below to use a directory selection dialog instead of hardcoded path
    """
    try:
        import tkinter as tk
        from tkinter import filedialog
        
        # Create and hide the Tkinter root window
        root = tk.Tk()
        root.withdraw()
        
        # Show directory selection dialog
        selected_dir = filedialog.askdirectory(title="Select directory containing videos")
        
        # Use the selected directory if one was chosen
        if selected_dir:
            video_dir = selected_dir
            print(f"Using selected directory: {video_dir}")
    except ImportError:
        print("Tkinter not available - using configured directory path")
    
    """
    if os.path.isdir(video_dir):
        select_points_from_videos(video_dir)
    else:
        print(f"Error: Directory '{video_dir}' not found or not accessible.")

## 3. Tracking

This section implements the core tracking functionality to detect and follow bumblebees as they move through the Y-tube. The tracking algorithm uses computer vision techniques including background subtraction, contour detection, and trajectory analysis.

### Tracking Process:

1. **ROI Loading**: Read the Y-tube coordinates from the CSV file
2. **Preprocessing**: 
   - Create masks for each Y-tube segment (bottom, left, right)
   - Apply background subtraction to isolate moving objects
   - Use morphological operations to clean up the image
3. **Detection**:
   - Find contours in the binary mask
   - Filter contours by count, area, and distance between centroids
   - Compute centroid of detected bumblebee
4. **Tracking**:
   - Record X,Y coordinates frame by frame
   - Project points onto Y-tube segments
   - Clean and interpolate trajectory data
5. **Metrics Calculation**:
   - Compute velocity and displacement
   - Calculate segment proportions
   - Generate additional behavioral metrics
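
The exact metric definitions are not shown in this part of the notebook. As a rough illustration of the displacement and velocity step, the sketch below computes per-frame step lengths and speeds from a short list of positions. The frame rate of 30 fps and the positions are assumptions chosen for the example.

```python
import math

FPS = 30.0  # assumed frame rate; the real value should come from the video metadata

# Hypothetical projected positions in pixels, one per consecutive frame.
positions = [(400, 700), (400, 697), (400, 693), (397, 689)]

# Per-frame displacement: Euclidean distance between consecutive positions.
step_px = [math.dist(a, b) for a, b in zip(positions, positions[1:])]

# Speed in pixels per second: displacement per frame times frames per second.
speed_px_s = [d * FPS for d in step_px]

# Total path length travelled over the whole list.
total_path_px = sum(step_px)

print(step_px, speed_px_s, total_path_px)
```

Converting pixels to physical units would need a scale factor, for example derived from the known length of a tube arm. That factor is not defined in this section.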

### Key Parameters:
- Detection thresholds for size and movement
- Background subtraction sensitivity
- Trajectory cleaning parameters
- Y-tube segment definitions
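
To make the role of the detection thresholds concrete, the sketch below applies the three frame-level filters (contour count, area of the largest contour, distance between centroids) to plain Python data instead of OpenCV contours. The function name `accept_frame` and the blob values are hypothetical. The thresholds mirror the defaults in the configuration cell that follows.

```python
import math

MAX_CONTOURS = 2     # reject frames with more candidate blobs than this
MIN_AREA = 30        # smallest plausible area of the largest blob (pixels^2)
MAX_AREA = 3000      # largest plausible area of the largest blob (pixels^2)
MAX_DISTANCE = 200   # blobs further apart than this are treated as noise

def accept_frame(blobs):
    """Return the centroid of the largest blob, or None if the frame is rejected.

    blobs is a list of (area, (cx, cy)) pairs, one per detected contour.
    """
    if not blobs or len(blobs) > MAX_CONTOURS:
        return None
    largest_area, largest_centroid = max(blobs, key=lambda blob: blob[0])
    if not MIN_AREA <= largest_area <= MAX_AREA:
        return None
    for i, (_, first) in enumerate(blobs):
        for _, second in blobs[i + 1:]:
            if math.dist(first, second) > MAX_DISTANCE:
                return None
    return largest_centroid

print(accept_frame([(500, (100, 100))]))                     # (100, 100)
print(accept_frame([(500, (100, 100)), (60, (120, 110))]))   # (100, 100)
print(accept_frame([(500, (100, 100)), (60, (600, 100))]))   # None
print(accept_frame([(10, (100, 100))]))                      # None
```

A rejected frame is recorded without coordinates, which leaves a gap for the cleaning and interpolation steps to handle later.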

In [None]:
# ===== TRACKING CONFIGURATION =====
# These parameters control the bumblebee detection and tracking process

# Background subtraction parameters
HISTORY = 500                      # History length for background subtractor
VAR_THRESHOLD = 200                # Threshold for background subtractor
DETECT_SHADOWS = True              # Whether to detect shadows
BLUR_SIZE = 5                      # Size of median blur kernel

# Detection parameters
MAX_CONTOURS = 2                   # Maximum number of contours allowed
MIN_AREA = 30                      # Minimum contour area (pixels²)
MAX_AREA = 3000                    # Maximum contour area (pixels²)
MAX_DISTANCE = 200                 # Maximum distance between multiple contours

# Y-tube structure parameters
TUBE_THICKNESS = 60                # Half-width of each arm mask in pixels (mask spans 2x this value)

# Trajectory cleaning parameters
CLEAN_WINDOW_SIZE = 50             # Window size for cleaning coordinates
CLEAN_MIN_VALID = 10               # Minimum valid points in window
CLEAN_DISTANCE_THRESHOLD = 30      # Maximum allowed displacement between frames
MAX_ITERATIONS = 5                 # Maximum iterations for trajectory denoising
# ===================================

import os
import pandas as pd
import ast
import numpy as np
import cv2
import matplotlib.pyplot as plt
import concurrent.futures
from concurrent.futures import ThreadPoolExecutor
import time

def extract_coordinates(row):
    """
    Extract and convert coordinates from a DataFrame row to tuples.
    
    Args:
        row (pd.DataFrame): DataFrame row containing the coordinate points.
        
    Returns:
        tuple: Four coordinate tuples (point_1, point_2, point_3, point_4) or (None, None, None, None) if error.
    """
    try:
        point_1 = ast.literal_eval(row['Point 1 (x, y)'].values[0])
        point_2 = ast.literal_eval(row['Point 2 (x, y)'].values[0])
        point_3 = ast.literal_eval(row['Point 3 (x, y)'].values[0])
        point_4 = ast.literal_eval(row['Point 4 (x, y)'].values[0])
        return point_1, point_2, point_3, point_4
    except (ValueError, SyntaxError, TypeError, KeyError, IndexError) as e:
        print(f"Error parsing coordinates: {e}")
        return None, None, None, None

def clip_point(point, frame_width, frame_height):
    """
    Clip a point to ensure it lies within the frame boundaries.
    
    Args:
        point (tuple): (x, y) coordinates of the point.
        frame_width (int): Width of the frame.
        frame_height (int): Height of the frame.
        
    Returns:
        tuple: Clipped (x, y) coordinates.
    """
    x, y = point
    x = max(0, min(frame_width - 1, x))
    y = max(0, min(frame_height - 1, y))
    return x, y

def create_thick_segment(center, end, thickness, frame_width, frame_height):
    """
    Create a thick segment as a polygon around a line segment.
    
    Args:
        center (tuple): (x, y) coordinates of the segment's start point (the Y-junction center).
        end (tuple): (x, y) coordinates of the segment's end point.
        thickness (int): Half-width of the segment in pixels; the polygon extends this far on each side of the line.
        frame_width (int): Width of the frame.
        frame_height (int): Height of the frame.
        
    Returns:
        np.ndarray: Array of polygon points defining the thick segment.
    """
    dx, dy = end[0] - center[0], end[1] - center[1]
    
    # Handle the case where center and end are the same point
    if dx == 0 and dy == 0:
        return np.array([center, center, center, center], dtype=np.int32)
        
    length = np.sqrt(dx**2 + dy**2)
    dx, dy = dx / length, dy / length
    perp_x, perp_y = -dy, dx
    
    # Create four corners of the polygon
    p1 = clip_point((int(center[0] + perp_x * thickness), int(center[1] + perp_y * thickness)), frame_width, frame_height)
    p2 = clip_point((int(center[0] - perp_x * thickness), int(center[1] - perp_y * thickness)), frame_width, frame_height)
    p3 = clip_point((int(end[0] - perp_x * thickness), int(end[1] - perp_y * thickness)), frame_width, frame_height)
    p4 = clip_point((int(end[0] + perp_x * thickness), int(end[1] + perp_y * thickness)), frame_width, frame_height)
    
    return np.array([p1, p2, p3, p4], dtype=np.int32)

def clean_coordinates(coord_list, window_size=CLEAN_WINDOW_SIZE, min_valid=CLEAN_MIN_VALID):
    """
    Cleans the coordinates by removing isolated valid points surrounded by mostly None values.
    
    Args:
        coord_list (list): A list of coordinates [[frame, x, y], ...].
        window_size (int): Size of the window to check for density.
        min_valid (int): Minimum number of non-None points required to keep data.
        
    Returns:
        list: The cleaned list of coordinates.
    """
    half_window = window_size // 2
    coord_list = [list(entry) for entry in coord_list]  # Copy each entry so the caller's data is not modified
    
    for frame in range(half_window, len(coord_list) - half_window):
        # Extract the window around the current frame
        window = coord_list[frame - half_window:frame + half_window + 1]
        
        # Count valid (non-None) data points in the window
        valid_count = sum(1 for entry in window if entry[1] is not None and entry[2] is not None)
        
        # If valid data points are less than the threshold, discard them
        if valid_count < min_valid:
            coord_list[frame][1] = None
            coord_list[frame][2] = None
    
    return coord_list
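
To illustrate the sliding-window cleaning, here is a compact restatement of the same idea on a 12-frame synthetic track. It is a sketch for intuition rather than a replacement for `clean_coordinates`. The name `drop_sparse_detections`, the small window settings, and the data are assumptions made for the example.

```python
def drop_sparse_detections(track, window_size, min_valid):
    """Blank out detections whose surrounding window holds fewer than min_valid valid points."""
    half = window_size // 2
    track = [list(entry) for entry in track]  # copy entries so the input stays untouched
    for i in range(half, len(track) - half):
        window = track[i - half:i + half + 1]
        valid = sum(1 for _, x, y in window if x is not None and y is not None)
        if valid < min_valid:
            track[i][1] = track[i][2] = None
    return track

# Frames 0-4 form a dense run; frame 8 is an isolated detection surrounded by gaps.
valid_frames = {0, 1, 2, 3, 4, 8}
raw = [[f, 100 + f, 200] if f in valid_frames else [f, None, None] for f in range(12)]

cleaned = drop_sparse_detections(raw, window_size=4, min_valid=3)
kept = [f for f, x, y in cleaned if x is not None]
print(kept)  # [0, 1, 2, 3, 4]
```

The isolated detection at frame 8 is removed while the dense run survives. Frames within half a window of either end of the track are never examined, so they keep whatever value they had.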

In [None]:
def point_to_segment_distance(px, py, x1, y1, x2, y2):
    """
    Calculate the distance from a point to a line segment and the projection onto that segment.
    
    Args:
        px (float): x-coordinate of the point.
        py (float): y-coordinate of the point.
        x1 (float): x-coordinate of the segment's start.
        y1 (float): y-coordinate of the segment's start.
        x2 (float): x-coordinate of the segment's end.
        y2 (float): y-coordinate of the segment's end.
        
    Returns:
        tuple: (distance, proportion, (proj_x, proj_y)) where:
            - distance is the shortest distance from the point to the segment
            - proportion is the position along the segment (0 at start, 1 at end)
            - (proj_x, proj_y) are the coordinates of the projection point
    """
    dx, dy = x2 - x1, y2 - y1
    
    # Handle degenerate case where segment is a point
    if dx == 0 and dy == 0:
        return np.sqrt((px - x1)**2 + (py - y1)**2), 0, (x1, y1)
    
    dot = (px - x1) * dx + (py - y1) * dy
    length_sq = dx * dx + dy * dy
    proportion = dot / length_sq if length_sq != 0 else 0
    proportion = max(0, min(1, proportion))  # Clamp to [0,1]
    
    # Calculate projection point
    proj_x = x1 + proportion * dx
    proj_y = y1 + proportion * dy
    
    # Calculate distance
    distance = np.sqrt((px - proj_x)**2 + (py - proj_y)**2)
    
    return distance, proportion, (proj_x, proj_y)

def find_closest_segment_with_projection(row, px_col, py_col, segments):
    """
    Find the closest segment to a given point and calculate projection details.
    
    Args:
        row (pd.Series): A row of the DataFrame.
        px_col (str): Column name for x-coordinates.
        py_col (str): Column name for y-coordinates.
        segments (list): List of segments as ((x1, y1), (x2, y2), segment_id).
        
    Returns:
        tuple: ((distance, proportion, (proj_x, proj_y)), segment_id) of the closest segment.
    """
    # Skip calculation if coordinates are None/NaN
    if pd.isna(row[px_col]) or pd.isna(row[py_col]):
        # Return a default value with NaNs
        return ((np.nan, np.nan, (np.nan, np.nan)), "")
    
    px, py = row[px_col], row[py_col]
    
    # Calculate distances and projections for all segments
    distances_and_projections = [
        (point_to_segment_distance(px, py, x1, y1, x2, y2), seg_id)
        for (x1, y1), (x2, y2), seg_id in segments
    ]
    
    # Return the entry with minimum distance
    return min(distances_and_projections, key=lambda x: x[0][0])

def denoise_trajectory(df, distance_threshold, max_iterations=MAX_ITERATIONS):
    """
    Denoise a trajectory by removing and interpolating points with abnormal instantaneous displacement.
    
    Args:
        df (pd.DataFrame): DataFrame containing trajectory data.
        distance_threshold (float): Maximum allowed instantaneous displacement.
        max_iterations (int): Maximum number of denoising iterations.
        
    Returns:
        pd.DataFrame: Denoised DataFrame.
    """
    iteration = 0
    
    while iteration < max_iterations:
        # Calculate instantaneous distances
        df["Instantaneous_Distance"] = np.sqrt(
            np.diff(df["X_raw_projected"], prepend=np.nan) ** 2 +
            np.diff(df["Y_raw_projected"], prepend=np.nan) ** 2
        )

        # Identify points to remove (distance > threshold)
        outlier_indices = df[df["Instantaneous_Distance"] > distance_threshold].index

        # If no outliers, break the loop
        if outlier_indices.empty:
            break

        # Include surrounding points for removal (window of 5 points centered on outlier)
        indices_to_remove = set()
        for idx in outlier_indices:
            for offset in range(-2, 3):  # -2, -1, 0, 1, 2
                if 0 <= idx + offset < len(df):
                    indices_to_remove.add(idx + offset)

        # Remove the points by setting them to NaN
        df.loc[sorted(indices_to_remove), ["X_raw_projected", "Y_raw_projected"]] = np.nan

        # Interpolate missing values
        df["X_raw_projected"] = df["X_raw_projected"].interpolate(method="linear")
        df["Y_raw_projected"] = df["Y_raw_projected"].interpolate(method="linear")

        iteration += 1

    # Recalculate instantaneous distances after final cleanup
    df["Instantaneous_Distance"] = np.sqrt(
        np.diff(df["X_raw_projected"], prepend=np.nan) ** 2 +
        np.diff(df["Y_raw_projected"], prepend=np.nan) ** 2
    )
    
    return df
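
As a worked example of the projection step, the sketch below assigns one detection to the nearest of three Y-tube segments and reports how far along that segment it lies. It restates the geometry with the standard library only. The helper name `project_onto_segment` and all coordinates are hypothetical values chosen so the arithmetic is easy to follow.

```python
import math

def project_onto_segment(p, a, b):
    """Return (distance, proportion, projection) of point p relative to segment a -> b."""
    dx, dy = b[0] - a[0], b[1] - a[1]
    length_sq = dx * dx + dy * dy
    if length_sq == 0:  # degenerate segment: both ends coincide
        return math.dist(p, a), 0.0, a
    t = ((p[0] - a[0]) * dx + (p[1] - a[1]) * dy) / length_sq
    t = max(0.0, min(1.0, t))  # clamp so the projection stays on the segment
    projection = (a[0] + t * dx, a[1] + t * dy)
    return math.dist(p, projection), t, projection

junction = (400, 400)
segments = {
    "bottom": (junction, (400, 700)),
    "left": (junction, (200, 150)),
    "right": (junction, (600, 150)),
}
detection = (410, 550)  # 10 px to the side of the bottom arm, halfway along it

name, (distance, proportion, projection) = min(
    ((seg, project_onto_segment(detection, a, b)) for seg, (a, b) in segments.items()),
    key=lambda item: item[1][0],
)
print(name, distance, proportion, projection)  # bottom 10.0 0.5 (400.0, 550.0)
```

Because every segment starts at the junction, a proportion of 0 corresponds to the junction and 1 to the entrance or arm tip.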

In [None]:
def process_video(file_name, folder_path, df):
    """
    Process a video file to detect bumblebees and calculate their trajectories in the Y-tube.
    This function performs the following steps:
    1. Extracts Y-tube coordinates from metadata
    2. Creates a mask for the Y-tube branches
    3. Applies background subtraction to isolate moving objects
    4. Detects contours and filters them based on size and number
    5. Tracks the position of the bumblebee through the Y-tube
    
    Args:
        file_name (str): Name of the video file.
        folder_path (str): Path to the folder containing the video.
        df (pd.DataFrame): DataFrame containing point metadata and coordinates.
        
    Returns:
        tuple: (DataFrame with tracking results, list of segments)
    """
    # Filter the row for the current video
    row = df[df['Video File'] == file_name]
    if row.empty:
        print(f"No data found for {file_name}.")
        return pd.DataFrame(columns=["Frame", "X_raw", "Y_raw"]), []

    # Extract coordinates that define the Y-tube structure
    point_1, point_2, point_3, point_4 = extract_coordinates(row)
    if None in (point_1, point_2, point_3, point_4):
        return pd.DataFrame(columns=["Frame", "X_raw", "Y_raw"]), []

    x1, y1 = point_1  # Bottom point
    x2, y2 = point_2  # Center point
    x3, y3 = point_3  # Left point
    x4, y4 = point_4  # Right point

    # Define the three segments of interest (center-to-point, segment label)
    segments = [
        ((x2, y2), (x1, y1), "bottom"),
        ((x2, y2), (x3, y3), "left"),
        ((x2, y2), (x4, y4), "right")
    ]
    
    # Open the video
    cap = cv2.VideoCapture(os.path.join(folder_path, file_name))
    if not cap.isOpened():
        print(f"Error: Video file {file_name} could not be opened.")
        return pd.DataFrame(columns=["Frame", "X_raw", "Y_raw"]), []

    # Initialize background subtractor
    fgbg = cv2.createBackgroundSubtractorMOG2(history=HISTORY, varThreshold=VAR_THRESHOLD, detectShadows=DETECT_SHADOWS)
    coordinates = []

    # Read first frame to get dimensions
    ret, frame = cap.read()
    if not ret:
        print("Error: Unable to read the first frame.")
        cap.release()
        return pd.DataFrame(columns=["Frame", "X_raw", "Y_raw"]), []

    frame_height, frame_width = frame.shape[:2]

    # Create mask for the Y-tube structure
    mask = np.zeros((frame_height, frame_width), dtype=np.uint8)
    for center, end, segment_id in segments:
        poly = create_thick_segment(center, end, thickness=TUBE_THICKNESS, frame_width=frame_width, frame_height=frame_height)
        cv2.fillPoly(mask, [poly], 255)

    # Reset to beginning of video
    cap.set(cv2.CAP_PROP_POS_FRAMES, 0)
    
    # Process each frame
    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            break

        # CAP_PROP_POS_FRAMES is the index of the next frame to decode, so numbering starts at 1
        frame_number = int(cap.get(cv2.CAP_PROP_POS_FRAMES))

        # Apply background subtraction
        fgmask = fgbg.apply(frame)
        fgmask = cv2.medianBlur(fgmask, BLUR_SIZE)

        # Apply morphological operations to clean up the mask
        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
        fgmask = cv2.morphologyEx(fgmask, cv2.MORPH_OPEN, kernel)
        fgmask = cv2.morphologyEx(fgmask, cv2.MORPH_CLOSE, kernel)

        # Apply the Y-tube mask to focus only on our region of interest
        fgmask = cv2.bitwise_and(fgmask, mask)
        
        # Apply median blur again to smooth the result
        fgmask = cv2.medianBlur(fgmask, BLUR_SIZE)

        # Threshold to create binary mask
        _, binary_mask = cv2.threshold(fgmask, 50, 255, cv2.THRESH_BINARY)

        # Find contours in the binary mask
        contours, _ = cv2.findContours(binary_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

        # Filter 1: Based on number of contours
        if len(contours) > MAX_CONTOURS:
            coordinates.append([frame_number, None, None])
            continue

        # Filter 2: Based on contour area
        if contours:
            largest_contour_area = max(cv2.contourArea(c) for c in contours)
            if largest_contour_area > MAX_AREA or largest_contour_area < MIN_AREA:
                coordinates.append([frame_number, None, None])
                continue

        # Filter 3: Based on distance between contour centroids
        contour_centroids = []
        for contour in contours:
            M = cv2.moments(contour)
            if M["m00"] != 0:
                cX = int(M["m10"] / M["m00"])
                cY = int(M["m01"] / M["m00"])
                contour_centroids.append((cX, cY))

        # Check if multiple centroids are too far apart
        too_far_apart = False
        # Use distinct names so the Y-tube points (x1, y1, x2, y2) are not overwritten
        for i, (ax, ay) in enumerate(contour_centroids):
            for j, (bx, by) in enumerate(contour_centroids):
                if i != j:
                    dist = np.sqrt((bx - ax)**2 + (by - ay)**2)
                    if dist > MAX_DISTANCE:
                        too_far_apart = True
                        break
            if too_far_apart:
                break

        if too_far_apart:
            coordinates.append([frame_number, None, None])
            continue

        # If all filters pass, use the largest contour centroid
        black_point = None
        if contours:
            sorted_contours = sorted(contours, key=cv2.contourArea, reverse=True)
            largest_contour = sorted_contours[0]
            M = cv2.moments(largest_contour)
            if M["m00"] != 0:
                cX = int(M["m10"] / M["m00"])
                cY = int(M["m01"] / M["m00"])
                black_point = (cX, cY)

        # Record coordinates (or None if no valid point found)
        coordinates.append([frame_number, None if black_point is None else black_point[0], 
                           None if black_point is None else black_point[1]])

    # Release video resources
    cap.release()

    # Clean coordinates to remove isolated detections
    cleaned_coordinates = clean_coordinates(coordinates)

    # Convert to DataFrame
    return pd.DataFrame(cleaned_coordinates, columns=["Frame", "X_raw", "Y_raw"]), segments
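
Since `process_video` handles one file per call and creates its own background subtractor, several videos can in principle be dispatched through the `ThreadPoolExecutor` imported earlier. The following is a minimal sketch of that pattern, not the notebook's actual driver code. `track_one` is a hypothetical stand-in for the real tracking call, and the file names are placeholders.

```python
from concurrent.futures import ThreadPoolExecutor, as_completed

def track_one(file_name):
    """Stand-in for a per-video tracking call; returns (file_name, dummy_result)."""
    return file_name, len(file_name)

video_files = ["a.mp4", "bb.mp4", "ccc.mp4"]  # hypothetical file names
results = {}

with ThreadPoolExecutor(max_workers=2) as executor:
    # Map each future back to its file so failures can be reported per video.
    futures = {executor.submit(track_one, name): name for name in video_files}
    for future in as_completed(futures):
        name = futures[future]
        try:
            results[name] = future.result()
        except Exception as exc:  # keep going if a single video fails
            print(f"{name} failed: {exc}")

print(sorted(results))  # ['a.mp4', 'bb.mp4', 'ccc.mp4']
```

Results arrive in completion order rather than submission order, which is why they are stored by file name. The benefit from threads depends on how much of the per-video time is spent in native library calls, so it is worth timing a small batch before relying on it.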

In [None]:
def visualize_background_subtraction(file_name, folder_path, df):
    """
    Visualize the background subtraction and object detection process in real-time.
    Shows the binary mask resulting from background subtraction and morphological operations.
    This helps in understanding how well the object (bumblebee) is being detected.
    
    Args:
        file_name (str): Name of the video file.
        folder_path (str): Path to the folder containing the video.
        df (pd.DataFrame): DataFrame containing point metadata.
        
    Returns:
        tuple: Empty DataFrame and empty list (only for consistency with process_video).
    """
    # Filter the row for the current video
    row = df[df['Video File'] == file_name]
    if row.empty:
        print(f"No data found for {file_name}.")
        return pd.DataFrame(columns=["Frame", "X_raw", "Y_raw"]), []

    # Extract coordinates
    point_1, point_2, point_3, point_4 = extract_coordinates(row)
    if None in (point_1, point_2, point_3, point_4):
        return pd.DataFrame(columns=["Frame", "X_raw", "Y_raw"]), []

    # Define segments
    x1, y1 = point_1
    x2, y2 = point_2
    x3, y3 = point_3
    x4, y4 = point_4

    segments = [
        ((x2, y2), (x1, y1), "bottom"),
        ((x2, y2), (x3, y3), "left"),
        ((x2, y2), (x4, y4), "right")
    ]
    
    # Open the video
    cap = cv2.VideoCapture(os.path.join(folder_path, file_name))
    if not cap.isOpened():
        print(f"Error: Video file {file_name} could not be opened.")
        return pd.DataFrame(columns=["Frame", "X_raw", "Y_raw"]), []

    # Initialize background subtractor
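    # Use the same settings as process_video so the preview matches the tracking run
    fgbg = cv2.createBackgroundSubtractorMOG2(history=HISTORY, varThreshold=VAR_THRESHOLD, detectShadows=DETECT_SHADOWS)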

    # Read first frame to get dimensions
    ret, frame = cap.read()
    if not ret:
        print("Error: Unable to read the first frame.")
        cap.release()
        return pd.DataFrame(columns=["Frame", "X_raw", "Y_raw"]), []

    frame_height, frame_width = frame.shape[:2]

    # Create mask for the Y-tube
    mask = np.zeros((frame_height, frame_width), dtype=np.uint8)
    for center, end, segment_id in segments:
        poly = create_thick_segment(center, end, thickness=TUBE_THICKNESS, frame_width=frame_width, frame_height=frame_height)
        cv2.fillPoly(mask, [poly], 255)

    # Reset to beginning of video
    cap.set(cv2.CAP_PROP_POS_FRAMES, 0)
    
    # Process each frame
    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            break

        # Apply background subtraction
        fgmask = fgbg.apply(frame)
        fgmask = cv2.medianBlur(fgmask, 5)

        # Apply morphological operations
        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
        fgmask = cv2.morphologyEx(fgmask, cv2.MORPH_OPEN, kernel)
        fgmask = cv2.morphologyEx(fgmask, cv2.MORPH_CLOSE, kernel)

        # Apply the mask
        masked_fgmask = cv2.bitwise_and(fgmask, mask)
        masked_fgmask = cv2.medianBlur(masked_fgmask, 5)

        # Create binary mask for visualization
        _, binary_mask = cv2.threshold(masked_fgmask, 50, 255, cv2.THRESH_BINARY)

        # Display the binary mask
        cv2.imshow('Background Subtraction Result', binary_mask)
        
        # Wait for 'q' key to quit
        if cv2.waitKey(1) & 0xFF == ord('q'):
            break

    # Release resources
    cap.release()
    cv2.destroyAllWindows()

    return pd.DataFrame(columns=["Frame", "X_raw", "Y_raw"]), []

In [None]:
def visualize_contour_detection(file_name, folder_path, df):
    """
    Visualize the contour detection process in real-time, showing detected objects.
    Draws the detected bumblebee contours, their centroids, and the Y-tube segments.
    This helps in debugging the tracking process and assessing the quality of contour detection.
    
    Args:
        file_name (str): Name of the video file.
        folder_path (str): Path to the folder containing the video.
        df (pd.DataFrame): DataFrame containing point metadata.
        
    Returns:
        tuple: Empty DataFrame and empty list (placeholders that match the return shape of process_video).
    """
    # Filter the row for the current video
    row = df[df['Video File'] == file_name]
    if row.empty:
        print(f"No data found for {file_name}.")
        return pd.DataFrame(columns=["Frame", "X_raw", "Y_raw"]), []

    # Extract coordinates
    point_1, point_2, point_3, point_4 = extract_coordinates(row)
    if None in (point_1, point_2, point_3, point_4):
        return pd.DataFrame(columns=["Frame", "X_raw", "Y_raw"]), []

    # Define segments
    x1, y1 = point_1
    x2, y2 = point_2
    x3, y3 = point_3
    x4, y4 = point_4

    segments = [
        ((x2, y2), (x1, y1), "bottom"),
        ((x2, y2), (x3, y3), "left"),
        ((x2, y2), (x4, y4), "right")
    ]
    
    # Open the video
    cap = cv2.VideoCapture(os.path.join(folder_path, file_name))
    if not cap.isOpened():
        print(f"Error: Video file {file_name} could not be opened.")
        return pd.DataFrame(columns=["Frame", "X_raw", "Y_raw"]), []

    # Initialize background subtractor
    fgbg = cv2.createBackgroundSubtractorMOG2(history=300, varThreshold=200, detectShadows=True)

    # Read first frame to get dimensions
    ret, frame = cap.read()
    if not ret:
        print("Error: Unable to read the first frame.")
        cap.release()
        return pd.DataFrame(columns=["Frame", "X_raw", "Y_raw"]), []

    frame_height, frame_width = frame.shape[:2]

    # Create mask for the Y-tube
    mask = np.zeros((frame_height, frame_width), dtype=np.uint8)
    for center, end, segment_id in segments:
        poly = create_thick_segment(center, end, thickness=TUBE_THICKNESS, frame_width=frame_width, frame_height=frame_height)
        cv2.fillPoly(mask, [poly], 255)

    # Reset to beginning of video
    cap.set(cv2.CAP_PROP_POS_FRAMES, 0)
    
    # Process each frame
    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            break

        # Apply background subtraction
        fgmask = fgbg.apply(frame)
        fgmask = cv2.medianBlur(fgmask, 5)

        # Apply morphological operations
        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
        fgmask = cv2.morphologyEx(fgmask, cv2.MORPH_OPEN, kernel)
        fgmask = cv2.morphologyEx(fgmask, cv2.MORPH_CLOSE, kernel)

        # Apply the mask
        masked_fgmask = cv2.bitwise_and(fgmask, mask)
        masked_fgmask = cv2.medianBlur(masked_fgmask, 5)

        # Create binary mask
        _, binary_mask = cv2.threshold(masked_fgmask, 50, 255, cv2.THRESH_BINARY)

        # Find contours in the binary mask
        contours, _ = cv2.findContours(binary_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

        # Create a visualization frame
        vis_frame = frame.copy()
        
        # Draw Y-tube segments
        for center, end, segment_id in segments:
            cv2.line(vis_frame, center, end, (0, 255, 0), 2)
            cv2.putText(vis_frame, segment_id, end, cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)

        # Draw detected contours with random colors
        for contour in contours:
            # Filter out tiny contours
            if cv2.contourArea(contour) > MIN_AREA:
                # Random color in BGR format, converted to plain Python ints
                # (OpenCV drawing functions may reject NumPy integer scalars as colors)
                color = tuple(int(c) for c in np.random.randint(0, 256, size=3))
                
                # Draw contour
                cv2.drawContours(vis_frame, [contour], -1, color, 2)
                
                # Draw centroid
                M = cv2.moments(contour)
                if M["m00"] != 0:
                    cX = int(M["m10"] / M["m00"])
                    cY = int(M["m01"] / M["m00"])
                    cv2.circle(vis_frame, (cX, cY), 5, color, -1)
                    cv2.putText(vis_frame, f"Area: {cv2.contourArea(contour):.1f}", (cX + 10, cY), 
                                cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)

        # Display the result
        cv2.imshow('Contour Detection', vis_frame)

        # Wait for 'q' key to quit
        if cv2.waitKey(1) & 0xFF == ord('q'):
            break

    # Release resources
    cap.release()
    cv2.destroyAllWindows()

    return pd.DataFrame(columns=["Frame", "X_raw", "Y_raw"]), []
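
The centroid step in the function above divides the first-order moments by the zeroth-order moment. The sketch below illustrates the same arithmetic with NumPy on a toy binary mask. It is an illustration only: `cv2.moments` applied to a contour integrates over the contour polygon, so its values can differ slightly from this pixel-count version. The mask and all variable names here are made up for the example.

```python
import numpy as np

# Toy binary mask with one rectangular blob: rows 2-3, columns 4-7.
mask = np.zeros((6, 10), dtype=np.uint8)
mask[2:4, 4:8] = 1

ys, xs = np.nonzero(mask)   # pixel coordinates of the blob

m00 = len(xs)               # zeroth moment: blob area in pixels
m10 = xs.sum()              # first moment along x
m01 = ys.sum()              # first moment along y

if m00 != 0:                # same guard against empty blobs as above
    cX = m10 / m00
    cY = m01 / m00
    print(cX, cY)
```

The `m00 != 0` guard matters because a degenerate contour can have zero area, in which case the division would fail.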

In [None]:

def process_video_parallel(file_name, folder_path, df_segments):
    """
    Process a video file with complete tracking workflow including interpolation and trajectory analysis.
    
    Args:
        file_name (str): Name of the video file.
        folder_path (str): Path to the folder containing the video.
        df_segments (pd.DataFrame): DataFrame containing point metadata.
        
    Returns:
        tuple: (file_name, output_path) or None if processing failed.
    """
    video_path = os.path.join(folder_path, file_name)
    print(f"Processing {file_name} in process {os.getpid()}")

    # Validate file existence and permissions
    if not os.path.exists(video_path):
        print(f"Error: File not found - {video_path}")
        return None
        
    if not os.access(video_path, os.R_OK):
        print(f"Error: No read permission for - {video_path}")
        return None

    # Try opening the video file
    video = cv2.VideoCapture(video_path)
    if not video.isOpened():
        print(f"Error: Could not open video file - {video_path}")
        return None

    print(f"Successfully opened video: {file_name}")
    video.release()  # Release video here, it will be reopened in process_video
    
    # ========== TRACKING PHASE ==========
    # Call tracking function
    df, segments = process_video(file_name, folder_path, df_segments)
    print(f"Basic tracking completed for {file_name}")
    
    # Check if tracking produced valid results
    if df.empty or 'X_raw' not in df.columns:
        print(f"Error: Tracking failed for {file_name}")
        return None
    
    # ========== INTERPOLATION PHASE ==========
    # Fill missing values using linear interpolation
    df["X_raw_filled"] = df["X_raw"].interpolate(method="linear", limit_direction="both")
    df["Y_raw_filled"] = df["Y_raw"].interpolate(method="linear", limit_direction="both")

    # Check if interpolation succeeded
    if df['X_raw_filled'].isna().any() or df['Y_raw_filled'].isna().any():
        print(f"Error: Interpolation failed for {file_name} - still have NaN values")
        return None

    # ========== PROJECTION PHASE ==========
    # Project points onto segments
    df['intermediate'] = df.apply(find_closest_segment_with_projection, 
                                axis=1, 
                                args=('X_raw_filled', 'Y_raw_filled', segments))
    
    # Extract projection information
    df['X_raw_projected'] = df['intermediate'].apply(lambda x: x[0][2][0])
    df['Y_raw_projected'] = df['intermediate'].apply(lambda x: x[0][2][1])
    
    # Remove intermediate column to keep dataframe clean
    df.drop(columns=['intermediate'], inplace=True)

    # ========== DENOISING PHASE ==========
    # Clean the trajectory by removing erratic movements
    df = denoise_trajectory(df, distance_threshold=30, max_iterations=3)
    
    # Project points again after denoising for more accurate segment assignment
    df['intermediate'] = df.apply(find_closest_segment_with_projection, 
                                axis=1, 
                                args=('X_raw_projected', 'Y_raw_projected', segments))
    
    # Extract final projection information
    df['X_projected'] = df['intermediate'].apply(lambda x: x[0][2][0])
    df['Y_projected'] = df['intermediate'].apply(lambda x: x[0][2][1])
    df['proportion'] = df['intermediate'].apply(lambda x: x[0][1])
    df['closest_segment'] = df['intermediate'].apply(lambda x: x[1])

    # Remove intermediate column
    df.drop(columns=['intermediate'], inplace=True)
    
    # ========== METRICS CALCULATION PHASE ==========
    # Compute displacement metrics
    df["delta_X"] = df["X_projected"].diff()
    df["delta_Y"] = df["Y_projected"].diff()
    
    # Compute per-frame displacement (in pixels and in segment proportion)
    df["velocity"] = np.sqrt(df["delta_X"]**2 + df["delta_Y"]**2)
    df["velocity_prop"] = df["proportion"].diff().abs()
    
    # Compute cumulative distance
    df["cumulative_distance"] = df["velocity"].cumsum()

    # ========== OUTPUT PHASE ==========
    # Save processed DataFrame to CSV
    output_path = os.path.join(folder_path, f"processed_{file_name}.csv")
    df.to_csv(output_path, index=False)
    print(f"Processed trajectory data saved to {output_path}")

    return file_name, output_path  # Return processing result
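
The interpolation and metrics phases of the function above can be tried on a toy track. This is a minimal sketch that skips the projection and denoising phases; the column names `X_filled` and `Y_filled` and the data are invented for the example.

```python
import numpy as np
import pandas as pd

# Toy track: the bee is not detected on frames 1 and 2.
df = pd.DataFrame({
    "Frame": [0, 1, 2, 3],
    "X_raw": [0.0, np.nan, np.nan, 3.0],
    "Y_raw": [0.0, np.nan, np.nan, 4.0],
})

# Fill gaps linearly; limit_direction="both" also fills leading/trailing gaps.
df["X_filled"] = df["X_raw"].interpolate(method="linear", limit_direction="both")
df["Y_filled"] = df["Y_raw"].interpolate(method="linear", limit_direction="both")

# Per-frame displacement in pixels (the pipeline calls this "velocity").
df["velocity"] = np.sqrt(df["X_filled"].diff() ** 2 + df["Y_filled"].diff() ** 2)

# Running total of the displacement; the first row stays NaN because
# diff() has no previous frame to compare against.
df["cumulative_distance"] = df["velocity"].cumsum()

print(df[["Frame", "X_filled", "Y_filled", "velocity", "cumulative_distance"]])
```

Because the gap is filled with a straight line, the total distance equals the straight-line distance between the two detections (5 pixels here). Long gaps can therefore underestimate the true path length.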

### Tracking Execution

This cell runs the tracking process on every video in the folder. Videos are processed concurrently in a thread pool to reduce total run time. Before the batch starts, you can optionally preview background subtraction and contour detection in real time on the first video to debug tracking issues.

**Process Steps:**
1. Load the CSV file with ROI coordinates
2. Find all video files in the specified folder
3. Optionally run visualization for the first video
4. Process all videos in parallel threads
5. Save trajectory data to CSV files

**Output Files:**
- `processed_<video_name>.csv`: CSV file containing tracking data for each video

**Visualization Options:**
- Set `VISUALIZE_BACKGROUND` to True to see background subtraction results
- Set `VISUALIZE_CONTOURS` to True to see contour detection
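
The execution pattern described above can be sketched with the standard library alone. In this sketch `fake_process` is a hypothetical stand-in for the real per-video function; like the pipeline's function, it returns `None` on failure. The file names are invented.

```python
from concurrent.futures import ThreadPoolExecutor

def fake_process(file_name):
    """Hypothetical stand-in for the per-video tracking function."""
    if file_name.startswith("bad"):
        return None  # signal failure the same way the pipeline does
    return file_name, f"processed_{file_name}.csv"

videos = ["a.mp4", "bad.mp4", "c.mp4"]

# executor.map returns results in input order, so they can be zipped
# back to the file names to find out which videos failed.
with ThreadPoolExecutor(max_workers=2) as executor:
    results = list(executor.map(fake_process, videos))

succeeded = [res for res in results if res is not None]
failed = [name for name, res in zip(videos, results) if res is None]

print(f"{len(succeeded)}/{len(videos)} videos processed; failed: {failed}")
```

Threads suit this workload to the extent that most of the time is spent inside native library code. If pure-Python steps dominate, a process pool might scale better, though that has not been tested here.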

In [None]:
# ===== TRACKING EXECUTION CONFIGURATION =====
# Modify these parameters to configure the tracking execution

# Directory settings
FOLDER_PATH = r'C:\Users\YourUsername\Videos'  # Path to video directory
CSV_POINTS_FILE = 'selected_points.csv'                                # Filename of points data

# Processing settings
MAX_WORKERS = 8                                                        # Number of parallel worker threads

# Visualization options
VISUALIZE_BACKGROUND = False                                           # Set to True to visualize background subtraction
VISUALIZE_CONTOURS = True                                              # Set to True to visualize contour detection
# ==========================================

# Main processing cell for bumblebee tracking in olfactometer videos
import os
import time
import pandas as pd
from concurrent.futures import ThreadPoolExecutor

# Get a list of all .mp4 filenames in the folder
mp4_files = [f for f in os.listdir(FOLDER_PATH) if f.lower().endswith('.mp4')]  # Case-insensitive match (.mp4, .MP4)
print(f"Found {len(mp4_files)} video files to process")

# Load the selected_points.csv file containing Y-tube coordinates
csv_path = os.path.join(FOLDER_PATH, CSV_POINTS_FILE)
if not os.path.exists(csv_path):
    raise FileNotFoundError(f"Cannot find coordinate file: {csv_path}")

df_segments = pd.read_csv(csv_path)
print(f"Loaded coordinate data for {len(df_segments)} videos")

# Initialize result containers
df_results_list = []
processed_trajectory_list = []

# Define wrapper function for processing
def process_video_wrapper(f):
    return process_video_parallel(f, FOLDER_PATH, df_segments)

# Optional: Run visualizations if configured
if VISUALIZE_BACKGROUND or VISUALIZE_CONTOURS:
    print("Running visualizations for the first video...")
    first_video = mp4_files[0] if mp4_files else None
    
    if first_video:
        if VISUALIZE_BACKGROUND:
            print(f"Visualizing background subtraction for {first_video}")
            visualize_background_subtraction(first_video, FOLDER_PATH, df_segments)
            
        if VISUALIZE_CONTOURS:
            print(f"Visualizing contour detection for {first_video}")
            visualize_contour_detection(first_video, FOLDER_PATH, df_segments)
    else:
        print("No videos found for visualization")

# Process videos in parallel
print(f"Starting parallel processing with {MAX_WORKERS} workers...")
start_time_para = time.time()

with ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor:
    results = list(executor.map(process_video_wrapper, mp4_files))
    
end_time_para = time.time()
para_duration = end_time_para - start_time_para

# Clean results (remove None values from failed processing)
valid_results = [res for res in results if res is not None]

# Results summary
print("=" * 50)
print("PROCESSING COMPLETED")
print("=" * 50)
print(f"Total videos processed: {len(valid_results)}/{len(mp4_files)}")
print(f"Parallel execution time: {para_duration:.2f} seconds")
if mp4_files:  # Avoid division by zero when no videos were found
    print(f"Average time per video: {para_duration / len(mp4_files):.2f} seconds")

# Check for failed videos
if len(valid_results) < len(mp4_files):
    failed_videos = [f for f, res in zip(mp4_files, results) if res is None]
    print("\nFailed videos:")
    for vid in failed_videos:
        print(f" - {vid}")

print("\nProcessed video files can be found at:")
print(f" {os.path.join(FOLDER_PATH, 'processed_*.csv')}")

## 4. Trajectory Visualization

This section provides tools to visualize the tracking results in several formats. Use them to interpret bumblebee behavior and to check tracking quality.

### Types of Visualizations:
1. **Raw Trajectory Plots**: Show the spatial path with color-coded segments
2. **Time-series Plots**: Display movement metrics over time
3. **Heatmaps**: Visualize spatial density of bumblebee positions
4. **Animations**: Create dynamic visualizations of movement

### How to Use:
- Load processed trajectories using `load_processed_trajectories()`
- Generate individual plots with the specific plotting functions
- Use `visualize_all_trajectories()` to process multiple videos
- Set `save_plots=True` to save visualizations to disk

In [None]:
# ===== VISUALIZATION CONFIGURATION =====
# These parameters control the visualization options

# Plot settings
DEFAULT_FIGSIZE = (10, 8)          # Default figure size for plots
HEATMAP_FIGSIZE = (14, 12)         # Figure size for heatmaps
TIMELINE_FIGSIZE = (18, 6)         # Figure size for timeline plots
DPI = 300                          # Resolution for saved figures

# Heatmap settings
HEATMAP_RESOLUTION = 50            # Resolution of the heatmap grid
HEATMAP_SMOOTHING = 3.0            # Gaussian smoothing factor for heatmaps

# Animation settings
ANIMATION_FPS = 30                 # Frames per second for animations
ANIMATION_DPI = 100                # DPI for animations

# Color schemes
SEGMENT_COLORS = {                 # Colors for Y-tube segments
    'bottom': 'orange',
    'left': 'green',
    'right': 'purple'
}
# ======================================

import seaborn as sns
import matplotlib.pyplot as plt
import os
import pandas as pd
import numpy as np
import cv2
from matplotlib.colors import LinearSegmentedColormap
from matplotlib.patches import Patch
from matplotlib.animation import FuncAnimation

def count_peaks_by_segment(proportion_series, segments_series):
    """
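    Count peaks in the signed proportion values, categorized by segment.
    A peak is counted when the proportion reaches +0.75 (right) or -0.75 (left). It is not
    counted again until the proportion has returned past +0.25 or -0.25 toward the junction.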
    
    Args:
        proportion_series (pd.Series): Series of proportion values
        segments_series (pd.Series): Series of segment labels
        
    Returns:
        tuple: (count_left, count_right, time_list)
            - count_left: Number of peaks in the left segment
            - count_right: Number of peaks in the right segment
            - time_list: 0-based row positions in the input series where peaks occur
    """
    count_left = 0  # Peaks in the left arm (proportion <= -0.75)
    count_right = 0  # Peaks in the right arm (proportion >= 0.75)
    above_threshold_left = True  # True once proportion is back at >= -0.25; arms the next left peak
    below_threshold_right = True  # True once proportion is back at <= 0.25; arms the next right peak
    time_list = []
    
    for time, (proportion, segment) in enumerate(zip(proportion_series, segments_series)):
        if segment == "bottom" or segment == "no_choice":        
            above_threshold_left = True  
            below_threshold_right = True
            continue
                
        if segment == "left" or segment == "right":
            # Re-arm each side once the proportion returns toward the junction
            if proportion >= -0.25:
                above_threshold_left = True  # A new left peak may be counted
            if proportion <= 0.25:
                below_threshold_right = True  # A new right peak may be counted
        
            # Count a right peak (proportion >= 0.75) if the right side is armed
            if proportion >= 0.75 and below_threshold_right:
                count_right += 1
                below_threshold_right = False  # Prevents double counting
                time_list.append(time)
        
            # Count a left peak (proportion <= -0.75) if the left side is armed
            if proportion <= -0.75 and above_threshold_left:
                count_left += 1
                above_threshold_left = False  # Prevents double counting
                time_list.append(time)
    
    return count_left, count_right, time_list

def load_processed_trajectories(folder_path):
    """
    Load all processed trajectory CSV files from a specified folder.
    
    Args:
        folder_path (str): Path to the folder containing processed CSV files
        
    Returns:
        list: List of tuples (filename, dataframe) for each processed trajectory
    """
    # Find all processed CSV files
    tracking_csv_files = [file for file in os.listdir(folder_path) 
                         if file.endswith(".csv") and "processed_" in file]
    print(f"Found {len(tracking_csv_files)} processed trajectory files")
    
    # Load each file into a DataFrame
    processed_trajectory_list = []
    for csv_filename in tracking_csv_files:
        csv_path = os.path.join(folder_path, csv_filename)
        try:
            df_tracking = pd.read_csv(csv_path)
            # Convert 'left' segment proportions to negative for better visualization
            df_tracking.loc[df_tracking['closest_segment'] == 'left', 'proportion'] *= -1
            processed_trajectory_list.append((csv_filename, df_tracking))  # Tuple, as documented
            print(f"Loaded: {csv_filename}")
        except Exception as e:
            print(f"Error loading {csv_filename}: {e}")
    
    return processed_trajectory_list
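
The peak counting above relies on hysteresis: a peak is counted at the outer threshold and cannot be counted again until the signal has returned past the inner threshold. The function below is a simplified, hypothetical illustration of that idea. It ignores segment labels and takes only the signed proportion (negative for left, positive for right); the name `count_visits` and its parameters are not part of the notebook.

```python
def count_visits(proportions, enter=0.75, reset=0.25):
    """Count excursions into the far end of each arm, with hysteresis.

    Positive values stand for the right arm and negative values for the
    left arm, mirroring the signed proportion used in this notebook.
    """
    left = right = 0
    armed_left = armed_right = True
    for p in proportions:
        # Re-arm a side once the bee has come back toward the junction.
        if p >= -reset:
            armed_left = True
        if p <= reset:
            armed_right = True
        if p >= enter and armed_right:
            right += 1
            armed_right = False   # prevents double counting
        elif p <= -enter and armed_left:
            left += 1
            armed_left = False    # prevents double counting
    return left, right

signal = [0.1, 0.8, 0.9, 0.6, 0.8, 0.2, 0.8, -0.8, -0.3, -0.9]
print(count_visits(signal))  # (1, 2)
```

In the toy signal, the dip to 0.6 does not re-arm the right side, so the following 0.8 is not counted again. Only after the drop to 0.2 does the next 0.8 count as a second right peak. Likewise, -0.3 does not re-arm the left side before -0.9.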

In [None]:
def plot_raw_trajectory(filename, df, show_plot=True):
    """
    Plot the raw tracked coordinates colored by segment.
    
    Args:
        filename (str): Name of the file being visualized
        df (pd.DataFrame): DataFrame containing trajectory data
        show_plot (bool): Whether to display the plot immediately
        
    Returns:
        matplotlib.figure.Figure: The generated figure
    """
    fig, ax = plt.subplots(figsize=DEFAULT_FIGSIZE)
    
    # Create scatter plot with colors by segment
    sns.scatterplot(data=df, x="X_raw", y="Y_raw", hue='closest_segment', 
                   palette=SEGMENT_COLORS, s=20, alpha=0.7, ax=ax)
    
    # Draw lines to connect points and show trajectory path
    ax.plot(df["X_raw"], df["Y_raw"], color='gray', alpha=0.3, linewidth=0.5)
    
    # Overlay the points assigned to each segment as a thicker line in that segment's color
    segments = ['bottom', 'left', 'right']
    for segment in segments:
        segment_df = df[df['closest_segment'] == segment]
        if not segment_df.empty:
            # Draw a thicker line for the segment
            ax.plot(segment_df["X_raw"], segment_df["Y_raw"], 
                   color=SEGMENT_COLORS[segment], linewidth=2, alpha=0.8, label=f"{segment} path")
    
    # Enhance plot appearance
    ax.set_title(f"Raw Tracking Trajectory: {filename}", fontsize=16)
    ax.set_xlabel("X Coordinate", fontsize=14)
    ax.set_ylabel("Y Coordinate", fontsize=14)
    ax.grid(True, alpha=0.3)
    ax.legend(title="Segment", fontsize=10)
    
    # Add note about coordinate system
    ax.text(0.02, 0.02, "Note: (0,0) is at top-left corner", 
           transform=ax.transAxes, fontsize=10, alpha=0.7,
           bbox=dict(facecolor='white', alpha=0.5))
    
    plt.tight_layout()
    
    if show_plot:
        plt.show()
    
    return fig

def plot_cumulative_distance(filename, df, show_plot=True):
    """
    Plot the cumulative distance traveled over time.
    
    Args:
        filename (str): Name of the file being visualized
        df (pd.DataFrame): DataFrame containing trajectory data
        show_plot (bool): Whether to display the plot immediately
        
    Returns:
        matplotlib.figure.Figure: The generated figure
    """
    fig, ax = plt.subplots(figsize=TIMELINE_FIGSIZE)
    
    # Plot cumulative distance as line
    ax.plot(df["Frame"], df["cumulative_distance"], label="Cumulative Distance", 
           color="green", linewidth=2)
    
    # Add segment transitions as vertical lines
    segment_changes = df['closest_segment'].ne(df['closest_segment'].shift()).cumsum()
    change_frames = df.groupby(segment_changes)['Frame'].first().iloc[1:]
    
    for frame in change_frames:
        ax.axvline(x=frame, color='gray', linestyle='--', alpha=0.5)
    
    # Add velocity information
    ax2 = ax.twinx()
    ax2.plot(df["Frame"], df["velocity"], label="Instantaneous Velocity", 
            color="red", alpha=0.3)
    ax2.set_ylabel("Velocity (pixels/frame)", color="red", fontsize=14)
    ax2.tick_params(axis='y', labelcolor="red")
    
    # Enhance plot appearance
    ax.set_title(f"Cumulative Distance Over Time: {filename}", fontsize=16)
    ax.set_xlabel("Frame", fontsize=14)
    ax.set_ylabel("Cumulative Distance (pixels)", fontsize=14)
    ax.grid(True, alpha=0.3)
    
    # Combine legends
    lines1, labels1 = ax.get_legend_handles_labels()
    lines2, labels2 = ax2.get_legend_handles_labels()
    ax.legend(lines1 + lines2, labels1 + labels2, loc='upper left')
    
    plt.tight_layout()
    
    if show_plot:
        plt.show()
    
    return fig

In [None]:
def plot_proportion_by_segment(filename, df, show_plot=True, highlight_peaks=True):
    """
    Plot proportion values over time, colored by segment, with threshold lines.
    
    Args:
        filename (str): Name of the file being visualized
        df (pd.DataFrame): DataFrame containing trajectory data
        show_plot (bool): Whether to display the plot immediately
        highlight_peaks (bool): Whether to highlight detected peaks
        
    Returns:
        matplotlib.figure.Figure: The generated figure
    """
    # Create segment-specific DataFrames
    df_bottom = df[df['closest_segment'] == "bottom"]
    df_left = df[df['closest_segment'] == "left"]
    df_right = df[df['closest_segment'] == "right"]
    
    fig, ax = plt.subplots(figsize=TIMELINE_FIGSIZE)
    
    # Plot proportion values by segment
    sns.scatterplot(x=df_left["Frame"], y=df_left["proportion"], 
                   label="Left", color=SEGMENT_COLORS["left"], ax=ax, alpha=0.6)
    sns.scatterplot(x=df_right["Frame"], y=df_right["proportion"], 
                   label="Right", color=SEGMENT_COLORS["right"], ax=ax, alpha=0.6)
    sns.scatterplot(x=df_bottom["Frame"], y=df_bottom["proportion"], 
                   label="Bottom", color=SEGMENT_COLORS["bottom"], ax=ax, alpha=0.6)
    
    # Connect points with lines
    ax.plot(df["Frame"], df["proportion"], color='gray', alpha=0.2, linewidth=0.5)
    
    # Add threshold lines
    ax.axhline(y=0.75, color='red', linestyle='--', linewidth=1.5)
    ax.axhline(y=0.25, color='cyan', linestyle='--', linewidth=1.5)
    ax.axhline(y=-0.75, color='red', linestyle='--', linewidth=1.5)
    ax.axhline(y=-0.25, color='cyan', linestyle='--', linewidth=1.5)
    
    # Highlight threshold zones
    ax.axhspan(0.75, 1.0, alpha=0.1, color='red')
    ax.axhspan(0.25, 0.75, alpha=0.1, color='cyan')
    ax.axhspan(-0.25, 0.25, alpha=0.1, color='gray')
    ax.axhspan(-0.75, -0.25, alpha=0.1, color='cyan')
    ax.axhspan(-1.0, -0.75, alpha=0.1, color='red')
    
    # Add labels for threshold regions
    ax.text(df["Frame"].max() * 1.01, 0.9, "Peak Right", fontsize=10, va='center')
    ax.text(df["Frame"].max() * 1.01, 0.5, "Transition Zone", fontsize=10, va='center')
    ax.text(df["Frame"].max() * 1.01, 0.0, "No Choice Zone", fontsize=10, va='center')
    ax.text(df["Frame"].max() * 1.01, -0.5, "Transition Zone", fontsize=10, va='center')
    ax.text(df["Frame"].max() * 1.01, -0.9, "Peak Left", fontsize=10, va='center')
    
    # Highlight peaks if requested
    if highlight_peaks:
        count_left, count_right, peak_frames = count_peaks_by_segment(
            df['proportion'], df['closest_segment'])
        
        if peak_frames:
            # peak_frames holds row positions, so look up both axes by position
            peak_rows = df.iloc[peak_frames]
            ax.scatter(peak_rows['Frame'], peak_rows['proportion'], color='red', s=100, 
                      zorder=5, label=f"Peaks (L:{count_left}, R:{count_right})")
    
    # Enhance plot appearance
    ax.set_title(f"Proportion Values by Segment: {filename}", fontsize=16)
    ax.set_xlabel("Frame", fontsize=14)
    ax.set_ylabel("Proportion", fontsize=14)
    ax.set_ylim(-1.1, 1.1)
    ax.grid(True, alpha=0.3)
    ax.legend(fontsize=12)
    
    plt.tight_layout()
    
    if show_plot:
        plt.show()
    
    return fig

In [None]:
def plot_heatmap_trajectory(filename, df, show_plot=True, resolution=HEATMAP_RESOLUTION, smoothing_factor=HEATMAP_SMOOTHING):
    """
    Create a smoother heatmap showing where the bumblebee spent most time,
    with a blue-to-red colormap and faded trajectories.
    
    Args:
        filename (str): Name of the file being visualized
        df (pd.DataFrame): DataFrame containing trajectory data
        show_plot (bool): Whether to display the plot immediately
        resolution (int): Resolution of the heatmap grid (lower = smoother)
        smoothing_factor (float): Gaussian smoothing sigma (higher = smoother)
        
    Returns:
        matplotlib.figure.Figure: The generated figure
    """
    from scipy.ndimage import gaussian_filter
    
    # Filter out NaN values
    valid_df = df.dropna(subset=['X_raw', 'Y_raw'])
    
    if valid_df.empty:
        print(f"No valid trajectory data for {filename}")
        return None
    
    # Determine the boundaries of the plot
    x_min, x_max = valid_df['X_raw'].min(), valid_df['X_raw'].max()
    y_min, y_max = valid_df['Y_raw'].min(), valid_df['Y_raw'].max()
    
    # Add some margin
    margin = 0.1
    x_range = x_max - x_min
    y_range = y_max - y_min
    x_min -= margin * x_range
    x_max += margin * x_range
    y_min -= margin * y_range
    y_max += margin * y_range
    
    # Create a 2D histogram with lower resolution for smoother appearance
    hist, x_edges, y_edges = np.histogram2d(
        valid_df['X_raw'], valid_df['Y_raw'], 
        bins=[resolution, resolution], 
        range=[[x_min, x_max], [y_min, y_max]]
    )
    
    # Apply Gaussian smoothing to make the heatmap smoother
    hist_smooth = gaussian_filter(hist, sigma=smoothing_factor)
    
    # Apply logarithmic transformation to better visualize the distribution
    hist_smooth = np.log1p(hist_smooth)  # log(1+x) to handle zeros
    
    # Create figure with black background for better contrast
    fig, ax = plt.subplots(figsize=HEATMAP_FIGSIZE, facecolor='black')
    ax.set_facecolor('black')
    
    # Use Matplotlib's built-in blue-to-red 'coolwarm' colormap
    cmap = plt.cm.coolwarm
    
    # Plot the heatmap
    im = ax.imshow(hist_smooth.T, extent=[x_min, x_max, y_max, y_min], 
                  origin='upper', cmap=cmap, interpolation='gaussian',
                  aspect='auto', alpha=0.9)
    
    # Add the raw trajectory as very faint lines
    ax.plot(valid_df['X_raw'], valid_df['Y_raw'], 
           color='white', alpha=0.1, linewidth=0.5)
    
    # Highlight segment paths with very light transparency
    segment_colors = {'bottom': 'yellow', 'left': 'lime', 'right': 'magenta'}
    
    for segment, color in segment_colors.items():
        segment_df = valid_df[valid_df['closest_segment'] == segment]
        if not segment_df.empty:
            ax.plot(segment_df['X_raw'], segment_df['Y_raw'], 
                   color=color, alpha=0.2, linewidth=1.0, label=segment)
    
    # Add colorbar with custom formatting
    cbar = plt.colorbar(im, ax=ax, pad=0.01)
    cbar.set_label('Visitation Frequency', fontsize=14, color='white')
    cbar.ax.yaxis.set_tick_params(color='white')
    cbar.outline.set_edgecolor('white')
    plt.setp(plt.getp(cbar.ax, 'yticklabels'), color='white')
    
    # Enhance plot appearance
    ax.set_title(f"Bumblebee Visitation Heatmap: {filename}", fontsize=18, color='white')
    ax.set_xlabel("X Coordinate", fontsize=14, color='white')
    ax.set_ylabel("Y Coordinate", fontsize=14, color='white')
    ax.tick_params(colors='white')
    for spine in ax.spines.values():
        spine.set_edgecolor('white')
    
    # Create custom legend with white text
    legend = ax.legend(framealpha=0.7, facecolor='black', edgecolor='white', fontsize=12)
    for text in legend.get_texts():
        text.set_color('white')
    
    plt.tight_layout()
    
    if show_plot:
        plt.show()
    
    return fig
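
The heatmap function transposes the 2D histogram before displaying it. The sketch below, using made-up points, shows why: `np.histogram2d` indexes its result as `[x_bin, y_bin]`, whereas an image array is indexed as `[row, column]`, that is `[y_bin, x_bin]`.

```python
import numpy as np

# Three detections at the same spot near the top-right of a 4x4 pixel area
# (image coordinates: x grows rightward, y grows downward).
x = np.array([3.5, 3.5, 3.5])
y = np.array([0.5, 0.5, 0.5])

hist, x_edges, y_edges = np.histogram2d(
    x, y, bins=[4, 4], range=[[0, 4], [0, 4]]
)

# histogram2d indexes the result as hist[x_bin, y_bin] ...
print(hist[3, 0])      # 3.0

# ... whereas images are indexed [row, column] = [y_bin, x_bin], so the
# array is transposed before being shown.
image = hist.T
print(image[0, 3])     # 3.0

# log1p compresses large counts while keeping empty bins at zero.
print(np.log1p(image).max())
```

After the transpose, the hot spot sits in the first row and last column, which matches the top-right corner when row 0 is drawn at the top.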

In [None]:
def create_trajectory_animation(filename, df, output_file=None, fps=ANIMATION_FPS, dpi=ANIMATION_DPI):
    """
    Create an animation of the bumblebee's trajectory over time.
    
    Args:
        filename (str): Name of the file being visualized
        df (pd.DataFrame): DataFrame containing trajectory data
        output_file (str, optional): Path to save the animation (if None, animation is displayed)
        fps (int): Frames per second for the animation
        dpi (int): Resolution of the output animation
        
    Returns:
        matplotlib.animation.FuncAnimation: The animation object
    """
    # Filter out NaN values
    valid_df = df.dropna(subset=['X_raw', 'Y_raw'])
    
    if valid_df.empty:
        print(f"No valid trajectory data for {filename}")
        return None
    
    # Create figure and axis
    fig, ax = plt.subplots(figsize=DEFAULT_FIGSIZE)
    
    # Determine axis limits
    x_min, x_max = valid_df['X_raw'].min(), valid_df['X_raw'].max()
    y_min, y_max = valid_df['Y_raw'].min(), valid_df['Y_raw'].max()
    
    # Add margins
    margin = 0.1
    x_range = x_max - x_min
    y_range = y_max - y_min
    ax.set_xlim(x_min - margin * x_range, x_max + margin * x_range)
    ax.set_ylim(y_max + margin * y_range, y_min - margin * y_range)  # Inverted for image coordinates
    
    # Prepare segment colors
    segment_colors = SEGMENT_COLORS
    
    # Plot the Y-tube structure (all segments)
    for segment, color in segment_colors.items():
        segment_df = valid_df[valid_df['closest_segment'] == segment]
        if not segment_df.empty:
            ax.plot(segment_df['X_raw'], segment_df['Y_raw'], 
                   color=color, alpha=0.2, linewidth=3, label=f"{segment} path")
    
    # Animation elements
    trail_line, = ax.plot([], [], 'gray', alpha=0.5, linewidth=1)
    bee_point, = ax.plot([], [], 'ro', markersize=8)
    timer_text = ax.text(0.02, 0.96, '', transform=ax.transAxes, fontsize=12)
    segment_text = ax.text(0.02, 0.91, '', transform=ax.transAxes, fontsize=12)
    
    # Create legend for segments
    legend_elements = [
        Patch(facecolor=color, edgecolor='black', alpha=0.6, label=segment)
        for segment, color in segment_colors.items()
    ]
    ax.legend(handles=legend_elements, loc='upper right')
    
    # Set up plot details
    ax.set_title(f"Bumblebee Trajectory Animation: {filename}", fontsize=16)
    ax.set_xlabel("X Coordinate", fontsize=14)
    ax.set_ylabel("Y Coordinate", fontsize=14)
    ax.grid(True, alpha=0.3)
    
    # Subsample rows to keep the animation small and quick to render:
    # use every Nth row, aiming for roughly 300 animation frames
    N = max(1, len(valid_df) // 300)
    sampled_df = valid_df.iloc[::N].reset_index(drop=True)
    
    def init():
        """Initialize the animation"""
        trail_line.set_data([], [])
        bee_point.set_data([], [])
        timer_text.set_text('')
        segment_text.set_text('')
        return trail_line, bee_point, timer_text, segment_text
    
    def update(frame):
        """Update the animation for each frame"""
        current_frame = sampled_df.iloc[:frame+1]
        
        # Update trail
        trail_line.set_data(current_frame['X_raw'], current_frame['Y_raw'])
        
        # Update current position
        curr_pos = current_frame.iloc[-1]
        bee_point.set_data([curr_pos['X_raw']], [curr_pos['Y_raw']])
        
        # Update text information
        timer_text.set_text(f"Frame: {curr_pos['Frame']}")
        segment_text.set_text(f"Segment: {curr_pos['closest_segment']}")
        
        # Set color based on current segment
        segment = curr_pos['closest_segment']
        bee_point.set_color(segment_colors.get(segment, 'red'))
        
        return trail_line, bee_point, timer_text, segment_text
    
    # Create animation
    ani = FuncAnimation(fig, update, frames=len(sampled_df),
                       init_func=init, blit=True, interval=1000/fps)
    
    # Save or display animation
    plt.tight_layout()
    
    if output_file:
        ani.save(output_file, writer='pillow', fps=fps, dpi=dpi)
        print(f"Animation saved to {output_file}")
    else:
        plt.show()
    
    return ani
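
The frame subsampling used for the animation can be tried in isolation. The following is a minimal sketch, assuming a DataFrame with one row per tracked frame; the helper name `subsample_for_animation` and the toy data are illustrative and not part of the notebook.

```python
import pandas as pd


def subsample_for_animation(df: pd.DataFrame, target_frames: int = 300) -> pd.DataFrame:
    """Keep every Nth row so that roughly `target_frames` rows remain."""
    # Integer division gives the stride; max(1, ...) protects short recordings
    step = max(1, len(df) // target_frames)
    return df.iloc[::step].reset_index(drop=True)


# Toy recording with 1000 tracked frames (hypothetical data)
toy = pd.DataFrame({"Frame": range(1000)})
sampled = subsample_for_animation(toy)
print(len(toy), "->", len(sampled))

# A short recording is returned unchanged
short = subsample_for_animation(pd.DataFrame({"Frame": range(100)}))
print(len(short))
```

Because the stride is rounded down, the result can contain somewhat more than the target number of frames; it never contains fewer unless the recording itself is shorter.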

In [None]:
def visualize_all_trajectories(folder_path, output_folder=None, 
                              save_plots=False, save_animations=False):
    """
    Process and visualize all trajectory files in the specified folder.
    
    Args:
        folder_path (str): Path to the folder containing processed CSV files
        output_folder (str, optional): Path to save visualizations (defaults to folder_path/visualizations)
        save_plots (bool): Whether to save plots as PNG files
        save_animations (bool): Whether to save animations as GIF files
        
    Returns:
        dict: Dictionary of visualization results
    """
    # Create output folder if needed
    if save_plots or save_animations:
        if output_folder is None:
            output_folder = os.path.join(folder_path, 'visualizations')
        
        if not os.path.exists(output_folder):
            os.makedirs(output_folder)
        print(f"Visualizations will be saved to: {output_folder}")
    
    # Load all trajectory files
    trajectory_list = load_processed_trajectories(folder_path)
    if not trajectory_list:
        print("No trajectory files found")
        return {}
    
    # Process each trajectory
    results = {}
    for filename, df in trajectory_list:
        print(f"\nProcessing: {filename}")
        trajectory_name = filename.replace('processed_', '').replace('.csv', '')
        result_dict = {}
        
        # Create basic name for output files
        base_filename = os.path.join(output_folder, trajectory_name) if output_folder else None
        
        # Create visualizations
        try:
            # Raw trajectory plot
            fig1 = plot_raw_trajectory(trajectory_name, df, show_plot=not save_plots)
            result_dict['raw_trajectory'] = fig1
            if save_plots and base_filename:
                fig1.savefig(f"{base_filename}_raw_trajectory.png", dpi=DPI)
                plt.close(fig1)
            
            # Cumulative distance plot
            fig2 = plot_cumulative_distance(trajectory_name, df, show_plot=not save_plots)
            result_dict['cumulative_distance'] = fig2
            if save_plots and base_filename:
                fig2.savefig(f"{base_filename}_cumulative_distance.png", dpi=DPI)
                plt.close(fig2)
            
            # Proportion by segment plot
            fig3 = plot_proportion_by_segment(trajectory_name, df, show_plot=not save_plots)
            result_dict['proportion_plot'] = fig3
            if save_plots and base_filename:
                fig3.savefig(f"{base_filename}_proportion.png", dpi=DPI)
                plt.close(fig3)
            
            # Heatmap trajectory
            fig4 = plot_heatmap_trajectory(trajectory_name, df, show_plot=not save_plots)
            result_dict['heatmap'] = fig4
            if save_plots and base_filename and fig4:
                fig4.savefig(f"{base_filename}_heatmap.png", dpi=DPI)
                plt.close(fig4)
            
            # Animation (only if requested)
            if save_animations and base_filename:
                ani = create_trajectory_animation(
                    trajectory_name, df, 
                    output_file=f"{base_filename}_animation.gif",
                    fps=10, dpi=100
                )
                result_dict['animation'] = ani
            
            results[trajectory_name] = result_dict
            print(f"Completed visualizations for: {trajectory_name}")
            
        except Exception as e:
            print(f"Error processing {trajectory_name}: {e}")
    
    print(f"\nCompleted visualization of {len(results)} trajectory files")
    return results

### Visualization Configuration and Run

This cell creates visualizations of the tracking results. It can generate various plots and animations to help understand the bumblebee's behavior.

#### Key Parameters:
- `VISUALIZATION_FOLDER`: Directory containing the processed CSV files
- `OUTPUT_FOLDER`: Where to save visualizations (None = auto-create subfolder)
- `SAVE_PLOTS`: Whether to save plot images to disk
- `SAVE_ANIMATIONS`: Whether to create and save animations (GIF files)
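- `INTERACTIVE_DISPLAY`: Whether to display plots in the notebook (only used when `VISUALIZE_ALL` is False; with `VISUALIZE_ALL = True`, plots are displayed only when `SAVE_PLOTS` is False)
- `VISUALIZE_ALL`: Process all trajectories or just specific ones
- `SPECIFIC_FILES`: List of specific files to process (if not VISUALIZE_ALL); entries are matched as substrings of the file names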

#### Instructions:
1. Make sure you've completed the tracking step for all videos
2. Configure the visualization parameters
3. Run this cell to generate visualizations
4. View the plots in the notebook or check the output folder

#### Visualization Types:
- Raw trajectory plots (spatial path with color-coded segments)
- Cumulative distance plots (distance over time)
- Proportion plots (time spent in each segment)
- Heatmaps (spatial density of bumblebee positions)
- Animations (dynamic visualization of movement, if enabled)

#### Output:
If saving is enabled, files will be named `<video_name>_<plot_type>.png` or `<video_name>_animation.gif`
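
As a rough illustration of the file selection and naming described above, the sketch below mimics the substring matching applied to `SPECIFIC_FILES` and builds the expected output names. The helpers `select_files` and `output_names` and the example file names are hypothetical; only the suffixes are taken from the notebook.

```python
def select_files(filenames, specific_files):
    """Return the files whose name contains any of the requested substrings."""
    if not specific_files:
        # An empty list means "no restriction"
        return list(filenames)
    return [name for name in filenames
            if any(spec in name for spec in specific_files)]


def output_names(video_name,
                 plot_types=("raw_trajectory", "cumulative_distance",
                             "proportion", "heatmap")):
    """Build the names of the PNG plots and the GIF animation for one video."""
    names = [f"{video_name}_{plot_type}.png" for plot_type in plot_types]
    names.append(f"{video_name}_animation.gif")
    return names


# Hypothetical processed files
available = ["processed_bee_01.mp4.csv", "processed_bee_02.mp4.csv"]
selected = select_files(available, ["bee_02"])
names = output_names("bee_02.mp4")
print(selected)
print(names)
```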

In [None]:
# ===== VISUALIZATION EXECUTION =====
# Modify these parameters for visualization settings

# Directory and file settings
VISUALIZATION_FOLDER = r'C:\Users\YourUsername\Videos'  # Path to processed files
OUTPUT_FOLDER = None                                  # Output folder (None = auto-create)

# Output options
SAVE_PLOTS = True                                   # Whether to save plots to disk
SAVE_ANIMATIONS = True                               # Whether to create and save animations
INTERACTIVE_DISPLAY = False                            # Whether to display plots interactively

# Visualization subset
VISUALIZE_ALL = True                                  # Process all trajectories
SPECIFIC_FILES = []                                   # List of specific files to process
# =====================================

# Example usage
if __name__ == "__main__":
    folder_path = VISUALIZATION_FOLDER
    
    print(f"Starting visualization of trajectories in {folder_path}")
    
    if VISUALIZE_ALL:
        # Option 1: Visualize all trajectories
        results = visualize_all_trajectories(
            folder_path, 
            output_folder=OUTPUT_FOLDER,
            save_plots=SAVE_PLOTS, 
            save_animations=SAVE_ANIMATIONS
        )
        
        print(f"Generated {len(results)} visualization sets")
    else:
        # Option 2: Process specific trajectory files
        processed_trajectories = load_processed_trajectories(folder_path)
        
        if SPECIFIC_FILES:
            # Filter to specific files if requested
            processed_trajectories = [(filename, df) for filename, df in processed_trajectories 
                                     if any(spec_file in filename for spec_file in SPECIFIC_FILES)]
            
        # Process each trajectory file
        for filename, df in processed_trajectories:
            print(f"Visualizing {filename}")
            
            # Create individual plots
            raw_fig = plot_raw_trajectory(filename, df, show_plot=INTERACTIVE_DISPLAY)
            proportion_fig = plot_proportion_by_segment(filename, df, show_plot=INTERACTIVE_DISPLAY)
            
            # Save plots if requested (falls back to the input folder when OUTPUT_FOLDER is None)
            if SAVE_PLOTS:
                save_folder = OUTPUT_FOLDER if OUTPUT_FOLDER else folder_path
                os.makedirs(save_folder, exist_ok=True)
                output_path = os.path.join(save_folder, filename.replace('.csv', ''))
                raw_fig.savefig(f"{output_path}_raw.png", dpi=DPI)
                proportion_fig.savefig(f"{output_path}_proportion.png", dpi=DPI)
                print(f"Saved plots to {output_path}_*.png")

## 5. Analysis & Metrics Computation

This section calculates behavioral metrics from the tracking data and generates summary statistics. The analysis provides insights into bumblebee behavior in the Y-tube, such as preferences for different odors, movement patterns, and decision-making dynamics.

### Analysis Process:
1. Load tracking files for each video
2. Calculate metrics such as segment preferences and velocity
3. Merge with experimental metadata
4. Generate summary statistics by experimental conditions
5. Create visualizations of the results

### Key Metrics:
- Time and proportion spent in each Y-tube arm
- Number and frequency of movements between arms
- Velocity and acceleration patterns
- First and final arm choices

The results are saved to a CSV file for further analysis in statistical software.
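
To make the first metric concrete, here is a minimal sketch of how time and proportion per arm can be derived from per-frame segment labels. It assumes one label per frame and a 10 fps recording; the toy labels are invented for illustration.

```python
import pandas as pd

# One segment label per tracked frame (hypothetical data)
segments = pd.Series(["bottom", "left", "left", "right", "left", "bottom"])
fps = 10  # assumed frame rate

# Frames spent in each arm; reindex guarantees both arms are present
frames_per_arm = segments.value_counts().reindex(["left", "right"], fill_value=0)

# Convert frame counts to seconds
seconds_per_arm = frames_per_arm / fps

# Proportion of arm time spent in each arm (bottom frames are excluded)
prop_per_arm = frames_per_arm / frames_per_arm.sum()

print(frames_per_arm.to_dict())
print(prop_per_arm.to_dict())
```

In this sketch the proportions are relative to the time spent in the two arms only, so they sum to 1 even when the bee spends time in the bottom segment.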

In [None]:
# ===== ANALYSIS CONFIGURATION =====
# Modify these parameters to configure the analysis process

# Directory and file settings
ANALYSIS_FOLDER = r'C:\Users\YourUsername\Videos'  # Path to processed files
METADATA_FILE = 'metadata_phtalates.csv'                        # Name of the metadata file 
OUTPUT_FOLDER = None                                  # Output folder (None = auto-create)
OUTPUT_FILE = 'results.csv'                           # Name of the output results file

# Filtering parameters
MIN_DURATION = 6.0                                    # Minimum video duration (minutes)
MAX_DURATION = 12.0                                   # Maximum video duration (minutes)
MAX_BOTTOM_RATIO = 0.7                                # Maximum proportion in bottom segment

# Analysis settings
FPS = 10                                              # Frames per second (for time calculations)
GROUP_COLUMN = 'MODA'                            # Main column for grouping results
GENERATE_PLOTS = True                                 # Whether to generate summary plots
SAVE_RESULTS = True                                  # Whether to save summary plots
# ====================================

# Main cell for analyzing insect olfactometer experiment data

In [None]:
import os
import re                  # Used to extract video names from CSV file names
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.ndimage import gaussian_filter

print(f"Starting analysis of olfactometer data in {ANALYSIS_FOLDER}")

# Set output folder
if OUTPUT_FOLDER is None:
    OUTPUT_FOLDER = os.path.join(ANALYSIS_FOLDER, 'analysis')
if not os.path.exists(OUTPUT_FOLDER):
    os.makedirs(OUTPUT_FOLDER)
    print(f"Created output directory: {OUTPUT_FOLDER}")


def remove_consecutive_duplicates(lst):
    """
    Remove consecutive duplicate items from a list while preserving order.
    
    Args:
        lst (list): List of items to process
        
    Returns:
        list: List with consecutive duplicates removed
    """
    if not lst:
        return []
    
    result = [lst[0]]
    for elem in lst[1:]:
        if elem != result[-1]:  # Only add if different from the last element
            result.append(elem)
    return result


def count_transition_patterns(segment_list):
    """
    Count different transition patterns between segments.
    
    Args:
        segment_list (list): List of segment labels
        
    Returns:
        dict: Dictionary with counts of different transition patterns
    """
    patterns = {
        "left_no_choice_right": 0,
        "right_no_choice_left": 0,
        "right_no_choice_right": 0,
        "left_no_choice_left": 0,
        "bottom_no_choice_right": 0,
        "bottom_no_choice_left": 0
    }

    for i in range(len(segment_list) - 2):
        # Check for each transition pattern
        if segment_list[i] == "left" and segment_list[i+1] == "no_choice" and segment_list[i+2] == "right":
            patterns["left_no_choice_right"] += 1
        elif segment_list[i] == "right" and segment_list[i+1] == "no_choice" and segment_list[i+2] == "left":
            patterns["right_no_choice_left"] += 1
        elif segment_list[i] == "right" and segment_list[i+1] == "no_choice" and segment_list[i+2] == "right":
            patterns["right_no_choice_right"] += 1
        elif segment_list[i] == "left" and segment_list[i+1] == "no_choice" and segment_list[i+2] == "left":
            patterns["left_no_choice_left"] += 1
        elif segment_list[i] == "bottom" and segment_list[i+1] == "no_choice" and segment_list[i+2] == "right":
            patterns["bottom_no_choice_right"] += 1
        elif segment_list[i] == "bottom" and segment_list[i+1] == "no_choice" and segment_list[i+2] == "left":
            patterns["bottom_no_choice_left"] += 1

    return patterns


def find_first_and_last_choice(segment_list):
    """
    Find the first and last choice (left or right) in a sequence of segments.
    
    Args:
        segment_list (list): List of segment labels
        
    Returns:
        tuple: (first_choice, last_choice) - The first and last "left" or "right" segments
    """
    first_choice = None
    last_choice = None

    for item in segment_list:
        if item in ("left", "right"):
            if first_choice is None:
                first_choice = item.capitalize()  # Store the first occurrence
            last_choice = item.capitalize()  # Keep updating to get the last occurrence

    return first_choice, last_choice


def preprocess_trajectory_data(df, proportion_threshold=0.25):
    """
    Preprocess trajectory data by adjusting proportion values and marking no_choice zones.
    
    Args:
        df (pd.DataFrame): DataFrame containing trajectory data
        proportion_threshold (float): Threshold below which segments are marked as "no_choice"
        
    Returns:
        pd.DataFrame: Preprocessed DataFrame
    """
    # Create a copy to avoid modifying the original
    df_processed = df.copy()
    
    # Convert left segment proportions to negative values
    df_processed.loc[df_processed['closest_segment'] == 'left', 'proportion'] *= -1
    
    # Calculate speed in proportion space (absolute frame-to-frame change in signed proportion)
    df_processed["velocity_prop"] = df_processed["proportion"].diff().abs()
    
    # Mark low-proportion regions as "no_choice"
    df_processed.loc[abs(df_processed["proportion"]) < proportion_threshold, "closest_segment"] = "no_choice"
    
    return df_processed
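
The sign convention and the `no_choice` relabeling can be checked on a toy table. This is a sketch that repeats the same three steps on invented data with the default threshold of 0.25; it is not a replacement for `preprocess_trajectory_data`.

```python
import pandas as pd

# Hypothetical tracking rows: segment label and position along the segment (0..1)
toy = pd.DataFrame({
    "closest_segment": ["bottom", "left", "left", "right", "right"],
    "proportion":      [0.10,     0.20,   0.60,   0.15,    0.80],
})

signed = toy.copy()

# Step 1: left-arm positions become negative, right-arm positions stay positive
is_left = signed["closest_segment"] == "left"
signed.loc[is_left, "proportion"] *= -1

# Step 2: speed in proportion space is the absolute frame-to-frame change
signed["velocity_prop"] = signed["proportion"].diff().abs()

# Step 3: positions close to the junction are relabeled as "no_choice"
signed.loc[signed["proportion"].abs() < 0.25, "closest_segment"] = "no_choice"

print(signed)
```

Note that the threshold is applied to every row, so a `bottom` row with a small proportion is relabeled as `no_choice` as well.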

In [None]:
def analyze_trajectory(df, file_name, threshold=0.75):
    """
    Analyze the insect trajectory data and compute various behavioral metrics.
    
    Args:
        df (pd.DataFrame): DataFrame containing trajectory data
        file_name (str): Name of the file being analyzed
        threshold (float): Threshold for considering high proportion values
        
    Returns:
        pd.DataFrame: Single-row DataFrame with calculated metrics
    """
    # Validate required columns ('velocity_prop' is added by preprocess_trajectory_data)
    required_columns = {'closest_segment', 'proportion', 'velocity', 'cumulative_distance',
                        'X_projected', 'velocity_prop'}
    missing_columns = required_columns - set(df.columns)
    if missing_columns:
        raise ValueError(f"DataFrame is missing required columns: {sorted(missing_columns)}")
    
    # Initialize metrics dictionary
    metrics = {}
    
    # Calculate segment-specific metrics
    for segment in ['right', 'left']:
        # Time spent in each segment
        time_segment = df['closest_segment'].eq(segment).sum()
        metrics[f"time_{segment}"] = time_segment

        # Sum of proportion values above threshold, and that sum averaged over all frames in the segment
        segment_high_prop = df.loc[
            (df['closest_segment'] == segment) & (abs(df['proportion']) > threshold),
            'proportion'
        ]
        sum_proportion = segment_high_prop.sum()
        avg_proportion = sum_proportion / time_segment if time_segment > 0 else None
        
        metrics[f"sum_proportion_{segment}"] = sum_proportion
        metrics[f"avg_proportion_{segment}"] = avg_proportion

        # Count frames above threshold
        time_above_threshold = len(segment_high_prop)
        metrics[f"time_above_threshold_{segment}"] = time_above_threshold

        # Cumulative distance and velocity metrics
        cumulative_distance = df.loc[df['closest_segment'] == segment, 'velocity_prop'].sum()
        
        # Calculate velocity in transition zones (proportion between 0.25 and 0.5)
        velocity = df.loc[
            (df['closest_segment'] == segment) & 
            (df['proportion'].abs().between(0.25, 0.5)),
            'velocity_prop'
        ].mean()
        
        # Handle NaN values
        if pd.isna(velocity):
            velocity = 0
            
        metrics[f"cumulative_distance_{segment}"] = cumulative_distance
        metrics[f"velocity_{segment}"] = velocity

    # Calculate overall metrics
    total_time = metrics["time_right"] + metrics["time_left"]
    total_time_above_threshold = metrics["time_above_threshold_right"] + metrics["time_above_threshold_left"]
    
    # Calculate proportion of time spent in each segment
    metrics["prop_right"] = metrics["time_right"] / total_time if total_time > 0 else None
    metrics["prop_left"] = metrics["time_left"] / total_time if total_time > 0 else None
    
    # Calculate distance and velocity metrics
    total_cumulative_distance = metrics["cumulative_distance_left"] + metrics["cumulative_distance_right"]
    
    # Average velocity in branches (proportion between 0.25 and 0.75)
    avg_velocity_branches = df.loc[
        ((df['closest_segment'] == "right") | (df['closest_segment'] == "left")) &
        (df['proportion'].abs().between(0.25, 0.75)),
        'velocity_prop'
    ].mean()
    
    # Overall average velocity in all branches
    avg_velocity_total = df.loc[
        ((df['closest_segment'] == "right") | (df['closest_segment'] == "left")),
        'velocity_prop'
    ].mean()
    
    metrics["cumulative_distance"] = total_cumulative_distance
    metrics["average_velocity_in_branch"] = avg_velocity_branches
    metrics["overall_average_velocity"] = avg_velocity_total

    # Calculate proportional distance metrics
    if total_cumulative_distance > 0:
        metrics['prop_cumulative_distance_left'] = metrics["cumulative_distance_left"] / total_cumulative_distance
        metrics['prop_cumulative_distance_right'] = metrics["cumulative_distance_right"] / total_cumulative_distance
    else:
        metrics['prop_cumulative_distance_left'] = None
        metrics['prop_cumulative_distance_right'] = None
    
    # Calculate proportion metrics
    total_sum_proportion = abs(metrics["sum_proportion_right"]) + abs(metrics["sum_proportion_left"])
    metrics["sum_proportion_total"] = total_sum_proportion
    
    if total_sum_proportion != 0:
        metrics["prop_proportion_left"] = abs(metrics["sum_proportion_left"]) / total_sum_proportion
        metrics["prop_proportion_right"] = abs(metrics["sum_proportion_right"]) / total_sum_proportion

    else:
        metrics["prop_proportion_left"] = None
        metrics["prop_proportion_right"] = None
   
    
    # Calculate peak-related metrics
    bout_left, bout_right, peak_times = count_peaks_by_segment(df['proportion'], df['closest_segment'])
    
    # Analyze segment transitions
    segment_sequence = remove_consecutive_duplicates(df['closest_segment'].tolist())
    transition_patterns = count_transition_patterns(segment_sequence)
    first_choice_lr, last_choice_lr = find_first_and_last_choice(segment_sequence)
    
    # Store choice-related metrics
    metrics["first_choice_lr"] = first_choice_lr
    metrics["last_choice_lr"] = last_choice_lr
    
    # Store transition pattern counts
    metrics["change_count"] = transition_patterns["right_no_choice_left"] + transition_patterns["left_no_choice_right"]
    metrics["baf_left"] = transition_patterns["left_no_choice_left"] + transition_patterns["bottom_no_choice_left"]
    metrics["baf_right"] = transition_patterns["right_no_choice_right"] + transition_patterns["bottom_no_choice_right"]
    
    # Calculate back-and-forth indices
    total_baf = metrics["baf_right"] + metrics["baf_left"]

    # Store peak counts
    metrics["bout_left"] = bout_left
    metrics["bout_right"] = bout_right
    
    # Calculate peak-related indices
    total_bout = bout_right + bout_left
    if total_bout != 0:
        metrics["prop_bout_left"] = bout_left / total_bout
        metrics["prop_bout_right"] = bout_right / total_bout
    else:
        metrics["prop_bout_left"] = None
        metrics["prop_bout_right"] = None
    
    # Print summary information
    print(f"{file_name}: Left prop = {metrics['prop_left']}, Right prop = {metrics['prop_right']}")
    
    # Add file identifier to results
    metrics["ID"] = file_name
    
    # Return results as a single-row DataFrame
    return pd.DataFrame([metrics])
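
The transition counts behind `change_count`, `baf_left` and `baf_right` can be verified on a short label sequence. The sketch below is an alternative, compact formulation using `itertools.groupby`; the helper names `collapse_runs` and `count_triples` are illustrative, and unlike `count_transition_patterns` it counts every `X -> no_choice -> Y` triple rather than a fixed set of six.

```python
from itertools import groupby


def collapse_runs(labels):
    """Collapse consecutive duplicates: [a, a, b, a] -> [a, b, a]."""
    return [label for label, _ in groupby(labels)]


def count_triples(sequence):
    """Count X -> no_choice -> Y patterns in a collapsed label sequence."""
    counts = {}
    for first, middle, last in zip(sequence, sequence[1:], sequence[2:]):
        if middle == "no_choice":
            key = f"{first}_no_choice_{last}"
            counts[key] = counts.get(key, 0) + 1
    return counts


# Hypothetical per-frame labels
frames = ["bottom", "bottom", "no_choice", "left", "left",
          "no_choice", "no_choice", "right", "no_choice", "right"]

sequence = collapse_runs(frames)
counts = count_triples(sequence)

# Same definitions as in analyze_trajectory
change_count = counts.get("left_no_choice_right", 0) + counts.get("right_no_choice_left", 0)
baf_left = counts.get("left_no_choice_left", 0) + counts.get("bottom_no_choice_left", 0)
baf_right = counts.get("right_no_choice_right", 0) + counts.get("bottom_no_choice_right", 0)

print(sequence)
print(change_count, baf_left, baf_right)
```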

In [None]:
def load_and_filter_csv_files(folder_path, pattern="processed_*.mp4.csv", 
                              min_duration=6.0, max_duration=12.0, max_bottom_ratio=0.7):
    """
    Load and filter processed CSV trajectory files based on specified criteria.
    
    Args:
        folder_path (str): Path to folder containing CSV files
        pattern (str): Filename pattern (currently unused; files are selected if their name contains "processed_" and ends with ".csv")
        min_duration (float): Minimum duration in minutes (at 10fps)
        max_duration (float): Maximum duration in minutes (at 10fps)
        max_bottom_ratio (float): Maximum ratio of time spent in bottom segment
        
    Returns:
        tuple: (list of valid filenames, list of valid DataFrames, list of recording lengths)
    """
    # Get all processed CSV files (name contains "processed_" and ends with ".csv")
    tracking_csv_files = [file for file in os.listdir(folder_path) 
                        if file.endswith(".csv") and "processed_" in file]
    
    if not tracking_csv_files:
        print(f"No tracking CSV files found in {folder_path}")
        return [], [], []
    
    print(f"Found {len(tracking_csv_files)} tracking files")
    
    # Initialize result lists
    valid_files = []
    valid_dfs = []
    recording_lengths = []
    
    # Process each file
    for csv_filename in tracking_csv_files:
        csv_path = os.path.join(folder_path, csv_filename)
        
        try:
            df_tracking = pd.read_csv(csv_path)
            
            # Filter based on proportion of time in bottom segment
            length_bottom = len(df_tracking.loc[df_tracking['closest_segment'] == 'bottom'])
            length_of_recording = len(df_tracking)
            bottom_ratio = length_bottom / length_of_recording if length_of_recording > 0 else 1.0
            
            # Convert frame count to minutes (assuming 10fps)
            minutes = length_of_recording / (60 * 10)
            
            # Check if file meets criteria
            if bottom_ratio > max_bottom_ratio:
                print(f"Skipping {csv_filename}: Too much time in bottom segment ({bottom_ratio:.1%})")
                continue
                
            if minutes < min_duration or minutes > max_duration:
                print(f"Skipping {csv_filename}: Duration ({minutes:.1f} min) outside range")
                continue
            
            # Extract the video name from the file name (separate variable, so the
            # `pattern` argument is not overwritten)
            name_regex = r'processed_(.+)\.mp4'
            match = re.search(name_regex, csv_filename)
            
            if match:
                file_name = match.group(1)  # Extract the content between "processed_" and ".mp4"
                
                # Preprocess the data
                df_processed = preprocess_trajectory_data(df_tracking)
                
                # Add to valid lists
                valid_files.append(file_name)
                valid_dfs.append(df_processed)
                recording_lengths.append(length_of_recording)
                
                print(f"Processed {file_name}: {minutes:.1f} min, {bottom_ratio:.1%} bottom")
            else:
                print(f"Skipping {csv_filename}: Filename does not match expected pattern")
                
        except Exception as e:
            print(f"Error processing {csv_filename}: {e}")
    
    print(f"Successfully loaded {len(valid_files)} valid tracking files")
    return valid_files, valid_dfs, recording_lengths
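
The filtering criteria reduce to a little arithmetic on frame counts. The following sketch assumes a 10 fps recording and reuses the default limits; the helper names and the example file name are hypothetical.

```python
import re


def frames_to_minutes(n_frames, fps=10):
    """Convert a frame count to minutes at the given frame rate."""
    return n_frames / (60 * fps)


def passes_filters(n_frames, n_bottom_frames, fps=10,
                   min_duration=6.0, max_duration=12.0, max_bottom_ratio=0.7):
    """Apply the duration and bottom-segment criteria to one recording."""
    minutes = frames_to_minutes(n_frames, fps)
    # An empty recording is treated as fully "bottom" and therefore rejected
    bottom_ratio = n_bottom_frames / n_frames if n_frames > 0 else 1.0
    return bottom_ratio <= max_bottom_ratio and min_duration <= minutes <= max_duration


def extract_video_name(csv_filename):
    """Return the text between 'processed_' and '.mp4', or None if absent."""
    match = re.search(r"processed_(.+)\.mp4", csv_filename)
    return match.group(1) if match else None


print(frames_to_minutes(6000))                         # 10 minutes at 10 fps
print(passes_filters(6000, 1000))                      # within both limits
print(passes_filters(3000, 0))                         # only 5 minutes
print(extract_video_name("processed_bee_01.mp4.csv"))  # hypothetical file name
```

If the videos were recorded at a different frame rate, the `fps` argument would need to change accordingly, since the duration limits are expressed in minutes.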

In [None]:
def make_correspondance_df(df):
    """
    Convert left/right metrics to odor/control metrics based on which 
    side contained the odor treatment.
    
    Args:
        df (pd.DataFrame): DataFrame with metrics and 'Cote_Odor' column
        
    Returns:
        pd.DataFrame: DataFrame with metrics reorganized by treatment
    """
    # Create a copy to avoid modifying the original
    df_result = df.copy()
    
    
    # Mapping dictionaries for each orientation
    left_odor_mapping = {
        'time_left': 'time_odor',
        'prop_left': 'prop_odor',
        'sum_proportion_left': 'sum_proportion_odor',
        'avg_proportion_left': 'avg_proportion_odor',
        'time_right': 'time_control',
        'prop_right': 'prop_control',
        'sum_proportion_right': 'sum_proportion_control',
        'avg_proportion_right': 'avg_proportion_control',
        'time_above_threshold_left': 'time_odor_above_threshold',
        'time_above_threshold_right': 'time_control_above_threshold',
        'cumulative_distance_left': 'cumulative_distance_odor',
        'cumulative_distance_right': 'cumulative_distance_control',
        'prop_cumulative_distance_left': 'prop_cumulative_distance_odor',
        'velocity_left': 'average_velocity_odor',
        'velocity_right': 'average_velocity_control',
        'bout_left': 'bout_odor',
        'bout_right': 'bout_control',
        'prop_bout_left': 'prop_bout_odor',
        'prop_bout_right': 'prop_bout_control',
        'prop_proportion_left': 'prop_proportion_odor',
        'prop_proportion_right': 'prop_proportion_control',
        'prop_cumulative_distance_right': 'prop_cumulative_distance_control',
    }
    
    right_odor_mapping = {
        'time_right': 'time_odor',
        'prop_right': 'prop_odor',
        'sum_proportion_right': 'sum_proportion_odor',
        'avg_proportion_right': 'avg_proportion_odor',
        'time_left': 'time_control',
        'prop_left': 'prop_control',
        'sum_proportion_left': 'sum_proportion_control',
        'avg_proportion_left': 'avg_proportion_control',
        'time_above_threshold_right': 'time_odor_above_threshold',
        'time_above_threshold_left': 'time_control_above_threshold',
        'cumulative_distance_right': 'cumulative_distance_odor',
        'cumulative_distance_left': 'cumulative_distance_control',
        'prop_cumulative_distance_right': 'prop_cumulative_distance_odor',
        'velocity_right': 'average_velocity_odor',
        'velocity_left': 'average_velocity_control',
        'bout_right': 'bout_odor',
        'bout_left': 'bout_control',
        'prop_bout_right': 'prop_bout_odor',
        'prop_bout_left': 'prop_bout_control',
        'prop_proportion_right': 'prop_proportion_odor',
        'prop_proportion_left': 'prop_proportion_control',
        'prop_cumulative_distance_left': 'prop_cumulative_distance_control',
    }
    
    # Process each row
    for index, row in df.iterrows():
        try:
            if row.get('Cote_Odor') == 'Left':
                # Map metrics based on left odor
                for source, target in left_odor_mapping.items():
                    if source in row and pd.notna(row[source]):
                        df_result.at[index, target] = row[source]
                
                
                # First and last choice
                if row['first_choice_lr'] == "Left":  # Left
                    df_result.at[index, 'first_choice'] = "Odor"  # Chose odor
                else:
                    df_result.at[index, 'first_choice'] = "Control"  # Did not choose odor
                    
                if row['last_choice_lr'] == "Left":  # Left
                    df_result.at[index, 'last_choice'] = "Odor"  # Ended on odor
                else:
                    df_result.at[index, 'last_choice'] = "Control"  # Did not end on odor
                    
            elif row.get('Cote_Odor') == 'Right':
                # Map metrics based on right odor
                for source, target in right_odor_mapping.items():
                    if source in row and pd.notna(row[source]):
                        df_result.at[index, target] = row[source]
                
           
           
                
                # First and last choice
                if row['first_choice_lr'] == "Right":  # Right
                    df_result.at[index, 'first_choice'] = "Odor"  # Chose odor
                else:
                    df_result.at[index, 'first_choice'] = "Control" # Did not choose odor
                    
                if row['last_choice_lr'] == "Right":  # Right
                    df_result.at[index, 'last_choice'] = "Odor" # Ended on odor
                else:
                    df_result.at[index, 'last_choice'] = "Control" # Did not end on odor
                    
            else:
                print(f"Warning: Unknown Cote_Odor value '{row.get('Cote_Odor')}' for index {index}")
                
        except Exception as e:
            print(f"Error processing row {index}: {e}")
    
    return df_result
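
The left/right to odor/control conversion can be illustrated on two toy rows. This is a simplified sketch that assumes `Cote_Odor` is always `"Left"` or `"Right"` and that metrics follow the `<metric>_left` / `<metric>_right` naming; the helper `to_odor_control` and the data are hypothetical, and it covers only a subset of the columns handled by the mapping dictionaries.

```python
import pandas as pd


def to_odor_control(row, metrics=("time", "prop")):
    """Relabel side-based metrics of one row as odor/control metrics."""
    odor_side = row["Cote_Odor"].lower()
    control_side = "right" if odor_side == "left" else "left"

    out = {}
    for metric in metrics:
        out[f"{metric}_odor"] = row[f"{metric}_{odor_side}"]
        out[f"{metric}_control"] = row[f"{metric}_{control_side}"]

    # The first choice counts as "Odor" only if it matches the odor side
    out["first_choice"] = "Odor" if row["first_choice_lr"] == row["Cote_Odor"] else "Control"
    return pd.Series(out)


# Two hypothetical recordings with the odor on opposite sides
toy = pd.DataFrame({
    "Cote_Odor": ["Left", "Right"],
    "time_left": [120, 50],
    "time_right": [30, 150],
    "prop_left": [0.8, 0.25],
    "prop_right": [0.2, 0.75],
    "first_choice_lr": ["Right", "Right"],
})

remapped = toy.apply(to_odor_control, axis=1)
print(remapped)
```

As in the function above, a row without any recorded choice falls into the `"Control"` branch, which may be worth handling separately in a stricter analysis.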

In [None]:
def generate_summary_statistics(df, group_columns=None, metric_columns=None):
    """
    Generate summary statistics for behavioral metrics.
    
    Args:
        df (pd.DataFrame): DataFrame with behavioral metrics
        group_columns (list): Columns to group by (e.g., ['Dose', 'Treatment', 'Genotype', 'Condition'])
        metric_columns (list): Metrics to summarize (if None, use common behavior metrics)
        
    Returns:
        pd.DataFrame: Summary statistics DataFrame
    """
    # Default grouping if none provided
    if group_columns is None:
        # Try to find common grouping columns
        potential_groups = ['Dose', 'Treatment', 'Hive', 'Cov','Genotype','Condition']
        group_columns = [col for col in potential_groups if col in df.columns]
        
        if not group_columns:
            print("No grouping columns found. Using all data as one group.")
            group_columns = []
    
    # Default metrics if none provided
    if metric_columns is None:
        # Common behavior metrics to analyze
        metric_columns = [
            'prop_odor', 'prop_bout_odor', 'prop_proportion_odor',
            'prop_left', 'prop_bout_left', 'prop_proportion_left',
            'average_velocity_in_branch', 'overall_average_velocity'
        ]
    
    # Available metrics
    available_metrics = [col for col in metric_columns if col in df.columns]
    
    if not available_metrics:
        print("Warning: None of the specified metrics are in the DataFrame.")
        return None
    
    # Generate summary statistics
    summary_stats = []
    
    # Handle the case with no grouping
    if not group_columns:
        stat_dict = {'Group': 'All Data', 'Count': len(df)}
        
        # Calculate statistics for each metric
        for metric in available_metrics:
            data = df[metric].dropna()
            if len(data) > 0:
                stat_dict[f'{metric}_mean'] = data.mean()
                stat_dict[f'{metric}_median'] = data.median()
                stat_dict[f'{metric}_std'] = data.std()
                stat_dict[f'{metric}_sem'] = data.sem()
                stat_dict[f'{metric}_n'] = len(data)
                
        summary_stats.append(stat_dict)
    else:
        # Group data and calculate statistics
        grouped = df.groupby(group_columns)
        
        for name, group in grouped:
            # Handle different group types
            if isinstance(name, tuple):
                group_name = ' - '.join([str(x) for x in name])
                stat_dict = {group_col: val for group_col, val in zip(group_columns, name)}
            else:
                group_name = str(name)
                stat_dict = {group_columns[0]: name}
                
            stat_dict['Group'] = group_name
            stat_dict['Count'] = len(group)
            
            # Calculate statistics for each metric
            for metric in available_metrics:
                data = group[metric].dropna()
                if len(data) > 0:
                    stat_dict[f'{metric}_mean'] = data.mean()
                    stat_dict[f'{metric}_median'] = data.median()
                    stat_dict[f'{metric}_std'] = data.std()
                    stat_dict[f'{metric}_sem'] = data.sem()
                    stat_dict[f'{metric}_n'] = len(data)
                    
            summary_stats.append(stat_dict)
    
    # Convert to DataFrame
    summary_df = pd.DataFrame(summary_stats)
    
    return summary_df
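
As a quick cross-check of `generate_summary_statistics`, the same five statistics (mean, median, standard deviation, standard error, and n) can be reproduced with a plain pandas aggregation. The sketch below uses a small made-up table; the `Treatment` labels and `prop_odor` values are illustrative, not experimental data.

```python
import pandas as pd

# Toy data: values are invented for illustration only
toy = pd.DataFrame({
    "Treatment": ["control", "control", "control", "odor", "odor", "odor"],
    "prop_odor": [0.40, 0.50, 0.60, 0.70, 0.80, None],  # one missing value
})

# groupby().agg() skips NaN, matching the dropna() call in the function above
check = toy.groupby("Treatment")["prop_odor"].agg(
    ["mean", "median", "std", "sem", "count"]
)
print(check)
```

If the numbers from this aggregation and from the function disagree for a given group, missing values or an unexpected grouping column are the first things to inspect.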

In [None]:

def visualize_behavioral_metrics(df, metrics=None, group_col=None, title=None, 
                                output_file=None, figsize=(12, 8)):
    """
    Create a set of visualizations for behavioral metrics from olfactometer experiments.
    
    Args:
        df (pd.DataFrame): DataFrame with behavioral metrics
        metrics (list): List of metrics to visualize (if None, use key metrics)
        group_col (str): Column to use for grouping (e.g., 'Treatment', 'Dose')
        title (str): Title prefix for the plots
        output_file (str): Path to save the figure (if None, display instead)
        figsize (tuple): Figure size (width, height) in inches
        
    Returns:
        matplotlib.figure.Figure: The generated figure
    """
    # Default metrics if none provided
    if metrics is None:
        # Try to find key metrics
        potential_metrics = [
        'prop_odor', 'prop_bout_odor', 'prop_proportion_odor',
        'prop_left','prop_bout_left','prop_proportion_left',
        'average_velocity_in_branch','overall_average_velocity'
        ]
        metrics = [m for m in potential_metrics if m in df.columns]
        
        if not metrics:
            print("No suitable metrics found for visualization.")
            return None
    
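    # Filter to available metrics
    metrics = [m for m in metrics if m in df.columns]
    
    # Stop early if nothing is left to plot (avoids an empty subplot grid)
    if not metrics:
        print("No suitable metrics found for visualization.")
        return None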
    
    # Detect group column if not specified
    if group_col is None:
        potential_groups = ['Dose', 'Treatment', 'Hive', 'Cov','Genotype','Condition']
        for col in potential_groups:
            if col in df.columns and df[col].nunique() > 1:
                group_col = col
                break
    
    # Set up figure
    n_metrics = len(metrics)
    n_cols = min(3, n_metrics)
    n_rows = (n_metrics + n_cols - 1) // n_cols
    
    fig, axes = plt.subplots(n_rows, n_cols, figsize=figsize)
    if n_metrics == 1:
        axes = np.array([axes])
    axes = axes.flatten()
    
    # Create a plot for each metric
    for i, metric in enumerate(metrics):
        ax = axes[i]
        
        if group_col:
            # Group data
            sns.boxplot(x=group_col, y=metric, data=df, ax=ax)
            
            # Add individual points
            sns.stripplot(x=group_col, y=metric, data=df, 
                        size=4, color='black', alpha=0.5, ax=ax)
            
            # Rotate x-labels if many groups
            if df[group_col].nunique() > 4:
                plt.setp(ax.get_xticklabels(), rotation=45, ha='right')
        else:
            # Single group, just show distribution
            sns.boxplot(y=metric, data=df, ax=ax)
            sns.stripplot(y=metric, data=df, size=4, color='black', alpha=0.5, ax=ax)

        # Add reference line at 0.5 for proportion metrics
        if metric.startswith('prop_') :
            ax.axhline(y=0.5, color='red', linestyle='--', alpha=0.5)
        
        # Add title and labels
        metric_name = metric.replace('_', ' ').title()
        ax.set_title(metric_name)
        ax.set_ylabel(metric_name)
        ax.grid(True, alpha=0.3)
        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)
        ax.tick_params(axis='both', which='major', labelsize=12)
    
    # Hide unused subplots
    for j in range(i+1, len(axes)):
        axes[j].set_visible(False)
    
    # Add overall title
    if title:
        fig.suptitle(title, fontsize=16, y=1.02)
    else:
        fig.suptitle('Behavioral Metrics Summary', fontsize=16, y=1.02)
    
    # Adjust layout
    plt.tight_layout()
    
    # Save or display
    if output_file:
        plt.savefig(output_file, bbox_inches='tight', dpi=300)
        print(f"Figure saved to {output_file}")
    
    return fig


def create_pairplot(df, vars=None, hue=None, figsize=(18, 18), title=None):
    """
    Create a pairplot (scatter plot matrix) of multiple behavioral metrics.
    
    Args:
        df (pd.DataFrame): DataFrame with behavioral metrics
        vars (list): List of variables to include in the pairplot
        hue (str): Column to use for coloring points (e.g., 'Treatment')
        figsize (tuple): Figure size
        title (str): Title for the plot
        
    Returns:
        seaborn.PairGrid: The generated pairplot
    """
    # Default variables if none specified
    if vars is None:
        potential_vars = [
        'prop_odor', 'prop_bout_odor', 'prop_proportion_odor',
        'prop_left','prop_bout_left','prop_proportion_left',
        'average_velocity_in_branch','overall_average_velocity'
        ]
        vars = [v for v in potential_vars if v in df.columns]
        
    # Filter to available variables
    vars = [v for v in vars if v in df.columns]
    
    if len(vars) < 2:
        print("Not enough variables for pairplot")
        return None
    
    # Create pairplot; diag_kws uses 'fill' because 'shade' is deprecated in seaborn
    g = sns.pairplot(df, vars=vars, hue=hue, 
                   height=figsize[0] / len(vars),  # true division keeps height > 0
                   diag_kind='kde', 
                   plot_kws={'alpha': 0.6, 's': 80, 'edgecolor': 'w'},
                   diag_kws={'fill': True})
    
    # Add title if provided
    if title:
        g.fig.suptitle(title, fontsize=16, y=1.02)
    
    # Format axis labels
    for ax in g.axes.flatten():
        ax.set_xlabel(ax.get_xlabel().replace('_', ' ').title())
        ax.set_ylabel(ax.get_ylabel().replace('_', ' ').title())
    
    # Adjust layout
    g.fig.tight_layout()
    
    return g


def create_heatmap_by_treatment(df, metrics=None, treatment_col=None, figsize=(15, 10)):
    """
    Create a heatmap comparing metrics across different treatment groups.
    
    Args:
        df (pd.DataFrame): DataFrame with behavioral metrics
        metrics (list): List of metrics to include
        treatment_col (str): Column containing treatment groups
        figsize (tuple): Figure size
        
    Returns:
        matplotlib.figure.Figure: The generated figure
    """
    # Default metrics if none provided
    if metrics is None:
        potential_metrics = [
        'prop_odor', 'prop_bout_odor', 'prop_proportion_odor',
        'prop_left','prop_bout_left','prop_proportion_left',
        'average_velocity_in_branch','overall_average_velocity'
        ]
        metrics = [m for m in potential_metrics if m in df.columns]
    
    # Filter to available metrics
    metrics = [m for m in metrics if m in df.columns]
    
    if len(metrics) < 2:
        print("Not enough metrics for heatmap")
        return None
    
    # Detect treatment column if not specified
    if treatment_col is None:
        for col in ['Dose', 'Treatment', 'Hive', 'Cov','Genotype','Condition']:
            if col in df.columns and df[col].nunique() > 1:
                treatment_col = col
                break
    
    if treatment_col is None or treatment_col not in df.columns:
        print("No suitable treatment column found")
        return None
    
    # Aggregate data by treatment
    grouped_data = df.groupby(treatment_col)[metrics].mean()
    
    # Convert data for better visualization
    for col in grouped_data.columns:
        # Center proportion metrics around 0.5
        if col.startswith('prop_') and 'index' not in col:
            grouped_data[col] = grouped_data[col] - 0.5
    
    # Create figure
    fig, ax = plt.subplots(figsize=figsize)
    
    # Create heatmap
    heatmap = sns.heatmap(grouped_data, annot=True, cmap='coolwarm', center=0,
                        linewidths=.5, fmt='.2f', cbar=True, ax=ax)
    
    # Format axis labels
    ax.set_title(f'Behavioral Metrics by {treatment_col}', fontsize=16)
    
    # Format y-tick labels if they're long
    if max(len(str(x)) for x in grouped_data.index) > 10:
        plt.yticks(rotation=0)
    
    # Rename metrics for display
    ax.set_xticklabels([x.replace('_', ' ').title() for x in grouped_data.columns], 
                     rotation=45, ha='right')
    
    plt.tight_layout()
    
    return fig


def create_radar_chart(df, metrics=None, group_col=None, figsize=(12, 10)):
    """
    Create a radar chart comparing metrics across different treatment groups.
    
    Args:
        df (pd.DataFrame): DataFrame with behavioral metrics
        metrics (list): List of metrics to include
        group_col (str): Column for grouping
        figsize (tuple): Figure size
        
    Returns:
        matplotlib.figure.Figure: The generated figure
    """
    # Default metrics if none provided
    if metrics is None:
        potential_metrics = [
        'prop_odor', 'prop_bout_odor', 'prop_proportion_odor',
        'prop_left','prop_bout_left','prop_proportion_left',
        'average_velocity_in_branch','overall_average_velocity'
        ]
        metrics = [m for m in potential_metrics if m in df.columns]
    
    # Filter to available metrics
    metrics = [m for m in metrics if m in df.columns]
    
    if len(metrics) < 3:
        print("At least 3 metrics are needed for a radar chart")
        return None
    
    # Detect group column if not specified
    if group_col is None:
        for col in ['Dose', 'Treatment', 'Hive', 'Cov','Genotype','Condition']:
            if col in df.columns and df[col].nunique() > 1 and df[col].nunique() <= 6:
                group_col = col
                break
    
    if group_col is None or group_col not in df.columns:
        print("No suitable grouping column found")
        return None
    
    # Group data
    grouped_data = df.groupby(group_col)[metrics].mean()
    
    # Normalize data between 0 and 1 for radar chart
    normalized_data = pd.DataFrame(index=grouped_data.index)
    
    for col in metrics:
        # For proportion metrics already between 0-1, use as is
        if col.startswith('prop_') and not col.endswith('index'):
            normalized_data[col] = grouped_data[col]
        # For other metrics, normalize min-max
        else:
            min_val = grouped_data[col].min()
            max_val = grouped_data[col].max()
            if max_val > min_val:
                normalized_data[col] = (grouped_data[col] - min_val) / (max_val - min_val)
            else:
                normalized_data[col] = 0.5
    
    # Set up figure
    fig = plt.figure(figsize=figsize)
    ax = fig.add_subplot(111, polar=True)
    
    # Number of metrics
    N = len(metrics)
    
    # Compute angle for each metric
    angles = [n / N * 2 * np.pi for n in range(N)]
    angles += angles[:1]  # Close the loop
    
    # Set up radar chart
    ax.set_theta_offset(np.pi / 2)
    ax.set_theta_direction(-1)
    
    # Set up axis labels
    plt.xticks(angles[:-1], [m.replace('_', ' ').title() for m in metrics], 
              fontsize=12)
    
    # Draw axis lines
    ax.set_rlabel_position(0)
    plt.yticks([0.25, 0.5, 0.75], ["0.25", "0.5", "0.75"], 
              color="grey", size=10)
    plt.ylim(0, 1)
    
    # Plot each group
    for i, group in enumerate(normalized_data.index):
        values = normalized_data.loc[group].values.flatten().tolist()
        values += values[:1]  # Close the loop
        
        # Plot data
        ax.plot(angles, values, linewidth=2, linestyle='solid', 
               label=str(group))
        ax.fill(angles, values, alpha=0.1)
    
    # Add legend
    plt.legend(loc='upper right', bbox_to_anchor=(0.1, 0.1))
    
    # Add title
    plt.title(f'Behavioral Metrics by {group_col}', size=16, y=1.1)
    
    return fig


def create_correlation_matrix(df, metrics=None, method='pearson', cmap='coolwarm', 
                             figsize=(10, 8), title=None):
    """
    Create a correlation matrix visualization for behavioral metrics.
    
    Args:
        df (pd.DataFrame): DataFrame with behavioral metrics
        metrics (list): List of metrics to include in correlation (if None, use numeric columns)
        method (str): Correlation method ('pearson', 'spearman', or 'kendall')
        cmap (str): Colormap for the heatmap
        figsize (tuple): Figure size
        title (str): Plot title
        
    Returns:
        matplotlib.figure.Figure: The generated figure
    """
    # Select metrics if not specified
    if metrics is None:
        # Use numeric columns excluding IDs and metadata
        exclude_cols = ['ID', 'Frame', 'X_raw', 'Y_raw', 'X_projected', 'Y_projected']
        metrics = [col for col in df.select_dtypes(include=[np.number]).columns 
                  if not any(ex in col for ex in exclude_cols)]
    else:
        # Filter to available metrics
        metrics = [m for m in metrics if m in df.columns]
    
    # Check if we have enough metrics
    if len(metrics) < 2:
        print("Not enough metrics for correlation analysis.")
        return None
    
    # Calculate correlation matrix
    corr_matrix = df[metrics].corr(method=method)
    
    # Create figure
    fig, ax = plt.subplots(figsize=figsize)
    
    # Create heatmap
    mask = np.triu(np.ones_like(corr_matrix, dtype=bool))  # Mask upper triangle
    sns.heatmap(corr_matrix, mask=mask, cmap=cmap, vmin=-1, vmax=1, 
               center=0, square=True, linewidths=.5, annot=True, fmt='.2f',
               cbar_kws={"shrink": .7}, ax=ax)
    
    # Set title
    if title:
        ax.set_title(title, fontsize=14, pad=20)
    else:
        ax.set_title(f'{method.capitalize()} Correlation Matrix', fontsize=14, pad=20)
    
    plt.tight_layout()
    
    return fig


def create_comparison_dashboard(df, group_col=None, output_folder=None):
    """
    Create a comprehensive dashboard of visualizations for comparing groups.
    
    Args:
        df (pd.DataFrame): DataFrame with behavioral metrics
        group_col (str): Column to use for grouping
        output_folder (str): Folder to save visualizations
        
    Returns:
        dict: Dictionary of generated figures
    """
    # Detect group column if not specified
    if group_col is None:
        potential_groups = ['Dose', 'Treatment', 'Hive', 'Cov','Genotype','Condition']
        for col in potential_groups:
            if col in df.columns and df[col].nunique() > 1:
                group_col = col
                break
    
    if group_col is None or group_col not in df.columns:
        print("No suitable grouping column found. Cannot create comparison dashboard.")
        return None
    
    # Create output folder if specified and doesn't exist
    if output_folder:
        os.makedirs(output_folder, exist_ok=True)
    
    # Dictionary to store figures
    figures = {}
    
    # 1. Key behavioral metrics visualization
    key_metrics = [
        'prop_odor', 'prop_bout_odor', 'prop_proportion_odor',
        'prop_left','prop_bout_left','prop_proportion_left',
        'average_velocity_in_branch','overall_average_velocity'
    ]
    
    # Filter to available metrics
    key_metrics = [m for m in key_metrics if m in df.columns]
    
    if key_metrics:
        fig1 = visualize_behavioral_metrics(
            df, metrics=key_metrics, group_col=group_col,
            title=f"Key Behavioral Metrics by {group_col}",
            figsize = (24,24)
        )
        figures['key_metrics'] = fig1

    # 2. Heatmap by treatment
    heatmap_metrics = [
        'prop_odor', 'prop_bout_odor', 'prop_proportion_odor',
        'prop_left','prop_bout_left','prop_proportion_left',
        'average_velocity_in_branch','overall_average_velocity'
    ]
    
    # Filter to available metrics
    heatmap_metrics = [m for m in heatmap_metrics if m in df.columns]
    
    if heatmap_metrics:
        fig2 = create_heatmap_by_treatment(
            df, metrics=heatmap_metrics, treatment_col=group_col,
            figsize=(14, 8)
        )
        figures['heatmap'] = fig2
        
    
    # 3. Radar chart
    radar_metrics = [
        'prop_odor', 'prop_bout_odor', 'prop_proportion_odor',
        'prop_left','prop_bout_left','prop_proportion_left',
        'average_velocity_in_branch','overall_average_velocity'
    ]
    
    # Filter to available metrics
    radar_metrics = [m for m in radar_metrics if m in df.columns]
    
    if len(radar_metrics) >= 3:
        fig3 = create_radar_chart(
            df, metrics=radar_metrics, group_col=group_col,
            figsize=(10, 10)
        )
        figures['radar'] = fig3
    
    # 4. Pairplot of key metrics
    pairplot_metrics = [
        'prop_odor', 'prop_bout_odor', 'prop_proportion_odor',
        'prop_left','prop_bout_left','prop_proportion_left',
        'average_velocity_in_branch','overall_average_velocity'
    ]
    
    # Filter to available metrics
    pairplot_metrics = [m for m in pairplot_metrics if m in df.columns]
    
    if len(pairplot_metrics) >= 2:
        # Use the function's own arguments (df, group_col) rather than notebook globals
        g = create_pairplot(
            df, vars=pairplot_metrics, hue=group_col,
            figsize=(24, 24),  # Much larger pairplot
            title=f"Relationships Between Key Metrics by {group_col}"
        )
        figures['pairplot'] = g
    
    # 5. Correlation matrix
    correlation_metrics = [
        'prop_odor', 'prop_bout_odor', 'prop_proportion_odor',
        'prop_left','prop_bout_left','prop_proportion_left',
        'average_velocity_in_branch','overall_average_velocity'
    ]
    
    # Filter to available metrics
    correlation_metrics = [m for m in correlation_metrics if m in df.columns]
    
    if len(correlation_metrics) >= 3:
        fig5 = create_correlation_matrix(
            df, metrics=correlation_metrics,
            title="Correlation Between Behavioral Metrics",
            figsize=(12, 10)
        )
        figures['correlation'] = fig5
    
    return figures
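
The scaling rule in `create_radar_chart` is worth seeing in isolation: proportion columns are plotted as-is, while every other metric is min-max scaled across groups so that all axes share the 0 to 1 range. Below is a minimal sketch of that rule; the helper name `normalize_for_radar` and the numbers are hypothetical.

```python
import pandas as pd

def normalize_for_radar(group_means: pd.DataFrame) -> pd.DataFrame:
    """Scale per-group means to [0, 1]; hypothetical helper mirroring the rule above."""
    out = pd.DataFrame(index=group_means.index)
    for col in group_means.columns:
        if col.startswith("prop_") and not col.endswith("index"):
            out[col] = group_means[col]  # already a proportion
        else:
            lo, hi = group_means[col].min(), group_means[col].max()
            # Constant columns carry no contrast, so place them mid-scale
            out[col] = (group_means[col] - lo) / (hi - lo) if hi > lo else 0.5
    return out

# Invented per-group means for two groups
means = pd.DataFrame(
    {"prop_odor": [0.4, 0.7], "overall_average_velocity": [2.0, 6.0]},
    index=["control", "odor"],
)
scaled = normalize_for_radar(means)
print(scaled)
```

Because the scaling is relative to the groups being plotted, the velocity axes show contrast between groups rather than absolute speed: the slowest group always lands at 0 and the fastest at 1.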

### Analysis Configuration and Run

This cell analyzes the tracking data to extract behavioral metrics and generate summary statistics. It calculates various metrics related to bumblebee preferences, movement patterns, and decision-making in the Y-tube.

#### Key Parameters:
- `ANALYSIS_FOLDER`: Directory containing the processed CSV files
- `METADATA_FILE`: Filename of the experimental metadata CSV
- `OUTPUT_FOLDER`: Where to save results (None = auto-create subfolder)
- `OUTPUT_FILE`: Filename for the results CSV
- `MIN_DURATION`, `MAX_DURATION`: Filter videos by duration (minutes)
- `MAX_BOTTOM_RATIO`: Maximum allowable proportion of time spent in the bottom segment
- `FPS`: Frames per second (for time calculations)
- `GROUP_COLUMN`: Main column for grouping in summary statistics
- `GENERATE_PLOTS`: Whether to create summary plots
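
For orientation, a configuration block using these parameter names might look like the sketch below. Every value is a placeholder assumption to adapt to your own data, not a recommended setting; `SAVE_RESULTS` is not in the list above but is read by the run cell when saving plots.

```python
# Placeholder configuration: every value below is an assumption to adapt
ANALYSIS_FOLDER = "path/to/processed_csvs"          # hypothetical path
METADATA_FILE = "metadata.csv"                      # hypothetical filename
OUTPUT_FOLDER = "path/to/processed_csvs/analysis"   # hypothetical path
OUTPUT_FILE = "results.csv"
MIN_DURATION = 1         # minutes (placeholder)
MAX_DURATION = 10        # minutes (placeholder)
MAX_BOTTOM_RATIO = 0.9   # proportion of time allowed in the bottom segment (placeholder)
FPS = 30                 # frames per second of the videos (placeholder)
GROUP_COLUMN = "Treatment"
GENERATE_PLOTS = True
SAVE_RESULTS = True      # read by the run cell when saving plots
```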

#### Instructions:
1. Make sure you've completed the tracking step for all videos
2. Ensure your metadata file is in the correct format and location
3. Configure the analysis parameters
4. Run this cell to perform the analysis
5. Review the summary statistics in the output
6. Check the output folder for the results files

#### Analysis Process:
1. Load and filter tracking files based on criteria
2. Calculate metrics for each trajectory
3. Merge with experimental metadata (e.g., odor treatments)
4. Convert left/right metrics to odor/control metrics
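5. Save the per-video results to a CSV file
6. Generate summary statistics by experimental condition and save them to a CSV file
7. Optionally generate comparison plots (when `GENERATE_PLOTS` is enabled and more than one valid file was loaded)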
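
The metadata merge is the stage where rows are most easily lost without notice, because the run cell uses an inner join that drops every ID missing from either table. One way to inspect the match before committing to it is an outer merge with the pandas `indicator=True` option, sketched below. The IDs, `Treatment` labels, and `prop_left` values are invented for illustration.

```python
import pandas as pd

# Invented example tables
df_metadata = pd.DataFrame({"ID": [1, 2, 3], "Treatment": ["A", "B", "A"]})
df_results = pd.DataFrame({"ID": ["2", "3", "4"], "prop_left": [0.3, 0.6, 0.5]})

# Cast both keys to str so 2 and "2" are treated as the same ID
df_metadata["ID"] = df_metadata["ID"].astype(str)
df_results["ID"] = df_results["ID"].astype(str)

# indicator=True adds a "_merge" column: "both", "left_only", or "right_only"
audit = pd.merge(df_metadata, df_results, on="ID", how="outer", indicator=True)
print(audit["_merge"].value_counts())

# Keep only matched rows, which is equivalent to how="inner"
merged_df = audit[audit["_merge"] == "both"].drop(columns="_merge")
```

Rows tagged `left_only` are metadata entries with no tracking result, and `right_only` rows are tracked videos with no metadata entry.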

#### Output Files:
- `results.csv`: Complete results for all videos
- `summary_statistics.csv`: Statistical summary by experimental groups
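
Both files are written with a semicolon separator (`sep=';'` in the run cell), so they need the same separator when read back; with the default comma, pandas would parse each row as a single column. A small round-trip sketch using an in-memory buffer and invented values:

```python
import io
import pandas as pd

# Invented rows standing in for results.csv
df = pd.DataFrame({"ID": ["a1", "b2"], "prop_odor": [0.62, 0.48]})

buffer = io.StringIO()
df.to_csv(buffer, index=False, sep=";")   # same call style as the run cell
buffer.seek(0)

reloaded = pd.read_csv(buffer, sep=";")   # the separator must match
print(reloaded.columns.tolist())
```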

In [None]:
# Step 1: Load and filter tracking files
print("\n== STEP 1: Loading and filtering tracking files ==")
valid_files, valid_dfs, recording_lengths = load_and_filter_csv_files(
    ANALYSIS_FOLDER,
    min_duration=MIN_DURATION,       # Minimum recording duration in minutes
    max_duration=MAX_DURATION,       # Maximum recording duration in minutes
    max_bottom_ratio=MAX_BOTTOM_RATIO    # Maximum proportion of time in bottom segment
)

print(f"Successfully loaded {len(valid_files)} valid tracking files")
if not valid_files:
    print("No valid files found. Analysis cannot continue.")
else:
    # Step 2: Analyze each trajectory
    print("\n== STEP 2: Analyzing trajectories ==")
    df_results_list = []
    
    for file_name, df, rec_length in zip(valid_files, valid_dfs, recording_lengths):
        try:
            # Analyze the trajectory
            result_df = analyze_trajectory(df, file_name)
            # Attach the recording length per file so rows stay aligned if another file fails
            result_df["length_of_recording"] = rec_length
            df_results_list.append(result_df)
            print(f"Analyzed {file_name}")
        except Exception as e:
            print(f"Error analyzing {file_name}: {e}")
    
    # Combine all results
    if df_results_list:
        df_results = pd.concat(df_results_list, ignore_index=True)
        print(f"Generated metrics for {len(df_results)} trajectories")
        
        # Step 3: Load and merge with metadata
        print("\n== STEP 3: Merging with metadata ==")
        metadata_path = os.path.join(ANALYSIS_FOLDER, METADATA_FILE)
        
        if os.path.exists(metadata_path):
            try:
                # Detect delimiter
                with open(metadata_path, 'r') as f:
                    first_line = f.readline().strip()
                    delimiter = ';' if ';' in first_line else ','
                
                df_metadata = pd.read_csv(metadata_path, delimiter=delimiter)
                
                # Ensure ID columns are string type for proper joining
                df_metadata['ID'] = df_metadata['ID'].astype(str)
                df_results['ID'] = df_results['ID'].astype(str)
                
                # Merge the results with metadata
                merged_df = pd.merge(df_metadata, df_results, on='ID', how='inner')
                
                # Report merge statistics
                print(f"Merged with metadata: {len(merged_df)} matched entries")
                unmatched_metadata = df_metadata[~df_metadata['ID'].isin(df_results['ID'])]
                unmatched_results = df_results[~df_results['ID'].isin(df_metadata['ID'])]
                
                if not unmatched_metadata.empty:
                    print(f"Warning: {len(unmatched_metadata)} entries in metadata not found in results")
                
                if not unmatched_results.empty:
                    print(f"Warning: {len(unmatched_results)} entries in results not found in metadata")
            
            except Exception as e:
                print(f"Error processing metadata: {e}")
                merged_df = df_results
                
        else:
            print(f"Metadata file {METADATA_FILE} not found. Using results without metadata.")
            merged_df = df_results
        
        # Step 4: Convert left/right metrics to odor/control
        print("\n== STEP 4: Converting left/right metrics to odor/control ==")
        
        if 'Cote_Odor' in merged_df.columns:
            final_df = make_correspondance_df(merged_df)
            print("Successfully converted metrics to odor/control format")
        else:
            print("Warning: 'Cote_Odor' column not found. Using original metrics.")
            final_df = merged_df
        
        # Step 5: Save results
        final_csv_path = os.path.join(OUTPUT_FOLDER, OUTPUT_FILE)
        final_df.to_csv(final_csv_path, index=False, sep=';')
        print(f"Final results saved to {final_csv_path}")
        
        # Step 6: Generate and display summary statistics
        print("\n== STEP 6: Generating summary statistics ==")
        
        # Determine grouping columns
        group_cols = []
        if GROUP_COLUMN in final_df.columns and final_df[GROUP_COLUMN].nunique() > 1:
            group_cols = [GROUP_COLUMN]
            print(f"Using '{GROUP_COLUMN}' for grouping")
        else:
            # Try to find alternative grouping columns
            for col in ['Dose', 'Treatment', 'Hive', 'Cov','Genotype','Condition']:
                if col in final_df.columns and final_df[col].nunique() > 1:
                    group_cols = [col]
                    print(f"Using '{col}' for grouping")
                    break
        
        # Generate summary statistics
        summary_stats = generate_summary_statistics(final_df, group_columns=group_cols)
        
        if summary_stats is not None:
            # Display summary
            print("\nSummary Statistics:")
            print("=" * 80)
            # Display only the most important columns
            display_cols = group_cols + ['Count'] if group_cols else ['Group', 'Count']
            key_metrics = ['prop_odor', 'prop_bout_odor', 'prop_proportion_odor',
            'average_velocity_in_branch','overall_average_velocity']
            
            for metric in key_metrics:
                mean_col = f"{metric}_mean"
                sem_col = f"{metric}_sem"
                if mean_col in summary_stats.columns and sem_col in summary_stats.columns:
                    display_cols.extend([mean_col, sem_col])
            
            # Check if display columns exist in the DataFrame
            display_cols = [col for col in display_cols if col in summary_stats.columns]
            
            # Print the summary
            if display_cols:
                print(summary_stats[display_cols].to_string(index=False))
            else:
                print(summary_stats.to_string(index=False))
            
            # Save summary statistics
            summary_path = os.path.join(OUTPUT_FOLDER, 'summary_statistics.csv')
            summary_stats.to_csv(summary_path, index=False, sep=';')
            print(f"\nSummary statistics saved to {summary_path}")
        
        # Step 7: Generate visualizations
        if GENERATE_PLOTS and len(valid_files) > 1:
            print("\n== STEP 7: Generating visualizations ==")
            
            # Create a visualization dashboard with multiple plot types
            if group_cols:
                group_col = group_cols[0]
                dashboard = create_comparison_dashboard(
                    final_df, group_col=group_col, output_folder=None  # Don't save in the function
                )
                
                if dashboard:
                    print(f"Created visualization dashboard with {len(dashboard)} plot types")
                    
                    # Save the plots manually 
                    if SAVE_RESULTS:
                        for plot_name, fig in dashboard.items():
                            try:
                                output_path = os.path.join(OUTPUT_FOLDER, f"{plot_name}.png")
                                # PairGrid and Figure objects both expose savefig()
                                fig.savefig(output_path, bbox_inches='tight', dpi=400, format='png')
                                print(f"Saved high-quality {plot_name} plot to {output_path}")
                            except Exception as e:
                                print(f"Error saving {plot_name} plot: {e}")
                                        
                    # Show the plots
                    for plot_name, fig in dashboard.items():
                        try:
                            if hasattr(fig, 'fig'):  # For PairGrid objects
                                plt.figure(fig.fig.number)
                            elif hasattr(fig, 'number'):  # For Figure objects
                                plt.figure(fig.number)
                            plt.show()
                        except Exception as e:
                            print(f"Error displaying {plot_name} plot: {e}")
    else:
        print(f"No trajectory data was successfully analyzed. Check for errors above. Path: {ANALYSIS_FOLDER}")