In [None]:
import matplotlib.pyplot as plt
import nest
import numpy
import os
from pynestml.frontend.pynestml_frontend import generate_nest_target

In [None]:
def generate_code(neuron_model: str, models_path: str = "",
                  gap_current_port: str = "I_stim",
                  membrane_potential_variable: str = "V_m") -> str:
    """
    Generate NEST code for a NESTML neuron model with gap junction support.

    The model is read from ``<models_path>/<neuron_model>.nestml`` and handed to
    ``generate_nest_target`` with module name ``nestml_gap_<neuron_model>_module``
    and suffix ``_nestml``.

    Parameters
    ----------
    neuron_model : str
        Name of the neuron model to compile. This should correspond to a
        ``<neuron_model>.nestml`` file containing the neuron model definition.
    models_path : str, optional
        Path to the directory containing the NESTML model files.
        Default is the empty string (paths resolve relative to the current
        working directory).
    gap_current_port : str, optional
        Value passed as ``gap_current_port`` in the gap-junction code
        generation options. Default is ``"I_stim"``.
    membrane_potential_variable : str, optional
        Value passed as ``membrane_potential_variable`` in the gap-junction
        code generation options. Default is ``"V_m"``.

    Returns
    -------
    str
        ``neuron_model``, unchanged.
    """
    codegen_opts = {
        "gap_junctions": {
            "enable": True,
            "gap_current_port": gap_current_port,
            "membrane_potential_variable": membrane_potential_variable,
        }
    }

    input_path = os.path.join(models_path, neuron_model + ".nestml")
    generate_nest_target(input_path=input_path,
                         logging_level="WARNING",
                         module_name="nestml_gap_" + neuron_model + "_module",
                         suffix="_nestml",
                         codegen_opts=codegen_opts)

    return neuron_model

# Compile ../nest_models/aeif_cond_exp_neuron.nestml with gap junction support.
model_name = "aeif_cond_exp_neuron"
generate_code(neuron_model=model_name, models_path="../nest_models")

In [None]:

# Start from a clean kernel.
nest.ResetKernel()

# Load the extension module produced by generate_code(); the name follows the
# "nestml_gap_<model>_module" pattern used there.
gap_module_name = "nestml_gap_aeif_cond_exp_neuron_module"
nest.Install(gap_module_name)

# Simulation resolution (time step).
nest.resolution = 0.05

In [None]:


# Two neurons sharing the same I_e; only the first starts from a different
# membrane potential, so the two traces begin apart.
neuron = nest.Create("aeif_cond_exp_neuron_nestml", 2, params={"I_e": 650.0})
neuron[0].V_m = -10.0

In [None]:
# Record V_m of both neurons every 0.1 ms.
voltmeter_params = {"interval": 0.1}
vm = nest.Create("voltmeter", params=voltmeter_params)
nest.Connect(vm, neuron, conn_spec="all_to_all")

In [None]:
# Toggle gap-junction coupling between the two neurons; the flag is also used
# when naming the saved figure.
with_gaps = True

if with_gaps:
    # all_to_all on the same population without autapses links each neuron to
    # the other one in both directions.
    gap_conn_spec = {"rule": "all_to_all", "allow_autapses": False}
    gap_syn_spec = {"synapse_model": "gap_junction", "weight": 5}
    nest.Connect(neuron, neuron, gap_conn_spec, gap_syn_spec)

In [None]:
nest.Simulate(5000.0)

# Read the recorder once instead of querying `vm.events` for every field.
events = vm.events
senders = events["senders"]
times = events["times"]
v_m_values = events["V_m"]

# Select each neuron's samples by its actual node ID rather than assuming the
# IDs are 1 and 2 (true only when no node is created before `neuron`).
# NOTE(review): the -65.0 in the second label is not set anywhere in this
# notebook; confirm it matches the model's initial V_m.
trace_styles = [("r-", "Neuron 1 (V_m = -10.0)"),
                ("g-", "Neuron 2 (V_m = -65.0)")]

fig, ax = plt.subplots(figsize=(10, 5))
for node_id, (fmt, label) in zip(neuron.tolist(), trace_styles):
    mask = senders == node_id
    ax.plot(times[mask], v_m_values[mask], fmt, label=label)
ax.legend(loc='upper right')
ax.set_xlabel("time (ms)")
ax.set_ylabel("membrane potential (mV)")
ax.set_title(f"aeif_cond_exp_neuron_nestml, {'with' if with_gaps else 'without'} gap junctions")

# savefig does not create missing directories; ensure the target folder exists
# so the notebook survives Restart & Run All in a fresh checkout.
os.makedirs("images", exist_ok=True)
fig.savefig(f"images/aeif_cond_exp_{'with' if with_gaps else 'without'}")
plt.show()
