New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Added function arrayFold()
#49794
Added function arrayFold()
#49794
Changes from 22 commits
39a505e
213ac1e
67e28ae
0a6d08f
1a8846c
74ab98a
bd3c084
23dec23
65cdae8
b2e4f3b
2c9635c
03f5465
0334edf
c19a13e
c213ee1
6c31772
28589d8
a32cbfa
f09a221
2e6a48f
9e91e19
cb29764
c645d5f
2848548
07e0cc1
878e36d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,226 @@ | ||
#include "FunctionArrayMapped.h" | ||
#include <Functions/FunctionFactory.h> | ||
#include <Common/Exception.h> | ||
|
||
namespace DB | ||
{ | ||
|
||
namespace ErrorCodes | ||
{ | ||
extern const int ILLEGAL_COLUMN; | ||
extern const int ILLEGAL_TYPE_OF_ARGUMENT; | ||
extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH; | ||
extern const int SIZES_OF_ARRAYS_DONT_MATCH; | ||
extern const int TYPE_MISMATCH; | ||
} | ||
|
||
/** arrayFold(x1,...,xn,accum -> expression, array1,...,arrayn, init_accum) - apply the expression to each element of the array (or set of parallel arrays). | ||
*/ | ||
class ArrayFold : public IFunction | ||
{ | ||
public: | ||
static constexpr auto name = "arrayFold"; | ||
static FunctionPtr create(ContextPtr) { return std::make_shared<ArrayFold>(); } | ||
|
||
bool isVariadic() const override { return true; } | ||
size_t getNumberOfArguments() const override { return 0; } | ||
bool isSuitableForShortCircuitArgumentsExecution(const DataTypesWithConstInfo & /*arguments*/) const override { return true; } | ||
|
||
void getLambdaArgumentTypes(DataTypes & arguments) const override | ||
{ | ||
if (arguments.size() < 3) | ||
throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Function {} needs lambda function, at least one array argument and one accumulator argument.", getName()); | ||
DataTypes nested_types(arguments.size() - 1); | ||
for (size_t i = 0; i < nested_types.size() - 1; ++i) | ||
{ | ||
const DataTypeArray * array_type = checkAndGetDataType<DataTypeArray>(&*arguments[i + 1]); | ||
if (!array_type) | ||
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Argument {} of function {} must be array. Found {} instead.", toString(i + 2), getName(), arguments[i + 1]->getName()); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ( |
||
nested_types[i] = recursiveRemoveLowCardinality(array_type->getNestedType()); | ||
} | ||
nested_types[nested_types.size() - 1] = arguments[arguments.size() - 1]; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
|
||
const DataTypeFunction * function_type = checkAndGetDataType<DataTypeFunction>(arguments[0].get()); | ||
if (!function_type || function_type->getArgumentTypes().size() != nested_types.size()) | ||
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "First argument for this overload of {} must be a function with {} arguments. Found {} instead.", | ||
getName(), toString(nested_types.size()), arguments[0]->getName()); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. (toString() isn't necessary) |
||
|
||
arguments[0] = std::make_shared<DataTypeFunction>(nested_types); | ||
} | ||
|
||
DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName & arguments) const override | ||
{ | ||
if (arguments.size() < 2) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Don't we need at least three arguments? |
||
throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Function {} needs at least 2 arguments; passed {}.", getName(), toString(arguments.size())); | ||
const auto * data_type_function = checkAndGetDataType<DataTypeFunction>(arguments[0].type.get()); | ||
if (!data_type_function) | ||
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "First argument for function {} must be a function.", getName()); | ||
|
||
auto const accumulator_type = arguments.back().type; | ||
auto const lambda_type = data_type_function->getReturnType(); | ||
if (! accumulator_type->equals(*lambda_type)) | ||
throw Exception(ErrorCodes::TYPE_MISMATCH, "Return type of lambda function must be the same as the accumulator type. " | ||
"Inferred type of lambda {}, inferred type of accumulator {}.", lambda_type->getName(), accumulator_type->getName()); | ||
|
||
return DataTypePtr(accumulator_type); | ||
} | ||
|
||
ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr &, size_t input_rows_count) const override | ||
{ | ||
const auto & column_with_type_and_name = arguments[0]; | ||
rschu1ze marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
if (!column_with_type_and_name.column) | ||
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "First argument for function {} must be a function.", getName()); | ||
|
||
const auto * column_function = typeid_cast<const ColumnFunction *>(column_with_type_and_name.column.get()); | ||
rschu1ze marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
if (!column_function) | ||
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "First argument for function {} must be a function.", getName()); | ||
|
||
ColumnPtr offsets_column; | ||
ColumnPtr column_first_array_ptr; | ||
const ColumnArray * column_first_array = nullptr; | ||
ColumnsWithTypeAndName arrays; | ||
arrays.reserve(arguments.size() - 1); | ||
|
||
for (size_t i = 1; i < arguments.size() - 1; ++i) | ||
{ | ||
const auto & array_with_type_and_name = arguments[i]; | ||
ColumnPtr column_array_ptr = array_with_type_and_name.column; | ||
const auto * column_array = checkAndGetColumn<ColumnArray>(column_array_ptr.get()); | ||
const DataTypePtr & array_type_ptr = array_with_type_and_name.type; | ||
rschu1ze marked this conversation as resolved.
Show resolved
Hide resolved
|
||
const auto * array_type = checkAndGetDataType<DataTypeArray>(array_type_ptr.get()); | ||
if (!column_array) | ||
{ | ||
const ColumnConst * column_const_array = checkAndGetColumnConst<ColumnArray>(column_array_ptr.get()); | ||
if (!column_const_array) | ||
throw Exception(ErrorCodes::ILLEGAL_COLUMN, "Expected array column, found {}", column_array_ptr->getName()); | ||
column_array_ptr = recursiveRemoveLowCardinality(column_const_array->convertToFullColumn()); | ||
rschu1ze marked this conversation as resolved.
Show resolved
Hide resolved
|
||
column_array = checkAndGetColumn<ColumnArray>(column_array_ptr.get()); | ||
} | ||
if (!array_type) | ||
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Expected array type, found {}", array_type_ptr->getName()); | ||
if (!offsets_column) | ||
{ | ||
rschu1ze marked this conversation as resolved.
Show resolved
Hide resolved
|
||
offsets_column = column_array->getOffsetsPtr(); | ||
} | ||
else | ||
{ | ||
/// The first condition is optimization: do not compare data if the pointers are equal. | ||
if (column_array->getOffsetsPtr() != offsets_column | ||
&& column_array->getOffsets() != typeid_cast<const ColumnArray::ColumnOffsets &>(*offsets_column).getData()) | ||
throw Exception(ErrorCodes::SIZES_OF_ARRAYS_DONT_MATCH, "Arrays passed to {} must have equal size", getName()); | ||
} | ||
if (i == 1) | ||
{ | ||
column_first_array_ptr = column_array_ptr; | ||
column_first_array = column_array; | ||
} | ||
arrays.emplace_back(ColumnWithTypeAndName(column_array->getDataPtr(), | ||
recursiveRemoveLowCardinality(array_type->getNestedType()), | ||
array_with_type_and_name.name)); | ||
} | ||
|
||
ssize_t rows_count = input_rows_count; | ||
rschu1ze marked this conversation as resolved.
Show resolved
Hide resolved
|
||
ssize_t data_row_count = arrays[0].column->size(); | ||
auto array_count = arrays.size(); | ||
rschu1ze marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
if (rows_count == 0) | ||
return arguments.back().column->convertToFullColumnIfConst()->cloneEmpty(); | ||
|
||
ColumnPtr current_column; | ||
current_column = arguments.back().column->convertToFullColumnIfConst(); | ||
MutableColumnPtr result_data = arguments.back().column->convertToFullColumnIfConst()->cloneEmpty(); | ||
|
||
size_t max_array_size = 0; | ||
const auto & offsets = column_first_array->getOffsets(); | ||
|
||
//get columns of Nth array elements | ||
IColumn::Selector selector(data_row_count); | ||
size_t cur_ind = 0; | ||
ssize_t cur_arr = 0; | ||
|
||
if (data_row_count) | ||
while (offsets[cur_arr] == 0) | ||
++cur_arr; | ||
|
||
for (ssize_t i = 0; i < data_row_count; ++i) | ||
{ | ||
selector[i] = cur_ind++; | ||
rschu1ze marked this conversation as resolved.
Show resolved
Hide resolved
|
||
if (cur_ind > max_array_size) | ||
max_array_size = cur_ind; | ||
while (cur_arr < rows_count && cur_ind >= offsets[cur_arr] - offsets[cur_arr - 1]) | ||
{ | ||
++cur_arr; | ||
cur_ind = 0; | ||
} | ||
} | ||
|
||
std::vector<MutableColumns> data_arrays; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. As written elsewhere, below code is a bit hard to grasp. Let's add comments. |
||
data_arrays.resize(array_count); | ||
|
||
if (max_array_size > 0) | ||
for (size_t i = 0; i < array_count; ++i) | ||
data_arrays[i] = arrays[i].column->scatter(max_array_size, selector); | ||
|
||
size_t prev_size = rows_count; | ||
IColumn::Permutation inverse_permutation(rows_count); | ||
size_t inverse_permutation_count = 0; | ||
|
||
for (size_t ind = 0; ind < max_array_size; ++ind) | ||
{ | ||
IColumn::Selector prev_selector(prev_size); | ||
size_t prev_ind = 0; | ||
for (ssize_t irow = 0; irow < rows_count; ++irow) | ||
{ | ||
if (offsets[irow] - offsets[irow - 1] > ind) | ||
{ | ||
prev_selector[prev_ind++] = 1; | ||
} | ||
else if (offsets[irow] - offsets[irow - 1] == ind) | ||
{ | ||
inverse_permutation[inverse_permutation_count++] = irow; | ||
prev_selector[prev_ind++] = 0; | ||
} | ||
} | ||
auto prev = current_column->scatter(2, prev_selector); | ||
|
||
result_data->insertRangeFrom(*(prev[0]), 0, prev[0]->size()); | ||
|
||
auto res_lambda = column_function->cloneResized(prev[1]->size()); | ||
auto * res_lambda_ptr = typeid_cast<ColumnFunction *>(res_lambda.get()); | ||
|
||
for (size_t i = 0; i < array_count; i++) | ||
res_lambda_ptr->appendArguments(std::vector({ColumnWithTypeAndName(std::move(data_arrays[i][ind]), arrays[i].type, arrays[i].name)})); | ||
res_lambda_ptr->appendArguments(std::vector({ColumnWithTypeAndName(std::move(prev[1]), arguments.back().type, arguments.back().name)})); | ||
|
||
current_column = IColumn::mutate(res_lambda_ptr->reduce().column); | ||
prev_size = current_column->size(); | ||
} | ||
|
||
result_data->insertRangeFrom(*current_column, 0, current_column->size()); | ||
for (ssize_t irow = 0; irow < rows_count; ++irow) | ||
if (offsets[irow] - offsets[irow - 1] == max_array_size) | ||
inverse_permutation[inverse_permutation_count++] = irow; | ||
|
||
IColumn::Permutation perm(rows_count); | ||
for (ssize_t i = 0; i < rows_count; i++) | ||
perm[inverse_permutation[i]] = i; | ||
return result_data->permute(perm, 0); | ||
} | ||
|
||
private: | ||
String getName() const override | ||
{ | ||
return name; | ||
} | ||
}; | ||
|
||
/// Registers ArrayFold in the function factory together with its user-facing documentation:
/// a description, one usage example ("sum") and the category "Array".
REGISTER_FUNCTION(ArrayFold) | ||
{ | ||
/// The example folds [1,2,3,4] with `acc + x` starting from 1; the documented result is 11.
factory.registerFunction<ArrayFold>(FunctionDocumentation{.description=R"( | ||
Function arrayFold(x1,...,xn,accum -> expression, array1,...,arrayn, init_accum) applies lambda function to a number of same sized array columns | ||
and collects result in accumulator. Accumulator can be either constant or column. | ||
)", .examples{{"sum", "SELECT arrayFold(x,acc -> acc + x, [1,2,3,4], toInt64(1));", "11"}}, .categories{"Array"}}); | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
<test> | ||
<query>SELECT arrayFold((x, acc) -> acc + x, range(number % 100), toUInt64(0)) from numbers(100000) Format Null</query> | ||
<query>SELECT arrayFold((x, acc) -> acc + 1, range(number % 100), toUInt64(0)) from numbers(100000) Format Null</query> | ||
<query>SELECT arrayFold((x, acc) -> acc + x, range(number), toUInt64(0)) from numbers(10000) Format Null</query> | ||
</test> |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
23 | ||
3 | ||
101 | ||
269 | ||
[1,2,3,4] | ||
[4,3,2,1] | ||
([4,3,2,1],[1,2,3,4]) | ||
([1,3,5],[2,4,6]) | ||
0 | ||
0 | ||
1 | ||
3 | ||
6 | ||
10 | ||
0 | ||
1 | ||
3 | ||
6 | ||
10 | ||
15 | ||
[] | ||
[0] | ||
[1,0] | ||
[2,1,0] | ||
[3,2,1,0] | ||
[4,3,2,1,0] | ||
[] | ||
[0] | ||
[1,0] | ||
[1,0,2] | ||
[3,1,0,2] | ||
[3,1,0,2,4] | ||
[(0,0)] | ||
[(0,1),(0,0)] | ||
[(1,2),(0,1),(0,0)] | ||
[(2,3),(1,2),(0,1),(0,0)] | ||
[(3,4),(2,3),(1,2),(0,1),(0,0)] | ||
[(4,5),(3,4),(2,3),(1,2),(0,1),(0,0)] | ||
[] | ||
['0'] | ||
['0','1'] | ||
['0','1','2'] | ||
['0','1','2','3'] | ||
['0','1','2','3','4'] |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
SELECT arrayFold(x,acc -> acc + x * 2, [1,2,3,4], toInt64(3)); | ||
rschu1ze marked this conversation as resolved.
Show resolved
Hide resolved
|
||
SELECT arrayFold(x,acc -> acc + x * 2, emptyArrayInt64(), toInt64(3)); | ||
SELECT arrayFold(x,y,acc -> acc + x * 2 + y * 3, [1,2,3,4], [5,6,7,8], toInt64(3)); | ||
rschu1ze marked this conversation as resolved.
Show resolved
Hide resolved
|
||
SELECT arrayFold(x,y,z,acc -> acc + x * 2 + y * 3 + z * 4, [1,2,3,4], [5,6,7,8], [9,10,11,12], toInt64(3)); | ||
SELECT arrayFold(x,acc -> arrayPushBack(acc,x), [1,2,3,4], emptyArrayInt64()); | ||
SELECT arrayFold(x,acc -> arrayPushFront(acc,x), [1,2,3,4], emptyArrayInt64()); | ||
SELECT arrayFold(x,acc -> (arrayPushFront(acc.1,x), arrayPushBack(acc.2,x)), [1,2,3,4], (emptyArrayInt64(), emptyArrayInt64())); | ||
SELECT arrayFold(x,acc -> x % 2 ? (arrayPushBack(acc.1,x), acc.2): (acc.1, arrayPushBack(acc.2,x)), [1,2,3,4,5,6], (emptyArrayInt64(), emptyArrayInt64())); | ||
|
||
SELECT arrayFold(x,acc -> acc+x, range(number), toInt64(0)) FROM system.numbers LIMIT 6; | ||
SELECT arrayFold(x,acc -> acc+x, range(number), number) FROM system.numbers LIMIT 6; | ||
SELECT arrayFold(x,acc -> arrayPushFront(acc, x), range(number), emptyArrayUInt64()) FROM system.numbers LIMIT 6; | ||
SELECT arrayFold(x,acc -> x % 2 ? arrayPushFront(acc, x) : arrayPushBack(acc, x), range(number), emptyArrayUInt64()) FROM system.numbers LIMIT 6; | ||
SELECT arrayFold(x,acc -> arrayPushFront(acc, (x, x+1)), range(number), [(toUInt64(0),toUInt64(0))]) FROM system.numbers LIMIT 6; | ||
SELECT arrayFold(x, acc -> concat(acc, arrayMap(z -> toString(x), [number])) , range(number), CAST([] as Array(String))) FROM system.numbers LIMIT 6; | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Function `getLambdaArgumentTypes` isn't implemented often in the codebase. I suggest adding some comments explaining what the elements of `arguments` represent. I think they are simply the function arguments (lambda function, array arguments, initial accumulator), but a naive reader could think they are the lambda's arguments.