Skip to content

Commit

Permalink
Fix thread_variables, in functions and tables.
Browse files Browse the repository at this point in the history
  • Loading branch information
1uc committed Jun 12, 2024
1 parent 13550d7 commit 55cea68
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 12 deletions.
26 changes: 20 additions & 6 deletions src/codegen/codegen_neuron_cpp_visitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,11 @@ void CodegenNeuronCppVisitor::print_check_table_function_prototypes() {
printer->push_block();
printer->add_line("_nrn_mechanism_cache_range _lmc{_sorted_token, *_nt, *_ml, _type};");
printer->fmt_line("auto inst = make_instance_{}(_lmc);", info.mod_suffix);
if (!codegen_thread_variables.empty()) {
printer->fmt_line("auto _thread_vars = {}(_thread[{}].get<double*>());",
thread_variables_struct(),
info.thread_var_thread_id);

Check warning on line 187 in src/codegen/codegen_neuron_cpp_visitor.cpp

View check run for this annotation

Codecov / codecov/patch

src/codegen/codegen_neuron_cpp_visitor.cpp#L184-L187

Added lines #L184 - L187 were not covered by tests
}

for (const auto& function: info.functions_with_table) {
auto method_name_str = function->get_node_name();
Expand Down Expand Up @@ -388,6 +393,11 @@ void CodegenNeuronCppVisitor::print_hoc_py_wrapper_function_body(
)CODE");
}
printer->fmt_line("auto inst = make_instance_{}(_lmc);", info.mod_suffix);
if (!codegen_thread_variables.empty()) {
printer->fmt_line("auto _thread_vars = {}(_thread[{}].get<double*>());",
thread_variables_struct(),
info.thread_var_thread_id);

Check warning on line 399 in src/codegen/codegen_neuron_cpp_visitor.cpp

View check run for this annotation

Codecov / codecov/patch

src/codegen/codegen_neuron_cpp_visitor.cpp#L397-L399

Added lines #L397 - L399 were not covered by tests
}
if (info.function_uses_table(block_name)) {
printer->fmt_line("{}{}({});",
table_function_prefix(),
Expand Down Expand Up @@ -458,12 +468,16 @@ std::string CodegenNeuronCppVisitor::internal_method_arguments() {


CodegenCppVisitor::ParamVector CodegenNeuronCppVisitor::internal_method_parameters() {
ParamVector params = {{"", "_nrn_mechanism_cache_range&", "", "_lmc"},
{"", fmt::format("{}&", instance_struct()), "", "inst"},
{"", "size_t", "", "id"},
{"", "Datum*", "", "_ppvar"},
{"", "Datum*", "", "_thread"},
{"", "NrnThread*", "", "_nt"}};
ParamVector params;
params.emplace_back("", "_nrn_mechanism_cache_range&", "", "_lmc");
params.emplace_back("", fmt::format("{}&", instance_struct()), "", "inst");
params.emplace_back("", "size_t", "", "id");
params.emplace_back("", "Datum*", "", "_ppvar");
params.emplace_back("", "Datum*", "", "_thread");
if (!codegen_thread_variables.empty()) {
params.emplace_back("", fmt::format("{}&", thread_variables_struct()), "", "_thread_vars");

Check warning on line 478 in src/codegen/codegen_neuron_cpp_visitor.cpp

View check run for this annotation

Codecov / codecov/patch

src/codegen/codegen_neuron_cpp_visitor.cpp#L478

Added line #L478 was not covered by tests
}
params.emplace_back("", "NrnThread*", "", "_nt");
return params;
}

Expand Down
4 changes: 2 additions & 2 deletions test/usecases/global/simulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,12 +68,12 @@
assert np.all(y0[i0] == g_w_init)
assert np.all(y0[i1] == g_w1)
assert np.all(y0[i2] == 33.3)
assert np.all(y0[i3] == z0)
assert np.all(y0[i3] == z0 + z0**2)

# The values on thread 1:
assert y1[i0] == g_w_init
# The value of `g_w` on thread 1 is not specified, after setting
# `g_w` from Python.
# assert np.all(y0[i1] == g_w1)
assert np.all(y1[i2] == 34.3)
assert np.all(y1[i3] == z1)
assert np.all(y1[i3] == z1 + z1**2)
24 changes: 20 additions & 4 deletions test/usecases/global/thread_variable.mod
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ NEURON {
SUFFIX shared_global
NONSPECIFIC_CURRENT il
RANGE y, z
GLOBAL g_w, g_arr
GLOBAL g_v1, g_w, g_arr
THREADSAFE
}

Expand All @@ -16,6 +16,7 @@ ASSIGNED {

INITIAL {
g_w = 48.0
g_v1 = 0.0
g_arr[0] = 10.0 + z
g_arr[1] = 10.1
g_arr[2] = 10.2
Expand All @@ -24,12 +25,27 @@ INITIAL {

BREAKPOINT {
if(t > 0.33) {
g_w = g_arr[0] + g_arr[1] + g_arr[2]
g_w = sum_arr()
}

if(t > 0.66) {
g_w = z
set_g_w(z)
compute_g_v1(z)
}
y = g_w

y = g_w + g_v1
il = 0.0000001 * (v - 10.0)
}

FUNCTION sum_arr() {
sum_arr = g_arr[0] + g_arr[1] + g_arr[2]
}

PROCEDURE set_g_w(zz) {
g_w = zz
}

PROCEDURE compute_g_v1(zz) {
TABLE g_v1 FROM 3 TO 4 WITH 8
g_v1 = zz * zz
}

0 comments on commit 55cea68

Please sign in to comment.