In [6]:
import warnings

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pingouin as pg
import seaborn as sns
from IPython.display import display
from scipy.stats import zscore

from ssri_interactions.io import load_derived_generic, load_lfp_raw
from ssri_interactions.transforms import SpikesHandler
from ssri_interactions.surrogates import shuffle_spikes
from ssri_interactions.transforms.brain_state import StateHandler, RawEEGHandler
from ssri_interactions.transforms.brain_state_spikes import (
    align_spikes_to_states_long, align_spikes_to_phase_long, align_bins_to_states_long,
    )
from ssri_interactions.transforms.nbox_transforms import segment_spikes
from ssri_interactions.spiketrains.spiketrain_stats import cv_isi_burst
from ssri_interactions.config import ExperimentInfo, Config
from ssri_interactions.responders.brain_state import SpikeRateResonders, PhaseLockResponders
from ssri_interactions.spiketrains.neurontype_props import ChiSquarePostHoc
from ssri_interactions.plots.circular import circular_hist

sns.set_theme(context="poster", style="ticks")

# Spike Rate Change During EEG States


### Load Data

In [7]:
# LFP brain-state scoring qualities accepted for analysis. Used both to pick
# the sessions and to filter the state table, so the two can never disagree.
STATE_QUALITIES = ("good", "med")

states_path = Config.derived_data_dir / "lfp_states.csv"
session_names = (
    pd.read_csv(states_path)
    .loc[lambda df: df["quality"].isin(STATE_QUALITIES), "session_name"]
    .unique()
    .tolist()
)
assert session_names, f"no sessions with quality in {STATE_QUALITIES} found in {states_path}"

neuron_types = load_derived_generic("neuron_types.csv")
states_handler = StateHandler(
    states_path=states_path,
    quality_to_include=STATE_QUALITIES,
    t_start=0,
    t_stop=1800,
    session_names=session_names,
)
spikes_handler = SpikesHandler(
    block="pre",
    t_start=0,
    bin_width=1,
    t_stop=1800,
    session_names=session_names,
)

# One row per (neuron, time bin) with the bin's brain state attached.
df_aligned = align_bins_to_states_long(
    spikes_handler=spikes_handler,
    states_handler=states_handler,
    neuron_types=neuron_types,
)

# z-score spike counts within each neuron so that state effects are comparable
# across neurons with very different baseline rates. NOTE: a neuron whose counts
# are constant has zero standard deviation and therefore gets NaN z-scores.
df_aligned["zcounts"] = (
    df_aligned
    .groupby("neuron_id")["counts"]
    .transform(zscore)
)

### Calculate Responders

- Mixed ANOVA with brain state as the within-neuron factor and neuron type as the between-neuron factor, including their interaction
- Post hoc: the responder status of each neuron is determined with a Mann-Whitney U test

In [8]:
# Minimum state difference passed to get_responders (presumably the |Diff| a
# neuron must exceed to count as a responder — TODO confirm in
# SpikeRateResonders). The same threshold is used for the z-scored and raw fits.
RESPONDER_ABS_DIFF_THRESH = 0.1

# Silence warnings only while fitting the models, so that warnings from display
# or file I/O further down remain visible.
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    rate_model = SpikeRateResonders(df_value_col="zcounts", round_output=2)
    anova, contrasts = rate_model.get_anova(df_aligned, fit_neuron_types=True)
    responders = rate_model.get_responders(
        df_aligned, abs_diff_thresh=RESPONDER_ABS_DIFF_THRESH
    )
    # Same responder test on raw counts; only its Diff column is kept, as Diff_raw.
    responders_raw = (
        SpikeRateResonders(df_value_col="counts", round_output=2)
        .get_responders(df_aligned, abs_diff_thresh=RESPONDER_ABS_DIFF_THRESH)
        .rename(columns={"Diff": "Diff_raw"})
    )

display(anova)
# Fixed seed so the preview rows are identical on every Restart & Run All.
display(responders.sample(3, random_state=0))

(
    responders
    .merge(responders_raw[["neuron_id", "Diff_raw"]], how="left", on="neuron_id")
    .to_csv(
        Config.derived_data_dir / "brain_states_spikerate_responders.csv",
        index=False,
    )
)

df_aligned.to_parquet(
    Config.derived_data_dir / "brain_states_counts.parquet",
)

Unnamed: 0,Source,SS,DF1,DF2,MS,F,p-unc,np2,eps
0,neuron_type,0.09,2,653,0.04,4.0,0.02,0.01,
1,state,9.24,1,653,9.24,85.85,0.0,0.12,1.0
2,Interaction,0.96,2,653,0.48,4.46,0.01,0.01,


Unnamed: 0,neuron_id,n_sw,n_act,Mean_sw,Mean_act,Diff,U,p,sig
188,1304,1233.0,566.0,-0.22,0.49,0.71,217188.0,0.0,True
72,1166,1019.0,780.0,-0.26,0.34,0.61,250020.0,0.0,True
590,2252,1187.0,612.0,0.07,-0.13,-0.19,411316.5,0.0,True


# Phase Locking Analysis

### Load Data and Align to EEG Oscillation Phase
- Raw EEG signal downsampled to 250 Hz
- In activated brain states, it is filtered between 4 and 8 Hz (theta)
- In slow-wave states, it is filtered between 0.5 and 4 Hz (delta)
- Spike times are aligned to EEG phase separately for each brain state
- In each state, the distribution of phases is tested for uniformity using Rayleigh tests
- The preferred phase of each neuron and whether it is significantly phase locked are saved to a file
- This file is loaded into R and analysed using a GLM on angular embeddings (see below)

