# GLM model for tectum

In [1]:
import numpy as np
import matplotlib.pyplot as plt
import scipy
import pandas as pd

In [2]:
params = scipy.io.loadmat("./optimized_params.mat")['params'].squeeze()

In [3]:
wc = params['wc'].flatten()[0].astype(np.float32) # neuron locations
frame_dur = np.float32(params['frame_dur'])
steps_per_frame = np.float32(params['steps_per_frame'])
ex_sigma = np.float32(params['ex_sigma'])
in_sigma = np.float32(params['in_sigma'])
ex_tau_sec = np.float32(params['ex_tau_sec'])
in_tau_sec = np.float32(params['in_tau_sec'])
ex_gain = np.float32(params['ex_gain'])
in_gain = np.float32(params['in_gain'])
run_steps = np.float32(params['run_steps'])
hemisphere_correction = np.float32(params['hemisphere_correction'])
dc = np.float32(params['dc'])
random_seed = np.float32(params['random_seed'])
transient_steps = np.float32(params['transient_steps'])
dtSp = frame_dur/steps_per_frame
ex_tau_steps = ex_tau_sec/dtSp
in_tau_steps = in_tau_sec/dtSp
inhibition_downsample = 100
iht_ex=np.arange(0, ex_tau_steps*5, dtype=np.float32)
iht_in=np.arange(0, in_tau_steps*5, inhibition_downsample, dtype=np.float32)
ihcpl_ex = ex_gain*np.exp(-iht_ex/ex_tau_steps)
ihcpl_in = in_gain*np.exp(-iht_in/in_tau_steps)
n_neuron = wc.shape[0]

In [4]:
slen = int(run_steps+transient_steps)
Istm = np.zeros((slen, n_neuron), dtype=np.float32)
initial_inhibition = (10*(-np.exp(-(np.arange(1, transient_steps+1))/(transient_steps/5)))).T
Istm[:int(transient_steps), :] = np.repeat(initial_inhibition[:, None], axis=1, repeats=n_neuron)
Istm = Istm + dc # add dc

In [5]:
# generate weight masks
left_inds = wc[:, 1] > 309
right_inds = wc[:, 1] <= 309
w_ex = np.zeros((n_neuron, n_neuron))
w_in = np.zeros((n_neuron, n_neuron))

for jj in range(n_neuron):
    cl1 = wc[jj, :]
    dists = np.sqrt((cl1[0]-wc[:, 0])**2 + (cl1[1]-wc[:, 1])**2 + (cl1[2]-wc[:, 2])**2)
    if left_inds[jj]:
        same_hemisphere = left_inds
    else:
        same_hemisphere = right_inds
    hc = np.zeros(n_neuron)
    hc[same_hemisphere] = 1
    hc[~same_hemisphere] = hemisphere_correction

    w_ex[jj, :] = np.exp(-dists**2 / (2 * ex_sigma**2)) * hc
    w_in[jj, :] = np.exp(-dists**2 / (2 * in_sigma**2)) * hc

## why this is so slow? 
# hemisphere_mask = np.where(wc[:, 1] > 309, 1, -1)
# hemisphere_mask = hemisphere_mask[:, None] @ hemisphere_mask[None, :]
# hemisphere_mask = np.where(hemisphere_mask == 1, 1, hemisphere_correction)
# pairwise_dist = scipy.spatial.distance.squareform(scipy.spatial.distance.pdist(wc))
# w_ex2 = np.exp(-pairwise_dist**2 / (2 * ex_sigma**2)) * hemisphere_mask
# w_in2 = np.exp(-pairwise_dist**2 / (2 * in_sigma**2)) * hemisphere_mask

In [6]:
ex_filt = ihcpl_ex[None, :]
in_filt = ihcpl_in[None, :]
membrane_potential = np.zeros((n_neuron, 1), dtype=np.float32) # current membrane potential for each neuron
I_ex = Istm.T
rlen_in = int(I_ex.shape[1]/inhibition_downsample)
I_in = np.zeros((I_ex.shape[0], rlen_in+2), dtype=np.float32)
linear_drive = np.zeros((I_ex.shape[0], slen), dtype=np.float32)
n_spikes_neuron = np.zeros((n_neuron, 1), dtype=np.float32)
n_spikes_time = np.zeros((1, slen), dtype=np.float32)
dt = dtSp
next_spike_time = np.random.exponential(1, size=(n_neuron, 1))
spike_times = [[] for i in range(n_neuron)]

