# Toy illustration of the effect of correlation

In [None]:
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import mlab as ml
import scipy.fftpack as ff
from brainsignals.plotting_convention import mark_subplots, simplify_axes
from elephant.spike_train_generation import inhomogeneous_poisson_process
from quantities import Hz, ms
import neo

# Fix the global NumPy RNG so the np.random draws made further down
# (spike times, jitters, subsampling) are reproducible from run to run.
np.random.seed(seed=12345)

def return_freq_and_psd_welch(sig, welch_dict):
    """Estimate the power spectral density of each signal with Welch's method.

    Parameters
    ----------
    sig : array_like
        One signal (1-D) or a stack of signals, one per row (2-D).
    welch_dict : dict
        Keyword arguments forwarded unchanged to ``matplotlib.mlab.psd``.

    Returns
    -------
    freqs : ndarray or None
        Frequency axis reported by the last ``mlab.psd`` call; ``None`` when
        `sig` contains no rows.
    psd : ndarray
        The stacked PSD estimates, one row per input signal.

    Raises
    ------
    RuntimeError
        If `sig` is neither 1-D nor 2-D.
    """
    data = np.array(sig)
    if data.ndim not in (1, 2):
        raise RuntimeError("Not compatible with given array shape!")
    if data.ndim == 1:
        # Treat a single trace as a one-row stack.
        data = data[np.newaxis, :]

    # mlab.psd returns (Pxx, freqs) for each row.
    spectra = [ml.psd(row, **welch_dict) for row in data]
    freqs = spectra[-1][1] if spectra else None
    return freqs, np.array([pxx for pxx, _ in spectra])

def return_freq_and_psd(tvec, sig):
    """Return the frequency axis and periodogram power of the input signal(s).

    Parameters
    ----------
    tvec : array_like or scalar
        Either the signal's time axis in ms (list, tuple or array with at least
        two entries), or a scalar sample spacing.
        - For a time axis, the spacing is taken from its first two entries and
          converted to seconds.
        - A scalar (including a 0-d array) is used as the spacing as-is.
    sig : array_like
        One signal (1-D) or a stack of signals, one per row (2-D).

    Returns
    -------
    freqs : ndarray
        The non-negative FFT frequencies. They are in Hz when `tvec` is a time
        axis in ms.
    power : ndarray
        ``|FFT|**2`` at those frequencies, one row per signal, divided by the
        number of non-negative frequency bins.

    Raises
    ------
    RuntimeError
        If `sig` is neither 1-D nor 2-D.
    """
    sig = np.array(sig)
    if sig.ndim == 1:
        # Treat a single trace as a one-row stack.
        sig = sig[np.newaxis, :]
    elif sig.ndim != 2:
        raise RuntimeError("Not compatible with given array shape!")

    # Dispatch on dimensionality rather than exact type, so that tuples,
    # ndarray subclasses and 0-d arrays are classified correctly.
    if np.ndim(tvec) == 0:
        timestep = tvec
    else:
        timestep = (tvec[1] - tvec[0]) / 1000.  # ms -> s

    sample_freq = ff.fftfreq(sig.shape[1], d=timestep)
    non_negative = sample_freq >= 0
    freqs = sample_freq[non_negative]

    Y = ff.fft(sig, axis=1)[:, non_negative]

    # NOTE(review): normalised by the number of non-negative bins, not by the
    # signal length; kept unchanged because callers may rely on this scaling.
    power = np.abs(Y)**2 / Y.shape[1]
    return freqs, power

# Simulation parameters; times are in ms.
num_tsteps = 2**19        # number of samples in each simulated signal
dt = 2**-5                # time step (ms)
syn_tau = 2               # decay time constant of the synaptic kernel (ms)
num_tsteps_syn = 2**10    # number of samples in the synaptic kernel
num_synapses = 10
num_spikes = 50           # spikes per synapse
jitter = 2.               # standard deviation of the spike-time jitter (ms)

# Zero-padding on each side of the single kernel. It makes the padded kernel
# (`syn_pad`) exactly `num_tsteps` samples long, i.e. the same duration as the
# summed signals. This is exact when the difference is even, as with the
# powers of two here. A longer record would dilute its Welch PSD by the extra
# length and bias the scaled reference curves.
num_pads = (num_tsteps - num_tsteps_syn) // 2

