Skip to content

Commit

Permalink
Rerun resolve in JoinOnLogicalExpressionOptimizerVisitor
Browse files Browse the repository at this point in the history
  • Loading branch information
vdimir committed Nov 16, 2023
1 parent bdeb04f commit 6ad0e90
Showing 1 changed file with 27 additions and 8 deletions.
35 changes: 27 additions & 8 deletions src/Analyzer/Passes/LogicalExpressionOptimizerPass.cpp
Expand Up @@ -7,6 +7,7 @@
#include <Analyzer/ConstantNode.h>
#include <Analyzer/JoinNode.h>
#include <Analyzer/HashUtils.h>
#include <Analyzer/Utils.h>

namespace DB
{
Expand Down Expand Up @@ -35,13 +36,27 @@ class JoinOnLogicalExpressionOptimizerVisitor : public InDepthQueryTreeVisitorWi

if (function_node->getFunctionName() == "or")
{
tryOptimizeIsNotDistinctOrIsNull(node);
bool is_argument_type_changed = tryOptimizeIsNotDistinctOrIsNull(node, getContext());
if (is_argument_type_changed)
need_rerun_resolve = true;
return;
}
}

void leaveImpl(QueryTreeNodePtr & node)
{
if (!need_rerun_resolve)
return;

if (auto * function_node = node->as<FunctionNode>())
rerunFunctionResolve(function_node, getContext());
}

private:
void tryOptimizeIsNotDistinctOrIsNull(QueryTreeNodePtr & node)
bool need_rerun_resolve = false;

/// Returns true if type of some operand is changed and parent function needs to be re-resolved
static bool tryOptimizeIsNotDistinctOrIsNull(QueryTreeNodePtr & node, const ContextPtr & context)
{
auto & function_node = node->as<FunctionNode &>();
assert(function_node.getFunctionName() == "or");
Expand Down Expand Up @@ -139,10 +154,12 @@ class JoinOnLogicalExpressionOptimizerVisitor : public InDepthQueryTreeVisitorWi

if (arguments_to_reresolve.empty())
/// Nothing have been changed
return;
return false;

auto and_function_resolver = FunctionFactory::instance().get("and", getContext());
auto strict_equals_function_resolver = FunctionFactory::instance().get("isNotDistinctFrom", getContext());
auto and_function_resolver = FunctionFactory::instance().get("and", context);
auto strict_equals_function_resolver = FunctionFactory::instance().get("isNotDistinctFrom", context);

bool need_reresolve = false;
QueryTreeNodes new_or_operands;
for (size_t i = 0; i < or_operands.size(); ++i)
{
Expand All @@ -151,7 +168,8 @@ class JoinOnLogicalExpressionOptimizerVisitor : public InDepthQueryTreeVisitorWi
auto * function = or_operands[i]->as<FunctionNode>();
if (function->getFunctionName() == "equals")
{
/// Because we removed checks for IS NULL, we should replace `a = b` with `a <=> b`
/// We should replace `a = b` with `a <=> b` because we removed checks for IS NULL
need_reresolve = need_reresolve || function->getResultType()->isNullable();
function->resolveAsFunction(strict_equals_function_resolver);
new_or_operands.emplace_back(std::move(or_operands[i]));
}
Expand Down Expand Up @@ -181,13 +199,14 @@ class JoinOnLogicalExpressionOptimizerVisitor : public InDepthQueryTreeVisitorWi
if (new_or_operands.size() == 1)
{
node = std::move(new_or_operands[0]);
return;
return need_reresolve;
}

/// Rebuild OR function
auto or_function_resolver = FunctionFactory::instance().get("or", getContext());
auto or_function_resolver = FunctionFactory::instance().get("or", context);
function_node.getArguments().getNodes() = std::move(new_or_operands);
function_node.resolveAsFunction(or_function_resolver);
return need_reresolve;
}
};

Expand Down

0 comments on commit 6ad0e90

Please sign in to comment.