In [None]:
# Render figures inline and automatically reload imported modules (e.g. the
# local example_network_*, plotting and LIF_net modules) before each cell runs
%matplotlib inline
%load_ext autoreload
%autoreload 2

In [None]:
import os
from parameters import ParameterSpace, ParameterSet
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
from matplotlib.ticker import MaxNLocator
import h5py
import scipy.signal as ss
from example_network_parameters import (networkParameters, population_names,
                                        population_sizes)
import example_network_parameters as params
import example_network_methods as methods
from lfpykernels import KernelApprox, GaussCylinderPotential
from lfpykit import CurrentDipoleMoment
from plotting import draw_lineplot, remove_axis_junk, annotate_subplot
import plotting

import scipy.stats as st
from copy import deepcopy
import json
import hashlib
from time import time
import h5py
import neuron
import nest
from pynestml.frontend.pynestml_frontend import to_nest, install_nest
from nest import raster_plot
import LIF_net

In [None]:
# apply the project's shared matplotlib style and figure-size conventions
plt.rcParams.update(plotting.rcParams)
golden_ratio, figwidth = plotting.golden_ratio, plotting.figwidth

In [None]:
# Load the NEURON mechanisms from ./mod; if they cannot be loaded (e.g. they
# are not compiled yet), compile them with nrnivmodl and load them again.
# Compilation or loading failures raise here instead of surfacing later as
# obscure missing-mechanism errors.
import subprocess

mech_loaded = neuron.load_mechanisms('mod')
if not mech_loaded:
    # run nrnivmodl inside ./mod; check=True raises CalledProcessError on failure
    subprocess.run(['nrnivmodl'], cwd='mod', check=True)
    mech_loaded = neuron.load_mechanisms('mod')
    if not mech_loaded:
        raise RuntimeError("could not load NEURON mechanisms from 'mod' after compiling them")

In [None]:
# parameter spaces read from PS0.txt and PS1.txt (PS2.txt is not used here)
PS0, PS1 = ParameterSpace('PS0.txt'), ParameterSpace('PS1.txt')

In [None]:
# Timing (ms): time step and end time come from the network parameters;
# TRANSIENT is the initial period excluded from spike/signal analyses
dt = networkParameters['dt']
tstop = networkParameters['tstop']
TRANSIENT = 2000
tau = 100  # time lag relative to spike for kernel predictions

In [None]:
# decimation factor for plotted signals: they are shown at a time
# resolution of dt * decimate_ratio
decimate_ratio = 4

In [None]:
# spectral-estimate (PSD) settings shared by the analyses below
Fs = 1000 / dt  # sampling frequency (Hz) for a time step dt in ms
NFFT = 2048  # segment length (samples)
noverlap = 1536  # overlap between consecutive segments (samples)
detrend = 'constant'  # per-segment detrending mode
cutoff = 200  # presumably the upper frequency (Hz) kept by LIF_net.get_psd

In [None]:
# line/marker colors indexed by population: E first, then I
colors = [
    'tab:blue',  # E
    'tab:red',   # I
]

In [None]:
# Figure out which ground-truth dataset to compare with: the output directory
# is named by the MD5 hash of the sorted-JSON parameter set built from the
# first parameter combination of PS1 (plus n_ext from PS0).
try:
    pset = next(iter(PS1.iter_inner()))
except StopIteration:
    # fail explicitly rather than silently reusing an OUTPUTPATH_REAL that a
    # previous run may have left in the kernel
    raise RuntimeError('PS1 contains no parameter combinations') from None

weight_EE = pset['weight_EE']
weight_IE = pset['weight_IE']
weight_EI = pset['weight_EI']
weight_II = pset['weight_II']
weight_scaling = pset['weight_scaling']
pset_0 = ParameterSet(dict(weight_EE=weight_EE,
                           weight_IE=weight_IE,
                           weight_EI=weight_EI,
                           weight_II=weight_II,
                           weight_scaling=weight_scaling,
                           n_ext=PS0['n_ext'].value))
js_0 = json.dumps(pset_0, sort_keys=True).encode()
md5_0 = hashlib.md5(js_0).hexdigest()
OUTPUTPATH_REAL = os.path.join('output', md5_0)
print(f'comparing with ground truth dataset: {OUTPUTPATH_REAL}')

