In [None]:
# render matplotlib figures inline in the notebook
%matplotlib inline
# reload edited local modules (plotting, isyn_approximator, ...) automatically
# before each cell is executed
%load_ext autoreload
%autoreload 2

In [None]:
import os
from parameters import ParameterSpace, ParameterSet
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
import h5py
import scipy.signal as ss
from example_network_parameters import (networkParameters, population_names)
from isyn_approximator import ISynApprox
import example_network_methods as methods
import example_network_parameters as params
import scipy.stats as st
from plotting import draw_lineplot, annotate_subplot, remove_axis_junk
import json
import hashlib
import scipy.optimize as so

In [None]:
# notebook-wide matplotlib defaults: small axis margins and larger fonts
plt.rcParams.update(dict([
    ('axes.xmargin', 0.01),
    ('axes.ymargin', 0.01),
    ('font.size', 14),
    ('legend.fontsize', 12),
    ('axes.titlesize', 14),
]))

In [None]:
def ztransform(u, axis=-1):
    '''Return mean-subtracted input normalized by its standard deviation

    Parameters
    ----------
    u: ndarray
        1D or 2D array (array-likes are converted with np.asarray)
    axis: int
        for 2D arrays, the axis along which mean and standard deviation are
        computed: axis=-1 or 1 normalizes each row, axis=0 each column.
        Ignored for 1D arrays. Default axis=-1

    Returns
    -------
    ndarray
        z-transformed array with the same shape as ``u``

    Raises
    ------
    AssertionError
        if axis not in [-1, 0, 1]
    ValueError
        if ``u`` is neither 1D nor 2D
    '''
    assert axis in [-1, 0, 1], 'axis not in [-1, 0, 1]'
    u = np.asarray(u)
    if u.ndim == 1:
        return (u - u.mean()) / u.std()
    if u.ndim == 2:
        # keepdims retains the reduced axis with length 1, so the statistics
        # broadcast back along the correct axis for every shape (including
        # square arrays) and the result keeps the input's orientation
        return ((u - u.mean(axis=axis, keepdims=True))
                / u.std(axis=axis, keepdims=True))
    raise ValueError(f'can not transform array of shape {u.shape}')

In [None]:
# parameter spaces describing the simulated network variants
PS0, PS1 = ParameterSpace('PS0.txt'), ParameterSpace('PS1.txt')

# duration and time step of the network simulation
tstop, dt = networkParameters['tstop'], networkParameters['dt']

# analysis settings
TRANSIENT = 2000  # leading samples discarded as start-up transient (used as an index into dt_proxy-resolution series)
dt_proxy = 1.  # time resolution of the proxy / LFP comparison
tau = 50  # time lag relative to spike for kernel predictions
T = [2000, 2200]  # time segment for plots

# time axis from 0 to tstop in steps of dt_proxy
time_proxy = np.arange(0, tstop + dt_proxy, dt_proxy)

In [None]:
# figure out which real LFP to compare with: the first parameter set of PS1
# identifies the reference ("real") network simulation, whose output lives in
# os.path.join('output', <md5 of the JSON-encoded parameter set>)
pset = next(iter(PS1.iter_inner()), None)
if pset is None:
    # without a parameter set none of the names below would be defined and
    # later cells would fail with an obscure NameError
    raise ValueError('PS1 contains no parameter sets')
weight_EE = pset['weight_EE']
weight_IE = pset['weight_IE']
weight_EI = pset['weight_EI']
weight_II = pset['weight_II']
pset_0 = ParameterSet(dict(weight_EE=weight_EE,
                           weight_IE=weight_IE,
                           weight_EI=weight_EI,
                           weight_II=weight_II,
                           weight_scaling=pset['weight_scaling'],
                           # NOTE(review): .value presumably is the single
                           # value of this PS0 entry — confirm
                           n_ext=PS0['n_ext'].value))
# sort_keys makes the JSON (and therefore the hash) independent of key order
js_0 = json.dumps(pset_0, sort_keys=True).encode()
md5_0 = hashlib.md5(js_0).hexdigest()
OUTPUTPATH_REAL = os.path.join('output', md5_0)

