Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Backport 43311 to 22.8 #50017

Merged
merged 6 commits into from
May 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 54 additions & 6 deletions src/Columns/ColumnFunction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include <Common/assert_cast.h>
#include <IO/WriteHelpers.h>
#include <Functions/IFunction.h>
#include <DataTypes/DataTypeLowCardinality.h>


namespace ProfileEvents
Expand All @@ -23,8 +24,18 @@ namespace ErrorCodes
extern const int LOGICAL_ERROR;
}

ColumnFunction::ColumnFunction(size_t size, FunctionBasePtr function_, const ColumnsWithTypeAndName & columns_to_capture, bool is_short_circuit_argument_, bool is_function_compiled_)
: elements_size(size), function(function_), is_short_circuit_argument(is_short_circuit_argument_), is_function_compiled(is_function_compiled_)
ColumnFunction::ColumnFunction(
size_t size,
FunctionBasePtr function_,
const ColumnsWithTypeAndName & columns_to_capture,
bool is_short_circuit_argument_,
bool is_function_compiled_,
bool recursively_convert_result_to_full_column_if_low_cardinality_)
: elements_size(size)
, function(function_)
, is_short_circuit_argument(is_short_circuit_argument_)
, recursively_convert_result_to_full_column_if_low_cardinality(recursively_convert_result_to_full_column_if_low_cardinality_)
, is_function_compiled(is_function_compiled_)
{
appendArguments(columns_to_capture);
}
Expand Down Expand Up @@ -113,7 +124,13 @@ ColumnPtr ColumnFunction::filter(const Filter & filt, ssize_t result_size_hint)
else
filtered_size = capture.front().column->size();

return ColumnFunction::create(filtered_size, function, capture, is_short_circuit_argument, is_function_compiled);
return ColumnFunction::create(
filtered_size,
function,
capture,
is_short_circuit_argument,
is_function_compiled,
recursively_convert_result_to_full_column_if_low_cardinality);
}

void ColumnFunction::expand(const Filter & mask, bool inverted)
Expand All @@ -135,7 +152,13 @@ ColumnPtr ColumnFunction::permute(const Permutation & perm, size_t limit) const
for (auto & column : capture)
column.column = column.column->permute(perm, limit);

return ColumnFunction::create(limit, function, capture, is_short_circuit_argument, is_function_compiled);
return ColumnFunction::create(
limit,
function,
capture,
is_short_circuit_argument,
is_function_compiled,
recursively_convert_result_to_full_column_if_low_cardinality);
}

ColumnPtr ColumnFunction::index(const IColumn & indexes, size_t limit) const
Expand All @@ -144,7 +167,13 @@ ColumnPtr ColumnFunction::index(const IColumn & indexes, size_t limit) const
for (auto & column : capture)
column.column = column.column->index(indexes, limit);

return ColumnFunction::create(limit, function, capture, is_short_circuit_argument, is_function_compiled);
return ColumnFunction::create(
limit,
function,
capture,
is_short_circuit_argument,
is_function_compiled,
recursively_convert_result_to_full_column_if_low_cardinality);
}

std::vector<MutableColumnPtr> ColumnFunction::scatter(IColumn::ColumnIndex num_columns,
Expand Down Expand Up @@ -173,7 +202,13 @@ std::vector<MutableColumnPtr> ColumnFunction::scatter(IColumn::ColumnIndex num_c
{
auto & capture = captures[part];
size_t capture_size = capture.empty() ? counts[part] : capture.front().column->size();
columns.emplace_back(ColumnFunction::create(capture_size, function, std::move(capture), is_short_circuit_argument));
columns.emplace_back(ColumnFunction::create(
capture_size,
function,
std::move(capture),
is_short_circuit_argument,
is_function_compiled,
recursively_convert_result_to_full_column_if_low_cardinality));
}

return columns;
Expand Down Expand Up @@ -237,6 +272,9 @@ void ColumnFunction::appendArgument(const ColumnWithTypeAndName & column)

DataTypePtr ColumnFunction::getResultType() const
{
if (recursively_convert_result_to_full_column_if_low_cardinality)
return recursiveRemoveLowCardinality(function->getResultType());

return function->getResultType();
}

Expand Down Expand Up @@ -270,9 +308,19 @@ ColumnWithTypeAndName ColumnFunction::reduce() const
ProfileEvents::increment(ProfileEvents::CompiledFunctionExecute);

res.column = function->execute(columns, res.type, elements_size);
if (recursively_convert_result_to_full_column_if_low_cardinality)
{
res.column = recursiveRemoveLowCardinality(res.column);
res.type = recursiveRemoveLowCardinality(res.type);
}
return res;
}

ColumnPtr ColumnFunction::recursivelyConvertResultToFullColumnIfLowCardinality() const
{
return ColumnFunction::create(elements_size, function, captured_columns, is_short_circuit_argument, is_function_compiled, true);
}

const ColumnFunction * checkAndGetShortCircuitArgument(const ColumnPtr & column)
{
const ColumnFunction * column_function;
Expand Down
13 changes: 12 additions & 1 deletion src/Columns/ColumnFunction.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@ class ColumnFunction final : public COWHelper<IColumn, ColumnFunction>
FunctionBasePtr function_,
const ColumnsWithTypeAndName & columns_to_capture,
bool is_short_circuit_argument_ = false,
bool is_function_compiled_ = false);
bool is_function_compiled_ = false,
bool recursively_convert_result_to_full_column_if_low_cardinality_ = false);

