# LFP and EEG signals from volleys of synaptic input to a neural population
NB! This notebook may take a long time to execute, since it simulates 1000 neurons one after another.

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import elephant
import LFPy
from lfpykit import CurrentDipoleMoment, RecExtElectrode
from lfpykit.eegmegcalc import NYHeadModel
from brainsignals import neural_simulations as ns
from brainsignals.plotting_convention import simplify_axes, mark_subplots, cmap_v_e

# Seed NumPy's global RNG so the random cell placements, rotations and
# synaptic input used below are reproducible
np.random.seed(1234)

# Laminar electrode: a vertical column of contacts along the z-axis (µm)
n_contacts = 8
elec_params = {  # parameters for RecExtElectrode class
    'sigma': 0.3,  # extracellular conductivity (S/m)
    'x': np.zeros(n_contacts),  # coordinates of electrode contacts (µm)
    'y': np.zeros(n_contacts),
    'z': np.linspace(-200, 1200, n_contacts),
}
# Inter-contact distance; used to scale the ECP traces when plotting
dz = np.abs(elec_params["z"][1] - elec_params["z"][0])


In [None]:
# --- Population and simulation settings ---
num_cells = 1000        # number of neurons in the population
spike_fraction = 0.05   # fraction of cells that receive the doubled weight
num_syns = 25           # synapses per cell for each input wave
tstop = 170             # simulation end time (ms)
dt = 2**-4              # simulation time step (ms)
data_dicts = []         # per-cell results, filled by the simulation loop

# One synaptic weight per cell: a baseline value, doubled for a random
# subset of roughly `spike_fraction` of the cells
weights = np.full(num_cells, 0.001)
weights[np.random.random(size=num_cells) < spike_fraction] = 0.002

# Shared keyword arguments for LFPy.Synapse; "idx" (and "weight") are
# overwritten for every synapse in the simulation loop
syn_params = dict(
    e=0.,
    record_current=True,
    syntype='Exp2Syn',
    tau1=1,
    tau2=3,
    idx=0,
)

# Mean activation times (ms) of the three input volleys, and their jitter (SD)
basal_wave = 20
apical_wave = 70
uniform_wave = 120
wave_std = 5


In [None]:

for cell_id in range(num_cells):
    # NOTE: the order of the np.random calls below is deliberate; changing it
    # changes the realised population for a given seed.
    cell = ns.return_hay_cell(tstop=tstop, dt=dt, make_passive=False)
    # Random rotation around the z-axis
    cell.set_rotation(z=2 * np.pi * np.random.uniform())

    # Area-weighted random synapse locations for the three input waves,
    # restricted by z-range: basal region, apical region, whole cell
    syn_idxs_basal = cell.get_rand_idx_area_norm(nidx=num_syns, z_min=-500, z_max=300)
    syn_idxs_apical = cell.get_rand_idx_area_norm(nidx=num_syns, z_min=500, z_max=1500)
    syn_idxs_uniform = cell.get_rand_idx_area_norm(nidx=num_syns, z_min=-500, z_max=1500)

    # Random cell position: wide lateral spread (SD 100), narrow in depth (SD 20)
    cell.set_pos(x=np.random.normal(0, 100),
                 y=np.random.normal(0, 100),
                 z=np.random.normal(0, 20))

    # Jittered activation times (ms) around the mean time of each wave
    wave_basal = np.random.normal(basal_wave, wave_std, size=num_syns)
    wave_apical = np.random.normal(apical_wave, wave_std, size=num_syns)
    wave_uniform = np.random.normal(uniform_wave, wave_std, size=num_syns)

    syn_idxs = np.array([syn_idxs_basal, syn_idxs_apical, syn_idxs_uniform]).flatten()
    syn_times = np.array([wave_basal, wave_apical, wave_uniform]).flatten()

    # One synapse per (segment index, activation time) pair; all synapses of
    # this cell use the cell's weight. syn_params is updated in place.
    # `synapse` is pre-bound so the `del` below is valid even with no synapses.
    synapse = None
    for sidx, t_syn in zip(syn_idxs, syn_times):
        syn_params["idx"] = sidx
        syn_params["weight"] = weights[cell_id]

        synapse = LFPy.Synapse(cell, **syn_params)
        synapse.set_spike_times(np.array([t_syn]))

    cell.simulate(rec_vmem=True, rec_imem=True)

    # Current dipole moment of this cell (summed into a (3, n_tsteps) array later)
    cdm = CurrentDipoleMoment(cell)
    M = cdm.get_transformation_matrix()
    cdm = M @ cell.imem

    # Extracellular potential at the laminar electrode contacts
    electrode = RecExtElectrode(cell, **elec_params)
    M = electrode.get_transformation_matrix()
    V_ex = M @ cell.imem * 1000  # uV

    # Store only what is needed downstream. Membrane potentials are recorded
    # (rec_vmem=True) but not copied out, which keeps memory use down.
    data_dicts.append({"cell_x": cell.x.copy(),
                       "cell_z": cell.z.copy(),
                       "cell_y": cell.y.copy(),
                       "tvec": cell.tvec.copy(),
                       "cdm": cdm.copy(),
                       "V_ex": V_ex.copy(),
                       "syn_times": syn_times,
                       # depth of each synapse: mean over the columns of
                       # cell.z for the synapse's segment
                       "syn_zs": cell.z[syn_idxs].mean(axis=1),
                       })

    # Release the cell and the last synapse before building the next cell
    del cell
    del synapse

