Skip to content

Commit

Permalink
Add testing for MST with mutual reachability metric (#802)
Browse files Browse the repository at this point in the history
* Add testing for MST with mutual reachability metric

Testing mrd based on the edges output is hard because many edges may
have the same weight. Instead, we just test the total weight of the MST,
and compare to the reference weight computed externally.

* Switch back to using relative tolerance to get better error message

* Add missing <map> header
  • Loading branch information
aprokop committed Dec 28, 2022
1 parent 4b17922 commit 9f1f2d3
Showing 1 changed file with 33 additions and 0 deletions.
33 changes: 33 additions & 0 deletions test/tstMinimumSpanningTreeGoldenTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
#include <boost/tokenizer.hpp>

#include <fstream>
#include <map>
#include <numeric> // accumulate

namespace Test
{
Expand Down Expand Up @@ -103,4 +105,35 @@ BOOST_AUTO_TEST_CASE_TEMPLATE(minimum_spanning_tree_golden_test, DeviceType,
reinterpret_cast<Test::UndirectedEdge const *>(edges.data()),
edges.size())),
boost::test_tools::per_element());

// clang-format off
// Computed with the following Python code with variable k:
// import numpy as np
// import hdbscan
// clusterer = hdbscan.HDBSCAN(min_cluster_size=k-1, gen_min_span_tree=True)
// filename = "mst_golden_test_points.csv"
// points = np.loadtxt(filename, delimiter=",", dtype="double", comments="#")
// clusterer.fit(points)
// print(np.sum(clusterer.minimum_spanning_tree_._mst[:,2]))
// clang-format on
std::map<int, double> ref_total_weight;
ref_total_weight[5] = 102.68084503576422;
ref_total_weight[10] = 138.0244333174116;
ref_total_weight[15] = 162.51948793942978;

std::map<int, double> total_weight;
for (auto k : {5, 10, 15})
{
auto edges = Kokkos::create_mirror_view_and_copy(
Kokkos::HostSpace{},
MinimumSpanningTree<MemorySpace>(exec_space, points, k).edges);
total_weight[k] = std::accumulate(
edges.data(), edges.data() + edges.size(), 0.,
[](auto const &sum, auto const &b) { return sum + b.weight; });
}

namespace tt = boost::test_tools;
BOOST_TEST(total_weight[5] == ref_total_weight[5], tt::tolerance(1e-8));
BOOST_TEST(total_weight[10] == ref_total_weight[10], tt::tolerance(1e-8));
BOOST_TEST(total_weight[15] == ref_total_weight[15], tt::tolerance(1e-8));
}

0 comments on commit 9f1f2d3

Please sign in to comment.