# Notebook to generate videos cropped to area around a bee for FU (BeesBook) dataset

This code could be extended to automatically merge consecutive videos of the same bee if no frame gap exists between the extracted videos. Right now, videos have to be merged manually if clips longer than one minute are required (the raw BeesBook input videos are one minute long).

In [None]:
import os
import sys

import bb_behavior.utils.images

import pandas as pd
import numpy as np
from pyarrow.parquet import ParquetFile
import imageio
import skimage.io
from typing import Tuple
import joblib
import math
import logging
from video_utils import CustomVideoManager

In [None]:
# Remove handlers left over from earlier runs of this cell; logging.basicConfig
# is a no-op while the root logger still has handlers attached.
root_logger = logging.getLogger()
for handler in list(root_logger.handlers):
    root_logger.removeHandler(handler)

# Everything from DEBUG upwards goes to the log file ...
logging.basicConfig(
    filename='indiv_video_gen_fu_data.log',
    level=logging.DEBUG,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)

# ... while the console only shows INFO and above, in a shorter format.
console_formatter = logging.Formatter('%(levelname)s - %(message)s')
console_handler = logging.StreamHandler()
console_handler.setLevel(logging.INFO)
console_handler.setFormatter(console_formatter)
root_logger.addHandler(console_handler)

logger = logging.getLogger(__name__)
logger.info("Starting video generation ...")

## Set global variables

| Variable | Explanation |
| --- | --- |
| OVERWRITE_ORIENT_ERRORS_IF_EXIST | Whether to overwrite orientation errors (by bee) file if it exists |
| APPLY_BODY_MASKS | Whether to apply a body mask in each extracted video, which masks the area around the bee (liberally) |

In [None]:
# If False, an existing ORIENT_ERRORS_FILE is loaded instead of recomputing the
# per-bee orientation errors; if True they are recomputed and the file is rewritten.
OVERWRITE_ORIENT_ERRORS_IF_EXIST = False
# If True, the extracted frames are additionally multiplied with the body masks.
APPLY_BODY_MASKS = False


# Frames per second: used to convert time offsets within a video into frame
# indices and passed as frame rate when writing the output videos.
FRAME_RATE = 6 
N_JOBS = 8     # number of parallel jobs for frame extraction

# The first three paths are handed to CustomVideoManager in this order
# (presumably: input video root, frame cache directory, output directory — confirm in video_utils).
VIDEO_ROOT = '/mnt/windows-ssd/BeesData/'
CACHE_PATH = '/mnt/windows-ssd/BeesData/tmp/'
VIDEO_OUTPUT_PATH = '/mnt/windows-ssd/BeesData/videos_single_bees/'
# Directory that is scanned for *.parquet trajectory files.
PARQUET_DIR = "/mnt/windows-ssd/BeesData/parquet/Hive_B/"
# Tab-separated file with one orientation error per bee (columns: bee_id, orientation_error).
ORIENT_ERRORS_FILE = "/mnt/windows-ssd/trophallaxis_detection_code/data/orientation_errors_fu/orientation_errors.csv"

# Load trajectory parquet files into dataframe and prepare data

Currently only videos of cam-2 are extracted

In [None]:
# Collect all trajectory files. os.listdir returns entries in arbitrary order,
# so sort them to make the row order of the concatenated frame reproducible.
parquet_files = []
if os.path.exists(PARQUET_DIR):
    parquet_files.extend(sorted(
        os.path.join(PARQUET_DIR, f) for f in os.listdir(PARQUET_DIR) if f.endswith('.parquet')
    ))

# Fail early with a clear message instead of pd.concat's "No objects to concatenate"
if not parquet_files:
    raise FileNotFoundError(f"No .parquet files found in {PARQUET_DIR}")

dfs = [pd.read_parquet(parquet_file) for parquet_file in parquet_files]
df = pd.concat(dfs)

# Only cam 2 is processed. The explicit copy makes the column assignments below
# operate on an independent frame (avoids pandas' SettingWithCopyWarning).
df = df[df.cam_id == 2].copy()

