Skip to content

Commit

Permalink
Backport #54480 to 23.3: Fix aggregate projections with normalized st…
Browse files Browse the repository at this point in the history
…ates
  • Loading branch information
robot-clickhouse committed Sep 20, 2023
1 parent 3390d54 commit fecd683
Show file tree
Hide file tree
Showing 8 changed files with 121 additions and 39 deletions.
2 changes: 1 addition & 1 deletion src/Columns/ColumnAggregateFunction.cpp
Expand Up @@ -73,7 +73,7 @@ ColumnAggregateFunction::ColumnAggregateFunction(const AggregateFunctionPtr & fu

}

void ColumnAggregateFunction::set(const AggregateFunctionPtr & func_, size_t version_)
void ColumnAggregateFunction::set(const AggregateFunctionPtr & func_, std::optional<size_t> version_)
{
func = func_;
version = version_;
Expand Down
2 changes: 1 addition & 1 deletion src/Columns/ColumnAggregateFunction.h
Expand Up @@ -103,7 +103,7 @@ class ColumnAggregateFunction final : public COWHelper<IColumn, ColumnAggregateF
public:
~ColumnAggregateFunction() override;

void set(const AggregateFunctionPtr & func_, size_t version_);
void set(const AggregateFunctionPtr & func_, std::optional<size_t> version_ = std::nullopt);

AggregateFunctionPtr getAggregateFunction() { return func; }
AggregateFunctionPtr getAggregateFunction() const { return func; }
Expand Down
48 changes: 24 additions & 24 deletions src/DataTypes/DataTypeAggregateFunction.cpp
Expand Up @@ -117,43 +117,43 @@ Field DataTypeAggregateFunction::getDefault() const
return field;
}


bool DataTypeAggregateFunction::equals(const IDataType & rhs) const
bool DataTypeAggregateFunction::strictEquals(const DataTypePtr & lhs_state_type, const DataTypePtr & rhs_state_type)
{
if (typeid(rhs) != typeid(*this))
return false;
const auto * lhs_state = typeid_cast<const DataTypeAggregateFunction *>(lhs_state_type.get());
const auto * rhs_state = typeid_cast<const DataTypeAggregateFunction *>(rhs_state_type.get());

auto lhs_state_type = function->getNormalizedStateType();
auto rhs_state_type = typeid_cast<const DataTypeAggregateFunction &>(rhs).function->getNormalizedStateType();

if (typeid(lhs_state_type.get()) != typeid(rhs_state_type.get()))
if (!lhs_state || !rhs_state)
return false;

if (const auto * lhs_state = typeid_cast<const DataTypeAggregateFunction *>(lhs_state_type.get()))
{
const auto & rhs_state = typeid_cast<const DataTypeAggregateFunction &>(*rhs_state_type);
if (lhs_state->function->getName() != rhs_state->function->getName())
return false;

if (lhs_state->function->getName() != rhs_state.function->getName())
return false;
if (lhs_state->parameters.size() != rhs_state->parameters.size())
return false;

if (lhs_state->parameters.size() != rhs_state.parameters.size())
for (size_t i = 0; i < lhs_state->parameters.size(); ++i)
if (lhs_state->parameters[i] != rhs_state->parameters[i])
return false;

for (size_t i = 0; i < lhs_state->parameters.size(); ++i)
if (lhs_state->parameters[i] != rhs_state.parameters[i])
return false;
if (lhs_state->argument_types.size() != rhs_state->argument_types.size())
return false;

if (lhs_state->argument_types.size() != rhs_state.argument_types.size())
for (size_t i = 0; i < lhs_state->argument_types.size(); ++i)
if (!lhs_state->argument_types[i]->equals(*rhs_state->argument_types[i]))
return false;

for (size_t i = 0; i < lhs_state->argument_types.size(); ++i)
if (!lhs_state->argument_types[i]->equals(*rhs_state.argument_types[i]))
return false;
return true;
}

return true;
}
bool DataTypeAggregateFunction::equals(const IDataType & rhs) const
{
if (typeid(rhs) != typeid(*this))
return false;

auto lhs_state_type = function->getNormalizedStateType();
auto rhs_state_type = typeid_cast<const DataTypeAggregateFunction &>(rhs).function->getNormalizedStateType();

return lhs_state_type->equals(*rhs_state_type);
return strictEquals(lhs_state_type, rhs_state_type);
}


Expand Down
1 change: 1 addition & 0 deletions src/DataTypes/DataTypeAggregateFunction.h
Expand Up @@ -59,6 +59,7 @@ class DataTypeAggregateFunction final : public IDataType

Field getDefault() const override;

static bool strictEquals(const DataTypePtr & lhs_state_type, const DataTypePtr & rhs_state_type);
bool equals(const IDataType & rhs) const override;

