# Import

In [None]:
# add path to the current directory
%load_ext autoreload
%autoreload 2

import sys
import os
sys.path.append(os.getcwd())

# switch to the simulation directory
# Verify the current directory
print("Current directory:", os.getcwd())

%load_ext autoreload
%autoreload 2

import sys
# assert('neuron' not in sys.modules)
import os
nrn_options = "-nogui -NSTACK 100000 -NFRAME 20000"
os.environ["NEURON_MODULE_OPTIONS"] = nrn_options
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from matplotlib import cm
import matplotlib.animation as animation
import time
import math

import numpy as np
np.random.seed(237)
import matplotlib.pyplot as plt
from skopt.plots import plot_gaussian_process
from functools import partial
from skopt.plots import plot_convergence
from skopt import gp_minimize
from skopt.space import Real


from brian2 import *
from utils import set_params_utils, eqs_utils, plotting_utils, obj_func_utils

# vanilla

## params

In [None]:
# Network size and architecture.
N = 700                    # total number of neurons
N_E = int(0.8 * N)         # excitatory (pyramidal) neurons
N_I = int(0.2 * N)         # inhibitory interneurons
f = 0.1                    # fraction of N_E in each selective sub-population
p = 3                      # number of selective sub-populations
C_ext = 800                # external-input count passed to set_synapses
DC_amp = 0

# Currents recorded by the monitors; the plotted list is the same object.
# (The external noise rate is set where the synapses are built.)
currents_to_track = [
    'I_syn',
    'I_AMPA_ext',
    'I_GABA_rec',
    'I_AMPA_rec',
    'I_NMDA_rec',
    'I_DC1',
    'I_DC2',
    'I_sino1',
    'I_sino2',
]
currents_to_plot = currents_to_track

## currents

In [None]:
# Sampling step of each stimulus array: a coarse grid for the step-shaped DC
# inputs, a fine grid for the sinusoids.
DC_input_ts = 25 * ms
sino_input_ts = 0.1 * ms

# DC steps -- amplitude (nA), start time (ms), duration (ms).
DC_amp1, DC_start_time1, DC_duration1 = 0.1, 0, 1000
DC_amp2, DC_start_time2, DC_duration2 = -0.05, 500, 2000

# Sinusoids -- start time (ms), duration (ms), amplitude (nA), frequency (Hz).
sino_start_time1, sino_duration1, sino_amp1, sino_freq1 = 0, 3000, 0.5, 30
sino_start_time2, sino_duration2, sino_amp2, sino_freq2 = 1000, 2000, 0.2, 50

# Build the four stimulus objects from the parameters above.
DC_input1 = set_params_utils.set_DC_input(
    DC_amp=DC_amp1,
    DC_duration=DC_duration1,
    DC_start_time=DC_start_time1,
    timestep=DC_input_ts,
)
DC_input2 = set_params_utils.set_DC_input(
    DC_amp=DC_amp2,
    DC_duration=DC_duration2,
    DC_start_time=DC_start_time2,
    timestep=DC_input_ts,
)
sino_input1 = set_params_utils.set_sino_input(
    sino_start_time=sino_start_time1,
    sino_duration=sino_duration1,
    sino_amp=sino_amp1,
    sino_freq=sino_freq1,
    timestep=sino_input_ts,
)
sino_input2 = set_params_utils.set_sino_input(
    sino_start_time=sino_start_time2,
    sino_duration=sino_duration2,
    sino_amp=sino_amp2,
    sino_freq=sino_freq2,
    timestep=sino_input_ts,
)

## run

In [None]:
# Build the full network (parameters, equations, groups, synapses, monitors,
# selection stimuli), store its initial state as 'initial', run 3 s and plot.
# start_scope() makes Network(collect()) below pick up only objects created
# after this point.
start_scope()

# Sizes of one selective sub-population and of the nonselective population.
N_sub = int(N_E * f)
N_non = int(N_E * (1. - f * p))


# Excitatory layout assumed throughout: nonselective neurons first, then p
# selective groups of N_sub neurons each. One representative neuron per
# population is recorded.
E_neuron_index = [0] # index of the neuron in the population
E_index_map = {0: 'nonselective'} # map the index in the monitor to population name
for i in range(p):
    E_neuron_index.append(N_non + i * N_sub)
    E_index_map[i+1] = f'selective {i}'


