In [1]:
from ephys_queries import select_spike_times
from ephys_queries import db_setup_core
from spiketimes.df import (
    spike_count_correlation_df_test,
    spike_count_correlation_between_groups_test,
)
import dotenv
from pathlib import Path
import pandas as pd
import numpy as np
from pyarrow.feather import write_feather, read_feather

In [2]:
# Locate the project's data directory (a sibling of the notebook's working
# directory) and make sure the scratch folder for intermediate files exists.
# When run as a script the equivalent would be:
#   data_dir = Path(__file__).absolute().parent.parent / "data"
data_dir = Path.cwd().parent / "data"
tmp_dir = data_dir / "tmp"
tmp_dir.mkdir(exist_ok=True)

# Load the per-neuron labels, keeping only the columns used downstream.
label_columns = ["neuron_id", "type", "session_name"]
df_labels = pd.read_csv(data_dir / "chronic_baseline.csv").loc[:, label_columns]


In [3]:
# Load environment variables from a local .env file before opening the
# database connection (presumably the connection settings — confirm).
dotenv.load_dotenv()
engine, metadata = db_setup_core()

# NOTE(review): `group_names` is defined here but is not passed to
# `select_spike_times` below, so the query is not restricted by group in this
# cell. The trailing-underscore entry "chronic_saline_" also looks unusual —
# confirm whether both are intentional.
group_names = [
    "citalopram_continuation",
    "chronic_saline",
    "citalopram_discontinuation",
    "chronic_citalopram",
    "chronic_saline_",
]

# Fetch spike times for the selected recording block.
block_name = "pre"
df_spikes = select_spike_times(engine, metadata, block_name=block_name)

# Rescale the raw sample indices into the "spiketimes" column.
# NOTE(review): 30000 is presumably the sampling rate in Hz, which would make
# "spiketimes" seconds — confirm.
df_spikes["spiketimes"] = df_spikes["spike_time_samples"] / 30000

In [4]:
# Attach the neuron labels to the spike data. The default (inner) join on
# "neuron_id" keeps only neurons present in both tables.
df = df_spikes.merge(df_labels, on="neuron_id")

In [5]:
# Exclusion criteria: drop every row whose neuron type is "no_baseline".
has_baseline = df["type"].ne("no_baseline")
df = df.loc[has_baseline]

In [6]:
# Sessions to analyse, in the order they appear in the label file. Only
# sessions that still have rows in the merged, filtered `df` are kept: a
# session with no spike data, only excluded neurons, or a missing name would
# otherwise hand an empty frame to the bootstrap correlation tests.
sessions_with_data = set(df["session_name"].dropna())
recording_sessions = [
    session
    for session in df_labels["session_name"].unique()
    if session in sessions_with_data
]

# Accumulators for the paths of the per-session feather files written by the
# analysis loop.
# NOTE: despite its name, `group_wise_frames` stores file paths, not frames.
slow_reg_fnames = []
group_wise_frames = []

In [7]:
# Keyword arguments shared by both correlation tests, kept in one place so
# the two calls cannot drift apart.
# NOTE(review): fs=1 presumably corresponds to 1 s spike-count bins (the
# output files are named "..._1s_...") — confirm against the spiketimes API.
test_kwargs = dict(
    fs=1,
    n_boot=4000,
    spiketimes_col="spiketimes",
    neuron_col="neuron_id",
    verbose=True,
)

for session in recording_sessions:
    # Work on a copy of this session's rows so the tests cannot touch `df`.
    df_sub = df[df["session_name"] == session].copy()

    # Between-group test, grouping neurons by their "type" label.
    group_wise = spike_count_correlation_between_groups_test(
        df_sub, group_col="type", **test_kwargs
    )

    # Same settings, restricted to the neurons labelled "slow_regular".
    slow_reg = spike_count_correlation_df_test(
        df_sub[df_sub["type"] == "slow_regular"].copy(), **test_kwargs
    )

    # Tag each result with its session so rows stay identifiable after the
    # per-session tables are concatenated.
    slow_reg["session_name"] = session
    group_wise["session_name"] = session

    # Persist the per-session results to the scratch directory before moving
    # on to the next session.
    fname_sr = str(tmp_dir / f"{session}_slow_reg.feather")
    fname_gw = str(tmp_dir / f"{session}_group_wise.feather")

    write_feather(slow_reg, fname_sr)
    write_feather(group_wise, fname_gw)

    slow_reg_fnames.append(fname_sr)
    group_wise_frames.append(fname_gw)

# Read every per-session file back and stack them into one table per test.
# NOTE: the names are misleading but kept for compatibility with the rest of
# the notebook: `group_wise_frames` holds file paths and `gw_fnames` holds the
# loaded DataFrames.
sr_frames = [read_feather(fname) for fname in slow_reg_fnames]
gw_fnames = [read_feather(fname) for fname in group_wise_frames]

df_sr = pd.concat(sr_frames, axis=0)
df_gw = pd.concat(gw_fnames, axis=0)

# Final outputs. The group-wise filename previously contained a stray ")"
# ("spikecount_corr_1s_gw).csv"); it now matches the slow-regular file's
# naming pattern.
df_sr.to_csv(data_dir / "spikecount_corr_1s_sr.csv", index=False)
df_gw.to_csv(data_dir / "spikecount_corr_1s_gw.csv", index=False)