# Boosting of the extracellular spike signal by synchronous spikes.

In [None]:
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
import LFPy
from brainsignals.plotting_convention import mark_subplots, simplify_axes
from brainsignals.neural_simulations import return_hay_cell
import brainsignals.neural_simulations as ns
import elephant

# Seed NumPy's global RNG so the random spike times drawn later
# (np.random.normal in the analysis cell) are reproducible between runs.
np.random.seed(12345)
# Load the mechanisms found in brainsignals' cell-model folder.
# NOTE(review): presumably the compiled NEURON .mod mechanisms required by the
# Hay model and the ISyn point process used below — confirm.
ns.load_mechs_from_folder(ns.cell_models_folder)

In [None]:
def insert_current_stimuli(cell):
    """Attach a constant current injection to compartment 0 of ``cell``.

    An ``ISyn`` point process is created through ``LFPy.StimIntElectrode``
    with amplitude -0.4, starting at time 0 and lasting 1e9 time units, i.e.
    far beyond any simulation end time used here, so it is on for the whole
    run. NOTE(review): compartment 0 is presumably the soma, and the negative
    amplitude is presumably depolarizing under ISyn's sign convention — confirm.

    Returns
    -------
    (stimulus, cell)
        The created stimulus object and the same ``cell`` that was passed in.
    """
    current_electrode = LFPy.StimIntElectrode(
        cell,
        amp=-0.4,
        idx=0,
        pptype="ISyn",
        dur=1e9,
        delay=0,
    )
    return current_electrode, cell

# Simulate the active cell model under a constant current injection, then cut
# out a short window around one spike and compute its extracellular waveform.
tstop = 150  # simulation end time (ms)
dt = 2**-6  # time step (ms); a power of two, so sample times are exact floats
cell = return_hay_cell(tstop=tstop, dt=dt, make_passive=False)
ns.point_axon_down(cell)  # NOTE(review): by name, reorients the morphology so the axon points downward — confirm

syn, cell = insert_current_stimuli(cell)
# Record membrane currents (needed for the extracellular potential) and
# membrane potentials of all compartments.
cell.simulate(rec_imem=True, rec_vmem=True)

# One extracellular point contact at x=50 (presumably µm), y=z=0.
elec_params = dict(
            sigma = 0.3,      # extracellular conductivity
            x = np.array([50]),
            y = np.zeros(1),
            z = np.zeros(1),
            method = 'root_as_point',
            )

# Time window to extract spike from:
# NOTE(review): this 10 ms window is assumed to contain exactly one spike of
# the stimulated cell — re-check against cell.vmem if tstop/stimulus change.
t0 = 114.25
t1 = 124.25
# Indices of the simulation samples closest to t0 and t1.
t0_idx = np.argmin(np.abs(cell.tvec - t0))
t1_idx = np.argmin(np.abs(cell.tvec - t1))

# Slice is end-exclusive; time axis is re-referenced to start at 0.
vmem = cell.vmem[:, t0_idx:t1_idx]  # not used further in the cells shown here
imem = cell.imem[:, t0_idx:t1_idx]
tvec = cell.tvec[t0_idx:t1_idx] - cell.tvec[t0_idx]

elec = LFPy.RecExtElectrode(cell, **elec_params)
# Linear map from compartment membrane currents to potential at each contact.
M_elec = elec.get_transformation_matrix()
eap = M_elec @ imem  # one row per electrode contact

t_eap = tvec
eap = eap[0]  # single contact -> 1-D waveform of the extracellular spike


In [None]:

# Sum many copies of the single-spike waveform ``eap`` placed at random spike
# times drawn from a Gaussian around ``midpoint``, for decreasing spike-time
# jitter (``spiketime_stds``; 0 = perfectly synchronous). For each jitter, plot
# the summed signal and compare its PSD with PSD(gaussian) * PSD(single spike).
num_spikes = 500
end_t = 500  # ms
t = np.arange(0, end_t, dt)
num_tsteps = len(t)
spiketime_stds = [5, 3, 1., 0]
fs = 1 / dt * 1000  # sampling frequency in Hz (dt is in ms)

# Subtract the straight line through the first and last samples so the
# waveform starts and ends at zero (no steps at the zero-padding edges).
eap -= np.linspace(eap[0], eap[-1], len(eap))

# Zero-pad to exactly num_tsteps samples so the single-spike PSD is computed on
# the same frequency grid as the PSDs in the loop below; the element-wise
# product gauss_psd * eap_psd requires this. Any odd remainder goes into the
# trailing padding (previously an odd remainder left eap_pad one sample short).
num_pads = (num_tsteps - len(eap)) // 2
num_pads_after = num_tsteps - len(eap) - num_pads
eap_pad = np.r_[np.zeros(num_pads), eap, np.zeros(num_pads_after)]

