# Example of population EEG study

### Note: executing this code takes a long time!

In [None]:
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from matplotlib.collections import PatchCollection
import scipy.fftpack as ff
import sympy as sp
from lfpykit.eegmegcalc import FourSphereVolumeConductor
from brainsignals.plotting_convention import mark_subplots, simplify_axes, cmap_v_e
from brainsignals.neural_simulations import return_equidistal_xyz
import brainsignals.neural_simulations as ns

# Seed the global RNG so the population geometry and signals are reproducible.
np.random.seed(12345)

# Four-sphere head model, listed from the innermost to the outermost sphere.
# head_colors[i] is the fill color of sphere i when the model is drawn.
head_colors = ["#ffb380", "#74abff", "#b3b3b3", "#c87137"]
radii = [89000., 90000., 95000., 100000.]  # (µm)
sigmas = [0.276, 1.65, 0.01, 0.465]  # (S/m)

# Small inward offset so electrodes sit just inside the outermost sphere.
eps = 1e-2

# Time axis in ms.
dt = 0.1
tstop = 1000
num_tsteps = int(tstop / dt)
tvec = dt * np.arange(num_tsteps)

# Amplitude scale of each population's current dipole moment.
cdm_amp = 1e6

# Non-negative FFT frequencies (Hz; dt is converted from ms to s).
sample_freq = ff.fftfreq(num_tsteps, d=dt / 1000.)
pidxs = np.where(sample_freq >= 0)
freqs = sample_freq[pidxs]

# Dipole populations lie on a sphere 1000 µm inside the innermost sphere.
pop_radial_loc = radii[0] - 1000
# Surface area of the innermost sphere (µm²).
cortical_area = 4 * np.pi * radii[0] ** 2

pop_xs, pop_ys, pop_zs = return_equidistal_xyz(5000, pop_radial_loc)
num_populations = len(pop_xs)
# Index of the population with the largest z (closest to the +z pole).
top_pop_idx = np.argmax(pop_zs)

pop_center = np.array([0., 0., pop_radial_loc])

def corr_phi_func(corr_area, radius=None):
    """Return the polar half-angle (rad) of a spherical cap of area ``corr_area``.

    A cap with polar half-angle phi on a sphere of radius R has area
    2*pi*R**2*(1 - cos(phi)); this function inverts that relation.

    Parameters
    ----------
    corr_area : float or array_like
        Cap area, in the squared length unit of ``radius`` (µm² here).
    radius : float, optional
        Sphere radius. Defaults to ``radii[0]`` (the innermost sphere).

    Returns
    -------
    float or ndarray
        Half-angle in [0, pi].
    """
    if radius is None:
        radius = radii[0]
    cos_phi = 1 - corr_area / (2 * np.pi * radius ** 2)
    # Clip so that an area equal to (or a rounding error above) the full
    # sphere gives pi instead of NaN from arccos.
    return np.arccos(np.clip(cos_phi, -1.0, 1.0))

def return_corr_pop_idxs(corr_area):
    """Return a boolean mask of the populations inside a cap around +z.

    The cap is centered on the +z pole and has area ``corr_area``. Its
    half-angle comes from ``corr_phi_func``. A population is selected when
    its polar angle from +z is at most that half-angle.

    Parameters
    ----------
    corr_area : float
        Cap area (µm²).

    Returns
    -------
    ndarray of bool, shape like ``pop_zs``
    """
    # Clip guards against |z| / r exceeding 1 by a rounding error. Otherwise
    # arccos would return NaN and silently drop that population.
    cos_phis = np.clip(pop_zs / pop_radial_loc, -1.0, 1.0)
    pop_phis = np.arccos(cos_phis)
    return pop_phis <= corr_phi_func(corr_area)

# Electrode positions come from return_equidistal_xyz on a sphere just inside
# the outermost surface. One extra electrode is appended at the +z pole.
elec_x, elec_y, elec_z = return_equidistal_xyz(500, radii[-1] - eps)
elec_x = np.append(elec_x, 0)
elec_y = np.append(elec_y, 0)
elec_z = np.append(elec_z, radii[-1] - eps)


def elec_clrs(z):
    """Map an electrode height z in [-radii[-1], radii[-1]] to a jet color."""
    return plt.cm.jet((z + radii[-1]) / (2 * radii[-1]))


num_elecs = len(elec_x)
# Electrodes on the upper hemisphere (z > 0), used for the top-view maps.
upper_idxs = np.flatnonzero(elec_z > 0)
top_elec_idx = np.argmax(elec_z)