# x and y pixel values are swapped in the parquet files
df[['x_pixels','y_pixels']] = df[['y_pixels','x_pixels']]
df['orientation_hive'] = df['orientation_hive'].to_numpy() - np.pi / 2

# add 2 * pi to all negative values
df.loc[df["orientation_hive"] < 0, "orientation_hive"] += 2 * np.pi

# Untouched snapshot of all cam-2 detections; used later for the orientation
# correction, independent of which detections could be matched to a video.
df_all_data = df.copy()
df.head()



# Combine the dataframe with the video file data

In [None]:
video_manager = CustomVideoManager(VIDEO_ROOT, CACHE_PATH, VIDEO_OUTPUT_PATH, max_workers=16)
video_manager.clear_video_cache()

# Raw input videos are one minute long. The value bounds how far behind a video's
# start time a detection may lie to still be matched to it, and determines the
# minimum spacing between the frame-id ranges of consecutive videos.
VIDEO_LENGTH_S = 60

frame_id_cam_dfs = []
for cam_id in df['cam_id'].unique():
    df_cam = df[df['cam_id'] == cam_id].copy()
    df_cam.reset_index(drop=True, inplace=True)

    # Get video metadata
    videos_df = video_manager.get_all_video_files_for_cam(cam_id, df_cam)

    if videos_df.empty:
        logger.info(f"No video files found for df_cam with cam_id {cam_id}")
        continue

    # merge_asof requires both frames to be sorted by their merge key
    df_cam.sort_values('timestamp', inplace=True)
    videos_df.sort_values('start_time', inplace=True)

    # avoid a column clash with the video_filename column coming from videos_df
    if 'video_filename' in df_cam.columns:
        df_cam = df_cam.drop(columns='video_filename')

    # Match every detection to the latest video that started at or before its timestamp
    df_cam = pd.merge_asof(df_cam, videos_df,
                           left_on='timestamp',
                           right_on='start_time',
                           direction='backward',
                           tolerance=pd.Timedelta(seconds=VIDEO_LENGTH_S))

    # Detections without a matching video have NaN video columns after the merge, so remove
    # these rows. The copy keeps the column assignments below off a filtered view.
    df_cam = df_cam[df_cam.start_time.notna()].copy()
    if df_cam.empty:
        logger.info(f"No detections with a matching video for cam_id {cam_id}")
        continue

    # Calculate frame indices
    # NOTE(review): astype(int) truncates, so a timestamp marginally before the nominal frame
    # time maps to the previous frame; confirm that rounding is not what is wanted here.
    df_cam['time_diff'] = (df_cam['timestamp'] - df_cam['start_time']).dt.total_seconds()
    df_cam['frame_index'] = (df_cam['time_diff'] * FRAME_RATE).astype(int)

    # Map video identifiers to unique integers
    unique_videos = df_cam['video_filename'].unique()
    video_id_map = {video: idx for idx, video in enumerate(unique_videos)}

    # Assign video ID
    df_cam['video_id'] = df_cam['video_filename'].map(video_id_map)

    # Calculate frame_id using the video ID as an offset. The stride is the nominal number of
    # frames per video (360 at 6 fps), enlarged if a larger frame index occurs so that frame
    # ids of different videos can never collide.
    frames_per_video = max(
        int(round(VIDEO_LENGTH_S * FRAME_RATE)),
        int(df_cam['frame_index'].max()) + 1,
    )
    df_cam['frame_id'] = df_cam['video_id'] * frames_per_video + df_cam['frame_index']

    # dataframe with all detections sorted by video, then by bee, then by timestamp
    df_cam.sort_values(['video_filename', 'bee_id', 'timestamp'], inplace=True)

    # Create a DataFrame mapping frame IDs to video names and frame indices
    frame_id_cam_df = df_cam[['frame_id', 'cam_id', 'video_filename', 'frame_index', 'x_pixels', 'y_pixels', 'bee_id', 'orientation_hive', 'timestamp']].drop_duplicates(['frame_id', 'cam_id', 'video_filename', 'frame_index', 'bee_id'])
    frame_id_cam_df['frame_id_unique'] = frame_id_cam_df['frame_id'] + (frame_id_cam_df['frame_id'].max()+1)*frame_id_cam_df['cam_id']
    frame_id_cam_df = frame_id_cam_df.set_index('frame_id_unique')

    frame_id_cam_dfs.append(frame_id_cam_df)

