Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix aggregate projections with normalized states #54480

Merged
merged 3 commits into from Sep 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/Columns/ColumnAggregateFunction.cpp
Expand Up @@ -73,7 +73,7 @@ ColumnAggregateFunction::ColumnAggregateFunction(const AggregateFunctionPtr & fu

}

void ColumnAggregateFunction::set(const AggregateFunctionPtr & func_, size_t version_)
void ColumnAggregateFunction::set(const AggregateFunctionPtr & func_, std::optional<size_t> version_)
{
func = func_;
version = version_;
Expand Down
2 changes: 1 addition & 1 deletion src/Columns/ColumnAggregateFunction.h
Expand Up @@ -103,7 +103,7 @@ class ColumnAggregateFunction final : public COWHelper<IColumn, ColumnAggregateF
public:
~ColumnAggregateFunction() override;

void set(const AggregateFunctionPtr & func_, size_t version_);
void set(const AggregateFunctionPtr & func_, std::optional<size_t> version_ = std::nullopt);

AggregateFunctionPtr getAggregateFunction() { return func; }
AggregateFunctionPtr getAggregateFunction() const { return func; }
Expand Down
48 changes: 24 additions & 24 deletions src/DataTypes/DataTypeAggregateFunction.cpp
Expand Up @@ -117,43 +117,43 @@ Field DataTypeAggregateFunction::getDefault() const
return field;
}


bool DataTypeAggregateFunction::equals(const IDataType & rhs) const
bool DataTypeAggregateFunction::strictEquals(const DataTypePtr & lhs_state_type, const DataTypePtr & rhs_state_type)
{
if (typeid(rhs) != typeid(*this))
return false;
const auto * lhs_state = typeid_cast<const DataTypeAggregateFunction *>(lhs_state_type.get());
const auto * rhs_state = typeid_cast<const DataTypeAggregateFunction *>(rhs_state_type.get());

auto lhs_state_type = function->getNormalizedStateType();
auto rhs_state_type = typeid_cast<const DataTypeAggregateFunction &>(rhs).function->getNormalizedStateType();

if (typeid(lhs_state_type.get()) != typeid(rhs_state_type.get()))
if (!lhs_state || !rhs_state)
return false;

if (const auto * lhs_state = typeid_cast<const DataTypeAggregateFunction *>(lhs_state_type.get()))
{
const auto & rhs_state = typeid_cast<const DataTypeAggregateFunction &>(*rhs_state_type);
if (lhs_state->function->getName() != rhs_state->function->getName())
return false;

if (lhs_state->function->getName() != rhs_state.function->getName())
return false;
if (lhs_state->parameters.size() != rhs_state->parameters.size())
return false;

if (lhs_state->parameters.size() != rhs_state.parameters.size())
for (size_t i = 0; i < lhs_state->parameters.size(); ++i)
if (lhs_state->parameters[i] != rhs_state->parameters[i])
return false;

for (size_t i = 0; i < lhs_state->parameters.size(); ++i)
if (lhs_state->parameters[i] != rhs_state.parameters[i])
return false;
if (lhs_state->argument_types.size() != rhs_state->argument_types.size())
return false;

if (lhs_state->argument_types.size() != rhs_state.argument_types.size())
for (size_t i = 0; i < lhs_state->argument_types.size(); ++i)
if (!lhs_state->argument_types[i]->equals(*rhs_state->argument_types[i]))
return false;

for (size_t i = 0; i < lhs_state->argument_types.size(); ++i)
if (!lhs_state->argument_types[i]->equals(*rhs_state.argument_types[i]))
return false;
return true;
}

return true;
}
bool DataTypeAggregateFunction::equals(const IDataType & rhs) const
{
if (typeid(rhs) != typeid(*this))
return false;

auto lhs_state_type = function->getNormalizedStateType();
auto rhs_state_type = typeid_cast<const DataTypeAggregateFunction &>(rhs).function->getNormalizedStateType();

return lhs_state_type->equals(*rhs_state_type);
return strictEquals(lhs_state_type, rhs_state_type);
}


Expand Down
1 change: 1 addition & 0 deletions src/DataTypes/DataTypeAggregateFunction.h
Expand Up @@ -60,6 +60,7 @@ class DataTypeAggregateFunction final : public IDataType

Field getDefault() const override;

static bool strictEquals(const DataTypePtr & lhs_state_type, const DataTypePtr & rhs_state_type);
bool equals(const IDataType & rhs) const override;