np.save("neural_data_dicts.npy", data_dicts)


In [None]:
# Reload the per-cell results (saved as a pickled object array of dicts)
data_dicts = np.load("neural_data_dicts.npy", allow_pickle=True)

tvec = data_dicts[0]["tvec"]
num_tsteps = len(tvec)

# Population signals are the sums of the single-cell contributions
p = np.zeros((3, num_tsteps))    # current dipole moment
V_e = np.zeros((8, num_tsteps))  # extracellular potential (uV), one row per contact
for dd in data_dicts:
    p += dd["cdm"]
    V_e += dd["V_ex"]

# Flat lists of all synaptic activation times and depths, cell by cell
syn_times = [t for dd in data_dicts for t in dd["syn_times"]]
syn_zs = [z for dd in data_dicts for z in dd["syn_zs"]]

# 4th-order Butterworth low-pass at 300 Hz, applied forward and backward
filt_dict_lf = dict(highpass_freq=None,
                    lowpass_freq=300,
                    order=4,
                    filter_function='filtfilt',
                    fs=1 / dt * 1000,  # sampling frequency (Hz); dt is in ms
                    axis=-1)

# Low-pass filtered copy of V_e (the plots use the unfiltered V_e)
V_e_lf = elephant.signal_processing.butter(V_e, **filt_dict_lf)
    

In [None]:

dipole_loc = np.array([-9.17467219, 0.09118176, 75.61421356])  # x, y, z location in mm
nyhead = NYHeadModel()

cortex = nyhead.cortex  # Locations of every vertex in cortex
# Electrode table: rows 0-2 are positions; rows 3 onward are used as the
# electrode normals when the electrodes are drawn on the head
elecs = np.array(nyhead.elecs)

# Surface triangulations for 3D plotting; the "- 1" presumably converts
# 1-based (MATLAB) triangle indices to 0-based ones
head_tri = np.array(nyhead.head_data["head"]["tri"]).T - 1
head_vc = np.array(nyhead.head_data["head"]["vc"])
cortex_tri = np.array(nyhead.head_data["cortex75K"]["tri"]).T - 1
x_ctx, y_ctx, z_ctx = cortex
x_h, y_h, z_h = head_vc[:3]
num_elecs = elecs.shape[1]  # number of electrodes before selection

# Keep only electrodes above z = 0 (upper part of the head)
upper_idxs = np.flatnonzero(elecs[2, :] > 0)
elecs = elecs[:, upper_idxs]
nyhead.set_dipole_pos(dipole_loc)

