# Illustration of the kernel method used to calculate EEG signals

In [None]:
%matplotlib inline
import os
import sys
import numpy as np
import matplotlib.pyplot as plt
import h5py
from copy import deepcopy
import json
import hashlib
from parameters import ParameterSpace, ParameterSet
import elephant
from lfpykernels import KernelApprox, GaussCylinderPotential
from lfpykit import CurrentDipoleMoment

from brainsignals.plotting_convention import mark_subplots, cmap_v_e
import brainsignals.neural_simulations as ns

# Import code used in Chapter 6
codebase_dir = os.path.join("..", "Ch-6")
sys.path.append(codebase_dir)
from example_network_parameters import (networkParameters, population_names,
                                        population_sizes)
import example_network_methods as methods
import example_network_parameters as params

ns.load_mechs_from_folder(os.path.join(codebase_dir, "mod"))

In [None]:
# Parameter spaces written by the Chapter 6 code base. PS0 supplies 'n_ext',
# PS1 identifies the ground-truth simulation and PS2 lists the parameter
# combinations for which kernels are predicted (see the cells below).
PS0 = ParameterSpace(os.path.join(codebase_dir, 'PS0.txt'))
PS1 = ParameterSpace(os.path.join(codebase_dir, 'PS1.txt'))
PS2 = ParameterSpace(os.path.join(codebase_dir, 'PS2.txt'))

# Start-up transient of the network simulation. It is used below both as the
# presynaptic activation time t_X and as a slice start when reading voltages.
# Presumably given in ms - TODO confirm against the Chapter 6 code.
TRANSIENT = 2000
# temporal resolution of the network simulation (same units as tau)
dt = networkParameters['dt']
tau = 100  # max time lag relative to spike for kernel predictions
tau_trunc = 25 # max time lag shown in the kernel plots

# flag; if True, use the per-compartment membrane potential (median over time,
# read from vmem.h5) instead of the median soma potential for kernel predictions
perseg_Vrest = False

In [None]:
# Locate the output folder of the "real" (ground truth) network simulation
# that the kernel predictions are compared with. The folder name is the MD5
# hex digest of the key-sorted JSON dump of the parameter set. Only the first
# parameter combination of PS1 is used.
for pset in PS1.iter_inner():
    weight_EE, weight_IE, weight_EI, weight_II, weight_scaling = (
        pset[name] for name in ('weight_EE', 'weight_IE', 'weight_EI',
                                'weight_II', 'weight_scaling'))
    # same parameters, but with the number of external synapses from PS0
    pset_0 = ParameterSet({
        'weight_EE': weight_EE,
        'weight_IE': weight_IE,
        'weight_EI': weight_EI,
        'weight_II': weight_II,
        'weight_scaling': weight_scaling,
        'n_ext': PS0['n_ext'].value,
    })
    js_0 = json.dumps(pset_0, sort_keys=True).encode()
    md5_0 = hashlib.md5(js_0).hexdigest()
    OUTPUTPATH_REAL = os.path.join(codebase_dir, 'output', md5_0)
    break  # the first combination is sufficient
print(f'comparing with ground truth dataset: {OUTPUTPATH_REAL}')

In [None]:
# Firing-rate time series of the "real" network for every population,
# expressed as the number of spikes per time bin of width dt.
tstop = networkParameters['tstop']
# bin edges are placed half a time step before/after each sample 0, dt, 2*dt, ...
bins = (np.arange(0, tstop / dt + 2)
        * dt - dt / 2)
nu_X = {}
with h5py.File(os.path.join(OUTPUTPATH_REAL, 'spikes.h5'), 'r') as f:
    for X in params.population_names:
        # pool the spike times of all cells in the population
        pooled_spike_times = np.concatenate(f[X]['times'])
        counts = np.histogram(pooled_spike_times, bins=bins)[0]
        nu_X[X] = counts.astype(float)