In [None]:
# Generate, build and install the NESTML FIR_filter.nestml model as a NEST
# extension module, then load it. The module is regenerated on every run;
# to reuse an already installed module instead, try nest.Install(module_name)
# first and only build on failure.
nestml_model_file = 'FIR_filter.nestml'
nestml_model_name = 'fir_filter_nestml'
target_path = '/tmp/fir-filter'  # NOTE(review): POSIX-specific build directory
logging_level = 'INFO'
module_name = 'nestmlmodule'
store_log = False
suffix = '_nestml'
dev = True

# generate NEST code from the model and install it into NEST's prefix
input_path = os.path.realpath(nestml_model_file)
nest_path = nest.ll_api.sli_func("statusdict/prefix ::")
to_nest(input_path, target_path, logging_level, module_name, store_log, suffix, dev)
install_nest(target_path, nest_path)

nest.set_verbosity("M_ALL")
nest.Install(module_name)

nest.ResetKernel()

In [None]:
# display the parameter set selected above (first combination of PS1)
pset

In [None]:
# Create kernels from the multicompartment neuron description: for each
# presynaptic population X and postsynaptic population Y, KernelApprox
# predicts the kernel for the extracellular potential (GaussCylinderPotential)
# and the current dipole moment (CurrentDipoleMoment). Results are stored in
# H_YX_pred['Y:X'].

# parameters of the selected parameter set
weight_EE = pset['weight_EE']
weight_IE = pset['weight_IE']
weight_EI = pset['weight_EI']
weight_II = pset['weight_II']
biophys = pset['biophys']
n_ext = pset['n_ext']  # indexed by postsynaptic population (n_ext[j] below)
g_eff = pset['g_eff']

t_X = TRANSIENT  # presynaptic activation time

# define biophysical membrane properties
if biophys == 'pas':
    set_biophys = [methods.set_passive_hay2011, methods.make_cell_uniform]
elif biophys == 'pas_v2':
    set_biophys = [methods.set_passive_hay2011_no_Ih, methods.make_cell_uniform]
elif biophys == 'lin':
    set_biophys = [methods.set_Ih_linearized_hay2011, methods.make_cell_uniform]
else:
    raise NotImplementedError(f'unsupported biophys option: {biophys!r}')

# synaptic weights, indexed [presynaptic][postsynaptic] (see weights[ii][j])
weights = [[weight_EE, weight_IE],
           [weight_EI, weight_II]]

# class RecExtElectrode/PointSourcePotential parameters, without the keys
# 'r', 'n', 'N' and 'method'
electrodeParameters = params.electrodeParameters.copy()
for key in ['r', 'n', 'N', 'method']:
    del electrodeParameters[key]

# Not using the RecExtElectrode class as we anyway average the potential in
# space for each source element. This is to be replaced by a closed-form
# volumetric method (point source & volumetric contacts should result in the
# same mappings as volumetric source & point contacts).

# Predictor assuming planar disk source elements convolved with a Gaussian
# along the z-axis
gauss_cyl_potential = GaussCylinderPotential(
    cell=None,
    z=electrodeParameters['z'],
    sigma=electrodeParameters['sigma'],
    R=params.populationParameters['pop_args']['radius'],
    sigma_z=params.populationParameters['pop_args']['scale'],
    )

# set up recording of current dipole moments
current_dipole_moment = CurrentDipoleMoment(cell=None)

# Compute average firing rate of presynaptic populations X
mean_nu_X = methods.compute_mean_nu_X(params, OUTPUTPATH_REAL,
                                      TRANSIENT=TRANSIENT)

# tic toc
tic = time()

# Median soma voltage of each population in the ground-truth network
# simulation, assumed to correspond to Vrest. It depends only on the
# postsynaptic population Y, so the file is read once instead of once per
# (X, Y) pair. The first 200 columns are skipped (presumably an initial
# period of the recording - TODO confirm).
with h5py.File(os.path.join(OUTPUTPATH_REAL, 'somav.h5'), 'r') as f:
    Vrest_Y = {Y: np.median(f[Y][()][:, 200:])
               for Y in params.population_names}

# kernel container, keyed 'Y:X' (postsynaptic:presynaptic)
H_YX_pred = dict()

