diff --git a/src/Analyzer/Passes/ConvertInToEqualPass.cpp b/src/Analyzer/Passes/ConvertInToEqualPass.cpp
new file mode 100644
index 000000000000..66a37fea5bdb
--- /dev/null
+++ b/src/Analyzer/Passes/ConvertInToEqualPass.cpp
@@ -0,0 +1,95 @@
+#include <Analyzer/ColumnNode.h>
+#include <Analyzer/ConstantNode.h>
+#include <Analyzer/FunctionNode.h>
+#include <Analyzer/InDepthQueryTreeVisitor.h>
+#include <Analyzer/Passes/ConvertInToEqualPass.h>
+#include <Functions/equals.h>
+#include <Functions/notEquals.h>
+
+namespace DB
+{
+
+namespace
+{
+
+/** Rewrites `in(column, constant)` into `equals(column, constant)` and
+  * `notIn(column, constant)` into `notEquals(column, constant)`.
+  * Any node that does not have exactly this shape is left untouched.
+  */
+class ConvertInToEqualPassVisitor : public InDepthQueryTreeVisitorWithContext<ConvertInToEqualPassVisitor>
+{
+public:
+    using Base = InDepthQueryTreeVisitorWithContext<ConvertInToEqualPassVisitor>;
+    using Base::Base;
+
+    /// Replaces `node` in place when it is an eligible `in` / `notIn` call.
+    void enterImpl(QueryTreeNodePtr & node)
+    {
+        static const std::unordered_map<String, String> MAPPING = {
+            {"in", "equals"},
+            {"notIn", "notEquals"}
+        };
+
+        auto * func_node = node->as<FunctionNode>();
+        if (!func_node)
+            return;
+
+        /// One lookup serves both the eligibility check and the name translation.
+        const auto mapping_it = MAPPING.find(func_node->getFunctionName());
+        if (mapping_it == MAPPING.end())
+            return;
+        const auto & result_func_name = mapping_it->second;
+
+        /// Bind by reference: copying the list would copy every shared pointer in it.
+        const auto & args = func_node->getArguments().getNodes();
+        if (args.size() != 2)
+            return;
+
+        auto * column_node = args[0]->as<ColumnNode>();
+        auto * constant_node = args[1]->as<ConstantNode>();
+        if (!column_node || !constant_node)
+            return;
+
+        const auto & value = constant_node->getValue();
+
+        /// A tuple or array on the right side means IN over several values: not supported.
+        const auto value_type = value.getType();
+        if (value_type == Field::Types::Which::Tuple || value_type == Field::Types::Which::Array)
+            return;
+
+        /// `x IN NULL` is not equivalent to `x = NULL`.
+        if (value.isNull())
+            return;
+
+        auto equal = std::make_shared<FunctionNode>(result_func_name);
+        QueryTreeNodes arguments{column_node->clone(), constant_node->clone()};
+        equal->getArguments().getNodes() = std::move(arguments);
+
+        const bool decimal_check_overflow = getContext()->getSettingsRef().decimal_check_overflow;
+        FunctionOverloadResolverPtr resolver = result_func_name == "equals"
+            ? createInternalFunctionEqualOverloadResolver(decimal_check_overflow)
+            : createInternalFunctionNotEqualOverloadResolver(decimal_check_overflow);
+
+        try
+        {
+            equal->resolveAsFunction(resolver);
+        }
+        catch (...)
+        {
+            /// Best effort: when the comparison cannot be resolved, keep the original node.
+            /// NOTE(review): this also hides unrelated exceptions thrown during resolution.
+            return;
+        }
+
+        node = std::move(equal);
+    }
+};
+
+}
+
+void ConvertInToEqualPass::run(QueryTreeNodePtr & query_tree_node, ContextPtr context)
+{
+    ConvertInToEqualPassVisitor visitor(std::move(context));
+    visitor.visit(query_tree_node);
+}
+}
diff --git a/src/Analyzer/Passes/ConvertInToEqualPass.h b/src/Analyzer/Passes/ConvertInToEqualPass.h
new file mode 100644
index 000000000000..bd4f8607c88a
--- /dev/null
+++ b/src/Analyzer/Passes/ConvertInToEqualPass.h
@@ -0,0 +1,27 @@
+#pragma once
+
+#include <Analyzer/IQueryTreePass.h>
+
+namespace DB
+{
+/** Optimize `in` to `equals` if possible.
+  * 1. Convert `IN` over a single value to `equals`.
+  *    Example: SELECT * FROM test WHERE x IN (1);
+  *    Result:  SELECT * FROM test WHERE x = 1;
+  *
+  * 2. Convert `NOT IN` over a single value to `notEquals`.
+  *    Example: SELECT * FROM test WHERE x NOT IN (1);
+  *    Result:  SELECT * FROM test WHERE x != 1;
+  *
+  * Not converted when the value is NULL, a tuple or an array, or when the left side is not a column.
+  */
+class ConvertInToEqualPass final : public IQueryTreePass
+{
+public:
+    String getName() override { return "ConvertInToEqualPass"; }
+
+    String getDescription() override { return "Convert in to equal"; }
+
+    void run(QueryTreeNodePtr & query_tree_node, ContextPtr context) override;
+};
+}
diff --git a/src/Analyzer/QueryTreePassManager.cpp b/src/Analyzer/QueryTreePassManager.cpp
index 9c07884a4642..14eb179680c6 100644
--- a/src/Analyzer/QueryTreePassManager.cpp
+++ b/src/Analyzer/QueryTreePassManager.cpp
@@ -28,6 +28,7 @@
 #include
 #include
 #include
