# In [None]:
import numpy as np
import pandas as pd
import json
import matplotlib.pyplot as plt
import os

def load_data(file_path, px_to_cm):
    """Load a saved trajectory and rescale its coordinates to centimetres.

    Parameters
    ----------
    file_path : str
        Path to a file readable by ``np.load``. The loaded array must have
        exactly two columns, which are labelled ``'x'`` and ``'y'``; any other
        shape makes the DataFrame constructor raise.
    px_to_cm : float
        Scale factor (cm per pixel) applied to both coordinates.

    Returns
    -------
    pandas.DataFrame
        Frame with columns ``'x'`` and ``'y'`` multiplied by *px_to_cm*.
    """
    raw = np.load(file_path)
    frame = pd.DataFrame(raw, columns=['x', 'y'])
    for axis in ('x', 'y'):
        frame[axis] = frame[axis] * px_to_cm
    return frame

def calculate_distance_speed(df, fps=30):
    """Compute path length and speed statistics for a trajectory.

    Side effect: adds the per-frame columns ``dx``, ``dy``, ``distance`` and
    ``speed`` to *df* in place, so the caller's frame is modified.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain ``'x'`` and ``'y'`` columns (in this script they have
        already been scaled by ``load_data``).
    fps : int or float, default 30
        Frame rate used to turn per-frame distance into distance per second.

    Returns
    -------
    tuple
        ``(total_distance, avg_speed, max_speed)``. The pandas reductions skip
        the NaN produced for the first frame by ``diff()``.
    """
    step_x = df['x'].diff()
    step_y = df['y'].diff()
    step_len = np.sqrt(step_x ** 2 + step_y ** 2)
    df['dx'] = step_x
    df['dy'] = step_y
    df['distance'] = step_len
    # Per-frame distance times frames-per-second gives distance per second.
    df['speed'] = step_len * fps

    speed = df['speed']
    return df['distance'].sum(), speed.mean(), speed.max()

def get_px_to_cm(trace_file_path, length_of_side=30, roi_file="roi_data.json"):
    """Return the cm-per-pixel scale for the arena a trace was recorded in.

    The ROI file is a JSON object keyed by sample name; each value is a list of
    corner coordinates (pixels). The sample name is built from the first three
    underscore-separated fields of the trace file's base name, taking the text
    before the first '.'. The distance between the first two corners is taken
    as one side of the arena, whose real length is *length_of_side*.

    Parameters
    ----------
    trace_file_path : str
        Path of the trajectory file; only its base name is used.
    length_of_side : float, default 30
        Real-world length of the side spanned by corners 0 and 1 (cm per the
        original docstring).
    roi_file : str, default "roi_data.json"
        Path to the ROI JSON file. Relative paths resolve against the current
        working directory.

    Returns
    -------
    float or None
        Scale in cm per pixel. Returns None (after printing a message) when the
        sample is missing from the ROI file or its first two corners coincide.
    """
    with open(roi_file, 'r', encoding='utf-8') as f:
        roi_dict = json.load(f)

    file_name = os.path.basename(trace_file_path).split(".")[0]
    sample_name = "_".join(file_name.split("_")[0:3])
    if sample_name not in roi_dict:
        print(f"ROI coordinates not found: {sample_name}")
        return None

    # float64 keeps sub-pixel precision; the former int32 cast truncated
    # fractional corner coordinates and biased the calibration.
    # NOTE(review): assumes corners 0 and 1 are adjacent (span one side) — confirm
    field_corners = np.asarray(roi_dict[sample_name], dtype=np.float64)
    field_width_px = np.linalg.norm(field_corners[0] - field_corners[1])
    if field_width_px == 0:
        # Would otherwise divide by zero and silently yield inf.
        print(f"Degenerate ROI (first two corners coincide): {sample_name}")
        return None
    return length_of_side / field_width_px

def analyze_mouse_trajectory(file_path, arena_center=(15, 15), fps=30):
    """Summarise one trajectory file as a dict of movement statistics.

    Parameters
    ----------
    file_path : str
        Path to the trajectory file passed to ``get_px_to_cm`` and
        ``load_data``.
    arena_center : tuple, default (15, 15)
        Accepted for interface compatibility; not used by this function.
    fps : int or float, default 30
        Frame rate forwarded to ``calculate_distance_speed``.

    Returns
    -------
    dict or None
        Keys "Total Distance (cm)", "Average Speed (cm/s)" and
        "Max Speed (cm/s)". Returns None when ``get_px_to_cm`` cannot provide
        a scale for this file.
    """
    scale = get_px_to_cm(file_path)
    if scale is None:
        return None
    total, mean_speed, peak_speed = calculate_distance_speed(
        load_data(file_path, scale), fps
    )
    return {
        "Total Distance (cm)": total,
        "Average Speed (cm/s)": mean_speed,
        "Max Speed (cm/s)": peak_speed,
    }

# Batch driver: analyse every .npy trace in `input_dir` and write one summary
# row per trace. File names are expected to look like "<mouse>_<?>_<date>...";
# the first three underscore-separated fields become the row labels.
input_dir = r"trace_data"
# os.listdir returns entries in arbitrary order; sorting makes the CSV
# row order reproducible.
trace_files = sorted(f for f in os.listdir(input_dir) if f.lower().endswith(".npy"))

rows = []
for file in trace_files:
    label_names = file.split(".")[0].split("_")[0:3]
    if len(label_names) < 3:
        # label_names[2] (date) below would raise IndexError.
        print(f"Unexpected file name (need 3 '_'-separated fields): {file}")
        continue
    file_path = os.path.join(input_dir, file)
    result = analyze_mouse_trajectory(file_path, arena_center=(15, 15))
    if result is None:
        continue
    rows.append({
        'sample': "_".join(label_names[0:2]),
        'mouse': label_names[0],
        "date": label_names[2],
        **result,
    })

# Build the frame once; calling pd.concat inside the loop copies all prior
# rows on every iteration (quadratic).
summary_df = pd.DataFrame(rows)
# NOTE(review): "sammary" looks like a typo for "summary"; kept unchanged
# because downstream steps may read this exact file name.
summary_df.to_csv("injury_striatum_sammary.csv", index=False)