# voltage
V_L, V_thr, V_reset, V_E, V_I = set_params_utils.set_voltage()
# membrane capacitance and membrane leak
C_m_E, C_m_I, g_m_E, g_m_I = set_params_utils.set_membrane_params()

# AMPA (excitatory)
g_AMPA_ext_E, g_AMPA_rec_E, g_AMPA_ext_I, g_AMPA_rec_I, tau_AMPA = set_params_utils.set_AMPA_params(N_E)
# NMDA (excitatory)
g_NMDA_E, g_NMDA_I, tau_NMDA_rise, tau_NMDA_decay, alpha, Mg2 = set_params_utils.set_NMDA_params(N_E)
# GABAergic (inhibitory)
g_GABA_E, g_GABA_I, tau_GABA = set_params_utils.set_GABA_params(N_I)

# Model equations for the excitatory and inhibitory populations and the synapses.
# The parameter names defined above are presumably looked up by these
# equation strings at run time -- confirm in eqs_utils.
eqs_E = eqs_utils.write_eqs_E()
eqs_I = eqs_utils.write_eqs_I()
eqs_glut, eqs_pre_glut, eqs_pre_gaba = eqs_utils.write_other_eqs()

# neuron groups (P_E: excitatory, P_I: inhibitory)
P_E, P_I = set_params_utils.set_neuron_groups(N_E, N_I, eqs_E, eqs_I, V_L)
# synapses (recurrent connections plus background input at external_noise_rate)
external_noise_rate = 3 * Hz
C_E, C_I, C_E_E, C_E_I, C_I_I, C_I_E, C_P_E, C_P_I = set_params_utils.set_synapses(P_E, P_I, N_E, N_I, N_sub, N_non, p, f, C_ext, external_noise_rate, eqs_glut, eqs_pre_glut, eqs_pre_gaba)


# Spike/rate monitors (N_activity_plot neurons per population in the raster)
# and state monitors recording `currents_to_track` for E_neuron_index.
N_activity_plot = 15
DC_monitor_E, DC_monitor_I, sp_E_sels, sp_E, sp_I, r_E_sels, r_E, r_I = set_params_utils.set_monitors(N_activity_plot, N_non, N_sub, p, P_E, P_I, E_neuron_index=E_neuron_index, currents_to_track=currents_to_track)


## set external stimuli
# at 1s, select population 1
C_selection = int(f * C_ext)
rate_selection = 25 * Hz


# 25 ms bins: 40 off (1 s), 2 on (a 50 ms pulse), then off. The PoissonInput
# weight expression 'stimuli1(t)' gates the input: it adds nothing while the
# TimedArray is 0.
stimuli1 = TimedArray(np.r_[np.zeros(40), np.ones(2), np.zeros(100)], dt=25 * ms)
input1 = PoissonInput(P_E[N_non:N_non + N_sub], 's_AMPA_ext', C_selection, rate_selection, 'stimuli1(t)')

# at 2s, select population 2 (80 bins of 25 ms = 2 s before the pulse)
stimuli2 = TimedArray(np.r_[np.zeros(80), np.ones(2), np.zeros(100)], dt=25 * ms)
input2 = PoissonInput(P_E[N_non + N_sub:N_non + 2 * N_sub], 's_AMPA_ext', C_selection, rate_selection, 'stimuli2(t)')


# simulate, can be long >120s
# collect() gathers the Brian objects created since start_scope() (including
# input1/input2); the explicit adds cover objects held in containers/tuples.
net = Network(collect())
net.add(sp_E_sels)
net.add(r_E_sels)
net.add(P_E, P_I, C_E_E, C_E_I, C_I_I, C_I_E, C_P_E, C_P_I)

# Add the monitors to the network
net.add(DC_monitor_E)
net.add(DC_monitor_I)

# Snapshot of the pre-run state; later cells call net.restore('initial').
net.store('initial')

net.run(3 * second, report='stdout')

plotting_utils.plot_firing_rate(r_E, r_I, r_E_sels)
plotting_utils.plot_raster(N_activity_plot, sp_E, sp_I, sp_E_sels, p)
plotting_utils.plot_currents(DC_monitor_E, DC_monitor_I, currents_to_plot, E_index_map)

## rerun 

In [None]:
import numpy as np
import matplotlib.pyplot as plt

# Preview of a custom asymmetric square-wave current (positive pulse followed
# by a negative pulse), sampled on a regular grid.

