In [1]:
from pathlib import Path
import colorcet as cc
import numpy as np
import pandas as pd
from matplotlib.colors import to_hex
from bg_space_extra import AnatomicalPoints
from mplex.axes_collection import AxArray
from bino_utils import Atlas
from bino_utils.plotting import BrainGrid, add_ocular_condition_symbol

In [2]:
def get_trans_cmap(c, n=256):
    """Build a single-colour colormap whose opacity ramps from 0 to 1.

    Parameters
    ----------
    c : colour spec
        Any colour accepted by ``matplotlib.colors.to_rgb``; every entry of the
        colormap uses this RGB value.
    n : int
        Number of entries in the colormap.

    Returns
    -------
    matplotlib.colors.ListedColormap
        ``n`` RGBA entries with alpha linearly spaced from 0 (first) to 1 (last).
    """
    from matplotlib.colors import ListedColormap, to_rgb

    rgba = np.empty((n, 4))
    rgba[:, :3] = to_rgb(c)  # same RGB broadcast to every row
    rgba[:, 3] = np.linspace(0, 1, n)  # transparent -> opaque
    return ListedColormap(rgba)

In [3]:
# Shared configuration used by the cells below.
atlas = Atlas()
data_dir = Path("../../data/ocular_types")  # input tables (*.h5)
out_dir = Path("outputs")  # figure destination; created if missing
out_dir.mkdir(exist_ok=True)
fmt = ".pdf"  # figure file suffix; a falsy value disables saving

In [4]:
regions_path = data_dir / "regions.h5"


def _projection_centroid(projs, collapse_axis):
    """Return the intensity-weighted centroid of each projection in ``projs``.

    Assumes ``projs`` has shape (n_regions, H, W), i.e. one 2-D projection per
    region. ``collapse_axis`` (1 or 2) is summed out; the result is the
    centroid, in pixel-index units, along the remaining spatial axis, one
    value per region. A region whose projection sums to zero yields NaN.
    """
    profile = projs.sum(collapse_axis)
    profile = profile / profile.sum(1, keepdims=True)  # normalise per region
    return profile @ np.arange(profile.shape[1])


if regions_path.exists():
    # Cached table: delete regions.h5 to recompute after changing the code below.
    df_regions = pd.read_hdf(regions_path)
else:
    regions = np.array(
        [
            "tectum",
            "pretectum (alar prosomere 1)",
            "thalamus proper",
            "nucleus isthmi",
            "tegmentum (midbrain tegmentum)",
        ]
    )

    projs = np.array([atlas.get_masks(i, views="left", method="mean") for i in regions])
    # NOTE(review): this was also computed with method="mean", which made
    # max_projs identical to projs and zm/xm exact copies of z/x. "max" matches
    # the variable name; confirm Atlas.get_masks supports it.
    max_projs = np.array(
        [atlas.get_masks(i, views="left", method="max") for i in regions]
    )

    df_regions = pd.DataFrame(
        dict(
            z=_projection_centroid(projs, 1),
            x=_projection_centroid(projs, 2),
            zm=_projection_centroid(max_projs, 1),
            xm=_projection_centroid(max_projs, 2),
            acronym=["OT", "Pt", "Thal", "NI", "Teg"],
            color=[to_hex(i) for i in np.array(cc.glasbey_hv)[[0, 1, 2, 3, 6]]],
            n_voxels=[np.array(atlas.get_structures_stack(i)).sum() for i in regions],
        ),
        index=regions,
    )

    df_regions.to_hdf(regions_path, key="data")

In [5]:
# Per-condition p-values and R^2 tables, plus the coordinates they refer to.
# Rows of these tables are indexed with the same boolean masks as `points`
# in the plotting cell.
df_pval = pd.read_hdf(data_dir.joinpath("pval.h5"))
df_r2 = pd.read_hdf(data_dir.joinpath("r2.h5"))
# NOTE(review): "sal" is presumably the axis-order/orientation string for the
# point coordinates; confirm against bg_space_extra.
points = AnatomicalPoints("sal", pd.read_hdf(data_dir.joinpath("points.h5")))

In [6]:
# Invalidate the "I" p-value for points within 10 units of the atlas midline
# whose "a" coordinate is below 560.
# NOTE(review): the rationale for these thresholds is not stated here;
# presumably it removes midline artefacts from the ipsilateral fit. Confirm.
below_a_cutoff = points["a"] < 560
near_midline = np.abs(points["l"] - atlas.midline) < 10
df_pval.loc[below_a_cutoff & near_midline, "I"] = np.nan

