# Calculation of the LFP shape function for the Hay model

In [None]:
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
import matplotlib
from elephant.spike_train_generation import homogeneous_poisson_process
from quantities import Hz, ms
import LFPy
import brainsignals.neural_simulations as ns
from brainsignals.plotting_convention import mark_subplots, simplify_axes

# Load the NEURON mechanisms found in the cell-models folder (the helper lives
# in brainsignals.neural_simulations and is not visible here; presumably it
# compiles/loads the .mod files the Hay model and 'ExpSynI' synapses need).
ns.load_mechs_from_folder(ns.cell_models_folder)
# Seed NumPy's global RNG for reproducibility. NOTE(review): this only fixes
# the Poisson spike trains and random synapse placement drawn later if those
# helpers draw from np.random's global state — confirm.
np.random.seed(12345)

sigma = 0.3  # S/m
dists = np.logspace(1, 4, 50)  # 50 radial distances from 10 to 10^4
angles = np.linspace(0, 2 * np.pi, 61)[:-1]  # 60 azimuths, 2*pi excluded
heights = np.linspace(-800, 1200, 6)  # z-levels of the electrode planes

# The electrode grid is a stack of rings: for every height, for every radial
# distance, one contact per angle. Building it by broadcasting and flattening
# in C order reproduces that (height, dist, angle) nesting exactly, so the
# contacts of one ring occupy a contiguous run of len(angles) indices.
num_angles = len(angles)
ring_x = (dists[:, None] * np.cos(angles)).ravel()  # one height plane
ring_y = (dists[:, None] * np.sin(angles)).ravel()
elec_x = np.tile(ring_x, len(heights))
elec_y = np.tile(ring_y, len(heights))
elec_z = np.repeat(heights, len(dists) * num_angles)

# dists_idxs[height][i] -> indices of the contacts at radial distance dists[i]
# in the plane z = height. Derived from the known layout rather than by
# searching positions with a float tolerance (which recomputed every radius
# once per ring and relied on the tolerance being tight enough).
dists_idxs = {}
for d, height in enumerate(heights):
    plane_start = d * len(dists) * num_angles
    dists_idxs[height] = [
        plane_start + i * num_angles + np.arange(num_angles)
        for i in range(len(dists))
    ]


In [None]:

# Keyword arguments for LFPy.RecExtElectrode: every contact of the grid built
# above, in a homogeneous medium of conductivity `sigma`, with each segment
# treated as a line source.
grid_elec_params = dict(
    sigma=sigma,
    x=elec_x,
    y=elec_y,
    z=elec_z,
    method='linesource',
)

def insert_synaptic_input(idx, cell, spiketimes):
    """Attach one synapse to `cell` at segment `idx` and drive it with `spiketimes`.

    The synapse uses syntype 'ExpSynI' with weight 0.1 and tau 1, and its
    current is not recorded. NOTE(review): 'ExpSynI' is presumably a
    current-based exponential synapse from the loaded mod files, in which case
    tau is its (single) decay time constant rather than a rise time — confirm.

    Parameters
    ----------
    idx : int
        Index of the segment of `cell` that receives the synapse.
    cell : LFPy cell
        Cell to modify; the same object is handed back.
    spiketimes : array_like
        Activation times passed to ``Synapse.set_spike_times``.

    Returns
    -------
    tuple
        ``(synapse, cell)``: the new LFPy.Synapse and the input `cell`.
    """
    params = dict(weight=0.1,
                  record_current=False,
                  syntype='ExpSynI',
                  tau=1,
                  idx=idx)
    syn = LFPy.Synapse(cell, **params)
    syn.set_spike_times(spiketimes)
    return syn, cell



### Running all simulations

In [None]:
tstop = 1200
dt = 2**-4
num_syns = 1000
synrate = 5  # Hz
sig_cutoff = 200  # initial transient discarded before computing the LFP std
# Number of electrodes handled per matrix product when computing the LFP
# standard deviation. Materializing the full LFP would need
# n_electrodes x n_timesteps values (6*50*60 x 16001 doubles ~ 2.3 GB here,
# plus an equally large temporary inside np.std); chunking bounds peak memory.
elec_chunk_size = 1000

num_tsteps = int(tstop / dt + 1)
tvec = np.arange(num_tsteps) * dt

# One Poisson spike train per synapse. The same trains are reused for every
# input region, so the regions differ only in where the synapses sit.
spiketimes = []

for syn_idx in range(num_syns):
    times = homogeneous_poisson_process(synrate * Hz, t_stop=tvec[-1] * ms)
    spiketimes.append(np.array(times))

# Temporary cell, used only to draw synapse locations and as the geometry
# for the electrode below.
cell = ns.return_hay_cell(tstop=tstop, dt=dt, make_passive=True)

# z-ranges (same units as cell.z) from which synapses are drawn per region.
syn_ranges = {"basal": [np.min(cell.z), np.min(cell.z) + 500],
              "apical": [np.max(cell.z) - 500, np.max(cell.z)],
              "uniform": [np.min(cell.z), np.max(cell.z)],
              }

input_region_clrs = {"basal": 'b',
                     "apical": 'r',
                     "uniform": 'gray'}

syn_idxs_dict = {}
for input_region in syn_ranges.keys():
    syn_idxs_dict[input_region] = cell.get_rand_idx_area_norm(
        section='allsec', z_max=syn_ranges[input_region][1],
        z_min=syn_ranges[input_region][0],
        nidx=num_syns)

cell.__del__()

