Skip to content

Commit

Permalink
Merge pull request #50017 from ClickHouse/backport/22.8/43311
Browse files Browse the repository at this point in the history
Backport 43311 to 22.8
  • Loading branch information
Avogar committed May 22, 2023
2 parents cfc05c7 + adea01c commit 9305607
Show file tree
Hide file tree
Showing 6 changed files with 133 additions and 8 deletions.
60 changes: 54 additions & 6 deletions src/Columns/ColumnFunction.cpp
Expand Up @@ -6,6 +6,7 @@
#include <Common/assert_cast.h>
#include <IO/WriteHelpers.h>
#include <Functions/IFunction.h>
#include <DataTypes/DataTypeLowCardinality.h>


namespace ProfileEvents
Expand All @@ -23,8 +24,18 @@ namespace ErrorCodes
extern const int LOGICAL_ERROR;
}

ColumnFunction::ColumnFunction(size_t size, FunctionBasePtr function_, const ColumnsWithTypeAndName & columns_to_capture, bool is_short_circuit_argument_, bool is_function_compiled_)
: elements_size(size), function(function_), is_short_circuit_argument(is_short_circuit_argument_), is_function_compiled(is_function_compiled_)
/// Creates a function column of `size` rows that wraps `function_` together with
/// the arguments captured so far.
///
/// size                  - number of rows; later passed to function->execute() as the input row count.
/// function_             - the wrapped function. Taken by value and moved into the member,
///                         so the caller's copy is the only one made.
/// columns_to_capture    - arguments to capture; forwarded to appendArguments().
/// is_short_circuit_argument_ - marks this column as a lazily executed argument of a short-circuit function.
/// is_function_compiled_ - whether the wrapped function is compiled; used for profiling.
/// recursively_convert_result_to_full_column_if_low_cardinality_
///                       - if true, LowCardinality is recursively removed from the result
///                         column and type when the function is executed.
ColumnFunction::ColumnFunction(
    size_t size,
    FunctionBasePtr function_,
    const ColumnsWithTypeAndName & columns_to_capture,
    bool is_short_circuit_argument_,
    bool is_function_compiled_,
    bool recursively_convert_result_to_full_column_if_low_cardinality_)
    : elements_size(size)
    /// `function_` is a by-value sink parameter: move it instead of copying it a second time.
    , function(std::move(function_))
    , is_short_circuit_argument(is_short_circuit_argument_)
    , recursively_convert_result_to_full_column_if_low_cardinality(recursively_convert_result_to_full_column_if_low_cardinality_)
    , is_function_compiled(is_function_compiled_)
{
    appendArguments(columns_to_capture);
}
Expand Down Expand Up @@ -113,7 +124,13 @@ ColumnPtr ColumnFunction::filter(const Filter & filt, ssize_t result_size_hint)
else
filtered_size = capture.front().column->size();

return ColumnFunction::create(filtered_size, function, capture, is_short_circuit_argument, is_function_compiled);
return ColumnFunction::create(
filtered_size,
function,
capture,
is_short_circuit_argument,
is_function_compiled,
recursively_convert_result_to_full_column_if_low_cardinality);
}

void ColumnFunction::expand(const Filter & mask, bool inverted)
Expand All @@ -135,7 +152,13 @@ ColumnPtr ColumnFunction::permute(const Permutation & perm, size_t limit) const
for (auto & column : capture)
column.column = column.column->permute(perm, limit);

return ColumnFunction::create(limit, function, capture, is_short_circuit_argument, is_function_compiled);
return ColumnFunction::create(
limit,
function,
capture,
is_short_circuit_argument,
is_function_compiled,
recursively_convert_result_to_full_column_if_low_cardinality);
}

ColumnPtr ColumnFunction::index(const IColumn & indexes, size_t limit) const
Expand All @@ -144,7 +167,13 @@ ColumnPtr ColumnFunction::index(const IColumn & indexes, size_t limit) const
for (auto & column : capture)
column.column = column.column->index(indexes, limit);

return ColumnFunction::create(limit, function, capture, is_short_circuit_argument, is_function_compiled);
return ColumnFunction::create(
limit,
function,
capture,
is_short_circuit_argument,
is_function_compiled,
recursively_convert_result_to_full_column_if_low_cardinality);
}