In [7]:
from tqdm import tqdm
tbar = tqdm(range(slen), desc='Simulating', total=slen)

left_edge = 0
nbinsPerEval = 100 
while left_edge < slen: # do the simulation in chunks
    chunk_inds = np.arange(left_edge, min(left_edge+nbinsPerEval, slen)) # indices of the current chunk
    next_weight = left_edge/inhibition_downsample - left_edge//inhibition_downsample
    interp_inhibition = (1-next_weight) * I_in[:, left_edge//inhibition_downsample] + next_weight * I_in[:, left_edge//inhibition_downsample + 1]
    total_input = I_ex[:, chunk_inds] - interp_inhibition[:, None] # linear drive
    linear_drive[:, chunk_inds] = total_input
    instantaneous_potential = np.exp(total_input) # instantaneous firing rate (membrane potential)
    cumulative_potential = membrane_potential + np.cumsum(instantaneous_potential, axis=1)*dt  # Cumulative intensity

    next_spike_time_expanded = np.repeat(next_spike_time, axis=1, repeats=len(chunk_inds)) # expand to the size of the cumulative potential, along the time axis
    cell_idx, time_idx = np.where(cumulative_potential - next_spike_time_expanded > 0) # time to next spike
    if len(cell_idx) <= 0:
        left_edge = chunk_inds[-1]+1 # skip the whole chunk
        membrane_potential = cumulative_potential[:, [-1]]
    else: # have spike
        spike_cells = np.unique(cell_idx[time_idx == min(time_idx)]) # cell inds with the mimimum time to next spike
        actual_spike_time = chunk_inds[min(time_idx)] # actual spike time
        membrane_potential = cumulative_potential[:, [min(time_idx)]]

        if actual_spike_time >= slen:
            left_edge = actual_spike_time
            break
        # deal with the current due to this spike
        excitatory_border = min(slen, actual_spike_time+ex_filt.shape[1])
        actual_spike_time_inhibitory = actual_spike_time // inhibition_downsample
        inhibitory_border = min(slen, actual_spike_time_inhibitory+in_filt.shape[1])

        # use the formula I_spatial @ Spikes @ I_temporal, (N, #spike_cells) @ (#spike_cells, 1) @ (1, #time_steps)
        outgoing_total_excitatory = w_ex[:, spike_cells] @ np.ones((len(spike_cells),1), dtype=np.float32) @ ex_filt
        outgoing_total_inhibitory = w_in[:, spike_cells] @ np.ones((len(spike_cells),1), dtype=np.float32) @ in_filt
        I_ex[:, (actual_spike_time+1):(excitatory_border+1)] += outgoing_total_excitatory[:, 0:(excitatory_border-actual_spike_time)]
        next_weight = actual_spike_time / inhibition_downsample - actual_spike_time_inhibitory
        I_in[:, (actual_spike_time_inhibitory+1): (inhibitory_border+1)] += (1-next_weight)*outgoing_total_inhibitory[:, 0: (inhibitory_border-actual_spike_time_inhibitory)]        
        I_in[:, (actual_spike_time_inhibitory+2): (inhibitory_border+2)] += (1-next_weight)*outgoing_total_inhibitory[:, 0: (inhibitory_border-actual_spike_time_inhibitory)]
        
        # collect spike times
        for i in spike_cells:
            spike_times[i].append(actual_spike_time*dt)
        
        # reset neuron potential
        membrane_potential[spike_cells] = 0
        n_spikes_neuron[spike_cells] += 1
        next_spike_time[spike_cells] = np.random.exponential(1, size=(len(spike_cells), 1))
        n_spikes_time[:, actual_spike_time] = len(spike_cells)
        left_edge = actual_spike_time+1

    muISI = left_edge/(sum(n_spikes_neuron))
    nbinsPerEval = max(20, np.round(1.5*muISI))
    tbar.update(1)
    tbar.set_description(f'Bin: {left_edge} Spiking cells {len(spike_cells)}')


Simulating:   0%|          | 0/20000 [00:00<?, ?it/s]

  instantaneous_potential = np.exp(total_input) # instantaneous firing rate (membrane potential)
Bin: 1978 Spiking cells 664:   4%|▎         | 730/20000 [01:05<20:33, 15.63it/s]  

In [15]:
n_spikes_time.shape

(1, 20000)