# Parameters
duration = 1.0             # total duration in seconds
sampling_rate = 1000       # Hz (samples per second)
t = np.linspace(0, duration, int(duration * sampling_rate), endpoint=False)

# Custom square wave parameters
freq = 5  # Hz
period = 1 / freq  # seconds

# Positive phase
pos_amp = 10   # nA
pos_dur = 0.08  # in seconds (positive pulse width)

# Negative phase
neg_amp = -4   # nA
neg_dur = 0.12  # in seconds (negative pulse width)

# One cycle = positive pulse + negative pulse. It must agree with `period`;
# otherwise changing `freq` would silently have no effect on the waveform.
cycle_len = pos_dur + neg_dur
assert np.isclose(cycle_len, period), (
    f"pos_dur + neg_dur = {cycle_len} s does not match 1/freq = {period} s")

# Work in whole samples: the phase boundaries are rounded to the sampling grid
# once, so every cycle has the same length and there is no accumulated drift
# (restarting each cycle at the next sample time would drift whenever the
# durations are not multiples of the sample period).
samples_per_cycle = int(round(cycle_len * sampling_rate))
samples_pos = int(round(pos_dur * sampling_rate))
if samples_per_cycle <= 0:
    raise ValueError("pos_dur + neg_dur must span at least one sample")

sample_phase = np.arange(len(t)) % samples_per_cycle
waveform = np.where(sample_phase < samples_pos, pos_amp, neg_amp).astype(float)

# Plot the waveform
fig, ax = plt.subplots()
ax.plot(t * 1000, waveform)
ax.set_title("Custom Asymmetric Square Wave AC Current")
ax.set_xlabel("Time (ms)")
ax.set_ylabel("Current (nA)")
ax.grid(True)
plt.show()


### once

In [None]:
# Reset the network to the state stored as 'initial' before the first run,
# rebuild the four stimulus arrays with new parameters, run again and plot.
# Alternative title: f'sino_amp2 = {sino_amp2} nA: '
title_prefix = ''
net.restore('initial')

# Sampling step of the stimulus arrays.
DC_input_ts, sino_input_ts = 25 * ms, 0.1 * ms

# DC steps -- amplitude (nA), start time (ms), duration (ms).
DC_amp1, DC_start_time1, DC_duration1 = 0.1, 0, 1000
DC_amp2, DC_start_time2, DC_duration2 = -0.05, 500, 2000

# Sinusoids -- start time (ms), duration (ms), amplitude (nA), frequency (Hz).
sino_start_time1, sino_duration1, sino_amp1, sino_freq1 = 0, 3000, 0.1, 20
sino_start_time2, sino_duration2, sino_amp2, sino_freq2 = 1000, 2000, 0.05, 10

# Rebind the stimulus names; presumably the model equations look these
# names up when net.run() is called -- confirm in eqs_utils.
DC_input1 = set_params_utils.set_DC_input(
    DC_amp=DC_amp1,
    DC_duration=DC_duration1,
    DC_start_time=DC_start_time1,
    timestep=DC_input_ts,
)
DC_input2 = set_params_utils.set_DC_input(
    DC_amp=DC_amp2,
    DC_duration=DC_duration2,
    DC_start_time=DC_start_time2,
    timestep=DC_input_ts,
)
sino_input1 = set_params_utils.set_sino_input(
    sino_start_time=sino_start_time1,
    sino_duration=sino_duration1,
    sino_amp=sino_amp1,
    sino_freq=sino_freq1,
    timestep=sino_input_ts,
)
sino_input2 = set_params_utils.set_sino_input(
    sino_start_time=sino_start_time2,
    sino_duration=sino_duration2,
    sino_amp=sino_amp2,
    sino_freq=sino_freq2,
    timestep=sino_input_ts,
)

net.run(3 * second, report='stdout')
plotting_utils.plot_firing_rate(r_E, r_I, r_E_sels, title_prefix=title_prefix)
plotting_utils.plot_currents(DC_monitor_E, DC_monitor_I, currents_to_plot, E_index_map, title_prefix=title_prefix)

### current params

In [None]:
# Re-draw the current traces recorded by the most recent net.run().
plotting_utils.plot_currents(DC_monitor_E, DC_monitor_I, currents_to_plot, E_index_map, title_prefix=title_prefix)

In [None]:
# Stimulus parameters for the sino_amp2 sweep (sino_amp2 itself is the loop
# variable of that sweep).
DC_amp = -0.01 # in nA
DC_duration = 2000 # in ms
DC_start_time = 800 # in ms

