Illustration of levels of neuronal network description and predictions 

In [None]:
%matplotlib inline
%load_ext autoreload
%autoreload 2

In [None]:
import os
import LFPy
import numpy as np
import h5py
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from matplotlib.gridspec import GridSpec
from matplotlib.collections import PolyCollection, PatchCollection
from plotting import annotate_subplot
import plotting
import example_network_parameters as params
import scipy.signal as ss

In [None]:
# Switch matplotlib to its hand-drawn "xkcd" sketch style with a reduced
# wiggle amplitude (scale=0.5).
plt.xkcd(scale=0.5)
# Disabled experiment: clearing the 'path.effects' rc entry after enabling the
# sketch style. Kept for reference only; it is not executed.
# from matplotlib import patheffects
# plt.rcParams.update({
#     'path.effects': [], # patheffects.withStroke(linewidth=4, foreground="w")]
# })

In [None]:
# Apply the rc settings shipped with the local `plotting` module, then copy
# its layout constants into the notebook namespace.
plt.rcParams.update(plotting.rcParams)
golden_ratio = plotting.golden_ratio
figwidth = plotting.figwidth

In [None]:
# Put the font family back to matplotlib's built-in default, overriding
# whatever family the style settings applied earlier in the notebook chose.
plt.rcParams['font.family'] = plt.rcParamsDefault['font.family']

In [None]:
# E and I colors
# Indexed by population: entry 0 is used for 'E', entry 1 for 'I'.
colors = ['tab:blue', 'tab:red']

In [None]:
# Dataset
# Output folders read by the figure cell: OUTPUTPATH supplies the reference
# signals, spikes and soma potentials; OUTPUTPATH_APPROX supplies the
# approximated extracellular signal.
# NOTE(review): the folder names look like parameter-set hashes - confirm
# that both folders exist under ./output before running.
OUTPUTPATH = os.path.join('output', 'adb947bfb931a5a8d09ad078a6d256b0')
OUTPUTPATH_APPROX = os.path.join('output', '7c88fea99ae5f4fc3669292354655c2c')

In [None]:
# Seed for numpy's global RNG; it is re-applied before each randomly placed
# population is drawn so that the layout is reproducible.
SEED = 1234

In [None]:
# Square canvas with a single Axes whose frame, ticks and labels are hidden;
# all elements of the schematic are placed in data coordinates.
fig, ax = plt.subplots(figsize=(figwidth, figwidth))
ax.set_axis_off()

# ax.axis('equal')
# Fixed view window: horizontal range [-500, 1500], vertical range [-1000, 800].
ax.set_xlim(-500, 1500)
ax.set_ylim(-1000, 800)


# LFP data
# Read every second channel and the last 1601 samples, then remove each
# channel's temporal mean. The file is opened in a `with` block so the HDF5
# handle is released as soon as the slice has been copied into memory.
with h5py.File(os.path.join(OUTPUTPATH, 'RecExtElectrode.h5'), 'r') as f:
    data = f['data']['imem'][::2, -1601:]
data = (data.T - data.mean(axis=-1)).T
# Time axis used for plotting: as many points as there are samples,
# spanning 200 to 500.
t = np.linspace(200, 500, data.shape[1])

# approximated data
# Same channel/sample selection and mean removal, read from the second
# output folder.
with h5py.File(os.path.join(OUTPUTPATH_APPROX, 'RecExtElectrode.h5'), 'r') as f:
    data_approx = f['data']['imem'][::2, -1601:]
data_approx = (data_approx.T - data_approx.mean(axis=-1)).T


