From 368f28a95952e65f98c457e7d27a5e172a66da38 Mon Sep 17 00:00:00 2001 From: Jon Malkin Date: Tue, 5 Apr 2022 16:49:14 -0700 Subject: [PATCH] ensure k is power of 2 to match java --- quantiles/include/quantiles_sketch.hpp | 1 + quantiles/include/quantiles_sketch_impl.hpp | 15 +++++++-- quantiles/test/quantiles_sketch_test.cpp | 36 +++++++++++---------- 3 files changed, 32 insertions(+), 20 deletions(-) diff --git a/quantiles/include/quantiles_sketch.hpp b/quantiles/include/quantiles_sketch.hpp index 075dd512..ec3b4754 100644 --- a/quantiles/include/quantiles_sketch.hpp +++ b/quantiles/include/quantiles_sketch.hpp @@ -502,6 +502,7 @@ class quantiles_sketch { template static std::pair deserialize_array(const void* bytes, size_t size, uint32_t num_items, uint32_t capcacity, const SerDe& serde, const Allocator& allocator); + static void check_k(uint16_t k); static void check_serial_version(uint8_t serial_version); static void check_header_validity(uint8_t preamble_longs, uint8_t flags_byte, uint8_t serial_version); static void check_family_id(uint8_t family_id); diff --git a/quantiles/include/quantiles_sketch_impl.hpp b/quantiles/include/quantiles_sketch_impl.hpp index 84f28672..f42d02fa 100644 --- a/quantiles/include/quantiles_sketch_impl.hpp +++ b/quantiles/include/quantiles_sketch_impl.hpp @@ -43,9 +43,7 @@ min_value_(nullptr), max_value_(nullptr), is_sorted_(true) { - if (k < quantiles_constants::MIN_K || k > quantiles_constants::MAX_K) { - throw std::invalid_argument("K must be >= " + std::to_string(quantiles_constants::MIN_K) + " and <= " + std::to_string(quantiles_constants::MAX_K) + ": " + std::to_string(k)); - } + check_k(k_); base_buffer_.reserve(2 * std::min(quantiles_constants::MIN_K, k)); } @@ -268,6 +266,7 @@ auto quantiles_sketch::deserialize(std::istream &is, const SerDe& serde const auto k = read(is); read(is); // unused + check_k(k); check_serial_version(serial_version); // a little redundant with the next line, but explicit checks check_family_id(family_id); check_header_validity(preamble_longs, flags_byte, serial_version); @@ -376,6 +375,7 @@ auto quantiles_sketch::deserialize(const void* bytes, size_t size, cons uint16_t unused; ptr += copy_from_mem(ptr, unused); + check_k(k); check_serial_version(serial_version); // a little redundant with the next line, but explicit checks check_family_id(family_id); check_header_validity(preamble_longs, flags_byte, serial_version); @@ -768,6 +768,15 @@ uint8_t quantiles_sketch::compute_levels_needed(const uint16_t k, const return static_cast(64U) - count_leading_zeros_in_u64(n / (2 * k)); } +template +void quantiles_sketch::check_k(uint16_t k) { + if (k < quantiles_constants::MIN_K || k > quantiles_constants::MAX_K || (k & k - 1) != 0) { + throw std::invalid_argument("k must be a power of 2 that is >= " + + std::to_string(quantiles_constants::MIN_K) + " and <= " + + std::to_string(quantiles_constants::MAX_K) + ". Found: " + std::to_string(k)); + } +} + template void quantiles_sketch::check_serial_version(uint8_t serial_version) { if (serial_version == SERIAL_VERSION || serial_version == SERIAL_VERSION_1 || serial_version == SERIAL_VERSION_2) diff --git a/quantiles/test/quantiles_sketch_test.cpp b/quantiles/test/quantiles_sketch_test.cpp index 847a761f..3c17ef8f 100644 --- a/quantiles/test/quantiles_sketch_test.cpp +++ b/quantiles/test/quantiles_sketch_test.cpp @@ -20,6 +20,7 @@ #include #include #include +#include #include #include @@ -49,6 +50,7 @@ TEST_CASE("quantiles sketch", "[quantiles_sketch]") { quantiles_float_sketch sketch1(quantiles_constants::MIN_K, 0); // this should work quantiles_float_sketch sketch2(quantiles_constants::MAX_K, 0); // this should work REQUIRE_THROWS_AS(new quantiles_float_sketch(quantiles_constants::MIN_K - 1, 0), std::invalid_argument); + REQUIRE_THROWS_AS(new quantiles_float_sketch(40, 0), std::invalid_argument); // not power of 2 // MAX_K + 1 makes no sense because k is uint16_t } @@ -75,7 +77,7 @@ TEST_CASE("quantiles sketch", "[quantiles_sketch]") { } SECTION("get bad quantile") { - quantiles_float_sketch sketch(200, 0); + quantiles_float_sketch sketch(64, 0); sketch.update(0.0f); // has to be non-empty to reach the check REQUIRE_THROWS_AS(sketch.get_quantile(-1), std::invalid_argument); } @@ -108,7 +110,7 @@ TEST_CASE("quantiles sketch", "[quantiles_sketch]") { } SECTION("NaN") { - quantiles_float_sketch sketch(200, 0); + quantiles_float_sketch sketch(256, 0); sketch.update(std::numeric_limits::quiet_NaN()); REQUIRE(sketch.is_empty()); @@ -119,7 +121,7 @@ TEST_CASE("quantiles sketch", "[quantiles_sketch]") { SECTION("sampling mode") { - const uint16_t k = 10; + const uint16_t k = 8; const uint32_t n = 16 * (2 * k) + 1; quantiles_float_sketch sk(k, 0); for (uint32_t i = 0; i < n; ++i) { @@ -170,7 +172,7 @@ TEST_CASE("quantiles sketch", "[quantiles_sketch]") { } SECTION("10 items") { - quantiles_float_sketch sketch(200, 0); + quantiles_float_sketch sketch(128, 0); sketch.update(1.0f); sketch.update(2.0f); sketch.update(3.0f); @@ -188,7 +190,7 @@ TEST_CASE("quantiles sketch", "[quantiles_sketch]") { } SECTION("100 items") { - quantiles_float_sketch sketch(200, 0); + quantiles_float_sketch sketch(128, 0); for (int i = 0; i < 100; ++i) sketch.update(static_cast(i)); REQUIRE(sketch.get_quantile(0) == 0); REQUIRE(sketch.get_quantile(0.01) == 1); @@ -237,7 +239,7 @@ TEST_CASE("quantiles sketch", "[quantiles_sketch]") { previous_quantile = quantile; } - std::cout << sketch.to_string(); + //std::cout << sketch.to_string(); uint32_t count = 0; uint64_t total_weight = 0; @@ -250,7 +252,7 @@ TEST_CASE("quantiles sketch", "[quantiles_sketch]") { } SECTION("consistency between get_rank and get_PMF/CDF") { - quantiles_float_sketch sketch(200, 0); + quantiles_float_sketch sketch(64, 0); const int n = 1000; float values[n]; for (int i = 0; i < n; i++) { @@ -276,7 +278,7 @@ TEST_CASE("quantiles sketch", "[quantiles_sketch]") { } SECTION("stream serialize deserialize empty") { - quantiles_float_sketch sketch(200, 0); + quantiles_float_sketch sketch(128, 0); std::stringstream s(std::ios::in | std::ios::out | std::ios::binary); sketch.serialize(s); REQUIRE(static_cast(s.tellp()) == sketch.get_serialized_size_bytes()); @@ -294,7 +296,7 @@ TEST_CASE("quantiles sketch", "[quantiles_sketch]") { } SECTION("bytes serialize deserialize empty") { - quantiles_float_sketch sketch(200, 0); + quantiles_float_sketch sketch(256, 0); auto bytes = sketch.serialize(); auto sketch2 = quantiles_float_sketch::deserialize(bytes.data(), bytes.size(), serde(), 0); REQUIRE(bytes.size() == sketch.get_serialized_size_bytes()); @@ -309,7 +311,7 @@ TEST_CASE("quantiles sketch", "[quantiles_sketch]") { } SECTION("stream serialize deserialize one item") { - quantiles_float_sketch sketch(200, 0); + quantiles_float_sketch sketch(32, 0); sketch.update(1.0f); std::stringstream s(std::ios::in | std::ios::out | std::ios::binary); sketch.serialize(s); @@ -329,7 +331,7 @@ TEST_CASE("quantiles sketch", "[quantiles_sketch]") { } SECTION("bytes serialize deserialize one item") { - quantiles_float_sketch sketch(200, 0); + quantiles_float_sketch sketch(64, 0); sketch.update(1.0f); auto bytes = sketch.serialize(); REQUIRE(bytes.size() == sketch.get_serialized_size_bytes()); @@ -347,7 +349,7 @@ TEST_CASE("quantiles sketch", "[quantiles_sketch]") { } SECTION("stream serialize deserialize three items") { - quantiles_float_sketch sketch(200, 0); + quantiles_float_sketch sketch(128, 0); sketch.update(1.0f); sketch.update(2.0f); sketch.update(3.0f); @@ -366,7 +368,7 @@ TEST_CASE("quantiles sketch", "[quantiles_sketch]") { } SECTION("bytes serialize deserialize three items") { - quantiles_float_sketch sketch(200, 0); + quantiles_float_sketch sketch(128, 0); sketch.update(1.0f); sketch.update(2.0f); sketch.update(3.0f); @@ -383,7 +385,7 @@ TEST_CASE("quantiles sketch", "[quantiles_sketch]") { } SECTION("stream serialize deserialize many floats") { - quantiles_float_sketch sketch(200, 0); + quantiles_float_sketch sketch(128, 0); const int n = 1000; for (int i = 0; i < n; i++) sketch.update(static_cast(i)); std::stringstream s(std::ios::in | std::ios::out | std::ios::binary); @@ -405,7 +407,7 @@ TEST_CASE("quantiles sketch", "[quantiles_sketch]") { REQUIRE(sketch2.get_rank(static_cast(n)) == sketch.get_rank(static_cast(n))); } SECTION("bytes serialize deserialize many floats") { - quantiles_float_sketch sketch(200, 0); + quantiles_float_sketch sketch(128, 0); const int n = 1000; for (int i = 0; i < n; i++) sketch.update(static_cast(i)); auto bytes = sketch.serialize(); @@ -453,7 +455,7 @@ TEST_CASE("quantiles sketch", "[quantiles_sketch]") { } SECTION("out of order split points, float") { - quantiles_float_sketch sketch(200, 0); + quantiles_float_sketch sketch(256, 0); sketch.update(0.0f); // has too be non-empty to reach the check float split_points[2] = {1, 0}; REQUIRE_THROWS_AS(sketch.get_CDF(split_points, 2), std::invalid_argument); @@ -467,7 +469,7 @@ TEST_CASE("quantiles sketch", "[quantiles_sketch]") { } SECTION("NaN split point") { - quantiles_float_sketch sketch(200, 0); + quantiles_float_sketch sketch(512, 0); sketch.update(0.0f); // has too be non-empty to reach the check float split_points[1] = {std::numeric_limits::quiet_NaN()}; REQUIRE_THROWS_AS(sketch.get_CDF(split_points, 1), std::invalid_argument);