Closed
@@ -218,6 +218,14 @@ private Plan inferNewPredicate(Plan plan, Set<Expression> expressions) {
  Set<Expression> predicates = new LinkedHashSet<>();
  Set<Slot> planOutputs = plan.getOutputSet();
  for (Expression expr : expressions) {
+ if (expr.containsUniqueFunction()) {
+ // Non-deterministic expressions (e.g. rand(), uuid()) must not be cloned into
+ // subtrees that did not already evaluate them. Otherwise, callers that perform
+ // slot substitution (e.g. SetOp visitors below) would introduce a fresh
+ // per-row evaluation of the unique function on a sibling branch, changing
+ // query semantics (see EXCEPT/INTERSECT regression cases).
+ continue;
+ }
  Set<Slot> slots = expr.getInputSlots();
  if (!slots.isEmpty() && planOutputs.containsAll(slots)) {
  predicates.add(expr);
@@ -242,6 +250,11 @@ private Plan inferNewPredicateRemoveUselessIsNull(Plan plan, Set<Expression> exp
  Set<Expression> predicates = new LinkedHashSet<>();
  Set<Slot> planOutputs = plan.getOutputSet();
  for (Expression expr : expressions) {
+ if (expr.containsUniqueFunction()) {
+ // See inferNewPredicate for rationale: never clone non-deterministic
+ // predicates into a subtree that did not already evaluate them.
+ continue;
+ }
  Set<Slot> slots = expr.getInputSlots();
  if (slots.isEmpty() || !planOutputs.containsAll(slots)) {
  continue;
@@ -114,6 +114,13 @@ public Expression visit(Expression expression, ReplacerContext ctx) {
  if (input.isEmpty() || expression instanceof Slot) {
  return expression;
  }
+ // A mixed expression like `t1.a + rand() > t2.b` has inputSlots={t1.a}; if we alias
+ // it into a child Project, rand()'s evaluation granularity changes from "per join
+ // pair" to "per row of that child", which silently changes results. Keep such
+ // expressions inline in otherJoinConjuncts.
+ if (expression.containsUniqueFunction()) {
Contributor

This guard stops PROJECT_OTHER_JOIN_CONDITION from pushing the expression into a child project, but the same semantic change still happens later in AddProjectForUniqueFunction.JoinRewrite. The updated regression-test/data/nereids_rules_p0/unique_function/add_project_for_unique_function.out still shows t1.id + t2.id + random(1, 100) between 10 and 20 becoming a PhysicalProject[random(...) AS ...] under the NLJ, so random() is evaluated once per left row instead of once per join pair. Please guard that later join rewrite too (or otherwise preserve per-pair evaluation), and add a regression that exercises the BETWEEN / duplicated-UniqueFunction path.

Contributor Author

Replied in the main PR comment: this is a pre-existing trade-off in AddProjectForUniqueFunction.JoinRewrite caused by BETWEEN expansion into two independent rand() calls, and is intentionally out of scope for this PR. Please see the main comment for the full reasoning.
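
For readers following the BETWEEN point, here is a minimal sketch in plain Java (not Doris code) of why it matters whether the two comparisons produced by expanding `x BETWEEN lo AND hi` share one draw. `Sequence`, `expandedIndependent` and `expandedShared` are hypothetical names, and the fixed sequence is only a deterministic stand-in for `random()`.

```java
import java.util.function.IntSupplier;

final class BetweenExpansion {
    /** Deterministic stand-in for random(): cycles through fixed values. */
    static final class Sequence implements IntSupplier {
        private final int[] values;
        private int next;

        Sequence(int... values) {
            this.values = values;
        }

        @Override
        public int getAsInt() {
            return values[next++ % values.length];
        }
    }

    /** (draw() >= lo) AND (draw() <= hi): each comparison draws on its own. */
    static boolean expandedIndependent(IntSupplier draw, int lo, int hi) {
        int first = draw.getAsInt();
        int second = draw.getAsInt();
        return first >= lo && second <= hi;
    }

    /** The draw is materialized once and both comparisons read the same value. */
    static boolean expandedShared(IntSupplier draw, int lo, int hi) {
        int value = draw.getAsInt();
        return value >= lo && value <= hi;
    }

    public static void main(String[] args) {
        System.out.println(expandedIndependent(new Sequence(25, 5), 10, 20));
        System.out.println(expandedShared(new Sequence(25, 5), 10, 20));
    }
}
```

With the draws 25 and 5, the independent form accepts the row although neither value lies in [10, 20], while the shared form rejects it. Sharing the draw keeps the BETWEEN meaning; where that shared value is computed is a separate question.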

Contributor

This guard is still too local to preserve join-pair semantics. ProjectOtherJoinConditionForNestedLoopJoin now stops aliasing the whole mixed subtree here, but later AddProjectForUniqueFunction.JoinRewrite still scans otherJoinConjuncts and hoists repeated UniqueFunctions into the left child project. The updated add_project_for_unique_function.out already shows that for t1.id + t2.id + random() between 10 and 20: $_random_7_$ is materialized on the left side, so the same random draw is reused for every right row of a given left row. That is still different from evaluating the ON predicate once per join pair.

Can we also block the later join-side materialization for unique functions that live inside mixed-side otherJoinConjuncts (or recurse here and only skip aliasing the exact unique-function subtree)? Otherwise this PR still leaves wrong-results cases in NLJ.
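
To make the granularity difference concrete, here is a small self-contained simulation in plain Java. It is an illustration under stated assumptions, not Doris code: `CountingDraw`, `perPair` and `perLeftRow` are hypothetical names, and the cycling sequence stands in for `random()` so the outcome is reproducible.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.IntSupplier;

final class JoinDrawGranularity {
    /** Deterministic stand-in for random(): cycles through fixed values and counts calls. */
    static final class CountingDraw implements IntSupplier {
        private final int[] values;
        private int calls;

        CountingDraw(int... values) {
            this.values = values;
        }

        @Override
        public int getAsInt() {
            return values[calls++ % values.length];
        }

        int calls() {
            return calls;
        }
    }

    /** ON l + r + draw() BETWEEN lo AND hi, with a fresh draw for every join pair. */
    static List<int[]> perPair(int[] left, int[] right, IntSupplier draw, int lo, int hi) {
        List<int[]> out = new ArrayList<>();
        for (int l : left) {
            for (int r : right) {
                int v = l + r + draw.getAsInt();
                if (v >= lo && v <= hi) {
                    out.add(new int[] {l, r});
                }
            }
        }
        return out;
    }

    /** Same predicate after the draw is hoisted into a left-side project: one draw per left row. */
    static List<int[]> perLeftRow(int[] left, int[] right, IntSupplier draw, int lo, int hi) {
        List<int[]> out = new ArrayList<>();
        for (int l : left) {
            int d = draw.getAsInt(); // reused for every right row of this left row
            for (int r : right) {
                int v = l + r + d;
                if (v >= lo && v <= hi) {
                    out.add(new int[] {l, r});
                }
            }
        }
        return out;
    }

    public static void main(String[] args) {
        int[] left = {1, 2};
        int[] right = {1, 2, 3};
        CountingDraw a = new CountingDraw(0, 100);
        CountingDraw b = new CountingDraw(0, 100);
        int pairRows = perPair(left, right, a, 0, 10).size();
        int leftRows = perLeftRow(left, right, b, 0, 10).size();
        System.out.println("per pair: " + pairRows + " rows, " + a.calls() + " draws");
        System.out.println("per left row: " + leftRows + " rows, " + b.calls() + " draws");
    }
}
```

In this run both variants return the same number of rows, but the per-left-row variant accepts or rejects every right row of a given left row together, which is the correlation described above.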

Contributor Author

Replied in the main PR comment: this is a pre-existing trade-off in AddProjectForUniqueFunction.JoinRewrite caused by BETWEEN expansion into two independent rand() calls, and is intentionally out of scope for this PR. Please see the main comment for the full reasoning.

+ return expression;
+ }
  if (ctx.leftSlots.containsAll(input)) {
  Alias alias = ctx.aliasMap.computeIfAbsent(expression, o -> new Alias(o));
  ctx.leftAlias.add(alias);
@@ -27,6 +27,7 @@
  import org.apache.doris.nereids.trees.expressions.SlotReference;
  import org.apache.doris.nereids.trees.plans.Plan;
  import org.apache.doris.nereids.trees.plans.algebra.EmptyRelation;
+ import org.apache.doris.nereids.trees.plans.algebra.SetOperation.Qualifier;
  import org.apache.doris.nereids.trees.plans.logical.LogicalEmptyRelation;
  import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
  import org.apache.doris.nereids.trees.plans.logical.LogicalOneRowRelation;
@@ -44,6 +45,7 @@

  import java.util.ArrayList;
  import java.util.HashMap;
+ import java.util.LinkedHashSet;
  import java.util.List;
  import java.util.Map;
  import java.util.Optional;
@@ -61,11 +63,47 @@ public Rule build() {
  .when(s -> s.arity() > 0
  || (s instanceof LogicalUnion && !((LogicalUnion) s).getConstantExprsList().isEmpty())))
  .thenApply(ctx -> {
- LogicalFilter<LogicalSetOperation> filter = ctx.root;
- LogicalSetOperation setOperation = filter.child();
+ LogicalFilter<LogicalSetOperation> origFilter = ctx.root;
+ LogicalSetOperation setOperation = origFilter.child();
+
+ // Pushing a conjunct that contains a UniqueFunction (rand/uuid/random_bytes/...)
+ // into each branch changes semantics for every set-op except UNION ALL.
+ // - UNION ALL: each branch row = exactly one output row (1:1), so evaluating
+ // rand() once per branch row still matches the per-output-row semantic.
+ // - UNION DISTINCT / INTERSECT / EXCEPT: the set-op semantics depend on the
+ // full branch row sets before dedup/intersect/except. Sampling rows in each
+ // branch independently changes which rows participate (e.g. INTERSECT becomes
+ // "half of A intersect half of B" instead of "half of (A intersect B)").
+ boolean canPushUniqueFn = setOperation instanceof LogicalUnion
+ && setOperation.getQualifier() == Qualifier.ALL;
+ Set<Expression> pushableConjuncts;
+ Set<Expression> keptAboveConjuncts;
+ if (canPushUniqueFn) {
+ pushableConjuncts = origFilter.getConjuncts();
+ keptAboveConjuncts = ImmutableSet.of();
+ } else {
+ pushableConjuncts = new LinkedHashSet<>();
+ Set<Expression> kept = new LinkedHashSet<>();
+ for (Expression c : origFilter.getConjuncts()) {
+ if (c.containsUniqueFunction()) {
+ kept.add(c);
+ } else {
+ pushableConjuncts.add(c);
+ }
+ }
+ keptAboveConjuncts = kept;
+ if (pushableConjuncts.isEmpty()) {
+ return null;
+ }
+ }
+ LogicalFilter<LogicalSetOperation> filter = pushableConjuncts == origFilter.getConjuncts()
+ ? origFilter
+ : new LogicalFilter<>(ImmutableSet.copyOf(pushableConjuncts), setOperation);
+
  List<Plan> newChildren = new ArrayList<>();
  List<List<SlotReference>> newRegularChildrenOutputs = Lists.newArrayList();
  CascadesContext cascadesContext = ctx.cascadesContext;
+ Plan rewritten;
  if (setOperation instanceof LogicalUnion) {
  List<List<NamedExpression>> constantExprs = ((LogicalUnion) setOperation).getConstantExprsList();
  StatementContext statementContext = ctx.statementContext;
@@ -85,7 +123,7 @@ public Rule build() {

  List<NamedExpression> setOutputs = setOperation.getOutputs();
  if (newChildren.isEmpty() && newConstantExprs.isEmpty()) {
- return new LogicalEmptyRelation(
+ rewritten = new LogicalEmptyRelation(
  statementContext.getNextRelationId(), setOutputs
  );
  } else if (newChildren.isEmpty() && newConstantExprs.size() == 1) {
@@ -104,27 +142,32 @@ public Rule build() {
  }
  newOneRowRelationOutput.add(oneRowRelationOutput);
  }
- return new LogicalOneRowRelation(
+ rewritten = new LogicalOneRowRelation(
  ctx.statementContext.getNextRelationId(), newOneRowRelationOutput.build()
  );
- }
+ } else {
+ Builder<List<SlotReference>> newChildrenOutput
+ = ImmutableList.builderWithExpectedSize(newChildren.size());
+ for (Plan newChild : newChildren) {
+ newChildrenOutput.add((List) newChild.getOutput());
+ }

- Builder<List<SlotReference>> newChildrenOutput
- = ImmutableList.builderWithExpectedSize(newChildren.size());
- for (Plan newChild : newChildren) {
- newChildrenOutput.add((List) newChild.getOutput());
+ rewritten = ((LogicalUnion) setOperation).withChildrenAndConstExprsList(
+ newChildren, newRegularChildrenOutputs, newConstantExprs);
  }
-
- return ((LogicalUnion) setOperation).withChildrenAndConstExprsList(
- newChildren, newRegularChildrenOutputs, newConstantExprs);
+ } else {
+ addFiltersToNewChildren(setOperation, filter, setOperation.children(),
+ setOperation.getRegularChildrenOutputs(),
+ cascadesContext, newChildren, newRegularChildrenOutputs, null,
+ (rowIndex, columnIndex) -> setOperation.getRegularChildOutput(rowIndex).get(columnIndex),
+ Function.identity());
+ rewritten = setOperation.withChildren(newChildren);
  }

- addFiltersToNewChildren(setOperation, filter, setOperation.children(),
- setOperation.getRegularChildrenOutputs(),
- cascadesContext, newChildren, newRegularChildrenOutputs, null,
- (rowIndex, columnIndex) -> setOperation.getRegularChildOutput(rowIndex).get(columnIndex),
- Function.identity());
- return setOperation.withChildren(newChildren);
+ if (keptAboveConjuncts.isEmpty()) {
+ return rewritten;
+ }
+ return new LogicalFilter<>(ImmutableSet.copyOf(keptAboveConjuncts), rewritten);
  }).toRule(RuleType.PUSH_DOWN_FILTER_THROUGH_SET_OPERATION);
  }
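
The conjunct split introduced in this rule can be sketched in isolation as follows. This is a simplified model, not the Doris implementation: `Conjunct`, `Split` and `split` are hypothetical names, and the boolean flag stands in for `Expression.containsUniqueFunction()`.

```java
import java.util.Arrays;
import java.util.LinkedHashSet;
import java.util.Set;

final class SplitConjuncts {
    /** Hypothetical stand-in for an expression plus its containsUniqueFunction() flag. */
    static final class Conjunct {
        final String sql;
        final boolean containsUniqueFunction;

        Conjunct(String sql, boolean containsUniqueFunction) {
            this.sql = sql;
            this.containsUniqueFunction = containsUniqueFunction;
        }

        @Override
        public String toString() {
            return sql;
        }
    }

    /** Result of the split: what goes into the branches and what stays above the set operation. */
    static final class Split {
        final Set<Conjunct> pushable = new LinkedHashSet<>();
        final Set<Conjunct> keptAbove = new LinkedHashSet<>();
    }

    /**
     * For UNION ALL every conjunct may be pushed into the branches. For any other
     * set operation, conjuncts containing a unique function stay above it.
     */
    static Split split(Iterable<Conjunct> conjuncts, boolean isUnionAll) {
        Split result = new Split();
        for (Conjunct c : conjuncts) {
            if (!isUnionAll && c.containsUniqueFunction) {
                result.keptAbove.add(c);
            } else {
                result.pushable.add(c);
            }
        }
        return result;
    }

    public static void main(String[] args) {
        Conjunct idEq = new Conjunct("id = 1", false);
        Conjunct rand = new Conjunct("random() > 0.1", true);
        Split intersect = split(Arrays.asList(idEq, rand), false);
        System.out.println("pushable=" + intersect.pushable + " keptAbove=" + intersect.keptAbove);
    }
}
```

In the rule itself, an empty pushable set means there is nothing to push, so the rule returns null and leaves the plan unchanged; the sketch leaves that decision to the caller.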

@@ -179,13 +179,13 @@ PhysicalResultSink
  -- !join_1 --
  PhysicalResultSink
  --PhysicalProject[t1.id, t1.msg, t2.id, t2.msg]
- ----NestedLoopJoin[INNER_JOIN](((cast(id as BIGINT) + cast(id as BIGINT)) + $_random_9_$) >= 10)(((cast(id as BIGINT) + cast(id as BIGINT)) + $_random_9_$) <= 20)
- ------PhysicalProject[cast(id as BIGINT) AS `cast(id as BIGINT)`, random(1, 100) AS `$_random_9_$`, t1.id, t1.msg]
- --------filter(($_random_10_$ <= 10) and ($_random_10_$ >= 1))
- ----------PhysicalProject[random(1, 100) AS `$_random_10_$`, t1.id, t1.msg]
+ ----NestedLoopJoin[INNER_JOIN](((cast(id as BIGINT) + cast(id as BIGINT)) + $_random_7_$) >= 10)(((cast(id as BIGINT) + cast(id as BIGINT)) + $_random_7_$) <= 20)
+ ------PhysicalProject[random(1, 100) AS `$_random_7_$`, t1.id, t1.msg]
+ --------filter(($_random_8_$ <= 10) and ($_random_8_$ >= 1))
+ ----------PhysicalProject[random(1, 100) AS `$_random_8_$`, t1.id, t1.msg]
  ------------PhysicalOlapScan[t1]
- ------PhysicalProject[cast(id as BIGINT) AS `cast(id as BIGINT)`, t2.id, t2.msg]
- --------filter(((cast(id as BIGINT) * $_random_11_$) <= 200) and ((cast(id as BIGINT) * $_random_11_$) >= 100))
- ----------PhysicalProject[random(1, 100) AS `$_random_11_$`, t2.id, t2.msg]
+ ------PhysicalProject[t2.id, t2.msg]
+ --------filter(((cast(id as BIGINT) * $_random_9_$) <= 200) and ((cast(id as BIGINT) * $_random_9_$) >= 100))
+ ----------PhysicalProject[random(1, 100) AS `$_random_9_$`, t2.id, t2.msg]
  ------------PhysicalOlapScan[t2]

@@ -0,0 +1,47 @@
-- This file is automatically generated. You should know what you did if you want to edit this
-- !except_with_rand --
PhysicalResultSink
--PhysicalExcept
----PhysicalProject[t1.id]
------filter(((cast(id as DOUBLE) + random()) > 5.0))
--------PhysicalOlapScan[t1]
----PhysicalProject[t2.id]
------PhysicalOlapScan[t2]

-- !intersect_with_rand --
PhysicalResultSink
--PhysicalIntersect
----PhysicalProject[t1.id]
------filter(((cast(id as DOUBLE) + random()) > 5.0))
--------PhysicalOlapScan[t1]
----PhysicalProject[t2.id]
------PhysicalOlapScan[t2]

-- !except_with_uuid --
PhysicalResultSink
--PhysicalExcept
----PhysicalProject[t1.id]
------filter(((uuid_to_int(uuid()) + cast(id as LARGEINT)) > 5))
--------PhysicalOlapScan[t1]
----PhysicalProject[t2.id]
------PhysicalOlapScan[t2]

-- !intersect_with_uuid --
PhysicalResultSink
--PhysicalIntersect
----PhysicalProject[t1.id]
------filter(((uuid_to_int(uuid()) + cast(id as LARGEINT)) > 5))
--------PhysicalOlapScan[t1]
----PhysicalProject[t2.id]
------PhysicalOlapScan[t2]

-- !except_deterministic --
PhysicalResultSink
--PhysicalExcept
----PhysicalProject[t1.id]
------filter((t1.id > 4))
--------PhysicalOlapScan[t1]
----PhysicalProject[t2.id]
------filter((t2.id > 4))
--------PhysicalOlapScan[t2]

@@ -0,0 +1,25 @@
-- This file is automatically generated. You should know what you did if you want to edit this
-- !left_join_rand_one_side --
PhysicalResultSink
--NestedLoopJoin[LEFT_OUTER_JOIN]((cast(id as DOUBLE) + random()) < cast(id as DOUBLE))
----PhysicalProject[t1.id]
------PhysicalOlapScan[t1]
----PhysicalProject[t2.id]
------PhysicalOlapScan[t2]

-- !cross_rand_one_side --
PhysicalResultSink
--NestedLoopJoin[INNER_JOIN]((cast(id as DOUBLE) + random()) > cast(id as DOUBLE))
----PhysicalProject[t1.id]
------PhysicalOlapScan[t1]
----PhysicalProject[t2.id]
------PhysicalOlapScan[t2]

-- !cross_rand_both_sides --
PhysicalResultSink
--NestedLoopJoin[INNER_JOIN]((cast(id as DOUBLE) + random()) > (cast(id as DOUBLE) + random()))
----PhysicalProject[t1.id]
------PhysicalOlapScan[t1]
----PhysicalProject[t2.id]
------PhysicalOlapScan[t2]

@@ -0,0 +1,74 @@
-- This file is automatically generated. You should know what you did if you want to edit this
-- !union_distinct_keep_rand --
PhysicalResultSink
--filter((random() > 0.1))
----hashAgg[GLOBAL]
------hashAgg[LOCAL]
--------PhysicalUnion
----------PhysicalProject[t1.id]
------------PhysicalOlapScan[t1]
----------PhysicalProject[t2.id]
------------PhysicalOlapScan[t2]

-- !intersect_keep_rand --
PhysicalResultSink
--filter((random() > 0.1))
----PhysicalIntersect
------PhysicalProject[t1.id]
--------PhysicalOlapScan[t1]
------PhysicalProject[t2.id]
--------PhysicalOlapScan[t2]

-- !except_keep_rand --
PhysicalResultSink
--filter((random() > 0.1))
----PhysicalExcept
------PhysicalProject[t1.id]
--------PhysicalOlapScan[t1]
------PhysicalProject[t2.id]
--------PhysicalOlapScan[t2]

-- !union_all_push_rand --
PhysicalResultSink
--PhysicalUnion
----PhysicalProject[t1.id]
------filter((random() > 0.1))
--------PhysicalOlapScan[t1]
----PhysicalProject[t2.id]
------filter((random() > 0.1))
--------PhysicalOlapScan[t2]

-- !union_distinct_split --
PhysicalResultSink
--filter((random() > 0.1))
----PhysicalLimit[GLOBAL]
------PhysicalUnion
--------PhysicalProject[t1.id]
----------filter((t1.id = 1))
------------PhysicalOlapScan[t1]
--------PhysicalProject[t2.id]
----------filter((t2.id = 1))
------------PhysicalOlapScan[t2]

-- !intersect_split --
PhysicalResultSink
--filter((random() > 0.1))
----PhysicalIntersect
------PhysicalProject[t1.id]
--------filter((t1.id = 1))
----------PhysicalOlapScan[t1]
------PhysicalProject[t2.id]
--------filter((t2.id = 1))
----------PhysicalOlapScan[t2]

-- !except_split --
PhysicalResultSink
--filter((random() > 0.1))
----PhysicalExcept
------PhysicalProject[t1.id]
--------filter((t1.id = 1))
----------PhysicalOlapScan[t1]
------PhysicalProject[t2.id]
--------filter((t2.id = 1))
----------PhysicalOlapScan[t2]
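
The expected plans above keep `random() > 0.1` above every set operation except UNION ALL. Here is a small deterministic sketch, in plain Java rather than Doris code, of why pushing such a filter into INTERSECT branches changes the result. The `alternating` predicate is a contrived stand-in for a non-deterministic function (it keeps every other row it is asked about), and all names are hypothetical.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Set;
import java.util.function.BooleanSupplier;

final class IntersectSampling {
    /** Stand-in for a non-deterministic predicate: answers true, false, true, ... per evaluation. */
    static BooleanSupplier alternating() {
        boolean[] last = {false};
        return () -> {
            last[0] = !last[0];
            return last[0];
        };
    }

    /** Evaluates the predicate once per input row, in input order. */
    static Set<Integer> filter(List<Integer> rows, BooleanSupplier keep) {
        Set<Integer> out = new LinkedHashSet<>();
        for (Integer row : rows) {
            if (keep.getAsBoolean()) {
                out.add(row);
            }
        }
        return out;
    }

    /** Filter kept above the set operation: sample from (A intersect B). */
    static Set<Integer> filterAboveIntersect(List<Integer> a, List<Integer> b) {
        Set<Integer> both = new LinkedHashSet<>(a);
        both.retainAll(b);
        return filter(new ArrayList<>(both), alternating());
    }

    /** Filter pushed into both branches: (sample of A) intersect (independent sample of B). */
    static Set<Integer> filterPushedIntoBranches(List<Integer> a, List<Integer> b) {
        Set<Integer> left = filter(a, alternating());
        Set<Integer> right = filter(b, alternating());
        left.retainAll(right);
        return left;
    }

    public static void main(String[] args) {
        List<Integer> a = Arrays.asList(1, 2, 3, 4);
        List<Integer> b = Arrays.asList(2, 1, 4, 3);
        System.out.println("filter above intersect: " + filterAboveIntersect(a, b));
        System.out.println("filter pushed into branches: " + filterPushedIntoBranches(a, b));
    }
}
```

Because each branch samples independently, a row survives the pushed-down form only if both branches happen to keep it, so the pushed-down result can be smaller than, or disjoint from, the result of filtering after the intersection.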