# Welch settings, passed to matplotlib.mlab.psd: each signal is cut into
# `divide_into_welch` Hann-windowed, mean-detrended segments with 50 % overlap,
# and the PSD is scaled per Hz.
divide_into_welch = 16
welch_dict = dict(
    Fs=1000 / dt,                                        # sampling rate (Hz); dt is in ms
    NFFT=int(num_tsteps / divide_into_welch),            # samples per segment
    noverlap=int(num_tsteps / divide_into_welch / 2.),   # half a segment of overlap
    window=ml.window_hanning,
    detrend=ml.detrend_mean,
    scale_by_freq=True,
)

# Time axis of the simulated signals (ms).
tvec = dt * np.arange(num_tsteps)

# Synaptic kernel:
# - The first half is zero. With np.convolve(..., mode="same") this makes the
#   response start right after the spike rather than before it.
# - The second half is a decaying exponential.
syn = np.zeros(num_tsteps_syn)
syn[num_tsteps_syn // 2:] = np.exp(-tvec[:num_tsteps_syn // 2] / syn_tau)
# Subtract a linear ramp ending at the exponential's last value, so that the
# kernel ends exactly at zero.
syn[num_tsteps_syn // 2:] -= np.linspace(0, syn[-1], num_tsteps_syn - num_tsteps_syn // 2)

# Single kernel embedded in zeros, used as the reference for the PSD plots.
syn_pad = np.concatenate((np.zeros(num_pads), syn, np.zeros(num_pads)))

In [None]:
def _spike_counts(spiketimes, time_axis):
    """Bin spike times onto the nearest sample of a sorted time axis.

    This gives the same result as running
    ``counts[np.argmin(np.abs(time_axis - t))] += 1`` for every spike:
    - ties go to the earlier sample;
    - times outside the axis land on the first/last sample.
    The search costs O(log n) per spike instead of O(n).

    Parameters
    ----------
    spiketimes : array_like
        Spike times, in the same units as `time_axis`.
    time_axis : ndarray
        Strictly increasing sample times with at least two entries.

    Returns
    -------
    ndarray
        A float array of length ``len(time_axis)`` with the spike count per
        sample.
    """
    spiketimes = np.asarray(spiketimes, dtype=float).ravel()
    counts = np.zeros(len(time_axis))
    if spiketimes.size == 0:
        return counts
    # Index of the first sample >= t, clipped so that both neighbours exist.
    right = np.clip(np.searchsorted(time_axis, spiketimes), 1, len(time_axis) - 1)
    left = right - 1
    take_left = np.abs(spiketimes - time_axis[left]) <= np.abs(time_axis[right] - spiketimes)
    # np.add.at accumulates repeated indices (several spikes in one sample).
    np.add.at(counts, np.where(take_left, left, right), 1)
    return counts


sig_corr = np.zeros(num_tsteps)
sig_corr_jit = np.zeros(num_tsteps)
sig_ucorr = np.zeros(num_tsteps)
sig_sinrate = np.zeros(num_tsteps)

# Modulation frequency (Hz) of the sinusoidal rate. It is kept separate from
# `freq`, which is reused below for the PSD frequency axis.
mod_freq = 30

# Total input rate (Hz) that yields num_synapses * num_spikes spikes over the
# simulated duration (tvec is in ms).
avrg_rate = num_synapses * num_spikes * 1000 / tvec[-1]

# The +1.01 offset keeps the modulated rate strictly positive.
rate_signal = np.sin(2 * np.pi * mod_freq * tvec / 1000) + 1.01
syn_rate = neo.AnalogSignal(np.array(rate_signal) * avrg_rate * Hz,
                            sampling_rate=(1/dt * 1000)*Hz, t_stop=tvec[-1] * ms)

# Draw with a 20 % higher rate, then subsample exactly num_synapses * num_spikes
# spikes, so the spike count matches the other scenarios.
# NOTE(review): np.random.choice(replace=False) raises if the process returns
# fewer spikes than requested. The * 1000 presumably converts s -> ms; confirm
# the units returned by elephant.
spiketimes_rate = inhomogeneous_poisson_process(syn_rate * 1.2)
spiketimes_rate = np.random.choice(spiketimes_rate, size=num_synapses * num_spikes, replace=False) * 1000

# Keep spikes at least one kernel length away from either end of the record.
spiketimes_ucorr = np.random.uniform(tvec[num_tsteps_syn], tvec[-num_tsteps_syn],
                                     size=(num_synapses, num_spikes))
spiketimes_corr = np.random.uniform(tvec[num_tsteps_syn], tvec[-num_tsteps_syn],
                                    size=(num_spikes))
jitters = np.random.normal(0, jitter, size=(num_synapses, num_spikes))

firingrate_sin = _spike_counts(spiketimes_rate, tvec)
sig_sinrate += np.convolve(firingrate_sin, syn, mode="same")

# Every synapse receives the same correlated spike train, so its filtered
# contribution is identical for all synapses: compute it once and scale it.
firingrate_corr = _spike_counts(spiketimes_corr, tvec)
sig_corr += num_synapses * np.convolve(firingrate_corr, syn, mode="same")

for syn_idx in range(num_synapses):
    firingrate_ucorr = _spike_counts(spiketimes_ucorr[syn_idx], tvec)
    sig_ucorr += np.convolve(firingrate_ucorr, syn, mode="same")

    # Shift each shared spike by -jitter. The jitter is zero-mean normal, so the
    # sign of the shift is immaterial.
    firingrate_corr_jit = _spike_counts(spiketimes_corr - jitters[syn_idx], tvec)
    sig_corr_jit += np.convolve(firingrate_corr_jit, syn, mode="same")

syn_pad -= np.mean(syn_pad)
sig_corr -= np.mean(sig_corr)
sig_corr_jit -= np.mean(sig_corr_jit)
sig_ucorr -= np.mean(sig_ucorr)
sig_sinrate -= np.mean(sig_sinrate)

freq_syn, syn_psd = return_freq_and_psd_welch(syn_pad, welch_dict)
freq, corr_psd = return_freq_and_psd_welch(sig_corr, welch_dict)
freq, corr_jit_psd = return_freq_and_psd_welch(sig_corr_jit, welch_dict)
freq, ucorr_psd = return_freq_and_psd_welch(sig_ucorr, welch_dict)
freq, sinrate_psd = return_freq_and_psd_welch(sig_sinrate, welch_dict)


In [None]:
# Three panels:
# - left: the single synaptic kernel;
# - middle: the summed signals over time;
# - right: their Welch PSDs plus scaled copies of the single-kernel PSD.
fig = plt.figure(figsize=[6, 2.6])
fig.subplots_adjust(left=0.08, right=0.98, bottom=0.38, top=0.9,
                    wspace=0.4, hspace=0.5)

ax1 = fig.add_subplot(1, 3, 1, xlabel="time (ms)", ylabel="µV")
ax2 = fig.add_subplot(1, 3, 2, xlabel="time (s)", ylabel="µV")
ax4 = fig.add_subplot(1, 3, 3, xlabel="frequency (Hz)", ylabel="µV²/Hz",
                      xlim=[2, 1000], ylim=[1e-8, 1e-2])

# Left panel: the kernel on its own time axis (ms).
l0, = ax1.plot(tvec[:num_tsteps_syn], syn, color='k')

# Middle panel: summed signals, with time converted from ms to s.
l1, = ax2.plot(tvec / 1000, sig_corr, color='r')
l3, = ax2.plot(tvec / 1000, sig_corr_jit, color='orange')
l4, = ax2.plot(tvec / 1000, sig_sinrate, color='g')
l2, = ax2.plot(tvec / 1000, sig_ucorr, color="b")

max_idx = np.argmax(sig_corr_jit)

# Right panel: PSDs, skipping the 0 Hz bin that cannot be shown on a log axis.
ax4.loglog(freq[1:], corr_psd[0][1:], color='r')
ax4.loglog(freq[1:], corr_jit_psd[0][1:], color='orange')
ax4.loglog(freq[1:], sinrate_psd[0][1:], color='g')
ax4.loglog(freq[1:], ucorr_psd[0][1:], color="b")

# Single-kernel PSD and its N_p*N_s / N_p*N_s**2 scaled references (see legend).
ax4.loglog(freq_syn[1:], syn_psd[0][1:], color='k')
l5, = ax4.loglog(freq_syn[1:], syn_psd[0][1:] * num_spikes * num_synapses,
                 color='k', linestyle='--')
l6, = ax4.loglog(freq_syn[1:], syn_psd[0][1:] * num_spikes * num_synapses **2,
                 color='k', linestyle=':')

simplify_axes([ax1, ax2, ax4])
mark_subplots([ax1, ax2, ax4], ypos=1.08)

ax4.set_xticks([10, 100, 1000])
fig.legend(handles=[l0, l1, l2, l3, l4, l5, l6],
           labels=[r"LFP$_{\rm single}$",
                   "correlated",
                   "uncorrelated",
                   "correlated + jitter",
                   "sinus-modulated",
                   r"$N_{\rm{p}}\times N_{\rm{s}} \times $PSD(LFP$_{\rm single}$)",
                   r"$N_{\rm{p}}\times N_{\rm{s}}^2 \times$PSD(LFP$_{\rm single}$)",
                   ],
           loc="lower center", ncol=3, frameon=False)

plt.savefig("correlation_jitter_sinmod.pdf")
