# Contribution from dendritic calcium spikes to LFPs

In [None]:
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
import elephant
import LFPy
from brainsignals.plotting_convention import mark_subplots, simplify_axes, cmap_v_e
from brainsignals.neural_simulations import return_hay_cell, remove_active_mechanisms
import brainsignals.neural_simulations as ns

# Seed the global NumPy RNG so the random population below is reproducible.
np.random.seed(1234)

# Simulation duration and time step (ms). dt is a power of two, so
# tstop / dt is exact and each trace holds tstop / dt + 1 samples.
tstop = 100
dt = 2**-5
num_tsteps = int(tstop / dt) + 1

# Ca hot-zone, given as distance from soma (µm) (Hay et al. 2011).
ca_hotzone_range = [685, 885]

# Extracellular conductivity shared by all electrodes.
sigma = 0.3

# 2D grid of measurement points (µm) in the plane y = 0, 25 µm spacing.
grid_x, grid_z = np.mgrid[-650:651:25, -650:1101:25]
grid_y = np.zeros_like(grid_x, dtype=float)
grid_elec_params = dict(
    sigma=sigma,
    x=grid_x.flatten(),
    y=grid_y.flatten(),
    z=grid_z.flatten(),
)

# Linear electrode: num_elecs contacts on the z-axis, z = -200 ... 1200 µm.
num_elecs = 14
elec_params = dict(
    sigma=sigma,
    x=np.zeros(num_elecs),
    y=np.zeros(num_elecs),
    z=np.linspace(-200, 1200, num_elecs),
    method='root_as_point',
)
# Spacing between neighbouring contacts (uniform, since z comes from linspace).
dz = abs(np.diff(elec_params["z"][:2])[0])


In [None]:

# Population set-up: num_cells copies of one simulated cell, placed uniformly
# in a disc of radius pop_radius (x-y plane) with Gaussian height offsets,
# random rotations about z, and a Gaussian time jitter of the whole response.
num_cells = 100
pop_radius = 100
height_sd = 100
jitter_sd = 10

# sqrt of a uniform variate gives a uniform density over the disc area.
# NOTE: keep the order of the random draws below; it fixes the seeded layout.
rs = pop_radius * np.sqrt(np.random.rand(num_cells))
theta = np.random.uniform(0, 2 * np.pi, num_cells)
pop_xs = rs * np.cos(theta)
pop_ys = rs * np.sin(theta)
pop_zs = np.random.normal(0, height_sd, num_cells)
cell_rots = np.random.uniform(0, 2 * np.pi, num_cells)


def cell_clrs(idx):
    """Return a grey shade for population member `idx` (darker for larger idx)."""
    return plt.cm.Greys(0.4 + idx / num_cells * 0.6)


# Just for plotting convenience: cell 0 sits at a fixed, unrotated position.
pop_xs[0] = 0
pop_ys[0] = 50
pop_zs[0] = 0

cell_rots[0] = 0

# Per-cell time shift (ms) applied to the simulated response.
t_shift = np.random.normal(0, jitter_sd, num_cells)

weight = 0.1

sim_names = ["control", "without Ca"]
remove_mech_dict = {"control": [],
                    "without Ca": ["CaDynamics_E2",
                                   "Ca_LVAst", "Ca",
                                   "Ca_HVA"]}
vmem_dict = {}
tvec = None
LFP_dict = {sim_name: np.zeros((num_elecs, num_tsteps)) for sim_name in sim_names}
LFP_dict_ufilt = {sim_name: np.zeros((num_elecs, num_tsteps)) for sim_name in sim_names}

grid_LFP_dict = {}

# Low-pass settings for elephant.signal_processing.butter (forward-backward
# 'filtfilt'); fs is in Hz because dt is in ms. Loop-invariant, so built once.
filt_dict_low_pass = {'highpass_freq': None,
                      'lowpass_freq': 300,
                      'order': 4,
                      'filter_function': 'filtfilt',
                      'fs': 1 / dt * 1000,
                      'axis': -1}

