In [None]:
import bycycle
from bycycle.cyclepoints import extrema_interpolated_phase

from neurodsp.filt import filter_signal
from neurodsp.plts import plot_time_series

from bycycle.features import compute_features
from bycycle.cyclepoints import find_extrema, find_zerox
from bycycle.cyclepoints.zerox import find_flank_zerox
from bycycle.plts import plot_burst_detect_summary, plot_cyclepoints_array

import numpy as np
import scipy.signal as ss
import math
from sklearn.utils import resample

import matplotlib
import matplotlib.pyplot as plt


# Load the raw spike output for 60 seeds per condition plus the processed EEG.
# Healthy circuits use seeds 1000-1059; the SST condition (labelled MDD later in
# this notebook) uses seeds 2000-2059. Each file holds a pickled object array
# (read with .item(0) further down), hence allow_pickle -- only load trusted files.
data_root = '/home/e/etayhay/frankm/Mazza2023/data/'
spikepath = data_root + 'simulations_raw/'

SPIKES_1 = [
    np.load(f'{spikepath}Healthy/SPIKES_Seed{seed}.npy', allow_pickle=True)
    for seed in range(1000, 1060)
]
SPIKES_2 = [
    np.load(f'{spikepath}SST/Circuit_output_4/SPIKES_Seed{seed}.npy', allow_pickle=True)
    for seed in range(2000, 2060)
]

processed_dir = data_root + 'simulations_processed/'
EEG_h = np.load(processed_dir + 'EEG_h.npy', allow_pickle=True)
EEG_m = np.load(processed_dir + 'EEG_m.npy', allow_pickle=True)

def cohen_d(x,y):
    """Cohen's d effect size for two independent samples, using the pooled SD.

    d = (mean(x) - mean(y)) / s_pooled, where
    s_pooled = sqrt(((nx-1)*s_x**2 + (ny-1)*s_y**2) / (nx + ny - 2))
    and s_x, s_y are the sample (ddof=1) standard deviations.

    Parameters
    ----------
    x, y : 1-D array-like
        Observations of the two groups.

    Returns
    -------
    float
        Positive when x has the larger mean.
    """
    # Uses np.* (imported in this cell) instead of bare mean/std/sqrt, which were
    # only imported in a later cell and raised NameError on a fresh kernel.
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    nx, ny = len(x), len(y)
    dof = nx + ny - 2
    pooled_var = ((nx - 1) * np.var(x, ddof=1) + (ny - 1) * np.var(y, ddof=1)) / dof
    return (np.mean(x) - np.mean(y)) / np.sqrt(pooled_var)

def myround(x, prec=4, base=.025):
    """Snap x to the nearest multiple of `base`, then round to `prec` decimals."""
    n_steps = round(float(x) / base)
    snapped = n_steps * base
    return round(snapped, prec)

def moving_average(x, w):
    """Mean over every full length-w window of x (output length len(x) - w + 1)."""
    kernel = np.full(w, 1.0)
    window_sums = np.convolve(x, kernel, mode='valid')
    return window_sums / w


In [None]:
# 1) Count, per condition / run / population, the spikes that survive the
#    post-transient + rate filter. The smallest count per population (across both
#    conditions and all runs) becomes the sample size every run is resampled to in
#    the next cell, so depression doesn't get increased preference just due to
#    increased spikecounts.
T_TRANSIENT_MS = 2000   # spikes at or before this time (ms) are discarded
SIM_SECONDS = 28        # denominator of the per-cell rate check
MIN_RATE = .2           # a cell contributes only if len(spikes) / SIM_SECONDS exceeds this

SPIKES = [SPIKES_1, SPIKES_2]
nspikes = [[[[], [], [], []] for _ in range(60)] for _ in range(2)]

eeg_conds = [EEG_h.item(0), EEG_m.item(0)]
for c in range(len(eeg_conds)):
    # only the first 60 entries correspond to runs
    n_runs = min(len(eeg_conds[c]), 60)
    for r in range(n_runs):
        run_spikes = SPIKES[c][r].item(0)
        for z, (_, spikelist) in enumerate(zip(run_spikes['gids'], run_spikes['times'])):
            kept = []
            for cell_spikes in spikelist:
                late = [t for t in cell_spikes if t > T_TRANSIENT_MS]
                if len(late) / SIM_SECONDS > MIN_RATE:
                    kept.extend(late)
            nspikes[c][r][z] = len(kept)

nspikes_arr = np.array(nspikes)  # condition x run x population
mins = [nspikes_arr[:, :, p].ravel().min() for p in range(4)]

mins

In [None]:
# 2) For each run: estimate the theta phase of the EEG (bycycle extrema-interpolated
#    phase) and sample it at every count-matched population spike time.