for i, (X, N_X) in enumerate(zip(params.population_names,
                                 params.population_sizes)):
    for j, (Y, N_Y, morphology) in enumerate(zip(params.population_names,
                                                 params.population_sizes,
                                                 params.morphologies)):
        Vrest = Vrest_Y[Y]

        cellParameters = deepcopy(params.cellParameters)
        cellParameters.update(dict(
            morphology=morphology,
            custom_fun=set_biophys,
            custom_fun_args=[dict(Vrest=Vrest), dict(Vrest=Vrest)],
        ))

        # some inputs must be lists (one entry per presynaptic population)
        synapseParameters = [
            dict(weight=weights[ii][j],
                 syntype='Exp2Syn',
                 **params.synapseParameters[ii][j])
            for ii in range(len(params.population_names))]
        synapsePositionArguments = [
            params.synapsePositionArguments[ii][j]
            for ii in range(len(params.population_names))]

        # Create kernel approximator object
        kernel = KernelApprox(
            X=params.population_names,
            Y=Y,
            N_X=np.array(params.population_sizes),
            N_Y=N_Y,
            # NOTE(review): indexed by the presynaptic index i, whereas the
            # other per-connection tables here are indexed [pre][post] with
            # the postsynaptic index j. Confirm connectionProbability's
            # layout; this only matters if that table is asymmetric.
            C_YX=np.array(params.connectionProbability[i]),
            cellParameters=cellParameters,
            populationParameters=params.populationParameters['pop_args'],
            multapseFunction=params.multapseFunction,
            multapseParameters=[params.multapseArguments[ii][j] for ii in range(len(params.population_names))],
            delayFunction=params.delayFunction,
            delayParameters=[params.delayArguments[ii][j] for ii in range(len(params.population_names))],
            synapseParameters=synapseParameters,
            synapsePositionArguments=synapsePositionArguments,
            extSynapseParameters=params.extSynapseParameters,
            nu_ext=1000. / params.netstim_interval,
            n_ext=n_ext[j],
            nu_X=mean_nu_X,
        )

        # make kernel predictions
        H_YX_pred['{}:{}'.format(Y, X)] = kernel.get_kernel(
            probes=[gauss_cyl_potential, current_dipole_moment],
            Vrest=Vrest, dt=dt, X=X, t_X=t_X, tau=tau,
            g_eff=g_eff,
        )

# tic toc
toc = time()

In [None]:
# wall-clock time spent on the kernel simulations above
print('kernel simulations: {} seconds'.format(toc - tic))

In [None]:
# Best-fit LIF-network parameters found by fit_LIF_net (Fit_LIF_net.h5).
# Entries of the optimized vector res_x used below:
#   res_x[0:4]  -> J_YX     (row-major 2 x 2)
#   res_x[4]    -> J_ext
#   res_x[5:9]  -> delay_YX (row-major 2 x 2)
#   res_x[9:11] -> C_m_X
# NOTE(review): this rebinds `params`, which the kernel cell above uses as the
# example_network_parameters module; re-running that cell after this one
# fails. A distinct name would avoid this (later cells would need updating).
with h5py.File('Fit_LIF_net.h5', 'r') as f:
    res_x = f['X_opt'][()]

params = {
    'X': ['E', 'I'],
    'N_X': [8192, 1024],
    'C_m_X': [res_x[9], res_x[10]],
    'tau_m_X': [10., 10.],
    'E_L_X': [-65., -65.],
    'C_YX': [[0.5, 0.5], [0.5, 0.5]],
    'J_YX': [[res_x[0], res_x[1]],
             [res_x[2], res_x[3]]],
    'delay_YX': [[res_x[5], res_x[6]],
                 [res_x[7], res_x[8]]],
    'tau_syn_YX': [[0.5, 0.5], [0.5, 0.5]],
    'n_ext': n_ext,
    'nu_ext': 40.,
    'J_ext': res_x[4],
    'model': 'iaf_psc_exp',
    'dt': 2**-4,
}
params

In [None]:
# shape of the predicted E->E kernel for the GaussCylinderPotential probe
H_YX_pred['E:E']['GaussCylinderPotential'].shape

In [None]:
# probes for which the E->E kernel was predicted
H_YX_pred['E:E'].keys()

In [None]:
# Create the LIF network and its FIR filters from the predicted kernels
# H_YX_pred (construction is timed from tic to tac)
tic = time()
net = LIF_net.Network(**params)
net.create_fir_filters(H_YX=H_YX_pred)
tac = time()

In [None]:
# run the LIF network simulation until tstop (timed from tac to toc)
net.simulate(tstop=tstop)
toc = time()

In [None]:
# wall-clock timing summary of network construction and simulation
print(f'network create: {tac - tic} seconds',
      f'network simulate: {toc - tac} seconds',
      f'total: {toc - tic} seconds',
      sep='\n')

In [None]:
# multimeters of the network, looked up below as net.multimeters[probe]['Y:X']
net.multimeters

