@@ -99,7 +99,7 @@ protected Expression orderKeyEqualsExpression() {
equalTo(lasValue, operand(i)));
}
Optional<Expression> ret = Arrays.stream(orderKeyEquals).reduce(ExpressionBuilder::and);
return ret.orElseGet(() -> literal(true));
return ret.orElseGet(() -> literal(false));
Contributor:

Hi, I just wonder why this value should be changed?

Author:

@xuyangzhong The default value of true will cause the sort value to be unchanged, which I don't think conforms to the semantics of sorting.
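
For context, here is a minimal standalone Scala sketch (not the actual Flink codegen; the object and function names are made up purely for illustration) of why the default matters when the order-key equality terms are empty, e.g. when ordering by a proctime attribute:

```scala
// Sketch of the rank-update semantics: "orderKeyEquals" stands in for the
// reduced order-key equality expression. A default of true makes every row
// look like a tie with the previous row, so the rank never advances; a
// default of false lets the rank advance with every incoming row.
object RankDefaultSketch {
  def ranks(rowCount: Int, orderKeyEquals: Boolean): Seq[Int] =
    (1 to rowCount).scanLeft(0) { (prevRank, rowNumber) =>
      if (orderKeyEquals) math.max(prevRank, 1) else rowNumber
    }.tail

  def main(args: Array[String]): Unit = {
    println(ranks(3, orderKeyEquals = true))  // old default: 1, 1, 1
    println(ranks(3, orderKeyEquals = false)) // new default: 1, 2, 3
  }
}
```

The second behavior is what the new ITCases further down expect: for partition key a=5, the rows receive ranks 1 through 5.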

}

protected Expression generateInitLiteral(LogicalType orderType) {
@@ -532,18 +532,23 @@ class AggFunctionFactory(
}

private def createRankAggFunction(argTypes: Array[LogicalType]): UserDefinedFunction = {
val argTypes = orderKeyIndexes.map(inputRowType.getChildren.get(_))
new RankAggFunction(argTypes)
new RankAggFunction(getArgTypesOrEmpty())
}

private def createDenseRankAggFunction(argTypes: Array[LogicalType]): UserDefinedFunction = {
val argTypes = orderKeyIndexes.map(inputRowType.getChildren.get(_))
new DenseRankAggFunction(argTypes)
new DenseRankAggFunction(getArgTypesOrEmpty())
}

private def createPercentRankAggFunction(argTypes: Array[LogicalType]): UserDefinedFunction = {
val argTypes = orderKeyIndexes.map(inputRowType.getChildren.get(_))
new PercentRankAggFunction(argTypes)
new PercentRankAggFunction(getArgTypesOrEmpty())
}

private def getArgTypesOrEmpty(): Array[LogicalType] = {
if (orderKeyIndexes != null) {
orderKeyIndexes.map(inputRowType.getChildren.get(_))
} else {
Array[LogicalType]()
}
}

private def createNTILEAggFUnction(argTypes: Array[LogicalType]): UserDefinedFunction = {
@@ -280,6 +280,50 @@ OverAggregate(partitionBy=[c], window#0=[COUNT(*) AS w0$o0 RANG BETWEEN UNBOUNDE
]]>
</Resource>
</TestCase>
<TestCase name="testDenseRankOnOrder">
<Resource name="sql">
<![CDATA[SELECT a, DENSE_RANK() OVER (PARTITION BY a ORDER BY proctime) FROM MyTableWithProctime]]>
</Resource>
<Resource name="ast">
<![CDATA[
LogicalProject(a=[$0], EXPR$1=[DENSE_RANK() OVER (PARTITION BY $0 ORDER BY $3 NULLS FIRST)])
+- LogicalTableScan(table=[[default_catalog, default_database, MyTableWithProctime, source: [TestTableSource(a, b, c, proctime)]]])
]]>
</Resource>
<Resource name="optimized exec plan">
<![CDATA[
Calc(select=[a, w0$o0 AS $1])
+- OverAggregate(partitionBy=[a], orderBy=[proctime ASC], window#0=[DENSE_RANK(*) AS w0$o0 RANG BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], select=[a, proctime, w0$o0])
+- Exchange(distribution=[forward])
+- Sort(orderBy=[a ASC, proctime ASC])
+- Exchange(distribution=[hash[a]])
+- Calc(select=[a, proctime])
+- LegacyTableSourceScan(table=[[default_catalog, default_database, MyTableWithProctime, source: [TestTableSource(a, b, c, proctime)]]], fields=[a, b, c, proctime])
]]>
</Resource>
</TestCase>
<TestCase name="testRankOnOver">
<Resource name="sql">
<![CDATA[SELECT a, RANK() OVER (PARTITION BY a ORDER BY proctime) FROM MyTableWithProctime]]>
</Resource>
<Resource name="ast">
<![CDATA[
LogicalProject(a=[$0], EXPR$1=[RANK() OVER (PARTITION BY $0 ORDER BY $3 NULLS FIRST)])
+- LogicalTableScan(table=[[default_catalog, default_database, MyTableWithProctime, source: [TestTableSource(a, b, c, proctime)]]])
]]>
</Resource>
<Resource name="optimized exec plan">
<![CDATA[
Calc(select=[a, w0$o0 AS $1])
+- OverAggregate(partitionBy=[a], orderBy=[proctime ASC], window#0=[RANK(*) AS w0$o0 RANG BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], select=[a, proctime, w0$o0])
+- Exchange(distribution=[forward])
+- Sort(orderBy=[a ASC, proctime ASC])
+- Exchange(distribution=[hash[a]])
+- Calc(select=[a, proctime])
+- LegacyTableSourceScan(table=[[default_catalog, default_database, MyTableWithProctime, source: [TestTableSource(a, b, c, proctime)]]], fields=[a, b, c, proctime])
]]>
</Resource>
</TestCase>
<TestCase name="testOverWindowWithoutPartitionBy">
<Resource name="sql">
<![CDATA[SELECT c, SUM(a) OVER (ORDER BY b) FROM MyTable]]>
@@ -31,6 +31,7 @@ class OverAggregateTest extends TableTestBase {

private val util = batchTestUtil()
util.addTableSource[(Int, Long, String)]("MyTable", 'a, 'b, 'c)
util.addTableSource[(Int, Long, String, Long)]("MyTableWithProctime", 'a, 'b, 'c, 'proctime)

@Test
def testOverWindowWithoutPartitionByOrderBy(): Unit = {
@@ -47,6 +48,18 @@
util.verifyExecPlan("SELECT c, SUM(a) OVER (ORDER BY b) FROM MyTable")
}

@Test
def testDenseRankOnOrder(): Unit = {
Contributor @xuyangzhong (May 14, 2024):

The plan tests in batch mode appear not to be able to capture the modifications introduced by this PR, and I apologize for mentioning the unit test for the plan earlier.

I've attempted this case, and it seems only possible to test it within an integration test in stream mode. I propose we add another testRankOnOver test specifically for testing the RANK() function.

Additionally, there seems to be a potential risk of encountering an NPE in AggFunctionFactory#createPercentRankAggFunction (even though PercentRank is not currently supported in streaming mode, the code suggests such a risk). Maybe we'd better address this as well?

// ----------
// Broadening the discussion a bit, perhaps we should consider fundamentally preventing the orderKeyIndexes, which isn't marked as Nullable, from being null, although this would require substantial changes 🤔. I agree that we can just avoid huge changes as demonstrated by this PR.

Contributor @snuyanzin (May 14, 2024):

> I've attempted this case, and it seems only possible to test it within an integration test in stream mode. I propose we add another testRankOnOver test specifically for testing the RANK() function.

This I didn't get: there is already testRankOnOver, would you like to add another one?

> Maybe we'd better address this as well

Yep, makes sense. I extracted the logic and reused it for `createPercentRankAggFunction` as well.
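
A simplified, self-contained sketch of the NPE risk discussed above and the null-safe fallback (Flink's LogicalType is replaced with a stand-in type; this is an illustration, not the real AggFunctionFactory):

```scala
// orderKeyIndexes can be null when a rank-like call has no resolvable
// order keys; mapping over it directly throws a NullPointerException.
object ArgTypesSketch {
  type LogicalType = String // stand-in for Flink's LogicalType, illustration only
  val inputFieldTypes: Array[LogicalType] = Array("INT", "BIGINT", "STRING")

  // Mirrors the extracted helper: fall back to an empty array instead of an NPE.
  def argTypesOrEmpty(orderKeyIndexes: Array[Int]): Array[LogicalType] =
    if (orderKeyIndexes != null) orderKeyIndexes.map(inputFieldTypes(_))
    else Array.empty[LogicalType]

  def main(args: Array[String]): Unit = {
    println(argTypesOrEmpty(Array(0, 2)).mkString(", ")) // INT, STRING
    println(argTypesOrEmpty(null).length)                // 0, no NPE
  }
}
```

Before the extraction, createPercentRankAggFunction mapped over orderKeyIndexes unconditionally, which is the risk pointed out above.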

Contributor:

What I mean is to add a testRankOnOver test in OverAggregateITCase.scala to test the RANK() function, because the plan-based unit tests here cannot cover the modified part (when I reverted the changes in AggFunctionFactory.scala, these tests still passed).

Contributor:

Besides the plans, there are already tests in OverAggregateITCase.scala for both RANK and DENSE_RANK; that's why I had this question:
https://github.com/apache/flink/pull/19797/files#diff-0fa625835219d2a3fcacb2fa8de5274ed6795f10b06035c6c9b68d11b700f9a8R198-R203

Contributor:

Oh, I missed it. Pardon me...

util.verifyExecPlan(
"SELECT a, DENSE_RANK() OVER (PARTITION BY a ORDER BY proctime) FROM MyTableWithProctime")
}

@Test
def testRankOnOver(): Unit = {
util.verifyExecPlan(
"SELECT a, RANK() OVER (PARTITION BY a ORDER BY proctime) FROM MyTableWithProctime")
}

@Test
def testDiffPartitionKeysWithSameOrderKeys(): Unit = {
val sqlQuery =
@@ -165,6 +165,66 @@ class OverAggregateITCase(mode: StateBackendMode) extends StreamingWithStateTest
assertThat(sink.getAppendResults.sorted).isEqualTo(expected.sorted)
}

@TestTemplate
def testDenseRankOnOver(): Unit = {
Contributor:
Maybe you can add some test cases for the plan as well, not just ITCases.

Author:

Hi @xuyangzhong. Thanks for your review and sorry for the delayed response. I didn't find a unit test for AggFunctionFactory, so I borrowed the test for ROW_NUMBER. I haven't waited for the community's reply; I can add corresponding unit tests later if necessary.

val t = failingDataSource(TestData.tupleData5)
.toTable(tEnv, 'a, 'b, 'c, 'd, 'e, 'proctime.proctime)
tEnv.createTemporaryView("MyTable", t)
val sqlQuery = "SELECT a, DENSE_RANK() OVER (PARTITION BY a ORDER BY proctime) FROM MyTable"

val sink = new TestingAppendSink
tEnv.sqlQuery(sqlQuery).toDataStream.addSink(sink)
env.execute()

val expected = List(
"1,1",
"2,1",
"2,2",
"3,1",
"3,2",
"3,3",
"4,1",
"4,2",
"4,3",
"4,4",
"5,1",
"5,2",
"5,3",
"5,4",
"5,5")
assertThat(sink.getAppendResults.sorted).isEqualTo(expected.sorted)
}

@TestTemplate
def testRankOnOver(): Unit = {
val t = failingDataSource(TestData.tupleData5)
.toTable(tEnv, 'a, 'b, 'c, 'd, 'e, 'proctime.proctime)
tEnv.createTemporaryView("MyTable", t)
val sqlQuery = "SELECT a, RANK() OVER (PARTITION BY a ORDER BY proctime) FROM MyTable"

val sink = new TestingAppendSink
tEnv.sqlQuery(sqlQuery).toDataStream.addSink(sink)
env.execute()

val expected = List(
"1,1",
"2,1",
"2,2",
"3,1",
"3,2",
"3,3",
"4,1",
"4,2",
"4,3",
"4,4",
"5,1",
"5,2",
"5,3",
"5,4",
"5,5")
assertThat(sink.getAppendResults.sorted).isEqualTo(expected.sorted)
}

@TestTemplate
def testProcTimeBoundedPartitionedRowsOver(): Unit = {
val t = failingDataSource(TestData.tupleData5)