In [None]:
# Notebook set-up: render matplotlib figures inline and automatically reload
# edited project modules (e.g. example_network_methods, plotting) before each
# cell is executed.
%matplotlib inline
%load_ext autoreload
%autoreload 2

In [None]:
import os
from parameters import ParameterSpace, ParameterSet
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
import h5py
import scipy.signal as ss
from hay2011_network_parameters import (networkParameters, population_names,
                                        population_sizes)
from lfpykernels import KernelApprox, GaussCylinderPotential, KernelApproxCurrentDipoleMoment
import neuron
import example_network_methods as methods
import hay2011_network_parameters as params
from plotting import remove_axis_junk
import scipy.stats as st
from copy import deepcopy
from plotting import draw_lineplot, annotate_subplot
import plotting
from lfpykit import CurrentDipoleMoment
import json
import hashlib
import pandas as pd
import seaborn as sb
from time import time

In [None]:
# Apply the project-wide matplotlib style and pull the shared figure-size
# constants from the `plotting` module.
plt.rcParams.update(plotting.rcParams)
figwidth, golden_ratio = plotting.figwidth, plotting.golden_ratio

In [None]:
# Parameter spaces read from Hay2011_PS{0,1,2}.txt. PS0 supplies `n_ext` for
# the ground-truth run, PS1 the network simulations compared against it, and
# PS2 the parameter sets used for the kernel computations.
PS0, PS1, PS2 = (ParameterSpace(f'Hay2011_PS{n}.txt') for n in range(3))

In [None]:
# load mod files
# Load the compiled NEURON mechanisms found in the 'mod' directory; presumably
# required by the cell templates instantiated further down.
neuron.load_mechanisms('mod')

In [None]:
# Analysis time constants; `dt` is the network simulation time step.
dt = networkParameters['dt']
TRANSIENT = 2000  # initial transient; excluded from several analyses and used as kernel activation time
tau = 100  # max time lag relative to spike for kernel predictions
tau_trunc = 50  # max time lag shown in plots

In [None]:
# Plotted signals are downsampled to a time resolution of dt * decimate_ratio.
decimate_ratio: int = 4

In [None]:
# Spectral-estimation settings shared by ss.welch / plt.mlab.psd / csd
Fs = 1000 / dt  # sampling frequency (Hz), with dt in ms
NFFT = 2 * 1024  # samples per segment
noverlap = 2 * 768  # overlapping samples between consecutive segments
detrend = False

In [None]:
# Low-pass elliptic IIR filter settings
N, rp, rs = 2, 0.1, 40.  # filter order, passband ripple (dB), min. stop-band attenuation (dB)
fc = 100.  # critical frequency (Hz)
btype = 'lp'  # filter type (low-pass)

# filter coefficients in second-order-sections ('sos') format
sos_ellip = ss.ellip(N, rp, rs, fc, btype=btype, fs=Fs, output='sos')

In [None]:
# Marker symbols for line plots.
# NOTE(review): `markers` is not referenced in the visible part of this notebook.
markers = 'od*p'

In [None]:
# Figure out which real ("ground truth") LFP to compare with.
# The ground-truth run is identified by the synaptic weights of the first PS1
# parameter set together with `n_ext` from PS0. Its output directory is named
# after the md5 hash of the sorted-JSON serialization of those parameters.
# Fail loudly if PS1 is empty instead of silently reusing a stale
# OUTPUTPATH_REAL left over from an earlier kernel state.
pset = next(iter(PS1.iter_inner()), None)
if pset is None:
    raise ValueError('PS1 contains no parameter sets; cannot locate the '
                     'ground-truth dataset')

weight_EE = pset['weight_EE']
weight_IE = pset['weight_IE']
weight_EI = pset['weight_EI']
weight_II = pset['weight_II']
weight_scaling = pset['weight_scaling']
pset_0 = ParameterSet(dict(weight_EE=weight_EE,
                           weight_IE=weight_IE,
                           weight_EI=weight_EI,
                           weight_II=weight_II,
                           weight_scaling=weight_scaling,
                           n_ext=PS0['n_ext'].value))
js_0 = json.dumps(pset_0, sort_keys=True).encode()
md5_0 = hashlib.md5(js_0).hexdigest()
OUTPUTPATH_REAL = os.path.join('output', md5_0)
print(f'comparing with ground truth dataset: {OUTPUTPATH_REAL}')

In [None]:
# Firing-rate time series of the ground-truth network, expressed as spike
# counts per time bin of width dt (bin edges offset by dt / 2 so that bins
# are centred on multiples of dt).
tstop = networkParameters['tstop']
bins = (np.arange(0, tstop / dt + 2)
        * dt - dt / 2)
nu_X = dict()
with h5py.File(os.path.join(OUTPUTPATH_REAL, 'spikes.h5'), 'r') as spike_file:
    for X in params.population_names:
        spike_times = np.concatenate(spike_file[X]['times'])
        counts = np.histogram(spike_times, bins=bins)[0]
        nu_X[X] = counts.astype(float)

In [None]:
# Plot a 100 ms excerpt (900-1000 ms) of the ground-truth firing-rate time series,
# one panel per population.
inds = (bins[:-1] >= 900) & (bins[:-1] <= 1000)
t_window = bins[:-1][inds]
fig, axes = plt.subplots(2, 1, sharex=True, sharey=True, figsize=(16, 9))
for ax, X in zip(axes, population_names):
    ax.step(t_window, nu_X[X][inds])
    ax.set_title(r'$\nu_{%s}(t)$' % X)
    ax.set_ylabel(r'$\nu_X$ (# spikes/$\Delta t$)')
axes[-1].set_xlabel('$t$ (ms)')

In [None]:
# Firing-rate power spectral densities (Welch's method) of the ground-truth network

# normalized Gaussian smoothing filter coefficients
# NOTE(review): `w` is not used in the visible part of this notebook.
std = 1  # (ms) standard deviation of Gaussian filter
width = 100  # (ms) width of Gaussian filter
w = ss.windows.gaussian(M=int(width / dt) + 1, std=std / dt)
w = w / w.sum()

cutoff = 200.  # highest frequency shown (Hz)

fig, ax = plt.subplots(1, 1, figsize=(8, 8))
for X in population_names:
    nu = nu_X[X] / dt  # spike counts per bin divided by the bin width dt
    # NOTE(review): `detrend` from the spectral-settings cell is not passed
    # here, so ss.welch applies its default ('constant') detrending.
    freqs, Pxx = ss.welch(nu, fs=Fs, nperseg=NFFT, noverlap=noverlap)
    below_cutoff = freqs <= cutoff
    ax.semilogy(freqs[below_cutoff], Pxx[below_cutoff], label=X)
ax.set_xlabel('$f$ (Hz)', labelpad=0)
ax.set_ylabel(r'PSD$_\nu$ (s$^{-2}$/Hz)')
ax.legend()

In [None]:
# compute hybrid-scheme kernels
#
# For each PS2 parameter set, load the stored responses of every postsynaptic
# population Y to the activation of presynaptic population X (at time t_E or
# t_I) for the RecExtElectrode and CurrentDipoleMoment probes. Then:
#   * cut a window of lags [-tau, tau] around the activation time,
#   * divide by the presynaptic population size N_X,
#   * subtract the value at lag 0 and zero all negative lags.
# Result: H_YX_hybrid[md5]['Y:X'][probe] -> array (channels, lags).