bool isParametric() const override { return true; }
Expand Down
50 changes: 43 additions & 7 deletions src/Functions/FunctionsConversion.h
Expand Up @@ -33,6 +33,7 @@
#include <Columns/ColumnString.h>
#include <Columns/ColumnFixedString.h>
#include <Columns/ColumnConst.h>
#include <Columns/ColumnAggregateFunction.h>
#include <Columns/ColumnArray.h>
#include <Columns/ColumnNullable.h>
#include <Columns/ColumnTuple.h>
Expand Down Expand Up @@ -3080,14 +3081,40 @@ class FunctionCast final : public FunctionCastBase
{
return &ConvertImplGenericFromString<ColumnString>::execute;
}
else
else if (const auto * agg_type = checkAndGetDataType<DataTypeAggregateFunction>(from_type_untyped.get()))
{
if (cast_type == CastType::accurateOrNull)
return createToNullableColumnWrapper();
else
throw Exception(ErrorCodes::CANNOT_CONVERT_TYPE, "Conversion from {} to {} is not supported",
from_type_untyped->getName(), to_type->getName());
if (agg_type->getFunction()->haveSameStateRepresentation(*to_type->getFunction()))
{
return [function = to_type->getFunction()](
ColumnsWithTypeAndName & arguments,
const DataTypePtr & /* result_type */,
const ColumnNullable * /* nullable_source */,
size_t /*input_rows_count*/) -> ColumnPtr
{
const auto & argument_column = arguments.front();
const auto * col_agg = checkAndGetColumn<ColumnAggregateFunction>(argument_column.column.get());
if (col_agg)
{
auto new_col_agg = ColumnAggregateFunction::create(*col_agg);
new_col_agg->set(function);
return new_col_agg;
}
else
{
throw Exception(
ErrorCodes::LOGICAL_ERROR,
"Illegal column {} for function CAST AS AggregateFunction",
argument_column.column->getName());
}
};
}
}

if (cast_type == CastType::accurateOrNull)
return createToNullableColumnWrapper();
else
throw Exception(ErrorCodes::CANNOT_CONVERT_TYPE, "Conversion from {} to {} is not supported",
from_type_untyped->getName(), to_type->getName());
}

WrapperType createArrayWrapper(const DataTypePtr & from_type_untyped, const DataTypeArray & to_type) const
Expand Down Expand Up @@ -3854,7 +3881,16 @@ class FunctionCast final : public FunctionCastBase
safe_convert_custom_types = to_type->getCustomName() && from_type_custom_name->getName() == to_type->getCustomName()->getName();

if (from_type->equals(*to_type) && safe_convert_custom_types)
return createIdentityWrapper(from_type);
{
/// We can only use identity conversion for DataTypeAggregateFunction when they are strictly equivalent.
if (typeid_cast<const DataTypeAggregateFunction *>(from_type.get()))
{
if (DataTypeAggregateFunction::strictEquals(from_type, to_type))
return createIdentityWrapper(from_type);
}
else
return createIdentityWrapper(from_type);
}
else if (WhichDataType(from_type).isNothing())
return createNothingWrapper(to_type.get());

Expand Down
Expand Up @@ -155,12 +155,12 @@ std::optional<AggregateFunctionMatches> matchAggregateFunctions(
argument_types.clear();
const auto & candidate = info.aggregates[idx];

/// Note: this check is a bit strict.
/// We check that aggregate function names, argument types and parameters are equal.
/// In some cases it's possible only to check that states are equal,
/// e.g. for quantile(0.3)(...) and quantile(0.5)(...).
/// But also functions sum(...) and sumIf(...) will have equal states,
/// and we can't replace one to another from projection.
///
/// Note we already checked that aggregate function names are equal,
/// so that functions sum(...) and sumIf(...) with equal states will
/// not match.
if (!candidate.function->getStateType()->equals(*aggregate.function->getStateType()))
{
// LOG_TRACE(&Poco::Logger::get("optimizeUseProjections"), "Cannot match agg func {} vs {} by state {} vs {}",
Expand Down Expand Up @@ -267,12 +267,24 @@ static void appendAggregateFunctions(

auto & input = inputs[match.description];
if (!input)
input = &proj_dag.addInput(match.description->column_name, std::move(type));
input = &proj_dag.addInput(match.description->column_name, type);

const auto * node = input;

if (node->result_name != aggregate.column_name)
node = &proj_dag.addAlias(*node, aggregate.column_name);
{
if (DataTypeAggregateFunction::strictEquals(type, node->result_type))
{
node = &proj_dag.addAlias(*node, aggregate.column_name);
}
else
{
/// Cast to aggregate types specified in query if it's not
/// strictly the same as the one specified in projection. This
/// is required to generate correct results during finalization.
node = &proj_dag.addCast(*node, type, aggregate.column_name);
}
}

proj_dag_outputs.push_back(node);
}
Expand Down
@@ -0,0 +1,2 @@
3
950 990 500 2000
@@ -0,0 +1,31 @@
DROP TABLE IF EXISTS r;

select finalizeAggregation(cast(quantileState(0)(arrayJoin([1,2,3])) as AggregateFunction(quantile(1), UInt8)));

CREATE TABLE r (
x String,
a LowCardinality(String),
q AggregateFunction(quantilesTiming(0.5, 0.95, 0.99), Int64),
s Int64,
PROJECTION p
(SELECT a, quantilesTimingMerge(0.5, 0.95, 0.99)(q), sum(s) GROUP BY a)
) Engine=SummingMergeTree order by (x, a);

insert into r
select number%100 x,
'x' a,
quantilesTimingState(0.5, 0.95, 0.99)(number::Int64) q,
sum(1) s
from numbers(1000)
group by x,a;

SELECT
ifNotFinite(quantilesTimingMerge(0.95)(q)[1],0) as d1,
ifNotFinite(quantilesTimingMerge(0.99)(q)[1],0) as d2,
ifNotFinite(quantilesTimingMerge(0.50)(q)[1],0) as d3,
sum(s)
FROM cluster('test_cluster_two_shards', currentDatabase(), r)
WHERE a = 'x'
settings prefer_localhost_replica=0;

DROP TABLE r;

0 comments on commit fecd683

Please sign in to comment.