Skip to content

Commit

Permalink
adjusted python wrapper
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexanderSaydakov committed Jan 26, 2022
1 parent 59b2924 commit a9c8d07
Showing 1 changed file with 21 additions and 8 deletions.
29 changes: 21 additions & 8 deletions python/src/kll_wrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,11 +58,24 @@ double kll_sketch_get_rank(const kll_sketch<T>& sk, const T& item, bool inclusiv
return sk.template get_rank<false>(item);
}

// Python-facing helper for single-quantile queries.
//
// kll_sketch<T>::get_quantile takes its inclusive/exclusive mode as a
// compile-time template argument; this wrapper maps the runtime `inclusive`
// flag onto the matching instantiation so the mode can be chosen per call.
//
// sk        - sketch to query
// rank      - rank value forwarded unchanged to get_quantile
// inclusive - true selects get_quantile<true>, false selects get_quantile<false>
//
// The value returned by the sketch is explicitly converted to T before it is
// handed back to the caller.
template<typename T>
T kll_sketch_get_quantile(const kll_sketch<T>& sk,
                          double rank,
                          bool inclusive) {
  return inclusive
      ? T(sk.template get_quantile<true>(rank))
      : T(sk.template get_quantile<false>(rank));
}

template<typename T>
py::list kll_sketch_get_quantiles(const kll_sketch<T>& sk,
std::vector<double>& fractions) {
std::vector<double>& fractions,
bool inclusive) {
size_t nQuantiles = fractions.size();
auto result = sk.get_quantiles(&fractions[0], nQuantiles);
auto result = inclusive ?
sk.template get_quantiles<true>(fractions.data(), nQuantiles)
: sk.template get_quantiles<false>(fractions.data(), nQuantiles);

// returning as std::vector<> would copy values to a list anyway
py::list list(nQuantiles);
Expand All @@ -79,8 +92,8 @@ py::list kll_sketch_get_pmf(const kll_sketch<T>& sk,
bool inclusive) {
size_t nPoints = split_points.size();
auto result = inclusive ?
sk.template get_PMF<true>(&split_points[0], nPoints)
: sk.template get_PMF<false>(&split_points[0], nPoints);
sk.template get_PMF<true>(split_points.data(), nPoints)
: sk.template get_PMF<false>(split_points.data(), nPoints);

py::list list(nPoints + 1);
for (size_t i = 0; i <= nPoints; ++i) {
Expand All @@ -96,8 +109,8 @@ py::list kll_sketch_get_cdf(const kll_sketch<T>& sk,
bool inclusive) {
size_t nPoints = split_points.size();
auto result = inclusive ?
sk.template get_CDF<true>(&split_points[0], nPoints)
: sk.template get_CDF<false>(&split_points[0], nPoints);
sk.template get_CDF<true>(split_points.data(), nPoints)
: sk.template get_CDF<false>(split_points.data(), nPoints);

py::list list(nPoints + 1);
for (size_t i = 0; i <= nPoints; ++i) {
Expand Down Expand Up @@ -156,7 +169,7 @@ void bind_kll_sketch(py::module &m, const char* name) {
"Returns the minimum value from the stream. If empty, kll_floats_sketch retursn nan; kll_ints_sketch throws a RuntimeError")
.def("get_max_value", &kll_sketch<T>::get_max_value,
"Returns the maximum value from the stream. If empty, kll_floats_sketch retursn nan; kll_ints_sketch throws a RuntimeError")
.def("get_quantile", &kll_sketch<T>::get_quantile, py::arg("fraction"),
.def("get_quantile", &dspy::kll_sketch_get_quantile<T>, py::arg("fraction"), py::arg("inclusive")=false,
"Returns an approximation to the value of the data item "
"that would be preceded by the given fraction of a hypothetical sorted "
"version of the input stream so far.\n"
Expand All @@ -165,7 +178,7 @@ void bind_kll_sketch(py::module &m, const char* name) {
"sketch. Instead use get_quantiles(), which pays the overhead only once.\n"
"For kll_floats_sketch: if the sketch is empty this returns nan. "
"For kll_ints_sketch: if the sketch is empty this throws a RuntimeError.")
.def("get_quantiles", &dspy::kll_sketch_get_quantiles<T>, py::arg("fractions"),
.def("get_quantiles", &dspy::kll_sketch_get_quantiles<T>, py::arg("fractions"), py::arg("inclusive")=false,
"This is a more efficient multiple-query version of get_quantile().\n"
"This returns an array that could have been generated by using get_quantile() for each "
"fractional rank separately, but would be very inefficient. "
Expand Down

0 comments on commit a9c8d07

Please sign in to comment.