In [None]:
# Spike-count time series of the "real" network at the simulation
# resolution dt: one histogram per population, with bins centred on
# multiples of dt from 0 to tstop.
tstop = networkParameters['tstop']
bins = np.arange(0, tstop / dt + 2) * dt - dt / 2
with h5py.File(os.path.join(OUTPUTPATH_REAL, 'spikes.h5'), 'r') as f:
    nu_X = {
        pop: np.histogram(np.concatenate(f[pop]['times']),
                          bins=bins)[0].astype(float)
        for pop in population_names
    }

In [None]:
# Spike-count time series of the "real" network at the coarser resolution
# dt_proxy, plotted as step functions over the window T.
nu_X_proxy = {}
bins = dt_proxy * np.arange(0, tstop / dt_proxy + 2) - dt_proxy / 2
inds = np.logical_and(bins[:-1] >= T[0], bins[:-1] <= T[1])
fig, ax = plt.subplots(1, 1, figsize=(16, 9))
with h5py.File(os.path.join(OUTPUTPATH_REAL, 'spikes.h5'), 'r') as f:
    for X in population_names:
        spike_times = np.concatenate(f[X]['times'])
        nu_X_proxy[X] = np.histogram(spike_times, bins=bins)[0].astype(float)
        ax.step(bins[:-1][inds], nu_X_proxy[X][inds],
                label=r'$\nu_{%s}(t)$' % X)
ax.set_ylabel(r'$\nu_X$ (# spikes/$\Delta t$)')
ax.set_xlim(T)
ax.set_xlabel('$t$ (ms)')
ax.legend()

In [None]:
# mean/median somatic potentials of each population of the "real" network
V_soma = dict()  # container
op = np.mean  # or np.median
fig, ax = plt.subplots(1, 1, figsize=(16, 9))
# plotting window T on the time_proxy grid; loop invariant, computed once
inds = (time_proxy >= T[0]) & (time_proxy <= T[1])
# open the file once for all populations rather than once per population
with h5py.File(os.path.join(OUTPUTPATH_REAL, 'somav.h5'), 'r') as f:
    for j, Y in enumerate(params.population_names):
        # reduce over axis 0 (presumably the recorded cells — confirm); the
        # result is assumed to be sampled on the time_proxy grid
        V_soma[Y] = op(f[Y][()], axis=0)
        ax.plot(time_proxy[inds], V_soma[Y][inds], label=Y)

ax.axis(ax.axis('tight'))
ax.set_xlim(T)
ax.set_xlabel('$t$ (ms)')
ax.set_ylabel(r'%s$(V_\mathrm{soma})$' % op.__name__)
ax.legend()

In [None]:
# ground truth LFP, decimated from the simulation resolution dt to dt_proxy
fname = 'RecExtElectrode.h5'
unit = 'mv'
vlimround = 2**-1

# integer decimation factor between dt and dt_proxy; draw_lineplot below is
# told the samples are dt_proxy apart, so q must equal dt_proxy / dt
decimation_factor = int(round(dt_proxy / dt))
assert np.isclose(decimation_factor * dt, dt_proxy), \
    'dt_proxy must be an integer multiple of dt'

fig, ax = plt.subplots(1, 1, figsize=(16, 9))

with h5py.File(os.path.join(OUTPUTPATH_REAL, fname), 'r') as f:
    # NOTE(review): the scipy docs recommend several decimate calls for
    # factors above 13; a single call is kept to preserve current results
    data = ss.decimate(f['data'][()]['imem'], q=decimation_factor,
                       zero_phase=True)

label = 'real\n({})'.format(md5_0[:6])
draw_lineplot(ax,
              data,
              dt=dt_proxy,
              T=T,
              scaling_factor=1.,
              vlimround=vlimround,
              label=label,
              scalebar=True,
              unit=unit,
              ylabels=True,
              color='k',
              ztransform=True
              )
ax.set_title(label);



In [None]:

# synaptic currents per connection
##########################################################################
# Compute kernels mapping spikes to Isyn
##########################################################################

