In [None]:
from utils_behavior import Sleap_utils
import importlib

from pathlib import Path

from matplotlib import pyplot as plt

import pandas as pd

import numpy as np
import h5py


# Working on one dataset to set up the metrics

In [None]:
# Re-import Sleap_utils so that edits made to the module's source are picked up
# without restarting the kernel.
importlib.reload(Sleap_utils)

In [None]:
# Folder of the single experiment used to develop the metrics
data_path = Path("/mnt/upramdya_data/MD/F1_Tracks/Videos/240924_F1_3mm_ends_Videos_Checked/arena4/Right")


def first_match(directory, pattern):
    """Return the alphabetically first path in `directory` matching the glob `pattern`.

    Sorting makes the choice reproducible when several files match (glob order
    depends on the file system).

    Raises:
        FileNotFoundError: if no file matches `pattern`.
    """
    matches = sorted(directory.glob(pattern))
    if not matches:
        raise FileNotFoundError(f"No file matching {pattern!r} in {directory}")
    return matches[0]


video = first_match(data_path, "*.mp4")
ball = first_match(data_path, "*ball*.h5")
fly = first_match(data_path, "*fly*.h5")

# Load the tracking files, then combine them with the video (fly first, ball second)
balltracks = Sleap_utils.Sleap_Tracks(ball, object_type="ball")
flytracks = Sleap_utils.Sleap_Tracks(fly, object_type="fly")

Experiment = Sleap_utils.CombinedSleapTracks(video, [flytracks, balltracks])

Experiment.dataset

## Find when the fly exits the first corridor

This corresponds to its x position being at least 100 px away from its starting position.

In [None]:
# Keep only the rows that track fly_1
is_fly1 = Experiment.dataset["object"] == "fly_1"
fly1 = Experiment.dataset[is_fly1]

# x position the thorax has to exceed: its initial value + 100 px
exit_threshold_x = fly1["x_thorax"].iloc[0] + 100

# First row in which fly_1 is beyond that threshold
fly1.loc[fly1["x_thorax"] > exit_threshold_x].iloc[0]

In [None]:
# Show the annotated video frame corresponding to the exit row found above.
# The frame index is hard-coded; it must be updated by hand if the data changes.
exit_frame_index = 77816

# First entry of the combined tracks list (the fly tracks were passed first)
fly_sleap_tracks = Experiment.sleap_tracks_list[0]
Exit_frame = fly_sleap_tracks.generate_annotated_frame(exit_frame_index)

plt.imshow(Exit_frame)

## Finding when the second ball is moved

For this one, I will find out when the ball has been displaced by 25 px, 50 px, 75 px and 100 px.

In [None]:
# Rows that track the second ball only
ball2 = Experiment.dataset[Experiment.dataset["object"] == "ball_2"]

# Starting y position of the ball, used as the reference for displacement
initial_y_centre = ball2["y_centre"].iloc[0]

# Displacement thresholds (px) to test
distances = [25, 50, 75, 100]

# Absolute displacement from the starting position, computed once for all thresholds
y_displacement = (ball2["y_centre"] - initial_y_centre).abs()

# Report the time of the first row beyond each threshold, if there is one
for distance in distances:
    moved = ball2[y_displacement > distance]
    if moved.empty:
        print(f"No row found where y_centre is greater than {distance} px away from initial value.")
        continue
    print(f"Ball reached {distance} px distance:")
    print(moved.iloc[0]["time"])



In [None]:
# This fly apparently never moved the ball. Look at the very last frame to see
# whether the ball is still in place.
ball_sleap_tracks = Experiment.sleap_tracks_list[1]
last_frame_index = ball_sleap_tracks.total_frames - 1
last_frame = ball_sleap_tracks.generate_annotated_frame(last_frame_index)

plt.imshow(last_frame)

In [None]:
# Check the range of values taken by the ball_2 y_centre
ball2_y = ball2["y_centre"]
(ball2_y.min(), ball2_y.max())

# Applying to whole dataset

In [None]:
# Root directory that contains every experiment's videos and tracking files
data_path = Path("/mnt/upramdya_data/MD/F1_Tracks/Videos/")

# Every mp4 below the root, sorted so that runs process them in a reproducible order
video_paths = sorted(data_path.glob("**/*.mp4"))

# Define the distances (px) to check for ball movement
distances = [25, 50, 75, 100]


def first_row_beyond(tracks, column, threshold, absolute):
    """Return the first row of `tracks` whose `column` moved more than `threshold` from its start.

    The start is the first non-NaN value of `column`, so a missing first sample
    does not silently disable the detection (any comparison against NaN is False).

    Args:
        tracks: DataFrame holding the rows of a single tracked object.
        column: name of the position column to test.
        threshold: displacement that has to be exceeded (strictly).
        absolute: if True, displacement in either direction counts
            (|value - start| > threshold); if False only an increase counts
            (value > start + threshold).

    Returns:
        The first matching row as a Series, or None when `tracks` has no valid
        sample in `column` or the threshold is never exceeded.
    """
    valid_positions = tracks[column].dropna()
    if valid_positions.empty:
        return None
    start = valid_positions.iloc[0]

    if absolute:
        beyond = (tracks[column] - start).abs() > threshold
    else:
        beyond = tracks[column] > start + threshold

    matching_rows = tracks[beyond]
    return None if matching_rows.empty else matching_rows.iloc[0]