In [None]:
# re-select the inline backend; this may reset some rcParams, so the project
# style is applied again afterwards
%matplotlib inline

In [None]:
# re-apply the project's shared matplotlib style and figure-size conventions
plt.rcParams.update(plotting.rcParams)
golden_ratio, figwidth = plotting.golden_ratio, plotting.figwidth

In [None]:
# Sanity check that the GaussCylinderPotential multimeters recorded data:
# one panel per connection (titled by its multimeter key), one trace per
# recorded sender, showing the last 10000 samples.
fig, axes = plt.subplots(2, 2, sharex=True, sharey=False)
for i, Y in enumerate(net.X):
    for j, X in enumerate(net.X):
        key = f'{Y}:{X}'
        [events] = nest.GetStatus(net.multimeters['GaussCylinderPotential'][key], 'events')
        for sender in np.unique(events['senders']):
            inds = events['senders'] == sender
            axes[i, j].plot(events['times'][inds][-10000:], events['y'][inds][-10000:])
        axes[i, j].set_title(key)
for ax in axes[-1, :]:
    ax.set_xlabel('t (ms)')

In [None]:
# Mean per-neuron firing rates (s^-1), binned spike rates and rate spectra of
# the LIF-network populations, using only spikes at or after TRANSIENT
lif_mean_nu_X = {}  # mean spike rates
lif_nu_X = {}  # binned firing rates
lif_psd_X = {}  # rate spectra
for i, X in enumerate(population_names):
    times = nest.GetStatus(net.spike_recorders[X])[0]['events']['times']
    times = times[times >= TRANSIENT]

    lif_mean_nu_X[X] = (
        LIF_net.get_mean_spike_rate(times, TRANSIENT=TRANSIENT, tstop=tstop)
        / population_sizes[i]
    )
    _, lif_nu_X[X] = LIF_net.get_spike_rate(
        times, TRANSIENT=TRANSIENT, tstop=tstop, dt=dt)
    _, lif_psd_X[X] = LIF_net.get_psd(
        lif_nu_X[X], Fs=Fs, NFFT=NFFT, noverlap=noverlap,
        detrend=detrend, cutoff=cutoff)

lif_mean_nu_X

In [None]:
# Plot LIF-network spikes, spike rates and signal predictions:
#   A:         raster of LIF-network spikes in the time window T, with mean rates
#   below A:   spike counts per bin of width Delta_t
#   B-E:       per-connection FIR-filter responses recorded in NEST; top row
#              GaussCylinderPotential, bottom row CurrentDipoleMoment (last
#              row of the response, labelled P_z)
#   F:         sum over all connections
fig = plt.figure(figsize=(figwidth, figwidth / golden_ratio))
gs = GridSpec(8, 6, wspace=0.5)

ax = fig.add_subplot(gs[:-1, 0])
remove_axis_junk(ax)
annotate_subplot(ax, ncols=8, nrows=1, letter='A', linear_offset=0.02)

T = [4000, 4100]  # time window shown (ms)
for i, Y in enumerate(net.X):
    times = nest.GetStatus(net.spike_recorders[Y])[0]['events']['times']
    gids = nest.GetStatus(net.spike_recorders[Y])[0]['events']['senders']

    gids = gids[times >= TRANSIENT]
    times = times[times >= TRANSIENT]

    ii = (times >= T[0]) & (times <= T[1])
    ax.plot(times[ii], gids[ii], '.',
            mfc=colors[i],
            mec='none',
            ms=2,
            label=r'$\langle \nu_\mathrm{%s} \rangle =%.2f$ s$^{-1}$' % (
                Y, lif_mean_nu_X[Y])
           )
ax.legend(loc=1, markerscale=5)
ax.axis('tight')
ax.set_xticklabels([])
ax.set_ylabel('gid')


#####
# Rates
####
ax = fig.add_subplot(gs[-1, 0])
remove_axis_junk(ax)

Delta_t = dt
# bin edges every Delta_t across T; the bin count is computed from scalars
# because int() of a 1-element array (np.diff(T)) is deprecated in NumPy
bins = np.linspace(T[0], T[1], int((T[1] - T[0]) / Delta_t + 1))

for i, Y in enumerate(net.X):
    times = nest.GetStatus(net.spike_recorders[Y])[0]['events']['times']

    ii = (times >= T[0]) & (times <= T[1])
    ax.hist(times[ii], bins=bins, histtype='step', color=colors[i])