# Forward model evaluated at every electrode position (µm).
r_elecs = np.array([elec_x, elec_y, elec_z]).T
sphere_model = FourSphereVolumeConductor(r_elecs, radii, sigmas)

def plot_four_sphere_model(ax, radii, dipole_loc=None):
    """Draw the concentric head-model spheres as filled circles on ``ax``.

    Parameters
    ----------
    ax : matplotlib Axes
        Target axes.
    radii : sequence of float
        Sphere radii from innermost to outermost (µm). ``radii[i]`` is filled
        with ``head_colors[i]``. At most ``len(head_colors)`` spheres are drawn.
    dipole_loc : array_like, optional
        If given, draw a short upward arrow just below
        ``(dipole_loc[0], dipole_loc[2])`` in the x-z plane.
    """
    # Paint the outermost sphere first so each inner circle lands on top.
    for radius, color in reversed(list(zip(radii, head_colors))):
        ax.add_patch(plt.Circle((0, 0), radius=radius, color=color,
                                fill=True, ec='k', lw=.1))

    if dipole_loc is not None:
        ax.arrow(dipole_loc[0], dipole_loc[2] - 1200, 0, 500,
                 color='k', head_width=300)

def plot_head_outline(ax, radius):
    """Draw a cartoon head outline on ``ax``: a circle, two ears and a nose.

    Parameters
    ----------
    ax : matplotlib Axes
        Target axes.
    radius : float
        Head radius in data units (this file passes ``radii[-1]``, in µm).
    """
    angles = np.linspace(0, 2 * np.pi, 100)
    head_x = radius * np.cos(angles)
    head_y = radius * np.sin(angles)

    # Ears: rounded 3000 x 30000 boxes placed symmetrically just outside the
    # left and right edges of the head circle.
    right_ear = mpatches.FancyBboxPatch(
        [radius + 5000, -15000], 3000, 30000,
        boxstyle=mpatches.BoxStyle("Round", pad=5000))
    left_ear = mpatches.FancyBboxPatch(
        [-radius - 8000, -15000], 3000, 30000,
        boxstyle=mpatches.BoxStyle("Round", pad=5000))
    ax.add_collection(PatchCollection([right_ear, left_ear], facecolor='none',
                                      edgecolor='k', alpha=1.0))

    ax.plot(head_x, head_y, 'k', lw=1)

    # Presumably the nose: a triangle pointing up from the top of the circle.
    ax.plot([-10000, 0, 10000],
            [radius, radius + 10000, radius], 'k', lw=1)


In [None]:

def return_time_dependence():
    """Return one random, unit-standard-deviation signal sampled on ``tvec``.

    The signal sums sinusoids at ``freqs[1:200]``, skipping the zero
    frequency. Each sinusoid has amplitude 1/f and an independent phase
    drawn uniformly from [0, 2*pi) with the global NumPy RNG. The sum is
    then divided by its standard deviation.
    """
    components = []
    for f in freqs[1:200]:
        phase = np.random.uniform(0, 2 * np.pi)
        # tvec is in ms, so divide by 1000 to get seconds.
        components.append(1 / f * np.sin(2 * np.pi * f * tvec / 1000 + phase))
    xi = np.sum(components, axis=0)
    return xi / np.std(xi)

def calculate_all_xis():
    """Draw one random signal per population.

    Returns
    -------
    xis : ndarray, shape (num_populations, num_tsteps)
        Row p holds the signal of population p (from return_time_dependence).
    corr_sig : ndarray, shape (num_tsteps,)
        ``xis[top_pop_idx]`` (a view into ``xis``). It serves as the shared
        signal when the populations are correlated.
    """
    print("Preparing XIs")
    xis = np.empty((num_populations, num_tsteps))
    for row in xis:
        row[:] = return_time_dependence()
    return xis, xis[top_pop_idx]


xis, corr_sig = calculate_all_xis()

In [None]:
# Forward gains: Fs[p, e] is the contribution at electrode e of a radially
# oriented dipole with amplitude cdm_amp at population p.
print("Preparing Fs")
Fs = np.zeros((num_populations, num_elecs))
for p_idx in range(num_populations):
    if not p_idx % 500:
        print(p_idx, "/", num_populations)

    pop_pos = np.array([pop_xs[p_idx], pop_ys[p_idx], pop_zs[p_idx]])
    # The dipole points radially outward from the head center.
    unit_radial = pop_pos / np.linalg.norm(pop_pos)
    transform = sphere_model.get_transformation_matrix(pop_pos)
    # Factor 1000: presumably mV -> µV (run_simulation labels the EEG in µV).
    Fs[p_idx] = 1000 * transform @ unit_radial * cdm_amp

