In [None]:
import matplotlib.pyplot as plt
import pingouin as pg
import seaborn as sns
from drn_interactions.config import Config
from drn_interactions.io import load_derived_generic
import pandas as pd
import numpy as np
from scipy.stats import wilcoxon, mannwhitneyu
from drn_interactions.stats import mannwhitneyu_plusplus
from drn_interactions.plots.pallets_cmaps import PAL_GREY_BLACK
sns.set_theme(style="ticks", context="paper")

In [None]:
def load_responders():
    bs_response = load_derived_generic("brain_states_spikerate_responders.csv")[
        ["neuron_id", "Diff", "sig"]
    ].assign(
        response_bs=lambda x: np.where(
            x["sig"] == False,
            "Not State-Responsive",
            np.where(x["Diff"] < 0, "Inactivated-Preferring", "Activated-Preferring"),
        )
    )[["neuron_id", "response_bs"]]
    phase_lock_response  = (
        load_derived_generic("brain_states_phase_responders.csv")
        .pivot(index="neuron_id", columns="oscillation", values="p")
        [["delta", "theta"]]
        .transform(lambda x: x < 0.5)
        .rename(columns={"delta": "phase_lock_delta", "theta": "phase_lock_theta"})
        .reset_index()
        [["neuron_id", "phase_lock_delta", "phase_lock_theta"]]
    )
    df_responders = pd.merge(bs_response, phase_lock_response, on="neuron_id", how="outer")
    return df_responders

In [None]:
# encoding
# Per-neuron encoder R^2 for the "state" and "comb" (state + interactions) models,
# each paired with its shuffled control, stacked into long format for plotting.


def _load_encoding_scores(model: str, label: str) -> pd.DataFrame:
    """Load real and shuffled encoder scores for one model, one row per neuron.

    Reads ``encoding/brain state - {model}.csv`` and
    ``encoding/brain state - {model} shuffle.csv``, renames the score column
    ``model`` to ``label`` / ``"{label}\\nShuffle"`` and joins them on
    ``neuron_id`` + ``session_name``.
    """
    real = load_derived_generic(f"encoding/brain state - {model}.csv").rename(
        columns={model: label}
    )
    shuffle = load_derived_generic(f"encoding/brain state - {model} shuffle.csv").rename(
        columns={model: f"{label}\nShuffle"}
    )
    return real.merge(shuffle, on=["neuron_id", "session_name"])


comb = _load_encoding_scores("comb", "State + Interactions")
state = _load_encoding_scores("state", "State")

state_pop = (
    pd.concat(
        [
            state.melt(id_vars=["neuron_id", "session_name"]),
            comb.melt(id_vars=["neuron_id", "session_name"]),
        ],
        # fresh RangeIndex; avoids the stray "index" column reset_index() used to add
        ignore_index=True,
    )
    .merge(
        load_derived_generic("neuron_types.csv")[["neuron_id", "neuron_type"]],
        on="neuron_id",
    )
    .merge(load_responders(), on="neuron_id", how="left")
)

limit = load_derived_generic("encoding/brain state - limit.csv")
dropout = load_derived_generic("encoding/brain state - dropout.csv")

fig_dir = Config.fig_dir

In [None]:
# Paired box plot of encoder R^2 for both models and their shuffled controls.
box_order = ["State\nShuffle", "State", "State + Interactions", "State + Interactions\nShuffle"]

f = plt.figure(figsize=(3, 1.5))
ax_box = f.subplots()

# Fix the y-range before plotting so autoscaling does not override it.
ax_box.set_ylim(-0.4, 1)
ax_box.set_yticks([-0.25, 0, 0.25, 0.5, 0.75, 1])
pg.plot_paired(
    data=state_pop,
    dv="value",
    subject="neuron_id",
    within="variable",
    boxplot_in_front=True,
    pointplot_kwargs=dict(alpha=0.05),
    boxplot_kwargs=dict(width=0.3),
    ax=ax_box,
    order=box_order,
)

# Rotate the existing labels in place; re-setting them from get_xticklabels()
# before a draw can yield empty labels and freezes them into a FixedFormatter.
ax_box.tick_params(axis="x", labelrotation=45)
ax_box.set_ylabel("Encoder\nPerformance\n($R^{2}$)")
ax_box.set_xlabel("")
# Reference lines at chance (0) and perfect (1) performance.
for y in (0, 1):
    ax_box.axhline(y, color="grey", linewidth=0.5, linestyle="--")

f.savefig(fig_dir / "bs - encoding_performance_box.png", bbox_inches="tight", dpi=300)

In [None]:
def plot_paired_neuron_type(df, ax):
    """Draw a paired State vs State + Interactions encoder-performance plot on ``ax``.

    Only rows whose ``variable`` is "State" or "State + Interactions" are used;
    ``pingouin.plot_paired`` links each ``neuron_id``'s two ``value`` scores.
    Dashed grey reference lines are drawn at R^2 = 0 and R^2 = 1.

    Args:
        df: long-format frame with ``neuron_id``, ``variable`` and ``value`` columns.
        ax: matplotlib Axes to draw on.
    """
    models = ["State", "State + Interactions"]

    # Axis range is fixed up front so the plot call does not autoscale it.
    ax.set_ylim(-0.35, 1)
    ax.set_yticks([-0.25, 0, 0.25, 0.5, 0.75, 1])

    paired = df[df["variable"].isin(models)]
    pg.plot_paired(
        data=paired,
        dv="value",
        subject="neuron_id",
        within="variable",
        boxplot_in_front=True,
        pointplot_kwargs=dict(alpha=0.2),
        boxplot_kwargs=dict(width=0.3),
        ax=ax,
        order=models,
    )

    ax.set_xticklabels(["State", "State +\nInteractions"], rotation=45)
    ax.set_ylabel("Encoder\nPerformance\n($R^{2}$)")
    ax.set_xlabel("")
    for level in (0, 1):
        ax.axhline(level, color="grey", linewidth=0.5, linestyle="--")



