In [None]:
import matplotlib.pyplot as plt
import nest
import numpy as np
import os
from pynestml.frontend.pynestml_frontend import generate_nest_target

In [None]:
def generate_code(neuron_model: str, models_path=""):
    """
    Compile a NESTML neuron model into a NEST extension module with gap-junction support.

    Parameters
    ----------
    neuron_model : str
        Base name of the model; the file ``<neuron_model>.nestml`` is compiled.
    models_path : str, optional
        Directory holding the ``.nestml`` file. Defaults to ``""``
        (i.e. the current working directory).

    Returns
    -------
    str
        ``neuron_model``, unchanged.

    Notes
    -----
    The generated module is named ``nestml_gap_<neuron_model>_module`` and the
    compiled model carries the ``_nestml`` suffix. Gap-junction currents are
    routed into the ``I_stim`` input port and coupled via the ``V_m`` state.
    """
    gap_junction_settings = {
        "enable": True,
        "gap_current_port": "I_stim",
        "membrane_potential_variable": "V_m",
    }
    model_file = os.path.join(models_path, neuron_model + ".nestml")
    target_module = "nestml_gap_" + neuron_model + "_module"

    generate_nest_target(
        input_path=model_file,
        logging_level="WARNING",
        module_name=target_module,
        suffix="_nestml",
        codegen_opts={"gap_junctions": gap_junction_settings},
    )
    return neuron_model

In [None]:
def initialize_hh():
    """
    Load the NESTML HH module and create four ``hh_psc_alpha_neuron_nestml`` neurons.

    Every neuron gets ``I_e = 20.0``. The first and third neurons start at
    ``V_m = -60.0``; the second and fourth keep the model's default initial V_m.

    Returns
    -------
    The node collection returned by ``nest.Create`` (four neurons).
    """
    nest.Install("nestml_gap_hh_psc_alpha_neuron_module")
    hh_cells = nest.Create("hh_psc_alpha_neuron_nestml", 4)
    hh_cells.I_e = 20.0
    for idx in (0, 2):
        hh_cells[idx].V_m = -60.0
    return hh_cells

In [None]:
def initialize_aeif():
    """
    Load the NESTML AdEx module and create four ``aeif_cond_exp_neuron_nestml`` neurons.

    Every neuron gets ``I_e = 20``. The first and third neurons start at
    ``V_m = -60.0``; the second and fourth keep the model's default initial V_m.

    Returns
    -------
    The node collection returned by ``nest.Create`` (four neurons).
    """
    nest.Install("nestml_gap_aeif_cond_exp_neuron_module")
    aeif_cells = nest.Create("aeif_cond_exp_neuron_nestml", 4)
    aeif_cells.I_e = 20
    for idx in (0, 2):
        aeif_cells[idx].V_m = -60.0
    return aeif_cells

In [None]:
def initialize_eglif():
    """
    Load the NESTML E-GLIF module and create four ``eglif_cond_alpha_multisyn_nestml`` neurons.

    All four neurons receive the same fixed parameter set (including
    ``V_m = -45``). The first and third neurons are then moved to
    ``V_m = -46.0``, so they start from a different initial state than the
    second and fourth.

    Returns
    -------
    The node collection returned by ``nest.Create`` (four neurons).
    """
    nest.Install("nestml_gap_eglif_cond_alpha_multisyn_module")
    eglif_cells = nest.Create("eglif_cond_alpha_multisyn_nestml", 4)
    eglif_cells.set({
        "V_m": -45,
        "E_L": -45,
        "C_m": 189,
        "tau_m": 11,
        "I_e": -18.101,
        "k_adap": 1.928,
        "k_1": 0.191,
        "k_2": 0.090909,
        "V_th": -35,
    })
    eglif_cells[0].V_m = -46.0
    eglif_cells[2].V_m = -46.0

    return eglif_cells


# Backward-compatible alias for the original (misspelled) name; initialize_cells()
# and possibly other cells still call it.
initilize_eglif = initialize_eglif

In [None]:
def initialize_cells(model):
    """
    Create the four-neuron population for the requested model family.

    Parameters
    ----------
    model : str
        One of ``"hh"``, ``"aeif"`` or ``"eglif"``.

    Returns
    -------
    The node collection produced by the matching ``initialize_*`` function.

    Raises
    ------
    ValueError
        If ``model`` is not one of the supported keys.
    """
    if model == "hh":
        return initialize_hh()
    if model == "aeif":
        return initialize_aeif()
    if model == "eglif":
        return initilize_eglif()
    # Previously this fell through and returned None, which only failed later
    # (with a confusing error) when the result was passed to nest.Connect.
    raise ValueError(
        f"Unknown model {model!r}; expected one of 'hh', 'aeif', 'eglif'."
    )
        

