Skip to content

Commit

Permalink
Merge pull request #55948 from rschu1ze/arrayFold-switch-args
Browse files Browse the repository at this point in the history
arrayFold: Switch accumulator and array arguments
  • Loading branch information
rschu1ze committed Oct 24, 2023
2 parents 2619eec + 68c3f41 commit 0955107
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 44 deletions.
11 changes: 9 additions & 2 deletions docs/en/sql-reference/functions/array-functions.md
Original file line number Diff line number Diff line change
Expand Up @@ -1081,6 +1081,10 @@ Result:
└─────────────────────────────────────────────────────────────┘
```

**See also**

- [arrayFold](#arrayFold)

## arrayReduceInRanges

Applies an aggregate function to array elements in given ranges and returns an array containing the result corresponding to each range. The function will return the same result as multiple `arrayReduce(agg_func, arraySlice(arr1, index, length), ...)`.
Expand Down Expand Up @@ -1138,7 +1142,7 @@ arrayFold(lambda_function, arr1, arr2, ..., accumulator)
Query:

``` sql
SELECT arrayFold( x,acc -> acc + x*2, [1, 2, 3, 4], toInt64(3)) AS res;
SELECT arrayFold( acc,x -> acc + x*2, [1, 2, 3, 4], toInt64(3)) AS res;
```

Result:
Expand All @@ -1152,7 +1156,7 @@ Result:
**Example with the Fibonacci sequence**

```sql
SELECT arrayFold( x, acc -> (acc.2, acc.2 + acc.1), range(number), (1::Int64, 0::Int64)).1 AS fibonacci
SELECT arrayFold( acc,x -> (acc.2, acc.2 + acc.1), range(number), (1::Int64, 0::Int64)).1 AS fibonacci
FROM numbers(1,10);