###############################
# Draw MC+VC network
###############################
# Seed the global RNG so soma positions and stacking order are reproducible.
np.random.seed(SEED)
N_Y = [16, 8]
labels = ['E', 'I']
morphologies = ['BallAndSticks_short.hoc', 'BallAndSticks_I.hoc']
for i, (N, label, morphology) in enumerate(zip(N_Y, labels, morphologies)):
    # soma locations: uniform in a 500 (x) by 200 (z) box centred on the origin
    cell_xz = np.c_[(np.random.rand(N) - 0.5) * 500, (np.random.rand(N) - 0.5) * 200]

    # Identical for every cell of this population, so build it once per
    # population instead of once per cell.
    cellParameters = dict(
        morphology=morphology,
        templatefile='BallAndSticksTemplate.hoc',
        templatename='BallAndSticksTemplate',
        templateargs=None,
        delete_sections=True,
    )

    for j, (x, z) in enumerate(cell_xz):
        cell = LFPy.TemplateCell(**cellParameters)
        cell.set_pos(x, 0, z)

        # show morphology: one polygon per entry returned for the x-z
        # projection (distinct names avoid shadowing the soma position)
        zips = [list(zip(x_poly, z_poly))
                for x_poly, z_poly in cell.get_idx_polygons(projection=('x', 'z'))]
        polycol = PolyCollection(zips,
                                 edgecolors='k',
                                 facecolors=colors[i],
                                 linewidths=1,
                                 # random stacking order among the cells
                                 zorder=np.random.rand()-0.5,
                                 # only the first cell of a population gets
                                 # a legend entry
                                 label=label if j == 0 else '__nolabel__')
        ax.add_collection(polycol)

# legend
ax.legend(loc='center', bbox_to_anchor=(0.45, 0.65))

# probe
# Outline of the shank as separate x and z vertex lists; the lowest vertex
# (0, -100) forms the tip.
probe_x = [-40, -40, 0, 40, 40, -40]
probe_z = [500, -50, -100, -50, 600, 600]
el = [list(zip(probe_x, probe_z))]
elpoly = PolyCollection(el, edgecolors='k', facecolors='gray', zorder=0)
ax.add_collection(elpoly)

# One black dot per plotted channel, evenly spaced along the shank axis.
n_contacts = data.shape[0]
contact_z = np.linspace(-50, 450, n_contacts)
ax.plot(np.zeros(n_contacts), contact_z, 'o', color='k', zorder=0)

# connections
# Arrow style string shared by the arrows drawn in this figure.
style = "Simple, tail_width=3, head_width=10, head_length=6"
# Black, slightly curved arrow starting at the top of the probe outline
# (0, 600) and ending at (380, 650).
kw = {'arrowstyle': style, 'color': "k"}
arrow = mpatches.FancyArrowPatch(
    (-0, 600), (380, 650),
    connectionstyle="arc3,rad=-0.2",
    **kw,
)
ax.add_patch(arrow)



# draw extracellular traces
# Each row is (traces, horizontal shift, z of first channel, label position,
# label text): first the reference signal, then its approximation. Every
# channel is scaled by 1e2 and drawn 50 units below the previous one.
trace_panels = (
    (data, 200, 700, (550, 750), r'$V_\mathrm{e}(\mathbf{R},t)$'),
    # draw approximated extracellular traces
    (data_approx, 950, -450, (1300, -400), r'$\hat{V}_\mathrm{e}(\mathbf{R},t)$'),
)
for traces, t_shift, z_first, (x_text, z_text), text in trace_panels:
    for i, x in enumerate(traces):
        ax.plot(t + t_shift, x*1e2 - i * 50 + z_first, 'k', lw=1)
    ax.text(x_text, z_text, text, fontsize=16, ha='center', va='center')


###############################
# Draw point-neuron network
###############################

# Re-seed the RNG. The random numbers below are drawn in the same order as in
# the multicompartment section (two position vectors, then one number per
# cell), so as long as nothing else consumed the global RNG in between, the
# somas end up at the same x positions, shifted 500 units down.
np.random.seed(SEED)
N_Y = [16, 8]
morphologies = ['BallAndSticks_short.hoc', 'BallAndSticks_I.hoc']
coords = []
markers = ['^', 'o']  # triangle for the first population, circle for the second
for i, (N, morphology, marker) in enumerate(zip(N_Y, morphologies, markers)):
    # soma locations
    x_soma = (np.random.rand(N) - 0.5) * 500
    z_soma = (np.random.rand(N) - 0.5) * 200 - 500
    cell_xz = np.c_[x_soma, z_soma]
    coords.append(cell_xz)

    for x, z in cell_xz:
        # random stacking order among the markers
        ax.plot(x, z, marker=marker, mec='k', mfc=colors[i], ms=15,
                zorder=np.random.rand() - 0.5)
    