In [None]:
# Compute spike-LFP and spike-dipole moment kernel approximations using the KernelApprox class
#
# For every parameter combination in PS2 and every (presynaptic X,
# postsynaptic Y) population pair, one kernel is predicted. Results are stored
# as H_YX_pred[md5]['Y:X'], where md5 identifies the parameter combination;
# later cells index each entry by probe name ('GaussCylinderPotential',
# 'CurrentDipoleMoment').

# kernel container
H_YX_pred = dict()
for k, pset in enumerate(PS2.iter_inner()):
    # sorted json dictionary -> MD5 digest used as key for this parameter set
    js = json.dumps(pset, sort_keys=True).encode()
    md5 = hashlib.md5(js).hexdigest()
    
    # parameters
    weight_EE = pset['weight_EE']
    weight_IE = pset['weight_IE']
    weight_EI = pset['weight_EI']
    weight_II = pset['weight_II']
    weight_scaling = pset['weight_scaling']
    biophys = pset['biophys']
    n_ext = pset['n_ext']
    g_eff = pset['g_eff']

    t_X = TRANSIENT  # presynaptic activation time

    # define biophysical membrane properties; each variant is a pair of
    # functions from the Chapter 6 methods module that is handed to the cell
    # as `custom_fun`
    if biophys == 'pas':
        custom_fun = [methods.set_pas_hay2011, methods.make_cell_uniform]
    elif biophys == 'frozen':
        custom_fun = [methods.set_frozen_hay2011, methods.make_cell_uniform]
    elif biophys == 'frozen_no_Ih':
        custom_fun = [methods.set_frozen_hay2011_no_Ih, methods.make_cell_uniform]
    elif biophys == 'lin':
        custom_fun = [methods.set_Ih_linearized_hay2011, methods.make_cell_uniform]
    else:
        # unknown membrane variant requested by the parameter set
        raise NotImplementedError

    # 2x2 array of synaptic weights, scaled by weight_scaling; it is indexed
    # below as weights[ii][j] with ii running over presynaptic populations
    weights = np.array([[weight_EE, weight_IE],
                        [weight_EI, weight_II]]) * weight_scaling

    # class RecExtElectrode/PointSourcePotential parameters:
    # (shallow copy, with the keys not needed by GaussCylinderPotential removed)
    electrodeParameters = params.electrodeParameters.copy()
    for key in ['r', 'n', 'N', 'method']:
        del electrodeParameters[key]

    # Not using RecExtElectrode class as we anyway average potential in
    # space for each source element. 

    # Predictor assuming planar disk source elements convolved with Gaussian
    # along z-axis
    gauss_cyl_potential = GaussCylinderPotential(
        cell=None,
        z=electrodeParameters['z'],
        sigma=electrodeParameters['sigma'],
        R=params.populationParameters['pop_args']['radius'],
        sigma_z=params.populationParameters['pop_args']['scale'],
        )

    # set up recording of current dipole moments.
    current_dipole_moment = CurrentDipoleMoment(cell=None)

    # Compute average firing rate of presynaptic populations X
    # NOTE(review): the arguments do not depend on `pset`, so this value is
    # recomputed identically on every iteration of the outer loop.
    mean_nu_X = methods.compute_mean_nu_X(params, OUTPUTPATH_REAL,
                                     TRANSIENT=TRANSIENT)

    # kernel container
    H_YX_pred[md5] = dict()

    # i / X: presynaptic population, j / Y: postsynaptic population
    for i, (X, N_X) in enumerate(zip(params.population_names,
                                     params.population_sizes)):
        for j, (Y, N_Y, morphology) in enumerate(zip(params.population_names,
                                                     params.population_sizes,
                                                     params.morphologies)):
            
            # Extract median soma voltages from actual network simulation and
            # assume this value corresponds to Vrest.
            # NOTE(review): TRANSIENT is used here as an index along the last
            # axis; if it is a time in ms and the voltages are sampled with
            # dt != 1 this does not skip the full transient - confirm.
            # NOTE(review): Vrest depends only on Y but the file is re-read
            # for every presynaptic population X.
            if not perseg_Vrest:
                with h5py.File(os.path.join(OUTPUTPATH_REAL, 'somav.h5'
                                            ), 'r') as f:
                    Vrest = np.median(f[Y][()][:, TRANSIENT:])
            else:  # perseg_Vrest == True
                # one median value per row of the vmem dataset
                with h5py.File(os.path.join(OUTPUTPATH_REAL, 'vmem.h5'
                                            ), 'r') as f:
                    Vrest = np.median(f[Y][()][:, TRANSIENT:], axis=-1)

            # deep copy so that the shared params.cellParameters stays untouched
            cellParameters = deepcopy(params.cellParameters)
            cellParameters.update(dict(
                morphology=os.path.join(codebase_dir, morphology),
                custom_fun=custom_fun,
                custom_fun_args=[dict(Vrest=Vrest), dict(Vrest=Vrest)],
                templatefile=os.path.join(codebase_dir, params.cellParameters["templatefile"])
            ))
            
            # some inputs must be lists
            # (one entry per presynaptic population ii, all targeting Y)
            synapseParameters = [
                dict(weight=weights[ii][j],
                     syntype='Exp2Syn',
                     **params.synapseParameters[ii][j])
                for ii in range(len(params.population_names))]
            synapsePositionArguments = [
                params.synapsePositionArguments[ii][j]
                for ii in range(len(params.population_names))]

            # Create kernel approximator object
            # NOTE(review): C_YX takes row [i] of connectionProbability while
            # every other per-connection argument is gathered as [ii][j] over
            # all presynaptic populations - verify this is intended.
            kernel = KernelApprox(
                X=params.population_names,
                Y=Y,
                N_X=np.array(params.population_sizes),
                N_Y=N_Y,
                C_YX=np.array(params.connectionProbability[i]),
                cellParameters=cellParameters,
                populationParameters=params.populationParameters['pop_args'],
                multapseFunction=params.multapseFunction,
                multapseParameters=[params.multapseArguments[ii][j] for ii in range(len(params.population_names))],
                delayFunction=params.delayFunction,
                delayParameters=[params.delayArguments[ii][j] for ii in range(len(params.population_names))],
                synapseParameters=synapseParameters,
                synapsePositionArguments=synapsePositionArguments,
                extSynapseParameters=params.extSynapseParameters,
                nu_ext=1000. / params.netstim_interval,
                n_ext=n_ext[j],
                nu_X=mean_nu_X,
            )

            # make kernel predictions for both probes; stored under 'Y:X'
            H_YX_pred[md5]['{}:{}'.format(Y, X)] = kernel.get_kernel(
                probes=[gauss_cyl_potential, current_dipole_moment],
                Vrest=Vrest, dt=dt, X=X, t_X=t_X, tau=tau,
                g_eff=g_eff)
            

