This notebook contains code for animating activity propagation throughout the hive.

### Imports

In [1]:
import os
import sys

parent_dir = os.path.abspath(os.path.join(os.getcwd(), ".."))
sys.path.append(parent_dir)

In [2]:
from pathlib import Path

import imageio.v3 as iio
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from bb_rhythm.interactions import filter_overlap
from bb_rhythm.rhythm import circadian_cosine
from PIL import Image

import path_settings
from analysis.speed_transfers import make_both_bees_focal
from analysis.velocity_change_per_location import (
    concat_grids_over_time,
    replace_time_with_hour,
    swap_focal_bee,
)

### Prepare Data

In [3]:
df = pd.read_csv(
    path_settings.INTERACTION_SIDE_0_DF_PATH_2019,
    usecols=[
        "x_pos_start_bee0",
        "x_pos_start_bee1",
        "y_pos_start_bee0",
        "y_pos_start_bee1",
        "vel_change_bee0",
        "vel_change_bee1",
        "phase_bee0",
        "phase_bee1",
        "interaction_start",
        "overlapping",
    ],
    parse_dates=["interaction_start"],
)

df = filter_overlap(df)
df.drop(columns=["overlapping"], inplace=True)
df = replace_time_with_hour(df)
df.head()

Unnamed: 0,x_pos_start_bee0,y_pos_start_bee0,x_pos_start_bee1,y_pos_start_bee1,vel_change_bee0,vel_change_bee1,phase_bee0,phase_bee1,hour
0,53.41,28.9561,54.4619,36.3934,0.106241,0.438786,-1.196545,-1.283345,3
2,59.6309,56.6406,52.3356,61.4632,1.837692,-0.537582,-0.564902,-1.279075,3
3,54.2329,72.6049,52.3356,61.4632,-1.80831,-0.501465,-0.525804,-1.279075,3
7,104.294,31.6248,98.3685,23.8711,0.426386,-0.116344,-0.211704,-1.347064,3
9,116.116,71.0969,118.672,77.516,0.397959,0.010605,-0.778305,-0.899531,3


In [4]:
# Set the bee with the higher increase in velocity to be the focal one
fast_focal_df = swap_focal_bee(df)
fast_focal_df.head()

Unnamed: 0,x_grid,y_grid,vel_change,hour
0,54,36,0.438786,3
1,60,57,1.837692,3
2,52,61,-0.501465,3
3,104,32,0.426386,3
4,116,71,0.397959,3


In [5]:
phase_df = make_both_bees_focal(
    df[
        [
            "x_pos_start_bee0",
            "y_pos_start_bee0",
            "phase_bee0",
            "x_pos_start_bee1",
            "y_pos_start_bee1",
            "phase_bee1",
        ]
    ],
    var_list=["x_pos_start", "y_pos_start", "phase"],
)
phase_df = phase_df.drop(
    columns=[col for col in phase_df.columns if "non_focal" in col]
)

# Round the positions to int values
phase_df["x_pos_start_focal"] = phase_df["x_pos_start_focal"].astype(int)
phase_df["y_pos_start_focal"] = phase_df["y_pos_start_focal"].astype(int)

# Rename for consistency
phase_df.rename(
    columns={
        "x_pos_start_focal": "x_grid",
        "y_pos_start_focal": "y_grid",
        "phase_focal": "phase",
    },
    inplace=True,
)
phase_df.head()

Unnamed: 0,x_grid,y_grid,phase
0,53,28,-1.196545
2,59,56,-0.564902
3,54,72,-0.525804
7,104,31,-0.211704
9,116,71,-0.778305


### Compute number of interactions and velocity change per location per hour

In [6]:
n_interactions_per_loc = concat_grids_over_time(
    fast_focal_df, var="vel_change", aggfunc="count", scale=True
) # Shape (n_hours, hive_hight, hive_width)
vel_change_per_loc = concat_grids_over_time(
    fast_focal_df, var="vel_change", aggfunc="median", scale=True
)  # Shape (n_hours, hive_hight, hive_width)

### Compute cosine curves with constant amplitude per location

