Skip to content

Commit

Permalink
Directly parse DBSCAN implementation cmd line option into an enum
Browse files Browse the repository at this point in the history
  • Loading branch information
aprokop committed May 7, 2021
1 parent 0e85a30 commit c3a6d3d
Showing 1 changed file with 46 additions and 14 deletions.
60 changes: 46 additions & 14 deletions examples/dbscan/dbscan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,45 @@ void printClusterSizesAndCenters(ExecutionSpace const &exec_space,
}
}
}
// The type resolution in Boost may require shift operators to be defined in the
// std namespace.
namespace std
{
// This function is required for Boost program_options to be able to use the
// Implementation enum.
std::istream &operator>>(std::istream &in,
ArborX::DBSCAN::Implementation &implementation)
{
std::string impl_string;
in >> impl_string;

if (impl_string == "fdbscan")
implementation = ArborX::DBSCAN::Implementation::FDBSCAN;
else if (impl_string == "fdbscan-densebox")
implementation = ArborX::DBSCAN::Implementation::FDBSCAN_DenseBox;
else
in.setstate(std::ios_base::failbit);

return in;
}

// This function is required for Boost program_options to use Implementation
// enum as the default_value().
std::ostream &operator<<(std::ostream &out,
ArborX::DBSCAN::Implementation const &implementation)
{
switch (implementation)
{
case ArborX::DBSCAN::Implementation::FDBSCAN:
out << "fdbscan";
break;
case ArborX::DBSCAN::Implementation::FDBSCAN_DenseBox:
out << "fdbscan-densebox";
break;
}
return out;
}
} // namespace std

int main(int argc, char *argv[])
{
Expand All @@ -308,6 +347,7 @@ int main(int argc, char *argv[])
std::cout << "ArborX hash : " << ArborX::gitCommitHash() << std::endl;

namespace bpo = boost::program_options;
using ArborX::DBSCAN::Implementation;

std::string filename;
bool binary;
Expand All @@ -320,7 +360,7 @@ int main(int argc, char *argv[])
int max_num_points;
int num_samples;
std::string filename_labels;
std::string implementation;
Implementation implementation;

bpo::options_description desc("Allowed options");
// clang-format off
Expand All @@ -331,7 +371,7 @@ int main(int argc, char *argv[])
( "core-min-size", bpo::value<int>(&core_min_size)->default_value(2), "DBSCAN min_pts")
( "eps", bpo::value<float>(&eps), "DBSCAN eps" )
( "filename", bpo::value<std::string>(&filename), "filename containing data" )
( "impl", bpo::value<std::string>(&implementation)->default_value("fdbscan"), "implementation (\"fdbscan\" or \"fdbscan-densebox\")")
( "impl", bpo::value<Implementation>(&implementation)->default_value(Implementation::FDBSCAN), "implementation (\"fdbscan\" or \"fdbscan-densebox\")")
( "labels", bpo::value<std::string>(&filename_labels)->default_value(""), "clutering results output" )
( "max-num-points", bpo::value<int>(&max_num_points)->default_value(-1), "max number of points to read in")
( "output-sizes-and-centers", bpo::bool_switch(&print_sizes_centers)->default_value(false), "print cluster sizes and centers")
Expand All @@ -350,16 +390,8 @@ int main(int argc, char *argv[])
return 1;
}

if (implementation != "fdbscan" && implementation != "fdbscan-densebox")
{
std::cout << "Unknown implementation: \"" << implementation
<< "\". Valid values are [\"fdbscan\", \"fdbscan-densebox\"]."
<< std::endl;
return 2;
}
auto impl = (implementation == "fdbscan"
? ArborX::DBSCAN::Implementation::FDBSCAN
: ArborX::DBSCAN::Implementation::FDBSCAN_DenseBox);
std::stringstream ss;
ss << implementation;

// Print out the runtime parameters
printf("eps : %f\n", eps);
Expand All @@ -368,7 +400,7 @@ int main(int argc, char *argv[])
printf("filename : %s [%s, max_pts = %d]\n", filename.c_str(),
(binary ? "binary" : "text"), max_num_points);
printf("filename [labels] : %s [binary]\n", filename_labels.c_str());
printf("implementation : %s\n", implementation.c_str());
printf("implementation : %s\n", ss.str().c_str());
printf("samples : %d\n", num_samples);
printf("verify : %s\n", (verify ? "true" : "false"));
printf("print timers : %s\n", (print_dbscan_timers ? "true" : "false"));
Expand Down Expand Up @@ -403,7 +435,7 @@ int main(int argc, char *argv[])
auto labels = ArborX::dbscan(exec_space, primitives, eps, core_min_size,
ArborX::DBSCAN::Parameters()
.setPrintTimers(print_dbscan_timers)
.setImplementation(impl));
.setImplementation(implementation));

timer_start(timer);
Kokkos::View<int *, MemorySpace> cluster_indices("Testing::cluster_indices",
Expand Down

0 comments on commit c3a6d3d

Please sign in to comment.