sino_start_time = 0 # in ms
sino_duration = 3000 # in ms
sino_amp1 = 0.2 # in nA
# sino_amp2 = 0.1 # in nA  (swept by the loop instead)
sino_freq1 = 20 # in Hz
sino_freq2 = 50 # in Hz


In [None]:
# Sweep the amplitude of the second sinusoid; every run starts from the state
# stored as 'initial'. The stimuli are assigned to DC_input1/2 and
# sino_input1/2 -- the names the other stimulus cells in this notebook
# define -- and set_sino_input is called with the same keywords used there
# (sino_amp / sino_freq, one call per sinusoid).
# NOTE(review): assumes the model equations read DC_input1/2 and
# sino_input1/2 by name at run time (not visible here) -- confirm in eqs_utils.
for sino_amp2 in np.arange(0.01, 0.15, 0.02):

    title_prefix = f'sino_amp2 = {sino_amp2} nA: '
    net.restore('initial')

    # Single DC step on the first channel; `DC_input` is kept as an alias.
    DC_input = DC_input1 = set_params_utils.set_DC_input(DC_amp=DC_amp, # in nA
                    DC_duration=DC_duration, # in ms
                    DC_start_time=DC_start_time, # in ms
                    timestep=25 * ms
                    )
    # Silence the second DC channel so a step left over from an earlier cell
    # does not leak into this sweep.
    DC_input2 = set_params_utils.set_DC_input(DC_amp=0, # in nA
                    DC_duration=DC_duration, # in ms
                    DC_start_time=DC_start_time, # in ms
                    timestep=25 * ms
                    )

    # Two sinusoids over the same window: a fixed one and the swept one.
    sino_input1 = set_params_utils.set_sino_input(sino_start_time=sino_start_time, # in ms
                        sino_duration=sino_duration, # in ms
                        sino_amp=sino_amp1, # in nA
                        sino_freq=sino_freq1, # in Hz
                        timestep=0.1 * ms
                        )
    sino_input2 = set_params_utils.set_sino_input(sino_start_time=sino_start_time, # in ms
                        sino_duration=sino_duration, # in ms
                        sino_amp=sino_amp2, # in nA
                        sino_freq=sino_freq2, # in Hz
                        timestep=0.1 * ms
                        )

    net.run(3 * second, report='stdout')
    plotting_utils.plot_firing_rate(r_E, r_I, r_E_sels, title_prefix=title_prefix)
    plotting_utils.plot_currents(DC_monitor_E, DC_monitor_I, currents_to_plot, E_index_map, title_prefix=title_prefix)
    

### I_DC time

In [None]:
# Sweep the duration of a 0.5 nA DC step that starts at t = 0.
# The step is assigned to DC_input1, the name the other stimulus cells use
# for the first DC channel (`DC_input` is kept as an alias).
# NOTE(review): DC_input2 and the sinusoidal inputs keep whatever the
# previous cells assigned -- confirm that is intended.
for DC_duration in [300, 500, 800, 1000, 1500]:
    dc_title = f'DC_duration = {DC_duration} ms: '
    net.restore('initial')
    DC_input = DC_input1 = set_params_utils.set_DC_input(DC_amp = 0.5, # in nA
                 DC_duration= DC_duration, # in ms
                 DC_start_time = 0, # in ms
                 timestep = 25 * ms
                 )
    net.run(3 * second, report='stdout')
    plotting_utils.plot_firing_rate(r_E, r_I, r_E_sels, title_prefix=dc_title)
    plotting_utils.plot_currents(DC_monitor_E, DC_monitor_I, currents_to_plot, E_index_map, title_prefix=dc_title)

### stimuli rate

