Skip to content

Commit

Permalink
Merge pull request ClickHouse#61564 from liuneng1994/optimize_in_sing…
Browse files Browse the repository at this point in the history
…le_value

New analyzer pass to optimize in single value
  • Loading branch information
yariks5s committed Mar 25, 2024
2 parents a642f4d + 97bd6ec commit 20a45b4
Show file tree
Hide file tree
Showing 12 changed files with 385 additions and 2 deletions.
73 changes: 73 additions & 0 deletions src/Analyzer/Passes/ConvertInToEqualPass.cpp
@@ -0,0 +1,73 @@
#include <Analyzer/ColumnNode.h>
#include <Analyzer/ConstantNode.h>
#include <Analyzer/FunctionNode.h>
#include <Analyzer/InDepthQueryTreeVisitor.h>
#include <Analyzer/Passes/ConvertInToEqualPass.h>
#include <Functions/equals.h>
#include <Functions/notEquals.h>

namespace DB
{

class ConvertInToEqualPassVisitor : public InDepthQueryTreeVisitorWithContext<ConvertInToEqualPassVisitor>
{
public:
using Base = InDepthQueryTreeVisitorWithContext<ConvertInToEqualPassVisitor>;
using Base::Base;

void enterImpl(QueryTreeNodePtr & node)
{
static const std::unordered_map<String, String> MAPPING = {
{"in", "equals"},
{"notIn", "notEquals"}
};
auto * func_node = node->as<FunctionNode>();
if (!func_node
|| !MAPPING.contains(func_node->getFunctionName())
|| func_node->getArguments().getNodes().size() != 2)
return ;
auto args = func_node->getArguments().getNodes();
auto * column_node = args[0]->as<ColumnNode>();
auto * constant_node = args[1]->as<ConstantNode>();
if (!column_node || !constant_node)
return ;
// IN multiple values is not supported
if (constant_node->getValue().getType() == Field::Types::Which::Tuple
|| constant_node->getValue().getType() == Field::Types::Which::Array)
return ;
// x IN null not equivalent to x = null
if (constant_node->getValue().isNull())
return ;
auto result_func_name = MAPPING.at(func_node->getFunctionName());
auto equal = std::make_shared<FunctionNode>(result_func_name);
QueryTreeNodes arguments{column_node->clone(), constant_node->clone()};
equal->getArguments().getNodes() = std::move(arguments);
FunctionOverloadResolverPtr resolver;
bool decimal_check_overflow = getContext()->getSettingsRef().decimal_check_overflow;
if (result_func_name == "equals")
{
resolver = createInternalFunctionEqualOverloadResolver(decimal_check_overflow);
}
else
{
resolver = createInternalFunctionNotEqualOverloadResolver(decimal_check_overflow);
}
try
{
equal->resolveAsFunction(resolver);
}
catch (...)
{
// When function resolver fails, we should not replace the function node
return;
}
node = equal;
}
};

void ConvertInToEqualPass::run(QueryTreeNodePtr & query_tree_node, ContextPtr context)
{
ConvertInToEqualPassVisitor visitor(std::move(context));
visitor.visit(query_tree_node);
}
}
27 changes: 27 additions & 0 deletions src/Analyzer/Passes/ConvertInToEqualPass.h
@@ -0,0 +1,27 @@
#pragma once

#include <Analyzer/IQueryTreePass.h>

namespace DB
{
/** Optimize `in` to `equals` if possible.
* 1. convert in single value to equal
* Example: SELECT * from test where x IN (1);
* Result: SELECT * from test where x = 1;
*
* 2. convert not in single value to notEqual
* Example: SELECT * from test where x NOT IN (1);
* Result: SELECT * from test where x != 1;
*
* If value is null or tuple, do not convert.
*/
class ConvertInToEqualPass final : public IQueryTreePass
{
public:
String getName() override { return "ConvertInToEqualPass"; }

String getDescription() override { return "Convert in to equal"; }

void run(QueryTreeNodePtr & query_tree_node, ContextPtr context) override;
};
}
2 changes: 2 additions & 0 deletions src/Analyzer/QueryTreePassManager.cpp
Expand Up @@ -28,6 +28,7 @@
#include <Analyzer/Passes/MultiIfToIfPass.h>
#include <Analyzer/Passes/IfConstantConditionPass.h>
#include <Analyzer/Passes/IfChainToMultiIfPass.h>
#include <Analyzer/Passes/ConvertInToEqualPass.h>
#include <Analyzer/Passes/OrderByTupleEliminationPass.h>
#include <Analyzer/Passes/NormalizeCountVariantsPass.h>
#include <Analyzer/Passes/AggregateFunctionsArithmericOperationsPass.h>
Expand Down Expand Up @@ -263,6 +264,7 @@ void addQueryTreePasses(QueryTreePassManager & manager, bool only_analyze)
manager.addPass(std::make_unique<SumIfToCountIfPass>());
manager.addPass(std::make_unique<RewriteArrayExistsToHasPass>());
manager.addPass(std::make_unique<NormalizeCountVariantsPass>());
manager.addPass(std::make_unique<ConvertInToEqualPass>());

/// should before AggregateFunctionsArithmericOperationsPass
manager.addPass(std::make_unique<AggregateFunctionOfGroupByKeysPass>());
Expand Down
2 changes: 2 additions & 0 deletions src/Functions/CMakeLists.txt
Expand Up @@ -14,6 +14,8 @@ extract_into_parent_list(clickhouse_functions_sources dbms_sources
multiMatchAny.cpp
checkHyperscanRegexp.cpp
array/has.cpp
equals.cpp
notEquals.cpp
CastOverloadResolver.cpp
)
extract_into_parent_list(clickhouse_functions_headers dbms_headers
Expand Down
5 changes: 5 additions & 0 deletions src/Functions/equals.cpp
Expand Up @@ -13,6 +13,11 @@ REGISTER_FUNCTION(Equals)
factory.registerFunction<FunctionEquals>();
}