ax.yaxis.set_major_locator(MaxNLocator(nbins=5, integer=True))
ax.axis('tight')
ax.set_xlabel('t (ms)')
ax.set_ylabel(r'$\nu_X$' + '\n(spikes/' + r'$\Delta t$)')


# contributions by each connection:
for k, (ylabel, probe, unit, vlimround) in enumerate(zip(
    [r'$V_\mathrm{e}$', r'$\mathbf{P}$'],
    ['GaussCylinderPotential', 'CurrentDipoleMoment'],
    ['mV', 'nAµm'],
    [2**-1, 2**4])):

    if probe == 'CurrentDipoleMoment':
        scaling = 1E-4  # nAum --> nAcm unit conversion
        unit = 'nAcm'
    else:
        scaling = 1

    data_mm = None  # running sum of all connections' responses for this probe

    for i, X in enumerate(population_names):
        for j, Y in enumerate(population_names):
            if k == 0:
                ax = fig.add_subplot(gs[:-1, i * 2 + j + 1])
                annotate_subplot(ax, ncols=6, nrows=1, letter='ABCDE'[i * 2 + j + 1], linear_offset=0.02)
            else:
                ax = fig.add_subplot(gs[-1, i * 2 + j + 1])

            # FIR filter responses in NEST: one row per recorded sender (in
            # sorted sender order), TRANSIENT removed, decimated for plotting.
            # np.vstack raises if nothing was recorded instead of silently
            # reusing the previous connection's data.
            [mm_YX] = nest.GetStatus(net.multimeters[probe][f'{Y}:{X}'], 'events')
            d = np.vstack([
                ss.decimate(mm_YX['y'][mm_YX['senders'] == sender][int(TRANSIENT / dt):],
                            q=decimate_ratio, zero_phase=True)
                for sender in np.unique(mm_YX['senders'])
            ])
            # copy on first use so the in-place sum never modifies `d`
            if data_mm is None:
                data_mm = d.copy()
            else:
                data_mm += d

            draw_lineplot(ax,
                          d[-1, :].reshape((1, -1)) * scaling if probe == 'CurrentDipoleMoment' else d * scaling,
                          dt=dt * decimate_ratio,
                          T=T,
                          scaling_factor=1.,
                          # vlimround=vlimround,
                          label='',
                          scalebar=True,
                          unit=unit,
                          ylabels=True,
                          color='k',
                          ztransform=True
                          )

            if probe == 'CurrentDipoleMoment':
                ax.set_yticklabels(['$P_z$'])

            if i * 2 + j > 0:
                plt.setp(ax.get_yticklabels(), visible=False)
            if k == 0:
                ax.set_title(r'$\nu_{%s} \ast \hat{H}_\mathrm{%s%s}$' % (X, Y, X))
                ax.set_xlabel('')
                plt.setp(ax.get_xticklabels(), visible=False)
            ax.set_ylabel('')

    # sum
    if k == 0:
        ax = fig.add_subplot(gs[:-1, -1])
        annotate_subplot(ax, ncols=6, nrows=1, letter='F', linear_offset=0.02)
    else:
        ax = fig.add_subplot(gs[-1, -1])

    draw_lineplot(ax,
                  data_mm[-1, :].reshape((1, -1)) * scaling if probe == 'CurrentDipoleMoment' else data_mm * scaling,
                  dt=dt * decimate_ratio,
                  T=T,
                  scaling_factor=1.,
                  label='',
                  scalebar=True,
                  unit=unit,
                  ylabels=True,
                  color='k',
                  ztransform=True
                  )

    plt.setp(ax.get_yticklabels(), visible=False)
    if k == 0:
        ax.set_title(r'$\sum_X \sum_Y \nu_X \ast \hat{H}_\mathrm{YX}$')
        ax.set_xlabel('')
        plt.setp(ax.get_xticklabels(), visible=False)
    ax.set_ylabel('')

# make sure the output directory exists before saving
os.makedirs('figures', exist_ok=True)
fig.savefig(os.path.join('figures', 'figure16.pdf'), bbox_inches='tight')

