From d8c377ca8d52845c2ded8add592e48bf361bd12e Mon Sep 17 00:00:00 2001 From: Jimmy Lu Date: Wed, 28 Sep 2022 09:39:37 -0700 Subject: [PATCH] Change the intermediate type of approx_percentile to `ROW` (#18386) Summary: X-link: https://github.com/prestodb/presto/pull/18386 Pull Request resolved: https://github.com/facebookincubator/velox/pull/2621 Before this change, intermediate aggregation node for `approx_percentile` in worker only sees a type signature of `VARBINARY -> VARBINARY`, thus not be able to figure out the actual value type and cannot perform the merge properly when the value type is not `DOUBLE`. This fix changes the intermediate result type to `ROW` so that we can get the value type from it and instantiate accumulator with proper type. Fix https://github.com/facebookincubator/velox/issues/2430 Reviewed By: mbasmanova Differential Revision: D39733152 fbshipit-source-id: f7476a9a4a2e92d7fdc976f388a07692787cada8 --- velox/docs/develop/aggregate-functions.rst | 15 + velox/functions/lib/KllSketch.h | 44 +-- .../aggregates/ApproxPercentileAggregate.cpp | 329 +++++++++++------- .../aggregates/tests/AggregationTestBase.cpp | 12 +- .../aggregates/tests/ApproxPercentileTest.cpp | 70 ++-- 5 files changed, 274 insertions(+), 196 deletions(-) diff --git a/velox/docs/develop/aggregate-functions.rst b/velox/docs/develop/aggregate-functions.rst index 1840960d9690..303bc5e6f060 100644 --- a/velox/docs/develop/aggregate-functions.rst +++ b/velox/docs/develop/aggregate-functions.rst @@ -580,3 +580,18 @@ To confirm that aggregate function works end to end as part of query, update tes .. code-block:: java assertQuery("SELECT orderkey, array_agg(linenumber) FROM lineitem GROUP BY 1"); + +Overwrite Intermediate Type in Presto +------------------------------------- + +Sometimes we need to change the intermediate type of aggregation function in +Presto, due to the differences in implementation or in the type information +worker node receives. This is done in Presto class +``com.facebook.presto.metadata.BuiltInTypeAndFunctionNamespaceManager``. When +``FeaturesConfig.isUseAlternativeFunctionSignatures()`` is enabled, we can +register a different set of function signatures used specifically by Velox. An +example of how to create such alternative function signatures from scratch can +be found in +``com.facebook.presto.operator.aggregation.AlternativeApproxPercentile``. An +example pull request can be found at +https://github.com/prestodb/presto/pull/18386. diff --git a/velox/functions/lib/KllSketch.h b/velox/functions/lib/KllSketch.h index 1d319c43cfb3..e9e6e2b90c75 100644 --- a/velox/functions/lib/KllSketch.h +++ b/velox/functions/lib/KllSketch.h @@ -140,26 +140,6 @@ struct KllSketch { /// Get frequencies of items being tracked. The result is sorted by item. std::vector> getFrequencies() const; - private: - KllSketch(const Allocator&, uint32_t seed); - void doInsert(T); - uint32_t insertPosition(); - int findLevelToCompact() const; - void addEmptyTopLevelToCompletelyFullSketch(); - void shiftItems(uint32_t delta); - - uint8_t numLevels() const { - return levels_.size() - 1; - } - - uint32_t getNumRetained() const { - return levels_.back() - levels_[0]; - } - - uint32_t safeLevelSize(uint8_t level) const { - return level < numLevels() ? levels_[level + 1] - levels_[level] : 0; - } - struct View { uint32_t k; size_t n; @@ -179,10 +159,32 @@ struct KllSketch { void deserialize(const char* FOLLY_NONNULL); }; + void mergeViews(const folly::Range& views); + View toView() const; + + private: + KllSketch(const Allocator&, uint32_t seed); + void doInsert(T); + uint32_t insertPosition(); + int findLevelToCompact() const; + void addEmptyTopLevelToCompletelyFullSketch(); + void shiftItems(uint32_t delta); + + uint8_t numLevels() const { + return levels_.size() - 1; + } + + uint32_t getNumRetained() const { + return levels_.back() - levels_[0]; + } + + uint32_t safeLevelSize(uint8_t level) const { + return level < numLevels() ? levels_[level + 1] - levels_[level] : 0; + } + static KllSketch fromView(const View&, const Allocator&, uint32_t seed); - void mergeViews(const folly::Range& views); using AllocU32 = typename std::allocator_traits< Allocator>::template rebind_alloc; diff --git a/velox/functions/prestosql/aggregates/ApproxPercentileAggregate.cpp b/velox/functions/prestosql/aggregates/ApproxPercentileAggregate.cpp index ef6c6458553b..85b6673c5d12 100644 --- a/velox/functions/prestosql/aggregates/ApproxPercentileAggregate.cpp +++ b/velox/functions/prestosql/aggregates/ApproxPercentileAggregate.cpp @@ -66,12 +66,12 @@ struct KllSketchAccumulator { } } - void append(const char* deserializedSketch) { - sketch_.mergeDeserialized(deserializedSketch); + void append(const typename KllSketch::View& view) { + sketch_.mergeViews(folly::Range(&view, 1)); } - void append(const std::vector& sketches) { - sketch_.mergeDeserialized(folly::Range(sketches.begin(), sketches.end())); + void append(const std::vector::View>& views) { + sketch_.mergeViews(views); } void finalize() { @@ -81,7 +81,7 @@ struct KllSketchAccumulator { sketch_.compact(); } - const KllSketch& getSketch() { + const KllSketch& getSketch() const { return sketch_; } @@ -104,11 +104,18 @@ struct KllSketchAccumulator { } }; -// The following variations are possible: -// x, percentile -// x, weight, percentile -// x, percentile, accuracy -// x, weight, percentile, accuracy +enum IntermediateTypeChildIndex { + kPercentiles = 0, + kPercentilesIsArray = 1, + kAccuracy = 2, + kK = 3, + kN = 4, + kMinValue = 5, + kMaxValue = 6, + kItems = 7, + kLevels = 8, +}; + template class ApproxPercentileAggregate : public exec::Aggregate { public: @@ -198,15 +205,103 @@ class ApproxPercentileAggregate : public exec::Aggregate { void extractAccumulators(char** groups, int32_t numGroups, VectorPtr* result) override { VELOX_CHECK(result); - auto flatResult = (*result)->asFlatVector(); + auto rowResult = (*result)->as(); + VELOX_CHECK(rowResult); + auto pool = rowResult->pool(); - extract( - groups, + if (percentiles_) { + auto& values = percentiles_->values; + auto size = values.size(); + auto elements = + BaseVector::create>(DOUBLE(), size, pool); + std::copy(values.begin(), values.end(), elements->mutableRawValues()); + auto array = std::make_shared( + pool, + ARRAY(DOUBLE()), + nullptr, + 1, + AlignedBuffer::allocate(1, pool, 0), + AlignedBuffer::allocate(1, pool, size), + std::move(elements)); + rowResult->childAt(kPercentiles) = + BaseVector::wrapInConstant(numGroups, 0, std::move(array)); + rowResult->childAt(kPercentilesIsArray) = + std::make_shared>( + pool, numGroups, false, bool(percentiles_->isArray)); + } else { + rowResult->childAt(kPercentiles) = BaseVector::wrapInConstant( + numGroups, + 0, + std::make_shared( + pool, + ARRAY(DOUBLE()), + AlignedBuffer::allocate(1, pool, bits::kNull), + 1, + AlignedBuffer::allocate(1, pool, 0), + AlignedBuffer::allocate(1, pool, 0), + nullptr)); + rowResult->childAt(kPercentilesIsArray) = + std::make_shared>(pool, numGroups, true, false); + } + rowResult->childAt(kAccuracy) = std::make_shared>( + pool, numGroups, - flatResult, - [&](const KllSketch& digest, - FlatVector* result, - vector_size_t index) { serializeDigest(digest, result, index); }); + accuracy_ == kMissingNormalizedValue, + double(accuracy_)); + auto k = rowResult->childAt(kK)->asFlatVector(); + auto n = rowResult->childAt(kN)->asFlatVector(); + auto minValue = rowResult->childAt(kMinValue)->asFlatVector(); + auto maxValue = rowResult->childAt(kMaxValue)->asFlatVector(); + auto items = rowResult->childAt(kItems)->as(); + auto levels = rowResult->childAt(kLevels)->as(); + + rowResult->resize(numGroups); + k->resize(numGroups); + n->resize(numGroups); + minValue->resize(numGroups); + maxValue->resize(numGroups); + items->resize(numGroups); + levels->resize(numGroups); + + auto itemsElements = items->elements()->asFlatVector(); + auto levelsElements = levels->elements()->asFlatVector(); + size_t itemsCount = 0; + vector_size_t levelsCount = 0; + for (int i = 0; i < numGroups; ++i) { + auto accumulator = value>(groups[i]); + auto v = accumulator->getSketch().toView(); + itemsCount += v.items.size(); + levelsCount += v.levels.size(); + } + VELOX_CHECK_LE(itemsCount, std::numeric_limits::max()); + itemsElements->resetNulls(); + itemsElements->resize(itemsCount); + levelsElements->resetNulls(); + levelsElements->resize(levelsCount); + + auto rawItems = itemsElements->mutableRawValues(); + auto rawLevels = levelsElements->mutableRawValues(); + itemsCount = 0; + levelsCount = 0; + for (int i = 0; i < numGroups; ++i) { + auto accumulator = value>(groups[i]); + auto v = accumulator->getSketch().toView(); + if (v.n == 0) { + rowResult->setNull(i, true); + } else { + rowResult->setNull(i, false); + k->set(i, v.k); + n->set(i, v.n); + minValue->set(i, v.minValue); + maxValue->set(i, v.maxValue); + std::copy(v.items.begin(), v.items.end(), rawItems + itemsCount); + items->setOffsetAndSize(i, itemsCount, v.items.size()); + itemsCount += v.items.size(); + std::copy(v.levels.begin(), v.levels.end(), rawLevels + levelsCount); + levels->setOffsetAndSize(i, levelsCount, v.levels.size()); + levelsCount += v.levels.size(); + } + } } void addRawInput( @@ -256,20 +351,7 @@ class ApproxPercentileAggregate : public exec::Aggregate { const SelectivityVector& rows, const std::vector& args, bool /*mayPushdown*/) override { - decodedDigest_.decode(*args[0], rows, true); - - rows.applyToSelected([&](auto row) { - if (decodedDigest_.isNullAt(row)) { - return; - } - auto tracker = trackRowSize(groups[row]); - auto accumulator = value>(groups[row]); - auto digest = getDeserializedDigest(row); - if (accuracy_ != kMissingNormalizedValue) { - accumulator->setAccuracy(accuracy_); - } - accumulator->append(digest); - }); + addIntermediate(groups, rows, args); } void addSingleGroupRawInput( @@ -318,26 +400,7 @@ class ApproxPercentileAggregate : public exec::Aggregate { const SelectivityVector& rows, const std::vector& args, bool /*mayPushdown*/) override { - decodedDigest_.decode(*args[0], rows, true); - auto accumulator = value>(group); - - auto tracker = trackRowSize(group); - std::vector digests; - digests.reserve(rows.end()); - - rows.applyToSelected([&](auto row) { - if (decodedDigest_.isNullAt(row)) { - return; - } - digests.push_back(getDeserializedDigest(row)); - }); - - if (!digests.empty()) { - if (accuracy_ != kMissingNormalizedValue) { - accumulator->setAccuracy(accuracy_); - } - accumulator->append(digests); - } + addIntermediate(group, rows, args); } private: @@ -470,66 +533,94 @@ class ApproxPercentileAggregate : public exec::Aggregate { KllSketchAccumulator* initRawAccumulator(char* group) { auto accumulator = value>(group); - if (hasAccuracy_) { + if (accuracy_ != kMissingNormalizedValue) { accumulator->setAccuracy(accuracy_); } return accumulator; } - void serializeDigest( - const KllSketch& digest, - FlatVector* result, - vector_size_t index) { - auto size = sizeof(int32_t) + sizeof(double) * percentiles_->values.size() + - sizeof accuracy_ + digest.serializedByteSize(); - Buffer* buffer = result->getBufferWithSpace(size); - char* data = buffer->asMutable() + buffer->size(); - common::OutputByteStream stream(data); - if (percentiles_) { - stream.appendOne( - percentiles_->isArray ? percentiles_->values.size() : -1); - for (double p : percentiles_->values) { - stream.appendOne(p); - } - } else { - stream.appendOne(0); - } - stream.appendOne(accuracy_); - digest.serialize(data + stream.offset()); - buffer->setSize(buffer->size() + size); - result->setNoCopy(index, StringView(data, size)); - } - - const char* getDeserializedDigest(vector_size_t row) { - auto data = decodedDigest_.valueAt(row); - common::InputByteStream stream(data.data()); - auto percentileCount = stream.read(); - bool percentileIsArray; - if (percentileCount == -1) { - percentileIsArray = false; - percentileCount = 1; - } else { - percentileIsArray = true; - } - VELOX_DCHECK_GE(percentileCount, 0); - if (percentileCount > 0) { - checkSetPercentile( - percentileIsArray, - stream.read(percentileCount), - percentileCount); + template + void addIntermediate( + std::conditional_t group, + const SelectivityVector& rows, + const std::vector& args) { + VELOX_CHECK_EQ(args.size(), 1); + DecodedVector decoded(*args[0], rows); + auto rowVec = decoded.base()->as(); + VELOX_CHECK(rowVec); + DecodedVector percentiles(*rowVec->childAt(kPercentiles), rows); + auto percentileIsArray = + rowVec->childAt(kPercentilesIsArray)->asUnchecked>(); + auto accuracy = + rowVec->childAt(kAccuracy)->asUnchecked>(); + auto k = rowVec->childAt(kK)->asUnchecked>(); + auto n = rowVec->childAt(kN)->asUnchecked>(); + auto minValue = rowVec->childAt(kMinValue)->asUnchecked>(); + auto maxValue = rowVec->childAt(kMaxValue)->asUnchecked>(); + auto items = rowVec->childAt(kItems)->asUnchecked(); + auto levels = rowVec->childAt(kLevels)->asUnchecked(); + + auto rawItems = items->elements()->asFlatVector()->rawValues(); + auto rawLevels = + levels->elements()->asFlatVector()->rawValues(); + KllSketchAccumulator* accumulator = nullptr; + std::vector::View> views; + if constexpr (kSingleGroup) { + views.reserve(rows.end()); } - if (auto accuracy = stream.read(); - accuracy != kMissingNormalizedValue) { - checkSetAccuracy(accuracy); + rows.applyToSelected([&](auto row) { + if (decoded.isNullAt(row)) { + return; + } + int i = decoded.index(row); + if (percentileIsArray->isNullAt(i)) { + return; + } + if (!accumulator) { + int j = percentiles.index(i); + auto percentilesBase = percentiles.base()->asUnchecked(); + auto rawPercentiles = + percentilesBase->elements()->asFlatVector()->rawValues(); + checkSetPercentile( + percentileIsArray->valueAt(i), + rawPercentiles + percentilesBase->offsetAt(j), + percentilesBase->sizeAt(j)); + if (!accuracy->isNullAt(i)) { + checkSetAccuracy(accuracy->valueAt(i)); + } + } + if constexpr (kSingleGroup) { + if (!accumulator) { + accumulator = initRawAccumulator(group); + } + } else { + accumulator = initRawAccumulator(group[row]); + } + typename KllSketch::View v{ + .k = static_cast(k->valueAt(i)), + .n = static_cast(n->valueAt(i)), + .minValue = minValue->valueAt(i), + .maxValue = maxValue->valueAt(i), + .items = + {rawItems + items->offsetAt(i), + static_cast(items->sizeAt(i))}, + .levels = + {rawLevels + levels->offsetAt(i), + static_cast(levels->sizeAt(i))}, + }; + if constexpr (kSingleGroup) { + views.push_back(v); + } else { + auto tracker = trackRowSize(group[row]); + accumulator->append(v); + } + }); + if constexpr (kSingleGroup) { + if (!views.empty()) { + auto tracker = trackRowSize(group); + accumulator->append(views); + } } - // If 'data' is inline, this function will return a local - // address. Assert data is not inline. - VELOX_DCHECK(!data.isInline()); - // Some compilers cannot deduce that the StringView cannot be inline from - // the assert above. Suppress warning. - VELOX_SUPPRESS_RETURN_LOCAL_ADDR_WARNING - return data.data() + stream.offset(); - VELOX_UNSUPPRESS_RETURN_LOCAL_ADDR_WARNING } struct Percentiles { @@ -564,29 +655,32 @@ void addSignatures( const std::string& returnType, std::vector>& signatures) { + auto intermediateType = fmt::format( + "row(array(double), boolean, double, integer, bigint, {0}, {0}, array({0}), array(integer))", + inputType); signatures.push_back(exec::AggregateFunctionSignatureBuilder() .returnType(returnType) - .intermediateType("varbinary") + .intermediateType(intermediateType) .argumentType(inputType) .argumentType(percentileType) .build()); signatures.push_back(exec::AggregateFunctionSignatureBuilder() .returnType(returnType) - .intermediateType("varbinary") + .intermediateType(intermediateType) .argumentType(inputType) .argumentType("bigint") .argumentType(percentileType) .build()); signatures.push_back(exec::AggregateFunctionSignatureBuilder() .returnType(returnType) - .intermediateType("varbinary") + .intermediateType(intermediateType) .argumentType(inputType) .argumentType(percentileType) .argumentType("double") .build()); signatures.push_back(exec::AggregateFunctionSignatureBuilder() .returnType(returnType) - .intermediateType("varbinary") + .intermediateType(intermediateType) .argumentType(inputType) .argumentType("bigint") .argumentType(percentileType) @@ -645,24 +739,19 @@ bool registerApproxPercentile(const std::string& name) { VELOX_USER_CHECK_EQ( argTypes.size(), 1, - "The type of partial result for {} must be VARBINARY", + "The type of partial result for {} must be ROW", name); - VELOX_USER_CHECK_GE( + VELOX_USER_CHECK_EQ( argTypes[0]->kind(), - TypeKind::VARBINARY, - "The type of partial result for {} must be VARBINARY", + TypeKind::ROW, + "The type of partial result for {} must be ROW", name); } - if (!isRawInput && exec::isPartialOutput(step)) { - // FIXME: This is not working for non-double type, issue #2430 to - // track this. - return std::make_unique>( - false, false, VARBINARY()); - } - TypePtr type; - if (isRawInput) { + if (!isRawInput && exec::isPartialOutput(step)) { + type = argTypes[0]->asRow().childAt(kMinValue); + } else if (isRawInput) { type = argTypes[0]; } else if (resultType->isArray()) { type = resultType->as().elementType(); diff --git a/velox/functions/prestosql/aggregates/tests/AggregationTestBase.cpp b/velox/functions/prestosql/aggregates/tests/AggregationTestBase.cpp index 03403de2f92a..2cbe844642d2 100644 --- a/velox/functions/prestosql/aggregates/tests/AggregationTestBase.cpp +++ b/velox/functions/prestosql/aggregates/tests/AggregationTestBase.cpp @@ -125,7 +125,7 @@ void AggregationTestBase::testAggregations( assertResults) { { SCOPED_TRACE("Run partial + final"); - PlanBuilder builder; + PlanBuilder builder(pool()); makeSource(builder); builder.partialAggregation(groupingKeys, aggregates).finalAggregation(); if (!postAggregationProjections.empty()) { @@ -138,7 +138,7 @@ void AggregationTestBase::testAggregations( { SCOPED_TRACE("Run single"); - PlanBuilder builder; + PlanBuilder builder(pool()); makeSource(builder); builder.singleAggregation(groupingKeys, aggregates); if (!postAggregationProjections.empty()) { @@ -151,7 +151,7 @@ void AggregationTestBase::testAggregations( if (!groupingKeys.empty() && allowInputShuffle_) { SCOPED_TRACE("Run partial + final with spilling"); - PlanBuilder builder; + PlanBuilder builder(pool()); makeSource(builder); // Spilling needs at least 2 batches of input. Use round-robin @@ -187,7 +187,7 @@ void AggregationTestBase::testAggregations( { SCOPED_TRACE("Run partial + intermediate + final"); - PlanBuilder builder; + PlanBuilder builder(pool()); makeSource(builder); builder.partialAggregation(groupingKeys, aggregates) @@ -204,7 +204,7 @@ void AggregationTestBase::testAggregations( if (!groupingKeys.empty()) { SCOPED_TRACE("Run partial + local exchange + final"); - PlanBuilder builder; + PlanBuilder builder(pool()); makeSource(builder); builder.partialAggregation(groupingKeys, aggregates) @@ -222,7 +222,7 @@ void AggregationTestBase::testAggregations( { SCOPED_TRACE( "Run partial + local exchange + intermediate + local exchange + final"); - PlanBuilder builder; + PlanBuilder builder(pool()); makeSource(builder); builder.partialAggregation(groupingKeys, aggregates); diff --git a/velox/functions/prestosql/aggregates/tests/ApproxPercentileTest.cpp b/velox/functions/prestosql/aggregates/tests/ApproxPercentileTest.cpp index 808a7fddd1a9..bb112b426174 100644 --- a/velox/functions/prestosql/aggregates/tests/ApproxPercentileTest.cpp +++ b/velox/functions/prestosql/aggregates/tests/ApproxPercentileTest.cpp @@ -58,9 +58,9 @@ class ApproxPercentileTest : public AggregationTestBase { void SetUp() override { AggregationTestBase::SetUp(); random::setSeed(0); + allowInputShuffle(); } - // TODO: Use `testAggregations` once issue #2430 is fixed. template void testGlobalAgg( const VectorPtr& values, @@ -73,39 +73,20 @@ class ApproxPercentileTest : public AggregationTestBase { weights != nullptr, percentile, accuracy)); - auto call = functionCall(false, weights.get(), percentile, accuracy, -1); auto rows = weights ? makeRowVector({values, weights}) : makeRowVector({values}); - auto op = - PlanBuilder().values({rows}).singleAggregation({}, {call}).planNode(); - { - SCOPED_TRACE("single_agg=false"); - assertQuery(op, fmt::format("SELECT {}", expectedResult)); - } - op = PlanBuilder() - .values({rows}) - .partialAggregation({}, {call}) - .finalAggregation() - .planNode(); - { - SCOPED_TRACE("single_agg=true"); - assertQuery(op, fmt::format("SELECT {}", expectedResult)); - } - call = functionCall(false, weights.get(), percentile, accuracy, 3); - op = PlanBuilder(pool()) - .values({rows}) - .partialAggregation({}, {call}) - .finalAggregation() - .planNode(); - { - SCOPED_TRACE("Percentile array"); - auto expected = makeRowVector( - {makeArrayVector({std::vector(3, expectedResult)})}); - assertQuery(op, expected); - } + testAggregations( + {rows}, + {}, + {functionCall(false, weights.get(), percentile, accuracy, -1)}, + fmt::format("SELECT {}", expectedResult)); + testAggregations( + {rows}, + {}, + {functionCall(false, weights.get(), percentile, accuracy, 3)}, + fmt::format("SELECT ARRAY[{0},{0},{0}]", expectedResult)); } - // TODO: Use `testAggregations` once issue #2430 is fixed. void testGroupByAgg( const VectorPtr& keys, const VectorPtr& values, @@ -113,26 +94,13 @@ class ApproxPercentileTest : public AggregationTestBase { double percentile, double accuracy, const RowVectorPtr& expectedResult) { - auto call = functionCall(true, weights.get(), percentile, accuracy, -1); auto rows = weights ? makeRowVector({keys, values, weights}) : makeRowVector({keys, values}); - auto op = PlanBuilder() - .values({rows}) - .singleAggregation({"c0"}, {call}) - .planNode(); - assertQuery(op, expectedResult); - op = PlanBuilder() - .values({rows}) - .partialAggregation({"c0"}, {call}) - .finalAggregation() - .planNode(); - assertQuery(op, expectedResult); - call = functionCall(true, weights.get(), percentile, accuracy, 3); - op = PlanBuilder(pool()) - .values({rows}) - .partialAggregation({"c0"}, {call}) - .finalAggregation() - .planNode(); + testAggregations( + {rows}, + {"c0"}, + {functionCall(true, weights.get(), percentile, accuracy, -1)}, + {expectedResult}); { SCOPED_TRACE("Percentile array"); auto resultValues = expectedResult->childAt(1); @@ -159,7 +127,11 @@ class ApproxPercentileTest : public AggregationTestBase { offsets, sizes, elements)}); - assertQuery(op, expected); + testAggregations( + {rows}, + {"c0"}, + {functionCall(true, weights.get(), percentile, accuracy, 3)}, + {expected}); } } };