In [9]:
# Handlers covering the same "pre" block window as the spike-rate analysis.
spikes_handler = SpikesHandler(
    block="pre",
    t_start=0,
    bin_width=1,
    t_stop=1800,
    session_names=session_names,
)
eeg_handler = RawEEGHandler(
    block="pre",
    t_start=0,
    t_stop=1800,
    session_names=session_names,
    loader=load_lfp_raw,
)
df_aligned_phase = align_spikes_to_phase_long(
    spikes_handler=spikes_handler,
    states_handler=states_handler,
    raw_eeg_handler=eeg_handler,
    neuron_types=None,
).dropna()

# Delta phase is analysed in slow-wave ("sw") epochs and theta phase in
# activated ("act") epochs.
df_sw = df_aligned_phase.query("state == 'sw'")
df_act = df_aligned_phase.query("state == 'act'")

# NOTE(review): fs = 1500 / (2*pi) ≈ 238.7; how this relates to the 250 Hz EEG
# sampling rate is not evident from this notebook — TODO document.
mod = PhaseLockResponders(round_output=2, fs=(250 * 6) / (2 * np.pi))
df_res_act = mod.prefered_angles(df_act, phase_col="theta_phase")
df_res_sw = mod.prefered_angles(df_sw, phase_col="delta_phase")

# Stack both oscillations into one long table, then attach neuron metadata.
# The join key is explicit so that any future column shared with neuron_types
# cannot silently become part of the join.
angle_cols = ["neuron_id", "oscillation", "mean_angle", "var", "p"]
df_prefered_angles = (
    pd.concat(
        [
            df_res_sw.assign(oscillation="delta")[angle_cols],
            df_res_act.assign(oscillation="theta")[angle_cols],
        ]
    )
    .merge(neuron_types, on="neuron_id")
)
# Fixed seed so the preview rows are identical on every Restart & Run All.
display(df_prefered_angles.sample(3, random_state=0))

df_prefered_angles.to_csv(Config.derived_data_dir / "brain_states_phase_responders.csv", index=False)
df_aligned_phase.to_parquet(
    Config.derived_data_dir / "brain_states_phase_aligned.parquet",
)

Unnamed: 0,neuron_id,oscillation,mean_angle,var,p,session_name,group_name,experiment_name,group,neuron_type,width_basepost,mean_firing_rate,cv_isi_burst
181,1190,theta,1.83,0.91,0.01,hamilton_17,citalopram_discontinuation,HAMILTON,DIS,SIR,50.227541,1.080933,0.86661
341,1286,theta,-0.81,0.35,0.13,hamilton_23,citalopram_continuation,HAMILTON,CIT,SR,,,
892,2006,theta,-1.55,0.99,0.94,hamilton_25,citalopram_discontinuation,HAMILTON,DIS,SIR,,,1.003977


In [10]:
# A neuron counts as phase locked when its p-value from prefered_angles
# (assumed to be the Rayleigh test of phase uniformity) is below this alpha.
PHASE_LOCK_ALPHA = 0.05

df_prefered_angles = df_prefered_angles.assign(sig=lambda x: x.p < PHASE_LOCK_ALPHA)
mod = ChiSquarePostHoc(value_col="sig", round=2)

# Compare the proportion of phase-locked neurons across neuron types, separately
# for each oscillation and treatment group. Each table is labelled so the outputs
# can be told apart.
for oscillation in ("delta", "theta"):
    for group in ("CIT", "SAL"):
        subset = df_prefered_angles[
            (df_prefered_angles["oscillation"] == oscillation)
            & (df_prefered_angles["group"] == group)
        ]
        print(f"{oscillation} oscillation | group {group}")
        display(mod(subset))

anova                          Chi2(2)=7.5 (p=0.02*)
SIR - SR     66.67%; 60.53% | Chi(1.0)=0.76 (p=0.38)
SIR - FF     66.67%; 94.12% | Chi(1.0)=4.17 (p=0.06)
SR - FF     60.53%; 94.12% | Chi(1.0)=5.93 (p=0.04*)
dtype: object

anova                          Chi2(2)=5.0 (p=0.08)
SIR - FF    70.21%; 84.21% | Chi(1.0)=0.75 (p=0.39)
SIR - SR     70.21%; 58.67% | Chi(1.0)=1.2 (p=0.39)
FF - SR     84.21%; 58.67% | Chi(1.0)=3.25 (p=0.21)
dtype: object

anova                          Chi2(2)=1.1 (p=0.58)
SIR - SR      22.22%; 22.88% | Chi(1.0)=0.0 (p=1.0)
SIR - FF    22.22%; 11.76% | Chi(1.0)=0.46 (p=0.74)
SR - FF     22.88%; 11.76% | Chi(1.0)=0.53 (p=0.74)
dtype: object

anova                         Chi2(2)=8.0 (p=0.02*)
SIR - FF    44.68%; 73.68% | Chi(1.0)=3.48 (p=0.09)
SIR - SR    44.68%; 68.06% | Chi(1.0)=5.49 (p=0.06)
FF - SR     73.68%; 68.06% | Chi(1.0)=0.04 (p=0.85)
dtype: object