public:
const char * getFamilyName() const override { return "Function"; }
Expand Down Expand Up @@ -177,6 +178,9 @@ class ColumnFunction final : public COWHelper<IColumn, ColumnFunction>

DataTypePtr getResultType() const;

/// Create copy of this column, but with recursively_convert_result_to_full_column_if_low_cardinality = true
ColumnPtr recursivelyConvertResultToFullColumnIfLowCardinality() const;

private:
size_t elements_size;
FunctionBasePtr function;
Expand All @@ -188,6 +192,13 @@ class ColumnFunction final : public COWHelper<IColumn, ColumnFunction>
/// See ExpressionActions.cpp for details.
bool is_short_circuit_argument;

/// Special flag for lazy executed argument for short-circuit function.
/// If true, call recursiveRemoveLowCardinality on the result column
/// when function will be executed.
/// It's used when short-circuit function uses default implementation
/// for low cardinality arguments.
bool recursively_convert_result_to_full_column_if_low_cardinality = false;

/// Determine if passed function is compiled. Used for profiling.
bool is_function_compiled;

Expand Down
12 changes: 12 additions & 0 deletions src/DataTypes/DataTypeLowCardinalityHelpers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include <Columns/ColumnTuple.h>
#include <Columns/ColumnMap.h>
#include <Columns/ColumnLowCardinality.h>
#include <Columns/ColumnFunction.h>

#include <DataTypes/DataTypeLowCardinality.h>
#include <DataTypes/DataTypeArray.h>
Expand Down Expand Up @@ -95,6 +96,17 @@ ColumnPtr recursiveRemoveLowCardinality(const ColumnPtr & column)
return ColumnMap::create(nested_no_lc);
}

/// Special case when column is a lazy argument of short circuit function.
/// We should call recursiveRemoveLowCardinality on the result column
/// when function will be executed.
if (const auto * column_function = typeid_cast<const ColumnFunction *>(column.get()))
{
if (!column_function->isShortCircuitArgument())
return column;

return column_function->recursivelyConvertResultToFullColumnIfLowCardinality();
}

if (const auto * column_low_cardinality = typeid_cast<const ColumnLowCardinality *>(column.get()))
return column_low_cardinality->convertToFullColumn();

Expand Down
2 changes: 1 addition & 1 deletion src/Functions/IFunction.h
Original file line number Diff line number Diff line change
Expand Up @@ -381,7 +381,7 @@ class IFunctionOverloadResolver
*/
virtual bool useDefaultImplementationForSparseColumns() const { return true; }

// /// If it isn't, will convert all ColumnLowCardinality arguments to full columns.
/// If it isn't, will convert all ColumnLowCardinality arguments to full columns.
virtual bool canBeExecutedOnLowCardinalityDictionary() const { return true; }

private:
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
if with one LC argument
b
a
b
b
a
b
a
if with LC and NULL arguments
\N
a
\N
\N
a
\N
a
if with two LC arguments
b
a
b
b
a
a
a
\N
1
1
1
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
set short_circuit_function_evaluation='force_enable';

select 'if with one LC argument';
select if(0, toLowCardinality('a'), 'b');
select if(1, toLowCardinality('a'), 'b');
select if(materialize(0), materialize(toLowCardinality('a')), materialize('b'));
select if(number % 2, toLowCardinality('a'), 'b') from numbers(2);
select if(number % 2, materialize(toLowCardinality('a')), materialize('b')) from numbers(2);

select 'if with LC and NULL arguments';
select if(0, toLowCardinality('a'), NULL);
select if(1, toLowCardinality('a'), NULL);
select if(materialize(0), materialize(toLowCardinality('a')), NULL);
select if(number % 2, toLowCardinality('a'), NULL) from numbers(2);
select if(number % 2, materialize(toLowCardinality('a')), NULL) from numbers(2);

select 'if with two LC arguments';
select if(0, toLowCardinality('a'), toLowCardinality('b'));
select if(1, toLowCardinality('a'), toLowCardinality('b'));
select if(materialize(0), materialize(toLowCardinality('a')), materialize(toLowCardinality('b')));
select if(number % 2, toLowCardinality('a'), toLowCardinality('b')) from numbers(2);
select if(number % 2, materialize(toLowCardinality('a')), materialize(toLowCardinality('a'))) from numbers(2);

select if(number % 2, toLowCardinality(number), NULL) from numbers(2);
select if(number % 2, toLowCardinality(number), toLowCardinality(number + 1)) from numbers(2);

Loading