diff --git a/velox/docs/develop/expression-evaluation.rst b/velox/docs/develop/expression-evaluation.rst index a34eacac0606..4b5075ce2f36 100644 --- a/velox/docs/develop/expression-evaluation.rst +++ b/velox/docs/develop/expression-evaluation.rst @@ -79,6 +79,12 @@ try Handles errors generated by the input expression by returning nulls for the corresponding rows. +coalesce + COALESCE expression. Takes multiple input expressions of the same type. + + Returns the first non-null value in the argument list. Like an IF or SWITCH + expression, arguments are only evaluated if necessary. + When evaluating AND and OR expressions, the engine adaptively reorders the conjuncts to evaluate the cheapest most decisive conjuncts first. E.g. the AND expression chooses to evaluate the cheapest conjunct that returns FALSE most diff --git a/velox/expression/CMakeLists.txt b/velox/expression/CMakeLists.txt index 7f9a5029441a..3009a3e67d31 100644 --- a/velox/expression/CMakeLists.txt +++ b/velox/expression/CMakeLists.txt @@ -20,6 +20,7 @@ target_link_libraries(velox_expression_functions velox_common_base) add_library( velox_expression CastExpr.cpp + CoalesceExpr.cpp ControlExpr.cpp EvalCtx.cpp Expr.cpp diff --git a/velox/expression/CoalesceExpr.cpp b/velox/expression/CoalesceExpr.cpp new file mode 100644 index 000000000000..6561ae222e82 --- /dev/null +++ b/velox/expression/CoalesceExpr.cpp @@ -0,0 +1,66 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "velox/expression/CoalesceExpr.h" +#include "velox/expression/VarSetter.h" + +namespace facebook::velox::exec { + +CoalesceExpr::CoalesceExpr(TypePtr type, std::vector&& inputs) + : SpecialForm(std::move(type), std::move(inputs), kCoalesce) { + for (auto i = 1; i < inputs_.size(); i++) { + VELOX_USER_CHECK_EQ( + inputs_[0]->type()->kind(), + inputs_[i]->type()->kind(), + "Inputs to coalesce must have the same type"); + } +} + +void CoalesceExpr::evalSpecialForm( + const SelectivityVector& rows, + EvalCtx* context, + VectorPtr* result) { + // Make sure to include current expression in the error message in case of an + // exception. + ExceptionContextSetter exceptionContext( + {[](auto* expr) { return static_cast(expr)->toString(); }, this}); + + // Null positions to populate. + exec::LocalSelectivityVector activeRowsHolder(context, rows.end()); + auto activeRows = activeRowsHolder.get(); + *activeRows = rows; + + // Fix finalSelection at "rows" unless already fixed. + VarSetter finalSelection( + context->mutableFinalSelection(), &rows, context->isFinalSelection()); + VarSetter isFinalSelection(context->mutableIsFinalSelection(), false); + + for (int i = 0; i < inputs_.size(); i++) { + inputs_[i]->eval(*activeRows, context, result); + + const uint64_t* rawNulls = (*result)->flatRawNulls(*activeRows); + if (!rawNulls) { + // No nulls left. + return; + } + + activeRows->deselectNonNulls(rawNulls, 0, activeRows->end()); + if (!activeRows->hasSelections()) { + // No nulls left. + return; + } + } +} +} // namespace facebook::velox::exec diff --git a/velox/expression/CoalesceExpr.h b/velox/expression/CoalesceExpr.h new file mode 100644 index 000000000000..181963cafd9b --- /dev/null +++ b/velox/expression/CoalesceExpr.h @@ -0,0 +1,37 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "velox/expression/ControlExpr.h" + +namespace facebook::velox::exec { + +const char* const kCoalesce = "coalesce"; + +class CoalesceExpr : public SpecialForm { + public: + CoalesceExpr(TypePtr type, std::vector&& inputs); + + void evalSpecialForm( + const SelectivityVector& rows, + EvalCtx* context, + VectorPtr* result) override; + + bool propagatesNulls() const override { + return false; + } +}; +} // namespace facebook::velox::exec diff --git a/velox/expression/ExprCompiler.cpp b/velox/expression/ExprCompiler.cpp index 4f716ee562b0..31e3b4221cda 100644 --- a/velox/expression/ExprCompiler.cpp +++ b/velox/expression/ExprCompiler.cpp @@ -15,8 +15,8 @@ */ #include "velox/expression/ExprCompiler.h" -#include "velox/core/SimpleFunctionMetadata.h" #include "velox/expression/CastExpr.h" +#include "velox/expression/CoalesceExpr.h" #include "velox/expression/ControlExpr.h" #include "velox/expression/Expr.h" #include "velox/expression/SimpleFunctionRegistry.h" @@ -193,6 +193,9 @@ ExprPtr getSpecialForm( VELOX_CHECK_EQ(compiledChildren.size(), 1); return std::make_shared(type, std::move(compiledChildren[0])); } + if (name == kCoalesce) { + return std::make_shared(type, std::move(compiledChildren)); + } return nullptr; } diff --git a/velox/expression/tests/CMakeLists.txt b/velox/expression/tests/CMakeLists.txt index 9d94b5496781..14d91799e51b 100644 --- a/velox/expression/tests/CMakeLists.txt +++ b/velox/expression/tests/CMakeLists.txt @@ -17,6 +17,7 @@ add_executable( ExprTest.cpp ExprStatsTest.cpp CastExprTest.cpp + CoalesceTest.cpp MapWriterTest.cpp ArrayWriterTest.cpp RowWriterTest.cpp diff --git a/velox/functions/prestosql/tests/CoalesceTest.cpp b/velox/expression/tests/CoalesceTest.cpp similarity index 66% rename from velox/functions/prestosql/tests/CoalesceTest.cpp rename to velox/expression/tests/CoalesceTest.cpp index a2f232733cb4..8c32645dce2f 100644 --- a/velox/functions/prestosql/tests/CoalesceTest.cpp +++ b/velox/expression/tests/CoalesceTest.cpp @@ -39,23 +39,36 @@ TEST_F(CoalesceTest, basic) { auto row = makeRowVector({first, second, third}); auto result = evaluate>("coalesce(c0, c1, c2)", row); - for (int i = 0; i < size; ++i) { - EXPECT_EQ(result->valueAt(i), i * pow(10, i % 3)) << "at " << i; - } + auto expectedResult = makeFlatVector( + size, [](auto row) { return row * pow(10, row % 3); }); + assertEqualVectors(expectedResult, result); + + // Verify that input expressions are evaluated only on rows that are still + // null after evaluating all the preceding inputs and not evaluated at all if + // there are no nulls remaining. + + // The last expression 'c1 / 0' should not be evaluated. + result = evaluate>( + "coalesce(c0, c1, c2, cast(c1 / 0 as integer))", row); + assertEqualVectors(expectedResult, result); + + // The second expression 'c1 / (c1 % 3)' should not be evaluated on rows where + // c1 % 3 is zero. + result = evaluate>( + "coalesce(c0, cast(c1 / (c1 % 3) as integer), c2)", row); + assertEqualVectors(expectedResult, result); result = evaluate>("coalesce(c2, c1, c0)", row); - for (int i = 0; i < size; ++i) { - EXPECT_EQ(result->valueAt(i), i * 100) << "at " << i; - } + expectedResult = + makeFlatVector(size, [](auto row) { return row * 100; }); + assertEqualVectors(expectedResult, result); result = evaluate>("coalesce(c0, c1)", row); - for (int i = 0; i < size; ++i) { - if (i % 3 == 2) { - EXPECT_TRUE(result->isNullAt(i)); - } else { - EXPECT_EQ(result->valueAt(i), i * pow(10, i % 3)) << "at " << i; - } - } + expectedResult = makeFlatVector( + size, + [](auto row) { return row * pow(10, row % 3); }, + [](auto row) { return row % 3 == 2; }); + assertEqualVectors(expectedResult, result); } TEST_F(CoalesceTest, strings) { diff --git a/velox/expression/tests/TryExprTest.cpp b/velox/expression/tests/TryExprTest.cpp index 016d4234ad0c..5d93e1eff7f4 100644 --- a/velox/expression/tests/TryExprTest.cpp +++ b/velox/expression/tests/TryExprTest.cpp @@ -82,14 +82,14 @@ TEST_F(TryExprTest, nestedTryChildErrors) { // throw. auto flatVector = makeFlatVector( 5, [&](auto row) { return row % 2 == 0 ? "1" : "a"; }); - auto result = evaluate>( - "try(coalesce(try(cast(c0 as integer)), 3))", + auto result = evaluate>( + "try(coalesce(try(cast(c0 as integer)), cast(3 as integer)))", makeRowVector({flatVector})); assertEqualVectors( // Every other row throws an exception, which should get caught and // coalesced to 3. - makeFlatVector({1, 3, 1, 3, 1}), + makeFlatVector({1, 3, 1, 3, 1}), result); } @@ -104,12 +104,12 @@ TEST_F(TryExprTest, nestedTryParentErrors) { size, [&](auto row) { return row % 3 == 0 ? "a" : "1"; }); auto col1 = makeFlatVector( size, [&](auto row) { return row % 3 == 1 ? "a" : "1"; }); - auto result = evaluate>( - "try(cast(c1 as integer) + coalesce(try(cast(c0 as integer)), 3))", + auto result = evaluate>( + "try(cast(c1 as integer) + coalesce(try(cast(c0 as integer)), cast(3 as integer)))", makeRowVector({col0, col1})); assertEqualVectors( - makeFlatVector( + makeFlatVector( size, [&](auto row) { if (row % 3 == 0) { diff --git a/velox/functions/prestosql/CMakeLists.txt b/velox/functions/prestosql/CMakeLists.txt index 2b92254d92dc..9ec732f5d93c 100644 --- a/velox/functions/prestosql/CMakeLists.txt +++ b/velox/functions/prestosql/CMakeLists.txt @@ -24,7 +24,6 @@ add_library( ArrayIntersectExcept.cpp ArrayMinMax.cpp ArrayPosition.cpp - Coalesce.cpp ElementAt.cpp FilterFunctions.cpp FromUnixTime.cpp diff --git a/velox/functions/prestosql/Coalesce.cpp b/velox/functions/prestosql/Coalesce.cpp deleted file mode 100644 index daca34db18c1..000000000000 --- a/velox/functions/prestosql/Coalesce.cpp +++ /dev/null @@ -1,88 +0,0 @@ -/* - * Copyright (c) Facebook, Inc. and its affiliates. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "velox/expression/EvalCtx.h" -#include "velox/expression/VectorFunction.h" - -namespace facebook::velox::functions { -namespace { -class CoalesceFunction : public exec::VectorFunction { - public: - bool isDefaultNullBehavior() const override { - return false; - } - - void apply( - const SelectivityVector& rows, - std::vector& args, - const TypePtr& /* outputType */, - exec::EvalCtx* context, - VectorPtr* result) const override { - BaseVector::ensureWritable(rows, args[0]->type(), args[0]->pool(), result); - - // null positions to populate - exec::LocalSelectivityVector activeRowsHolder(context, rows.end()); - auto activeRows = activeRowsHolder.get(); - *activeRows = rows; - - // positions to be copied from the next argument - exec::LocalSelectivityVector copyRowsHolder(context, rows.end()); - auto copyRows = copyRowsHolder.get(); - for (int i = 0; i < args.size(); i++) { - auto& arg = args[i]; - const uint64_t* rawNulls = arg->flatRawNulls(*activeRows); - if (!rawNulls) { - (*result)->copy(arg.get(), *activeRows, nullptr); - return; // no nulls left - } - - if (i == 0) { - // initialize result by copying all rows from the first argument - (*result)->copy(arg.get(), *activeRows, nullptr); - } else { - *copyRows = *activeRows; - copyRows->deselectNulls(rawNulls, 0, activeRows->end()); - if (copyRows->hasSelections()) { - (*result)->copy(arg.get(), *copyRows, nullptr); - } else { - continue; - } - } - - activeRows->deselectNonNulls(rawNulls, 0, activeRows->end()); - if (!activeRows->hasSelections()) { - // no nulls left - return; - } - } - } - - static std::vector> signatures() { - // T... -> T - return {exec::FunctionSignatureBuilder() - .typeVariable("T") - .returnType("T") - .argumentType("T") - .variableArity() - .build()}; - } -}; -} // namespace - -VELOX_DECLARE_VECTOR_FUNCTION( - udf_coalesce, - CoalesceFunction::signatures(), - std::make_unique()); -} // namespace facebook::velox::functions diff --git a/velox/functions/prestosql/registration/GeneralFunctionsRegistration.cpp b/velox/functions/prestosql/registration/GeneralFunctionsRegistration.cpp index f81328e65619..f6e3c3734c54 100644 --- a/velox/functions/prestosql/registration/GeneralFunctionsRegistration.cpp +++ b/velox/functions/prestosql/registration/GeneralFunctionsRegistration.cpp @@ -23,7 +23,6 @@ void registerGeneralFunctions() { VELOX_REGISTER_VECTOR_FUNCTION(udf_subscript, "subscript"); VELOX_REGISTER_VECTOR_FUNCTION(udf_transform, "transform"); VELOX_REGISTER_VECTOR_FUNCTION(udf_reduce, "reduce"); - VELOX_REGISTER_VECTOR_FUNCTION(udf_coalesce, "coalesce"); VELOX_REGISTER_VECTOR_FUNCTION(udf_is_null, "is_null"); VELOX_REGISTER_VECTOR_FUNCTION(udf_in, "in"); VELOX_REGISTER_VECTOR_FUNCTION(udf_array_filter, "filter"); diff --git a/velox/functions/prestosql/tests/CMakeLists.txt b/velox/functions/prestosql/tests/CMakeLists.txt index a9072cc86822..2ad1596b8ec1 100644 --- a/velox/functions/prestosql/tests/CMakeLists.txt +++ b/velox/functions/prestosql/tests/CMakeLists.txt @@ -34,7 +34,6 @@ add_executable( BitwiseTest.cpp CardinalityTest.cpp CeilFloorTest.cpp - CoalesceTest.cpp ComparisonsTest.cpp DateTimeFunctionsTest.cpp ElementAtTest.cpp diff --git a/velox/functions/sparksql/Register.cpp b/velox/functions/sparksql/Register.cpp index 91d35792dd9e..ae447d70b1e4 100644 --- a/velox/functions/sparksql/Register.cpp +++ b/velox/functions/sparksql/Register.cpp @@ -55,7 +55,6 @@ static void workAroundRegistrationMacro(const std::string& prefix) { VELOX_REGISTER_VECTOR_FUNCTION(udf_replace, prefix + "replace"); VELOX_REGISTER_VECTOR_FUNCTION(udf_upper, prefix + "upper"); // Logical. - VELOX_REGISTER_VECTOR_FUNCTION(udf_coalesce, prefix + "coalesce"); VELOX_REGISTER_VECTOR_FUNCTION(udf_is_null, prefix + "isnull"); VELOX_REGISTER_VECTOR_FUNCTION(udf_is_not_null, prefix + "isnotnull"); VELOX_REGISTER_VECTOR_FUNCTION(udf_not, prefix + "not"); diff --git a/velox/parse/TypeResolver.cpp b/velox/parse/TypeResolver.cpp index c14500fa2897..bc15e9681f13 100644 --- a/velox/parse/TypeResolver.cpp +++ b/velox/parse/TypeResolver.cpp @@ -49,7 +49,8 @@ std::shared_ptr resolveType( return BOOLEAN(); } - if (expr->getFunctionName() == "try") { + if (expr->getFunctionName() == "try" || + expr->getFunctionName() == "coalesce") { VELOX_CHECK(!inputs.empty()); return inputs.front()->type(); } diff --git a/velox/vector/FlatVector.h b/velox/vector/FlatVector.h index 964a15abe543..3dd12847e189 100644 --- a/velox/vector/FlatVector.h +++ b/velox/vector/FlatVector.h @@ -202,7 +202,7 @@ class FlatVector final : public SimpleVector { T* mutableRawValues() { if (!values_ || !values_->unique()) { BufferPtr newValues = - AlignedBuffer::allocate(BaseVector::length_, values_->pool()); + AlignedBuffer::allocate(BaseVector::length_, BaseVector::pool()); if (values_) { // This codepath is not yet enabled for OPAQUE types (asMutable will // fail below)