bool isParametric() const override { return true; }
Expand Down
50 changes: 43 additions & 7 deletions src/Functions/FunctionsConversion.h
Expand Up @@ -33,6 +33,7 @@
#include <Columns/ColumnString.h>
#include <Columns/ColumnFixedString.h>
#include <Columns/ColumnConst.h>
#include <Columns/ColumnAggregateFunction.h>
#include <Columns/ColumnArray.h>
#include <Columns/ColumnNullable.h>
#include <Columns/ColumnTuple.h>
Expand Down Expand Up @@ -3188,14 +3189,40 @@ class FunctionCast final : public FunctionCastBase
{
return &ConvertImplGenericFromString<ColumnString>::execute;
}
else
else if (const auto * agg_type = checkAndGetDataType<DataTypeAggregateFunction>(from_type_untyped.get()))
{
if (cast_type == CastType::accurateOrNull)
return createToNullableColumnWrapper();
else
throw Exception(ErrorCodes::CANNOT_CONVERT_TYPE, "Conversion from {} to {} is not supported",
from_type_untyped->getName(), to_type->getName());
if (agg_type->getFunction()->haveSameStateRepresentation(*to_type->getFunction()))
{
return [function = to_type->getFunction()](
ColumnsWithTypeAndName & arguments,
const DataTypePtr & /* result_type */,
const ColumnNullable * /* nullable_source */,
size_t /*input_rows_count*/) -> ColumnPtr
{
const auto & argument_column = arguments.front();
const auto * col_agg = checkAndGetColumn<ColumnAggregateFunction>(argument_column.column.get());
if (col_agg)
{
auto new_col_agg = ColumnAggregateFunction::create(*col_agg);
new_col_agg->set(function);
return new_col_agg;
}
else
{
throw Exception(
ErrorCodes::LOGICAL_ERROR,
"Illegal column {} for function CAST AS AggregateFunction",
argument_column.column->getName());
}
};
}
}

if (cast_type == CastType::accurateOrNull)
return createToNullableColumnWrapper();
else
throw Exception(ErrorCodes::CANNOT_CONVERT_TYPE, "Conversion from {} to {} is not supported",
from_type_untyped->getName(), to_type->getName());
}

WrapperType createArrayWrapper(const DataTypePtr & from_type_untyped, const DataTypeArray & to_type) const
Expand Down Expand Up @@ -3976,7 +4003,16 @@ class FunctionCast final : public FunctionCastBase
safe_convert_custom_types = to_type->getCustomName() && from_type_custom_name->getName() == to_type->getCustomName()->getName();

if (from_type->equals(*to_type) && safe_convert_custom_types)
return createIdentityWrapper(from_type);
{
/// We can only use identity conversion for DataTypeAggregateFunction when they are strictly equivalent.
if (typeid_cast<const DataTypeAggregateFunction *>(from_type.get()))
{
if (DataTypeAggregateFunction::strictEquals(from_type, to_type))
return createIdentityWrapper(from_type);
}
else
return createIdentityWrapper(from_type);
}
else if (WhichDataType(from_type).isNothing())
return createNothingWrapper(to_type.get());

Expand Down
Expand Up @@ -143,12 +143,12 @@ std::optional<AggregateFunctionMatches> matchAggregateFunctions(
argument_types.clear();
const auto & candidate = info.aggregates[idx];

/// Note: this check is a bit strict.
/// We check that aggregate function names, argument types and parameters are equal.
/// In some cases it's possible only to check that states are equal,
/// e.g. for quantile(0.3)(...) and quantile(0.5)(...).
/// But also functions sum(...) and sumIf(...) will have equal states,
/// and we can't replace one to another from projection.
///
/// Note we already checked that aggregate function names are equal,
/// so that functions sum(...) and sumIf(...) with equal states will
/// not match.
if (!candidate.function->getStateType()->equals(*aggregate.function->getStateType()))
{
// LOG_TRACE(&Poco::Logger::get("optimizeUseProjections"), "Cannot match agg func {} vs {} by state {} vs {}",
Expand Down Expand Up @@ -249,12 +249,24 @@ static void appendAggregateFunctions(

auto & input = inputs[match.description];
if (!input)
input = &proj_dag.addInput(match.description->column_name, std::move(type));
input = &proj_dag.addInput(match.description->column_name, type);

const auto * node = input;

if (node->result_name != aggregate.column_name)
node = &proj_dag.addAlias(*node, aggregate.column_name);
{
if (DataTypeAggregateFunction::strictEquals(type, node->result_type))
{
node = &proj_dag.addAlias(*node, aggregate.column_name);
}
else
{
/// Cast to aggregate types specified in query if it's not
/// strictly the same as the one specified in projection. This
/// is required to generate correct results during finalization.
node = &proj_dag.addCast(*node, type, aggregate.column_name);
}
}

proj_dag_outputs.push_back(node);
}
Expand Down
@@ -0,0 +1,2 @@
3
950 990 500 2000
@@ -0,0 +1,31 @@
DROP TABLE IF EXISTS r;

select finalizeAggregation(cast(quantileState(0)(arrayJoin([1,2,3])) as AggregateFunction(quantile(1), UInt8)));

CREATE TABLE r (
x String,
a LowCardinality(String),
q AggregateFunction(quantilesTiming(0.5, 0.95, 0.99), Int64),
s Int64,
PROJECTION p
(SELECT a, quantilesTimingMerge(0.5, 0.95, 0.99)(q), sum(s) GROUP BY a)
) Engine=SummingMergeTree order by (x, a);

insert into r
select number%100 x,
'x' a,
quantilesTimingState(0.5, 0.95, 0.99)(number::Int64) q,
sum(1) s
from numbers(1000)
group by x,a;

SELECT
ifNotFinite(quantilesTimingMerge(0.95)(q)[1],0) as d1,
ifNotFinite(quantilesTimingMerge(0.99)(q)[1],0) as d2,
ifNotFinite(quantilesTimingMerge(0.50)(q)[1],0) as d3,
sum(s)
FROM cluster('test_cluster_two_shards', currentDatabase(), r)
WHERE a = 'x'
settings prefer_localhost_replica=0;

DROP TABLE r;