In [None]:
def plot_vm(vm):
    """
    Plot the membrane-potential trace of every neuron recorded by a voltmeter.

    Parameters
    ----------
    vm : NEST recording device
        Device whose ``events`` dict contains ``"V_m"``, ``"senders"`` and
        ``"times"`` entries (one sample per row, tagged by sender id).
    """
    # Read the events once; each attribute access may query the NEST kernel.
    events = vm.events
    vm_values = np.asarray(events["V_m"])
    senders = np.asarray(events["senders"])
    times = np.asarray(events["times"])

    fig, ax = plt.subplots(figsize=(10, 5))
    for cell_num in np.unique(senders):
        mask = senders == cell_num
        ax.plot(times[mask], vm_values[mask], label=f"Neuron {cell_num}")
    ax.legend(loc="upper right")
    ax.set_xlabel("time (ms)")
    ax.set_ylabel("membrane potential (mV)")
    ax.set_title("Membrane potentials")
    plt.show()

In [None]:
def plot_mm_stim(mm_stim):
    """
    Plot the recorded ``I_stim_recordable`` current for every neuron.

    Parameters
    ----------
    mm_stim : NEST multimeter
        Multimeter recording ``I_stim_recordable``; its ``events`` dict must
        hold ``"I_stim_recordable"``, ``"senders"`` and ``"times"``.
    """
    recorded = mm_stim.events
    currents = recorded["I_stim_recordable"]
    sender_ids = recorded["senders"]
    sample_times = recorded["times"]

    plt.figure(figsize=(10, 5))
    for sender in np.unique(sender_ids):
        rows = np.where(sender_ids == sender)
        plt.plot(sample_times[rows], currents[rows], label=f"Neuron {sender}")

    plt.legend(loc="upper right")
    plt.xlabel("time (ms)")
    plt.ylabel("I_stim (pA)")
    plt.title("Gap junction currents")
    plt.show()

In [None]:
def simulate_network(selected_model, dc_stim=False, sim_time=1000.0, gap_weight=100):
    """
    Build, simulate and analyse a four-neuron network where neurons 1 and 2 share a gap junction.

    Neurons 3 and 4 receive no gap junction; at the end the membrane traces of
    cell 1 vs. cell 3 and cell 2 vs. cell 4 are compared for exact equality.

    Parameters
    ----------
    selected_model : str
        Model key understood by ``initialize_cells`` (``"hh"``, ``"aeif"``, ``"eglif"``).
    dc_stim : bool, optional
        If True, neuron 1 additionally receives a 0.5 pA DC current.
    sim_time : float, optional
        Simulated time in ms (default 1000.0, the previous hard-coded value).
    gap_weight : float, optional
        Gap-junction weight (default 100, the previous hard-coded value).
    """
    nest.ResetKernel()
    nest.resolution = 0.05

    neurons = initialize_cells(selected_model)
    # Real NEST node ids in creation order, instead of assuming they are 1..4.
    neuron_ids = neurons.tolist()

    # Voltmeter connected to all neurons
    vm = nest.Create("voltmeter", params={"interval": 0.1})
    nest.Connect(vm, neurons, "all_to_all")

    # Optional DC stimulation for neuron 1
    if dc_stim:
        dc = nest.Create("dc_generator", params={"amplitude": 0.5})
        nest.Connect(dc, neurons[0], syn_spec={"weight": 1.0})

    # Record I_stim_recordable only if the model exposes it. The model name is
    # taken from the created nodes rather than the module-level `models` dict,
    # which is defined in a later cell.
    model_name = neurons[0].get("model")
    has_i_stim = "I_stim_recordable" in nest.GetDefaults(model_name)["recordables"]
    mm_stim = None
    if has_i_stim:
        mm_stim = nest.Create("multimeter", {"record_from": ["I_stim_recordable"]})
        nest.Connect(mm_stim, neurons)

    # Gap junction between neurons 1 and 2; NEST requires symmetric one_to_one connections.
    nest.Connect(
        neurons[0], neurons[1],
        {"rule": "one_to_one", "allow_autapses": False, "make_symmetric": True},
        {"synapse_model": "gap_junction", "weight": gap_weight},
    )

    # Verify connections
    for connection in nest.GetConnections():
        print(connection)

    nest.Simulate(sim_time)

    plot_vm(vm)
    if has_i_stim:
        plot_mm_stim(mm_stim)

    # Split the voltmeter samples per neuron and compare the paired traces.
    events = vm.events
    vm_values = np.asarray(events["V_m"])
    senders = np.asarray(events["senders"])
    vm_per_cell = {sender: vm_values[senders == sender] for sender in np.unique(senders)}

    cell1, cell2, cell3, cell4 = neuron_ids
    print(f"V_m of cell 1 and cell 3 are equal: {np.array_equal(vm_per_cell[cell1], vm_per_cell[cell3])}")
    print(f"V_m of cell 2 and cell 4 are equal: {np.array_equal(vm_per_cell[cell2], vm_per_cell[cell4])}")
    

In [None]:
# NESTML source file stem (without ".nestml") for each model key.
filenames = {
    "hh": "hh_psc_alpha_neuron",
    "eglif": "eglif_cond_alpha_multisyn",
    "aeif": "aeif_cond_exp_neuron",
}
# Compiled NEST model names: generate_code() appends the "_nestml" suffix.
models = {key: stem + "_nestml" for key, stem in filenames.items()}

selected_model = "hh"   # one of the keys above
generate_model = False  # set True to (re)compile the selected model before simulating

if generate_model:
    generate_code(neuron_model=filenames[selected_model], models_path="../nest_models")

simulate_network(selected_model, dc_stim=True)