In [None]:
# Mean per-neuron firing rates, binned spike rates and rate spectra of the
# ground-truth ("real") network populations, read from spikes.h5
mean_nu_X, nu_X, psd_X = dict(), dict(), dict()
with h5py.File(os.path.join(OUTPUTPATH_REAL, 'spikes.h5'), 'r') as f:
    for i, X in enumerate(population_names):
        # concatenate the stored spike-time arrays of population X
        times = np.concatenate(f[X]['times'])
        mean_nu_X[X] = (
            LIF_net.get_mean_spike_rate(times, TRANSIENT=TRANSIENT, tstop=tstop)
            / population_sizes[i]
        )
        bins, nu_X[X] = LIF_net.get_spike_rate(
            times, TRANSIENT=TRANSIENT, tstop=tstop, dt=dt)
        freqs, psd_X[X] = LIF_net.get_psd(
            nu_X[X], Fs=Fs, NFFT=NFFT, noverlap=noverlap,
            detrend=detrend, cutoff=cutoff)

mean_nu_X

In [None]:
# Recompute the LIF-network rate spectra from the binned rates lif_nu_X, and
# the mean per-neuron rates (kept separately as lif_mean_nu_X_)
lif_psd_X, lif_mean_nu_X_ = {}, {}
for i, X in enumerate(population_names):
    times = nest.GetStatus(net.spike_recorders[X])[0]['events']['times']
    times = times[times >= TRANSIENT]
    lif_mean_nu_X_[X] = (
        LIF_net.get_mean_spike_rate(times, TRANSIENT=TRANSIENT, tstop=tstop)
        / population_sizes[i]
    )
    _, lif_psd_X[X] = LIF_net.get_psd(
        lif_nu_X[X], Fs=Fs, NFFT=NFFT, noverlap=noverlap,
        detrend=detrend, cutoff=cutoff)

In [None]:
# display the recomputed LIF-network rate spectra
lif_psd_X

In [None]:
# Compare the spike-rate spectra of the ground-truth network ('mc') and the
# LIF network ('lif') per population; the legend shows mean rates (s^-1)
fig, ax = plt.subplots(1, 1, figsize=(8, 8))
for i, X in enumerate(population_names):
    ax.semilogy(freqs, psd_X[X], label=f'{X} mc {mean_nu_X[X]:.2f} s-1')
    ax.semilogy(freqs, lif_psd_X[X], label=f'{X} lif {lif_mean_nu_X_[X]:.2f} s-1')
ax.set_xlabel('$f$ (Hz)')
ax.set_ylabel(r'PSD$_\nu$ (s$^{-2}$/Hz)')
ax.legend();  # trailing ';' suppresses the Legend repr in the cell output

In [None]:
# Current dipole moment spectra: compare row 2 of the ground-truth current
# dipole moment with row 2 of the summed FIR-filter responses of the LIF
# network, both after discarding the initial TRANSIENT period.
fname = 'CurrentDipoleMoment.h5'
probe = 'CurrentDipoleMoment'

scaling = 1E-4  # nAum --> nAcm unit conversion
unit = 'nAcm'

with h5py.File(os.path.join(OUTPUTPATH_REAL, fname), 'r') as f:
    data = f['data'][()]['imem']
# TRANSIENT (ms) spans TRANSIENT / dt samples, as in every other slice here
data = data[2, int(TRANSIENT / dt):]
data -= data.mean()
freqs, psd_gt = ss.welch(data * scaling, fs=Fs, nperseg=NFFT, noverlap=noverlap)

# FIR filter responses in NEST, summed over all connections; one row per
# recorded sender (in sorted sender order), TRANSIENT removed
data_mm = None
for X in params['X']:
    for Y in params['X']:
        [mm_YX] = nest.GetStatus(net.multimeters[probe][f'{Y}:{X}'], 'events')
        d = np.vstack([
            mm_YX['y'][mm_YX['senders'] == sender][int(TRANSIENT / dt):]
            for sender in np.unique(mm_YX['senders'])
        ])
        # copy on first use so the in-place sum never modifies `d`
        if data_mm is None:
            data_mm = d.copy()
        else:
            data_mm += d

# remove the temporal mean of each row
data_mm = data_mm - data_mm.mean(axis=-1, keepdims=True)
_, psd_mm = ss.welch(data_mm[2] * scaling, fs=Fs, nperseg=NFFT, noverlap=noverlap)

fig, ax = plt.subplots(1, 1, figsize=(8, 8))
in_band = freqs <= cutoff
ax.semilogy(freqs[in_band], psd_gt[in_band], label='ground truth')
ax.semilogy(freqs[in_band], psd_mm[in_band], label='LIFnet')
ax.set_xlabel('$f$ (Hz)')
ax.set_ylabel(f'PSD ({unit}$^2$/Hz)')
ax.legend();  # trailing ';' suppresses the Legend repr in the cell output