# Fail with a clear message instead of pd.concat's "No objects to concatenate"
if not frame_id_cam_dfs:
    raise ValueError("No detections could be matched to any video file")

frame_id_df = pd.concat(frame_id_cam_dfs)
    

# Do orientation correction for all bees based on their movement
Some of the bees have misaligned bee tags that don't face to the front of the bee. This cell aims to correct that by factoring in the movement directions of the bees.

Assume that bees move forwards, so that the movement vector points in the same direction as the orientation vector. For each bee, look at the movement between each pair of consecutive detections at $t_n$ and $t_{n+1}$, provided that $t_{n+1}$ is also the next frame in the original data and that there is enough movement over the next 3 detections (between $t_n$ and $t_{n+3}$). Create a vector from the movement between $t_n$ and $t_{n+1}$ and calculate the angle difference between this vector and the orientation in the data at time $t_n$. These orientation errors are collected in a list. If this list is longer than 50 elements for a bee, it is assumed to be long enough for orientation correction: the list is sorted, the top 45% and bottom 45% of values are discarded, and the mean of the remaining elements is calculated. This value is used as the final orientation error for the bee and is applied to all orientation values of this bee in the dataframe. Bees with too few usable samples get an orientation error of 0, i.e. their orientations are left unchanged.

In the videos, the bees don't move exactly forwards in most cases and sometimes also bend their heads in a different direction than their bodies. Over a long enough period of time, however, the average difference between the movement direction and the orientation in the data gives a good estimate of the orientation error.

In [None]:
def angle_difference(
    a: float, 
    b: float
) -> float:
    """Return the signed difference ``b - a`` wrapped into ``[-pi, pi)``.

    Args:
        a (float): Start angle in radians
        b (float): Target angle in radians

    Returns:
        float: Smallest signed rotation (in radians) that turns ``a`` into ``b``
    """
    full_turn = 2 * math.pi
    # shift by pi so the modulo wraps at +-pi instead of at 0 / 2*pi
    shifted = b - a + math.pi
    return shifted % full_turn - math.pi


def get_orientation_error_for_detection(
    i: int, 
    x_arr: np.ndarray, 
    y_arr: np.ndarray, 
    timestamps: np.ndarray, 
    orientations: np.ndarray, 
    next_detections_to_check_cnt: int
):
    """Calculate the orientation error for a bee at a certain detection based on its orientation 
       and its movement to the next detection. If not enough movement happens over the next few 
       detections or if the very next detection is not also the next frame in the original video, 
       the detection is not used for orientation correction and None is returned.

       The caller has to make sure that ``i + 1`` and ``i + next_detections_to_check_cnt`` are
       valid indices into the arrays.

    Args:
        i (int): Detection index of the frame for this bee
        x_arr (np.ndarray): x-pixel values of detections for the bee
        y_arr (np.ndarray): y-pixel values of detections for the bee
        timestamps (np.ndarray): Timestamps of detections for the bee
        orientations (np.ndarray): Orientation values of detections for the bee
        next_detections_to_check_cnt (int): How many next detections should be checked for minimum movement

    Returns:
        float: Signed difference between movement direction and stored orientation in [-pi, pi),
            or None if the detection is not usable
    """
    x = x_arr[i]
    y = y_arr[i]

    # L2 norm (distance of travel) over the look-ahead window
    total_travel_dist = np.hypot(
        x_arr[i + next_detections_to_check_cnt] - x,
        y_arr[i + next_detections_to_check_cnt] - y
    )
    # if too little movement over the next detections (less than 5 px per step on average),
    # then don't use the heading for orientation correction
    if total_travel_dist < 5 * next_detections_to_check_cnt:
        return None

    # if timestep too large (not consecutive frames) don't use the heading for orientation correction
    time_step = np.timedelta64(timestamps[i + 1] - timestamps[i], 'ms')
    if time_step > np.timedelta64(200, 'ms'):
        return None

    # direction of travel towards the next detection, mapped into [0, 2*pi)
    heading_angle_rad = math.atan2(y_arr[i + 1] - y, x_arr[i + 1] - x)
    if heading_angle_rad < 0:
        heading_angle_rad += 2 * math.pi

    # NOTE(review): the previous version flipped the heading by pi when
    # abs(angle_difference(heading, next orientation)) exceeded pi. That wrapped difference
    # can never exceed pi, so the branch never ran and was removed. If backward movement
    # should be compensated, the threshold probably has to be pi / 2 - confirm on real data.
    return angle_difference(orientations[i], heading_angle_rad)