# One paired panel per neuron type: State vs State + Interactions encoder R^2.
f = plt.figure(figsize=(4.4, 2.2))
axes = f.subplots(1, 3, sharey=True)
for neuron_type, ax in zip(["SR", "SIR", "FF"], axes):
    plot_paired_neuron_type(state_pop.query("neuron_type == @neuron_type"), ax=ax)
    ax.set_title(neuron_type, pad=22)

# tight_layout() recomputes the inter-panel spacing (wspace) itself, so a
# preceding subplots_adjust(wspace=...) had no effect; tune spacing via
# tight_layout(w_pad=...) if more room is needed.
f.tight_layout()
f.savefig(fig_dir / "bs - encoding_performance_neuron_type.png", bbox_inches="tight", dpi=300)

In [None]:
# Encoder performance as a function of the number of neighbouring neurons
# included (n_best), long format with one row per neuron / n_best / model.
dfp = (
    pd.merge(
        limit
        .rename(columns=dict(score="State + Interactions"))
        [["neuron_id", "n_best", "State + Interactions"]],
        state[["neuron_id", "State"]],
        on="neuron_id",
    )
    .melt(
        id_vars=["neuron_id", "n_best"],
        value_vars=["State", "State + Interactions"],
        var_name="Metric",
        value_name="Score",
    )
    # only plot up to 15 neighbouring neurons
    .loc[lambda x: x.n_best <= 15]
    # Add an n_best == 0 point carrying the state-only score so the curve starts
    # from the no-neighbours baseline. pd.concat replaces DataFrame.append,
    # which was removed in pandas 2.0.
    .pipe(
        lambda x: pd.concat(
            [x, x.query("n_best == 1 and Metric == 'State'").assign(n_best=0)]
        )
    )
    .merge(
        load_derived_generic("neuron_types.csv")[["neuron_id", "neuron_type"]],
        on="neuron_id",
    )
    .merge(load_responders(), on="neuron_id", how="left")
)

In [None]:
# Mean encoder R^2 vs number of neighbouring neurons, one line per model.
f = plt.figure(figsize=(5, 1.8))
ax_line = f.subplots()
ax_line.set_ylim(0, 0.7)

# pointplot places n_best categorically at positions 0..k-1 (and resets the
# x-limits itself), so pin the order and place ticks by category position.
n_best_levels = sorted(dfp["n_best"].unique())
sns.pointplot(
    data=dfp,
    x="n_best",
    y="Score",
    order=n_best_levels,
    ax=ax_line,
    hue="Metric",
    palette=PAL_GREY_BLACK[::-1],
)

tick_values = [v for v in (0, 5, 10, 15) if v in n_best_levels]
ax_line.set_xticks([n_best_levels.index(v) for v in tick_values])
ax_line.set_xticklabels(tick_values)
ax_line.set_ylabel("Encoder Performance\n($R^{2}$)")
ax_line.set_xlabel("Number of Neighboring Neurons")

sns.despine(ax=ax_line, offset=5)

sns.move_legend(
    ax_line, "lower center",
    bbox_to_anchor=(.5, 1), ncol=3, title="Encoding Model", frameon=False,
)
f.savefig(fig_dir / "bs - encoding_by_model_line.png", bbox_inches="tight", dpi=300)

In [None]:
# State + Interactions encoder R^2 vs number of neighbouring neurons, split by
# each neuron's brain-state response profile.
comb_scores = dfp.query("Metric == 'State + Interactions'")

f = plt.figure(figsize=(5, 1.8))
ax_line = f.subplots()
ax_line.set_ylim(-0.3, 0.8)

# n_best is plotted categorically; pin the order so tick positions map to
# known n_best values, then label them explicitly (set_xticks alone can pick
# up the wrong labels from a FixedFormatter).
n_best_levels = sorted(comb_scores["n_best"].unique())
sns.pointplot(
    data=comb_scores,
    x="n_best",
    y="Score",
    order=n_best_levels,
    ax=ax_line,
    hue="response_bs",
    palette="Set2",
    dodge=0.35,
    hue_order=['Not State-Responsive', 'Inactivated-Preferring', 'Activated-Preferring'],
)

tick_values = [v for v in (1, 5, 10, 15) if v in n_best_levels]
ax_line.set_xticks([n_best_levels.index(v) for v in tick_values])
ax_line.set_xticklabels(tick_values)
ax_line.set_ylabel("Encoder\nPerformance\n($R^{2}$)")
ax_line.set_xlabel("Number of Neighboring Neurons")
ax_line.axhline(0, color="grey", linewidth=0.5, linestyle="--")

sns.despine(ax=ax_line, offset=5)
sns.move_legend(
    ax_line, "lower center",
    bbox_to_anchor=(.5, 1), ncol=3, title="Brain State Response Profile", frameon=False,
)
f.savefig(fig_dir / "bs - encoding_by_response_type_line.png", bbox_inches="tight", dpi=300)