Skip to content

Commit

Permalink
Merge pull request #62074 from kitaisreal/analyzer-support-recursive-cte
Browse files Browse the repository at this point in the history
Analyzer support recursive CTEs
  • Loading branch information
alexey-milovidov committed Apr 29, 2024
2 parents eaf3b91 + c570316 commit dffcc51
Show file tree
Hide file tree
Showing 57 changed files with 3,519 additions and 138 deletions.
110 changes: 80 additions & 30 deletions src/Analyzer/Passes/QueryAnalysisPass.cpp
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
#include <Analyzer/Passes/QueryAnalysisPass.h>

#include <boost/algorithm/string.hpp>

#include <Common/checkStackSize.h>
#include <Common/NamePrompter.h>
#include <Common/ProfileEvents.h>
#include <Analyzer/FunctionSecretArgumentsFinderTreeNode.h>

#include <IO/WriteBuffer.h>
#include <IO/WriteHelpers.h>
Expand Down Expand Up @@ -81,8 +82,8 @@
#include <Analyzer/QueryTreeBuilder.h>
#include <Analyzer/IQueryTreeNode.h>
#include <Analyzer/Identifier.h>

#include <boost/algorithm/string.hpp>
#include <Analyzer/FunctionSecretArgumentsFinderTreeNode.h>
#include <Analyzer/RecursiveCTE.h>