In [None]:
# Re-run the stored network with a different selection-stimulus rate.
# NOTE(review): input1/input2 created in this loop are never added to `net`
# (it was built once with Network(collect()) before they existed), so they
# presumably do not take part in the run; the PoissonInputs created when the
# network was built (rate_selection = 25 Hz) are still the ones in `net`.
# The "Change Hz of poisson stimulation" section rebuilds the network per rate
# and adds the inputs explicitly instead.
for rate in [1]:
    net.restore('initial')

    rate_selection = rate * Hz
    # at 1s, select population 1 (25 ms bins: 40 off, 2 on)
    stimuli1 = TimedArray(np.r_[np.zeros(40), np.ones(2), np.zeros(100)], dt=25 * ms)
    input1 = PoissonInput(P_E[N_non:N_non + N_sub], 's_AMPA_ext', C_selection, rate_selection, 'stimuli1(t)')

    # at 2s, select population 2
    stimuli2 = TimedArray(np.r_[np.zeros(80), np.ones(2), np.zeros(100)], dt=25 * ms)
    input2 = PoissonInput(P_E[N_non + N_sub:N_non + 2 * N_sub], 's_AMPA_ext', C_selection, rate_selection, 'stimuli2(t)')

    net.run(3 * second, report='stdout')

    plotting_utils.plot_firing_rate(r_E, r_I, r_E_sels, title_prefix=f'rate_selection = {rate_selection}: ')
    # plotting_utils.plot_raster(N_activity_plot, sp_E, sp_I, sp_E_sels, p)
    # plotting_utils.plot_currents(DC_monitor_E, DC_monitor_I, currents_to_plot, E_index_map)


# Change the rate (Hz) of the Poisson selection stimulus

In [None]:
# populations
N = 700
N_E = int(N * 0.8)  # pyramidal neurons
N_I = int(N * 0.2)  # interneurons
f = 0.1
p = 3
C_ext = 800

DC_input = set_params_utils.set_DC_input(DC_amp = 0.5, # in nA
                 DC_duration= 800, # in ms
                 DC_start_time = 0, # in ms
                 timestep = 25 * ms
                 )



In [None]:
# Sweep the rate of the selection stimulus. Each iteration rebuilds the whole
# network from scratch (start_scope + Network(collect())), so the
# PoissonInputs created with the new rate are part of the simulated network.
# It then runs 3 s and plots.
# Relies on E_neuron_index, E_index_map, currents_to_track and
# currents_to_plot defined in earlier cells.
for rate in range(8, 18):
    start_scope()

    N_sub = int(N_E * f)
    N_non = int(N_E * (1. - f * p))

    # voltage
    V_L, V_thr, V_reset, V_E, V_I = set_params_utils.set_voltage()
    # membrane capacitance and membrane leak
    C_m_E, C_m_I, g_m_E, g_m_I = set_params_utils.set_membrane_params()

    # AMPA (excitatory)
    g_AMPA_ext_E, g_AMPA_rec_E, g_AMPA_ext_I, g_AMPA_rec_I, tau_AMPA = set_params_utils.set_AMPA_params(N_E)
    # NMDA (excitatory)
    g_NMDA_E, g_NMDA_I, tau_NMDA_rise, tau_NMDA_decay, alpha, Mg2 = set_params_utils.set_NMDA_params(N_E)
    # GABAergic (inhibitory)
    g_GABA_E, g_GABA_I, tau_GABA = set_params_utils.set_GABA_params(N_I)

    # Model equations for the populations and synapses
    eqs_E = eqs_utils.write_eqs_E()
    eqs_I = eqs_utils.write_eqs_I()
    eqs_glut, eqs_pre_glut, eqs_pre_gaba = eqs_utils.write_other_eqs()

    # neuron groups
    P_E, P_I = set_params_utils.set_neuron_groups(N_E, N_I, eqs_E, eqs_I, V_L)
    # synapses
    external_noise_rate = 3 * Hz
    C_E, C_I, C_E_E, C_E_I, C_I_I, C_I_E, C_P_E, C_P_I = set_params_utils.set_synapses(P_E, P_I, N_E, N_I, N_sub, N_non, p, f, C_ext, external_noise_rate, eqs_glut, eqs_pre_glut, eqs_pre_gaba)

    # Monitors: record the same currents as the main run so plot_currents
    # below is given the currents it is asked to plot.
    N_activity_plot = 15
    DC_monitor_E, DC_monitor_I, sp_E_sels, sp_E, sp_I, r_E_sels, r_E, r_I = set_params_utils.set_monitors(N_activity_plot, N_non, N_sub, p, P_E, P_I, E_neuron_index=E_neuron_index, currents_to_track=currents_to_track)

    ## set external stimuli
    C_selection = int(f * C_ext)
    rate_selection = rate * Hz
    rate_title = f'rate_selection = {rate_selection}: '

    # at 1s, select population 1 (25 ms bins: 40 off, 2 on = 50 ms pulse)
    stimuli1 = TimedArray(np.r_[np.zeros(40), np.ones(2), np.zeros(100)], dt=25 * ms)
    input1 = PoissonInput(P_E[N_non:N_non + N_sub], 's_AMPA_ext', C_selection, rate_selection, 'stimuli1(t)')

    # at 2s, select population 2
    stimuli2 = TimedArray(np.r_[np.zeros(80), np.ones(2), np.zeros(100)], dt=25 * ms)
    input2 = PoissonInput(P_E[N_non + N_sub:N_non + 2 * N_sub], 's_AMPA_ext', C_selection, rate_selection, 'stimuli2(t)')

    net = Network(collect())
    net.add(sp_E_sels)
    net.add(r_E_sels)
    net.add(P_E, P_I, C_E_E, C_E_I, C_I_I, C_I_E, C_P_E, C_P_I)

    # Add the monitors and the selection inputs to the network
    net.add(DC_monitor_E)
    net.add(DC_monitor_I)
    net.add(input1)
    net.add(input2)
    net.store('initial')

    net.run(3 * second, report='stdout')

    # Report every PoissonInput in this network and its rate. Scan all
    # objects: net.objects is unordered, so skipping a position is arbitrary,
    # and only PoissonInput is guaranteed to have a `.rate` attribute.
    poisson_inputs = [obj for obj in net.objects if isinstance(obj, PoissonInput)]
    for poisson_input in poisson_inputs:
        print(poisson_input)
        print(poisson_input.rate)

    plotting_utils.plot_firing_rate(r_E, r_I, r_E_sels, title_prefix=rate_title)
    plotting_utils.plot_raster(N_activity_plot, sp_E, sp_I, sp_E_sels, p, title_prefix=rate_title)
    plotting_utils.plot_currents(DC_monitor_E, DC_monitor_I, currents_to_plot, E_index_map, title_prefix=rate_title)