In [None]:
# Compute reconstructed signals as the sum over convolutions
#   phi(r, t) = sum_X sum_Y (nu_X * H_YX)(r, t)
# using the kernels predicted above (H_YX_pred), once for every probe.
#
# Result: all_kernel_predictions[probe_index] is a list with one
# (label, data) tuple per parameter combination in PS2, where `data` has one
# row per kernel row and one column per sample of nu_X.
# Index 0 is 'GaussCylinderPotential', index 1 is 'CurrentDipoleMoment'.
#
# The ground-truth signal files are not read here: the previous read was
# overwritten before it was ever used.

# parameter names containing any of these substrings are left out of the label
label_skip_substrings = ('weight', 'n_ext', 'i_syn', 't_E', 't_I',
                         'perseg_Vrest')

all_kernel_predictions = []
for probe in ['GaussCylinderPotential', 'CurrentDipoleMoment']:
    # compare biophysical variants using predicted kernels
    kernel_predictions = []  # container
    for pset in PS2.iter_inner():
        # sorted json dictionary -> key of this parameter set in H_YX_pred
        # (`md5` is also read by the plotting cell further down)
        js = json.dumps(pset, sort_keys=True).encode()
        md5 = hashlib.md5(js).hexdigest()

        # human-readable label listing the remaining parameters
        label = ''
        for h, (key, value) in enumerate(pset.items()):
            if any(substring in key for substring in label_skip_substrings):
                continue
            if h > 5:
                label += '\n'
            label += '{}:{}'.format(key, value)

        prediction_label = r'$\sum_X \sum_Y \nu_X \ast \hat{H}_\mathrm{YX}$' + '\n' + label

        # accumulate the contributions of all (X, Y) pathways
        data = None
        for X in population_names:
            for Y in population_names:
                H_YX = H_YX_pred[md5]['{}:{}'.format(Y, X)][probe]
                if data is None:
                    # allocate lazily, once the number of kernel rows is known
                    data = np.zeros((H_YX.shape[0], nu_X[X].size))
                for h, h_YX in enumerate(H_YX):
                    data[h, :] = data[h, :] + np.convolve(nu_X[X], h_YX, 'same')
        kernel_predictions.append((prediction_label, data))
    all_kernel_predictions.append(kernel_predictions)

