In [2]:
from typing import List
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from ssri_interactions.transforms import SpikesHandler
from ssri_interactions.spiketrains.spiketrain_stats import SpikeTrainDescriptor, SpikeTrainStats
from ssri_interactions.config import Config
from IPython.display import display

from ssri_interactions.spiketrains.spiketrain_stats import cv2
from binit.bin import which_bin
import warnings
from scipy.stats import variation


%load_ext autoreload
%autoreload 2

In [3]:
# Global seaborn/matplotlib styling for every figure produced by this notebook.
sns.set_theme(style="ticks", context="paper")

In [4]:
# Spike data for the "pre" block. t_start/t_stop bound the analysis window
# (presumably seconds — confirm against SpikesHandler).
sh = SpikesHandler(t_start=0, t_stop=1800, block="pre", bin_width=1)
df_spikes = sh.spikes

In [5]:
descriptor = SpikeTrainDescriptor()

# Whole-window descriptive statistics; saved to derived data below.
with warnings.catch_warnings():
    # Silence every warning emitted while the descriptor runs (scoped to this block only).
    warnings.filterwarnings("ignore")
    df_stats = descriptor.describe(
        df_spikes,
        # analysis window
        t_start=0,
        t_stop=1800,
        # bin sizes
        max_fr_binsize=30,
        cv_fr_binsize=30,
        mfr_bin_size=30,
        # thresholds
        burst_thresh=0.02,
        mfr_exclude_below=0.5,
    )

display(df_stats.sample(5))
df_stats.to_csv(Config.derived_data_dir / "spiketrain_stats.csv", index=False)

Unnamed: 0,neuron_id,cv2_isi,is_buster,cv2_isi_burst,median_burst_interval,mean_firing_rate,mean_firing_rate_ifr,fraction_bursts,max_firing_rate,cv_firing_rate
558,2191,0.517842,False,0.514676,0.00545,4.715,4.819395,0.001178,193,0.084317
2,1071,0.743635,False,0.637153,0.010367,9.367778,9.56613,0.067904,423,0.078872
75,1170,1.178299,False,1.118318,0.015,0.136111,0.722582,0.036735,36,0.692167
1,1070,0.999238,True,0.822681,0.009733,6.366111,6.265885,0.149926,492,0.119258
511,2106,0.876126,False,0.876126,,0.2,0.688578,0.0,80,0.803801


In [14]:
# Assign every spike to a 30-unit segment of the 0..1800 window, then compute
# the spike-train statistics separately for each (neuron_id, segment) pair.
bins = np.arange(0, 1800 + 1, 30)
df_spikes["segment"] = which_bin(df_spikes["spiketimes"].values, bins, time_after=30)
stats = SpikeTrainStats(thresh_burst=0.02)

with warnings.catch_warnings():
    # suppress warnings about empty bins
    warnings.filterwarnings("ignore")
    # NOTE: despite the "_wide" suffix this is a stacked Series — apply() puts
    # each returned statistic under an unnamed third index level ("level_2").
    stats_by_segment_wide = (
        df_spikes
        .groupby(["neuron_id", "segment"])["spiketimes"]
        .apply(stats)
    )
    # Long format: neuron_id, segment, metric, and the metric value (kept in a
    # column still named "spiketimes"), cast to float.
    stats_by_segment_long = (
        stats_by_segment_wide
        .reset_index()
        .rename(columns=dict(level_2="metric"))
        .astype({"spiketimes": float})
    )


In [15]:
# median stats by segment
# One row per (neuron_id, metric): the median of that metric across the
# neuron's segments. GroupBy.median skips NaN segments and returns NaN for an
# all-NaN group — the same result as np.nanmedian, but computed natively and
# without the "All-NaN slice encountered" RuntimeWarning.
median_over_segments_long = (
    stats_by_segment_long
    .groupby(["neuron_id", "metric"])["spiketimes"]
    .median()
    .reset_index()
)

# cv of stats by segment (volatility)
# One row per (neuron_id, metric): the coefficient of variation of that metric
# across the neuron's segments.
# NaN segment values are imputed from neighbouring segments of the SAME neuron
# and metric: first from the next segment (bfill), then, for trailing gaps,
# from the previous one (ffill). The fill must happen inside each group —
# filling the whole long frame would copy values across metrics and neurons,
# because rows are ordered neuron -> segment -> metric.
# A metric that is NaN in every segment stays NaN.
cv_over_segments_long = (
    stats_by_segment_long
    .sort_values(["neuron_id", "metric", "segment"])
    .groupby(["neuron_id", "metric"])["spiketimes"]
    .apply(lambda values: variation(values.bfill().ffill()))
    .reset_index()
)

  r, k = function_base._ureduce(a, func=_nanmedian, axis=axis, out=out,


In [16]:
# Reshape both long tables to one row per neuron_id and one column per metric.
median_over_segments_wide, cv_over_segments_wide = (
    long_frame.pivot(index="neuron_id", columns="metric", values="spiketimes")
    for long_frame in (median_over_segments_long, cv_over_segments_long)
)

In [17]:
# Save both per-neuron tables; reset_index() turns neuron_id into a regular column.
median_over_segments_wide.reset_index().to_csv(
    Config.derived_data_dir / "spiketrain_stats_segments.csv",
    index=False,
)
# NOTE(review): "volitility" is misspelled. The filename is left unchanged in
# case other code reads this file — confirm before renaming.
cv_over_segments_wide.reset_index().to_csv(
    Config.derived_data_dir / "spiketrain_stats_volitility.csv",
    index=False,
)