# Spatial variation of extracellular spikes from the Hay model

In [None]:
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.collections import PolyCollection
import LFPy
from brainsignals.plotting_convention import mark_subplots, simplify_axes
from brainsignals.neural_simulations import return_hay_cell
import brainsignals.neural_simulations as ns

# Seed NumPy's legacy global RNG so that any code drawing from it in this
# notebook produces reproducible results.
RANDOM_SEED = 12345
np.random.seed(RANDOM_SEED)

In [None]:
def insert_current_stimuli(cell, *, amp=-0.4, idx=0, dur=1e9, delay=0):
    """Attach a single ``ISyn`` point-process current stimulus to ``cell``.

    The keyword defaults reproduce the previously hard-coded stimulus, so
    ``insert_current_stimuli(cell)`` behaves exactly as before; the values are
    exposed so other amplitudes / locations can be tried without editing the
    function.

    Parameters
    ----------
    cell : LFPy cell
        Cell to stimulate (in this notebook, the one from ``return_hay_cell``).
    amp : float
        Amplitude passed to the ``ISyn`` mechanism (nA in LFPy's convention).
        NOTE(review): the sign convention of ``ISyn`` is not visible here --
        confirm whether a negative value is depolarizing.
    idx : int
        Index of the compartment that receives the stimulus (presumably the
        soma for index 0 -- TODO confirm for this morphology).
    dur : float
        Stimulus duration in ms; the default of 1e9 keeps the current on for
        any realistic simulation length, i.e. a constant current.
    delay : float
        Stimulus onset time in ms.

    Returns
    -------
    synapse : LFPy.StimIntElectrode
        The stimulus object. Keep a reference to it while simulating
        (conservative: NEURON objects without references may be freed --
        TODO confirm whether LFPy already keeps its own reference).
    cell
        The same ``cell`` object that was passed in.
    """
    stim_params = {'amp': amp,
                   'idx': idx,
                   'pptype': "ISyn",
                   'dur': dur,
                   'delay': delay}
    # Load the NEURON mechanisms from the cell-models folder before creating
    # the point process (presumably this provides ``ISyn`` -- TODO confirm).
    ns.load_mechs_from_folder(ns.cell_models_folder)
    synapse = LFPy.StimIntElectrode(cell, **stim_params)
    return synapse, cell

In [None]:
# Simulation end time and time step (ms, LFPy's convention). The step is a
# power of two (1/64 == 2**-6), which is exactly representable as a float.
tstop = 150
dt = 1 / 64
cell = return_hay_cell(tstop=tstop, dt=dt, make_passive=False)

# In the original morphology the axon points up and to the left; reorient it
# downwards purely so the figures look cleaner.
ns.point_axon_down(cell)

# ``syn`` keeps the stimulus object referenced for the rest of the notebook.
syn, cell = insert_current_stimuli(cell)

# Record membrane currents (used below for the extracellular potentials) and
# membrane potentials.
cell.simulate(rec_vmem=True, rec_imem=True)

In [None]:
# Regular grid of recording sites in the x-z plane (y = 0), spaced dx and dz
# apart (µm) and covering [xmin, xmax] x [zmin, zmax] with both ends included.
xmin, xmax = -50, 70
zmin, zmax = -50, 130
dx = dz = 20

x_grid, z_grid = np.mgrid[xmin:xmax + dx:dx, zmin:zmax + dz:dz]
num_elecs = x_grid.size

# Keyword arguments for LFPy.RecExtElectrode.
elec_grid_params = {
    "sigma": 0.3,  # extracellular conductivity (S/m)
    "x": x_grid.flatten(),
    "y": np.zeros(num_elecs),
    "z": z_grid.flatten(),
    "method": "root_as_point",
}

# Sites on the row z = 10 with x > 0. Their spikes are drawn in colour and
# compared against each other in the figure.
eap_idxs = np.flatnonzero(
    (np.abs(elec_grid_params["z"] - 10) < 1e-9) & (elec_grid_params["x"] > 0)
)

# One shade of the reversed Reds colormap per highlighted site, keyed by the
# electrode index (darkest for the first site, i.e. the smallest x).
n_highlight = len(eap_idxs)
eap_clrs = {
    site: plt.cm.Reds_r(rank / n_highlight)
    for rank, site in enumerate(eap_idxs)
}

# Time window (ms) containing the spike to analyse. Each bound snaps to the
# nearest simulated time step; the end index is exclusive.
t0 = 118
t1 = 123
t0_idx, t1_idx = [np.argmin(np.abs(cell.tvec - t)) for t in (t0, t1)]
window = slice(t0_idx, t1_idx)

# Extracellular potential at every grid site: the transformation matrix maps
# membrane currents to potentials (mV), and the factor 1000 converts to µV.
elec = LFPy.RecExtElectrode(cell, **elec_grid_params)
M_elec = elec.get_transformation_matrix()
eaps = 1000 * (M_elec @ cell.imem[:, window])

