# Comparison of spikes from different neuron models

In [None]:
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.collections import PolyCollection
import LFPy
import neuron
from neuron import h
import brainsignals.neural_simulations as ns
from brainsignals.plotting_convention import mark_subplots
from brainsignals.neural_simulations import return_hay_cell

# Load the NEURON mechanisms found in the project's cell-model folder and
# seed NumPy's global RNG so that runs are reproducible.
ns.load_mechs_from_folder(ns.cell_models_folder)
np.random.seed(12345)

# Simulation length and time step in ms (2**-6 ms = 0.015625 ms)
tstop, dt = 150, 2**-6

# Time window (ms) from which the spike is extracted
t0, t1 = 118, 123

In [None]:
def insert_current_stimuli(cell):
    """Attach a constant current stimulus to compartment 0 of ``cell``.

    The stimulus is an ``LFPy.StimIntElectrode`` of point-process type
    ``"ISyn"`` with amplitude -0.4. It starts at t = 0 and has a duration of
    1e9, i.e. it stays on for the whole simulation.

    Returns
    -------
    tuple
        ``(stimulus, cell)``: the created stimulus object and the same cell.
        The caller keeps the stimulus object (presumably so it stays alive
        during the simulation; confirm against LFPy/NEURON lifetime rules).
    """
    stimulus = LFPy.StimIntElectrode(
        cell,
        amp=-0.4,
        idx=0,
        pptype="ISyn",
        dur=1e9,
        delay=0,
    )
    return stimulus, cell

In [None]:
def return_electrode_grid():
    """Return keyword arguments for an ``LFPy.RecExtElectrode`` grid.

    The electrodes lie in the y = 0 plane on a regular grid with 20 µm
    spacing. x runs from -50 to 70 and z from -50 to 130, both inclusive.
    In the flattened coordinate arrays, z varies fastest.

    Returns
    -------
    dict
        ``sigma`` (extracellular conductivity, 0.3), the electrode coordinate
        arrays ``x``, ``y`` and ``z``, and ``method`` ("linesource").
    """
    spacing = 20
    x_vals = np.arange(-50, 70 + spacing, spacing)
    z_vals = np.arange(-50, 130 + spacing, spacing)
    # "ij" indexing matches np.mgrid: first axis is x, second is z
    x_grid, z_grid = np.meshgrid(x_vals, z_vals, indexing="ij")
    x_flat = x_grid.ravel()
    z_flat = z_grid.ravel()
    return {
        "sigma": 0.3,  # extracellular conductivity
        "x": x_flat,
        "y": np.zeros(x_flat.size),
        "z": z_flat,
        "method": "linesource",
    }



