Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

New analyzer pass to optimize in single value #61564

Merged
merged 12 commits into from Mar 25, 2024
73 changes: 73 additions & 0 deletions src/Analyzer/Passes/ConvertInToEqualPass.cpp
@@ -0,0 +1,73 @@
#include <Analyzer/ColumnNode.h>
#include <Analyzer/ConstantNode.h>
#include <Analyzer/FunctionNode.h>
#include <Analyzer/InDepthQueryTreeVisitor.h>
#include <Analyzer/Passes/ConvertInToEqualPass.h>
#include <Functions/equals.h>
#include <Functions/notEquals.h>

namespace DB
{

class ConvertInToEqualPassVisitor : public InDepthQueryTreeVisitorWithContext<ConvertInToEqualPassVisitor>
{
public:
using Base = InDepthQueryTreeVisitorWithContext<ConvertInToEqualPassVisitor>;
using Base::Base;

void enterImpl(QueryTreeNodePtr & node)
{
static const std::unordered_map<String, String> MAPPING = {
{"in", "equals"},
{"notIn", "notEquals"}
};
auto * func_node = node->as<FunctionNode>();
if (!func_node
|| !MAPPING.contains(func_node->getFunctionName())
|| func_node->getArguments().getNodes().size() != 2)
return ;
auto args = func_node->getArguments().getNodes();
auto * column_node = args[0]->as<ColumnNode>();
auto * constant_node = args[1]->as<ConstantNode>();
if (!column_node || !constant_node)
return ;
// IN multiple values is not supported
if (constant_node->getValue().getType() == Field::Types::Which::Tuple
|| constant_node->getValue().getType() == Field::Types::Which::Array)
return ;
// x IN null not equivalent to x = null
if (constant_node->getValue().isNull())
return ;
auto result_func_name = MAPPING.at(func_node->getFunctionName());
auto equal = std::make_shared<FunctionNode>(result_func_name);
QueryTreeNodes arguments{column_node->clone(), constant_node->clone()};
equal->getArguments().getNodes() = std::move(arguments);
FunctionOverloadResolverPtr resolver;
bool decimal_check_overflow = getContext()->getSettingsRef().decimal_check_overflow;
if (result_func_name == "equals")
{
resolver = createInternalFunctionEqualOverloadResolver(decimal_check_overflow);
}
else
{
resolver = createInternalFunctionNotEqualOverloadResolver(decimal_check_overflow);
}
try
{
equal->resolveAsFunction(resolver);
}
catch (...)
{
// When function resolver fails, we should not replace the function node
return;
}
node = equal;
}
};

void ConvertInToEqualPass::run(QueryTreeNodePtr & query_tree_node, ContextPtr context)
{
ConvertInToEqualPassVisitor visitor(std::move(context));
visitor.visit(query_tree_node);
}
}
27 changes: 27 additions & 0 deletions src/Analyzer/Passes/ConvertInToEqualPass.h
@@ -0,0 +1,27 @@
#pragma once

#include <Analyzer/IQueryTreePass.h>

namespace DB
{
/** Optimize `in` to `equals` if possible.
* 1. convert in single value to equal
* Example: SELECT * from test where x IN (1);
* Result: SELECT * from test where x = 1;
*
* 2. convert not in single value to notEqual
* Example: SELECT * from test where x NOT IN (1);
* Result: SELECT * from test where x != 1;
*
* If value is null or tuple, do not convert.
*/
class ConvertInToEqualPass final : public IQueryTreePass
{
public:
String getName() override { return "ConvertInToEqualPass"; }

String getDescription() override { return "Convert in to equal"; }

void run(QueryTreeNodePtr & query_tree_node, ContextPtr context) override;
};
}
2 changes: 2 additions & 0 deletions src/Analyzer/QueryTreePassManager.cpp
Expand Up @@ -28,6 +28,7 @@
#include <Analyzer/Passes/MultiIfToIfPass.h>
#include <Analyzer/Passes/IfConstantConditionPass.h>
#include <Analyzer/Passes/IfChainToMultiIfPass.h>
#include <Analyzer/Passes/ConvertInToEqualPass.h>
#include <Analyzer/Passes/OrderByTupleEliminationPass.h>
#include <Analyzer/Passes/NormalizeCountVariantsPass.h>
#include <Analyzer/Passes/AggregateFunctionsArithmericOperationsPass.h>
Expand Down Expand Up @@ -263,6 +264,7 @@ void addQueryTreePasses(QueryTreePassManager & manager, bool only_analyze)
manager.addPass(std::make_unique<SumIfToCountIfPass>());
manager.addPass(std::make_unique<RewriteArrayExistsToHasPass>());
manager.addPass(std::make_unique<NormalizeCountVariantsPass>());
manager.addPass(std::make_unique<ConvertInToEqualPass>());

/// should before AggregateFunctionsArithmericOperationsPass
manager.addPass(std::make_unique<AggregateFunctionOfGroupByKeysPass>());
Expand Down
2 changes: 2 additions & 0 deletions src/Functions/CMakeLists.txt
Expand Up @@ -14,6 +14,8 @@ extract_into_parent_list(clickhouse_functions_sources dbms_sources
multiMatchAny.cpp
checkHyperscanRegexp.cpp
array/has.cpp
equals.cpp
notEquals.cpp
CastOverloadResolver.cpp
)
extract_into_parent_list(clickhouse_functions_headers dbms_headers
Expand Down
5 changes: 5 additions & 0 deletions src/Functions/equals.cpp
Expand Up @@ -13,6 +13,11 @@ REGISTER_FUNCTION(Equals)
factory.registerFunction<FunctionEquals>();
}

