In [1]:
# Import necessary libraries
import nest  # NEST simulator for spiking neural networks
import matplotlib.pyplot as plt  # For plotting results
import ipywidgets as widgets  # For interactive sliders
import numpy as np
from ipywidgets import interact, FloatSlider, IntSlider

# Limit NEST's console logging to the "M_WARNING" level so that routine
# simulator messages do not clutter the cell output on every run.
nest.set_verbosity("M_WARNING")


              -- N E S T --
  Copyright (C) 2004 The NEST Initiative

 Version: 3.8.0-post0.dev0
 Built: Nov  7 2024 15:09:56

 This program is provided AS IS and comes with
 NO WARRANTY. See the file LICENSE for details.

 Problems or suggestions?
   Visit https://www.nest-simulator.org

 Type 'nest.help()' to find out more about NEST.



Simulate 3 neurons:
    1. Only Neuron 1 receives external input.
    2. Each neuron has a self-excitatory connection.
    3. Neurons 1 and 2 excite each other.
    4. Neurons 2 and 3 excite each other.
    5. Neurons 1 and 3 inhibit each other.


In [None]:
def three_neurons_network(I_e, t_ref, sim_time, weight_self, weight_exci, weight_inh):
    """Simulate a recurrent three-neuron network and plot the membrane potentials.

    Network built on every call (the NEST kernel is reset first):
      * three ``iaf_psc_alpha`` neurons sharing the same membrane parameters;
      * only neuron 1 receives the constant input current ``I_e``;
      * every neuron has a self-connection with weight ``weight_self``;
      * neurons 1<->2 and 2<->3 are connected in both directions with ``weight_exci``;
      * neurons 1<->3 are connected in both directions with ``weight_inh``.

    For each neuron the spike times and the mean firing rate are printed, and
    the three V_m traces are drawn in one figure together with the threshold.

    Parameters
    ----------
    I_e : float
        Constant input current set on neuron 1 (pA, per the slider label).
    t_ref : float
        Refractory period applied to all three neurons (ms, per the slider label).
    sim_time : float
        Duration handed to ``nest.Simulate`` (ms). Must be non-zero because the
        firing rate is computed as ``spike count / (sim_time / 1000)``.
    weight_self : float
        Synaptic weight of each neuron's connection onto itself.
    weight_exci : float
        Synaptic weight of the 1<->2 and 2<->3 connections.
    weight_inh : float
        Synaptic weight of the 1<->3 connections (the slider restricts it to <= 0).

    Returns
    -------
    None
        Results are reported through ``print`` and a matplotlib figure.
    """
    # Single source of truth for the firing threshold: it is set on the neurons
    # and drawn as the dashed reference line in the plot.
    v_th = -55.0

    # Start from a clean kernel so nothing from a previous slider run survives
    nest.ResetKernel()

    # Simulation resolution; the multimeters below sample at the same interval
    nest.SetKernelStatus({"resolution": 0.001})

    # Create three IAF neurons (creation order defines "Neuron 1..3" in the output)
    neurons = [nest.Create('iaf_psc_alpha') for _ in range(3)]
    neuron1, neuron2, neuron3 = neurons

    # Shared neuron parameters, written as floats like the later notebook cell
    neuron_params = {
        "C_m": 250.0,
        "tau_m": 20.0,
        "V_reset": -70.0,
        "V_th": v_th,
        "V_min": -70.0,
        "t_ref": t_ref,
    }
    for neuron in neurons:
        nest.SetStatus(neuron, neuron_params)

    # External input goes to neuron 1 only
    nest.SetStatus(neuron1, {"I_e": I_e})

    # Recording devices: one spike recorder and one multimeter per neuron
    spike_recorders = [nest.Create('spike_recorder') for _ in range(3)]
    multimeters = [nest.Create('multimeter') for _ in range(3)]

    # Record the membrane potential at every simulation step
    for mm in multimeters:
        nest.SetStatus(mm, {"record_from": ["V_m"], "interval": 0.001})

    # Attach the i-th multimeter and spike recorder to the i-th neuron
    for neuron, mm, sr in zip(neurons, multimeters, spike_recorders):
        nest.Connect(mm, neuron)
        nest.Connect(neuron, sr)

    # Self-connections
    for neuron in neurons:
        nest.Connect(neuron, neuron, syn_spec={"weight": weight_self})

    # Reciprocal excitatory connections: 1<->2 and 2<->3
    excitatory_pairs = [
        (neuron1, neuron2),
        (neuron2, neuron1),
        (neuron2, neuron3),
        (neuron3, neuron2),
    ]
    for pre, post in excitatory_pairs:
        nest.Connect(pre, post, syn_spec={"weight": weight_exci})

    # Reciprocal inhibitory connections: 1<->3
    for pre, post in [(neuron1, neuron3), (neuron3, neuron1)]:
        nest.Connect(pre, post, syn_spec={"weight": weight_inh})

    # Simulate for the specified time
    nest.Simulate(sim_time)

    # Retrieve, report and plot the data of each neuron
    plt.figure(figsize=(10, 6))

    for i, (multimeter, spike_recorder) in enumerate(zip(multimeters, spike_recorders)):
        events = nest.GetStatus(multimeter, 'events')[0]
        time = events['times']
        V_m = events['V_m']
        spikes = nest.GetStatus(spike_recorder, 'events')[0]
        spike_times = spikes['times']

        # Print spike times
        print(f'Neuron {i+1} Spike Times: {", ".join(map(str, spike_times))}')

        # Mean firing rate over the whole run (sim_time is in ms, hence / 1000)
        firing_rate = len(spike_times) / (sim_time / 1000.0)
        print(f'Neuron {i+1} Firing Rate: {firing_rate:.2f} Hz')

        # Plot membrane potentials
        plt.plot(time, V_m, label=f'Neuron {i+1} Membrane Potential (V_m)')

    # Label the figure so it can be read on its own
    plt.xlabel('Time (ms)')
    plt.ylabel('Membrane Potential (mV)')
    plt.title('IAF Neurons Membrane Potential Over Time')
    plt.axhline(v_th, color='gray', linestyle='--', label='Threshold (V_th)')
    plt.legend()
    plt.grid()
    plt.show()

