In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from pathlib import Path

from Ballpushing_utils import *

In [None]:
# Victor's function, with comments added
def find_events_from(
    signal,
    limit_values,
    gap_between_events,
    event_min_length,
    omit_events=None,
    plot_signals=False,
    signal_name="",
):
    """
    Find events, i.e. runs of frames where a signal stays inside an open interval.

    Frames whose value lies strictly between limit_values[0] and limit_values[1]
    are grouped into candidate events; two such frames belong to the same
    candidate as long as they are no more than gap_between_events apart.
    A candidate is kept when it is longer than event_min_length and more than
    75% of its frames are inside the limits.

    Parameters:
    signal (sequence): The signal in which to find events (list, array or Series).
        All indices used and returned are positions in this sequence.
    limit_values (sequence): The lower and upper limit values for the signal (both exclusive).
    gap_between_events (int): Largest distance between two in-limit frames that
        still keeps them in the same event.
    event_min_length (int): An event is kept only if its duration is strictly greater than this.
    omit_events (sequence, optional): [omit_start, omit_end] window. An event starting inside
        [omit_start, omit_end) is dropped if it ends at or before omit_end, otherwise its
        start is moved to omit_end. Defaults to None.
    plot_signals (bool, optional): Whether to print and plot debugging information. Defaults to False.
    signal_name (str, optional): The name of the signal, used in the plot legend. Defaults to "".

    Returns:
    list: One [start, end, duration] entry per event, with duration = end - start.
    """

    events = []

    # Convert once; every index below is a position in this array.
    values = np.asarray(signal)
    lower, upper = limit_values[0], limit_values[1]
    inside = (values > lower) & (values < upper)
    frames_inside = np.where(inside)[0]

    # Nothing within the limits: there can be no events.
    if len(frames_inside) == 0:
        if plot_signals:
            print(f"No point is between {lower} and {upper}")
            plt.plot(signal, label=f"{signal_name}-filtered")
            plt.legend()
            plt.show()
        return events

    # A jump larger than the allowed gap between consecutive in-limit frames
    # closes one candidate event and opens the next.
    jumps = np.diff(frames_inside)
    split_points = np.where(jumps > gap_between_events)[0]

    # Sentinels so that the first and last candidates are also delimited:
    # candidate k covers frames_inside[split_points[k] + 1 : split_points[k + 1] + 1].
    split_points = np.insert(split_points, 0, -1)
    split_points = np.append(split_points, len(frames_inside) - 1)

    if plot_signals:
        # Height at which the markers and the limit line are drawn.
        limit_value = lower if upper == np.inf else upper
        print(frames_inside[split_points])
        plt.plot(signal, label=f"{signal_name}-filtered")

    # Explicit None/length test so that arrays and tuples work as well as lists.
    use_omit = omit_events is not None and len(omit_events) > 0

    for prev_split, next_split in zip(split_points[:-1], split_points[1:]):
        # Candidates made of a single in-limit frame are ignored.
        if next_split - prev_split < 2:
            continue

        start_roi = frames_inside[prev_split + 1]
        end_roi = frames_inside[next_split]

        if use_omit and omit_events[0] <= start_roi < omit_events[1]:
            # Ending exactly on the window edge leaves nothing after trimming,
            # so that case is skipped together with the fully-contained one.
            if end_roi <= omit_events[1]:
                continue
            start_roi = int(omit_events[1])

        duration = end_roi - start_roi

        # Fraction of frames in [start_roi, end_roi) that are inside the limits
        # (short out-of-limit gaps may have been bridged above).
        roi = values[start_roi:end_roi]
        signal_within_limits = np.count_nonzero(inside[start_roi:end_roi]) / len(roi)

        if duration > event_min_length and signal_within_limits > 0.75:
            events.append([start_roi, end_roi, duration])
            if plot_signals:
                print(
                    start_roi,
                    end_roi,
                    duration,
                    np.mean(roi),
                    np.median(roi),
                    signal_within_limits,
                )
                plt.plot(start_roi, limit_value, "go")
                plt.plot(end_roi, limit_value, "rx")

    if plot_signals:
        plt.plot([0, len(signal)], [limit_value, limit_value], "c-")
        plt.legend()
        plt.show()

    return events

In [None]:
# My function