## Plot

In [None]:
# Load the head-model data needed for the EEG calculation and the plots.
# Every dataset is copied into memory with np.array(), so the HDF5 file can be
# closed as soon as the block is left; `head_data` must not be used afterwards.
nyhead_file = "sa_nyhead.mat"
with h5py.File(nyhead_file, 'r') as nyhead_h5:
    head_data = nyhead_h5["sa"]
    lead_field_normal = np.array(head_data["cortex75K"]["V_fem_normal"])
    #lead_field = np.array(head_data["cortex75K"]["V_fem"])
    cortex = np.array(head_data["cortex75K"]["vc"]) # Locations of every vertex in cortex
    elecs = np.array(head_data["locs_3D"]) # 3D locations of electrodes
    #elecs_2D = np.array(head_data["locs_2D"]) # 2D locations of electrodes
    # "- 1": presumably converts 1-based (MATLAB) indices to 0-based ones
    head_tri = np.array(head_data["head"]["tri"]).T - 1 # For 3D plotting
    head_vc = np.array(head_data["head"]["vc"])

    cortex_plt_idxs = np.array(head_data["cortex10K"]["in_from_cortex75K"], dtype=int)

    cortex_tri = np.array(head_data["cortex10K"]["tri"]).T - 1 # For 3D plotting

# coordinates of the down-sampled cortex surface and of the head surface
x_ctx, y_ctx, z_ctx = cortex[:, cortex_plt_idxs[0] - 1]
x_h, y_h, z_h = head_vc[0, :], head_vc[1, :], head_vc[2, :]
# number of electrodes BEFORE restricting to the upper half; the EEG array is
# computed for all electrodes and filtered with upper_idxs afterwards
num_eeg_elecs = elecs.shape[1]

# keep only electrodes with positive z-coordinate
upper_idxs = np.where(elecs[2, :] > 0)[0]
elecs = elecs[:, upper_idxs]

dipole_loc = np.array([-10., 0., 88.]) # x, y, z location in mm
# index of the cortex vertex closest to the requested dipole location
vertex_idx = np.argmin(np.sqrt(np.sum((dipole_loc[:, None] - cortex)**2, axis=0)))


In [None]:
from mpl_toolkits.mplot3d import art3d
from matplotlib.patches import Circle
def rotation_matrix(d):
    """
    Calculates a rotation matrix given a vector d. The direction of d
    corresponds to the rotation axis. The length of d corresponds to
    the sin of the angle of rotation.

    The input is never modified: it is converted to a float array and the
    normalised axis is stored in a new array, so integer arrays, lists and
    tuples are accepted as well.

    Parameters
    ----------
    d : array_like, 3 elements
        Rotation vector; its norm is interpreted as sin(angle). A norm larger
        than 1 makes ``sqrt(1 - sin_angle**2)`` undefined (NaN entries).

    Returns
    -------
    numpy.ndarray
        3x3 matrix; the identity if ``d`` has zero length.

    Variant of: http://mail.scipy.org/pipermail/numpy-discussion/2009-March/040806.html
    """
    d = np.asarray(d, dtype=np.float64)
    sin_angle = np.linalg.norm(d)

    if sin_angle == 0:
        return np.identity(3)

    # out-of-place division: an in-place `d /= sin_angle` would rescale the
    # caller's array and fail for integer input
    axis = d / sin_angle

    eye = np.eye(3)
    ddt = np.outer(axis, axis)
    skew = np.array([[0, axis[2], -axis[1]],
                     [-axis[2], 0, axis[0]],
                     [axis[1], -axis[0], 0]], dtype=np.float64)

    M = ddt + np.sqrt(1 - sin_angle**2) * (eye - ddt) + sin_angle * skew
    return M