# illustrate some connections
# Visit every (presynaptic population, postsynaptic population) combination
# and every pair of cell indices. Pairs with equal index are skipped without
# drawing a random number; each remaining pair is drawn with a probability of
# about 10 %.
for h, pre_xz in enumerate(coords):
    for post_xz in coords:
        for i, (x_post, z_post) in enumerate(post_xz):
            for j, (x_pre, z_pre) in enumerate(pre_xz):
                if i == j or np.random.rand() <= 0.9:
                    continue
                # draw line from presyn soma to synapse; the end point sits
                # 15 units from the postsynaptic soma on the presynaptic side
                x_syn = x_post - np.sign(x_post - x_pre) * 15
                ax.plot([x_pre, x_syn], [z_pre, z_post], 'o-',
                        color=colors[h],
                        path_effects=[],
                        lw=1)
        
        
        
#########################
# Neural mass model illustration
#########################

z = -850
# One symbol per population, given as (x position, vertex count, radius):
# a triangle for the first label and a 100-sided polygon (visually a circle)
# for the second.
shapes = [(-150, 3, 100), (150, 100, 75)]
for i, ((x, nverts, radius), label) in enumerate(zip(shapes, labels)):
    # `radius` is passed by keyword: recent Matplotlib releases no longer
    # accept it positionally.
    polygon = mpatches.RegularPolygon((x, z), nverts, radius=radius, orientation=0,
                                      edgecolor=colors[i], facecolor='w', lw=3)
    ax.add_patch(polygon)
    ax.text(x, z, label, color=colors[i], size=32, ha='center', va='center')


# connections
style = "Simple, tail_width=3, head_width=10, head_length=6"
# For each population (in the order of `colors`): the arrow towards the other
# symbol followed by the strongly curved self-loop, each given as
# (start, end, connectionstyle).
connections = [
    # E outgoing
    [((-120, -800), (80, -800), "arc3,rad=-0.5"),
     ((-180, -800), (-230, -880), "arc3,rad=2.0")],
    # I outgoing
    [((100, -900), (-50, -900), "arc3,rad=-0.5"),
     ((205, -900), (215, -800), "arc3,rad=2")],
]
for color, arrows in zip(colors, connections):
    kw = dict(arrowstyle=style, color=color)
    for start, end, connectionstyle in arrows:
        arrow = mpatches.FancyArrowPatch(start, end,
                                         connectionstyle=connectionstyle, **kw)
        ax.add_patch(arrow)


############
# Biophysical detail arrow
############
# Black arrow keywords; also reused by the arrows drawn further below.
kw = dict(arrowstyle=style, color="k")
# Straight vertical arrow on the left, pointing from the bottom row upwards.
arrow = mpatches.FancyArrowPatch((-400, -800), (-400, 500),
                                 connectionstyle="arc3,rad=-0.", **kw)
ax.add_patch(arrow)
ax.text(-410, -150, 'biophysical detail',
        rotation='vertical', ha='right', va='center', fontsize=18)


#################
# Bridging scales
##################
# One pair of opposing curved arrows plus a caption per gap between levels:
# first MC<->Point neurons, then Point-neurons <-> Mass models.
for z_low, z_high in ((-375, -250), (-725, -600)):
    arrow_ends = (
        ((5, z_low), (5, z_high)),     # upward arrow
        ((-5, z_high), (-5, z_low)),   # downward arrow
    )
    for start, end in arrow_ends:
        arrow = mpatches.FancyArrowPatch(start, end,
                                         connectionstyle="arc3,rad=0.2", **kw)
        ax.add_patch(arrow)
    # caption vertically centred between the two arrow ends
    ax.text(30, (z_low + z_high) / 2, 'parameter\nmapping',
            fontsize=16, ha='left', va='center')