# This loop is slow. To reuse the result across sessions, save it with
# np.save("Fs.npy", Fs) and reload it with Fs = np.load("Fs.npy").

In [None]:
def calc_var_from_theory(Fs, corr_pop_idxs, pop_correlated):
    """Predict the EEG variance at every electrode from the forward gains.

    Population p contributes ``Fs[p, e] * xi_p(t)`` to electrode e, where
    every xi_p has unit variance (return_time_dependence normalizes by the
    standard deviation).

    * Correlated: all populations share one signal, so contributions add
      coherently and var = (sum_p F_p)**2. This includes the diagonal
      F_p**2 terms, consistent with calculate_analytic_var.
    * Uncorrelated: the signals are (approximately) independent, so only the
      diagonal terms remain and var = sum_p F_p**2.

    Parameters
    ----------
    Fs : ndarray, shape (num_populations, num_elecs)
        Forward gains.
    corr_pop_idxs : array_like
        Boolean mask (or indices) of the active populations.
    pop_correlated : bool
        Whether the active populations share one signal.

    Returns
    -------
    ndarray, shape (num_elecs,)
        Predicted variance per electrode.
    """
    F_active = Fs[corr_pop_idxs, :]
    if pop_correlated:
        return np.sum(F_active, axis=0) ** 2
    return np.sum(F_active ** 2, axis=0)


# Reference gain F_0: the top-electrode contribution of one +z (radial)
# dipole at (0, 0, pop_radial_loc), scaled the same way as the entries of Fs.
M_top = sphere_model.get_transformation_matrix(
    np.array([0, 0, pop_radial_loc]))[top_elec_idx]
unit_z = np.array([0, 0, 1])
F_0 = 1000 * M_top @ unit_z * cdm_amp

# Number of populations per unit area of the innermost sphere (1/µm²).
rho_p = num_populations / cortical_area

# Decay length a (µm) of the exponential falloff used in the analytic
# approximation: 2.2 cm -> 22 mm -> 22000 µm.
a = 2.2 * 10 * 1e3

def calculate_analytic_var(corr_area, pop_correlated):
    """Closed-form estimate of the top-electrode EEG variance.

    The model uses a gain that falls off as F_0 * exp(-phi * r_p / a) with
    polar angle phi (r_p = radii[0]). Populations have density rho_p over a
    cap of area ``corr_area``. The integrands match the sympy derivation in
    this notebook.

    Parameters
    ----------
    corr_area : float
        Cap area (µm²).
    pop_correlated : bool
        True gives (sum F)**2 over the cap; False gives sum(F**2).

    Returns
    -------
    float
        Predicted variance.
    """
    phi_max = corr_phi_func(corr_area)
    r_p = radii[0]

    if pop_correlated:
        # (sum F)**2
        decay = np.exp(-phi_max * r_p / a)
        numerator = (1 - decay * (np.cos(phi_max) + (r_p / a) * np.sin(phi_max)))**2
        denominator = (1 + a**2 / r_p**2)**2
        return 4 * np.pi**2 * a**4 * rho_p**2 * F_0**2 * numerator / denominator

    # sum (F**2)
    decay = np.exp(-2 * phi_max * r_p / a)
    numerator = 1 - decay * (np.cos(phi_max) + 2 * (r_p / a) * np.sin(phi_max))
    denominator = 1 + a**2 / (4 * r_p**2)
    return rho_p * F_0**2 * (np.pi * a**2 / 2) * numerator / denominator



# Symbolic counterpart of calculate_analytic_var. The falloff
# exp(-theta * r_p / a) (squared for the uncorrelated case) is integrated
# over a cap of polar half-angle theta_m with population density rho.
theta_, theta_m_, a_, r_p_, rho_, F_0_ = sp.symbols("theta theta_m a_ r_p rho F_0", positive=True)

# Uncorrelated: sum(F**2) -> 2*pi*r_p**2*rho*F_0**2 * integral.
integrand_ucorr = sp.exp(-2 * theta_ * r_p_ / a_) * sp.sin(theta_)
ucorr_integral = sp.integrate(integrand_ucorr, (theta_, 0, theta_m_))
y_ucorr = sp.simplify(2 * sp.pi * r_p_**2 * rho_ * F_0_**2 * ucorr_integral)