std::vector<MutableColumnPtr> ColumnFunction::scatter(IColumn::ColumnIndex num_columns,
Expand Down Expand Up @@ -173,7 +202,13 @@ std::vector<MutableColumnPtr> ColumnFunction::scatter(IColumn::ColumnIndex num_c
{
auto & capture = captures[part];
size_t capture_size = capture.empty() ? counts[part] : capture.front().column->size();
columns.emplace_back(ColumnFunction::create(capture_size, function, std::move(capture), is_short_circuit_argument));
columns.emplace_back(ColumnFunction::create(
capture_size,
function,
std::move(capture),
is_short_circuit_argument,
is_function_compiled,
recursively_convert_result_to_full_column_if_low_cardinality));
}

return columns;
Expand Down Expand Up @@ -237,6 +272,9 @@ void ColumnFunction::appendArgument(const ColumnWithTypeAndName & column)

/// Returns the type of the column produced when the wrapped function is executed.
/// When the low-cardinality conversion flag is set, LowCardinality is recursively
/// stripped from the function's declared result type, matching what reduce() returns.
DataTypePtr ColumnFunction::getResultType() const
{
    const auto & declared_type = function->getResultType();

    if (!recursively_convert_result_to_full_column_if_low_cardinality)
        return declared_type;

    return recursiveRemoveLowCardinality(declared_type);
}

Expand Down Expand Up @@ -270,9 +308,19 @@ ColumnWithTypeAndName ColumnFunction::reduce() const
ProfileEvents::increment(ProfileEvents::CompiledFunctionExecute);

res.column = function->execute(columns, res.type, elements_size);
if (recursively_convert_result_to_full_column_if_low_cardinality)
{
res.column = recursiveRemoveLowCardinality(res.column);
res.type = recursiveRemoveLowCardinality(res.type);
}
return res;
}

/// Returns a new function column with the same size, function, captured columns and
/// flags as this one, except that the low-cardinality conversion flag is switched on.
ColumnPtr ColumnFunction::recursivelyConvertResultToFullColumnIfLowCardinality() const
{
    const bool convert_result_to_full_column = true;

    return ColumnFunction::create(
        elements_size,
        function,
        captured_columns,
        is_short_circuit_argument,
        is_function_compiled,
        convert_result_to_full_column);
}

const ColumnFunction * checkAndGetShortCircuitArgument(const ColumnPtr & column)
{
const ColumnFunction * column_function;
Expand Down
13 changes: 12 additions & 1 deletion src/Columns/ColumnFunction.h
Expand Up @@ -29,7 +29,8 @@ class ColumnFunction final : public COWHelper<IColumn, ColumnFunction>
FunctionBasePtr function_,
const ColumnsWithTypeAndName & columns_to_capture,
bool is_short_circuit_argument_ = false,
bool is_function_compiled_ = false);
bool is_function_compiled_ = false,
bool recursively_convert_result_to_full_column_if_low_cardinality_ = false);

public:
const char * getFamilyName() const override { return "Function"; }
Expand Down Expand Up @@ -177,6 +178,9 @@ class ColumnFunction final : public COWHelper<IColumn, ColumnFunction>

DataTypePtr getResultType() const;

/// Create a copy of this column, but with recursively_convert_result_to_full_column_if_low_cardinality = true
ColumnPtr recursivelyConvertResultToFullColumnIfLowCardinality() const;

private:
size_t elements_size;
FunctionBasePtr function;
Expand All @@ -188,6 +192,13 @@ class ColumnFunction final : public COWHelper<IColumn, ColumnFunction>
/// See ExpressionActions.cpp for details.
bool is_short_circuit_argument;

/// Special flag for a lazily executed argument of a short-circuit function.
/// If true, recursiveRemoveLowCardinality is called on the result column
/// when the function is executed.
/// It is used when the short-circuit function uses the default implementation
/// for low cardinality arguments.
bool recursively_convert_result_to_full_column_if_low_cardinality = false;

/// Determine if passed function is compiled. Used for profiling.
bool is_function_compiled;