freqs_eaps, eap_psd = elephant.spectral.welch_psd(eap_pad, fs=fs, num_seg=4,
                                                  window="hann")

fig = plt.figure(figsize=[6, 5])
fig.subplots_adjust(top=0.95, left=0.11, right=0.5, hspace=0.9, wspace=0.6,
                    bottom=0.1)
ax_eap = fig.add_subplot(len(spiketime_stds) + 1, 2, 1, xlabel="time (ms)",
                         ylabel="mV", title="single spike")

ax_eap_psd = fig.add_axes([0.8, 0.88, 0.17, 0.08], ylim=[1e-12, 5e-9],
                          xlim=[5, 1000], xlabel="frequency (Hz)",
                          ylabel="mV$^2$/Hz")

# The detrended, unpadded waveform is exactly the middle of eap_pad, so plot it
# directly against its own time axis (always the same length as t_eap).
l0, = ax_eap.plot(t_eap, eap, 'k')
ax_eap_psd.loglog(freqs_eaps[1:], eap_psd[1:], 'k')  # [1:] drops 0 Hz on log axes


def gauss_func(mu, sigma):
    """Return the normal probability density N(mu, sigma) evaluated on ``t``."""
    return (1 / np.sqrt(2 * np.pi * sigma**2)
            * np.exp(-(t - mu)**2 / (2 * sigma**2)))


midpoint = end_t / 2  # centre of the spike-time distribution (ms)
# Number of vertical gaps between per-jitter rows; guarded so a single-entry
# spiketime_stds does not divide by zero.
num_row_gaps = max(len(spiketime_stds) - 1, 1)

for idx, spiketime_std in enumerate(spiketime_stds):

    # Expected spike count per time bin: a sampled Gaussian (or a single-bin
    # delta for zero jitter) summing to ~1, scaled by the number of spikes.
    if np.abs(spiketime_std) < 1e-9:
        gauss = np.zeros(num_tsteps)
        gauss[np.argmin(np.abs(t - midpoint))] = 1
    else:
        gauss = gauss_func(midpoint, spiketime_std) * dt
    gauss *= num_spikes

    # One random realisation: bin each spike time to its nearest sample.
    spiketimes = np.random.normal(midpoint, spiketime_std, num_spikes)
    firingrate = np.zeros(num_tsteps)
    for spiketime in spiketimes:
        firingrate[np.argmin(np.abs(t - spiketime))] += 1

    # Summed extracellular signal: each binned spike adds one copy of eap.
    sig = np.convolve(firingrate, eap, mode="same")

    row_bottom = 0.70 - 0.59 * idx / num_row_gaps
    ax_sig = fig.add_axes([0.11, row_bottom, 0.58, 0.08],
                          xlabel="time (ms)", ylabel="mV",
                          xlim=[midpoint - 50, midpoint + 50],
                          title="spiketime SD: {:1.1f} ms".format(spiketime_std))
    ax_psd = fig.add_axes([0.8, row_bottom, 0.17, 0.08],
                          xlabel="frequency (Hz)", ylabel="mV$^2$/Hz",
                          ylim=[1e-7, 5e-4],
                          xlim=[5, 1000])

    freqs, sig_psd = elephant.spectral.welch_psd(sig, fs=fs, num_seg=4,
                                                 window="hann")
    freqs, gauss_psd = elephant.spectral.welch_psd(gauss, fs=fs, num_seg=4,
                                                   window="hann")

    ax_sig.plot(t, sig, 'k')
    # Rate profile rescaled to the signal's peak amplitude, as a visual overlay.
    ax_sig.plot(t, gauss / np.max(gauss) * np.max(np.abs(sig)), c='gray',
                ls=':', lw=1.5)

    # sig is the binned spike train convolved with eap, so its PSD is compared
    # with the product PSD(gaussian) * PSD(single spike). No absolute
    # prefactor is derived: the product is scaled to match sig_psd at 10 Hz,
    # so only the spectral shapes are compared.
    idx_10Hz = np.argmin(np.abs(freqs - 10))
    th = gauss_psd * eap_psd
    th *= sig_psd[idx_10Hz] / th[idx_10Hz]

    l1, = ax_psd.loglog(freqs[1:], sig_psd[1:], 'k')
    l2, = ax_psd.loglog(freqs[1:], th[1:], 'gray', ls='--')
    l3, = ax_psd.loglog(freqs[1:], gauss_psd[1:], 'gray', ls=':')

fig.legend([l1, l3, l2],
           [r"$V_{\rm e}$", "gaussian", "PSD(gaussian)*PSD(single spike)"],
           ncol=4, frameon=False, loc=(0, 0))
simplify_axes(fig.axes)
mark_subplots(fig.axes)
fig.savefig("LFP_spike_filter_effect_gauss.pdf")