grid_electrode = LFPy.RecExtElectrode(cell, **grid_elec_params)
# The electrode stays bound to that first cell, so the currents-to-potential
# mapping is identical for every input region: compute it once. The factor
# 1000 is presumably mV -> µV (LFPy's matrix maps nA to mV) — confirm.
M = 1000 * grid_electrode.get_transformation_matrix()
LFPs = {}  # not populated below
shape_functions = {}
t0 = np.argmin(np.abs(tvec - sig_cutoff))

for input_region in syn_ranges.keys():
    syn_idxs = syn_idxs_dict[input_region]
    cell = ns.return_hay_cell(tstop=tstop, dt=dt, make_passive=True)
    for idx, syn_idx in enumerate(syn_idxs):
        syn, cell = insert_synaptic_input(syn_idx, cell, spiketimes[idx])
    cell.simulate(rec_imem=True)

    # Per-electrode standard deviation of the LFP after the transient,
    # computed chunk by chunk over electrodes instead of building the full
    # (n_electrodes, n_timesteps) LFP array.
    imem = np.ascontiguousarray(cell.imem[:, t0:])
    LFP_std = np.concatenate([
        np.std(M[start:start + elec_chunk_size] @ imem, axis=1)
        for start in range(0, M.shape[0], elec_chunk_size)
    ])

    # Shape function: LFP std averaged over all angles of each ring.
    shape_functions[input_region] = {
        height: [np.mean(LFP_std[ring_idxs]) for ring_idxs in dists_idxs[height]]
        for height in heights
    }

    cell.__del__()

tvec = tvec[t0:] - tvec[t0]


### Plotting results

In [None]:

plt.close("all")
fig = plt.figure(figsize=[6, 4])
fig.subplots_adjust(bottom=0.1, top=0.9, right=0.98,
                    left=0.4, wspace=0.3, hspace=0.3)

# Panel A (left): morphology, electrode planes and synaptic input ranges.
ax0 = fig.add_axes([0.0, 0.05, 0.3, 0.9], aspect=1, frameon=False,
                   xticks=[], yticks=[])

# `cell` is the last simulated cell; only its geometry arrays are used here.
ax0.plot(cell.x.T, cell.z.T, c='k', lw=1)

# 200 µm scale bar
ax0.plot([-200, 0], [-250, -250], c='k', lw=1)
ax0.text(-100, -270, "200 µm", va="top", ha="center")

# Dashed marker and label for each electrode plane.
for height in heights:
    ax0.plot([dists[0], dists[0] + 500], [height, height], c='cyan', lw=1, ls='--')
    ax0.text(200, height, "{:d} µm".format(int(height)), va="bottom")

# One vertical bar per input region showing its z-range, offset sideways so
# the bars do not overlap.
for i, input_region in enumerate(syn_ranges.keys()):
    ax0.plot([-200 - i * 25, -200 - i * 25], syn_ranges[input_region],
             c=input_region_clrs[input_region])

# Panel B (right): shape function vs radial distance, one subplot per plane,
# highest plane first.
h_axes = []
for d, height in enumerate(heights[::-1]):
    ax = fig.add_subplot(3, 2, d + 1, ylim=[1e-6, 2e0])
    ax.set_title("z={:d} µm".format(int(height)), y=0.9)
    h_axes.append(ax)
    for input_region in syn_ranges.keys():
        ax.loglog(dists, shape_functions[input_region][height],
                  c=input_region_clrs[input_region], lw=1.5)

    # Guide lines for two planes. np.isclose instead of == so the selection
    # does not depend on `heights` hitting these values exactly in floating
    # point.
    if np.isclose(height, -400):
        # Flat out to r = 300, then falling as 1/r^2.
        ax.plot([10, 300], [1e-2, 1e-2], c='k', lw=2, ls=':')
        ax.plot([300, 1e4], [1e-2 * 300**2 / 300**2,
                             1e-2 * 300**2 / (1e4)**2], c='k', lw=2, ls=':')
        ax.axvline(300, ls=':', lw=1, c='k')
    if np.isclose(height, 0):
        # Falling as 1/sqrt(r) out to r = 60, then as 1/r^2.
        ax.plot([10, 60], [np.sqrt(10) / np.sqrt(10), np.sqrt(10) / np.sqrt(60)],
                c='k', lw=2, ls=':')
        ax.plot([60, 1e4], [np.sqrt(10) / np.sqrt(60) * (60)**2 / (60)**2,
                            np.sqrt(10) / np.sqrt(60) * (60)**2 / (1e4)**2],
                c='k', lw=2, ls=':')
        ax.axvline(60, ls=':', lw=1, c='k')

    # Reference slope: a factor 100 drop over one decade, i.e. 1/r^2.
    ax.plot([1e3, 1e4], [5e-3, 5e-5], lw=2, c='k', ls='--')
    ax.text(3e3, 10e-4, "1/r$^2$")

    ax.set_xticks([10, 100, 1000, 10000])
    ax.set_yticks([1e-5, 1e-3, 1e-1])

    # A fresh locator per axis: matplotlib locators must not be shared
    # between axes.
    locmin = matplotlib.ticker.LogLocator(base=10.0, subs=(1, 2, 3, 4, 5, 6, 7, 8, 9, 10),
                                          numticks=12)
    ax.xaxis.set_minor_locator(locmin)
    # Only the bottom row (the last two of the 3x2 grid) gets x tick labels.
    if d < 4:
        ax.set_xticklabels([])
    else:
        ax.set_xlabel("radial distance (µm)")

    # Left column of the 3x2 grid gets the y label.
    if d % 2 == 0:
        ax.set_ylabel("shape function $F(r)$")
    ax.grid(True)

simplify_axes(fig.axes)
mark_subplots(fig.axes[1], "B", ypos=1.08, xpos=-0.1)
mark_subplots(fig.axes[0], "A", ypos=1.01, xpos=0.1)

fig.savefig("fig_LFP_spatial_decay.pdf")