# Other potential waveforms

In [None]:
import numpy as np
import matplotlib.pyplot as plt

# Preview of a three-phase pulse train: positive pulse, negative pulse, then
# a zero-current rest phase, sampled on a regular grid.

# Parameters
duration = 1.0             # total duration in seconds
sampling_rate = 1000       # Hz
t = np.linspace(0, duration, int(duration * sampling_rate), endpoint=False)

# Phases
pos_amp = 10   # nA
pos_dur = 0.05  # seconds

neg_amp = -5   # nA
neg_dur = 0.05  # seconds

rest_dur = 0.05  # seconds (zero amplitude)

# Full cycle duration
cycle_dur = pos_dur + neg_dur + rest_dur

# Phase boundaries in whole samples: rounding once to the sampling grid gives
# every cycle the same length, so cycles do not drift. It also avoids the
# endless loop that a zero-length cycle would otherwise cause.
samples_per_cycle = int(round(cycle_dur * sampling_rate))
samples_pos_end = int(round(pos_dur * sampling_rate))
samples_neg_end = int(round((pos_dur + neg_dur) * sampling_rate))
if samples_per_cycle <= 0:
    raise ValueError("pos_dur + neg_dur + rest_dur must span at least one sample")

sample_phase = np.arange(len(t)) % samples_per_cycle
waveform = np.select(
    [sample_phase < samples_pos_end, sample_phase < samples_neg_end],
    [pos_amp, neg_amp],
    default=0,  # rest phase
).astype(float)

# Plot it
fig, ax = plt.subplots()
ax.plot(t * 1000, waveform)
ax.set_title("Asymmetric Square Wave with Rest Phase")
ax.set_xlabel("Time (ms)")
ax.set_ylabel("Current (nA)")
ax.grid(True)
plt.show()


In [None]:
from brian2 import *

# Wrap the previewed waveform (in nA) as a Brian2 TimedArray. The time step
# is derived from `sampling_rate`, so it always matches the grid `waveform`
# was sampled on (1 ms for sampling_rate = 1000 Hz).
stimulus = TimedArray(waveform * nA, dt=(1.0 / sampling_rate) * second)

# Other brian2 examples

## selective stimulation

In [None]:
# Minimal single-neuron model driven by an input current, used by the BO section.
# NOTE(review): `I_recorded` is not defined in any cell shown here; it has to
# be in scope (e.g. a TimedArray) before this network is run.
start_scope()

# Sweep grids (not used in this cell).
amp_range = [0.5, 1.0, 2.0, 4.0]
freq_range = [5, 10, 20, 30]