# Correlated: (sum F)**2 -> (2*pi*r_p**2*rho*F_0 * integral)**2.
integrand_corr = sp.exp(- theta_ * r_p_ / a_) * sp.sin(theta_)
corr_integral = sp.integrate(integrand_corr, (theta_, 0, theta_m_))
y_corr = sp.simplify(4 * sp.pi**2 * r_p_**4 * rho_**2 * F_0_**2 * corr_integral**2)

# Numerical values substituted into y_corr / y_ucorr (theta_m_ is added per call).
param_dict = {F_0_: F_0,
              a_: 2.2 * 10 * 1e3,
              r_p_: radii[0],
              rho_: rho_p}

def calculate_analytic_var_sympy(corr_area, pop_correlated):
    """Numerically evaluate the symbolic variance expressions.

    Computes the same quantity as calculate_analytic_var by substituting the
    model parameters and the cap half-angle into y_corr / y_ucorr.

    Parameters
    ----------
    corr_area : float
        Cap area (µm²).
    pop_correlated : bool
        True evaluates (sum F)**2 (y_corr); False evaluates sum(F**2) (y_ucorr).

    Returns
    -------
    sympy Float
        The evaluated expression.
    """
    # Work on a copy so the module-level param_dict is never mutated.
    subs = dict(param_dict)
    subs[theta_m_] = corr_phi_func(corr_area)
    expr = y_corr if pop_correlated else y_ucorr
    return expr.evalf(subs=subs)


def run_simulation(corr_pop_idxs, pop_correlated, chunk_size=500):
    """Simulate the EEG produced by the selected populations.

    Each selected population p adds ``Fs[p, e] * xi_p(t)`` to electrode e.
    When ``pop_correlated`` is True, every population uses the shared signal
    ``corr_sig``; otherwise each uses its own row ``xis[p]``.

    Parameters
    ----------
    corr_pop_idxs : array_like of bool, shape (num_populations,)
        Mask of the active populations.
    pop_correlated : bool
        Whether all active populations share ``corr_sig``.
    chunk_size : int, optional
        Number of populations combined per matrix product in the uncorrelated
        case. It bounds the size of the temporary copy of ``xis`` rows.

    Returns
    -------
    ndarray, shape (num_elecs, num_tsteps)
        Summed EEG at every electrode (µV).
    """
    active = np.flatnonzero(corr_pop_idxs)

    if pop_correlated:
        # The shared signal factors out: sum_p F_p * s(t) = (sum_p F_p) * s(t).
        # This needs one outer product instead of one per population.
        return np.outer(Fs[active].sum(axis=0), corr_sig)

    eeg = np.zeros((num_elecs, num_tsteps))
    for start in range(0, active.size, chunk_size):
        rows = active[start:start + chunk_size]
        # (n_elecs, n_rows) @ (n_rows, n_tsteps) = sum of per-population outer products.
        eeg += Fs[rows].T @ xis[rows]
    return eeg

        

In [None]:
def plot_eeg_colorplot(eeg, vmax, elec_x, elec_y, ax):
    """Draw a contour map of ``eeg`` at the electrode positions, with a head outline.

    The color scale is symmetric, [-vmax, vmax] in 60 levels using
    ``cmap_v_e``. Filled contours are overlaid with contour lines that use
    the same levels and colormap.

    Parameters
    ----------
    eeg : array_like
        One value per electrode.
    vmax : float
        Half-width of the symmetric color range.
    elec_x, elec_y : array_like
        Electrode coordinates in the plotting plane (µm).
    ax : matplotlib Axes
        Target axes.

    Returns
    -------
    The object returned by ``ax.tricontourf`` (usable for a colorbar).
    """
    style = {
        "levels": np.linspace(-vmax, vmax, 60),
        "cmap": cmap_v_e,
        "vmax": vmax,
        "vmin": -vmax,
        "extend": "both",
    }
    filled = ax.tricontourf(elec_x, elec_y, eeg, **style)
    ax.tricontour(elec_x, elec_y, eeg, **style)
    plot_head_outline(ax, radii[-1])
    return filled

def plot_eeg_var_colorplot(eeg, vmax, elec_x, elec_y, ax):
    """Draw a filled contour map of per-electrode EEG variance, with a head outline.

    The color scale runs from 0 to ``vmax`` in 21 levels with the "hot"
    colormap.

    Parameters
    ----------
    eeg : array_like
        One value per electrode (here a variance, µV²).
    vmax : float
        Upper end of the color range.
    elec_x, elec_y : array_like
        Electrode coordinates in the plotting plane (µm).
    ax : matplotlib Axes
        Target axes.

    Returns
    -------
    The object returned by ``ax.tricontourf`` (usable for a colorbar).
    """
    contourf_kwargs = dict(levels=np.linspace(0, vmax, 21),
                           cmap="hot",
                           vmax=vmax,
                           vmin=0,
                           extend="both")

    img = ax.tricontourf(elec_x, elec_y, eeg, **contourf_kwargs)
    plot_head_outline(ax, radii[-1])
    return img