for sim_name in sim_names:
    cell = return_hay_cell(tstop=tstop, dt=dt, make_passive=False)
    ns.point_axon_down(cell)

    remove_active_mechanisms(remove_mech_dict[sim_name], cell)
    cell.set_pos(x=pop_xs[0], y=pop_ys[0], z=pop_zs[0])
    cell.set_rotation(z=cell_rots[0])

    # Compartments whose membrane potential is plotted: the soma and the
    # compartment closest to z = 785 µm (centre of ca_hotzone_range).
    plot_idxs = [cell.somaidx[0], cell.get_closest_idx(z=785)]
    idx_clr = {idx: clr for idx, clr in zip(plot_idxs, ['b', 'orange'])}

    synapse = LFPy.Synapse(cell, idx=0,
                           syntype='Exp2Syn', weight=weight,
                           tau1=0.1, tau2=1.)
    synapse.set_spike_times(np.array([25, 30, 35]))

    cell.simulate(rec_imem=True, rec_vmem=True)
    print("MAX |I_mem(soma, apic)|: ", np.max(np.abs(cell.imem[plot_idxs]), axis=1))

    # Single-cell LFP on the 2D grid, relative to its value at t = 0.
    grid_electrode = LFPy.RecExtElectrode(cell, **grid_elec_params)
    grid_LFP = 1e3 * grid_electrode.get_transformation_matrix() @ cell.imem
    grid_LFP -= grid_LFP[:, 0, None]
    grid_LFP_dict[sim_name] = grid_LFP

    vmem_dict[sim_name] = cell.vmem.copy()
    tvec = cell.tvec.copy()

    # Population sum on a buffer three simulations long: num_tsteps samples
    # of zero padding on each side absorb the jitter shifts and filter edge
    # effects; only the middle third is kept.
    # (assumes len(cell.tvec) == num_tsteps)
    elec_LFP = np.zeros((num_elecs, num_tsteps * 3))

    morph_data = []

    for cell_idx in range(num_cells):
        # The one simulated cell is re-positioned for every population member,
        # so all members share the same membrane currents, shifted in time.
        cell.set_pos(x=pop_xs[cell_idx], y=pop_ys[cell_idx], z=pop_zs[cell_idx])
        # NOTE(review): LFPy's Cell.set_rotation appears to rotate the
        # *current* geometry, so these rotations accumulate over iterations.
        # The resulting orientations are still uniform on [0, 2π), but are not
        # cell_rots[cell_idx] itself (except for cell 0). Confirm if it matters.
        cell.set_rotation(z=cell_rots[cell_idx])
        elec = LFPy.RecExtElectrode(cell, **elec_params)

        t_shift_idx = int(t_shift[cell_idx] / dt)

        t0 = num_tsteps + t_shift_idx
        t1 = t0 + len(cell.tvec)

        # Remove this cell's own t = 0 offset before placing it in the
        # zero-padded buffer (same baseline convention as grid_LFP). Doing it
        # after summation cannot work: column 0 of elec_LFP is always padding.
        cell_LFP = 1e3 * elec.get_transformation_matrix() @ cell.imem
        cell_LFP -= cell_LFP[:, 0, None]
        elec_LFP[:, t0:t1] += cell_LFP

        morph_data.append({
            "cell_x": cell.x.copy(),
            "cell_z": cell.z.copy()
            })

    LFP_dict_ufilt[sim_name] = elec_LFP[:, num_tsteps:(2*num_tsteps)].copy()
    elec_LFP = elephant.signal_processing.butter(elec_LFP, **filt_dict_low_pass)

    LFP_dict[sim_name] = elec_LFP[:, num_tsteps:(2*num_tsteps)]

    del cell
    del synapse
    del elec


In [None]:
# Figure: membrane potentials, single-cell grid LFP snapshots (t1, t2), and
# unfiltered population LFP depth profiles for "control" vs. "without Ca".
grid_plot_times = [25.7, 38.5]  # ms; snapshot times t1 and t2

ylim = [-300, 1500]

fig = plt.figure(figsize=[6, 4])
fig.subplots_adjust(hspace=0.5, left=0.0, wspace=0.4, right=0.96,
                    top=0.97, bottom=0.17)

ax_m = fig.add_axes([0.0, 0.6, 0.12, 0.4], aspect=1, frameon=False,
                    xticks=[], yticks=[], ylim=ylim)
ax_pop = fig.add_axes([0.0, 0.08, 0.12, 0.4], aspect=1, frameon=False,
                    xticks=[], yticks=[], ylim=ylim)


ax_ca1_grid = fig.add_axes([0.63, 0.47, 0.17, 0.52], xlim=[-350, 350],
                           aspect=1, frameon=False, xticks=[], yticks=[])
ax_ca2_grid = fig.add_axes([0.75, 0.47, 0.17, 0.52], xlim=[-350, 350],
                           aspect=1, frameon=False, xticks=[], yticks=[])
ax_vm_ca = fig.add_axes([0.21, 0.6, 0.18, 0.35], xlabel="time (ms)", title="control",
                        ylim=[-80, 50], xlim=[0, 100])
ax_vm_nca = fig.add_axes([0.46, 0.6, 0.18, 0.35], xlabel="time (ms)", title="without Ca",
                         ylim=[-80, 50], xlim=[0, 100])