namespace ProfileEvents
{
Expand Down Expand Up @@ -740,7 +741,7 @@ struct IdentifierResolveScope
/// Identifier lookup to result
std::unordered_map<IdentifierLookup, IdentifierResolveState, IdentifierLookupHash> identifier_lookup_to_resolve_state;

/// Lambda argument can be expression like constant, column, or it can be function
/// Argument can be expression like constant, column, function or table expression
std::unordered_map<std::string, QueryTreeNodePtr> expression_argument_name_to_node;

/// Alias name to query expression node
Expand Down Expand Up @@ -1464,7 +1465,8 @@ class QueryAnalyzer
/// Lambdas that are currently in resolve process
std::unordered_set<IQueryTreeNode *> lambdas_in_resolve_process;

std::unordered_set<std::string_view> cte_in_resolve_process;
/// CTEs that are currently in resolve process
std::unordered_set<std::string_view> ctes_in_resolve_process;

/// Function name to user defined lambda map
std::unordered_map<std::string, QueryTreeNodePtr> function_name_to_user_defined_lambda;
Expand Down Expand Up @@ -2148,9 +2150,9 @@ void QueryAnalyzer::evaluateScalarSubqueryIfNeeded(QueryTreeNodePtr & node, Iden
else
{
/** Make unique column names for tuple.
*
* Example: SELECT (SELECT 2 AS x, x)
*/
*
* Example: SELECT (SELECT 2 AS x, x)
*/
makeUniqueColumnNamesInBlock(block);

scalar_block.insert({
Expand Down Expand Up @@ -3981,21 +3983,20 @@ IdentifierResolveResult QueryAnalyzer::tryResolveIdentifierInParentScopes(const
auto * union_node = resolved_identifier->as<UnionNode>();

bool is_cte = (subquery_node && subquery_node->isCTE()) || (union_node && union_node->isCTE());
bool is_table_from_expression_arguments = lookup_result.resolve_place == IdentifierResolvePlace::EXPRESSION_ARGUMENTS &&
resolved_identifier->getNodeType() == QueryTreeNodeType::TABLE;
bool is_valid_table_expression = is_cte || is_table_from_expression_arguments;

/** From parent scopes we can resolve table identifiers only as CTE,
* or as a table node that was resolved from expression arguments.
* Example: SELECT (SELECT 1 FROM a) FROM test_table AS a;
*
* During resolution of table identifier a in the child scope, table node test_table
* with alias a from the parent scope is invalid.
*/
if (identifier_lookup.isTableExpressionLookup() && !is_cte)
if (identifier_lookup.isTableExpressionLookup() && !is_valid_table_expression)
continue;

if (is_cte)
{
return lookup_result;
}
else if (resolved_identifier->as<ConstantNode>())
if (is_valid_table_expression || resolved_identifier->as<ConstantNode>())
{
return lookup_result;
}
Expand Down Expand Up @@ -4071,13 +4072,9 @@ IdentifierResolveResult QueryAnalyzer::tryResolveIdentifier(const IdentifierLook

if (it->second.resolve_result.isResolved() &&
scope.use_identifier_lookup_to_result_cache &&
!scope.non_cached_identifier_lookups_during_expression_resolve.contains(identifier_lookup))
{
if (!it->second.resolve_result.isResolvedFromCTEs() || !cte_in_resolve_process.contains(identifier_lookup.identifier.getFullName()))
{
return it->second.resolve_result;
}
}
!scope.non_cached_identifier_lookups_during_expression_resolve.contains(identifier_lookup) &&
(!it->second.resolve_result.isResolvedFromCTEs() || !ctes_in_resolve_process.contains(identifier_lookup.identifier.getFullName())))
return it->second.resolve_result;
}
else
{
Expand Down Expand Up @@ -4150,7 +4147,7 @@ IdentifierResolveResult QueryAnalyzer::tryResolveIdentifier(const IdentifierLook
/// To accomplish this behaviour it's not allowed to resolve identifiers to
/// CTE that is being resolved.
if (cte_query_node_it != scope.cte_name_to_query_node.end()
&& !cte_in_resolve_process.contains(full_name))
&& !ctes_in_resolve_process.contains(full_name))
{
resolve_result.resolved_identifier = cte_query_node_it->second;
resolve_result.resolve_place = IdentifierResolvePlace::CTE;
Expand Down Expand Up @@ -6296,14 +6293,14 @@ ProjectionNames QueryAnalyzer::resolveExpressionNode(QueryTreeNodePtr & node, Id
///
/// In this example the argument of function `in` is being resolved here. If CTE `test1` is not forbidden,
/// `test1` is resolved to CTE (not to the table) in `initializeQueryJoinTreeNode` function.
cte_in_resolve_process.insert(cte_name);
ctes_in_resolve_process.insert(cte_name);

if (subquery_node)
resolveQuery(resolved_identifier_node, subquery_scope);
else
resolveUnion(resolved_identifier_node, subquery_scope);

cte_in_resolve_process.erase(cte_name);
ctes_in_resolve_process.erase(cte_name);
}
}
}
Expand Down Expand Up @@ -7874,7 +7871,7 @@ void QueryAnalyzer::resolveQuery(const QueryTreeNodePtr & query_node, Identifier
auto & query_node_typed = query_node->as<QueryNode &>();

if (query_node_typed.isCTE())
cte_in_resolve_process.insert(query_node_typed.getCTEName());
ctes_in_resolve_process.insert(query_node_typed.getCTEName());

bool is_rollup_or_cube = query_node_typed.isGroupByWithRollup() || query_node_typed.isGroupByWithCube();

Expand Down Expand Up @@ -7956,7 +7953,6 @@ void QueryAnalyzer::resolveQuery(const QueryTreeNodePtr & query_node, Identifier
auto * union_node = node->as<UnionNode>();

bool subquery_is_cte = (subquery_node && subquery_node->isCTE()) || (union_node && union_node->isCTE());

if (!subquery_is_cte)
continue;

Expand Down Expand Up @@ -8213,21 +8209,64 @@ void QueryAnalyzer::resolveQuery(const QueryTreeNodePtr & query_node, Identifier
query_node_typed.resolveProjectionColumns(std::move(projection_columns));

if (query_node_typed.isCTE())
cte_in_resolve_process.erase(query_node_typed.getCTEName());
ctes_in_resolve_process.erase(query_node_typed.getCTEName());
}

void QueryAnalyzer::resolveUnion(const QueryTreeNodePtr & union_node, IdentifierResolveScope & scope)
{
auto & union_node_typed = union_node->as<UnionNode &>();

if (union_node_typed.isCTE())
cte_in_resolve_process.insert(union_node_typed.getCTEName());
ctes_in_resolve_process.insert(union_node_typed.getCTEName());

auto & queries_nodes = union_node_typed.getQueries().getNodes();

for (auto & query_node : queries_nodes)
std::optional<RecursiveCTETable> recursive_cte_table;
TableNodePtr recursive_cte_table_node;

if (union_node_typed.isCTE() && union_node_typed.isRecursiveCTE())
{
auto & non_recursive_query = queries_nodes[0];
bool non_recursive_query_is_query_node = non_recursive_query->getNodeType() == QueryTreeNodeType::QUERY;
auto & non_recursive_query_mutable_context = non_recursive_query_is_query_node ? non_recursive_query->as<QueryNode &>().getMutableContext()
: non_recursive_query->as<UnionNode &>().getMutableContext();

IdentifierResolveScope non_recursive_subquery_scope(non_recursive_query, &scope /*parent_scope*/);
non_recursive_subquery_scope.subquery_depth = scope.subquery_depth + 1;

if (non_recursive_query_is_query_node)
resolveQuery(non_recursive_query, non_recursive_subquery_scope);
else
resolveUnion(non_recursive_query, non_recursive_subquery_scope);

auto temporary_table_columns = non_recursive_query_is_query_node
? non_recursive_query->as<QueryNode &>().getProjectionColumns()
: non_recursive_query->as<UnionNode &>().computeProjectionColumns();

auto temporary_table_holder = std::make_shared<TemporaryTableHolder>(
non_recursive_query_mutable_context,
ColumnsDescription{NamesAndTypesList{temporary_table_columns.begin(), temporary_table_columns.end()}},
ConstraintsDescription{},
nullptr /*query*/,
true /*create_for_global_subquery*/);
auto temporary_table_storage = temporary_table_holder->getTable();

recursive_cte_table_node = std::make_shared<TableNode>(temporary_table_storage, non_recursive_query_mutable_context);
recursive_cte_table_node->setTemporaryTableName(union_node_typed.getCTEName());

recursive_cte_table.emplace(std::move(temporary_table_holder), std::move(temporary_table_storage), std::move(temporary_table_columns));
}

size_t queries_nodes_size = queries_nodes.size();
for (size_t i = recursive_cte_table.has_value(); i < queries_nodes_size; ++i)
{
auto & query_node = queries_nodes[i];

IdentifierResolveScope subquery_scope(query_node, &scope /*parent_scope*/);

if (recursive_cte_table_node)
subquery_scope.expression_argument_name_to_node[union_node_typed.getCTEName()] = recursive_cte_table_node;

auto query_node_type = query_node->getNodeType();

if (query_node_type == QueryTreeNodeType::QUERY)
Expand All @@ -8247,8 +8286,19 @@ void QueryAnalyzer::resolveUnion(const QueryTreeNodePtr & union_node, Identifier
}
}

if (recursive_cte_table && isStorageUsedInTree(recursive_cte_table->storage, union_node.get()))
{
if (union_node_typed.getUnionMode() != SelectUnionMode::UNION_ALL)
throw Exception(ErrorCodes::UNSUPPORTED_METHOD,
"Recursive CTE subquery {} with {} union mode is unsupported, only UNION ALL union mode is supported",
union_node_typed.formatASTForErrorMessage(),
toString(union_node_typed.getUnionMode()));

union_node_typed.setRecursiveCTETable(std::move(*recursive_cte_table));
}

if (union_node_typed.isCTE())
cte_in_resolve_process.erase(union_node_typed.getCTEName());
ctes_in_resolve_process.erase(union_node_typed.getCTEName());
}

}
Expand Down
69 changes: 56 additions & 13 deletions src/Analyzer/QueryNode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,14 @@

#include <Parsers/ASTExpressionList.h>
#include <Parsers/ASTTablesInSelectQuery.h>
#include <Parsers/ASTWithElement.h>
#include <Parsers/ASTSubquery.h>
#include <Parsers/ASTSelectQuery.h>
#include <Parsers/ASTSelectWithUnionQuery.h>
#include <Parsers/ASTSetQuery.h>

#include <Analyzer/Utils.h>
#include <Analyzer/UnionNode.h>

namespace DB
{
Expand Down Expand Up @@ -107,6 +109,9 @@ void QueryNode::dumpTreeImpl(WriteBuffer & buffer, FormatState & format_state, s
if (is_cte)
buffer << ", is_cte: " << is_cte;

if (is_recursive_with)
buffer << ", is_recursive_with: " << is_recursive_with;

if (is_distinct)
buffer << ", is_distinct: " << is_distinct;

Expand Down Expand Up @@ -259,6 +264,7 @@ bool QueryNode::isEqualImpl(const IQueryTreeNode & rhs, CompareOptions) const

return is_subquery == rhs_typed.is_subquery &&
is_cte == rhs_typed.is_cte &&
is_recursive_with == rhs_typed.is_recursive_with &&
is_distinct == rhs_typed.is_distinct &&
is_limit_with_ties == rhs_typed.is_limit_with_ties &&
is_group_by_with_totals == rhs_typed.is_group_by_with_totals &&
Expand Down Expand Up @@ -291,6 +297,7 @@ void QueryNode::updateTreeHashImpl(HashState & state, CompareOptions) const
state.update(projection_column_type_name);
}

state.update(is_recursive_with);
state.update(is_distinct);
state.update(is_limit_with_ties);
state.update(is_group_by_with_totals);
Expand All @@ -317,26 +324,28 @@ QueryTreeNodePtr QueryNode::cloneImpl() const
{
auto result_query_node = std::make_shared<QueryNode>(context);

result_query_node->is_subquery = is_subquery;
result_query_node->is_cte = is_cte;
result_query_node->is_distinct = is_distinct;
result_query_node->is_limit_with_ties = is_limit_with_ties;
result_query_node->is_group_by_with_totals = is_group_by_with_totals;
result_query_node->is_group_by_with_rollup = is_group_by_with_rollup;
result_query_node->is_group_by_with_cube = is_group_by_with_cube;
result_query_node->is_subquery = is_subquery;
result_query_node->is_cte = is_cte;
result_query_node->is_recursive_with = is_recursive_with;
result_query_node->is_distinct = is_distinct;
result_query_node->is_limit_with_ties = is_limit_with_ties;
result_query_node->is_group_by_with_totals = is_group_by_with_totals;
result_query_node->is_group_by_with_rollup = is_group_by_with_rollup;
result_query_node->is_group_by_with_cube = is_group_by_with_cube;
result_query_node->is_group_by_with_grouping_sets = is_group_by_with_grouping_sets;
result_query_node->is_group_by_all = is_group_by_all;
result_query_node->is_order_by_all = is_order_by_all;
result_query_node->cte_name = cte_name;
result_query_node->projection_columns = projection_columns;
result_query_node->settings_changes = settings_changes;
result_query_node->is_group_by_all = is_group_by_all;
result_query_node->is_order_by_all = is_order_by_all;
result_query_node->cte_name = cte_name;
result_query_node->projection_columns = projection_columns;
result_query_node->settings_changes = settings_changes;

return result_query_node;
}

ASTPtr QueryNode::toASTImpl(const ConvertToASTOptions & options) const
{
auto select_query = std::make_shared<ASTSelectQuery>();
select_query->recursive_with = is_recursive_with;
select_query->distinct = is_distinct;
select_query->limit_with_ties = is_limit_with_ties;
select_query->group_by_with_totals = is_group_by_with_totals;
Expand All @@ -347,7 +356,41 @@ ASTPtr QueryNode::toASTImpl(const ConvertToASTOptions & options) const
select_query->order_by_all = is_order_by_all;

if (hasWith())
select_query->setExpression(ASTSelectQuery::Expression::WITH, getWith().toAST(options));
{
const auto & with = getWith();
auto expression_list_ast = std::make_shared<ASTExpressionList>();
expression_list_ast->children.reserve(with.getNodes().size());

for (const auto & with_node : with)
{
auto with_node_ast = with_node->toAST(options);
expression_list_ast->children.push_back(with_node_ast);

const auto * with_query_node = with_node->as<QueryNode>();
const auto * with_union_node = with_node->as<UnionNode>();
if (!with_query_node && !with_union_node)
continue;

bool is_with_node_cte = with_query_node ? with_query_node->isCTE() : with_union_node->isCTE();
if (!is_with_node_cte)
continue;

const auto & with_node_cte_name = with_query_node ? with_query_node->cte_name : with_union_node->getCTEName();

auto * with_node_ast_subquery = with_node_ast->as<ASTSubquery>();
if (with_node_ast_subquery)
with_node_ast_subquery->cte_name = "";

auto with_element_ast = std::make_shared<ASTWithElement>();
with_element_ast->name = with_node_cte_name;
with_element_ast->subquery = std::move(with_node_ast);
with_element_ast->children.push_back(with_element_ast->subquery);

expression_list_ast->children.back() = std::move(with_element_ast);
}

select_query->setExpression(ASTSelectQuery::Expression::WITH, std::move(expression_list_ast));
}

auto projection_ast = getProjection().toAST(options);
auto & projection_expression_list_ast = projection_ast->as<ASTExpressionList &>();
Expand Down
13 changes: 13 additions & 0 deletions src/Analyzer/QueryNode.h
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,18 @@ class QueryNode final : public IQueryTreeNode
cte_name = std::move(cte_name_value);
}

/** Returns true if this query node is marked as having RECURSIVE WITH, false otherwise.
  *
  * The flag is false by default and is changed with setIsRecursiveWith.
  * It takes part in node equality and in the tree hash, is copied on clone,
  * and is written to the `recursive_with` field of the select query AST.
  */
bool isRecursiveWith() const
{
return is_recursive_with;
}

/** Set the RECURSIVE WITH flag of this query node.
  *
  * The stored value is what isRecursiveWith returns; it is also compared in
  * isEqualImpl, mixed into the tree hash, copied by cloneImpl and emitted
  * as `recursive_with` on the select query built by toASTImpl.
  */
void setIsRecursiveWith(bool is_recursive_with_value)
{
is_recursive_with = is_recursive_with_value;
}

/// Returns true if query node has DISTINCT, false otherwise
bool isDistinct() const
{
Expand Down Expand Up @@ -618,6 +630,7 @@ class QueryNode final : public IQueryTreeNode
private:
bool is_subquery = false;
bool is_cte = false;
bool is_recursive_with = false;
bool is_distinct = false;
bool is_limit_with_ties = false;
bool is_group_by_with_totals = false;
Expand Down

0 comments on commit dffcc51

Please sign in to comment.