Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix bug in TABLE for POINT_PROCESS. #1342

Merged
merged 2 commits into from
Jul 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/codegen/codegen_coreneuron_cpp_visitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -418,7 +418,7 @@ void CodegenCoreneuronCppVisitor::print_check_table_thread_function() {
printer->add_line("double v = 0;");

for (const auto& function: info.functions_with_table) {
auto method_name_str = method_name(table_function_prefix() + function->get_node_name());
auto method_name_str = table_update_function_name(function->get_node_name());
auto arguments = internal_method_arguments();
printer->fmt_line("{}({});", method_name_str, arguments);
}
Expand Down
9 changes: 4 additions & 5 deletions src/codegen/codegen_cpp_visitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,8 @@ bool CodegenCppVisitor::has_parameter_of_name(const T& node, const std::string&
}


std::string CodegenCppVisitor::table_function_prefix() const {
return "lazy_update_";
std::string CodegenCppVisitor::table_update_function_name(const std::string& block_name) const {
return "update_table_" + method_name(block_name);
}


Expand Down Expand Up @@ -1524,9 +1524,8 @@ void CodegenCppVisitor::print_table_check_function(const Block& node) {
auto float_type = default_float_data_type();

printer->add_newline(2);
printer->fmt_push_block("void {}{}({})",
table_function_prefix(),
method_name(name),
printer->fmt_push_block("void {}({})",
table_update_function_name(name),
get_parameter_str(internal_params));
{
printer->fmt_push_block("if ({} == 0)", use_table_var);
Expand Down
8 changes: 5 additions & 3 deletions src/codegen/codegen_cpp_visitor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1112,10 +1112,12 @@ class CodegenCppVisitor: public visitor::ConstAstVisitor {
bool use_instance = true) const = 0;

/**
* Prefix used for the function that performs the lazy update
* The name of the function that updates the table value if the parameters
* changed.
*
* \param block_name The name of the block that contains the TABLE.
*/
std::string table_function_prefix() const;

std::string table_update_function_name(const std::string& block_name) const;

/**
* Return ion variable name and corresponding ion read variable name.
Expand Down
17 changes: 7 additions & 10 deletions src/codegen/codegen_neuron_cpp_visitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -164,9 +164,8 @@
for (const auto& function: info.functions_with_table) {
auto name = function->get_node_name();
auto internal_params = internal_method_parameters();
printer->fmt_line("void {}{}({});",
table_function_prefix(),
method_name(name),
printer->fmt_line("void {}({});",
table_update_function_name(name),

Check warning on line 168 in src/codegen/codegen_neuron_cpp_visitor.cpp

View check run for this annotation

Codecov / codecov/patch

src/codegen/codegen_neuron_cpp_visitor.cpp#L167-L168

Added lines #L167 - L168 were not covered by tests
get_parameter_str(internal_params));
}

Expand All @@ -192,10 +191,9 @@
}

for (const auto& function: info.functions_with_table) {
auto method_name_str = function->get_node_name();
auto method_args_str = get_arg_str(internal_method_parameters());
printer->fmt_line(
"{}{}{}({});", table_function_prefix(), method_name_str, info.rsuffix, method_args_str);
auto method_name = function->get_node_name();
auto method_args = get_arg_str(internal_method_parameters());
printer->fmt_line("{}({});", table_update_function_name(method_name), method_args);

Check warning on line 196 in src/codegen/codegen_neuron_cpp_visitor.cpp

View check run for this annotation

Codecov / codecov/patch

src/codegen/codegen_neuron_cpp_visitor.cpp#L194-L196

Added lines #L194 - L196 were not covered by tests
}
printer->pop_block();
}
Expand Down Expand Up @@ -403,9 +401,8 @@
info.thread_var_thread_id);
}
if (info.function_uses_table(block_name)) {
printer->fmt_line("{}{}({});",
table_function_prefix(),
method_name(block_name),
printer->fmt_line("{}({});",
table_update_function_name(block_name),

Check warning on line 405 in src/codegen/codegen_neuron_cpp_visitor.cpp

View check run for this annotation

Codecov / codecov/patch

src/codegen/codegen_neuron_cpp_visitor.cpp#L404-L405

Added lines #L404 - L405 were not covered by tests
internal_method_arguments());
}
const auto get_func_call_str = [&]() {
Expand Down
53 changes: 36 additions & 17 deletions test/usecases/table/simulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,14 @@ def check_solution(y_no_table, y_table, rtol):
), f"{y_no_table} == {y_table}"


def check_table(c1, c2, x, evaluate_table):
h.c1_tbl = 1
h.c2_tbl = 2
def check_table(c1, c2, x, mech_name, evaluate_table):
setattr(h, f"c1_{mech_name}", 1)
setattr(h, f"c2_{mech_name}", 2)

h.usetable_tbl = 0
setattr(h, f"usetable_{mech_name}", 0)
y_no_table = np.array([evaluate_table(i) for i in x])

h.usetable_tbl = 1
setattr(h, f"usetable_{mech_name}", 1)
y_table = np.array([evaluate_table(i) for i in x])

check_solution(y_table, y_no_table, rtol=1e-4)
Expand All @@ -31,34 +31,52 @@ def check_table(c1, c2, x, evaluate_table):
assert np.all(evaluate_table(x[-1] + 10) == y_table[-1])


def test_function():
def check_function(mech_name, make_instance):
s = h.Section()
s.insert("tbl")
obj = make_instance(s)

x = np.linspace(-3, 5, 18)
assert x[0] == -3.0
assert x[-1] == 5.0

check_table(1, 2, x, s(0.5).tbl.quadratic)
check_table(2, 2, x, s(0.5).tbl.quadratic)
check_table(2, 3, x, s(0.5).tbl.quadratic)
check_table(1, 2, x, mech_name, obj.quadratic)
check_table(2, 2, x, mech_name, obj.quadratic)
check_table(2, 3, x, mech_name, obj.quadratic)


def test_procedure():
s = h.Section()
def make_density_instance(s):
s.insert("tbl")
return s(0.5).tbl


def make_point_instance(s):
return h.tbl_point_process(s(0.5))


def test_function():
check_function("tbl", make_density_instance)
check_function("tbl_point_process", make_point_instance)


def check_procedure(mech_name, make_instance):
s = h.Section()
obj = make_instance(s)

def evaluate_table(x):
s(0.5).tbl.sinusoidal(x)
return np.array((s(0.5).tbl.v1, s(0.5).tbl.v2))
obj.sinusoidal(x)
return np.array((obj.v1, obj.v2))

x = np.linspace(-4, 6, 18)
assert x[0] == -4.0
assert x[-1] == 6.0

check_table(1, 2, x, evaluate_table)
check_table(2, 2, x, evaluate_table)
check_table(2, 3, x, evaluate_table)
check_table(1, 2, x, mech_name, evaluate_table)
check_table(2, 2, x, mech_name, evaluate_table)
check_table(2, 3, x, mech_name, evaluate_table)


def test_procedure():
check_procedure("tbl", make_density_instance)


def simulate():
Expand All @@ -83,4 +101,5 @@ def simulate():
if __name__ == "__main__":
test_function()
test_procedure()

simulate()
44 changes: 44 additions & 0 deletions test/usecases/table/table_point_process.mod
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
NEURON {
POINT_PROCESS tbl_point_process
NONSPECIFIC_CURRENT i
RANGE g, v1, v2
GLOBAL k, d, c1, c2
}

PARAMETER {
k = .1
d = -50
c1 = 1
c2 = 2
}

ASSIGNED {
g
i
v
sig
v1
v2
}

BREAKPOINT {
sigmoidal(v)
g = 0.001 * sig
i = g*(v - 30.0)
}

PROCEDURE sigmoidal(v) {
TABLE sig DEPEND k, d FROM -127 TO 128 WITH 155
sig = 1/(1 + exp(k*(v - d)))
}

FUNCTION quadratic(x) {
TABLE DEPEND c1, c2 FROM -3 TO 5 WITH 500
quadratic = c1 * x * x + c2
}

PROCEDURE sinusoidal(x) {
TABLE v1, v2 DEPEND c1, c2 FROM -4 TO 6 WITH 800
v1 = sin(c1 * x) + 2
v2 = cos(c2 * x) + 2
}
Loading