ax_ca1_grid.set_title("t1", y=0.93, color="gray")
ax_ca2_grid.set_title("t2", y=0.93, color="gray")

ax_ca = fig.add_axes([0.23, 0.08, 0.27, 0.37], ylim=ylim, title="control",
                     xlabel="time (ms)", ylabel="height (µm)", xlim=[20, 75])
ax_nca = fig.add_axes([0.6, 0.08, 0.27, 0.37], ylim=ylim, title="without Ca",
                      xlabel="time (ms)", ylabel="height (µm)", xlim=[20, 75])

t1_idx = np.argmin(np.abs(grid_plot_times[0] - tvec))
t2_idx = np.argmin(np.abs(grid_plot_times[1] - tvec))
grid_LFP_1 = grid_LFP_dict["control"][:, t1_idx].reshape(grid_x.shape)
grid_LFP_2 = grid_LFP_dict["control"][:, t2_idx].reshape(grid_x.shape)

# Symmetric, logarithmically spaced contour levels (µV) for the grid LFP.
grid_vmax = 10
num = 11
levels = np.logspace(-1.5, 0, num=num)

levels_norm = grid_vmax * np.concatenate((-levels[::-1], levels))

colors_from_map = [cmap_v_e(i / (len(levels_norm) - 2))
                   for i in range(len(levels_norm) - 1)]
# Bin num - 1 is the central interval (-levels[0], levels[0]) * grid_vmax: white.
colors_from_map[num - 1] = (1.0, 1.0, 1.0, 1.0)

ep_intervals = ax_ca1_grid.contourf(grid_x, grid_z, grid_LFP_1,
                                    zorder=-2, colors=colors_from_map,
                                    levels=levels_norm, extend='both')
ep_intervals = ax_ca2_grid.contourf(grid_x, grid_z, grid_LFP_2,
                                    zorder=-2, colors=colors_from_map,
                                    levels=levels_norm, extend='both')

cax = fig.add_axes([0.91, 0.53, 0.01, 0.4], frameon=False)
cbar = fig.colorbar(ep_intervals, cax=cax, label=r'$V_{\rm e}$ (µV)')
cbar.set_ticks([-grid_vmax, -grid_vmax/10, 0, grid_vmax/10, grid_vmax])
cbar.set_label(r'$V_{\rm e}$ (µV)', labelpad=2)

# Morphology of the reference cell (population member 0) on three panels.
plot_cell_idx = 0
for morph_ax in [ax_m, ax_ca1_grid, ax_ca2_grid]:
    morph_ax.plot(morph_data[plot_cell_idx]["cell_x"].T,
                  morph_data[plot_cell_idx]["cell_z"].T,
                  c='k', lw=1)

# Indicate Ca hot-zone as a height range above compartment 0
# (presumably the soma — TODO confirm compartment ordering).
ca_hotzone_z = [morph_data[plot_cell_idx]["cell_z"][0].mean() + ca_hotzone_range[0],
                morph_data[plot_cell_idx]["cell_z"][0].mean() + ca_hotzone_range[1]]
ax_m.plot([-50, -50], ca_hotzone_z, 'r')

for t_idx in [t1_idx, t2_idx]:
    ax_vm_ca.axvline(tvec[t_idx], zorder=10, c='gray', ls='--')
ax_vm_ca.text(tvec[t1_idx] - 3, 45, "t1", ha="right", c='gray')
ax_vm_ca.text(tvec[t2_idx] + 3, 45, "t2", ha="left", c='gray')

for cell_idx in range(num_cells):
    ax_pop.plot(morph_data[cell_idx]["cell_x"].T,
                morph_data[cell_idx]["cell_z"].T,
                c=cell_clrs(cell_idx), lw=1,
                # random stacking order; note this draws from the global RNG
                zorder=np.random.randint(100),
                rasterized=True, clip_on=False)

ax_pop.plot(elec_params["x"], elec_params["z"], 'o', c='lightseagreen', ms=4, zorder=200)

# NOTE(review): the colour scale is taken from the *filtered* control LFP,
# while the panels show the unfiltered traces, so larger unfiltered peaks
# saturate. This is kept as-is; confirm it is intended.
vmax = np.max(np.abs(LFP_dict["control"]))

# imshow treats `extent` as the outer edges of the image. Pad by half a sample
# and half a contact spacing (dz) so each row is centred on its electrode depth.
lfp_extent = [tvec[0] - dt / 2, tvec[-1] + dt / 2,
              elec_params["z"][0] - dz / 2,
              elec_params["z"][-1] + dz / 2]

img = ax_ca.imshow(LFP_dict_ufilt["control"], origin="lower", vmax=vmax, vmin=-vmax,
                   rasterized=True, cmap=cmap_v_e, extent=lfp_extent)