def pathpatch_2d_to_3d(pathpatch, z=0, normal='z'):
    """
    Transforms a 2D Patch to a 3D patch using the given normal vector.

    The patch is projected into the XY plane, rotated about the origin
    and finally translated by z.

    Parameters
    ----------
    pathpatch : matplotlib patch
        2D patch; it is converted in place by re-assigning its class to
        ``art3d.PathPatch3D`` and setting that class's private attributes.
    z : float
        Offset along the z-axis applied after the rotation.
    normal : str or array_like, 3 elements
        Either 'x', 'y' or 'z', or the normal vector of the plane the patch
        should lie in. The vector is normalised on a private copy, so the
        caller's array is left unchanged.

    Notes
    -----
    Relies on private matplotlib attributes (``_code3d``, ``_facecolor3d``,
    ``_segment3d``), which may differ between matplotlib versions.
    """
    if isinstance(normal, str): #Translate strings to normal vectors
        index = "xyz".index(normal)
        normal = np.roll((1.0,0,0), index)

    # Normalise out of place: an in-place division would rescale the caller's
    # data (e.g. a view into an electrode array) and fail for integer input.
    normal = np.asarray(normal, dtype=float)
    normal = normal / np.linalg.norm(normal)

    path = pathpatch.get_path() #Get the path and the associated transform
    trans = pathpatch.get_patch_transform()

    path = trans.transform_path(path) #Apply the transform

    pathpatch.__class__ = art3d.PathPatch3D #Change the class
    pathpatch._code3d = path.codes #Copy the codes
    # store the colour itself, not the bound getter method
    pathpatch._facecolor3d = pathpatch.get_facecolor()

    verts = path.vertices #Get the vertices in 2D

    d = np.cross(normal, (0, 0, 1)) #Obtain the rotation vector
    M = rotation_matrix(d) #Get the rotation matrix

    pathpatch._segment3d = np.array([np.dot(M, (x, y, 0)) + (0, 0, z) for x, y in verts])

def pathpatch_translate(pathpatch, delta):
    """
    Shift a 3D pathpatch by ``delta``.

    ``delta`` is added in place to the patch's ``_segment3d`` vertices and
    the attribute is re-assigned afterwards.
    """
    vertices_3d = pathpatch._segment3d
    vertices_3d += delta
    pathpatch._segment3d = vertices_3d

In [None]:
# Time window shown in the figure (same units as sig_t, i.e. multiples of dt)
T = [2000, 2200]

# time point at which the EEG map is drawn
eeg_time = 2079

# kernel keys in H_YX_pred, formatted as "<Y>:<X>"
pathways = ['E:E', 'I:E', 'E:I', 'I:I']

# settings for low-pass filtering the population firing rates
filt_dict_lf = dict(
    highpass_freq=None,
    lowpass_freq=100,
    order=4,
    filter_function='filtfilt',
    fs=1 / dt * 1000,
    axis=-1,
)

fr_I_lf = elephant.signal_processing.butter(nu_X["I"], **filt_dict_lf)
fr_E_lf = elephant.signal_processing.butter(nu_X["E"], **filt_dict_lf)

# spike raster data: per population a pair [gids, spike times]
with h5py.File(os.path.join(OUTPUTPATH_REAL, 'spikes.h5'), 'r') as f:
    spiketimes = {
        X: [np.array(f[X]['gids']), np.array(f[X]['times'])]
        for X in params.population_names
    }

