# Notebook to generate video clips of potential trophallaxis from Konstanz dataset (scented bee petri dish videos)

In [None]:
import os
import sys

import bb_behavior.utils.images

import pandas as pd
import numpy as np
import skimage.io
import imageio
from datetime import datetime
from typing import Tuple
import joblib
import matplotlib.pyplot as plt
import math
import glob
import logging
import random
import json
from scipy.ndimage import binary_closing, binary_opening, binary_dilation
from scipy.signal import savgol_filter
from scipy.stats import iqr
from video_utils import CustomVideoManager

In [None]:
# Detach every handler that is currently attached to the root logger so that
# the basicConfig() call below takes effect (basicConfig does nothing when the
# root logger already has handlers, e.g. after re-running this cell).
for stale_handler in list(logging.root.handlers):
    logging.root.removeHandler(stale_handler)

# File log: everything from DEBUG upwards, with timestamp and logger name.
logging.basicConfig(
    level=logging.DEBUG,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    filename='troph_video_gen_konst_data.log',
)

# Console log: INFO and above only, in a shorter format.
console_handler = logging.StreamHandler()
console_handler.setFormatter(logging.Formatter('%(levelname)s - %(message)s'))
console_handler.setLevel(logging.INFO)
logging.getLogger().addHandler(console_handler)

logger = logging.getLogger(__name__)
logger.info("Starting video generation ...")

## Set global variables

| Variable | Explanation |
| --- | --- |
| GENERATE_VIDEOS | whether to generate video clips of potential troph. events; if False then only the metadata for the videos is generated (much faster) |
| USE_ZOOM_LEVELS | whether to generate the clips in different zoom levels (for each video 1 of 3 zoom levels is chosen randomly) |
| USE_CROP_CENTER_SHIFT | whether to generate the clips with a random crop center shift for the clips (so that troph. is not always in the center) |
| USE_EGOCENTRIC_ALIGNMENT | whether to egocentrically align the videos to the focus bee (focus bee is always the bee with the lower bee ID and always mentioned first in the generated troph. clip name) |
| DEFAULT_CM_PER_PX | All pixel values in this notebook are specified relative to 1920x1080 videos (which all have a cm_per_px of around 0.0136) and are adjusted to the actual video dimensions. They are scaled by multiplying with this value and dividing by the video-specific cm_per_pixel that was used in TRex. |

In [None]:
# Feature switches (see the table above for their meaning).
GENERATE_VIDEOS = True
USE_ZOOM_LEVELS = False
USE_CROP_CENTER_SHIFT = False
USE_EGOCENTRIC_ALIGNMENT = True

# only use crop center shift OR egocentric alignment, not both!
assert not (USE_CROP_CENTER_SHIFT and USE_EGOCENTRIC_ALIGNMENT), (
    "Only use crop center shift OR egocentric alignment, not both!"
)

# Input / output locations.
VIDEO_ROOT = '/mnt/windows-ssd/BeesData/scented_bees_video_files/'
CACHE_PATH = '/mnt/windows-ssd/BeesData/tmp/'
POSTURE_FILES_PATH = '/mnt/windows-ssd/BeesData/scented_bees_posture_files/'
# NOTE(review): "OUTOUT" is a typo for "OUTPUT"; the name is kept because other
# cells reference it.
VIDEO_OUTOUT_PATH = '/mnt/windows-ssd/BeesData/scented_bees_troph_clips/'
# The metadata file name carries the date on which this cell was executed.
METADATA_OUTPUT_FILENAME = f"metadata_troph_konstanz_videos_{datetime.today().strftime('%Y-%m-%d')}.json"

FRAME_RATE = 50
# derived from FRAME_RATE so that the two values cannot drift apart
FRAMES_PER_MINUTE = 60 * FRAME_RATE
# cut off the first five minutes of frames because bees are still waking up
MINUTES_CUTOFF_START = 5
FRAME_CUTOFF_START = MINUTES_CUTOFF_START * FRAMES_PER_MINUTE

DEFAULT_CM_PER_PX = 0.0136

# Offset in pixels (relative to 1920x1080 videos) that is added in front of the
# focus bee's position to obtain the crop center.
BODY_TO_TROPHALLAXIS_CENTER_OFFSET_PX = 22
# Glob patterns (relative to POSTURE_FILES_PATH) of the posture files to load.
GROUPS = ["hex_*.npz", "OLE_*.npz", "OCI_*.npz"]
# Per-bee columns that are interpolated and copied into the bee-pair dataframes.
COLS_BEE_DATA = ['x_pixels', 'y_pixels', 'orientation']

#############################################################################################
# wcentroid is based on centroid weighted by pixel values, 
# pcentroid is based on posture centroid (the center of the midline)
# centroid is center of mass of all thresholded pixels
# nothing after the hashtag means based on head position
#############################################################################################
# Keys of the posture files that are not copied into the dataframe.
COLS_TO_EXCLUDE = [
    'tracklets', 'tracklet_vxys', 'video_size', 'id', 'frame_rate',
    'ACCELERATION#pcentroid', 'ACCELERATION#wcentroid', 'ANGULAR_A#centroid',
    'ANGULAR_V#centroid', 'BORDER_DISTANCE#pcentroid',
    'AX', 'AY', 'MIDLINE_OFFSET',
    'poseX0', 'poseY0', 'poseX1', 'poseY1', 'poseX2', 'poseY2',
    'poseX3', 'poseY3', 'poseX4', 'poseY4', 'poseX5', 'poseY5',
    'SPEED#wcentroid', 'SPEED#pcentroid', 'SPEED', 'VX', 'VY', 'X#wcentroid',
    'Y#wcentroid', 'midline_length', 'midline_segment_length', 'midline_x',
    'midline_y', 'missing', 'normalized_midline', 'num_pixels', 'timestamp',
]


