diff --git a/README.md b/README.md index 5363713..975cd2a 100644 --- a/README.md +++ b/README.md @@ -67,7 +67,7 @@ Creates an interval where the borders are sorted so the lower border is the firs - [Example](#example-1) - [(const)iterator find_all(interval_type const& ival, OnFindFunctionT const& on_find, CompareFunctionT const& compare)](#constiterator-find_allinterval_type-const-ival-onfindfunctiont-const-on_find-comparefunctiont-const-compare) - [(const)iterator find_next_in_subtree(iterator from, interval_type const& ival)](#constiterator-find_next_in_subtreeiterator-from-interval_type-const-ival) - - [(const)iterator find_next(iterator from, interval_type const& ival, CompareFunctionT const& compare)](#constiterator-find_nextiterator-from-interval_type-const-ival-comparefunctiont-const-compare) + - [(const)iterator find_next_in_subtree(iterator from, interval_type const& ival, CompareFunctionT const& compare)](#constiterator-find_next_in_subtreeiterator-from-interval_type-const-ival-comparefunctiont-const-compare) - [(const)iterator overlap_find(interval_type const& ival, bool exclusive)](#constiterator-overlap_findinterval_type-const-ival-bool-exclusive) - [(const)iterator overlap_find_all(interval_type const& ival, OnFindFunctionT const& on_find, bool exclusive)](#constiterator-overlap_find_allinterval_type-const-ival-onfindfunctiont-const-on_find-bool-exclusive) - [Example](#example-2) @@ -127,11 +127,11 @@ Finds the first interval in the interval tree that has an exact match. --- ### (const)iterator find(interval_type const& ival, CompareFunctionT const& compare) -Finds the first interval in the interval tree that has the following statement evaluate to true: compare(ival, interval_in_tree); +Finds the first interval in the interval tree that has the following statement evaluate to true: compare(interval_in_tree, ival); Allows for propper float comparisons. #### Parameters * `ival` The interval to find. -* `compare` The compare function to compare intervals with. +* `compare` The compare function to compare intervals with. Function is called like so: compare(interval_in_tree, ival). **Returns**: An iterator to the found element, or std::end(tree). @@ -160,7 +160,7 @@ tree.find_all({3, 7}, [](auto iter) /* iter will be const_iterator if tree is co Find all intervals in the tree that the compare function returns true for. #### Parameters * `ival` The interval to find. -* `compare` The compare function to compare intervals with. +* `compare` The compare function to compare intervals with. Function is called like so: compare(interval_in_tree, ival). * `on_find` A function of type bool(iterator) that is called when an interval was found. Return true to continue, false to preemptively abort search. @@ -177,13 +177,13 @@ You cannot find all matches this way, use find_all for that. **Returns**: An iterator to the found element, or std::end(tree). --- -### (const)iterator find_next(iterator from, interval_type const& ival, CompareFunctionT const& compare) +### (const)iterator find_next_in_subtree(iterator from, interval_type const& ival, CompareFunctionT const& compare) Finds the next exact match EXCLUDING from in the subtree originating from "from". You cannot find all matches this way, use find_all for that. #### Parameters * `from` The iterator to start from (including this iterator!) * `ival` The interval to find. -* `compare` The compare function to compare intervals with. +* `compare` The compare function to compare intervals with. Function is called like so: compare(interval_in_tree, ival). **Returns**: An iterator to the found element, or std::end(tree). diff --git a/interval_tree.hpp b/interval_tree.hpp index 44ad751..a1ec4ce 100644 --- a/interval_tree.hpp +++ b/interval_tree.hpp @@ -37,6 +37,9 @@ namespace lib_interval_tree * Constructs an interval. low MUST be smaller than high. */ #ifndef INTERVAL_TREE_SAFE_INTERVALS +#if __cplusplus >= 201703L + constexpr +#endif interval(value_type low, value_type high) : low_{low} , high_{high} @@ -44,6 +47,9 @@ namespace lib_interval_tree assert(low <= high); } #else +#if __cplusplus >= 201703L + constexpr +#endif interval(value_type low, value_type high) : low_{std::min(low, high)} , high_{std::max(low, high)} @@ -173,6 +179,9 @@ namespace lib_interval_tree * Creates a safe interval that puts the lower bound left automatically. */ template +#if __cplusplus >= 201703L + constexpr +#endif interval make_safe_interval(numerical_type lhs, numerical_type rhs) { return interval {std::min(lhs, rhs), std::max(lhs, rhs)}; @@ -621,6 +630,7 @@ namespace lib_interval_tree using iterator = interval_tree_iterator ; using const_iterator = const_interval_tree_iterator ; using size_type = long long; + using this_type = interval_tree; public: friend const_interval_tree_iterator ; @@ -854,14 +864,14 @@ namespace lib_interval_tree { if (root_ == nullptr) return; - find_all_i(root_, ival, on_find, compare); + find_all_i(this, root_, ival, on_find, compare); } template void find_all(interval_type const& ival, FunctionT const& on_find, CompareFunctionT const& compare) const { if (root_ == nullptr) return; - find_all_i(root_, ival, on_find, compare); + find_all_i(this, root_, ival, on_find, compare); } template @@ -950,9 +960,9 @@ namespace lib_interval_tree if (root_ == nullptr) return; if (exclusive) - overlap_find_all_i(root_, ival, on_find); + overlap_find_all_i(this, root_, ival, on_find); else - overlap_find_all_i(root_, ival, on_find); + overlap_find_all_i(this, root_, ival, on_find); } template void overlap_find_all(interval_type const& ival, FunctionT const& on_find, bool exclusive = false) const @@ -960,9 +970,9 @@ namespace lib_interval_tree if (root_ == nullptr) return; if (exclusive) - overlap_find_all_i(root_, ival, on_find); + overlap_find_all_i(this, root_, ival, on_find); else - overlap_find_all_i(root_, ival, on_find); + overlap_find_all_i(this, root_, ival, on_find); } /** @@ -1114,36 +1124,43 @@ namespace lib_interval_tree return nullptr; }; - template - bool find_all_i(node_type* ptr, interval_type const& ival, FunctionT const& on_find, ComparatorFunctionT const& compare) + template + static bool find_all_i + ( + typename std::conditional::value, ThisType, ThisType const>::type* self, + node_type* ptr, + interval_type const& ival, + FunctionT const& on_find, + ComparatorFunctionT const& compare + ) { if (compare(ptr->interval(), ival)) { - if (!on_find(IteratorT{ptr, this})) + if (!on_find(IteratorT{ptr, self})) return false; } if (ptr->left_ && ival.high() <= ptr->left_->max()) { // no right? can only continue left if (!ptr->right_ || ival.low() > ptr->right_->max()) - return find_all_i(ptr->left_, ival, on_find, compare); + return find_all_i(self, ptr->left_, ival, on_find, compare); - if (!find_all_i(ptr->left_, ival, on_find, compare)) + if (!find_all_i(self, ptr->left_, ival, on_find, compare)) return false; } if (ptr->right_ && ival.high() <= ptr->right_->max()) { if (!ptr->left_ || ival.low() > ptr->left_->max()) - return find_all_i(ptr->right_, ival, on_find, compare); + return find_all_i(self, ptr->right_, ival, on_find, compare); - if (!find_all_i(ptr->right_, ival, on_find, compare)) + if (!find_all_i(self, ptr->right_, ival, on_find, compare)) return false; } return true; } template - node_type* find_i(node_type* ptr, interval_type const& ival, ComparatorFunctionT const& compare) + node_type* find_i(node_type* ptr, interval_type const& ival, ComparatorFunctionT const& compare) const { if (compare(ptr->interval(), ival)) return ptr; @@ -1153,7 +1170,7 @@ namespace lib_interval_tree // excludes ptr template - node_type* find_i_ex(node_type* ptr, interval_type const& ival, ComparatorFunctionT const& compare) + node_type* find_i_ex(node_type* ptr, interval_type const& ival, ComparatorFunctionT const& compare) const { if (ptr->left_ && ival.high() <= ptr->left_->max()) { @@ -1178,7 +1195,7 @@ namespace lib_interval_tree } template - node_type* overlap_find_i(node_type* ptr, interval_type const& ival) + node_type* overlap_find_i(node_type* ptr, interval_type const& ival) const { #if __cplusplus >= 201703L if constexpr (Exclusive) @@ -1198,8 +1215,14 @@ namespace lib_interval_tree return overlap_find_i_ex(ptr, ival); } - template - bool overlap_find_all_i(node_type* ptr, interval_type const& ival, FunctionT const& on_find) + template + static bool overlap_find_all_i + ( + typename std::conditional::value, ThisType, ThisType const>::type* self, + node_type* ptr, + interval_type const& ival, + FunctionT const& on_find + ) { #if __cplusplus >= 201703L if constexpr (Exclusive) @@ -1209,7 +1232,7 @@ namespace lib_interval_tree { if (ptr->interval().overlaps_exclusive(ival)) { - if (!on_find(IteratorT{ptr, this})) + if (!on_find(IteratorT{ptr, self})) { return false; } @@ -1219,7 +1242,7 @@ namespace lib_interval_tree { if (ptr->interval().overlaps(ival)) { - if (!on_find(IteratorT{ptr, this})) + if (!on_find(IteratorT{ptr, self})) { return false; } @@ -1230,17 +1253,17 @@ namespace lib_interval_tree // no right? can only continue left // or interval low is bigger than max of right branch. if (!ptr->right_ || ival.low() > ptr->right_->max()) - return overlap_find_all_i(ptr->left_, ival, on_find); + return overlap_find_all_i(self, ptr->left_, ival, on_find); - if (!overlap_find_all_i(ptr->left_, ival, on_find)) + if (!overlap_find_all_i(self, ptr->left_, ival, on_find)) return false; } if (ptr->right_ && ptr->right_->max() >= ival.low()) { if (!ptr->left_ || ival.low() > ptr->right_->max()) - return overlap_find_all_i(ptr->right_, ival, on_find); + return overlap_find_all_i(self, ptr->right_, ival, on_find); - if (!overlap_find_all_i(ptr->right_, ival, on_find)) + if (!overlap_find_all_i(self, ptr->right_, ival, on_find)) return false; } return true; @@ -1248,7 +1271,7 @@ namespace lib_interval_tree // excludes ptr template - node_type* overlap_find_i_ex(node_type* ptr, interval_type const& ival) + node_type* overlap_find_i_ex(node_type* ptr, interval_type const& ival) const { if (ptr->left_ && ptr->left_->max() >= ival.low()) { diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index b7b7462..a66e9b7 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -24,7 +24,7 @@ file(GLOB sources "*.cpp") # Add Executable add_executable(tree-tests ${sources}) -target_link_libraries(tree-tests gtest) +target_link_libraries(tree-tests gtest gmock) # Options if(DRAW_EXAMPLES) diff --git a/tests/find_tests.hpp b/tests/find_tests.hpp index 24b53a8..2624c69 100644 --- a/tests/find_tests.hpp +++ b/tests/find_tests.hpp @@ -32,6 +32,15 @@ TEST_F(FindTests, WillFindRoot) EXPECT_EQ(tree.find({0, 1}), std::begin(tree)); } +TEST_F(FindTests, WillFindRootOnConstTree) +{ + tree.insert({0, 1}); + [](auto const& tree) + { + EXPECT_EQ(tree.find({0, 1}), std::begin(tree)); + }(tree); +} + TEST_F(FindTests, WillFindInBiggerTree) { tree.insert({16, 21}); @@ -137,3 +146,23 @@ TEST_F(FindTests, CanFindAllElementsBackInStrictlyAscendingOverlappingIntervals) ASSERT_NE(tree.find(ival), std::end(tree)); } } + +TEST_F(FindTests, CanFindAllOnConstTree) +{ + const auto targetInterval = lib_interval_tree::make_safe_interval(16, 21); + tree.insert(targetInterval); + tree.insert({8, 9}); + tree.insert({25, 30}); + std::vector intervals; + auto findWithConstTree = [&intervals, &targetInterval](auto const& tree) + { + tree.find_all(targetInterval, [&intervals](auto const& iter) { + intervals.emplace_back(*iter); + return true; + }); + }; + findWithConstTree(tree); + + ASSERT_EQ(intervals.size(), 1); + EXPECT_EQ(intervals[0], targetInterval); +} \ No newline at end of file diff --git a/tests/interval_tests.hpp b/tests/interval_tests.hpp index 869bf8a..d8c016a 100644 --- a/tests/interval_tests.hpp +++ b/tests/interval_tests.hpp @@ -1,7 +1,5 @@ #pragma once -#include - #include class IntervalTests @@ -30,13 +28,13 @@ class DistanceTests { public: using types = IntervalTypes ; -}; +}; TEST_F(IntervalTests, FailBadBorders) { auto f = []() { - [[maybe_unused]] auto ival = types::interval_type{1 BOOST_PP_COMMA() 0}; + [[maybe_unused]] auto ival = types::interval_type{1, 0}; }; EXPECT_DEATH(f(), "low <= high"); diff --git a/tests/overlap_find_tests.hpp b/tests/overlap_find_tests.hpp index dd492bd..ee2c027 100644 --- a/tests/overlap_find_tests.hpp +++ b/tests/overlap_find_tests.hpp @@ -28,6 +28,14 @@ TEST_F(OverlapFindTests, WillFindOverlapWithRoot) EXPECT_EQ(tree.overlap_find({2, 7}), std::begin(tree)); } +TEST_F(OverlapFindTests, WillFindOverlapWithRootOnConstTree) +{ + tree.insert({2, 4}); + [](auto const& tree) { + EXPECT_EQ(tree.overlap_find({2, 7}), std::begin(tree)); + }(tree); +} + TEST_F(OverlapFindTests, WillFindOverlapWithRootIfMatchingExactly) { tree.insert({2, 7}); @@ -149,4 +157,24 @@ TEST_F(OverlapFindTests, WillFindSingleOverlapInBiggerTree) EXPECT_NE(iter, std::end(tree)); EXPECT_EQ(iter->low(), 1000); EXPECT_EQ(iter->high(), 2000); -} \ No newline at end of file +} + +TEST_F(FindTests, CanOverlapFindAllOnConstTree) +{ + const auto targetInterval = lib_interval_tree::make_safe_interval(16, 21); + tree.insert(targetInterval); + tree.insert({8, 9}); + tree.insert({25, 30}); + std::vector intervals; + auto findWithConstTree = [&intervals, &targetInterval](auto const& tree) + { + tree.overlap_find_all(targetInterval, [&intervals](auto const& iter) { + intervals.emplace_back(*iter); + return true; + }); + }; + findWithConstTree(tree); + + ASSERT_EQ(intervals.size(), 1); + EXPECT_EQ(intervals[0], targetInterval); +}