def my_find_interaction_events(df, Thresh=80, min_time=60, plot_signals=False):
    """
    Detect runs of consecutive rows in which the fly is close to the ball.

    A row counts as "close" when yfly_smooth - yball_smooth is below Thresh.
    The helper columns "dist", "close" and "block" are written to df in place.

    Parameters:
    df (DataFrame): Must provide the columns "Frame", "yfly_smooth" and "yball_smooth".
    Thresh (int): Upper bound (exclusive) on the fly-ball distance for a row to be "close".
    min_time (int): Minimum value of end - start (in "Frame" units) for a run to be kept.
    plot_signals (bool, optional): Whether to draw each kept event as a segment. Defaults to False.

    Returns:
    list: One (start, end) tuple of "Frame" values per kept event.
    """

    # Working columns live on the caller's dataframe.
    df.loc[:, "dist"] = df.loc[:, "yfly_smooth"] - df.loc[:, "yball_smooth"]
    df.loc[:, "close"] = df.loc[:, "dist"] < Thresh

    # Every row where "close" differs from the previous row opens a new block id.
    is_close = df.loc[:, "close"]
    flips = is_close.shift(1) != is_close
    df.loc[:, "block"] = flips.cumsum()

    # First and last frame of each block made of "close" rows.
    close_runs = df[df["close"]].groupby("block")
    bounds = close_runs.agg(start=("Frame", "min"), end=("Frame", "max"))

    # Keep only the runs that last long enough.
    interaction_events = []
    for first, last in zip(bounds["start"], bounds["end"]):
        if last - first >= min_time:
            interaction_events.append((first, last))

    if plot_signals:
        for first, last in interaction_events:
            plt.plot([first, last], [Thresh, Thresh], "go-")
        plt.show()

    return interaction_events

In [None]:
# My function with VLR method added


def find_interaction_events(
    df,
    thresh=(0, 80),
    min_length=60,
    gap_between_events=1,
    plot_signals=False,
    omit_events=None,
    signal_name="",
):
    """
    Find interaction events from the fly-ball distance stored in a dataframe.

    The distance yfly_smooth - yball_smooth is written to df["dist"] (in place).
    Rows whose distance lies strictly between thresh[0] and thresh[1] are grouped
    into candidate events; two such rows belong to the same candidate as long as
    they are no more than gap_between_events apart. A candidate is kept when it is
    longer than min_length and more than 75% of its rows are inside the limits.

    Parameters:
    df (DataFrame): Must provide the columns "yfly_smooth" and "yball_smooth".
    thresh (sequence): Lower and upper limit (both exclusive) for the distance.
    min_length (int): An event is kept only if its duration is strictly greater than this.
    gap_between_events (int): Largest distance between two in-limit rows that
        still keeps them in the same event.
    plot_signals (bool, optional): Whether to print and plot debugging information. Defaults to False.
    omit_events (sequence, optional): [omit_start, omit_end] window. An event starting inside
        [omit_start, omit_end) is dropped if it ends at or before omit_end, otherwise its
        start is moved to omit_end. Defaults to None.
    signal_name (str, optional): Name used in the plot legend; "dist" is used when empty.

    Returns:
    list: One [start, end, duration] entry per event. start and end are row positions
        in df (not values of a "Frame" column) and duration = end - start.
    """

    df.loc[:, "dist"] = df.loc[:, "yfly_smooth"] - df.loc[:, "yball_smooth"]

    interaction_events = []

    # Positional array of the distance; every index below is a position in it.
    signal = df["dist"].to_numpy()
    lower, upper = thresh[0], thresh[1]
    inside = (signal > lower) & (signal < upper)
    frames_inside = np.where(inside)[0]
    plot_label = f"{signal_name or 'dist'}-filtered"

    # Nothing within the limits: there can be no events.
    if len(frames_inside) == 0:
        if plot_signals:
            print(f"No point is between {lower} and {upper}")
            plt.plot(signal, label=plot_label)
            plt.legend()
            plt.show()
        return interaction_events

    # A jump larger than the allowed gap between consecutive in-limit rows
    # closes one candidate event and opens the next.
    jumps = np.diff(frames_inside)
    split_points = np.where(jumps > gap_between_events)[0]

    # Sentinels so that the first and last candidates are also delimited:
    # candidate k covers frames_inside[split_points[k] + 1 : split_points[k + 1] + 1].
    split_points = np.insert(split_points, 0, -1)
    split_points = np.append(split_points, len(frames_inside) - 1)

    if plot_signals:
        # Height at which the markers and the limit line are drawn.
        limit_value = lower if upper == np.inf else upper
        print(frames_inside[split_points])
        plt.plot(signal, label=plot_label)

    # Explicit None/length test so that arrays and tuples work as well as lists.
    use_omit = omit_events is not None and len(omit_events) > 0

    for prev_split, next_split in zip(split_points[:-1], split_points[1:]):
        # Candidates made of a single in-limit row are ignored.
        if next_split - prev_split < 2:
            continue

        start_roi = frames_inside[prev_split + 1]
        end_roi = frames_inside[next_split]

        if use_omit and omit_events[0] <= start_roi < omit_events[1]:
            # Ending exactly on the window edge leaves nothing after trimming,
            # so that case is skipped together with the fully-contained one.
            if end_roi <= omit_events[1]:
                continue
            start_roi = int(omit_events[1])

        duration = end_roi - start_roi

        # Fraction of rows in [start_roi, end_roi) that are inside the limits
        # (short out-of-limit gaps may have been bridged above).
        roi = signal[start_roi:end_roi]
        signal_within_limits = np.count_nonzero(inside[start_roi:end_roi]) / len(roi)

        if duration > min_length and signal_within_limits > 0.75:
            interaction_events.append([start_roi, end_roi, duration])
            if plot_signals:
                print(
                    start_roi,
                    end_roi,
                    duration,
                    np.mean(roi),
                    np.median(roi),
                    signal_within_limits,
                )
                plt.plot(start_roi, limit_value, "go")
                plt.plot(end_roi, limit_value, "rx")

    if plot_signals:
        plt.plot([0, len(signal)], [limit_value, limit_value], "c-")
        plt.legend()
        plt.show()

    return interaction_events