# number of populations; all per-population look-ups below use this
n_pop = len(params.population_names)

# synapse max. conductances, indexed weights_YX[X][Y] (presynaptic X,
# postsynaptic Y) to match the [ii][j] look-ups below
weights_YX = [[weight_EE, weight_IE],
              [weight_EI, weight_II]]

# Compute average firing rate of presynaptic populations X
nu_X_mean = methods.compute_mean_nu_X(params, OUTPUTPATH_REAL,
                                      TRANSIENT=TRANSIENT)

# conduction delay function
delayFunction = params.delayFunction

# integer factor between the simulation resolution dt and dt_proxy, used to
# decimate the synaptic currents onto the time_proxy grid
decimation_factor = int(round(dt_proxy / dt))
assert np.isclose(decimation_factor * dt, dt_proxy), \
    'dt_proxy must be an integer multiple of dt'

# Median somatic potential of each postsynaptic population Y after the
# initial transient, used as Vrest for the kernel approximation. It depends
# on Y only, so it is read once per Y instead of once per (X, Y) pair.
with h5py.File(os.path.join(OUTPUTPATH_REAL, 'somav.h5'), 'r') as f:
    Vrest_Y = {Y: np.median(f[Y][()][:, TRANSIENT:])
               for Y in params.population_names}

# kernel container, keyed 'Y:X'
H_YX = dict()

# iterate over pre (X) and postsynaptic (Y) populations
for i, X in enumerate(params.population_names):
    for j, (Y, N_Y) in enumerate(zip(params.population_names,
                                     params.population_sizes)):
        # parameters of every presynaptic population ii onto Y
        # (some inputs must be lists)
        multapseParameters = [
            dict(loc=params.multapseArguments[ii][j]['loc'],
                 scale=params.multapseArguments[ii][j]['scale'])
            for ii in range(n_pop)]
        synapseParameters = [
            dict(weight=weights_YX[ii][j],
                 syntype='Exp2Syn',
                 **params.synapseParameters[ii][j])
            for ii in range(n_pop)]
        # connection probability of every presynaptic population ii onto Y,
        # indexed [ii][j] like the multapse/synapse parameters above
        C_YX = np.array([params.connectionProbability[ii][j]
                         for ii in range(n_pop)])

        # Create kernel approximator object
        kernel = ISynApprox(
            X=params.population_names,
            Y=Y,
            N_X=np.array(params.population_sizes),
            N_Y=N_Y,
            C_YX=C_YX,
            multapseParameters=multapseParameters,
            delayFunction=delayFunction,
            # delay parameters of this specific X -> Y connection
            delayParameters=dict(**params.delayArguments[i][j]),
            synapseParameters=synapseParameters,
            nu_X=nu_X_mean,
        )

        H_YX['{}:{}'.format(Y, X)] = kernel.get_kernel(
            Vrest=Vrest_Y[Y], dt=dt, X=X, tau=tau,
        )

# time average of the rate-kernel convolution for each X -> Y, divided by
# the size of the postsynaptic population Y
mean_Isyn_YX = np.zeros((n_pop, n_pop))
for i, X in enumerate(params.population_names):
    for j, Y in enumerate(params.population_names):
        mean_Isyn_YX[i, j] = (
            np.convolve(nu_X[X], H_YX['{}:{}'.format(Y, X)], 'same').mean()
            / params.population_sizes[j])


# Total synaptic current per presynaptic population X: convolve its firing
# rate with the sum of the kernels H_YX over the target populations Y, then
# decimate onto the time_proxy grid. Only the first postsynaptic population
# (population_names[:1], presumably the excitatory one) is included — TODO
# confirm intent. Slicing keeps this valid for names of any length.
target_Y = params.population_names[:1]
ISyn = dict()
inds = (time_proxy >= T[0]) & (time_proxy <= T[1])
fig, axes = plt.subplots(n_pop, 1, sharex=True, figsize=(16, 9),
                         squeeze=False)
