# Open field behavior analysis
Determines the centroid (center of mass) of the animal in open-field arena videos, then analyzes centroid displacement and plots the calculated metrics. Performs Gaussian mixture modeling to determine thresholds for lingering vs. progressing movement.

In [None]:
# Process videos

"""
Processes open-field videos by:
1) Loading and validating each video,
2) Thresholding and refining the largest contour to isolate the animal’s body,
3) Computing geometric properties (area, perimeter, aspect ratio, solidity, orientation),
4) Saving per-frame metrics (e.g., centroid coordinates) to a CSV,
5) Creating a side-by-side masked video with a highlighted centroid.

Expects video filenames of the form: 'AA###_AA##' (e.g. 'CK314_PD11-trim.mp4').

Args:
    folder_path (str): Path to the directory containing the video files.
    threshold_value (int or float): Threshold used for segmenting the subject.
        Higher values make the mask more selective; lower values make it more inclusive.

Outputs:
    - For each valid video, a CSV file capturing frame-by-frame metrics (centroid, area, perimeter, etc.).
    - A corresponding MP4 video with the original frames side-by-side with the segmented mask.
"""


import cv2
import numpy as np
import pandas as pd
import re
from pathlib import Path
import datetime
import os
import gc

# Local date (YYYY-MM-DD), evaluated once when this cell runs; used to name output folders.
TIMESTAMP = datetime.date.today().isoformat()

class OpenFieldVideoProcessor:
    """
    Loads a video file, extracts metadata, and provides processing methods.
    Expects filenames in the format: AA###_AA## (e.g. CK314_PD11-trim.mp4).

    Attributes:
        file_path (Path): Location of the video file.
        animal_id (str): Two letters + three digits at the start of the file stem (e.g. "CK314").
        day_id (str): Two letters + two digits that follow an underscore (e.g. "PD11").
        frame_rate (float): Frames per second as reported by OpenCV.
        num_frames (int): Frame count as reported by OpenCV.
        shape (tuple): (width, height) of a frame in pixels.

    Raises:
        ValueError: If the file name does not contain the expected IDs, the video
            cannot be opened, or the reported frame rate is not a positive number.
    """
    # Compiled once for the class instead of on every call.
    _ANIMAL_ID_PATTERN = re.compile(r'^[A-Za-z]{2}\d{3}')
    _DAY_ID_PATTERN = re.compile(r'_[A-Za-z]{2}\d{2}')

    def __init__(self, video_path: str):
        self.file_path = Path(video_path)
        # Parse the IDs first so a badly named file is rejected before the video is opened.
        self.animal_id = self.extract_animal_id()
        self.day_id = self.extract_day_id()

        # Reuse load_video() so the "cannot open" check lives in exactly one place.
        cap = self.load_video()
        try:
            self.frame_rate = cap.get(cv2.CAP_PROP_FPS)
            self.num_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            self.shape = (
                int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)),
                int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
            )
        finally:
            # Release the capture handle even if reading a property fails.
            cap.release()

        # Frame times are computed as frame_index / frame_rate downstream, so a zero
        # (or NaN) rate would only surface later as a ZeroDivisionError. Fail early.
        if not self.frame_rate > 0:
            raise ValueError(
                f"Video file reports an invalid frame rate ({self.frame_rate}): {self.file_path}"
            )

    def load_video(self) -> "cv2.VideoCapture":
        """Open the video and return the capture object; the caller must release it.

        Raises:
            ValueError: If OpenCV cannot open the file.
        """
        cap = cv2.VideoCapture(str(self.file_path))
        if not cap.isOpened():
            raise ValueError(f"Cannot open video file: {self.file_path}")
        return cap

    def extract_animal_id(self) -> str:
        """Return the animal ID found at the very start of the file stem.

        Raises:
            ValueError: If the stem does not start with two letters and three digits.
        """
        match = self._ANIMAL_ID_PATTERN.search(self.file_path.stem)
        if match:
            return match.group(0)
        raise ValueError(f"File name {self.file_path.name} does not match expected animal ID format.")

    def extract_day_id(self) -> str:
        """Return the day ID (first '_' + two letters + two digits in the stem), without the underscore.

        Raises:
            ValueError: If no such token is present in the stem.
        """
        match = self._DAY_ID_PATTERN.search(self.file_path.stem)
        if match:
            return match.group(0)[1:]  # drop the leading underscore
        raise ValueError(f"File name {self.file_path.name} does not match expected day ID format.")

def create_output_directory(input_folder, base_name="analysis"):
    """
    Create (if missing) and return the output directory for one analysis run.

    The directory is a sibling of `input_folder`: it is placed in the resolved
    parent of `input_folder` and named '<base_name>_<TIMESTAMP>'.

    Args:
        input_folder (str or Path): Folder whose parent will hold the output directory.
        base_name (str): Prefix of the output directory name.

    Returns:
        Path: The (now existing) output directory.
    """
    target = Path(input_folder).resolve().parent / f"{base_name}_{TIMESTAMP}"
    target.mkdir(parents=True, exist_ok=True)
    return target

def _contour_shape_metrics(contour):
    """
    Return (area, perimeter, aspect_ratio, solidity, orientation) for one contour.

    aspect_ratio is bounding-box width / height, solidity is contour area / convex
    hull area, and orientation is the angle reported by cv2.fitEllipse. Any value
    that cannot be computed (zero height, zero hull area, fewer than 5 points) is NaN.
    """
    area = cv2.contourArea(contour)
    perimeter = cv2.arcLength(contour, True)
    _, _, w_rect, h_rect = cv2.boundingRect(contour)
    aspect_ratio = float(w_rect) / h_rect if h_rect > 0 else np.nan
    hull_area = cv2.contourArea(cv2.convexHull(contour))
    solidity = area / hull_area if hull_area > 0 else np.nan
    # cv2.fitEllipse needs at least five points.
    orientation = cv2.fitEllipse(contour)[2] if len(contour) >= 5 else np.nan
    return area, perimeter, aspect_ratio, solidity, orientation


def process_video(video_path, threshold_value):
    """
    Segment the animal in every frame of a video and collect per-frame metrics.

    Per frame: grayscale -> Gaussian blur -> inverse binary threshold -> morphological
    close -> keep the largest contour -> trim thin parts (the tail) with a distance
    transform -> keep the largest remaining contour as the final mask. Shape metrics
    are taken from that final contour (or from the pre-trim contour if trimming left
    nothing), and the centroid is computed from the image moments of the final mask.

    Args:
        video_path (str): Path to the video file; the name must satisfy OpenFieldVideoProcessor.
        threshold_value (int or float): Gray level passed to cv2.threshold (THRESH_BINARY_INV).

    Returns:
        tuple: (data, video_obj) where `data` is a dict with keys
        'frames', 'masks', 'centroids', 'mask_area', 'mask_perimeter',
        'mask_aspect_ratio', 'mask_solidity', 'mask_orientation' (one entry per frame;
        a centroid is None and metrics are NaN when nothing was segmented), and
        `video_obj` is the OpenFieldVideoProcessor for the file.

    Raises:
        ValueError: Propagated from OpenFieldVideoProcessor (bad file name or unreadable video).

    Note:
        All frames and masks are kept in memory until the function returns.
    """
    video_obj = OpenFieldVideoProcessor(video_path)
    cap = video_obj.load_video()
    frames, masks, centroids = [], [], []
    mask_area_list = []
    mask_perimeter_list = []
    mask_aspect_ratio_list = []
    mask_solidity_list = []
    mask_orientation_list = []

    # Loop invariants: structuring element for the closing step and the "nothing found" metrics.
    kernel = np.ones((3, 3), np.uint8)
    nan_metrics = (np.nan, np.nan, np.nan, np.nan, np.nan)

    try:
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            frames.append(frame)
            gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
            blurred = cv2.GaussianBlur(gray, (11, 11), 0)  # Default is 5, 5
            _, mask = cv2.threshold(blurred, threshold_value, 255, cv2.THRESH_BINARY_INV)
            mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel)

            contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
            if contours:
                # Fill the largest contour
                largest = max(contours, key=cv2.contourArea)
                refined = np.zeros_like(mask)
                cv2.drawContours(refined, [largest], -1, 255, thickness=cv2.FILLED)

                # Distance transform + threshold at 10% of the maximum removes the thin tail
                dist = cv2.distanceTransform(refined, cv2.DIST_L2, 5)
                _, body_mask = cv2.threshold(dist, 0.1 * dist.max(), 255, cv2.THRESH_BINARY)
                body_mask = np.uint8(body_mask)

                # Keep only the largest blob that survived the trimming
                body_contours, _ = cv2.findContours(body_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
                if body_contours:
                    largest_body = max(body_contours, key=cv2.contourArea)
                    final_mask = np.zeros_like(body_mask)
                    cv2.drawContours(final_mask, [largest_body], -1, 255, thickness=cv2.FILLED)
                    mask = final_mask
                else:
                    mask = body_mask  # Fallback if no contours found

                # Measure the contour of the final mask; fall back to the pre-trim contour
                new_contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
                if new_contours:
                    metrics = _contour_shape_metrics(max(new_contours, key=cv2.contourArea))
                else:
                    metrics = _contour_shape_metrics(largest)
            else:
                metrics = nan_metrics

            area, perimeter, aspect_ratio, solidity, orientation = metrics
            mask_area_list.append(area)
            mask_perimeter_list.append(perimeter)
            mask_aspect_ratio_list.append(aspect_ratio)
            mask_solidity_list.append(solidity)
            mask_orientation_list.append(orientation)

            # Compute centroid using moments on the final mask
            M = cv2.moments(mask)
            if M['m00'] != 0:
                centroids.append((M['m10'] / M['m00'], M['m01'] / M['m00']))
            else:
                centroids.append(None)

            masks.append(mask)
    finally:
        # Release the capture even if a frame fails to process.
        cap.release()

    return {
        "frames": frames,
        "masks": masks,
        "centroids": centroids,
        "mask_area": mask_area_list,
        "mask_perimeter": mask_perimeter_list,
        "mask_aspect_ratio": mask_aspect_ratio_list,
        "mask_solidity": mask_solidity_list,
        "mask_orientation": mask_orientation_list
    }, video_obj

def create_masked_video(data, output_path, start_frame=0, end_frame=None, fps=30):
    """
    Creates a side-by-side masked video showing original frames and their binary masks with centroid overlay.

    Args:
        data (dict): Must provide 'frames', 'masks' and 'centroids' sequences of equal length
            (as returned by process_video). Centroid entries are (x, y) or None.
        output_path (str or Path): Destination file; written with the 'mp4v' codec.
        start_frame (int): Index of the first frame to write.
        end_frame (int or None): Index one past the last frame to write; None or a value
            beyond the end means "through the last frame".
        fps (int or float): Playback rate passed to cv2.VideoWriter.

    Returns:
        None. Prints a message and returns early when there is no data, the frame
        range is invalid, or the video writer cannot be opened.
    """
    if not data:
        print("No data provided.")
        return

    frames = data['frames']
    masks = data['masks']
    centroids = data['centroids']

    total_frames = len(frames)
    if end_frame is None or end_frame > total_frames:
        end_frame = total_frames

    if start_frame < 0 or start_frame >= total_frames or start_frame >= end_frame:
        print("Invalid frame range.")
        return

    height, width, _ = frames[0].shape
    composite_width = width * 2  # original and mask side by side

    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter(str(output_path), fourcc, fps, (composite_width, height))
    # A writer that failed to open drops every frame silently, so check before
    # claiming success below.
    if not out.isOpened():
        out.release()
        print(f"Could not open video writer for: {output_path}")
        return

    selection = slice(start_frame, end_frame)
    try:
        for frame, mask, centroid in zip(frames[selection], masks[selection], centroids[selection]):
            frame_copy = frame.copy()
            mask_bgr = cv2.cvtColor(mask, cv2.COLOR_GRAY2BGR)

            if centroid is not None:
                # cv2.circle expects integer coordinates
                center = (int(centroid[0]), int(centroid[1]))
                cv2.circle(frame_copy, center, 10, (0, 0, 255), -1)
                cv2.circle(mask_bgr, center, 10, (0, 0, 255), -1)

            out.write(np.hstack((frame_copy, mask_bgr)))
    finally:
        # Finalize/close the file even if a frame could not be written.
        out.release()
    print(f"Masked video saved to: {output_path}")

def main(folder_path, threshold_value):
    """
    Process every matching video in `folder_path` and write per-frame CSVs and masked videos.

    Only '.mp4' / '.avi' files whose stem starts with AA###_AA## are processed. Outputs go to
    a sibling directory named 'openfield_analysis_mask_<TIMESTAMP>': one
    '<animal>_<day>.csv' per video, plus '<animal>_<day>_masked.mp4' in 'masked_videos/'.
    Videos that raise ValueError while loading, or report a non-positive frame rate,
    are skipped with a printed message.

    Args:
        folder_path (str): Path to the directory containing the video files.
        threshold_value (int or float): Segmentation threshold forwarded to process_video.
    """
    folder = Path(folder_path)
    output_dir = create_output_directory(folder, base_name="openfield_analysis_mask")
    masked_videos_dir = output_dir / "masked_videos"
    os.makedirs(masked_videos_dir, exist_ok=True)

    video_pattern = re.compile(r'^[A-Za-z]{2}\d{3}_[A-Za-z]{2}\d{2}')
    video_suffixes = ('.mp4', '.avi')

    # Sorted so videos are processed (and logged) in a reproducible order.
    for video_file in sorted(folder.iterdir()):
        if not (video_file.is_file()
                and video_file.suffix.lower() in video_suffixes
                and video_pattern.search(video_file.stem)):
            continue

        try:
            data, video_obj = process_video(str(video_file), threshold_value)
        except ValueError as e:
            print(e)
            continue

        frame_rate = video_obj.frame_rate
        # Frame times are frame_index / frame_rate; a zero or NaN rate cannot be used.
        if not frame_rate > 0:
            print(f"Skipping {video_file.name}: invalid frame rate ({frame_rate}).")
            del data
            gc.collect()
            continue

        frames = list(range(len(data["centroids"])))
        real_times = [f / frame_rate for f in frames]

        df = pd.DataFrame({
            "frame": frames,
            "real_time_s": real_times,
            "centroid_x": [c[0] if c is not None else np.nan for c in data["centroids"]],
            "centroid_y": [c[1] if c is not None else np.nan for c in data["centroids"]],
            "mask_area": data["mask_area"],
            "mask_perimeter": data["mask_perimeter"],
            "mask_aspect_ratio": data["mask_aspect_ratio"],
            "mask_solidity": data["mask_solidity"],
            "mask_orientation": data["mask_orientation"],
            "animal_id": video_obj.animal_id,
            "day_id": video_obj.day_id
        })
        csv_filename = f"{video_obj.animal_id}_{video_obj.day_id}.csv"
        df.to_csv(output_dir / csv_filename, index=False)
        print(f"{csv_filename} saved to {output_dir}")

        masked_video_path = masked_videos_dir / f"{video_obj.animal_id}_{video_obj.day_id}_masked.mp4"
        create_masked_video(
            data=data,
            output_path=masked_video_path,
            start_frame=0,
            end_frame=None,
            # Pass the exact rate: truncating e.g. 29.97 to 29 would change playback speed.
            fps=frame_rate
        )

        # Frames and masks of a whole video are large; free them before the next file.
        del data, df
        gc.collect()

    print("Per-frame CSVs and masked videos saved in:", output_dir)

if __name__ == "__main__":
    # Directory containing open field videos
    folder_path = r"D:\CK3_open_field\videos"
    # 75 is the gray-level threshold used to segment the animal.
    main(folder_path=folder_path, threshold_value=75)


In [None]:
# Analyze the centroid data and compute summary metrics

"""
Reads centroid tracking CSVs from open-field experiments, computes:
- Displacement, velocity, acceleration, angle of movement
- Mean velocity, mean acceleration, total distance
- Percent-of-baseline distance (if day ID == "PD00" is available)
- Conversion to centimeters using PX_PER_CM

It saves a processed CSV for each original file with "_processed.csv" appended,
and aggregates all summaries into "summary_metrics.csv".

Args:
    analysis_folder (str): Path to the folder containing centroid CSV files.

Outputs:
    - Per-file processed CSV files: Each original CSV gains columns for displacement,
      velocity, acceleration, angle, and is saved with a "_processed.csv" suffix.
    - A summary_metrics.csv containing aggregate statistics for each file, including
      total distance, mean velocity, mean acceleration, and percent-of-baseline
      distances (both in pixels and centimeters).
"""


import pandas as pd
import numpy as np
import os
import re
from pathlib import Path


def _nanmean_or_nan(values):
    """Return np.nanmean(values), or NaN (without a RuntimeWarning) when there is no non-NaN entry."""
    if values.size == 0 or np.isnan(values).all():
        return np.nan
    return np.nanmean(values)


def compute_summary_metrics(df):
    """
    Add frame-to-frame kinematics to a centroid table and summarize them.

    The time step `dt` is the median difference of "real_time_s" (1.0 if there are
    fewer than two rows). For every row after the first:
      - displacement_px: Euclidean distance from the previous row's centroid; gaps
        caused by missing centroids are filled with pandas' default interpolation,
      - velocity_px_s: displacement / dt,
      - acceleration_px_s2: change in velocity / dt,
      - angle_deg: direction of the step, degrees(arctan2(dy, dx)); NaN where a
        centroid is missing.
    The first row of each column is NaN. Rows are paired by position, so the
    DataFrame's index labels do not matter.

    Args:
        df (pd.DataFrame): Needs "real_time_s", "centroid_x" and "centroid_y" columns.
            It is modified in place (the four columns above are added).

    Returns:
        tuple: (metrics, df) where metrics has "total_distance_px", "mean_velocity_px_s",
        "mean_acceleration_px_s2" and "frame_rate" (1 / dt, or 1.0 if dt is 0), and
        df is the same object that was passed in.
    """
    times = df["real_time_s"].values
    dt = np.median(np.diff(times)) if len(times) > 1 else 1.0
    frame_rate = 1 / dt if dt != 0 else 1.0

    xs = df["centroid_x"].to_numpy(dtype=float)
    ys = df["centroid_y"].to_numpy(dtype=float)

    # Step between consecutive rows; the prepended NaN makes row 0 "undefined" and
    # keeps every array the same length as df (including the empty case).
    dx = np.diff(xs, prepend=np.nan)
    dy = np.diff(ys, prepend=np.nan)

    # Frame-to-frame displacement; NaN wherever either endpoint is missing,
    # then interpolated across those gaps.
    displacement = pd.Series(np.sqrt(dx**2 + dy**2)).interpolate().to_numpy()

    velocity = displacement / dt
    velocity[:1] = np.nan  # no previous frame to measure from

    acceleration = np.diff(velocity, prepend=np.nan) / dt

    angles = np.degrees(np.arctan2(dy, dx))

    df["displacement_px"] = displacement
    df["velocity_px_s"] = velocity
    df["acceleration_px_s2"] = acceleration
    df["angle_deg"] = angles

    return {
        "total_distance_px": np.nansum(displacement),
        "mean_velocity_px_s": _nanmean_or_nan(velocity),
        "mean_acceleration_px_s2": _nanmean_or_nan(acceleration),
        "frame_rate": frame_rate
    }, df

def _percent_of_baseline(value, baseline):
    """Return value / baseline * 100, or NaN when the baseline is missing (None) or zero."""
    if baseline is None or baseline == 0:
        return np.nan
    return value / baseline * 100


def main(analysis_folder):
    """
    Compute kinematics for every centroid CSV in `analysis_folder` and write a summary.

    Files named AA###_PD##.csv are read, augmented by compute_summary_metrics and saved
    next to the original with a "_processed.csv" suffix. One summary row per file is
    collected into "summary_metrics.csv", including total distance in pixels and
    centimeters (pixels / PX_PER_CM) and each as a percentage of the same animal's
    "PD00" total (NaN when the animal has no PD00 file or its baseline is zero).
    CSVs without any data rows are skipped with a printed message.

    Args:
        analysis_folder (str): Path to the folder containing centroid CSV files.
    """
    folder = Path(analysis_folder)
    pattern = re.compile(r"^[A-Za-z]{2}\d{3}_PD\d{2}\.csv$")

    # Sorted so the rows of summary_metrics.csv come out in a reproducible order.
    csv_files = sorted(f for f in folder.iterdir() if f.is_file() and pattern.match(f.name))

    summary_list = []
    for csv_file in csv_files:
        df = pd.read_csv(csv_file)
        if df.empty:
            # A header-only CSV has no first row to read the IDs from and nothing to measure.
            print(f"Skipping {csv_file.name}: no rows to analyze.")
            continue

        metrics, df_processed = compute_summary_metrics(df)
        summary_list.append({
            "animal_id": df_processed["animal_id"].iloc[0],
            "day_id": df_processed["day_id"].iloc[0],
            "total_distance_px": metrics["total_distance_px"],
            "mean_velocity_px_s": metrics["mean_velocity_px_s"],
            "mean_acceleration_px_s2": metrics["mean_acceleration_px_s2"]
        })

        processed_path = folder / (csv_file.stem + "_processed.csv")
        df_processed.to_csv(processed_path, index=False)

    if summary_list:
        summary_df = pd.DataFrame(summary_list)

        # Baseline computed from PD00 for pixels: animal_id -> total distance on PD00.
        baseline_dict = summary_df[summary_df["day_id"] == "PD00"].set_index("animal_id")["total_distance_px"].to_dict()
        summary_df["percent_baseline_distance"] = [
            _percent_of_baseline(total, baseline_dict.get(animal))
            for animal, total in zip(summary_df["animal_id"], summary_df["total_distance_px"])
        ]

        # Compute total distance in cm.
        summary_df["total_distance_cm"] = summary_df["total_distance_px"] / PX_PER_CM

        # Compute baseline in cm.
        baseline_dict_cm = {animal: distance_px / PX_PER_CM for animal, distance_px in baseline_dict.items()}
        summary_df["total_distance_cm_pct_baseline"] = [
            _percent_of_baseline(total, baseline_dict_cm.get(animal))
            for animal, total in zip(summary_df["animal_id"], summary_df["total_distance_cm"])
        ]

        summary_df.to_csv(folder / "summary_metrics.csv", index=False)

    print("Processing complete. Processed CSVs and summary saved in:", folder)


# Pixels per centimeter; pixel distances are divided by this value to express them in cm.
PX_PER_CM = 16.8144

if __name__ == "__main__":
    # Output folder of the video-processing step (contains the per-frame centroid CSVs).
    input_folder = r"D:\CK3_open_field\openfield_analysis_mask_2025-03-31"
    main(analysis_folder=input_folder)



In [None]:
# Plotting functions for the processed CSV files

"""
Generates visualizations for processed open-field tracking data, including:
1. Trajectory plots with a user-specified arena rectangle (bounding box),
2. Velocity over time (raw vs. smoothed, in px/s and cm/s),
3. Per-frame metrics plots (mask area, perimeter, etc., with px→cm conversions),
4. Bar charts and line plots summarizing total distance, velocity, and acceleration,
   along with % of baseline metrics when available.

Flow:
    - Reads any "*_processed.csv" files in `analysis_folder`.
    - Creates an "analysis_plots" folder for storing individual plots (e.g., trajectories, velocity, etc.).
    - If "summary_metrics.csv" is present, generates additional summary plots and baseline comparisons.

Args:
    analysis_folder (str): Directory containing processed per-frame CSVs and an optional "summary_metrics.csv".

Outputs:
    - Trajectory PNGs for each file with valid centroid data (includes a bounding box defined by `arena_rect`),
    - Velocity plots (px/s, cm/s, raw vs. smoothed),
    - Line plots for mask metrics (area, perimeter, aspect ratio, solidity, orientation),
    - Summary bar/line plots comparing total distance, velocity, acceleration, and baseline percentages across days/animals.
"""

import pandas as pd
import matplotlib.pyplot as plt
from pathlib import Path
import os
from matplotlib.patches import Rectangle
from scipy.ndimage import gaussian_filter1d


def plot_velocity_comparison(df, output_dir):
    """
    Overlay raw and smoothed velocity (px/s) over time and save the figure as a PNG.

    Nothing is plotted unless both "velocity_px_s" and "velocity_px_s_smooth" are
    columns of `df`. The file is written to
    <output_dir>/<animal_id>_<day_id>_velocity_comparison.png.

    Args:
        df (pd.DataFrame): Processed per-frame table with "real_time_s", "animal_id", "day_id".
        output_dir (Path): Directory that receives the PNG.
    """
    needed = ("velocity_px_s", "velocity_px_s_smooth")
    if any(column not in df.columns for column in needed):
        return

    seconds = df["real_time_s"]
    plt.figure(figsize=(12, 6), dpi=300)
    plt.plot(seconds, df["velocity_px_s"], linestyle='-', label='Raw Velocity (px/s)')
    plt.plot(seconds, df["velocity_px_s_smooth"], linestyle='-', label='Smoothed Velocity (px/s)')

    animal = df.animal_id.iloc[0]
    day = df.day_id.iloc[0]
    plt.title(f"Velocity Comparison: {animal} {day}")
    plt.xlabel("Time (s)")
    plt.ylabel("Velocity (px/s)")
    plt.legend(loc='best')

    target = output_dir / f"{animal}_{day}_velocity_comparison.png"
    plt.tight_layout()
    plt.savefig(target, dpi=300)
    plt.close()

def plot_velocity_comparison_cm(df, output_dir):
    """
    Overlay raw and smoothed velocity in cm/s (px/s divided by PX_PER_CM) and save a PNG.

    Nothing is plotted unless both "velocity_px_s" and "velocity_px_s_smooth" are
    columns of `df`. The file is written to
    <output_dir>/<animal_id>_<day_id>_velocity_comparison_cm.png.

    Args:
        df (pd.DataFrame): Processed per-frame table with "real_time_s", "animal_id", "day_id".
        output_dir (Path): Directory that receives the PNG.
    """
    needed = ("velocity_px_s", "velocity_px_s_smooth")
    if any(column not in df.columns for column in needed):
        return

    seconds = df["real_time_s"]
    plt.figure(figsize=(12, 6), dpi=300)
    plt.plot(seconds, df["velocity_px_s"] / PX_PER_CM, linestyle='-', label='Raw Velocity (cm/s)')
    plt.plot(seconds, df["velocity_px_s_smooth"] / PX_PER_CM, linestyle='-', label='Smoothed Velocity (cm/s)')

    animal = df.animal_id.iloc[0]
    day = df.day_id.iloc[0]
    plt.title(f"Velocity Comparison (cm/s): {animal} {day}")
    plt.xlabel("Time (s)")
    plt.ylabel("Velocity (cm/s)")
    plt.legend(loc='best')

    target = output_dir / f"{animal}_{day}_velocity_comparison_cm.png"
    plt.tight_layout()
    plt.savefig(target, dpi=300)
    plt.close()

def plot_trajectory(df, output_dir, arena_rect=(25, 25, 650, 650)):
    """
    Draw the centroid path inside the arena rectangle and save it as a PNG.

    Rows with a missing centroid coordinate are dropped before plotting. The first
    remaining point is marked with a black dot, the axes are limited to the arena
    rectangle, and the y axis is inverted so the picture matches image coordinates.
    The file is written to <output_dir>/<animal_id>_<day_id>_trajectory.png.

    Args:
        df (pd.DataFrame): Needs "centroid_x", "centroid_y", "animal_id", "day_id".
        output_dir (Path): Directory that receives the PNG.
        arena_rect (tuple): (x, y, width, height) of the arena outline in pixels.

    Returns:
        None. If no row has a valid centroid, a message is printed and nothing is saved.
    """
    valid = df.dropna(subset=["centroid_x", "centroid_y"])
    if valid.empty:
        # Without this guard the start-point lookup below raises IndexError.
        print("No valid centroid data; trajectory plot skipped.")
        return

    x = valid["centroid_x"].values
    y = valid["centroid_y"].values
    left, top, width, height = arena_rect[0], arena_rect[1], arena_rect[2], arena_rect[3]

    fig, ax = plt.subplots(figsize=(6, 6), dpi=600)
    ax.add_patch(Rectangle((left, top), width, height,
                           edgecolor='black', facecolor='none', linewidth=2))
    # Palette notes: original blue #1E88E5, original orange #ff7f0e, new blue #9AAACE, new orange #EAA973
    ax.plot(x, y, color='#ff7f0e', linewidth=1)
    ax.scatter(x[0], y[0], color='black', s=40, zorder=3)  # starting position
    ax.set_xlim(left, left + width)
    ax.set_ylim(top, top + height)
    ax.invert_yaxis()
    ax.set_aspect('equal')
    ax.axis('off')

    animal_id = str(df.animal_id.iloc[0])
    day_id = str(df.day_id.iloc[0])
    out_path = output_dir / f"{animal_id}_{day_id}_trajectory.png"
    fig.savefig(out_path, dpi=600, bbox_inches='tight', pad_inches=0.05)
    plt.close(fig)


def plot_velocity(df, output_dir):
    """
    Plot velocity (px/s) against time and save it as
    <output_dir>/<animal_id>_<day_id>_velocity_px_s.png.

    Returns without plotting when `df` has no "velocity_px_s" column.

    Args:
        df (pd.DataFrame): Processed per-frame table with "real_time_s", "animal_id", "day_id".
        output_dir (Path): Directory that receives the PNG.
    """
    has_velocity = "velocity_px_s" in df.columns
    if not has_velocity:
        return

    plt.figure(figsize=(12, 6), dpi=300)
    plt.plot(df["real_time_s"], df["velocity_px_s"], marker='o', linestyle='-')

    animal = df.animal_id.iloc[0]
    day = df.day_id.iloc[0]
    plt.title(f"Velocity Over Time: {animal} {day}")
    plt.xlabel("Time (s)")
    plt.ylabel("Velocity (px/s)")
    plt.savefig(output_dir / f"{animal}_{day}_velocity_px_s.png", dpi=300)
    plt.close()

def plot_velocity_cm(df, output_dir):
    """
    Plot velocity in cm/s (px/s divided by PX_PER_CM) against time and save it as
    <output_dir>/<animal_id>_<day_id>_velocity_cm_s.png.

    Returns without plotting when `df` has no "velocity_px_s" column.

    Args:
        df (pd.DataFrame): Processed per-frame table with "real_time_s", "animal_id", "day_id".
        output_dir (Path): Directory that receives the PNG.
    """
    has_velocity = "velocity_px_s" in df.columns
    if not has_velocity:
        return

    plt.figure(figsize=(12, 6), dpi=300)
    velocity_cm_s = df["velocity_px_s"] / PX_PER_CM
    plt.plot(df["real_time_s"], velocity_cm_s, marker='o', linestyle='-')

    animal = df.animal_id.iloc[0]
    day = df.day_id.iloc[0]
    plt.title(f"Velocity Over Time (cm/s): {animal} {day}")
    plt.xlabel("Time (s)")
    plt.ylabel("Velocity (cm/s)")
    plt.savefig(output_dir / f"{animal}_{day}_velocity_cm_s.png", dpi=300)
    plt.close()

def plot_line_parameter(df, parameter, title, ylabel, output_dir, animal_id, day_id):
    """
    Plot one per-frame column against time and save it as
    <output_dir>/<animal_id>_<day_id>_<parameter>.png.

    Args:
        df (pd.DataFrame): Table containing "real_time_s" and the `parameter` column.
        parameter (str): Name of the column to plot; also used in the file name.
        title (str): Figure title.
        ylabel (str): Label of the y axis.
        output_dir (Path): Directory that receives the PNG.
        animal_id, day_id: Identifiers used in the file name.
    """
    destination = output_dir / f"{animal_id}_{day_id}_{parameter}.png"

    plt.figure(figsize=(12, 6), dpi=300)
    plt.plot(df["real_time_s"], df[parameter], marker='o', linestyle='-')
    plt.title(title)
    plt.xlabel("Time (s)")
    plt.ylabel(ylabel)
    plt.tight_layout()
    plt.savefig(destination, dpi=300)
    plt.close()

def plot_summary_metrics(summary_df, output_dir):
    """
    Save one bar chart per pixel-based summary metric (total distance, mean velocity,
    mean acceleration), with one bar per "<animal_id>_<day_id>".

    Files are written as <output_dir>/summary_<column>.png.

    Args:
        summary_df (pd.DataFrame): Summary table; an "id" column is added to it in place.
        output_dir (Path): Directory that receives the PNGs.
    """
    # Bar labels; note that this writes a new column into the caller's DataFrame.
    summary_df["id"] = summary_df["animal_id"] + "_" + summary_df["day_id"]

    # column name -> text used for both the title and the y-axis label
    labels_by_column = {
        "total_distance_px": "Total Distance (px)",
        "mean_velocity_px_s": "Mean Velocity (px/s)",
        "mean_acceleration_px_s2": "Mean Acceleration (px/s²)",
    }
    for column, label in labels_by_column.items():
        plt.figure(figsize=(12, 6), dpi=300)
        plt.bar(summary_df["id"], summary_df[column])
        plt.title(label)
        plt.ylabel(label)
        plt.xlabel("Video ID")
        plt.xticks(rotation=45, ha="right")
        plt.tight_layout()
        plt.savefig(output_dir / f"summary_{column}.png", dpi=300)
        plt.close()

def plot_summary_metrics_cm(summary_df, output_dir):
    """
    Save one bar chart per centimeter-based summary metric (total distance, mean
    velocity, mean acceleration), with one bar per "<animal_id>_<day_id>".

    Files are written as <output_dir>/summary_<column>.png.

    Args:
        summary_df (pd.DataFrame): Summary table that already holds the "*_cm*" columns;
            an "id" column is added to it in place.
        output_dir (Path): Directory that receives the PNGs.
    """
    # Bar labels; note that this writes a new column into the caller's DataFrame.
    summary_df["id"] = summary_df["animal_id"] + "_" + summary_df["day_id"]

    # column name -> text used for both the title and the y-axis label
    labels_by_column = {
        "total_distance_cm": "Total Distance (cm)",
        "mean_velocity_cm_s": "Mean Velocity (cm/s)",
        "mean_acceleration_cm_s2": "Mean Acceleration (cm/s²)",
    }
    for column, label in labels_by_column.items():
        plt.figure(figsize=(12, 6), dpi=300)
        plt.bar(summary_df["id"], summary_df[column])
        plt.title(label)
        plt.ylabel(label)
        plt.xlabel("Video ID")
        plt.xticks(rotation=45, ha="right")
        plt.tight_layout()
        plt.savefig(output_dir / f"summary_{column}.png", dpi=300)
        plt.close()

def plot_percent_baseline(summary_df, output_dir):
    """
    Plot "percent_baseline_distance" against day ID, one line per animal, with a dashed
    reference line at 100 %, and save it as
    <output_dir>/percent_baseline_total_distance.png.

    Args:
        summary_df (pd.DataFrame): Needs "animal_id", "day_id", "percent_baseline_distance".
        output_dir (Path): Directory that receives the PNG.
    """
    plt.figure(figsize=(12, 6), dpi=300)

    for animal, rows in summary_df.groupby("animal_id"):
        by_day = rows.sort_values("day_id")
        plt.plot(by_day["day_id"], by_day["percent_baseline_distance"],
                 marker='o', label=animal)

    # 100 % marks the baseline itself
    plt.axhline(100, color='gray', linestyle='--', linewidth=1)
    plt.title("Total Distance as % of Baseline (PD00)")
    plt.ylabel("Percent of Baseline Distance (%)")
    plt.xlabel("Day ID")
    plt.legend(title="Animal ID", loc='best')
    plt.tight_layout()
    plt.savefig(output_dir / "percent_baseline_total_distance.png", dpi=300)
    plt.close()

def plot_raw_total_distance(summary_df, output_dir):
    """
    Plot "total_distance_px" against day ID, one line per animal, and save it as
    <output_dir>/raw_total_distance.png.

    Args:
        summary_df (pd.DataFrame): Needs "animal_id", "day_id", "total_distance_px".
        output_dir (Path): Directory that receives the PNG.
    """
    plt.figure(figsize=(12, 6), dpi=300)

    for animal, rows in summary_df.groupby("animal_id"):
        by_day = rows.sort_values("day_id")
        plt.plot(by_day["day_id"], by_day["total_distance_px"],
                 marker='o', label=animal)

    plt.title("Raw Total Distance over Days")
    plt.ylabel("Total Distance (px)")
    plt.xlabel("Day ID")
    plt.legend(title="Animal ID", loc='best')
    plt.tight_layout()
    plt.savefig(output_dir / "raw_total_distance.png", dpi=300)
    plt.close()

def plot_raw_total_distance_cm(summary_df, output_dir):
    """
    Plot "total_distance_cm" against day ID, one line per animal, and save it as
    <output_dir>/raw_total_distance_cm.png.

    Args:
        summary_df (pd.DataFrame): Needs "animal_id", "day_id", "total_distance_cm".
        output_dir (Path): Directory that receives the PNG.
    """
    plt.figure(figsize=(12, 6), dpi=300)

    for animal, rows in summary_df.groupby("animal_id"):
        by_day = rows.sort_values("day_id")
        plt.plot(by_day["day_id"], by_day["total_distance_cm"],
                 marker='o', label=animal)

    plt.title("Raw Total Distance over Days (cm)")
    plt.ylabel("Total Distance (cm)")
    plt.xlabel("Day ID")
    plt.legend(title="Animal ID", loc='best')
    plt.tight_layout()
    plt.savefig(output_dir / "raw_total_distance_cm.png", dpi=300)
    plt.close()

def main(analysis_folder):
    """
    Generate every per-file and summary plot for one analysis folder.

    For each "*_processed.csv" found under `analysis_folder` (searched recursively), the
    trajectory, velocity and mask-metric plots are written to
    "analysis_plots/<animal_id>/". Mask area and perimeter are additionally plotted in
    cm² / cm using PX_PER_CM. If "summary_metrics.csv" exists, cm versions of its
    metrics are derived and the summary plots are written to "analysis_plots/".

    Args:
        analysis_folder (str): Directory containing processed per-frame CSVs and an
            optional "summary_metrics.csv".
    """
    folder = Path(analysis_folder)
    plots_dir = folder / "analysis_plots"
    os.makedirs(plots_dir, exist_ok=True)

    # (column, plot title, y-axis label) for the mask metrics plotted per file
    mask_metrics = (
        ("mask_area", "Mask Area Over Time", "Area_px2"),
        ("mask_perimeter", "Mask Perimeter Over Time", "Perimeter_px"),
        ("mask_aspect_ratio", "Mask Aspect Ratio Over Time", "Aspect_Ratio"),
        ("mask_solidity", "Mask Solidity Over Time", "Solidity"),
        ("mask_orientation", "Mask Orientation Over Time", "Orientation_deg"),
    )
    # Metrics that also get a centimeter plot:
    # column -> (new-column suffix, power of PX_PER_CM to divide by, word replaced in the title, cm label)
    cm_variants = {
        "mask_area": ("_cm2", 2, "Area", "Area (cm²)"),
        "mask_perimeter": ("_cm", 1, "Perimeter", "Perimeter (cm)"),
    }

    for proc_file in list(folder.glob("**/*_processed.csv")):
        df = pd.read_csv(proc_file)
        animal_id = df["animal_id"].iloc[0]
        day_id = df["day_id"].iloc[0]
        animal_folder = plots_dir / animal_id
        os.makedirs(animal_folder, exist_ok=True)

        for draw in (plot_trajectory, plot_velocity, plot_velocity_cm,
                     plot_velocity_comparison, plot_velocity_comparison_cm):
            draw(df, animal_folder)

        for column, title, ylabel in mask_metrics:
            if column not in df.columns:
                continue
            plot_line_parameter(df, column, title, ylabel, animal_folder, animal_id, day_id)

            variant = cm_variants.get(column)
            if variant is None:
                continue
            suffix, power, word, cm_label = variant
            cm_column = column + suffix
            # Areas scale with PX_PER_CM squared, lengths with PX_PER_CM.
            df[cm_column] = df[column] / (PX_PER_CM ** 2) if power == 2 else df[column] / PX_PER_CM
            plot_line_parameter(df, cm_column, title.replace(word, cm_label), cm_label,
                                animal_folder, animal_id, day_id)

    summary_path = folder / "summary_metrics.csv"
    if summary_path.exists():
        summary_df = pd.read_csv(summary_path)
        # Add cm conversions for summary metrics.
        for px_column, cm_column in (
            ("total_distance_px", "total_distance_cm"),
            ("mean_velocity_px_s", "mean_velocity_cm_s"),
            ("mean_acceleration_px_s2", "mean_acceleration_cm_s2"),
        ):
            summary_df[cm_column] = summary_df[px_column] / PX_PER_CM

        for draw_summary in (plot_summary_metrics, plot_summary_metrics_cm, plot_percent_baseline,
                             plot_raw_total_distance, plot_raw_total_distance_cm):
            draw_summary(summary_df, plots_dir)

    print("Plotting complete. Plots saved in:", plots_dir)


# Folder holding the "*_processed.csv" files (and, optionally, "summary_metrics.csv") to plot.
INPUT_FOLDER = r"D:\CK3_open_field\openfield_analysis_mask_2025-03-31\temp_recolor"
# Pixels per centimeter; pixel-based values are divided by this to express them in cm.
PX_PER_CM = 16.8144

if __name__ == "__main__":
    main(INPUT_FOLDER)



In [None]:
# Analyze time spent in the middle of an arena

import pandas as pd
import numpy as np
import re
from pathlib import Path
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle

def process_file(file, arena_width, arena_height, output_dir, middle_pct):
    """
    Read CSV with 'centroid_x','centroid_y'; compute:
      - in_middle: inside central square (middle_pct of arena)
      - cross_into_middle: entry events (outside→inside), first frame excluded
      - in_middle_post_cross: in_middle after ≥1 cross_into_middle event
    Returns augmented df and summary metrics.

    The central region spans middle_pct percent of the arena width and height and is
    centered on (arena_width / 2, arena_height / 2); its edges count as inside. Frames
    with a missing centroid are treated as outside and still count toward the
    denominators of the percentages. The augmented table is written to
    <output_dir>/<file stem>_middle.csv. A CSV without data rows is handled: the flag
    columns are empty and the crossing count is 0.

    Args:
        file (Path): Centroid CSV, expected to be named '<animal>_<day>.csv'.
        arena_width, arena_height (int or float): Arena size in pixels.
        output_dir (Path): Directory that receives the '_middle.csv' file.
        middle_pct (int or float): Side length of the central region, in percent of the arena.

    Returns:
        tuple: (df, animal_id, day_id, pct_time, pct_time_post, n_cross). The IDs are
        taken from the file stem ("unknown" if it has no underscore) and are also
        written into df's 'animal_id' / 'day_id' columns after the CSV is saved.
    """
    df = pd.read_csv(file)
    df.columns = df.columns.str.strip()

    # define center bounds
    m = middle_pct / 100
    x_min, x_max = arena_width * (1 - m) / 2, arena_width * (1 + m) / 2
    y_min, y_max = arena_height * (1 - m) / 2, arena_height * (1 + m) / 2

    # boolean flags
    in_middle = (
        df['centroid_x'].between(x_min, x_max) &
        df['centroid_y'].between(y_min, y_max)
    )
    df['in_middle'] = in_middle

    # An entry is "inside now, outside on the previous frame". Treating the
    # non-existent frame before the first one as inside means the first frame can
    # never be an entry, and it needs no first-row lookup (which fails on an empty table).
    df['cross_into_middle'] = in_middle & ~in_middle.shift(fill_value=True)

    # post‐cross flag: true from first entry onward (includes crossing frame)
    df['in_middle_post_cross'] = in_middle & (df['cross_into_middle'].cumsum() >= 1)

    # save processed CSV
    out_csv = output_dir / f"{file.stem}_middle.csv"
    df.to_csv(out_csv, index=False)

    # summary metrics
    pct_time = df['in_middle'].mean() * 100
    pct_time_post = df['in_middle_post_cross'].mean() * 100
    n_cross = df['cross_into_middle'].sum()

    # extract IDs
    parts = file.stem.split('_')
    if len(parts) >= 2:
        animal_id, day_id = parts[0], parts[1]
    else:
        animal_id = day_id = "unknown"

    # annotate df for plotting if needed
    df['animal_id'], df['day_id'] = animal_id, day_id

    return df, animal_id, day_id, pct_time, pct_time_post, n_cross

def plot_trajectory(df, output_dir, arena_width, arena_height, animal_id, day_id, middle_pct, margin=25):
    """
    Plot trajectory with arena border and dashed center square.

    Rows with a missing centroid coordinate are dropped. The arena border is drawn
    `margin` pixels inside the (arena_width x arena_height) frame, the dashed square
    covers middle_pct percent of the frame around its center, and the first valid
    position is marked with a black dot. The figure is saved as
    <output_dir>/<animal_id>_<day_id>_trajectory.png.

    Args:
        df (pd.DataFrame): Needs "centroid_x" and "centroid_y".
        output_dir (Path): Directory that receives the PNG.
        arena_width, arena_height (int or float): Frame size in pixels.
        animal_id, day_id: Identifiers used in the file name.
        middle_pct (int or float): Side length of the center square, in percent of the frame.
        margin (int or float): Distance in pixels between the frame edge and the drawn border.

    Returns:
        None. If no row has a valid centroid, a message is printed and nothing is saved.
    """
    valid = df.dropna(subset=["centroid_x","centroid_y"])
    if valid.empty:
        # Without this guard the start-point lookup below raises IndexError.
        print(f"No valid centroid data for {animal_id} {day_id}; trajectory plot skipped.")
        return
    x, y = valid["centroid_x"], valid["centroid_y"]

    fig, ax = plt.subplots(figsize=(6,6), dpi=600)
    # arena border (inset by `margin` on every side)
    ax.add_patch(Rectangle((margin, margin), arena_width - 2 * margin, arena_height - 2 * margin,
                           edgecolor='black', facecolor='none', linewidth=2))
    # center square
    m = middle_pct/100
    cx, cy = arena_width*(1-m)/2, arena_height*(1-m)/2
    ax.add_patch(Rectangle((cx,cy), arena_width*m, arena_height*m,
                           edgecolor='black', facecolor='none', linestyle='--', linewidth=2))
    ax.plot(x, y, linewidth=1)
    ax.scatter(x.iat[0], y.iat[0], color='black', s=40, zorder=3)  # starting position
    ax.set_xlim(margin, arena_width - margin)
    ax.set_ylim(margin, arena_height - margin)
    ax.invert_yaxis()  # image coordinates: y grows downward
    ax.set_aspect('equal')
    ax.axis('off')
    fig.savefig(output_dir / f"{animal_id}_{day_id}_trajectory.png", dpi=600,
                bbox_inches='tight', pad_inches=0.05)
    plt.close(fig)

def plot_metric_line(summary_df, metric, ylabel, title, out_path):
    """
    Draw one line per animal of `metric` against the numeric day parsed from 'day_id'.

    Args:
        summary_df (pd.DataFrame): Needs 'animal_id', 'day_id' (containing 'PD##')
            and the `metric` column.
        metric (str): Column to plot; names starting with 'norm_' also get a
            dashed reference line at 100.
        ylabel (str): Y-axis label.
        title (str): Figure title.
        out_path (Path or str): Destination image file.
    """
    day_regex = r'PD(\d{2})'
    fig, ax = plt.subplots(figsize=(10, 6))

    # one trace per animal, ordered by day number
    for animal, rows in summary_df.groupby('animal_id'):
        ordered = rows.copy()
        ordered['day_num'] = ordered['day_id'].str.extract(day_regex).astype(int)
        ordered = ordered.sort_values('day_num')
        ax.plot(ordered['day_num'], ordered[metric], marker='o', label=animal)

    if metric.startswith('norm_'):
        ax.axhline(100, linestyle='--', color='grey')

    # x ticks sit at the numeric day and are labelled with the original day_id
    tick_days = summary_df[['day_id']].drop_duplicates()
    tick_days['day_num'] = tick_days['day_id'].str.extract(day_regex).astype(int)
    tick_days = tick_days.sort_values('day_num')
    ax.set_xticks(tick_days['day_num'])
    ax.set_xticklabels(tick_days['day_id'])

    ax.set_xlabel("Days Post Stroke")
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.legend()
    fig.tight_layout()
    fig.savefig(out_path, dpi=300)
    plt.close(fig)

def process_all_files(analysis_folder, arena_width, arena_height, middle_pct):
    """
    Run process_file and plot_trajectory on every matching CSV in a folder.

    Only regular files named like '<letters><digits>_PD##.csv' are processed.
    Outputs go to '<analysis_folder>/time_spent_middle_analysis', with the
    trajectory plots in its 'individual_plots' subfolder.

    Returns:
        tuple: (summary DataFrame with one row per processed file, output root Path).
    """
    source_dir = Path(analysis_folder)
    out_root = source_dir / "time_spent_middle_analysis"
    plots_dir = out_root / "individual_plots"
    for directory in (out_root, plots_dir):
        directory.mkdir(exist_ok=True, parents=True)

    name_re = re.compile(r"^[A-Za-z]+\d+_PD\d{2}\.csv$")
    # collect the inputs before any output is written
    candidates = [
        entry for entry in source_dir.iterdir()
        if entry.is_file() and name_re.match(entry.name)
    ]

    rows = []
    for csv_path in candidates:
        result = process_file(csv_path, arena_width, arena_height, out_root, middle_pct)
        frame_df, animal, day, pct_mid, pct_mid_post, crossings = result
        rows.append({
            'file': csv_path.name,
            'animal_id': animal,
            'day_id': day,
            'pct_time_in_middle': pct_mid,
            'pct_time_in_mid_post_cross': pct_mid_post,
            'n_crossings': crossings
        })
        plot_trajectory(frame_df, plots_dir, arena_width, arena_height, animal, day, middle_pct)

    return pd.DataFrame(rows), out_root

def main(analysis_folder, arena_width=700, arena_height=700, middle_pct=50):
    """
    Run the middle-of-arena analysis on a folder and write the summary CSV and line plots.

    Each metric is also expressed relative to the animal's PD00 value
    (PD00 == 100) in a 'norm_<metric>' column. Animals with no PD00 row, or
    with a PD00 value of zero, keep NaN in the normalized columns.

    Args:
        analysis_folder (str or Path): Folder holding the '<animal>_PD##.csv' inputs.
        arena_width (int or float): Arena width passed to the per-file analysis.
        arena_height (int or float): Arena height passed to the per-file analysis.
        middle_pct (int or float): Size of the center zone as a percentage of the arena.
    """
    summary_df, out_root = process_all_files(
        analysis_folder, arena_width, arena_height, middle_pct
    )
    if summary_df.empty:
        # With no processed files the summary has no 'animal_id' column, so the
        # groupby below would raise KeyError. Stop with a clear message instead.
        print("No matching CSV files found; nothing to summarize.")
        return

    # normalize metrics to PD00 per animal
    for col in ['pct_time_in_middle', 'pct_time_in_mid_post_cross', 'n_crossings']:
        norm_col = f"norm_{col}"
        summary_df[norm_col] = np.nan
        for ani, grp in summary_df.groupby('animal_id'):
            base = grp[grp['day_id'] == 'PD00']
            # the truthiness test skips a zero baseline (avoids division by zero)
            if not base.empty and base.iloc[0][col]:
                factor = 100 / base.iloc[0][col]
                idx = summary_df['animal_id'] == ani
                summary_df.loc[idx, norm_col] = summary_df.loc[idx, col] * factor

    # save summary
    summary_df.to_csv(out_root / "middle_analysis_summary.csv", index=False)

    # (metric column, y label, title, output filename), plotted in this order
    plot_specs = [
        ('pct_time_in_middle',
         'Percentage Time in Middle',
         'Time Spent in Middle per Animal',
         "raw_pct_time_in_middle.png"),
        ('n_crossings',
         'Number of Crossings',
         'Crossings into Middle per Animal',
         "raw_n_crossings.png"),
        ('norm_pct_time_in_middle',
         'Normalized % Time in Middle',
         'Normalized Time in Middle per Animal',
         "norm_pct_time_in_middle.png"),
        ('norm_n_crossings',
         'Normalized Number of Crossings',
         'Normalized Crossings into Middle per Animal',
         "norm_n_crossings.png"),
        ('pct_time_in_mid_post_cross',
         '% Time Post‐Cross in Middle',
         'Time in Middle After First Re‐entry',
         "raw_pct_time_post_cross.png"),
        ('norm_pct_time_in_mid_post_cross',
         'Normalized % Post‐Cross Time in Middle',
         'Normalized Post‐Cross Time in Middle',
         "norm_pct_time_post_cross.png"),
    ]
    for metric, ylabel, title, filename in plot_specs:
        plot_metric_line(summary_df, metric, ylabel, title, out_root / filename)

# Entry point for the middle-of-arena analysis: runs main() on a hard-coded
# folder, using main()'s default arena size and middle percentage.
if __name__ == "__main__":
    ANALYSIS_FOLDER = r"D:\CK3_open_field\openfield_analysis_mask_2025-03-31\total_distance_analysis"
    main(ANALYSIS_FOLDER)


In [None]:
"""
Time Spent Moving Analysis

Automates Gaussian Mixture Model (GMM) analysis on baseline (PD00) files to establish
a velocity threshold distinguishing 'moving' vs. 'stationary' frames, then applies
that threshold to subsequent files, with optional removal of jump timepoints.

Args:
    analysis_folder (str): Directory with '*_PD##.csv' files containing centroid data.
    smoothing_n (int): Rolling window size for velocity smoothing (default 1 = no smoothing).

Outputs:
    - 'gmm_analysis_auto' directory:
      1. 'individual_plots': Histograms for each PD00 file showing the GMM fit,
      2. 'time_spent_moving': Per-day classification CSVs with 'is_moving' and 'is_removed_jump' flags,
         plus a summary of time spent moving (raw and normalized), velocity while moving,
         jump-removal-adjusted velocity metrics, and line plots.
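      3. Plots of BIC scores and velocity PDFs for the PD00 files (saved in 'gmm_analysis_auto'), and line plots of
         time spent moving and of velocity while moving, the latter both with and without jumps (saved in 'time_spent_moving').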
"""

import pandas as pd
import numpy as np
import re
from pathlib import Path
from sklearn.mixture import GaussianMixture
import matplotlib.pyplot as plt
from scipy.stats import norm
from scipy.optimize import brentq

# ------------------------------
# Global Variables & Jump Loading
# ------------------------------
# Folder handed to main() at the bottom of this cell.
ANALYSIS_FOLDER = r"D:\CK3_open_field\openfield_analysis_mask_2025-03-31\time_spent_moving_analysis\temp_velocity"
SMOOTHING_N     = 1       # Simple Moving Average. 1 = no smoothing
PX_PER_CM       = 16.8144 # Conversion factor from pixels to cm.
JUMP_FILE_PATH  = r"D:\CK3_open_field\openfield_analysis_mask_2025-03-31\jumping_timestamps_2025-05-08.csv"  # Set to None to disable

# load jump ranges {(animal, day): [(start,end), ...]}
# This runs when the cell/module is executed, so the CSV is read as soon as
# JUMP_FILE_PATH is set. The file must provide 'Animal', 'Day', 'frame_start'
# and 'frame_end' columns.
jumps = {}
if JUMP_FILE_PATH:
    jf = pd.read_csv(JUMP_FILE_PATH)
    # strip stray whitespace from the header names
    jf.columns = jf.columns.str.strip()
    # rows missing either frame bound are ignored
    for _, row in jf.dropna(subset=['frame_start','frame_end']).iterrows():
        key = (row['Animal'], row['Day'])
        jumps.setdefault(key, []).append((int(row['frame_start']), int(row['frame_end'])))

# ------------------------------
# GMM Utility Functions
# ------------------------------
def find_intersection(means, covars, weights):
    """
    Find where two weighted Gaussian densities cross, preferring the crossing between their means.

    Two Gaussians with different variances generally cross twice. The crossing
    that separates the two components lies between the means, so a sign change
    in that interval is used when one exists. If there is none, the first sign
    change in the scanned range is used.

    Args:
        means: Two component means.
        covars: Two component variances (square-rooted here to get std devs).
        weights: Two mixture weights.

    Returns:
        float: x where w0*N(x; m0, s0) == w1*N(x; m1, s1), or NaN when the
        weighted densities do not cross within the scanned range.
    """
    m0, m1 = means
    s0, s1 = np.sqrt(covars)
    w0, w1 = weights

    def diff(x):
        return w0 * norm.pdf(x, m0, s0) - w1 * norm.pdf(x, m1, s1)

    # scan +/- 4 std devs around both components, clipped at zero
    lower = max(0, min(m0 - 4*s0, m1 - 4*s1))
    upper = max(m0 + 4*s0, m1 + 4*s1)
    xs = np.linspace(lower, upper, 1000)
    diffs = diff(xs)
    changes = np.where(np.diff(np.sign(diffs)))[0]
    if not changes.size:
        return np.nan

    # Keep only grid cells that overlap [min(mean), max(mean)]. This avoids
    # returning the tail crossing below the lower mean when the wider
    # component dominates near the start of the scan.
    lo_mean, hi_mean = min(m0, m1), max(m0, m1)
    between = changes[(xs[changes + 1] >= lo_mean) & (xs[changes] <= hi_mean)]
    idx = between[0] if between.size else changes[0]
    # refine the bracketed root
    return brentq(diff, xs[idx], xs[idx+1])

def find_n_components_bic(data, max_components=5):
    """
    Choose the number of GMM components at the elbow of the BIC curve.

    A GaussianMixture (random_state=0) is fitted for every component count from
    1 to `max_components`. The elbow is the point of the (count, BIC) curve
    farthest from the straight line joining its first and last points.

    Returns:
        tuple: (chosen component count, the fitted model for that count,
        list of BIC values for every count tried).
    """
    counts = np.arange(1, max_components + 1)
    fitted = [GaussianMixture(n_components=k, random_state=0).fit(data) for k in counts.tolist()]
    scores = [model.bic(data) for model in fitted]

    # curve expressed relative to its first point
    score_arr = np.array(scores)
    curve = np.vstack((counts - counts[0], score_arr - score_arr[0])).T

    # unit vector along the chord from the first to the last point
    chord = np.array([counts[-1] - counts[0], score_arr[-1] - score_arr[0]])
    unit = chord / np.linalg.norm(chord)

    # perpendicular distance = length of the component orthogonal to the chord
    along = np.outer(np.dot(curve, unit), unit)
    perpendicular = np.linalg.norm(curve - along, axis=1)

    elbow = np.argmax(perpendicular)
    return elbow + 1, fitted[elbow], scores

def compute_velocity(df, smoothing_n):
    """
    Add frame-to-frame centroid speed and its rolling mean to `df`.

    Speed for row i is the Euclidean centroid displacement from row i-1 divided
    by the elapsed 'real_time_s'. The first row has no predecessor and is NaN.
    Values are assigned by position, so the function works whatever index
    `df` carries (for example after rows have been filtered out).

    Args:
        df (pd.DataFrame): Needs 'centroid_x', 'centroid_y' and 'real_time_s'.
            Modified in place.
        smoothing_n (int): Rolling-mean window; 1 leaves the speed unsmoothed.

    Returns:
        pd.DataFrame: The same `df`, with 'velocity_px_s' and
        'smoothed_velocity_px_s' columns added.
    """
    dx = np.diff(df['centroid_x'])
    dy = np.diff(df['centroid_y'])
    dt = np.diff(df['real_time_s'])

    # Build the full-length column positionally; a label slice such as
    # df.loc[1:] only lines up with a default 0..n-1 index.
    speeds = np.full(len(df), np.nan)
    speeds[1:] = np.sqrt(dx**2 + dy**2) / dt
    df['velocity_px_s'] = speeds

    # min_periods=1 lets the window shrink at the start instead of emitting NaN
    df['smoothed_velocity_px_s'] = df['velocity_px_s'].rolling(window=smoothing_n, min_periods=1).mean()
    return df

# ------------------------------
# Plotting Functions
# ------------------------------
def plot_individual_histogram(res, plots_dir, smoothing_n):
    """
    Save a velocity histogram for one PD00 file with its fitted GMM overlaid.

    The plot shows the density histogram, the total mixture PDF, each weighted
    component PDF, and the threshold as a vertical line. Every artist carries
    its own label so that the legend entries match what they describe.

    Args:
        res (dict): Result from process_pd00_file. Reads 'base_name',
            'valid_velocities', 'gmm', 'threshold' and 'n_components'.
        plots_dir (Path): Directory that receives the PNG.
        smoothing_n (int): Smoothing window, used in the title and filename.
    """
    aid = res['base_name'].split('_')[0]
    valid = res['valid_velocities']
    gmm, thr, nc = res['gmm'], res['threshold'], res['n_components']

    # evaluation grid across the observed velocity range
    vr = np.linspace(valid.min(), valid.max(), 1000).reshape(-1, 1)
    pdf_tot = np.exp(gmm.score_samples(vr))
    w, m, c = gmm.weights_, gmm.means_.flatten(), gmm.covariances_.flatten()

    fig, ax = plt.subplots(figsize=(8, 5))
    ax.hist(valid, bins=100, density=True, alpha=0.5, label='Data')
    ax.plot(vr, pdf_tot, color='black', label='GMM total')
    for i in range(nc):
        ax.plot(vr, w[i]*norm.pdf(vr.flatten(), m[i], np.sqrt(c[i])),
                linestyle='--', label=f'Comp {i}')
    ax.axvline(thr, color='red', linestyle=':', label=f'Thr {thr:.2f}')
    ax.set_title(f'{aid} (gmm-{nc}, smooth-{smoothing_n})')
    ax.set_xlabel('Smoothed Velocity (px/s)')
    ax.set_ylabel('Density')
    # Labels are attached to the artists above. Passing only a label list to
    # legend() pairs labels with artists in draw order, which would put the
    # component labels on histogram bars.
    ax.legend()
    fig.tight_layout()
    fig.savefig(plots_dir / f"{res['base_name']}_gmm-{nc}_smooth-{smoothing_n}_hist.png", dpi=300)
    plt.close(fig)

def plot_bic_scores(pd00_res, out_root, smoothing_n, global_nc):
    """
    Plot the BIC curve of every PD00 file on one figure and mark each selected component count.

    Args:
        pd00_res (dict): Per-file results; reads 'bics', 'n_components' and 'base_name'.
        out_root (Path): Directory that receives the PNG.
        smoothing_n (int): Smoothing window, used in the title and filename.
        global_nc (int): Component count used in the title and filename.
    """
    fig, ax = plt.subplots(figsize=(8, 6))
    for entry in pd00_res.values():
        scores = entry['bics']
        chosen = entry['n_components']
        component_counts = np.arange(1, len(scores) + 1)
        ax.plot(component_counts, scores, label=f"{entry['base_name']} (best={chosen})")
        # black dot on the selected component count
        ax.scatter(chosen, scores[chosen - 1], color='black')

    ax.set_xticks(np.arange(1, 6))
    ax.set_xlabel('Components')
    ax.set_ylabel('BIC')
    ax.set_title(f'BIC Scores (gmm-{global_nc}, smooth-{smoothing_n})')
    ax.legend()
    fig.tight_layout()
    fig.savefig(out_root / f"gmm-{global_nc}_smooth-{smoothing_n}_bic_scores_pd00.png", dpi=300)
    plt.close(fig)

def plot_distribution_pdf(pd00_res, out_root, smoothing_n, global_nc):
    """
    Overlay the fitted mixture PDF of every PD00 file and mark each file's threshold.

    Args:
        pd00_res (dict): Per-file results; reads 'valid_velocities', 'gmm',
            'threshold' and 'base_name'.
        out_root (Path): Directory that receives the PNG.
        smoothing_n (int): Smoothing window, used in the title and filename.
        global_nc (int): Component count used in the title and filename.
    """
    fig, ax = plt.subplots(figsize=(10, 6))
    for entry in pd00_res.values():
        speeds = entry['valid_velocities']
        model = entry['gmm']
        cutoff = entry['threshold']
        # evaluate the mixture density across this file's velocity range
        grid = np.linspace(speeds.min(), speeds.max(), 1000).reshape(-1, 1)
        density = np.exp(model.score_samples(grid))
        ax.plot(grid, density, label=entry['base_name'])
        ax.axvline(cutoff, linestyle='--', color='grey')

    ax.set_xlabel('Smoothed Velocity (px/s)')
    ax.set_ylabel('Density')
    ax.set_title(f'GMM PDFs (gmm-{global_nc}, smooth-{smoothing_n})')
    ax.legend()
    fig.tight_layout()
    fig.savefig(out_root / f"gmm-{global_nc}_smooth-{smoothing_n}_velocity_pdfs_pd00.png", dpi=300)
    plt.close(fig)

def plot_time_spent_line(df, folder, smoothing_n, global_nc):
    """
    Line plot of 'pct_time_moving' against day, one trace per animal.

    Args:
        df (pd.DataFrame): Summary with 'animal_id', 'day_id' and 'pct_time_moving'.
        folder (Path): Directory that receives the PNG.
        smoothing_n (int): Smoothing window, used in the title and filename.
        global_nc (int): Component count used in the title and filename.
    """
    day_regex = r'PD(\d{2})'
    fig, ax = plt.subplots(figsize=(10, 6))

    for animal, rows in df.groupby('animal_id'):
        ordered = rows.copy()
        ordered['dnum'] = ordered['day_id'].str.extract(day_regex).astype(int)
        ordered = ordered.sort_values('dnum')
        ax.plot(ordered['dnum'], ordered['pct_time_moving'], marker='o', label=animal)

    # ticks at the numeric day, labelled with the original day_id
    tick_days = df[['day_id']].drop_duplicates()
    tick_days['dnum'] = tick_days['day_id'].str.extract(day_regex).astype(int)
    tick_days = tick_days.sort_values('dnum')
    ax.set_xticks(tick_days['dnum'])
    ax.set_xticklabels(tick_days['day_id'])
    ax.set_ylim(0, None)

    ax.set_xlabel('Days Post Stroke')
    ax.set_ylabel('% Time Moving')
    ax.set_title(f'Time Moving (gmm-{global_nc}, smooth-{smoothing_n})')
    ax.legend()
    fig.tight_layout()
    fig.savefig(folder / f"gmm-{global_nc}_smooth-{smoothing_n}_time_spent_line.png", dpi=300)
    plt.close(fig)

def plot_normalized_time_spent_line(df, folder, smoothing_n, global_nc):
    """
    Line plot of time spent moving as a percentage of each animal's PD00 value.

    The 'pct_time_moving_pct-baseline' column is value / PD00 value * 100, so
    the baseline day sits at 100. The dashed reference line is drawn there.

    Args:
        df (pd.DataFrame): Summary with 'animal_id', 'day_id' and
            'pct_time_moving_pct-baseline'.
        folder (Path): Directory that receives the PNG.
        smoothing_n (int): Smoothing window, used in the title and filename.
        global_nc (int): Component count used in the title and filename.
    """
    plt.figure(figsize=(10, 6))
    for aid, grp in df.groupby('animal_id'):
        g = grp.copy()
        g['dnum'] = g['day_id'].str.extract(r'PD(\d{2})').astype(int)
        g = g.sort_values('dnum')
        plt.plot(g['dnum'], g['pct_time_moving_pct-baseline'], marker='o', label=aid)
    # baseline (PD00) corresponds to 100% in this column, not 0
    plt.axhline(100, linestyle='--', color='grey')
    days = df[['day_id']].drop_duplicates()
    days['dnum'] = days['day_id'].str.extract(r'PD(\d{2})').astype(int)
    days = days.sort_values('dnum')
    plt.xticks(days['dnum'], days['day_id'])
    plt.ylim(0, None)
    plt.xlabel('Days Post Stroke')
    plt.ylabel('Time Moving (% of PD00 Baseline)')
    plt.title(f'Normalized Time Moving (gmm-{global_nc}, smooth-{smoothing_n})')
    plt.legend()
    plt.tight_layout()
    plt.savefig(folder / f"gmm-{global_nc}_smooth-{smoothing_n}_time_spent_norm_line.png", dpi=300)
    plt.close()

def plot_velocity_while_moving_line(df, folder, smoothing_n, global_nc):
    """
    Line plot of 'mean_velocity_while_moving_cm_s' against day, one trace per animal.

    Args:
        df (pd.DataFrame): Summary with 'animal_id', 'day_id' and
            'mean_velocity_while_moving_cm_s'.
        folder (Path): Directory that receives the PNG.
        smoothing_n (int): Smoothing window, used in the title and filename.
        global_nc (int): Component count used in the title and filename.
    """
    day_regex = r'PD(\d{2})'
    fig, ax = plt.subplots(figsize=(10, 6))

    for animal, rows in df.groupby('animal_id'):
        ordered = rows.copy()
        ordered['dnum'] = ordered['day_id'].str.extract(day_regex).astype(int)
        ordered = ordered.sort_values('dnum')
        ax.plot(ordered['dnum'], ordered['mean_velocity_while_moving_cm_s'], marker='o', label=animal)

    # ticks at the numeric day, labelled with the original day_id
    tick_days = df[['day_id']].drop_duplicates()
    tick_days['dnum'] = tick_days['day_id'].str.extract(day_regex).astype(int)
    tick_days = tick_days.sort_values('dnum')
    ax.set_xticks(tick_days['dnum'])
    ax.set_xticklabels(tick_days['day_id'])
    ax.set_ylim(0, None)

    ax.set_xlabel('Days Post Stroke')
    ax.set_ylabel('Mean Vel (cm/s)')
    ax.set_title(f'Velocity Moving (gmm-{global_nc}, smooth-{smoothing_n})')
    ax.legend()
    fig.tight_layout()
    fig.savefig(folder / f"gmm-{global_nc}_smooth-{smoothing_n}_vel_moving_line.png", dpi=300)
    plt.close(fig)

def plot_velocity_while_moving_normalized_line(df, folder, smoothing_n, global_nc):
    """
    Line plot of mean velocity while moving as a percentage of each animal's PD00 value.

    The 'mean_velocity_while_moving_pct-baseline' column is value / PD00 value
    * 100, so the baseline day sits at 100. The dashed reference line is drawn there.

    Args:
        df (pd.DataFrame): Summary with 'animal_id', 'day_id' and
            'mean_velocity_while_moving_pct-baseline'.
        folder (Path): Directory that receives the PNG.
        smoothing_n (int): Smoothing window, used in the title and filename.
        global_nc (int): Component count used in the title and filename.
    """
    plt.figure(figsize=(10, 6))
    for aid, grp in df.groupby('animal_id'):
        g = grp.copy()
        g['dnum'] = g['day_id'].str.extract(r'PD(\d{2})').astype(int)
        g = g.sort_values('dnum')
        plt.plot(g['dnum'], g['mean_velocity_while_moving_pct-baseline'], marker='o', label=aid)
    # baseline (PD00) corresponds to 100% in this column, not 0
    plt.axhline(100, linestyle='--', color='grey')
    days = df[['day_id']].drop_duplicates()
    days['dnum'] = days['day_id'].str.extract(r'PD(\d{2})').astype(int)
    days = days.sort_values('dnum')
    plt.xticks(days['dnum'], days['day_id'])
    plt.ylim(0, None)
    plt.xlabel('Days Post Stroke')
    plt.ylabel('Mean Vel (% of PD00 Baseline)')
    plt.title(f'Normalized Velocity Moving (gmm-{global_nc}, smooth-{smoothing_n})')
    plt.legend()
    plt.tight_layout()
    plt.savefig(folder / f"gmm-{global_nc}_smooth-{smoothing_n}_vel_moving_norm_line.png", dpi=300)
    plt.close()

def plot_velocity_nojumps_line(df, folder, smoothing_n, global_nc):
    """
    Line plot of 'mean_velocity_while_moving_nojumps_cm_s' against day, one trace per animal.

    Args:
        df (pd.DataFrame): Summary with 'animal_id', 'day_id' and
            'mean_velocity_while_moving_nojumps_cm_s'.
        folder (Path): Directory that receives the PNG.
        smoothing_n (int): Smoothing window, used in the title and filename.
        global_nc (int): Component count used in the title and filename.
    """
    day_regex = r'PD(\d{2})'
    fig, ax = plt.subplots(figsize=(10, 6))

    for animal, rows in df.groupby('animal_id'):
        ordered = rows.copy()
        ordered['dnum'] = ordered['day_id'].str.extract(day_regex).astype(int)
        ordered = ordered.sort_values('dnum')
        ax.plot(ordered['dnum'], ordered['mean_velocity_while_moving_nojumps_cm_s'], marker='o', label=animal)

    # ticks at the numeric day, labelled with the original day_id
    tick_days = df[['day_id']].drop_duplicates()
    tick_days['dnum'] = tick_days['day_id'].str.extract(day_regex).astype(int)
    tick_days = tick_days.sort_values('dnum')
    ax.set_xticks(tick_days['dnum'])
    ax.set_xticklabels(tick_days['day_id'])
    ax.set_ylim(0, None)

    ax.set_xlabel('Days Post Stroke')
    ax.set_ylabel('Mean Vel No Jumps (cm/s)')
    ax.set_title(f'Velocity No Jumps (gmm-{global_nc}, smooth-{smoothing_n})')
    ax.legend()
    fig.tight_layout()
    fig.savefig(folder / f"gmm-{global_nc}_smooth-{smoothing_n}_vel_nojumps_line.png", dpi=300)
    plt.close(fig)

def plot_velocity_nojumps_normalized_line(df, folder, smoothing_n, global_nc):
    """
    Line plot of jump-excluded mean velocity as a percentage of each animal's PD00 value.

    The 'mean_velocity_while_moving_nojumps_pct-baseline' column is value /
    PD00 value * 100, so the baseline day sits at 100. The dashed reference
    line is drawn there.

    Args:
        df (pd.DataFrame): Summary with 'animal_id', 'day_id' and
            'mean_velocity_while_moving_nojumps_pct-baseline'.
        folder (Path): Directory that receives the PNG.
        smoothing_n (int): Smoothing window, used in the title and filename.
        global_nc (int): Component count used in the title and filename.
    """
    plt.figure(figsize=(10, 6))
    for aid, grp in df.groupby('animal_id'):
        g = grp.copy()
        g['dnum'] = g['day_id'].str.extract(r'PD(\d{2})').astype(int)
        g = g.sort_values('dnum')
        plt.plot(g['dnum'], g['mean_velocity_while_moving_nojumps_pct-baseline'], marker='o', label=aid)
    # baseline (PD00) corresponds to 100% in this column, not 0
    plt.axhline(100, linestyle='--', color='grey')
    days = df[['day_id']].drop_duplicates()
    days['dnum'] = days['day_id'].str.extract(r'PD(\d{2})').astype(int)
    days = days.sort_values('dnum')
    plt.xticks(days['dnum'], days['day_id'])
    plt.ylim(0, None)
    plt.xlabel('Days Post Stroke')
    plt.ylabel('Mean Vel No Jumps (% of PD00 Baseline)')
    plt.title(f'Normalized Vel No Jumps (gmm-{global_nc}, smooth-{smoothing_n})')
    plt.legend()
    plt.tight_layout()
    plt.savefig(folder / f"gmm-{global_nc}_smooth-{smoothing_n}_vel_nojumps_norm_line.png", dpi=300)
    plt.close()

# ------------------------------
# PD00 Processing with Jump Removal
# ------------------------------
def process_pd00_file(file, smoothing_n):
    """
    Fit a velocity GMM to one baseline file and derive its moving/stationary threshold.

    Rows inside any jump range registered for this (animal, day) are dropped
    before fitting. The threshold is the crossing of the two lowest-mean
    components, or NaN when only one component is selected.

    Args:
        file (Path): CSV named '<animal>_<day>.csv'.
        smoothing_n (int): Rolling window passed to compute_velocity.

    Returns:
        dict or None: Keys 'base_name', 'df', 'valid_velocities',
        'n_components', 'gmm', 'bics' and 'threshold'; None when fewer than
        two velocity samples remain.
    """
    frames = pd.read_csv(file)
    frames.columns = frames.columns.str.strip()
    frames = compute_velocity(frames, smoothing_n)

    animal, day_label = file.stem.split('_')
    # remove jumps before GMM
    for first_frame, last_frame in jumps.get((animal, day_label), []):
        inside_jump = frames['frame'].between(first_frame, last_frame)
        frames = frames[~inside_jump]

    samples = frames['smoothed_velocity_px_s'].dropna().values.reshape(-1, 1)
    if len(samples) < 2:
        return None

    n_best, model, bic_values = find_n_components_bic(samples)

    if n_best >= 2:
        weights = model.weights_
        centers = model.means_.flatten()
        variances = model.covariances_.flatten()
        # the two slowest components
        slowest_two = np.argsort(centers)[:2]
        cutoff = find_intersection(centers[slowest_two], variances[slowest_two], weights[slowest_two])
    else:
        cutoff = np.nan

    return {
        'base_name': file.stem,
        'df': frames,
        'valid_velocities': samples,
        'n_components': n_best,
        'gmm': model,
        'bics': bic_values,
        'threshold': cutoff
    }

def process_pd00_files(pd00_files, smoothing_n, plots_dir):
    """
    Analyse every baseline file and save a histogram for each successful fit.

    Files for which process_pd00_file returns a falsy result are skipped.

    Returns:
        dict: Per-file results keyed by their 'base_name'.
    """
    collected = {}
    for path in pd00_files:
        outcome = process_pd00_file(path, smoothing_n)
        if outcome:
            collected[outcome['base_name']] = outcome
            plot_individual_histogram(outcome, plots_dir, smoothing_n)
    return collected

def save_pd00_summary_metrics(pd00_res, out_root, smoothing_n, global_nc):
    """
    Write the component count and threshold of each PD00 file to one CSV.

    The file is 'gmm-{global_nc}_smooth-{smoothing_n}_PD00_summary_metrics.csv'
    inside `out_root`. Nothing is written when `pd00_res` is empty.

    Args:
        pd00_res (dict): Results keyed by '<animal>_<day>'; reads
            'n_components' and 'threshold' from each value.
        out_root (Path): Destination directory.
        smoothing_n (int): Smoothing window, used in the filename.
        global_nc (int): Component count used in the filename.
    """
    rows = [
        {
            'file': f"{base_name}.csv",
            'animal_id': animal,
            'day_id': day_label,
            'n_components': result['n_components'],
            'threshold': result['threshold']
        }
        for base_name, result in pd00_res.items()
        # unpack the two halves of '<animal>_<day>'
        for animal, day_label in (base_name.split('_'),)
    ]
    if not rows:
        return

    target = out_root / f"gmm-{global_nc}_smooth-{smoothing_n}_PD00_summary_metrics.csv"
    pd.DataFrame(rows).to_csv(target, index=False)

# ------------------------------
# Main Analysis
# ------------------------------
def process_all_files(folder, smoothing_n, thresholds, time_dir):
    """
    Classify every frame of every matching CSV as moving or not and summarise per file.

    Files named like 'AA###_PD##.csv' are processed in sorted order. A file is
    skipped when its animal has no entry in `thresholds`. Each processed file
    is re-saved in `time_dir` with 'is_removed_jump' and 'is_moving' columns.
    The '*_pct-baseline' summary columns express each value as a percentage of
    the same animal's PD00 value; they are NaN when that baseline is missing
    or zero.

    Args:
        folder (Path): Directory holding the per-day CSVs.
        smoothing_n (int): Rolling window passed to compute_velocity.
        thresholds (dict): animal_id -> {'threshold': float, 'n_components': int}.
        time_dir (Path): Directory that receives the per-file classification CSVs.

    Returns:
        pd.DataFrame: One row per processed file. With no processed files the
        frame is empty but still carries the expected columns.
    """
    summary_cols = [
        'animal_id','day_id',
        'pct_time_moving','pct_time_moving_pct-baseline',
        'mean_velocity_while_moving_cm_s','mean_velocity_while_moving_pct-baseline',
        'mean_velocity_while_moving_nojumps_cm_s','mean_velocity_while_moving_nojumps_pct-baseline',
        'n_components','smoothing_n'
    ]
    pattern = re.compile(r"^[A-Za-z]{2}\d{3}_PD\d{2}\.csv$")
    # sorted so the summary row order does not depend on directory listing order
    files = sorted(f for f in folder.iterdir() if f.is_file() and pattern.match(f.name))
    summary, base_pct, base_vel, base_vel_noj = [], {}, {}, {}

    for f in files:
        aid, day = f.stem.split('_')
        if aid not in thresholds:
            continue
        df = pd.read_csv(f)
        df.columns = df.columns.str.strip()
        df = compute_velocity(df, smoothing_n)

        # flag frames that fall inside any registered jump range
        df['is_removed_jump'] = False
        for start, end in jumps.get((aid, day), []):
            df.loc[df['frame'].between(start, end), 'is_removed_jump'] = True

        thr = thresholds[aid]['threshold']
        df['is_moving'] = df['smoothed_velocity_px_s'] > thr

        pct = df['is_moving'].mean() * 100
        mv = df.loc[df['is_moving'], 'smoothed_velocity_px_s']
        mv_cm = mv.mean()/PX_PER_CM if not mv.empty else np.nan
        mv_noj = df.loc[df['is_moving'] & ~df['is_removed_jump'], 'smoothed_velocity_px_s']
        mv_noj_cm = mv_noj.mean()/PX_PER_CM if not mv_noj.empty else np.nan

        # remember each animal's baseline values for the normalization pass
        if day == "PD00":
            base_pct[aid] = pct
            base_vel[aid] = mv_cm
            base_vel_noj[aid] = mv_noj_cm

        df.to_csv(
            time_dir / f"{f.stem}_gmm-{thresholds[aid]['n_components']}_smooth-{smoothing_n}.csv",
            index=False
        )

        summary.append({
            'animal_id': aid,
            'day_id': day,
            'pct_time_moving': pct,
            'pct_time_moving_pct-baseline': np.nan,
            'mean_velocity_while_moving_cm_s': mv_cm,
            'mean_velocity_while_moving_pct-baseline': np.nan,
            'mean_velocity_while_moving_nojumps_cm_s': mv_noj_cm,
            'mean_velocity_while_moving_nojumps_pct-baseline': np.nan,
            'n_components': thresholds[aid]['n_components'],
            'smoothing_n': smoothing_n
        })

    # Second pass: the baselines are only complete once every file has been read.
    for e in summary:
        p0 = base_pct.get(e['animal_id'], np.nan)
        e['pct_time_moving_pct-baseline'] = (e['pct_time_moving']/p0*100) if p0 else np.nan
        v0 = base_vel.get(e['animal_id'], np.nan)
        e['mean_velocity_while_moving_pct-baseline'] = (e['mean_velocity_while_moving_cm_s']/v0*100) if v0 else np.nan
        vnj0 = base_vel_noj.get(e['animal_id'], np.nan)
        e['mean_velocity_while_moving_nojumps_pct-baseline'] = (e['mean_velocity_while_moving_nojumps_cm_s']/vnj0*100) if vnj0 else np.nan

    # columns= fixes the column order and keeps the schema when summary is
    # empty; selecting the columns from pd.DataFrame([]) would raise KeyError.
    return pd.DataFrame(summary, columns=summary_cols)

def main(analysis_folder, smoothing_n=SMOOTHING_N):
    """
    Run the GMM-based time-spent-moving analysis on a folder.

    Thresholds are derived from the '*_PD00.csv' files and then applied to
    every per-day file. The summary CSV and all line plots are written under
    '<analysis_folder>/gmm_analysis_auto'. Prints a message and stops when no
    PD00 file yields a result.

    Args:
        analysis_folder (str or Path): Directory with the per-day CSV files.
        smoothing_n (int): Rolling window for velocity smoothing.
    """
    source_dir = Path(analysis_folder)
    out_root = source_dir / "gmm_analysis_auto"
    hist_dir = out_root / "individual_plots"
    moving_dir = out_root / "time_spent_moving"
    for directory in (hist_dir, moving_dir):
        directory.mkdir(parents=True, exist_ok=True)

    # PD00 analysis
    baseline_files = [entry for entry in source_dir.iterdir() if entry.name.endswith("_PD00.csv")]
    pd00_res = process_pd00_files(baseline_files, smoothing_n, hist_dir)
    if not pd00_res:
        print("No valid PD00 files found.")
        return

    # Output names use the first PD00 result's component count; this is only
    # accurate when every PD00 file selected the same number of components.
    global_nc = next(iter(pd00_res.values()))['n_components']

    save_pd00_summary_metrics(pd00_res, out_root, smoothing_n, global_nc)
    plot_bic_scores(pd00_res, out_root, smoothing_n, global_nc)
    plot_distribution_pdf(pd00_res, out_root, smoothing_n, global_nc)

    # per-animal threshold lookup, keyed by the animal part of the base name
    thresholds = {}
    for base_name, result in pd00_res.items():
        thresholds[base_name.split('_')[0]] = {
            'threshold': result['threshold'],
            'n_components': result['n_components'],
        }

    summary_df = process_all_files(source_dir, smoothing_n, thresholds, moving_dir)

    # summary file
    summary_path = moving_dir / f"gmm-{global_nc}_smooth-{smoothing_n}_time_spent_moving_summary.csv"
    summary_df.to_csv(summary_path, index=False)

    # plots
    line_plotters = (
        plot_time_spent_line,
        plot_normalized_time_spent_line,
        plot_velocity_while_moving_line,
        plot_velocity_while_moving_normalized_line,
        plot_velocity_nojumps_line,
        plot_velocity_nojumps_normalized_line,
    )
    for draw in line_plotters:
        draw(summary_df, moving_dir, smoothing_n, global_nc)

# Entry point for the time-spent-moving analysis: uses the folder and smoothing
# window defined in the globals at the top of this cell.
if __name__ == "__main__":
    main(ANALYSIS_FOLDER, SMOOTHING_N)


In [None]:
# Create videos with moving/still overlays based on GMM velocity thresholds
"""
Overlays a color-coded "moving" or "still" status on each video frame using thresholds
derived from a GMM or similar baseline approach.

1. Extracts a velocity threshold (or thresholds) from a summary metrics CSV.
2. For each per-frame CSV in `csv_folder` (matching '*_processed.csv'):
    - Computes smoothed velocity (rolling average) from centroid displacements.
    - Finds the corresponding video (named '{animal_id}_{day_id}-trim_crop.mp4').
    - For each frame, compares velocity vs. threshold, overlaying "PROGRESS" (moving)
    or "LINGER" (still).
3. Saves a new annotated video to an output directory named according to:
    the threshold type ("baseline" or "mean") plus the GMM parameters.

Args:
    videos_folder (str or Path):
        Directory containing the raw/truncated videos (named '{animal_id}_{day_id}-trim_crop.mp4').
    csv_folder (str or Path):
        Directory containing '*_processed.csv' files with centroid/time columns.
    threshold_csv_path (str or Path):
        Path to the summary metrics CSV that includes a 'threshold' column.
        Expected filename format: 'method-method_n_smooth-smoothing_n_summary_metrics.csv'.
    frame_width (int):
        Desired output video width in pixels.
    frame_height (int):
        Desired output video height in pixels.
    threshold_type (str):
        Either 'baseline' or 'mean'.
        - 'baseline': Use each animal's PD00 threshold.
        - 'mean': Use the mean PD00 threshold across all animals.

Notes:
    - If video length > CSV rows, extra frames default to zero velocity.
    - If CSV rows > video length, the extra rows go unused.
    - Filenames for both videos and CSVs must follow a consistent pattern to match
    animals and days (e.g., '{animal_id}_{day_id}-trim_crop.mp4' and
    '{animal_id}_{day_id}_processed.csv').
    - In 'baseline' mode, animals without PD00 entries in `threshold_csv_path` are skipped.
"""

import cv2
import pandas as pd
import numpy as np
import re
from pathlib import Path

def extract_metadata_from_filename(path):
    """
    Extracts metadata from a threshold CSV filename.

    Expected filename format: method-method_n_smooth-smoothing_n_summary_metrics.csv
    An optional '_PD##' tag before '_summary_metrics' is also accepted, which
    is the form written by save_pd00_summary_metrics.
    Examples: "gmm-2_smooth-1_summary_metrics.csv",
              "gmm-2_smooth-1_PD00_summary_metrics.csv"

    Args:
        path (str or Path): Path to the summary metrics CSV; only the file stem is parsed.

    Returns:
      - method: The base method name (e.g., "gmm")
      - method_n: The method version/number as a string (e.g., "2")
      - smoothing_n: The smoothing factor as an integer (e.g., 1)

    Raises:
        ValueError: If the filename does not follow the expected format.
    """
    filename = Path(path).stem
    # (?:_PD\d{2})? tolerates the day tag without capturing it
    match = re.match(r"([^-]+)-(\d+)_smooth-(\d+)(?:_PD\d{2})?_summary_metrics", filename)
    if not match:
        raise ValueError(
            "Filename must follow format: method-method_n_smooth-smoothing_n_summary_metrics "
            f"(got '{filename}')"
        )
    method, method_n, smoothing_n = match.groups()
    return method, method_n, int(smoothing_n)

def _smoothed_velocity_px_s(df, smoothing_n):
    """Return the rolling-mean centroid speed for each CSV row as a float array (row 0 is NaN)."""
    dx = np.diff(df["centroid_x"])
    dy = np.diff(df["centroid_y"])
    dt = np.diff(df["real_time_s"])
    raw = np.full(len(df), np.nan)
    raw[1:] = np.sqrt(dx**2 + dy**2) / dt
    return pd.Series(raw).rolling(window=smoothing_n, min_periods=1).mean().to_numpy()


def _write_status_video(video_path, out_path, is_progress, default_progress, frame_width, frame_height):
    """Write a resized copy of the video with a centred LINGER/PROGRESS label; return False if it cannot be opened."""
    cap = cv2.VideoCapture(str(video_path))
    if not cap.isOpened():
        cap.release()
        return False

    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 1
    thickness = 2

    # The overlay text, colour and centred position depend only on the status,
    # so compute them once instead of once per frame.
    overlays = {}
    for moving, (status_text, color) in {False: ("LINGER", (0, 0, 255)),
                                         True: ("PROGRESS", (0, 255, 0))}.items():
        (text_width, text_height), _ = cv2.getTextSize(status_text, font, font_scale, thickness)
        origin = ((frame_width - text_width) // 2, (frame_height + text_height) // 2)
        overlays[moving] = (status_text, origin, color)

    out = None
    try:
        fps = cap.get(cv2.CAP_PROP_FPS)
        fourcc = cv2.VideoWriter_fourcc(*"mp4v")
        out = cv2.VideoWriter(str(out_path), fourcc, fps, (frame_width, frame_height))

        n_rows = len(is_progress)
        frame_idx = 0
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            frame = cv2.resize(frame, (frame_width, frame_height))
            # frames beyond the CSV fall back to the zero-velocity classification
            moving = bool(is_progress[frame_idx]) if frame_idx < n_rows else default_progress
            status_text, origin, color = overlays[moving]
            cv2.putText(frame, status_text, origin, font, font_scale, color, thickness)
            out.write(frame)
            frame_idx += 1
    finally:
        # release both handles even if reading or writing raises
        cap.release()
        if out is not None:
            out.release()
    return True


def process_videos(videos_folder, csv_folder, threshold_csv_path, frame_width, frame_height, threshold_type):
    """
    Process videos to display moving/still status based on either animal-specific PD00 thresholds (baseline)
    or the mean PD00 threshold (mean).

    For every '{animal_id}_{day_id}_processed.csv' in `csv_folder`, the matching
    '{animal_id}_{day_id}-trim_crop.mp4' in `videos_folder` is re-encoded with
    "PROGRESS" drawn on frames whose smoothed velocity is at or above the
    threshold and "LINGER" on all others (including frames with NaN velocity).
    Frames beyond the end of the CSV are treated as zero velocity. Output videos
    go to a 'moving_threshold_videos_*' directory inside `csv_folder`.

    Files are skipped, with a printed message, when the threshold is missing or
    NaN, required columns are absent, or the video is missing or cannot be
    opened. A NaN threshold would otherwise label every frame "PROGRESS".

    Args:
        videos_folder (str or Path): Directory containing the source videos.
        csv_folder (str or Path): Directory containing the '*_processed.csv' files.
        threshold_csv_path (str or Path): Summary metrics CSV with 'animal_id',
            'day_id' and 'threshold' columns; its filename encodes the method
            and smoothing window.
        frame_width (int): Output video width in pixels.
        frame_height (int): Output video height in pixels.
        threshold_type (str): 'baseline' (per-animal PD00 threshold) or 'mean'
            (mean PD00 threshold across animals).

    Raises:
        ValueError: If `threshold_type` is neither 'baseline' nor 'mean', or
            the threshold filename does not follow the expected format.
    """
    method, method_n, smoothing_n = extract_metadata_from_filename(threshold_csv_path)
    print(f"Using threshold method: {method}-{method_n}")
    print(f"Smoothing factor: {smoothing_n}")

    df_thresh = pd.read_csv(threshold_csv_path)
    pd00_thresh = df_thresh[df_thresh["day_id"] == "PD00"]

    if threshold_type == "baseline":
        animal_thresholds = dict(zip(pd00_thresh["animal_id"], pd00_thresh["threshold"]))
        avg_thresh = None
        thresh_tag = "Baseline"
    elif threshold_type == "mean":
        animal_thresholds = None
        avg_thresh = pd00_thresh["threshold"].mean()
        if pd.isna(avg_thresh):
            print("No usable PD00 threshold found; no videos written.")
            return
        thresh_tag = f"{avg_thresh:.2f}"
    else:
        raise ValueError("threshold_type must be either 'baseline' or 'mean'")

    output_dir = Path(csv_folder) / f"moving_threshold_videos_{method}-{method_n}_smooth-{smoothing_n}_thresh-{thresh_tag}"
    output_dir.mkdir(parents=True, exist_ok=True)

    for csv_file in Path(csv_folder).glob("*_processed.csv"):
        match = re.match(r"([A-Za-z0-9]+)_([A-Za-z0-9]+)_processed\.csv", csv_file.name)
        if not match:
            continue
        animal_id, day_id = match.groups()

        if threshold_type == "baseline":
            if animal_id not in animal_thresholds:
                print(f"Missing PD00 threshold for {animal_id}, skipping.")
                continue
            thresh = animal_thresholds[animal_id]
            if pd.isna(thresh):
                print(f"PD00 threshold for {animal_id} is NaN, skipping.")
                continue
        else:
            thresh = avg_thresh

        df = pd.read_csv(csv_file)

        if not {'centroid_x', 'centroid_y', 'real_time_s'}.issubset(df.columns):
            print(f"Missing centroid/time columns in {csv_file.name}, skipping.")
            continue

        velocities = _smoothed_velocity_px_s(df, smoothing_n)
        # NaN velocities compare False and are therefore labelled LINGER
        with np.errstate(invalid="ignore"):
            is_progress = velocities >= thresh
        default_progress = bool(0 >= thresh)

        video_name = f"{animal_id}_{day_id}-trim_crop.mp4"
        video_path = Path(videos_folder) / video_name
        if not video_path.exists():
            print(f"Video not found: {video_path}")
            continue

        out_name = f"{animal_id}_{day_id}-moving_{method}-{method_n}_thresh-{thresh_tag}_smooth-{smoothing_n}.mp4"
        out_path = output_dir / out_name

        if not _write_status_video(video_path, out_path, is_progress, default_progress,
                                   frame_width, frame_height):
            print(f"Cannot open video: {video_path}")
            continue
        print(f"Processed: {out_path}")

# Run configuration for the overlay-video cell.
VIDEOS_FOLDER = r"D:\CK3_open_field\videos"
CSV_FOLDER = r"D:\CK3_open_field\openfield_analysis_mask_2025-03-31"
# NOTE(review): the GMM cell in this file writes its PD00 summary as
# 'gmm-{nc}_smooth-{n}_PD00_summary_metrics.csv' -- confirm that this path
# points at the intended file.
THRESHOLD_CSV_PATH = r"D:\CK3_open_field\openfield_analysis_mask_2025-03-31\gmm_analysis_auto\gmm-2_smooth-1_summary_metrics.csv"
# Output video size in pixels; every frame is resized to this.
FRAME_WIDTH = 700
FRAME_HEIGHT = 700
THRESHOLD_TYPE = "baseline"  # or "mean"

# Entry point: annotate every matching video with the settings above.
if __name__ == "__main__":
    process_videos(
        videos_folder=VIDEOS_FOLDER,
        csv_folder=CSV_FOLDER,
        threshold_csv_path=THRESHOLD_CSV_PATH,
        frame_width=FRAME_WIDTH,
        frame_height=FRAME_HEIGHT,
        threshold_type=THRESHOLD_TYPE
    )