Load the tracking data for one video and compare the events detected by each method

In [None]:
# Inputs for one recording (arena5 / corridor3). These are absolute paths under
# the /mnt/labserver mount, so they must be edited to run the notebook elsewhere.

# Video file of the corridor.
# NOTE(review): vidpath is not used by any of the cells visible here — confirm it is still needed.
vidpath = Path(
    "/mnt/labserver/DURRIEU_Matthias/Experimental_data/MultiMazeRecorder/Videos/230721_Feedingstate_4_PM_Videos_Tracked/arena5/corridor3/corridor3.mp4"
)

# Tracking output (.h5) for the ball.
ballpath = Path(
    "/mnt/labserver/DURRIEU_Matthias/Experimental_data/MultiMazeRecorder/Videos/230721_Feedingstate_4_PM_Videos_Tracked/arena5/corridor3/corridor3_tracked_ball.000_corridor3.analysis.h5"
)

# Tracking output (.h5) for the fly.
flypath = Path(
    "/mnt/labserver/DURRIEU_Matthias/Experimental_data/MultiMazeRecorder/Videos/230721_Feedingstate_4_PM_Videos_Tracked/arena5/corridor3/tracked_fly.000_corridor3.analysis.h5"
)

In [None]:
# Build the dataset from the ball and fly tracking files.
# NOTE(review): get_coordinates is not defined in this notebook — presumably it
# comes from the Ballpushing_utils star import; confirm.
Dataset = get_coordinates(ballpath, flypath)

In [None]:
# Preview the first rows of the loaded dataset.
Dataset.head()

In [None]:
# Events from the threshold / consecutive-run method; kept in `Me` for the comparison below.
# Note: this call also adds the "dist", "close" and "block" columns to Dataset.
Me = my_find_interaction_events(Dataset, Thresh=80, min_time=60, plot_signals=True)

In [None]:
# Same detection through the combined function, which takes `thresh` as
# (lower, upper) limits and `min_length` instead of `Thresh` / `min_time`.
find_interaction_events(Dataset, thresh=[0, 80], min_length=60, plot_signals=True)

In [None]:
# Fly-to-ball distance signal: the same difference the detection functions store in "dist".
Dist = Dataset.loc[:, "yfly_smooth"] - Dataset.loc[:, "yball_smooth"]

In [None]:
# Display the distance signal.
Dist

In [None]:
# Run find_events_from on the same signal: limits (0, 80), gap of 1 between events, minimum length 60.
VLR = find_events_from(Dist, [0, 80], 1, 60, plot_signals=True, signal_name="Dist")

In [None]:
# Start index of the first detected event (raises IndexError if no event was found).
VLR[0][0]

In [None]:
# All events from find_events_from, as [start, end, duration] entries.
VLR

In [None]:
# Number of events found by find_events_from.
len(VLR)

In [None]:
# All events from my_find_interaction_events, as (start, end) tuples of "Frame" values.
Me

In [None]:
# Number of events found by my_find_interaction_events.
len(Me)

In [None]:
# Events from the combined function, with every option spelled out explicitly.
New = find_interaction_events(
    Dataset,
    thresh=[0, 80],
    min_length=60,
    gap_between_events=1,
    plot_signals=True,
    omit_events=None,
    signal_name="Dist",
)