# current dipole moment of the ground-truth simulation
with h5py.File(os.path.join(OUTPUTPATH_REAL, "CurrentDipoleMoment.h5"), 'r') as f:
    p_gt = np.array(f['data']["imem"][()])

# kernel-based dipole-moment prediction: probe index 1
# ('CurrentDipoleMoment'), first parameter combination, data part of the tuple
p_kernels = all_kernel_predictions[1][0][1]

# Calculate EEG signal from lead field: lead-field entries of the chosen
# vertex times the third (z) dipole-moment component.
eeg = np.zeros((num_eeg_elecs, len(nu_X["E"])))
lead_field_column = lead_field_normal[vertex_idx, :][:, None]
eeg[:, :] = lead_field_column * p_kernels[2, :] * 1e-3  # µV (author's annotation)

# keep only the electrodes that are plotted
eeg = eeg[upper_idxs, :]

# time axis of the kernels, shifted such that lag zero sits at index tau / dt;
# `md5` is the key left over from the last parameter combination processed
num_kernel_samples = len(H_YX_pred[md5][pathways[0]]['CurrentDipoleMoment'][0, :])
kernel_t = np.arange(num_kernel_samples) * dt
kernel_t -= tau

# time axis of the firing rates / predicted signals
sig_t = np.arange(len(nu_X["E"])) * dt

# sample indices closest to the edges of the plotted window and to eeg_time
t0, t1 = (np.argmin(np.abs(sig_t - t_edge)) for t_edge in T)
time_idx = np.argmin(np.abs(sig_t - eeg_time))

# z-score all traces (zero mean, unit standard deviation) for plotting
fr_I_lf -= np.mean(fr_I_lf)
fr_I_lf = fr_I_lf / np.std(fr_I_lf)

fr_E_lf -= np.mean(fr_E_lf)
fr_E_lf = fr_E_lf / np.std(fr_E_lf)

pz_gt_raw = p_gt[2, :]
pz_gt = pz_gt_raw - np.mean(pz_gt_raw)
pz_gt /= np.std(pz_gt)

pz_kernels_raw = p_kernels[2, :]
pz_kernels = pz_kernels_raw - np.mean(pz_kernels_raw)
pz_kernels /= np.std(pz_kernels)

# Figure layout. All axes are placed manually in figure coordinates
# [left, bottom, width, height]; the kernel panels are added later with
# fig.add_subplot.
fig = plt.figure(figsize=[6, 6])
# spike raster, restricted to the time window T
ax_spikes = fig.add_axes([0.05, 0.75, 0.9, 0.2], xlim=T, 
                         frameon=False, xticks=[], yticks=[], rasterized=True,
                         title="spike times")

# normalized, low-pass filtered firing rates
ax_fr = fig.add_axes([0.05, 0.6, 0.9, 0.1], xlim=T, yticks=[], ylim=[-2, 2.3],
                        frameon=False, xticks=[], 
                     title="firing rates (low-pass filtered)")

# normalized z-component of the current dipole moment
ax_pz = fig.add_axes([0.05, 0.28, 0.9, 0.1], xlim=T, 
                        frameon=False, xticks=[], title="$P_z$",
                     yticks=[], ylim=[-2.2, 2.2])

# Plot 3D head
ax_head = fig.add_axes([.02, 0.0, 0.25, 0.33], projection='3d', 
                       frame_on=False,
                          xticks=[], yticks=[], zticks=[],
                          xlim=[-70, 70], facecolor="none", rasterized=True,
                          ylim=[-70, 70], zlim=[-70, 70],
                          )

# cross-section of the cortex around the dipole location
ax_geom = fig.add_axes([0.33, 0.02, 0.25, 0.28], aspect=1, 
                       frameon=False, 
                      xticks=[], yticks=[], rasterized=True)

# interpolated EEG map at time eeg_time
ax_eeg = fig.add_axes([0.65, 0., 0.25, 0.33], xlim=[-110, 110], 
                       ylim=[-120, 110], aspect=1,
                       frameon=False, 
                      xticks=[], yticks=[])

