Skip to content

Commit

Permalink
[FLINK-13433][table-planner-blink] Do not fetch data from LookupableT…
Browse files Browse the repository at this point in the history
…ableSource if the JoinKey in left side of LookupJoin contains null value

This closes #9285
  • Loading branch information
beyond1920 authored and wuchong committed Aug 7, 2019
1 parent 16f519a commit 8b2680a
Show file tree
Hide file tree
Showing 5 changed files with 240 additions and 93 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ object LookupJoinCodeGenerator {
: GeneratedFunction[FlatMapFunction[BaseRow, BaseRow]] = {

val ctx = CodeGeneratorContext(config)
val (prepareCode, parameters) = prepareParameters(
val (prepareCode, parameters, nullInParameters) = prepareParameters(
ctx,
typeFactory,
inputType,
Expand All @@ -87,11 +87,17 @@ object LookupJoinCodeGenerator {
s"$lookupFunctionTerm.setCollector($DEFAULT_COLLECTOR_TERM);"
}

// TODO: filter all records when there is any nulls on the join key, because
// "IS NOT DISTINCT FROM" is not supported yet.
val body =
s"""
|$prepareCode
|$setCollectorCode
|$lookupFunctionTerm.eval($parameters);
|if ($nullInParameters) {
| return;
|} else {
| $lookupFunctionTerm.eval($parameters);
| }
""".stripMargin

FunctionCodeGenerator.generateFunction(
Expand All @@ -118,7 +124,7 @@ object LookupJoinCodeGenerator {
: GeneratedFunction[AsyncFunction[BaseRow, AnyRef]] = {

val ctx = CodeGeneratorContext(config)
val (prepareCode, parameters) = prepareParameters(
val (prepareCode, parameters, nullInParameters) = prepareParameters(
ctx,
typeFactory,
inputType,
Expand All @@ -130,11 +136,18 @@ object LookupJoinCodeGenerator {
val lookupFunctionTerm = ctx.addReusableFunction(asyncLookupFunction)
val DELEGATE = className[DelegatingResultFuture[_]]

// TODO: filter all records when there is any nulls on the join key, because
// "IS NOT DISTINCT FROM" is not supported yet.
val body =
s"""
|$prepareCode
|$DELEGATE delegates = new $DELEGATE($DEFAULT_COLLECTOR_TERM);
|$lookupFunctionTerm.eval(delegates.getCompletableFuture(), $parameters);
|if ($nullInParameters) {
| $DEFAULT_COLLECTOR_TERM.complete(java.util.Collections.emptyList());
| return;
|} else {
| $DELEGATE delegates = new $DELEGATE($DEFAULT_COLLECTOR_TERM);
| $lookupFunctionTerm.eval(delegates.getCompletableFuture(), $parameters);
|}
""".stripMargin

FunctionCodeGenerator.generateFunction(
Expand All @@ -156,7 +169,7 @@ object LookupJoinCodeGenerator {
lookupKeyInOrder: Array[Int],
allLookupFields: Map[Int, LookupKey],
isExternalArgs: Boolean,
fieldCopy: Boolean): (String, String) = {
fieldCopy: Boolean): (String, String, String) = {

val inputFieldExprs = for (i <- lookupKeyInOrder) yield {
allLookupFields.get(i) match {
Expand Down Expand Up @@ -195,9 +208,12 @@ object LookupJoinCodeGenerator {
| $newTerm = $assign;
|}
""".stripMargin
(code, newTerm)
(code, newTerm, e.nullTerm)
}
(codeAndArg.map(_._1).mkString("\n"), codeAndArg.map(_._2).mkString(", "))
(
codeAndArg.map(_._1).mkString("\n"),
codeAndArg.map(_._2).mkString(", "),
codeAndArg.map(_._3).mkString("|| "))
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -560,7 +560,7 @@ abstract class CommonLookupJoin(
joinType: JoinRelType): Unit = {

// check join on all fields of PRIMARY KEY or (UNIQUE) INDEX
if (allLookupKeys.isEmpty || allLookupKeys.isEmpty) {
if (allLookupKeys.isEmpty) {
throw new TableException(
"Temporal table join requires an equality condition on fields of " +
s"table [${tableSource.explainSource()}].")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,17 @@ import org.apache.flink.api.java.typeutils.RowTypeInfo
import org.apache.flink.table.api.Types
import org.apache.flink.table.planner.runtime.utils.{BatchTableEnvUtil, BatchTestBase, InMemoryLookupableTableSource}

import org.junit.runner.RunWith
import org.junit.runners.Parameterized
import org.junit.{Before, Test}

class LookupJoinITCase extends BatchTestBase {
import java.lang.Boolean
import java.util

import scala.collection.JavaConversions._

@RunWith(classOf[Parameterized])
class LookupJoinITCase(isAsyncMode: Boolean) extends BatchTestBase {

val data = List(
BatchTestBase.row(1L, 12L, "Julian"),
Expand All @@ -33,6 +41,12 @@ class LookupJoinITCase extends BatchTestBase {
BatchTestBase.row(8L, 11L, "Hello world"),
BatchTestBase.row(9L, 12L, "Hello world!"))

val dataWithNull = List(
BatchTestBase.row(null, 15L, "Hello"),
BatchTestBase.row(3L, 15L, "Fabian"),
BatchTestBase.row(null, 11L, "Hello world"),
BatchTestBase.row(9L, 12L, "Hello world!"))

val typeInfo = new RowTypeInfo(LONG_TYPE_INFO, LONG_TYPE_INFO, STRING_TYPE_INFO)

val userData = List(
Expand All @@ -55,19 +69,72 @@ class LookupJoinITCase extends BatchTestBase {
.enableAsync()
.build()

val userDataWithNull = List(
(11, 1L, "Julian"),
(22, null, "Hello"),
(33, 3L, "Fabian"),
(44, null, "Hello world"))

val userWithNullDataTableSource = InMemoryLookupableTableSource.builder()
.data(userDataWithNull)
.field("age", Types.INT)
.field("id", Types.LONG)
.field("name", Types.STRING)
.build()

val userAsyncWithNullDataTableSource = InMemoryLookupableTableSource.builder()
.data(userDataWithNull)
.field("age", Types.INT)
.field("id", Types.LONG)
.field("name", Types.STRING)
.enableAsync()
.build()

var userTable: String = _
var userTableWithNull: String = _

@Before
override def before() {
super.before()
BatchTableEnvUtil.registerCollection(tEnv, "T0", data, typeInfo, "id, len, content")
val myTable = tEnv.sqlQuery("SELECT *, PROCTIME() as proctime FROM T0")
tEnv.registerTable("T", myTable)

BatchTableEnvUtil.registerCollection(
tEnv, "T1", dataWithNull, typeInfo, "id, len, content")
val myTable1 = tEnv.sqlQuery("SELECT *, PROCTIME() as proctime FROM T1")
tEnv.registerTable("nullableT", myTable1)

tEnv.registerTableSource("userTable", userTableSource)
tEnv.registerTableSource("userAsyncTable", userAsyncTableSource)
userTable = if (isAsyncMode) "userAsyncTable" else "userTable"

tEnv.registerTableSource("userWithNullDataTable", userWithNullDataTableSource)
tEnv.registerTableSource("userWithNullDataAsyncTable", userAsyncWithNullDataTableSource)
userTableWithNull = if (isAsyncMode) "userWithNullDataAsyncTable" else "userWithNullDataTable"

// TODO: enable object reuse until [FLINK-12351] is fixed.
env.getConfig.disableObjectReuse()
}

@Test
def testLeftJoinTemporalTableWithLocalPredicate(): Unit = {
val sql = s"SELECT T.id, T.len, T.content, D.name, D.age FROM T LEFT JOIN $userTable " +
"for system_time as of T.proctime AS D ON T.id = D.id " +
"AND T.len > 1 AND D.age > 20 AND D.name = 'Fabian' " +
"WHERE T.id > 1"

val expected = Seq(
BatchTestBase.row(2, 15, "Hello", null, null),
BatchTestBase.row(3, 15, "Fabian", "Fabian", 33),
BatchTestBase.row(8, 11, "Hello world", null, null),
BatchTestBase.row(9, 12, "Hello world!", null, null))
checkResult(sql, expected, false)
}

@Test
def testJoinTemporalTable(): Unit = {
val sql = "SELECT T.id, T.len, T.content, D.name FROM T JOIN userTable " +
val sql = s"SELECT T.id, T.len, T.content, D.name FROM T JOIN $userTable " +
"for system_time as of T.proctime AS D ON T.id = D.id"

val expected = Seq(
Expand All @@ -79,7 +146,7 @@ class LookupJoinITCase extends BatchTestBase {

@Test
def testJoinTemporalTableWithPushDown(): Unit = {
val sql = "SELECT T.id, T.len, T.content, D.name FROM T JOIN userTable " +
val sql = s"SELECT T.id, T.len, T.content, D.name FROM T JOIN $userTable " +
"for system_time as of T.proctime AS D ON T.id = D.id AND D.age > 20"

val expected = Seq(
Expand All @@ -90,7 +157,7 @@ class LookupJoinITCase extends BatchTestBase {

@Test
def testJoinTemporalTableWithNonEqualFilter(): Unit = {
val sql = "SELECT T.id, T.len, T.content, D.name, D.age FROM T JOIN userTable " +
val sql = s"SELECT T.id, T.len, T.content, D.name, D.age FROM T JOIN $userTable " +
"for system_time as of T.proctime AS D ON T.id = D.id WHERE T.len <= D.age"

val expected = Seq(
Expand All @@ -101,7 +168,7 @@ class LookupJoinITCase extends BatchTestBase {

@Test
def testJoinTemporalTableOnMultiFields(): Unit = {
val sql = "SELECT T.id, T.len, D.name FROM T JOIN userTable " +
val sql = s"SELECT T.id, T.len, D.name FROM T JOIN $userTable " +
"for system_time as of T.proctime AS D ON T.id = D.id AND T.content = D.name"

val expected = Seq(
Expand All @@ -112,7 +179,7 @@ class LookupJoinITCase extends BatchTestBase {

@Test
def testJoinTemporalTableOnMultiFieldsWithUdf(): Unit = {
val sql = "SELECT T.id, T.len, D.name FROM T JOIN userTable " +
val sql = s"SELECT T.id, T.len, D.name FROM T JOIN $userTable " +
"for system_time as of T.proctime AS D ON mod(T.id, 4) = D.id AND T.content = D.name"

val expected = Seq(
Expand All @@ -123,7 +190,7 @@ class LookupJoinITCase extends BatchTestBase {

@Test
def testJoinTemporalTableOnMultiKeyFields(): Unit = {
val sql = "SELECT T.id, T.len, D.name FROM T JOIN userTable " +
val sql = s"SELECT T.id, T.len, D.name FROM T JOIN $userTable " +
"for system_time as of T.proctime AS D ON T.content = D.name AND T.id = D.id"

val expected = Seq(
Expand All @@ -134,7 +201,7 @@ class LookupJoinITCase extends BatchTestBase {

@Test
def testLeftJoinTemporalTable(): Unit = {
val sql = "SELECT T.id, T.len, D.name, D.age FROM T LEFT JOIN userTable " +
val sql = s"SELECT T.id, T.len, D.name, D.age FROM T LEFT JOIN $userTable " +
"for system_time as of T.proctime AS D ON T.id = D.id"

val expected = Seq(
Expand All @@ -147,88 +214,50 @@ class LookupJoinITCase extends BatchTestBase {
}

@Test
def testAsyncJoinTemporalTable(): Unit = {
// TODO: enable object reuse until [FLINK-12351] is fixed.
env.getConfig.disableObjectReuse()
val sql = "SELECT T.id, T.len, T.content, D.name FROM T JOIN userAsyncTable " +
"for system_time as of T.proctime AS D ON T.id = D.id"

val expected = Seq(
BatchTestBase.row(1, 12, "Julian", "Julian"),
BatchTestBase.row(2, 15, "Hello", "Jark"),
BatchTestBase.row(3, 15, "Fabian", "Fabian"))
checkResult(sql, expected, false)
}

@Test
def testAsyncJoinTemporalTableWithPushDown(): Unit = {
// TODO: enable object reuse until [FLINK-12351] is fixed.
env.getConfig.disableObjectReuse()
val sql = "SELECT T.id, T.len, T.content, D.name FROM T JOIN userAsyncTable " +
"for system_time as of T.proctime AS D ON T.id = D.id AND D.age > 20"
def testJoinTemporalTableOnMultiKeyFieldsWithNullData(): Unit = {
val sql = s"SELECT T.id, T.len, D.name FROM nullableT T JOIN $userTableWithNull " +
"for system_time as of T.proctime AS D ON T.content = D.name AND T.id = D.id"

val expected = Seq(
BatchTestBase.row(2, 15, "Hello", "Jark"),
BatchTestBase.row(3, 15, "Fabian", "Fabian"))
BatchTestBase.row(3,15,"Fabian"))
checkResult(sql, expected, false)
}

@Test
def testAsyncJoinTemporalTableWithNonEqualFilter(): Unit = {
// TODO: enable object reuse until [FLINK-12351] is fixed.
env.getConfig.disableObjectReuse()
val sql = "SELECT T.id, T.len, T.content, D.name, D.age FROM T JOIN userAsyncTable " +
"for system_time as of T.proctime AS D ON T.id = D.id WHERE T.len <= D.age"

def testLeftJoinTemporalTableOnMultiKeyFieldsWithNullData(): Unit = {
val sql = s"SELECT D.id, T.len, D.name FROM nullableT T LEFT JOIN $userTableWithNull " +
"for system_time as of T.proctime AS D ON T.content = D.name AND T.id = D.id"
val expected = Seq(
BatchTestBase.row(2, 15, "Hello", "Jark", 22),
BatchTestBase.row(3, 15, "Fabian", "Fabian", 33))
BatchTestBase.row(null,15,null),
BatchTestBase.row(3,15,"Fabian"),
BatchTestBase.row(null,11,null),
BatchTestBase.row(null,12,null))
checkResult(sql, expected, false)
}

@Test
def testAsyncLeftJoinTemporalTableWithLocalPredicate(): Unit = {
// TODO: enable object reuse until [FLINK-12351] is fixed.
env.getConfig.disableObjectReuse()
val sql = "SELECT T.id, T.len, T.content, D.name, D.age FROM T LEFT JOIN userAsyncTable " +
"for system_time as of T.proctime AS D ON T.id = D.id " +
"AND T.len > 1 AND D.age > 20 AND D.name = 'Fabian' " +
"WHERE T.id > 1"

val expected = Seq(
BatchTestBase.row(2, 15, "Hello", null, null),
BatchTestBase.row(3, 15, "Fabian", "Fabian", 33),
BatchTestBase.row(8, 11, "Hello world", null, null),
BatchTestBase.row(9, 12, "Hello world!", null, null))
def testJoinTemporalTableOnNullConstantKey(): Unit = {
val sql = s"SELECT T.id, T.len, T.content FROM T JOIN $userTable " +
"for system_time as of T.proctime AS D ON D.id = null"
val expected = Seq()
checkResult(sql, expected, false)
}

@Test
def testAsyncJoinTemporalTableOnMultiFields(): Unit = {
// TODO: enable object reuse until [FLINK-12351] is fixed.
env.getConfig.disableObjectReuse()
val sql = "SELECT T.id, T.len, D.name FROM T JOIN userAsyncTable " +
"for system_time as of T.proctime AS D ON T.id = D.id AND T.content = D.name"

val expected = Seq(
BatchTestBase.row(1, 12, "Julian"),
BatchTestBase.row(3, 15, "Fabian"))
def testJoinTemporalTableOnMultiKeyFieldsWithNullConstantKey(): Unit = {
val sql = s"SELECT T.id, T.len, D.name FROM T JOIN $userTable " +
"for system_time as of T.proctime AS D ON T.content = D.name AND null = D.id"
val expected = Seq()
checkResult(sql, expected, false)
}
}

@Test
def testAsyncLeftJoinTemporalTable(): Unit = {
// TODO: enable object reuse until [FLINK-12351] is fixed.
env.getConfig.disableObjectReuse()
val sql = "SELECT T.id, T.len, D.name, D.age FROM T LEFT JOIN userAsyncTable " +
"for system_time as of T.proctime AS D ON T.id = D.id"
object LookupJoinITCase {

val expected = Seq(
BatchTestBase.row(1, 12, "Julian", 11),
BatchTestBase.row(2, 15, "Jark", 22),
BatchTestBase.row(3, 15, "Fabian", 33),
BatchTestBase.row(8, 11, null, null),
BatchTestBase.row(9, 12, null, null))
checkResult(sql, expected, false)
@Parameterized.Parameters(name = "isAsyncMode = {0}")
def parameters(): util.Collection[Array[java.lang.Object]] = {
Seq[Array[AnyRef]](
Array(Boolean.TRUE), Array(Boolean.FALSE)
)
}
}

0 comments on commit 8b2680a

Please sign in to comment.