def plot_results(cell, figname, figtitle, subplot_marker):
    """Plot extracellular spike waveforms around the cell morphology.

    The extracellular potentials are computed on the grid from
    ``return_electrode_grid`` using the cell's recorded membrane currents.
    The figure has two panels:

    * top: the morphology (grey polygons) with every electrode's waveform
      drawn next to its position, plus length and amplitude scale bars;
    * bottom: baseline-subtracted, peak-normalized waveforms for the
      electrodes in the z = -10 µm row with x > 0.

    The figure is saved as ``<figname>.pdf``.

    Parameters
    ----------
    cell : LFPy.Cell
        Simulated cell with ``imem`` recorded and ``tvec`` set.
    figname : str
        Output file name without the ``.pdf`` extension.
    figtitle : str
        Figure suptitle.
    subplot_marker : str
        Panel letter passed to ``mark_subplots``.

    Returns
    -------
    matplotlib.figure.Figure
        The created figure.
    """
    elec_grid_params = return_electrode_grid()
    elec_x = elec_grid_params["x"]
    elec_z = elec_grid_params["z"]

    elec = LFPy.RecExtElectrode(cell, **elec_grid_params)
    M_elec = elec.get_transformation_matrix()
    # Factor 1000: the "300 µV" scale bar below assumes eaps are in µV
    eaps = M_elec @ cell.imem * 1000

    xmin, xmax = np.min(elec_x), np.max(elec_x)
    zmin, zmax = np.min(elec_z), np.max(elec_z)

    # Electrodes in the z = -10 µm row with x > 0. Their traces are shown
    # normalized in the bottom panel and colour-coded in both panels.
    eap_idxs = np.where((np.abs(elec_z + 10) < 1e-9) & (elec_x > 0))[0]
    eap_clrs = {idx: plt.cm.Reds_r(num / len(eap_idxs))
                for num, idx in enumerate(eap_idxs)}

    fig = plt.figure(figsize=[2, 3.5])
    fig.suptitle(figtitle)
    ax_morph = fig.add_axes([0.05, 0.22, 0.9, 0.68], frameon=False,
                            aspect=1,
                            xticks=[], yticks=[],
                            xlim=[xmin - 5, xmax + 10],
                            ylim=[zmin - 10, zmax + 5])
    ax_eap = fig.add_axes([0.1, 0.01, 0.85, 0.2],
                          frameon=False, xticks=[], yticks=[],
                          ylim=[-1.05, 0.5])

    # Bottom panel: baseline-subtracted, peak-normalized waveforms
    for n, elec_idx in enumerate(eap_idxs[::-1]):
        c = eap_clrs[elec_idx]
        trace = eaps[elec_idx]
        peak = np.max(np.abs(trace))
        if peak > 0:
            trace_norm = (trace - trace[0]) / peak
        else:
            # All-zero trace: avoid 0/0 (NaN) and draw a flat line instead
            trace_norm = np.zeros_like(trace)
        ax_eap.plot(cell.tvec, trace_norm, c=c, lw=1, ls='-')
        x = int(elec_x[elec_idx])
        ax_eap.text(2.5, -1.05 + n * 0.25, "x={:d} µm".format(x), c=c)

    # Top panel: morphology as grey polygons in the x-z plane
    zips = [list(zip(x, z)) for x, z in cell.get_pt3d_polygons()]
    polycol = PolyCollection(zips, edgecolors='none',
                             facecolors='0.8', zorder=-1, rasterized=False)
    ax_morph.add_collection(polycol)

    # Draw each electrode's waveform next to its position. The largest |EAP|
    # over all electrodes spans 0.9 of the z grid spacing, and the time axis
    # spans 0.7 of it. The spacing is derived from the distinct z values, so
    # it does not depend on the grid's flattening order.
    dz = np.min(np.diff(np.unique(elec.z)))
    eap_scale = dz * 0.9 / np.max(np.abs(eaps))
    t_norm = cell.tvec / cell.tvec[-1] * dz * 0.7
    for elec_idx in range(len(elec.x)):
        c = eap_clrs[elec_idx] if elec_idx in eap_idxs else 'k'
        x, z = elec.x[elec_idx], elec.z[elec_idx]
        ax_morph.plot(x, z, '.', c='k', ms=3)
        ax_morph.plot(x + t_norm, z + eaps[elec_idx] * eap_scale, c=c, lw=1)

    # 20 µm length scale bar
    ax_morph.plot([20, 40], [15, 15], c='gray', lw=1)
    ax_morph.text(30, 17, "20 µm", ha="center", c='gray')

    # 300 µV amplitude scale bar centred at z = 10, placed beyond the x-limits
    half_bar = 300 * eap_scale / 2
    ax_morph.plot([82, 82], [10 - half_bar, 10 + half_bar], c='k', lw=1,
                  clip_on=False)
    ax_morph.text(79, 0, "300 µV", ha="right",
                  c='k', va="center")

    mark_subplots(ax_morph, subplot_marker, xpos=-0.04, ypos=1.11)

    ax_eap.text(-0.8, 0.25, "normalized spikes")
    # 1 ms time scale bar, label centred on the bar
    ax_eap.plot([0., 1.], [-0.55, -0.55], c='k', lw=1)
    ax_eap.text(0.5, - 0.6, "1 ms", va="top", ha='center')

    fig.savefig("{}.pdf".format(figname))
    return fig