# 20 ms scale bars at the right edge of the three time-series panels
ax_spikes.plot([T[1] - 20, T[1]], [-200, -200], lw=1, c='k', clip_on=False)
ax_spikes.text(T[1] - 10, -260, "20 ms", va="top", ha="center")

ax_fr.plot([T[1] - 20, T[1]], [-1, -1], lw=1, c='k')
ax_fr.text(T[1] - 10, -1.2, "20 ms", va="top", ha="center")

ax_pz.plot([T[1] - 20, T[1]], [-1, -1], lw=1, c='k')
ax_pz.text(T[1] - 10, -1.2, "20 ms", va="top", ha="center")

# dashed marker at the time point for which the EEG map is shown
ax_spikes.axvline(sig_t[time_idx], ls="--", c='gray')
ax_fr.axvline(sig_t[time_idx], ls="--", c='gray')
ax_pz.axvline(sig_t[time_idx], ls="--", c='gray')

# normalized firing rates: excitatory in blue, inhibitory in red
l_fr_E, = ax_fr.plot(sig_t[t0:t1], fr_E_lf[t0:t1], c='b', lw=1)
l_fr_I, = ax_fr.plot(sig_t[t0:t1], fr_I_lf[t0:t1], c='r', lw=1)

# spike raster: one row per cell (gid), only spikes inside the window T
for X in ["E", "I"]:
    gids, spikes = spiketimes[X]
    raster_color = {"E": 'b', "I": "r"}[X]
    for idx, gid in enumerate(gids):
        s_ = spikes[idx]
        ii = (s_ >= T[0]) & (s_ <= T[1])
        s_ = np.array(s_[ii])
        ax_spikes.plot(s_, np.ones(len(s_)) * gid, '.',
                       c=raster_color, ms=2, zorder=0)

# kernel-based P_z versus the sign-inverted inhibitory firing rate, the latter
# shifted by 5 time units along the x-axis
l_kernels, = ax_pz.plot(sig_t[t0:t1], pz_kernels[t0:t1], c='gray', lw=1.5)
l_fr, = ax_pz.plot(sig_t[t0:t1] + 5, -fr_I_lf[t0:t1], c='r', lw=1)

ax_pz.legend([l_kernels, l_fr],
             ["kernel method", "simple firing-rate proxy"],
             frameon=False, loc=(0.7, 0.9))

# one panel per pathway with the z-component of the dipole-moment kernel,
# shown for time lags 0 ... tau_trunc
ax_kernels = []

for p_idx, pathway in enumerate(pathways):
    p = H_YX_pred[md5][pathway]['CurrentDipoleMoment']
    pz_min = np.min(p[2, :])
    ax_p = fig.add_subplot(5, 4, p_idx + 9, xlim=[0, tau_trunc],
                           ylim=[pz_min * 1.05, -pz_min * 0.1],
                           frameon=False, xticks=[], yticks=[], title=pathway)
    ax_kernels.append(ax_p)
    ax_p.plot(kernel_t, p[2, :], c='k')

    # Vertical scale bar spanning [pz_min, 0], drawn just outside the axes.
    # The label states the full length of the bar, |pz_min|, scaled by 1e-3.
    ax_p.plot([tau_trunc + 1, tau_trunc + 1], [pz_min, 0], c='k', clip_on=False)
    ax_p.text(tau_trunc, pz_min * 0.65, "{:1.2f}\nnAmm".format(-pz_min * 1e-3),
              ha='right', va="center")

    # horizontal 5 ms scale bar
    ax_p.plot([0, 5], [-pz_min / 10] * 2, c='k')
    ax_p.text(2.5, -pz_min / 10, "5 ms", va="bottom", ha="center")
    
# 3D view: cortex surface, semi-transparent head surface and one disc per
# EEG electrode, oriented along the electrode's normal vector.
ax_head.axis('off')
ax_head.plot_trisurf(x_ctx, y_ctx, z_ctx, triangles=cortex_tri,
                     color="pink", zorder=0)
ax_head.plot_trisurf(x_h, y_h, z_h, triangles=head_tri,
                     color="#c87137", zorder=0, alpha=0.2)