## Build the dataframe from tracking data

In [None]:
def build_dataframe() -> pd.DataFrame:
    """Load all selected posture files into one long dataframe.

    The posture files are looked up below ``POSTURE_FILES_PATH`` using the glob
    patterns in ``GROUPS`` (or ``*.npz`` if ``GROUPS`` is empty/None) and are
    processed in sorted order. For every file the keys that are not listed in
    ``COLS_TO_EXCLUDE`` are turned into columns, the positions are converted
    from cm to pixels, frames before ``FRAME_CUTOFF_START`` are dropped and a
    ``frame_id`` (``frame_index`` plus a running offset) is assigned.

    The running offset is advanced after every fourth file, i.e. the code
    assumes that four consecutive files in sorted order belong to one video.

    Returns:
        pd.DataFrame: Concatenation of the per-file dataframes with an added
            ``video_filename`` column.

    Raises:
        ValueError: If no posture file matches the configured patterns.
        Exception: Whatever ``np.load`` raises for an unreadable file (the
            error is logged and re-raised unchanged).
    """
    # number of consecutive posture files that share one frame_id offset
    files_per_video = 4

    if GROUPS:
        posture_file_names = sorted(
            path
            for group in GROUPS
            for path in glob.glob(POSTURE_FILES_PATH + group)
        )
    else:
        posture_file_names = sorted(glob.glob(POSTURE_FILES_PATH + "*.npz"))

    if not posture_file_names:
        logger.error(f"No posture files found in {POSTURE_FILES_PATH}")
        raise ValueError(f"No posture files found in {POSTURE_FILES_PATH}")

    # collected locally so that repeated calls do not accumulate stale frames
    df_files = []
    frame_id_offset = 0
    for file_idx, file in enumerate(posture_file_names):
        try:
            npz_file = np.load(file)
        except Exception as e:
            logger.error(f"There was a problem loading the data file {file}. Exception: {e}")
            # re-raise the original exception to keep its type and traceback
            raise

        # the context manager closes the underlying archive once all arrays are read
        with npz_file as data:
            # Figure out how many rows (use 'time' as canonical-but any 1D key works)
            n = data['time'].shape[0]

            columns = {}
            for k in data.files:
                if k in COLS_TO_EXCLUDE:
                    continue
                v = data[k]

                # Scalar → broadcast to length n
                if np.ndim(v) == 0 or k == 'cm_per_pixel':
                    columns[k] = np.repeat(v.item(), n)
                # 1-D array (including object-dtype) → straight in
                elif v.ndim == 1:
                    columns[k] = v
                # Multi-D numeric array → flatten trailing dims into separate columns
                else:
                    flat = v.reshape(n, -1)
                    for i in range(flat.shape[1]):
                        columns[f"{k}_{i}"] = flat[:, i]

            # 'id' is excluded from the columns above; its first entry is the bee id
            bee_id = data['id'][0]

        # Build the DataFrame
        df_file = pd.DataFrame(columns)
        df_file['data_filename'] = os.path.basename(file).split(".")[0]

        # X and Y are in cm (convert to px)
        df_file = df_file.rename(columns={'cm_per_pixel': 'cm_per_px'})
        cm_per_px = df_file['cm_per_px'].iloc[0]
        df_file['x_pixels'] = np.round(df_file['X'] / cm_per_px)
        df_file['y_pixels'] = np.round(df_file['Y'] / cm_per_px)

        df_file['bee_id'] = bee_id
        df_file['frame_index'] = df_file['frame'].astype(int)
        df_file = df_file.drop(columns=['frame', 'X', 'Y'])
        df_file = df_file.rename(columns={'ANGLE': 'orientation'})

        # explicit copy: the filtered frame gets a new column assigned below
        df_file = df_file[df_file.frame_index >= FRAME_CUTOFF_START].copy()

        frames_cnt = len(df_file['frame_index'])
        df_file['frame_id'] = df_file['frame_index'] + frame_id_offset
        # advance the offset only after the last file of a video, using that file's row count
        if (file_idx + 1) % files_per_video == 0:
            frame_id_offset += frames_cnt

        df_files.append(df_file)

    df = pd.concat(df_files)
    # the video name is the part of the posture file name in front of '_fish'
    df['video_filename'] = VIDEO_ROOT + df['data_filename'].str.split('_fish').str[0] + ".mp4"
    return df

# df_files is (re)initialised to an empty list right before the dataframe is built.
df_files = []
df = build_dataframe()

# quick sanity check: column names and number of rows
print(df.columns)
print(df.shape[0])

## Combine the tracking dataframe with the video file dataframe

The `CustomVideoManager` is used for caching and extracting frames from videos as well as for saving frames to videos. It is partly adapted from a notebook by Jacob Davidson.