axes = axes[:, 0]
for i, X in enumerate(params.population_names):
    H_X = sum(H_YX['{}:{}'.format(Y, X)] for Y in target_Y)
    ISyn[X] = ss.decimate(np.convolve(nu_X[X], H_X, 'same'),
                          q=decimation_factor, zero_phase=True)

    ax = axes[i]
    ax.plot(time_proxy[inds], ISyn[X][inds])
    ax.set_ylabel(r'$(I_{\mathrm{syn}%s})$ (nA)' % X)

    ax.axis(ax.axis('tight'))
    ax.set_xlim(T)
    ax.set_xlabel('$t$ (ms)')


In [None]:


################################################
# Candidate temporal loading functions g_proxy(t), compared after
# z-transforms. Each series is formed over the full duration and then cut to
# the samples from index TRANSIENT onward.
################################################
# sum of firing rates
g_proxy_FR = ztransform((nu_X_proxy['E'] + nu_X_proxy['I'])[TRANSIENT:])

# sum of the population-level somatic voltages
g_proxy_Vm = ztransform((V_soma['E'] + V_soma['I'])[TRANSIENT:])

# synaptic current driven by population E, and by population I
g_proxy_ISynE = ztransform(ISyn['E'][TRANSIENT:])
g_proxy_ISynI = ztransform(ISyn['I'][TRANSIENT:])

# summed synaptic currents
g_proxy_I = ztransform((ISyn['E'] + ISyn['I'])[TRANSIENT:])

# difference of the synaptic currents; equals |I_E| + |I_I| provided
# ISyn['E'] >= 0 and ISyn['I'] <= 0 (assumed — confirm)
g_proxy_I_abs = ztransform((ISyn['E'] - ISyn['I'])[TRANSIENT:])



In [None]:

def calc_g_proxy_WS(I_E, I_I, alpha=1, tau=0, n_taps=21):
    '''
    Implements the temporal component of the weighted sum (WS) proxy of
    Mazzoni2015, computed here as:

    g_WS = ztransform(I_E - alpha * convolve(I_I, h_tau))

    where h_tau is a windowed-sinc fractional delay filter that delays I_I by
    tau samples. No absolute value is taken: if I_I <= 0, the subtraction
    adds alpha times the magnitude of I_I to I_E.

    Parameters
    ----------
    I_E: ndarray
        excitatory synapse current as function of time
    I_I: ndarray
        inhibitory synapse current as function of time, same length as I_E
    alpha: scalar
        relative weight of inhibitory synapse current
    tau: scalar
        fractional delay applied to inhibitory synapse current in units
        of samples. The sinc peak lies at sample (n_taps - 1) / 2 + tau of
        the filter, so |tau| must stay below (n_taps - 1) / 2 to be
        represented by the filter
    n_taps: int
        length of the fractional delay filter (odd). Default 21

    Returns
    -------
    g_WS: ndarray
        z-transformed proxy with the same length as I_E
    '''
    # Fractional delay coefficients from
    # https://tomroelandts.com/articles/how-to-create-a-fractional-delay-filter
    n = np.arange(n_taps)
    # sinc filter centred on the middle tap, shifted by tau
    h = np.sinc(n - (n_taps - 1) / 2 - tau)
    # taper with a Blackman window to limit ringing
    h *= np.blackman(n_taps)
    # Normalize to get unity gain.
    h /= np.sum(h)

    # 'same' keeps the output aligned with, and as long as, I_I
    return ztransform(I_E - alpha * np.convolve(I_I, h, 'same'))


In [None]:


# optimized g_proxy_WS with regards to minimizing (1-R2)
def func(x):
    '''Objective for fitting the weighted-sum proxy to the LFP.

    Parameters
    ----------
    x: sequence of two scalars
        [alpha, tau] forwarded to calc_g_proxy_WS

    Returns
    -------
    float
        1 minus the mean, over channels, of the squared Pearson correlation
        between g_WS and each channel of ``data`` after sample TRANSIENT
    '''
    g = calc_g_proxy_WS(ISyn['E'][TRANSIENT:], ISyn['I'][TRANSIENT:],
                        alpha=x[0], tau=x[1])

    # LFP after the transient, one column per channel, centered in time
    lfp = data[:, TRANSIENT:].T
    lfp = lfp - lfp.mean(axis=0)
    g_c = g - g.mean()

    # Pearson correlation of g with every channel. Only this single row of
    # the (n_ch + 1) x (n_ch + 1) correlation matrix is needed, so it is
    # computed directly rather than through np.corrcoef on every evaluation.
    cc = (g_c @ lfp) / (np.linalg.norm(g_c) * np.linalg.norm(lfp, axis=0))

    return 1 - np.mean(cc**2)


