In [1]:
from pathlib import Path
import numpy as np
import pandas as pd
from scipy.stats import kruskal, gaussian_kde
from scikit_posthocs import posthoc_dunn

from mplex import Grid
import seaborn as sns
from bino_utils.plotting import add_ocular_condition_symbol
from bino_utils import load_config

In [2]:
def p_value_to_asterisks(p, prob=(5e-2, 1e-2, 1e-3)):
    """Count how many significance thresholds a p-value satisfies.

    Parameters
    ----------
    p : scalar or array-like
        P-value(s) to grade.
    prob : sequence of float
        Thresholds; every threshold with ``p <= threshold`` adds one.

    Returns
    -------
    Integer count per p-value (for a 1-D ``prob``), e.g. 0.03 -> 1 and
    5e-4 -> 3 with the default thresholds.
    """
    # Compare every p-value with every threshold, then count the hits along
    # the threshold axis.
    meets_threshold = np.less_equal.outer(p, prob)
    return meets_threshold.sum(axis=-1)

In [3]:
# Shared settings used by every figure cell below.
config = load_config()
# Colour per ocular condition, looked up with the keys listed in `conds`.
palette = config["palettes"]["ocular_conditions"]
conds = ["l", "r", "b"]

figh = 100  # base dimension used when sizing the figure grids
fmt = ".pdf"  # suffix of the saved figures; a falsy value skips saving

# Saved figures go here; create the directory on the first run.
out_dir = Path("outputs")
out_dir.mkdir(exist_ok=True)

In [None]:
data_dir = Path("../../data/behavior")

# Denominator that turns a capture count into a percentage.
# NOTE(review): presumably the number of stimulus presentations per fish,
# condition and azimuth — confirm against the experimental protocol.
trials_per_azimuth = 8

# Per-fish tables keyed by file stem. Both tables of a file are read before
# either is stored, so a file lacking one of them never leaves a half-loaded
# fish behind.
pc_by_fish = {}
session_by_fish = {}
skipped_paths = []  # files without a "pc" or "session" table, kept for inspection

for path in sorted(data_dir.glob("*.h5")):  # sorted -> reproducible row order
    try:
        pc_table = pd.read_hdf(path, "pc")
        session_table = pd.read_hdf(path, "session")
    except KeyError:
        skipped_paths.append(path)
        continue
    pc_by_fish[path.stem] = pc_table
    session_by_fish[path.stem] = session_table

if not pc_by_fish:
    raise FileNotFoundError(
        f"no .h5 file with both 'pc' and 'session' tables found in {data_dir}"
    )

# One long table of events with the fish id as a regular column.
df_pc = pd.concat(pc_by_fish, names=["fish_id", ""]).reset_index().drop("", axis=1)
# Drop rows without a laterality value, then convert radians to degrees.
df_pc = df_pc[df_pc["laterality"].notna()].copy()
df_pc["laterality"] = np.rad2deg(df_pc["laterality"])

df_session = (
    pd.concat(session_by_fish, names=["fish_id", ""]).reset_index().drop("", axis=1)
)
# Keep only fish with exactly six counted sessions.
session_count = df_session.groupby("fish_id").count()["session"]
complete_fish = session_count[session_count.eq(6)].index.values

df_pc = df_pc[df_pc["fish_id"].isin(complete_fish)]

# Mask out condition "l" at azimuth >= 20 and condition "r" at azimuth <= -20.
# NOTE(review): presumably the stimulus is out of view there — confirm.
hidden_l = df_pc["condition"].eq("l") & df_pc["azimuth_int"].ge(20)
hidden_r = df_pc["condition"].eq("r") & df_pc["azimuth_int"].le(-20)
visible = ~(hidden_l | hidden_r)

# Count events per fish / condition / azimuth. The unstack-stack round trip
# inserts explicit zero counts for combinations that never occurred.
df_stats = (
    df_pc.groupby(["fish_id", "condition", "azimuth_int"])
    .count()
    .unstack(fill_value=0)
    .unstack(fill_value=0)
    .stack()
    .stack()["session"]
    .to_frame(name="count")
    .reset_index()
)
df_stats["percentage"] = df_stats["count"] / trials_per_azimuth * 100
# Azimuth 0 is halved.
# NOTE(review): presumably it is presented twice as often — confirm.
df_stats.loc[df_stats["azimuth_int"].eq(0), "percentage"] /= 2

In [None]:
# One bar-chart panel per ocular condition, stacked vertically. The panel
# height is chosen so that n_rows panels plus the gaps between them span figh.
space = 10
n_rows = len(conds)
panel_height = (figh - space * (n_rows - 1)) / n_rows
g = Grid((figh, panel_height), (n_rows, 1), space=space, spines="lb")