# Orient the population dipole moment along the local cortical surface normal
p = nyhead.rotate_dipole_to_surface_normal(p)

# Cortex vertex closest to the dipole location
vertex_idx = np.argmin(np.sqrt(np.sum((dipole_loc[:, None] - cortex)**2, axis=0)))
# Subtract the temporal mean of each dipole component
p -= p.mean(axis=1, keepdims=True)

# EEG from the lead field for all electrodes (nV), then keep the upper ones
eeg = np.zeros((num_elecs, num_tsteps))
eeg[...] = nyhead.get_transformation_matrix() @ p * 1e6
eeg = eeg[upper_idxs]

# Electrode closest to that cortex vertex, and its time of largest |EEG|
top_elec_idx = np.argmin(np.sqrt(np.sum((cortex[:, vertex_idx, None] -
                                         elecs[:3, :])**2, axis=0)))
max_time_idx = np.argmax(np.abs(eeg[top_elec_idx]))

In [None]:
from mpl_toolkits.mplot3d import art3d
from matplotlib.patches import Circle
def rotation_matrix(d):
    """
    Calculate a rotation matrix from a rotation vector ``d``.

    The direction of ``d`` is the rotation axis and its length is the sine
    of the rotation angle. The cosine is taken as the non-negative root, so
    only angles in [0, pi/2] can be represented. With
    ``d = np.cross(n, (0, 0, 1))`` for a unit vector ``n`` whose z-component
    is non-negative, the returned matrix maps (0, 0, 1) onto ``n``.

    Variant of: http://mail.scipy.org/pipermail/numpy-discussion/2009-March/040806.html

    Parameters
    ----------
    d : array_like, shape (3,)
        Rotation vector. It is not modified.

    Returns
    -------
    numpy.ndarray, shape (3, 3)
        Rotation matrix; the identity if ``d`` is the zero vector.
    """
    d = np.asarray(d, dtype=np.float64)
    sin_angle = np.linalg.norm(d)

    if sin_angle == 0:
        return np.identity(3)

    # Normalised copy: never divide the caller's array in place
    d = d / sin_angle

    eye = np.eye(3)
    ddt = np.outer(d, d)
    skew = np.array([[0, d[2], -d[1]],
                     [-d[2], 0, d[0]],
                     [d[1], -d[0], 0]], dtype=np.float64)

    # Clamp so round-off (|d| marginally above 1) cannot yield sqrt(<0) = NaN
    cos_angle = np.sqrt(max(0.0, 1.0 - sin_angle**2))
    M = ddt + cos_angle * (eye - ddt) + sin_angle * skew
    return M

def pathpatch_2d_to_3d(pathpatch, z=0, normal='z'):
    """
    Transform a 2D Patch into a 3D patch oriented by a normal vector.

    The patch path (with its patch transform applied) lies in the XY plane.
    It is rotated about the origin so that the z-axis points along
    ``normal`` (for normals with a non-negative z-component), then shifted
    by ``z`` along the z-axis. The patch object itself is converted in place
    to ``art3d.PathPatch3D``.

    Parameters
    ----------
    pathpatch : matplotlib.patches.Patch
        2D patch to convert; modified in place.
    z : float, optional
        Offset added to the z-coordinate of every rotated vertex.
    normal : {'x', 'y', 'z'} or array_like of 3 floats, optional
        Normal of the patch plane. It need not have unit length and is not
        modified by this function.
    """
    if isinstance(normal, str):  # Translate axis names to unit vectors
        index = "xyz".index(normal)
        normal = np.roll((1.0, 0, 0), index)

    # Normalised copy, so a caller's array (e.g. a view into an electrode
    # table) is not divided in place
    normal = np.asarray(normal, dtype=np.float64)
    normal = normal / np.linalg.norm(normal)

    path = pathpatch.get_path()  # Get the path and the associated transform
    trans = pathpatch.get_patch_transform()
    path = trans.transform_path(path)  # Apply the transform

    pathpatch.__class__ = art3d.PathPatch3D  # Change the class
    # Let matplotlib initialise the 3D state it expects (path codes, vertex
    # list, face colour / version-specific attributes). The flat vertices it
    # stores are replaced by the rotated ones below.
    pathpatch.set_3d_properties(path, zs=0)

    d = np.cross(normal, (0, 0, 1))  # Rotation vector: axis * sin(angle)
    M = rotation_matrix(d)  # Rotation taking the z-axis onto `normal`

    verts = path.vertices  # 2D vertices in the XY plane
    flat_verts = np.column_stack([verts, np.zeros(len(verts))])
    # Row-wise M @ (x, y, 0), then shift along z
    pathpatch._segment3d = flat_verts @ M.T + (0, 0, z)