In [None]:


# Fit [alpha, tau] of the weighted-sum proxy by minimizing func (1 - mean R^2)
x0 = [1., 0.]  # initial guess [alpha, tau]
res = so.minimize(func, x0)
if not res.success:
    # the fitted values are still used below, so make a failed fit visible
    print('so.minimize reported failure: {}'.format(res.message))
print(res.x)

# optimized g_proxy
g_proxy_WS = calc_g_proxy_WS(ISyn['E'][TRANSIENT:], ISyn['I'][TRANSIENT:],
                             alpha=res.x[0], tau=res.x[1])

# Figure on a 7 x 3 grid: column 0 holds one row per proxy (panel A);
# columns 1 and 2 hold the R^2 panels created further down.
fig = plt.figure(figsize=(16, 9))
gs = GridSpec(7, 3)

axes = []
# boolean mask of the plotting window T on the full time_proxy grid
inds = (time_proxy >= T[0]) & (time_proxy <= T[1])

# reshape data
# (LFP after the transient with one column per channel; X is also used by
# the R^2 computations that follow)
X = data[:, TRANSIENT:].T
# center data
X = X - X.mean(axis=0)
# Panel A: one row per proxy time series, all sharing the time axis
for i, (g_proxy, label) in enumerate(zip(
        [g_proxy_FR,
         g_proxy_Vm,
         g_proxy_ISynE,
         g_proxy_ISynI,
         g_proxy_I,
         g_proxy_I_abs,
         g_proxy_WS],
        [r'$g_{\sum \nu_X}(t)$',
         r'$g_{V_\mathrm{m}}(t)$',
         r'$g_{\sum I_E}(t)$',
         r'$g_{\sum I_I}(t)$',
         r'$g_{\sum I}(t)$',
         r'$g_{\sum {|I|}}(t)$',
         r'$g_\mathrm{WS}(t)$'])):
    if i == 0:
        ax = fig.add_subplot(gs[i, 0])
        annotate_subplot(ax, ncols=3, nrows=7, letter='A', linear_offset=0.025)
    else:
        ax = fig.add_subplot(gs[i, 0], sharex=axes[-1])
    # only the bottom row (i == 6) keeps its x tick labels
    if i != 6:
        plt.setp(ax.get_xticklabels(), visible=False)
    axes.append(ax)
    # g_proxy starts at sample TRANSIENT of time_proxy, so the window mask is
    # shifted by TRANSIENT to pick the matching proxy samples
    ax.plot(time_proxy[inds], g_proxy[inds[TRANSIENT:]], 'C{}'.format(i),
            label=label)
    ax.set_ylabel(label, labelpad=0)
    ax.axis('tight')
    ax.set_xlim(T)

# x label only on the last (bottom) axes
ax.set_xlabel('$t$ (ms)', labelpad=0)

# tidy every axes, and hide x tick labels on all rows except the bottom one
# again after remove_axis_junk
for i, ax in enumerate(axes):
    remove_axis_junk(ax)
    if i != 6:
        plt.setp(ax.get_xticklabels(), visible=False)



# R2 == CC^2 per channel
# Panel B: squared zero-lag Pearson correlation between every LFP channel
# and every proxy
ax = fig.add_subplot(gs[:, 1])
remove_axis_junk(ax)
# channel 1 at the top
ax.invert_yaxis()
# append the 7 proxies as extra columns after the n_ch LFP channel columns
X = np.c_[X, np.c_[g_proxy_FR, g_proxy_Vm, g_proxy_ISynE,
                   g_proxy_ISynI, g_proxy_I, g_proxy_I_abs,
                   g_proxy_WS]]