def get_all_orientation_errors(
    next_detections_to_check_cnt: int = 3, 
    min_required_errors: int = 50
) -> dict:
    """Calculates the orientation error of every bee that occurs in ``frame_id_df``, using
       all of its detections in ``df_all_data``. Bees for which not more than 
       ``min_required_errors`` usable samples are found get an orientation error of 0.

    Args:
        next_detections_to_check_cnt (int, optional): How many next detections should be checked for minimum movement. 
            Defaults to 3.
        min_required_errors (int, optional): Minimum number of required errors for a bee to calculate the orientation error. 
            Defaults to 50.

    Returns:
        dict: Dictionary with bee-IDs as keys and orientation errors as values
    """
    logger.info("Calculating orientation errors...")
    relevant_bee_ids = frame_id_df.bee_id.unique()
    df_relevant_data = df_all_data[df_all_data['bee_id'].isin(relevant_bee_ids)]

    # sort into a new frame; sorting the filtered slice in place triggers
    # pandas' SettingWithCopyWarning
    df_relevant_data = df_relevant_data.sort_values('timestamp')
    detections_by_bee = df_relevant_data.groupby("bee_id", sort=True)

    orientation_error_by_bee = {}
    for bee_id, bee_detections in detections_by_bee:
        x_arr = bee_detections["x_pixels"].to_numpy()
        y_arr = bee_detections["y_pixels"].to_numpy()
        timestamps = bee_detections["timestamp"].to_numpy()
        orientations = bee_detections["orientation_hive"].to_numpy()

        orient_errors = []
        row_cnt = len(bee_detections)

        # stop early enough that the look-ahead window stays inside the arrays
        for i in range(row_cnt - next_detections_to_check_cnt - 1):
            orientation_error = get_orientation_error_for_detection(
                i, x_arr, y_arr, timestamps, orientations, next_detections_to_check_cnt
            )
            if orientation_error is not None:
                orient_errors.append(orientation_error)

        final_orient_error = 0
        error_cnt = len(orient_errors)
        if error_cnt > 0 and error_cnt > min_required_errors:
            # keep only the central ~10% of the sorted errors (robust against outliers).
            # The explicit upper bound keeps the slice non-empty when cutoff is 0;
            # [cutoff:-cutoff] would be empty in that case.
            cutoff = math.floor(error_cnt * 0.45)
            orient_errors.sort()
            central_errors = orient_errors[cutoff:error_cnt - cutoff]
            final_orient_error = sum(central_errors) / len(central_errors)
        orientation_error_by_bee[bee_id] = final_orient_error
    return orientation_error_by_bee


