Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rewrite equality with is null check in JOIN ON section #56538

Merged
merged 6 commits into from Nov 21, 2023
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
212 changes: 212 additions & 0 deletions src/Analyzer/Passes/LogicalExpressionOptimizerPass.cpp
Expand Up @@ -5,11 +5,211 @@
#include <Analyzer/InDepthQueryTreeVisitor.h>
#include <Analyzer/FunctionNode.h>
#include <Analyzer/ConstantNode.h>
#include <Analyzer/JoinNode.h>
#include <Analyzer/HashUtils.h>
#include <Analyzer/Utils.h>

namespace DB
{

namespace ErrorCodes
{
extern const int LOGICAL_ERROR;
}

/** Visitor that optimizes logical expressions _only_ in JOIN ON section.
  *
  * For every `or` function it tries to get rid of explicit NULL checks that are already
  * covered by a NULL-safe comparison of the same pair of expressions:
  *
  *     a = b   OR (a IS NULL AND b IS NULL)   ->   a <=> b
  *     a <=> b OR (a IS NULL AND b IS NULL)   ->   a <=> b
  *
  * An AND operand is only ever removed as a whole, and only when it is implied by the comparison.
  * Its remaining conditions are never promoted to stand-alone OR operands, because
  * `a <=> b OR (a IS NULL AND b IS NULL AND c)` is not equivalent to `a <=> b OR c`.
  */
class JoinOnLogicalExpressionOptimizerVisitor : public InDepthQueryTreeVisitorWithContext<JoinOnLogicalExpressionOptimizerVisitor>
{
public:
    using Base = InDepthQueryTreeVisitorWithContext<JoinOnLogicalExpressionOptimizerVisitor>;

    explicit JoinOnLogicalExpressionOptimizerVisitor(ContextPtr context)
        : Base(std::move(context))
    {}

    /// Tries to optimize `node` if it is an `or` function; all other nodes are left untouched.
    void enterImpl(QueryTreeNodePtr & node)
    {
        auto * function_node = node->as<FunctionNode>();

        if (!function_node)
            return;

        if (function_node->getFunctionName() == "or")
        {
            bool is_argument_type_changed = tryOptimizeIsNotDistinctOrIsNull(node, getContext());
            if (is_argument_type_changed)
                need_rerun_resolve = true;
            return;
        }
    }

    /// Re-resolves function nodes on the way up once some operand may have changed its type.
    void leaveImpl(QueryTreeNodePtr & node)
    {
        if (!need_rerun_resolve)
            return;

        if (auto * function_node = node->as<FunctionNode>())
            rerunFunctionResolve(function_node, getContext());
    }

private:
    /// Set when a rewrite may have changed the result type of a node.
    /// It is never reset, so every function node left afterwards is re-resolved.
    bool need_rerun_resolve = false;

    /// Returns true if every argument of `and_function` is either `isNull(lhs)` or `isNull(rhs)`,
    /// i.e. the conjunction checks nothing except that the two compared expressions are NULL.
    static bool consistsOnlyOfNullChecks(const FunctionNode & and_function, const QueryTreeNodePtr & lhs, const QueryTreeNodePtr & rhs)
    {
        for (const auto & and_argument : and_function.getArguments().getNodes())
        {
            const auto * is_null_function = and_argument->as<FunctionNode>();
            if (!is_null_function || is_null_function->getFunctionName() != "isNull")
                return false;

            const auto & is_null_argument = is_null_function->getArguments().getNodes()[0];
            if (!is_null_argument->isEqual(*lhs) && !is_null_argument->isEqual(*rhs))
                return false;
        }
        return true;
    }