# squared correlation matrix of all columns; row n_ch + i, columns :n_ch
# hold proxy i versus each channel
R2 = np.corrcoef(X.T)**2
n_ch = data.shape[0]
channels = np.arange(n_ch) + 1
for i, label in enumerate([
        r'$g_{\sum \nu_X}(t)$',
        r'$g_{V_\mathrm{m}}(t)$',
        r'$g_{\sum I_E}(t)$',
        r'$g_{\sum I_I}(t)$',
        r'$g_{\sum I}(t)$',
        r'$g_{\sum {|I|}}(t)$',
        r'$g_\mathrm{WS}(t)$']):
    ax.plot(R2[i + n_ch, :n_ch], channels, '-' + 'o^>v<sd'[i], color='C{}'.format(i), lw=2,
            label=(label + '\n' +
                   r'$R^2_\mathrm{max}$=%.2f' % R2[i + n_ch, :n_ch].max()))
ax.set_yticks(channels)
ax.set_yticklabels(['ch.{}'.format(ch) for ch in channels])
ax.set_xlabel('$R^2$ (-)', labelpad=0)
ax.axis(ax.axis('tight'))

# shrink the axes to the upper 60% of its slot to make room for the legend,
# which is anchored just below the axes
box = ax.get_position()
ax.set_position([box.x0, box.y0 + box.height * 0.4, box.width, box.height * 0.6])

ax.legend(loc=2, ncol=2, bbox_to_anchor=(0.0, -0.15))


annotate_subplot(ax, ncols=3, nrows=2, letter='B', linear_offset=0.025)



# Panel C: peak of the squared normalized cross-correlation between each
# (centered) LFP channel and each proxy, allowing for a time lag
ax = fig.add_subplot(gs[:, 2], sharey=ax)
remove_axis_junk(ax)
# lag (ms) of every output sample of np.correlate(a, v, 'same') for two
# inputs of equal length L: sample m corresponds to lag (m - L // 2), where a
# positive lag means the LFP channel a follows the proxy v
timelags = (np.arange(X.shape[0]) - X.shape[0] // 2) * dt_proxy
# include max of squared normalized correlation functions
for i, x in enumerate([g_proxy_FR, g_proxy_Vm, g_proxy_ISynE,
                       g_proxy_ISynI, g_proxy_I, g_proxy_I_abs,
                       g_proxy_WS]):
    R2_delta = []
    lags = []
    for ch in range(n_ch):
        cc2 = (np.correlate(X[:, ch], x, 'same') /
               (X[:, ch].std() * x.std()) / x.size)**2
        # argmax yields exactly one lag per channel even when the maximum is
        # reached at several lags (equality masking could return several and
        # make np.array(lags) ragged)
        peak = np.argmax(cc2)
        R2_delta.append(cc2[peak])
        lags.append(timelags[peak])

    lags = np.array(lags)

    ax.plot(R2_delta, channels, '-' + 'o^>v<sd'[i], color='C{}'.format(i), lw=2,
            label=(r'$\langle \tau \rangle ={%.2f}$ ms' % lags.mean() +
                   '\n' +
                   r'$\sigma_\tau={%.2f}$ ms' % lags.std() +
                   '\n' +
                   r'$R^2_\mathrm{max}$=%.2f' % np.max(R2_delta)))

plt.setp(ax.get_yticklabels(), visible=False)
ax.set_xlabel('$R^2$ (-)', labelpad=0)
ax.axis(ax.axis('tight'))

# shrink the axes to the upper 60% of its slot to make room for the legend
box = ax.get_position()
ax.set_position([box.x0, box.y0 + box.height * 0.4, box.width, box.height * 0.6])

ax.legend(loc=2, ncol=2, bbox_to_anchor=(0.0, -0.15))

annotate_subplot(ax, ncols=3, nrows=2, letter='C', linear_offset=0.025)


fig.savefig('Figure_8.pdf', bbox_inches='tight', pad_inches=0)