def pathpatch_translate(pathpatch, delta):
    """
    Shift the vertices of a 3D path patch by ``delta``.

    Parameters
    ----------
    pathpatch : object with a ``_segment3d`` attribute
        Patch whose ``_segment3d`` vertices are shifted. For a patch built by
        ``pathpatch_2d_to_3d`` this is an (N, 3) NumPy array, updated in place.
    delta : array_like, shape (3,)
        Translation (dx, dy, dz) added to every vertex.
    """
    vertices = pathpatch._segment3d
    vertices += delta
    pathpatch._segment3d = vertices


In [None]:
# The rest is just plotting
plt.close("all")
fig = plt.figure(figsize=[6, 3.5])
# NOTE(review): subplots_adjust presumably has no effect here, because every
# axes below is placed with an explicit add_axes rectangle.
fig.subplots_adjust(bottom=0.17, right=0.93, left=0.01,
                    top=0.9, wspace=.1, hspace=0.6)

# Shared settings: tick- and frameless axes spanning the cortical depth
depth_axes_kw = dict(xticks=[], yticks=[], frameon=False, ylim=[-260, 1250])
# Shared settings for the three top-view EEG scalp maps
eeg_map_kw = dict(xlim=[-110, 110], ylim=[-120, 110], aspect=1,
                  frameon=False, rasterized=True, xticks=[], yticks=[])

# 3D head with cortex and EEG electrodes
ax_head = fig.add_axes([.01, 0.01, 0.2, 0.3], projection='3d', frame_on=False,
                       xticks=[], yticks=[], zticks=[],
                       xlim=[-70, 70], facecolor="none",
                       ylim=[-70, 70], zlim=[-70, 70], rasterized=True,
                       computed_zorder=False)

# Cortical cross-section around the dipole location
ax_geom = fig.add_axes([0.35, 0.05, 0.15, 0.27], aspect=1, frameon=False,
                       xticks=[], yticks=[], rasterized=True)

# Cell morphologies, synaptic input raster and laminar ECP on a common depth axis
ax_neur = fig.add_axes([0.0, 0.35, 0.25, 0.62], aspect=1, rasterized=True,
                       **depth_axes_kw)
ax_syn = fig.add_axes([0.25, 0.35, 0.12, 0.62], rasterized=True, **depth_axes_kw)
ax_ecp = fig.add_axes([0.39, 0.35, 0.12, 0.62], **depth_axes_kw)

# Time courses of the dipole moment and of the EEG at the top electrode
ax_p = fig.add_axes([0.6, 0.6, 0.15, 0.35], xlabel="time (ms)", ylabel="$P_z$ (µAµm)")
ax_eeg1 = fig.add_axes([0.6, 0.1, 0.15, 0.35], xlabel="time (ms)", ylabel="EEG (nV)")

# Scalp maps stacked top to bottom: basal, apical and uniform wave
ax_eeg_b, ax_eeg_a, ax_eeg_u = [fig.add_axes([0.75, bottom, 0.17, 0.33], **eeg_map_kw)
                                for bottom in (0.66, 0.33, 0.0)]

ax_head.axis('off')
# Cortex surface, then the semi-transparent scalp surface on top of it
ax_head.plot_trisurf(x_ctx, y_ctx, z_ctx, triangles=cortex_tri,
                     color="pink", zorder=0, rasterized=True)