fs = 40000                   # EEG sampling rate (Hz)
samples_per_ms = fs // 1000  # spike times are in ms -> 40 EEG samples per ms

T_TRANSIENT_MS = 2000        # drop the initial transient (EEG and spikes up to 2000 ms)
transient_samples = T_TRANSIENT_MS * samples_per_ms  # 80000: first EEG sample analysed
SIM_SECONDS = 28             # denominator of the per-cell rate check (presumably the 2-30 s window)
MIN_RATE = .2                # a cell contributes only if len(spikes) / SIM_SECONDS exceeds this

# Loop-invariant filter settings (hoisted out of the run loop)
f_theta = (4, 15)            # paper suggests 4x high point of band
f_lowpass = 15 * 4           # lowpass cutoff (Hz)
n_seconds_filter = .1        # lowpass filter length (s)
n_seconds_theta = .75        # bandpass filter length (s) passed to find_extrema
remove_edges = True

plot = False                 # set True to inspect cycle-point detection per run
xlim = (0, 1)                # plotting window (s)

spikephase = [[[[], [], [], []] for _ in range(60)] for _ in range(2)]

for c, cond in enumerate([EEG_h.item(0), EEG_m.item(0)]):
    # only the first 60 entries are runs (last item in EEG is mean)
    for r in range(min(len(cond), 60)):
        print(c, r)
        runnum = r + 1000 if c == 0 else r + 2000
        sig = cond[str(runnum)]['ts_raw'][transient_samples:]

        # Preprocess with lowpass filter
        sig_low = filter_signal(sig=sig, fs=fs, pass_type='lowpass', f_range=f_lowpass,
                                n_seconds=n_seconds_filter, remove_edges=remove_edges)

        if plot:
            times = np.arange(0, len(sig) / fs, 1 / fs)
            fig, ax = plt.subplots(figsize=(15, 5))
            ax.plot(times, sig, color='k', alpha=.5, lw=.5)
            ax.plot(times, sig_low, color='k', alpha=1)
            ax.set_xlim(xlim)
            ax.set_ylabel('Voltage')
            ax.set_xlabel('Time (s)')

        # Localize peaks and troughs. find_extrema does its own bandpass filtering
        # (pass_type='bandpass'); the separate narrowband filter and flank
        # zero-crossings previously computed here were never used, so they are dropped.
        peaks, troughs = find_extrema(sig=sig_low, fs=fs, f_range=f_theta,
                                      boundary=0,            # samples from edge to ignore
                                      first_extrema=None,    # don't force peak/trough first
                                      filter_kwargs={'n_seconds': n_seconds_theta},
                                      pass_type='bandpass',
                                      pad=True)              # pad to avoid missing edge cyclepoints

        if plot:
            plot_cyclepoints_array(sig_low, fs, peaks=peaks, troughs=troughs, xlim=xlim)

        # Localize rise and decay midpoints
        rises, decays = find_zerox(sig_low, peaks, troughs)

        if plot:
            plot_cyclepoints_array(sig_low, fs, xlim=xlim, peaks=peaks, troughs=troughs,
                                   rises=rises, decays=decays)

        # Phase in degrees; contains NaNs near the edges, removed per spike below.
        rad_phase = extrema_interpolated_phase(sig=sig, peaks=peaks, troughs=troughs,
                                               rises=rises, decays=decays)
        phase = np.degrees(rad_phase)

        run_spikes = SPIKES[c][r].item(0)
        for z, (_, spikelist) in enumerate(zip(run_spikes['gids'], run_spikes['times'])):
            # Pool post-transient spikes of cells that pass the rate threshold.
            popspikes = []
            for spikes in spikelist:
                spikes2 = [spike for spike in spikes if spike > T_TRANSIENT_MS]
                if len(spikes2) / SIM_SECONDS > MIN_RATE:
                    popspikes.extend(spikes2)

            # Draw mins[z] spikes (from the previous cell) for every run so phase
            # statistics are not biased by spike count.
            # NOTE(review): replace=True samples with replacement, so the draw can contain
            # duplicate spikes; replace=False would subsample without duplicates -- confirm intent.
            resampled_pop = resample(popspikes, replace=True, n_samples=mins[z], random_state=100)

            # ms -> sample index relative to the start of the trimmed signal
            # (astype(int) truncates exactly like int()).
            idx = (np.asarray(resampled_pop, dtype=float) * samples_per_ms).astype(int) - transient_samples
            phase_w_nans = phase[idx]
            spikephase[c][r][z] = phase_w_nans[~np.isnan(phase_w_nans)]
                    

In [None]:
# np.save('/home/e/etayhay/frankm/Mazza2022_scratch/data/Figure5/newanal/spikephase_low_4_16.npy',spikephase)