# Interactive widgets for input parameters.
# continuous_update=False makes a slider report its value only when it is
# released, so the (expensive) simulation is not re-run for every intermediate
# position while dragging.
I_e = widgets.FloatSlider(value=9200.00, min=0.0, max=100000.0, step=0.01,
                          description='I_e (pA):', continuous_update=False)
t_ref = widgets.FloatSlider(value=8.0, min=0.0, max=100.0, step=1.0,
                            description='t_ref (ms):', continuous_update=False)
sim_time = widgets.FloatSlider(value=250.0, min=1.0, max=60000.0, step=1.0,
                               description='Sim Time (ms):', continuous_update=False)
weight_self = widgets.FloatSlider(value=690.0, min=0.0, max=100000.0, step=0.1,
                                  description='Weight Self:', continuous_update=False)
weight_exci = widgets.FloatSlider(value=566.0, min=0.0, max=100000.0, step=0.1,
                                  description='Weight Exci:', continuous_update=False)
weight_inh = widgets.FloatSlider(value=-3.0, min=-100000.0, max=0.0, step=0.1,
                                 description='Weight Inh:', continuous_update=False)

# The trailing semicolon keeps the returned function object from being echoed
# as a "<function ...>" line below the widget.
widgets.interact(
    three_neurons_network,
    I_e=I_e,
    t_ref=t_ref,
    sim_time=sim_time,
    weight_self=weight_self,
    weight_exci=weight_exci,
    weight_inh=weight_inh,
);