############################
# Predictions
############################
x = 800
# Each section is (bold header or None, entries, z of the first entry, extra
# text keywords). Entries are stacked upwards in steps of 50; a header is
# placed one step above the last entry of its own list and 50 units to the
# left.
sections = [
    # MC + VC networks
    ('multicompartment-neuron\nnetworks:',
     ['extracellular potentials',
      'ECoG',
      'current dipole moment',
      'EEG',
      'MEG'],
     200, dict(color='C2')),
    (None,
     ['spikes',
      'membrane potentials',
      'axial currents',
      'transmembrane currents'],
     0, dict(color='k')),
    # Point neuron networks
    ('point-neuron networks:',
     ['spikes',
      "'soma' potentials"],
     -400, {}),
    # Neural mass models
    ('population models:',
     ['averaged spike rates',
      "averaged 'soma' potentials"],
     -950, {}),
]
for header, signals, z_first, text_kw in sections:
    if header is not None:
        ax.text(x-50, len(signals) * 50 + z_first, header, fontsize=18, weight='bold')
    for i, signal in enumerate(signals):
        ax.text(x, i * 50 + z_first, '- ' + signal, fontsize=18, **text_kw)


##########
# Signals
##########

# Soma potentials
# Plot the last 1601 samples of the first 8 'E' rows and the first 2 'I' rows
# of somav.h5. Rows are stacked 25 units apart, starting at z_ for each
# population.
somav_path = os.path.join(OUTPUTPATH, 'somav.h5')
rows_shown = [8, 2]
z_first = [150, -50]
for j, (label, N, z_) in enumerate(zip(labels, rows_shown, z_first)):
    with h5py.File(somav_path, 'r') as f:
        V = f[label][:N, -1601:]

    for i, x in enumerate(V):
        z_trace = x - i * 25 + z_
        ax.plot(t + 200, z_trace, colors[j], lw=1)
ax.text(550, 150, r'$V_{\mathrm{m}j}(t)$', fontsize=16, ha='center', va='center')
        
        
# spike trains
# Time window shown: the last 100 time units before 'tstop'.
tstop = params.networkParameters['tstop']
T = [tstop - 100, tstop]
with h5py.File(os.path.join(OUTPUTPATH, 'spikes.h5'), 'r') as f:
    # Use at most 900 'E' and 128 'I' units; sgn flips the vertical direction
    # for the second population.
    for i, (y, N_y, sgn) in enumerate(zip(labels, [900, 128], [1, -1])):
        times = []
        gids = []

        # flatten the per-unit spike times and repeat each unit's gid once
        # per spike
        for g, t_ in zip(f[y]['gids'][()][:N_y], f[y]['times'][:N_y]):
            times = np.r_[times, t_]
            gids = np.r_[gids, np.zeros(t_.size) + g]

        # keep the spikes inside the window, move them to start at 0 and
        # stretch the time axis by a factor of 3
        in_window = (times >= T[0]) & (times <= T[1])
        t_ = (times[in_window] - T[0]) * 3
        # NOTE(review): np.argsort returns sort indices, not per-unit ranks,
        # so the vertical coordinate is a permutation of 0..n_spikes-1 rather
        # than a unit index - confirm this is the intended illustration.
        row = np.argsort(gids[in_window])*sgn - 500
        ax.plot(t_ + 400, row, '|', color=colors[i], ms=5)
ax.text(550, -250, r'$s_j(t)$', fontsize=16, ha='center', va='center')