In [None]:
# Per-run spike-phase histograms in 5-degree bins covering the full [-180, 180] range.
# The last edge must be 180 itself: np.arange(-180, 180, 5) stops at 175, which
# silently dropped every phase in (175, 180] (np.histogram's last bin is closed).
bin_width = 5
phase_edges = np.arange(-180, 180 + bin_width, bin_width)

spikehists = [[[], [], [], []], [[], [], [], []]]
for c in range(2):
    for r in range(60):
        for p in range(4):
            hist, bins = np.histogram(spikephase[c][r][p], bins=phase_edges)
            spikehists[c][p].append(hist)

spikehists = np.array(spikehists)  # condition x population x run x bin

# np.save('/home/e/etayhay/frankm/Mazza2022_scratch/data/Figure5/newanal/bins_low_4_16.npy',bins)

# Stats

In [None]:
import math
from numpy import mean,std,sqrt
from astropy.stats import rayleightest,circmean,vonmisesmle
import scipy.stats as st

In [None]:
# Reload per-run spike phases instead of recomputing them with the phase loop.
# NOTE(review): the commented-out save after the phase loop targets Mazza2022_scratch/...,
# not this Mazza2023/... path, and the file name says 4_16 while the loop uses
# f_theta=(4, 15) -- confirm this file matches the current settings.
spikephase_path = '/home/e/etayhay/frankm/Mazza2023/data/figures/Figure5/newanal/spikephase_low_4_16.npy'
spikephase = np.load(spikephase_path, allow_pickle=True)

In [None]:
# Rayleigh test of phase non-uniformity per condition x run x population.
# Phases are degrees in [-180, 180]; adding pi maps the radians to [0, 2*pi]
# (a pure rotation, which does not change the Rayleigh statistic).
rps = np.zeros(shape=(2, 60, 4))

for c, cond in enumerate(spikephase):
    for r, run in enumerate(cond):
        for p, popspikes in enumerate(run):
            phases_deg = np.asarray(popspikes, dtype=float)
            # Drop NaNs. The former `spike != np.nan` check was always True
            # (NaN never compares equal), so it never filtered anything.
            phases_deg = phases_deg[~np.isnan(phases_deg)]
            poprads = np.radians(phases_deg) + np.pi
            rps[c][r][p] = rayleightest(poprads)

np.mean(rps, axis=1)  # mean p-value per condition x population; all nonuniform

## Mean angle

In [None]:
# Circular mean spike phase per condition x run x population
# (vonmisesmle(...)[0] could be used as an alternative).
m_angle = np.zeros(shape=(2, 60, 4))

for c, cond in enumerate(spikephase):
    for r, run in enumerate(cond):
        for p, popspikes in enumerate(run):
            poprads = np.radians(np.asarray(popspikes, dtype=float))
            m_angle[c][r][p] = circmean(poprads)

# Across runs use the ordinary (linear) mean/std: per the original analysis the run
# means do not cross +/-180 degrees between circuits, so no wrap-around handling is needed.
m_angle_mean = m_angle.mean(axis=1)  # condition x population
m_angle_std = m_angle.std(axis=1)

# Prints mean then std for each condition (the MDD std line was previously labelled 'healthy').
for label, c in (('healthy', 0), ('mdd', 1)):
    print(label, [round(math.degrees(i), 0) for i in m_angle_mean[c]])
    print(label, [round(math.degrees(i), 0) for i in m_angle_std[c]])

## Kurtosis (von Mises concentration κ)

In [None]:
# Von Mises concentration (kappa) of spike phases per condition x run x population;
# a larger kappa means phases are more tightly clustered around the mean angle.
kappas = np.zeros(shape=(2, 60, 4))
for c, cond in enumerate(spikephase):
    for r, run in enumerate(cond):
        for p, popspikes in enumerate(run):
            poprads = np.array([math.radians(spike) for spike in popspikes])
            _mu, kappa = vonmisesmle(poprads)
            kappas[c, r, p] = kappa

mean_kappa = kappas.mean(axis=1)  # condition x population
print('Healthy', mean_kappa[0])
print('MDD    ', mean_kappa[1])

# Per population: % change of mean kappa (MDD relative to healthy; negative = decrease),
# Cohen's d, and independent-samples t-test p-value.
for i in range(4):
    healthy_k = kappas[0, :, i]
    mdd_k = kappas[1, :, i]
    d = round(cohen_d(healthy_k, mdd_k), 1)
    _, p = st.ttest_ind(healthy_k, mdd_k)
    pct_change = (mean_kappa[1][i] - mean_kappa[0][i]) / mean_kappa[0][i] * 100
    print(round(pct_change, 0), d, p)

# Figure

In [None]:
# Example traces for Figure 5, all from the first healthy run (seed 1000 / EEG key '1000').
spikes = SPIKES_1[0].item(0)  # for 1 A

