diff --git a/docs/en/sql-reference/functions/array-functions.md b/docs/en/sql-reference/functions/array-functions.md index 6e460a64bcf9..40bfb65e4e8d 100644 --- a/docs/en/sql-reference/functions/array-functions.md +++ b/docs/en/sql-reference/functions/array-functions.md @@ -1081,6 +1081,10 @@ Result: └─────────────────────────────────────────────────────────────┘ ``` +**See also** + +- [arrayFold](#arrayFold) + ## arrayReduceInRanges Applies an aggregate function to array elements in given ranges and returns an array containing the result corresponding to each range. The function will return the same result as multiple `arrayReduce(agg_func, arraySlice(arr1, index, length), ...)`. @@ -1138,7 +1142,7 @@ arrayFold(lambda_function, arr1, arr2, ..., accumulator) Query: ``` sql -SELECT arrayFold( x,acc -> acc + x*2, [1, 2, 3, 4], toInt64(3)) AS res; +SELECT arrayFold( acc,x -> acc + x*2, [1, 2, 3, 4], toInt64(3)) AS res; ``` Result: @@ -1152,7 +1156,7 @@ Result: **Example with the Fibonacci sequence** ```sql -SELECT arrayFold( x, acc -> (acc.2, acc.2 + acc.1), range(number), (1::Int64, 0::Int64)).1 AS fibonacci +SELECT arrayFold( acc,x -> (acc.2, acc.2 + acc.1), range(number), (1::Int64, 0::Int64)).1 AS fibonacci FROM numbers(1,10); ┌─fibonacci─┐ @@ -1169,6 +1173,9 @@ FROM numbers(1,10); └───────────┘ ``` +**See also** + +- [arrayReduce](#arrayReduce) ## arrayReverse(arr) diff --git a/src/Functions/array/arrayFold.cpp b/src/Functions/array/arrayFold.cpp index 94ed5d59ca9b..b5b650e72897 100644 --- a/src/Functions/array/arrayFold.cpp +++ b/src/Functions/array/arrayFold.cpp @@ -30,37 +30,37 @@ class ArrayFold : public IFunction void getLambdaArgumentTypes(DataTypes & arguments) const override { if (arguments.size() < 3) - throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Function {} requires as arguments a lambda function, at least one array and an accumulator argument", getName()); + throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Function {} 
requires as arguments a lambda function, at least one array and an accumulator", getName()); - DataTypes nested_types(arguments.size() - 1); - for (size_t i = 0; i < nested_types.size() - 1; ++i) + DataTypes accumulator_and_array_types(arguments.size() - 1); + accumulator_and_array_types[0] = arguments.back(); + for (size_t i = 1; i < accumulator_and_array_types.size(); ++i) { - const auto * array_type = checkAndGetDataType<DataTypeArray>(&*arguments[i + 1]); + const auto * array_type = checkAndGetDataType<DataTypeArray>(&*arguments[i]); if (!array_type) - throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Argument {} of function {} must be array, found {} instead", i + 2, getName(), arguments[i + 1]->getName()); - nested_types[i] = recursiveRemoveLowCardinality(array_type->getNestedType()); + throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Argument {} of function {} must be of type Array, found {} instead", i + 1, getName(), arguments[i]->getName()); + accumulator_and_array_types[i] = recursiveRemoveLowCardinality(array_type->getNestedType()); } - nested_types[nested_types.size() - 1] = arguments[arguments.size() - 1]; - const auto * function_type = checkAndGetDataType<DataTypeFunction>(arguments[0].get()); - if (!function_type || function_type->getArgumentTypes().size() != nested_types.size()) - throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "First argument for this overload of {} must be a function with {} arguments, found {} instead.", - getName(), nested_types.size(), arguments[0]->getName()); + const auto * lambda_function_type = checkAndGetDataType<DataTypeFunction>(arguments[0].get()); + if (!lambda_function_type || lambda_function_type->getArgumentTypes().size() != accumulator_and_array_types.size()) + throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "First argument of function {} must be a lambda function with {} arguments, found {} instead.", + getName(), accumulator_and_array_types.size(), arguments[0]->getName()); - arguments[0] = std::make_shared<DataTypeFunction>(nested_types); + arguments[0] = 
std::make_shared<DataTypeFunction>(accumulator_and_array_types); } DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName & arguments) const override { - if (arguments.size() < 2) - throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Function {} requires at least 2 arguments, passed: {}.", getName(), arguments.size()); + if (arguments.size() < 3) + throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Function {} requires as arguments a lambda function, at least one array and an accumulator", getName()); - const auto * data_type_function = checkAndGetDataType<DataTypeFunction>(arguments[0].type.get()); - if (!data_type_function) + const auto * lambda_function_type = checkAndGetDataType<DataTypeFunction>(arguments[0].type.get()); + if (!lambda_function_type) throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "First argument for function {} must be a function", getName()); auto accumulator_type = arguments.back().type; - auto lambda_type = data_type_function->getReturnType(); + auto lambda_type = lambda_function_type->getReturnType(); if (!accumulator_type->equals(*lambda_type)) throw Exception(ErrorCodes::TYPE_MISMATCH, "Return type of lambda function must be the same as the accumulator type, inferred return type of lambda: {}, inferred type of accumulator: {}", @@ -71,12 +71,12 @@ class ArrayFold : public IFunction ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr &, size_t input_rows_count) const override { - const auto & lambda_with_type_and_name = arguments[0]; + const auto & lambda_function_with_type_and_name = arguments[0]; - if (!lambda_with_type_and_name.column) + if (!lambda_function_with_type_and_name.column) throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "First argument for function {} must be a function", getName()); - const auto * lambda_function = typeid_cast<const ColumnFunction *>(lambda_with_type_and_name.column.get()); + const auto * lambda_function = typeid_cast<const ColumnFunction *>(lambda_function_with_type_and_name.column.get()); if (!lambda_function) throw 
Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "First argument for function {} must be a function", getName()); @@ -85,6 +85,7 @@ class ArrayFold : public IFunction const ColumnArray * column_first_array = nullptr; ColumnsWithTypeAndName arrays; arrays.reserve(arguments.size() - 1); + /// Validate input types and get input array columns in convenient form for (size_t i = 1; i < arguments.size() - 1; ++i) { @@ -131,8 +132,7 @@ class ArrayFold : public IFunction if (rows_count == 0) return arguments.back().column->convertToFullColumnIfConst()->cloneEmpty(); - ColumnPtr current_column; - current_column = arguments.back().column->convertToFullColumnIfConst(); + ColumnPtr current_column = arguments.back().column->convertToFullColumnIfConst(); MutableColumnPtr result_data = arguments.back().column->convertToFullColumnIfConst()->cloneEmpty(); size_t max_array_size = 0; @@ -198,9 +198,9 @@ class ArrayFold : public IFunction auto res_lambda = lambda_function->cloneResized(prev[1]->size()); auto * res_lambda_ptr = typeid_cast<ColumnFunction *>(res_lambda.get()); + res_lambda_ptr->appendArguments(std::vector({ColumnWithTypeAndName(std::move(prev[1]), arguments.back().type, arguments.back().name)})); for (size_t i = 0; i < array_count; i++) res_lambda_ptr->appendArguments(std::vector({ColumnWithTypeAndName(std::move(data_arrays[i][ind]), arrays[i].type, arrays[i].name)})); - res_lambda_ptr->appendArguments(std::vector({ColumnWithTypeAndName(std::move(prev[1]), arguments.back().type, arguments.back().name)})); current_column = IColumn::mutate(res_lambda_ptr->reduce().column); prev_size = current_column->size(); diff --git a/tests/performance/array_fold.xml b/tests/performance/array_fold.xml index fae8bd164a72..32bd45beb1ec 100644 --- a/tests/performance/array_fold.xml +++ b/tests/performance/array_fold.xml @@ -1,5 +1,5 @@ - SELECT arrayFold((x, acc) -> acc + x, range(number % 100), toUInt64(0)) from numbers(100000) Format Null - SELECT arrayFold((x, acc) -> acc + 1, range(number % 100), 
toUInt64(0)) from numbers(100000) Format Null - SELECT arrayFold((x, acc) -> acc + x, range(number), toUInt64(0)) from numbers(10000) Format Null + SELECT arrayFold((acc, x) -> acc + x, range(number % 100), toUInt64(0)) from numbers(100000) Format Null + SELECT arrayFold((acc, x) -> acc + 1, range(number % 100), toUInt64(0)) from numbers(100000) Format Null + SELECT arrayFold((acc, x) -> acc + x, range(number), toUInt64(0)) from numbers(10000) Format Null diff --git a/tests/queries/0_stateless/02718_array_fold.sql b/tests/queries/0_stateless/02718_array_fold.sql index 7f20602a3715..0486a5ce2e36 100644 --- a/tests/queries/0_stateless/02718_array_fold.sql +++ b/tests/queries/0_stateless/02718_array_fold.sql @@ -1,23 +1,24 @@ SELECT 'Negative tests'; SELECT arrayFold(); -- { serverError NUMBER_OF_ARGUMENTS_DOESNT_MATCH } SELECT arrayFold(1); -- { serverError NUMBER_OF_ARGUMENTS_DOESNT_MATCH } -SELECT arrayFold(1, toUInt64(0)); -- { serverError ILLEGAL_TYPE_OF_ARGUMENT } -SELECT arrayFold( x,acc -> x, emptyArrayString(), toInt8(0)); -- { serverError TYPE_MISMATCH } -SELECT arrayFold( x,acc -> x, 'not an array', toUInt8(0)); -- { serverError ILLEGAL_TYPE_OF_ARGUMENT } -SELECT arrayFold( x,y,acc -> x, [0, 1], 'not an array', toUInt8(0)); -- { serverError ILLEGAL_TYPE_OF_ARGUMENT } -SELECT arrayFold( x,acc -> x, [0, 1], [2, 3], toUInt8(0)); -- { serverError ILLEGAL_TYPE_OF_ARGUMENT } -SELECT arrayFold( x,y,acc -> x, [0, 1], [2, 3, 4], toUInt8(0)); -- { serverError SIZES_OF_ARRAYS_DONT_MATCH } +SELECT arrayFold(1, toUInt64(0)); -- { serverError NUMBER_OF_ARGUMENTS_DOESNT_MATCH } +SELECT arrayFold(1, emptyArrayUInt64(), toUInt64(0)); -- { serverError ILLEGAL_TYPE_OF_ARGUMENT } +SELECT arrayFold( acc,x -> x, emptyArrayString(), toInt8(0)); -- { serverError TYPE_MISMATCH } +SELECT arrayFold( acc,x -> x, 'not an array', toUInt8(0)); -- { serverError ILLEGAL_TYPE_OF_ARGUMENT } +SELECT arrayFold( acc,x,y -> x, [0, 1], 'not an array', toUInt8(0)); -- { serverError 
ILLEGAL_TYPE_OF_ARGUMENT } +SELECT arrayFold( acc,x -> x, [0, 1], [2, 3], toUInt8(0)); -- { serverError ILLEGAL_TYPE_OF_ARGUMENT } +SELECT arrayFold( acc,x,y -> x, [0, 1], [2, 3, 4], toUInt8(0)); -- { serverError SIZES_OF_ARRAYS_DONT_MATCH } SELECT 'Const arrays'; -SELECT arrayFold( x,acc -> acc+x*2, [1, 2, 3, 4], toInt64(3)); -SELECT arrayFold( x,acc -> acc+x*2, emptyArrayInt64(), toInt64(3)); -SELECT arrayFold( x,y,acc -> acc+x*2+y*3, [1, 2, 3, 4], [5, 6, 7, 8], toInt64(3)); -SELECT arrayFold( x,acc -> arrayPushBack(acc, x), [1, 2, 3, 4], emptyArrayInt64()); -SELECT arrayFold( x,acc -> arrayPushFront(acc, x), [1, 2, 3, 4], emptyArrayInt64()); -SELECT arrayFold( x,acc -> (arrayPushFront(acc.1, x),arrayPushBack(acc.2, x)), [1, 2, 3, 4], (emptyArrayInt64(), emptyArrayInt64())); -SELECT arrayFold( x,acc -> x%2 ? (arrayPushBack(acc.1, x), acc.2): (acc.1, arrayPushBack(acc.2, x)), [1, 2, 3, 4, 5, 6], (emptyArrayInt64(), emptyArrayInt64())); +SELECT arrayFold( acc,x -> acc+x*2, [1, 2, 3, 4], toInt64(3)); +SELECT arrayFold( acc,x -> acc+x*2, emptyArrayInt64(), toInt64(3)); +SELECT arrayFold( acc,x,y -> acc+x*2+y*3, [1, 2, 3, 4], [5, 6, 7, 8], toInt64(3)); +SELECT arrayFold( acc,x -> arrayPushBack(acc, x), [1, 2, 3, 4], emptyArrayInt64()); +SELECT arrayFold( acc,x -> arrayPushFront(acc, x), [1, 2, 3, 4], emptyArrayInt64()); +SELECT arrayFold( acc,x -> (arrayPushFront(acc.1, x),arrayPushBack(acc.2, x)), [1, 2, 3, 4], (emptyArrayInt64(), emptyArrayInt64())); +SELECT arrayFold( acc,x -> x%2 ? (arrayPushBack(acc.1, x), acc.2): (acc.1, arrayPushBack(acc.2, x)), [1, 2, 3, 4, 5, 6], (emptyArrayInt64(), emptyArrayInt64())); SELECT 'Non-const arrays'; -SELECT arrayFold( x,acc -> acc+x, range(number), number) FROM system.numbers LIMIT 5; -SELECT arrayFold( x,acc -> arrayPushFront(acc,x), range(number), emptyArrayUInt64()) FROM system.numbers LIMIT 5; -SELECT arrayFold( x,acc -> x%2 ? 
arrayPushFront(acc,x) : arrayPushBack(acc,x), range(number), emptyArrayUInt64()) FROM system.numbers LIMIT 5; +SELECT arrayFold( acc,x -> acc+x, range(number), number) FROM system.numbers LIMIT 5; +SELECT arrayFold( acc,x -> arrayPushFront(acc,x), range(number), emptyArrayUInt64()) FROM system.numbers LIMIT 5; +SELECT arrayFold( acc,x -> x%2 ? arrayPushFront(acc,x) : arrayPushBack(acc,x), range(number), emptyArrayUInt64()) FROM system.numbers LIMIT 5;