Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve Normalization of ASTSelectWithUnionQuery #21246

Merged
merged 5 commits into from Mar 2, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
117 changes: 0 additions & 117 deletions src/Interpreters/InterpreterSelectWithUnionQuery.cpp
Expand Up @@ -24,110 +24,8 @@ namespace ErrorCodes
{
extern const int LOGICAL_ERROR;
extern const int UNION_ALL_RESULT_STRUCTURES_MISMATCH;
extern const int EXPECTED_ALL_OR_DISTINCT;
}

/// Visitor payload that normalizes ASTSelectWithUnionQuery nodes in place.
/// It is plugged into OneTypeMatcher / InDepthNodeVisitor via the
/// CustomizeASTSelectWithUnionQueryNormalizeVisitor alias declared after this struct.
///
/// After visit() the node is marked as normalized, its union_mode is ALL and its list of selects
/// has been rebuilt; everything up to and including the right-most UNION DISTINCT operand is
/// collapsed into a single nested DISTINCT union node.
struct CustomizeASTSelectWithUnionQueryNormalize
{
/// Node type this payload is invoked for.
using TypeToVisit = ASTSelectWithUnionQuery;

/// Mode substituted for every UNION whose mode is Unspecified (see visit()).
/// Held by reference, so the referenced value must outlive this object.
const UnionMode & union_default_mode;

/// Recursively collects the non-union leaves reachable from `ast_select` into `selects`,
/// preserving their order. Leaves are moved out of their current owner, so `ast_select`
/// (or the nested child slot it came from) is left in a moved-from state.
static void getSelectsFromUnionListNode(ASTPtr & ast_select, ASTs & selects)
{
if (auto * inner_union = ast_select->as<ASTSelectWithUnionQuery>())
{
for (auto & child : inner_union->list_of_selects->children)
getSelectsFromUnionListNode(child, selects);

return;
}

selects.push_back(std::move(ast_select));
}

/// Normalizes `ast` in place. The second argument (the owning pointer) is unused.
/// Throws EXPECTED_ALL_OR_DISTINCT if a mode is Unspecified and union_default_mode is neither
/// ALL nor DISTINCT.
void visit(ASTSelectWithUnionQuery & ast, ASTPtr &) const
{
auto & union_modes = ast.list_of_modes;
/// New list of selects; filled right-to-left and reversed at the end.
ASTs selects;
auto & select_list = ast.list_of_selects->children;

/// union_modes[i] is the mode between select_list[i] and select_list[i + 1].
/// `i` is declared outside the loop because its final value is inspected afterwards.
int i;
for (i = union_modes.size() - 1; i >= 0; --i)
{
/// Rewrite UNION Mode
if (union_modes[i] == ASTSelectWithUnionQuery::Mode::Unspecified)
{
if (union_default_mode == UnionMode::ALL)
union_modes[i] = ASTSelectWithUnionQuery::Mode::ALL;
else if (union_default_mode == UnionMode::DISTINCT)
union_modes[i] = ASTSelectWithUnionQuery::Mode::DISTINCT;
else
throw Exception(
"Expected ALL or DISTINCT in SelectWithUnion query, because setting (union_default_mode) is empty",
DB::ErrorCodes::EXPECTED_ALL_OR_DISTINCT);
}

if (union_modes[i] == ASTSelectWithUnionQuery::Mode::ALL)
{
/// NOTE(review): the inner union's own union_mode is not checked here; the code assumes
/// any nested union on the right-hand side is a UNION ALL list - confirm.
if (auto * inner_union = select_list[i + 1]->as<ASTSelectWithUnionQuery>())
{
/// Inner_union is an UNION ALL list, just lift up
/// (children are appended in reverse order because `selects` is reversed at the end).
for (auto child = inner_union->list_of_selects->children.rbegin();
child != inner_union->list_of_selects->children.rend();
++child)
selects.push_back(std::move(*child));
}
else
selects.push_back(std::move(select_list[i + 1]));
}
/// flatten all left nodes and current node to a UNION DISTINCT list
else if (union_modes[i] == ASTSelectWithUnionQuery::Mode::DISTINCT)
{
auto distinct_list = std::make_shared<ASTSelectWithUnionQuery>();
distinct_list->list_of_selects = std::make_shared<ASTExpressionList>();
distinct_list->children.push_back(distinct_list->list_of_selects);

/// Elements 0 .. i + 1 have not been moved from yet: only indices greater than i + 1
/// were consumed by earlier (right-hand) iterations.
for (int j = 0; j <= i + 1; ++j)
{
getSelectsFromUnionListNode(select_list[j], distinct_list->list_of_selects->children);
}

distinct_list->union_mode = ASTSelectWithUnionQuery::Mode::DISTINCT;
distinct_list->is_normalized = true;
selects.push_back(std::move(distinct_list));
/// Everything to the left is now inside distinct_list, so stop scanning.
break;
}
}

/// No UNION DISTINCT or only one child in select_list
/// (the loop ran to completion without `break`, so select_list[0] is still unprocessed).
if (i == -1)
{
if (auto * inner_union = select_list[0]->as<ASTSelectWithUnionQuery>())
{
/// Inner_union is an UNION ALL list, just lift it up
for (auto child = inner_union->list_of_selects->children.rbegin(); child != inner_union->list_of_selects->children.rend();
++child)
selects.push_back(std::move(*child));
}
else
selects.push_back(std::move(select_list[0]));
}

/// `selects` was collected right-to-left; restore the left-to-right order.
std::reverse(selects.begin(), selects.end());

ast.is_normalized = true;
ast.union_mode = ASTSelectWithUnionQuery::Mode::ALL;

ast.list_of_selects->children = std::move(selects);
}
};