# 1 B: spike-count time course of the first population (spikes['times'][0]) in
# 10 ms bins from 2000 to 30000 ms, smoothed with a 3-bin moving average.
# moving_average (and myround) come from the first cell; the identical
# re-definitions that used to sit here were removed.
popspikes = np.concatenate(spikes['times'][0]).ravel()
hist, bins = np.histogram(popspikes, bins=np.arange(2000, 30000 + 10, 10))

spikecount = moving_average(hist, 3)
# A 'valid' 3-point moving average is 2 samples shorter than hist: pad one zero
# at each end so spikecount lines up with the histogram bins again.
spikecount = np.concatenate(([0], spikecount, [0]))
assert len(spikecount) == len(hist)


# 1 C, D: theta phase of the same run's full raw EEG.
eeg = EEG_h.item(0)['1000']['ts_raw']

fs = 40000

# Filter settings
f_theta = (4, 12)  # paper suggests 4x high point of band
f_lowpass = 8 * 4
# NOTE(review): 8*4 = 32 Hz is 4x 8, not 4x the band's upper edge (12 -> 48 Hz), and the
# phase-extraction cell uses f_theta=(4, 15) with a 60 Hz lowpass -- confirm which
# settings this figure is meant to illustrate.
n_seconds_filter = .1
# Defined here (same value as the phase-extraction cell) so this cell no longer
# depends on that loop having run in the current kernel.
n_seconds_theta = .75
remove_edges = True

# Lowpass filter
eeg_low = filter_signal(sig=eeg, fs=fs, pass_type='lowpass', f_range=f_lowpass,
                        n_seconds=n_seconds_filter, remove_edges=remove_edges)

# Narrowband filter
# NOTE(review): eeg_narrow / rise_xs / decay_xs are not used by later cells in this notebook.
eeg_narrow = filter_signal(sig=eeg, fs=fs, pass_type='bandpass', f_range=f_theta,
                           n_seconds=n_seconds_theta, remove_edges=remove_edges)

# Find rising and falling zerocrossings (narrowband)
rise_xs = find_flank_zerox(sig=eeg_narrow, flank='rise')
decay_xs = find_flank_zerox(sig=eeg_narrow, flank='decay')

# Find peaks and troughs (find_extrema does its own bandpass filtering)
peaks, troughs = find_extrema(sig=eeg_low, fs=fs, f_range=f_theta,
                              boundary=0,            # samples from edge to ignore
                              first_extrema=None,    # don't force peak/trough first
                              filter_kwargs={'n_seconds': n_seconds_theta},
                              pass_type='bandpass',
                              pad=True)              # pad to avoid missing edge cyclepoints

rises, decays = find_zerox(eeg_low, peaks, troughs)

# NOTE(review): df_features is not used by later cells in this notebook.
df_features = compute_features(sig=eeg,
                               fs=fs,
                               f_range=f_theta,
                               center_extrema='peak',
                               burst_method='cycles',   # burst based on period of consistent cycles
                               burst_kwargs=None,       # burst analysis not used
                               threshold_kwargs=None,
                               find_extrema_kwargs=None,
                               return_samples=True)

rad_phase = extrema_interpolated_phase(sig=eeg, peaks=peaks, troughs=troughs,
                                       rises=rises, decays=decays)
phase = np.degrees(rad_phase)



In [None]:
# PHASE PLOTS: per-run spike-phase histograms in 10-degree bins for the figure.
# Edges run from -180 to 180 inclusive. np.arange(-180, 180, 10) stopped at 170,
# silently dropping every phase in (170, 180], and appending 180 afterwards left
# 35 counts for 37 edges.
bin_width = 10
phase_edges = np.arange(-180, 180 + bin_width, bin_width)

spikehists = [[[], [], [], []], [[], [], [], []]]
for c in range(2):
    for r in range(60):
        for p in range(4):
            hist, bins = np.histogram(spikephase[c][r][p], bins=phase_edges)
            spikehists[c][p].append(hist)

spikehists = np.array(spikehists)  # condition x population x run x 36 bins

# bins already ends at 180, so len(radbins) == number of bins + 1
radbins = [math.radians(i) for i in bins]

In [None]:
# Write the arrays used to draw Figure 5, one .npy file per variable.
fig5_dir = '/home/e/etayhay/frankm/Mazza2023/data/figures/Figure5/'
fig5_outputs = {
    'spikecount': spikecount,
    'phase': phase,
    'eeg_low': eeg_low,
    'eeg': eeg,
    'spikehists': spikehists,
    'radbins': radbins,
}
for out_name, out_arr in fig5_outputs.items():
    np.save(fig5_dir + out_name + '.npy', out_arr)