diff --git a/src/codegen/codegen_neuron_cpp_visitor.cpp b/src/codegen/codegen_neuron_cpp_visitor.cpp index c013f72df..ba9da2372 100644 --- a/src/codegen/codegen_neuron_cpp_visitor.cpp +++ b/src/codegen/codegen_neuron_cpp_visitor.cpp @@ -181,6 +181,11 @@ void CodegenNeuronCppVisitor::print_check_table_function_prototypes() { printer->push_block(); printer->add_line("_nrn_mechanism_cache_range _lmc{_sorted_token, *_nt, *_ml, _type};"); printer->fmt_line("auto inst = make_instance_{}(_lmc);", info.mod_suffix); + if (!codegen_thread_variables.empty()) { + printer->fmt_line("auto _thread_vars = {}(_thread[{}].get());", + thread_variables_struct(), + info.thread_var_thread_id); + } for (const auto& function: info.functions_with_table) { auto method_name_str = function->get_node_name(); @@ -388,6 +393,11 @@ void CodegenNeuronCppVisitor::print_hoc_py_wrapper_function_body( )CODE"); } printer->fmt_line("auto inst = make_instance_{}(_lmc);", info.mod_suffix); + if (!codegen_thread_variables.empty()) { + printer->fmt_line("auto _thread_vars = {}(_thread[{}].get());", + thread_variables_struct(), + info.thread_var_thread_id); + } if (info.function_uses_table(block_name)) { printer->fmt_line("{}{}({});", table_function_prefix(), @@ -458,12 +468,16 @@ std::string CodegenNeuronCppVisitor::internal_method_arguments() { CodegenCppVisitor::ParamVector CodegenNeuronCppVisitor::internal_method_parameters() { - ParamVector params = {{"", "_nrn_mechanism_cache_range&", "", "_lmc"}, - {"", fmt::format("{}&", instance_struct()), "", "inst"}, - {"", "size_t", "", "id"}, - {"", "Datum*", "", "_ppvar"}, - {"", "Datum*", "", "_thread"}, - {"", "NrnThread*", "", "_nt"}}; + ParamVector params; + params.emplace_back("", "_nrn_mechanism_cache_range&", "", "_lmc"); + params.emplace_back("", fmt::format("{}&", instance_struct()), "", "inst"); + params.emplace_back("", "size_t", "", "id"); + params.emplace_back("", "Datum*", "", "_ppvar"); + params.emplace_back("", "Datum*", "", "_thread"); + if (!codegen_thread_variables.empty()) { + params.emplace_back("", fmt::format("{}&", thread_variables_struct()), "", "_thread_vars"); + } + params.emplace_back("", "NrnThread*", "", "_nt"); return params; } diff --git a/test/usecases/global/simulate.py b/test/usecases/global/simulate.py index d8bf22a6e..8de8df686 100644 --- a/test/usecases/global/simulate.py +++ b/test/usecases/global/simulate.py @@ -68,7 +68,7 @@ assert np.all(y0[i0] == g_w_init) assert np.all(y0[i1] == g_w1) assert np.all(y0[i2] == 33.3) -assert np.all(y0[i3] == z0) +assert np.all(y0[i3] == z0 + z0**2) # The values on thread 1: assert y1[i0] == g_w_init @@ -76,4 +76,4 @@ # `g_w` from Python. # assert np.all(y0[i1] == g_w1) assert np.all(y1[i2] == 34.3) -assert np.all(y1[i3] == z1) +assert np.all(y1[i3] == z1 + z1**2) diff --git a/test/usecases/global/thread_variable.mod b/test/usecases/global/thread_variable.mod index e5da59305..19478428a 100644 --- a/test/usecases/global/thread_variable.mod +++ b/test/usecases/global/thread_variable.mod @@ -2,7 +2,7 @@ NEURON { SUFFIX shared_global NONSPECIFIC_CURRENT il RANGE y, z - GLOBAL g_w, g_arr + GLOBAL g_v1, g_w, g_arr THREADSAFE } @@ -16,6 +16,7 @@ ASSIGNED { INITIAL { g_w = 48.0 + g_v1 = 0.0 g_arr[0] = 10.0 + z g_arr[1] = 10.1 g_arr[2] = 10.2 @@ -24,12 +25,27 @@ INITIAL { BREAKPOINT { if(t > 0.33) { - g_w = g_arr[0] + g_arr[1] + g_arr[2] + g_w = sum_arr() } if(t > 0.66) { - g_w = z + set_g_w(z) + compute_g_v1(z) } - y = g_w + + y = g_w + g_v1 il = 0.0000001 * (v - 10.0) } + +FUNCTION sum_arr() { + sum_arr = g_arr[0] + g_arr[1] + g_arr[2] +} + +PROCEDURE set_g_w(zz) { + g_w = zz +} + +PROCEDURE compute_g_v1(zz) { + TABLE g_v1 FROM 3 TO 4 WITH 8 + g_v1 = zz * zz +}