def apply_orientation_errors_to_dataframe(
    orientation_error_by_bee: dict
):
    """Applies the orientation errors by bee to all detections of the corresponding bees
       in the global ``frame_id_df`` (modified in place). Bees without an entry in the
       dictionary keep their orientation.

    Args:
        orientation_error_by_bee (dict): Dictionary with bee-IDs as keys and orientation errors as values
    """
    # Look up the correction of every row in a single pass instead of masking the whole
    # column once per bee. Unknown bees map to NaN and therefore get a correction of 0.
    corrections = frame_id_df['bee_id'].map(orientation_error_by_bee).fillna(0.0).astype(float)

    # add positionally: the index of frame_id_df is not unique, so label alignment is avoided
    frame_id_df['orientation_hive'] = frame_id_df['orientation_hive'] + corrections.to_numpy()

    # wrap the corrected orientations back into the [0, 2*pi] range
    frame_id_df.loc[frame_id_df["orientation_hive"] < 0, "orientation_hive"] += 2 * np.pi
    frame_id_df.loc[frame_id_df["orientation_hive"] > 2 * np.pi, "orientation_hive"] -= 2 * np.pi


def save_orientation_errors_to_csv(
    orientation_error_by_bee: dict, 
    orient_errors_path: str
):
    """Saves orientation errors as a tab-separated CSV-file with the columns
       ``bee_id`` and ``orientation_error``. Missing parent directories are created.

    Args:
        orientation_error_by_bee (dict): Dictionary with bee-IDs as keys and orientation errors as values
        orient_errors_path (str): Path to CSV-file
    """
    df_orientation_errors = pd.DataFrame(
        list(orientation_error_by_bee.items()),
        columns=['bee_id', 'orientation_error']
    )

    # make sure the target directory exists so the (expensive) calculation is not lost
    target_dir = os.path.dirname(orient_errors_path)
    if target_dir:
        os.makedirs(target_dir, exist_ok=True)

    df_orientation_errors.to_csv(orient_errors_path, sep='\t', index=False, header=True)
    logger.info("Saved orientation errors")



# Reuse the stored orientation errors unless they are missing or a recalculation is requested
errors_file_exists = os.path.exists(ORIENT_ERRORS_FILE)
if errors_file_exists and not OVERWRITE_ORIENT_ERRORS_IF_EXIST:
    # the first column (bee_id) becomes the index and thereby the key of the dictionary
    df_orientation_errors = pd.read_csv(ORIENT_ERRORS_FILE, sep='\t', index_col=0) 
    orientation_error_by_bee = df_orientation_errors['orientation_error'].to_dict()
    logger.info("Imported orientation errors.")
else:
    if not errors_file_exists:
        logger.warning("No orientation error file found.")
    # calculate from scratch and store the result for later runs
    orientation_error_by_bee = get_all_orientation_errors()
    save_orientation_errors_to_csv(orientation_error_by_bee, ORIENT_ERRORS_FILE)

# correct the orientations in frame_id_df in place
apply_orientation_errors_to_dataframe(orientation_error_by_bee)

# Get the cropped images for the bees with the corresponding tag masks and body masks (for other bees)
This cell was adapted from https://github.com/nebw/unsupervised_behaviors/blob/master/unsupervised_behaviors/data.py

