In [5]:
from pathlib import Path
import numpy as np
import pandas as pd
from mplex.axes_collection import AxArray
from bg_space_extra import AnatomicalPoints
from bino_utils.atlas import Atlas
from bino_utils.plotting import BrainGrid, add_ocular_condition_symbol

In [6]:
# Input data location (relative to this notebook).
data_dir = Path("../../data/lensectomy")

# Output folder for figures; created on first run, left alone afterwards.
out_dir = Path("outputs")
out_dir.mkdir(exist_ok=True)

# File suffix used when saving figures; a falsy value skips the final savefig.
fmt = ".pdf"

atlas = Atlas()

In [7]:
# Load the per-neuron tables used by the plotting cell below.

# R² per neuron from the regression file: select the "r2" entries of the
# swapped column MultiIndex and move the index into columns (keeping fish_id).
# NOTE(review): this frame used to be assigned to `df` and was then silently
# overwritten by the p-value table read below, so nothing downstream used it
# (including the "lens" column attached to it). It now has its own name so
# the overwrite is explicit; confirm whether it is still needed.
df_r2 = (
    pd.read_hdf(data_dir / "regression.h5")
    .swaplevel(axis=1)["r2"]
    .reset_index()
    .drop("local_id", axis=1)
)
points = AnatomicalPoints("sal", pd.read_hdf(data_dir / "points.h5"))
df_lens = pd.read_hdf(data_dir / "conditions.h5")
df_r2["lens"] = df_lens.loc[df_r2["fish_id"], "lens"].values

# Table that the plotting cell actually reads from.
data_path = data_dir / "data.h5"
df = pd.read_hdf(data_path, "pval")
# Attach the ocular condition to the frame used downstream; previously it was
# only added to the discarded regression frame.
if "lens" not in df.columns:
    df["lens"] = df_lens.loc[df["fish_id"], "lens"].values

# Fail here with a clear message rather than deep inside the plotting loop.
missing_cols = {"fish_id", "lens", "pval", "b"} - set(df.columns)
assert not missing_cols, f"data.h5['pval'] is missing columns: {sorted(missing_cols)}"

bounds = dict(a=220, p=560, i=234)
# Same location as "../../data/ocular_types/regions.h5", derived from data_dir.
df_regions = pd.read_hdf(data_dir.parent / "ocular_types" / "regions.h5")

In [None]:
# One brain map per ocular condition: significant neurons coloured by their
# "b" value, plus a horizontal bar chart of significant neurons per fish in
# each region of df_regions.
conditions = ["le", "re", "sham", "null"]
vmin = 0.3  # colour-scale limits for the plotted "b" values (colorbar labelled R²)
vmax = 0.9
alpha = 0.025  # p-value threshold: only neurons with pval <= alpha are shown/counted
cmap = "magma_r"

bg = BrainGrid(
    atlas,
    # Round up so an odd number of conditions still gets one panel each
    # (floor division would give fewer panels than conditions).
    (2, (len(conditions) + 1) // 2),
    w=130,
    bounds=bounds,
    space_within=5,
    space_across=20,
)
bg.add_structures("root", n=1, sigma=5, alpha=0.05, fc="k")

first_bar_ax = None

for j, condition in enumerate(conditions):
    bp = bg.brain_plots.ravel()[j]

    in_condition = df["lens"].eq(condition)
    sel = in_condition & df["pval"].le(alpha)

    points_ = points[sel]
    values = df.loc[sel, "b"].values

    bp.scatter_values(
        points_,
        values,
        mode="csort",
        cmap=cmap,
        vmin=vmin,
        vmax=vmax,
        rasterized=True,
    )
    ax_map = AxArray(bp.axes_dict.values()).make_ax()

    counts = np.array(
        [atlas.is_points_in_structures(points_, i).sum() for i in df_regions.index]
    )
    # Normalise by every fish recorded in this condition, not only the fish
    # that contributed a significant neuron: fish with zero significant
    # neurons still count towards "neurons per fish", and an empty selection
    # must not divide by zero.
    n_fish = df.loc[in_condition, "fish_id"].nunique()
    density = counts / n_fish if n_fish else np.zeros(len(counts))

    add_ocular_condition_symbol(condition, ax=ax_map, scale=0.8)

    ax_bar = AxArray(bp.empty).add_axes((35, 35), loc0="lt", loc1="lt", pad=(5, 0))

    # All bar charts share one y axis (same region order in every panel).
    if first_bar_ax is None:
        first_bar_ax = ax_bar
    else:
        ax_bar.sharey(first_bar_ax)
    ax_bar.barh(range(len(df_regions)), density, color=df_regions["color"])

    ax_bar.set_yticks(np.arange(len(df_regions)), labels=df_regions["acronym"])
    ax_bar.tick_params("y", length=0, pad=1)
    ax_bar.set_xmargin(0)

    for side in ("right", "top", "left"):
        ax_bar.spines[side].set_visible(False)

    ax_bar.set_xlabel("Neurons per fish", size=6, labelpad=0)
    # invert_yaxis() toggles the orientation, and with sharey the toggle
    # propagates to every shared axis. Calling it once per panel would flip
    # the order back and forth (with four panels it ends up not inverted).
    # Invert only when needed, so the first region is drawn at the top.
    if not ax_bar.yaxis_inverted():
        ax_bar.invert_yaxis()

cb = bg.grid.add_colorbar(
    vmin, vmax, cmap, loc0="rb", loc1="lb", length=20, thick=4, pad=2
)
cb.set_ticks([vmin, vmax], labels=[f"{v:g}" for v in (vmin, vmax)])
cb.ax.set_title("  R$^2$", pad=1, size=7)

if fmt:
    bg.grid.savefig((out_dir / "maps").with_suffix(fmt))