diff --git a/examples/dbscan/dbscan.cpp b/examples/dbscan/dbscan.cpp index b003a2233f..8070707b38 100644 --- a/examples/dbscan/dbscan.cpp +++ b/examples/dbscan/dbscan.cpp @@ -320,22 +320,24 @@ int main(int argc, char *argv[]) int max_num_points; int num_samples; std::string filename_labels; + std::string implementation; bpo::options_description desc("Allowed options"); // clang-format off desc.add_options() ( "help", "help message" ) - ( "filename", bpo::value(&filename), "filename containing data" ) ( "binary", bpo::bool_switch(&binary)->default_value(false), "binary file indicator") - ( "max-num-points", bpo::value(&max_num_points)->default_value(-1), "max number of points to read in") - ( "eps", bpo::value(&eps), "DBSCAN eps" ) ( "cluster-min-size", bpo::value(&cluster_min_size)->default_value(2), "minimum cluster size") ( "core-min-size", bpo::value(&core_min_size)->default_value(2), "DBSCAN min_pts") - ( "verify", bpo::bool_switch(&verify)->default_value(false), "verify connected components") - ( "samples", bpo::value(&num_samples)->default_value(-1), "number of samples" ) + ( "eps", bpo::value(&eps), "DBSCAN eps" ) + ( "filename", bpo::value(&filename), "filename containing data" ) + ( "impl", bpo::value(&implementation)->default_value("fdbscan"), "implementation (\"fdbscan\" or \"fdbscan-densebox\")") ( "labels", bpo::value(&filename_labels)->default_value(""), "clutering results output" ) - ( "print-dbscan-timers", bpo::bool_switch(&print_dbscan_timers)->default_value(false), "print dbscan timers") + ( "max-num-points", bpo::value(&max_num_points)->default_value(-1), "max number of points to read in") ( "output-sizes-and-centers", bpo::bool_switch(&print_sizes_centers)->default_value(false), "print cluster sizes and centers") + ( "print-dbscan-timers", bpo::bool_switch(&print_dbscan_timers)->default_value(false), "print dbscan timers") + ( "samples", bpo::value(&num_samples)->default_value(-1), "number of samples" ) + ( "verify", bpo::bool_switch(&verify)->default_value(false), "verify connected components") ; // clang-format on bpo::variables_map vm; @@ -348,6 +350,17 @@ int main(int argc, char *argv[]) return 1; } + if (implementation != "fdbscan" && implementation != "fdbscan-densebox") + { + std::cout << "Unknown implementation: \"" << implementation + << "\". Valid values are [\"fdbscan\", \"fdbscan-densebox\"]." + << std::endl; + return 2; + } + auto impl = (implementation == "fdbscan" + ? ArborX::DBSCAN::Implementation::FDBSCAN + : ArborX::DBSCAN::Implementation::FDBSCAN_DenseBox); + // Print out the runtime parameters printf("eps : %f\n", eps); printf("minpts : %d\n", core_min_size); @@ -355,6 +368,7 @@ int main(int argc, char *argv[]) printf("filename : %s [%s, max_pts = %d]\n", filename.c_str(), (binary ? "binary" : "text"), max_num_points); printf("filename [labels] : %s [binary]\n", filename_labels.c_str()); + printf("implementation : %s\n", implementation.c_str()); printf("samples : %d\n", num_samples); printf("verify : %s\n", (verify ? "true" : "false")); printf("print timers : %s\n", (print_dbscan_timers ? "true" : "false")); @@ -386,9 +400,10 @@ int main(int argc, char *argv[]) timer_start(timer_total); - auto labels = ArborX::dbscan( - exec_space, primitives, eps, core_min_size, - ArborX::DBSCAN::Parameters().setPrintTimers(print_dbscan_timers)); + auto labels = ArborX::dbscan(exec_space, primitives, eps, core_min_size, + ArborX::DBSCAN::Parameters() + .setPrintTimers(print_dbscan_timers) + .setImplementation(impl)); timer_start(timer); Kokkos::View cluster_indices("Testing::cluster_indices",