FunctionOverloadResolverPtr createInternalFunctionEqualOverloadResolver(bool decimal_check_overflow)
{
return std::make_unique<FunctionToOverloadResolverAdaptor>(std::make_shared<FunctionEquals>(decimal_check_overflow));
}

template <>
ColumnPtr FunctionComparison<EqualsOp, NameEquals>::executeTupleImpl(
const ColumnsWithTypeAndName & x, const ColumnsWithTypeAndName & y, size_t tuple_size, size_t input_rows_count) const
Expand Down
11 changes: 11 additions & 0 deletions src/Functions/equals.h
@@ -0,0 +1,11 @@
#pragma once
#include <memory>

namespace DB
{

class IFunctionOverloadResolver;
using FunctionOverloadResolverPtr = std::shared_ptr<IFunctionOverloadResolver>;

FunctionOverloadResolverPtr createInternalFunctionEqualOverloadResolver(bool decimal_check_overflow);
}
5 changes: 5 additions & 0 deletions src/Functions/notEquals.cpp
Expand Up @@ -12,6 +12,11 @@ REGISTER_FUNCTION(NotEquals)
factory.registerFunction<FunctionNotEquals>();
}

FunctionOverloadResolverPtr createInternalFunctionNotEqualOverloadResolver(bool decimal_check_overflow)
{
return std::make_unique<FunctionToOverloadResolverAdaptor>(std::make_shared<FunctionNotEquals>(decimal_check_overflow));
}

template <>
ColumnPtr FunctionComparison<NotEqualsOp, NameNotEquals>::executeTupleImpl(
const ColumnsWithTypeAndName & x, const ColumnsWithTypeAndName & y, size_t tuple_size, size_t input_rows_count) const
Expand Down
11 changes: 11 additions & 0 deletions src/Functions/notEquals.h
@@ -0,0 +1,11 @@
#pragma once
#include <memory>

namespace DB
{

class IFunctionOverloadResolver;
using FunctionOverloadResolverPtr = std::shared_ptr<IFunctionOverloadResolver>;

FunctionOverloadResolverPtr createInternalFunctionNotEqualOverloadResolver(bool decimal_check_overflow);
}
Expand Up @@ -306,7 +306,8 @@ TEST(TransformQueryForExternalDatabase, Aliases)

check(state, 1, {"field"},
"SELECT field AS value, field AS display FROM table WHERE field NOT IN ('') AND display LIKE '%test%'",
R"(SELECT "field" FROM "test"."table" WHERE ("field" NOT IN ('')) AND ("field" LIKE '%test%'))");
R"(SELECT "field" FROM "test"."table" WHERE ("field" NOT IN ('')) AND ("field" LIKE '%test%'))",
R"(SELECT "field" FROM "test"."table" WHERE ("field" != '') AND ("field" LIKE '%test%'))");
}

TEST(TransformQueryForExternalDatabase, ForeignColumnInWhere)
Expand Down Expand Up @@ -408,5 +409,6 @@ TEST(TransformQueryForExternalDatabase, Analyzer)

check(state, 1, {"column", "apply_id", "apply_type", "apply_status", "create_time", "field", "value", "a", "b", "foo"},
"SELECT * FROM table WHERE (column) IN (1)",
R"(SELECT "column", "apply_id", "apply_type", "apply_status", "create_time", "field", "value", "a", "b", "foo" FROM "test"."table" WHERE "column" IN (1))");
R"(SELECT "column", "apply_id", "apply_type", "apply_status", "create_time", "field", "value", "a", "b", "foo" FROM "test"."table" WHERE "column" IN (1))",
R"(SELECT "column", "apply_id", "apply_type", "apply_status", "create_time", "field", "value", "a", "b", "foo" FROM "test"."table" WHERE "column" = 1)");
}
28 changes: 28 additions & 0 deletions tests/performance/function_in.xml
@@ -0,0 +1,28 @@
<test>
<settings>
<max_insert_threads>8</max_insert_threads>
<max_threads>1</max_threads>
</settings>

<create_query>
CREATE TABLE t_nullable
(
key_string1 Nullable(String),
key_string2 Nullable(String),
key_string3 Nullable(String),
key_int64_1 Nullable(Int64),
key_int64_2 Nullable(Int64),
key_int64_3 Nullable(Int64),
key_int64_4 Nullable(Int64),
key_int64_5 Nullable(Int64),
m1 Int64,
m2 Int64
)
ENGINE = Memory
</create_query>
<fill_query>insert into t_nullable select ['aaaaaa','bbaaaa','ccaaaa','ddaaaa'][number % 101 + 1], ['aa','bb','cc','dd'][number % 100 + 1], ['aa','bb','cc','dd'][number % 102 + 1], number%10+1, number%10+2, number%10+3, number%10+4,number%10+5, number%6000+1, number%5000+2 from numbers_mt(30000000)</fill_query>
<query>select * from t_nullable where key_string1 in ('aaaaaa') format Null SETTINGS allow_experimental_analyzer=1</query>
<query>select * from t_nullable where key_string2 in ('3') format Null SETTINGS allow_experimental_analyzer=1</query>
<drop_query>drop table if exists t_nullable</drop_query>

</test>

0 comments on commit 20a45b4

Please sign in to comment.