+#include <Analyzer/Passes/ConvertInToEqualPass.h>
 #include
 #include
 #include
@@ -263,6 +264,7 @@ void addQueryTreePasses(QueryTreePassManager & manager, bool only_analyze)
     manager.addPass(std::make_unique());
     manager.addPass(std::make_unique());
     manager.addPass(std::make_unique());
+    manager.addPass(std::make_unique<ConvertInToEqualPass>());
 
     /// should before AggregateFunctionsArithmericOperationsPass
     manager.addPass(std::make_unique());
diff --git a/src/Functions/CMakeLists.txt b/src/Functions/CMakeLists.txt
index 
733ae25274e2..d5eb12f3deef 100644 --- a/src/Functions/CMakeLists.txt +++ b/src/Functions/CMakeLists.txt @@ -14,6 +14,8 @@ extract_into_parent_list(clickhouse_functions_sources dbms_sources multiMatchAny.cpp checkHyperscanRegexp.cpp array/has.cpp + equals.cpp + notEquals.cpp CastOverloadResolver.cpp ) extract_into_parent_list(clickhouse_functions_headers dbms_headers diff --git a/src/Functions/equals.cpp b/src/Functions/equals.cpp index 5c59daf05372..512abaa6fc70 100644 --- a/src/Functions/equals.cpp +++ b/src/Functions/equals.cpp @@ -13,6 +13,11 @@ REGISTER_FUNCTION(Equals) factory.registerFunction(); } +FunctionOverloadResolverPtr createInternalFunctionEqualOverloadResolver(bool decimal_check_overflow) +{ + return std::make_unique(std::make_shared(decimal_check_overflow)); +} + template <> ColumnPtr FunctionComparison::executeTupleImpl( const ColumnsWithTypeAndName & x, const ColumnsWithTypeAndName & y, size_t tuple_size, size_t input_rows_count) const diff --git a/src/Functions/equals.h b/src/Functions/equals.h new file mode 100644 index 000000000000..855cba4db3e1 --- /dev/null +++ b/src/Functions/equals.h @@ -0,0 +1,11 @@ +#pragma once +#include + +namespace DB +{ + +class IFunctionOverloadResolver; +using FunctionOverloadResolverPtr = std::shared_ptr; + +FunctionOverloadResolverPtr createInternalFunctionEqualOverloadResolver(bool decimal_check_overflow); +} diff --git a/src/Functions/notEquals.cpp b/src/Functions/notEquals.cpp index 3a63db467117..744a0997d95e 100644 --- a/src/Functions/notEquals.cpp +++ b/src/Functions/notEquals.cpp @@ -12,6 +12,11 @@ REGISTER_FUNCTION(NotEquals) factory.registerFunction(); } +FunctionOverloadResolverPtr createInternalFunctionNotEqualOverloadResolver(bool decimal_check_overflow) +{ + return std::make_unique(std::make_shared(decimal_check_overflow)); +} + template <> ColumnPtr FunctionComparison::executeTupleImpl( const ColumnsWithTypeAndName & x, const ColumnsWithTypeAndName & y, size_t tuple_size, size_t input_rows_count) 
const diff --git a/src/Functions/notEquals.h b/src/Functions/notEquals.h new file mode 100644 index 000000000000..961889d68d7c --- /dev/null +++ b/src/Functions/notEquals.h @@ -0,0 +1,11 @@ +#pragma once +#include + +namespace DB +{ + +class IFunctionOverloadResolver; +using FunctionOverloadResolverPtr = std::shared_ptr; + +FunctionOverloadResolverPtr createInternalFunctionNotEqualOverloadResolver(bool decimal_check_overflow); +} diff --git a/src/Storages/tests/gtest_transform_query_for_external_database.cpp b/src/Storages/tests/gtest_transform_query_for_external_database.cpp index 7e2d393c3d11..6490498d717a 100644 --- a/src/Storages/tests/gtest_transform_query_for_external_database.cpp +++ b/src/Storages/tests/gtest_transform_query_for_external_database.cpp @@ -306,7 +306,8 @@ TEST(TransformQueryForExternalDatabase, Aliases) check(state, 1, {"field"}, "SELECT field AS value, field AS display FROM table WHERE field NOT IN ('') AND display LIKE '%test%'", - R"(SELECT "field" FROM "test"."table" WHERE ("field" NOT IN ('')) AND ("field" LIKE '%test%'))"); + R"(SELECT "field" FROM "test"."table" WHERE ("field" NOT IN ('')) AND ("field" LIKE '%test%'))", + R"(SELECT "field" FROM "test"."table" WHERE ("field" != '') AND ("field" LIKE '%test%'))"); } TEST(TransformQueryForExternalDatabase, ForeignColumnInWhere) @@ -408,5 +409,6 @@ TEST(TransformQueryForExternalDatabase, Analyzer) check(state, 1, {"column", "apply_id", "apply_type", "apply_status", "create_time", "field", "value", "a", "b", "foo"}, "SELECT * FROM table WHERE (column) IN (1)", - R"(SELECT "column", "apply_id", "apply_type", "apply_status", "create_time", "field", "value", "a", "b", "foo" FROM "test"."table" WHERE "column" IN (1))"); + R"(SELECT "column", "apply_id", "apply_type", "apply_status", "create_time", "field", "value", "a", "b", "foo" FROM "test"."table" WHERE "column" IN (1))", + R"(SELECT "column", "apply_id", "apply_type", "apply_status", "create_time", "field", "value", "a", "b", "foo" FROM 
"test"."table" WHERE "column" = 1)"); } diff --git a/tests/performance/function_in.xml b/tests/performance/function_in.xml new file mode 100644 index 000000000000..af4f8737ba78 --- /dev/null +++ b/tests/performance/function_in.xml @@ -0,0 +1,28 @@ + + + 8 + 1 + + + + CREATE TABLE t_nullable + ( + key_string1 Nullable(String), + key_string2 Nullable(String), + key_string3 Nullable(String), + key_int64_1 Nullable(Int64), + key_int64_2 Nullable(Int64), + key_int64_3 Nullable(Int64), + key_int64_4 Nullable(Int64), + key_int64_5 Nullable(Int64), + m1 Int64, + m2 Int64 + ) + ENGINE = Memory + + insert into t_nullable select ['aaaaaa','bbaaaa','ccaaaa','ddaaaa'][number % 101 + 1], ['aa','bb','cc','dd'][number % 100 + 1], ['aa','bb','cc','dd'][number % 102 + 1], number%10+1, number%10+2, number%10+3, number%10+4,number%10+5, number%6000+1, number%5000+2 from numbers_mt(30000000) + select * from t_nullable where key_string1 in ('aaaaaa') format Null SETTINGS allow_experimental_analyzer=1 + select * from t_nullable where key_string2 in ('3') format Null SETTINGS allow_experimental_analyzer=1 + drop table if exists t_nullable + + diff --git a/tests/queries/0_stateless/03013_optimize_in_to_equal.reference b/tests/queries/0_stateless/03013_optimize_in_to_equal.reference new file mode 100644 index 000000000000..93ac91bd9570 --- /dev/null +++ b/tests/queries/0_stateless/03013_optimize_in_to_equal.reference @@ -0,0 +1,188 @@ +a 1 +------------------- +0 +0 +0 +------------------- +QUERY id: 0 + PROJECTION COLUMNS + x String + y Int32 + PROJECTION + LIST id: 1, nodes: 2 + COLUMN id: 2, column_name: x, result_type: String, source_id: 3 + COLUMN id: 4, column_name: y, result_type: Int32, source_id: 3 + JOIN TREE + TABLE id: 3, alias: __table1, table_name: default.test + WHERE + FUNCTION id: 5, function_name: equals, function_type: ordinary, result_type: UInt8 + ARGUMENTS + LIST id: 6, nodes: 2 + COLUMN id: 7, column_name: x, result_type: String, source_id: 3 + CONSTANT id: 8, 
constant_value: \'a\', constant_value_type: String +------------------- +QUERY id: 0 + PROJECTION COLUMNS + x String + y Int32 + PROJECTION + LIST id: 1, nodes: 2 + COLUMN id: 2, column_name: x, result_type: String, source_id: 3 + COLUMN id: 4, column_name: y, result_type: Int32, source_id: 3 + JOIN TREE + TABLE id: 3, alias: __table1, table_name: default.test + WHERE + FUNCTION id: 5, function_name: equals, function_type: ordinary, result_type: UInt8 + ARGUMENTS + LIST id: 6, nodes: 2 + COLUMN id: 7, column_name: x, result_type: String, source_id: 3 + CONSTANT id: 8, constant_value: \'A\', constant_value_type: String + EXPRESSION + FUNCTION id: 9, function_name: upper, function_type: ordinary, result_type: String + ARGUMENTS + LIST id: 10, nodes: 1 + CONSTANT id: 11, constant_value: \'a\', constant_value_type: String +------------------- +QUERY id: 0 + PROJECTION COLUMNS + x String + y Int32 + PROJECTION + LIST id: 1, nodes: 2 + COLUMN id: 2, column_name: x, result_type: String, source_id: 3 + COLUMN id: 4, column_name: y, result_type: Int32, source_id: 3 + JOIN TREE + TABLE id: 3, alias: __table1, table_name: default.test + WHERE + FUNCTION id: 5, function_name: in, function_type: ordinary, result_type: UInt8 + ARGUMENTS + LIST id: 6, nodes: 2 + COLUMN id: 7, column_name: x, result_type: String, source_id: 3 + CONSTANT id: 8, constant_value: Tuple_(\'a\', \'b\'), constant_value_type: Tuple(String, String) +------------------- +QUERY id: 0 + PROJECTION COLUMNS + x String + y Int32 + PROJECTION + LIST id: 1, nodes: 2 + COLUMN id: 2, column_name: x, result_type: String, source_id: 3 + COLUMN id: 4, column_name: y, result_type: Int32, source_id: 3 + JOIN TREE + TABLE id: 3, alias: __table1, table_name: default.test + WHERE + FUNCTION id: 5, function_name: in, function_type: ordinary, result_type: UInt8 + ARGUMENTS + LIST id: 6, nodes: 2 + COLUMN id: 7, column_name: x, result_type: String, source_id: 3 + CONSTANT id: 8, constant_value: Array_[\'a\', \'b\'], 
constant_value_type: Array(String) +------------------- +b 2 +c 3 +------------------- +QUERY id: 0 + PROJECTION COLUMNS + x String + y Int32 + PROJECTION + LIST id: 1, nodes: 2 + COLUMN id: 2, column_name: x, result_type: String, source_id: 3 + COLUMN id: 4, column_name: y, result_type: Int32, source_id: 3 + JOIN TREE + TABLE id: 3, alias: __table1, table_name: default.test + WHERE + FUNCTION id: 5, function_name: notEquals, function_type: ordinary, result_type: UInt8 + ARGUMENTS + LIST id: 6, nodes: 2 + COLUMN id: 7, column_name: x, result_type: String, source_id: 3 + CONSTANT id: 8, constant_value: \'a\', constant_value_type: String +------------------- +QUERY id: 0 + PROJECTION COLUMNS + x String + y Int32 + PROJECTION + LIST id: 1, nodes: 2 + COLUMN id: 2, column_name: x, result_type: String, source_id: 3 + COLUMN id: 4, column_name: y, result_type: Int32, source_id: 3 + JOIN TREE + TABLE id: 3, alias: __table1, table_name: default.test + WHERE + FUNCTION id: 5, function_name: notEquals, function_type: ordinary, result_type: UInt8 + ARGUMENTS + LIST id: 6, nodes: 2 + COLUMN id: 7, column_name: x, result_type: String, source_id: 3 + CONSTANT id: 8, constant_value: \'A\', constant_value_type: String + EXPRESSION + FUNCTION id: 9, function_name: upper, function_type: ordinary, result_type: String + ARGUMENTS + LIST id: 10, nodes: 1 + CONSTANT id: 11, constant_value: \'a\', constant_value_type: String +------------------- +QUERY id: 0 + PROJECTION COLUMNS + x String + y Int32 + PROJECTION + LIST id: 1, nodes: 2 + COLUMN id: 2, column_name: x, result_type: String, source_id: 3 + COLUMN id: 4, column_name: y, result_type: Int32, source_id: 3 + JOIN TREE + TABLE id: 3, alias: __table1, table_name: default.test + WHERE + FUNCTION id: 5, function_name: notIn, function_type: ordinary, result_type: UInt8 + ARGUMENTS + LIST id: 6, nodes: 2 + COLUMN id: 7, column_name: x, result_type: String, source_id: 3 + CONSTANT id: 8, constant_value: Tuple_(\'a\', \'b\'), 
constant_value_type: Tuple(String, String) +------------------- +QUERY id: 0 + PROJECTION COLUMNS + x String + y Int32 + PROJECTION + LIST id: 1, nodes: 2 + COLUMN id: 2, column_name: x, result_type: String, source_id: 3 + COLUMN id: 4, column_name: y, result_type: Int32, source_id: 3 + JOIN TREE + TABLE id: 3, alias: __table1, table_name: default.test + WHERE + FUNCTION id: 5, function_name: notIn, function_type: ordinary, result_type: UInt8 + ARGUMENTS + LIST id: 6, nodes: 2 + COLUMN id: 7, column_name: x, result_type: String, source_id: 3 + CONSTANT id: 8, constant_value: Array_[\'a\', \'b\'], constant_value_type: Array(String) +------------------- +QUERY id: 0 + PROJECTION COLUMNS + x String + y Int32 + PROJECTION + LIST id: 1, nodes: 2 + COLUMN id: 2, column_name: x, result_type: String, source_id: 3 + COLUMN id: 4, column_name: y, result_type: Int32, source_id: 3 + JOIN TREE + TABLE id: 3, alias: __table1, table_name: default.test + WHERE + FUNCTION id: 5, function_name: notIn, function_type: ordinary, result_type: UInt8 + ARGUMENTS + LIST id: 6, nodes: 2 + COLUMN id: 7, column_name: x, result_type: String, source_id: 3 + CONSTANT id: 8, constant_value: NULL, constant_value_type: Nullable(Nothing) +------------------- +QUERY id: 0 + PROJECTION COLUMNS + x String + y Int32 + PROJECTION + LIST id: 1, nodes: 2 + COLUMN id: 2, column_name: x, result_type: String, source_id: 3 + COLUMN id: 4, column_name: y, result_type: Int32, source_id: 3 + JOIN TREE + TABLE id: 3, alias: __table1, table_name: default.test + WHERE + FUNCTION id: 5, function_name: in, function_type: ordinary, result_type: UInt8 + ARGUMENTS + LIST id: 6, nodes: 2 + COLUMN id: 7, column_name: x, result_type: String, source_id: 3 + CONSTANT id: 8, constant_value: NULL, constant_value_type: Nullable(Nothing) diff --git a/tests/queries/0_stateless/03013_optimize_in_to_equal.sql b/tests/queries/0_stateless/03013_optimize_in_to_equal.sql new file mode 100644 index 000000000000..ba6eb5d4f5f9 --- 
/dev/null +++ b/tests/queries/0_stateless/03013_optimize_in_to_equal.sql @@ -0,0 +1,29 @@ +DROP TABLE IF EXISTS test; +CREATE TABLE test (x String, y Int32) ENGINE = MergeTree() ORDER BY x; +SET allow_experimental_analyzer = 1; +INSERT INTO test VALUES ('a', 1), ('b', 2), ('c', 3); +select * from test where x in ('a'); +select '-------------------'; +select x in Null from test; +select '-------------------'; +explain query tree select * from test where x in ('a'); +select '-------------------'; +explain query tree select * from test where x in (upper('a')); +select '-------------------'; +explain query tree select * from test where x in ('a','b'); +select '-------------------'; +explain query tree select * from test where x in ['a','b']; +select '-------------------'; +select * from test where x not in ('a'); +select '-------------------'; +explain query tree select * from test where x not in ('a'); +select '-------------------'; +explain query tree select * from test where x not in (upper('a')); +select '-------------------'; +explain query tree select * from test where x not in ('a','b'); +select '-------------------'; +explain query tree select * from test where x not in ['a','b']; +select '-------------------'; +explain query tree select * from test where x not in (NULL); +select '-------------------'; +explain query tree select * from test where x in (NULL);