In [None]:
# get responsivity of all neurons with an adjustable window
# assess for neuron-type level effects
# plot examples
# individual trial variability

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import pingouin as pg
from IPython.display import display
from scipy.stats import variation
from spiketimes.df.surrogates import shuffled_isi_spiketrains_by

# from drn_interactions.spikes import SpikesHandler
# from drn_interactions.transforms import align_to_data_by
# from drn_interactions.stats import mannwhitneyu_plusplus
from drn_interactions.load import (
    get_fig_dir, load_events, load_spikes, load_neurons_derived
)
from drn_interactions.fs_fast import ShortTsAnova, ShortTsAvg, ShockPlotter
from drn_interactions.shock_transforms import ShockUtils
from drn_interactions.stats import cv2


%load_ext autoreload
%autoreload 2

In [None]:
sns.set_theme(style="ticks", context="poster")

# Figures from this notebook are written under <fig dir>/base_shock.
outdir = get_fig_dir() / "base_shock"
outdir.mkdir(exist_ok=True, parents=True)

neurons = load_neurons_derived()
neurons_sub = neurons[["neuron_id", "session_name"]]
df_events = load_events("base_shock")

# Sessions that have both neurons and "base_shock" events, in order of first
# appearance in `neurons_sub`. A membership test gives the same sessions as an
# inner merge on session_name without materialising one row per
# (neuron, event) pair.
sessions = neurons_sub.loc[
    neurons_sub["session_name"].isin(df_events["session_name"]), "session_name"
].unique()

df_spikes = load_spikes("base_shock").merge(neurons_sub)

# Waveform-cluster label ("wf_3") per neuron; shared by the analysis cells below.
clusters = neurons[["neuron_id", "wf_3"]]


In [None]:
# Anova level
# Post-shock response window (s). It is shared by the cluster-level ANOVA and
# the per-neuron test so that both analyse the same interval.
response_window = (0.05, 0.2)

transformer = ShockUtils()
df_aligned = transformer.aligned_binned_from_spikes(
    df_spikes,
    df_events,
    session=None,
    bin_width=0.01,
)
anova, contrasts = ShortTsAnova(window=response_window).get_responders(
    df_aligned, z=True, clusters=clusters,
)
display(anova)
display(contrasts)

# Neuron Level
unit_mod = ShortTsAvg(window=response_window)
responders = unit_mod.get_responders(df_aligned, z=True)
unit_mod.plot_responders(
    responders,
    clusters=clusters,
    bins=np.arange(-2.6, 2.6, 0.2),
)

sns.despine()
# Fraction of neurons whose `p` value is below 0.05.
print((responders["p"] < 0.05).mean())

figs = ShockPlotter().psth_heatmap_by_cluster(df_aligned, responders, clusters)

In [None]:
# Spike times aligned to each event, for raster plots.
df_spikes_aligned = ShockUtils().align_spikes(df_spikes, df_events)
# Per-neuron responder statistics joined with waveform-cluster labels.
responders1 = responders.reset_index().merge(clusters)

# Candidate example units: lowest Diff, highest Diff, lowest Diff within "ff".
display(responders1.sort_values(["Diff"]).head())
display(responders1.sort_values(["Diff"], ascending=False).head())
display(responders1.sort_values(["Diff"]).loc[lambda x: x.wf_3 == "ff"].head())

# Hand-picked example neuron IDs -> panel title (presumably chosen from the
# tables above). Dict order fixes the plotting order.
example_rasters = {1843: "SR", 1974: "SIR", 1897: "FF"}
for example_id, example_label in example_rasters.items():
    ax = ShockPlotter().unit_raster_across_trials(df_spikes_aligned, neuron=example_id)
    ax.set_title(example_label)

In [None]:
# One row per (neuron, trial): the sum of that neuron's binned values inside the
# 0.05-0.2 s window. Series.between is inclusive at both ends.
# NOTE(review): if `bin` labels are left edges, the 0.2 bin reaches 0.21 s, and
# float bin labels may miss an exact 0.2 match. Confirm the intended edges.
evoked_counts = (
    df_aligned[df_aligned["bin"].between(0.05, 0.2)]
    .drop(columns=["bin"])
    .melt(id_vars=["event"], var_name="neuron_id", value_name="value")
    .groupby(["neuron_id", "event"], as_index=False)
    .sum()
)

