# Extraction of LFP signal from extracellular potentials
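NB: This notebook takes a long time to run, because the second code cell simulates `num_cells` = 1000 neurons one after another. The simulation results are saved to `neural_data_dicts.npy`, so the analysis and plotting cells can be rerun without repeating the simulation.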

In [None]:
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
import elephant
import LFPy
from lfpykit import RecExtElectrode
from brainsignals import neural_simulations as ns
from brainsignals.plotting_convention import mark_subplots

# Fixed seed so that cell rotations, positions, weights and input times are
# reproducible between runs.
np.random.seed(1234)

# Recording electrode: a vertical column of equally spaced contacts along the
# z-axis (x = y = 0). The contact count is defined once and reused below.
num_elecs = 8
elec_params = {
    'sigma': 0.3,  # extracellular conductivity passed to RecExtElectrode
    'x': np.zeros(num_elecs),
    'y': np.zeros(num_elecs),
    'z': np.linspace(-200, 1200, num_elecs),
}
# Inter-contact distance; used later to scale the traces when plotting.
dz = np.abs(elec_params["z"][1] - elec_params["z"][0])

num_cells = 1000
spike_fraction = 0.05  # fraction of cells that get the larger synaptic weight
num_syns = 25          # synapses per input wave, per cell
tstop = 170            # simulation end time (ms)
dt = 2**-4             # simulation time step (ms)

# All cells get weight 0.001; a random ~spike_fraction of them get a doubled
# weight (presumably strong enough to make those cells spike, hence the
# name -- confirm against the simulation output).
weights = np.ones(num_cells) * 0.001
weights[np.random.random(size=num_cells) < spike_fraction] = 0.002

# Exp2Syn synapse template with rise/decay time constants tau1/tau2 and
# reversal potential e. 'idx' is a placeholder that is overwritten for each
# synapse when the synapses are created.
syn_params = {'e': 0.,
              'record_current': True,
              'syntype': 'Exp2Syn',
              'tau1': 1, 'tau2': 3,
              'idx': 0}

# Mean input times (ms) of the three synaptic waves, and their jitter (std).
basal_wave = 20
apical_wave = 70
uniform_wave = 120
wave_std = 5


In [None]:

# Simulate each cell separately and keep only what is needed later: segment
# coordinates, time vector, extracellular potential at the electrode contacts,
# and synaptic input times/depths. Results are cached to disk so this (slow)
# loop does not have to be rerun.
# NOTE: the order of the np.random calls below defines the seeded random
# stream; reordering them changes every cell's rotation, position and inputs.
data_dicts = []


for cell_id in range(num_cells):
    cell = ns.return_hay_cell(tstop=tstop, dt=dt, make_passive=False)
    # Random rotation about the z-axis, angle drawn uniformly in [0, 2*pi).
    cell.set_rotation(z=2 * np.pi * np.random.uniform())

    # Pick num_syns synapse locations in each of three depth ranges: basal
    # (z in [-500, 300]), apical ([500, 1500]) and whole cell ([-500, 1500]).
    # Presumably sampled in proportion to segment area (per the method name)
    # -- confirm in the LFPy docs. This runs before set_pos, so the z-ranges
    # refer to the cell's unshifted coordinates.
    syn_idxs_basal = cell.get_rand_idx_area_norm(nidx=num_syns, z_min=-500, z_max=300)
    syn_idxs_apical = cell.get_rand_idx_area_norm(nidx=num_syns, z_min=500, z_max=1500)
    syn_idxs_uniform = cell.get_rand_idx_area_norm(nidx=num_syns, z_min=-500, z_max=1500)

    # Jitter the cell position: lateral (x, y) std 100 um, vertical (z) std 20 um.
    cell.set_pos(x=np.random.normal(0, 100),
                 y=np.random.normal(0, 100),
                 z=np.random.normal(0, 20),)

    # One input time per synapse, Gaussian around each wave's mean time.
    wave_basal = np.random.normal(basal_wave, wave_std, size=num_syns)
    wave_apical = np.random.normal(apical_wave, wave_std, size=num_syns)
    wave_uniform = np.random.normal(uniform_wave, wave_std, size=num_syns)

    # Flatten into one index array and one time array, ordered basal,
    # apical, uniform, so that syn_idxs[i] receives input at syn_times[i].
    syn_idxs = np.array([syn_idxs_basal, syn_idxs_apical, syn_idxs_uniform]).flatten()
    syn_times = np.array([wave_basal, wave_apical, wave_uniform]).flatten()

    # The shared syn_params dict is mutated for each synapse. Each synapse
    # receives a single input event, and all synapses on this cell use the
    # cell's weight.
    for idx, sidx in enumerate(syn_idxs):
        syn_params["idx"] = sidx
        syn_params["weight"] = weights[cell_id]

        synapse = LFPy.Synapse(cell, **syn_params)
        synapse.set_spike_times(np.array([syn_times[idx]]))

    cell.simulate(rec_vmem=True, rec_imem=True)

    # Map transmembrane currents to contact potentials with the linear
    # transformation matrix M. The factor 1000 gives uV, assuming
    # M @ imem is in mV (TODO confirm units in the LFPykit docs).
    electrode = RecExtElectrode(cell, **elec_params)
    M = electrode.get_transformation_matrix()
    V_ex = M @ cell.imem * 1000 # uV

    # .copy() detaches the stored arrays from the cell object deleted below.
    # syn_zs averages cell.z[syn_idxs] over axis 1 -- presumably the mean of
    # each segment's start/end depth (verify cell.z layout in LFPy).
    data_dicts.append({"cell_x": cell.x.copy(),
                 "cell_z": cell.z.copy(),
                 "cell_y": cell.y.copy(),
                 "tvec": cell.tvec.copy(),
                 "V_ex": V_ex.copy(),
                 "syn_times": syn_times,
                 "syn_zs": cell.z[syn_idxs].mean(axis=1),
                 })

    # Drop references before building the next cell. NOTE(review): `synapse`
    # only names the last synapse created (and is undefined if num_syns == 0).
    del cell
    del synapse
    del electrode

# A list of dicts is stored as a pickled object array; it must be reloaded
# with np.load(..., allow_pickle=True).
np.save("neural_data_dicts.npy", data_dicts)


In [None]:
# Load the cached per-cell results and sum them into the population signal.
# allow_pickle is required because the file holds a list of dicts; only load
# files this notebook produced itself.
data_dicts = np.load("neural_data_dicts.npy", allow_pickle=True)

tvec = data_dicts[0]["tvec"]
num_tsteps = len(tvec)
# Take the number of electrode contacts from the stored data instead of
# hard-coding it, so this cell stays correct if the electrode layout changes.
num_elec_contacts = data_dicts[0]["V_ex"].shape[0]
V_e = np.zeros((num_elec_contacts, num_tsteps))
syn_times = []
syn_zs = []
for dd in data_dicts:
    V_e += dd["V_ex"]  # population signal = sum of single-cell signals
    syn_times.extend(dd["syn_times"])
    syn_zs.extend(dd["syn_zs"])

# 4th-order Butterworth low-pass at 300 Hz, applied with filtfilt (forward
# and backward, i.e. without phase shift) along the time axis. dt is in ms,
# so 1 / dt * 1000 is the sampling frequency in Hz.
filt_dict_lf = {'highpass_freq': None,
                'lowpass_freq': 300,
                'order': 4,
                'filter_function': 'filtfilt',
                'fs': 1 / dt * 1000,
                'axis': -1}

V_e_lf = elephant.signal_processing.butter(V_e, **filt_dict_lf)

In [None]:
# Four-panel figure: morphologies of the first cells with the electrode
# contacts, synaptic input times vs. depth, the summed extracellular
# potential, and its low-pass filtered (LFP) version. Saved to a PDF.
num_plot_cells = min(300, num_cells)
fig = plt.figure(figsize=[6, 3.5])
fig.subplots_adjust(wspace=0, top=1, bottom=0, left=0, right=0.9)

# All panels share the same depth range so traces line up with the cells.
ylim = [-260, 1250]
ax_neur = fig.add_axes([0.0, 0.02, 0.33, 0.96], xticks=[], yticks=[],
                       frameon=False, aspect=1, ylim=ylim, rasterized=True)
ax_syn = fig.add_axes([0.34, 0.02, 0.19, 0.96], xticks=[], yticks=[],
                      frameon=False, ylim=ylim, rasterized=True)
ax_ecp = fig.add_axes([0.55, 0.02, 0.19, 0.96], xticks=[], yticks=[],
                      frameon=False, ylim=ylim)
ax_lfp = fig.add_axes([0.75, 0.02, 0.19, 0.96], xticks=[], yticks=[],
                      frameon=False, ylim=ylim)

# Only every 20th synaptic event, to keep the scatter (and the PDF) light.
ax_syn.scatter(syn_times[::20], syn_zs[::20], marker=".", s=1, color='k')

# Each trace is baseline-subtracted (first sample), scaled so that max_V_e uV
# spans one inter-contact distance dz, and offset to its contact depth.
max_V_e = 300
norm = 1 / max_V_e * dz
for v_raw, v_filt, z_elec in zip(V_e, V_e_lf, elec_params["z"]):
    v_ecp = (v_raw - v_raw[0]) * norm + z_elec
    v_lfp = (v_filt - v_filt[0]) * norm + z_elec
    ax_ecp.plot(tvec, v_ecp, c='k', lw=1)
    ax_lfp.plot(tvec, v_lfp, c='k', lw=1)

# 200 um length scale bar next to the morphologies.
ax_neur.plot([-350, -350], [400, 600], lw=1, c='k')
ax_neur.text(-355, 500, "200\nµm", va="center", ha="right")

# Voltage scale bar: one dz of height corresponds to max_V_e uV.
ax_lfp.plot([tvec[-1] + 10, tvec[-1] + 10], [1000, 1000 - dz], lw=1,
            c='k')
ax_lfp.text(tvec[-1] - 3, 1000 - dz/2, "{:1.0f}\nµV".format(max_V_e),
            va="center", ha="right")

# 50 ms time scale bar at the end of the LFP traces.
ax_lfp.plot([tvec[-1] - 50, tvec[-1]],
            [1200 - dz + 50, 1200 - dz + 50], lw=1, c='k')
ax_lfp.text(tvec[-1] - 25, 1200 - dz + 60, "50\nms",
            va="bottom", ha="center")

# Time scale bar in the synapse panel spanning the basal-to-apical wave
# interval. Its label is derived from the bar length and centred on the bar
# (previously the text sat at x = apical_wave - basal_wave, off-centre, with
# a hard-coded "50").
syn_bar_len = apical_wave - basal_wave
ax_syn.plot([basal_wave, apical_wave],
            [350, 350], lw=1, c='k')
ax_syn.text((basal_wave + apical_wave) / 2, 360,
            "{:1.0f}\nms".format(syn_bar_len),
            va="bottom", ha="center")

# Dashed markers at the mean time of each input wave, in all three
# time-axis panels (red: basal, blue: apical, orange: uniform).
for wave_time, wave_clr in ((basal_wave, 'r'), (apical_wave, 'b'),
                            (uniform_wave, 'orange')):
    for ax in (ax_lfp, ax_ecp, ax_syn):
        ax.axvline(wave_time, ls="--", c=wave_clr, lw=0.5)


def clrfunc_cells(c_idx):
    """Return a grey shade for cell c_idx (darker for higher indices)."""
    return plt.cm.Greys(0.4 + c_idx / num_plot_cells * 0.6)


for cell_idx, dd in enumerate(data_dicts[:num_plot_cells]):
    # Cells with larger mean y are drawn on top (simple depth ordering).
    ax_neur.plot(dd["cell_x"].T, dd["cell_z"].T,
                 c=clrfunc_cells(cell_idx), lw=1.,
                 zorder=dd["cell_y"].mean())

# Electrode contacts drawn above all morphologies.
ax_neur.plot(elec_params["x"], elec_params["z"], 'o', c="cyan", ms=5,
             zorder=1000)
mark_subplots(fig.axes, ypos=0.99, xpos=0.05)

plt.savefig("fig_LFP_waves_spikes_300Hz_cmap.pdf")