In [None]:
def plot_results(eeg, corr_area, corr_pop_idxs, corr_sig, pop_correlated):
    """Plot and save an overview figure for one simulated population EEG.

    The figure contains:

    * the head model with active populations in red (A);
    * upper-hemisphere EEG maps at four time points (B);
    * EEG traces at all electrodes, with the top electrode in black and the
      shared signal P_z in red (C);
    * PSDs of all electrodes;
    * observed vs theoretical per-electrode variance (D), plus small
      variance maps (theory, observed).

    The figure is saved as
    ``population_eeg_no_bg_{corr|ucorr}_{area}cm2.png``.

    Parameters
    ----------
    eeg : ndarray, shape (num_elecs, num_tsteps)
        Simulated EEG.
    corr_area : float
        Area of the active population cap (µm²; shown in cm²).
    corr_pop_idxs : ndarray of bool
        Mask of the active populations.
    corr_sig : ndarray, shape (num_tsteps,)
        Shared population signal, plotted (scaled) as P_z.
    pop_correlated : bool
        Selects the variance prediction and the output file name.
    """
    var_eeg_theory = calc_var_from_theory(Fs, corr_pop_idxs, pop_correlated)
    var_eeg_observed = np.var(eeg, axis=1)

    # Integer-rounded limits shared by both variance maps and the scatter plot.
    max_var = np.ceil(np.max([var_eeg_observed, var_eeg_theory]))
    min_var = np.floor(np.min([var_eeg_observed, var_eeg_theory]))

    freq, eeg_psd = ns.return_freq_and_psd(tvec, eeg)
    max_psd = np.max(eeg_psd)

    plt.close("all")
    fig = plt.figure(figsize=[6, 4])
    fig.subplots_adjust(bottom=0.1, right=0.94, left=0.05,
                        top=0.96, wspace=.4, hspace=0.3)
    # "\\%" yields the literal backslash-percent that mathtext needs. A bare
    # "\%" is an invalid escape sequence (SyntaxWarning on Python >= 3.12).
    ax_geom = fig.add_subplot(241, frameon=False, aspect=1,
                              xticks=[], yticks=[],
                              title="population area:\n{:1.0f} cm² ({:1.1f} $\\%$)".format(
                                  corr_area * 1e-8, corr_area / cortical_area * 100),
                              xlim=[-120000, 120000],
                              ylim=[-120000, 120000])

    # Insets: head model with electrodes, and variance maps (theory, observed).
    ax_elecs = fig.add_axes([0.75, 0.35, 0.1, 0.1], frameon=False, aspect=1,
                            xticks=[], yticks=[],
                            xlim=[-101000, 101000],
                            ylim=[-101000, 101000], zorder=1000)
    ax_elecs_var1 = fig.add_axes([0.8, 0.45, 0.1, 0.1], frameon=False, aspect=1,
                                 xticks=[], yticks=[],
                                 xlim=[-120000, 120000],
                                 ylim=[-120000, 120000])
    ax_elecs_var2 = fig.add_axes([0.9, 0.45, 0.1, 0.1], frameon=False, aspect=1,
                                 xticks=[], yticks=[],
                                 xlim=[-120000, 120000],
                                 ylim=[-120000, 120000])

    img = plot_eeg_var_colorplot(var_eeg_theory[upper_idxs],
                                 max_var,
                                 elec_x[upper_idxs],
                                 elec_y[upper_idxs], ax_elecs_var1)
    plot_eeg_var_colorplot(var_eeg_observed[upper_idxs],
                           max_var,
                           elec_x[upper_idxs], elec_y[upper_idxs], ax_elecs_var2)

    cax = fig.add_axes([0.8, 0.65, 0.17, 0.01])
    cbar = plt.colorbar(img, cax=cax, label="µV²", orientation="horizontal")
    cbar.set_ticks([0, max_var/2, max_var])

    xlim_freqs = [freqs[1], 200]
    ax_eeg_top = fig.add_axes([0.08, 0.1, 0.25, 0.35], xlabel="time (ms)",
                              ylabel="EEG (µV)")
    # Red twin y-axis for the shared population signal P_z.
    twin_ax_pz = ax_eeg_top.twinx()
    twin_ax_pz.yaxis.label.set_color("r")
    twin_ax_pz.tick_params(axis='y', colors="r")
    twin_ax_pz.set_ylabel("$P_z$ (nAm)")

    ax_eeg_psd = fig.add_axes([0.48, 0.1, 0.17, 0.35], xlabel="frequency (Hz)",
                              ylabel="EEG PSD (µV²/Hz)",
                              ylim=[max_psd * 1e-6, max_psd * 2], xlim=xlim_freqs)

    ax_var = fig.add_axes([0.74, 0.1, 0.25, 0.35], xlabel="var(EEG) observed (µV²)",
                          ylabel="var(EEG) theory (µV²)", aspect=1,
                          xlim=[min_var - 1, max_var + 1],
                          ylim=[min_var - 1, max_var + 1])

    plot_four_sphere_model(ax_geom, radii)
    plot_four_sphere_model(ax_elecs, radii)

    # Identity line: perfect agreement between theory and observation.
    ax_var.plot([0, max_var], [0, max_var], lw=0.5, c='gray', ls="--")

    # All populations in gray, the active ones in red (x-z projection).
    ax_geom.plot(pop_xs, pop_zs, 'o', c='gray', ms=0.3)
    ax_geom.plot(pop_xs[corr_pop_idxs], pop_zs[corr_pop_idxs], 'o', c='r', ms=0.3)

    for elec in range(num_elecs):
        ax_elecs.plot(elec_x[elec], elec_z[elec], 'o', c=elec_clrs(elec_z[elec]), ms=1)

        ax_var.plot(var_eeg_observed[elec], var_eeg_theory[elec], 'o',
                    c=elec_clrs(elec_z[elec]), ms=2, mec='k', mew=0.1)
        ax_eeg_psd.loglog(freq, eeg_psd[elec], c=elec_clrs(elec_z[elec]), lw=0.7)

        l_other, = ax_eeg_top.plot(tvec, eeg[elec], c='0.7', lw=0.5)

    l_eeg, = ax_eeg_top.plot(tvec, eeg[top_elec_idx], c='k')
    l_pz, = twin_ax_pz.plot(tvec, corr_sig[:] * cdm_amp * 1e-6, c='r', lw=0.7)

    ax_eeg_top.legend([l_eeg, l_other, l_pz], ["top electrode", "other electrodes",
                                               r"$P_z$ of correlated pop"], ncol=1,
                      frameon=False, loc=(0., 1.0))

    # Upper-hemisphere EEG maps at four evenly spaced time points, sharing
    # one symmetric color scale.
    vmax = np.ceil(np.max(np.abs(eeg)))
    time_idxs = np.linspace(0, num_tsteps - 1, 5, dtype=int)[:-1]
    # Projected (top-view) radius of the active cap boundary, drawn dashed.
    plot_pop_r = pop_radial_loc * np.sin(corr_phi_func(corr_area))

    for i, time_idx in enumerate(time_idxs):
        ax_eeg = fig.add_subplot(2, 5, 2 + i, frameon=False, aspect=1,
                                 xticks=[], yticks=[],
                                 title="t={:1.0f} ms".format(tvec[time_idx]),
                                 xlim=[-120000, 120000],
                                 ylim=[-120000, 120000])
        img = plot_eeg_colorplot(eeg[upper_idxs, time_idx], vmax,
                                 elec_x[upper_idxs], elec_y[upper_idxs], ax_eeg)

        ax_eeg.plot(plot_pop_r * np.cos(np.linspace(0, 2 * np.pi, 20)),
                    plot_pop_r * np.sin(np.linspace(0, 2 * np.pi, 20)),
                    c='gray', lw=1, ls='--', zorder=1000)
        ax_eeg_top.axvline(tvec[time_idx], ls='--', c='gray')

        if i == 0:
            mark_subplots(ax_eeg, "B", xpos=0.1)

    cax = fig.add_axes([0.25, 0.62, 0.5, 0.01])
    cbar = plt.colorbar(img, cax=cax, label="µV", orientation="horizontal")
    cbar.set_ticks([-vmax, -vmax/2, 0, vmax/2, vmax])

    simplify_axes([ax_var, ax_eeg_psd])

    ax_var.set_xticks(ax_var.get_yticks())
    ax_var.set_yticks(ax_var.get_yticks())

    mark_subplots(ax_geom, "A", ypos=1.2)
    mark_subplots(ax_eeg_top, "C")
    mark_subplots(ax_var, "D")

    corr_name = "corr" if pop_correlated else "ucorr"
    plt.savefig("population_eeg_no_bg_{:s}_{:1.0f}cm2.png".format(
        corr_name, corr_area * 1e-8), facecolor="w", dpi=300)
    