In [None]:
def get_image_and_mask_for_detections(
    detections: pd.DataFrame,
    video_manager: CustomVideoManager,
    image_size_px: int = 320,
    image_crop_px: int = 48,
    tag_mask_size_px: int = 20,
    body_center_offset_px: int = 40,
    body_mask_length_px: int = 120,
    body_mask_width_px: int = 80,
    egocentric: bool = True,
    use_clahe: bool = True,
    clahe_kernel_size_px: int = 25,
    n_jobs: int = -1,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, pd.DataFrame]:
    """Fetch image regions from cached images of raw BeesBook videos centered on detections of tagged bees. 
       Automatically generate loss masks for ellipsoid region around bee based on body orientation and tags 
       of all individuals visible in the image region. Optionally automatically rotate image region according 
       to body orientation.

       The detections are processed frame by frame (grouped by ``frame_id``) in parallel. All returned
       arrays and the returned dataframe share the same order: frames in ``groupby("frame_id")`` order
       and, within a frame, detections in their original row order.

       NOTE(review): this function uses ``skimage.transform``, ``skimage.draw`` and ``skimage.exposure``
       while the file only imports ``skimage.io`` - confirm that these submodules are available in the
       installed scikit-image version without an explicit import.

    Args:
        detections (pd.DataFrame): Dataframe with detections. Needs the columns ``frame_id``, ``x_pixels``,
            ``y_pixels`` and ``orientation_hive``.
        video_manager (CustomVideoManager): Manages cache. Only used to resolve the path of the cached frame images.
        image_size_px (int, optional): Image size before cropping. Defaults to 320.
        image_crop_px (int, optional): Crop amount after rotation (removed on every side). Defaults to 48.
        tag_mask_size_px (int, optional): Size of tag mask in pixels (radius of the masked disk). Defaults to 20.
        body_center_offset_px (int, optional): Offset from tag (bee coordinates) to body center. Defaults to 40.
        body_mask_length_px (int, optional): Length of body mask. Defaults to 120.
        body_mask_width_px (int, optional): Width of body mask. Defaults to 80.
        egocentric (bool, optional): Whether to rotate the frames so that the focus bee is egocentrically aligned. 
            Defaults to True.
        use_clahe (bool, optional): Process entire frame using CLAHE. Defaults to True.
        clahe_kernel_size_px (int, optional): Kernel size for CLAHE. Defaults to 25.
        n_jobs (int, optional): Number of parallel jobs for processing. Defaults to -1.

    Returns:
        Tuple[np.ndarray, np.ndarray, np.ndarray, pd.DataFrame]: Images, tag masks, body masks and a dataframe 
            with the detection rows in the same order as the images. The dataframe is rebuilt from the row 
            values and therefore has a fresh default index (0..n-1), so that index label k refers to image k.
            ``None`` is returned if ``detections`` is empty.
    """

    def rotate_crop(image: np.array, rotation_deg: float) -> np.array:
        # rotate around the image center, then cut image_crop_px off every side
        image = skimage.transform.rotate(image, rotation_deg)
        image = image[image_crop_px:-image_crop_px, image_crop_px:-image_crop_px]
        return image


    def get_tag_mask(frame, all_detections_df):
        # mask is 1 everywhere except for a disk of radius tag_mask_size_px
        # around every detection (tag) of the frame
        tag_mask = np.ones_like(frame)
        for _, row in all_detections_df.iterrows():
            tag_mask[
                skimage.draw.disk((row.y_pixels, row.x_pixels), tag_mask_size_px, shape=frame.shape)
            ] = 0
        return tag_mask


    def extract_images_from_frame(
        frame_detections: pd.DataFrame, 
        frame_path: str,
        egocentric: bool = True,
    ) -> Tuple[np.array, np.array, np.array, np.array]:
        # Processes one cached frame: returns four lists (images, tag masks, body masks,
        # row values) with one entry per detection in frame_detections.
        images = []
        tag_masks = []
        body_masks = []
        rows = []

        # all detections passed in have to belong to the same frame
        assert frame_detections.frame_id.nunique() == 1

        frame = imageio.v3.imread(frame_path, plugin="opencv", colorspace="GRAY")
        if use_clahe:
            # contrast enhancement on the whole frame
            frame = skimage.exposure.equalize_adapthist(frame, kernel_size=(clahe_kernel_size_px, clahe_kernel_size_px))

        # one mask for the whole frame, cropped per detection below
        tag_mask = get_tag_mask(frame, frame_detections)

        for _, row in frame_detections.iterrows():
            # move from the tag position against the orientation direction to get the body center
            center_x = row.x_pixels - np.cos(row.orientation_hive) * body_center_offset_px
            center_y = row.y_pixels - np.sin(row.orientation_hive) * body_center_offset_px

            # assert center point always within frame (even after trajectory extrapolation)
            center_x = max(0, center_x)
            center_x = min(frame.shape[1] - 1, center_x)
            center_y = max(0, center_y)
            center_y = min(frame.shape[0] - 1, center_y)

            center = np.array((center_x, center_y))

            if egocentric:
                # so that bee is facing to the right (row.orientation_hive + np.pi / 2 for facing upwards)
                # (conversion of the orientation from radians to degrees)
                rotation_deg = (1 / (2 * np.pi)) * 360 * row.orientation_hive
            else:
                rotation_deg = 0

            # crop around the body center, rotate, crop again and convert to 8-bit
            image = bb_behavior.utils.images.get_crop_from_image(
                center, frame, width=image_size_px, clahe=False
            )
            image = (rotate_crop(image, rotation_deg) * 255).astype(np.uint8)

            # ellipse around the body center, drawn in frame coordinates (row, column)
            body_mask = np.zeros_like(frame)
            body_coords = skimage.draw.ellipse(
                center[1],
                center[0],
                body_mask_length_px,
                body_mask_width_px,
                rotation=-(row.orientation_hive - np.pi / 2),
                shape=frame.shape,
            )
            body_mask[body_coords] = 1
            # NOTE(review): the comparison with 255 presumably relies on get_crop_from_image
            # scaling its input to the 0..255 range - confirm against bb_behavior.
            body_mask = (
                bb_behavior.utils.images.get_crop_from_image(
                    center, body_mask, width=image_size_px, clahe=False
                )
                == 255
            )
            # rotation interpolates the mask, so binarize it again
            body_mask = rotate_crop(body_mask, rotation_deg) > 0.5
            
            # same crop / rotation for the tag mask ("task_mask" holds the cropped tag mask)
            task_mask = (
                bb_behavior.utils.images.get_crop_from_image(
                    center, tag_mask, width=image_size_px, clahe=False
                )
                == 255
            )
            task_mask = rotate_crop(task_mask, rotation_deg) > 0.5

            
            images.append(image)
            body_masks.append(body_mask)
            tag_masks.append(task_mask)
            rows.append(row.values)

        return images, tag_masks, body_masks, rows

    # nothing to do for an empty dataframe (returns None, not a tuple)
    if len(detections.index) == 0:
        return

    images = []
    tag_masks = []
    body_masks = []
    rows = []

    logger.debug(f'Detection count in video: {len(detections.index)}')
    
    detections_by_frame = detections.groupby("frame_id")

    # preload file paths because video_manager can't be used in parallel.
    frame_paths = []
    for _, frame_detections in detections_by_frame:
        frame_paths.append(video_manager.get_frame_id_path(frame_detections.frame_id.iloc[0]))
    
    logger.debug(f"Processing {len(frame_paths)} cached images ...")
    # one task per frame; the results are collected in the order the tasks were submitted
    parallel = joblib.Parallel(prefer="processes", n_jobs=n_jobs)(
        joblib.delayed(extract_images_from_frame)(
            frame_detections, 
            frame_path, 
            egocentric=egocentric
        )
        for (_, frame_detections), frame_path in zip(detections_by_frame, frame_paths)
    )

    logger.debug("Processing of cached images complete.")

    # flatten the per-frame result lists
    for results in parallel:
        images += results[0]
        tag_masks += results[1]
        body_masks += results[2]
        rows += results[3]

    images = np.stack(images)
    tag_masks = np.stack(tag_masks)
    body_masks = np.stack(body_masks)
    # rebuild the dataframe in image order; it gets a default index 0..n-1
    detections = pd.DataFrame(np.stack(rows), columns=detections.columns)

    return images, tag_masks, body_masks, detections

