## Reproduces Extended Data Figure 2 in Lauer et al., Nature Methods 2022

- Note: to preserve benchmark integrity, only the 70% training split is loaded for any ground-truth (GT) aspects of the plots. The Extended Data in the paper, by contrast, shows statistics for the full datasets, so the histograms here may differ slightly from the published figure.

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
sns.set_theme(style='ticks')

from scipy.spatial import cKDTree
from scipy.spatial.distance import cdist


def _proximity_index(coords):
    """Compute the per-frame, per-animal proximity index.

    ``coords`` is a float array of shape (frames, animals, keypoints, 2) in
    which missing keypoints are NaN. For every animal, a disc is centred on
    the mean of its labelled keypoints, with a radius equal to the largest
    centroid-to-keypoint distance rounded up. The index is the number of
    keypoints of *other* animals falling in that disc, divided by the number
    of the animal's own labelled keypoints.

    Rows of frames with fewer than two detected animals are left at 0.
    Animals without a usable centroid are set to NaN.
    """
    n_frames, n_animals = coords.shape[:2]
    index = np.zeros((n_frames, n_animals))

    # Mean over labelled keypoints, done by hand (sum / count) so that animals
    # without any label give NaN without a "Mean of empty slice" warning.
    n_values = (~np.isnan(coords)).sum(axis=2)
    with np.errstate(invalid='ignore', divide='ignore'):
        centroids = np.nansum(coords, axis=2) / n_values
    detected = np.isfinite(centroids).all(axis=2)
    # A keypoint counts as labelled as soon as one of its coordinates is set.
    n_labelled = np.isfinite(coords).any(axis=3).sum(axis=2)

    for i in np.flatnonzero(detected.sum(axis=1) >= 2):
        xy = coords[i]
        dists = np.linalg.norm(xy - centroids[i][:, np.newaxis], axis=2)
        usable = np.isfinite(dists).any(axis=1)
        # Excluded from the per-frame mean taken by the caller (nanmean).
        index[i, ~usable] = np.nan
        if not usable.any():
            continue
        # cKDTree requires finite data, so unlabelled keypoints are removed
        # before building the tree, and only usable animals are queried.
        points = xy.reshape((-1, 2))
        tree = cKDTree(points[np.isfinite(points).all(axis=1)])
        radii = np.ceil(np.nanmax(dists[usable], axis=1))
        n_all = tree.query_ball_point(
            centroids[i][usable], radii, return_length=True,
        )
        n_own = n_labelled[i, usable]
        index[i, usable] = (n_all - n_own) / n_own
    return index


def calc_proximity_and_visibility_indices(hdf_file):
    """Return per-frame proximity and visibility indices of an annotation set.

    Parameters
    ----------
    hdf_file
        Path of an HDF5 file readable by ``pandas.read_hdf``, or an
        already-loaded DataFrame. The columns must be a MultiIndex with a
        'scorer' and an 'individuals' level.
        NOTE(review): the reshape below further assumes that columns are
        grouped contiguously by individual, that every individual keeps the
        same number of keypoints after all-NaN columns are dropped, and that
        each keypoint is an (x, y) column pair -- confirm against the data.

    Returns
    -------
    prox : numpy.ndarray
        Mean proximity index over animals, one value per frame (0 for frames
        with fewer than two detected animals); NaN frames are removed.
    viz : pandas.Series
        Fraction of animals with at least four non-NaN coordinate values
        (presumably two complete keypoints), one value per frame.
    """
    if isinstance(hdf_file, pd.DataFrame):
        df = hdf_file
    else:
        df = pd.read_hdf(hdf_file)
    df = df.droplevel('scorer', axis=1).dropna(axis=1, how='all')
    if 'single' in df:
        # Columns filed under the 'single' individual are excluded.
        df = df.drop('single', axis=1, level=0)
    n_animals = df.columns.get_level_values('individuals').nunique()

    # Non-NaN values per frame and individual. Grouping the transposed frame
    # avoids `groupby(..., axis=1)`, which recent pandas no longer supports.
    n_values = df.T.groupby(level='individuals').count().T
    min_values = 2 * 2
    viz = (n_values >= min_values).sum(axis=1) / n_animals

    coords = df.to_numpy().reshape((df.shape[0], n_animals, -1, 2))
    prox = np.nanmean(_proximity_index(coords), axis=1)
    prox = prox[~np.isnan(prox)]
    return prox, viz

Fig S2

In [None]:
# Proximity index of each dataset; the visibility index is not plotted here.
p1, _ = calc_proximity_and_visibility_indices('gt_trimice.h5')
p2, _ = calc_proximity_and_visibility_indices('gt_pups.h5')
p3, _ = calc_proximity_and_visibility_indices('gt_marmosets.h5')
p4, _ = calc_proximity_and_visibility_indices('gt_fish.h5')

# One RGB colour per panel, in the order trimice, pups, marmosets, fish.
palette = [
    (206/255, 101/255, 41/255),
    (68/255, 145/255, 90/255),
    (199/255, 57/255, 122/255),
    (72/255, 132/255, 175/255),
]

fig, axes = plt.subplots(1, 4, tight_layout=True, figsize=(9, 2))
for ax, prox, color in zip(axes, (p1, p2, p3, p4), palette):
    sns.histplot(x=prox, bins=51, ax=ax, stat='probability', color=color)

# Fixed x-range for the second (pups) panel. An earlier call copying the
# limits of the third panel was removed: it was overwritten by this line.
axes[1].set_xlim(-0.0825, 1.0825)
# Only the left-most panel keeps its y-axis label.
for ax in axes[1:]:
    ax.set_ylabel('')