ax_head.plot_trisurf(x_h, y_h, z_h, triangles=head_tri,
                     color="#c87137", zorder=0, alpha=0.2)

# EEG electrodes as gray discs, each tilted along its electrode normal and
# moved to the electrode position
all_patches = []
for elec_column in elecs.T:
    elec_xyz = elec_column[:3]
    elec_normal = elec_column[3:]
    disc = Circle((0, 0), 5, facecolor='gray', zorder=elec_xyz[2])  # disc in the xy plane
    all_patches.append(disc)
    ax_head.add_patch(disc)
    pathpatch_2d_to_3d(disc, z=0, normal=elec_normal)
    pathpatch_translate(disc, elec_xyz)

# View the head from straight above
ax_head.view_init(elev=90., azim=-90)
fig.text(0.05, 0.3, "E", horizontalalignment='center', verticalalignment='center',
         fontweight='demibold', fontsize=10)
# Dipole moment z-component (1e-3 scales it to the µAµm of the axis label)
# and EEG at the electrode closest to the active cortex vertex
ax_p.plot(tvec, 1e-3 * p[2, :], c='k')
ax_eeg1.plot(tvec, eeg[top_elec_idx], c='k')

# Raster of synaptic activation time vs. depth, thinned to every 20th synapse
ax_syn.scatter(syn_times[::20], syn_zs[::20], marker=".", s=1, color='k')

# Laminar ECP traces: baseline-subtracted, scaled so that max_V_e (uV) spans
# one inter-contact distance, and offset to the depth of each contact
max_V_e = 300
norm = 1 / max_V_e * dz
for elec, v_row in enumerate(V_e):
    v_ecp = (v_row - v_row[0]) * norm + elec_params["z"][elec]
    ax_ecp.plot(tvec, v_ecp, c='k', lw=1)

# Time indices nearest to the mean arrival time of each input wave
t_idx_b, t_idx_a, t_idx_u = (np.argmin(np.abs(wave_t - tvec))
                             for wave_t in (basal_wave, apical_wave, uniform_wave))

# EEG time points for the scalp maps: global maximum (bmax), global minimum
# (amax) and the maximum from the uniform-wave time onwards (umax)
t_idx_bmax = np.argmax(eeg[top_elec_idx, :])
t_idx_amax = np.argmin(eeg[top_elec_idx, :])
t_idx_umax = t_idx_u + np.argmax(eeg[top_elec_idx, t_idx_u:])

# Dashed markers at the mean wave times: red = basal, blue = apical, orange = uniform
for wave_t, clr in ((basal_wave, 'r'), (apical_wave, 'b'), (uniform_wave, 'orange')):
    l_ecp = ax_ecp.axvline(wave_t, ls="--", c=clr, lw=0.5)
    l_syn = ax_syn.axvline(wave_t, ls="--", c=clr, lw=0.5)

# Same wave markers on the time-course panels, plus gray lines at the EEG
# time points shown in the scalp maps
for t_mark, clr in ((basal_wave, 'r'),
                    (tvec[t_idx_bmax], 'gray'),
                    (tvec[t_idx_amax], 'gray'),
                    (tvec[t_idx_umax], 'gray'),
                    (apical_wave, 'b'),
                    (uniform_wave, 'orange')):
    l_p = ax_p.axvline(t_mark, ls="--", c=clr, lw=0.5)
    l_eeg = ax_eeg1.axvline(t_mark, ls="--", c=clr, lw=0.5)


cax = fig.add_axes([0.92, 0.2, 0.01, 0.6])  # This axis is just the colorbar

# Colour range of the scalp maps (nV), fixed by hand; the second printed
# value is the largest |EEG| over electrodes at max_time_idx, for comparison
vmax = 10
print(vmax, np.max(np.abs(eeg[:, max_time_idx])))