# Cache frames, extract and process them and write to videos
Main tasks:
1. Cache all frames with bees in them from the videos. 
2. From the cached frames extract all cropped and rotated frames showing the individual bees in the frames.
3. Apply the masks and write the extracted frames to videos for each individual bee.

In [None]:
def find_continuous_ranges(arr: np.ndarray) -> list:
    """Finds starts and ends of continuous ranges (no value gaps) in array

    Two neighbouring values belong to the same range if the second one is at most
    1 larger than the first one. The input is expected to be sorted in ascending order.

    Args:
        arr (np.ndarray): Input array

    Returns:
        list: List of Tuples containing the start index (inclusive) and end index (exclusive)
            of each continuous range in arr. Empty if arr is empty.
    """
    arr = np.asarray(arr)

    # an empty input contains no range at all (and not a single empty range (0, 0))
    if len(arr) == 0:
        return []

    diffs = np.diff(arr)

    # Indices where a gap occurs (position of the last element before the gap)
    gap_indices = np.where(diffs > 1)[0]

    # Start and end of continuous segments
    starts = np.concatenate(([0], gap_indices + 1))
    ends = np.concatenate((gap_indices + 1, [len(arr)]))

    # plain Python ints are returned so the values can be used for slicing and formatting
    return list(zip(starts.tolist(), ends.tolist()))


