In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import numpy as np
import pandas as pd
from pathlib import Path
from Sociability_Learning.utils_files import get_10min_control_dataset

# Acquisition constants
fps = 80  # frames per second
pixel_to_mm = 32 / 832  # conversion factor applied to pixel coordinates

# Width (in frames) of the window extracted around each proximity event.
# It must be even so the window splits evenly around its centre frame.
w = 40
assert w % 2 == 0
arange = np.arange(-(w // 2), w // 2)  # frame offsets relative to the centre

# Input data location and output directories
base_data_dir = Path(
    "/mnt/upramdya_files/LOBATO_RIOS_Victor/Experimental_data/Optogenetics/Optobot/"
)
outputs_dir = Path("../outputs/")
figures_dir = outputs_dir / "figures"
figures_dir.mkdir(parents=True, exist_ok=True)

In [None]:
# Load the 10-minute control dataset. `df` is used below as a mapping with at
# least the keys "data" (per-frame body-part positions) and "clips" (one row per
# proximity event). The second argument is a local HDF5 path.
# NOTE(review): presumably a cache file; confirm in get_10min_control_dataset.
df = get_10min_control_dataset(base_data_dir, "../data/10min_control_old.h5")

To generate a behavioral space for the proximity events, we first computed time series describing how flies were positioned and moved relative to each other:
1. $d\in\mathbb{R}_{\geq 0}$: the distance between two flies
2. $v^\text{to.}_1\in\mathbb{R}$: the velocity component of fly 1 toward fly 2
3. $v^\perp_1\in\mathbb{R}_{\geq 0}$: the absolute value of the velocity component of fly 1 perpendicular to the line connecting the two flies
4. $\theta_1\in[0,\pi]$: the magnitude of the angle between the heading vector of fly 1 and the direction from fly 1 toward fly 2
5. $v^\text{to.}_2\in\mathbb{R}$: the velocity component of fly 2 toward fly 1
6. $v^\perp_2\in\mathbb{R}_{\geq 0}$: the absolute value of the velocity component of fly 2 perpendicular to the line connecting the two flies
7. $\theta_2\in[0,\pi]$: the magnitude of the angle between the heading vector of fly 2 and the direction from fly 2 toward fly 1

For each proximity event $i$, we extracted from each time series a $w$-frame time window (here $w=40$ frames, i.e. 0.5 s at 80 fps) centered at the frame index $t_i$ at which the two flies were closest to each other, and concatenated the windows. This expanded the number of dimensions from $7$ to $7w$:
\begin{equation*}
    \mathbf{x}_i=\text{vec}\left(\begin{bmatrix}
        d(t_i-w/2) & \ldots & d(t_i+w/2-1) \\
        \vdots & \ddots & \vdots \\
        \theta_2(t_i-w/2) & \ldots & \theta_2(t_i+w/2-1)
    \end{bmatrix}\right)\in \mathbb{R}^{7w}
\end{equation*}

In [None]:
def get_features(fly1: pd.DataFrame, fly2: pd.DataFrame, frame_indices: np.ndarray):
    """Compute windowed interaction features for a pair of flies.

    Seven per-frame time series are computed:
    d, vt1, abs_vp1, theta1, vt2, abs_vp2, theta2.
    For every entry of ``frame_indices``, the ``w`` frames at offsets ``arange``
    (``-w//2 .. w//2 - 1``) around that frame are extracted from each series.

    Parameters
    ----------
    fly1 : pd.DataFrame
        DataFrame with the position of the body parts of fly 1 (columns
        "head", "thorax", "abdomen"), in pixels.
        Coordinates are encoded as complex numbers
        (real and imaginary parts are x- and y-coordinates, respectively).
    fly2 : pd.DataFrame
        Same as ``fly1`` for fly 2.
    frame_indices : np.ndarray
        Positional frame indices (rows of ``fly1``/``fly2``) around which
        time windows are extracted.

    Returns
    -------
    pd.DataFrame
        One row per entry of ``frame_indices`` (used as the index) and
        ``7 * w`` columns, indexed by a MultiIndex (feature, frame offset).

    Raises
    ------
    ValueError
        If a window would start before the first frame or end after the last
        frame. Negative indices would otherwise silently wrap around.
    """
    from scipy.ndimage import gaussian_filter1d
    from itertools import product

    frame_indices = np.asarray(frame_indices)

    # position of the flies: mean of head, thorax and abdomen, in mm
    p1 = np.nanmean(fly1[["head", "thorax", "abdomen"]].values, axis=1) * pixel_to_mm
    p2 = np.nanmean(fly2[["head", "thorax", "abdomen"]].values, axis=1) * pixel_to_mm

    p1p2 = p2 - p1  # vector from fly1 to fly2
    dist = np.abs(p1p2)  # distance between flies
    p1p2 /= dist  # unit vector from fly1 to fly2

    # velocity of the flies in image coordinates (mm/s), computed as a
    # Gaussian-smoothed (sigma = 2 frames) first derivative
    v1 = gaussian_filter1d(p1, 2, order=1, mode="nearest") * fps
    v2 = gaussian_filter1d(p2, 2, order=1, mode="nearest") * fps

    # rotate velocities so that 1 + 0j points towards the other fly:
    # real part = component toward the other fly, imag part = perpendicular
    v1 /= p1p2
    v2 /= -p1p2

    # heading of the flies (body centre -> head) in image coordinates
    heading1 = fly1["head"].values * pixel_to_mm - p1
    heading2 = fly2["head"].values * pixel_to_mm - p2

    # normalize headings to unit length
    heading1 /= np.abs(heading1)
    heading2 /= np.abs(heading2)

    # angle between each heading and the direction toward the other fly;
    # np.angle lies in (-pi, pi], so the absolute value lies in [0, pi]
    theta1 = np.abs(np.angle(heading1 / p1p2))
    theta2 = np.abs(np.angle(heading2 / -p1p2))

    # column-stack all features: shape (n_frames, 7)
    X = np.column_stack(
        [dist, v1.real, np.abs(v1.imag), theta1, v2.real, np.abs(v2.imag), theta2]
    )

    # every window must lie inside the recording (negative indices would
    # silently wrap around to the end of the array)
    if len(frame_indices) and (
        frame_indices.min() + arange[0] < 0
        or frame_indices.max() + arange[-1] >= len(X)
    ):
        raise ValueError(
            "time windows around frame_indices extend beyond the recording"
        )

    # windows of w frames centred on each frame index:
    # (n_events, w, 7) -> (n_events, 7, w)
    X = X[frame_indices[:, None] + arange].transpose((0, 2, 1))

    columns = pd.MultiIndex.from_tuples(
        product(["d", "vt1", "abs_vp1", "theta1", "vt2", "abs_vp2", "theta2"], arange)
    )
    # explicit column count so that an empty frame_indices also reshapes
    return pd.DataFrame(
        X.reshape((X.shape[0], X.shape[1] * X.shape[2])), frame_indices, columns
    )


# Features of every proximity event, computed per (datetime, arena) group and
# stacked into one DataFrame indexed by (datetime, arena, frame).
df["features"] = pd.concat(
    {
        key: get_features(
            df_flies["l"],
            df_flies["r"],
            df["clips"].loc[key, "ind_min_dist"].values,
        )
        for key, df_flies in df["data"].groupby(["datetime", "arena"])
    },
    names=["datetime", "arena", "frame"],
)

In [None]:
def z_normalize(X: np.ndarray):
    """Standardize X in place to zero mean and unit standard deviation.

    The mean and standard deviation are scalars taken over *all* elements of
    X, not per column. X must be writable; passing a view (e.g. a basic
    slice) modifies the underlying array.

    Parameters
    ----------
    X : np.ndarray
        Floating-point array to be z-normalized in place.
    """
    np.subtract(X, X.mean(), out=X)
    np.divide(X, X.std(), out=X)


X = df["features"].values.copy()

# z-normalize each feature block (w consecutive columns) separately. Basic
# slicing returns views, so z_normalize modifies X in place.
for i in range(0, X.shape[1], w):
    z_normalize(X[:, i : i + w])

# Normalize the distance block, and each fly-1 block jointly with its fly-2
# counterpart (vt1 & vt2, abs_vp1 & abs_vp2, theta1 & theta2). Advanced (list)
# indexing returns a *copy*, so the normalized block must be written back
# explicitly.
# Each block is already standardized above, so pooling two of them only
# changes X at floating-point precision.
# NOTE(review): if the intent was to keep the relative scale of fly-1 vs fly-2
# features, the per-block loop above should be dropped. That would change
# the embedding and the manual cluster order.
for cols in (
    np.r_[0:w],
    np.r_[w : 2 * w, 4 * w : 5 * w],
    np.r_[2 * w : 3 * w, 5 * w : 6 * w],
    np.r_[3 * w : 4 * w, 6 * w : 7 * w],
):
    block = X[:, cols]
    z_normalize(block)
    X[:, cols] = block

In [None]:
from umap import UMAP
import numba
from pynndescent.distances import euclidean


@numba.njit
def flip(x):
    """Swap the fly-1 and fly-2 features of a feature vector.

    The last axis holds 7 consecutive blocks of ``w`` columns:
    d, vt1, abs_vp1, theta1, vt2, abs_vp2, theta2. The distance block stays in
    place and the three fly-1 blocks are exchanged with the three fly-2 blocks.
    ``w`` is a notebook global; numba treats globals as compile-time constants.
    """
    x = np.asarray(x)
    return np.concatenate(
        (x[..., :w], x[..., 4 * w :], x[..., w : 4 * w]), axis=-1
    )


@numba.njit
def my_dist(a, b):
    """Euclidean distance that is invariant to swapping the two flies' identities."""
    return min(euclidean(a, b), euclidean(a, flip(b)))


# 2D UMAP embedding with the fly-swap-invariant distance
umap = UMAP(n_components=2, n_neighbors=100, metric=my_dist, random_state=0)
Z = umap.fit_transform(X)

In [None]:
def rotate_embedding(Z):
    """Center a 2D embedding and express it in its principal-component axes.

    Parameters
    ----------
    Z : np.ndarray
        Embedding of shape (n_samples, 2).

    Returns
    -------
    np.ndarray
        Centered embedding projected onto the two PCA components.
    """
    from sklearn.decomposition import PCA

    centered = Z - Z.mean(0)
    pca = PCA(n_components=2).fit(centered)
    return centered @ pca.components_.T


# Align the UMAP axes with the principal axes, mirroring the first axis.
Z = rotate_embedding(Z) * (-1, 1)

In [None]:
from sklearn.cluster import KMeans

# Partition the 2D embedding into clusters.
# NOTE(review): scikit-learn 1.4 changed KMeans' default `n_init` (10 -> "auto",
# i.e. a single k-means++ run). The manual cluster ordering below depends on the
# exact labels, so consider pinning n_init to the value used originally
# (TODO confirm which scikit-learn version produced the published order).
n_clusters = 20
kmeans = KMeans(n_clusters=n_clusters, random_state=0)
labels = kmeans.fit_predict(Z)

In [None]:
from mplex import Grid

# Mean embedding position of each cluster; used to place the cluster ids.
coms = np.stack([Z[labels == k].mean(axis=0) for k in range(n_clusters)])

ax = Grid(100).item()
ax.axis("off")
ax.set_xlim(np.min(coms[:, 0]), np.max(coms[:, 0]))
ax.set_ylim(np.min(coms[:, 1]), np.max(coms[:, 1]))
ax.set_aspect("equal")

for cluster_id, (x, y) in enumerate(coms):
    ax.text(x, y, cluster_id, ha="center", va="center", fontsize=5)

In [None]:
# manually reorder clusters
# Manual relabelling chosen from the plot above: entry j of this list is the
# original KMeans label that becomes new label j.
new_to_old = np.array(
    [14, 9, 15, 4, 17, 6, 12, 18, 1, 13, 16, 3, 10, 2, 11, 0, 7, 5, 8, 19]
)
# Validate the hand-written list itself. np.argsort always returns a
# permutation, so checking `order` would be vacuous.
assert np.array_equal(np.sort(new_to_old), np.arange(n_clusters))

order = np.argsort(new_to_old)  # order[old_label] -> new label

new_labels = order[labels].astype(labels.dtype)
labels = new_labels

In [None]:
# Mean embedding position of each (renumbered) cluster; used to check the new ids.
coms = np.stack([Z[labels == k].mean(axis=0) for k in range(n_clusters)])

ax = Grid(100).item()
ax.axis("off")
ax.set_xlim(np.min(coms[:, 0]), np.max(coms[:, 0]))
ax.set_ylim(np.min(coms[:, 1]), np.max(coms[:, 1]))
ax.set_aspect("equal")

for cluster_id, (x, y) in enumerate(coms):
    ax.text(x, y, cluster_id, ha="center", va="center", fontsize=5)

In [None]:
# For every event, decide whether to swap the roles of fly 1 and fly 2 so that
# the events within each cluster are consistently oriented.
need_flip = np.zeros(len(X), dtype=bool)

# `labels` were renumbered above, but `kmeans.cluster_centers_` is still indexed
# by the original KMeans labels. Map each new label back to its original center
# (same mapping as in get_areas).
centers = kmeans.cluster_centers_[order.argsort()]

for k in range(n_clusters):
    idx = np.where(labels == k)[0]
    Xk = X[idx]  # advanced indexing already returns a copy
    Zk = Z[idx]
    # The event whose embedding lies closest to the cluster center serves as
    # the reference orientation for the cluster.
    centroid = Xk[np.argmin(np.linalg.norm(Zk - centers[k], axis=1))]

    for i, x in enumerate(Xk):
        if euclidean(centroid, x) >= euclidean(centroid, flip(x)):
            need_flip[idx[i]] = True
            Xk[i] = flip(x)

    # Orient the whole cluster so that, on average, fly 1's velocity toward the
    # other fly (columns w:2w) is not smaller than fly 2's (columns 4w:5w).
    mean = Xk.mean(0)

    if mean[w : 2 * w].mean() < mean[4 * w : 5 * w].mean():
        need_flip[idx] = ~need_flip[idx]

In [None]:
f = df["features"].columns.get_level_values(0).unique()  # feature names
# Unnormalized features with fly 1 / fly 2 swapped for events flagged in need_flip
F = np.array(
    [
        flip(row) if swap else row
        for row, swap in zip(df["features"].values, need_flip)
    ]
)
F = F.reshape((F.shape[0], len(f), -1))  # (n_events, n_features, w)

In [None]:
import colorcet as cc
from matplotlib.colors import ListedColormap

# Row j of the figure shows feature feature_order[j]:
# d, vt1, vt2, abs_vp1, abs_vp2, theta1, theta2
feature_order = [0, 1, 4, 2, 5, 3, 6]

# one color per cluster
palette = ListedColormap(cc.rainbow)(np.linspace(0, 1, n_clusters))
g = Grid(
    (16, 20), (len(feature_order), n_clusters), sharey="row", space=(2, 6), dpi=144
)
g[:, :].set_visible_sides("")
g[:, 0].set_visible_sides("l")
ylabels = [
    "$d$\n(mm)",
    "$v^\\text{ to.}_1$\n(mm/s)",
    "$v^\\perp_1$\n(mm/s)",
    "$\\theta_1$\n(rad)",
    "$v^\\text{ to.}_2$\n(mm/s)",
    "$v^\\perp_2$\n(mm/s)",
    "$\\theta_2$\n(rad)",
]  # noqa
t = arange / fps  # time (s) relative to the closest approach

for k in range(n_clusters):
    Fk = F[labels == k]
    c = palette[k]

    for j, i in enumerate(feature_order):
        if k == 0:
            # background axes spanning the whole row: zero line and row label
            ax = g[j, :].make_ax(sharey=True)
            ax.axhline(0, color="k", ls="--", clip_on=False)
            ax.axis("on")
            ax.set_visible_sides("")
            ax.add_text(0, 0.5, ylabels[i], ha="c", va="c", transform="a", pad=(-23, 0))

        # mean +/- std over the events of cluster k
        mean = Fk[:, i].mean(0)
        std = Fk[:, i].std(0)

        ax = g.axs[j, k]
        ax.set_xlim(t[0], t[-1])
        ax.plot(t, mean, c=c, alpha=1, lw=1, clip_on=False)
        ax.fill_between(
            t, mean - std, mean + std, color=c, alpha=0.5, lw=0, clip_on=False
        )
        if j == 0:
            ax.set_title(f"{k+1}", color="k")

# y-limits span the first to the last tick of each row (axes share y per row)
pi_labels = ["0", r"$\pi$"]
row_yticks = [
    ([0, 10], None),
    ([-40, 0, 40], None),
    ([-40, 0, 40], None),
    ([0, 40], None),
    ([0, 40], None),
    ([0, np.pi], pi_labels),
    ([0, np.pi], pi_labels),
]
for j, (ticks, tick_labels) in enumerate(row_yticks):
    g[j, 0].set_ylim(ticks[0], ticks[-1])
    g[j, 0].set_yticks(ticks, labels=tick_labels)

g.savefig(figures_dir / "edfig2d.pdf", transparent=True)

In [None]:
bound = 4.5  # half-width of the square region of the embedding that is rasterized
n_bins = 512  # raster resolution (pixels per side)


def get_areas(bound: float, n_bins: int) -> np.ndarray:
    """Rasterize the KMeans partition of the embedding into a label image.

    Every pixel of an ``n_bins x n_bins`` grid covering ``[-bound, bound]^2`` is
    assigned the index of the nearest cluster center, using the renumbered
    labels. Pixels outside the hole-filled region where ``get_kde(Z, ...)[0]``
    exceeds ``2e-3`` are set to -1.

    Reads the notebook globals ``kmeans``, ``order`` and ``Z``.

    Parameters
    ----------
    bound : float
        Half-width of the rasterized square.
    n_bins : int
        Number of pixels per side.

    Returns
    -------
    np.ndarray
        float32 array of shape (n_bins, n_bins). Row 0 corresponds to
        y = +bound (rows are flipped below) and columns run from x = -bound
        to x = +bound.
    """
    from Sociability_Learning.utils_embedding import get_kde
    from scipy.ndimage import binary_fill_holes

    Y, X = np.mgrid[-bound : bound : n_bins * 1j, -bound : bound : n_bins * 1j]
    # (n_bins, n_bins, 1, 2) grid of (x, y) points, rows flipped so row 0 is y=+bound
    mg = np.stack((X, Y), axis=-1)[::-1, :, None]
    # cluster centers ordered by the renumbered labels (order.argsort() maps new -> old)
    centers = kmeans.cluster_centers_[order.argsort()]
    # distance of every pixel to every center -> index of the nearest center
    im_argmin = np.linalg.norm(mg - centers, axis=-1).argmin(-1).astype(np.float32)
    # NOTE(review): assumes get_kde returns a density image on the same grid and
    # orientation as `mg` as its first element; confirm.
    im_argmin[~binary_fill_holes(get_kde(Z, n_bins, bound)[0] > 2e-3)] = -1
    return im_argmin


im_regions = get_areas(bound, n_bins)

In [None]:
from scipy.ndimage import binary_fill_holes

In [None]:
from Sociability_Learning.utils_embedding import get_bbox

# Bounding box (pixels) of the valid region of im_regions, padded by 16 pixels,
# converted to embedding coordinates for the axis limits of the figures below.
# NOTE(review): assumes get_bbox returns [[row_min, row_max], [col_min, col_max]].
# Rows of im_regions run from y=+bound downwards (see get_areas), so confirm
# the row-derived limits are not mirrored about y=0.
bbox = get_bbox(im_regions != -1) + (-16, 16)
ylim, xlim = (bbox / n_bins - 0.5) * bound * 2

In [None]:
import seaborn as sns
from matplotlib.colors import to_hex

# Housing condition of each event: "i" single housed, "g" group housed.
# NOTE(review): assumes df["clips"] rows are in the same order as the rows of Z.
conds = df["clips"]["condition"].values

axw = 90
# axes height follows the aspect ratio of the plotted region
axh = axw * (ylim[1] - ylim[0]) / (xlim[1] - xlim[0])

g = Grid((axw, axh), (2, 1), space=(6.716 + 5.487) / 10 / 2.54 * 72, facecolor="w")
g.set_visible_sides("")

cmaps = {"i": "Blues", "g": "Oranges"}
titles = ["Single\nhoused", "Group\nhoused"]

for i, c in enumerate("ig"):
    x, y = Z[conds == c].T
    ax = g[i, 0]
    # density of the condition's events, with the events themselves on top
    sns.kdeplot(x=x, y=y, bw_method=0.1, cmap=cmaps[c], fill=True, levels=100, ax=ax)
    ax.scatter(x, y, s=3, c=f"C{i}", lw=0)
    # outlines of the cluster regions
    ax.contour(
        im_regions + 1,
        levels=np.arange(n_clusters + 1),
        colors=to_hex((0.3,) * 3),
        extent=(-bound, bound, bound, -bound),
        antialiased=True,
        linewidths=0.5,
    )
    ax.set_xlim(xlim)
    ax.set_ylim(ylim)
    ax.set_aspect("equal")
    ax.add_text(
        0.5,
        1,
        titles[i],
        ha="c",
        va="c",
        transform="a",
        c=f"C{i}",
        size=7,
        pad=(0, 5),
    )

ax = g[0, 0]
ax.add_scale_bars(
    xlim[1] - 0.7,
    ylim[0],
    -2.1,
    2.1,
    xlabel="UMAP 1",
    ylabel="UMAP 2",
    fmt="",
    pad=(2, -6),
)

# the figure is written to figures_dir; make sure it exists
figures_dir.mkdir(parents=True, exist_ok=True)
g.savefig(figures_dir / "fig1ef.pdf", transparent=True)

In [None]:
from scipy.ndimage import center_of_mass

from mplex.text import set_text_outline

# Center of mass of each cluster region, converted from pixel indices to
# embedding coordinates. With the np.mgrid grid of get_areas and the contour
# extent used below, pixel i lies at -bound + 2 * bound * i / (n - 1).
coms = (
    np.array([center_of_mass(im_regions == k) for k in range(n_clusters)])
    / (np.array(im_regions.shape) - 1)
    * bound
    * 2
    - bound
)
g = Grid(40 / 10 / 2.54 * 72)
ax = g.item()
ax.axis("off")
ax.set_aspect("equal")

# region outlines
ax.contour(
    im_regions + 1,
    levels=np.arange(n_clusters + 1),
    colors="k",
    extent=(-bound, bound, bound, -bound),
    antialiased=True,
    linewidths=0.5,
)

# region fills: one palette color per cluster, masked pixels left transparent
ax.contourf(
    im_regions + 1,
    levels=np.arange(n_clusters + 2) - 1,
    colors=["none", *palette, "k"],
    extent=(-bound, bound, bound, -bound),
    antialiased=True,
)

ax.set_xlim(xlim)
ax.set_ylim(ylim)

# (row, col) -> (x, y): columns give x, rows give y with row 0 at y = +bound
for i, (x, y) in enumerate(coms[:, ::-1] * (1, -1)):
    txt = ax.text(
        x, y, i + 1, ha="center", va="center", fontsize=7, c="k", weight="bold"
    )
    set_text_outline(txt, 1, "w")

g.savefig(figures_dir / "edfig2c.pdf", transparent=True)

In [None]:
from mplex.colors import change_hsv

# Shared style for the arrows in the schematic drawn by draw_flies.
arrow_kw = {
    "width": 0.06,
    "overhang": 0.25,
    "fc": "k",
    "ec": "none",
    "head_width": 0.4,
    "head_length": 0.4,
    "length_includes_head": True,
    "zorder": 100,
}


def c1(c):
    """Fly-1 shade of color ``c`` (HSV value 0.95, saturation 0.6)."""
    return change_hsv(c, v=0.95, s=0.6)


def c2(c):
    """Fly-2 shade of color ``c`` (HSV value 0.7, saturation 1)."""
    return change_hsv(c, v=0.7, s=1)


# Colors per quantity. The first seven entries follow the feature order of
# get_features (d, vt1, vp1, theta1, vt2, vp2, theta2). Later code indexes
# list(cdict.values()) by feature index, so insertion order matters.
cdict = {
    "d": change_hsv("goldenrod", v=0.6),
    "vt1": c1("r"),
    "vp1": c1("b"),
    "t1": c1((0, 1, 0)),
    "vt2": c2("r"),
    "vp2": c2("b"),
    "t2": c2((0, 1, 0)),
    "v1": c1("m"),
    "v2": c2("m"),
    "p": "gray",
}

In [None]:
def draw_flies(ax):
    """Draw the schematic of two flies annotated with the interaction features.

    The schematic places two flies at fixed example positions, each with faded
    copies showing recent motion. It overlays:

    - the velocity vectors;
    - their components toward and perpendicular to the other fly;
    - the heading angles relative to the other fly;
    - the inter-fly distance.

    Colors come from ``cdict`` and arrows use the style in ``arrow_kw``.

    Parameters
    ----------
    ax : Axes
        Axes to draw into. Every artist, including arrows and lines, is added
        to this axes. It presumably must be an mplex axes because
        ``ax.add_text`` is used.
    """
    from matplotlib.patches import FancyArrowPatch, Wedge
    from flyplotlib import add_fly

    def c2r(c):
        """Convert a complex number to an (x, y) tuple."""
        return (c.real, c.imag)

    def proj(u, v):
        """Project vector v onto unit vector u (both complex numbers)."""
        return u * (v * u.conjugate()).real

    def arrow(p1, p2, *args, **kwargs):
        """Draw an arrow from p1 to p2 on ``ax`` (not the implicit current axes)."""
        delta = p2 - p1
        return ax.arrow(p1.real, p1.imag, delta.real, delta.imag, *args, **kwargs)

    def line(p1, p2, *args, **kwargs):
        """Draw a straight line from p1 to p2 on ``ax``."""
        return ax.plot([p1.real, p2.real], [p1.imag, p2.imag], *args, **kwargs)

    rh = 1  # length of the heading line
    ra = 0.7  # radius of the angle wedges

    # thorax positions
    pt1 = 0
    pt2 = 4 - 1j

    # direction toward other fly
    u = pt2 - pt1
    u /= np.abs(u)

    # angles
    theta1 = np.pi / 6
    theta2 = np.deg2rad(-120)
    phi1 = theta1 + np.angle(pt2 - pt1)
    phi2 = theta2 + np.angle(pt2 - pt1) + np.pi

    # velocity vectors
    v1 = (0.9 + 0.6j) * 2
    v2 = (-0.1 + 1j) * 2

    # head positions
    ph1 = pt1 + rh * np.exp(1j * phi1)
    ph2 = pt2 + rh * np.exp(1j * phi2)

    ax.set_aspect(1)
    ax.axis("off")

    # plot velocity vectors
    arrow(pt1, pt1 + v1, **dict(arrow_kw, fc=cdict["v1"]))
    arrow(pt2, pt2 + v2, **dict(arrow_kw, fc=cdict["v2"]))
    ax.add_text(
        *c2r(pt1 + v1),
        "$(v^\\text{to.}_1,v^\\perp_1)$",
        color="k",
        ha="c",
        va="c",
        pad=(-2, 7),
        size=7,
    )
    ax.add_text(
        *c2r(pt2 + v2),
        "$(v^\\text{to.}_2,v^\\perp_2)$",
        color="k",
        ha="c",
        va="c",
        pad=(4, 6),
        size=7,
    )

    # plot component toward the other fly
    line(pt1 + v1, pt1 + proj(u, v1), cdict["p"], ls="--")
    line(pt2 + v2, pt2 + proj(u, v2), cdict["p"], ls="--")
    arrow(pt1, pt1 + proj(u, v1), **dict(arrow_kw, fc=cdict["vt1"]))
    arrow(pt2, pt2 + proj(u, v2), **dict(arrow_kw, fc=cdict["vt2"]))

    # plot component perpendicular to the direction toward other fly
    line(pt1 + v1, pt1 + proj(u * 1j, v1), cdict["p"], ls="--")
    line(pt2 + v2, pt2 + proj(u * 1j, v2), cdict["p"], ls="--")
    arrow(pt1, pt1 + proj(u * 1j, v1), **dict(arrow_kw, fc=cdict["vp1"]))
    arrow(pt2, pt2 + proj(u * 1j, v2), **dict(arrow_kw, fc=cdict["vp2"]))

    # plot angles
    ax.add_patch(
        Wedge(
            c2r(pt1),
            ra,
            np.rad2deg(np.angle(pt2 - pt1)),
            np.rad2deg(np.angle(ph1 - pt1)),
            color=cdict["t1"],
            alpha=0.7,
            lw=0,
        )
    )
    ax.add_patch(
        Wedge(
            c2r(pt2),
            ra,
            np.rad2deg(np.angle(ph2 - pt2)),
            np.rad2deg(np.angle(pt1 - pt2)),
            color=cdict["t2"],
            alpha=0.7,
            lw=0,
        )
    )
    ax.add_text(
        *c2r(pt1),
        "$\\theta_1$",
        color=cdict["t1"],
        ha="c",
        va="c",
        pad=(14.5, 2.5),
        size=7,
    )
    ax.add_text(
        *c2r(pt2),
        "$\\theta_2$",
        color=cdict["t2"],
        ha="c",
        va="c",
        pad=(-12, 7),
        size=7,
    )

    # plot heading directions
    line(pt1, ph1, "k", zorder=100, solid_capstyle="round")
    line(pt2, ph2, "k", zorder=100, solid_capstyle="round")

    # plot distance (offset sideways from the connecting line by d)
    d = 0.25
    ax.add_patch(
        FancyArrowPatch(
            c2r(pt1 - u * 1j * d),
            c2r(pt2 - u * 1j * d),
            arrowstyle="|-|",
            mutation_scale=1,
            color=cdict["d"],
            shrinkA=0,
            shrinkB=0,
            zorder=1000,
        )
    )
    line(pt1, pt2, "k", ls="-", alpha=0.2)
    ax.add_text(
        *c2r((pt1 + pt2) / 2),
        "$d$",
        color=cdict["d"],
        ha="c",
        va="c",
        pad=(0, -8),
        size=7,
    )

    # plot movements: earlier positions drawn more transparent
    rv = 0.08
    n_frames = 3

    for j, i in enumerate(np.linspace(rv, 0, n_frames)):
        add_fly(
            c2r(pt1 - v1 * i),
            np.rad2deg(phi1),
            alpha=(j + 1) / n_frames,
            grayscale=True,
            zorder=-200,
            ax=ax,
        )
        add_fly(
            c2r(pt2 - v2 * i),
            np.rad2deg(phi2),
            alpha=(j + 1) / n_frames,
            grayscale=True,
            zorder=-200,
            ax=ax,
        )

In [None]:
# Style of the dashed box marking the w-frame window.
window_kw = {"fc": "none", "ec": "k", "ls": "--", "alpha": 1, "clip_on": False}

# Example event (row 126 of the clips table) used in the schematic figure.
clip = df["clips"].iloc[126]
df_flies = df["data"].loc[clip.name]
# Three consecutive w-frame windows (previous, centred, next) around the
# closest-approach frame, concatenated into traces of shape (7, 3 * w).
f = get_features(
    df_flies["l"], df_flies["r"], clip["ind_min_dist"] + w * np.array([-1, 0, 1])
)
f = np.concatenate(f.values.reshape((3, -1, w)), axis=-1)


def plot_traces(g_, traces=None):
    """Plot the example feature traces, one feature per row of ``g_``.

    Parameters
    ----------
    g_ : grid slice
        Column of axes (e.g. ``g[:, 1]``). Row j shows feature
        ``feature_order[j]``.
    traces : np.ndarray, optional
        Traces of the seven features over three consecutive w-frame windows
        around the example event, shape (7, 3 * w). Defaults to the global
        ``f`` computed in this cell.

    Returns
    -------
    Axes
        Overlay axes spanning ``g_``, showing the centre line and the
        highlighted w-frame window.
    """
    if traces is None:
        traces = f

    # time (s) relative to the closest approach, spanning the three windows
    t_ = np.arange(-w // 2 - w, w // 2 + w) / fps
    colors = list(cdict.values())

    ylabels = [
        "$d$",
        "$v^\\text{to.}_1$",
        "$v^\\perp_1$",
        "$\\theta_1$",
        "$v^\\text{to.}_2$",
        "$v^\\perp_2$",
        "$\\theta_2$",
    ]
    for j, ax in enumerate(g_.axs.ravel()):
        i = feature_order[j]
        ax.plot(t_, traces[i], c=colors[i])
        ax.add_text(
            0,
            0.5,
            ylabels[i],
            transform="a",
            ha="c",
            va="c",
            pad=(-7, 0),
            size=6,
            c=colors[i],
        )

    # overlay axes spanning the column passed in (not the global grid)
    ax = g_.make_ax(sharex=True)
    ax.axvline(0, color="k", linestyle="-", alpha=0.3)
    ax.set_xlim(-0.35, 0.35)

    ax.axvspan(-w // 2 / fps, w // 2 / fps, **window_kw)
    ax.add_text(
        0.5,
        0,
        f"{w} frames ({w / fps:g} s)",
        transform="a",
        ha="c",
        va="t",
        pad=(0, -2),
        size=6,
        c="k",
    )
    return ax

In [None]:
def draw_data_matrix(ax):
    """Draw the schematic of one concatenated feature vector.

    Draws one colored cell per feature (colors from ``cdict``, in
    ``feature_order``) inside a dashed box. The box is labeled with the total
    dimensionality ``n_features * w``.

    Parameters
    ----------
    ax : Axes
        Axes to draw into (``ax.add_text`` is used, so presumably an mplex axes).
    """
    n_features = len(feature_order)
    colors = list(cdict.values())
    row_height = 0.02

    ax.axis("on")
    ax.set_xticks([])
    ax.set_yticks([])

    for i, feature_idx in enumerate(feature_order):
        ax.axvspan(
            i,
            i + 1,
            0.5 - row_height,
            0.5 + row_height,
            color=colors[feature_idx],
            lw=0,
        )

    ax.axvspan(0, n_features, 0.5 - row_height, 0.5 + row_height, **window_kw)
    ax.set_xlim(0, n_features)
    for spine in ax.spines.values():
        spine.set_alpha(0.3)

    ax.add_text(
        0.5,
        0,
        f"{n_features * w} dims",
        transform="a",
        ha="c",
        va="t",
        pad=(0, -2),
        size=6,
        c="k",
    )

In [None]:
def plot_umap(ax, example_idx=126):
    """Scatter the embedding, highlight one event and outline the valid region.

    Parameters
    ----------
    ax : Axes
        Axes to draw into (``ax.add_scale_bars`` is used, so presumably an
        mplex axes).
    example_idx : int, optional
        Row of ``Z`` to highlight in red. It should match the example clip
        whose traces are plotted (row 126 of ``df["clips"]``).
    """
    from matplotlib.colors import to_hex

    # Axis limits: bounding box of the valid region of im_regions, padded by
    # 16 pixels, in embedding coordinates. Uses the same `bound` as get_areas.
    n_px = im_regions.shape[0]
    bbox = get_bbox(im_regions != -1) + (-16, 16)
    ylim, xlim = (bbox / n_px - 0.5) * bound * 2

    ax.scatter(*Z.T, s=5, c="k", alpha=0.2, marker=".", lw=0)
    ax.scatter(*Z[example_idx], s=15, c="r", marker=".", lw=0)
    ax.set_aspect(1)
    # outline of the region not masked out in im_regions
    ax.contour(
        im_regions != -1,
        colors=to_hex((0.3,) * 3),
        extent=(-bound, bound, bound, -bound),
        antialiased=True,
        linewidths=0.25,
    )

    ax.add_scale_bars(
        xlim[1] - 0.7,
        ylim[0],
        -2.1,
        2.1,
        xlabel="UMAP 1",
        ylabel="UMAP 2",
        fmt="",
        pad=(2, -6),
        size=5,
    )

    ax.set_xlim(*xlim)
    ax.set_ylim(*ylim)

In [None]:
# Extended-data figure 2a,b: schematic | example traces | feature vector | embedding
g = Grid(
    (np.array([120, 40, 40, 80]) * 1.25, 60 / 7),
    (7, None),
    space=(16, 0),
    sharex=True,
    sharey=False,
)
g.set_visible_sides(False)
for col, panel in enumerate((draw_flies, plot_traces, draw_data_matrix, plot_umap)):
    # plot_traces draws into the column of axes itself; the other panels get a
    # single axes spanning the whole column.
    panel(g[:, col] if panel is plot_traces else g[:, col].make_ax())
g.savefig(figures_dir / "edfig2ab.pdf", transparent=True)

In [None]:
axw = 110
# axes height follows the aspect ratio of the plotted region
axh = axw * (ylim[1] - ylim[0]) / (xlim[1] - xlim[0])
g = Grid((axw, axh), (2, 1), space=(62.486 - 57.409) / 25.4 * 72, facecolor="w")
axs = g.axs.ravel()
g.set_visible_sides(False)
scatter_kw = dict(lw=0, marker=".", s=5)

# One panel per event type; events colored by their "below_threshold" flag
# (C0: above, C1: below).
for i, event_type in enumerate(["distancing", "standstill"]):
    bidx = df["clips"]["type"].eq(event_type)
    point_colors = [
        "C1" if below else "C0"
        for below in df["clips"].loc[bidx, "below_threshold"]
    ]
    axs[i].scatter(*Z[bidx].T, c=point_colors, **scatter_kw)

# cluster outlines and shared limits
for ax in axs:
    ax.contour(
        im_regions + 1,
        levels=np.arange(n_clusters + 1),
        colors=to_hex((0.3,) * 3),
        extent=(-bound, bound, bound, -bound),
        antialiased=True,
        linewidths=0.2,
    )
    ax.set_xlim(xlim)
    ax.set_ylim(ylim)
    ax.set_aspect(1)

g[0, 0].set_title("Moving events")
ax = g[1, 0]
ax.set_title("Standstill events")

# Two-color legend "Above / below threshold" written word by word under the
# bottom panel.
for s, c, xpad in zip(
    ["Above", "/", "below", "threshold"],
    ["C0", "k", "C1", "k"],
    [-30.2, -13.3, -11.5, 5.65],
):
    ax.add_text(
        0.5,
        0,
        s=s,
        va="t",
        transform="a",
        size=7,
        c=c,
        pad=(xpad / 6 * 7, 0),
    )

g.savefig(figures_dir / "edfig2e.pdf", transparent=True)