┌─fibonacci─┐
Expand All @@ -1169,6 +1173,9 @@ FROM numbers(1,10);
└───────────┘
```

**See also**

- [arrayReduce](#arrayReduce)

## arrayReverse(arr)

Expand Down
46 changes: 23 additions & 23 deletions src/Functions/array/arrayFold.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,37 +30,37 @@ class ArrayFold : public IFunction
void getLambdaArgumentTypes(DataTypes & arguments) const override
{
if (arguments.size() < 3)
throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Function {} requires as arguments a lambda function, at least one array and an accumulator argument", getName());
throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Function {} requires as arguments a lambda function, at least one array and an accumulator", getName());

DataTypes nested_types(arguments.size() - 1);
for (size_t i = 0; i < nested_types.size() - 1; ++i)
DataTypes accumulator_and_array_types(arguments.size() - 1);
accumulator_and_array_types[0] = arguments.back();
for (size_t i = 1; i < accumulator_and_array_types.size(); ++i)
{
const auto * array_type = checkAndGetDataType<DataTypeArray>(&*arguments[i + 1]);
const auto * array_type = checkAndGetDataType<DataTypeArray>(&*arguments[i]);
if (!array_type)
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Argument {} of function {} must be array, found {} instead", i + 2, getName(), arguments[i + 1]->getName());
nested_types[i] = recursiveRemoveLowCardinality(array_type->getNestedType());
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Argument {} of function {} must be of type Array, found {} instead", i + 1, getName(), arguments[i]->getName());
accumulator_and_array_types[i] = recursiveRemoveLowCardinality(array_type->getNestedType());
}
nested_types[nested_types.size() - 1] = arguments[arguments.size() - 1];

const auto * function_type = checkAndGetDataType<DataTypeFunction>(arguments[0].get());
if (!function_type || function_type->getArgumentTypes().size() != nested_types.size())
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "First argument for this overload of {} must be a function with {} arguments, found {} instead.",
getName(), nested_types.size(), arguments[0]->getName());
const auto * lambda_function_type = checkAndGetDataType<DataTypeFunction>(arguments[0].get());
if (!lambda_function_type || lambda_function_type->getArgumentTypes().size() != accumulator_and_array_types.size())
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "First argument of function {} must be a lambda function with {} arguments, found {} instead.",
getName(), accumulator_and_array_types.size(), arguments[0]->getName());

arguments[0] = std::make_shared<DataTypeFunction>(nested_types);
arguments[0] = std::make_shared<DataTypeFunction>(accumulator_and_array_types);
}

DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName & arguments) const override
{
if (arguments.size() < 2)
throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Function {} requires at least 2 arguments, passed: {}.", getName(), arguments.size());
if (arguments.size() < 3)
throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Function {} requires as arguments a lambda function, at least one array and an accumulator", getName());

const auto * data_type_function = checkAndGetDataType<DataTypeFunction>(arguments[0].type.get());
if (!data_type_function)
const auto * lambda_function_type = checkAndGetDataType<DataTypeFunction>(arguments[0].type.get());
if (!lambda_function_type)
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "First argument for function {} must be a function", getName());

auto accumulator_type = arguments.back().type;
auto lambda_type = data_type_function->getReturnType();
auto lambda_type = lambda_function_type->getReturnType();
if (!accumulator_type->equals(*lambda_type))
throw Exception(ErrorCodes::TYPE_MISMATCH,
"Return type of lambda function must be the same as the accumulator type, inferred return type of lambda: {}, inferred type of accumulator: {}",
Expand All @@ -71,12 +71,12 @@ class ArrayFold : public IFunction

ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr &, size_t input_rows_count) const override
{
const auto & lambda_with_type_and_name = arguments[0];
const auto & lambda_function_with_type_and_name = arguments[0];

if (!lambda_with_type_and_name.column)
if (!lambda_function_with_type_and_name.column)
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "First argument for function {} must be a function", getName());

const auto * lambda_function = typeid_cast<const ColumnFunction *>(lambda_with_type_and_name.column.get());
const auto * lambda_function = typeid_cast<const ColumnFunction *>(lambda_function_with_type_and_name.column.get());
if (!lambda_function)
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "First argument for function {} must be a function", getName());

Expand All @@ -85,6 +85,7 @@ class ArrayFold : public IFunction
const ColumnArray * column_first_array = nullptr;
ColumnsWithTypeAndName arrays;
arrays.reserve(arguments.size() - 1);

/// Validate input types and get input array columns in convenient form
for (size_t i = 1; i < arguments.size() - 1; ++i)
{
Expand Down Expand Up @@ -131,8 +132,7 @@ class ArrayFold : public IFunction
if (rows_count == 0)
return arguments.back().column->convertToFullColumnIfConst()->cloneEmpty();

ColumnPtr current_column;
current_column = arguments.back().column->convertToFullColumnIfConst();
ColumnPtr current_column = arguments.back().column->convertToFullColumnIfConst();
MutableColumnPtr result_data = arguments.back().column->convertToFullColumnIfConst()->cloneEmpty();

size_t max_array_size = 0;
Expand Down Expand Up @@ -198,9 +198,9 @@ class ArrayFold : public IFunction
auto res_lambda = lambda_function->cloneResized(prev[1]->size());
auto * res_lambda_ptr = typeid_cast<ColumnFunction *>(res_lambda.get());

res_lambda_ptr->appendArguments(std::vector({ColumnWithTypeAndName(std::move(prev[1]), arguments.back().type, arguments.back().name)}));
for (size_t i = 0; i < array_count; i++)
res_lambda_ptr->appendArguments(std::vector({ColumnWithTypeAndName(std::move(data_arrays[i][ind]), arrays[i].type, arrays[i].name)}));
res_lambda_ptr->appendArguments(std::vector({ColumnWithTypeAndName(std::move(prev[1]), arguments.back().type, arguments.back().name)}));

current_column = IColumn::mutate(res_lambda_ptr->reduce().column);
prev_size = current_column->size();
Expand Down
6 changes: 3 additions & 3 deletions tests/performance/array_fold.xml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
<test>
<query>SELECT arrayFold((x, acc) -> acc + x, range(number % 100), toUInt64(0)) from numbers(100000) Format Null</query>
<query>SELECT arrayFold((x, acc) -> acc + 1, range(number % 100), toUInt64(0)) from numbers(100000) Format Null</query>
<query>SELECT arrayFold((x, acc) -> acc + x, range(number), toUInt64(0)) from numbers(10000) Format Null</query>
<query>SELECT arrayFold((acc, x) -> acc + x, range(number % 100), toUInt64(0)) from numbers(100000) Format Null</query>
<query>SELECT arrayFold((acc, x) -> acc + 1, range(number % 100), toUInt64(0)) from numbers(100000) Format Null</query>
<query>SELECT arrayFold((acc, x) -> acc + x, range(number), toUInt64(0)) from numbers(10000) Format Null</query>
</test>
33 changes: 17 additions & 16 deletions tests/queries/0_stateless/02718_array_fold.sql
Original file line number Diff line number Diff line change
@@ -1,23 +1,24 @@
SELECT 'Negative tests';
SELECT arrayFold(); -- { serverError NUMBER_OF_ARGUMENTS_DOESNT_MATCH }
SELECT arrayFold(1); -- { serverError NUMBER_OF_ARGUMENTS_DOESNT_MATCH }
SELECT arrayFold(1, toUInt64(0)); -- { serverError ILLEGAL_TYPE_OF_ARGUMENT }
SELECT arrayFold( x,acc -> x, emptyArrayString(), toInt8(0)); -- { serverError TYPE_MISMATCH }
SELECT arrayFold( x,acc -> x, 'not an array', toUInt8(0)); -- { serverError ILLEGAL_TYPE_OF_ARGUMENT }
SELECT arrayFold( x,y,acc -> x, [0, 1], 'not an array', toUInt8(0)); -- { serverError ILLEGAL_TYPE_OF_ARGUMENT }
SELECT arrayFold( x,acc -> x, [0, 1], [2, 3], toUInt8(0)); -- { serverError ILLEGAL_TYPE_OF_ARGUMENT }
SELECT arrayFold( x,y,acc -> x, [0, 1], [2, 3, 4], toUInt8(0)); -- { serverError SIZES_OF_ARRAYS_DONT_MATCH }
SELECT arrayFold(1, toUInt64(0)); -- { serverError NUMBER_OF_ARGUMENTS_DOESNT_MATCH }
SELECT arrayFold(1, emptyArrayUInt64(), toUInt64(0)); -- { serverError ILLEGAL_TYPE_OF_ARGUMENT }
SELECT arrayFold( acc,x -> x, emptyArrayString(), toInt8(0)); -- { serverError TYPE_MISMATCH }
SELECT arrayFold( acc,x -> x, 'not an array', toUInt8(0)); -- { serverError ILLEGAL_TYPE_OF_ARGUMENT }
SELECT arrayFold( acc,x,y -> x, [0, 1], 'not an array', toUInt8(0)); -- { serverError ILLEGAL_TYPE_OF_ARGUMENT }
SELECT arrayFold( acc,x -> x, [0, 1], [2, 3], toUInt8(0)); -- { serverError ILLEGAL_TYPE_OF_ARGUMENT }
SELECT arrayFold( acc,x,y -> x, [0, 1], [2, 3, 4], toUInt8(0)); -- { serverError SIZES_OF_ARRAYS_DONT_MATCH }

SELECT 'Const arrays';
SELECT arrayFold( x,acc -> acc+x*2, [1, 2, 3, 4], toInt64(3));
SELECT arrayFold( x,acc -> acc+x*2, emptyArrayInt64(), toInt64(3));
SELECT arrayFold( x,y,acc -> acc+x*2+y*3, [1, 2, 3, 4], [5, 6, 7, 8], toInt64(3));
SELECT arrayFold( x,acc -> arrayPushBack(acc, x), [1, 2, 3, 4], emptyArrayInt64());
SELECT arrayFold( x,acc -> arrayPushFront(acc, x), [1, 2, 3, 4], emptyArrayInt64());
SELECT arrayFold( x,acc -> (arrayPushFront(acc.1, x),arrayPushBack(acc.2, x)), [1, 2, 3, 4], (emptyArrayInt64(), emptyArrayInt64()));
SELECT arrayFold( x,acc -> x%2 ? (arrayPushBack(acc.1, x), acc.2): (acc.1, arrayPushBack(acc.2, x)), [1, 2, 3, 4, 5, 6], (emptyArrayInt64(), emptyArrayInt64()));
SELECT arrayFold( acc,x -> acc+x*2, [1, 2, 3, 4], toInt64(3));
SELECT arrayFold( acc,x -> acc+x*2, emptyArrayInt64(), toInt64(3));
SELECT arrayFold( acc,x,y -> acc+x*2+y*3, [1, 2, 3, 4], [5, 6, 7, 8], toInt64(3));
SELECT arrayFold( acc,x -> arrayPushBack(acc, x), [1, 2, 3, 4], emptyArrayInt64());
SELECT arrayFold( acc,x -> arrayPushFront(acc, x), [1, 2, 3, 4], emptyArrayInt64());
SELECT arrayFold( acc,x -> (arrayPushFront(acc.1, x),arrayPushBack(acc.2, x)), [1, 2, 3, 4], (emptyArrayInt64(), emptyArrayInt64()));
SELECT arrayFold( acc,x -> x%2 ? (arrayPushBack(acc.1, x), acc.2): (acc.1, arrayPushBack(acc.2, x)), [1, 2, 3, 4, 5, 6], (emptyArrayInt64(), emptyArrayInt64()));

SELECT 'Non-const arrays';
SELECT arrayFold( x,acc -> acc+x, range(number), number) FROM system.numbers LIMIT 5;
SELECT arrayFold( x,acc -> arrayPushFront(acc,x), range(number), emptyArrayUInt64()) FROM system.numbers LIMIT 5;
SELECT arrayFold( x,acc -> x%2 ? arrayPushFront(acc,x) : arrayPushBack(acc,x), range(number), emptyArrayUInt64()) FROM system.numbers LIMIT 5;
SELECT arrayFold( acc,x -> acc+x, range(number), number) FROM system.numbers LIMIT 5;
SELECT arrayFold( acc,x -> arrayPushFront(acc,x), range(number), emptyArrayUInt64()) FROM system.numbers LIMIT 5;
SELECT arrayFold( acc,x -> x%2 ? arrayPushFront(acc,x) : arrayPushBack(acc,x), range(number), emptyArrayUInt64()) FROM system.numbers LIMIT 5;

0 comments on commit 0955107

Please sign in to comment.