def vmap(v):
    """Map an EEG value (nV) to a colour on the shared [-vmax, vmax] scale."""
    return cmap_v_e((v + vmax) / (2*vmax))


levels = np.linspace(-vmax, vmax, 60)

contourf_kwargs = dict(levels=levels,
                       cmap=cmap_v_e,
                       vmax=vmax,
                       vmin=-vmax,
                       extend="both")

# Top-view scalp maps of the EEG at the three selected time points: filled
# contours plus contour lines with the same levels and colours
for ax_map, t_idx in ((ax_eeg_b, t_idx_bmax),
                      (ax_eeg_a, t_idx_amax),
                      (ax_eeg_u, t_idx_umax)):
    img = ax_map.tricontourf(elecs[0], elecs[1], eeg[:, t_idx], **contourf_kwargs)
    ax_map.tricontour(elecs[0], elecs[1], eeg[:, t_idx], **contourf_kwargs)

cbar = plt.colorbar(img, cax=cax)
cbar.set_label("nV", labelpad=-5)
cbar.set_ticks([-vmax, -vmax/2, 0, vmax/2, vmax])


# Plot a cross-section of cortex (x-z plane) around the active region
threshold = 2  # threshold in mm for including points in plot
xz_plane_idxs = np.where(np.abs(cortex[1, :] -
                                dipole_loc[1]) < threshold)[0]

ax_geom.scatter(cortex[0, xz_plane_idxs],
                cortex[2, xz_plane_idxs], s=1, c='0.9')

# Scale bar. Cortex coordinates are in mm, so the bar is scale_bar_mm long;
# the label is derived from the same number (it previously read "20 mm"
# for a 10 mm bar).
scale_bar_mm = 10
ax_geom.plot([-30, -30 - scale_bar_mm], [-60, -60], c='k', lw=2)
ax_geom.text(-30 - scale_bar_mm / 2, -65, f"{scale_bar_mm} mm", ha='center', va="top")

# Vertical arrow ending at the cortex vertex closest to the dipole
ax_geom.arrow(cortex[0, vertex_idx], cortex[2, vertex_idx] - 4, 0, 4,
              color='k', head_width=2)

# Mark that vertex on the scalp maps and on the 3D head
for ax_map in (ax_eeg_b, ax_eeg_a, ax_eeg_u):
    ax_map.plot(cortex[0, vertex_idx], cortex[1, vertex_idx], 'o', c='k', ms=2)
ax_head.plot(cortex[0, vertex_idx], cortex[1, vertex_idx], cortex[2, vertex_idx],
             'o', c='k', ms=2)


# Morphologies of the first simulated cells: at most 300, and never more
# than were simulated (avoids an IndexError for small populations)
num_plot_cells = min(300, len(data_dicts))


def clrfunc_cells(c_idx):
    """Grey shade for the c_idx-th plotted cell, from lighter to darker."""
    return plt.cm.Greys(0.4 + c_idx / num_plot_cells * 0.6)


for cell_idx in range(num_plot_cells):
    dd = data_dicts[cell_idx]
    # Segments in the x-z projection; zorder follows the cell's mean y
    ax_neur.plot(dd["cell_x"].T, dd["cell_z"].T,
                 c=clrfunc_cells(cell_idx), lw=1.,
                 zorder=dd["cell_y"].mean())

# Laminar electrode contacts, drawn on top of the morphologies
ax_neur.plot(elec_params["x"], elec_params["z"], 'o', c="cyan", ms=4,
             zorder=1000)

mark_subplots([ax_neur, ax_syn, ax_ecp], "ABC", ypos=1.01, xpos=0.05)

simplify_axes([ax_p, ax_eeg1])
mark_subplots([ax_geom], "F", ypos=1.01, xpos=0.0)
mark_subplots([ax_eeg_b], "H", ypos=0.97, xpos=0.1)
mark_subplots([ax_p, ax_eeg1], "DG", ypos=1.05, xpos=-0.1)
# Save this figure explicitly rather than whichever figure is "current"
fig.savefig("EEG_waves_illustration.pdf")