Skip to content

Commit

Permalink
Use ParamVector in codegen (#1289)
Browse files Browse the repository at this point in the history
  • Loading branch information
JCGoran committed Jun 4, 2024
1 parent 751377d commit 3754a71
Show file tree
Hide file tree
Showing 6 changed files with 106 additions and 86 deletions.
66 changes: 34 additions & 32 deletions src/codegen/codegen_coreneuron_cpp_visitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -408,7 +408,7 @@ void CodegenCoreneuronCppVisitor::print_check_table_thread_function() {

printer->add_newline(2);
auto name = method_name("check_table_thread");
auto parameters = external_method_parameters(true);
auto parameters = get_parameter_str(external_method_parameters(true));

printer->fmt_push_block("static void {} ({})", name, parameters);
printer->add_line("setup_instance(nt, ml);");
Expand Down Expand Up @@ -695,45 +695,50 @@ void CodegenCoreneuronCppVisitor::add_variable_point_process(
}

std::string CodegenCoreneuronCppVisitor::internal_method_arguments() {
if (ion_variable_struct_required()) {
return "id, pnodecount, inst, ionvar, data, indexes, thread, nt, v";
}
return "id, pnodecount, inst, data, indexes, thread, nt, v";
return get_arg_str(internal_method_parameters());
}


/**
* @todo: figure out how to correctly handle qualifiers
*/
CodegenCoreneuronCppVisitor::ParamVector CodegenCoreneuronCppVisitor::internal_method_parameters() {
ParamVector params;
params.emplace_back("", "int", "", "id");
params.emplace_back("", "int", "", "pnodecount");
params.emplace_back("", fmt::format("{}*", instance_struct()), "", "inst");
ParamVector params = {{"", "int", "", "id"},
{"", "int", "", "pnodecount"},
{"", fmt::format("{}*", instance_struct()), "", "inst"}};
if (ion_variable_struct_required()) {
params.emplace_back("", "IonCurVar&", "", "ionvar");
}
params.emplace_back("", "double*", "", "data");
params.emplace_back("const ", "Datum*", "", "indexes");
params.emplace_back("", "ThreadDatum*", "", "thread");
params.emplace_back("", "NrnThread*", "", "nt");
params.emplace_back("", "double", "", "v");
ParamVector other_params = {{"", "double*", "", "data"},
{"const ", "Datum*", "", "indexes"},
{"", "ThreadDatum*", "", "thread"},
{"", "NrnThread*", "", "nt"},
{"", "double", "", "v"}};
params.insert(params.end(), other_params.begin(), other_params.end());
return params;
}


const char* CodegenCoreneuronCppVisitor::external_method_arguments() noexcept {
return "id, pnodecount, data, indexes, thread, nt, ml, v";
const std::string CodegenCoreneuronCppVisitor::external_method_arguments() noexcept {
return get_arg_str(external_method_parameters());
}


const char* CodegenCoreneuronCppVisitor::external_method_parameters(bool table) noexcept {
const CodegenCppVisitor::ParamVector CodegenCoreneuronCppVisitor::external_method_parameters(
bool table) noexcept {
ParamVector args = {{"", "int", "", "id"},
{"", "int", "", "pnodecount"},
{"", "double*", "", "data"},
{"", "Datum*", "", "indexes"},
{"", "ThreadDatum*", "", "thread"},
{"", "NrnThread*", "", "nt"},
{"", "Memb_list*", "", "ml"}};
if (table) {
return "int id, int pnodecount, double* data, Datum* indexes, "
"ThreadDatum* thread, NrnThread* nt, Memb_list* ml, int tml_id";
args.emplace_back("", "int", "", "tml_id");
} else {
args.emplace_back("", "double", "", "v");
}
return "int id, int pnodecount, double* data, Datum* indexes, "
"ThreadDatum* thread, NrnThread* nt, Memb_list* ml, double v";
return args;
}


Expand All @@ -750,10 +755,7 @@ std::string CodegenCoreneuronCppVisitor::nrn_thread_arguments() const {
* same mod file itself
*/
std::string CodegenCoreneuronCppVisitor::nrn_thread_internal_arguments() {
if (ion_variable_struct_required()) {
return "id, pnodecount, inst, ionvar, data, indexes, thread, nt, v";
}
return "id, pnodecount, inst, data, indexes, thread, nt, v";
return get_arg_str(internal_method_parameters());
}


Expand All @@ -779,7 +781,7 @@ std::string CodegenCoreneuronCppVisitor::replace_if_verbatim_variable(std::strin
}
}
if (name == naming::THREAD_ARGS_PROTO) {
name = external_method_parameters();
name = get_parameter_str(external_method_parameters());
}
return name;
}
Expand Down Expand Up @@ -2750,10 +2752,10 @@ void CodegenCoreneuronCppVisitor::print_net_receive() {
printing_net_receive = true;
if (!info.artificial_cell) {
const auto& name = method_name("net_receive");
ParamVector params;
params.emplace_back("", "Point_process*", "", "pnt");
params.emplace_back("", "int", "", "weight_index");
params.emplace_back("", "double", "", "flag");
ParamVector params = {
{"", "Point_process*", "", "pnt"},
{"", "int", "", "weight_index"},
{"", "double", "", "flag"}};
printer->add_newline(2);
printer->fmt_push_block("static void {}({})", name, get_parameter_str(params));
printer->add_line("NrnThread* nt = nrn_threads + pnt->_tid;");
Expand Down Expand Up @@ -2785,7 +2787,7 @@ void CodegenCoreneuronCppVisitor::print_net_receive() {
*/
void CodegenCoreneuronCppVisitor::print_derivimplicit_kernel(const Block& block) {
auto ext_args = external_method_arguments();
auto ext_params = external_method_parameters();
auto ext_params = get_parameter_str(external_method_parameters());
auto suffix = info.mod_suffix;
auto list_num = info.derivimplicit_list_num;
auto block_name = block.get_node_name();
Expand All @@ -2796,7 +2798,7 @@ void CodegenCoreneuronCppVisitor::print_derivimplicit_kernel(const Block& block)

printer->push_block("namespace");
printer->fmt_push_block("struct _newton_{}_{}", block_name, info.mod_suffix);
printer->fmt_push_block("int operator()({}) const", external_method_parameters());
printer->fmt_push_block("int operator()({}) const", get_parameter_str(external_method_parameters()));
auto const instance = fmt::format("auto* const inst = static_cast<{0}*>(ml->instance);",
instance_struct());
auto const slist1 = fmt::format("auto const& slist{} = {};",
Expand Down
6 changes: 3 additions & 3 deletions src/codegen/codegen_coreneuron_cpp_visitor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -433,7 +433,7 @@ class CodegenCoreneuronCppVisitor: public CodegenCppVisitor {
* Arguments for external functions called from generated code
* \return A string representing the arguments passed to an external function
*/
const char* external_method_arguments() noexcept override;
const std::string external_method_arguments() noexcept override;


/**
Expand All @@ -443,9 +443,9 @@ class CodegenCoreneuronCppVisitor: public CodegenCppVisitor {
* calling convention. This method generates the string representing the function parameters for
* these externally called functions.
* \param table
* \return A string representing the parameters of the function
* \return A ParamVector representing the parameters of the function
*/
const char* external_method_parameters(bool table = false) noexcept override;
const ParamVector external_method_parameters(bool table = false) noexcept override;


/**
Expand Down
28 changes: 15 additions & 13 deletions src/codegen/codegen_cpp_visitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,23 +34,25 @@ bool CodegenCppVisitor::ion_variable_struct_required() const {
return optimize_ion_variable_copies() && info.ion_has_write_variable();
}

std::string CodegenCppVisitor::get_arg_str(const ParamVector& params) {
std::vector<std::string> variables;
for (const auto& param: params) {
variables.push_back(std::get<3>(param));
}
return fmt::format("{}", fmt::join(variables, ", "));
}


std::string CodegenCppVisitor::get_parameter_str(const ParamVector& params) {
std::string str;
bool is_first = true;
std::vector<std::string> variables;
for (const auto& param: params) {
if (is_first) {
is_first = false;
} else {
str += ", ";
}
str += fmt::format("{}{} {}{}",
std::get<0>(param),
std::get<1>(param),
std::get<2>(param),
std::get<3>(param));
variables.push_back(fmt::format("{}{} {}{}",
std::get<0>(param),
std::get<1>(param),
std::get<2>(param),
std::get<3>(param)));
}
return str;
return fmt::format("{}", fmt::join(variables, ", "));
}


Expand Down
16 changes: 14 additions & 2 deletions src/codegen/codegen_cpp_visitor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -508,6 +508,18 @@ class CodegenCppVisitor: public visitor::ConstAstVisitor {
static std::string get_parameter_str(const ParamVector& params);


/**
* Generate the string representing the parameters in a function call
*
* The procedure parameters are stored in a vector of 4-tuples each representing a parameter.
*
* \param params The parameters that should be concatenated into the function parameter
* declaration
* \return The string representing the function call parameters
*/
static std::string get_arg_str(const ParamVector& params);


/**
* Check if function or procedure node has parameter with given name
*
Expand Down Expand Up @@ -894,7 +906,7 @@ class CodegenCppVisitor: public visitor::ConstAstVisitor {
* Arguments for external functions called from generated code
* \return A string representing the arguments passed to an external function
*/
virtual const char* external_method_arguments() noexcept = 0;
virtual const std::string external_method_arguments() noexcept = 0;


/**
Expand All @@ -906,7 +918,7 @@ class CodegenCppVisitor: public visitor::ConstAstVisitor {
* \param table
* \return A string representing the parameters of the function
*/
virtual const char* external_method_parameters(bool table = false) noexcept = 0;
virtual const ParamVector external_method_parameters(bool table = false) noexcept = 0;


/**
Expand Down
72 changes: 38 additions & 34 deletions src/codegen/codegen_neuron_cpp_visitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -456,30 +456,31 @@ void CodegenNeuronCppVisitor::add_variable_point_process(
}

std::string CodegenNeuronCppVisitor::internal_method_arguments() {
return "_ml, inst, id, _ppvar, _thread, _nt";
const auto& args = internal_method_parameters();
return get_arg_str(args);
}


CodegenNeuronCppVisitor::ParamVector CodegenNeuronCppVisitor::internal_method_parameters() {
ParamVector params;
params.emplace_back("", "_nrn_mechanism_cache_range*", "", "_ml");
params.emplace_back("", fmt::format("{}&", instance_struct()), "", "inst");
params.emplace_back("", "size_t", "", "id");
params.emplace_back("", "Datum*", "", "_ppvar");
params.emplace_back("", "Datum*", "", "_thread");
params.emplace_back("", "NrnThread*", "", "_nt");
CodegenCppVisitor::ParamVector CodegenNeuronCppVisitor::internal_method_parameters() {
ParamVector params = {{"", "_nrn_mechanism_cache_range*", "", "_ml"},
{"", fmt::format("{}&", instance_struct()), "", "inst"},
{"", "size_t", "", "id"},
{"", "Datum*", "", "_ppvar"},
{"", "Datum*", "", "_thread"},
{"", "NrnThread*", "", "_nt"}};
return params;
}


/// TODO: Edit for NEURON
const char* CodegenNeuronCppVisitor::external_method_arguments() noexcept {
const std::string CodegenNeuronCppVisitor::external_method_arguments() noexcept {
return {};
}


/// TODO: Edit for NEURON
const char* CodegenNeuronCppVisitor::external_method_parameters(bool table) noexcept {
const CodegenCppVisitor::ParamVector CodegenNeuronCppVisitor::external_method_parameters(
bool table) noexcept {
return {};
}

Expand Down Expand Up @@ -1369,12 +1370,11 @@ void CodegenNeuronCppVisitor::print_make_node_data() const {
node_data_struct(),
info.mod_suffix);

std::vector<std::string> make_node_data_args;
make_node_data_args.push_back("_ml_arg.nodeindices");
make_node_data_args.push_back("_nt.node_voltage_storage()");
make_node_data_args.push_back("_nt.node_d_storage()");
make_node_data_args.push_back("_nt.node_rhs_storage()");
make_node_data_args.push_back("_ml_arg.nodecount");
std::vector<std::string> make_node_data_args = {"_ml_arg.nodeindices",
"_nt.node_voltage_storage()",
"_nt.node_d_storage()",
"_nt.node_rhs_storage()",
"_ml_arg.nodecount"};

printer->fmt_push_block("return {}", node_data_struct());
printer->add_multi_line(fmt::format("{}", fmt::join(make_node_data_args, ",\n")));
Expand Down Expand Up @@ -1440,10 +1440,11 @@ void CodegenNeuronCppVisitor::print_initial_block(const InitialBlock* node) {
void CodegenNeuronCppVisitor::print_global_function_common_code(BlockType type,
const std::string& function_name) {
std::string method = function_name.empty() ? compute_method_name(type) : function_name;
std::string args =
"_nrn_model_sorted_token const& _sorted_token, NrnThread* _nt, Memb_list* _ml_arg, int "
"_type";
printer->fmt_push_block("void {}({})", method, args);
ParamVector args = {{"", "const _nrn_model_sorted_token&", "", "_sorted_token"},
{"", "NrnThread*", "", "_nt"},
{"", "Memb_list*", "", "_ml_arg"},
{"", "int", "", "_type"}};
printer->fmt_push_block("void {}({})", method, get_parameter_str(args));

printer->add_line("_nrn_mechanism_cache_range _lmr{_sorted_token, *_nt, *_ml_arg, _type};");
printer->fmt_line("auto inst = make_instance_{}(_lmr);", info.mod_suffix);
Expand Down Expand Up @@ -1483,10 +1484,15 @@ void CodegenNeuronCppVisitor::print_nrn_init(bool skip_init_check) {
void CodegenNeuronCppVisitor::print_nrn_jacob() {
printer->add_newline(2);

printer->fmt_push_block(
"static void {}(_nrn_model_sorted_token const& _sorted_token, NrnThread* "
"_nt, Memb_list* _ml_arg, int _type)",
method_name(naming::NRN_JACOB_METHOD)); // begin function
ParamVector args = {{"", "const _nrn_model_sorted_token&", "", "_sorted_token"},
{"", "NrnThread*", "", "_nt"},
{"", "Memb_list*", "", "_ml_arg"},
{"", "int", "", "_type"}};

printer->fmt_push_block("static void {}({})",
method_name(naming::NRN_JACOB_METHOD),
get_parameter_str(args)); // begin function


printer->add_multi_line(
"_nrn_mechanism_cache_range _lmr{_sorted_token, *_nt, *_ml_arg, _type};");
Expand Down Expand Up @@ -1692,11 +1698,10 @@ CodegenNeuronCppVisitor::ParamVector CodegenNeuronCppVisitor::nrn_current_parame
throw std::runtime_error("Not implemented.");
}

ParamVector params;
params.emplace_back("", "_nrn_mechanism_cache_range*", "", "_ml");
params.emplace_back("", "NrnThread*", "", "_nt");
params.emplace_back("", "Datum*", "", "_ppvar");
params.emplace_back("", "Datum*", "", "_thread");
ParamVector params = {{"", "_nrn_mechanism_cache_range*", "", "_ml"},
{"", "NrnThread*", "", "_nt"},
{"", "Datum*", "", "_ppvar"},
{"", "Datum*", "", "_thread"}};

if (info.thread_callback_register) {
auto type_name = fmt::format("{}&", thread_variables_struct());
Expand Down Expand Up @@ -2119,10 +2124,9 @@ void CodegenNeuronCppVisitor::print_net_receive() {
return;
}

ParamVector args;
args.emplace_back("", "Point_process*", "", "_pnt");
args.emplace_back("", "double*", "", "_args");
args.emplace_back("", "double", "", "flag");
ParamVector args = {{"", "Point_process*", "", "_pnt"},
{"", "double*", "", "_args"},
{"", "double", "", "flag"}};

printer->fmt_push_block("static void nrn_net_receive_{}({})",
info.mod_suffix,
Expand Down
4 changes: 2 additions & 2 deletions src/codegen/codegen_neuron_cpp_visitor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,7 @@ class CodegenNeuronCppVisitor: public CodegenCppVisitor {
* Arguments for external functions called from generated code
* \return A string representing the arguments passed to an external function
*/
const char* external_method_arguments() noexcept override;
const std::string external_method_arguments() noexcept override;


/**
Expand All @@ -289,7 +289,7 @@ class CodegenNeuronCppVisitor: public CodegenCppVisitor {
* \param table
* \return A string representing the parameters of the function
*/
const char* external_method_parameters(bool table = false) noexcept override;
const ParamVector external_method_parameters(bool table = false) noexcept override;


/**
Expand Down

0 comments on commit 3754a71

Please sign in to comment.