img2 = ax_nca.imshow(LFP_dict_ufilt["without Ca"], origin="lower", vmax=vmax, vmin=-vmax,
                     rasterized=True, cmap=cmap_v_e, extent=lfp_extent)

ax_ca.axis("auto")
ax_nca.axis("auto")
print("Max LFP: ", vmax)

cax = fig.add_axes([0.88, 0.07, 0.01, 0.37], frameon=False)
cbar = fig.colorbar(img, cax=cax, label=r'V$_{\rm e}$ (µV)')
# NOTE(review): hard-coded ticks assume vmax is at least ~150 µV.
cbar.set_ticks([-150, -100, -50, 0, 50, 100, 150])
cbar.set_label(r'$V_{\rm e}$ (µV)', labelpad=2)

ax_vm_ca.set_ylabel(r"$V_{\rm m}$ (mV)", labelpad=-1)
ax_vm_ca.text(65, -5, "Ca$^{2+}$\nspike", fontsize=11, color='orange')
ax_vm_ca.arrow(70, -10, -7, -10, color='orange', head_width=4)

mark_subplots(ax_m, "A", xpos=0.05, ypos=0.95)
mark_subplots(ax_vm_ca, "B", xpos=-0.3, ypos=1.07)
mark_subplots(ax_vm_nca, "C", xpos=-0.3, ypos=1.07)
mark_subplots(ax_ca1_grid, "D", xpos=0.07, ypos=0.97)
mark_subplots(ax_ca2_grid, "E", xpos=0.07, ypos=0.97)
mark_subplots(ax_pop, "F", xpos=0.05, ypos=1.00)
mark_subplots(ax_ca, "G", xpos=-0.1, ypos=1.07)
mark_subplots(ax_nca, "H", xpos=-0.1, ypos=1.07)

for idx in plot_idxs:
    ax_vm_ca.plot(tvec, vmem_dict["control"][idx], c=idx_clr[idx])
    ax_vm_nca.plot(tvec, vmem_dict["without Ca"][idx], c=idx_clr[idx])
    ax_m.plot(morph_data[plot_cell_idx]["cell_x"][idx].mean(),
              morph_data[plot_cell_idx]["cell_z"][idx].mean(), 'o', c=idx_clr[idx])

simplify_axes([ax_vm_ca, ax_vm_nca, ax_ca, ax_nca])

plt.savefig('fig_ca_spike_100_cells_ufilt.pdf')


In [None]:
# Welch power spectra of the low-pass-filtered population LFP: control vs.
# "without Ca", at one electrode near the soma and one in the apical region.
divide_into_welch = 1
# NOTE(review): NFFT is based on num_tsteps, not on the zero-padded length
# used below. Confirm this is the intended segment length.
welch_dict = {'Fs': 1000 / dt,
              'NFFT': int(num_tsteps/divide_into_welch),
              'noverlap': int(num_tsteps/divide_into_welch/2.),
              'detrend': 'mean',
              'scale_by_freq': True,
              }

# Zero padding (samples) added on both sides of each LFP trace before the PSD.
pad_len = 2000

fig = plt.figure()
ax1 = fig.add_subplot(121, title="soma region", xlabel="frequency (Hz)", ylabel="PSD (µV²/Hz)", xlim=[1, 500])
ax2 = fig.add_subplot(122, title="apical region", xlabel="frequency (Hz)", ylabel="PSD (µV²/Hz)", xlim=[1, 500])

freqs_wca, psd_lfp_wca = ns.return_freq_and_psd_welch(
    np.pad(LFP_dict["without Ca"], pad_width=((0, 0), (pad_len, pad_len))), welch_dict)
freqs_ctr, psd_lfp_ctr = ns.return_freq_and_psd_welch(
    np.pad(LFP_dict["control"], pad_width=((0, 0), (pad_len, pad_len))), welch_dict)

elec_idxs = [2, 8]  # electrode shown in ax1 (soma region), ax2 (apical region)
clrs = ['k', 'r']   # colours for control, without Ca

# Skip the zero-frequency bin ([1:]), which a log-log axis cannot show.
for ax, elec_idx in zip([ax1, ax2], elec_idxs):
    l_ctr, = ax.loglog(freqs_ctr[1:], psd_lfp_ctr[elec_idx][1:], c=clrs[0])
    l_wca, = ax.loglog(freqs_wca[1:], psd_lfp_wca[elec_idx][1:], c=clrs[1])

ax2.legend([l_ctr, l_wca], ["control", "without Ca"], frameon=False)

simplify_axes([ax1, ax2])