In [7]:
# Axis limits passed to BrainGrid (units presumably atlas voxels; confirm).
bounds = {"a": 220, "p": 560, "i": 234}
# Significance threshold: a point counts as responsive when its p-value < alpha.
alpha = 0.025

In [8]:
# Number of distinct fish, used to express counts as "neurons per fish".
n_fish = len(df_r2["fish_id"].unique())
assert n_fish > 0, "df_r2 contains no fish_id values"

regions = df_regions.index.values
count_conditions = "lrb0BIC"  # one row of df_counts per condition column

# Region membership does not depend on the condition, so evaluate it once per
# region instead of once per (condition, region) pair.
region_masks = [atlas.is_points_in_structures(points, r) for r in regions]
is_significant = df_pval[list(count_conditions)].lt(alpha)

# Significant points inside each region, averaged over fish.
df_counts = pd.DataFrame(
    [
        [(mask & is_significant[c]).sum() / n_fish for mask in region_masks]
        for c in count_conditions
    ],
    columns=regions,
    index=list(count_conditions),
)

# Same counts normalised by region size: neurons per fish per 10^6 voxels.
df_density = df_counts / (df_regions["n_voxels"] / 1_000_000)

In [9]:
# Colour scale shared by the R^2 scatter plots and their colorbar.
cmap, vmin, vmax = "magma_r", 0.3, 0.9

In [None]:
def _draw_condition_panel(bp, condition):
    """Fill one brain-plot panel for a single ocular condition.

    The panel gets three elements:
    - a scatter of the R^2 values of points whose p-value for ``condition`` is
      below ``alpha``, coloured with ``cmap``/``vmin``/``vmax``;
    - the condition symbol;
    - an inset horizontal bar chart of ``df_counts`` (neurons per fish) for each
      region in ``df_regions``.
    """
    significant = df_pval[condition].lt(alpha)
    bp.scatter_values(
        points[significant],
        df_r2.loc[significant, condition].values,
        mode="csort",
        cmap=cmap,
        vmin=vmin,
        vmax=vmax,
        rasterized=True,
    )
    symbol_ax = AxArray(bp.axes_dict.values()).make_ax()
    add_ocular_condition_symbol(condition, ax=symbol_ax, scale=0.8)

    # Inset bar chart placed in the panel's empty area.
    bar_ax = AxArray(bp.empty).add_axes((35, 35), loc0="lt", loc1="lt", pad=(5, 0))
    bar_ax.barh(
        range(len(df_regions)), df_counts.loc[condition], color=df_regions["color"]
    )
    bar_ax.set_yticks(np.arange(len(df_regions)), labels=df_regions["acronym"])
    bar_ax.tick_params("y", length=0, pad=1)
    bar_ax.set_xmargin(0)
    for spine_name in ("right", "top", "left"):
        bar_ax.spines[spine_name].set_visible(False)
    bar_ax.set_xlabel("Neurons per fish", size=6, labelpad=0)
    bar_ax.invert_yaxis()  # first region at the top


# One figure per group of conditions, one panel per condition in the group.
for conditions in ("lrb", "CB", "I"):
    bg = BrainGrid(
        atlas,
        (1, len(conditions)),
        w=130,
        bounds=bounds,
        space_within=5,
        space_across=20,
    )
    bg.add_structures("root", n=1, sigma=5, alpha=0.05, fc="k")

    panels = bg.brain_plots.ravel()
    for j, condition in enumerate(conditions):
        _draw_condition_panel(panels[j], condition)

    cb = bg.grid.add_colorbar(
        vmin, vmax, cmap, loc0="rb", loc1="lb", length=20, thick=4, pad=2
    )
    cb.set_ticks([vmin, vmax], labels=[vmin, vmax])
    cb.ax.set_title("  R$^2$", pad=1, size=7)

    if fmt:
        bg.grid.savefig(out_dir.joinpath(conditions.lower()).with_suffix(fmt))

In [None]:
# Per-point responsiveness flags. "prn" is True when any of the l/r/b
# p-values is significant; the table ends with per-fish counts.
is_significant = df_pval.lt(alpha)
is_responsive = (
    is_significant.assign(
        fish_id=df_r2["fish_id"],
        prn=is_significant[["l", "r", "b"]].max(axis=1),
    )[["fish_id", "I", "C", "B", "prn"]]
    .rename(columns={"I": "ipsi", "C": "contra", "B": "bino"})
)
is_responsive.groupby("fish_id").sum()

In [None]:
# Across-fish spread (pandas default sample std, ddof=1) of responsive counts.
per_fish_counts = is_responsive.groupby("fish_id").sum()
per_fish_counts.std(axis=0)