tau_idx = int(tau // dt)  # index of lag 0 within the [-tau, tau] window

# kernel container
H_YX_hybrid = dict()
for pset in PS2.iter_inner():
    # output directory is named by the md5 hash of the sorted json dictionary
    md5 = hashlib.md5(json.dumps(pset, sort_keys=True).encode()).hexdigest()
    OUTPUTPATH = os.path.join('output', md5)

    H_YX_hybrid[md5] = dict()

    for X, t, N_X in zip(population_names,
                         [pset['t_E'], pset['t_I']],
                         population_sizes):
        # sample indices spanning lags [-tau, tau] around activation time t
        inds = (np.arange(-tau // dt,
                          tau // dt + 1)
                + t // dt).astype(int)
        for Y in population_names:
            H_YX_probe = dict()
            for fname, label in zip(
                    ['RecExtElectrode.h5', 'CurrentDipoleMoment.h5'],
                    ['RecExtElectrode', 'CurrentDipoleMoment']):
                with h5py.File(os.path.join(OUTPUTPATH, fname), 'r') as f:
                    H_YX = f['data'][Y][:, inds] / N_X
                H_YX = (H_YX.T - H_YX[:, tau_idx]).T  # remove lag-0 offset
                H_YX[:, :tau_idx] = 0  # kernels are causal
                H_YX_probe[label] = H_YX

            H_YX_hybrid[md5]['{}:{}'.format(Y, X)] = H_YX_probe

In [None]:
# flag; if True, use the median membrane potential per compartment for kernel predictions 
# Selects how Vrest is estimated in the kernel-approximation cell:
#   False -> one scalar, the median of the ground-truth somatic potentials
#            (somav.h5) of the postsynaptic population;
#   True  -> one value per compartment computed from vmem.h5.
# NOTE(review): the per-compartment branch computes the mean over time, not
# the median stated above; confirm which is intended.
perseg_Vrest = False

In [None]:
# Compare the per-compartment mean and median membrane potential of population
# E in the ground-truth network (cf. the `perseg_Vrest` flag).
# The file is opened in a context manager so it is closed again, and the dataset
# is read only once.
# NOTE(review): TRANSIENT is used as a sample index into the recorded time axis
# here, while elsewhere it is used as a time (t_X, T=(TRANSIENT, ...)); confirm
# that vmem.h5 is sampled such that both agree.
with h5py.File(os.path.join(OUTPUTPATH_REAL, 'vmem.h5'), 'r') as f:
    vmem_E = f['E'][()][:, TRANSIENT:]

fig, ax = plt.subplots()
ax.plot(vmem_E.mean(axis=-1), label='mean')
ax.plot(np.median(vmem_E, axis=-1), label='median')
ax.axis('tight')
ax.set_xlabel('compartment index')
ax.set_ylabel('membrane potential')
ax.legend();

In [None]:
# Compute spike-LFP and spike-dipole moment kernel approximations using the KernelApprox class
#
# For every PS2 parameter set and every (presynaptic X, postsynaptic Y) pair a
# KernelApprox object is created, and its kernels are predicted for the
# GaussCylinderPotential and CurrentDipoleMoment probes. Results are stored as
# H_YX_pred[md5]['Y:X'].
# Wall-clock durations of the 'setup', 'create' and 'simulate' steps are summed
# per parameter set and collected in the DataFrame H_YX_pred_times (columns
# 'step', 'time_s'). DataFrame.append was removed in pandas 2.0, so timings are
# gathered as plain records and concatenated once at the end.

# kernel container
H_YX_pred = dict()
pred_times_per_pset = []  # one summed timing table per parameter set
for pset in PS2.iter_inner():
    # sorted json dictionary
    js = json.dumps(pset, sort_keys=True).encode()
    md5 = hashlib.md5(js).hexdigest()

    # tic toc
    tic = time()

    # parameters
    weight_EE = pset['weight_EE']
    weight_IE = pset['weight_IE']
    weight_EI = pset['weight_EI']
    weight_II = pset['weight_II']
    weight_scaling = pset['weight_scaling']
    biophys = pset['biophys']
    n_ext = pset['n_ext']
    g_eff = pset['g_eff']

    t_X = TRANSIENT  # presynaptic activation time

    # define biophysical membrane properties
    # NOTE(review): `custom_fun` is not used further down in this cell (the cell
    # parameters below set their own 'custom_fun'), so this chain effectively
    # only validates `biophys`. 'frozen_no_Ih' passes here but is rejected when
    # the cell parameters are assembled.
    if biophys == 'pas':
        custom_fun = [methods.set_pas_hay2011, methods.make_cell_uniform]
    elif biophys == 'frozen':
        custom_fun = [methods.set_frozen_hay2011, methods.make_cell_uniform]
    elif biophys == 'frozen_no_Ih':
        custom_fun = [methods.set_frozen_hay2011_no_Ih, methods.make_cell_uniform]
    elif biophys == 'lin':
        custom_fun = [methods.set_Ih_linearized_hay2011, methods.make_cell_uniform]
    else:
        raise NotImplementedError

    # synapse max. conductance; indexed below as weights[ii][j] with ii the
    # presynaptic and j the postsynaptic population index
    weights = np.array([[weight_EE, weight_IE],
                        [weight_EI, weight_II]]) * weight_scaling

    # class RecExtElectrode/PointSourcePotential parameters:
    electrodeParameters = params.electrodeParameters.copy()
    for key in ['r', 'n', 'N', 'method']:
        del electrodeParameters[key]

    # Not using RecExtElectrode class as we anyway average potential in
    # space for each source element.

    # Predictor assuming planar disk source elements convolved with Gaussian
    # along z-axis
    gauss_cyl_potential = GaussCylinderPotential(
        cell=None,
        z=electrodeParameters['z'],
        sigma=electrodeParameters['sigma'],
        R=params.populationParameters['pop_args']['radius'],
        sigma_z=params.populationParameters['pop_args']['scale'],
        )

    # set up recording of current dipole moments.
    current_dipole_moment = CurrentDipoleMoment(cell=None)

    # Compute average firing rate of presynaptic populations X
    mean_nu_X = methods.compute_mean_nu_X(params, OUTPUTPATH_REAL,
                                          TRANSIENT=TRANSIENT)

    # tic-tac
    tac = time()
    step_times = [{'step': 'setup', 'time_s': tac - tic}]

    # kernel container
    H_YX_pred[md5] = dict()

    n_pop = len(params.population_names)
    for i, (X, N_X) in enumerate(zip(params.population_names,
                                     params.population_sizes)):
        for j, (Y, N_Y) in enumerate(zip(params.population_names,
                                         params.population_sizes)):
            # tic tac
            tic = time()

            # Extract median soma voltages from actual network simulation and
            # assume this value corresponds to Vrest.
            # NOTE(review): TRANSIENT is used as a sample index here; confirm
            # the recording resolution of somav.h5 / vmem.h5.
            if not perseg_Vrest:
                with h5py.File(os.path.join(OUTPUTPATH_REAL, 'somav.h5'),
                               'r') as f:
                    Vrest = np.median(f[Y][()][:, TRANSIENT:])
            else:  # perseg_Vrest == True; NOTE(review): mean, not median
                with h5py.File(os.path.join(OUTPUTPATH_REAL, 'vmem.h5'),
                               'r') as f:
                    Vrest = np.mean(f[Y][()][:, TRANSIENT:], axis=-1)

            # cell parameters for the selected biophysics variant
            cellParameters = deepcopy(params.cellParameters[Y])
            if biophys not in ('frozen', 'lin', 'pas'):
                raise NotImplementedError(f'biophys={biophys} not implemented')
            if Y == 'E':
                # dedicated L5PC template files per variant; the 'pas' variant
                # only applies make_cell_uniform
                if biophys == 'pas':
                    cell_funs = [methods.make_cell_uniform]
                else:
                    cell_funs = [methods.set_V_R, methods.make_cell_uniform]
                cellParameters.update({
                    'templatefile': [
                        f'L5bPCmodelsEH/models/L5PCbiophys3_{biophys}.hoc',
                        f'L5bPCmodelsEH/models/L5PCtemplate_{biophys}.hoc'
                        ],
                    'templatename': f'L5PCtemplate_{biophys}',
                    'custom_fun': cell_funs,
                    'custom_fun_args': [dict(Vrest=Vrest)] * len(cell_funs),
                })
            elif Y == 'I':
                set_biophys = {
                    'frozen': methods.set_frozen_hay2011,
                    'lin': methods.set_Ih_linearized_hay2011,
                    'pas': methods.set_pas_hay2011,
                }[biophys]
                cellParameters.update({
                    'custom_fun': [
                        set_biophys,
                        methods.make_cell_uniform
                        ],
                    'custom_fun_args': [dict(Vrest=Vrest)] * 2,
                })
            else:
                raise Exception(f'population {Y} not recognized')

            # population parameters
            populationParameters = deepcopy(params.populationParameters)
            populationParameters['rotation_args'] = deepcopy(params.rotation_args[Y])
            populationParameters['cell_args'] = cellParameters

            # per-presynaptic-population inputs must be lists
            synapseParameters = [
                dict(weight=weights[ii][j],
                     syntype='Exp2Syn',
                     **params.synapseParameters[ii][j])
                for ii in range(n_pop)]
            synapsePositionArguments = [
                params.synapsePositionArguments[ii][j]
                for ii in range(n_pop)]

            # Create kernel approximator object
            # NOTE(review): C_YX takes row i (presynaptic X) of
            # connectionProbability, whereas all other per-Y inputs are taken
            # as [ii][j] over presynaptic ii; confirm the indexing convention
            # of connectionProbability.
            kernel = KernelApprox(
                X=params.population_names,
                Y=Y,
                N_X=np.array(params.population_sizes),
                N_Y=N_Y,
                C_YX=np.array(params.connectionProbability[i]),
                cellParameters=cellParameters,
                populationParameters=populationParameters['pop_args'],
                rotationParameters=params.rotation_args[Y],
                multapseFunction=params.multapseFunction,
                multapseParameters=[params.multapseArguments[ii][j]
                                    for ii in range(n_pop)],
                delayFunction=params.delayFunction,
                delayParameters=[params.delayArguments[ii][j]
                                 for ii in range(n_pop)],
                synapseParameters=synapseParameters,
                synapsePositionArguments=synapsePositionArguments,
                extSynapseParameters=params.extSynapseParameters,
                nu_ext=1000. / params.netstim_interval,
                n_ext=n_ext[j],
                nu_X=mean_nu_X,
            )

            # tic-tac
            tac = time()
            step_times.append({'step': 'create', 'time_s': tac - tic})

            # make kernel predictions
            H_YX_pred[md5]['{}:{}'.format(Y, X)] = kernel.get_kernel(
                probes=[gauss_cyl_potential, current_dipole_moment],
                Vrest=Vrest, dt=dt, X=X, t_X=t_X, tau=tau,
                g_eff=g_eff,
            )

            # tic-tac-toc
            toc = time()
            step_times.append({'step': 'simulate', 'time_s': toc - tac})

    # total wall-clock time per step for this parameter set
    H_YX_pred_times_pset = (pd.DataFrame(step_times)
                            .groupby('step', as_index=False)
                            .agg({'time_s': 'sum'}))
    pred_times_per_pset.append(H_YX_pred_times_pset)

H_YX_pred_times = (pd.concat(pred_times_per_pset, ignore_index=True)
                   if pred_times_per_pset
                   else pd.DataFrame(columns=['step', 'time_s']))

In [None]:
# Plot spike-LFP and spike-dipole moment kernels
# A: hybrid-scheme kernels H_YX, B: predicted kernels \hat{H}_YX.
# Column index i * 2 + j with i the presynaptic (X) and j the postsynaptic (Y)
# population index. Row 0: extracellular potential; row 1: z-component of the
# current dipole moment. Only lags 0 <= tau' < tau_trunc are shown.


def make_kernel_axes(fig, gs, top_rows, bottom_row, ncols=4):
    """Add a 2 x `ncols` grid of axes to `fig` and return it as an object array.

    Row 0 spans the GridSpec rows `top_rows` (a slice); row 1 occupies the
    single GridSpec row `bottom_row`. Every axes shares the x-axis of
    ``axes[0, 0]``; within each row, axes share the y-axis of the row's first
    axes.
    """
    axes = np.empty((2, ncols), dtype=object)
    for col in range(ncols):
        share = {} if col == 0 else dict(sharey=axes[0, 0], sharex=axes[0, 0])
        axes[0, col] = fig.add_subplot(gs[top_rows, col], **share)
    for col in range(ncols):
        share = dict(sharex=axes[0, 0])
        if col > 0:
            share['sharey'] = axes[1, 0]
        axes[1, col] = fig.add_subplot(gs[bottom_row, col], **share)
    return axes


fig = plt.figure(figsize=(figwidth, figwidth))
gs = GridSpec(15, 4)
lag_slice = slice(int(tau // dt), int((tau + tau_trunc) // dt))  # lags [0, tau_trunc)

# Hybrid scheme kernels
axes = make_kernel_axes(fig, gs, slice(None, 6), 6)
vlims = np.zeros((2, 4))  # per-panel scale limits, reused for predicted kernels
for k, pset in enumerate(PS2.iter_inner()):
    # sorted json dictionary
    md5 = hashlib.md5(json.dumps(pset, sort_keys=True).encode()).hexdigest()

    for i, X in enumerate(population_names):
        for j, Y in enumerate(population_names):
            col = i * 2 + j
            for h, (unit, label) in enumerate(
                    zip(['mV', 'nAµm'],
                        ['RecExtElectrode', 'CurrentDipoleMoment'])):
                ax = axes[h, col]
                H_YX = H_YX_hybrid[md5]['{}:{}'.format(Y, X)][label][:, lag_slice]

                # deal with current dipole moment
                if label == 'CurrentDipoleMoment':
                    scaling = 1E-4  # nAum --> nAcm unit conversion
                    unit = 'nAcm'
                    H_YX = H_YX[-1:, :]  # show only z-component
                else:
                    scaling = 1

                # first parameter set determines the scale; later ones reuse it
                vlims[h, col] = draw_lineplot(
                    ax,
                    H_YX * scaling,
                    dt=dt,
                    T=(0, tau),
                    scaling_factor=1.,
                    vlimround=None if vlims[h, col] == 0 else vlims[h, col],
                    label=f"biophys:{pset['biophys']}",
                    scalebar=(k == 0),
                    unit=unit,
                    ylabels=True,
                    color=f'C{k + PS1.num_conditions()}',
                    ztransform=False
                    )
                if label == 'CurrentDipoleMoment':
                    ax.set_yticklabels(['$P_z$'])
                if h == 0:
                    ax.set_title(r'$H_\mathrm{%s %s}(\mathbf{R}, \tau)$' % (Y, X))
                ax.set_ylabel('')
                if col != 0:
                    plt.setp(ax.get_yticklabels(), visible=False)
                # the predicted-kernel panels below carry the time axis
                ax.set_xlabel('')
                plt.setp(ax.get_xticklabels(), visible=False)

axes[0, 0].legend(loc=1)
annotate_subplot(axes[0, 0], ncols=4, nrows=1, letter='A', linear_offset=0.02)


# predicted kernels
axes = make_kernel_axes(fig, gs, slice(8, 14), 14)
for k, pset in enumerate(PS2.iter_inner()):
    # sorted json dictionary
    md5 = hashlib.md5(json.dumps(pset, sort_keys=True).encode()).hexdigest()

    for i, X in enumerate(population_names):
        for j, Y in enumerate(population_names):
            col = i * 2 + j
            for h, (unit, label) in enumerate(
                    zip(['mV', 'nAµm'],
                        ['GaussCylinderPotential', 'CurrentDipoleMoment'])):
                ax = axes[h, col]
                H_YX = H_YX_pred[md5]['{}:{}'.format(Y, X)][label][:, lag_slice]

                # deal with current dipole moment
                if label == 'CurrentDipoleMoment':
                    scaling = 1E-4  # nAum --> nAcm unit conversion
                    unit = 'nAcm'
                    H_YX = H_YX[-1:, :]  # show only z-component
                else:
                    scaling = 1

                draw_lineplot(
                    ax,
                    H_YX * scaling,
                    dt=dt,
                    T=(0, tau),
                    scaling_factor=1.,
                    vlimround=vlims[h, col],
                    label=f"biophys:{pset['biophys']}",
                    scalebar=(k == 0),
                    unit=unit,
                    ylabels=True,
                    color=f'C{k + PS1.num_conditions() + PS2.num_conditions()}',
                    ztransform=False
                    )
                if label == 'CurrentDipoleMoment':
                    ax.set_yticklabels(['$P_z$'])
                if h == 0:
                    ax.set_title(r'$\hat{H}_\mathrm{%s %s}(\mathbf{R}, \tau)$' % (Y, X))
                ax.set_ylabel('')
                if col != 0:
                    plt.setp(ax.get_yticklabels(), visible=False)
                if h == 1:
                    ax.set_xlabel(r'$\tau$ (ms)')
                else:
                    ax.set_xlabel('')
                    plt.setp(ax.get_xticklabels(), visible=False)

axes[0, 0].legend(loc=1)
annotate_subplot(axes[0, 0], ncols=4, nrows=1, letter='B', linear_offset=0.02)

# NOTE(review): the following plotting cell also saves to figures/figure13.pdf
# and overwrites this file; confirm which figure should carry this name.
os.makedirs('figures', exist_ok=True)
fig.savefig(os.path.join('figures', 'figure13.pdf'), bbox_inches='tight')

In [None]:
# Which contribution to the signal to visualize:
# Used as a field of the recorded 'data' arrays and as the postsynaptic Y in
# the 'Y:X' kernel keys of the following cells.
# NOTE(review): 'imem' is not a population name and would not match any 'Y:X'
# kernel key; presumably it is only meaningful for the ground-truth comparison.
data_entry = 'E'  # or 'imem' or 'I'

In [None]:
# compare averaged (hybrid-scheme) spike-signal and deterministic spike-signal kernels
# Columns 0-1: hybrid-scheme kernels H_YX; columns 2-3: predicted
# (deterministic) kernels \hat{H}_YX, for postsynaptic Y = data_entry and each
# presynaptic X. Scale limits come from `vlims` of the previous figure, whose
# column index is X_index * 2 + Y_index. The Y index is looked up explicitly,
# so data_entry = 'I' uses the I-row limits rather than always the E-row ones.
fig = plt.figure(figsize=(figwidth, figwidth / 2))
# create subplots
gs = GridSpec(7, 4)
axes = np.empty((2, 4), dtype=object)
for col in range(4):
    if col == 0:
        axes[0, col] = fig.add_subplot(gs[:6, col])
    else:
        axes[0, col] = fig.add_subplot(gs[:6, col], sharey=axes[0, 0], sharex=axes[0, 0])
for col in range(4):
    if col == 0:
        axes[1, col] = fig.add_subplot(gs[6, col], sharex=axes[0, 0])
    else:
        axes[1, col] = fig.add_subplot(gs[6, col], sharey=axes[1, 0], sharex=axes[0, 0])

lag_slice = slice(int(tau // dt), int((tau + tau_trunc) // dt))  # lags [0, tau_trunc)
Y = data_entry
j_Y = population_names.index(Y)  # row of Y in the (Y, X) kernel layout

# (kernel container, probe keys, first column, title format, colour offset)
kernel_sets = [
    (H_YX_hybrid, ['RecExtElectrode', 'CurrentDipoleMoment'], 0,
     r'$H_\mathrm{%s %s}(\mathbf{R}, \tau)$',
     PS1.num_conditions()),
    (H_YX_pred, ['GaussCylinderPotential', 'CurrentDipoleMoment'], 2,
     r'$\hat{H}_\mathrm{%s %s}(\mathbf{R}, \tau)$',
     PS1.num_conditions() + PS2.num_conditions()),
]
for kernels, probes, col0, title_fmt, color_offset in kernel_sets:
    for k, pset in enumerate(PS2.iter_inner()):
        # sorted json dictionary
        md5 = hashlib.md5(json.dumps(pset, sort_keys=True).encode()).hexdigest()

        for i, X in enumerate(population_names):
            # plot responses, iterate over probes
            for h, (unit, probe) in enumerate(zip(['mV', 'nAµm'], probes)):
                H_YX = kernels[md5][f'{Y}:{X}'][probe][:, lag_slice]

                if probe == 'CurrentDipoleMoment':
                    scaling = 1E-4  # nAum --> nAcm unit conversion
                    unit = 'nAcm'
                    H_YX = H_YX[-1:, :]  # show only z-component
                else:
                    scaling = 1

                ax = axes[h, col0 + i]
                draw_lineplot(
                    ax,
                    H_YX * scaling,
                    dt=dt,
                    T=(0, tau),
                    scaling_factor=1.,
                    vlimround=vlims[h, i * 2 + j_Y],
                    label=f"biophys:{pset['biophys']}",
                    scalebar=(k == 0),
                    unit=unit,
                    ylabels=True,
                    color=f'C{k + color_offset}',
                    ztransform=False
                    )
                if probe == 'CurrentDipoleMoment':
                    ax.set_yticklabels(['$P_z$'])
                if h == 0:
                    ax.set_title(title_fmt % (Y, X))
                ax.set_ylabel('')
                if col0 + i != 0:
                    plt.setp(ax.get_yticklabels(), visible=False)
                if h == 1:
                    ax.set_xlabel(r'$\tau$ (ms)')
                else:
                    ax.set_xlabel('')
                    plt.setp(ax.get_xticklabels(), visible=False)

axes[0, 0].legend(loc=1)
axes[0, 2].legend(loc=1)

# NOTE(review): this overwrites figures/figure13.pdf written by the previous
# plotting cell; confirm which figure should carry this name.
os.makedirs('figures', exist_ok=True)
fig.savefig(os.path.join('figures', 'figure13.pdf'), bbox_inches='tight')

In [None]:
# Plot hybrid-scheme vs. ground truth data
# One column per PS1 parameter set. Top row: extracellular potential; bottom
# row: z-component of the current dipole moment. Each panel shows the
# `data_entry` contribution over [TRANSIENT, TRANSIENT + 200], downsampled by
# `decimate_ratio`, with the ground-truth signal (black) overlaid on the
# corresponding PS1 network output.

# create figure
fig = plt.figure(figsize=(figwidth, figwidth / 2))
fig.subplots_adjust(wspace=0.25)
# create subplots
ncols = PS1.num_conditions()
gs = GridSpec(7, ncols)
axes = np.empty((2, ncols), dtype=object)
for col in range(ncols):
    if col == 0:
        axes[0, col] = fig.add_subplot(gs[:6, col])
    else:
        axes[0, col] = fig.add_subplot(gs[:6, col], sharey=axes[0, 0], sharex=axes[0, 0])
for col in range(ncols):
    if col == 0:
        axes[1, col] = fig.add_subplot(gs[6, col], sharex=axes[0, 0])
    else:
        axes[1, col] = fig.add_subplot(gs[6, col], sharey=axes[1, 0], sharex=axes[0, 0])

for i, ax in enumerate(axes[0, :]):
    annotate_subplot(ax, ncols=ncols, nrows=3, letter='ABCDEFGHIJ'[i], linear_offset=0.02)

# parameter keys containing any of these substrings are left out of legend labels
skipped_keys = ('weight', 'n_ext', 'i_syn', 'perseg_Vrest')

# compare summed extracellular signals
for j, (fname, probe, unit, vlimround) in enumerate(zip(
        ['RecExtElectrode.h5', 'CurrentDipoleMoment.h5'],
        ['GaussCylinderPotential', 'CurrentDipoleMoment'],
        ['mV', 'nAµm'],
        [2**-1, 2**4])):
    is_dipole = probe == 'CurrentDipoleMoment'
    if is_dipole:
        scaling = 1E-4  # nAum --> nAcm unit conversion
        unit = 'nAcm'
    else:
        scaling = 1

    # ground truth
    with h5py.File(os.path.join(OUTPUTPATH_REAL, fname), 'r') as f:
        data = f['data'][()]
    if is_dipole:
        data = data[-1, :].reshape((1, data.shape[1]))  # z-component only

    for ax in axes[j, :]:
        draw_lineplot(ax,
                      ss.decimate(data[data_entry], q=decimate_ratio,
                                  zero_phase=True) * scaling,
                      dt=dt * decimate_ratio,
                      T=(TRANSIENT, TRANSIENT+200),
                      scaling_factor=1.,
                      vlimround=vlimround,
                      label='ground truth',
                      scalebar=True,
                      unit=unit,
                      ylabels=True,
                      color='k',
                      ztransform=True
                      )
        if j == 0:
            ax.set_xlabel('')
            plt.setp(ax.get_xticklabels(), visible=False)
        ax.set_ylabel('')

    # network simulations of each PS1 parameter set
    for i, pset in enumerate(PS1.iter_inner()):
        md5 = hashlib.md5(json.dumps(pset, sort_keys=True).encode()).hexdigest()

        OUTPUTPATH = os.path.join('output', md5)
        with h5py.File(os.path.join(OUTPUTPATH, fname), 'r') as f:
            data = f['data'][()]
        if is_dipole:
            data = data[-1, :].reshape((1, data.shape[1]))  # z-component only

        # legend label from the remaining parameters; a line break precedes
        # every entry whose position in pset is beyond 5
        label = ''
        for h, (key, value) in enumerate(pset.items()):
            if any(s in key for s in skipped_keys):
                continue
            if h > 5:
                label += '\n'
            label += '{}:{}'.format(key, value)

        ax = axes[j, i]
        # NOTE(review): ylabels is enabled only for columns i >= 2, yet tick
        # labels are hidden below for every i > 0; confirm intent.
        draw_lineplot(ax,
                      ss.decimate(data[data_entry], q=decimate_ratio,
                                  zero_phase=True) * scaling,
                      dt=dt * decimate_ratio,
                      T=(TRANSIENT, TRANSIENT+200),
                      scaling_factor=1.,
                      vlimround=vlimround,
                      label=label,
                      scalebar=False,
                      unit=unit,
                      ylabels=(i // 2 > 0),
                      color=f'C{i}',
                      ztransform=True
                      )
        if is_dipole:
            ax.set_yticklabels(['$P_z$'])

        if j == 0:
            ax.set_xlabel('')
            plt.setp(ax.get_xticklabels(), visible=False)
            ax.legend(loc=2)
        if i > 0:
            plt.setp(ax.get_yticklabels(), visible=False)
        ax.set_ylabel('')

# make sure the output directory exists even if earlier cells were skipped
os.makedirs('figures', exist_ok=True)
fig.savefig(os.path.join('figures', 'figure12.pdf'), bbox_inches='tight')

In [None]:
# Compute reconstructed signals as the sum over convolutions
# phi(r, t) = sum_X sum_Y (nu_X * H_YX)(r, t), with Y = data_entry,
# using kernels obtained either via the hybrid scheme (H_YX_hybrid) or the
# direct method (H_YX_pred).
#
# all_kernel_predictions[j] belongs to probe j ('GaussCylinderPotential',
# 'CurrentDipoleMoment') and is a list of (label, signal) tuples: first one
# entry per PS2 condition using hybrid-scheme kernels, then one entry per PS2
# condition using predicted kernels. `signal` has one row per kernel row and
# nu_X[X].size columns.


def _kernel_prediction_label(pset, prefix):
    """Return `prefix`, a newline, and a 'key:value' listing of `pset`.

    Keys containing 'weight', 'n_ext', 'i_syn', 't_E', 't_I' or
    'perseg_Vrest' are left out; a line break precedes every listed entry
    whose position in `pset` is greater than 5.
    """
    skip = ('weight', 'n_ext', 'i_syn', 't_E', 't_I', 'perseg_Vrest')
    label = ''
    for h, (key, value) in enumerate(pset.items()):
        if any(key.rfind(s) >= 0 for s in skip):
            continue
        if h > 5:
            label += '\n'
        label += '{}:{}'.format(key, value)
    return prefix + '\n' + label


def _sum_of_convolutions(kernel_of, populations):
    """Sum over X in `populations` of nu_X[X] convolved with every kernel row.

    kernel_of(X) returns the kernel array H_YX for presynaptic population X
    (one row per output row); np.convolve(..., 'same') keeps nu_X[X].size
    samples per row.
    """
    signal = None
    for X in populations:
        H_YX = kernel_of(X)
        if signal is None:
            signal = np.zeros((H_YX.shape[0], nu_X[X].size))
        for h, h_YX in enumerate(H_YX):
            signal[h, :] = signal[h, :] + np.convolve(nu_X[X], h_YX, 'same')
    return signal


all_kernel_predictions = []
for probe in ['GaussCylinderPotential', 'CurrentDipoleMoment']:
    # hybrid-scheme kernels are stored under the key 'RecExtElectrode'
    hybrid_probe = probe.replace('GaussCylinderPotential', 'RecExtElectrode')
    hybrid_predictions = []  # kernels from the hybrid scheme
    direct_predictions = []  # kernels from the direct (predicted) method
    for pset in PS2.iter_inner():
        # sorted json dictionary -> hash identifying this parameter set
        md5 = hashlib.md5(json.dumps(pset, sort_keys=True).encode()).hexdigest()

        # populations are paired with (t_E, t_I) here, so at most the first
        # two populations contribute to the hybrid-scheme prediction
        hybrid_populations = [X for X, _t, _N_X in zip(population_names,
                                                       [pset['t_E'], pset['t_I']],
                                                       population_sizes)]
        hybrid_predictions.append((
            _kernel_prediction_label(pset, r'$\sum_X \sum_Y \nu_X \ast H_\mathrm{YX}$'),
            _sum_of_convolutions(
                lambda X: H_YX_hybrid[md5][f'{data_entry}:{X}'][hybrid_probe],
                hybrid_populations)))

        direct_populations = [X for X, _N_X in zip(population_names,
                                                   population_sizes)]
        direct_predictions.append((
            _kernel_prediction_label(pset, r'$\sum_X \sum_Y \nu_X \ast \hat{H}_\mathrm{YX}$'),
            _sum_of_convolutions(
                lambda X: H_YX_pred[md5][f'{data_entry}:{X}'][probe],
                direct_populations)))

    kernel_predictions = hybrid_predictions + direct_predictions
    all_kernel_predictions.append(kernel_predictions)

In [None]:
## plot kernel predictions
# Figure 14: ground truth (black) vs. the kernel predictions stored in
# all_kernel_predictions. Row 0: extracellular potential V_e, row 1: current
# dipole moment P_z. Columns [0, n_ps2) use hybrid-scheme kernels H_YX,
# columns [n_ps2, 2 * n_ps2) use predicted kernels \hat{H}_YX; column c
# corresponds to PS2 condition c % n_ps2.
n_ps2 = PS2.num_conditions()
# one column per entry of all_kernel_predictions[j]
ncols = 2 * n_ps2

# legend label of every PS2 condition (biophysics parameters only)
variant_labels = []
for pset in PS2.iter_inner():
    label = ''
    for h, (key, value) in enumerate(pset.items()):
        if any(key.rfind(s) >= 0 for s in ('weight', 'n_ext', 'i_syn', 't_E',
                                            't_I', 'perseg_Vrest')):
            continue
        if h > 5:
            label += '\n'
        label += '{}:{}'.format(key, value)
    variant_labels.append(label)

fig14_titles = [
    'ground truth vs.\n' + r'$\sum_X \sum_{Y=\{E\}} \left( \nu_X \ast H_\mathrm{YX} \right) (\mathbf{R}, t)$',
    'ground truth vs.\n' + r'$\sum_X \sum_{Y=\{E\}} \left( \nu_X \ast \hat{H}_\mathrm{YX} \right) (\mathbf{R}, t)$',
]

# create figure
fig = plt.figure(figsize=(figwidth, figwidth / 2))
fig.subplots_adjust(wspace=0.15)
# create subplots: V_e spans 6 grid rows, P_z one; all panels share the time
# axis and panels within a row share the y axis
gs = GridSpec(7, ncols, wspace=0.25)
axes = np.array([[None] * ncols] * 2, dtype=object)
for i in range(2):
    for j in range(ncols):
        if i == 0:
            if j == 0:
                axes[i, j] = fig.add_subplot(gs[:6, j])
            else:
                axes[i, j] = fig.add_subplot(gs[:6, j], sharey=axes[0, 0], sharex=axes[0, 0])
        else:
            if j == 0:
                axes[i, j] = fig.add_subplot(gs[6, j], sharex=axes[0, 0])
            else:
                axes[i, j] = fig.add_subplot(gs[6, j], sharey=axes[1, 0], sharex=axes[0, 0])

for i, ax in enumerate(axes[0, :]):
    annotate_subplot(ax, ncols=ncols, nrows=3, letter='EFGHIJ'[i], linear_offset=0.02)

# compare summed extracellular signals
for j, (fname, probe, unit, vlimround) in enumerate(zip(
        ['RecExtElectrode.h5', 'CurrentDipoleMoment.h5'],
        ['GaussCylinderPotential', 'CurrentDipoleMoment'],
        ['mV', 'nAcm'],
        [2**-1, 2**4])):

    # P is stored in nAum and plotted in nAcm
    scaling = 1E-4 if probe == 'CurrentDipoleMoment' else 1

    with h5py.File(os.path.join(OUTPUTPATH_REAL, fname), 'r') as f:
        data = f['data'][()]
    if probe == 'CurrentDipoleMoment':
        data = data[-1, :].reshape((1, data.shape[1]))  # keep the P_z row

    # ground truth: the same data_entry signal the R^2 analysis compares with
    for ax in axes[j, :]:
        draw_lineplot(ax,
                      ss.decimate(data[data_entry], q=decimate_ratio,
                                  zero_phase=True) * scaling,
                      dt=dt * decimate_ratio,
                      T=(TRANSIENT, TRANSIENT+200),
                      scaling_factor=1.,
                      vlimround=vlimround,
                      label='ground truth',
                      scalebar=True,
                      unit=unit,
                      ylabels=True,
                      color='k',
                      ztransform=True
                      )
        if j == 0:
            ax.set_xlabel('')
            plt.setp(ax.get_xticklabels(), visible=False)
        ax.set_ylabel('')

    # kernel predictions: hybrid-scheme kernels, then predicted kernels
    for col, ax in enumerate(axes[j, :]):
        _, prediction = all_kernel_predictions[j][col]
        if probe == 'CurrentDipoleMoment':
            prediction = prediction[-1, :].reshape((1, prediction.shape[1]))

        draw_lineplot(ax,
                      ss.decimate(prediction, q=decimate_ratio,
                                  zero_phase=True) * scaling,
                      dt=dt * decimate_ratio,
                      T=(TRANSIENT, TRANSIENT+200),
                      scaling_factor=1.,
                      vlimround=vlimround,
                      label=variant_labels[col % n_ps2],
                      scalebar=False,
                      unit=unit,
                      ylabels=True,
                      color=f'C{col + PS1.num_conditions()}',
                      ztransform=True
                      )

        if probe == 'CurrentDipoleMoment':
            ax.set_yticklabels(['$P_z$'])

        if j == 0:
            ax.set_title(fig14_titles[col // n_ps2], va='bottom')
            ax.set_xlabel('')
            plt.setp(ax.get_xticklabels(), visible=False)
            ax.legend(loc=2)
        if col > 0:
            plt.setp(ax.get_yticklabels(), visible=False)
        ax.set_ylabel('')

fig.savefig(os.path.join('figures', 'figure14.pdf'), bbox_inches='tight')

In [None]:
# Correlation coefficients between approximations and ground truth (GT), as
# well as r_STD = STD(approx) / STD(GT) (scaling), per channel, for the raw
# and the low-pass (LP) filtered signals. One row per (approximation,
# channel, signal) is collected in `df`.

# True: compare extracellular potentials; False: current dipole moments
USE_POTENTIAL = True
if USE_POTENTIAL:
    fname = 'RecExtElectrode.h5'
    probe = 'GaussCylinderPotential'
    unit = 'mV'
    vlimround = 2**-1
    kernel_predictions = all_kernel_predictions[0]
else:
    fname = 'CurrentDipoleMoment.h5'
    probe = 'CurrentDipoleMoment'
    unit = 'nAµm'
    vlimround = 2**18
    kernel_predictions = all_kernel_predictions[1]

i_transient = int(TRANSIENT // dt)  # first sample after the start-up transient

with h5py.File(os.path.join(OUTPUTPATH_REAL, fname), 'r') as f:
    raw_gt = f['data'][()]
# raw data
data_gt = raw_gt[:, i_transient:]
# low pass filtered data (filtered before removing the transient)
data_gt_lp = ss.sosfiltfilt(sos_ellip, raw_gt[data_entry])[:, i_transient:]


def _r2_std_records(gt, gt_lp, approx, approx_lp, title, color, marker):
    """Return 'raw' and 'LP' records for every channel (numbered from 1).

    R2 is the squared Pearson correlation between row ch of the GT and row ch
    of the approximation; STD/STD is STD(approximation) / STD(GT).
    """
    n_ch = gt.shape[0]
    # the lower-left block of the stacked correlation matrix pairs row ch of
    # `approx` with row ch of `gt` on its diagonal
    Pcc = np.corrcoef(gt, approx)[n_ch:, :n_ch].diagonal()
    Pcc_lp = np.corrcoef(gt_lp, approx_lp)[n_ch:, :n_ch].diagonal()
    scaling = approx.std(axis=-1) / gt.std(axis=-1)
    scaling_lp = approx_lp.std(axis=-1) / gt_lp.std(axis=-1)
    records = []
    for ch in range(n_ch):
        for signal, r, s in (('raw', Pcc[ch], scaling[ch]),
                             ('LP', Pcc_lp[ch], scaling_lp[ch])):
            records.append({'R2': r**2, 'STD/STD': s, 'label': title,
                            'signal': signal, 'channel': ch + 1,
                            'color': color, 'marker': marker})
    return records


records = []
for i, pset in enumerate(PS1.iter_inner()):
    md5 = hashlib.md5(json.dumps(pset, sort_keys=True).encode()).hexdigest()

    title = ''
    for key, value in pset.items():
        if key.rfind('weight') >= 0 or key.rfind('n_ext') >= 0 or key.rfind('i_syn') >= 0 or key.rfind('perseg_Vrest') >= 0:
            continue
        title += '{}:{}\n'.format(key, value)
    title = title.removesuffix('\n')

    with h5py.File(os.path.join('output', md5, fname), 'r') as f:
        raw = f['data'][()]

    records += _r2_std_records(
        data_gt[data_entry], data_gt_lp,
        raw[:, i_transient:][data_entry],
        ss.sosfiltfilt(sos_ellip, raw[data_entry])[:, i_transient:],
        title, f'C{i}', markers[i // 2])

for j, (title, prediction) in enumerate(kernel_predictions):
    # LP correlations compare filtered GT with the *filtered* prediction
    records += _r2_std_records(
        data_gt[data_entry], data_gt_lp,
        prediction[:, i_transient:],
        ss.sosfiltfilt(sos_ellip, prediction)[:, i_transient:],
        title, f'C{j + PS1.num_conditions()}', markers[j // 2 + 2])

# container (built once; DataFrame.append was removed in pandas 2.0)
df = pd.DataFrame(records, columns=['R2', 'STD/STD', 'label', 'signal',
                                   'channel', 'color', 'marker'])

In [None]:
# per-channel R^2 and STD-ratio table for every approximation (raw and LP)
df

In [None]:
# Aggregate R^2 and STD ratio across channels per signal type and
# approximation: mean, median and methods.quant10 / methods.quant90
# (presumably the 10th and 90th percentiles). The result has
# (metric, statistic) MultiIndex columns, e.g. df_agg['R2']['median'].
summary_stats = ['mean', 'median', methods.quant10, methods.quant90]
df_agg = (
    df.groupby(['signal', 'label', 'color', 'marker'], as_index=False)
      .agg({metric: list(summary_stats) for metric in ('R2', 'STD/STD')})
)
df_agg

In [None]:
# R2(P_z) vs. STD(P_z^approx) / STD(P_z)
# Same metrics as in `df`, restricted to row index 2 of the current dipole
# moment (presumably rows are x, y, z, so this is P_z), collected in `df_Pz`.
fname = 'CurrentDipoleMoment.h5'
kernel_predictions = all_kernel_predictions[1]
i_transient = int(TRANSIENT // dt)  # first sample after the start-up transient

skipfirst = len(PS2['biophys'])  # NOTE(review): not used in this cell

with h5py.File(os.path.join(OUTPUTPATH_REAL, fname), 'r') as f:
    raw_gt = f['data'][()]
data_gt = raw_gt[:, i_transient:]
data_gt_lp = ss.sosfiltfilt(sos_ellip, raw_gt[data_entry])[:, i_transient:]
pz_gt = data_gt[data_entry][2]
pz_gt_lp = data_gt_lp[2]


def _pz_records(pz, pz_lp, title, color, marker):
    """Return the 'raw' and 'LP' records (R^2, STD ratio) of one P_z approximation."""
    records = []
    for signal, gt, approx in (('raw', pz_gt, pz), ('LP', pz_gt_lp, pz_lp)):
        r = np.corrcoef(gt, approx)[0, 1]
        records.append({'R2': r**2, 'STD/STD': approx.std() / gt.std(),
                        'label': title, 'signal': signal,
                        'color': color, 'marker': marker})
    return records


records = []
for i, pset in enumerate(PS1.iter_inner()):
    md5 = hashlib.md5(json.dumps(pset, sort_keys=True).encode()).hexdigest()

    # same label filter as used for `df`
    title = ''
    for key, value in pset.items():
        if key.rfind('weight') >= 0 or key.rfind('n_ext') >= 0 or key.rfind('i_syn') >= 0 or key.rfind('perseg_Vrest') >= 0:
            continue
        title += '{}:{}\n'.format(key, value)
    title = title.removesuffix('\n')

    with h5py.File(os.path.join('output', md5, fname), 'r') as f:
        raw = f['data'][()]

    records += _pz_records(
        raw[:, i_transient:][data_entry][2],
        ss.sosfiltfilt(sos_ellip, raw[data_entry][2])[i_transient:],
        title, f'C{i}', markers[i // 2])

for j, (title, prediction) in enumerate(kernel_predictions):
    records += _pz_records(
        prediction[2, i_transient:],
        ss.sosfiltfilt(sos_ellip, prediction[2])[i_transient:],
        title, f'C{j + PS1.num_conditions()}', markers[j // 2 + 2])

# container; column order matters for the row unpacking in the figure cell
df_Pz = pd.DataFrame(records, columns=['R2', 'STD/STD', 'label', 'signal',
                                       'color', 'marker'])

In [None]:
# R^2 and STD ratio of the P_z approximations (raw and LP)
df_Pz

In [None]:
# combine R2 and r_STD into one figure (figure 15)
# A, B: per-channel R^2 / r_STD from `df`, one column per approximation
# C, D: median R^2 vs. median r_STD with quant10-quant90 error bars from
#       `df_agg` (D omits the first two approximations)
# E, F: R^2 vs. r_STD of P_z from `df_Pz` (F: only labels with g_eff:True)
# G, H: PSD of P_z, and coherence between ground-truth and approximated P_z
fname = 'CurrentDipoleMoment.h5'  # signal of panels G, H
kernel_predictions = all_kernel_predictions[1]
i_transient = int(TRANSIENT // dt)  # first sample after the start-up transient
to_nAcm = 1E-4  # P is stored in nAum; the PSD axis is in (nAcm)^2/Hz
f_max = 1000
markersize = 8

gs = GridSpec(4, len([1 for i in PS1.iter_inner()]) + len([1 for i in PS2.iter_inner()] * 2),
              wspace=0.3, hspace=0.4)
gs1 = GridSpec(4, 4, wspace=0.4, hspace=0.4)

fig = plt.figure(figsize=(figwidth, figwidth))

# column titles: PS1 conditions (same filter as the `df` labels), then the
# kernel predictions
titles = []
for pset in PS1.iter_inner():
    title = ''
    for key, value in pset.items():
        if key.rfind('weight') >= 0 or key.rfind('n_ext') >= 0 or key.rfind('i_syn') >= 0 or key.rfind('perseg_Vrest') >= 0:
            continue
        title += '{}:{}\n'.format(key, value)
    titles.append(title.removesuffix('\n'))
titles += [kp[0] for kp in kernel_predictions]

# A, B: per-channel values
axes = []
for i, title in enumerate(titles):
    for j, (x, xlabel) in enumerate(zip(['R2', 'STD/STD'], [r'$R^2$ (-)', r'$r_\mathrm{STD}$ (-)'])):
        if i == 0:
            ax = fig.add_subplot(gs[j, i])
            axes.append(ax)
        else:
            ax = fig.add_subplot(gs[j, i], sharex=axes[j])
        ax.invert_yaxis()
        if j == 0:
            ax.set_title(title)
        for signal in ['raw', 'LP']:
            rows = df[(df['label'] == title) & (df['signal'] == signal)]
            d = rows[x].values
            channels = rows['channel'].values
            color, marker = rows[['color', 'marker']].values[0]

            ax.plot(d, channels,
                    marker=marker,
                    color=color,
                    mec=color,
                    mfc=color if signal == 'raw' else 'w',
                    ms=markersize,
                    label=signal,
                    clip_on=False)

        if j == 0:
            ax.legend()
            ax.set_xlim(0, 1)
        ax.set_yticks(channels)

        if i == 0:
            ax.set_yticklabels(['ch.{}'.format(c) for c in channels])
        else:
            ax.set_yticklabels([])
        ax.set_xlabel(xlabel)
        remove_axis_junk(ax)
        if i == 0:
            annotate_subplot(ax, ncols=8, nrows=4, letter='AB'[j], linear_offset=0.02)


# aggregate numbers
axes = []
for i in range(4):
    ax = fig.add_subplot(gs1[2, i])
    axes.append(ax)

for i, ax in enumerate(axes):
    annotate_subplot(ax, ncols=4, nrows=4, letter='CDEFGHIJ'[i], linear_offset=0.02)

# unique labels in order of first appearance in df
labels, index = np.unique(df['label'], return_index=True)
labels = labels[np.argsort(index)]

# C: all approximations; D: all but the first two
for ax, shown_labels in ((axes[0], labels), (axes[1], labels[2:])):
    for label in shown_labels:
        for j, signal in enumerate(['raw', 'LP']):
            d = df_agg[(df_agg['label'] == label) & (df_agg['signal'] == signal)]
            color, marker = d[['color', 'marker']].values[0]

            ax.errorbar(x=d['R2']['median'],
                        y=d['STD/STD']['median'],
                        xerr=np.c_[d['R2']['median'] - d['R2']['quant10'],
                                   d['R2']['quant90'] - d['R2']['median']].T,
                        yerr=np.c_[d['STD/STD']['median'] - d['STD/STD']['quant10'],
                                   d['STD/STD']['quant90'] - d['STD/STD']['median']].T,
                        fmt=marker,
                        ms=markersize,
                        ecolor=color,
                        mec=color,
                        mfc='w' if signal == 'LP' else color,
                        label=label if j == 0 else '_nolegend_',
                        clip_on=False)

# E: P_z of all approximations; F: only labels containing 'g_eff:True'
for ax, keep in ((axes[2], lambda lbl: True),
                 (axes[3], lambda lbl: 'g_eff:True' in lbl)):
    for j, signal in enumerate(['raw', 'LP']):
        for r2, covvar, label, _, color, marker in df_Pz[df_Pz['signal'] == signal].values:
            if keep(label):
                ax.plot(r2, covvar, marker,
                        label=label, ms=markersize,
                        mfc=color if j == 0 else 'w',
                        mec=color,
                        clip_on=False)

for ax in axes:
    ax.set_xlabel('$R^2$ (-)')
    remove_axis_junk(ax)

axes[0].set_ylabel(r'$r_\mathrm{STD}$ (-)')


# create axes for PSD and coherence
axes = []
for i in range(2):
    ax = fig.add_subplot(gs1[3, i])
    axes.append(ax)
    remove_axis_junk(ax)
    annotate_subplot(ax, ncols=4, nrows=4, letter='GHIJ'[i], linear_offset=0.02)

# ground-truth P_z and (legend label, P_z) of every approximation:
# PS1 simulations first, then kernel predictions
with h5py.File(os.path.join(OUTPUTPATH_REAL, fname), 'r') as f:
    data_gt = f['data'][()][:, i_transient:]
pz_gt = data_gt[data_entry][2]

pz_approx = []
for pset in PS1.iter_inner():
    md5 = hashlib.md5(json.dumps(pset, sort_keys=True).encode()).hexdigest()
    title = ''
    for key, value in pset.items():
        if key.rfind('weight') >= 0 or key.rfind('n_ext') >= 0:
            continue
        title += '{}:{}\n'.format(key, value)
    title = title.removesuffix('\n')
    with h5py.File(os.path.join('output', md5, fname), 'r') as f:
        data = f['data'][()][:, i_transient:]
    pz_approx.append((title, data[data_entry][2]))
for title, prediction in kernel_predictions:
    pz_approx.append((title, prediction[2, i_transient:]))


# G: GT PSD vs approx (P_z, P_z_approx), in (nAcm)^2/Hz
ax = axes[0]
P_xx, freqs = methods.csd(pz_gt * to_nAcm, y=None, Fs=Fs, NFFT=NFFT, noverlap=noverlap, detrend=detrend, library='scipy')
f_inds = freqs <= f_max
ax.loglog(freqs[f_inds], P_xx[f_inds], 'k')

# color index = position in pz_approx (PS1 conditions, then kernel predictions)
for i, (_, pz) in enumerate(pz_approx):
    P_xx, _ = methods.csd(pz * to_nAcm, y=None, Fs=Fs, NFFT=NFFT, noverlap=noverlap, detrend=detrend, library='scipy')
    ax.loglog(freqs[f_inds], P_xx[f_inds], f'C{i}')

axis = ax.axis(ax.axis('tight'))
ax.set_xlabel(r'$f$ (Hz)')
ax.set_ylabel(r'$S_{P_z P_z}(f)$ ($(\mathrm{nAcm})^2/\mathrm{Hz}$)')


# H: normalized coherence(P_z, P_z_approx)
ax = axes[1]
for title, pz in pz_approx:
    gamma_xy, freqs = methods.coherence(pz_gt, pz,
                                        Fs=Fs, NFFT=NFFT, noverlap=noverlap,
                                        detrend=detrend,
                                        library='scipy')
    f_inds = freqs <= f_max
    ax.semilogx(freqs[f_inds], gamma_xy[f_inds], label=title)

axis = ax.axis(ax.axis('tight'))
ax.set_xlabel(r'$f$ (Hz)')
ax.set_ylabel(r'$C_{P_z\hat{P_z}}(f)$ (-)')

fig.savefig(os.path.join('figures', 'figure15.pdf'), bbox_inches='tight')

In [None]:
# Cross correlation functions (left) and coherence (right) between the
# ground-truth P_z and every approximation: PS1 simulations first, then the
# kernel predictions.
fname = 'CurrentDipoleMoment.h5'
kernel_predictions = all_kernel_predictions[1]
i_transient = int(TRANSIENT // dt)  # first sample after the start-up transient
max_lag = 20  # x-axis limit of the cross-correlation panel (ms)
f_max = 1000  # upper frequency limit of the coherence panel (Hz)

fig, axes = plt.subplots(1, 2, figsize=(16, 8))

with h5py.File(os.path.join(OUTPUTPATH_REAL, fname), 'r') as f:
    data_gt = f['data'][()][:, i_transient:]
pz_gt = data_gt[data_entry][2]

# (legend label, P_z) of every approximation, each file read once
pz_approx = []
for pset in PS1.iter_inner():
    md5 = hashlib.md5(json.dumps(pset, sort_keys=True).encode()).hexdigest()
    title = ''
    for key, value in pset.items():
        if key.rfind('weight') >= 0 or key.rfind('n_ext') >= 0:
            continue
        title += '{}:{}\n'.format(key, value)
    title = title.removesuffix('\n')
    with h5py.File(os.path.join('output', md5, fname), 'r') as f:
        data = f['data'][()][:, i_transient:]
    pz_approx.append((title, data[data_entry][2]))
for title, prediction in kernel_predictions:
    pz_approx.append((title, prediction[2, i_transient:]))

# lags of np.correlate(..., 'same') output with zero lag at the centre
# sample; derived from the ground truth instead of state left by other cells
lag = (np.arange(pz_gt.size) - pz_gt.size // 2) * params.networkParameters['dt']
lag_inds = (lag >= -max_lag) & (lag <= max_lag)

# cross correlation functions
ax = axes[0]
for title, pz in pz_approx:
    xcorr = np.correlate(methods.zscore(pz_gt),
                         methods.zscore(pz), 'same') / pz_gt.size
    ax.plot(lag[lag_inds], xcorr[lag_inds], label=title)

ax.legend(ncol=2)
axis = ax.axis(ax.axis('tight'))
ax.vlines(0, axis[2], axis[3], ls=':', color='k')
ax.set_xlabel(r'$\tau$ (ms)', labelpad=0)
ax.set_ylabel(r'$\rho_{\psi\hat{\psi}}(\tau)$', labelpad=0)


# Normalized coherence(P_z, P_z_approx)
ax = axes[1]
for title, pz in pz_approx:
    gamma_xy, freqs = methods.coherence(pz_gt, pz,
                                        Fs=Fs, NFFT=NFFT, noverlap=noverlap,
                                        library='mpl')
    f_inds = freqs <= f_max
    ax.plot(freqs[f_inds], gamma_xy[f_inds], label=title)

axis = ax.axis(ax.axis('tight'))
ax.set_xlabel(r'$f$ (Hz)', labelpad=0)
ax.set_ylabel(r'$\gamma_{\psi\hat{\psi}}^2(f)$ (-)', labelpad=0)


# fig.savefig('figures/Hay2011-crosscorr_Pz.pdf', bbox_inches='tight')

In [None]:
# swarmplots of setup and simulation times per step, in core-seconds, for the
# network (PS0), hybrid-scheme (PS1) and kernel (PS2) runs, plus the
# rate-kernel predictions (H_YX_pred_times).
# The timing table is kept in `df_times` so the R^2 table `df` used by the
# figure cells is not overwritten.
ntasks = 1024  # presumably parallel tasks per run: time_s * ntasks -> core-seconds

fig, axes = plt.subplots(1, 4, sharey=True, sharex=False, figsize=(16, 9))
for j, (ax, PS, title) in enumerate(zip(axes,
                                        [PS0, PS1, PS2],
                                        ['network', 'hybrid', 'kernel'])):
    # one tic_tac.txt per parameter set (DataFrame.append was removed in
    # pandas 2.0, so collect and concatenate once)
    frames = []
    for pset in PS.iter_inner():
        md5 = hashlib.md5(json.dumps(pset, sort_keys=True).encode()).hexdigest()
        frames.append(pd.read_csv(os.path.join('output', md5, 'tic_tac.txt'), sep=' '))
    df_times = pd.concat(frames, ignore_index=True)

    df_times['time_s'] = df_times['time_s'] * ntasks

    sb.swarmplot(x='step', y='time_s', data=df_times, ax=ax)
    ax.set_title(title)
    ax.set_xticklabels(ax.get_xticklabels(), rotation=30)
    ax.semilogy(base=10)
    ax.set_ylabel('core-seconds (s)' if j == 0 else '')
    ax.set_xlabel('')

ax = axes[-1]
sb.swarmplot(x='step', y='time_s', data=H_YX_pred_times, ax=ax)
ax.set_title('rate kernels')
ax.set_xticklabels(ax.get_xticklabels(), rotation=30)
ax.semilogy(base=10)
ax.set_ylabel('')
ax.set_xlabel('')

# fig.savefig('figures/Hay2011_simulation_times.pdf', bbox_inches='tight')

In [None]:
# Print the median somatic potential of each population in the ground-truth
# run, skipping the first TRANSIENT entries along axis 1 (presumably time).
# NOTE(review): other cells drop the transient with int(TRANSIENT // dt);
# this slice matches that only if somav.h5 is sampled every 1 ms — confirm.
with h5py.File(os.path.join(OUTPUTPATH_REAL, 'somav.h5'), 'r') as somav:
    for Y in params.population_names:
        Vrest = np.median(somav[Y][()][:, TRANSIENT:])
        print(Y, Vrest)

In [None]:
# Maximum and mean recorded time_s per step over all parameter sets of each
# parameter space. Timing tables are concatenated once (DataFrame.append was
# removed in pandas 2.0) into `df_times`, leaving the R^2 table `df` intact.
for PS in [PS0, PS1, PS2]:
    frames = []
    for pset in PS.iter_inner():
        md5 = hashlib.md5(json.dumps(pset, sort_keys=True).encode()).hexdigest()
        frames.append(pd.read_csv(os.path.join('output', md5, 'tic_tac.txt'), sep=' '))
    df_times = pd.concat(frames, ignore_index=True)

    print(df_times.groupby('step', as_index=False).agg({'time_s': 'max'}))
    print(df_times.groupby('step', as_index=False).agg({'time_s': 'mean'}))