/// Children must be normalized first, so the AST is visited bottom-up.
using CustomizeASTSelectWithUnionQueryNormalizeVisitor
= InDepthNodeVisitor<OneTypeMatcher<CustomizeASTSelectWithUnionQueryNormalize>, false>;

InterpreterSelectWithUnionQuery::InterpreterSelectWithUnionQuery(
const ASTPtr & query_ptr_, const Context & context_, const SelectQueryOptions & options_, const Names & required_result_column_names)
: IInterpreterUnionOrSelectQuery(query_ptr_, context_, options_)
Expand All @@ -138,21 +36,6 @@ InterpreterSelectWithUnionQuery::InterpreterSelectWithUnionQuery(
if (options.subquery_depth == 0 && (settings.limit > 0 || settings.offset > 0))
settings_limit_offset_needed = true;

/// Normalize AST Tree
if (!ast->is_normalized)
{
CustomizeASTSelectWithUnionQueryNormalizeVisitor::Data union_default_mode{settings.union_default_mode};
CustomizeASTSelectWithUnionQueryNormalizeVisitor(union_default_mode).visit(query_ptr);

/// After normalization, if it only has one ASTSelectWithUnionQuery child,
/// we can lift it up, this can reduce one unnecessary recursion later.
if (ast->list_of_selects->children.size() == 1 && ast->list_of_selects->children.at(0)->as<ASTSelectWithUnionQuery>())
{
query_ptr = std::move(ast->list_of_selects->children.at(0));
ast = query_ptr->as<ASTSelectWithUnionQuery>();
}
}

size_t num_children = ast->list_of_selects->children.size();
if (!num_children)
throw Exception("Logical error: no children in ASTSelectWithUnionQuery", ErrorCodes::LOGICAL_ERROR);
Expand Down
116 changes: 116 additions & 0 deletions src/Interpreters/NormalizeSelectWithUnionQueryVisitor.cpp
@@ -0,0 +1,116 @@
#include <Interpreters/NormalizeSelectWithUnionQueryVisitor.h>
#include <Parsers/ASTExpressionList.h>
#include <Common/typeid_cast.h>

namespace DB
{

namespace ErrorCodes
{
extern const int EXPECTED_ALL_OR_DISTINCT;
}

/// Appends to `selects` every non-union node reachable from `ast_select`, in left-to-right order.
/// Nested ASTSelectWithUnionQuery nodes are not appended themselves; they are expanded recursively.
/// Nodes are copied (shared), so `ast_select` and its children are left untouched.
void NormalizeSelectWithUnionQueryMatcher::getSelectsFromUnionListNode(ASTPtr & ast_select, ASTs & selects)
{
    auto * nested_union = ast_select->as<ASTSelectWithUnionQuery>();

    /// Not a union: this node is a leaf of the flattened list.
    if (!nested_union)
    {
        selects.push_back(ast_select);
        return;
    }

    /// A union contributes the leaves of each of its operands.
    for (auto & operand : nested_union->list_of_selects->children)
        getSelectsFromUnionListNode(operand, selects);
}

/// Generic entry point taking any AST node. Nodes that are not ASTSelectWithUnionQuery are ignored;
/// union nodes are forwarded to the typed overload.
void NormalizeSelectWithUnionQueryMatcher::visit(ASTPtr & ast, Data & data)
{
    auto * select_union = ast->as<ASTSelectWithUnionQuery>();
    if (select_union == nullptr)
        return;

    visit(*select_union, data);
}

/// Normalizes one ASTSelectWithUnionQuery node in place.
///
/// The operands are scanned from right to left (union_modes[i] is the mode between
/// select_list[i] and select_list[i + 1]):
///  - an Unspecified mode is replaced by data.union_default_mode;
///  - for UNION ALL the right operand is kept, and a nested UNION ALL list is lifted up;
///  - at the first (right-most) UNION DISTINCT, all operands to its left together with its right
///    operand are flattened into one new DISTINCT union node and the scan stops.
///
/// Afterwards the node is either replaced by its only union child (if exactly one child remains
/// and it is a union), or becomes a normalized UNION ALL list of the collected operands.
///
/// An empty list of selects is left untouched.
///
/// Throws EXPECTED_ALL_OR_DISTINCT if a mode is Unspecified and data.union_default_mode is
/// neither ALL nor DISTINCT.
void NormalizeSelectWithUnionQueryMatcher::visit(ASTSelectWithUnionQuery & ast, Data & data)
{
    auto & union_modes = ast.list_of_modes;
    auto & select_list = ast.list_of_selects->children;

    /// select_list[0] is accessed unconditionally below when no DISTINCT is found; on an empty
    /// list that would be undefined behaviour, so there is nothing to normalize here.
    if (select_list.empty())
        return;

    /// New list of operands; collected right-to-left and reversed before being stored.
    ASTs selects;

    /// Appends `node` to `selects`. If `node` is itself a UNION ALL list, its children are
    /// appended instead, in reverse order (to match the right-to-left collection order).
    auto push_or_lift = [&selects](const ASTPtr & node)
    {
        const auto * inner_union = node->as<ASTSelectWithUnionQuery>();
        if (inner_union && inner_union->union_mode == ASTSelectWithUnionQuery::Mode::ALL)
        {
            const auto & inner_children = inner_union->list_of_selects->children;
            selects.insert(selects.end(), inner_children.rbegin(), inner_children.rend());
        }
        else
            selects.push_back(node);
    };

    /// `i` outlives the loop: i == -1 afterwards means no UNION DISTINCT was encountered.
    /// The explicit cast keeps the "size - 1" computation signed, so an empty mode list yields -1.
    int i;
    for (i = static_cast<int>(union_modes.size()) - 1; i >= 0; --i)
    {
        /// Rewrite UNION Mode
        if (union_modes[i] == ASTSelectWithUnionQuery::Mode::Unspecified)
        {
            if (data.union_default_mode == UnionMode::ALL)
                union_modes[i] = ASTSelectWithUnionQuery::Mode::ALL;
            else if (data.union_default_mode == UnionMode::DISTINCT)
                union_modes[i] = ASTSelectWithUnionQuery::Mode::DISTINCT;
            else
                throw Exception(
                    "Expected ALL or DISTINCT in SelectWithUnion query, because setting (union_default_mode) is empty",
                    DB::ErrorCodes::EXPECTED_ALL_OR_DISTINCT);
        }

        if (union_modes[i] == ASTSelectWithUnionQuery::Mode::ALL)
        {
            push_or_lift(select_list[i + 1]);
        }
        /// Flatten all left nodes and the current node into one UNION DISTINCT list
        else if (union_modes[i] == ASTSelectWithUnionQuery::Mode::DISTINCT)
        {
            auto distinct_list = std::make_shared<ASTSelectWithUnionQuery>();
            distinct_list->list_of_selects = std::make_shared<ASTExpressionList>();
            distinct_list->children.push_back(distinct_list->list_of_selects);

            for (int j = 0; j <= i + 1; ++j)
                getSelectsFromUnionListNode(select_list[j], distinct_list->list_of_selects->children);

            distinct_list->union_mode = ASTSelectWithUnionQuery::Mode::DISTINCT;
            distinct_list->is_normalized = true;
            selects.push_back(std::move(distinct_list));

            /// Everything to the left is now inside distinct_list.
            break;
        }
    }

    /// No UNION DISTINCT or only one child in select_list: the left-most operand is still pending.
    if (i == -1)
        push_or_lift(select_list[0]);

    /// Just one union type child, lift it up: this node becomes a copy of that child.
    if (selects.size() == 1)
    {
        if (const auto * single_union = selects[0]->as<ASTSelectWithUnionQuery>())
        {
            ast = *single_union;
            return;
        }
    }

    /// Restore the left-to-right order.
    std::reverse(selects.begin(), selects.end());

    ast.is_normalized = true;
    ast.union_mode = ASTSelectWithUnionQuery::Mode::ALL;

    ast.list_of_selects->children = std::move(selects);
}
}
34 changes: 34 additions & 0 deletions src/Interpreters/NormalizeSelectWithUnionQueryVisitor.h
@@ -0,0 +1,34 @@
#pragma once

#include <unordered_set>

#include <Parsers/IAST.h>
#include <Interpreters/InDepthNodeVisitor.h>

#include <Core/Settings.h>
#include <Parsers/ASTSelectWithUnionQuery.h>

namespace DB
{

class ASTFunction;

/// Matcher that normalizes ASTSelectWithUnionQuery nodes in place.
/// Intended to be driven by InDepthNodeVisitor (see NormalizeSelectWithUnionQueryVisitor below).
class NormalizeSelectWithUnionQueryMatcher
{
public:
/// State shared by all visit() calls of one traversal.
struct Data
{
/// Mode substituted for UNION operators whose mode is Unspecified.
/// Held by reference, so the referenced value must outlive this object.
const UnionMode & union_default_mode;
};

/// Appends to `selects` every non-union node reachable from `ast_select`,
/// expanding nested ASTSelectWithUnionQuery nodes recursively.
static void getSelectsFromUnionListNode(ASTPtr & ast_select, ASTs & selects);

/// Generic entry point: forwards to the typed overload if `ast` is an ASTSelectWithUnionQuery,
/// otherwise does nothing.
static void visit(ASTPtr & ast, Data &);
/// Normalizes a single union node in place.
static void visit(ASTSelectWithUnionQuery &, Data &);
/// Always true: no child is excluded from the traversal.
static bool needChildVisit(const ASTPtr &, const ASTPtr &) { return true; }
};

/// Children must be normalized first, so the AST is visited bottom-up.
using NormalizeSelectWithUnionQueryVisitor
= InDepthNodeVisitor<NormalizeSelectWithUnionQueryMatcher, false>;
}
14 changes: 9 additions & 5 deletions src/Interpreters/executeQuery.cpp
Expand Up @@ -39,16 +39,17 @@
#include <Storages/StorageInput.h>

#include <Access/EnabledQuota.h>
#include <Interpreters/ApplyWithGlobalVisitor.h>
#include <Interpreters/Context.h>
#include <Interpreters/InterpreterFactory.h>
#include <Interpreters/ProcessList.h>
#include <Interpreters/InterpreterSetQuery.h>
#include <Interpreters/NormalizeSelectWithUnionQueryVisitor.h>
#include <Interpreters/OpenTelemetrySpanLog.h>
#include <Interpreters/ProcessList.h>
#include <Interpreters/QueryLog.h>
#include <Interpreters/InterpreterSetQuery.h>
#include <Interpreters/ApplyWithGlobalVisitor.h>
#include <Interpreters/ReplaceQueryParameterVisitor.h>
#include <Interpreters/SelectQueryOptions.h>
#include <Interpreters/executeQuery.h>
#include <Interpreters/Context.h>
#include <Common/ProfileEvents.h>

#include <Common/SensitiveDataMasker.h>
Expand Down Expand Up @@ -472,9 +473,12 @@ static std::tuple<ASTPtr, BlockIO> executeQueryImpl(
if (settings.enable_global_with_statement)
{
ApplyWithGlobalVisitor().visit(ast);
query = serializeAST(*ast);
}

/// Normalize SelectWithUnionQuery
NormalizeSelectWithUnionQueryVisitor::Data data{context.getSettingsRef().union_default_mode};
NormalizeSelectWithUnionQueryVisitor{data}.visit(ast);

/// Check the limits.
checkASTSizeLimits(*ast, settings);

Expand Down
1 change: 1 addition & 0 deletions src/Interpreters/ya.make
Expand Up @@ -111,6 +111,7 @@ SRCS(
MetricLog.cpp
MutationsInterpreter.cpp
MySQL/InterpretersMySQLDDLQuery.cpp
NormalizeSelectWithUnionQueryVisitor.cpp
NullableUtils.cpp
OpenTelemetrySpanLog.cpp
OptimizeIfChains.cpp
Expand Down
@@ -0,0 +1,66 @@
SELECT 1
UNION ALL
SELECT 1
UNION ALL
SELECT 1
UNION ALL
SELECT 1
UNION ALL
SELECT 1

SELECT 1
UNION ALL
(
SELECT 1
UNION DISTINCT
SELECT 1
UNION DISTINCT
SELECT 1
)
UNION ALL
SELECT 1

SELECT x
FROM
(
SELECT 1 AS x
UNION ALL
(
SELECT 1
UNION DISTINCT
SELECT 1
UNION DISTINCT
SELECT 1
)
UNION ALL
SELECT 1
)

SELECT x
FROM
(
SELECT 1 AS x
UNION ALL
SELECT 1
UNION ALL
SELECT 1
)

SELECT 1
UNION DISTINCT
SELECT 1
UNION DISTINCT
SELECT 1

SELECT 1


(
SELECT 1
UNION DISTINCT
SELECT 1
UNION DISTINCT
SELECT 1
)
UNION ALL
SELECT 1