Skip to content

Commit

Permalink
Optimize equality with is null check in JOIN ON section
Browse files Browse the repository at this point in the history
  • Loading branch information
vdimir committed Nov 9, 2023
1 parent 0437b57 commit b726455
Show file tree
Hide file tree
Showing 4 changed files with 250 additions and 0 deletions.
187 changes: 187 additions & 0 deletions src/Analyzer/Passes/LogicalExpressionOptimizerPass.cpp
Expand Up @@ -5,11 +5,17 @@
#include <Analyzer/InDepthQueryTreeVisitor.h>
#include <Analyzer/FunctionNode.h>
#include <Analyzer/ConstantNode.h>
#include <Analyzer/JoinNode.h>
#include <Analyzer/HashUtils.h>

namespace DB
{

namespace ErrorCodes
{
extern const int LOGICAL_ERROR;
}

class LogicalExpressionOptimizerVisitor : public InDepthQueryTreeVisitorWithContext<LogicalExpressionOptimizerVisitor>
{
public:
Expand All @@ -21,6 +27,15 @@ class LogicalExpressionOptimizerVisitor : public InDepthQueryTreeVisitorWithCont

void enterImpl(QueryTreeNodePtr & node)
{
if (auto * join_node = node->as<JoinNode>())
{
join_stack.push_back(join_node);
return;
}

if (!join_stack.empty() && join_stack.back()->getJoinExpression().get() == node.get())
is_inside_on_section = true;

auto * function_node = node->as<FunctionNode>();

if (!function_node)
Expand All @@ -29,6 +44,10 @@ class LogicalExpressionOptimizerVisitor : public InDepthQueryTreeVisitorWithCont
if (function_node->getFunctionName() == "or")
{
tryReplaceOrEqualsChainWithIn(node);

/// Operator <=> is not supported outside of JOIN ON section
if (is_inside_on_section)
tryOptimizeIsNotDistinctOrIsNull(node);
return;
}

Expand All @@ -38,6 +57,20 @@ class LogicalExpressionOptimizerVisitor : public InDepthQueryTreeVisitorWithCont
return;
}
}

void leaveImpl(QueryTreeNodePtr & node)
{
if (!join_stack.empty() && join_stack.back()->getJoinExpression().get() == node.get())
is_inside_on_section = false;

if (auto * join_node = node->as<JoinNode>())
{
assert(join_stack.back() == join_node);
join_stack.pop_back();
return;
}
}

private:
void tryReplaceAndEqualsChainsWithConstant(QueryTreeNodePtr & node)
{
Expand Down Expand Up @@ -231,6 +264,160 @@ class LogicalExpressionOptimizerVisitor : public InDepthQueryTreeVisitorWithCont
function_node.getArguments().getNodes() = std::move(or_operands);
function_node.resolveAsFunction(or_function_resolver);
}

void tryOptimizeIsNotDistinctOrIsNull(QueryTreeNodePtr & node)
{
auto & function_node = node->as<FunctionNode &>();
assert(function_node.getFunctionName() == "or");

QueryTreeNodes or_operands;

/// Indices of `equals` or `isNotDistinctFrom` functions in the vector above
std::vector<size_t> equals_functions_indices;

/** Map from `isNull` argument to indices of operands that contains that `isNull` functions
* `a = b OR (a IS NULL AND b IS NULL) OR (a IS NULL AND c IS NULL)`
* will be mapped to
* {
* a => [(a IS NULL AND b IS NULL), (a IS NULL AND c IS NULL)]
* b => [(a IS NULL AND b IS NULL)]
* c => [(a IS NULL AND c IS NULL)]
* }
* Then for each a <=> b we can find all operands that contains both a IS NULL and b IS NULL
*/
QueryTreeNodePtrWithHashMap<std::vector<size_t>> is_null_argument_to_indices;

for (const auto & argument : function_node.getArguments())
{
or_operands.push_back(argument);

auto * argument_function = argument->as<FunctionNode>();
if (!argument_function)
continue;

const auto & func_name = argument_function->getFunctionName();
if (func_name == "equals" || func_name == "isNotDistinctFrom")
equals_functions_indices.push_back(or_operands.size() - 1);

if (func_name == "and")
{
for (const auto & and_argument : argument_function->getArguments().getNodes())
{
auto * and_argument_function = and_argument->as<FunctionNode>();
if (and_argument_function && and_argument_function->getFunctionName() == "isNull")
{
const auto & is_null_argument = and_argument_function->getArguments().getNodes()[0];
is_null_argument_to_indices[is_null_argument].push_back(or_operands.size() - 1);
}
}
}
}

/// OR operands that are changed to and needs to be re-resolved
std::unordered_set<size_t> arguments_to_reresolve;

for (size_t equals_function_idx : equals_functions_indices)
{
auto * equals_function = or_operands[equals_function_idx]->as<FunctionNode>();

/// For a <=> b we are looking for expressions containing both `a IS NULL` and `b IS NULL` combined with AND
const auto & argument_nodes = equals_function->getArguments().getNodes();
const auto & lhs_is_null_parents = is_null_argument_to_indices[argument_nodes[0]];
const auto & rhs_is_null_parents = is_null_argument_to_indices[argument_nodes[1]];
std::unordered_set<size_t> operands_to_optimize;
std::set_intersection(lhs_is_null_parents.begin(), lhs_is_null_parents.end(),
rhs_is_null_parents.begin(), rhs_is_null_parents.end(),
std::inserter(operands_to_optimize, operands_to_optimize.begin()));

/// If we have `a = b OR (a IS NULL AND b IS NULL)` we can optimize it to `a <=> b`
if (!operands_to_optimize.empty() && equals_function->getFunctionName() == "equals")
arguments_to_reresolve.insert(equals_function_idx);

for (size_t to_optimize_idx : operands_to_optimize)
{
/// We are looking for operand `a IS NULL AND b IS NULL AND ...`
auto * operand_to_optimize = or_operands[to_optimize_idx]->as<FunctionNode>();

/// Remove `a IS NULL` and `b IS NULL` arguments from AND
QueryTreeNodes new_arguments;
for (const auto & and_argument : operand_to_optimize->getArguments().getNodes())
{
bool to_eliminate = false;

const auto * and_argument_function = and_argument->as<FunctionNode>();
if (and_argument_function && and_argument_function->getFunctionName() == "isNull")
{
const auto & is_null_argument = and_argument_function->getArguments().getNodes()[0];
to_eliminate = (is_null_argument->isEqual(*argument_nodes[0]) || is_null_argument->isEqual(*argument_nodes[1]));
}

if (to_eliminate)
arguments_to_reresolve.insert(to_optimize_idx);
else
new_arguments.emplace_back(and_argument);
}
/// If less than two arguments left, we will remove or replace the whole AND below
operand_to_optimize->getArguments().getNodes() = std::move(new_arguments);
}
}


if (arguments_to_reresolve.empty())
/// Nothing have been changed
return;

auto and_function_resolver = FunctionFactory::instance().get("and", getContext());
auto strict_equals_function_resolver = FunctionFactory::instance().get("isNotDistinctFrom", getContext());
QueryTreeNodes new_or_operands;
for (size_t i = 0; i < or_operands.size(); ++i)
{
if (arguments_to_reresolve.contains(i))
{
auto * function = or_operands[i]->as<FunctionNode>();
if (function->getFunctionName() == "equals")
{
/// Because we removed checks for IS NULL, we should replace `a = b` with `a <=> b`
function->resolveAsFunction(strict_equals_function_resolver);
new_or_operands.emplace_back(std::move(or_operands[i]));
}
else if (function->getFunctionName() == "and")
{
const auto & and_arguments = function->getArguments().getNodes();
if (and_arguments.size() > 1)
{
function->resolveAsFunction(and_function_resolver);
new_or_operands.emplace_back(std::move(or_operands[i]));
}
else if (and_arguments.size() == 1)
{
/// Replace AND with a single argument with the argument itself
new_or_operands.emplace_back(std::move(and_arguments[0]));
}
}
else
throw Exception(ErrorCodes::LOGICAL_ERROR, "Unexpected function name: '{}'", function->getFunctionName());
}
else
{
new_or_operands.emplace_back(std::move(or_operands[i]));
}
}

if (new_or_operands.size() == 1)
{
node = std::move(new_or_operands[0]);
return;
}

/// Rebuild OR function
auto or_function_resolver = FunctionFactory::instance().get("or", getContext());
function_node.getArguments().getNodes() = std::move(new_or_operands);
function_node.resolveAsFunction(or_function_resolver);
}

private:
bool is_inside_on_section = false;
std::deque<const JoinNode *> join_stack;
};