tau = 5*ms
eqs = '''
dv/dt = (I-v)/tau : 1
I = I_recorded(t) : 1
'''
E_neurons = NeuronGroup(1, eqs, threshold='v>1', reset='v=0', method='exact')
M_state = StateMonitor(E_neurons, variables=True, record=True)
M_spike = SpikeMonitor(E_neurons)

# Snapshot of the freshly built network (Brian's module-level store()).
store()

# Empty result containers (not filled in any cell shown here).
t_array, output_v, output_i, output_rates = [], [], [], []

## see network objects

In [None]:
# Print every Brian object registered with `net`, one per line.
for brian_obj in net.objects:
    print(brian_obj)

# Bayesian optimization (BO)

In [None]:
# Debugging aid: enter the post-mortem debugger automatically on uncaught
# exceptions. Switch off (next cell) or remove before a clean "Run All".
import ipdb
%pdb on

In [None]:
# Turn automatic post-mortem debugging back off.
%pdb off

In [None]:
# Search space for the Bayesian optimisation: four continuous dimensions
# (presumably amplitudes and frequencies of two stimulus components).
amp1_range = [2, 10]
amp2_range = [2, 10]
freq1_range = [1, 20]
freq2_range = [0.1, 10]
space = [
    Real(low, high, name=dim_name)
    for dim_name, (low, high) in (
        ('amp1', amp1_range),
        ('amp2', amp2_range),
        ('freq1', freq1_range),
        ('freq2', freq2_range),
    )
]

# Objective for gp_minimize: run_again2 with the single-neuron model objects
# bound as keyword arguments.
# NOTE(review): `my_utils` is not imported in any cell shown here (the first
# cell imports `obj_func_utils` from `utils`) -- confirm which module
# provides run_again2.
objective_with_factor = partial(my_utils.run_again2,  M_state=M_state, M_spike=M_spike, E_neurons=E_neurons, tau=tau)


In [None]:

# Gaussian-process Bayesian optimisation of the objective over `space`.
# In scikit-optimize, `n_random_starts` is a deprecated alias that overrides
# `n_initial_points`. Passing both (2 and 5) meant 5 random starts were used
# and the 2 was ignored. Only the value that was in effect is passed now.
res = gp_minimize(objective_with_factor, # the function to minimize
                  space,
                  n_initial_points=5, # the number of random initialization points
                  acq_func="EI",      # the acquisition function
                  n_calls=15,         # the number of evaluations of f
                  noise=0.1**2,       # the noise level (optional)
                  random_state=1234)  # the random seed

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from skopt import gp_minimize

# 2-D slice of the objective over (amp1, amp2), with freq1 and freq2 held at
# the best point found by gp_minimize (res.x[2], res.x[3]). Every grid point
# calls the objective once (10 x 10 = 100 evaluations), so this can be slow.
n_grid = 10
x1_vals = np.linspace(amp1_range[0], amp1_range[1], n_grid)
x2_vals = np.linspace(amp2_range[0], amp2_range[1], n_grid)
X1, X2 = np.meshgrid(x1_vals, x2_vals)

best_freq1, best_freq2 = res.x[2], res.x[3]
Z = np.array([
    [objective_with_factor((x1, x2, best_freq1, best_freq2)) for x1, x2 in zip(row_x1, row_x2)]
    for row_x1, row_x2 in zip(X1, X2)
])

# Plot the objective function as a filled contour plot
fig, ax = plt.subplots()
filled = ax.contourf(X1, X2, Z, levels=50, cmap="viridis")
fig.colorbar(filled, ax=ax, label="Objective Value")

# All evaluated points, projected onto the (amp1, amp2) plane; their
# frequencies generally differ from the fixed values used for the slice.
x_iters = np.array(res.x_iters)
ax.scatter(x_iters[:, 0], x_iters[:, 1], color="red", marker="x", label="Sampled Points", s=50, zorder=5)
ax.scatter([res.x[0]], [res.x[1]], color="white", edgecolors="black", marker="*", s=200, zorder=6, label="Best Point")

# Label the axes with the search-space dimensions they actually show.
ax.set_xlabel("amp1")
ax.set_ylabel("amp2")
ax.set_title("Contour Plot of the Objective Function with Sampled Points")
ax.legend()
plt.show()


In [None]:
# Display the optimisation result returned by gp_minimize (includes the best
# point `x` and all evaluated points `x_iters`).
res