In [None]:
def get_curves_from_median_phase(df: pd.DataFrame) -> np.array:
    """Computes the median phase per location and generates cosine curves.

    Args:
        df (pd.DataFrame): DataFrame containing x_grid, y_grid, and phase columns.

    Returns:
        np.array: 3D array of cosine curves with shape (24, height, width).
    """

    # Compute median phase at each location
    phase_grid = df.pivot_table(
        index="y_grid", columns="x_grid", values="phase", aggfunc="median"
    ).to_numpy()

    # Define time index for cosine curves
    t_index = np.arange(-12, 12)

    # Keep amplitude and offset constant to ensure range of [0,1]
    amplitude = 0.5
    offset = 0.5

    # Generate cosine curves where phase is available
    result = np.array(
        [
            [
                circadian_cosine(t_index, amplitude, phase, offset, period=24)
                if not np.isnan(phase)
                else np.full(t_index.shape, np.nan)
                for phase in row
            ]
            for row in phase_grid
        ]
    ).transpose(2, 0, 1)  # Reshape to (24, height, width)

    return result

In [8]:
curves_per_loc = get_curves_from_median_phase(phase_df)

### Create Animations

In [None]:
def create_animation(
    data: np.array, kind: str, output_format: str = "gif", fps: int = 6
) -> None:
    """
    Creates an animation (GIF or MP4) from heatmap data over time without saving frames to disk.

    Args:
        data (np.array): 3D array (n_hours x height x width) representing time-series data.
        kind (str): Key to determine the colorbar label.
        output_format (str): "gif" for GIF, "mp4" for high-quality video.
        fps (int): Frames per second for the animation.
    """
    # Ensure directory exists
    animation_dir = Path("animations")
    animation_dir.mkdir(exist_ok=True)

    labels = {
        "fit": "Value of fitted cosine curve (scaled)",
        "vel_change": "Post-interaction velocity change (scaled)",
        "n_interactions": "Number of interactions (scaled)",
    }

    frames = []
    n_hours = data.shape[0]

    for h in range(n_hours):
        fig, ax = plt.subplots(figsize=(8, 5), dpi=96)

        # Get hour in AM/PM format
        timelabel = f"{h} am" if h < 12 else f"{h - 12} pm" if h > 12 else "12 pm"

        sns.heatmap(
            data[h, :, :],
            xticklabels=50,
            yticklabels=50,
            cmap="rocket",
            cbar=True,
            cbar_kws={"label": labels[kind]},
            alpha=0.95,
            square=True,
            ax=ax,
            vmin=0,
            vmax=1,
            rasterized=True,
            robust=True,
        )
        ax.set_xlabel("x position [mm]")
        ax.set_ylabel("y position [mm]")
        ax.set_title(timelabel)

        # Convert figure to an image in memory
        fig.canvas.draw()
        image = Image.frombytes(
            "RGB", fig.canvas.get_width_height(), fig.canvas.tostring_rgb()
        )
        frames.append(image)

        plt.close()

    output_path = animation_dir / f"flow_{kind}.{output_format}"

    if output_format == "gif":
        frames[0].save(
            output_path,
            save_all=True,
            append_images=frames[1:],
            duration=int(1000 / fps),  # Convert FPS to frame duration in milliseconds
            loop=0,
        )

    elif output_format == "mp4":
        n_loops = 5
        # Convert frames to numpy arrays for MP4 creation
        frames_np = [np.array(img) for img in frames] * n_loops
        iio.imwrite(
            output_path,
            frames_np,
            fps=fps,
            plugin="FFMPEG",
            quality=9,
        )

    print(f"Animation saved as {output_path}")

In [20]:
create_animation(data=n_interactions_per_loc, kind="n_interactions", output_format="gif")

Animation saved as flow_n_interactions.gif


<img src="animations/flow_n_interactions.gif" width="650" align="center">

In [26]:
create_animation(data=vel_change_per_loc, kind="vel_change", output_format="gif")

Animation saved as flow_vel_change.gif


<img src="animations/flow_vel_change.gif" width="650" align="center">

In [24]:
create_animation(data=curves_per_loc, kind="fit", output_format="gif")

Animation saved as flow_fit.gif


<img src="animations/flow_fit.gif" width="650" align="center">