# Membrane potentials and a time axis for the same window, with time shifted
# so that the window starts at 0.
vmem = cell.vmem[:, window]
tvec = cell.tvec[window] - cell.tvec[t0_idx]

In [None]:
plt.close("all")
fig = plt.figure(figsize=[6, 5])

# Panel A: morphology, electrode grid and the spike recorded at every site.
ax_morph = fig.add_axes([0.01, 0.01, 0.65, 0.98], frameon=False, aspect=1,
                        xticks=[], yticks=[], xlim=[xmin - 5, xmax + 15],
                        ylim=[zmin - 10, zmax + 5])

# Panel B: membrane potential of compartment 0.
ax_vm = fig.add_axes([0.69, 0.51, 0.3, 0.4], title="membrane\npotential",
                      frameon=False, xticks=[])

# Panel C: peak-normalized spikes at the highlighted sites.
ax_eap = fig.add_axes([0.69, 0.02, 0.3, 0.4], title="normalized spikes",
                      frameon=False, xticks=[], yticks=[])

# --- Panel C: each highlighted spike scaled to unit peak amplitude ---------
# Iterate in reverse (largest x first) so the "x=... µm" labels stack upwards
# starting from the farthest site.
for n, elec_idx in enumerate(eap_idxs[::-1]):
    c = eap_clrs[elec_idx]
    eap_ = eaps[elec_idx] - eaps[elec_idx, 0]  # remove offset at window start
    eap_normalized = eap_ / np.max(np.abs(eap_))
    ax_eap.plot(tvec, eap_normalized, c=c, lw=1.5)
    x = int(elec_grid_params["x"][elec_idx])
    ax_eap.text(2.5, -0.7 + n * 0.12, "x={:d} µm".format(x), c=c)

# --- Panel A: morphology as grey polygons in the x-z plane -----------------
zips = [list(zip(x, z)) for x, z in cell.get_pt3d_polygons()]
polycol = PolyCollection(zips, edgecolors='none',
                         facecolors='0.8', zorder=-1, rasterized=False)
ax_morph.add_collection(polycol)

# Shared scaling for the grid traces: the largest |EAP| anywhere spans 90 % of
# the vertical site spacing, and the time window spans 70 % of dz horizontally.
eap_norm = dz * 0.9 / np.max(np.abs(eaps))
t_norm = tvec / tvec[-1] * dz * 0.7
for elec_idx in range(num_elecs):
    # Highlighted sites keep their panel-C colour; all others are black.
    c = eap_clrs.get(elec_idx, 'k')
    x, z = elec.x[elec_idx], elec.z[elec_idx]
    ax_morph.plot(x, z, '.', c='k', ms=7)
    eap = (eaps[elec_idx] - eaps[elec_idx, 0]) * eap_norm
    ax_morph.plot(x + t_norm, z + eap, c=c, lw=1.5)

# --- Panel B ----------------------------------------------------------------
ax_vm.plot(tvec, vmem[0, :], c='k', lw=1.5)

# 50 mV vertical scale bar.
ax_vm.plot([3, 3], [-50, 0], c='k', lw=2)
ax_vm.text(3.2, -25, "50\nmV")

# Initial potential, truncated to an integer, used as the single y tick.
v0 = int(vmem[0, 0])

# 20 µm length scale bar on the morphology.
ax_morph.plot([20, 40], [15, 15], c='gray', lw=2)
ax_morph.text(30, 17, "20 µm", ha="center", c='gray')

# 300 µV amplitude scale bar for the grid traces, using the same eap_norm.
ax_morph.plot([82, 82], [-10 - 300 * eap_norm, -10], c='k', lw=2,
              clip_on=False)
ax_morph.text(79, -10 - 300 * eap_norm / 2, "300 µV", ha="right",
              c='k', va="center")

# 1 ms time scale bars for panels B and C.
ax_vm.plot([1.5, 2.5], [v0, v0], c='k', lw=2)
ax_vm.text(2., v0 - 2, "1 ms", va="top", ha='center')

ax_eap.plot([2.5, 3.5], [-0.85, -0.85], c='k', lw=2)
ax_eap.text(3., -0.9, "1 ms", va="top", ha='center')

ax_vm.set_yticks([v0])
ax_vm.set_yticklabels(["{:d}\nmV".format(v0)])

mark_subplots(fig.axes[0], 'A', ypos=0.98, xpos=0.0)
# Only two axes follow panel A, so presumably only 'B' and 'C' get used.
mark_subplots(fig.axes[1:], 'BCDE', xpos=0.05, ypos=1.15)
simplify_axes(fig.axes)

fig.savefig("fig_hay_eap_simpler.pdf")