void LogicalExpressionOptimizerPass::run(QueryTreeNodePtr query_tree_node, ContextPtr context)
Expand Down
11 changes: 11 additions & 0 deletions src/Analyzer/Passes/LogicalExpressionOptimizerPass.h
Expand Up @@ -67,6 +67,17 @@ namespace DB
* FROM TABLE
* WHERE a = 1 AND b = 'test';
* -------------------------------
*
* 5. Remove unnecessary IS NULL checks in JOIN ON clause
* - equality check with explicit IS NULL check replaced with <=> operator
* -------------------------------
* SELECT * FROM t1 JOIN t2 ON a = b OR (a IS NULL AND b IS NULL)
* SELECT * FROM t1 JOIN t2 ON a <=> b OR (a IS NULL AND b IS NULL)
*
* will be transformed into
*
* SELECT * FROM t1 JOIN t2 ON a <=> b
* -------------------------------
*/

class LogicalExpressionOptimizerPass final : public IQueryTreePass
Expand Down
@@ -0,0 +1,25 @@
-- { echoOn }
SELECT * FROM t1 JOIN t2 ON (t1.x <=> t2.x OR (t1.x IS NULL AND t2.x IS NULL)) ORDER BY t1.x NULLS LAST;
2 2 2 2
3 3 3 33
\N \N \N \N
SELECT * FROM t1 JOIN t2 ON (t1.x <=> t2.x OR t1.x IS NULL AND t1.y <=> t2.y AND t2.x IS NULL) ORDER BY t1.x NULLS LAST;
1 42 4 42
2 2 2 2
3 3 3 33
\N \N \N \N
SELECT * FROM t1 JOIN t2 ON (t1.x = t2.x OR t1.x IS NULL AND t2.x IS NULL) AND t1.y <=> t2.y ORDER BY t1.x NULLS LAST;
2 2 2 2
\N \N \N \N
SELECT * FROM t1 JOIN t2 ON (t1.x <=> t2.x OR t1.y <=> t2.y OR (t1.x IS NULL AND t1.y IS NULL AND t2.x IS NULL AND t2.y IS NULL)) ORDER BY t1.x NULLS LAST;
1 42 4 42
2 2 2 2
3 3 3 33
\N \N \N \N
SELECT * FROM t1 JOIN t2 ON (t1.x <=> t2.x OR (t1.x IS NULL AND t2.x IS NULL)) AND (t1.y == t2.y OR (t1.y IS NULL AND t2.y IS NULL)) AND COALESCE(t1.x, 0) != 2 ORDER BY t1.x NULLS LAST;
\N \N \N \N
SELECT x = y OR (x IS NULL AND y IS NULL) FROM t1 ORDER BY x NULLS LAST;
0
1
1
1
27 changes: 27 additions & 0 deletions tests/queries/0_stateless/02911_join_on_nullsafe_optimization.sql
@@ -0,0 +1,27 @@
DROP TABLE IF EXISTS t1;
DROP TABLE IF EXISTS t2;