    /** Rewrites the `or` function in `node` in place (the node itself is replaced when a single operand is left).
      *
      * Rules applied for every `equals` / `isNotDistinctFrom` operand `cmp(a, b)`:
      *  - an AND operand that contains both `a IS NULL` and `b IS NULL` can only be true when `a <=> b` is true,
      *    so next to `a <=> b` it is redundant and the whole operand is removed;
      *  - `a = b` is replaced with `a <=> b` only if there is an AND operand consisting of nothing
      *    but `a IS NULL` and `b IS NULL`; otherwise `a = b` and its AND operands are left as is.
      *
      * Returns true if the type of some operand may have changed and parent functions need to be re-resolved.
      */
    static bool tryOptimizeIsNotDistinctOrIsNull(QueryTreeNodePtr & node, const ContextPtr & context)
    {
        auto & function_node = node->as<FunctionNode &>();
        assert(function_node.getFunctionName() == "or");

        QueryTreeNodes or_operands;

        /// Indices of `equals` or `isNotDistinctFrom` functions in the vector above
        std::vector<size_t> equals_functions_indices;

        /** Map from `isNull` argument to indices of AND operands that contain that `isNull` function
          * `a = b OR (a IS NULL AND b IS NULL) OR (a IS NULL AND c IS NULL)`
          * will be mapped to
          * {
          *     a => [(a IS NULL AND b IS NULL), (a IS NULL AND c IS NULL)]
          *     b => [(a IS NULL AND b IS NULL)]
          *     c => [(a IS NULL AND c IS NULL)]
          * }
          * Then for each a <=> b we can find all operands that contain both a IS NULL and b IS NULL.
          * Indices are appended in increasing order, so every vector is sorted.
          */
        QueryTreeNodePtrWithHashMap<std::vector<size_t>> is_null_argument_to_indices;

        for (const auto & argument : function_node.getArguments())
        {
            or_operands.push_back(argument);
            const size_t operand_idx = or_operands.size() - 1;

            auto * argument_function = argument->as<FunctionNode>();
            if (!argument_function)
                continue;

            const auto & func_name = argument_function->getFunctionName();
            if (func_name == "equals" || func_name == "isNotDistinctFrom")
                equals_functions_indices.push_back(operand_idx);

            if (func_name == "and")
            {
                for (const auto & and_argument : argument_function->getArguments().getNodes())
                {
                    auto * and_argument_function = and_argument->as<FunctionNode>();
                    if (and_argument_function && and_argument_function->getFunctionName() == "isNull")
                    {
                        const auto & is_null_argument = and_argument_function->getArguments().getNodes()[0];
                        is_null_argument_to_indices[is_null_argument].push_back(operand_idx);
                    }
                }
            }
        }

        /// `equals` operands that have to be replaced with `isNotDistinctFrom`
        std::unordered_set<size_t> equals_to_convert;
        /// AND operands that are implied by a NULL-safe comparison and are dropped from the OR
        std::unordered_set<size_t> operands_to_remove;

        for (size_t equals_function_idx : equals_functions_indices)
        {
            const auto * equals_function = or_operands[equals_function_idx]->as<FunctionNode>();

            /// For a <=> b we are looking for expressions containing both `a IS NULL` and `b IS NULL` combined with AND
            const auto & argument_nodes = equals_function->getArguments().getNodes();
            const auto & lhs_is_null_parents = is_null_argument_to_indices[argument_nodes[0]];
            const auto & rhs_is_null_parents = is_null_argument_to_indices[argument_nodes[1]];
            std::unordered_set<size_t> candidate_operands;
            std::set_intersection(lhs_is_null_parents.begin(), lhs_is_null_parents.end(),
                                  rhs_is_null_parents.begin(), rhs_is_null_parents.end(),
                                  std::inserter(candidate_operands, candidate_operands.begin()));

            const bool is_equals = equals_function->getFunctionName() == "equals";
            if (is_equals)
            {
                /// `a = b` may become `a <=> b` only if `a IS NULL AND b IS NULL` is present without any extra
                /// condition, otherwise rows with both NULLs that fail the extra condition would start to match.
                bool has_pure_null_check = false;
                for (size_t candidate_idx : candidate_operands)
                {
                    const auto & and_function = or_operands[candidate_idx]->as<FunctionNode &>();
                    if (consistsOnlyOfNullChecks(and_function, argument_nodes[0], argument_nodes[1]))
                    {
                        has_pure_null_check = true;
                        break;
                    }
                }

                if (!has_pure_null_check)
                    continue;

                equals_to_convert.insert(equals_function_idx);
            }

            /// The comparison is (or becomes) NULL-safe, so it is true whenever any of the candidates is true
            /// and the candidates can be dropped entirely.
            operands_to_remove.insert(candidate_operands.begin(), candidate_operands.end());
        }

        if (operands_to_remove.empty())
            /// Nothing has been changed
            return false;

        auto strict_equals_function_resolver = FunctionFactory::instance().get("isNotDistinctFrom", context);

        bool need_reresolve = false;
        QueryTreeNodes new_or_operands;
        for (size_t i = 0; i < or_operands.size(); ++i)
        {
            if (operands_to_remove.contains(i))
            {
                /// Removing a Nullable operand can change the result type of the OR
                const auto * removed_function = or_operands[i]->as<FunctionNode>();
                need_reresolve = need_reresolve || removed_function->getResultType()->isNullable();
                continue;
            }

            if (equals_to_convert.contains(i))
            {
                auto * function = or_operands[i]->as<FunctionNode>();
                if (function->getFunctionName() != "equals")
                    throw Exception(ErrorCodes::LOGICAL_ERROR, "Unexpected function name: '{}'", function->getFunctionName());

                /// We should replace `a = b` with `a <=> b` because we removed checks for IS NULL
                need_reresolve = need_reresolve || function->getResultType()->isNullable();
                function->resolveAsFunction(strict_equals_function_resolver);
            }

            new_or_operands.emplace_back(std::move(or_operands[i]));
        }

        if (new_or_operands.size() == 1)
        {
            /// OR with a single operand is replaced with the operand itself.
            /// `function_node` refers to the old node and must not be used after this assignment.
            node = std::move(new_or_operands[0]);
            return need_reresolve;
        }

        /// Rebuild OR function
        auto or_function_resolver = FunctionFactory::instance().get("or", context);
        function_node.getArguments().getNodes() = std::move(new_or_operands);
        function_node.resolveAsFunction(or_function_resolver);
        return need_reresolve;
    }
};

class LogicalExpressionOptimizerVisitor : public InDepthQueryTreeVisitorWithContext<LogicalExpressionOptimizerVisitor>
{
public:
Expand All @@ -21,6 +221,17 @@ class LogicalExpressionOptimizerVisitor : public InDepthQueryTreeVisitorWithCont

void enterImpl(QueryTreeNodePtr & node)
{
if (auto * join_node = node->as<JoinNode>())
{
/// Operator <=> is not supported outside of JOIN ON section
if (join_node->hasJoinExpression())
{
JoinOnLogicalExpressionOptimizerVisitor join_on_visitor(getContext());
join_on_visitor.visit(join_node->getJoinExpression());
}
return;
}

auto * function_node = node->as<FunctionNode>();

if (!function_node)
Expand All @@ -38,6 +249,7 @@ class LogicalExpressionOptimizerVisitor : public InDepthQueryTreeVisitorWithCont
return;
}
}

private:
void tryReplaceAndEqualsChainsWithConstant(QueryTreeNodePtr & node)
{
Expand Down
11 changes: 11 additions & 0 deletions src/Analyzer/Passes/LogicalExpressionOptimizerPass.h
Expand Up @@ -67,6 +67,17 @@ namespace DB
* FROM TABLE
* WHERE a = 1 AND b = 'test';
* -------------------------------
*
* 5. Remove unnecessary IS NULL checks in JOIN ON clause
* - an equality check OR-ed with explicit IS NULL checks on both of its arguments is replaced with the <=> operator
* -------------------------------
* SELECT * FROM t1 JOIN t2 ON a = b OR (a IS NULL AND b IS NULL)
* SELECT * FROM t1 JOIN t2 ON a <=> b OR (a IS NULL AND b IS NULL)
*
* will be transformed into
*
* SELECT * FROM t1 JOIN t2 ON a <=> b
* -------------------------------
*/

class LogicalExpressionOptimizerPass final : public IQueryTreePass
Expand Down
@@ -0,0 +1,25 @@
-- { echoOn }
SELECT * FROM t1 JOIN t2 ON (t1.x <=> t2.x OR (t1.x IS NULL AND t2.x IS NULL)) ORDER BY t1.x NULLS LAST;
2 2 2 2
3 3 3 33
\N \N \N \N
SELECT * FROM t1 JOIN t2 ON (t1.x <=> t2.x OR t1.x IS NULL AND t1.y <=> t2.y AND t2.x IS NULL) ORDER BY t1.x NULLS LAST;
2 2 2 2
3 3 3 33
\N \N \N \N
SELECT * FROM t1 JOIN t2 ON (t1.x = t2.x OR t1.x IS NULL AND t2.x IS NULL) AND t1.y <=> t2.y ORDER BY t1.x NULLS LAST;
2 2 2 2
\N \N \N \N
SELECT * FROM t1 JOIN t2 ON (t1.x <=> t2.x OR t1.y <=> t2.y OR (t1.x IS NULL AND t1.y IS NULL AND t2.x IS NULL AND t2.y IS NULL)) ORDER BY t1.x NULLS LAST;
1 42 4 42
2 2 2 2
3 3 3 33
\N \N \N \N
SELECT * FROM t1 JOIN t2 ON (t1.x <=> t2.x OR (t1.x IS NULL AND t2.x IS NULL)) AND (t1.y == t2.y OR (t1.y IS NULL AND t2.y IS NULL)) AND COALESCE(t1.x, 0) != 2 ORDER BY t1.x NULLS LAST;
\N \N \N \N
SELECT x = y OR (x IS NULL AND y IS NULL) FROM t1 ORDER BY x NULLS LAST;
0
1
1
1
27 changes: 27 additions & 0 deletions tests/queries/0_stateless/02911_join_on_nullsafe_optimization.sql
@@ -0,0 +1,27 @@
-- Checks JOIN ON conditions that combine an equality (`=` / `==` / `<=>`) with explicit
-- `IS NULL` checks on the same columns; the final query evaluates such an expression in a
-- plain SELECT list rather than in JOIN ON.
DROP TABLE IF EXISTS t1;
DROP TABLE IF EXISTS t2;

-- Both tables have the same schema with Nullable columns.
CREATE TABLE t1 (x Nullable(Int64), y Nullable(UInt64)) ENGINE = TinyLog;
CREATE TABLE t2 (x Nullable(Int64), y Nullable(UInt64)) ENGINE = TinyLog;

-- Rows are chosen so that the pairs differ in what they share:
--   (2,2)/(2,2)             agree on x and y
--   (3,3)/(3,33)            agree on x only
--   (1,42)/(4,42)           agree on y only
--   (NULL,NULL)/(NULL,NULL) match only under NULL-safe comparison
INSERT INTO t1 VALUES (1,42), (2,2), (3,3), (NULL,NULL);
INSERT INTO t2 VALUES (NULL,NULL), (2,2), (3,33), (4,42);

-- All queries below run with the analyzer enabled.
SET allow_experimental_analyzer = 1;

-- { echoOn }
SELECT * FROM t1 JOIN t2 ON (t1.x <=> t2.x OR (t1.x IS NULL AND t2.x IS NULL)) ORDER BY t1.x NULLS LAST;

SELECT * FROM t1 JOIN t2 ON (t1.x <=> t2.x OR t1.x IS NULL AND t1.y <=> t2.y AND t2.x IS NULL) ORDER BY t1.x NULLS LAST;

SELECT * FROM t1 JOIN t2 ON (t1.x = t2.x OR t1.x IS NULL AND t2.x IS NULL) AND t1.y <=> t2.y ORDER BY t1.x NULLS LAST;

SELECT * FROM t1 JOIN t2 ON (t1.x <=> t2.x OR t1.y <=> t2.y OR (t1.x IS NULL AND t1.y IS NULL AND t2.x IS NULL AND t2.y IS NULL)) ORDER BY t1.x NULLS LAST;

SELECT * FROM t1 JOIN t2 ON (t1.x <=> t2.x OR (t1.x IS NULL AND t2.x IS NULL)) AND (t1.y == t2.y OR (t1.y IS NULL AND t2.y IS NULL)) AND COALESCE(t1.x, 0) != 2 ORDER BY t1.x NULLS LAST;

SELECT x = y OR (x IS NULL AND y IS NULL) FROM t1 ORDER BY x NULLS LAST;
-- { echoOff }

DROP TABLE IF EXISTS t1;
DROP TABLE IF EXISTS t2;