# population averaged Vsoma
# Average the last 101 samples over all rows of each population, remove the
# temporal mean, scale by 20 and draw the populations 50 units apart. The
# time axis is subsampled (every 16th point) for these shorter traces.
t_coarse = t[::16] + 200
for j, label in enumerate(labels):
    with h5py.File(os.path.join(OUTPUTPATH, 'somav.h5'), 'r') as f:
        V = f[label][:, -101:].mean(axis=0)
    V = V - V.mean()
    ax.plot(t_coarse, V*20 - j * 50 - 900, colors[j], lw=1)
ax.text(550, -850, r'$V_{\mathrm{m}X}(t)$', fontsize=16, ha='center', va='center')

    
# spike rates
# Gaussian smoothing kernel (161 taps, standard deviation of 16 taps),
# normalised to unit sum.
w = ss.windows.gaussian(161, 16)
w /= w.sum()
Delta_t = 2**-4
# Histogram bin edges spanning the window T with spacing Delta_t. The edge
# count is computed from plain scalars: calling int() on the one-element
# array returned by np.diff(T) is deprecated in recent NumPy releases.
n_edges = int((T[1] - T[0]) / Delta_t + 1)
bins = np.linspace(T[0], T[1], n_edges)
with h5py.File(os.path.join(OUTPUTPATH, 'spikes.h5'), 'r') as f:
    for j, Y in enumerate(labels):
        times = []

        # pool the spike times of every unit in the population
        for t_ in f[Y]['times']:
            times = np.r_[times, t_]

        ii = (times >= T[0]) & (times <= T[1])
        hist, _ = np.histogram(times[ii], bins=bins)

        # smoothen rate profiles
        hist = np.convolve(hist, w, 'same')
        # remove the mean so that only the fluctuations are drawn
        hist = hist - hist.mean()
        ax.plot(t[:-1] + 200, hist*50 - j * 50 - 700, color=colors[j], lw=1)
ax.text(550, -650, r'$\nu_X(t)$', fontsize=16, ha='center', va='center')


# BLACK BOX (LFPykernels)
# Filled black square with its caption centred on it.
# NOTE(review): the caption colour is not set explicitly; its legibility on
# the black face depends on the active text colour / path effects - confirm
# in the rendered figure.
box_corner = (840, -710)
box_side = 170
blackbox = mpatches.Rectangle(box_corner, box_side, box_side,
                              edgecolor='k', facecolor='k', lw=3)
ax.text(925, -625, 'black\nbox', fontsize=24, va='center', ha='center')
ax.add_patch(blackbox)

        
        
# MOAR arrows
# Every arrow as (start, end, connectionstyle), listed in drawing order. All
# of them share the black arrow keywords stored in `kw`.
arrow_specs = [
    ((275, 0), (375, 0), "arc3,rad=0."),            # MC -> V_soma
    ((275, -500), (375, -50), "arc3,rad=0."),       # Point -> V_soma
    ((275, 0), (375, -450), "arc3,rad=0."),         # MC -> spikes
    ((275, -500), (375, -500), "arc3,rad=0."),      # Point -> spikes
    ((325, -850), (375, -775), "arc3,rad=0."),      # Pop -> rates
    ((325, -850), (375, -925), "arc3,rad=0."),      # Pop -> V_soma
    ((725, -500), (825, -615), "arc3,rad=0.2"),     # spikes->blackbox
    ((725, -700), (825, -635), "arc3,rad=-0.2"),    # rates->blackbox
    ((1025, -625), (1125, -625), "arc3,rad=0."),    # blackbox->approximations
]
for start, end, connectionstyle in arrow_specs:
    arrow = mpatches.FancyArrowPatch(start, end,
                                     connectionstyle=connectionstyle, **kw)
    ax.add_patch(arrow)
        
# Make sure the output folder exists, then write the figure as a tightly
# cropped PDF. exist_ok=True replaces the separate isdir() check, so there is
# no window between checking for and creating the folder.
os.makedirs('figures', exist_ok=True)
fig.savefig(os.path.join('figures', 'figure01.pdf'), bbox_inches='tight')