In [None]:
# The manager is created with the video root, the cache directory and the clip
# output directory; the cache is cleared before any frame is extracted.
video_manager = CustomVideoManager(VIDEO_ROOT, CACHE_PATH, VIDEO_OUTOUT_PATH, max_workers=16)
video_manager.clear_video_cache()

# Attach the per-video information to every tracking row (default inner join
# on 'video_filename').
videos_df = video_manager.get_all_video_files()
df = df.merge(videos_df, on='video_filename')

## Get the cropped images for the bees
This cell was adapted from https://github.com/nebw/unsupervised_behaviors/blob/master/unsupervised_behaviors/data.py

In [None]:
def process_frames(
    detections: pd.DataFrame,
    video_manager: CustomVideoManager,
    px_adj_ratio: float,
    uncropped_image_size_px: int = 120,
    image_crop_px: int = 20,
    egocentric: bool = True,
    generate_videos: bool = True,
    use_zoom_levels: bool = False,
    use_crop_center_shift: bool = False,
    use_clahe: bool = True,
    clahe_kernel_size_px: int = 25,
    n_jobs: int = -1,
) -> Tuple[np.ndarray, dict]:
    """Create a metadata dictionary for this potential trophallaxis run. Get cached images, crop them to the area 
       of interest and optionally rotate image region in each image so that the focus bee is egocentrically aligned.

    Args:
        detections (pd.DataFrame): Dataframe with detections. The columns ``frame_id``, ``frame_index``,
            ``orientation_0``, ``crop_area_center_x`` and ``crop_area_center_y`` are read.
        video_manager (CustomVideoManager): Manages cache.
        px_adj_ratio (float): Ratio for pixel adjustment in this video compared to 1920x1080 videos
        uncropped_image_size_px (int, optional): Image size before cropping. Defaults to 120.
        image_crop_px (int, optional): Crop amount after rotation. Defaults to 20.
        egocentric (bool, optional): Whether to rotate the frames so that the focus bee is egocentrically aligned. 
            Defaults to True.
        generate_videos (bool, optional): Whether to generate video clips. If False, no image is read and only 
            the metadata is computed. Defaults to True.
        use_zoom_levels (bool, optional): Whether to generate the clips with zoom levels. Defaults to False.
        use_crop_center_shift (bool, optional): Whether to generate the clips with a random crop center shift. 
            Defaults to False.
        use_clahe (bool, optional): Process entire frame using CLAHE. Defaults to True.
        clahe_kernel_size_px (int, optional): Kernel size for CLAHE. Defaults to 25.
        n_jobs (int, optional): Number of parallel jobs for processing. Defaults to -1.

    Returns:
        Tuple[np.ndarray, dict]: Extracted image regions (one per distinct ``frame_id``, ordered by ``frame_id``; 
            an empty array if ``generate_videos`` is False) and a dictionary with metadata about the video clip. 
            For an empty ``detections`` dataframe an empty array and an empty dictionary are returned.

    Raises:
        ValueError: If ``use_crop_center_shift`` and ``egocentric`` are both True.
    """
    # The file only imports skimage.io at module level; import the submodules used
    # below explicitly so that they are guaranteed to be loaded.
    import skimage.exposure
    import skimage.transform

    def rotate_crop_img(image: np.ndarray, rotation_deg: float, image_crop_px: int) -> np.ndarray:
        """Rotate the image and cut ``image_crop_px`` pixels off every side."""
        image = skimage.transform.rotate(image, rotation_deg)
        # a slice of the form [0:-0] would be empty, so only crop for a positive margin
        if image_crop_px > 0:
            image = image[image_crop_px:-image_crop_px, image_crop_px:-image_crop_px]
        return image

    def extract_crop_from_frame(
        detection: pd.DataFrame,
        frame_path: str,
        uncropped_image_size_px: int,
        image_crop_px: int,
        x_shift_px: int,
        y_shift_px: int,
    ) -> np.ndarray:
        """Read one cached frame and return the rotated and cropped region as uint8 image."""
        row = detection.iloc[0]
        frame = imageio.v3.imread(frame_path, plugin="opencv", colorspace="GRAY")

        if use_clahe:
            frame = skimage.exposure.equalize_adapthist(frame, kernel_size=(clahe_kernel_size_px, clahe_kernel_size_px))

        if egocentric:
            # bee will be facing to the right (row.orientation_0 + np.pi / 2 would be facing upwards)
            rotation_deg = np.rad2deg(row.orientation_0)
        else:
            rotation_deg = 0

        # ensure that center is in frame
        min_px_to_edge = uncropped_image_size_px // 2
        center_x_shifted = max(min(row.crop_area_center_x + x_shift_px, frame.shape[1] - min_px_to_edge), min_px_to_edge)
        center_y_shifted = max(min(row.crop_area_center_y + y_shift_px, frame.shape[0] - min_px_to_edge), min_px_to_edge)

        center = np.array((center_x_shifted, center_y_shifted))

        image = bb_behavior.utils.images.get_crop_from_image(
            center, frame, width=uncropped_image_size_px, clahe=False
        )
        # the rotated image is scaled by 255 before the conversion to uint8
        return (rotate_crop_img(image, rotation_deg, image_crop_px) * 255).astype(np.uint8)

    logger.debug(f'Detection count in video: {len(detections.index)}')
    if len(detections.index) == 0:
        # keep the documented tuple shape so that callers can always unpack the result
        return np.array([]), {}

    # adjust pixel values to video dimensions
    uncropped_image_size_px = math.ceil(uncropped_image_size_px * px_adj_ratio)
    image_crop_px = math.floor(image_crop_px * px_adj_ratio)
    cropped_image_size_px = uncropped_image_size_px - 2 * image_crop_px
    if use_zoom_levels:
        cropped_image_size_px = random.choice([
            cropped_image_size_px,
            math.floor(cropped_image_size_px * 1.5),
            math.floor(cropped_image_size_px * 2)
        ])
        # sqrt(2) times the final size, so that the final crop still fits after rotation
        uncropped_image_size_px = math.ceil(cropped_image_size_px * math.sqrt(2))
        # round up to an even size
        uncropped_image_size_px = uncropped_image_size_px if uncropped_image_size_px % 2 == 0 else uncropped_image_size_px + 1
        image_crop_px = int((uncropped_image_size_px - cropped_image_size_px) / 2)

    x_shift_px = 0
    y_shift_px = 0
    if use_crop_center_shift:
        if egocentric:
            raise ValueError("Only use egocentric if crop center shift is turned off!")
        # trophallaxis has to remain in the crop area
        x_shift_px = random.randint(-(cropped_image_size_px // 3),
                                    cropped_image_size_px // 3)
        y_shift_px = random.randint(-(cropped_image_size_px // 3),
                                    cropped_image_size_px // 3)

    images = []
    if generate_videos:
        # preload file paths because video_manager can't be used in parallel.
        # Each path is looked up with the key of its own group, so a frame and its
        # detection cannot get out of step even if the rows are not sorted by frame_id.
        frame_jobs = [
            (frame_detection, video_manager.get_frame_id_path(frame_id))
            for frame_id, frame_detection in detections.groupby("frame_id")
        ]
        logger.info(f"Processing {len(frame_jobs)} cached images ...")
        images = joblib.Parallel(prefer="processes", n_jobs=n_jobs)(
            joblib.delayed(extract_crop_from_frame)(
                frame_detection,
                frame_path,
                uncropped_image_size_px,
                image_crop_px,
                x_shift_px,
                y_shift_px,
            )
            for frame_detection, frame_path in frame_jobs
        )
        logger.info("Processing of cached images complete.")

    images = np.stack(images) if len(images) > 0 else np.array([])
    metadata_value_dict = {
        "Trophallaxis video dimensions in px": cropped_image_size_px,
        "x_shift_px": x_shift_px,
        "y_shift_px": y_shift_px,
        "start_frame_index": detections.frame_index.iat[0].item(),
        "end_frame_index": detections.frame_index.iat[-1].item(),
    }
    return images, metadata_value_dict

## Utility functions

In [None]:
def interpolate_missing_vals(
    df_video: pd.DataFrame, 
    cols_to_interpol: list
) -> pd.DataFrame:
    """Interpolates inf values in the trajectory data

    Positive and negative infinities are treated as missing values. They are
    filled separately for every ``bee_id`` by linear interpolation; values at
    the start or end of a bee's data are filled with the nearest valid value.

    Args:
        df_video (pd.DataFrame): dataframe, in which to interpolate. It is modified: the columns in 
            ``cols_to_interpol`` are overwritten with the interpolated values.
        cols_to_interpol (list): columns to interpolate

    Returns:
        pd.DataFrame: Interpolated dataframe (the same object as ``df_video``)
    """
    def _fill_gaps(values):
        # -inf is replaced as well, otherwise it would survive as a "valid" value
        return (values.replace([np.inf, -np.inf], np.nan)
                      .interpolate(method='linear')
                      .bfill()
                      .ffill())

    # transform returns a result aligned with df_video's index, which is then
    # written back into the original columns
    df_video[cols_to_interpol] = (
        df_video
        .groupby('bee_id', sort=False)[cols_to_interpol]
        .transform(_fill_gaps)
    )
    return df_video


def check_troph_conditions(
    df_troph: pd.DataFrame, 
    px_adj_ratio: float, 
    shift_dist: float = 15.0, 
    max_dist: float = 40.0,
    max_inv_orientation_diff_deg: float = 80.0
) -> list[bool]:
    """Evaluate the distance and orientation criteria for trophallaxis frame by frame.

    Both bee positions are first moved by ``shift_dist`` along the bee's own
    orientation. A frame fulfils the conditions if the shifted points are closer
    than ``max_dist`` and the bees are close to facing in opposite directions.

    Args:
        df_troph (pd.DataFrame): Dataframe with the columns ``orientation_0/1``, ``x_pixels_0/1`` 
            and ``y_pixels_0/1``
        px_adj_ratio (float): Ratio for pixel adjustment in this video compared to 1920x1080 videos; 
            ``shift_dist`` and ``max_dist`` are multiplied by it
        shift_dist (float, optional): Shift in pixels along the orientation of each bee that is applied 
            before the distance is calculated. Defaults to 15.0.
        max_dist (float, optional): Upper bound (exclusive) for the distance in pixels. Defaults to 40.0.
        max_inv_orientation_diff_deg (float, optional): Upper bound (exclusive) in degrees for the orientation 
            difference that remains after one bee is rotated by 180 degrees. Defaults to 80.0.

    Returns:
        list[bool]: One boolean per row (returned as a numpy boolean array), True where both conditions hold
    """
    scaled_shift = shift_dist * px_adj_ratio
    scaled_max_dist = max_dist * px_adj_ratio

    heading_0 = df_troph["orientation_0"].to_numpy()
    heading_1 = df_troph["orientation_1"].to_numpy()

    # absolute difference, folded at 2*pi so that it is measured the short way round
    raw_diff = np.abs(heading_0 - heading_1)
    heading_diff = np.minimum(raw_diff, np.abs(2 * np.pi - raw_diff))

    # positions moved forward along each bee's orientation
    shifted_x0 = df_troph["x_pixels_0"].to_numpy() + np.cos(heading_0) * scaled_shift
    shifted_x1 = df_troph["x_pixels_1"].to_numpy() + np.cos(heading_1) * scaled_shift
    shifted_y0 = df_troph["y_pixels_0"].to_numpy() + np.sin(heading_0) * scaled_shift
    shifted_y1 = df_troph["y_pixels_1"].to_numpy() + np.sin(heading_1) * scaled_shift

    separation = np.hypot(shifted_x0 - shifted_x1, shifted_y0 - shifted_y1)

    # pi - heading_diff is the difference left over once one bee is turned by 180 deg
    facing_each_other = np.pi - heading_diff < np.deg2rad(max_inv_orientation_diff_deg)
    close_enough = separation < scaled_max_dist
    return facing_each_other & close_enough


def remove_bool_islands(
    arr: np.ndarray, 
    structure_length: int = 25
) -> np.ndarray:
    """Clean a boolean array with a morphological closing followed by an opening.

    Args:
        arr (np.ndarray): Boolean array
        structure_length (int, optional): Number of ones in the structuring element that is passed to 
            both scipy functions. Defaults to 25.

    Returns:
        np.ndarray: Processed boolean array
    """
    kernel = np.ones(structure_length, dtype=bool)
    # the closing runs first because removing false islands in true regions takes priority
    gaps_closed = binary_closing(arr, structure=kernel)
    # the opening is then applied to the result of the closing
    return binary_opening(gaps_closed, structure=kernel)


def find_long_true_runs(
    arr: np.ndarray, 
    min_run: int = 250, 
    n_frame_padding: int = 10
) -> Tuple[np.ndarray, list[range]]:
    """Finds true regions with a certain minimum length in a boolean array

    Args:
        arr (np.ndarray): Boolean array
        min_run (int, optional): Minimum run length. Defaults to 250.
        n_frame_padding (int, optional): Number of frames to pad each true run before start and after end. Defaults to 10.

    Returns:
        Tuple[np.ndarray, list[range]]: A boolean array of the same shape as arr that is True only inside 
            the padded runs of at least min_run True values, and a list with one ``range`` per run. Each 
            range covers the padded run as positions in arr shifted by ``FRAME_CUTOFF_START``. Both are 
            empty/all-False if arr is empty or contains no run that is long enough.
    """
    arr = np.asarray(arr, dtype=bool)
    meets_troph_conds = np.zeros(arr.shape, dtype=bool)
    runs_frame_indices = []
    # an empty array has no runs; indexing arr[starts] below would fail for it
    if arr.size == 0:
        return meets_troph_conds, runs_frame_indices

    # positions after which the value changes
    edges = np.flatnonzero(np.diff(arr))

    # first and last position of every run (true runs and false runs alike)
    starts = np.r_[0, edges + 1]
    ends = np.r_[edges, len(arr) - 1]
    lengths = ends - starts + 1

    # keep the runs that consist of True values and are long enough
    keep = arr[starts] & (lengths >= min_run)

    for start, length in zip(starts[keep], lengths[keep]):
        # pad run, clipped to the bounds of arr
        padded_start = max(0, start - n_frame_padding)
        padded_end = min(len(arr), start + length + n_frame_padding)
        meets_troph_conds[padded_start:padded_end] = True
        # padded_start and padded_end are relative to arr
        runs_frame_indices.append(
            range(padded_start + FRAME_CUTOFF_START, padded_end + FRAME_CUTOFF_START)
        )
    return meets_troph_conds, runs_frame_indices


def smooth_column_data(
    df: pd.DataFrame, 
    cols_to_smooth: list, 
    window_length: int = 5, 
    savgol_order: int = 3
):
    """Apply a Savitzky-Golay filter to selected columns and write the result back into the dataframe.

    The filter runs along the rows (axis 0), i.e. every column is smoothed on its own. 
    Nothing is returned; ``df`` is modified.

    Args:
        df (pd.DataFrame): Dataframe
        cols_to_smooth (list): Columns to smooth
        window_length (int, optional): Window length of filter. Defaults to 5.
        savgol_order (int, optional): Polynomial order of filter. Defaults to 3.
    """
    raw_values = df[cols_to_smooth].to_numpy()
    smoothed_values = savgol_filter(
        raw_values,
        window_length=window_length,
        polyorder=savgol_order,
        axis=0,
    )
    df[cols_to_smooth] = smoothed_values
    


def clean_orientation_outliers(
    orientations: np.ndarray,
    window_length: int = 11,
    iqr_factor: float = 6.0,
    dilation_iters: int = 3
) -> np.ndarray:
    """Replace orientation outliers by linearly interpolated values.

    A centered rolling median is computed and the absolute deviation of every value from it is 
    folded at 2*pi. Values whose deviation exceeds ``iqr_factor`` times the IQR of all deviations 
    are flagged, the flagged regions are widened by binary dilation, and the flagged values are 
    replaced by linear interpolation (values at the borders take the nearest valid value).

    Args:
        orientations (np.ndarray): Orientation values in radians
        window_length (int, optional): Length of rolling median window. Defaults to 11.
        iqr_factor (float, optional): Used to calculate the threshold for outliers. Defaults to 6.0.
        dilation_iters (int, optional): Binary dilation iterations. Defaults to 3.

    Returns:
        np.ndarray: Cleaned orientation array
    """
    series = pd.Series(orientations)
    rolling_median = series.rolling(window_length, center=True, min_periods=1).median()

    # deviation from the local median, measured the short way round the circle
    raw_deviation = np.abs(series - rolling_median)
    circular_deviation = np.minimum(raw_deviation, np.abs(2*np.pi - raw_deviation))

    cutoff = iqr_factor * iqr(circular_deviation.dropna())
    flagged = binary_dilation(
        (circular_deviation > cutoff).to_numpy(),
        iterations=dilation_iters,
    )

    # flagged entries become NaN and are then filled from their neighbours
    with_gaps = series.mask(flagged)
    filled = with_gaps.interpolate(method="linear").bfill().ffill()
    return filled.to_numpy()


def _clean_trajectory_outliers(
    x: np.ndarray, 
    y: np.ndarray, 
    iqr_factor: float = 3.0, 
    dilation_iters: int = 3
) -> Tuple[np.ndarray, np.ndarray]:
    """Clean trajectory outliers by calculating the z-score of x- and y-values, combining those,
       calculating the IQR, setting values above iqr_factor * IQR to NaN, dilating those NaN regions
       and interpolating them.

    Args:
        x (np.ndarray): X-values in pixels
        y (np.ndarray): Y-values in pixels
        iqr_factor (float, optional): Used to calculate the threshold for outliers. Defaults to 3.0.
        dilation_iters (int, optional): Binary dilation iterations. Defaults to 3.

    Returns:
        Tuple[np.ndarray, np.ndarray]: The cleaned x- and y-values
    """

    mean_x = np.nanmean(x)
    mean_y = np.nanmean(y)
    # Avoid division by zero: if std is 0, set to tiny value
    eps = 1e-8
    std_x = min(np.nanstd(x), eps)
    std_y = min(np.nanstd(y), eps)

    zx = (x - mean_x) / std_x
    zy = (y - mean_y) / std_y

    combined_z = np.hypot(zx, zy)

    iqr_val = iqr(combined_z)

    outlier_mask = combined_z > iqr_factor * iqr_val
    outlier_mask = binary_dilation(outlier_mask, iterations=dilation_iters)

    # mark both axes as NaN for outlier frames
    x[outlier_mask] = np.nan
    y[outlier_mask] = np.nan

    x_clean = pd.Series(x).interpolate(method="linear").bfill().ffill().to_numpy()
    y_clean = pd.Series(y).interpolate(method="linear").bfill().ffill().to_numpy()

    return x_clean, y_clean


def clean_trajectory_outliers(
    df_troph_video: pd.DataFrame, 
    video_width: int, 
    video_height: int, 
    px_adj_ratio: float
):
    """Compute the crop center for every frame of an event and remove outliers from it.

    The crop center is the position of bee 0 moved forward along its orientation, limited to the 
    video dimensions and cleaned with _clean_trajectory_outliers. The result is stored in the new 
    columns ``crop_area_center_x`` and ``crop_area_center_y`` of ``df_troph_video``; nothing is returned.

    Args:
        df_troph_video (pd.DataFrame): Dataframe for one potential trophallaxis event
        video_width (int): Original video width in pixels
        video_height (int): Original video height in pixels
        px_adj_ratio (float): Ratio for pixel adjustment in this video compared to 1920x1080 videos
    """
    focus_heading = df_troph_video["orientation_0"].to_numpy()

    # offset from the position of bee 0 to the area in front of it
    offset_x = np.cos(focus_heading) * BODY_TO_TROPHALLAXIS_CENTER_OFFSET_PX * px_adj_ratio
    offset_y = np.sin(focus_heading) * BODY_TO_TROPHALLAXIS_CENTER_OFFSET_PX * px_adj_ratio

    center_x = df_troph_video["x_pixels_0"].to_numpy() + offset_x
    center_y = df_troph_video["y_pixels_0"].to_numpy() + offset_y

    # limit the center to [0, video_width] x [0, video_height]
    center_x = np.clip(center_x, 0, video_width)
    center_y = np.clip(center_y, 0, video_height)

    cleaned_x, cleaned_y = _clean_trajectory_outliers(center_x, center_y, dilation_iters=3)
    df_troph_video["crop_area_center_x"] = cleaned_x
    df_troph_video["crop_area_center_y"] = cleaned_y

## Process the videos, cache frame and generate potential trophallaxis video clips

This could still be sped up a lot by processing one original video, then caching all relevant frames for the video at once and only afterwards actually generating videos from those cached frames. I did not do this because I did not find a good way to pass a bunch of specific frame indices to ffmpeg (or a collection of ranges with one range for each clip to extract). You could simply cache all frames between first frame index to cache and last frame index to cache, which would also work and probably be faster too, but would also require more disk space because most of those frames are likely not needed. 

In [None]:
def generate_vid(
    df_troph_video: pd.DataFrame, 
    images_masked: np.ndarray, 
    video_name_short: str, 
    bee_0: int, 
    bee_1: int
) -> str:
    """Build the clip file name for one potential troph. event and, if enabled, write the clip.

    The file name consists of the short video name, both bee IDs and the time span of the event 
    (start rounded down, end rounded up). The clip is only written if the global GENERATE_VIDEOS 
    is set; the images are then put into time order, where the index labels of the time-sorted 
    dataframe are used as positions in ``images_masked``.

    Args:
        df_troph_video (pd.DataFrame): Dataframe for the event
        images_masked (np.ndarray): Images for the event
        video_name_short (str): Short string from the original video (video-group + index)
        bee_0 (int): ID of focus bee
        bee_1 (int): ID of other bee

    Returns:
        str: Filename of the generated video clip
    """
    by_time = df_troph_video.sort_values(['time'])
    frame_order = by_time.index

    clip_start = by_time["time"].iat[0]
    clip_end = by_time["time"].iat[-1]
    time_span = f"{math.floor(clip_start)}--{math.ceil(clip_end)}"
    filename = f"{video_name_short}_bees_{bee_0}_and_{bee_1}_{time_span}.mp4"

    if not GENERATE_VIDEOS:
        return filename

    ordered_images = [images_masked[position] for position in frame_order]
    video_manager.write_to_video(images=ordered_images, filename=filename, frame_rate=FRAME_RATE)
    return filename


def process_video(
    video_file: str, 
    df_video: pd.DataFrame, 
    generate_videos: bool
):
    """Process one of the original petri dish videos

    For each of the six pairs that can be formed from the bees 0-3, the frames that fulfil the 
    trophallaxis conditions are determined, grouped into runs and every run is turned into a clip 
    (if ``generate_videos`` is set) plus a metadata entry. The metadata of all clips processed so 
    far is kept in the global ``metadata_dict`` and written to ``METADATA_OUTPUT_FILENAME`` after 
    every run.

    Args:
        video_file (str): File name (incl. path) of the video
        df_video (pd.DataFrame): Dataframe for the video
        generate_videos (bool): Whether to generate video clips
    """

    def process_troph_run(
        df_bee_pair: pd.DataFrame, 
        frame_indices_run: list[int], 
        frame_id_offset: int, 
        video_width: int, 
        video_height: int, 
        px_adj_ratio_video: float, 
        video_name_short: str
    ):
        """Process one of the troph. runs

        Args:
            df_bee_pair (pd.DataFrame): Dataframe for the troph. run of the bee pair
            frame_indices_run (list[int]): Frame indices of the run (to cache)
            frame_id_offset (int): Offset between any frame_id and frame_index in this run
            video_width (int): Original video width in pixels
            video_height (int): Original video height in pixels
            px_adj_ratio_video (float): Ratio for pixel adjustment in this video compared to 1920x1080 videos
            video_name_short (str): Short string from the original video (video-group + index)
        """

        frame_indices_run = np.array(frame_indices_run)
        frame_ids_run = frame_indices_run + frame_id_offset

        # find indices, where elements should be inserted to maintain order (possible because df_troph is sorted by frame_id)
        start = np.searchsorted(df_bee_pair.frame_id.values, frame_ids_run[0], side="left")
        end = np.searchsorted(df_bee_pair.frame_id.values, frame_ids_run[-1], side="right")
        if end <= start:
            # nothing to cut out; the steps below cannot handle an empty run
            logger.warning(f"Skipped run starting at frame index {frame_indices_run[0]} because it contains no rows")
            return

        if generate_videos:
            video_manager.cache_frames(
                frame_ids=frame_ids_run, 
                video_name=video_file, 
                frame_indices=frame_indices_run
            )

        # the cleaned orientation is written back into df_bee_pair, which is shared by all runs of the pair
        orientation_0_cleaned = clean_orientation_outliers(df_bee_pair.orientation_0.iloc[start:end].to_numpy())
        df_bee_pair.iloc[start:end, df_bee_pair.columns.get_loc("orientation_0")] = orientation_0_cleaned

        # the index is reset so that row labels equal row positions within the run
        df_troph_run = df_bee_pair.iloc[start:end].reset_index(drop=True).copy()
        clean_trajectory_outliers(df_troph_run, video_width=video_width, video_height=video_height, px_adj_ratio=px_adj_ratio_video)
        # smooth camera center for a more stable video
        smooth_column_data(df_troph_run, cols_to_smooth=["crop_area_center_x", "crop_area_center_y"], savgol_order=2)

        images, metadata_value_dict = process_frames(
            detections=df_troph_run, 
            video_manager=video_manager, 
            px_adj_ratio=px_adj_ratio_video,
            generate_videos=generate_videos,
            use_zoom_levels=USE_ZOOM_LEVELS,
            use_crop_center_shift=USE_CROP_CENTER_SHIFT,
            egocentric=USE_EGOCENTRIC_ALIGNMENT
        )

        # no mask is applied at the moment; the images are passed on unchanged
        images_masked = images if len(images) > 0 else np.array([])

        # troph_bee_0 / troph_bee_1 are taken from the enclosing loop over the bee pairs
        filename = generate_vid(
            df_troph_video=df_troph_run, 
            images_masked=images_masked, 
            video_name_short=video_name_short,
            bee_0=troph_bee_0,
            bee_1=troph_bee_1
        )

        # overwrite json file after every video in case something goes wrong
        metadata_dict[filename] = metadata_value_dict
        # serialise first: if this fails, the existing file has not been truncated yet
        serialized = json.dumps(metadata_dict, ensure_ascii=False, indent=4)
        with open(METADATA_OUTPUT_FILENAME, "w", encoding="utf-8") as jsonfile:
            jsonfile.write(serialized)

    logger.info(f"Processing video {video_file} ...")
    video_manager.clear_video_cache()
    # file name without directory and without everything after the first dot
    video_name_short = os.path.basename(str(video_file)).split('.')[0]
    if len(df_video) == 0:
        logger.warning(f"Skipped video {video_file} because dataframe was empty")
        return

    video_width, video_height = video_manager.get_video_dimensions(video_file)

    # dataframe should already be sorted correctly but make sure
    df_video = df_video.sort_values(['bee_id', 'frame_id'])
    cm_per_px_video = df_video.cm_per_px.iat[0]
    px_adj_ratio_video = DEFAULT_CM_PER_PX / cm_per_px_video

    df_video_grouped = df_video.groupby('bee_id', sort=False)
    sizes = df_video_grouped.size()

    # some video-datasets have a different number of rows for each bee, use last common frame-id
    if sizes.nunique() != 1:
        max_common_frame_id = df_video_grouped.frame_id.last().min()
        df_video = df_video[df_video.frame_id <= max_common_frame_id].copy()

    # the rows are sorted by bee_id, so each bee is expected to occupy one block of rows_per_bee rows
    n_rows = len(df_video)
    rows_per_bee = int(n_rows / 4)
    logger.debug(f"Total rows for video: {len(df_video)}")
    frame_id_offset = df_video.frame_id.iat[0] - df_video.frame_index.iat[0]

    df_video = interpolate_missing_vals(df_video, cols_to_interpol=COLS_BEE_DATA)

    # row blocks of the four bees (identical for every pair)
    starts = np.arange(4) * rows_per_bee
    ends = starts + rows_per_bee
    cols_to_copy = COLS_BEE_DATA + ['frame_index', 'frame_id', 'video_filename', 'time']

    troph_bee_pairs = [(0, 1), (0, 2), (0, 3), (1, 2), (1, 3), (2, 3)]
    for pair in troph_bee_pairs:
        logger.info(f"Processing bee pair ({str(pair[0])}, {str(pair[1])})")
        # within the pair dataframe, suffixes _0 and _1 always denote the two bees considered for trophallaxis
        troph_bee_0, troph_bee_1 = pair
        bee_2, bee_3 = np.setdiff1d([0, 1, 2, 3], pair)

        # dataframe, which contains all potential trophallaxis occurances between one bee pair
        df_bee_pair = (df_video
                    .iloc[starts[troph_bee_0]:ends[troph_bee_0]][cols_to_copy]
                    .rename(columns={c: f"{c}_0" for c in COLS_BEE_DATA})
                    .copy())
        # the data of the other bees is added by position, not by index
        for i, bee in [(1, troph_bee_1), (2, bee_2), (3, bee_3)]:
            df_bee_pair[[f"{c}_{i}" for c in COLS_BEE_DATA]] = (
                df_video
                .iloc[starts[bee]:ends[bee]][COLS_BEE_DATA]
                .reset_index(drop=True)
                .to_numpy()
            )

        df_bee_pair["meets_troph_conds"] = check_troph_conditions(df_bee_pair, px_adj_ratio_video)
        df_bee_pair["meets_troph_conds"] = remove_bool_islands(df_bee_pair.meets_troph_conds.to_numpy())
        df_bee_pair["meets_troph_conds"], runs_frame_indices = find_long_true_runs(df_bee_pair.meets_troph_conds.to_numpy())
        if len(runs_frame_indices) == 0:
            continue
        # smooth the data for more stable videos before filtering out non-trophallaxis frames
        # NOTE(review): the orientation columns are smoothed like linear values; confirm that
        # a jump between -pi and pi cannot occur here or is acceptable
        cols_to_smooth = [col for i in range(4) for col in [f'x_pixels_{i}', f'y_pixels_{i}', f'orientation_{i}']]
        smooth_column_data(df_bee_pair, cols_to_smooth=cols_to_smooth)
        df_bee_pair = df_bee_pair.loc[df_bee_pair["meets_troph_conds"]].copy()
        logger.info("Preprocessing complete")

        for frame_indices_run in runs_frame_indices:
            process_troph_run(
                df_bee_pair, 
                frame_indices_run, 
                frame_id_offset, 
                video_width, 
                video_height, 
                px_adj_ratio_video, 
                video_name_short
            )
        

# Maps the file name of every clip to its metadata; it is filled while the
# videos are processed.
metadata_dict = {}

# One group per original video, in order of first appearance in df.
df_grouped = df.groupby('video_filename', sort=False)
logger.info(f"Videos: {df_grouped.groups.keys()}")
logger.info(f"Processing {df_grouped.ngroups} videos")
for video_file, df_video in df_grouped:
    process_video(video_file=video_file, df_video=df_video, generate_videos=GENERATE_VIDEOS)
logger.info("Video generation complete.")