def analyse_video(video_path, root, thresholds):
    """Compute the fly exit time and the ball movement times for one video.

    The fly and ball tracking files are the (alphabetically first) "*fly*.h5"
    and "*ball*.h5" files found in the video's own directory.

    Args:
        video_path: path of the mp4 file.
        root: directory the reported video name is made relative to.
        thresholds: ball displacements to time, one pair of result columns each.

    Returns:
        A dict with the keys "video_name", "condition", "fly_exit_time",
        "ball_movement_<d>px" and "adjusted_ball_movement_<d>px" (ball movement
        time minus fly exit time), with NaN for anything that could not be
        determined. Returns None when a tracking file is missing or cannot be
        opened as HDF5.
    """
    video_name = video_path.stem
    fly_files = sorted(video_path.parent.glob("*fly*.h5"))
    print(f"fly_files: {fly_files}")
    ball_files = sorted(video_path.parent.glob("*ball*.h5"))
    print(f"ball_files: {ball_files}")

    if not fly_files or not ball_files:
        print(f"Missing fly or ball file for {video_name}")
        return None

    fly, ball = fly_files[0], ball_files[0]

    # Check that both files can be opened as HDF5 before handing them to Sleap_utils
    try:
        for tracking_file in (fly, ball):
            with h5py.File(tracking_file, "r"):
                pass
    except OSError as e:
        print(f"Error opening file for {video_name}: {e}")
        return None

    flytracks = Sleap_utils.Sleap_Tracks(fly, object_type="fly")
    balltracks = Sleap_utils.Sleap_Tracks(ball, object_type="ball")
    dataset = Sleap_utils.CombinedSleapTracks(video_path, [flytracks, balltracks]).dataset

    # Fly exit: thorax x position more than 100 px above its starting value
    fly1 = dataset[dataset["object"] == "fly_1"]
    exit_row = first_row_beyond(fly1, "x_thorax", 100, absolute=False)
    if exit_row is None:
        print(f"No exit frame found for fly_1 in {video_name}")
        fly_exit_time = np.nan
    else:
        fly_exit_time = exit_row["time"]

    # Videos that contain a second ball are labelled "Pretrained" and analysed on
    # that ball; all others are labelled "Control" and analysed on ball_1.
    if "ball_2" in dataset["object"].values:
        ball_object, condition = "ball_2", "Pretrained"
    else:
        ball_object, condition = "ball_1", "Control"

    ball_data = dataset[dataset["object"] == ball_object]
    if ball_data["y_centre"].dropna().empty:
        # Previously this case raised IndexError and aborted the whole batch
        print(f"No {ball_object} positions found in {video_name}")

    ball_movement_times = {}
    adjusted_ball_movement_times = {}
    for distance in thresholds:
        moved_row = first_row_beyond(ball_data, "y_centre", distance, absolute=True)
        ball_movement_time = np.nan if moved_row is None else moved_row["time"]
        ball_movement_times[distance] = ball_movement_time
        # The adjusted time is only defined when both events were observed
        if pd.notna(ball_movement_time) and pd.notna(fly_exit_time):
            adjusted_ball_movement_times[distance] = ball_movement_time - fly_exit_time
        else:
            adjusted_ball_movement_times[distance] = np.nan

    print(f"Processed {video_name}")
    return {
        "video_name": str(video_path.relative_to(root)),
        "condition": condition,
        "fly_exit_time": fly_exit_time,
        **{f"ball_movement_{distance}px": ball_movement_times[distance] for distance in thresholds},
        **{f"adjusted_ball_movement_{distance}px": adjusted_ball_movement_times[distance] for distance in thresholds},
    }


# Process each video, keeping one result row per video that could be analysed
results = []
for video_path in video_paths:
    result = analyse_video(video_path, data_path, distances)
    if result is not None:
        results.append(result)

# Convert results to DataFrame
results_df = pd.DataFrame(results)

# Save results to a CSV file
#results_df.to_csv("experiment_results.csv", index=False)

# Display the results
print("Experiment Results:")
results_df.head()

In [None]:
import pandas as pd
import holoviews as hv
from holoviews import opts
hv.extension('bokeh')

# Expects results_df to hold one row per video with the columns 'video_name',
# 'condition', 'fly_exit_time' and one 'adjusted_ball_movement_<distance>px'
# column per distance threshold.
# Take the threshold columns from the frame itself so this cell stays in sync
# with whatever distances were used to build it.
adjusted_columns = [column for column in results_df.columns
                    if column.startswith('adjusted_ball_movement_')]

# Convert the DataFrame to long format for easier plotting with Holoviews
long_df = pd.melt(results_df, id_vars=['video_name', 'condition', 'fly_exit_time'],
                  value_vars=adjusted_columns,
                  var_name='distance', value_name='adjusted_time')

# Extract the distance value from the variable name.
# The pattern is a raw string because '\d' is not a valid escape in a normal
# string literal; expand=False returns a Series instead of a one-column DataFrame.
long_df['distance'] = long_df['distance'].str.extract(r'(\d+)', expand=False).astype(int)

# Create the boxplots: one box per (distance, condition) pair
boxplots = hv.BoxWhisker(long_df, ['distance', 'condition'], 'adjusted_time')

# Customize the plot
boxplots.opts(
    opts.BoxWhisker(width=800, height=400, box_fill_color=hv.Cycle('Category20'), 
                    whisker_color='black', show_legend=True, legend_position='top_right',
                    xlabel='Distance (px)', ylabel='Adjusted Time', title='Adjusted Time to Move Ball by Distance')
)

# Display the plot
#hv.save(boxplots, 'boxplots.html')
boxplots

In [None]:
# Count how many videos fall into each condition ("Control" vs "Pretrained")
results_df["condition"].value_counts()