for row, cond in enumerate(conds):
    ax = g[row, 0]
    cond_stats = df_stats[df_stats["condition"].eq(cond)]
    sns.barplot(
        cond_stats,
        x="azimuth_int",
        y="percentage",
        color=palette[cond],
        width=0.9,
        capsize=0.3,
        ax=ax,
    )
    # Axis labels are set once further down instead of on every panel.
    ax.set_xlabel(None)
    ax.set_ylabel(None)
    # Bars are addressed by integer positions 0..10 on the x axis.
    ax.set_xlim(-0.6, 10.6)
    ax.spines["bottom"].set_bounds(-0.45, 10.45)

g[1, 0].set_ylabel("Prey capture rate (%)")

# `ax` still refers to the panel from the last loop iteration (bottom row).
bottom_ax = ax
bottom_ax.set_ylim([0, 60])
bottom_ax.set_yticks([0, 60])
# Major ticks on every other bar position, minor ticks on all of them.
bottom_ax.set_xticks([1, 3, 5, 7, 9])
bottom_ax.set_xticks(np.arange(0, 11), minor=True)
bottom_ax.set_xticklabels(["$-40$", "$-20$", "$0$", "$20$", "$40$"])
bottom_ax.set_xlabel("Stimulus azimuth (°)")

# Mark each panel with the symbol of its ocular condition.
for row, cond in enumerate("lrb"):
    add_ocular_condition_symbol(
        cond, 0.5, 1, va="center", ha="center", scale=0.6, ax=g[row, 0]
    )

if fmt:
    fig_path = (out_dir / "prey_capture_rate_vs_stimulus_azimuth").with_suffix(fmt)
    g.savefig(fig_path)

In [6]:
# Per azimuth: Kruskal-Wallis test of laterality across the visible
# conditions, followed (when significant) by Dunn's post-hoc comparison of
# "l" and "r" against "b". Significant post-hoc p-values are collected in
# `results`, indexed by (condition, azimuth).
alpha = 0.05
significant = {}
for azimuth, df_azimuth in df_pc[visible].groupby("azimuth_int"):
    samples = [values for _, values in df_azimuth.groupby("condition")["laterality"]]
    pval = kruskal(*samples).pvalue
    if pval < alpha:
        # reindex instead of label lookup: the `visible` mask removes one
        # condition at some azimuths, and indexing with a missing label
        # raises KeyError. Missing entries become NaN and are dropped by the
        # `< alpha` filter below.
        pvals = posthoc_dunn(df_azimuth, "laterality", "condition")["b"].reindex(
            ["l", "r"]
        )
        for condition, p in pvals[pvals < alpha].items():
            significant[condition, azimuth] = p

# Build the index explicitly so that the two level names are valid even when
# nothing reached significance (an empty dict would give a flat, one-level
# index and the name assignment would fail).
results = pd.Series(
    list(significant.values()),
    index=pd.MultiIndex.from_tuples(
        list(significant), names=["condition", "azimuth"]
    ),
    dtype=float,
)

In [None]:
# Median bout angle versus stimulus azimuth, one line per ocular condition,
# with asterisks marking the significant post-hoc results.
g = Grid((105, 105), spines="lb")
ax = g.item()
ax.set_xlabel("Stimulus azimuth (°)")
ax.set_ylabel("Bout angle (°)")

# Draw on the grid's axes explicitly rather than on whatever the current
# axes happen to be (a stale figure from an earlier cell would otherwise
# receive the plot).
ax = sns.lineplot(
    data=df_pc[visible],
    x="azimuth_int",
    y="laterality",
    hue="condition",
    hue_order=conds,
    estimator="median",
    errorbar=("ci", 95),
    err_kws=dict(capsize=1, capthick=0.5),
    err_style="bars",
    marker="o",
    markersize=2,
    mec="none",
    ax=ax,
)
ax.get_legend().set_visible(False)
vmax = 25
ax.set_xmargin(4 / figh)
ax.set_xticks(np.arange(-40, 60, 20))
ax.set_xticks(np.arange(-50, 60, 10), minor=True)
ax.spines["bottom"].set_bounds(-50, 50)

ax.set_ylim(-vmax, vmax)
ax.set_yticks(np.arange(-20, 30, 10))
ax.set_yticks(np.arange(-25, 30, 5), minor=True)
ax.spines["left"].set_bounds(-vmax, vmax)