In [None]:
# Population cap areas to simulate (cm² -> µm²). The last entry is the
# whole innermost sphere.
corr_areas = np.array([0.5, 1, 5, 10, 20, 50, 100, 200, 300, 400,
                       500, 600, 700, 800, 900, cortical_area * 1e-8
                       ]) * 1e8

num_trials = 50
pops_correlated = [True, False]
# data_dict[corr_type][corr_area] holds the trial-0 results used for plotting.
data_dict = {corr_type: {} for corr_type in range(len(pops_correlated))}
# Top-electrode EEG variance for every (corr_type, trial, area).
eeg_vars = np.zeros((len(pops_correlated), num_trials, len(corr_areas)))

# The population masks depend only on the area, so compute them once.
corr_pop_idxs_per_area = [return_corr_pop_idxs(corr_area) for corr_area in corr_areas]

for trial in range(num_trials):
    # A given trial uses the same seed for both correlation types, so the
    # population signals are identical. Draw them once per trial; this is
    # the dominant cost of the loop.
    np.random.seed(1234 + trial * 10)
    print("trial number ", trial)
    xis, corr_sig = calculate_all_xis()

    for corr_type, pop_correlated in enumerate(pops_correlated):
        print("Correlated population: ", pop_correlated)
        for c_idx, corr_area in enumerate(corr_areas):

            print("correlated area: ", corr_area * 1e-8, "cm²")
            corr_pop_idxs = corr_pop_idxs_per_area[c_idx]
            eeg = run_simulation(corr_pop_idxs, pop_correlated)
            eeg_vars[corr_type, trial, c_idx] = np.var(eeg[top_elec_idx])

            if trial == 0:
                data_dict[corr_type][corr_area] = {
                    "pop_correlated": pop_correlated,
                    "eeg": eeg,
                    "corr_pop_idxs": corr_pop_idxs,
                    # Copy so the stored trace does not keep all of xis alive.
                    "corr_sig": corr_sig.copy(),
                }
                plot_results(corr_area=corr_area, **data_dict[corr_type][corr_area])


