Skip to content

Commit

Permalink
[fix](Nereids) could not run query with repeat node in cte (#26330)
Browse files Browse the repository at this point in the history
ExpressionDeepCopier not process VirtualReference, so we generate inline
plan with mistake.
  • Loading branch information
morrySnow committed Nov 3, 2023
1 parent 9243de1 commit a89477e
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,15 @@
import org.apache.doris.nereids.trees.expressions.ScalarSubquery;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.StatementScopeIdGenerator;
import org.apache.doris.nereids.trees.expressions.VirtualSlotReference;
import org.apache.doris.nereids.trees.expressions.functions.scalar.GroupingScalarFunction;
import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter;
import org.apache.doris.nereids.trees.plans.algebra.Repeat.GroupingSetShapes;
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;

import com.google.common.base.Function;

import java.util.List;
import java.util.Map;
import java.util.Optional;
Expand Down Expand Up @@ -75,6 +81,30 @@ public Expression visitSlotReference(SlotReference slotReference, DeepCopierCont
}
}

@Override
public Expression visitVirtualReference(VirtualSlotReference virtualSlotReference, DeepCopierContext context) {
Map<ExprId, ExprId> exprIdReplaceMap = context.exprIdReplaceMap;
ExprId newExprId;
if (exprIdReplaceMap.containsKey(virtualSlotReference.getExprId())) {
newExprId = exprIdReplaceMap.get(virtualSlotReference.getExprId());
} else {
newExprId = StatementScopeIdGenerator.newExprId();
}
// according to VirtualReference generating logic in Repeat.java
// generateVirtualGroupingIdSlot and generateVirtualSlotByFunction
Optional<GroupingScalarFunction> newOriginExpression = virtualSlotReference.getOriginExpression()
.map(func -> (GroupingScalarFunction) func.accept(this, context));
Function<GroupingSetShapes, List<Long>> newFunction = newOriginExpression
.<Function<GroupingSetShapes, List<Long>>>map(f -> f::computeVirtualSlotValue)
.orElseGet(() -> GroupingSetShapes::computeVirtualGroupingIdValue);
VirtualSlotReference newOne = new VirtualSlotReference(newExprId,
virtualSlotReference.getName(), virtualSlotReference.getDataType(),
virtualSlotReference.nullable(), virtualSlotReference.getQualifier(),
newOriginExpression, newFunction);
exprIdReplaceMap.put(virtualSlotReference.getExprId(), newOne.getExprId());
return newOne;
}

@Override
public Expression visitExistsSubquery(Exists exists, DeepCopierContext context) {
LogicalPlan logicalPlan = LogicalPlanDeepCopier.INSTANCE.deepCopy(exists.getQueryPlan(), context);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,13 +60,13 @@ default List<Expression> getGroupByExpressions() {

static VirtualSlotReference generateVirtualGroupingIdSlot() {
return new VirtualSlotReference(COL_GROUPING_ID, BigIntType.INSTANCE, Optional.empty(),
shapes -> shapes.computeVirtualGroupingIdValue());
GroupingSetShapes::computeVirtualGroupingIdValue);
}

static VirtualSlotReference generateVirtualSlotByFunction(GroupingScalarFunction function) {
return new VirtualSlotReference(
generateVirtualSlotName(function), function.getDataType(), Optional.of(function),
shapes -> function.computeVirtualSlotValue(shapes));
function::computeVirtualSlotValue);
}

/**
Expand Down Expand Up @@ -175,7 +175,7 @@ default List<Set<Integer>> getGroupingSetsIndexesInOutput() {
if (index == null) {
throw new AnalysisException("Can not find grouping set expression in output: " + expression);
}
if (groupingSetsIndex.contains(index)) {
if (groupingSetIndex.contains(index)) {
throw new AnalysisException("expression duplicate in grouping set: " + expression);
}
groupingSetIndex.add(index);
Expand Down Expand Up @@ -228,14 +228,6 @@ public GroupingSetShapes(Set<Expression> flattenGroupingSetExpression, List<Grou
this.shapes = ImmutableList.copyOf(shapes);
}

public GroupingSetShape getGroupingSetShape(int index) {
return shapes.get(index);
}

public Expression getExpression(int index) {
return flattenGroupingSetExpression.get(index);
}

// compute a long value that backend need to fill to the GROUPING_ID slot
public List<Long> computeVirtualGroupingIdValue() {
return shapes.stream()
Expand Down
11 changes: 11 additions & 0 deletions regression-test/data/nereids_syntax_p0/cte.out
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,17 @@ ASIA 1
15
29

-- !cte_with_repeat --
\N \N 1
\N 1 1
\N 2 1
\N 6 1
1309892 \N 0
1309892 1 0
1309892 2 0
1310179 \N 0
1310179 6 0

-- !test --
1 2023-08-25 00:00:00 10 10
1 2023-08-25 01:00:00 20 30
Expand Down
7 changes: 6 additions & 1 deletion regression-test/suites/nereids_syntax_p0/cte.groovy
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,7 @@ suite("cte") {
ORDER BY dd.s_suppkey;
"""

sql "set experimental_enable_pipeline_engine=true"
sql "set enable_pipeline_engine=true"

qt_cte14 """
SELECT abs(dd.s_suppkey)
Expand Down Expand Up @@ -308,6 +308,11 @@ suite("cte") {

sql "WITH cte_0 AS ( SELECT 1 AS a ) SELECT * from cte_0 t1 LIMIT 10 UNION SELECT * from cte_0 t1 LIMIT 10"

qt_cte_with_repeat """
with cte_0 as (select lo_orderkey, lo_linenumber, grouping_id(lo_orderkey) as id from lineorder group by cube(lo_orderkey, lo_linenumber))
select * from cte_0 order by lo_orderkey, lo_linenumber, id
"""

qt_test """
SELECT * FROM (
WITH temptable as (
Expand Down

0 comments on commit a89477e

Please sign in to comment.