Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -250,8 +250,9 @@ class DecomposeGroupingSetsRule extends RelOptRule(
val newAggCalls = aggCallsWithIndexes.collect {
case (aggCall, idx) if !groupIdExprs.contains(idx) =>
val newArgList = aggCall.getArgList.map(a => duplicateFieldMap.getOrElse(a, a)).toList
val newFilterArg = duplicateFieldMap.getOrDefault(aggCall.filterArg, aggCall.filterArg)
aggCall.adaptTo(
relBuilder.peek(), newArgList, aggCall.filterArg, agg.getGroupCount, newGroupCount)
relBuilder.peek(), newArgList, newFilterArg, agg.getGroupCount, newGroupCount)
}

// create simple aggregate
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,12 +62,17 @@ object ExpandUtil {
val commonGroupSet = groupSets.asList().reduce((g1, g2) => g1.intersect(g2)).asList()
val duplicateFieldIndexes = aggCalls.zipWithIndex.flatMap {
case (aggCall, idx) =>
// filterArg should also be considered here.
val allArgList = new util.ArrayList[Integer](aggCall.getArgList)
if (aggCall.filterArg > -1) {
allArgList.add(aggCall.filterArg)
}
if (groupIdExprs.contains(idx)) {
List.empty[Integer]
} else if (commonGroupSet.containsAll(aggCall.getArgList)) {
} else if (commonGroupSet.containsAll(allArgList)) {
List.empty[Integer]
} else {
aggCall.getArgList.diff(commonGroupSet)
allArgList.diff(commonGroupSet)
}
}.intersect(groupSet.asList()).sorted.toArray[Integer]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -667,4 +667,32 @@ Calc(select=[a, CAST(EXPR$1) AS EXPR$1, EXPR$2, EXPR$3])
]]>
</Resource>
</TestCase>
<TestCase name="testDistinctAggWithDuplicateFilterField">
<Resource name="sql">
<![CDATA[SELECT a, COUNT(c) FILTER (WHERE b > 1),
COUNT(DISTINCT d) FILTER (WHERE b > 1) FROM MyTable2 GROUP BY a]]>
</Resource>
<Resource name="planBefore">
<![CDATA[
LogicalAggregate(group=[{0}], EXPR$1=[COUNT($1) FILTER $2], EXPR$2=[COUNT(DISTINCT $3) FILTER $2])
+- LogicalProject(a=[$0], c=[$2], $f2=[IS TRUE(>($1, 1))], d=[$3])
+- LogicalTableScan(table=[[default_catalog, default_database, MyTable2, source: [TestTableSource(a, b, c, d, e)]]])
]]>
</Resource>
<Resource name="planAfter">
<![CDATA[
Calc(select=[a, CAST(EXPR$1) AS EXPR$1, EXPR$2])
+- HashAggregate(isMerge=[true], groupBy=[a], select=[a, Final_MIN(min$0) AS EXPR$1, Final_COUNT(count$1) AS EXPR$2])
+- Exchange(distribution=[hash[a]])
+- LocalHashAggregate(groupBy=[a], select=[a, Partial_MIN(EXPR$1) FILTER $g_3 AS min$0, Partial_COUNT(d) FILTER $g_0 AS count$1])
+- Calc(select=[a, d, EXPR$1, AND(=(CASE(=($e, 0:BIGINT), 0:BIGINT, 3:BIGINT), 0), IS TRUE(CAST($f2))) AS $g_0, =(CASE(=($e, 0:BIGINT), 0:BIGINT, 3:BIGINT), 3) AS $g_3])
+- HashAggregate(isMerge=[true], groupBy=[a, $f2, d, $e], select=[a, $f2, d, $e, Final_COUNT(count$0) AS EXPR$1])
+- Exchange(distribution=[hash[a, $f2, d, $e]])
+- LocalHashAggregate(groupBy=[a, $f2, d, $e], select=[a, $f2, d, $e, Partial_COUNT(c) FILTER $f2_0 AS count$0])
+- Expand(projects=[a, c, $f2, d, $e, $f2_0], projects=[{a, c, $f2, d, 0 AS $e, $f2 AS $f2_0}, {a, c, null AS $f2, null AS d, 3 AS $e, $f2 AS $f2_0}])
+- Calc(select=[a, c, IS TRUE(>(b, 1)) AS $f2, d])
+- LegacyTableSourceScan(table=[[default_catalog, default_database, MyTable2, source: [TestTableSource(a, b, c, d, e)]]], fields=[a, b, c, d, e])
]]>
</Resource>
</TestCase>
</Root>
Original file line number Diff line number Diff line change
Expand Up @@ -610,4 +610,28 @@ FlinkLogicalCalc(select=[a, CAST(EXPR$1) AS EXPR$1, EXPR$2, EXPR$3])
]]>
</Resource>
</TestCase>
<TestCase name="testDistinctAggWithDuplicateFilterField">
<Resource name="sql">
<![CDATA[SELECT a, COUNT(c) FILTER (WHERE b > 1),
COUNT(DISTINCT d) FILTER (WHERE b > 1) FROM MyTable2 GROUP BY a]]>
</Resource>
<Resource name="planBefore">
<![CDATA[
LogicalAggregate(group=[{0}], EXPR$1=[COUNT($1) FILTER $2], EXPR$2=[COUNT(DISTINCT $3) FILTER $2])
+- LogicalProject(a=[$0], c=[$2], $f2=[IS TRUE(>($1, 1))], d=[$3])
+- LogicalTableScan(table=[[default_catalog, default_database, MyTable2, source: [TestTableSource(a, b, c, d, e)]]])
]]>
</Resource>
<Resource name="planAfter">
<![CDATA[
FlinkLogicalCalc(select=[a, CAST(EXPR$1) AS EXPR$1, EXPR$2])
+- FlinkLogicalAggregate(group=[{0}], EXPR$1=[MIN($2) FILTER $4], EXPR$2=[COUNT($1) FILTER $3])
+- FlinkLogicalCalc(select=[a, d, EXPR$1, AND(=(CASE(=($e, 0:BIGINT), 0:BIGINT, 3:BIGINT), 0), IS TRUE(CAST($f2))) AS $g_0, =(CASE(=($e, 0:BIGINT), 0:BIGINT, 3:BIGINT), 3) AS $g_3])
+- FlinkLogicalAggregate(group=[{0, 2, 3, 4}], EXPR$1=[COUNT($1) FILTER $5])
+- FlinkLogicalExpand(projects=[a, c, $f2, d, $e, $f2_0])
+- FlinkLogicalCalc(select=[a, c, IS TRUE(>(b, 1)) AS $f2, d])
+- FlinkLogicalLegacyTableSourceScan(table=[[default_catalog, default_database, MyTable2, source: [TestTableSource(a, b, c, d, e)]]], fields=[a, b, c, d, e])
]]>
</Resource>
</TestCase>
</Root>
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,13 @@ abstract class DistinctAggregateTestBase extends TableTestBase {
util.verifyPlan(sqlQuery)
}

@Test
def testDistinctAggWithDuplicateFilterField(): Unit = {
  // Plan test for a query in which a plain aggregate and a DISTINCT aggregate
  // share the same FILTER condition (b > 1) on the grouped table MyTable2.
  val sqlQuery =
    """SELECT a, COUNT(c) FILTER (WHERE b > 1),
      |COUNT(DISTINCT d) FILTER (WHERE b > 1) FROM MyTable2 GROUP BY a""".stripMargin
  util.verifyPlan(sqlQuery)
}

@Test(expected = classOf[RuntimeException])
def testTooManyDistinctAggOnDifferentColumn(): Unit = {
// max group count must be less than 64
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -136,9 +136,9 @@ abstract class AggregateITCaseBase(testName: String) extends BatchTestBase {
checkResult(
sql,
Seq(
row(1,0,1,4),
row(2,0,0,7),
row(3,0,0,3)
row(1,4,1,4),
row(2,7,0,7),
row(3,3,0,3)
)
)
}
Expand Down