FunctionOverloadResolverPtr createInternalFunctionEqualOverloadResolver(bool decimal_check_overflow)
{
return std::make_unique<FunctionToOverloadResolverAdaptor>(std::make_shared<FunctionEquals>(decimal_check_overflow));
}

template <>
ColumnPtr FunctionComparison<EqualsOp, NameEquals>::executeTupleImpl(
const ColumnsWithTypeAndName & x, const ColumnsWithTypeAndName & y, size_t tuple_size, size_t input_rows_count) const
Expand Down
11 changes: 11 additions & 0 deletions src/Functions/equals.h
@@ -0,0 +1,11 @@
#pragma once
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the purpose of creating a .h files for equals and notEquals? I think we can avoid this behaviour.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I will also use it for other subsequent optimizations, which can be better reused through .h

#include <memory>

namespace DB
{

class IFunctionOverloadResolver;
using FunctionOverloadResolverPtr = std::shared_ptr<IFunctionOverloadResolver>;

FunctionOverloadResolverPtr createInternalFunctionEqualOverloadResolver(bool decimal_check_overflow);
}
5 changes: 5 additions & 0 deletions src/Functions/notEquals.cpp
Expand Up @@ -12,6 +12,11 @@ REGISTER_FUNCTION(NotEquals)
factory.registerFunction<FunctionNotEquals>();
}

FunctionOverloadResolverPtr createInternalFunctionNotEqualOverloadResolver(bool decimal_check_overflow)
{
return std::make_unique<FunctionToOverloadResolverAdaptor>(std::make_shared<FunctionNotEquals>(decimal_check_overflow));
}

template <>
ColumnPtr FunctionComparison<NotEqualsOp, NameNotEquals>::executeTupleImpl(
const ColumnsWithTypeAndName & x, const ColumnsWithTypeAndName & y, size_t tuple_size, size_t input_rows_count) const
Expand Down
11 changes: 11 additions & 0 deletions src/Functions/notEquals.h
@@ -0,0 +1,11 @@
#pragma once
#include <memory>

namespace DB
{

class IFunctionOverloadResolver;
using FunctionOverloadResolverPtr = std::shared_ptr<IFunctionOverloadResolver>;

FunctionOverloadResolverPtr createInternalFunctionNotEqualOverloadResolver(bool decimal_check_overflow);
}
Expand Up @@ -306,7 +306,8 @@ TEST(TransformQueryForExternalDatabase, Aliases)

check(state, 1, {"field"},
"SELECT field AS value, field AS display FROM table WHERE field NOT IN ('') AND display LIKE '%test%'",
R"(SELECT "field" FROM "test"."table" WHERE ("field" NOT IN ('')) AND ("field" LIKE '%test%'))");
R"(SELECT "field" FROM "test"."table" WHERE ("field" NOT IN ('')) AND ("field" LIKE '%test%'))",
R"(SELECT "field" FROM "test"."table" WHERE ("field" != '') AND ("field" LIKE '%test%'))");
}

TEST(TransformQueryForExternalDatabase, ForeignColumnInWhere)
Expand Down Expand Up @@ -408,5 +409,6 @@ TEST(TransformQueryForExternalDatabase, Analyzer)

check(state, 1, {"column", "apply_id", "apply_type", "apply_status", "create_time", "field", "value", "a", "b", "foo"},
"SELECT * FROM table WHERE (column) IN (1)",
R"(SELECT "column", "apply_id", "apply_type", "apply_status", "create_time", "field", "value", "a", "b", "foo" FROM "test"."table" WHERE "column" IN (1))");
R"(SELECT "column", "apply_id", "apply_type", "apply_status", "create_time", "field", "value", "a", "b", "foo" FROM "test"."table" WHERE "column" IN (1))",
R"(SELECT "column", "apply_id", "apply_type", "apply_status", "create_time", "field", "value", "a", "b", "foo" FROM "test"."table" WHERE "column" = 1)");
}
28 changes: 28 additions & 0 deletions tests/performance/function_in.xml
@@ -0,0 +1,28 @@
<test>
<settings>
<max_insert_threads>8</max_insert_threads>
<max_threads>1</max_threads>
</settings>

<create_query>
CREATE TABLE t_nullable
(
key_string1 Nullable(String),
key_string2 Nullable(String),
key_string3 Nullable(String),
key_int64_1 Nullable(Int64),
key_int64_2 Nullable(Int64),
key_int64_3 Nullable(Int64),
key_int64_4 Nullable(Int64),
key_int64_5 Nullable(Int64),
m1 Int64,
m2 Int64
)
ENGINE = Memory
</create_query>
<fill_query>insert into t_nullable select ['aaaaaa','bbaaaa','ccaaaa','ddaaaa'][number % 101 + 1], ['aa','bb','cc','dd'][number % 100 + 1], ['aa','bb','cc','dd'][number % 102 + 1], number%10+1, number%10+2, number%10+3, number%10+4,number%10+5, number%6000+1, number%5000+2 from numbers_mt(30000000)</fill_query>
<query>select * from t_nullable where key_string1 in ('aaaaaa') format Null SETTINGS allow_experimental_analyzer=1</query>
<query>select * from t_nullable where key_string2 in ('3') format Null SETTINGS allow_experimental_analyzer=1</query>
<drop_query>drop table if exists t_nullable</drop_query>

</test>