all_patches = []
for elec_idx in range(elecs.shape[1]):
    # rows 0-2 hold the electrode position, the remaining rows its normal
    elec_xyz = elecs[:3, elec_idx]
    elec_normal = elecs[3:, elec_idx]
    # circle in the xy plane; it is rotated and moved into place below
    p = Circle((0, 0), 5, facecolor='gray', zorder=elec_xyz[2])
    ax_head.add_patch(p)
    all_patches.append(p)
    pathpatch_2d_to_3d(p, z=0, normal=elec_normal)
    pathpatch_translate(p, elec_xyz)

# look straight down onto the head
ax_head.view_init(elev=90., azim=-90)


cax = fig.add_axes([0.9, 0.05, 0.01, 0.2]) # This axis is just the colorbar

# Symmetric colour range of the EEG map. Fixed value; a data-driven
# alternative is np.floor(np.max(np.abs(eeg[:, time_idx]))).
vmax = 50


def vmap(v):
    """Return the colour that ``cmap_v_e`` assigns to value v in [-vmax, vmax].

    Uses ``cmap_v_e``, the colormap of the contour plots below; the name
    ``cmap`` is not defined in this notebook.
    """
    return cmap_v_e((v + vmax) / (2 * vmax))


levels = np.linspace(-vmax, vmax, 60)

contourf_kwargs = dict(levels=levels,
                       cmap=cmap_v_e,
                       vmax=vmax,
                       vmin=-vmax,
                       extend="both")

# Interpolated EEG map at time index time_idx, using the x/y coordinates of
# the electrodes; filled contours plus contour lines with the same settings.
img = ax_eeg.tricontourf(elecs[0], elecs[1], eeg[:, time_idx], **contourf_kwargs)
img2 = ax_eeg.tricontour(elecs[0], elecs[1], eeg[:, time_idx], **contourf_kwargs)

cbar = plt.colorbar(img, cax=cax)
cbar.set_label("µV", labelpad=-5)
cbar.set_ticks([-vmax, -vmax/2, 0, vmax/2, vmax])
    
# Plotting cross-section of cortex around active region center:
# all cortex vertices whose y-coordinate lies within `threshold` of the dipole
threshold = 1  # threshold in mm for including points in plot
xz_plane_idxs = np.where(np.abs(cortex[1, :] -
                                dipole_loc[1]) < threshold)[0]

ax_geom.scatter(cortex[0, xz_plane_idxs],
                cortex[2, xz_plane_idxs], s=1, c='0.9')

# Scale bar; the label is derived from the bar's end points so that the two
# cannot disagree (the bar is 10 mm long).
scalebar_x = [-30, -40]
scalebar_length = abs(scalebar_x[1] - scalebar_x[0])
ax_geom.plot(scalebar_x, [-60, -60], c='k', lw=1)
ax_geom.text(-35, -65, "{:d} mm".format(scalebar_length),
             ha='center', va="top")
# arrow ending at the cortex vertex that carries the dipole
ax_geom.arrow(cortex[0, vertex_idx], cortex[2, vertex_idx] - 4, 0, 4,
              color='k', head_width=2)

# mark the dipole location in the EEG map and in the 3D head view
ax_eeg.plot(cortex[0, vertex_idx], cortex[1, vertex_idx], 'o', c='k', ms=6)
ax_head.plot([cortex[0, vertex_idx]],
             [cortex[1, vertex_idx]],
             [cortex[2, vertex_idx]], 'o', c='k',
             ms=6, zorder=10000)

# panel labels
mark_subplots(ax_spikes, "A", xpos=-0.01, ypos=1.02)
mark_subplots(ax_fr, "B", xpos=-0.01, ypos=1.02)
mark_subplots(ax_kernels[0], "C", xpos=-0.4, ypos=1.02)
mark_subplots(ax_pz, "D", xpos=-0.01, ypos=1.02)

mark_subplots(ax_geom, "E", ypos=0.9, xpos=-1.15)

fig.savefig("firing_rate_EEG_compare.pdf")