def extract_all_videos(
    frame_id_df: pd.DataFrame, 
    min_frames: int = 30
):
    """Extracts cropped videos of bees with enough consecutive frame detections

    For every input video the required frames are cached, the cropped frames of all bees
    are extracted and one output video is written per bee and per gap-free run of frames.
    The frame cache is cleared after each input video.

    Args:
        frame_id_df (pd.DataFrame): Dataframe with detections
        min_frames (int, optional): Minimum number of consecutive frames required to generate a video. 
            Defaults to 30.
    """

    df_grouped = frame_id_df.groupby('video_filename', sort=False)
    for video_filename, df_video in df_grouped:
        logger.info(f"Processing video {video_filename} ...")

        video_manager.cache_frames(
            frame_ids=np.sort(df_video['frame_id'].unique()), 
            video_name=video_filename, 
            frame_indices=np.sort(df_video['frame_index'].unique())
        )

        # get images and masks for all detections in one video
        extraction = get_image_and_mask_for_detections(
            detections=df_video, 
            video_manager=video_manager, 
            n_jobs=N_JOBS
        )
        # None is returned if there was nothing to extract
        if extraction is None:
            video_manager.clear_video_cache()
            continue
        images, tag_masks, body_masks, df_video_after = extraction

        # apply the masks
        images_masked = images * tag_masks
        
        if APPLY_BODY_MASKS:
            images_masked *= body_masks

        # sort the df by bee_id and timestamp and apply the same sorting to the images
        # (the index labels of df_video_after are the positions of the images)
        df_video_sorted = df_video_after.sort_values(['bee_id', 'timestamp'])
        images_sorted = [images_masked[i] for i in df_video_sorted.index]

        df_video_grouped = df_video_sorted.groupby('bee_id', sort=False)
        # dict{bee_id -> positions of the bee's rows in df_video_sorted / images_sorted}
        positions_by_bee = df_video_grouped.indices
        
        for bee_id, df_bee_video in df_video_grouped:
            # there can be frame gaps in the data, where the bee was not detected
            frame_indices = df_bee_video["frame_index"].to_numpy()
            ranges = find_continuous_ranges(frame_indices)

            cam_str = "cam-" + str(df_bee_video.cam_id.iloc[0])
            # look the positions up by key instead of relying on the dict being
            # iterated in the same order as the groups
            bee_positions = positions_by_bee[bee_id]

            # for every continous series of frames for this bee
            for (start, stop) in ranges:
                frame_cnt = stop - start
                # skip short runs before collecting any frames
                if frame_cnt < min_frames:
                    continue
                start_time = df_bee_video.timestamp.iloc[start].strftime('%Y-%m-%dT%H.%M.%S')
                end_time = df_bee_video.timestamp.iloc[stop-1].strftime('%Y-%m-%dT%H.%M.%S')
                filename = "_".join((cam_str, str(bee_id), (start_time + "--" + end_time), str(frame_cnt) + "frames" + ".mp4"))
                images_bee = [images_sorted[i] for i in bee_positions[start:stop]]
                video_manager.write_to_video(images=images_bee, filename=filename, frame_rate=FRAME_RATE)
        video_manager.clear_video_cache()
    
    
# Run the extraction for all detections that could be matched to a video
extract_all_videos(frame_id_df)