## First, we run the original simulation with active conductances. The somatic membrane potential from this simulation will be replayed into simpler models.

In [None]:
# Run the original, fully active Hay model with a constant current stimulus
# at compartment idx 0, the compartment whose potential is saved below as
# the somatic trace.
cell = return_hay_cell(tstop=tstop, dt=dt, make_passive=False)
ns.point_axon_down(cell)
syn, cell = insert_current_stimuli(cell)
cell.simulate(rec_imem=True, rec_vmem=True)

# Indices of the samples closest to the window edges t0 and t1
t0_idx = np.argmin(np.abs(cell.tvec - t0))
t1_idx = np.argmin(np.abs(cell.tvec - t1))

# Crop the recordings to the window and shift time so the window starts at 0
cell.vmem = cell.vmem[:, t0_idx:t1_idx]
cell.imem = cell.imem[:, t0_idx:t1_idx]
cell.tvec = cell.tvec[t0_idx:t1_idx] - cell.tvec[t0_idx]

# somatic_vmem.npy holds (time, somatic vmem) and is replayed in later cells
np.save("somatic_vmem.npy", [cell.tvec, cell.vmem[0, :]])
np.save("imem_orig.npy", cell.imem)
np.save("vmem_orig.npy", cell.vmem)

fig = plt.figure(figsize=[1.1, 1.5])
fig.subplots_adjust(left=0.45, top=0.8, right=0.95, bottom=0.05)

ax_vm = fig.add_subplot(111, title="membrane\npotential",
                        frameon=False, xticks=[], xlim=[0, 5])
ax_vm.plot(cell.tvec, cell.vmem[0, :], c='k', lw=1)
# 50 mV vertical scale bar, drawn outside the axes
ax_vm.plot([-0.3, -0.3], [-40, 10], c='k', lw=1, clip_on=False)
ax_vm.text(-0.5, -20, "50 mV", ha='right')
mark_subplots(ax_vm, "A", xpos=-0.2)

v0 = int(cell.vmem[0, 0])

# 1 ms horizontal scale bar just below the initial potential. The label is
# centred on the bar's midpoint (x = 2.0), matching plot_results.
ax_vm.plot([1.5, 2.5], [v0 - 1, v0 - 1], c='k', lw=1)
ax_vm.text(2.0, v0 - 2, "1 ms", va="top", ha='center')
ax_vm.set_yticks([v0])
ax_vm.set_yticklabels(["{:d} mV".format(v0)])
fig.savefig("somatic_vmem.pdf")
plt.close("all")

plot_results(cell, "fig_hay_orig_spike", "original active", 'A')
cell.__del__()

## Then we replay the simulated somatic membrane potential into the somas of different models. We start with the passive Hay model (the Hay model with its active mechanisms removed).

In [None]:
# Load the cropped (time, somatic vmem) trace saved by the original simulation
soma_t, soma_vmem = np.load("somatic_vmem.npy")

# Hay morphology, simulated for the length of the cropped window, with the
# listed active mechanisms removed. Presumably only passive membrane
# properties remain (TODO confirm against ns.remove_active_mechanisms).
cell = return_hay_cell(tstop=soma_t[-1], dt=dt, make_passive=False)
ns.point_axon_down(cell)
remove_list = ["Nap_Et2", "NaTa_t", "NaTs2_t", "SKv3_1",
               "SK_E2", "K_Tst", "K_Pst",
               "Im", "Ih", "CaDynamics_E2", "Ca_LVAst", "Ca", "Ca_HVA"]
cell = ns.remove_active_mechanisms(remove_list, cell)
h.dt = dt

