diff --git a/docs/dev/table_api.md b/docs/dev/table_api.md index c29ed725be0bec..f3e81b13b742ad 100644 --- a/docs/dev/table_api.md +++ b/docs/dev/table_api.md @@ -2800,39 +2800,41 @@ value NOT IN (value [, value]* )

Whether value is not equal to every value in a list.

- + + --> diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/nodes/dataset/DataSetJoin.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/nodes/dataset/DataSetJoin.scala index bbb6325ef41ccc..6d7a30ef8f43da 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/nodes/dataset/DataSetJoin.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/nodes/dataset/DataSetJoin.scala @@ -215,7 +215,7 @@ class DataSetJoin( } private def joinTypeToString = joinType match { - case JoinRelType.INNER => "Join" + case JoinRelType.INNER => "InnerJoin" case JoinRelType.LEFT=> "LeftOuterJoin" case JoinRelType.RIGHT => "RightOuterJoin" case JoinRelType.FULL => "FullOuterJoin" diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/validate/FunctionCatalog.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/validate/FunctionCatalog.scala index 68e2f978c65ad1..679733c9389ea7 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/validate/FunctionCatalog.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/validate/FunctionCatalog.scala @@ -281,7 +281,8 @@ class BasicOperatorTable extends ReflectiveSqlOperatorTable { SqlStdOperatorTable.CAST, SqlStdOperatorTable.EXTRACT, SqlStdOperatorTable.QUARTER, - SqlStdOperatorTable.SCALAR_QUERY + SqlStdOperatorTable.SCALAR_QUERY, + SqlStdOperatorTable.EXISTS ) builtInSqlOperators.foreach(register) diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/batch/sql/SetOperatorsTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/batch/sql/SetOperatorsTest.scala new file mode 100644 index 00000000000000..5bc6e4a4c1c380 --- /dev/null +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/batch/sql/SetOperatorsTest.scala @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.api.scala.batch.sql + +import org.apache.flink.api.scala._ +import org.apache.flink.api.scala.table._ +import org.apache.flink.api.table.utils.TableTestBase +import org.apache.flink.api.table.utils.TableTestUtil._ +import org.junit.Test + +class SetOperatorsTest extends TableTestBase { + + @Test + def testExists(): Unit = { + val util = batchTestUtil() + util.addTable[(Long, Int, String)]("A", 'a_long, 'a_int, 'a_string) + util.addTable[(Long, Int, String)]("B", 'b_long, 'b_int, 'b_string) + + val expected = unaryNode( + "DataSetCalc", + binaryNode( + "DataSetJoin", + batchTableNode(0), + unaryNode( + "DataSetAggregate", + unaryNode( + "DataSetCalc", + binaryNode( + "DataSetJoin", + batchTableNode(1), + unaryNode( + "DataSetAggregate", + batchTableNode(0), + term("groupBy", "a_long"), + term("select", "a_long") + ), + term("where", "=(a_long, b_long)"), + term("join", "b_long", "b_int", "b_string", "a_long"), + term("joinType", "InnerJoin") + ), + term("select", "a_long", "true AS $f0") + ), + term("groupBy", "a_long"), + term("select", "a_long", "MIN($f0) AS $f1") + ), + term("where", "=(a_long, a_long0)"), + term("join", "a_long", "a_int", "a_string", "a_long0", "$f1"), + term("joinType", "InnerJoin") + ), + term("select", "a_int", "a_string") + ) + + util.verifySql( + "SELECT a_int, a_string FROM A WHERE EXISTS(SELECT * FROM B WHERE a_long = b_long)", + expected + ) + } + +} diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/utils/TableTestBase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/utils/TableTestBase.scala index ce693ffde597f8..2ea15a01dd8650 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/utils/TableTestBase.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/utils/TableTestBase.scala @@ -56,6 +56,10 @@ abstract class TableTestUtil { def addTable[T: TypeInformation](name: String, fields: Expression*): Table def verifySql(query: String, expected: String): Unit def verifyTable(resultTable: Table, expected: String): Unit + + // the print methods are for debugging purposes only + def printTable(resultTable: Table): Unit + def printSql(query: String): Unit } object TableTestUtil { @@ -87,6 +91,7 @@ object TableTestUtil { def streamTableNode(idx: Int): String = { s"DataStreamScan(table=[[_DataStreamTable_$idx]])" } + } case class BatchTableTestUtil() extends TableTestUtil { @@ -121,6 +126,16 @@ case class BatchTableTestUtil() extends TableTestUtil { expected.split("\n").map(_.trim).mkString("\n"), actual.split("\n").map(_.trim).mkString("\n")) } + + def printTable(resultTable: Table): Unit = { + val relNode = resultTable.getRelNode + val optimized = tEnv.optimize(relNode) + println(RelOptUtil.toString(optimized)) + } + + def printSql(query: String): Unit = { + printTable(tEnv.sql(query)) + } } case class StreamTableTestUtil() extends TableTestUtil { @@ -156,4 +171,15 @@ case class StreamTableTestUtil() extends TableTestUtil { expected.split("\n").map(_.trim).mkString("\n"), actual.split("\n").map(_.trim).mkString("\n")) } + + // the print methods are for debugging purposes only + def printTable(resultTable: Table): Unit = { + val relNode = resultTable.getRelNode + val optimized = tEnv.optimize(relNode) + println(RelOptUtil.toString(optimized)) + } + + def printSql(query: String): Unit = { + printTable(tEnv.sql(query)) + } }