diff --git a/Cargo.lock b/Cargo.lock index 2fc4da2..a849360 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -412,19 +412,17 @@ dependencies = [ [[package]] name = "asap_sketchlib" version = "0.1.0" -source = "git+https://github.com/ProjectASAP/asap_sketchlib?branch=refactor%2Fadopt-sketch-core-modules#d84ff152c7ac7c90b97bf2fbe0d88f28c147d7a6" +source = "git+https://github.com/ProjectASAP/asap_sketchlib#19035220b7d999d1e12ca927574557fb702c2741" dependencies = [ "bytes", "prost", "prost-build", - "protoc-bin-vendored", "rand 0.9.4", "rmp-serde", "serde", "serde-big-array", "smallvec", "twox-hash 2.1.2", - "xxhash-rust", ] [[package]] @@ -579,6 +577,12 @@ dependencies = [ "tracing", ] +[[package]] +name = "base64" +version = "0.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9e1b586273c5702936fe7b7d6896644d8be71e6314cfe09d3167c95f712589e8" + [[package]] name = "base64" version = "0.21.7" @@ -665,6 +669,17 @@ dependencies = [ "alloc-stdlib", ] +[[package]] +name = "bstr" +version = "0.2.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba3569f383e8f1598449f1a423e72e99569137b47740b1da11ef19af3d5c3223" +dependencies = [ + "lazy_static", + "memchr", + "regex-automata 0.1.10", +] + [[package]] name = "bumpalo" version = "3.20.2" @@ -863,6 +878,17 @@ version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c8d4a3bb8b1e0c1050499d1815f5ab16d04f0959b233085fb31653fbfc9d98f9" +[[package]] +name = "codespan-reporting" +version = "0.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "af491d569909a7e4dee0ad7db7f5341fef5c614d5b8ec8cf765732aba3cff681" +dependencies = [ + "serde", + "termcolor", + "unicode-width 0.2.2", +] + [[package]] name = "colorchoice" version = "1.0.5" @@ -1074,6 +1100,78 @@ dependencies = [ "memchr", ] +[[package]] +name = "ctor" +version = "0.2.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"32a2785755761f3ddc1492979ce1e48d2c00d09311c39e4466429188f3dd6501" +dependencies = [ + "quote", + "syn 2.0.117", +] + +[[package]] +name = "cxx" +version = "1.0.194" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "747d8437319e3a2f43d93b341c137927ca70c0f5dabeea7a005a73665e247c7e" +dependencies = [ + "cc", + "cxx-build", + "cxxbridge-cmd", + "cxxbridge-flags", + "cxxbridge-macro", + "foldhash 0.2.0", + "link-cplusplus", +] + +[[package]] +name = "cxx-build" +version = "1.0.194" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b0f4697d190a142477b16aef7da8a99bfdc41e7e8b1687583c0d23a79c7afc1e" +dependencies = [ + "cc", + "codespan-reporting", + "indexmap 2.13.1", + "proc-macro2", + "quote", + "scratch", + "syn 2.0.117", +] + +[[package]] +name = "cxxbridge-cmd" +version = "1.0.194" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d0956799fa8678d4c50eed028f2de1c0552ae183c76e976cf7ca8c4e36a7c328" +dependencies = [ + "clap 4.6.0", + "codespan-reporting", + "indexmap 2.13.1", + "proc-macro2", + "quote", + "syn 2.0.117", +] + +[[package]] +name = "cxxbridge-flags" +version = "1.0.194" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "23384a836ab4f0ad98ace7e3955ad2de39de42378ab487dc28d3990392cb283a" + +[[package]] +name = "cxxbridge-macro" +version = "1.0.194" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6acc6b5822b9526adfb4fc377b67128fdd60aac757cc4a741a6278603f763cf" +dependencies = [ + "indexmap 2.13.1", + "proc-macro2", + "quote", + "syn 2.0.117", +] + [[package]] name = "dashmap" version = "5.5.3" @@ -1554,6 +1652,22 @@ dependencies = [ "syn 2.0.117", ] +[[package]] +name = "dsrs" +version = "0.6.1" +source = "git+https://github.com/ProjectASAP/datasketches-rs?rev=d748ec75c80fff21f7b24897244dd1c895df2e9a#d748ec75c80fff21f7b24897244dd1c895df2e9a" +dependencies = [ + "base64 0.13.1", + "bstr", + "cxx", + "cxx-build", + 
"memchr", + "rmp-serde", + "serde", + "structopt", + "thin-dst", +] + [[package]] name = "either" version = "1.15.0" @@ -1679,6 +1793,12 @@ version = "0.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d9c4f5dac5e15c24eb999c26181a6ca40b39fe946cbe4c263c7209467bc83af2" +[[package]] +name = "foldhash" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "77ce24cb58228fbb8aa041425bb1050850ac19177686ea6e0f41a70416f56fdb" + [[package]] name = "foreign-types" version = "0.3.2" @@ -1924,7 +2044,7 @@ version = "0.15.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9229cfe53dfd69f0609a49f65461bd93001ea1ef889cd5529dd176593f5338a1" dependencies = [ - "foldhash", + "foldhash 0.1.5", ] [[package]] @@ -2551,6 +2671,15 @@ dependencies = [ "vcpkg", ] +[[package]] +name = "link-cplusplus" +version = "1.0.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7f78c730aaa7d0b9336a299029ea49f9ee53b0ed06e9202e8cb7db9bae7b8c82" +dependencies = [ + "cc", +] + [[package]] name = "linux-raw-sys" version = "0.12.1" @@ -2658,7 +2787,7 @@ version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d1525a2a28c7f4fa0fc98bb91ae755d1e2d1505079e05539e35bc876b5d65ae9" dependencies = [ - "regex-automata", + "regex-automata 0.4.14", ] [[package]] @@ -2739,7 +2868,7 @@ version = "0.50.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7957b9740744892f114936ab4a57b3f487491bbeafaf8083688b16841a4240e5" dependencies = [ - "windows-sys 0.61.2", + "windows-sys 0.59.0", ] [[package]] @@ -3413,70 +3542,6 @@ version = "2.28.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "106dd99e98437432fed6519dedecfade6a06a73bb7b2a1e019fdd2bee5778d94" -[[package]] -name = "protoc-bin-vendored" -version = "3.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"d1c381df33c98266b5f08186583660090a4ffa0889e76c7e9a5e175f645a67fa" -dependencies = [ - "protoc-bin-vendored-linux-aarch_64", - "protoc-bin-vendored-linux-ppcle_64", - "protoc-bin-vendored-linux-s390_64", - "protoc-bin-vendored-linux-x86_32", - "protoc-bin-vendored-linux-x86_64", - "protoc-bin-vendored-macos-aarch_64", - "protoc-bin-vendored-macos-x86_64", - "protoc-bin-vendored-win32", -] - -[[package]] -name = "protoc-bin-vendored-linux-aarch_64" -version = "3.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c350df4d49b5b9e3ca79f7e646fde2377b199e13cfa87320308397e1f37e1a4c" - -[[package]] -name = "protoc-bin-vendored-linux-ppcle_64" -version = "3.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a55a63e6c7244f19b5c6393f025017eb5d793fd5467823a099740a7a4222440c" - -[[package]] -name = "protoc-bin-vendored-linux-s390_64" -version = "3.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1dba5565db4288e935d5330a07c264a4ee8e4a5b4a4e6f4e83fad824cc32f3b0" - -[[package]] -name = "protoc-bin-vendored-linux-x86_32" -version = "3.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8854774b24ee28b7868cd71dccaae8e02a2365e67a4a87a6cd11ee6cdbdf9cf5" - -[[package]] -name = "protoc-bin-vendored-linux-x86_64" -version = "3.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b38b07546580df720fa464ce124c4b03630a6fb83e05c336fea2a241df7e5d78" - -[[package]] -name = "protoc-bin-vendored-macos-aarch_64" -version = "3.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "89278a9926ce312e51f1d999fee8825d324d603213344a9a706daa009f1d8092" - -[[package]] -name = "protoc-bin-vendored-macos-x86_64" -version = "3.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "81745feda7ccfb9471d7a4de888f0652e806d5795b61480605d4943176299756" - -[[package]] -name = 
"protoc-bin-vendored-win32" -version = "3.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "95067976aca6421a523e491fce939a3e65249bac4b977adee0ee9771568e8aa3" - [[package]] name = "psm" version = "0.1.30" @@ -3505,9 +3570,11 @@ dependencies = [ "clap 4.6.0", "criterion", "csv", + "ctor", "dashmap 5.5.3", "datafusion", "datafusion_summary_library", + "dsrs", "elastic_dsl_utilities", "flate2", "form_urlencoded", @@ -3527,6 +3594,7 @@ dependencies = [ "serde", "serde_json", "serde_yaml", + "sketch-core", "snap", "sql_utilities", "sqlparser 0.59.0", @@ -3721,10 +3789,16 @@ checksum = "e10754a14b9137dd7b1e3e5b0493cc9171fdd105e0ab477f51b72e7f3ac0e276" dependencies = [ "aho-corasick", "memchr", - "regex-automata", + "regex-automata 0.4.14", "regex-syntax", ] +[[package]] +name = "regex-automata" +version = "0.1.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6c230d73fb8d8c1b9c0b3135c5142a8acee3a0558fb8db5cf1cb65f8d7862132" + [[package]] name = "regex-automata" version = "0.4.14" @@ -3882,6 +3956,12 @@ version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" +[[package]] +name = "scratch" +version = "1.0.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d68f2ec51b097e4c1a75b681a8bec621909b5e91f15bb7b840c4f2f7b01148b2" + [[package]] name = "security-framework" version = "3.7.0" @@ -4053,6 +4133,19 @@ version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b2aa850e253778c88a04c3d7323b043aeda9d3e30d5971937c1855769763678e" +[[package]] +name = "sketch-core" +version = "0.3.0" +dependencies = [ + "asap_sketchlib", + "clap 4.6.0", + "ctor", + "dsrs", + "rmp-serde", + "serde", + "xxhash-rust", +] + [[package]] name = "slab" version = "0.4.12" @@ -4181,7 +4274,6 @@ dependencies = [ "cfg-if", "libc", "psm", - "windows-sys 0.52.0", "windows-sys 
0.59.0", ] @@ -4334,6 +4426,15 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "termcolor" +version = "1.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "06794f8f6c5c898b3275aebefa6b8a1cb24cd2c6c79397ab15774837a0bc5755" +dependencies = [ + "winapi-util", +] + [[package]] name = "textwrap" version = "0.11.0" @@ -4343,6 +4444,12 @@ dependencies = [ "unicode-width 0.1.14", ] +[[package]] +name = "thin-dst" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "db3c46be180f1af9673ebb27bc1235396f61ef6965b3fe0dbb2e624deb604f0e" + [[package]] name = "thiserror" version = "1.0.69" @@ -4711,7 +4818,7 @@ dependencies = [ "matchers", "nu-ansi-term", "once_cell", - "regex-automata", + "regex-automata 0.4.14", "sharded-slab", "smallvec", "thread_local", diff --git a/Cargo.toml b/Cargo.toml index c5b0afe..8e7b744 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,7 @@ [workspace] resolver = "2" members = [ + "asap-common/sketch-core", "asap-common/dependencies/rs/promql_utilities", "asap-common/dependencies/rs/sql_utilities", "asap-common/dependencies/rs/elastic_dsl_utilities", @@ -31,6 +32,7 @@ promql-parser = "0.5.0" tokio = { version = "1.0", features = ["full"] } # Internal crates +sketch-core = { path = "asap-common/sketch-core" } promql_utilities = { path = "asap-common/dependencies/rs/promql_utilities" } sql_utilities = { path = "asap-common/dependencies/rs/sql_utilities" } asap_types = { path = "asap-common/dependencies/rs/asap_types" } diff --git a/asap-common/sketch-core/Cargo.toml b/asap-common/sketch-core/Cargo.toml new file mode 100644 index 0000000..fe9cf24 --- /dev/null +++ b/asap-common/sketch-core/Cargo.toml @@ -0,0 +1,16 @@ +[package] +name = "sketch-core" +version.workspace = true +edition.workspace = true + +[dependencies] +serde.workspace = true +# sketch-core-specific, keep pinned +rmp-serde = "1.1" +xxhash-rust = { version = "0.8", features = ["xxh32"] } +dsrs = 
{ git = "https://github.com/ProjectASAP/datasketches-rs", rev = "d748ec75c80fff21f7b24897244dd1c895df2e9a" } +asap_sketchlib = { git = "https://github.com/ProjectASAP/asap_sketchlib" } +clap = { version = "4.0", features = ["derive"] } + +[dev-dependencies] +ctor = "0.2" diff --git a/asap-common/sketch-core/report.md b/asap-common/sketch-core/report.md new file mode 100644 index 0000000..9dfde74 --- /dev/null +++ b/asap-common/sketch-core/report.md @@ -0,0 +1,125 @@ +# Report + +Compares the **legacy** sketch implementations in `sketch-core` vs the **asap_sketchlib** backends (Count-Min Sketch, Count-Min-With-Heap, KLL, HydraKLL). + +## Fidelity harness + +The fidelity binary selects backends via CLI flags (`--cms-impl`, `--kll-impl`, `--cmwh-impl`). + +| Goal | Command | +|--------------------------|--------------------------------------------------------------------------------------------------------------| +| Default (all sketchlib) | `cargo run -p sketch-core --bin sketchlib_fidelity` | +| All legacy | `cargo run -p sketch-core --bin sketchlib_fidelity -- --cms-impl legacy --kll-impl legacy --cmwh-impl legacy` | +| Legacy KLL only | `cargo run -p sketch-core --bin sketchlib_fidelity -- --cms-impl sketchlib --kll-impl legacy --cmwh-impl sketchlib` | +| CMS sketchlib only | `cargo run -p sketch-core --bin sketchlib_fidelity -- --cms-impl sketchlib` | +| CMS legacy only | `cargo run -p sketch-core --bin sketchlib_fidelity -- --cms-impl legacy` | + +## Unit tests + +Unit tests always run with **legacy** backends enabled (the test ctor calls +`force_legacy_mode_for_tests()`), so you only need: + +```bash +cargo test -p sketch-core +``` + +## Results + +### CountMinSketch (accuracy vs exact counts) + +#### depth=3 + +| width | n | domain | Mode | Pearson corr | MAPE (%) | RMSE (%) | +|-------|--------|--------|----------------|----------------|----------|----------| +| 1024 | 100000 | 1000 | Legacy | 0.9998451189 | 24.48 | 52.76 | +| 1024 | 100000 | 1000 | 
asap_sketchlib | 0.9998387103 | 24.36 | 54.11 | + +#### depth=5 + +| width | n | domain | Mode | Pearson corr | MAPE (%) | RMSE (%) | +|-------|--------|--------|----------------|----------------|----------|----------| +| 2048 | 200000 | 2000 | Legacy | 0.9999733814 | 8.75 | 29.94 | +| 2048 | 200000 | 2000 | asap_sketchlib | 0.9999744627 | 8.37 | 28.84 | +| 2048 | 50000 | 500 | Legacy | 1.0000000000 | 0.00 | 0.00 | +| 2048 | 50000 | 500 | asap_sketchlib | 1.0000000000 | 0.00 | 0.00 | + +#### depth=7 + +| width | n | domain | Mode | Pearson corr | MAPE (%) | RMSE (%) | +|-------|--------|--------|----------------|----------------|----------|----------| +| 4096 | 200000 | 2000 | Legacy | 0.9999993694 | 0.20 | 3.69 | +| 4096 | 200000 | 2000 | asap_sketchlib | 0.9999993499 | 0.21 | 4.27 | + +--- + +### CountMinSketchWithHeap (top-k + CMS accuracy on exact top-k) + +The heap is maintained by local updates; recall is measured against the **true** top-k at the end of the stream. + +#### depth=3 + +| width | n | domain | heap_size | Mode | Top-k recall | Pearson (top-k) | MAPE (%) | RMSE (%) | +|-------|--------|--------|-----------|----------------|--------------|-----------------|----------|----------| +| 1024 | 100000 | 1000 | 10 | Legacy | 0.40 | 0.9571 | 0.174 | 0.319 | +| 1024 | 100000 | 1000 | 10 | asap_sketchlib | 0.80 | 1.0000 | 0.000 | 0.000 | + +#### depth=5 + +| width | n | domain | heap_size | Mode | Top-k recall | Pearson (top-k) | MAPE (%) | RMSE (%) | +|-------|--------|--------|-----------|----------------|--------------|-----------------|----------|----------| +| 2048 | 200000 | 2000 | 20 | Legacy | 0.60 | 0.9964 | 0.045 | 0.101 | +| 2048 | 200000 | 2000 | 20 | asap_sketchlib | 1.00 | 0.9982 | 0.021 | 0.067 | +| 2048 | 200000 | 2000 | 50 | Legacy | 0.40 | 0.9999983 | 5.60 | 16.49 | +| 2048 | 200000 | 2000 | 50 | asap_sketchlib | 0.48 | 0.9999990 | 3.90 | 12.95 | + +--- + +### KllSketch (quantiles, absolute rank error) + +For each quantile \(q\), we 
compute the sketch estimate `est_value`, then: +`abs_rank_error = |rank_fraction(exact_sorted_values, est_value) - q|`. + +#### k=20 + +| n_updates | Mode | q=0.5 | q=0.9 | q=0.99 | +|-----------|----------------|---------|---------|---------| +| 200000 | Legacy | 0.0104 | 0.0145 | 0.0028 | +| 200000 | asap_sketchlib | 0.0275 | 0.0470 | 0.0061 | +| 50000 | Legacy | 0.0131 | 0.0091 | 0.0054 | +| 50000 | asap_sketchlib | 0.0110 | 0.0116 | 0.0031 | + +#### k=50 + +| n_updates | Mode | q=0.5 | q=0.9 | q=0.99 | +|-----------|----------------|---------|---------|---------| +| 200000 | Legacy | 0.0013 | 0.0021 | 0.0012 | +| 200000 | asap_sketchlib | 0.0101 | 0.0044 | 0.0074 | + +#### k=200 + +| n_updates | Mode | q=0.5 | q=0.9 | q=0.99 | +|-----------|----------------|---------|---------|---------| +| 200000 | Legacy | 0.0021 | 0.0036 | 0.0000 | +| 200000 | asap_sketchlib | 0.0015 | 0.0001 | 0.0002 | + +--- + +### HydraKllSketch (per-key quantiles, mean/max absolute rank error across 50 keys) + +#### rows=2, cols=64 + +| k | n | domain | Mode | q=0.5 (mean / max) | q=0.9 (mean / max) | +|-----|--------|--------|----------------|--------------------|--------------------| +| 20 | 200000 | 200 | Legacy | 0.0170 / 0.0546 | 0.0165 / 0.0452 | +| 20 | 200000 | 200 | asap_sketchlib | 0.0254 / 0.0629 | 0.0546 / 0.0942 | + +#### rows=3, cols=128 + +| k | n | domain | Mode | q=0.5 (mean / max) | q=0.9 (mean / max) | +|-----|--------|--------|----------------|--------------------|--------------------| +| 20 | 200000 | 200 | Legacy | 0.0166 / 0.0591 | 0.0114 / 0.0304 | +| 20 | 200000 | 200 | asap_sketchlib | 0.0216 / 0.0534 | 0.0238 / 0.1087 | +| 50 | 200000 | 200 | Legacy | 0.0099 / 0.0352 | 0.0087 / 0.0330 | +| 50 | 200000 | 200 | asap_sketchlib | 0.0119 / 0.0458 | 0.0119 / 0.0296 | +| 20 | 100000 | 100 | Legacy | 0.0141 / 0.0574 | 0.0149 / 0.0471 | +| 20 | 100000 | 100 | asap_sketchlib | 0.0202 / 0.0621 | 0.0287 / 0.0779 | diff --git 
a/asap-common/sketch-core/src/bin/sketchlib_fidelity.rs b/asap-common/sketch-core/src/bin/sketchlib_fidelity.rs new file mode 100644 index 0000000..ca95cb6 --- /dev/null +++ b/asap-common/sketch-core/src/bin/sketchlib_fidelity.rs @@ -0,0 +1,496 @@ +// Fidelity benchmarks comparing legacy vs sketchlib implementations across sketch types. +#![allow(dead_code)] + +use std::collections::HashMap; + +use clap::Parser; +use sketch_core::config::{self, ImplMode}; +use sketch_core::count_min::CountMinSketch; +use sketch_core::count_min_with_heap::CountMinSketchWithHeap; +use sketch_core::hydra_kll::HydraKllSketch; +use sketch_core::kll::KllSketch; + +#[derive(Clone)] +struct Lcg64 { + state: u64, +} + +impl Lcg64 { + fn new(seed: u64) -> Self { + Self { state: seed } + } + + fn next_u64(&mut self) -> u64 { + self.state = self + .state + .wrapping_mul(6364136223846793005) + .wrapping_add(1442695040888963407); + self.state + } + + fn next_f64_0_1(&mut self) -> f64 { + let x = self.next_u64() >> 11; + (x as f64) / ((1u64 << 53) as f64) + } +} + +fn pearson_corr(exact: &[f64], est: &[f64]) -> f64 { + let n = exact.len().min(est.len()); + if n == 0 { + return f64::NAN; + } + let (mut sum_x, mut sum_y) = (0.0, 0.0); + for i in 0..n { + sum_x += exact[i]; + sum_y += est[i]; + } + let mean_x = sum_x / (n as f64); + let mean_y = sum_y / (n as f64); + let (mut num, mut den_x, mut den_y) = (0.0, 0.0, 0.0); + for i in 0..n { + let dx = exact[i] - mean_x; + let dy = est[i] - mean_y; + num += dx * dy; + den_x += dx * dx; + den_y += dy * dy; + } + if den_x == 0.0 || den_y == 0.0 { + return f64::NAN; + } + num / (den_x.sqrt() * den_y.sqrt()) +} + +fn mape(exact: &[f64], est: &[f64]) -> f64 { + let n = exact.len().min(est.len()); + let mut num = 0.0; + let mut denom = 0.0; + for i in 0..n { + if exact[i] == 0.0 { + continue; + } + num += ((exact[i] - est[i]) / exact[i]).abs(); + denom += 1.0; + } + if denom == 0.0 { + return if exact == est { 0.0 } else { f64::INFINITY }; + } + (num / 
denom) * 100.0 +} + +fn rmse_percentage(exact: &[f64], est: &[f64]) -> f64 { + let n = exact.len().min(est.len()); + let mut sum_sq = 0.0; + let mut denom = 0.0; + for i in 0..n { + if exact[i] == 0.0 { + continue; + } + let rel = (exact[i] - est[i]) / exact[i]; + sum_sq += rel * rel; + denom += 1.0; + } + if denom == 0.0 { + return if exact == est { 0.0 } else { f64::INFINITY }; + } + (sum_sq / denom).sqrt() * 100.0 +} + +#[derive(Parser)] +struct Args { + #[arg(long, value_enum, default_value_t = sketch_core::config::DEFAULT_CMS_IMPL)] + cms_impl: ImplMode, + #[arg(long, value_enum, default_value_t = sketch_core::config::DEFAULT_KLL_IMPL)] + kll_impl: ImplMode, + #[arg(long, value_enum, default_value_t = sketch_core::config::DEFAULT_CMWH_IMPL)] + cmwh_impl: ImplMode, +} + +fn rank_fraction(sorted: &[f64], x: f64) -> f64 { + if sorted.is_empty() { + return 0.0; + } + let idx = sorted.partition_point(|v| *v <= x); + (idx as f64) / (sorted.len() as f64) +} + +// --- CountMinSketch parameter sets and runner --- + +struct CmsParams { + depth: usize, + width: usize, + n: usize, + domain: usize, +} + +struct CmsResult { + pearson: f64, + mape: f64, + rmse: f64, +} + +fn run_countmin_once(seed: u64, p: &CmsParams) -> CmsResult { + let mut rng = Lcg64::new(seed); + let mut exact: Vec = vec![0.0; p.domain]; + let mut cms = CountMinSketch::new(p.depth, p.width); + + for _ in 0..p.n { + let r = rng.next_u64(); + let key_id = if (r & 0xFF) < 200 { + (r as usize) % 20 + } else { + (r as usize) % p.domain + }; + let key = format!("k{key_id}"); + cms.update(&key, 1.0); + exact[key_id] += 1.0; + } + + let mut est: Vec = Vec::with_capacity(p.domain); + for key_id in 0..p.domain { + let key = format!("k{key_id}"); + est.push(cms.query_key(&key)); + } + + CmsResult { + pearson: pearson_corr(&exact, &est), + mape: mape(&exact, &est), + rmse: rmse_percentage(&exact, &est), + } +} + +// --- CountMinSketchWithHeap --- + +struct CmwhParams { + depth: usize, + width: usize, + n: usize, + 
domain: usize, + heap_size: usize, +} + +struct CmwhResult { + topk_recall: f64, + pearson: f64, + mape: f64, + rmse: f64, +} + +fn run_countmin_with_heap_once(seed: u64, p: &CmwhParams) -> CmwhResult { + let mut rng = Lcg64::new(seed ^ 0xA5A5_A5A5); + let mut exact: Vec = vec![0.0; p.domain]; + let mut cms = CountMinSketchWithHeap::new(p.depth, p.width, p.heap_size); + + for _ in 0..p.n { + let r = rng.next_u64(); + let key_id = if (r & 0xFF) < 200 { + (r as usize) % 20 + } else { + (r as usize) % p.domain + }; + let key = format!("k{key_id}"); + cms.update(&key, 1.0); + exact[key_id] += 1.0; + } + + let mut exact_pairs: Vec<(usize, f64)> = exact.iter().copied().enumerate().collect(); + exact_pairs.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap()); + exact_pairs.truncate(p.heap_size); + + let exact_top: HashMap = exact_pairs + .into_iter() + .map(|(k, v)| (format!("k{k}"), v)) + .collect(); + + let mut est_vals = Vec::with_capacity(exact_top.len()); + let mut exact_vals = Vec::with_capacity(exact_top.len()); + let mut hit = 0usize; + for item in cms.topk_heap_items() { + if exact_top.contains_key(&item.key) { + hit += 1; + } + } + for (k, v) in &exact_top { + exact_vals.push(*v); + est_vals.push(cms.query_key(k)); + } + + CmwhResult { + topk_recall: (hit as f64) / (p.heap_size as f64), + pearson: pearson_corr(&exact_vals, &est_vals), + mape: mape(&exact_vals, &est_vals), + rmse: rmse_percentage(&exact_vals, &est_vals), + } +} + +// --- KllSketch --- + +struct KllParams { + k: u16, + n: usize, +} + +struct KllResult { + rank_err_50: f64, + rank_err_90: f64, + rank_err_99: f64, +} + +fn run_kll_once(seed: u64, p: &KllParams) -> KllResult { + let mut rng = Lcg64::new(seed ^ 0x1234_5678); + let mut values: Vec = Vec::with_capacity(p.n); + let mut sk = KllSketch::new(p.k); + + for _ in 0..p.n { + let v = rng.next_f64_0_1() * 1_000_000.0; + values.push(v); + sk.update(v); + } + + values.sort_by(f64::total_cmp); + let qs = [0.5, 0.9, 0.99]; + let rank_err = |q: f64| 
(rank_fraction(&values, sk.get_quantile(q)) - q).abs(); + + KllResult { + rank_err_50: rank_err(qs[0]), + rank_err_90: rank_err(qs[1]), + rank_err_99: rank_err(qs[2]), + } +} + +// --- HydraKllSketch --- + +struct HydraKllParams { + rows: usize, + cols: usize, + k: u16, + n: usize, + domain: usize, + eval_keys: usize, +} + +struct HydraKllResult { + mean_50: f64, + max_50: f64, + mean_90: f64, + max_90: f64, +} + +fn run_hydra_kll_once(seed: u64, p: &HydraKllParams) -> HydraKllResult { + let mut rng = Lcg64::new(seed ^ 0xDEAD_BEEF); + let mut hydra = HydraKllSketch::new(p.rows, p.cols, p.k); + let mut exact: HashMap> = HashMap::new(); + + for _ in 0..p.n { + let r = rng.next_u64(); + let key_id = if (r & 0xFF) < 200 { + (r as usize) % 20 + } else { + (r as usize) % p.domain + }; + let key = format!("k{key_id}"); + let v = rng.next_f64_0_1() * 1_000_000.0; + hydra.update(&key, v); + exact.entry(key).or_default().push(v); + } + + let mut keys: Vec = exact.keys().cloned().collect(); + keys.sort(); + keys.truncate(p.eval_keys); + + let mut mean_50 = 0.0f64; + let mut max_50 = 0.0f64; + let mut mean_90 = 0.0f64; + let mut max_90 = 0.0f64; + let nk = keys.len() as f64; + for key in &keys { + let mut vals = exact.get(key).cloned().unwrap_or_default(); + vals.sort_by(f64::total_cmp); + for (q, mean_ref, max_ref) in [ + (0.5, &mut mean_50, &mut max_50), + (0.9, &mut mean_90, &mut max_90), + ] { + let est = hydra.query(key, q); + let err = (rank_fraction(&vals, est) - q).abs(); + *mean_ref += err; + if err > *max_ref { + *max_ref = err; + } + } + } + mean_50 /= nk; + mean_90 /= nk; + + HydraKllResult { + mean_50, + max_50, + mean_90, + max_90, + } +} + +fn main() { + let args = Args::parse(); + config::configure(args.cms_impl, args.kll_impl, args.cmwh_impl) + .expect("sketch backend already initialised"); + + let seed = 0xC0FFEE_u64; + let cms_mode = if matches!(args.cms_impl, ImplMode::Legacy) { + "Legacy" + } else { + "asap_sketchlib" + }; + let cmwh_mode = if 
matches!(args.cmwh_impl, ImplMode::Legacy) { + "Legacy" + } else { + "asap_sketchlib" + }; + let kll_mode = if matches!(args.kll_impl, ImplMode::Legacy) { + "Legacy" + } else { + "asap_sketchlib" + }; + + // CountMinSketch: multiple (depth, width, n, domain) + let cms_param_sets: Vec = vec![ + CmsParams { + depth: 3, + width: 1024, + n: 100_000, + domain: 1000, + }, + CmsParams { + depth: 5, + width: 2048, + n: 200_000, + domain: 2000, + }, + CmsParams { + depth: 7, + width: 4096, + n: 200_000, + domain: 2000, + }, + CmsParams { + depth: 5, + width: 2048, + n: 50_000, + domain: 500, + }, + ]; + + println!("## CountMinSketch ({cms_mode})"); + println!("| depth | width | n_updates | domain | Pearson corr | MAPE (%) | RMSE (%) |"); + println!("|-------|-------|------------|--------|--------------|----------|----------|"); + for p in &cms_param_sets { + let r = run_countmin_once(seed, p); + println!( + "| {} | {} | {} | {} | {:.10} | {:.6} | {:.6} |", + p.depth, p.width, p.n, p.domain, r.pearson, r.mape, r.rmse + ); + } + + // CountMinSketchWithHeap + let cmwh_param_sets: Vec = vec![ + CmwhParams { + depth: 3, + width: 1024, + n: 100_000, + domain: 1000, + heap_size: 10, + }, + CmwhParams { + depth: 5, + width: 2048, + n: 200_000, + domain: 2000, + heap_size: 20, + }, + CmwhParams { + depth: 5, + width: 2048, + n: 200_000, + domain: 2000, + heap_size: 50, + }, + ]; + + println!("\n## CountMinSketchWithHeap ({cmwh_mode})"); + println!("| depth | width | n | domain | heap_size | Top-k recall | Pearson (top-k) | MAPE (%) | RMSE (%) |"); + println!("|-------|-------|-----|--------|-----------|--------------|-----------------|----------|----------|"); + for p in &cmwh_param_sets { + let r = run_countmin_with_heap_once(seed, p); + println!( + "| {} | {} | {} | {} | {} | {:.4} | {:.10} | {:.6} | {:.6} |", + p.depth, p.width, p.n, p.domain, p.heap_size, r.topk_recall, r.pearson, r.mape, r.rmse + ); + } + // KllSketch + let kll_param_sets: Vec = vec![ + KllParams { k: 20, n: 
200_000 }, + KllParams { k: 50, n: 200_000 }, + KllParams { k: 200, n: 200_000 }, + KllParams { k: 20, n: 50_000 }, + ]; + + println!("\n## KllSketch ({kll_mode})"); + println!( + "| k | n_updates | q=0.5 abs_rank_error | q=0.9 abs_rank_error | q=0.99 abs_rank_error |" + ); + println!( + "|---|-----------|----------------------|----------------------|-----------------------|" + ); + for p in &kll_param_sets { + let r = run_kll_once(seed, p); + println!( + "| {} | {} | {:.6} | {:.6} | {:.6} |", + p.k, p.n, r.rank_err_50, r.rank_err_90, r.rank_err_99 + ); + } + + // HydraKllSketch + let hydra_param_sets: Vec = vec![ + HydraKllParams { + rows: 2, + cols: 64, + k: 20, + n: 200_000, + domain: 200, + eval_keys: 50, + }, + HydraKllParams { + rows: 3, + cols: 128, + k: 20, + n: 200_000, + domain: 200, + eval_keys: 50, + }, + HydraKllParams { + rows: 3, + cols: 128, + k: 50, + n: 200_000, + domain: 200, + eval_keys: 50, + }, + HydraKllParams { + rows: 3, + cols: 128, + k: 20, + n: 100_000, + domain: 100, + eval_keys: 50, + }, + ]; + + println!("\n## HydraKllSketch ({kll_mode})"); + println!("| rows | cols | k | n | domain | q=0.5 mean/max | q=0.9 mean/max |"); + println!("|------|------|---|-----|--------|----------------|----------------|"); + for p in &hydra_param_sets { + let r = run_hydra_kll_once(seed, p); + println!( + "| {} | {} | {} | {} | {} | {:.5} / {:.5} | {:.5} / {:.5} |", + p.rows, p.cols, p.k, p.n, p.domain, r.mean_50, r.max_50, r.mean_90, r.max_90 + ); + } +} diff --git a/asap-common/sketch-core/src/config.rs b/asap-common/sketch-core/src/config.rs new file mode 100644 index 0000000..b7c6abc --- /dev/null +++ b/asap-common/sketch-core/src/config.rs @@ -0,0 +1,83 @@ +use std::sync::OnceLock; + +/// Implementation mode for sketch-core internals. +#[derive(Debug, Clone, Copy, PartialEq, Eq, clap::ValueEnum)] +pub enum ImplMode { + /// Use the original hand-written implementations. + Legacy, + /// Use asap_sketchlib backed implementations. 
+ Sketchlib, +} + +/// Global default when impl mode is not explicitly configured (e.g. env var parsing). +pub const DEFAULT_IMPL_MODE: ImplMode = ImplMode::Sketchlib; + +/// Per-backend defaults. Used when configure() has not been called. +pub const DEFAULT_CMS_IMPL: ImplMode = ImplMode::Sketchlib; +pub const DEFAULT_KLL_IMPL: ImplMode = ImplMode::Sketchlib; +pub const DEFAULT_CMWH_IMPL: ImplMode = ImplMode::Sketchlib; + +static COUNTMIN_MODE: OnceLock = OnceLock::new(); + +/// Returns true if Count-Min operations should use asap_sketchlib internally. +pub fn use_sketchlib_for_count_min() -> bool { + *COUNTMIN_MODE.get_or_init(|| DEFAULT_CMS_IMPL) == ImplMode::Sketchlib +} + +static KLL_MODE: OnceLock = OnceLock::new(); + +/// Returns true if KLL operations should use asap_sketchlib internally. +pub fn use_sketchlib_for_kll() -> bool { + *KLL_MODE.get_or_init(|| DEFAULT_KLL_IMPL) == ImplMode::Sketchlib +} + +static COUNTMIN_WITH_HEAP_MODE: OnceLock = OnceLock::new(); + +/// Returns true if Count-Min-With-Heap operations should use asap_sketchlib internally for the +/// Count-Min portion. +pub fn use_sketchlib_for_count_min_with_heap() -> bool { + *COUNTMIN_WITH_HEAP_MODE.get_or_init(|| DEFAULT_CMWH_IMPL) == ImplMode::Sketchlib +} + +/// Set backend modes for all sketch types. Call once at process startup, +/// before any sketch operation. Returns Err if any OnceLock was already set. 
+pub fn configure(cms: ImplMode, kll: ImplMode, cmwh: ImplMode) -> Result<(), &'static str> { + let a = COUNTMIN_MODE.set(cms); + let b = KLL_MODE.set(kll); + let c = COUNTMIN_WITH_HEAP_MODE.set(cmwh); + if a.is_err() || b.is_err() || c.is_err() { + Err("configure() called after sketch backends were already initialised") + } else { + Ok(()) + } +} + +pub fn force_legacy_mode_for_tests() { + let _ = COUNTMIN_MODE.set(ImplMode::Legacy); + let _ = KLL_MODE.set(ImplMode::Legacy); + let _ = COUNTMIN_WITH_HEAP_MODE.set(ImplMode::Legacy); +} + +/// Helper used by UDF templates and documentation examples to parse implementation mode +/// from environment variables in a robust way. This is not used in the hot path. +pub fn parse_mode(var: Result) -> ImplMode { + match var { + Ok(v) => match v.to_ascii_lowercase().as_str() { + "legacy" => ImplMode::Legacy, + "sketchlib" => ImplMode::Sketchlib, + other => { + eprintln!( + "sketch-core: unrecognised IMPL value {other:?}, defaulting to {DEFAULT_IMPL_MODE:?}" + ); + DEFAULT_IMPL_MODE + } + }, + Err(std::env::VarError::NotPresent) => DEFAULT_IMPL_MODE, + Err(std::env::VarError::NotUnicode(v)) => { + eprintln!( + "sketch-core: IMPL env var has invalid UTF-8 ({v:?}), defaulting to {DEFAULT_IMPL_MODE:?}" + ); + DEFAULT_IMPL_MODE + } + } +} diff --git a/asap-common/sketch-core/src/count_min.rs b/asap-common/sketch-core/src/count_min.rs new file mode 100644 index 0000000..d0ebc13 --- /dev/null +++ b/asap-common/sketch-core/src/count_min.rs @@ -0,0 +1,423 @@ +// Adapted from QueryEngineRust/src/precompute_operators/count_min_sketch_accumulator.rs +// Changes: +// - Renamed CountMinSketchAccumulator -> CountMinSketch +// - _update(&KeyByLabelValues) -> pub update(&str) (caller does key-to-string conversion) +// - query_key(&KeyByLabelValues) -> query_key(&str) +// - serialize_to_bytes (trait) -> serialize_msgpack (inherent method) +// - deserialize_from_bytes_arroyo -> deserialize_msgpack +// - merge_accumulators -> merge +// - Removed: 
deserialize_from_json, deserialize_from_bytes (legacy QE formats, stay in QE) +// - Removed: merge_multiple (QE trait-object helper, stays in QE) +// - Removed: AggregateCore, SerializableToSink, MergeableAccumulator, MultipleSubpopulationAggregate impls +// - Added: aggregate_count() / aggregate_sum() one-shot helpers for Arroyo call pattern + +use serde::{Deserialize, Serialize}; +use xxhash_rust::xxh32::xxh32; + +use crate::config::use_sketchlib_for_count_min; +use crate::count_min_sketchlib::{ + matrix_from_sketchlib_cms, new_sketchlib_cms, sketchlib_cms_from_matrix, sketchlib_cms_query, + sketchlib_cms_update, SketchlibCms, +}; + +#[derive(Serialize, Deserialize)] +struct WireFormat { + sketch: Vec>, + row_num: usize, + col_num: usize, +} + +/// Backend implementation for Count-Min Sketch. Only one is active at a time. +#[derive(Debug, Clone)] +pub enum CountMinBackend { + /// Original hand-written matrix implementation. + Legacy(Vec>), + /// asap_sketchlib backed implementation. + Sketchlib(SketchlibCms), +} + +/// Count-Min Sketch probabilistic data structure for frequency counting. +/// Provides approximate frequency counts with error bounds. +/// This is the canonical shared implementation; the msgpack wire format is the +/// contract between Arroyo UDAFs (producers) and QueryEngineRust (consumer). +#[derive(Debug, Clone)] +pub struct CountMinSketch { + pub row_num: usize, + pub col_num: usize, + pub backend: CountMinBackend, +} + +impl CountMinSketch { + pub fn new(row_num: usize, col_num: usize) -> Self { + let backend = if use_sketchlib_for_count_min() { + CountMinBackend::Sketchlib(new_sketchlib_cms(row_num, col_num)) + } else { + CountMinBackend::Legacy(vec![vec![0.0; col_num]; row_num]) + }; + Self { + row_num, + col_num, + backend, + } + } + + /// Returns the sketch matrix (for wire format, serialization, tests). 
+ pub fn sketch(&self) -> Vec> { + match &self.backend { + CountMinBackend::Legacy(m) => m.clone(), + CountMinBackend::Sketchlib(s) => matrix_from_sketchlib_cms(s), + } + } + + /// Mutable access to the matrix. Only `Some` for Legacy backend. + pub fn sketch_mut(&mut self) -> Option<&mut Vec>> { + match &mut self.backend { + CountMinBackend::Legacy(m) => Some(m), + CountMinBackend::Sketchlib(_) => None, + } + } + + /// Construct from a legacy matrix (used by deserialization and query engine). + pub fn from_legacy_matrix(sketch: Vec>, row_num: usize, col_num: usize) -> Self { + let backend = if use_sketchlib_for_count_min() { + CountMinBackend::Sketchlib(sketchlib_cms_from_matrix(row_num, col_num, &sketch)) + } else { + CountMinBackend::Legacy(sketch) + }; + Self { + row_num, + col_num, + backend, + } + } + + pub fn update(&mut self, key: &str, value: f64) { + match &mut self.backend { + CountMinBackend::Legacy(sketch) => { + let key_bytes = key.as_bytes(); + for (i, row) in sketch.iter_mut().enumerate().take(self.row_num) { + let hash_value = xxh32(key_bytes, i as u32); + let col_index = (hash_value as usize) % self.col_num; + row[col_index] += value; + } + } + CountMinBackend::Sketchlib(s) => { + sketchlib_cms_update(s, key, value); + } + } + } + + pub fn query_key(&self, key: &str) -> f64 { + match &self.backend { + CountMinBackend::Legacy(sketch) => { + let key_bytes = key.as_bytes(); + let mut min_value = f64::MAX; + for (i, row) in sketch.iter().enumerate().take(self.row_num) { + let hash_value = xxh32(key_bytes, i as u32); + let col_index = (hash_value as usize) % self.col_num; + min_value = min_value.min(row[col_index]); + } + min_value + } + CountMinBackend::Sketchlib(s) => sketchlib_cms_query(s, key), + } + } + + pub fn merge( + accumulators: Vec, + ) -> Result> { + if accumulators.is_empty() { + return Err("No accumulators to merge".into()); + } + + if accumulators.len() == 1 { + return Ok(accumulators.into_iter().next().unwrap()); + } + + // Check that 
all accumulators have the same dimensions + let row_num = accumulators[0].row_num; + let col_num = accumulators[0].col_num; + + for acc in &accumulators { + if acc.row_num != row_num || acc.col_num != col_num { + return Err( + "Cannot merge CountMinSketch accumulators with different dimensions".into(), + ); + } + } + + if use_sketchlib_for_count_min() { + let mut sketchlib_inners: Vec = Vec::with_capacity(accumulators.len()); + for acc in accumulators { + let matrix = acc.sketch(); + let inner = sketchlib_cms_from_matrix(acc.row_num, acc.col_num, &matrix); + sketchlib_inners.push(inner); + } + let merged_sketchlib = sketchlib_inners + .into_iter() + .reduce(|mut lhs: SketchlibCms, rhs: SketchlibCms| { + lhs.merge(&rhs); + lhs + }) + .ok_or("No accumulators to merge")?; + + let sketch = matrix_from_sketchlib_cms(&merged_sketchlib); + let row_num = sketch.len(); + let col_num = sketch.first().map(|r| r.len()).unwrap_or(0); + + Ok(Self { + row_num, + col_num, + backend: CountMinBackend::Sketchlib(merged_sketchlib), + }) + } else { + let mut merged = accumulators[0].clone(); + for acc in &accumulators[1..] { + let acc_matrix = acc.sketch(); + if let CountMinBackend::Legacy(merged_matrix) = &mut merged.backend { + for (merged_row, acc_row) in merged_matrix.iter_mut().zip(acc_matrix.iter()) { + for (m_cell, a_cell) in merged_row.iter_mut().zip(acc_row.iter()) { + *m_cell += *a_cell; + } + } + } + } + Ok(merged) + } + } + + /// Merge from references, allocating only the output — no input clones. 
+ pub fn merge_refs( + accumulators: &[&Self], + ) -> Result> { + if accumulators.is_empty() { + return Err("No accumulators to merge".into()); + } + + let row_num = accumulators[0].row_num; + let col_num = accumulators[0].col_num; + + for acc in accumulators { + if acc.row_num != row_num || acc.col_num != col_num { + return Err( + "Cannot merge CountMinSketch accumulators with different dimensions".into(), + ); + } + } + + if use_sketchlib_for_count_min() { + let mut sketchlib_inners: Vec = Vec::with_capacity(accumulators.len()); + for acc in accumulators { + let acc_matrix = acc.sketch(); + let matrix_has_values = acc_matrix + .iter() + .any(|row: &Vec| row.iter().any(|&v| v != 0.0)); + + let inner = if matrix_has_values { + sketchlib_cms_from_matrix(acc.row_num, acc.col_num, &acc_matrix) + } else if let CountMinBackend::Sketchlib(s) = &acc.backend { + s.clone() + } else { + sketchlib_cms_from_matrix(acc.row_num, acc.col_num, &acc_matrix) + }; + + sketchlib_inners.push(inner); + } + + let merged_sketchlib = sketchlib_inners + .into_iter() + .reduce(|mut lhs: SketchlibCms, rhs: SketchlibCms| { + lhs.merge(&rhs); + lhs + }) + .ok_or("No accumulators to merge")?; + + let sketch = matrix_from_sketchlib_cms(&merged_sketchlib); + let r = sketch.len(); + let c = sketch.first().map(|row| row.len()).unwrap_or(0); + + Ok(Self { + row_num: r, + col_num: c, + backend: CountMinBackend::Sketchlib(merged_sketchlib), + }) + } else { + let mut merged = Self::new(row_num, col_num); + if let CountMinBackend::Legacy(ref mut merged_sketch) = merged.backend { + for acc in accumulators { + let acc_matrix = acc.sketch(); + for (merged_row, acc_row) in merged_sketch.iter_mut().zip(acc_matrix.iter()) { + for (m_cell, a_cell) in merged_row.iter_mut().zip(acc_row.iter()) { + *m_cell += *a_cell; + } + } + } + } + Ok(merged) + } + } + + /// Serialize to MessagePack — matches the Arroyo UDF wire format exactly. 
+ pub fn serialize_msgpack(&self) -> Vec { + let sketch = self.sketch(); + let wire = WireFormat { + sketch, + row_num: self.row_num, + col_num: self.col_num, + }; + + let mut buf = Vec::new(); + wire.serialize(&mut rmp_serde::Serializer::new(&mut buf)) + .unwrap(); + buf + } + + /// Deserialize from MessagePack produced by the Arroyo UDF. + pub fn deserialize_msgpack(buffer: &[u8]) -> Result> { + let wire: WireFormat = + rmp_serde::from_slice(buffer).map_err(|e| -> Box { + format!("Failed to deserialize CountMinSketch from MessagePack: {e}").into() + })?; + + let backend = if use_sketchlib_for_count_min() { + CountMinBackend::Sketchlib(sketchlib_cms_from_matrix( + wire.row_num, + wire.col_num, + &wire.sketch, + )) + } else { + CountMinBackend::Legacy(wire.sketch) + }; + + Ok(Self { + row_num: wire.row_num, + col_num: wire.col_num, + backend, + }) + } + + /// One-shot aggregation for the Arroyo UDAF call pattern: build a sketch from + /// parallel key/value slices and return the msgpack bytes. + pub fn aggregate_count( + depth: usize, + width: usize, + keys: &[&str], + values: &[f64], + ) -> Option> { + if keys.is_empty() { + return None; + } + let mut sketch = Self::new(depth, width); + for (key, &value) in keys.iter().zip(values.iter()) { + sketch.update(key, value); + } + Some(sketch.serialize_msgpack()) + } + + /// Same as aggregate_count — CMS accumulates sums by construction. 
+ pub fn aggregate_sum( + depth: usize, + width: usize, + keys: &[&str], + values: &[f64], + ) -> Option> { + Self::aggregate_count(depth, width, keys, values) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_count_min_sketch_creation() { + let cms = CountMinSketch::new(4, 1000); + assert_eq!(cms.row_num, 4); + assert_eq!(cms.col_num, 1000); + let sketch = cms.sketch(); + assert_eq!(sketch.len(), 4); + assert_eq!(sketch[0].len(), 1000); + + // Check all values are initialized to 0 + for row in &sketch { + for &value in row { + assert_eq!(value, 0.0); + } + } + } + + #[test] + fn test_count_min_sketch_update() { + let mut cms = CountMinSketch::new(2, 10); + cms.update("key1", 1.0); + // Query should return at least the updated value + let result = cms.query_key("key1"); + assert!(result >= 1.0); + } + + #[test] + fn test_count_min_sketch_query_empty() { + let cms = CountMinSketch::new(2, 10); + assert_eq!(cms.query_key("anything"), 0.0); + } + + #[test] + fn test_count_min_sketch_merge() { + // Use from_legacy_matrix so the test works regardless of sketchlib/legacy config + let mut sketch1 = vec![vec![0.0; 3]; 2]; + sketch1[0][0] = 5.0; + sketch1[1][2] = 10.0; + let cms1 = CountMinSketch::from_legacy_matrix(sketch1, 2, 3); + + let mut sketch2 = vec![vec![0.0; 3]; 2]; + sketch2[0][0] = 3.0; + sketch2[0][1] = 7.0; + let cms2 = CountMinSketch::from_legacy_matrix(sketch2, 2, 3); + + let merged = CountMinSketch::merge(vec![cms1, cms2]).unwrap(); + let merged_sketch = merged.sketch(); + + assert_eq!(merged_sketch[0][0], 8.0); // 5 + 3 + assert_eq!(merged_sketch[0][1], 7.0); // 0 + 7 + assert_eq!(merged_sketch[1][2], 10.0); // 10 + 0 + } + + #[test] + fn test_count_min_sketch_merge_dimension_mismatch() { + let cms1 = CountMinSketch::new(2, 3); + let cms2 = CountMinSketch::new(3, 3); + assert!(CountMinSketch::merge(vec![cms1, cms2]).is_err()); + } + + #[test] + fn test_count_min_sketch_msgpack_round_trip() { + let mut cms = CountMinSketch::new(4, 
256); + cms.update("apple", 5.0); + cms.update("banana", 3.0); + cms.update("apple", 2.0); // total "apple" = 7 + + let bytes = cms.serialize_msgpack(); + let deserialized = CountMinSketch::deserialize_msgpack(&bytes).unwrap(); + + assert_eq!(deserialized.row_num, 4); + assert_eq!(deserialized.col_num, 256); + assert!(deserialized.query_key("apple") >= 7.0); + assert!(deserialized.query_key("banana") >= 3.0); + } + + #[test] + fn test_aggregate_count() { + let keys = ["a", "b", "a"]; + let values = [1.0, 2.0, 3.0]; + let bytes = CountMinSketch::aggregate_count(4, 100, &keys, &values).unwrap(); + let cms = CountMinSketch::deserialize_msgpack(&bytes).unwrap(); + // "a" was updated twice (1.0 + 3.0 = 4.0), "b" once (2.0) + assert!(cms.query_key("a") >= 4.0); + assert!(cms.query_key("b") >= 2.0); + } + + #[test] + fn test_aggregate_count_empty() { + assert!(CountMinSketch::aggregate_count(4, 100, &[], &[]).is_none()); + } +} diff --git a/asap-common/sketch-core/src/count_min_sketchlib.rs b/asap-common/sketch-core/src/count_min_sketchlib.rs new file mode 100644 index 0000000..6950580 --- /dev/null +++ b/asap-common/sketch-core/src/count_min_sketchlib.rs @@ -0,0 +1,59 @@ +use asap_sketchlib::{CountMin, DataInput, RegularPath, Vector2D}; + +/// Concrete Count-Min type from asap_sketchlib when sketchlib backend is enabled. +/// Uses f64 counters (Vector2D) for weighted updates without integer rounding. +pub type SketchlibCms = CountMin, RegularPath>; + +/// Creates a fresh sketchlib Count-Min sketch with the given dimensions. +pub fn new_sketchlib_cms(row_num: usize, col_num: usize) -> SketchlibCms { + SketchlibCms::with_dimensions(row_num, col_num) +} + +/// Builds a sketchlib Count-Min sketch from an existing `sketch` matrix. 
+pub fn sketchlib_cms_from_matrix(
+    row_num: usize,
+    col_num: usize,
+    sketch: &[Vec<f64>],
+) -> SketchlibCms {
+    // Out-of-range cells default to 0.0, so a ragged/short input matrix is tolerated.
+    let matrix = Vector2D::from_fn(row_num, col_num, |r, c| {
+        sketch
+            .get(r)
+            .and_then(|row| row.get(c))
+            .copied()
+            .unwrap_or(0.0)
+    });
+    SketchlibCms::from_storage(matrix)
+}
+
+/// Converts a sketchlib Count-Min sketch into the legacy `Vec<Vec<f64>>` matrix.
+pub fn matrix_from_sketchlib_cms(inner: &SketchlibCms) -> Vec<Vec<f64>> {
+    let storage: &Vector2D<f64> = inner.as_storage();
+    let rows = storage.rows();
+    let cols = storage.cols();
+    let mut sketch = vec![vec![0.0; cols]; rows];
+
+    for (r, row) in sketch.iter_mut().enumerate().take(rows) {
+        for (c, cell) in row.iter_mut().enumerate().take(cols) {
+            if let Some(v) = storage.get(r, c) {
+                *cell = *v;
+            }
+        }
+    }
+
+    sketch
+}
+
+/// Helper to update a sketchlib Count-Min with a weighted key.
+/// Non-positive weights are ignored (no-op).
+pub fn sketchlib_cms_update(inner: &mut SketchlibCms, key: &str, value: f64) {
+    if value <= 0.0 {
+        return;
+    }
+    let input = DataInput::String(key.to_owned());
+    inner.insert_many(&input, value);
+}
+
+/// Helper to query a sketchlib Count-Min for a key, returning f64.
+pub fn sketchlib_cms_query(inner: &SketchlibCms, key: &str) -> f64 {
+    let input = DataInput::String(key.to_owned());
+    inner.estimate(&input)
+}
diff --git a/asap-common/sketch-core/src/count_min_with_heap.rs b/asap-common/sketch-core/src/count_min_with_heap.rs
new file mode 100644
index 0000000..f02efc7
--- /dev/null
+++ b/asap-common/sketch-core/src/count_min_with_heap.rs
@@ -0,0 +1,597 @@
+// Adapted from QueryEngineRust/src/precompute_operators/count_min_sketch_with_heap_accumulator.rs
+// Changes:
+// - Renamed CountMinSketchWithHeapAccumulator -> CountMinSketchWithHeap
+// - Inner CmsData helper renamed to avoid name collision with count_min::CountMinSketch
+// - update() takes &str instead of &KeyByLabelValues
+// - query_key() takes &str
+// - serialize_to_bytes (trait) -> serialize_msgpack (inherent method)
+// - deserialize_from_bytes_arroyo -> deserialize_msgpack
+// - merge_accumulators -> merge
+// - Removed: deserialize_from_json, deserialize_from_bytes (legacy QE formats, stay in QE)
+// - Removed: AggregateCore, SerializableToSink, MergeableAccumulator, MultipleSubpopulationAggregate impls
+// - Removed: get_topk_keys (returns KeyByLabelValues — QE-specific)
+// - Added: insert_or_update_heap helper, aggregate_topk() one-shot helper
+// - Refactored to enum-based backend (Legacy vs Sketchlib)
+//
+// NOTE (bug, do not fix): QueryEngineRust uses xxhash-rust::xxh32; the Arroyo template uses
+// twox-hash::XxHash32. Bucket assignments differ, so query results will be wrong until the
+// hash crate mismatch is resolved. Tracked separately.
+ +use serde::{Deserialize, Serialize}; +use std::collections::HashSet; +use xxhash_rust::xxh32::xxh32; + +use crate::config::use_sketchlib_for_count_min_with_heap; +use crate::count_min_with_heap_sketchlib::{ + heap_to_wire, matrix_from_sketchlib_cms_heap, new_sketchlib_cms_heap, + sketchlib_cms_heap_from_matrix_and_heap, sketchlib_cms_heap_query, sketchlib_cms_heap_update, + SketchlibCMSHeap, WireHeapItem, +}; + +/// Item in the top-k heap representing a key-value pair. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct HeapItem { + pub key: String, + pub value: f64, +} + +/// Helper struct matching Arroyo's nested serialization format (inner CMS). +#[derive(Debug, Clone, Serialize, Deserialize)] +struct CmsData { + sketch: Vec>, + row_num: usize, + col_num: usize, +} + +/// Helper struct matching Arroyo's serialization format (outer wrapper). +#[derive(Debug, Clone, Serialize, Deserialize)] +struct CountMinSketchWithHeapSerialized { + sketch: CmsData, + topk_heap: Vec, + heap_size: usize, +} + +/// Backend implementation for Count-Min Sketch with Heap. Only one is active at a time. +pub enum CountMinWithHeapBackend { + /// Legacy implementation: matrix + local heap. + Legacy { + sketch: Vec>, + heap: Vec, + }, + /// asap_sketchlib CMSHeap implementation. + Sketchlib(SketchlibCMSHeap), +} + +impl std::fmt::Debug for CountMinWithHeapBackend { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + CountMinWithHeapBackend::Legacy { sketch, heap } => f + .debug_struct("Legacy") + .field("sketch", sketch) + .field("heap", heap) + .finish(), + CountMinWithHeapBackend::Sketchlib(_) => write!(f, "Sketchlib(..)"), + } + } +} + +/// Count-Min Sketch with Heap for top-k tracking. +/// Combines probabilistic frequency counting with efficient top-k maintenance. 
+pub struct CountMinSketchWithHeap { + pub row_num: usize, + pub col_num: usize, + pub heap_size: usize, + pub backend: CountMinWithHeapBackend, +} + +impl std::fmt::Debug for CountMinSketchWithHeap { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("CountMinSketchWithHeap") + .field("row_num", &self.row_num) + .field("col_num", &self.col_num) + .field("heap_size", &self.heap_size) + .field("backend", &self.backend) + .finish() + } +} + +impl Clone for CountMinSketchWithHeap { + fn clone(&self) -> Self { + let backend = match &self.backend { + CountMinWithHeapBackend::Legacy { sketch, heap } => CountMinWithHeapBackend::Legacy { + sketch: sketch.clone(), + heap: heap.clone(), + }, + CountMinWithHeapBackend::Sketchlib(cms_heap) => { + let sketch = matrix_from_sketchlib_cms_heap(cms_heap); + let heap_items: Vec = heap_to_wire(cms_heap) + .into_iter() + .map(|w| HeapItem { + key: w.key, + value: w.value, + }) + .collect(); + let wire_ref: Vec = heap_items + .iter() + .map(|h| WireHeapItem { + key: h.key.clone(), + value: h.value, + }) + .collect(); + CountMinWithHeapBackend::Sketchlib(sketchlib_cms_heap_from_matrix_and_heap( + self.row_num, + self.col_num, + self.heap_size, + &sketch, + &wire_ref, + )) + } + }; + Self { + row_num: self.row_num, + col_num: self.col_num, + heap_size: self.heap_size, + backend, + } + } +} + +impl CountMinSketchWithHeap { + pub fn new(row_num: usize, col_num: usize, heap_size: usize) -> Self { + let backend = if use_sketchlib_for_count_min_with_heap() { + CountMinWithHeapBackend::Sketchlib(new_sketchlib_cms_heap(row_num, col_num, heap_size)) + } else { + CountMinWithHeapBackend::Legacy { + sketch: vec![vec![0.0; col_num]; row_num], + heap: Vec::new(), + } + }; + Self { + row_num, + col_num, + heap_size, + backend, + } + } + + /// Create from legacy matrix and heap (e.g. from JSON deserialization). 
+ pub fn from_legacy_matrix( + sketch: Vec>, + topk_heap: Vec, + row_num: usize, + col_num: usize, + heap_size: usize, + ) -> Self { + Self { + row_num, + col_num, + heap_size, + backend: CountMinWithHeapBackend::Legacy { + sketch, + heap: topk_heap, + }, + } + } + + /// Mutable reference to the sketch matrix. Only valid for Legacy backend. + pub fn sketch_mut(&mut self) -> Option<&mut Vec>> { + match &mut self.backend { + CountMinWithHeapBackend::Legacy { sketch, .. } => Some(sketch), + CountMinWithHeapBackend::Sketchlib(_) => None, + } + } + + /// Get the top-k heap items (works for both backends). + pub fn topk_heap_items(&self) -> Vec { + match &self.backend { + CountMinWithHeapBackend::Legacy { heap, .. } => heap.clone(), + CountMinWithHeapBackend::Sketchlib(cms_heap) => heap_to_wire(cms_heap) + .into_iter() + .map(|w| HeapItem { + key: w.key, + value: w.value, + }) + .collect(), + } + } + + /// Get the sketch matrix (works for both backends). + pub fn sketch_matrix(&self) -> Vec> { + match &self.backend { + CountMinWithHeapBackend::Legacy { sketch, .. 
} => sketch.clone(), + CountMinWithHeapBackend::Sketchlib(cms_heap) => { + matrix_from_sketchlib_cms_heap(cms_heap) + } + } + } + + pub fn update(&mut self, key: &str, value: f64) { + match &mut self.backend { + CountMinWithHeapBackend::Legacy { sketch, heap } => { + let key_bytes = key.as_bytes(); + for (i, row) in sketch.iter_mut().enumerate().take(self.row_num) { + let hash_value = xxh32(key_bytes, i as u32); + let col_index = (hash_value as usize) % self.col_num; + row[col_index] += value; + } + Self::insert_or_update_heap_inline(heap, key, value, self.heap_size); + } + CountMinWithHeapBackend::Sketchlib(cms_heap) => { + sketchlib_cms_heap_update(cms_heap, key, value); + } + } + } + + fn insert_or_update_heap_inline( + heap: &mut Vec, + key: &str, + value: f64, + heap_size: usize, + ) { + if let Some(item) = heap.iter_mut().find(|i| i.key == key) { + item.value += value; + } else if heap.len() < heap_size { + heap.push(HeapItem { + key: key.to_string(), + value, + }); + } else if let Some(min_item) = heap.iter_mut().min_by(|a, b| { + a.value + .partial_cmp(&b.value) + .unwrap_or(std::cmp::Ordering::Equal) + }) { + if value > min_item.value { + *min_item = HeapItem { + key: key.to_string(), + value, + }; + } + } + } + + pub fn query_key(&self, key: &str) -> f64 { + match &self.backend { + CountMinWithHeapBackend::Legacy { sketch, .. 
} => { + let key_bytes = key.as_bytes(); + let mut min_value = f64::MAX; + for (i, row) in sketch.iter().enumerate().take(self.row_num) { + let hash_value = xxh32(key_bytes, i as u32); + let col_index = (hash_value as usize) % self.col_num; + min_value = min_value.min(row[col_index]); + } + min_value + } + CountMinWithHeapBackend::Sketchlib(cms_heap) => sketchlib_cms_heap_query(cms_heap, key), + } + } + + pub fn merge( + accumulators: Vec, + ) -> Result> { + if accumulators.is_empty() { + return Err("No accumulators to merge".into()); + } + + if accumulators.len() == 1 { + return Ok(accumulators.into_iter().next().unwrap()); + } + + let row_num = accumulators[0].row_num; + let col_num = accumulators[0].col_num; + + for acc in &accumulators { + if acc.row_num != row_num || acc.col_num != col_num { + return Err( + "Cannot merge CountMinSketchWithHeap accumulators with different dimensions" + .into(), + ); + } + } + + let min_heap_size = accumulators + .iter() + .map(|acc| acc.heap_size) + .min() + .unwrap_or(0); + + let mut all_keys: HashSet = HashSet::new(); + for acc in &accumulators { + for item in acc.topk_heap_items() { + all_keys.insert(item.key); + } + } + + match &accumulators[0].backend { + CountMinWithHeapBackend::Sketchlib(_) => { + let mut sketchlib_cms_heaps: Vec = + Vec::with_capacity(accumulators.len()); + for acc in accumulators { + let (sketch, heap) = match &acc.backend { + CountMinWithHeapBackend::Legacy { sketch, heap } => { + (sketch.clone(), heap.clone()) + } + CountMinWithHeapBackend::Sketchlib(cms_heap) => ( + matrix_from_sketchlib_cms_heap(cms_heap), + heap_to_wire(cms_heap) + .into_iter() + .map(|w| HeapItem { + key: w.key, + value: w.value, + }) + .collect(), + ), + }; + let wire_heap: Vec = heap + .iter() + .map(|h| WireHeapItem { + key: h.key.clone(), + value: h.value, + }) + .collect(); + sketchlib_cms_heaps.push(sketchlib_cms_heap_from_matrix_and_heap( + acc.row_num, + acc.col_num, + acc.heap_size, + &sketch, + &wire_heap, + )); + } + + 
let merged_sketchlib = sketchlib_cms_heaps + .into_iter() + .reduce(|mut lhs, rhs| { + lhs.merge(&rhs); + lhs + }) + .ok_or("No accumulators to merge")?; + + let _merged_sketch = matrix_from_sketchlib_cms_heap(&merged_sketchlib); + let _heap_items: Vec = heap_to_wire(&merged_sketchlib) + .into_iter() + .map(|w| HeapItem { + key: w.key, + value: w.value, + }) + .collect(); + + Ok(CountMinSketchWithHeap { + row_num, + col_num, + heap_size: min_heap_size, + backend: CountMinWithHeapBackend::Sketchlib(merged_sketchlib), + }) + } + CountMinWithHeapBackend::Legacy { .. } => { + let mut merged_sketch = vec![vec![0.0; col_num]; row_num]; + for acc in &accumulators { + let sketch = match &acc.backend { + CountMinWithHeapBackend::Legacy { sketch, .. } => sketch, + CountMinWithHeapBackend::Sketchlib(_) => { + return Err( + "Cannot mix Legacy and Sketchlib backends when merging".into() + ); + } + }; + for (i, row) in merged_sketch.iter_mut().enumerate() { + for (j, cell) in row.iter_mut().enumerate() { + *cell += sketch[i][j]; + } + } + } + + let temp_merged = Self::from_legacy_matrix( + merged_sketch.clone(), + Vec::new(), + row_num, + col_num, + min_heap_size, + ); + + let mut heap_items: Vec = all_keys + .into_iter() + .map(|key_str| { + let frequency = temp_merged.query_key(&key_str); + HeapItem { + key: key_str, + value: frequency, + } + }) + .collect(); + + heap_items.sort_by(|a, b| b.value.partial_cmp(&a.value).unwrap()); + heap_items.truncate(min_heap_size); + + Ok(CountMinSketchWithHeap { + row_num, + col_num, + heap_size: min_heap_size, + backend: CountMinWithHeapBackend::Legacy { + sketch: merged_sketch, + heap: heap_items, + }, + }) + } + } + } + + pub fn serialize_msgpack(&self) -> Vec { + let (sketch, topk_heap) = (self.sketch_matrix(), self.topk_heap_items()); + + let serialized = CountMinSketchWithHeapSerialized { + sketch: CmsData { + sketch, + row_num: self.row_num, + col_num: self.col_num, + }, + topk_heap, + heap_size: self.heap_size, + }; + + let mut buf = 
Vec::new(); + serialized + .serialize(&mut rmp_serde::Serializer::new(&mut buf)) + .unwrap(); + buf + } + + pub fn deserialize_msgpack(buffer: &[u8]) -> Result> { + let serialized: CountMinSketchWithHeapSerialized = + rmp_serde::from_slice(buffer).map_err(|e| { + format!("Failed to deserialize CountMinSketchWithHeap from MessagePack: {e}") + })?; + + let mut sorted_topk_heap = serialized.topk_heap; + sorted_topk_heap.sort_by(|a, b| b.value.partial_cmp(&a.value).unwrap()); + + let backend = if use_sketchlib_for_count_min_with_heap() { + let wire_heap: Vec = sorted_topk_heap + .iter() + .map(|h| WireHeapItem { + key: h.key.clone(), + value: h.value, + }) + .collect(); + CountMinWithHeapBackend::Sketchlib(sketchlib_cms_heap_from_matrix_and_heap( + serialized.sketch.row_num, + serialized.sketch.col_num, + serialized.heap_size, + &serialized.sketch.sketch, + &wire_heap, + )) + } else { + CountMinWithHeapBackend::Legacy { + sketch: serialized.sketch.sketch, + heap: sorted_topk_heap, + } + }; + + Ok(Self { + row_num: serialized.sketch.row_num, + col_num: serialized.sketch.col_num, + heap_size: serialized.heap_size, + backend, + }) + } + + pub fn aggregate_topk( + row_num: usize, + col_num: usize, + heap_size: usize, + keys: &[&str], + values: &[f64], + ) -> Option> { + if keys.is_empty() { + return None; + } + let mut sketch = Self::new(row_num, col_num, heap_size); + for (key, &value) in keys.iter().zip(values.iter()) { + sketch.update(key, value); + } + Some(sketch.serialize_msgpack()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_creation() { + let cms = CountMinSketchWithHeap::new(4, 1000, 20); + assert_eq!(cms.row_num, 4); + assert_eq!(cms.col_num, 1000); + assert_eq!(cms.heap_size, 20); + assert_eq!(cms.sketch_matrix().len(), 4); + assert_eq!(cms.sketch_matrix()[0].len(), 1000); + assert_eq!(cms.topk_heap_items().len(), 0); + } + + #[test] + fn test_query_empty() { + let cms = CountMinSketchWithHeap::new(2, 10, 5); + 
assert_eq!(cms.query_key("anything"), 0.0); + } + + #[test] + fn test_merge() { + let mut cms1 = CountMinSketchWithHeap::new(2, 10, 5); + let mut cms2 = CountMinSketchWithHeap::new(2, 10, 3); + + if let Some(sketch) = cms1.sketch_mut() { + sketch[0][0] = 10.0; + sketch[1][1] = 20.0; + } + if let Some(sketch) = cms2.sketch_mut() { + sketch[0][0] = 5.0; + sketch[1][1] = 15.0; + } + if let CountMinWithHeapBackend::Legacy { heap, .. } = &mut cms1.backend { + heap.push(HeapItem { + key: "key1".to_string(), + value: 100.0, + }); + heap.push(HeapItem { + key: "key2".to_string(), + value: 50.0, + }); + } + if let CountMinWithHeapBackend::Legacy { heap, .. } = &mut cms2.backend { + heap.push(HeapItem { + key: "key3".to_string(), + value: 75.0, + }); + heap.push(HeapItem { + key: "key1".to_string(), + value: 80.0, + }); + } + + let merged = CountMinSketchWithHeap::merge(vec![cms1, cms2]).unwrap(); + + assert_eq!(merged.sketch_matrix()[0][0], 15.0); + assert_eq!(merged.sketch_matrix()[1][1], 35.0); + assert_eq!(merged.heap_size, 3); + assert!(merged.topk_heap_items().len() <= 3); + } + + #[test] + fn test_merge_dimension_mismatch() { + let cms1 = CountMinSketchWithHeap::new(2, 10, 5); + let cms2 = CountMinSketchWithHeap::new(3, 10, 5); + assert!(CountMinSketchWithHeap::merge(vec![cms1, cms2]).is_err()); + } + + #[test] + fn test_msgpack_round_trip() { + let mut cms = CountMinSketchWithHeap::new(4, 128, 3); + cms.update("hot", 100.0); + cms.update("cold", 1.0); + + let bytes = cms.serialize_msgpack(); + let deserialized = CountMinSketchWithHeap::deserialize_msgpack(&bytes).unwrap(); + + assert_eq!(deserialized.row_num, 4); + assert_eq!(deserialized.col_num, 128); + assert_eq!(deserialized.heap_size, 3); + assert!(!deserialized.topk_heap_items().is_empty()); + assert_eq!(deserialized.topk_heap_items()[0].key, "hot"); + assert!(deserialized.topk_heap_items()[0].value >= 100.0); + assert!(deserialized.query_key("hot") >= 100.0); + assert!(deserialized.query_key("cold") >= 1.0); + 
} + + #[test] + fn test_aggregate_topk() { + let keys = ["a", "b", "a", "c"]; + let values = [1.0, 2.0, 3.0, 0.5]; + let bytes = CountMinSketchWithHeap::aggregate_topk(4, 100, 2, &keys, &values).unwrap(); + let cms = CountMinSketchWithHeap::deserialize_msgpack(&bytes).unwrap(); + assert_eq!(cms.heap_size, 2); + assert!(cms.topk_heap_items().len() <= 2); + } + + #[test] + fn test_aggregate_topk_empty() { + assert!(CountMinSketchWithHeap::aggregate_topk(4, 100, 10, &[], &[]).is_none()); + } +} diff --git a/asap-common/sketch-core/src/count_min_with_heap_sketchlib.rs b/asap-common/sketch-core/src/count_min_with_heap_sketchlib.rs new file mode 100644 index 0000000..78c1415 --- /dev/null +++ b/asap-common/sketch-core/src/count_min_with_heap_sketchlib.rs @@ -0,0 +1,109 @@ +//! asap_sketchlib CMSHeap integration for CountMinSketchWithHeap. +//! +//! Uses CMSHeap (CountMin + HHHeap) from asap_sketchlib instead of CountMin + local heap, +//! providing automatic top-k tracking during insert and merge. + +use asap_sketchlib::RegularPath; +use asap_sketchlib::{CMSHeap, DataInput, Vector2D}; + +/// Wire-format heap item (key, value) to avoid circular dependency with count_min_with_heap. +pub struct WireHeapItem { + pub key: String, + pub value: f64, +} + +/// Concrete Count-Min-with-Heap type from asap_sketchlib (CMS + HHHeap). +pub type SketchlibCMSHeap = CMSHeap, RegularPath>; + +/// Creates a fresh CMSHeap with the given dimensions and heap capacity. +pub fn new_sketchlib_cms_heap( + row_num: usize, + col_num: usize, + heap_size: usize, +) -> SketchlibCMSHeap { + CMSHeap::new(row_num, col_num, heap_size) +} + +/// Builds a CMSHeap from an existing sketch matrix and optional heap items. +/// Used when deserializing or when ensuring sketchlib from legacy state. 
+pub fn sketchlib_cms_heap_from_matrix_and_heap(
+    row_num: usize,
+    col_num: usize,
+    heap_size: usize,
+    sketch: &[Vec<f64>],
+    topk_heap: &[WireHeapItem],
+) -> SketchlibCMSHeap {
+    // CMSHeap storage is integer-counted: f64 cells are rounded to i64 here,
+    // so fractional weights lose precision on this conversion path.
+    let matrix = Vector2D::from_fn(row_num, col_num, |r, c| {
+        sketch
+            .get(r)
+            .and_then(|row| row.get(c))
+            .copied()
+            .unwrap_or(0.0)
+            .round() as i64
+    });
+    let mut cms_heap = CMSHeap::from_storage(matrix, heap_size);
+
+    // Populate the heap from wire-format topk_heap; items rounding to <= 0 are dropped.
+    for item in topk_heap {
+        let count = item.value.round() as i64;
+        if count > 0 {
+            let input = DataInput::Str(&item.key);
+            cms_heap.heap_mut().update(&input, count);
+        }
+    }
+
+    cms_heap
+}
+
+/// Converts a CMSHeap's storage into the legacy `Vec<Vec<f64>>` matrix.
+pub fn matrix_from_sketchlib_cms_heap(cms_heap: &SketchlibCMSHeap) -> Vec<Vec<f64>> {
+    let storage = cms_heap.cms().as_storage();
+    let rows = storage.rows();
+    let cols = storage.cols();
+    let mut sketch = vec![vec![0.0; cols]; rows];
+
+    for (r, row) in sketch.iter_mut().enumerate().take(rows) {
+        for (c, cell) in row.iter_mut().enumerate().take(cols) {
+            if let Some(v) = storage.get(r, c) {
+                *cell = *v as f64;
+            }
+        }
+    }
+
+    sketch
+}
+
+/// Converts sketchlib HHHeap items to wire-format (key, value) pairs.
+/// Non-string heap keys are rendered via Debug formatting — assumes keys were
+/// inserted as strings on this path; TODO confirm no other key variants occur.
+pub fn heap_to_wire(cms_heap: &SketchlibCMSHeap) -> Vec<WireHeapItem> {
+    cms_heap
+        .heap()
+        .heap()
+        .iter()
+        .map(|hh_item| {
+            let key = match &hh_item.key {
+                asap_sketchlib::HeapItem::String(s) => s.clone(),
+                other => format!("{:?}", other),
+            };
+            WireHeapItem {
+                key,
+                value: hh_item.count as f64,
+            }
+        })
+        .collect()
+}
+
+/// Updates a CMSHeap with a weighted key. Automatically updates the heap.
+/// The weight is rounded to an integer count; values rounding to <= 0 are a no-op.
+pub fn sketchlib_cms_heap_update(cms_heap: &mut SketchlibCMSHeap, key: &str, value: f64) {
+    let many = value.round() as i64;
+    if many <= 0 {
+        return;
+    }
+    let input = DataInput::String(key.to_owned());
+    cms_heap.insert_many(&input, many);
+}
+
+/// Queries a CMSHeap for a key's frequency estimate.
+pub fn sketchlib_cms_heap_query(cms_heap: &SketchlibCMSHeap, key: &str) -> f64 { + let input = DataInput::String(key.to_owned()); + cms_heap.estimate(&input) as f64 +} diff --git a/asap-common/sketch-core/src/delta_set_aggregator.rs b/asap-common/sketch-core/src/delta_set_aggregator.rs new file mode 100644 index 0000000..c086e2a --- /dev/null +++ b/asap-common/sketch-core/src/delta_set_aggregator.rs @@ -0,0 +1,71 @@ +// Adapted from QueryEngineRust/src/precompute_operators/delta_set_aggregator_accumulator.rs +// Changes: +// - Only the wire format (DeltaResult) and its serialize/deserialize functions are extracted. +// - The Arroyo side uses lazy_static for stateful window tracking — that streaming logic +// stays in the Arroyo template and does NOT belong in sketch-core. +// - DeltaResult made pub (was private inline struct in QE). +// - serialize_msgpack / deserialize_msgpack are module-level free functions +// (not methods on DeltaSetAggregatorAccumulator, which stays in QE). +// - Removed: all QE accumulator struct/impls (stay in QE) + +use serde::{Deserialize, Serialize}; +use std::collections::HashSet; + +/// Wire format for the delta set aggregator — shared between Arroyo and QueryEngineRust. +/// Both sides agree on `{ added: HashSet, removed: HashSet }` in msgpack. +#[derive(Serialize, Deserialize, Clone, Debug)] +pub struct DeltaResult { + pub added: HashSet, + pub removed: HashSet, +} + +/// Serialize a delta result to MessagePack. +pub fn serialize_msgpack(added: &HashSet, removed: &HashSet) -> Vec { + let result = DeltaResult { + added: added.clone(), + removed: removed.clone(), + }; + let mut buf = Vec::new(); + rmp_serde::encode::write(&mut buf, &result).unwrap(); + buf +} + +/// Deserialize a delta result from MessagePack produced by the Arroyo UDF. 
+pub fn deserialize_msgpack(buffer: &[u8]) -> Result> { + rmp_serde::from_slice(buffer) + .map_err(|e| format!("Failed to deserialize DeltaResult from MessagePack: {e}").into()) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_msgpack_round_trip() { + let mut added = HashSet::new(); + added.insert("web".to_string()); + added.insert("api".to_string()); + + let mut removed = HashSet::new(); + removed.insert("db".to_string()); + + let bytes = serialize_msgpack(&added, &removed); + let result = deserialize_msgpack(&bytes).unwrap(); + + assert_eq!(result.added.len(), 2); + assert!(result.added.contains("web")); + assert!(result.added.contains("api")); + assert_eq!(result.removed.len(), 1); + assert!(result.removed.contains("db")); + } + + #[test] + fn test_empty_sets() { + let added = HashSet::new(); + let removed = HashSet::new(); + let bytes = serialize_msgpack(&added, &removed); + let result = deserialize_msgpack(&bytes).unwrap(); + assert!(result.added.is_empty()); + assert!(result.removed.is_empty()); + } +} diff --git a/asap-common/sketch-core/src/hydra_kll.rs b/asap-common/sketch-core/src/hydra_kll.rs new file mode 100644 index 0000000..e6888d5 --- /dev/null +++ b/asap-common/sketch-core/src/hydra_kll.rs @@ -0,0 +1,295 @@ +// Adapted from QueryEngineRust/src/precompute_operators/hydra_kll_accumulator.rs +// Changes: +// - Renamed HydraKllSketchAccumulator -> HydraKllSketch +// - KllSketchData import replaced by crate::kll::{KllSketch, KllSketchData} +// - Inner cells are KllSketch instead of DatasketchesKLLAccumulator +// - update() takes &str instead of &KeyByLabelValues +// - query_key() takes &str; renamed to query() +// - serialize_to_bytes (trait) -> serialize_msgpack (inherent method) +// - deserialize_from_bytes_arroyo -> deserialize_msgpack +// - merge_accumulators -> merge +// - Removed: deserialize_from_bytes (stub, stays in QE) +// - Removed: AggregateCore, SerializableToSink, MergeableAccumulator, MultipleSubpopulationAggregate 
impls +// - Removed: base64, serde_json imports (QE-specific) +// - Added: aggregate_hydrakll() one-shot helper + +use crate::kll::{KllSketch, KllSketchData}; +use serde::{Deserialize, Serialize}; +use std::cmp::Ordering; +use xxhash_rust::xxh32::xxh32; + +#[derive(Serialize, Deserialize)] +struct HydraKllSketchData { + row_num: usize, + col_num: usize, + sketches: Vec>, +} + +#[derive(Debug, Clone)] +pub struct HydraKllSketch { + pub sketch: Vec>, + pub row_num: usize, + pub col_num: usize, +} + +impl HydraKllSketch { + pub fn new(row_num: usize, col_num: usize, k: u16) -> Self { + let sketch = vec![vec![KllSketch::new(k); col_num]; row_num]; + Self { + sketch, + row_num, + col_num, + } + } + + pub fn update(&mut self, key: &str, value: f64) { + let key_bytes = key.as_bytes(); + // Update each row using different hash functions + for i in 0..self.row_num { + let hash_value = xxh32(key_bytes, i as u32); + let col_index = (hash_value as usize) % self.col_num; + self.sketch[i][col_index].update(value); + } + } + + pub fn query(&self, key: &str, quantile: f64) -> f64 { + let key_bytes = key.as_bytes(); + let mut quantiles = Vec::with_capacity(self.row_num); + + for i in 0..self.row_num { + let hash_value = xxh32(key_bytes, i as u32); + let col_index = (hash_value as usize) % self.col_num; + quantiles.push(self.sketch[i][col_index].get_quantile(quantile)); + } + + if quantiles.is_empty() { + return 0.0; + } + + quantiles.sort_by(|a, b| match a.partial_cmp(b) { + Some(ordering) => ordering, + None => Ordering::Equal, + }); + + let mid = quantiles.len() / 2; + if quantiles.len() % 2 == 0 { + (quantiles[mid - 1] + quantiles[mid]) / 2.0 + } else { + quantiles[mid] + } + } + + pub fn merge( + accumulators: Vec, + ) -> Result> { + if accumulators.is_empty() { + return Err("No accumulators to merge".into()); + } + + // Check dimensions match + let row_num = accumulators[0].row_num; + let col_num = accumulators[0].col_num; + for acc in &accumulators { + if acc.row_num != 
row_num || acc.col_num != col_num { + return Err( + "Cannot merge HydraKllSketch accumulators with different dimensions".into(), + ); + } + } + + // Transpose Vec into Vec>> indexed [row][col][acc], + // consuming the owned accumulators so no per-cell clones are needed. + let mut by_cell: Vec>> = (0..row_num) + .map(|_| (0..col_num).map(|_| Vec::new()).collect()) + .collect(); + for acc in accumulators { + for (i, row) in acc.sketch.into_iter().enumerate() { + for (j, cell) in row.into_iter().enumerate() { + by_cell[i][j].push(cell); + } + } + } + + // Merge each cell independently + let mut merged_sketch = Vec::with_capacity(row_num); + for row in by_cell { + let mut merged_row = Vec::with_capacity(col_num); + for cells in row { + merged_row.push(KllSketch::merge(cells)?); + } + merged_sketch.push(merged_row); + } + + Ok(HydraKllSketch { + sketch: merged_sketch, + row_num, + col_num, + }) + } + + /// Serialize to MessagePack — matches the Arroyo UDF wire format exactly. + pub fn serialize_msgpack(&self) -> Vec { + let mut sketches = Vec::with_capacity(self.row_num); + for row in &self.sketch { + let mut row_data = Vec::with_capacity(self.col_num); + for cell in row { + // Serialize each KllSketch to KllSketchData + let cell_bytes = cell.serialize_msgpack(); + let kll_data: KllSketchData = rmp_serde::from_slice(&cell_bytes) + .expect("Failed to deserialize KllSketchData from cell"); + row_data.push(kll_data); + } + sketches.push(row_data); + } + + let serialized = HydraKllSketchData { + row_num: self.row_num, + col_num: self.col_num, + sketches, + }; + + let mut buf = Vec::new(); + rmp_serde::encode::write(&mut buf, &serialized).unwrap(); + buf + } + + /// Deserialize from MessagePack produced by the Arroyo UDF. 
+ pub fn deserialize_msgpack(buffer: &[u8]) -> Result> { + let deserialized_sketch_data: HydraKllSketchData = rmp_serde::from_slice(buffer) + .map_err(|e| format!("Failed to deserialize HydraKLL from MessagePack: {e}"))?; + + if deserialized_sketch_data.sketches.len() != deserialized_sketch_data.row_num { + return Err(format!( + "HydraKLL row count mismatch: expected {}, got {}", + deserialized_sketch_data.row_num, + deserialized_sketch_data.sketches.len() + ) + .into()); + } + + let mut sketch: Vec> = Vec::with_capacity(deserialized_sketch_data.row_num); + + for (row_idx, row) in deserialized_sketch_data.sketches.into_iter().enumerate() { + if row.len() != deserialized_sketch_data.col_num { + return Err(format!( + "HydraKLL column count mismatch in row {}: expected {}, got {}", + row_idx, + deserialized_sketch_data.col_num, + row.len() + ) + .into()); + } + + let mut accum_row: Vec = + Vec::with_capacity(deserialized_sketch_data.col_num); + for cell in row { + let cell_bytes = rmp_serde::to_vec(&cell) + .map_err(|e| format!("Failed to serialize nested KLL sketch: {e}"))?; + let kll = KllSketch::deserialize_msgpack(&cell_bytes)?; + accum_row.push(kll); + } + + sketch.push(accum_row); + } + + Ok(Self { + sketch, + row_num: deserialized_sketch_data.row_num, + col_num: deserialized_sketch_data.col_num, + }) + } + + /// One-shot aggregation for the Arroyo UDAF call pattern. 
+ pub fn aggregate_hydrakll( + row_num: usize, + col_num: usize, + k: u16, + keys: &[&str], + values: &[f64], + ) -> Option> { + if keys.is_empty() { + return None; + } + let mut sketch = Self::new(row_num, col_num, k); + for (key, &value) in keys.iter().zip(values.iter()) { + sketch.update(key, value); + } + Some(sketch.serialize_msgpack()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_creation() { + let h = HydraKllSketch::new(2, 3, 200); + assert_eq!(h.row_num, 2); + assert_eq!(h.col_num, 3); + assert_eq!(h.sketch.len(), 2); + assert_eq!(h.sketch[0].len(), 3); + } + + #[test] + fn test_update_and_query() { + let mut h = HydraKllSketch::new(2, 10, 200); + h.update("key1", 5.0); + h.update("key1", 10.0); + // With 2 values, median quantile should be between them + let q = h.query("key1", 0.5); + assert!(q >= 0.0); + } + + #[test] + fn test_merge() { + let mut h1 = HydraKllSketch::new(2, 5, 200); + let mut h2 = HydraKllSketch::new(2, 5, 200); + + for i in 1..=5 { + h1.update("key1", i as f64); + } + for i in 6..=10 { + h2.update("key1", i as f64); + } + + let merged = HydraKllSketch::merge(vec![h1, h2]).unwrap(); + assert_eq!(merged.row_num, 2); + assert_eq!(merged.col_num, 5); + } + + #[test] + fn test_merge_dimension_mismatch() { + let h1 = HydraKllSketch::new(2, 5, 200); + let h2 = HydraKllSketch::new(3, 5, 200); + assert!(HydraKllSketch::merge(vec![h1, h2]).is_err()); + } + + #[test] + fn test_msgpack_round_trip() { + let mut h = HydraKllSketch::new(2, 3, 200); + h.update("key1", 5.0); + h.update("key2", 10.0); + + let bytes = h.serialize_msgpack(); + let deserialized = HydraKllSketch::deserialize_msgpack(&bytes).unwrap(); + + assert_eq!(deserialized.row_num, 2); + assert_eq!(deserialized.col_num, 3); + } + + #[test] + fn test_aggregate_hydrakll() { + let keys = ["a", "b", "a"]; + let values = [1.0, 2.0, 3.0]; + let bytes = HydraKllSketch::aggregate_hydrakll(2, 5, 200, &keys, &values).unwrap(); + let h = 
HydraKllSketch::deserialize_msgpack(&bytes).unwrap(); + assert_eq!(h.row_num, 2); + assert_eq!(h.col_num, 5); + } + + #[test] + fn test_aggregate_hydrakll_empty() { + assert!(HydraKllSketch::aggregate_hydrakll(2, 5, 200, &[], &[]).is_none()); + } +} diff --git a/asap-common/sketch-core/src/kll.rs b/asap-common/sketch-core/src/kll.rs new file mode 100644 index 0000000..5751cb2 --- /dev/null +++ b/asap-common/sketch-core/src/kll.rs @@ -0,0 +1,366 @@ +// Adapted from QueryEngineRust/src/precompute_operators/datasketches_kll_accumulator.rs +// Changes: +// - Renamed DatasketchesKLLAccumulator -> KllSketch +// - KllSketchData made pub (used by hydra_kll) +// - _update -> pub update +// - serialize_to_bytes (trait) -> serialize_msgpack (inherent method) +// - deserialize_from_bytes_arroyo -> deserialize_msgpack +// - merge_accumulators -> merge +// - Removed: deserialize_from_json, deserialize_from_bytes (legacy QE formats, stay in QE) +// - Removed: merge_multiple (QE trait-object helper, stays in QE) +// - Removed: AggregateCore, SerializableToSink, MergeableAccumulator, SingleSubpopulationAggregate impls +// - Removed: base64, serde_json, tracing imports (QE-specific) +// - Added: aggregate_kll() one-shot helper + +use core::panic; +use dsrs::KllDoubleSketch; +use serde::{Deserialize, Serialize}; + +use crate::config::use_sketchlib_for_kll; +use crate::kll_sketchlib::{ + bytes_from_sketchlib_kll, new_sketchlib_kll, sketchlib_kll_from_bytes, sketchlib_kll_merge, + sketchlib_kll_quantile, sketchlib_kll_update, SketchlibKll, +}; + +/// Wire format used in MessagePack serialization (matches Arroyo UDF output). +#[derive(Deserialize, Serialize)] +pub struct KllSketchData { + pub k: u16, + pub sketch_bytes: Vec, +} + +/// Backend implementation for KLL Sketch. Only one is active at a time. +pub enum KllBackend { + /// dsrs (DataSketches) implementation. + Legacy(KllDoubleSketch), + /// asap_sketchlib backed implementation. 
+ Sketchlib(SketchlibKll), +} + +impl std::fmt::Debug for KllBackend { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + KllBackend::Legacy(_) => write!(f, "Legacy(..)"), + KllBackend::Sketchlib(_) => write!(f, "Sketchlib(..)"), + } + } +} + +impl Clone for KllBackend { + fn clone(&self) -> Self { + match self { + KllBackend::Legacy(s) => { + if s.get_n() == 0 { + KllBackend::Legacy(KllDoubleSketch::with_k(200)) // k will be overwritten by KllSketch + } else { + let bytes = s.serialize(); + KllBackend::Legacy(KllDoubleSketch::deserialize(bytes.as_ref()).unwrap()) + } + } + KllBackend::Sketchlib(s) => KllBackend::Sketchlib(s.clone()), + } + } +} + +pub struct KllSketch { + pub k: u16, + pub backend: KllBackend, +} + +impl KllSketch { + pub fn new(k: u16) -> Self { + let backend = if use_sketchlib_for_kll() { + KllBackend::Sketchlib(new_sketchlib_kll(k)) + } else { + KllBackend::Legacy(KllDoubleSketch::with_k(k)) + }; + Self { k, backend } + } + + /// Returns the raw sketch bytes (for JSON serialization, etc.). 
+ pub fn sketch_bytes(&self) -> Vec { + match &self.backend { + KllBackend::Legacy(s) => s.serialize().as_ref().to_vec(), + KllBackend::Sketchlib(s) => bytes_from_sketchlib_kll(s), + } + } + + pub fn update(&mut self, value: f64) { + match &mut self.backend { + KllBackend::Legacy(s) => s.update(value), + KllBackend::Sketchlib(s) => sketchlib_kll_update(s, value), + } + } + + pub fn count(&self) -> u64 { + match &self.backend { + KllBackend::Legacy(s) => s.get_n(), + KllBackend::Sketchlib(s) => s.count() as u64, + } + } + + pub fn get_quantile(&self, quantile: f64) -> f64 { + if self.count() == 0 { + return 0.0; + } + match &self.backend { + KllBackend::Legacy(s) => s.get_quantile(quantile), + KllBackend::Sketchlib(s) => sketchlib_kll_quantile(s, quantile), + } + } + + pub fn merge( + accumulators: Vec, + ) -> Result> { + if accumulators.is_empty() { + return Err("No accumulators to merge".into()); + } + + let k = accumulators[0].k; + for acc in &accumulators { + if acc.k != k { + return Err("Cannot merge KllSketch with different k values".into()); + } + } + + let mut merged = KllSketch::new(k); + match &mut merged.backend { + KllBackend::Legacy(merged_legacy) => { + for acc in accumulators { + if let KllBackend::Legacy(acc_legacy) = acc.backend { + merged_legacy.merge(&acc_legacy); + } else { + return Err("Cannot merge Legacy with Sketchlib KLL".into()); + } + } + } + KllBackend::Sketchlib(merged_sketchlib) => { + for acc in accumulators { + if let KllBackend::Sketchlib(acc_sketchlib) = &acc.backend { + sketchlib_kll_merge(merged_sketchlib, acc_sketchlib); + } else { + return Err("Cannot merge Sketchlib with Legacy KLL".into()); + } + } + } + } + + Ok(merged) + } + + /// Serialize to MessagePack — matches the Arroyo UDF wire format exactly. 
+ pub fn serialize_msgpack(&self) -> Vec { + let sketch_bytes = self.sketch_bytes(); + let serialized = KllSketchData { + k: self.k, + sketch_bytes, + }; + + let mut buf = Vec::new(); + match rmp_serde::encode::write(&mut buf, &serialized) { + Ok(_) => buf, + Err(_) => { + panic!("Failed to serialize KllSketchData to MessagePack"); + } + } + } + + /// Deserialize from MessagePack produced by the Arroyo UDF. + pub fn deserialize_msgpack(buffer: &[u8]) -> Result> { + let wire: KllSketchData = rmp_serde::from_slice(buffer) + .map_err(|e| format!("Failed to deserialize KllSketchData from MessagePack: {e}"))?; + + let backend = if use_sketchlib_for_kll() { + KllBackend::Sketchlib(sketchlib_kll_from_bytes(&wire.sketch_bytes)?) + } else { + KllBackend::Legacy( + KllDoubleSketch::deserialize(&wire.sketch_bytes) + .map_err(|e| format!("Failed to deserialize KLL sketch: {e}"))?, + ) + }; + + Ok(Self { k: wire.k, backend }) + } + + /// Merge from references without cloning. + pub fn merge_refs( + sketches: &[&Self], + ) -> Result> { + if sketches.is_empty() { + return Err("No sketches to merge".into()); + } + let k = sketches[0].k; + for s in sketches { + if s.k != k { + return Err("Cannot merge KllSketch with different k values".into()); + } + } + let mut merged = Self::new(k); + match &mut merged.backend { + KllBackend::Legacy(merged_legacy) => { + for s in sketches { + if let KllBackend::Legacy(s_legacy) = &s.backend { + merged_legacy.merge(s_legacy); + } else { + return Err("Cannot merge Legacy with Sketchlib KLL".into()); + } + } + } + KllBackend::Sketchlib(merged_sketchlib) => { + for s in sketches { + if let KllBackend::Sketchlib(s_sketchlib) = &s.backend { + sketchlib_kll_merge(merged_sketchlib, s_sketchlib); + } else { + return Err("Cannot merge Sketchlib with Legacy KLL".into()); + } + } + } + } + Ok(merged) + } + + /// Deserialize from a raw datasketches byte buffer (legacy Flink/FlinkSketch format). 
+ pub fn from_dsrs_bytes(bytes: &[u8], k: u16) -> Result> { + let sketch = KllDoubleSketch::deserialize(bytes) + .map_err(|e| format!("Failed to deserialize KLL sketch from dsrs bytes: {e}"))?; + Ok(Self { + k, + backend: KllBackend::Legacy(sketch), + }) + } + + /// One-shot aggregation for the Arroyo UDAF call pattern. + pub fn aggregate_kll(k: u16, values: &[f64]) -> Option> { + if values.is_empty() { + return None; + } + let mut sketch = Self::new(k); + for &value in values { + sketch.update(value); + } + Some(sketch.serialize_msgpack()) + } +} + +// Manual trait implementations since the C++ and sketchlib types don't provide Clone +impl Clone for KllSketch { + fn clone(&self) -> Self { + let backend = match &self.backend { + KllBackend::Legacy(sketch) => { + let new_sketch = if sketch.get_n() == 0 { + KllDoubleSketch::with_k(self.k) + } else { + let bytes = sketch.serialize(); + KllDoubleSketch::deserialize(bytes.as_ref()).unwrap() + }; + KllBackend::Legacy(new_sketch) + } + KllBackend::Sketchlib(s) => { + let bytes = bytes_from_sketchlib_kll(s); + KllBackend::Sketchlib(sketchlib_kll_from_bytes(&bytes).unwrap()) + } + }; + Self { k: self.k, backend } + } +} + +impl std::fmt::Debug for KllSketch { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("KllSketch") + .field("k", &self.k) + .field("sketch_n", &self.count()) + .finish() + } +} + +// TODO: verify this +// Thread safety: The C++ library is not thread-safe by default, but since we're using it +// in a single-threaded context per accumulator instance and only sharing read-only operations, +// this should be safe. The actual sketch data is immutable once created. 
+unsafe impl Send for KllSketch {} +unsafe impl Sync for KllSketch {} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_kll_creation() { + let kll = KllSketch::new(200); + assert_eq!(kll.count(), 0); + assert_eq!(kll.k, 200); + } + + #[test] + fn test_kll_update() { + let mut kll = KllSketch::new(200); + kll.update(10.0); + kll.update(20.0); + kll.update(15.0); + assert_eq!(kll.count(), 3); + } + + #[test] + fn test_kll_quantile() { + let mut kll = KllSketch::new(200); + for i in 1..=10 { + kll.update(i as f64); + } + assert_eq!(kll.get_quantile(0.0), 1.0); + assert_eq!(kll.get_quantile(1.0), 10.0); + let median = kll.get_quantile(0.5); + assert!( + (5.0..=6.0).contains(&median), + "median should be between 5 and 6; got {median}" + ); + } + + #[test] + fn test_kll_merge() { + let mut kll1 = KllSketch::new(200); + let mut kll2 = KllSketch::new(200); + + for i in 1..=5 { + kll1.update(i as f64); + } + for i in 6..=10 { + kll2.update(i as f64); + } + + let merged = KllSketch::merge(vec![kll1, kll2]).unwrap(); + assert_eq!(merged.count(), 10); + assert_eq!(merged.get_quantile(0.0), 1.0); + assert_eq!(merged.get_quantile(1.0), 10.0); + } + + #[test] + fn test_msgpack_round_trip() { + let mut kll = KllSketch::new(200); + for i in 1..=5 { + kll.update(i as f64); + } + + let bytes = kll.serialize_msgpack(); + let deserialized = KllSketch::deserialize_msgpack(&bytes).unwrap(); + + assert_eq!(deserialized.k, 200); + assert_eq!(deserialized.count(), 5); + assert_eq!(deserialized.get_quantile(0.0), 1.0); + assert_eq!(deserialized.get_quantile(1.0), 5.0); + } + + #[test] + fn test_aggregate_kll() { + let values = [1.0, 2.0, 3.0, 4.0, 5.0]; + let bytes = KllSketch::aggregate_kll(200, &values).unwrap(); + let kll = KllSketch::deserialize_msgpack(&bytes).unwrap(); + assert_eq!(kll.count(), 5); + assert_eq!(kll.get_quantile(0.0), 1.0); + assert_eq!(kll.get_quantile(1.0), 5.0); + } + + #[test] + fn test_aggregate_kll_empty() { + assert!(KllSketch::aggregate_kll(200, 
&[]).is_none()); + } +} diff --git a/asap-common/sketch-core/src/kll_sketchlib.rs b/asap-common/sketch-core/src/kll_sketchlib.rs new file mode 100644 index 0000000..bdbee0e --- /dev/null +++ b/asap-common/sketch-core/src/kll_sketchlib.rs @@ -0,0 +1,36 @@ +use asap_sketchlib::KLL; + +/// Concrete KLL type from asap_sketchlib when sketchlib backend is enabled. +pub type SketchlibKll = KLL; + +/// Creates a fresh sketchlib KLL sketch with the requested accuracy parameter `k`. +pub fn new_sketchlib_kll(k: u16) -> SketchlibKll { + KLL::init_kll(k as i32) +} + +/// Updates a sketchlib KLL with one numeric observation. +pub fn sketchlib_kll_update(inner: &mut SketchlibKll, value: f64) { + // KLL accepts only numeric inputs. We intentionally ignore the error here because `value` + // is always numeric. + inner.update(&value); +} + +/// Queries a sketchlib KLL for the value at the requested quantile. +pub fn sketchlib_kll_quantile(inner: &SketchlibKll, q: f64) -> f64 { + inner.quantile(q) +} + +/// Merges `src` into `dst`. +pub fn sketchlib_kll_merge(dst: &mut SketchlibKll, src: &SketchlibKll) { + dst.merge(src); +} + +/// Serializes a sketchlib KLL into MessagePack bytes. +pub fn bytes_from_sketchlib_kll(inner: &SketchlibKll) -> Vec { + inner.serialize_to_bytes().unwrap() +} + +/// Deserializes a sketchlib KLL from MessagePack bytes. +pub fn sketchlib_kll_from_bytes(bytes: &[u8]) -> Result> { + Ok(KLL::deserialize_from_bytes(bytes)?) 
+} diff --git a/asap-common/sketch-core/src/lib.rs b/asap-common/sketch-core/src/lib.rs new file mode 100644 index 0000000..3ddd32b --- /dev/null +++ b/asap-common/sketch-core/src/lib.rs @@ -0,0 +1,16 @@ +#[cfg(test)] +#[ctor::ctor] +fn init_sketch_legacy_for_tests() { + crate::config::force_legacy_mode_for_tests(); +} + +pub mod config; +pub mod count_min; +pub mod count_min_sketchlib; +pub mod count_min_with_heap; +pub mod count_min_with_heap_sketchlib; +pub mod delta_set_aggregator; +pub mod hydra_kll; +pub mod kll; +pub mod kll_sketchlib; +pub mod set_aggregator; diff --git a/asap-common/sketch-core/src/set_aggregator.rs b/asap-common/sketch-core/src/set_aggregator.rs new file mode 100644 index 0000000..c745f28 --- /dev/null +++ b/asap-common/sketch-core/src/set_aggregator.rs @@ -0,0 +1,152 @@ +// Adapted from QueryEngineRust/src/precompute_operators/set_aggregator_accumulator.rs +// Changes: +// - Renamed SetAggregatorAccumulator -> SetAggregator +// - values field is now HashSet instead of HashSet +// - add_key(&str) instead of add_key(KeyByLabelValues) +// - serialize_msgpack / deserialize_msgpack use StringSet { values: HashSet } +// wire format matching the Arroyo setaggregator_ UDF exactly (same as DeltaResult pattern) +// - merge_accumulators -> merge +// - Removed: deserialize_from_json, deserialize_from_bytes, deserialize_from_bytes_arroyo +// (QE-specific / buggy legacy formats stay in QE) +// - Removed: AggregateCore, SerializableToSink, MergeableAccumulator, MultipleSubpopulationAggregate impls +// - Removed: with_added (QE-specific constructor) + +use serde::{Deserialize, Serialize}; +use std::collections::HashSet; + +/// Set aggregator for tracking a set of unique string keys. +/// Wire format: StringSet { values: HashSet } in MessagePack — matches Arroyo setaggregator_ UDF. 
+#[derive(Debug, Clone)] +pub struct SetAggregator { + pub values: HashSet, +} + +impl SetAggregator { + pub fn new() -> Self { + Self { + values: HashSet::new(), + } + } + + pub fn insert(&mut self, key: &str) { + self.values.insert(key.to_string()); + } + + pub fn merge( + accumulators: Vec, + ) -> Result> { + if accumulators.is_empty() { + return Err("No accumulators to merge".into()); + } + + let mut merged = SetAggregator::new(); + for accumulator in accumulators { + merged.values.extend(accumulator.values); + } + + Ok(merged) + } + + /// Serialize to MessagePack — matches the Arroyo setaggregator_ UDF wire format exactly: + /// StringSet { values: HashSet } as a msgpack map. + pub fn serialize_msgpack(&self) -> Vec { + #[derive(Serialize)] + struct StringSet<'a> { + values: &'a HashSet, + } + let wrapper = StringSet { + values: &self.values, + }; + let mut buf = Vec::new(); + rmp_serde::encode::write(&mut buf, &wrapper).unwrap(); + buf + } + + /// Deserialize from MessagePack produced by the Arroyo setaggregator_ UDF. 
+ pub fn deserialize_msgpack(buffer: &[u8]) -> Result> { + #[derive(Deserialize)] + struct StringSet { + values: HashSet, + } + let wrapper: StringSet = rmp_serde::from_slice(buffer) + .map_err(|e| format!("Failed to deserialize SetAggregator from MessagePack: {e}"))?; + Ok(Self { + values: wrapper.values, + }) + } +} + +impl Default for SetAggregator { + fn default() -> Self { + Self::new() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_creation() { + let sa = SetAggregator::new(); + assert!(sa.values.is_empty()); + } + + #[test] + fn test_insert() { + let mut sa = SetAggregator::new(); + sa.insert("web"); + sa.insert("api"); + sa.insert("web"); // duplicate + assert_eq!(sa.values.len(), 2); + assert!(sa.values.contains("web")); + assert!(sa.values.contains("api")); + } + + #[test] + fn test_merge() { + let mut sa1 = SetAggregator::new(); + let mut sa2 = SetAggregator::new(); + + sa1.insert("web"); + sa1.insert("api"); + sa2.insert("api"); // duplicate + sa2.insert("db"); + + let merged = SetAggregator::merge(vec![sa1, sa2]).unwrap(); + assert_eq!(merged.values.len(), 3); + assert!(merged.values.contains("web")); + assert!(merged.values.contains("api")); + assert!(merged.values.contains("db")); + } + + #[test] + fn test_msgpack_round_trip() { + let mut sa = SetAggregator::new(); + sa.insert("web"); + sa.insert("api"); + + let bytes = sa.serialize_msgpack(); + let deserialized = SetAggregator::deserialize_msgpack(&bytes).unwrap(); + + assert_eq!(deserialized.values.len(), 2); + assert!(deserialized.values.contains("web")); + assert!(deserialized.values.contains("api")); + } + + #[test] + fn test_msgpack_matches_arroyo_format() { + // Verify wire format is StringSet { values: [...] } not a plain array. + // Arroyo's setaggregator_.rs serializes StringSet { values: HashSet }. 
+ #[derive(Deserialize)] + struct StringSet { + values: HashSet, + } + let mut sa = SetAggregator::new(); + sa.insert("a"); + let bytes = sa.serialize_msgpack(); + let decoded: StringSet = + rmp_serde::from_slice(&bytes).expect("should decode as StringSet { values: ... }"); + assert!(decoded.values.contains("a")); + } +} diff --git a/asap-query-engine/Cargo.toml b/asap-query-engine/Cargo.toml index 135758f..9b73901 100644 --- a/asap-query-engine/Cargo.toml +++ b/asap-query-engine/Cargo.toml @@ -5,6 +5,7 @@ edition.workspace = true [dependencies] # Internal crates (workspace) +sketch-core.workspace = true promql_utilities.workspace = true sql_utilities.workspace = true asap_types.workspace = true @@ -38,6 +39,7 @@ urlencoding = "2.1" flate2 = "1.0" async-trait = "0.1" xxhash-rust = { version = "0.8", features = ["xxh32", "xxh64"] } +dsrs = { git = "https://github.com/ProjectASAP/datasketches-rs", rev = "d748ec75c80fff21f7b24897244dd1c895df2e9a" } base64 = "0.21" hex = "0.4" sqlparser = "0.59.0" @@ -58,7 +60,7 @@ tracing-appender = "0.2" arc-swap = "1" csv = "1" elastic_dsl_utilities.workspace = true -asap_sketchlib = { git = "https://github.com/ProjectASAP/asap_sketchlib", branch = "refactor/adopt-sketch-core-modules" } +asap_sketchlib = { git = "https://github.com/ProjectASAP/asap_sketchlib" } [[bin]] name = "precompute_engine" @@ -77,6 +79,7 @@ name = "e2e_quickstart_resource_test" path = "src/bin/e2e_quickstart_resource_test.rs" [dev-dependencies] +ctor = "0.2" tempfile = "3.20.0" criterion = { version = "0.5", features = ["html_reports"] } @@ -91,3 +94,4 @@ default = [] lock_profiling = [] # Enable extra debugging output extra_debugging = [] +sketchlib-tests = [] diff --git a/asap-query-engine/Dockerfile b/asap-query-engine/Dockerfile index 2401dc4..f54cdce 100644 --- a/asap-query-engine/Dockerfile +++ b/asap-query-engine/Dockerfile @@ -16,6 +16,9 @@ RUN apt-get update && apt-get install -y --no-install-recommends \ # Copy the asap-common directory COPY 
asap-common ./asap-common +# Copy path dependencies of asap-query-engine +COPY asap-common/sketch-core ./asap-common/sketch-core + COPY Cargo.toml ./ COPY Cargo.lock ./ COPY asap-query-engine/Cargo.toml ./asap-query-engine/ diff --git a/asap-query-engine/src/lib.rs b/asap-query-engine/src/lib.rs index f6b8a4e..b0d40dd 100644 --- a/asap-query-engine/src/lib.rs +++ b/asap-query-engine/src/lib.rs @@ -1,3 +1,16 @@ +#[cfg(test)] +#[ctor::ctor] +fn init_sketch_backend_for_tests() { + #[cfg(feature = "sketchlib-tests")] + let _ = sketch_core::config::configure( + sketch_core::config::ImplMode::Sketchlib, + sketch_core::config::ImplMode::Legacy, + sketch_core::config::ImplMode::Sketchlib, + ); + #[cfg(not(feature = "sketchlib-tests"))] + sketch_core::config::force_legacy_mode_for_tests(); +} + pub mod data_model; pub mod drivers; pub mod engines; diff --git a/asap-query-engine/src/main.rs b/asap-query-engine/src/main.rs index a96ec4b..6c601f0 100644 --- a/asap-query-engine/src/main.rs +++ b/asap-query-engine/src/main.rs @@ -5,6 +5,8 @@ use std::sync::{Arc, RwLock}; use tokio::signal; use tracing::{debug, error, info, warn}; +use sketch_core::config::{self, ImplMode}; + use asap_types::streaming_config::StreamingConfig; use query_engine_rust::data_model::enums::{ CleanupPolicy, InputFormat, LockStrategy, StreamingEngine, @@ -113,6 +115,18 @@ struct Args { #[arg(long)] promsketch_config: Option, + /// Backend implementation for Count-Min Sketch (legacy | sketchlib) + #[arg(long, value_enum, default_value_t = config::DEFAULT_CMS_IMPL)] + sketch_cms_impl: ImplMode, + + /// Backend implementation for KLL Sketch (legacy | sketchlib) + #[arg(long, value_enum, default_value_t = config::DEFAULT_KLL_IMPL)] + sketch_kll_impl: ImplMode, + + /// Backend implementation for Count-Min-With-Heap (legacy | sketchlib) + #[arg(long, value_enum, default_value_t = config::DEFAULT_CMWH_IMPL)] + sketch_cmwh_impl: ImplMode, + /// Enable OTLP metrics ingest (gRPC + HTTP) #[arg(long)] 
enable_otel_ingest: bool, @@ -192,6 +206,14 @@ struct Args { async fn main() -> Result<()> { let args = Args::parse(); + // Configure sketch-core backends before any sketch operations. + config::configure( + args.sketch_cms_impl, + args.sketch_kll_impl, + args.sketch_cmwh_impl, + ) + .expect("sketch backend already initialised"); + // Create output directory fs::create_dir_all(&args.output_dir)?; diff --git a/asap-query-engine/src/precompute_engine/worker.rs b/asap-query-engine/src/precompute_engine/worker.rs index 02c330b..1c37723 100644 --- a/asap-query-engine/src/precompute_engine/worker.rs +++ b/asap-query-engine/src/precompute_engine/worker.rs @@ -776,8 +776,8 @@ mod tests { use crate::precompute_operators::datasketches_kll_accumulator::DatasketchesKLLAccumulator; use crate::precompute_operators::multiple_sum_accumulator::MultipleSumAccumulator; use crate::precompute_operators::sum_accumulator::SumAccumulator; - use asap_sketchlib::sketches::kll::KllSketch; use asap_types::enums::{AggregationType, WindowType}; + use sketch_core::kll::KllSketch; fn make_agg_config( id: u64, diff --git a/asap-query-engine/src/precompute_operators/count_min_sketch_accumulator.rs b/asap-query-engine/src/precompute_operators/count_min_sketch_accumulator.rs index ed6ae31..fe3ec33 100644 --- a/asap-query-engine/src/precompute_operators/count_min_sketch_accumulator.rs +++ b/asap-query-engine/src/precompute_operators/count_min_sketch_accumulator.rs @@ -2,14 +2,14 @@ use crate::data_model::{ AggregateCore, AggregationType, KeyByLabelValues, MergeableAccumulator, MultipleSubpopulationAggregate, SerializableToSink, }; -use asap_sketchlib::sketches::countmin::CountMinSketch; use serde_json::Value; +use sketch_core::count_min::CountMinSketch; use std::collections::HashMap; use promql_utilities::query_logics::enums::Statistic; -/// Count-Min Sketch accumulator — wraps asap_sketchlib::sketches::CountMinSketch. -/// Core struct, update/merge/serde logic live in `asap_sketchlib::sketches`. 
+/// Count-Min Sketch accumulator — wraps sketch_core::CountMinSketch. +/// Core struct, update/merge/serde logic live in sketch-core. /// This file retains QE-specific trait impls, legacy deserializers, and JSON output. #[derive(Debug, Clone)] pub struct CountMinSketchAccumulator { @@ -29,7 +29,7 @@ impl CountMinSketchAccumulator { } pub fn query_key(&self, key: &KeyByLabelValues) -> f64 { - self.inner.estimate(&key.to_semicolon_str()) + self.inner.query_key(&key.to_semicolon_str()) } pub fn deserialize_from_json(data: &Value) -> Result> { @@ -64,8 +64,7 @@ impl CountMinSketchAccumulator { buffer: &[u8], ) -> Result> { Ok(Self { - inner: CountMinSketch::deserialize_msgpack(buffer) - .map_err(|e| -> Box { e.to_string().into() })?, + inner: CountMinSketch::deserialize_msgpack(buffer)?, }) } @@ -137,10 +136,10 @@ impl CountMinSketchAccumulator { } // Check dimensions are consistent - let rows = cms_accumulators[0].inner.rows(); - let cols = cms_accumulators[0].inner.cols(); + let row_num = cms_accumulators[0].inner.row_num; + let col_num = cms_accumulators[0].inner.col_num; for acc in &cms_accumulators { - if acc.inner.rows() != rows || acc.inner.cols() != cols { + if acc.inner.row_num != row_num || acc.inner.col_num != col_num { return Err( "Cannot merge CountMinSketch accumulators with different dimensions".into(), ); @@ -159,14 +158,14 @@ impl CountMinSketchAccumulator { impl SerializableToSink for CountMinSketchAccumulator { fn serialize_to_json(&self) -> Value { serde_json::json!({ - "row_num": self.inner.rows(), - "col_num": self.inner.cols(), + "row_num": self.inner.row_num, + "col_num": self.inner.col_num, "sketch": self.inner.sketch() }) } fn serialize_to_bytes(&self) -> Vec { - self.inner.serialize_msgpack().unwrap_or_default() + self.inner.serialize_msgpack() } } @@ -250,12 +249,11 @@ impl MergeableAccumulator for CountMinSketchAccumulat if accumulators.is_empty() { return Err("No accumulators to merge".into()); } - let mut iter = accumulators.into_iter(); 
- let mut merged = iter.next().unwrap(); - for acc in iter { - merged.inner.merge(&acc.inner)?; - } - Ok(merged) + let inners: Vec = accumulators.into_iter().map(|acc| acc.inner).collect(); + let merged_inner = CountMinSketch::merge(inners)?; + Ok(Self { + inner: merged_inner, + }) } } @@ -266,8 +264,8 @@ mod tests { #[test] fn test_count_min_sketch_creation() { let cms = CountMinSketchAccumulator::new(4, 1000); - assert_eq!(cms.inner.rows(), 4); - assert_eq!(cms.inner.cols(), 1000); + assert_eq!(cms.inner.row_num, 4); + assert_eq!(cms.inner.col_num, 1000); let sketch = cms.inner.sketch(); assert_eq!(sketch.len(), 4); assert_eq!(sketch[0].len(), 1000); @@ -346,8 +344,8 @@ mod tests { let deserialized = CountMinSketchAccumulator::deserialize_from_bytes_arroyo(&bytes).unwrap(); - assert_eq!(deserialized.inner.rows(), 2); - assert_eq!(deserialized.inner.cols(), 3); + assert_eq!(deserialized.inner.row_num, 2); + assert_eq!(deserialized.inner.col_num, 3); let deser_sketch = deserialized.inner.sketch(); assert_eq!(deser_sketch[0][1], 42.0); assert_eq!(deser_sketch[1][2], 100.0); diff --git a/asap-query-engine/src/precompute_operators/count_min_sketch_with_heap_accumulator.rs b/asap-query-engine/src/precompute_operators/count_min_sketch_with_heap_accumulator.rs index 3be5651..76ece28 100644 --- a/asap-query-engine/src/precompute_operators/count_min_sketch_with_heap_accumulator.rs +++ b/asap-query-engine/src/precompute_operators/count_min_sketch_with_heap_accumulator.rs @@ -2,22 +2,25 @@ use crate::data_model::{ AggregateCore, AggregationType, KeyByLabelValues, MergeableAccumulator, MultipleSubpopulationAggregate, SerializableToSink, }; -use asap_sketchlib::sketches::cms_heap::{CmsHeapItem, CountMinSketchWithHeap}; use serde_json::Value; +use sketch_core::count_min_with_heap::{CountMinSketchWithHeap, HeapItem}; use std::collections::HashMap; use promql_utilities::query_logics::enums::Statistic; -/// Count-Min Sketch with Heap accumulator — wraps 
`asap_sketchlib::sketches::CountMinSketchWithHeap`. -/// Core struct, update/merge/serde logic live in `asap_sketchlib::sketches::cms_heap`. +/// Count-Min Sketch with Heap accumulator — wraps sketch_core::CountMinSketchWithHeap. +/// Core struct, update/merge/serde logic live in sketch-core. /// This file retains QE-specific trait impls, legacy deserializers, and JSON output. +/// +/// NOTE (bug, do not fix): QueryEngineRust uses xxhash-rust::xxh32; the Arroyo template uses +/// twox-hash::XxHash32. Bucket assignments differ. Tracked separately. #[derive(Debug, Clone)] pub struct CountMinSketchWithHeapAccumulator { pub inner: CountMinSketchWithHeap, } // Re-export HeapItem so existing code using CountMinSketchWithHeapAccumulator::HeapItem still works. -pub use asap_sketchlib::sketches::cms_heap::CmsHeapItem as HeapItemReexport; +pub use sketch_core::count_min_with_heap::HeapItem as HeapItemReexport; impl CountMinSketchWithHeapAccumulator { pub fn new(row_num: usize, col_num: usize, heap_size: usize) -> Self { @@ -28,7 +31,7 @@ impl CountMinSketchWithHeapAccumulator { pub fn query_key(&self, key: &KeyByLabelValues) -> f64 { let key_string = key.labels.join(";"); - self.inner.estimate(&key_string) + self.inner.query_key(&key_string) } /// This function seems will never be used anymore. Keep it for possible future use. 
@@ -71,7 +74,7 @@ impl CountMinSketchWithHeapAccumulator { let value = item["value"] .as_f64() .ok_or("Missing or invalid 'value' in heap item")?; - topk_heap.push(CmsHeapItem { key, value }); + topk_heap.push(HeapItem { key, value }); } Ok(Self { @@ -85,8 +88,7 @@ impl CountMinSketchWithHeapAccumulator { buffer: &[u8], ) -> Result> { Ok(Self { - inner: CountMinSketchWithHeap::deserialize_msgpack(buffer) - .map_err(|e| -> Box { e.to_string().into() })?, + inner: CountMinSketchWithHeap::deserialize_msgpack(buffer)?, }) } @@ -122,8 +124,8 @@ impl SerializableToSink for CountMinSketchWithHeapAccumulator { .collect(); serde_json::json!({ - "row_num": self.inner.rows(), - "col_num": self.inner.cols(), + "row_num": self.inner.row_num, + "col_num": self.inner.col_num, "heap_size": self.inner.heap_size, "sketch": self.inner.sketch_matrix(), "topk_heap": heap_items @@ -131,7 +133,7 @@ impl SerializableToSink for CountMinSketchWithHeapAccumulator { } fn serialize_to_bytes(&self) -> Vec { - self.inner.serialize_msgpack().unwrap_or_default() + self.inner.serialize_msgpack() } } @@ -213,12 +215,12 @@ impl MergeableAccumulator for CountMinSketchW if accumulators.is_empty() { return Err("No accumulators to merge".into()); } - let mut iter = accumulators.into_iter(); - let mut merged = iter.next().unwrap(); - for acc in iter { - merged.inner.merge(&acc.inner)?; - } - Ok(merged) + let inners: Vec = + accumulators.into_iter().map(|acc| acc.inner).collect(); + let merged_inner = CountMinSketchWithHeap::merge(inners)?; + Ok(Self { + inner: merged_inner, + }) } } @@ -229,8 +231,8 @@ mod tests { #[test] fn test_count_min_sketch_with_heap_creation() { let cms = CountMinSketchWithHeapAccumulator::new(4, 1000, 20); - assert_eq!(cms.inner.rows(), 4); - assert_eq!(cms.inner.cols(), 1000); + assert_eq!(cms.inner.row_num, 4); + assert_eq!(cms.inner.col_num, 1000); assert_eq!(cms.inner.heap_size, 20); assert_eq!(cms.inner.topk_heap_items().len(), 0); } @@ -253,11 +255,11 @@ mod tests { 
vec![0.0, 20.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], ]; let heap1 = vec![ - CmsHeapItem { + HeapItem { key: "key1".to_string(), value: 100.0, }, - CmsHeapItem { + HeapItem { key: "key2".to_string(), value: 50.0, }, @@ -267,11 +269,11 @@ mod tests { vec![0.0, 15.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], ]; let heap2 = vec![ - CmsHeapItem { + HeapItem { key: "key3".to_string(), value: 75.0, }, - CmsHeapItem { + HeapItem { key: "key1".to_string(), value: 80.0, }, @@ -299,8 +301,8 @@ mod tests { let result = CountMinSketchWithHeapAccumulator::merge_accumulators(vec![cms.clone()]); assert!(result.is_ok()); let merged = result.unwrap(); - assert_eq!(merged.inner.rows(), cms.inner.rows()); - assert_eq!(merged.inner.cols(), cms.inner.cols()); + assert_eq!(merged.inner.row_num, cms.inner.row_num); + assert_eq!(merged.inner.col_num, cms.inner.col_num); assert_eq!(merged.inner.heap_size, cms.inner.heap_size); } @@ -310,14 +312,17 @@ mod tests { let cms2 = CountMinSketchWithHeapAccumulator::new(3, 10, 5); let result = CountMinSketchWithHeapAccumulator::merge_accumulators(vec![cms1, cms2]); assert!(result.is_err()); - assert!(result.unwrap_err().to_string().contains("dimension")); + assert!(result + .unwrap_err() + .to_string() + .contains("different dimensions")); } #[test] fn test_count_min_sketch_with_heap_serialization() { // Use from_legacy_matrix for a controlled state that round-trips correctly with both backends. 
let sketch = vec![vec![0.0, 42.0, 0.0], vec![0.0, 0.0, 100.0]]; - let topk_heap = vec![CmsHeapItem { + let topk_heap = vec![HeapItem { key: "test_key".to_string(), value: 99.0, }]; @@ -329,8 +334,8 @@ mod tests { let deserialized = CountMinSketchWithHeapAccumulator::deserialize_from_bytes_arroyo(&bytes).unwrap(); - assert_eq!(deserialized.inner.rows(), 2); - assert_eq!(deserialized.inner.cols(), 3); + assert_eq!(deserialized.inner.row_num, 2); + assert_eq!(deserialized.inner.col_num, 3); assert_eq!(deserialized.inner.heap_size, 5); assert_eq!(deserialized.inner.sketch_matrix()[0][1], 42.0); // [1][2] may be 100 (legacy, no hash collision) or 199 (100+99 when test_key hashes there) diff --git a/asap-query-engine/src/precompute_operators/datasketches_kll_accumulator.rs b/asap-query-engine/src/precompute_operators/datasketches_kll_accumulator.rs index 33e085d..e5c15e3 100644 --- a/asap-query-engine/src/precompute_operators/datasketches_kll_accumulator.rs +++ b/asap-query-engine/src/precompute_operators/datasketches_kll_accumulator.rs @@ -2,9 +2,9 @@ use crate::data_model::{ AggregateCore, AggregationType, MergeableAccumulator, SerializableToSink, SingleSubpopulationAggregate, }; -use asap_sketchlib::sketches::kll::KllSketch; use base64::{engine::general_purpose, Engine as _}; use serde_json::Value; +use sketch_core::kll::KllSketch; use std::collections::HashMap; #[cfg(feature = "extra_debugging")] use std::time::Instant; @@ -12,9 +12,9 @@ use tracing::debug; use promql_utilities::query_logics::enums::Statistic; -/// KLL sketch accumulator — wraps asap_sketchlib::sketches::KllSketch. -/// Core struct, update/merge/serde logic live in `asap_sketchlib::sketches`. -/// This file retains QE-specific trait impls and JSON output. +/// KLL sketch accumulator — wraps sketch_core::KllSketch. +/// Core struct, update/merge/serde logic live in sketch-core. +/// This file retains QE-specific trait impls, legacy deserializers, and JSON output. 
pub struct DatasketchesKLLAccumulator { pub inner: KllSketch, } @@ -31,19 +31,42 @@ impl DatasketchesKLLAccumulator { } pub fn get_quantile(&self, quantile: f64) -> f64 { - self.inner.quantile(quantile) + self.inner.get_quantile(quantile) + } + + pub fn deserialize_from_json(data: &Value) -> Result> { + // Mirror Python implementation: expects {"sketch": base64_encoded_string} + let sketch_b64 = data["sketch"] + .as_str() + .ok_or("Missing or invalid 'sketch' field")?; + + let sketch_bytes = general_purpose::STANDARD + .decode(sketch_b64) + .map_err(|e| format!("Failed to decode base64 sketch data: {e}"))?; + + // TODO: remove this hardcoding once FlinkSketch serializes k in its output + Ok(Self { + inner: KllSketch::from_dsrs_bytes(&sketch_bytes, 200)?, + }) + } + + pub fn deserialize_from_bytes(buffer: &[u8]) -> Result> { + // Mirror Python implementation: deserialize sketch directly from bytes + // TODO: remove this hardcoding once FlinkSketch serializes k in its output + Ok(Self { + inner: KllSketch::from_dsrs_bytes(buffer, 200)?, + }) } pub fn deserialize_from_bytes_arroyo( buffer: &[u8], ) -> Result> { debug!( - "Deserializing DatasketchesKLLAccumulator from MessagePack buffer of size {}", + "Deserializing DatasketchesKLLAccumulator from Arroyo MessagePack buffer of size {}", buffer.len() ); Ok(Self { - inner: KllSketch::deserialize_msgpack(buffer) - .map_err(|e| -> Box { e.to_string().into() })?, + inner: KllSketch::deserialize_msgpack(buffer)?, }) } @@ -113,7 +136,7 @@ impl SerializableToSink for DatasketchesKLLAccumulator { } fn serialize_to_bytes(&self) -> Vec { - self.inner.serialize_msgpack().unwrap_or_default() + self.inner.serialize_msgpack() } } @@ -232,12 +255,11 @@ impl MergeableAccumulator for DatasketchesKLLAccumul if accumulators.is_empty() { return Err("No accumulators to merge".into()); } - let mut iter = accumulators.into_iter(); - let mut merged = iter.next().unwrap(); - for acc in iter { - merged.inner.merge(&acc.inner)?; - } - Ok(merged) + 
let inners: Vec = accumulators.into_iter().map(|acc| acc.inner).collect(); + let merged_inner = KllSketch::merge(inners)?; + Ok(Self { + inner: merged_inner, + }) } } diff --git a/asap-query-engine/src/precompute_operators/delta_set_aggregator_accumulator.rs b/asap-query-engine/src/precompute_operators/delta_set_aggregator_accumulator.rs index f323426..e8b1b1b 100644 --- a/asap-query-engine/src/precompute_operators/delta_set_aggregator_accumulator.rs +++ b/asap-query-engine/src/precompute_operators/delta_set_aggregator_accumulator.rs @@ -2,15 +2,15 @@ use crate::data_model::{ AggregateCore, AggregationType, KeyByLabelValues, MergeableAccumulator, MultipleSubpopulationAggregate, SerializableToSink, }; -use asap_sketchlib::sketches::delta_set_aggregator::{deserialize_msgpack, serialize_msgpack}; use serde_json::Value; +use sketch_core::delta_set_aggregator::{deserialize_msgpack, serialize_msgpack}; use std::collections::{HashMap, HashSet}; use promql_utilities::query_logics::enums::Statistic; /// Accumulator that tracks sets of added and removed keys. /// Used for delta aggregation to track changes in cardinality. -/// Wire format (DeltaResult) and msgpack serde live in `asap_sketchlib::sketches`. +/// Wire format (DeltaResult) and msgpack serde live in sketch-core. 
#[derive(Debug, Clone)] pub struct DeltaSetAggregatorAccumulator { pub added: HashSet, @@ -153,8 +153,7 @@ impl DeltaSetAggregatorAccumulator { buffer: &[u8], ) -> Result> { // Delegate to sketch-core canonical DeltaResult msgpack format - let delta = deserialize_msgpack(buffer) - .map_err(|e| -> Box { e.to_string().into() })?; + let delta = deserialize_msgpack(buffer)?; let mut added = HashSet::new(); for item in &delta.added { @@ -203,7 +202,7 @@ impl SerializableToSink for DeltaSetAggregatorAccumulator { .iter() .map(|key| key.to_semicolon_str()) .collect(); - serialize_msgpack(&added, &removed).unwrap_or_default() + serialize_msgpack(&added, &removed) } } diff --git a/asap-query-engine/src/precompute_operators/hydra_kll_accumulator.rs b/asap-query-engine/src/precompute_operators/hydra_kll_accumulator.rs index f33012d..0b2e924 100644 --- a/asap-query-engine/src/precompute_operators/hydra_kll_accumulator.rs +++ b/asap-query-engine/src/precompute_operators/hydra_kll_accumulator.rs @@ -5,14 +5,14 @@ use crate::{ }, KeyByLabelValues, }; -use asap_sketchlib::sketches::hydra_kll::HydraKllSketch; use base64::{engine::general_purpose, Engine as _}; +use sketch_core::hydra_kll::HydraKllSketch; use std::collections::HashMap; use promql_utilities::query_logics::enums::Statistic; -/// HydraKLL sketch accumulator — wraps asap_sketchlib::sketches::HydraKllSketch. -/// Core struct, update/merge/serde logic live in `asap_sketchlib::sketches`. +/// HydraKLL sketch accumulator — wraps sketch_core::HydraKllSketch. +/// Core struct, update/merge/serde logic live in sketch-core. /// This file retains QE-specific trait impls and JSON output. 
#[derive(Debug, Clone)] pub struct HydraKllSketchAccumulator { @@ -38,26 +38,25 @@ impl HydraKllSketchAccumulator { buffer: &[u8], ) -> Result> { Ok(Self { - inner: HydraKllSketch::deserialize_msgpack(buffer) - .map_err(|e| -> Box { e.to_string().into() })?, + inner: HydraKllSketch::deserialize_msgpack(buffer)?, }) } pub fn query_key(&self, key: &KeyByLabelValues, quantile: f64) -> f64 { - self.inner.quantile(&key.to_semicolon_str(), quantile) + self.inner.query(&key.to_semicolon_str(), quantile) } } impl SerializableToSink for HydraKllSketchAccumulator { fn serialize_to_json(&self) -> serde_json::Value { // Mirror Python implementation: {"sketch": base64_encoded_string} - let sketch_bytes = self.inner.serialize_msgpack().unwrap_or_default(); + let sketch_bytes = self.inner.serialize_msgpack(); let sketch_b64 = general_purpose::STANDARD.encode(&sketch_bytes); serde_json::json!({ "sketch": sketch_b64 }) } fn serialize_to_bytes(&self) -> Vec { - self.inner.serialize_msgpack().unwrap_or_default() + self.inner.serialize_msgpack() } } @@ -68,12 +67,11 @@ impl MergeableAccumulator for HydraKllSketchAccumulat if accumulators.is_empty() { return Err("No accumulators to merge".into()); } - let mut iter = accumulators.into_iter(); - let mut merged = iter.next().unwrap(); - for acc in iter { - merged.inner.merge(&acc.inner)?; - } - Ok(merged) + let inners: Vec = accumulators.into_iter().map(|acc| acc.inner).collect(); + let merged_inner = HydraKllSketch::merge(inners)?; + Ok(Self { + inner: merged_inner, + }) } } diff --git a/asap-query-engine/src/precompute_operators/set_aggregator_accumulator.rs b/asap-query-engine/src/precompute_operators/set_aggregator_accumulator.rs index 45b74d5..4ec46c5 100644 --- a/asap-query-engine/src/precompute_operators/set_aggregator_accumulator.rs +++ b/asap-query-engine/src/precompute_operators/set_aggregator_accumulator.rs @@ -2,14 +2,14 @@ use crate::data_model::{ AggregateCore, AggregationType, KeyByLabelValues, MergeableAccumulator, 
MultipleSubpopulationAggregate, SerializableToSink, }; -use asap_sketchlib::sketches::set_aggregator::SetAggregator; use serde_json::Value; +use sketch_core::set_aggregator::SetAggregator; use std::collections::{HashMap, HashSet}; use promql_utilities::query_logics::enums::Statistic; -/// Set aggregator accumulator — wraps asap_sketchlib::sketches::SetAggregator. -/// Core struct, merge/serde logic live in `asap_sketchlib::sketches`. +/// Set aggregator accumulator — wraps sketch_core::SetAggregator. +/// Core struct, merge/serde logic live in sketch-core. /// This file retains QE-specific trait impls, KeyByLabelValues conversion, /// and legacy deserializers. #[derive(Debug, Clone)] @@ -92,8 +92,7 @@ impl SetAggregatorAccumulator { pub fn deserialize_from_bytes_arroyo( buffer: &[u8], ) -> Result> { - let sa = SetAggregator::deserialize_msgpack(buffer) - .map_err(|e| -> Box { e.to_string().into() })?; + let sa = SetAggregator::deserialize_msgpack(buffer)?; let added = sa .values .into_iter() @@ -107,9 +106,9 @@ impl SetAggregatorAccumulator { pub fn serialize_to_bytes_arroyo(&self) -> Vec { let mut sa = SetAggregator::new(); for key in &self.added { - sa.update(&key.to_semicolon_str()); + sa.insert(&key.to_semicolon_str()); } - sa.serialize_msgpack().unwrap_or_default() + sa.serialize_msgpack() } } diff --git a/asap-query-engine/tests/e2e_precompute_equivalence.rs b/asap-query-engine/tests/e2e_precompute_equivalence.rs index b8c6953..804884c 100644 --- a/asap-query-engine/tests/e2e_precompute_equivalence.rs +++ b/asap-query-engine/tests/e2e_precompute_equivalence.rs @@ -1,18 +1,18 @@ //! End-to-end integration tests: precompute engine output equivalence -//! with the wire-format sketch encoding. +//! with ArroYo sketch format. //! //! Each test: //! 1. Starts a PrecomputeEngine backed by a CapturingOutputSink //! 2. Sends Prometheus remote write samples via HTTP (Snappy-compressed protobuf) //! 3. Advances the watermark past the window boundary to close it -//! 4. 
Drains captured outputs and verifies equivalence with wire-format accumulators +//! 4. Drains captured outputs and verifies equivalence with ArroYo-format accumulators -use asap_sketchlib::sketches::kll::KllSketch; use asap_types::aggregation_config::AggregationConfig; use asap_types::enums::{AggregationType, WindowType}; use flate2::{write::GzEncoder, Compression}; use prost::Message; use serde_json::json; +use sketch_core::kll::KllSketch; use std::collections::HashMap; use std::io::Write; use std::sync::Arc; @@ -162,10 +162,10 @@ fn gzip_hex(bytes: &[u8]) -> String { hex::encode(encoder.finish().unwrap()) } -// ─── test 1: DatasketchesKLL output matches wire-format KLL ───────────────── +// ─── test 1: DatasketchesKLL output matches ArroYo KLL ────────────────────── /// Full e2e: send KLL samples through the HTTP ingest → PrecomputeEngine stack, -/// then verify the emitted DatasketchesKLLAccumulator matches what the wire-format +/// then verify the emitted DatasketchesKLLAccumulator matches what ArroYo's /// KllSketch::aggregate_kll would produce for the same values. 
#[tokio::test] async fn e2e_kll_output_matches_arroyo() { @@ -240,7 +240,7 @@ async fn e2e_kll_output_matches_arroyo() { .downcast_ref::() .expect("captured accumulator should be DatasketchesKLLAccumulator"); - // Build the wire-format equivalent and deserialize it + // Build the ArroYo-format equivalent and deserialize it let arroyo_bytes = KllSketch::aggregate_kll(k, &values).expect("KllSketch::aggregate_kll failed"); let arroyo_json = json!({ @@ -252,11 +252,11 @@ async fn e2e_kll_output_matches_arroyo() { let streaming_config_for_deser = StreamingConfig::new(agg_map); let (_arroyo_output, arroyo_acc_box) = PrecomputedOutput::deserialize_from_json_arroyo(&arroyo_json, &streaming_config_for_deser) - .expect("wire-format KLL deserialization failed"); + .expect("ArroYo KLL deserialization failed"); let arroyo_acc = arroyo_acc_box .as_any() .downcast_ref::() - .expect("wire-format payload should deserialize to DatasketchesKLLAccumulator"); + .expect("ArroYo payload should deserialize to DatasketchesKLLAccumulator"); // Window metadata assert_eq!(handcrafted_output.aggregation_id, agg_id); @@ -282,11 +282,11 @@ async fn e2e_kll_output_matches_arroyo() { } } -// ─── test 2: MultipleSum output matches wire-format MultipleSum ───────────── +// ─── test 2: MultipleSum output matches ArroYo MultipleSum ────────────────── /// Full e2e: send MultipleSum samples (grouped by "host") through the HTTP /// ingest → PrecomputeEngine stack, then verify the emitted -/// MultipleSumAccumulator matches the wire-format MessagePack-encoded sums map. +/// MultipleSumAccumulator matches the ArroYo MessagePack-encoded sums map. 
#[tokio::test] async fn e2e_multiple_sum_output_matches_arroyo() { let port = 19401u16; @@ -352,7 +352,7 @@ async fn e2e_multiple_sum_output_matches_arroyo() { .downcast_ref::() .expect("captured accumulator should be MultipleSumAccumulator"); - // Build the wire-format equivalent and deserialize it + // Build the ArroYo-format equivalent and deserialize it let mut expected_sums: HashMap = HashMap::new(); expected_sums.insert("A".to_string(), 6.0); let arroyo_bytes = rmp_serde::to_vec(&expected_sums).expect("msgpack encoding failed"); @@ -365,11 +365,11 @@ async fn e2e_multiple_sum_output_matches_arroyo() { let streaming_config_for_deser = StreamingConfig::new(agg_map); let (_arroyo_output, arroyo_acc_box) = PrecomputedOutput::deserialize_from_json_arroyo(&arroyo_json, &streaming_config_for_deser) - .expect("wire-format MultipleSum deserialization failed"); + .expect("ArroYo MultipleSum deserialization failed"); let arroyo_acc = arroyo_acc_box .as_any() .downcast_ref::() - .expect("wire-format payload should deserialize to MultipleSumAccumulator"); + .expect("ArroYo payload should deserialize to MultipleSumAccumulator"); // Window metadata assert_eq!(handcrafted_output.aggregation_id, agg_id); diff --git a/asap-query-engine/tests/test_both_backends.rs b/asap-query-engine/tests/test_both_backends.rs new file mode 100644 index 0000000..5643756 --- /dev/null +++ b/asap-query-engine/tests/test_both_backends.rs @@ -0,0 +1,30 @@ +//! Integration test that runs the library test suite with the sketchlib backend. +//! +//! When you run `cargo test -p query_engine_rust` (without --features sketchlib-tests), +//! the lib tests run with the legacy backend. This test spawns a second run with the +//! sketchlib backend so both modes are exercised in one `cargo test` invocation. +//! +//! This test is only compiled when sketchlib-tests is NOT enabled, to avoid recursion. 
+ +#[cfg(not(feature = "sketchlib-tests"))] +#[test] +fn test_sketchlib_backend() { + use std::process::Command; + + let status = Command::new(env!("CARGO")) + .args([ + "test", + "-p", + "query_engine_rust", + "--lib", + "--features", + "sketchlib-tests", + ]) + .status() + .expect("failed to spawn cargo test"); + + assert!( + status.success(), + "sketchlib backend tests failed (run `cargo test -p query_engine_rust --lib --features sketchlib-tests` for details)" + ); +}