for sec in neuron.h.allsec():
    if "soma" in sec.name():
        # Print the soma's passive parameters and geometry for reference
        print("g_pas: {}, e_pas: {}, cm: {}, "
              "Ra: {}, soma_diam: {}, soma_L: {}".format(sec.g_pas,
                                                         sec.e_pas, sec.cm,
                                                         sec.Ra, sec.diam,
                                                         sec.L))
        print("Inserting vclamp")
        # Clamp at the soma centre with a tiny series resistance so the soma
        # follows the replayed trace. SEClamp_i is not a built-in NEURON
        # mechanism; it is presumably a SEClamp variant loaded from the
        # project's mechanisms folder (confirm).
        vclamp = h.SEClamp_i(sec(0.5))
        vclamp.dur1 = 1e9  # keep the clamp on for the whole simulation
        vclamp.rs = 1e-9
        # NOTE(review): this replay drops the first sample (soma_vmem[1:]),
        # while the ball-and-stick and two-compartment replays below play
        # the full soma_vmem. Presumably a one-time-step alignment
        # correction; confirm which is intended and make the replays
        # consistent.
        vmem_to_insert = h.Vector(soma_vmem[1:])
        # Assign successive trace values to the clamp amplitude every h.dt
        vmem_to_insert.play(vclamp._ref_amp1, h.dt)

cell.simulate(rec_imem=True, rec_vmem=True)

plot_results(cell, "fig_hay_replay_spike", "reconstructed neuron", "A")
cell.__del__()

## Now we try a ball-and-stick neuron

In [None]:
# Ball-and-stick model (apic_diam=4), driven by replaying the recorded
# somatic potential through a voltage clamp at the soma.
cell = ns.return_ball_and_stick_cell(soma_t[-1], dt, apic_diam=4)

# Give every section the same passive parameters as the Hay model
hay_passive_params = {"g_pas": 3.38e-05, "e_pas": -90, "cm": 1.0, "Ra": 100}
for sec in neuron.h.allsec():
    for param_name, param_value in hay_passive_params.items():
        setattr(sec, param_name, param_value)
h.dt = dt

for sec in neuron.h.allsec():
    if "soma" not in sec.name():
        continue
    print("Inserting vclamp")
    # NOTE(review): the full soma_vmem trace is played here, whereas the
    # Hay-model replay plays soma_vmem[1:]. Confirm which alignment is
    # intended and make the replays consistent.
    vclamp = h.SEClamp_i(sec(0.5))
    vclamp.dur1 = 1e9  # keep the clamp on for the whole simulation
    vclamp.rs = 1e-9
    vmem_to_insert = h.Vector(soma_vmem)
    # Assign successive trace values to the clamp amplitude every h.dt
    vmem_to_insert.play(vclamp._ref_amp1, h.dt)

cell.simulate(rec_imem=True, rec_vmem=True)
plot_results(cell, "fig_ball_and_stick_replay_spike", "ball-and-stick", "B")
cell.__del__()

## Finally, we try a two-compartment neuron

In [None]:
# Two-compartment model, driven by replaying the recorded somatic potential
# through a voltage clamp at the soma.
cell = ns.return_two_comp_cell(soma_t[-1], dt)
h.dt = dt

# Give every section the same passive parameters as the Hay model
hay_passive_params = {"g_pas": 3.38e-05, "e_pas": -90, "cm": 1.0, "Ra": 100}
for sec in neuron.h.allsec():
    for param_name, param_value in hay_passive_params.items():
        setattr(sec, param_name, param_value)

for sec in neuron.h.allsec():
    if "soma" not in sec.name():
        continue
    print("Inserting vclamp")
    # NOTE(review): the full soma_vmem trace is played here, whereas the
    # Hay-model replay plays soma_vmem[1:]. Confirm which alignment is
    # intended and make the replays consistent.
    vclamp = h.SEClamp_i(sec(0.5))
    vclamp.dur1 = 1e9  # keep the clamp on for the whole simulation
    vclamp.rs = 1e-9
    vmem_to_insert = h.Vector(soma_vmem)
    # Assign successive trace values to the clamp amplitude every h.dt
    vmem_to_insert.play(vclamp._ref_amp1, h.dt)

cell.simulate(rec_imem=True, rec_vmem=True)

plot_results(cell, "fig_two_comp_replay_spike", "two-compartment", "C")
cell.__del__()