diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/CodeGenerator.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/CodeGenerator.scala index 13bf50a5ef227..afef97120d6a9 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/CodeGenerator.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/CodeGenerator.scala @@ -838,6 +838,11 @@ abstract class CodeGenerator( val right = operands(1) generateEquals(nullCheck, left, right) + case IS_NOT_DISTINCT_FROM => + val left = operands.head + val right = operands(1) + generateIsNotDistinctFrom(left, right); + case NOT_EQUALS => val left = operands.head val right = operands(1) diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/calls/ScalarOperators.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/calls/ScalarOperators.scala index 282b167e91429..afa534a70caf6 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/calls/ScalarOperators.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/calls/ScalarOperators.scala @@ -191,7 +191,37 @@ object ScalarOperators { ) ) } - } + } + + def generateIsNotDistinctFrom( + left: GeneratedExpression, + right: GeneratedExpression) + : GeneratedExpression = { + val resultTerm = newName("result") + val nullTerm = newName("isNull") + val resultTypeTerm = primitiveTypeTermForTypeInfo(BOOLEAN_TYPE_INFO) + val equalExpression = generateEquals( + false, + left.copy(code = GeneratedExpression.NO_CODE), + right.copy(code = GeneratedExpression.NO_CODE)) + + val resultCode = + s""" + |${left.code} + |${right.code} + |$resultTypeTerm $resultTerm; + |if (${left.nullTerm}) { + | $resultTerm = ${right.nullTerm}; + |} else if (${right.nullTerm}) { + | $resultTerm = ${left.nullTerm}; + |} else { + | ${equalExpression.code} + | $resultTerm = ${equalExpression.resultTerm}; + |} + |""".stripMargin + + GeneratedExpression(resultTerm, GeneratedExpression.NEVER_NULL, resultCode, BOOLEAN_TYPE_INFO) + } def generateEquals( nullCheck: Boolean, diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/ScalarOperatorsTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/ScalarOperatorsTest.scala index de4c804f075b0..d61627b2c41b0 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/ScalarOperatorsTest.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/ScalarOperatorsTest.scala @@ -406,6 +406,10 @@ class ScalarOperatorsTest extends ScalarOperatorsTestBase { "((((true) === true) || false).cast(STRING) + 'X ').trim", "trueX") testTableApi(12.isNull, "12.isNull", "false") + + testSqlApi("f12 IS NOT DISTINCT FROM NULL", "true") + testSqlApi("f9 IS NOT DISTINCT FROM NULL", "false") + testSqlApi("f9 IS NOT DISTINCT FROM 10", "true") } @Test diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/batch/sql/AggregateITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/batch/sql/AggregateITCase.scala index 09ccfc47edaba..fff252f4ea4f2 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/batch/sql/AggregateITCase.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/batch/sql/AggregateITCase.scala @@ -559,4 +559,41 @@ class AggregateITCase( TestBaseUtils.compareResultAsText(result.asJava, expected) } + + @Test + def testMultipleDistinctWithDiffParams(): Unit = { + val env = ExecutionEnvironment.getExecutionEnvironment + val tEnv = TableEnvironment.getTableEnvironment(env, config) + + val sqlWithNull = "SELECT a, " + + " CASE WHEN b = 2 THEN null ELSE b END AS b, " + + " c FROM MyTable" + + val sqlQuery = + "SELECT b, " + + " COUNT(DISTINCT b), " + + " SUM(DISTINCT (a / 3)), " + + " COUNT(DISTINCT SUBSTRING(c FROM 1 FOR 2))," + + " COUNT(DISTINCT c) " + + "FROM (" + + sqlWithNull + + ") GROUP BY b " + + "ORDER BY b" + + val t = CollectionDataSets.get3TupleDataSet(env).toTable(tEnv).as('a, 'b, 'c) + tEnv.registerTable("MyTable", t) + + val result = tEnv.sqlQuery(sqlQuery).toDataSet[Row].collect() + + val expected = Seq( + "1,1,0,1,1", + "3,1,3,3,3", + "4,1,5,1,4", + "5,1,12,1,5", + "6,1,18,1,6", + "null,0,1,1,2" + ).mkString("\n") + + TestBaseUtils.compareResultAsText(result.asJava, expected) + } }