interactive(children=(FloatSlider(value=9200.0, description='I_e (pA):', max=100000.0, step=0.01), FloatSlider…

<function __main__.three_neurons_network(I_e, t_ref, sim_time, weight_self, weight_exci, weight_inh)>

Now, remove the external current after a certain time.

In [3]:
def three_neurons_network(I_e, t_ref, sim_time, weight_self, weight_exci, weight_inh, I_e_start, I_e_stop):
    """Simulate the three-neuron network with a time-limited input current.

    Network built on every call (the NEST kernel is reset first):
      * three ``iaf_psc_alpha`` neurons sharing the same membrane parameters;
      * a ``dc_generator`` with amplitude ``I_e``, configured with
        ``start=I_e_start`` and ``stop=I_e_stop``, connected to neuron 1 only;
      * every neuron has a self-connection with weight ``weight_self``;
      * neurons 1<->2 and 2<->3 are connected in both directions with ``weight_exci``;
      * neurons 1<->3 are connected in both directions with ``weight_inh``.

    Three figures are produced: the membrane potentials, a spike raster and a
    bar chart of the mean firing rates. The firing rates are also printed.

    Parameters
    ----------
    I_e : float
        Amplitude of the DC generator (pA, per the slider label).
    t_ref : float
        Refractory period applied to all three neurons (ms, per the slider label).
    sim_time : float
        Duration handed to ``nest.Simulate`` (ms). Must be non-zero because the
        firing rate is computed as ``spike count / (sim_time / 1000)``.
    weight_self : float
        Synaptic weight of each neuron's connection onto itself.
    weight_exci : float
        Synaptic weight of the 1<->2 and 2<->3 connections.
    weight_inh : float
        Synaptic weight of the 1<->3 connections (the slider restricts it to <= 0).
    I_e_start : float
        ``start`` parameter of the DC generator.
    I_e_stop : float
        ``stop`` parameter of the DC generator.
        NOTE(review): NEST devices are expected to reject ``stop < start`` -
        confirm, and keep ``I_e_stop >= I_e_start`` when moving the sliders.

    Returns
    -------
    None
        Results are reported through ``print`` and matplotlib figures.
    """
    # Single source of truth for the firing threshold: it is set on the neurons
    # and drawn as the dashed reference line in the first plot.
    v_th = -55.0

    # Start from a clean kernel so nothing from a previous slider run survives
    nest.ResetKernel()

    # Simulation resolution; the multimeters below sample at the same interval
    nest.SetKernelStatus({"resolution": 0.001})

    # Create three IAF neurons (creation order defines "Neuron 1..3" in the output)
    neurons = [nest.Create('iaf_psc_alpha') for _ in range(3)]
    neuron1, neuron2, neuron3 = neurons

    # Shared neuron parameters
    neuron_params = {
        "C_m": 250.0,
        "tau_m": 20.0,
        "V_reset": -70.0,
        "V_th": v_th,
        "V_min": -70.0,
        "t_ref": t_ref,
    }
    for neuron in neurons:
        nest.SetStatus(neuron, neuron_params)

    # External input to neuron 1, limited to the [I_e_start, I_e_stop] window
    dc_gen = nest.Create("dc_generator", params={
        "amplitude": I_e,
        "start": I_e_start,
        "stop": I_e_stop
    })
    nest.Connect(dc_gen, neuron1)

    # Recording devices: one spike recorder and one multimeter per neuron
    spike_recorders = [nest.Create('spike_recorder') for _ in range(3)]
    multimeters = [nest.Create('multimeter') for _ in range(3)]

    # Record the membrane potential at every simulation step
    for mm in multimeters:
        nest.SetStatus(mm, {"record_from": ["V_m"], "interval": 0.001})

    # Attach the i-th multimeter and spike recorder to the i-th neuron
    for n, mm, sr in zip(neurons, multimeters, spike_recorders):
        nest.Connect(mm, n)
        nest.Connect(n, sr)

    # Self-connections
    for n in neurons:
        nest.Connect(n, n, syn_spec={"weight": weight_self})

    # Reciprocal excitatory connections: 1<->2 and 2<->3
    excitatory_pairs = [
        (neuron1, neuron2),
        (neuron2, neuron1),
        (neuron2, neuron3),
        (neuron3, neuron2),
    ]
    for pre, post in excitatory_pairs:
        nest.Connect(pre, post, syn_spec={"weight": weight_exci})

    # Reciprocal inhibitory connections: 1<->3
    for pre, post in [(neuron1, neuron3), (neuron3, neuron1)]:
        nest.Connect(pre, post, syn_spec={"weight": weight_inh})

    # Simulate for the specified time
    nest.Simulate(sim_time)

    # ----- Membrane Potential Plot -----
    plt.figure(figsize=(10, 6))
    for i, mm in enumerate(multimeters):
        mm_data = nest.GetStatus(mm, 'events')[0]
        time = mm_data['times']
        V_m = mm_data['V_m']
        plt.plot(time, V_m, label=f'Neuron {i+1} V_m')

    plt.xlabel('Time (ms)')
    plt.ylabel('Membrane Potential (mV)')
    plt.title('Membrane Potentials of IAF Neurons')
    plt.axhline(v_th, color='gray', linestyle='--', label='Threshold (V_th)')
    # Show the labels assigned above. A fixed location is used because the
    # automatic "best" placement has to inspect every plotted point, which is
    # slow for traces sampled at every simulation step.
    plt.legend(loc='upper right')
    plt.grid()
    plt.show()

    # ----- Spike Raster Plot -----
    firing_rates = []

    plt.figure(figsize=(10, 6))
    for i, sr in enumerate(spike_recorders):
        spike_data = nest.GetStatus(sr, 'events')[0]
        spike_times = spike_data['times']
        spike_senders = spike_data['senders']

        # Mean firing rate over the whole run (sim_time is in ms, hence / 1000)
        firing_rate = len(spike_times) / (sim_time / 1000.0)
        firing_rates.append(firing_rate)
        print(f'Neuron {i+1} Firing Rate: {firing_rate:.2f} Hz')

        # One row per sender ID in the raster
        plt.scatter(spike_times, spike_senders, s=2, label=f'Neuron {i+1}')

    plt.xlabel('Time (ms)')
    plt.ylabel('Sender ID')
    plt.title('Combined Spike Raster Plot')
    plt.xlim(0, sim_time)
    plt.show()

    # ----- Firing Rate Bar Plot -----
    plt.figure(figsize=(6, 4))
    neuron_ids = range(1, 4)  # 1, 2, 3
    plt.bar(neuron_ids, firing_rates, color=['C0', 'C1', 'C2'])
    plt.xlabel('Neuron')
    plt.ylabel('Firing Rate (Hz)')
    plt.title('Average Firing Rate of Each Neuron')
    plt.xticks(neuron_ids, [f'N{i}' for i in neuron_ids])
    # Leave 20% headroom above the tallest bar; fall back to 1 when nothing fired
    peak_rate = max(firing_rates)
    plt.ylim(0, peak_rate * 1.2 if peak_rate > 0 else 1)
    plt.grid(axis='y', linestyle='--', alpha=0.7)
    plt.show()

# --- Widgets for Interactive Control (Jupyter Notebook only) ---
# continuous_update=False makes a slider report its value only when it is
# released, so the (expensive) simulation is not re-run for every intermediate
# position while dragging.
I_e = widgets.FloatSlider(value=9200.0, min=0.0, max=100000.0, step=0.01,
                          description='I_e (pA):', continuous_update=False)
t_ref = widgets.FloatSlider(value=8.0, min=0.0, max=100.0, step=1.0,
                            description='t_ref (ms):', continuous_update=False)
sim_time = widgets.FloatSlider(value=10000.0, min=1.0, max=60000.0, step=1.0,
                               description='Sim Time (ms):', continuous_update=False)
weight_self = widgets.FloatSlider(value=5242.0, min=0.0, max=100000.0, step=0.1,
                                  description='Weight Self:', continuous_update=False)
weight_exci = widgets.FloatSlider(value=566.0, min=0.0, max=100000.0, step=0.1,
                                  description='Weight Exci:', continuous_update=False)
weight_inh = widgets.FloatSlider(value=-3.0, min=-100000.0, max=0.0, step=0.1,
                                 description='Weight Inh:', continuous_update=False)
# Time window during which the DC generator is switched on
I_e_start = widgets.FloatSlider(value=0.0, min=0.0, max=60000.0, step=1,
                                description='I_e_start:', continuous_update=False)
I_e_stop = widgets.FloatSlider(value=20.0, min=0.0, max=60000, step=1,
                               description='I_e_stop:', continuous_update=False)

# The trailing semicolon keeps the returned function object from being echoed
# as a "<function ...>" line below the widget.
widgets.interact(
    three_neurons_network, 
    I_e=I_e, 
    t_ref=t_ref, 
    sim_time=sim_time, 
    weight_self=weight_self, 
    weight_exci=weight_exci, 
    weight_inh=weight_inh, 
    I_e_start=I_e_start, 
    I_e_stop=I_e_stop
);


interactive(children=(FloatSlider(value=9200.0, description='I_e (pA):', max=100000.0, step=0.01), FloatSlider…

<function __main__.three_neurons_network(I_e, t_ref, sim_time, weight_self, weight_exci, weight_inh, I_e_start, I_e_stop)>