Skip to content

Commit

Permalink
DataTypeAggregateFunction::strictEquals
Browse files Browse the repository at this point in the history
  • Loading branch information
amosbird committed Sep 11, 2023
1 parent 9e56cff commit 667426f
Show file tree
Hide file tree
Showing 6 changed files with 50 additions and 30 deletions.
48 changes: 24 additions & 24 deletions src/DataTypes/DataTypeAggregateFunction.cpp
Expand Up @@ -117,43 +117,43 @@ Field DataTypeAggregateFunction::getDefault() const
return field;
}


bool DataTypeAggregateFunction::equals(const IDataType & rhs) const
bool DataTypeAggregateFunction::strictEquals(const DataTypePtr & lhs_state_type, const DataTypePtr & rhs_state_type)
{
if (typeid(rhs) != typeid(*this))
return false;
const auto * lhs_state = typeid_cast<const DataTypeAggregateFunction *>(lhs_state_type.get());
const auto * rhs_state = typeid_cast<const DataTypeAggregateFunction *>(rhs_state_type.get());

auto lhs_state_type = function->getNormalizedStateType();
auto rhs_state_type = typeid_cast<const DataTypeAggregateFunction &>(rhs).function->getNormalizedStateType();

if (typeid(lhs_state_type.get()) != typeid(rhs_state_type.get()))
if (!lhs_state || !rhs_state)
return false;

if (const auto * lhs_state = typeid_cast<const DataTypeAggregateFunction *>(lhs_state_type.get()))
{
const auto & rhs_state = typeid_cast<const DataTypeAggregateFunction &>(*rhs_state_type);
if (lhs_state->function->getName() != rhs_state->function->getName())
return false;

if (lhs_state->function->getName() != rhs_state.function->getName())
return false;
if (lhs_state->parameters.size() != rhs_state->parameters.size())
return false;

if (lhs_state->parameters.size() != rhs_state.parameters.size())
for (size_t i = 0; i < lhs_state->parameters.size(); ++i)
if (lhs_state->parameters[i] != rhs_state->parameters[i])
return false;

for (size_t i = 0; i < lhs_state->parameters.size(); ++i)
if (lhs_state->parameters[i] != rhs_state.parameters[i])
return false;
if (lhs_state->argument_types.size() != rhs_state->argument_types.size())
return false;

if (lhs_state->argument_types.size() != rhs_state.argument_types.size())
for (size_t i = 0; i < lhs_state->argument_types.size(); ++i)
if (!lhs_state->argument_types[i]->equals(*rhs_state->argument_types[i]))
return false;

for (size_t i = 0; i < lhs_state->argument_types.size(); ++i)
if (!lhs_state->argument_types[i]->equals(*rhs_state.argument_types[i]))
return false;
return true;
}

return true;
}
bool DataTypeAggregateFunction::equals(const IDataType & rhs) const
{
if (typeid(rhs) != typeid(*this))
return false;

auto lhs_state_type = function->getNormalizedStateType();
auto rhs_state_type = typeid_cast<const DataTypeAggregateFunction &>(rhs).function->getNormalizedStateType();

return lhs_state_type->equals(*rhs_state_type);
return strictEquals(lhs_state_type, rhs_state_type);
}


Expand Down
1 change: 1 addition & 0 deletions src/DataTypes/DataTypeAggregateFunction.h
Expand Up @@ -60,6 +60,7 @@ class DataTypeAggregateFunction final : public IDataType

Field getDefault() const override;

static bool strictEquals(const DataTypePtr & lhs_state_type, const DataTypePtr & rhs_state_type);
bool equals(const IDataType & rhs) const override;

bool isParametric() const override { return true; }
Expand Down
13 changes: 11 additions & 2 deletions src/Functions/FunctionsConversion.h
Expand Up @@ -3193,7 +3193,7 @@ class FunctionCast final : public FunctionCastBase
{
if (agg_type->getFunction()->haveSameStateRepresentation(*to_type->getFunction()))
{
return [function = agg_type->getFunction()](
return [function = to_type->getFunction()](
ColumnsWithTypeAndName & arguments,
const DataTypePtr & /* result_type */,
const ColumnNullable * /* nullable_source */,
Expand Down Expand Up @@ -4003,7 +4003,16 @@ class FunctionCast final : public FunctionCastBase
safe_convert_custom_types = to_type->getCustomName() && from_type_custom_name->getName() == to_type->getCustomName()->getName();

if (from_type->equals(*to_type) && safe_convert_custom_types)
return createIdentityWrapper(from_type);
{
/// We can only use identity conversion for DataTypeAggregateFunction when they are strictly equivalent.
if (typeid_cast<const DataTypeAggregateFunction *>(from_type.get()))
{
if (DataTypeAggregateFunction::strictEquals(from_type, to_type))
return createIdentityWrapper(from_type);
}
else
return createIdentityWrapper(from_type);
}
else if (WhichDataType(from_type).isNothing())
return createNothingWrapper(to_type.get());

Expand Down
Expand Up @@ -255,10 +255,17 @@ static void appendAggregateFunctions(

if (node->result_name != aggregate.column_name)
{
/// Always cast to aggregate types specified in query, because input
/// columns from projection might have the same state but different
/// type, which can generate wrong results during finalization.
node = &proj_dag.addCast(*node, type, aggregate.column_name);
if (DataTypeAggregateFunction::strictEquals(type, node->result_type))
{
node = &proj_dag.addAlias(*node, aggregate.column_name);
}
else
{
/// Cast to aggregate types specified in query if it's not
/// strictly the same as the one specified in projection. This
/// is required to generate correct results during finalization.
node = &proj_dag.addCast(*node, type, aggregate.column_name);
}
}

proj_dag_outputs.push_back(node);
Expand Down
@@ -1 +1,2 @@
3
950 990 500 2000
@@ -1,5 +1,7 @@
DROP TABLE IF EXISTS r;

select finalizeAggregation(cast(quantileState(0)(arrayJoin([1,2,3])) as AggregateFunction(quantile(1), UInt8)));

CREATE TABLE r (
x String,
a LowCardinality(String),
Expand Down

0 comments on commit 667426f

Please sign in to comment.