In [None]:
# Trial-to-trial variability of evoked spike counts.
# `cov` is scipy.stats.variation of each neuron's per-trial evoked counts,
# joined with responder stats and restricted to neurons with p < 0.05.
dfp = (
    evoked_counts
    .groupby("neuron_id")["value"].apply(variation)
    .to_frame("cov").reset_index()
    .merge(responders1)
    .query("p < 0.05")
)

# Example neuron: distribution of evoked spike counts across trials, then its raster.
example_neuron = 1671
fig, ax = plt.subplots()
evoked_counts.loc[lambda x: x["neuron_id"] == example_neuron, "value"].hist(
    color="black", ax=ax
)
ax.set_ylabel("Trial Counts")
ax.set_xlabel("Evoked Spikes")
sns.despine(fig=fig)
plt.show()
ax = ShockPlotter().unit_raster_across_trials(df_spikes_aligned, neuron=example_neuron)


In [None]:

# Per-trial evoked spike count summed over all neurons of one example session.
# Trials with zero total evoked spikes are excluded.
example_session = sessions[4]
session_counts = (
    evoked_counts
    .merge(neurons_sub)
    # Filter to the session first so only its rows are aggregated.
    .loc[lambda x: x["session_name"] == example_session]
    .groupby("event")["value"]
    .sum()
    .loc[lambda s: s != 0]
)
ax = session_counts.hist()
ax.set_xlabel("Evoked Spikes (session total)")
ax.set_ylabel("Trial Counts");

In [None]:
# Per-trial evoked spike count (0.05-0.2 s window sum) for a single neuron.
# NOTE(review): the original cell used `df_neuron`, which is only defined in a
# later cell, and `neuron`, which is never defined, so it failed on a fresh
# kernel. `evoked_counts` already holds these per-trial window sums of
# df_aligned. 1031 matches the raster example below; confirm it is the
# intended unit.
neuron = 1031
evoked_counts.loc[lambda x: x["neuron_id"] == neuron, "value"].hist();

In [None]:
# Hand-drawn raster for neuron 1031: one row of aligned spike times per event.
df_neuron = df_spikes_aligned.query("neuron_id == 1031")
trains = [trial["aligned"].values for _, trial in df_neuron.groupby("event")]

_, ax = plt.subplots(figsize=(5, 4))
ax.eventplot(trains, color="black")
ax.axvline(0, color="red")  # t = 0 of the alignment
# axis("off") hides ticks and tick labels, so the former set_yticks/set_xticks
# calls had no visible effect and were dropped. The trailing ';' suppresses the repr.
ax.axis("off");

In [None]:
# Five neurons with the lowest Diff. In top-to-bottom order this runs before the surrogate cell.
responders.sort_values(by="Diff").head(n=5)

In [None]:
# surrogate data
# Rerun the real-data pipeline on ISI-shuffled spike trains. Window, clusters
# and histogram bins match the real-data cell, so the responder fraction
# printed below is directly comparable.
# NOTE(review): this cell rebinds df_aligned, anova, contrasts, responders,
# transformer and unit_mod. Earlier cells re-run after it will see surrogate data.
# NOTE(review): no seed is set here for the shuffle. Confirm how
# shuffled_isi_spiketrains_by draws randomness if results must be reproducible.
response_window = (0.05, 0.2)

df_spikes_surr = shuffled_isi_spiketrains_by(
    df_spikes, spiketimes_col="spiketimes",
    by_col="neuron_id"
).merge(neurons_sub)

transformer = ShockUtils()
df_aligned = transformer.aligned_binned_from_spikes(
    df_spikes_surr,
    df_events,
    session=None,
    bin_width=0.01,
)
anova, contrasts = ShortTsAnova(window=response_window).get_responders(
    df_aligned, z=True, clusters=clusters,
)
display(anova)
display(contrasts)

# Neuron Level
unit_mod = ShortTsAvg(window=response_window)
responders = unit_mod.get_responders(df_aligned, z=True)
unit_mod.plot_responders(
    responders,
    clusters=clusters,
    bins=np.arange(-2.6, 2.6, 0.2),
)

sns.despine()
# Fraction of surrogate neurons whose `p` value is below 0.05.
print((responders["p"] < 0.05).mean())

In [None]:
# PSTH heatmaps, all neurons and then per cluster.
# NOTE(review): in top-to-bottom execution, df_aligned and responders were
# last rebound by the surrogate (ISI-shuffled) cell above, so these heatmaps
# show surrogate data. The title says so explicitly.
ax = ShockPlotter().psth_heatmap_all(df_aligned, responders, clusters=clusters)
ax.set_title("All (surrogate)")

figs = ShockPlotter().psth_heatmap_by_cluster(df_aligned, responders, clusters)