for (c, a), pval in results.items():
    print(c, a, pval)
    s = "*" * p_value_to_asterisks(pval)
    # lines[1] of the error-bar container holds its cap lines; entry 1 is
    # used for "r" and entry 0 otherwise.
    # NOTE(review): presumably upper vs. lower cap — confirm.
    line = ax.containers[conds.index(c)].lines[1][int(c == "r")]
    # Cap position at this azimuth, shifted down by 1.2 % of the y range.
    y = line.get_ydata()[line.get_xdata() == a].item() - 0.012 * vmax * 2
    ax.add_text(a, y, s, ha="c", va="b" if c == "r" else "t", color=palette[c])

# Condition symbols stacked in the top-left corner, on the same axes.
for ic, c in enumerate(conds):
    add_ocular_condition_symbol(
        c, 4 / figh, 1 - ic * 0.075, scale=0.6, va="t", ha="l", ax=ax
    )

if fmt:
    g.savefig((out_dir / "bout_angle_vs_stimulus_azimuth").with_suffix(fmt))

In [8]:
# Kernel density estimates of the bout angle at azimuth 0, per condition and
# for two combinations of the "l" and "r" samples.
step = 1e-2
x_min, x_max = -30, 30
# linspace pins both endpoints. np.arange with a float step may or may not
# emit an extra sample past x_max, depending on floating-point rounding.
x = np.linspace(x_min, x_max, int(round((x_max - x_min) / step)) + 1)

lat = df_pc.loc[df_pc["azimuth_int"].eq(0)].set_index("condition")["laterality"]
bw = 0.2  # scalar bw_method passed to every gaussian_kde below

density = {c: gaussian_kde(lat.loc[c], bw)(x) for c in conds}
# "averaging": density of the mean of every (l, r) pair of observations.
pairwise_mean = np.add.outer(lat["l"].values, lat["r"].values).ravel() / 2
density["averaging"] = gaussian_kde(pairwise_mean, bw)(x)
# "wta": equal-weight mixture of the "l" and "r" densities.
density["wta"] = (density["l"] + density["r"]) / 2

# Long-format table with columns: condition, laterality, density.
df_kde = {k: pd.DataFrame(dict(laterality=x, density=v)) for k, v in density.items()}
df_kde = pd.concat(df_kde, names=["condition", ""]).reset_index().drop("", axis=1)

In [None]:
# 3 x 2 grid of density panels: the left column shows the three ocular
# conditions, the lower two panels of the right column show the "wta" and
# "averaging" curves with the "b" density overlaid.
space = 10
n_cols = 2
n_rows = 3
# Panel size such that the panels plus the gaps between them span figh.
# Each gap is `space` wide, so it has to be scaled by `space` — the same
# formula the capture-rate figure uses for its panel height.
size = (
    (figh - space * (n_cols - 1)) / n_cols,
    (figh - space * (n_rows - 1)) / n_rows,
)
g = Grid(size, (n_rows, n_cols), space=space, spines="lb")
g[0, 1].set_visible(False)  # top-right slot stays empty
axs = np.concatenate([g.axs[:, 0], g.axs[1:, 1]])
g.set_ylim(0, df_kde["density"].max())
g.set_xticks([-30, 0, 30])

# Line colour per curve; the two derived curves are drawn in black.
colors = dict(palette, wta="k", averaging="k")

for ax, c in zip(axs, [*conds, "wta", "averaging"]):
    df = df_kde[df_kde["condition"].eq(c)]
    xc = df["laterality"]
    yc = df["density"]
    color = colors[c]
    ax.plot(xc, yc, color=color, lw=0.75)
    ax.fill_between(xc, yc * 0, yc, color=color, alpha=0.3)

    if c in conds:
        add_ocular_condition_symbol(c, 1, 1, scale=0.6, ha="r", va="t", ax=ax)

# Overlay the "b" density behind the two derived curves for comparison.
for ax, label in zip(axs[[-2, -1]], ("WTA", "Avg")):
    ax.plot(
        x,
        density["b"],
        color="C2",
        zorder=-1,
        alpha=0.7,
        ls="--",
        lw=0.75,
        clip_on=False,
    )
    ax.fill_between(
        x, x * 0, density["b"], color="C2", zorder=-1, alpha=0.3 * 0.7, lw=0
    )
    ax.text(1, 1, label, ha="right", va="top", fontsize=7, transform=ax.transAxes)

g[1, 0].set_ylabel("Probability density")
g.fig.supxlabel("Bout angle (°)", y=-0.2)

if fmt:
    g.savefig((out_dir / "bout_angle_distributions").with_suffix(fmt))