Expand Down
12 changes: 12 additions & 0 deletions src/DataTypes/DataTypeLowCardinalityHelpers.cpp
Expand Up @@ -3,6 +3,7 @@
#include <Columns/ColumnTuple.h>
#include <Columns/ColumnMap.h>
#include <Columns/ColumnLowCardinality.h>
#include <Columns/ColumnFunction.h>

#include <DataTypes/DataTypeLowCardinality.h>
#include <DataTypes/DataTypeArray.h>
Expand Down Expand Up @@ -95,6 +96,17 @@ ColumnPtr recursiveRemoveLowCardinality(const ColumnPtr & column)
return ColumnMap::create(nested_no_lc);
}

/// Special case: the column is a lazily executed argument of a short-circuit function.
/// recursiveRemoveLowCardinality should be called on the result column
/// when the function is executed.
if (const auto * column_function = typeid_cast<const ColumnFunction *>(column.get()))
{
if (!column_function->isShortCircuitArgument())
return column;

return column_function->recursivelyConvertResultToFullColumnIfLowCardinality();
}

if (const auto * column_low_cardinality = typeid_cast<const ColumnLowCardinality *>(column.get()))
return column_low_cardinality->convertToFullColumn();

Expand Down
2 changes: 1 addition & 1 deletion src/Functions/IFunction.h
Expand Up @@ -381,7 +381,7 @@ class IFunctionOverloadResolver
*/
virtual bool useDefaultImplementationForSparseColumns() const { return true; }

// /// If it isn't, will convert all ColumnLowCardinality arguments to full columns.
/// If it isn't, will convert all ColumnLowCardinality arguments to full columns.
virtual bool canBeExecutedOnLowCardinalityDictionary() const { return true; }

private:
Expand Down
@@ -0,0 +1,28 @@
if with one LC argument
b
a
b
b
a
b
a
if with LC and NULL arguments
\N
a
\N
\N
a
\N
a
if with two LC arguments
b
a
b
b
a
a
a
\N
1
1
1
@@ -0,0 +1,26 @@
-- Regression test: `if` with LowCardinality arguments while short-circuit
-- function evaluation is forced on.
set short_circuit_function_evaluation='force_enable';

-- Only the `then` branch is LowCardinality; the `else` branch is a plain string.
-- Covers constant conditions, a materialized condition and a per-row condition.
select 'if with one LC argument';
select if(0, toLowCardinality('a'), 'b');
select if(1, toLowCardinality('a'), 'b');
select if(materialize(0), materialize(toLowCardinality('a')), materialize('b'));
select if(number % 2, toLowCardinality('a'), 'b') from numbers(2);
select if(number % 2, materialize(toLowCardinality('a')), materialize('b')) from numbers(2);

-- Same condition shapes, but the `else` branch is NULL.
select 'if with LC and NULL arguments';
select if(0, toLowCardinality('a'), NULL);
select if(1, toLowCardinality('a'), NULL);
select if(materialize(0), materialize(toLowCardinality('a')), NULL);
select if(number % 2, toLowCardinality('a'), NULL) from numbers(2);
select if(number % 2, materialize(toLowCardinality('a')), NULL) from numbers(2);

-- Both branches are LowCardinality.
select 'if with two LC arguments';
select if(0, toLowCardinality('a'), toLowCardinality('b'));
select if(1, toLowCardinality('a'), toLowCardinality('b'));
select if(materialize(0), materialize(toLowCardinality('a')), materialize(toLowCardinality('b')));
select if(number % 2, toLowCardinality('a'), toLowCardinality('b')) from numbers(2);
-- NOTE(review): both branches are 'a' here, so the output cannot show which branch
-- was taken. Confirm whether 'b' was intended (the reference output would change).
select if(number % 2, materialize(toLowCardinality('a')), materialize(toLowCardinality('a'))) from numbers(2);

-- LowCardinality built from the `number` column instead of a string constant,
-- first against NULL and then against another LowCardinality expression.
select if(number % 2, toLowCardinality(number), NULL) from numbers(2);
select if(number % 2, toLowCardinality(number), toLowCardinality(number + 1)) from numbers(2);

0 comments on commit 9305607

Please sign in to comment.