CREATE TABLE t1 (x Nullable(Int64), y Nullable(UInt64)) ENGINE = TinyLog;
CREATE TABLE t2 (x Nullable(Int64), y Nullable(UInt64)) ENGINE = TinyLog;

INSERT INTO t1 VALUES (1,42), (2,2), (3,3), (NULL,NULL);
INSERT INTO t2 VALUES (NULL,NULL), (2,2), (3,33), (4,42);

SET allow_experimental_analyzer = 1;

-- { echoOn }
SELECT * FROM t1 JOIN t2 ON (t1.x <=> t2.x OR (t1.x IS NULL AND t2.x IS NULL)) ORDER BY t1.x NULLS LAST;

SELECT * FROM t1 JOIN t2 ON (t1.x <=> t2.x OR t1.x IS NULL AND t1.y <=> t2.y AND t2.x IS NULL) ORDER BY t1.x NULLS LAST;

SELECT * FROM t1 JOIN t2 ON (t1.x = t2.x OR t1.x IS NULL AND t2.x IS NULL) AND t1.y <=> t2.y ORDER BY t1.x NULLS LAST;

SELECT * FROM t1 JOIN t2 ON (t1.x <=> t2.x OR t1.y <=> t2.y OR (t1.x IS NULL AND t1.y IS NULL AND t2.x IS NULL AND t2.y IS NULL)) ORDER BY t1.x NULLS LAST;

SELECT * FROM t1 JOIN t2 ON (t1.x <=> t2.x OR (t1.x IS NULL AND t2.x IS NULL)) AND (t1.y == t2.y OR (t1.y IS NULL AND t2.y IS NULL)) AND COALESCE(t1.x, 0) != 2 ORDER BY t1.x NULLS LAST;

SELECT x = y OR (x IS NULL AND y IS NULL) FROM t1 ORDER BY x NULLS LAST;
-- { echoOff }

DROP TABLE IF EXISTS t1;
DROP TABLE IF EXISTS t2;

0 comments on commit b726455

Please sign in to comment.