In [None]:
def make_summary_figure(corr_type):
    """Plot and save the summary figure for one correlation type.

    There is one row per population area (1, 10 and 100 cm²). Each row shows
    an EEG map at one time point with the active cap boundary dashed, plus
    theory and simulated variance maps. The bottom panel shows the
    top-electrode EEG variance against population area: simulated (mean ± std
    over trials), theory and analytic.

    The figure is saved as ``population_eeg_summary_no_BG_{corr|ucorr}.pdf``.

    Parameters
    ----------
    corr_type : int
        Index into ``pops_correlated``, ``data_dict`` and ``eeg_vars``.
    """
    pop_correlated = pops_correlated[corr_type]

    show_corr_areas = np.array([1, 10, 100]) * 1e8
    plt.close("all")
    fig = plt.figure(figsize=[3, 1.7 * len(show_corr_areas)])
    fig.subplots_adjust(bottom=0.1, right=0.94, left=0.05,
                        top=0.96, wspace=.6, hspace=0.3)

    ax_h = 0.45 / len(show_corr_areas)
    ax_w = 0.2
    for c_idx, corr_area in enumerate(show_corr_areas):

        ax_y0 = .8 - 1.6 * ax_h * c_idx

        fig.text(0.05, ax_y0 + ax_h / 2, "{:1.0f}\ncm²".format(
            corr_area * 1e-8), va="center", ha="center")

        corr_pop_idxs = data_dict[corr_type][corr_area]["corr_pop_idxs"]
        eeg = data_dict[corr_type][corr_area]["eeg"]
        var_eeg_theory = calc_var_from_theory(Fs, corr_pop_idxs, pop_correlated)
        var_eeg_observed = np.var(eeg, axis=1)

        # Shared (integer-rounded) upper limit for both variance maps.
        max_var = np.ceil(np.max([var_eeg_observed, var_eeg_theory]))

        ax_elecs_var1 = fig.add_axes([0.56, ax_y0, ax_w, ax_h],
                                     frameon=False, aspect=1,
                                     xticks=[], yticks=[], title="theory",
                                     xlim=[-120000, 120000],
                                     ylim=[-120000, 120000])
        ax_elecs_var2 = fig.add_axes([0.56 + 1.1*ax_w, ax_y0, ax_w, ax_h],
                                     frameon=False, aspect=1,
                                     xticks=[], yticks=[], title="simulated",
                                     xlim=[-120000, 120000],
                                     ylim=[-120000, 120000])

        mark_subplots(ax_elecs_var1, ["B", "D", "F"][c_idx])
        img = plot_eeg_var_colorplot(var_eeg_theory[upper_idxs],
                                     max_var,
                                     elec_x[upper_idxs],
                                     elec_y[upper_idxs], ax_elecs_var1)
        plot_eeg_var_colorplot(var_eeg_observed[upper_idxs],
                               max_var,
                               elec_x[upper_idxs], elec_y[upper_idxs], ax_elecs_var2)

        cax = fig.add_axes([0.57, ax_y0 - 0.01, 0.4, 0.01])
        cbar = plt.colorbar(img, cax=cax, label="µV²", orientation="horizontal")
        cbar.set_ticks([0, int(max_var/2), int(max_var)])

        # One EEG map at the third of four evenly spaced time points.
        vmax = np.ceil(np.max(np.abs(eeg)))
        time_idxs = np.linspace(0, num_tsteps - 1, 5, dtype=int)[:-1][2:3]

        # Projected (top-view) radius of the active cap boundary.
        plot_pop_r = pop_radial_loc * np.sin(corr_phi_func(corr_area))

        for i, time_idx in enumerate(time_idxs):

            ax_x0 = 0.17 + ax_w * 1.1 * i
            ax_eeg = fig.add_axes([ax_x0, ax_y0, ax_w, ax_h],
                                  frameon=False, aspect=1,
                                  xticks=[], yticks=[],
                                  xlim=[-120000, 120000],
                                  ylim=[-120000, 120000])
            img = plot_eeg_colorplot(eeg[upper_idxs, time_idx], vmax,
                                     elec_x[upper_idxs], elec_y[upper_idxs], ax_eeg)

            ax_eeg.plot(plot_pop_r * np.cos(np.linspace(0, 2 * np.pi, 50)),
                        plot_pop_r * np.sin(np.linspace(0, 2 * np.pi, 50)),
                        c='gray', lw=0.5, ls='--', zorder=1000)

            if i == 0:
                mark_subplots(ax_eeg, ["A", "C", "E"][c_idx], xpos=0.1)

        cax = fig.add_axes([0.07, ax_y0 - 0.01, 0.4, 0.01])
        cbar = plt.colorbar(img, cax=cax, label="µV", orientation="horizontal")
        cbar.set_ticks([-int(vmax), -int(vmax/2), 0, int(vmax/2), int(vmax)])

    ax_var_corr = fig.add_axes([0.17, 0.07, 0.5, 0.1],
                               xlabel="$A_P$ (cm²)",
                               ylabel="EEG variance (µV²)")

    # Top-electrode variance per area: per-population sum (theory) and the
    # closed-form approximation (analytic).
    vars_theory = np.zeros(len(corr_areas))
    vars_analytic = np.zeros(len(corr_areas))
    for c_idx, corr_area in enumerate(corr_areas):
        corr_pop_idxs = data_dict[corr_type][corr_area]["corr_pop_idxs"]
        vars_theory[c_idx] = calc_var_from_theory(Fs, corr_pop_idxs, pop_correlated)[top_elec_idx]
        vars_analytic[c_idx] = calculate_analytic_var(corr_area, pop_correlated)

    # Draw on ax_var_corr explicitly rather than on pyplot's "current" axes.
    l1 = ax_var_corr.errorbar(corr_areas * 1e-8, np.mean(eeg_vars[corr_type], axis=0),
                              np.std(eeg_vars[corr_type], axis=0), c='k', zorder=1)
    l2, = ax_var_corr.plot(corr_areas * 1e-8, vars_theory, 'gray', ls='--', zorder=2)
    l3, = ax_var_corr.plot(corr_areas * 1e-8, vars_analytic, 'blue', zorder=3)

    fig.legend([l1, l2, l3], ["simulated", "theory", "analytic"],
               frameon=False, ncol=1, loc=(0.68, 0.1))
    mark_subplots(ax_var_corr, "G", xpos=-0.05)
    simplify_axes([ax_var_corr])

    corr_name = "corr" if pop_correlated else "ucorr"
    plt.savefig("population_eeg_summary_no_BG_%s.pdf" % corr_name, facecolor="w")
    
# One summary figure per correlation type (correlated first, then uncorrelated).
for corr_type in range(len(pops_correlated)):
    make_summary_figure(corr_type)