From c79588b134a1270956a6d32b7a0a13ff4e3f483d Mon Sep 17 00:00:00 2001 From: Xingcan Cui Date: Wed, 30 Aug 2017 13:57:38 +0800 Subject: [PATCH 1/7] [FLINK-6233] [table] Support rowtime inner equi-join between two streams in the SQL API --- .../scala/TimeBoundedJoinExample.scala | 147 +++++ .../datastream/DataStreamWindowJoin.scala | 67 ++- .../datastream/DataStreamWindowJoinRule.scala | 8 +- .../join/TimeBoundedStreamInnerJoin.scala | 533 ++++++++++++++++++ .../table/runtime/join/WindowJoinUtil.scala | 26 +- .../runtime/harness/JoinHarnessTest.scala | 162 +++++- 6 files changed, 928 insertions(+), 15 deletions(-) create mode 100644 flink-examples/flink-examples-table/src/main/scala/org/apache/flink/table/examples/scala/TimeBoundedJoinExample.scala create mode 100644 flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/join/TimeBoundedStreamInnerJoin.scala diff --git a/flink-examples/flink-examples-table/src/main/scala/org/apache/flink/table/examples/scala/TimeBoundedJoinExample.scala b/flink-examples/flink-examples-table/src/main/scala/org/apache/flink/table/examples/scala/TimeBoundedJoinExample.scala new file mode 100644 index 0000000000000..2c0f098114096 --- /dev/null +++ b/flink-examples/flink-examples-table/src/main/scala/org/apache/flink/table/examples/scala/TimeBoundedJoinExample.scala @@ -0,0 +1,147 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.table.examples.scala + +import org.apache.flink.api.common.io.GenericInputFormat +import org.apache.flink.api.common.typeinfo.TypeInformation +import org.apache.flink.api.java.typeutils.ResultTypeQueryable +import org.apache.flink.api.scala._ +import org.apache.flink.streaming.api.TimeCharacteristic +import org.apache.flink.streaming.api.functions.AssignerWithPunctuatedWatermarks +import org.apache.flink.streaming.api.functions.sink.SinkFunction +import org.apache.flink.streaming.api.scala.{DataStream, StreamExecutionEnvironment} +import org.apache.flink.streaming.api.watermark.Watermark +import org.apache.flink.table.api.TableEnvironment +import org.apache.flink.table.api.scala._ + +import scala.util.Random + +/** + * A temporary example to show how to use the TimeBoundedJoin. + * Will be removed before code committing. 
+ */ +object TimeBoundedJoinExample { + + // ************************************************************************* + // PROGRAM + // ************************************************************************* + + def main(args: Array[String]): Unit = { + + // set up execution environment + val env = StreamExecutionEnvironment.getExecutionEnvironment + val tEnv = TableEnvironment.getTableEnvironment(env) + env.setStreamTimeCharacteristic(TimeCharacteristic.EventTime) + env.getConfig.setAutoWatermarkInterval(10) + + val orderA: DataStream[Order] = env.createInput(new InfiniteSource(1000, 0)) + .setParallelism(1) + .assignTimestampsAndWatermarks(new WatermarksAssigner(500)) + + val orderB: DataStream[Order] = env.createInput(new InfiniteSource(1000, 0)) + .setParallelism(1).assignTimestampsAndWatermarks(new WatermarksAssigner(500)) + + // register the DataStreams under the names "OrderA" and "OrderB" + tEnv.registerDataStream("OrderA", orderA, 'userA, 'productA, 'amountA, 'rtA.rowtime) + tEnv.registerDataStream("OrderB", orderB, 'userB, 'productB, 'amountB, 'rtB.rowtime) + tEnv.queryConfig + + // join the two tables on product within a bounded rowtime interval + val result = tEnv.sql( + "SELECT userA, productA, amountA, DATE_FORMAT(rtA, '%i:%S'), DATE_FORMAT(rtB, '%i:%S'), " + + "productB FROM OrderA, OrderB" + + " WHERE OrderA.productA = OrderB.productB " + + "AND OrderA.rtA BETWEEN OrderB.rtB - INTERVAL '10' SECOND AND OrderB.rtB + INTERVAL '8' " + + //"AND OrderA.rtA BETWEEN OrderB.rtB - INTERVAL '10' SECOND AND OrderB.rtB - INTERVAL '8' + // " + + //"AND OrderA.rtA BETWEEN OrderB.rtB + INTERVAL '6' SECOND AND OrderB.rtB + INTERVAL '10' + // " + + "SECOND") + + result.toAppendStream[Order2].addSink(new SinkFunction[Order2] { + override def invoke(value: Order2): Unit = { + println(value) + } + }) + println(env.getExecutionPlan) + env.execute() + } + + /** + * InfiniteSource + * + * @param interval pause in milliseconds before each generated record + */ + class InfiniteSource(interval: Int, offset: Long) + extends GenericInputFormat[Order] + with 
ResultTypeQueryable[Order] { + var a: Long = 0 + val b: Seq[String] = Seq("beer", "diaper", "rubber") + var c: Int = 0 + + override def reachedEnd(): Boolean = { + false + } + + override def nextRecord(reuse: Order): Order = { + Thread.sleep(interval) + a += 1 + c += 1 + Order(a, b(1), c, System.currentTimeMillis() - Random.nextInt(1000) - offset) + } + + override def getProducedType: TypeInformation[Order] = { + implicitly[TypeInformation[Order]] + } + } + + + /** + * WatermarksAssigner + * + * @param interval + */ + class WatermarksAssigner(interval: Long) extends AssignerWithPunctuatedWatermarks[Order] { + var lastWatermarks: Long = 0 + + override def extractTimestamp(element: Order, previousElementTimestamp: Long): Long = { + element.rt + } + + override def checkAndGetNextWatermark(lastElement: Order, extractedTimestamp: Long): Watermark = { + if (extractedTimestamp >= lastWatermarks + interval) { + lastWatermarks = lastWatermarks + ((extractedTimestamp - lastWatermarks) / interval) * + interval + new Watermark(lastWatermarks) + } else { + null + } + } + } + + // ************************************************************************* + // USER DATA TYPES + // ************************************************************************* + + case class Order(user: Long, product: String, amount: Int, rt: Long) + + case class Order2(userA: Long, productA: String, amountA: Int, rtA: String, rtB: String, + productB: String) + + +} diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamWindowJoin.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamWindowJoin.scala index f8015b354e6d3..2f6693742b464 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamWindowJoin.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamWindowJoin.scala @@ -29,7 +29,8 @@ 
import org.apache.flink.table.api.{StreamQueryConfig, StreamTableEnvironment, Ta import org.apache.flink.table.plan.nodes.CommonJoin import org.apache.flink.table.plan.schema.RowSchema import org.apache.flink.table.plan.util.UpdatingPlanChecker -import org.apache.flink.table.runtime.join.{ProcTimeWindowInnerJoin, WindowJoinUtil} +import org.apache.flink.table.runtime.join.{JoinTimeIndicator, ProcTimeWindowInnerJoin, TimeBoundedStreamInnerJoin, WindowJoinUtil} +import org.apache.flink.table.runtime.operators.KeyedCoProcessOperatorWithWatermarkDelay import org.apache.flink.table.runtime.types.{CRow, CRowTypeInfo} /** @@ -48,6 +49,8 @@ class DataStreamWindowJoin( isRowTime: Boolean, leftLowerBound: Long, leftUpperBound: Long, + leftTimeIdx: Int, + rightTimeIdx: Int, remainCondition: Option[RexNode], ruleDescription: String) extends BiRel(cluster, traitSet, leftNode, rightNode) @@ -70,6 +73,8 @@ class DataStreamWindowJoin( isRowTime, leftLowerBound, leftUpperBound, + leftTimeIdx, + rightTimeIdx, remainCondition, ruleDescription) } @@ -127,8 +132,14 @@ class DataStreamWindowJoin( case JoinRelType.INNER => if (isRowTime) { // RowTime JoinCoProcessFunction - throw new TableException( - "RowTime inner join between stream and stream is not supported yet.") + createRowTimeInnerJoinFunction( + leftDataStream, + rightDataStream, + joinFunction.name, + joinFunction.code, + leftKeys, + rightKeys + ) } else { // Proctime JoinCoProcessFunction createProcTimeInnerJoinFunction( @@ -184,4 +195,54 @@ class DataStreamWindowJoin( .returns(returnTypeInfo) } } + + def createRowTimeInnerJoinFunction( + leftDataStream: DataStream[CRow], + rightDataStream: DataStream[CRow], + joinFunctionName: String, + joinFunctionCode: String, + leftKeys: Array[Int], + rightKeys: Array[Int]): DataStream[CRow] = { + + val returnTypeInfo = CRowTypeInfo(schema.typeInfo) + + val rowTimeInnerJoinFunc = new TimeBoundedStreamInnerJoin( + leftLowerBound, + leftUpperBound, + 0L, + leftSchema.typeInfo, + 
rightSchema.typeInfo, + joinFunctionName, + joinFunctionCode, + leftTimeIdx, + rightTimeIdx, + JoinTimeIndicator.ROWTIME + ) + + if (!leftKeys.isEmpty) { + leftDataStream + .connect(rightDataStream) + .keyBy(leftKeys, rightKeys) + .transform( + "rowTimeInnerJoinFunc", + returnTypeInfo, + new KeyedCoProcessOperatorWithWatermarkDelay[CRow, CRow, CRow, CRow]( + rowTimeInnerJoinFunc, + rowTimeInnerJoinFunc.getMaxOutputDelay) + ).returns(returnTypeInfo) + } else { + leftDataStream.connect(rightDataStream) + .keyBy(new NullByteKeySelector[CRow](), new NullByteKeySelector[CRow]) + .transform( + "rowTimeInnerJoinFunc", + returnTypeInfo, + new KeyedCoProcessOperatorWithWatermarkDelay[CRow, CRow, CRow, CRow]( + rowTimeInnerJoinFunc, + rowTimeInnerJoinFunc.getMaxOutputDelay) + ) + .setParallelism(1) + .setMaxParallelism(1) + .returns(returnTypeInfo) + } + } } diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/datastream/DataStreamWindowJoinRule.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/datastream/DataStreamWindowJoinRule.scala index 7dfcbc523d309..c51a1bc39cf10 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/datastream/DataStreamWindowJoinRule.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/datastream/DataStreamWindowJoinRule.scala @@ -55,8 +55,10 @@ class DataStreamWindowJoinRule if (windowBounds.isDefined) { if (windowBounds.get.isEventTime) { - // we cannot handle event-time window joins yet - false + val procTimeAttrInOutput = join.getRowType.getFieldList.asScala + .exists(f => FlinkTypeFactory.isProctimeIndicatorType(f.getType)) + + !remainingPredsAccessTime && !procTimeAttrInOutput } else { // Check that no event-time attributes are in the input. // The proc-time join implementation does ensure that record timestamp are correctly set. 
@@ -105,6 +107,8 @@ class DataStreamWindowJoinRule windowBounds.get.isEventTime, windowBounds.get.leftLowerBound, windowBounds.get.leftUpperBound, + windowBounds.get.leftTimeIdx, + windowBounds.get.rightTimeIdx, remainCondition, description) } diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/join/TimeBoundedStreamInnerJoin.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/join/TimeBoundedStreamInnerJoin.scala new file mode 100644 index 0000000000000..f9b751bdb7682 --- /dev/null +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/join/TimeBoundedStreamInnerJoin.scala @@ -0,0 +1,533 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.table.runtime.join + +import java.text.SimpleDateFormat +import java.util +import java.util.Map.Entry +import java.util.{Date, List => JList} + +import org.apache.flink.api.common.functions.FlatJoinFunction +import org.apache.flink.api.common.state._ +import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation} +import org.apache.flink.api.java.typeutils.ListTypeInfo +import org.apache.flink.configuration.Configuration +import org.apache.flink.streaming.api.functions.co.CoProcessFunction +import org.apache.flink.table.codegen.Compiler +import org.apache.flink.table.runtime.CRowWrappingCollector +import org.apache.flink.table.runtime.join.JoinTimeIndicator.JoinTimeIndicator +import org.apache.flink.table.runtime.types.CRow +import org.apache.flink.table.util.Logging +import org.apache.flink.types.Row +import org.apache.flink.util.Collector + +/** + * A CoProcessFunction to execute time-bounded stream inner-join. + * + * Sample criteria: + * + * L.time between R.time + X and R.time + Y + * or, equivalently, R.time between L.time - Y and L.time - X + * + * @param leftLowerBound X + * @param leftUpperBound Y + * @param allowedLateness the lateness allowed for the two streams + * @param leftType the input type of left stream + * @param rightType the input type of right stream + * @param genJoinFuncName the function name of other non-equi conditions + * @param genJoinFuncCode the function code of other non-equi conditions + * @param timeIndicator indicates whether the join is on proctime or rowtime + * + */ +class TimeBoundedStreamInnerJoin( + private val leftLowerBound: Long, + private val leftUpperBound: Long, + private val allowedLateness: Long, + private val leftType: TypeInformation[Row], + private val rightType: TypeInformation[Row], + private val genJoinFuncName: String, + private val genJoinFuncCode: String, + private val leftTimeIdx: Int, + private val rightTimeIdx: Int, + private val timeIndicator: JoinTimeIndicator) + extends 
CoProcessFunction[CRow, CRow, CRow] + with Compiler[FlatJoinFunction[Row, Row, Row]] + with Logging { + + private var cRowWrapper: CRowWrappingCollector = _ + + // the join function for other conditions + private var joinFunction: FlatJoinFunction[Row, Row, Row] = _ + + // cache to store the left stream records + private var leftCache: MapState[Long, JList[Row]] = _ + // cache to store right stream records + private var rightCache: MapState[Long, JList[Row]] = _ + + // state to record the timer on the left stream. 0 means no timer set + private var leftTimerState: ValueState[Long] = _ + // state to record the timer on the right stream. 0 means no timer set + private var rightTimerState: ValueState[Long] = _ + + private val leftRelativeSize: Long = -leftLowerBound + private val rightRelativeSize: Long = leftUpperBound + + private val relativeWindowSize = rightRelativeSize + leftRelativeSize + + private var leftOperatorTime: Long = 0L + private var rightOperatorTime: Long = 0L + + private var backPressureSuggestion: Long = 0L + + if (relativeWindowSize <= 0) { + LOG.warn("The relative window size is non-positive, please check the join conditions.") + } + + if (allowedLateness < 0) { + throw new IllegalArgumentException("The allowed lateness must be non-negative.") + } + + + /** + * For holding back watermarks. + * + * @return the maximum delay for the outputs + */ + def getMaxOutputDelay = Math.max(leftRelativeSize, rightRelativeSize) + allowedLateness; + + /** + * For dynamic query optimization. + * + * @return the suggested offset time for back-pressure + */ + def getBackPressureSuggestion = backPressureSuggestion + + override def open(config: Configuration) { + val clazz = compile( + getRuntimeContext.getUserCodeClassLoader, + genJoinFuncName, + genJoinFuncCode) + joinFunction = clazz.newInstance() + + cRowWrapper = new CRowWrappingCollector() + cRowWrapper.setChange(true) + + // Initialize the data caches. 
+ val leftListTypeInfo: TypeInformation[JList[Row]] = new ListTypeInfo[Row](leftType) + val leftStateDescriptor: MapStateDescriptor[Long, JList[Row]] = + new MapStateDescriptor[Long, JList[Row]](timeIndicator + "InnerJoinLeftCache", + BasicTypeInfo.LONG_TYPE_INFO.asInstanceOf[TypeInformation[Long]], leftListTypeInfo) + leftCache = getRuntimeContext.getMapState(leftStateDescriptor) + + val rightListTypeInfo: TypeInformation[JList[Row]] = new ListTypeInfo[Row](rightType) + val rightStateDescriptor: MapStateDescriptor[Long, JList[Row]] = + new MapStateDescriptor[Long, JList[Row]](timeIndicator + "InnerJoinRightCache", + BasicTypeInfo.LONG_TYPE_INFO.asInstanceOf[TypeInformation[Long]], rightListTypeInfo) + rightCache = getRuntimeContext.getMapState(rightStateDescriptor) + + // Initialize the timer states. + val leftTimerStateDesc: ValueStateDescriptor[Long] = + new ValueStateDescriptor[Long](timeIndicator + "InnerJoinLeftTimerState", + classOf[Long]) + leftTimerState = getRuntimeContext.getState(leftTimerStateDesc) + + val rightTimerStateDesc: ValueStateDescriptor[Long] = + new ValueStateDescriptor[Long](timeIndicator + "InnerJoinRightTimerState", + classOf[Long]) + rightTimerState = getRuntimeContext.getState(rightTimerStateDesc) + } + + /** + * Process records from the left stream. + * + * @param cRowValue the input record + * @param ctx the context to register timer or get current time + * @param out the collector for outputting results + * + */ + override def processElement1( + cRowValue: CRow, + ctx: CoProcessFunction[CRow, CRow, CRow]#Context, + out: Collector[CRow]): Unit = { + val timeForRecord: Long = getTimeForRecord(ctx, cRowValue, true) + getCurrentOperatorTime(ctx) + processElement( + cRowValue, + timeForRecord, + ctx, + out, + leftOperatorTime, + rightOperatorTime, + rightTimerState, + leftCache, + rightCache, + true + ) + } + + /** + * Process records from the right stream. 
+ * + * @param cRowValue the input record + * @param ctx the context to get current time + * @param out the collector for outputting results + * + */ + override def processElement2( + cRowValue: CRow, + ctx: CoProcessFunction[CRow, CRow, CRow]#Context, + out: Collector[CRow]): Unit = { + val timeForRecord: Long = getTimeForRecord(ctx, cRowValue, false) + getCurrentOperatorTime(ctx) + processElement( + cRowValue, + timeForRecord, + ctx, + out, + rightOperatorTime, + leftOperatorTime, + leftTimerState, + rightCache, + leftCache, + false + ) + } + + /** + * Put a record from the input stream into the cache and iterate the opposite cache to + * output records meeting the join conditions. If there is no timer set for the OPPOSITE + * STREAM, register one. + */ + private def processElement( + cRowValue: CRow, + timeForRecord: Long, + ctx: CoProcessFunction[CRow, CRow, CRow]#Context, + out: Collector[CRow], + myWatermark: Long, + oppositeWatermark: Long, + oppositeTimeState: ValueState[Long], + recordListCache: MapState[Long, JList[Row]], + oppositeCache: MapState[Long, JList[Row]], + leftRecord: Boolean): Unit = { + if (relativeWindowSize > 0) { + //TODO Shall we consider adding a method for initialization with the context and collector? + cRowWrapper.out = out + + val record = cRowValue.row + + //TODO Only if the time of the record is greater than the watermark, can we continue. + if (timeForRecord >= myWatermark - allowedLateness) { + val oppositeLowerBound: Long = + if (leftRecord) timeForRecord - rightRelativeSize else timeForRecord - leftRelativeSize + + val oppositeUpperBound: Long = + if (leftRecord) timeForRecord + leftRelativeSize else timeForRecord + rightRelativeSize + + // Put the record into the cache for later use. 
+ val recordList = if (recordListCache.contains(timeForRecord)) { + recordListCache.get(timeForRecord) + } else { + new util.ArrayList[Row]() + } + recordList.add(record) + recordListCache.put(timeForRecord, recordList) + + // Register a timer on THE OTHER STREAM to remove records from the cache once they are + // expired. + if (oppositeTimeState.value == 0) { + registerCleanUpTimer( + ctx, timeForRecord, oppositeWatermark, oppositeTimeState, leftRecord, true) + } + + // Join the record with records from the opposite stream. + val oppositeIterator = oppositeCache.iterator() + var oppositeEntry: Entry[Long, util.List[Row]] = null + var oppositeTime: Long = 0L; + while (oppositeIterator.hasNext) { + oppositeEntry = oppositeIterator.next + oppositeTime = oppositeEntry.getKey + if (oppositeTime < oppositeLowerBound - allowedLateness) { + //TODO Considering the data out-of-order, we should not remove records here. + } else if (oppositeTime >= oppositeLowerBound && oppositeTime <= oppositeUpperBound) { + val oppositeRows = oppositeEntry.getValue + var i = 0 + if (leftRecord) { + while (i < oppositeRows.size) { + joinFunction.join(record, oppositeRows.get(i), cRowWrapper) + i += 1 + } + } else { + while (i < oppositeRows.size) { + joinFunction.join(oppositeRows.get(i), record, cRowWrapper) + i += 1 + } + } + } else if (oppositeTime > oppositeUpperBound) { + //TODO If the keys are ordered, can we break here? + } + } + } else { + //TODO Need some extra logic here? + LOG.warn(s"$record is out-of-date.") + } + } + } + + /** + * Register a timer for cleaning up records in a specified time. 
+ * + * @param ctx the context to register timer + * @param timeForRecord time for the input record + * @param oppositeWatermark watermark of the opposite stream + * @param timerState stores the timestamp for the next timer + * @param leftRecord record from the left or the right stream + * @param firstTimer whether this is the first timer + */ + private def registerCleanUpTimer( + ctx: CoProcessFunction[CRow, CRow, CRow]#Context, + timeForRecord: Long, + oppositeWatermark: Long, + timerState: ValueState[Long], + leftRecord: Boolean, + firstTimer: Boolean): Unit = { + val cleanUpTime = timeForRecord + (if (leftRecord) leftRelativeSize else rightRelativeSize) + + allowedLateness + 1 + registerTimer(ctx, !leftRecord, cleanUpTime) + LOG.debug(s"Register a clean up timer on the ${if (leftRecord) "RIGHT" else "LEFT"} state:" + + s" timeForRecord = ${timeForRecord}, cleanUpTime = ${cleanUpTime}, oppositeWatermark = " + + s"${oppositeWatermark}") + timerState.update(cleanUpTime) + if (cleanUpTime <= oppositeWatermark + allowedLateness && firstTimer) { + backPressureSuggestion = + if (leftRecord) (oppositeWatermark + allowedLateness - cleanUpTime) + else -(oppositeWatermark + allowedLateness - cleanUpTime) + LOG.warn("The clean timer for the " + + s"${if (leftRecord) "left" else "right"}" + + s" stream is lower than ${if (leftRecord) "right" else "left"} watermark." + + s" requiredTime = ${formatTime(cleanUpTime)}, watermark = ${formatTime(oppositeWatermark)}," + + s"backPressureSuggestion = " + s"${backPressureSuggestion}.") + } + } + + + /** + * Called when a registered timer is fired. + * Remove records which are earlier than the expiration time, + * and register a new timer for the earliest remaining records. 
+ * + * @param timestamp the timestamp of the timer + * @param ctx the context to register timer or get current time + * @param out the collector for returning result values + */ + override def onTimer( + timestamp: Long, + ctx: CoProcessFunction[CRow, CRow, CRow]#OnTimerContext, + out: Collector[CRow]): Unit = { + getCurrentOperatorTime(ctx) + //TODO In the future, we should separate the left and right watermarks. Otherwise, the + //TODO registered timer of the faster stream will be delayed, even if the watermarks have + //TODO already been emitted by the source. + if (leftTimerState.value == timestamp) { + val rightExpirationTime = leftOperatorTime - rightRelativeSize - allowedLateness - 1 + removeExpiredRecords( + timestamp, + rightExpirationTime, + leftOperatorTime, + rightCache, + leftTimerState, + ctx, + false + ) + } + + if (rightTimerState.value == timestamp) { + val leftExpirationTime = rightOperatorTime - leftRelativeSize - allowedLateness - 1 + removeExpiredRecords( + timestamp, + leftExpirationTime, + rightOperatorTime, + leftCache, + rightTimerState, + ctx, + true + ) + } + } + + /** + * Remove the expired records. Register a new timer if the cache still holds records + * after the cleaning up. + */ + private def removeExpiredRecords( + timerFiringTime: Long, + expirationTime: Long, + oppositeWatermark: Long, + recordCache: MapState[Long, JList[Row]], + timerState: ValueState[Long], + ctx: CoProcessFunction[CRow, CRow, CRow]#OnTimerContext, + removeLeft: Boolean): Unit = { + + val keysIterator = recordCache.keys().iterator() + + // Search for expired timestamps. + // Every key is visited; expired keys are removed and the earliest non-expired one is remembered. + // A full pass is done because the keys of the state map may be unordered (see the TODO below). 
+ var earliestTimestamp: Long = -1L + var recordTime: Long = 0L + while (keysIterator.hasNext) { + //TODO The "short-circuit" code was commented, because when using a StateMap with + //TODO unordered keys, the cache will grow indefinitely! + // && earliestTimestamp < 0) { + recordTime = keysIterator.next + if (recordTime <= expirationTime) { + // TODO Not sure if we can remove records directly. + keysIterator.remove() + } else { + // We find the earliest timestamp that is still valid. + if (recordTime < earliestTimestamp || earliestTimestamp < 0) { + earliestTimestamp = recordTime + } + } + } + // If the cache contains non-expired timestamps, register a new timer. + // Otherwise clear the states. + if (earliestTimestamp > 0) { + registerCleanUpTimer(ctx, earliestTimestamp, oppositeWatermark, timerState, removeLeft, false) + } else { + // The timerState will be 0. + timerState.clear() + recordCache.clear() + } + } + + /** + * Get the operator times of the two streams. + * + * @param ctx the context to acquire watermarks + */ + protected def getCurrentOperatorTime( + ctx: CoProcessFunction[CRow, CRow, CRow]#Context): Unit = { + timeIndicator match { + case JoinTimeIndicator.ROWTIME => { + rightOperatorTime = + if (ctx.timerService().currentWatermark() > 0) ctx.timerService().currentWatermark() + else 0L; + leftOperatorTime = + if (ctx.timerService().currentWatermark() > 0) ctx.timerService().currentWatermark() + else 0L; + } + case JoinTimeIndicator.PROCTIME => { + rightOperatorTime = ctx.timerService().currentProcessingTime() + leftOperatorTime = ctx.timerService().currentProcessingTime() + } + } + } + + + /** + * Return the rowtime or proctime for the target record. 
+ * + * @param context the runtime context + * @param record the target record + * @param isLeft whether the record is from the left stream + * @return time for the target record + */ + protected def getTimeForRecord( + context: CoProcessFunction[CRow, CRow, CRow]#Context, + record: CRow, + isLeft: Boolean): Long = { + timeIndicator match { + case JoinTimeIndicator.ROWTIME => { + return if (isLeft) { + record.row.getField(leftTimeIdx).asInstanceOf[Long] + } else { + record.row.getField(rightTimeIdx).asInstanceOf[Long]; + } + } + case JoinTimeIndicator.PROCTIME => { + return context.timerService().currentProcessingTime(); + } + } + } + + /** + * Register a proctime or rowtime timer. + * + * @param ctx the context to register the timer + * @param isLeft whether this timer should be registered on the left stream + * @param cleanupTime timestamp for the timer + */ + protected def registerTimer( + ctx: CoProcessFunction[CRow, CRow, CRow]#Context, isLeft: Boolean, cleanupTime: Long): Unit = { + // Maybe we can register timers for different streams in the future. + timeIndicator match { + case JoinTimeIndicator.ROWTIME => { + ctx.timerService.registerEventTimeTimer(cleanupTime) + } + case JoinTimeIndicator.PROCTIME => { + ctx.timerService.registerProcessingTimeTimer(cleanupTime) + } + } + } + + //********* Functions for temporary test use. 
*****************// + + def formatTime(time: Long): String = { + if (0 == time) { + return "null" + } + val f: SimpleDateFormat = new SimpleDateFormat("HH:mm:ss SSS") + f.format(new Date(time)) + } + + def printCacheSize = { + var leftSize = 0; + var rightSize = 0; + var iterator = leftCache.iterator(); + while (iterator.hasNext) { + leftSize = leftSize + iterator.next().getValue.size() + } + iterator = rightCache.iterator(); + while (iterator.hasNext) { + rightSize = rightSize + iterator.next().getValue.size() + } + + println(s"leftSize = $leftSize, rightSize = $rightSize") + } + + override def toString = s"RowTimeWindowInnerJoin(" + + s"leftTimerState=${formatTime(leftTimerState.value())}, " + + s"rightTimerState=${formatTime(rightTimerState.value())}, " + + s"leftRelativeSize=$leftRelativeSize, " + + s"rightRelativeSize=$rightRelativeSize, relativeWindowSize=$relativeWindowSize, " + + s"leftOperatorTime=${formatTime(leftOperatorTime)}," + + s" rightOperatorTime=${formatTime(rightOperatorTime)})" + +} + +//********* Will be removed before committing. 
************// + +/** + * TODO Not sure if that can be replaced by [[org.apache.flink.streaming.api.TimeCharacteristic]] + */ +object JoinTimeIndicator extends Enumeration { + type JoinTimeIndicator = Value + val ROWTIME, PROCTIME = Value +} + diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/join/WindowJoinUtil.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/join/WindowJoinUtil.scala index b5661139d2efe..05a5e26f1d564 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/join/WindowJoinUtil.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/join/WindowJoinUtil.scala @@ -39,14 +39,17 @@ import scala.collection.JavaConverters._ */ object WindowJoinUtil { - case class WindowBounds(isEventTime: Boolean, leftLowerBound: Long, leftUpperBound: Long) + case class WindowBounds(isEventTime: Boolean, leftLowerBound: Long, leftUpperBound: Long, + leftTimeIdx: Int, rightTimeIdx: Int) protected case class WindowBound(bound: Long, isLeftLower: Boolean) protected case class TimePredicate( isEventTime: Boolean, leftInputOnLeftSide: Boolean, + leftTimeIdx: Int, + rightTimeIdx: Int, pred: RexCall) - protected case class TimeAttributeAccess(isEventTime: Boolean, isLeftInput: Boolean) + protected case class TimeAttributeAccess(isEventTime: Boolean, isLeftInput: Boolean, idx: Int) /** * Extracts the window bounds from a join predicate. 
@@ -115,10 +118,15 @@ object WindowJoinUtil { case _ => Some(otherPreds.reduceLeft((l, r) => RelOptUtil.andJoinFilters(rexBuilder, l, r))) } - - val bounds = Some(WindowBounds(timePreds.head.isEventTime, leftLowerBound, leftUpperBound)) - - (bounds, remainCondition) + if (timePreds.head.leftInputOnLeftSide) { + val bounds = Some(WindowBounds(timePreds.head.isEventTime, leftLowerBound, leftUpperBound, + timePreds.head.leftTimeIdx, timePreds.last.rightTimeIdx - leftLogicalFieldCnt)) + (bounds, remainCondition) + } else { + val bounds = Some(WindowBounds(timePreds.head.isEventTime, leftLowerBound, leftUpperBound, + timePreds.last.rightTimeIdx, timePreds.head.leftTimeIdx - leftLogicalFieldCnt)) + (bounds, remainCondition) + } } /** @@ -196,8 +204,8 @@ object WindowJoinUtil { case (Some(left), Some(right)) if left.isLeftInput == right.isLeftInput => // Window join predicates must reference the time attribute of both inputs. Right(pred) - case (Some(left), Some(_)) => - Left(TimePredicate(left.isEventTime, left.isLeftInput, c)) + case (Some(left), Some(right)) => + Left(TimePredicate(left.isEventTime, left.isLeftInput, left.idx, right.idx, c)) } // not a comparison predicate. case _ => Right(pred) @@ -225,7 +233,7 @@ object WindowJoinUtil { case t: TimeIndicatorRelDataType => // time attribute access. Remember time type and side of input val isLeftInput = idx < leftFieldCount - Seq(TimeAttributeAccess(t.isEventTime, isLeftInput)) + Seq(TimeAttributeAccess(t.isEventTime, isLeftInput, idx)) case _ => // not a time attribute access. 
Seq() diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/harness/JoinHarnessTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/harness/JoinHarnessTest.scala index 065b7bcc15e38..74c04edf3f88f 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/harness/JoinHarnessTest.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/harness/JoinHarnessTest.scala @@ -27,7 +27,7 @@ import org.apache.flink.streaming.api.operators.co.KeyedCoProcessOperator import org.apache.flink.streaming.runtime.streamrecord.StreamRecord import org.apache.flink.streaming.util.KeyedTwoInputStreamOperatorTestHarness import org.apache.flink.table.runtime.harness.HarnessTestBase.{RowResultSortComparator, TupleRowKeySelector} -import org.apache.flink.table.runtime.join.ProcTimeWindowInnerJoin +import org.apache.flink.table.runtime.join.{JoinTimeIndicator, ProcTimeWindowInnerJoin, TimeBoundedStreamInnerJoin} import org.apache.flink.table.runtime.types.CRow import org.apache.flink.types.Row import org.junit.Test @@ -232,4 +232,164 @@ class JoinHarnessTest extends HarnessTestBase{ testHarness.close() } + /** a.proctime >= b.proctime - 10 and a.proctime <= b.proctime + 20 **/ + @Test + def testNewNormalProcTimeJoin() { + + val joinProcessFunc = new TimeBoundedStreamInnerJoin( + -10, 20, 0, rT, rT, "TestJoinFunction", funcCode, -1, -1, JoinTimeIndicator.PROCTIME) + + val operator: KeyedCoProcessOperator[Integer, CRow, CRow, CRow] = + new KeyedCoProcessOperator[Integer, CRow, CRow, CRow](joinProcessFunc) + val testHarness: KeyedTwoInputStreamOperatorTestHarness[Integer, CRow, CRow, CRow] = + new KeyedTwoInputStreamOperatorTestHarness[Integer, CRow, CRow, CRow]( + operator, + new TupleRowKeySelector[Integer](0), + new TupleRowKeySelector[Integer](0), + BasicTypeInfo.INT_TYPE_INFO, + 1, 1, 0) + + testHarness.open() + + // left stream input + 
testHarness.setProcessingTime(1) + testHarness.processElement1(new StreamRecord( + CRow(Row.of(1: JInt, "aaa"), true), 1)) + assertEquals(1, testHarness.numProcessingTimeTimers()) + testHarness.setProcessingTime(2) + testHarness.processElement1(new StreamRecord( + CRow(Row.of(2: JInt, "bbb"), true), 2)) + assertEquals(2, testHarness.numProcessingTimeTimers()) + testHarness.setProcessingTime(3) + testHarness.processElement1(new StreamRecord( + CRow(Row.of(1: JInt, "aaa2"), true), 3)) + assertEquals(4, testHarness.numKeyedStateEntries()) + assertEquals(2, testHarness.numProcessingTimeTimers()) + + // right stream input and output normally + testHarness.processElement2(new StreamRecord( + CRow(Row.of(1: JInt, "Hi1"), true), 3)) + testHarness.setProcessingTime(4) + testHarness.processElement2(new StreamRecord( + CRow(Row.of(2: JInt, "Hello1"), true), 4)) + assertEquals(8, testHarness.numKeyedStateEntries()) + assertEquals(4, testHarness.numProcessingTimeTimers()) + + // expired left stream record at timestamp 1 + testHarness.setProcessingTime(12) + assertEquals(8, testHarness.numKeyedStateEntries()) + assertEquals(4, testHarness.numProcessingTimeTimers()) + testHarness.processElement2(new StreamRecord( + CRow(Row.of(1: JInt, "Hi2"), true), 12)) + + // expired right stream record at timestamp 4 and all left stream + testHarness.setProcessingTime(25) + assertEquals(2, testHarness.numKeyedStateEntries()) + assertEquals(1, testHarness.numProcessingTimeTimers()) + testHarness.processElement1(new StreamRecord( + CRow(Row.of(1: JInt, "aaa3"), true), 25)) + testHarness.processElement1(new StreamRecord( + CRow(Row.of(2: JInt, "bbb2"), true), 25)) + testHarness.processElement2(new StreamRecord( + CRow(Row.of(2: JInt, "Hello2"), true), 25)) + + testHarness.setProcessingTime(45) + assertTrue(testHarness.numKeyedStateEntries() > 0) + testHarness.setProcessingTime(46) + assertEquals(0, testHarness.numKeyedStateEntries()) + val result = testHarness.getOutput + + val expectedOutput = 
new ConcurrentLinkedQueue[Object]() + + expectedOutput.add(new StreamRecord( + CRow(Row.of(1: JInt, "aaa", 1: JInt, "Hi1"), true), 3)) + expectedOutput.add(new StreamRecord( + CRow(Row.of(1: JInt, "aaa2", 1: JInt, "Hi1"), true), 3)) + expectedOutput.add(new StreamRecord( + CRow(Row.of(2: JInt, "bbb", 2: JInt, "Hello1"), true), 4)) + expectedOutput.add(new StreamRecord( + CRow(Row.of(1: JInt, "aaa2", 1: JInt, "Hi2"), true), 12)) + expectedOutput.add(new StreamRecord( + CRow(Row.of(1: JInt, "aaa3", 1: JInt, "Hi2"), true), 25)) + expectedOutput.add(new StreamRecord( + CRow(Row.of(2: JInt, "bbb2", 2: JInt, "Hello2"), true), 25)) + + verify(expectedOutput, result, new RowResultSortComparator()) + + testHarness.close() + } + + /** a.proctime >= b.proctime - 10 and a.proctime <= b.proctime - 5 **/ + @Test + def testNewProcTimeJoinSingleNeedStore() { + + val joinProcessFunc = new TimeBoundedStreamInnerJoin( + -10, -5, 0, rT, rT, "TestJoinFunction", funcCode, -1, -1, JoinTimeIndicator.PROCTIME) + + val operator: KeyedCoProcessOperator[Integer, CRow, CRow, CRow] = + new KeyedCoProcessOperator[Integer, CRow, CRow, CRow](joinProcessFunc) + val testHarness: KeyedTwoInputStreamOperatorTestHarness[Integer, CRow, CRow, CRow] = + new KeyedTwoInputStreamOperatorTestHarness[Integer, CRow, CRow, CRow]( + operator, + new TupleRowKeySelector[Integer](0), + new TupleRowKeySelector[Integer](0), + BasicTypeInfo.INT_TYPE_INFO, + 1, 1, 0) + + testHarness.open() + + testHarness.setProcessingTime(1) + testHarness.processElement1(new StreamRecord( + CRow(Row.of(1: JInt, "aaa1"), true), 1)) + testHarness.setProcessingTime(2) + testHarness.processElement1(new StreamRecord( + CRow(Row.of(2: JInt, "aaa2"), true), 2)) + testHarness.setProcessingTime(3) + testHarness.processElement1(new StreamRecord( + CRow(Row.of(1: JInt, "aaa3"), true), 3)) + assertEquals(4, testHarness.numKeyedStateEntries()) + assertEquals(2, testHarness.numProcessingTimeTimers()) + + // all the records should be stored + 
testHarness.processElement2(new StreamRecord( + CRow(Row.of(1: JInt, "bbb3"), true), 3)) + assertEquals(6, testHarness.numKeyedStateEntries()) + assertEquals(3, testHarness.numProcessingTimeTimers()) + + // meet a.proctime <= b.proctime - 5 + testHarness.setProcessingTime(7) + testHarness.processElement2(new StreamRecord( + CRow(Row.of(2: JInt, "bbb7"), true), 7)) + assertEquals(6, testHarness.numKeyedStateEntries()) + assertEquals(3, testHarness.numProcessingTimeTimers()) + + // expire record of stream a at timestamp 1 + testHarness.setProcessingTime(12) + assertEquals(4, testHarness.numKeyedStateEntries()) + assertEquals(2, testHarness.numProcessingTimeTimers()) + testHarness.processElement2(new StreamRecord( + CRow(Row.of(1: JInt, "bbb12"), true), 12)) + + testHarness.setProcessingTime(13) + assertEquals(2, testHarness.numKeyedStateEntries()) + assertEquals(1, testHarness.numProcessingTimeTimers()) + + // state must be cleaned after the window timer interval has passed without new rows. + testHarness.setProcessingTime(23) + assertEquals(0, testHarness.numKeyedStateEntries()) + assertEquals(0, testHarness.numProcessingTimeTimers()) + val result = testHarness.getOutput + + val expectedOutput = new ConcurrentLinkedQueue[Object]() + + expectedOutput.add(new StreamRecord( + CRow(Row.of(2: JInt, "aaa2", 2: JInt, "bbb7"), true), 7)) + expectedOutput.add(new StreamRecord( + CRow(Row.of(1: JInt, "aaa3", 1: JInt, "bbb12"), true), 12)) + + verify(expectedOutput, result, new RowResultSortComparator()) + + testHarness.close() + } + } From d697a18214cd85cd0bbf8db3aaeaec10f4a95e4b Mon Sep 17 00:00:00 2001 From: Xingcan Cui Date: Tue, 12 Sep 2017 16:03:55 +0800 Subject: [PATCH 2/7] Feedback addressed Major changes: 1. Separate the rowtime and proctime join functions. 2. Replace the current proctime join function with the new one. 3. Add a static delay (relativeWindowSize/2) for cleanup timers. 4. Set an empty join function when relativeWindowSize < 0. 5. 
Remove the back pressure suggestion. 6. Add tests. 7. Refine the docs. 8. Reformat the codes. Other feedback has also been addressed. Remaining tasks: 1. Optimise the data caching and cleaning up policies. 2. Distinguish the < and <= (> and >=) for time criteria. --- docs/dev/table/sql.md | 2 +- .../scala/TimeBoundedJoinExample.scala | 147 ----- .../datastream/DataStreamWindowJoin.scala | 115 ++-- .../datastream/DataStreamWindowJoinRule.scala | 5 +- .../join/ProcTimeBoundedStreamInnerJoin.scala | 74 +++ .../join/RowTimeBoundedStreamInnerJoin.scala | 83 +++ .../join/TimeBoundedStreamInnerJoin.scala | 501 +++++++----------- .../table/runtime/join/WindowJoinUtil.scala | 40 +- .../flink/table/api/stream/sql/JoinTest.scala | 76 +++ .../runtime/harness/JoinHarnessTest.scala | 262 +++++++-- 10 files changed, 748 insertions(+), 557 deletions(-) delete mode 100644 flink-examples/flink-examples-table/src/main/scala/org/apache/flink/table/examples/scala/TimeBoundedJoinExample.scala create mode 100644 flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/join/ProcTimeBoundedStreamInnerJoin.scala create mode 100644 flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/join/RowTimeBoundedStreamInnerJoin.scala diff --git a/docs/dev/table/sql.md b/docs/dev/table/sql.md index fa4e3f330b268..88852336f6047 100644 --- a/docs/dev/table/sql.md +++ b/docs/dev/table/sql.md @@ -385,7 +385,7 @@ FROM Orders LEFT JOIN Product ON Orders.productId = Product.id

-

Note: Currently, only processing time window joins and INNER joins are supported.

+

Note: Currently, only INNER joins are supported.

{% highlight sql %} SELECT * diff --git a/flink-examples/flink-examples-table/src/main/scala/org/apache/flink/table/examples/scala/TimeBoundedJoinExample.scala b/flink-examples/flink-examples-table/src/main/scala/org/apache/flink/table/examples/scala/TimeBoundedJoinExample.scala deleted file mode 100644 index 2c0f098114096..0000000000000 --- a/flink-examples/flink-examples-table/src/main/scala/org/apache/flink/table/examples/scala/TimeBoundedJoinExample.scala +++ /dev/null @@ -1,147 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package org.apache.flink.table.examples.scala - -import org.apache.flink.api.common.io.GenericInputFormat -import org.apache.flink.api.common.typeinfo.TypeInformation -import org.apache.flink.api.java.typeutils.ResultTypeQueryable -import org.apache.flink.api.scala._ -import org.apache.flink.streaming.api.TimeCharacteristic -import org.apache.flink.streaming.api.functions.AssignerWithPunctuatedWatermarks -import org.apache.flink.streaming.api.functions.sink.SinkFunction -import org.apache.flink.streaming.api.scala.{DataStream, StreamExecutionEnvironment} -import org.apache.flink.streaming.api.watermark.Watermark -import org.apache.flink.table.api.TableEnvironment -import org.apache.flink.table.api.scala._ - -import scala.util.Random - -/** - * A temporary example to show how to use the TimeBoundedJoin. - * Will be removed before code committing. - */ -object TimeBoundedJoinExample { - - // ************************************************************************* - // PROGRAM - // ************************************************************************* - - def main(args: Array[String]): Unit = { - - // set up execution environment - val env = StreamExecutionEnvironment.getExecutionEnvironment - val tEnv = TableEnvironment.getTableEnvironment(env) - env.setStreamTimeCharacteristic(TimeCharacteristic.EventTime) - env.getConfig.setAutoWatermarkInterval(10) - - val orderA: DataStream[Order] = env.createInput(new InfiniteSource(1000, 0)) - .setParallelism(1) - .assignTimestampsAndWatermarks(new WatermarksAssigner(500)) - - val orderB: DataStream[Order] = env.createInput(new InfiniteSource(1000, 0)) - .setParallelism(1).assignTimestampsAndWatermarks(new WatermarksAssigner(500)) - - // register the DataStreams under the name "OrderA" and "OrderB" - tEnv.registerDataStream("OrderA", orderA, 'userA, 'productA, 'amountA, 'rtA.rowtime) - tEnv.registerDataStream("OrderB", orderB, 'userB, 'productB, 'amountB, 'rtB.rowtime) - tEnv.queryConfig - - // union the two tables - 
val result = tEnv.sql( - "SELECT userA, productA, amountA, DATE_FORMAT(rtA, '%i:%S'), DATE_FORMAT(rtB, '%i:%S'), " + - "productB FROM OrderA, OrderB" + - " WHERE OrderA.productA = OrderB.productB " + - "AND OrderA.rtA BETWEEN OrderB.rtB - INTERVAL '10' SECOND AND OrderB.rtB + INTERVAL '8' " + - //"AND OrderA.rtA BETWEEN OrderB.rtB - INTERVAL '10' SECOND AND OrderB.rtB - INTERVAL '8' - // " + - //"AND OrderA.rtA BETWEEN OrderB.rtB + INTERVAL '6' SECOND AND OrderB.rtB + INTERVAL '10' - // " + - "SECOND") - - result.toAppendStream[Order2].addSink(new SinkFunction[Order2] { - override def invoke(value: Order2): Unit = { - println(value) - } - }) - println(env.getExecutionPlan) - env.execute() - } - - /** - * InfiniteSource - * - * @param interval - */ - class InfiniteSource(interval: Int, offset: Long) - extends GenericInputFormat[Order] - with ResultTypeQueryable[Order] { - var a: Long = 0 - val b: Seq[String] = Seq("beer", "diaper", "rubber") - var c: Int = 0 - - override def reachedEnd(): Boolean = { - false - } - - override def nextRecord(reuse: Order): Order = { - Thread.sleep(interval) - a += 1 - c += 1 - Order(a, b(1), c, System.currentTimeMillis() - Random.nextInt(1000) - offset) - } - - override def getProducedType: TypeInformation[Order] = { - implicitly[TypeInformation[Order]] - } - } - - - /** - * WatermarksAssigner - * - * @param interval - */ - class WatermarksAssigner(interval: Long) extends AssignerWithPunctuatedWatermarks[Order] { - var lastWatermarks: Long = 0 - - override def extractTimestamp(element: Order, previousElementTimestamp: Long): Long = { - element.rt - } - - override def checkAndGetNextWatermark(lastElement: Order, extractedTimestamp: Long): Watermark = { - if (extractedTimestamp >= lastWatermarks + interval) { - lastWatermarks = lastWatermarks + ((extractedTimestamp - lastWatermarks) / interval) * - interval - new Watermark(lastWatermarks) - } else { - null - } - } - } - - // 
************************************************************************* - // USER DATA TYPES - // ************************************************************************* - - case class Order(user: Long, product: String, amount: Int, rt: Long) - - case class Order2(userA: Long, productA: String, amountA: Int, rtA: String, rtB: String, - productB: String) - - -} diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamWindowJoin.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamWindowJoin.scala index 2f6693742b464..f81a84fecea54 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamWindowJoin.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamWindowJoin.scala @@ -23,15 +23,19 @@ import org.apache.calcite.rel.`type`.RelDataType import org.apache.calcite.rel.core.{JoinInfo, JoinRelType} import org.apache.calcite.rel.{BiRel, RelNode, RelWriter} import org.apache.calcite.rex.RexNode +import org.apache.flink.api.common.typeinfo.TypeInformation import org.apache.flink.api.java.functions.NullByteKeySelector import org.apache.flink.streaming.api.datastream.DataStream +import org.apache.flink.streaming.api.functions.co.CoProcessFunction import org.apache.flink.table.api.{StreamQueryConfig, StreamTableEnvironment, TableException} import org.apache.flink.table.plan.nodes.CommonJoin import org.apache.flink.table.plan.schema.RowSchema import org.apache.flink.table.plan.util.UpdatingPlanChecker -import org.apache.flink.table.runtime.join.{JoinTimeIndicator, ProcTimeWindowInnerJoin, TimeBoundedStreamInnerJoin, WindowJoinUtil} +import org.apache.flink.table.runtime.join.{ProcTimeBoundedStreamInnerJoin, RowTimeBoundedStreamInnerJoin, WindowJoinUtil} import org.apache.flink.table.runtime.operators.KeyedCoProcessOperatorWithWatermarkDelay import 
org.apache.flink.table.runtime.types.{CRow, CRowTypeInfo} +import org.apache.flink.util.Collector +import org.slf4j.{Logger, LoggerFactory} /** * RelNode for a time windowed stream join. @@ -57,6 +61,8 @@ class DataStreamWindowJoin( with CommonJoin with DataStreamRel { + lazy val LOG: Logger = LoggerFactory.getLogger(getClass) + override def deriveRowType(): RelDataType = schema.relDataType override def copy(traitSet: RelTraitSet, inputs: java.util.List[RelNode]): RelNode = { @@ -117,6 +123,11 @@ class DataStreamWindowJoin( val leftKeys = joinInfo.leftKeys.toIntArray val rightKeys = joinInfo.rightKeys.toIntArray + val relativeWindowSize = leftUpperBound - leftLowerBound + + val returnTypeInfo = CRowTypeInfo(schema.typeInfo) + + // generate join function val joinFunction = WindowJoinUtil.generateJoinFunction( @@ -130,26 +141,34 @@ class DataStreamWindowJoin( joinType match { case JoinRelType.INNER => - if (isRowTime) { - // RowTime JoinCoProcessFunction - createRowTimeInnerJoinFunction( - leftDataStream, - rightDataStream, - joinFunction.name, - joinFunction.code, - leftKeys, - rightKeys - ) + if (relativeWindowSize < 0) { + LOG.warn(s"The relative window size ${relativeWindowSize} is negative," + + " please check the join conditions.") + createEmptyInnerJoinFunction(leftDataStream, rightDataStream, returnTypeInfo) } else { - // Proctime JoinCoProcessFunction - createProcTimeInnerJoinFunction( - leftDataStream, - rightDataStream, - joinFunction.name, - joinFunction.code, - leftKeys, - rightKeys - ) + if (isRowTime) { + // RowTime JoinCoProcessFunction + createRowTimeInnerJoinFunction( + leftDataStream, + rightDataStream, + returnTypeInfo, + joinFunction.name, + joinFunction.code, + leftKeys, + rightKeys + ) + } else { + // Proctime JoinCoProcessFunction + createProcTimeInnerJoinFunction( + leftDataStream, + rightDataStream, + returnTypeInfo, + joinFunction.name, + joinFunction.code, + leftKeys, + rightKeys + ) + } } case JoinRelType.FULL => throw new TableException( 
@@ -166,16 +185,16 @@ class DataStreamWindowJoin( def createProcTimeInnerJoinFunction( leftDataStream: DataStream[CRow], rightDataStream: DataStream[CRow], + returnTypeInfo: TypeInformation[CRow], joinFunctionName: String, joinFunctionCode: String, leftKeys: Array[Int], rightKeys: Array[Int]): DataStream[CRow] = { - val returnTypeInfo = CRowTypeInfo(schema.typeInfo) - - val procInnerJoinFunc = new ProcTimeWindowInnerJoin( + val procInnerJoinFunc = new ProcTimeBoundedStreamInnerJoin( leftLowerBound, leftUpperBound, + allowedLateness = 0L, leftSchema.typeInfo, rightSchema.typeInfo, joinFunctionName, @@ -196,45 +215,62 @@ class DataStreamWindowJoin( } } + def createEmptyInnerJoinFunction( + leftDataStream: DataStream[CRow], + rightDataStream: DataStream[CRow], + returnTypeInfo: TypeInformation[CRow]) = { + leftDataStream.connect(rightDataStream).process( + new CoProcessFunction[CRow, CRow, CRow] { + override def processElement1( + value: CRow, + ctx: CoProcessFunction[CRow, CRow, CRow]#Context, + out: Collector[CRow]) = { + //Do nothing. + } + override def processElement2( + value: CRow, + ctx: CoProcessFunction[CRow, CRow, CRow]#Context, + out: Collector[CRow]) = { + //Do nothing. 
+ } + }) + } def createRowTimeInnerJoinFunction( - leftDataStream: DataStream[CRow], - rightDataStream: DataStream[CRow], - joinFunctionName: String, - joinFunctionCode: String, - leftKeys: Array[Int], - rightKeys: Array[Int]): DataStream[CRow] = { - - val returnTypeInfo = CRowTypeInfo(schema.typeInfo) + leftDataStream: DataStream[CRow], + rightDataStream: DataStream[CRow], + returnTypeInfo: TypeInformation[CRow], + joinFunctionName: String, + joinFunctionCode: String, + leftKeys: Array[Int], + rightKeys: Array[Int]): DataStream[CRow] = { - val rowTimeInnerJoinFunc = new TimeBoundedStreamInnerJoin( + val rowTimeInnerJoinFunc = new RowTimeBoundedStreamInnerJoin( leftLowerBound, leftUpperBound, - 0L, + allowedLateness = 0L, leftSchema.typeInfo, rightSchema.typeInfo, joinFunctionName, joinFunctionCode, leftTimeIdx, - rightTimeIdx, - JoinTimeIndicator.ROWTIME - ) + rightTimeIdx) if (!leftKeys.isEmpty) { leftDataStream .connect(rightDataStream) .keyBy(leftKeys, rightKeys) .transform( - "rowTimeInnerJoinFunc", + "InnerRowtimeWindowJoin", returnTypeInfo, new KeyedCoProcessOperatorWithWatermarkDelay[CRow, CRow, CRow, CRow]( rowTimeInnerJoinFunc, rowTimeInnerJoinFunc.getMaxOutputDelay) - ).returns(returnTypeInfo) + ) } else { leftDataStream.connect(rightDataStream) .keyBy(new NullByteKeySelector[CRow](), new NullByteKeySelector[CRow]) .transform( - "rowTimeInnerJoinFunc", + "InnerRowtimeWindowJoin", returnTypeInfo, new KeyedCoProcessOperatorWithWatermarkDelay[CRow, CRow, CRow, CRow]( rowTimeInnerJoinFunc, @@ -242,7 +278,6 @@ class DataStreamWindowJoin( ) .setParallelism(1) .setMaxParallelism(1) - .returns(returnTypeInfo) } } } diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/datastream/DataStreamWindowJoinRule.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/datastream/DataStreamWindowJoinRule.scala index c51a1bc39cf10..a446864618649 100644 --- 
a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/datastream/DataStreamWindowJoinRule.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/datastream/DataStreamWindowJoinRule.scala @@ -55,10 +55,7 @@ class DataStreamWindowJoinRule if (windowBounds.isDefined) { if (windowBounds.get.isEventTime) { - val procTimeAttrInOutput = join.getRowType.getFieldList.asScala - .exists(f => FlinkTypeFactory.isProctimeIndicatorType(f.getType)) - - !remainingPredsAccessTime && !procTimeAttrInOutput + !remainingPredsAccessTime } else { // Check that no event-time attributes are in the input. // The proc-time join implementation does ensure that record timestamp are correctly set. diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/join/ProcTimeBoundedStreamInnerJoin.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/join/ProcTimeBoundedStreamInnerJoin.scala new file mode 100644 index 0000000000000..11d42b9b1a18c --- /dev/null +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/join/ProcTimeBoundedStreamInnerJoin.scala @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.runtime.join + +import org.apache.flink.api.common.typeinfo.TypeInformation +import org.apache.flink.streaming.api.functions.co.CoProcessFunction +import org.apache.flink.table.runtime.types.CRow +import org.apache.flink.types.Row + +/** + * The function to execute processing time bounded stream inner-join. + */ +class ProcTimeBoundedStreamInnerJoin( + leftLowerBound: Long, + leftUpperBound: Long, + allowedLateness: Long, + leftType: TypeInformation[Row], + rightType: TypeInformation[Row], + genJoinFuncName: String, + genJoinFuncCode: String) + extends TimeBoundedStreamInnerJoin( + leftLowerBound, + leftUpperBound, + allowedLateness, + leftType, + rightType, + genJoinFuncName, + genJoinFuncCode, + leftTimeIdx = -1, + rightTimeIdx = -1, + JoinTimeIndicator.PROCTIME) { + + override def checkRowOutOfDate(timeForRow: Long, watermark: Long) = false + + override def updateOperatorTime(ctx: CoProcessFunction[CRow, CRow, CRow]#Context): Unit = { + rightOperatorTime = ctx.timerService().currentProcessingTime() + leftOperatorTime = ctx.timerService().currentProcessingTime() + } + + override def getTimeForLeftStream( + context: CoProcessFunction[CRow, CRow, CRow]#Context, + row: CRow): Long = { + context.timerService().currentProcessingTime() + } + + override def getTimeForRightStream( + context: CoProcessFunction[CRow, CRow, CRow]#Context, + row: CRow): Long = { + context.timerService().currentProcessingTime() + } + + override def registerTimer( + ctx: CoProcessFunction[CRow, CRow, CRow]#Context, + isLeft: Boolean, + cleanupTime: Long): Unit = { + ctx.timerService.registerProcessingTimeTimer(cleanupTime) + } +} diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/join/RowTimeBoundedStreamInnerJoin.scala 
b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/join/RowTimeBoundedStreamInnerJoin.scala new file mode 100644 index 0000000000000..798fefdaaf384 --- /dev/null +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/join/RowTimeBoundedStreamInnerJoin.scala @@ -0,0 +1,83 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.runtime.join + +import org.apache.flink.api.common.typeinfo.TypeInformation +import org.apache.flink.streaming.api.functions.co.CoProcessFunction +import org.apache.flink.table.runtime.types.CRow +import org.apache.flink.types.Row + +/** + * The function to execute row(event) time bounded stream inner-join. 
+ */ +class RowTimeBoundedStreamInnerJoin( + leftLowerBound: Long, + leftUpperBound: Long, + allowedLateness: Long, + leftType: TypeInformation[Row], + rightType: TypeInformation[Row], + genJoinFuncName: String, + genJoinFuncCode: String, + leftTimeIdx: Int, + rightTimeIdx: Int) + extends TimeBoundedStreamInnerJoin( + leftLowerBound, + leftUpperBound, + allowedLateness, + leftType, + rightType, + genJoinFuncName, + genJoinFuncCode, + leftTimeIdx, + rightTimeIdx, + JoinTimeIndicator.ROWTIME) { + + override def checkRowOutOfDate(timeForRow: Long, watermark: Long) = { + timeForRow <= watermark - allowedLateness + } + + override def updateOperatorTime(ctx: CoProcessFunction[CRow, CRow, CRow]#Context): Unit = { + rightOperatorTime = + if (ctx.timerService().currentWatermark() > 0) ctx.timerService().currentWatermark() + else 0L + leftOperatorTime = + if (ctx.timerService().currentWatermark() > 0) ctx.timerService().currentWatermark() + else 0L + } + + override def getTimeForLeftStream( + context: CoProcessFunction[CRow, CRow, CRow]#Context, + row: CRow): Long = { + row.row.getField(leftTimeIdx).asInstanceOf[Long] + } + + override def getTimeForRightStream( + context: CoProcessFunction[CRow, CRow, CRow]#Context, + row: CRow): Long = { + row.row.getField(rightTimeIdx).asInstanceOf[Long] + } + + override def registerTimer( + ctx: CoProcessFunction[CRow, CRow, CRow]#Context, + isLeft: Boolean, + cleanupTime: Long): Unit = { + // Maybe we can register timers for different streams in the future. 
+ ctx.timerService.registerEventTimeTimer(cleanupTime) + } +} diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/join/TimeBoundedStreamInnerJoin.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/join/TimeBoundedStreamInnerJoin.scala index f9b751bdb7682..2ee0028781bf2 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/join/TimeBoundedStreamInnerJoin.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/join/TimeBoundedStreamInnerJoin.scala @@ -18,10 +18,7 @@ package org.apache.flink.table.runtime.join -import java.text.SimpleDateFormat -import java.util -import java.util.Map.Entry -import java.util.{Date, List => JList} +import java.util.{ArrayList, List => JList} import org.apache.flink.api.common.functions.FlatJoinFunction import org.apache.flink.api.common.state._ @@ -39,14 +36,11 @@ import org.apache.flink.util.Collector /** * A CoProcessFunction to execute time-bounded stream inner-join. + * Two kinds of time criteria: + * "L.time between R.time + X and R.time + Y" or "R.time between L.time - Y and L.time - X". 
* - * Sample criteria: - * - * L.time between R.time + X and R.time + Y - * or AND R.time between L.time - Y and L.time - X - * - * @param leftLowerBound X - * @param leftUpperBound Y + * @param leftLowerBound the lower bound for the left stream (X in the criteria) + * @param leftUpperBound the upper bound for the left stream (Y in the criteria) * @param allowedLateness the lateness allowed for the two streams * @param leftType the input type of left stream * @param rightType the input type of right stream @@ -55,18 +49,18 @@ import org.apache.flink.util.Collector * @param timeIndicator indicate whether joining on proctime or rowtime * */ -class TimeBoundedStreamInnerJoin( - private val leftLowerBound: Long, - private val leftUpperBound: Long, - private val allowedLateness: Long, - private val leftType: TypeInformation[Row], - private val rightType: TypeInformation[Row], - private val genJoinFuncName: String, - private val genJoinFuncCode: String, - private val leftTimeIdx: Int, - private val rightTimeIdx: Int, - private val timeIndicator: JoinTimeIndicator) - extends CoProcessFunction[CRow, CRow, CRow] +abstract class TimeBoundedStreamInnerJoin( + private val leftLowerBound: Long, + private val leftUpperBound: Long, + private val allowedLateness: Long, + private val leftType: TypeInformation[Row], + private val rightType: TypeInformation[Row], + private val genJoinFuncName: String, + private val genJoinFuncCode: String, + private val leftTimeIdx: Int, + private val rightTimeIdx: Int, + private val timeIndicator: JoinTimeIndicator) + extends CoProcessFunction[CRow, CRow, CRow] with Compiler[FlatJoinFunction[Row, Row, Row]] with Logging { @@ -75,9 +69,9 @@ class TimeBoundedStreamInnerJoin( // the join function for other conditions private var joinFunction: FlatJoinFunction[Row, Row, Row] = _ - // cache to store the left stream records + // cache to store rows from the left stream private var leftCache: MapState[Long, JList[Row]] = _ - // cache to store right stream 
records + // cache to store rows from the right stream private var rightCache: MapState[Long, JList[Row]] = _ // state to record the timer on the left stream. 0 means no timer set @@ -88,41 +82,32 @@ class TimeBoundedStreamInnerJoin( private val leftRelativeSize: Long = -leftLowerBound private val rightRelativeSize: Long = leftUpperBound - private val relativeWindowSize = rightRelativeSize + leftRelativeSize - - private var leftOperatorTime: Long = 0L - private var rightOperatorTime: Long = 0L - - private var backPressureSuggestion: Long = 0L + protected var leftOperatorTime: Long = 0L + protected var rightOperatorTime: Long = 0L - if (relativeWindowSize <= 0) { - LOG.warn("The relative window size is non-positive, please check the join conditions.") - } + //For delayed cleanup + private val cleanupDelay = (leftRelativeSize + rightRelativeSize) / 2 if (allowedLateness < 0) { throw new IllegalArgumentException("The allowed lateness must be non-negative.") } - /** - * For holding back watermarks. + * Get the maximum interval between receiving a row and emitting it (as part of a joined result). + * Only reasonable for row time join. * * @return the maximum delay for the outputs */ - def getMaxOutputDelay = Math.max(leftRelativeSize, rightRelativeSize) + allowedLateness; - - /** - * For dynamic query optimization. - * - * @return the suggested offset time for back-pressure - */ - def getBackPressureSuggestion = backPressureSuggestion + def getMaxOutputDelay: Long = Math.max(leftRelativeSize, rightRelativeSize) + allowedLateness override def open(config: Configuration) { + LOG.debug(s"Compiling JoinFunction: $genJoinFuncName \n\n " + + s"Code:\n$genJoinFuncCode") val clazz = compile( getRuntimeContext.getUserCodeClassLoader, genJoinFuncName, genJoinFuncCode) + LOG.debug("Instantiating JoinFunction.") joinFunction = clazz.newInstance() cRowWrapper = new CRowWrappingCollector() @@ -131,340 +116,308 @@ class TimeBoundedStreamInnerJoin( // Initialize the data caches. 
val leftListTypeInfo: TypeInformation[JList[Row]] = new ListTypeInfo[Row](leftType) val leftStateDescriptor: MapStateDescriptor[Long, JList[Row]] = - new MapStateDescriptor[Long, JList[Row]](timeIndicator + "InnerJoinLeftCache", - BasicTypeInfo.LONG_TYPE_INFO.asInstanceOf[TypeInformation[Long]], leftListTypeInfo) + new MapStateDescriptor[Long, JList[Row]]( + timeIndicator + "InnerJoinLeftCache", + BasicTypeInfo.LONG_TYPE_INFO.asInstanceOf[TypeInformation[Long]], + leftListTypeInfo) leftCache = getRuntimeContext.getMapState(leftStateDescriptor) val rightListTypeInfo: TypeInformation[JList[Row]] = new ListTypeInfo[Row](rightType) val rightStateDescriptor: MapStateDescriptor[Long, JList[Row]] = - new MapStateDescriptor[Long, JList[Row]](timeIndicator + "InnerJoinRightCache", - BasicTypeInfo.LONG_TYPE_INFO.asInstanceOf[TypeInformation[Long]], rightListTypeInfo) + new MapStateDescriptor[Long, JList[Row]]( + timeIndicator + "InnerJoinRightCache", + BasicTypeInfo.LONG_TYPE_INFO.asInstanceOf[TypeInformation[Long]], + rightListTypeInfo) rightCache = getRuntimeContext.getMapState(rightStateDescriptor) // Initialize the timer states. val leftTimerStateDesc: ValueStateDescriptor[Long] = - new ValueStateDescriptor[Long](timeIndicator + "InnerJoinLeftTimerState", - classOf[Long]) + new ValueStateDescriptor[Long](timeIndicator + "InnerJoinLeftTimerState", classOf[Long]) leftTimerState = getRuntimeContext.getState(leftTimerStateDesc) val rightTimerStateDesc: ValueStateDescriptor[Long] = - new ValueStateDescriptor[Long](timeIndicator + "InnerJoinRightTimerState", - classOf[Long]) + new ValueStateDescriptor[Long](timeIndicator + "InnerJoinRightTimerState", classOf[Long]) rightTimerState = getRuntimeContext.getState(rightTimerStateDesc) } /** - * Process records from the left stream. - * - * @param cRowValue the input record - * @param ctx the context to register timer or get current time - * @param out the collector for outputting results - * + * Process rows from the left stream. 
*/ override def processElement1( - cRowValue: CRow, - ctx: CoProcessFunction[CRow, CRow, CRow]#Context, - out: Collector[CRow]): Unit = { - val timeForRecord: Long = getTimeForRecord(ctx, cRowValue, true) - getCurrentOperatorTime(ctx) + cRowValue: CRow, + ctx: CoProcessFunction[CRow, CRow, CRow]#Context, + out: Collector[CRow]): Unit = { + updateOperatorTime(ctx) + val rowTime: Long = getTimeForLeftStream(ctx, cRowValue) + val oppositeLowerBound: Long = rowTime - rightRelativeSize + val oppositeUpperBound: Long = rowTime + leftRelativeSize processElement( cRowValue, - timeForRecord, + rowTime, ctx, out, leftOperatorTime, + oppositeLowerBound, + oppositeUpperBound, rightOperatorTime, rightTimerState, leftCache, rightCache, - true + leftRow = true ) } /** - * Process records from the right stream. - * - * @param cRowValue the input record - * @param ctx the context to get current time - * @param out the collector for outputting results - * + * Process rows from the right stream. */ override def processElement2( - cRowValue: CRow, - ctx: CoProcessFunction[CRow, CRow, CRow]#Context, - out: Collector[CRow]): Unit = { - val timeForRecord: Long = getTimeForRecord(ctx, cRowValue, false) - getCurrentOperatorTime(ctx) + cRowValue: CRow, + ctx: CoProcessFunction[CRow, CRow, CRow]#Context, + out: Collector[CRow]): Unit = { + updateOperatorTime(ctx) + val rowTime: Long = getTimeForRightStream(ctx, cRowValue) + val oppositeLowerBound: Long = rowTime - leftRelativeSize + val oppositeUpperBound: Long = rowTime + rightRelativeSize processElement( cRowValue, - timeForRecord, + rowTime, ctx, out, rightOperatorTime, + oppositeLowerBound, + oppositeUpperBound, leftOperatorTime, leftTimerState, rightCache, leftCache, - false + leftRow = false ) } /** - * Put a record from the input stream into the cache and iterate the opposite cache to - * output records meeting the join conditions. 
If there is no timer set for the OPPOSITE + * Put a row from the input stream into the cache and iterate the opposite cache to + * output join results meeting the conditions. If there is no timer set for the OPPOSITE * STREAM, register one. */ private def processElement( - cRowValue: CRow, - timeForRecord: Long, - ctx: CoProcessFunction[CRow, CRow, CRow]#Context, - out: Collector[CRow], - myWatermark: Long, - oppositeWatermark: Long, - oppositeTimeState: ValueState[Long], - recordListCache: MapState[Long, JList[Row]], - oppositeCache: MapState[Long, JList[Row]], - leftRecord: Boolean): Unit = { - if (relativeWindowSize > 0) { - //TODO Shall we consider adding a method for initialization with the context and collector? - cRowWrapper.out = out - - val record = cRowValue.row - - //TODO Only if the time of the record is greater than the watermark, can we continue. - if (timeForRecord >= myWatermark - allowedLateness) { - val oppositeLowerBound: Long = - if (leftRecord) timeForRecord - rightRelativeSize else timeForRecord - leftRelativeSize - - val oppositeUpperBound: Long = - if (leftRecord) timeForRecord + leftRelativeSize else timeForRecord + rightRelativeSize - - // Put the record into the cache for later use. - val recordList = if (recordListCache.contains(timeForRecord)) { - recordListCache.get(timeForRecord) - } else { - new util.ArrayList[Row]() - } - recordList.add(record) - recordListCache.put(timeForRecord, recordList) - - // Register a timer on THE OTHER STREAM to remove records from the cache once they are - // expired. 
- if (oppositeTimeState.value == 0) { - registerCleanUpTimer( - ctx, timeForRecord, oppositeWatermark, oppositeTimeState, leftRecord, true) - } + cRowValue: CRow, + timeForRow: Long, + ctx: CoProcessFunction[CRow, CRow, CRow]#Context, + out: Collector[CRow], + myWatermark: Long, + oppositeLowerBound: Long, + oppositeUpperBound: Long, + oppositeWatermark: Long, + oppositeTimeState: ValueState[Long], + rowListCache: MapState[Long, JList[Row]], + oppositeCache: MapState[Long, JList[Row]], + leftRow: Boolean): Unit = { + cRowWrapper.out = out + val row = cRowValue.row + if (!checkRowOutOfDate(timeForRow, myWatermark)) { + // Put the row into the cache for later use. + var rowList = rowListCache.get(timeForRow) + if (null == rowList) { + rowList = new ArrayList[Row](1) + } + rowList.add(row) + rowListCache.put(timeForRow, rowList) + // Register a timer on THE OPPOSITE STREAM to remove rows from the cache once they are + // expired. + if (oppositeTimeState.value == 0) { + registerCleanUpTimer( + ctx, timeForRow, oppositeWatermark, oppositeTimeState, leftRow, firstTimer = true) + } - // Join the record with records from the opposite stream. - val oppositeIterator = oppositeCache.iterator() - var oppositeEntry: Entry[Long, util.List[Row]] = null - var oppositeTime: Long = 0L; - while (oppositeIterator.hasNext) { - oppositeEntry = oppositeIterator.next - oppositeTime = oppositeEntry.getKey - if (oppositeTime < oppositeLowerBound - allowedLateness) { - //TODO Considering the data out-of-order, we should not remove records here. - } else if (oppositeTime >= oppositeLowerBound && oppositeTime <= oppositeUpperBound) { - val oppositeRows = oppositeEntry.getValue - var i = 0 - if (leftRecord) { - while (i < oppositeRows.size) { - joinFunction.join(record, oppositeRows.get(i), cRowWrapper) - i += 1 - } - } else { - while (i < oppositeRows.size) { - joinFunction.join(oppositeRows.get(i), record, cRowWrapper) - i += 1 - } + // Join the row with rows from the opposite stream. 
+ val oppositeIterator = oppositeCache.iterator() + while (oppositeIterator.hasNext) { + val oppositeEntry = oppositeIterator.next + val oppositeTime = oppositeEntry.getKey + if (oppositeTime >= oppositeLowerBound && oppositeTime <= oppositeUpperBound) { + val oppositeRows = oppositeEntry.getValue + var i = 0 + if (leftRow) { + while (i < oppositeRows.size) { + joinFunction.join(row, oppositeRows.get(i), cRowWrapper) + i += 1 + } + } else { + while (i < oppositeRows.size) { + joinFunction.join(oppositeRows.get(i), row, cRowWrapper) + i += 1 } - } else if (oppositeTime > oppositeUpperBound) { - //TODO If the keys are ordered, can we break here? } } - } else { - //TODO Need some extra logic here? - LOG.warn(s"$record is out-of-date.") + // We could do the short-cutting optimization here once we get a state with ordered keys. } } + // We need to deal with the late data in the future. } /** - * Register a timer for cleaning up records in a specified time. + * Register a timer for cleaning up rows in a specified time. 
* * @param ctx the context to register timer - * @param timeForRecord time for the input record + * @param rowTime time for the input row * @param oppositeWatermark watermark of the opposite stream * @param timerState stores the timestamp for the next timer - * @param leftRecord record from the left or the right stream + * @param leftRow whether this row comes from the left stream * @param firstTimer whether this is the first timer */ private def registerCleanUpTimer( - ctx: CoProcessFunction[CRow, CRow, CRow]#Context, - timeForRecord: Long, - oppositeWatermark: Long, - timerState: ValueState[Long], - leftRecord: Boolean, - firstTimer: Boolean): Unit = { - val cleanUpTime = timeForRecord + (if (leftRecord) leftRelativeSize else rightRelativeSize) + - allowedLateness + 1 - registerTimer(ctx, !leftRecord, cleanUpTime) - LOG.debug(s"Register a clean up timer on the ${if (leftRecord) "RIGHT" else "LEFT"} state:" - + s" timeForRecord = ${timeForRecord}, cleanUpTime = ${cleanUpTime}, oppositeWatermark = " + - s"${oppositeWatermark}") - timerState.update(cleanUpTime) - if (cleanUpTime <= oppositeWatermark + allowedLateness && firstTimer) { - backPressureSuggestion = - if (leftRecord) (oppositeWatermark + allowedLateness - cleanUpTime) - else -(oppositeWatermark + allowedLateness - cleanUpTime) - LOG.warn("The clean timer for the " + - s"${if (leftRecord) "left" else "right"}" + - s" stream is lower than ${if (leftRecord) "right" else "left"} watermark." 
+ - s" requiredTime = ${formatTime(cleanUpTime)}, watermark = ${formatTime(oppositeWatermark)}," - + s"backPressureSuggestion = " + s"${backPressureSuggestion}.") + ctx: CoProcessFunction[CRow, CRow, CRow]#Context, + rowTime: Long, + oppositeWatermark: Long, + timerState: ValueState[Long], + leftRow: Boolean, + firstTimer: Boolean): Unit = { + val cleanupTime = if (leftRow) { + rowTime + leftRelativeSize + cleanupDelay + allowedLateness + 1 + } else { + rowTime + rightRelativeSize + cleanupDelay + allowedLateness + 1 } + registerTimer(ctx, !leftRow, cleanupTime) + LOG.debug(s"Register a clean up timer on the ${if (leftRow) "RIGHT" else "LEFT"} state:" + + s" timeForRow = ${rowTime}, cleanupTime should be ${cleanupTime - cleanupDelay}," + + s" but delayed to ${cleanupTime}," + + s" oppositeWatermark = ${oppositeWatermark}") + timerState.update(cleanupTime) + //if cleanupTime <= oppositeWatermark + allowedLateness && firstTimer, we may set the + // backPressureSuggestion = + // if (leftRow) (oppositeWatermark + allowedLateness - cleanupTime) + // else -(oppositeWatermark + allowedLateness - cleanupTime) } - /** * Called when a registered timer is fired. - * Remove records which are earlier than the expiration time, - * and register a new timer for the earliest remaining records. + * Remove rows whose timestamps are earlier than the expiration time, + * and register a new timer for the remaining rows. * * @param timestamp the timestamp of the timer * @param ctx the context to register timer or get current time * @param out the collector for returning result values */ override def onTimer( - timestamp: Long, - ctx: CoProcessFunction[CRow, CRow, CRow]#OnTimerContext, - out: Collector[CRow]): Unit = { - getCurrentOperatorTime(ctx) - //TODO In the future, we should separate the left and right watermarks. Otherwise, the - //TODO registered timer of the faster stream will be delayed, even if the watermarks have - //TODO already been emitted by the source. 
+ timestamp: Long, + ctx: CoProcessFunction[CRow, CRow, CRow]#OnTimerContext, + out: Collector[CRow]): Unit = { + updateOperatorTime(ctx) + // In the future, we should separate the left and right watermarks. Otherwise, the + // registered timer of the faster stream will be delayed, even if the watermarks have + // already been emitted by the source. if (leftTimerState.value == timestamp) { val rightExpirationTime = leftOperatorTime - rightRelativeSize - allowedLateness - 1 - removeExpiredRecords( - timestamp, + removeExpiredRows( rightExpirationTime, leftOperatorTime, rightCache, leftTimerState, ctx, - false + removeLeft = false ) } if (rightTimerState.value == timestamp) { val leftExpirationTime = rightOperatorTime - leftRelativeSize - allowedLateness - 1 - removeExpiredRecords( - timestamp, + removeExpiredRows( leftExpirationTime, rightOperatorTime, leftCache, rightTimerState, ctx, - true + removeLeft = true ) } } /** - * Remove the expired records. Register a new timer if the cache still holds records + * Remove the expired rows. Register a new timer if the cache still holds valid rows * after the cleaning up. 
+ * + * @param expirationTime the expiration time for this cache + * @param oppositeWatermark the watermark of the opposite stream + * @param rowCache the row cache + * @param timerState timer state for the opposite stream + * @param ctx the context to register the cleanup timer + * @param removeLeft whether to remove the left rows */ - private def removeExpiredRecords( - timerFiringTime: Long, - expirationTime: Long, - oppositeWatermark: Long, - recordCache: MapState[Long, JList[Row]], - timerState: ValueState[Long], - ctx: CoProcessFunction[CRow, CRow, CRow]#OnTimerContext, - removeLeft: Boolean): Unit = { + private def removeExpiredRows( + expirationTime: Long, + oppositeWatermark: Long, + rowCache: MapState[Long, JList[Row]], + timerState: ValueState[Long], + ctx: CoProcessFunction[CRow, CRow, CRow]#OnTimerContext, + removeLeft: Boolean): Unit = { - val keysIterator = recordCache.keys().iterator() + val keysIterator = rowCache.keys().iterator() // Search for expired timestamps. // If we find a non-expired timestamp, remember the timestamp and leave the loop. // This way we find all expired timestamps if they are sorted without doing a full pass. var earliestTimestamp: Long = -1L - var recordTime: Long = 0L + var rowTime: Long = 0L while (keysIterator.hasNext) { - //TODO The "short-circuit" code was commented, because when using a StateMap with - //TODO unordered keys, the cache will grow indefinitely! - // && earliestTimestamp < 0) { - recordTime = keysIterator.next - if (recordTime <= expirationTime) { - // TODO Not sure if we can remove records directly. + rowTime = keysIterator.next + if (rowTime <= expirationTime) { keysIterator.remove() } else { // We find the earliest timestamp that is still valid. 
- if (recordTime < earliestTimestamp || earliestTimestamp < 0) { - earliestTimestamp = recordTime + if (rowTime < earliestTimestamp || earliestTimestamp < 0) { + earliestTimestamp = rowTime } } } // If the cache contains non-expired timestamps, register a new timer. // Otherwise clear the states. if (earliestTimestamp > 0) { - registerCleanUpTimer(ctx, earliestTimestamp, oppositeWatermark, timerState, removeLeft, false) + registerCleanUpTimer( + ctx, + earliestTimestamp, + oppositeWatermark, + timerState, + removeLeft, + firstTimer = false) } else { // The timerState will be 0. timerState.clear() - recordCache.clear() + rowCache.clear() } } /** - * Get the operator times of the two streams. + * Check if the row is out of date. + * + * @param timeForRow time of the row + * @param watermark watermark for the stream + * @return true if the row is out of date; false otherwise + */ + def checkRowOutOfDate(timeForRow: Long, watermark: Long): Boolean + + /** + * Update the operator time of the two streams. * * @param ctx the context to acquire watermarks */ - protected def getCurrentOperatorTime( - ctx: CoProcessFunction[CRow, CRow, CRow]#Context): Unit = { - timeIndicator match { - case JoinTimeIndicator.ROWTIME => { - rightOperatorTime = - if (ctx.timerService().currentWatermark() > 0) ctx.timerService().currentWatermark() - else 0L; - leftOperatorTime = - if (ctx.timerService().currentWatermark() > 0) ctx.timerService().currentWatermark() - else 0L; - } - case JoinTimeIndicator.PROCTIME => { - rightOperatorTime = ctx.timerService().currentProcessingTime() - leftOperatorTime = ctx.timerService().currentProcessingTime() - } - } - } + def updateOperatorTime(ctx: CoProcessFunction[CRow, CRow, CRow]#Context): Unit + /** + * Return the time for the target row from the left stream. 
+ * + * @param context the runtime context + * @param row the target row + * @return time for the target row + */ + def getTimeForLeftStream(context: CoProcessFunction[CRow, CRow, CRow]#Context, row: CRow): Long /** - * Return the rowtime or proctime for the target record. + * Return the time for the target row from the right stream. * * @param context the runtime context - * @param record the target record - * @param isLeft whether the record is from the left stream - * @return time for the target record + * @param row the target row + * @return time for the target row */ - protected def getTimeForRecord( - context: CoProcessFunction[CRow, CRow, CRow]#Context, - record: CRow, - isLeft: Boolean): Long = { - timeIndicator match { - case JoinTimeIndicator.ROWTIME => { - return if (isLeft) { - record.row.getField(leftTimeIdx).asInstanceOf[Long] - } else { - record.row.getField(rightTimeIdx).asInstanceOf[Long]; - } - } - case JoinTimeIndicator.PROCTIME => { - return context.timerService().currentProcessingTime(); - } - } - } + def getTimeForRightStream(context: CoProcessFunction[CRow, CRow, CRow]#Context, row: CRow): Long /** * Register a proctime or rowtime timer. @@ -473,58 +426,14 @@ class TimeBoundedStreamInnerJoin( * @param isLeft whether this timer should be registered on the left stream * @param cleanupTime timestamp for the timer */ - protected def registerTimer( - ctx: CoProcessFunction[CRow, CRow, CRow]#Context, isLeft: Boolean, cleanupTime: Long): Unit = { - // Maybe we can register timers for different streams in the future. - timeIndicator match { - case JoinTimeIndicator.ROWTIME => { - ctx.timerService.registerEventTimeTimer(cleanupTime) - } - case JoinTimeIndicator.PROCTIME => { - ctx.timerService.registerProcessingTimeTimer(cleanupTime) - } - } - } - - //********* Functions for temporary test use. 
*****************// - - def formatTime(time: Long): String = { - if (0 == time) { - return "null" - } - val f: SimpleDateFormat = new SimpleDateFormat("HH:mm:ss SSS") - f.format(new Date(time)) - } - - def printCacheSize = { - var leftSize = 0; - var rightSize = 0; - var iterator = leftCache.iterator(); - while (iterator.hasNext) { - leftSize = leftSize + iterator.next().getValue.size() - } - iterator = rightCache.iterator(); - while (iterator.hasNext) { - rightSize = rightSize + iterator.next().getValue.size() - } - - println(s"leftSize = $leftSize, rightSize = $rightSize") - } - - override def toString = s"RowTimeWindowInnerJoin(" + - s"leftTimerState=${formatTime(leftTimerState.value())}, " + - s"rightTimerState=${formatTime(rightTimerState.value())}, " + - s"leftRelativeSize=$leftRelativeSize, " + - s"rightRelativeSize=$rightRelativeSize, relativeWindowSize=$relativeWindowSize, " + - s"leftOperatorTime=${formatTime(leftOperatorTime)}," + - s" rightOperatorTime=${formatTime(rightOperatorTime)})" - + def registerTimer( + ctx: CoProcessFunction[CRow, CRow, CRow]#Context, + isLeft: Boolean, + cleanupTime: Long): Unit } -//********* Will be removed before committing. ************// - /** - * TODO Not sure if that can be replaced by [[org.apache.flink.streaming.api.TimeCharacteristic]] + * Defines the rowtime and proctime join indicators. 
*/ object JoinTimeIndicator extends Enumeration { type JoinTimeIndicator = Value diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/join/WindowJoinUtil.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/join/WindowJoinUtil.scala index 05a5e26f1d564..6f97f2a662f89 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/join/WindowJoinUtil.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/join/WindowJoinUtil.scala @@ -39,16 +39,22 @@ import scala.collection.JavaConverters._ */ object WindowJoinUtil { - case class WindowBounds(isEventTime: Boolean, leftLowerBound: Long, leftUpperBound: Long, - leftTimeIdx: Int, rightTimeIdx: Int) + case class WindowBounds( + isEventTime: Boolean, + leftLowerBound: Long, + leftUpperBound: Long, + leftTimeIdx: Int, + rightTimeIdx: Int) protected case class WindowBound(bound: Long, isLeftLower: Boolean) + protected case class TimePredicate( isEventTime: Boolean, leftInputOnLeftSide: Boolean, leftTimeIdx: Int, rightTimeIdx: Int, pred: RexCall) + protected case class TimeAttributeAccess(isEventTime: Boolean, isLeftInput: Boolean, idx: Int) /** @@ -118,15 +124,24 @@ object WindowJoinUtil { case _ => Some(otherPreds.reduceLeft((l, r) => RelOptUtil.andJoinFilters(rexBuilder, l, r))) } - if (timePreds.head.leftInputOnLeftSide) { - val bounds = Some(WindowBounds(timePreds.head.isEventTime, leftLowerBound, leftUpperBound, - timePreds.head.leftTimeIdx, timePreds.last.rightTimeIdx - leftLogicalFieldCnt)) - (bounds, remainCondition) + + val bounds = if (timePreds.head.leftInputOnLeftSide) { + Some(WindowBounds( + timePreds.head.isEventTime, + leftLowerBound, + leftUpperBound, + timePreds.head.leftTimeIdx, + timePreds.head.rightTimeIdx)) } else { - val bounds = Some(WindowBounds(timePreds.head.isEventTime, leftLowerBound, leftUpperBound, - timePreds.last.rightTimeIdx, timePreds.head.leftTimeIdx - leftLogicalFieldCnt)) 
- (bounds, remainCondition) + Some(WindowBounds( + timePreds.head.isEventTime, + leftLowerBound, + leftUpperBound, + timePreds.head.rightTimeIdx, + timePreds.head.leftTimeIdx)) } + + (bounds, remainCondition) } /** @@ -232,8 +247,11 @@ object WindowJoinUtil { inputType.getFieldList.get(idx).getType match { case t: TimeIndicatorRelDataType => // time attribute access. Remember time type and side of input - val isLeftInput = idx < leftFieldCount - Seq(TimeAttributeAccess(t.isEventTime, isLeftInput, idx)) + if (idx < leftFieldCount) { + Seq(TimeAttributeAccess(t.isEventTime, true, idx)) + } else { + Seq(TimeAttributeAccess(t.isEventTime, false, idx - leftFieldCount)) + } case _ => // not a time attribute access. Seq() diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/stream/sql/JoinTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/stream/sql/JoinTest.scala index 640fd266bd09f..be72d2e822ce5 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/stream/sql/JoinTest.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/stream/sql/JoinTest.scala @@ -69,6 +69,44 @@ class JoinTest extends TableTestBase { streamUtil.verifySql(sqlQuery, expected) } + @Test + def testRowTimeInnerJoinWithOnClause() = { + + val sqlQuery = + """ + |SELECT t1.a, t2.b + |FROM MyTable t1 JOIN MyTable2 t2 ON + | t1.a = t2.a AND + | t1.c BETWEEN t2.c - INTERVAL '10' SECOND AND t2.c + INTERVAL '1' HOUR + |""".stripMargin + + val expected = + unaryNode( + "DataStreamCalc", + binaryNode( + "DataStreamWindowJoin", + unaryNode( + "DataStreamCalc", + streamTableNode(0), + term("select", "a", "c") + ), + unaryNode( + "DataStreamCalc", + streamTableNode(1), + term("select", "a", "b", "c") + ), + term("where", + "AND(=(a, a0), >=(c, -(c0, 10000)), " + + "<=(c, DATETIME_PLUS(c0, 3600000)))"), + term("join", "a, c, a0, b, c0"), + term("joinType", "InnerJoin") + ), + term("select", "a", "b") 
+ ) + + streamUtil.verifySql(sqlQuery, expected) + } + @Test def testProcessingTimeInnerJoinWithWhereClause() = { @@ -107,6 +145,44 @@ class JoinTest extends TableTestBase { streamUtil.verifySql(sqlQuery, expected) } + @Test + def testRowTimeInnerJoinWithWhereClause() = { + + val sqlQuery = + """ + |SELECT t1.a, t2.b + |FROM MyTable t1, MyTable2 t2 + |WHERE t1.a = t2.a AND + | t1.c BETWEEN t2.c - INTERVAL '10' MINUTE AND t2.c + INTERVAL '1' HOUR + |""".stripMargin + + val expected = + unaryNode( + "DataStreamCalc", + binaryNode( + "DataStreamWindowJoin", + unaryNode( + "DataStreamCalc", + streamTableNode(0), + term("select", "a", "c") + ), + unaryNode( + "DataStreamCalc", + streamTableNode(1), + term("select", "a", "b", "c") + ), + term("where", + "AND(=(a, a0), >=(c, -(c0, 600000)), " + + "<=(c, DATETIME_PLUS(c0, 3600000)))"), + term("join", "a, c, a0, b, c0"), + term("joinType", "InnerJoin") + ), + term("select", "a", "b0 AS b") + ) + + streamUtil.verifySql(sqlQuery, expected) + } + @Test def testJoinTimeBoundary(): Unit = { verifyTimeBoundary( diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/harness/JoinHarnessTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/harness/JoinHarnessTest.scala index 74c04edf3f88f..0b18f95c25e3a 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/harness/JoinHarnessTest.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/harness/JoinHarnessTest.scala @@ -18,16 +18,17 @@ package org.apache.flink.table.runtime.harness import java.util.concurrent.ConcurrentLinkedQueue -import java.lang.{Integer => JInt} +import java.lang.{Long => JLong} import org.apache.flink.api.common.typeinfo.BasicTypeInfo._ import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation} import org.apache.flink.api.java.typeutils.RowTypeInfo import 
org.apache.flink.streaming.api.operators.co.KeyedCoProcessOperator +import org.apache.flink.streaming.api.watermark.Watermark import org.apache.flink.streaming.runtime.streamrecord.StreamRecord import org.apache.flink.streaming.util.KeyedTwoInputStreamOperatorTestHarness import org.apache.flink.table.runtime.harness.HarnessTestBase.{RowResultSortComparator, TupleRowKeySelector} -import org.apache.flink.table.runtime.join.{JoinTimeIndicator, ProcTimeWindowInnerJoin, TimeBoundedStreamInnerJoin} +import org.apache.flink.table.runtime.join.{ProcTimeBoundedStreamInnerJoin, ProcTimeWindowInnerJoin, RowTimeBoundedStreamInnerJoin} import org.apache.flink.table.runtime.types.CRow import org.apache.flink.types.Row import org.junit.Test @@ -36,7 +37,7 @@ import org.junit.Assert.{assertEquals, assertTrue} class JoinHarnessTest extends HarnessTestBase{ private val rT = new RowTypeInfo(Array[TypeInformation[_]]( - INT_TYPE_INFO, + LONG_TYPE_INFO, STRING_TYPE_INFO), Array("a", "b")) @@ -94,24 +95,24 @@ class JoinHarnessTest extends HarnessTestBase{ // left stream input testHarness.setProcessingTime(1) testHarness.processElement1(new StreamRecord( - CRow(Row.of(1: JInt, "aaa"), true), 1)) + CRow(Row.of(1L: JLong, "aaa"), true), 1)) assertEquals(1, testHarness.numProcessingTimeTimers()) testHarness.setProcessingTime(2) testHarness.processElement1(new StreamRecord( - CRow(Row.of(2: JInt, "bbb"), true), 2)) + CRow(Row.of(2L: JLong, "bbb"), true), 2)) assertEquals(2, testHarness.numProcessingTimeTimers()) testHarness.setProcessingTime(3) testHarness.processElement1(new StreamRecord( - CRow(Row.of(1: JInt, "aaa2"), true), 3)) + CRow(Row.of(1L: JLong, "aaa2"), true), 3)) assertEquals(4, testHarness.numKeyedStateEntries()) assertEquals(2, testHarness.numProcessingTimeTimers()) // right stream input and output normally testHarness.processElement2(new StreamRecord( - CRow(Row.of(1: JInt, "Hi1"), true), 3)) + CRow(Row.of(1L: JLong, "Hi1"), true), 3)) testHarness.setProcessingTime(4) 
testHarness.processElement2(new StreamRecord( - CRow(Row.of(2: JInt, "Hello1"), true), 4)) + CRow(Row.of(2L: JLong, "Hello1"), true), 4)) assertEquals(8, testHarness.numKeyedStateEntries()) assertEquals(4, testHarness.numProcessingTimeTimers()) @@ -120,18 +121,18 @@ class JoinHarnessTest extends HarnessTestBase{ assertEquals(8, testHarness.numKeyedStateEntries()) assertEquals(4, testHarness.numProcessingTimeTimers()) testHarness.processElement2(new StreamRecord( - CRow(Row.of(1: JInt, "Hi2"), true), 12)) + CRow(Row.of(1L: JLong, "Hi2"), true), 12)) // expired right stream record at timestamp 4 and all left stream testHarness.setProcessingTime(25) assertEquals(2, testHarness.numKeyedStateEntries()) assertEquals(1, testHarness.numProcessingTimeTimers()) testHarness.processElement1(new StreamRecord( - CRow(Row.of(1: JInt, "aaa3"), true), 25)) + CRow(Row.of(1L: JLong, "aaa3"), true), 25)) testHarness.processElement1(new StreamRecord( - CRow(Row.of(2: JInt, "bbb2"), true), 25)) + CRow(Row.of(2L: JLong, "bbb2"), true), 25)) testHarness.processElement2(new StreamRecord( - CRow(Row.of(2: JInt, "Hello2"), true), 25)) + CRow(Row.of(2L: JLong, "Hello2"), true), 25)) testHarness.setProcessingTime(45) assertTrue(testHarness.numKeyedStateEntries() > 0) @@ -142,17 +143,17 @@ class JoinHarnessTest extends HarnessTestBase{ val expectedOutput = new ConcurrentLinkedQueue[Object]() expectedOutput.add(new StreamRecord( - CRow(Row.of(1: JInt, "aaa", 1: JInt, "Hi1"), true), 3)) + CRow(Row.of(1L: JLong, "aaa", 1L: JLong, "Hi1"), true), 3)) expectedOutput.add(new StreamRecord( - CRow(Row.of(1: JInt, "aaa2", 1: JInt, "Hi1"), true), 3)) + CRow(Row.of(1L: JLong, "aaa2", 1L: JLong, "Hi1"), true), 3)) expectedOutput.add(new StreamRecord( - CRow(Row.of(2: JInt, "bbb", 2: JInt, "Hello1"), true), 4)) + CRow(Row.of(2L: JLong, "bbb", 2L: JLong, "Hello1"), true), 4)) expectedOutput.add(new StreamRecord( - CRow(Row.of(1: JInt, "aaa2", 1: JInt, "Hi2"), true), 12)) + CRow(Row.of(1L: JLong, "aaa2", 1L: 
JLong, "Hi2"), true), 12)) expectedOutput.add(new StreamRecord( - CRow(Row.of(1: JInt, "aaa3", 1: JInt, "Hi2"), true), 25)) + CRow(Row.of(1L: JLong, "aaa3", 1L: JLong, "Hi2"), true), 25)) expectedOutput.add(new StreamRecord( - CRow(Row.of(2: JInt, "bbb2", 2: JInt, "Hello2"), true), 25)) + CRow(Row.of(2L: JLong, "bbb2", 2L: JLong, "Hello2"), true), 25)) verify(expectedOutput, result, new RowResultSortComparator()) @@ -179,27 +180,27 @@ class JoinHarnessTest extends HarnessTestBase{ testHarness.setProcessingTime(1) testHarness.processElement1(new StreamRecord( - CRow(Row.of(1: JInt, "aaa1"), true), 1)) + CRow(Row.of(1L: JLong, "aaa1"), true), 1)) testHarness.setProcessingTime(2) testHarness.processElement1(new StreamRecord( - CRow(Row.of(2: JInt, "aaa2"), true), 2)) + CRow(Row.of(2L: JLong, "aaa2"), true), 2)) testHarness.setProcessingTime(3) testHarness.processElement1(new StreamRecord( - CRow(Row.of(1: JInt, "aaa3"), true), 3)) + CRow(Row.of(1L: JLong, "aaa3"), true), 3)) assertEquals(4, testHarness.numKeyedStateEntries()) assertEquals(2, testHarness.numProcessingTimeTimers()) // Do not store b elements // not meet a.proctime <= b.proctime - 5 testHarness.processElement2(new StreamRecord( - CRow(Row.of(1: JInt, "bbb3"), true), 3)) + CRow(Row.of(1L: JLong, "bbb3"), true), 3)) assertEquals(4, testHarness.numKeyedStateEntries()) assertEquals(2, testHarness.numProcessingTimeTimers()) // meet a.proctime <= b.proctime - 5 testHarness.setProcessingTime(7) testHarness.processElement2(new StreamRecord( - CRow(Row.of(2: JInt, "bbb7"), true), 7)) + CRow(Row.of(2L: JLong, "bbb7"), true), 7)) assertEquals(4, testHarness.numKeyedStateEntries()) assertEquals(2, testHarness.numProcessingTimeTimers()) @@ -208,7 +209,7 @@ class JoinHarnessTest extends HarnessTestBase{ assertEquals(4, testHarness.numKeyedStateEntries()) assertEquals(2, testHarness.numProcessingTimeTimers()) testHarness.processElement2(new StreamRecord( - CRow(Row.of(1: JInt, "bbb12"), true), 12)) + CRow(Row.of(1L: 
JLong, "bbb12"), true), 12)) testHarness.setProcessingTime(13) assertEquals(2, testHarness.numKeyedStateEntries()) @@ -223,9 +224,9 @@ class JoinHarnessTest extends HarnessTestBase{ val expectedOutput = new ConcurrentLinkedQueue[Object]() expectedOutput.add(new StreamRecord( - CRow(Row.of(2: JInt, "aaa2", 2: JInt, "bbb7"), true), 7)) + CRow(Row.of(2L: JLong, "aaa2", 2L: JLong, "bbb7"), true), 7)) expectedOutput.add(new StreamRecord( - CRow(Row.of(1: JInt, "aaa3", 1: JInt, "bbb12"), true), 12)) + CRow(Row.of(1L: JLong, "aaa3", 1L: JLong, "bbb12"), true), 12)) verify(expectedOutput, result, new RowResultSortComparator()) @@ -236,8 +237,8 @@ class JoinHarnessTest extends HarnessTestBase{ @Test def testNewNormalProcTimeJoin() { - val joinProcessFunc = new TimeBoundedStreamInnerJoin( - -10, 20, 0, rT, rT, "TestJoinFunction", funcCode, -1, -1, JoinTimeIndicator.PROCTIME) + val joinProcessFunc = new ProcTimeBoundedStreamInnerJoin( + -10, 20, 0, rT, rT, "TestJoinFunction", funcCode) val operator: KeyedCoProcessOperator[Integer, CRow, CRow, CRow] = new KeyedCoProcessOperator[Integer, CRow, CRow, CRow](joinProcessFunc) @@ -254,24 +255,24 @@ class JoinHarnessTest extends HarnessTestBase{ // left stream input testHarness.setProcessingTime(1) testHarness.processElement1(new StreamRecord( - CRow(Row.of(1: JInt, "aaa"), true), 1)) + CRow(Row.of(1L: JLong, "aaa"), true), 1)) assertEquals(1, testHarness.numProcessingTimeTimers()) testHarness.setProcessingTime(2) testHarness.processElement1(new StreamRecord( - CRow(Row.of(2: JInt, "bbb"), true), 2)) + CRow(Row.of(2L: JLong, "bbb"), true), 2)) assertEquals(2, testHarness.numProcessingTimeTimers()) testHarness.setProcessingTime(3) testHarness.processElement1(new StreamRecord( - CRow(Row.of(1: JInt, "aaa2"), true), 3)) + CRow(Row.of(1L: JLong, "aaa2"), true), 3)) assertEquals(4, testHarness.numKeyedStateEntries()) assertEquals(2, testHarness.numProcessingTimeTimers()) // right stream input and output normally 
testHarness.processElement2(new StreamRecord( - CRow(Row.of(1: JInt, "Hi1"), true), 3)) + CRow(Row.of(1L: JLong, "Hi1"), true), 3)) testHarness.setProcessingTime(4) testHarness.processElement2(new StreamRecord( - CRow(Row.of(2: JInt, "Hello1"), true), 4)) + CRow(Row.of(2L: JLong, "Hello1"), true), 4)) assertEquals(8, testHarness.numKeyedStateEntries()) assertEquals(4, testHarness.numProcessingTimeTimers()) @@ -280,39 +281,39 @@ class JoinHarnessTest extends HarnessTestBase{ assertEquals(8, testHarness.numKeyedStateEntries()) assertEquals(4, testHarness.numProcessingTimeTimers()) testHarness.processElement2(new StreamRecord( - CRow(Row.of(1: JInt, "Hi2"), true), 12)) + CRow(Row.of(1L: JLong, "Hi2"), true), 12)) // expired right stream record at timestamp 4 and all left stream testHarness.setProcessingTime(25) - assertEquals(2, testHarness.numKeyedStateEntries()) - assertEquals(1, testHarness.numProcessingTimeTimers()) + assertEquals(8, testHarness.numKeyedStateEntries()) + assertEquals(4, testHarness.numProcessingTimeTimers()) testHarness.processElement1(new StreamRecord( - CRow(Row.of(1: JInt, "aaa3"), true), 25)) + CRow(Row.of(1L: JLong, "aaa3"), true), 25)) testHarness.processElement1(new StreamRecord( - CRow(Row.of(2: JInt, "bbb2"), true), 25)) + CRow(Row.of(2L: JLong, "bbb2"), true), 25)) testHarness.processElement2(new StreamRecord( - CRow(Row.of(2: JInt, "Hello2"), true), 25)) + CRow(Row.of(2L: JLong, "Hello2"), true), 25)) testHarness.setProcessingTime(45) assertTrue(testHarness.numKeyedStateEntries() > 0) testHarness.setProcessingTime(46) - assertEquals(0, testHarness.numKeyedStateEntries()) + assertEquals(2, testHarness.numKeyedStateEntries()) val result = testHarness.getOutput val expectedOutput = new ConcurrentLinkedQueue[Object]() expectedOutput.add(new StreamRecord( - CRow(Row.of(1: JInt, "aaa", 1: JInt, "Hi1"), true), 3)) + CRow(Row.of(1L: JLong, "aaa", 1L: JLong, "Hi1"), true), 3)) expectedOutput.add(new StreamRecord( - CRow(Row.of(1: JInt, "aaa2", 
1: JInt, "Hi1"), true), 3)) + CRow(Row.of(1L: JLong, "aaa2", 1L: JLong, "Hi1"), true), 3)) expectedOutput.add(new StreamRecord( - CRow(Row.of(2: JInt, "bbb", 2: JInt, "Hello1"), true), 4)) + CRow(Row.of(2L: JLong, "bbb", 2L: JLong, "Hello1"), true), 4)) expectedOutput.add(new StreamRecord( - CRow(Row.of(1: JInt, "aaa2", 1: JInt, "Hi2"), true), 12)) + CRow(Row.of(1L: JLong, "aaa2", 1L: JLong, "Hi2"), true), 12)) expectedOutput.add(new StreamRecord( - CRow(Row.of(1: JInt, "aaa3", 1: JInt, "Hi2"), true), 25)) + CRow(Row.of(1L: JLong, "aaa3", 1L: JLong, "Hi2"), true), 25)) expectedOutput.add(new StreamRecord( - CRow(Row.of(2: JInt, "bbb2", 2: JInt, "Hello2"), true), 25)) + CRow(Row.of(2L: JLong, "bbb2", 2L: JLong, "Hello2"), true), 25)) verify(expectedOutput, result, new RowResultSortComparator()) @@ -323,8 +324,8 @@ class JoinHarnessTest extends HarnessTestBase{ @Test def testNewProcTimeJoinSingleNeedStore() { - val joinProcessFunc = new TimeBoundedStreamInnerJoin( - -10, -5, 0, rT, rT, "TestJoinFunction", funcCode, -1, -1, JoinTimeIndicator.PROCTIME) + val joinProcessFunc = new ProcTimeBoundedStreamInnerJoin( + -10, -5, 0, rT, rT, "TestJoinFunction", funcCode) val operator: KeyedCoProcessOperator[Integer, CRow, CRow, CRow] = new KeyedCoProcessOperator[Integer, CRow, CRow, CRow](joinProcessFunc) @@ -340,26 +341,26 @@ class JoinHarnessTest extends HarnessTestBase{ testHarness.setProcessingTime(1) testHarness.processElement1(new StreamRecord( - CRow(Row.of(1: JInt, "aaa1"), true), 1)) + CRow(Row.of(1L: JLong, "aaa1"), true), 1)) testHarness.setProcessingTime(2) testHarness.processElement1(new StreamRecord( - CRow(Row.of(2: JInt, "aaa2"), true), 2)) + CRow(Row.of(2L: JLong, "aaa2"), true), 2)) testHarness.setProcessingTime(3) testHarness.processElement1(new StreamRecord( - CRow(Row.of(1: JInt, "aaa3"), true), 3)) + CRow(Row.of(1L: JLong, "aaa3"), true), 3)) assertEquals(4, testHarness.numKeyedStateEntries()) assertEquals(2, testHarness.numProcessingTimeTimers()) // all 
the records should be stored testHarness.processElement2(new StreamRecord( - CRow(Row.of(1: JInt, "bbb3"), true), 3)) + CRow(Row.of(1L: JLong, "bbb3"), true), 3)) assertEquals(6, testHarness.numKeyedStateEntries()) assertEquals(3, testHarness.numProcessingTimeTimers()) // meet a.proctime <= b.proctime - 5 testHarness.setProcessingTime(7) testHarness.processElement2(new StreamRecord( - CRow(Row.of(2: JInt, "bbb7"), true), 7)) + CRow(Row.of(2L: JLong, "bbb7"), true), 7)) assertEquals(6, testHarness.numKeyedStateEntries()) assertEquals(3, testHarness.numProcessingTimeTimers()) @@ -368,11 +369,11 @@ class JoinHarnessTest extends HarnessTestBase{ assertEquals(4, testHarness.numKeyedStateEntries()) assertEquals(2, testHarness.numProcessingTimeTimers()) testHarness.processElement2(new StreamRecord( - CRow(Row.of(1: JInt, "bbb12"), true), 12)) + CRow(Row.of(1L: JLong, "bbb12"), true), 12)) testHarness.setProcessingTime(13) - assertEquals(2, testHarness.numKeyedStateEntries()) - assertEquals(1, testHarness.numProcessingTimeTimers()) + assertEquals(4, testHarness.numKeyedStateEntries()) + assertEquals(2, testHarness.numProcessingTimeTimers()) // state must be cleaned after the window timer interval has passed without new rows. 
testHarness.setProcessingTime(23) @@ -383,13 +384,158 @@ class JoinHarnessTest extends HarnessTestBase{ val expectedOutput = new ConcurrentLinkedQueue[Object]() expectedOutput.add(new StreamRecord( - CRow(Row.of(2: JInt, "aaa2", 2: JInt, "bbb7"), true), 7)) + CRow(Row.of(2L: JLong, "aaa2", 2L: JLong, "bbb7"), true), 7)) expectedOutput.add(new StreamRecord( - CRow(Row.of(1: JInt, "aaa3", 1: JInt, "bbb12"), true), 12)) + CRow(Row.of(1L: JLong, "aaa3", 1L: JLong, "bbb12"), true), 12)) verify(expectedOutput, result, new RowResultSortComparator()) testHarness.close() } + /** a.rowtime >= b.rowtime - 10 and a.rowtime <= b.rowtime + 20 **/ + @Test + def testCommonRowTimeJoin() { + + val joinProcessFunc = new RowTimeBoundedStreamInnerJoin( + -10, 20, 0, rT, rT, "TestJoinFunction", funcCode, 0, 0) + + val operator: KeyedCoProcessOperator[String, CRow, CRow, CRow] = + new KeyedCoProcessOperator[String, CRow, CRow, CRow](joinProcessFunc) + val testHarness: KeyedTwoInputStreamOperatorTestHarness[String, CRow, CRow, CRow] = + new KeyedTwoInputStreamOperatorTestHarness[String, CRow, CRow, CRow]( + operator, + new TupleRowKeySelector[String](1), + new TupleRowKeySelector[String](1), + BasicTypeInfo.STRING_TYPE_INFO, + 1, 1, 0) + + testHarness.open() + + // Advance + testHarness.processWatermark1(new Watermark(1)) + testHarness.processWatermark2(new Watermark(1)) + + // Test late data + testHarness.processElement1(new StreamRecord[CRow]( + CRow(Row.of(1L: JLong, "k1"), true), 0)) + + assertEquals(0, testHarness.numEventTimeTimers()) + + testHarness.processElement1(new StreamRecord[CRow]( + CRow(Row.of(2L: JLong, "k1"), true), 0)) + testHarness.processElement2(new StreamRecord[CRow]( + CRow(Row.of(2L: JLong, "k1"), true), 0)) + + assertEquals(2, testHarness.numEventTimeTimers()) + assertEquals(4, testHarness.numKeyedStateEntries()) + + testHarness.processElement1(new StreamRecord[CRow]( + CRow(Row.of(5L: JLong, "k1"), true), 0)) + testHarness.processElement2(new StreamRecord[CRow]( +
CRow(Row.of(15L: JLong, "k1"), true), 0)) + + testHarness.processWatermark1(new Watermark(20)) + testHarness.processWatermark2(new Watermark(20)) + + assertEquals(4, testHarness.numKeyedStateEntries()) + + testHarness.processElement1(new StreamRecord[CRow]( + CRow(Row.of(35L: JLong, "k1"), true), 0)) + + testHarness.processWatermark1(new Watermark(38)) + testHarness.processWatermark2(new Watermark(38)) + + testHarness.processElement1(new StreamRecord[CRow]( + CRow(Row.of(40L: JLong, "k2"), true), 0)) + testHarness.processElement2(new StreamRecord[CRow]( + CRow(Row.of(39L: JLong, "k2"), true), 0)) + + assertEquals(6, testHarness.numKeyedStateEntries()) + + testHarness.processWatermark1(new Watermark(61)) + testHarness.processWatermark2(new Watermark(61)) + + assertEquals(4, testHarness.numKeyedStateEntries()) + + val expectedOutput = new ConcurrentLinkedQueue[Object]() + expectedOutput.add(new StreamRecord( + CRow(Row.of(2L: JLong, "k1", 2L: JLong, "k1"), true), 0)) + expectedOutput.add(new StreamRecord( + CRow(Row.of(5L: JLong, "k1", 2L: JLong, "k1"), true), 0)) + expectedOutput.add(new StreamRecord( + CRow(Row.of(5L: JLong, "k1", 15L: JLong, "k1"), true), 0)) + expectedOutput.add(new StreamRecord( + CRow(Row.of(35L: JLong, "k1", 15L: JLong, "k1"), true), 0)) + expectedOutput.add(new StreamRecord( + CRow(Row.of(40L: JLong, "k2", 39L: JLong, "k2"), true), 0)) + + val result = testHarness.getOutput + verify(expectedOutput, result, new RowResultSortComparator()) + testHarness.close() + } + + /** a.rowtime >= b.rowtime - 10 and a.rowtime <= b.rowtime - 7 **/ + @Test + def testNegativeRowTimeJoin() { + + val joinProcessFunc = new RowTimeBoundedStreamInnerJoin( + -10, -7, 0, rT, rT, "TestJoinFunction", funcCode, 0, 0) + + val operator: KeyedCoProcessOperator[String, CRow, CRow, CRow] = + new KeyedCoProcessOperator[String, CRow, CRow, CRow](joinProcessFunc) + val testHarness: KeyedTwoInputStreamOperatorTestHarness[String, CRow, CRow, CRow] = + new 
KeyedTwoInputStreamOperatorTestHarness[String, CRow, CRow, CRow]( + operator, + new TupleRowKeySelector[String](1), + new TupleRowKeySelector[String](1), + BasicTypeInfo.STRING_TYPE_INFO, + 1, 1, 0) + + testHarness.open() + + // Advance + testHarness.processWatermark1(new Watermark(1)) + testHarness.processWatermark2(new Watermark(1)) + + testHarness.processElement2(new StreamRecord[CRow]( + CRow(Row.of(2L: JLong, "k1"), true), 0)) + + testHarness.processWatermark1(new Watermark(2)) + testHarness.processWatermark2(new Watermark(2)) + + assertEquals(0, testHarness.numKeyedStateEntries()) + + testHarness.processElement1(new StreamRecord[CRow]( + CRow(Row.of(3L: JLong, "k1"), true), 0)) + testHarness.processElement2(new StreamRecord[CRow]( + CRow(Row.of(3L: JLong, "k1"), true), 0)) + + testHarness.processElement2(new StreamRecord[CRow]( + CRow(Row.of(13L: JLong, "k1"), true), 0)) + testHarness.processElement1(new StreamRecord[CRow]( + CRow(Row.of(6L: JLong, "k1"), true), 0)) + + assertEquals(4, testHarness.numKeyedStateEntries()) + + testHarness.processWatermark1(new Watermark(10)) + testHarness.processWatermark2(new Watermark(10)) + + assertEquals(2, testHarness.numKeyedStateEntries()) + + testHarness.processWatermark1(new Watermark(18)) + testHarness.processWatermark2(new Watermark(18)) + + assertEquals(0, testHarness.numKeyedStateEntries()) + + val expectedOutput = new ConcurrentLinkedQueue[Object]() + expectedOutput.add(new StreamRecord( + CRow(Row.of(3L: JLong, "k1", 13L: JLong, "k1"), true), 0)) + expectedOutput.add(new StreamRecord( + CRow(Row.of(6L: JLong, "k1", 13L: JLong, "k1"), true), 0)) + + val result = testHarness.getOutput + verify(expectedOutput, result, new RowResultSortComparator()) + testHarness.close() + } } From fd666b049a7e39be7b24ea1f14db3c487d89ffee Mon Sep 17 00:00:00 2001 From: Xingcan Cui Date: Sun, 24 Sep 2017 21:33:49 +0800 Subject: [PATCH 3/7] Refactor Major changes: 1. Remove the JoinTimeIndicator. 2. 
Fix the identical processing time problem. 3. Use SQL data types instead of BasicTypeInfo. 4. Remove the processElement() method. 5. Add an eager removal policy for the cache. 6. Improve the tests. --- .../datastream/DataStreamWindowJoin.scala | 65 ++-- .../join/ProcTimeBoundedStreamInnerJoin.scala | 18 +- .../join/ProcTimeWindowInnerJoin.scala | 346 ------------------ .../join/RowTimeBoundedStreamInnerJoin.scala | 20 +- .../join/TimeBoundedStreamInnerJoin.scala | 298 +++++++-------- .../runtime/harness/JoinHarnessTest.scala | 332 +++++------------ .../table/runtime/stream/sql/JoinITCase.scala | 148 +++++++- 7 files changed, 414 insertions(+), 813 deletions(-) delete mode 100644 flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/join/ProcTimeWindowInnerJoin.scala diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamWindowJoin.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamWindowJoin.scala index f81a84fecea54..c5fa9314f425b 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamWindowJoin.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamWindowJoin.scala @@ -34,8 +34,8 @@ import org.apache.flink.table.plan.util.UpdatingPlanChecker import org.apache.flink.table.runtime.join.{ProcTimeBoundedStreamInnerJoin, RowTimeBoundedStreamInnerJoin, WindowJoinUtil} import org.apache.flink.table.runtime.operators.KeyedCoProcessOperatorWithWatermarkDelay import org.apache.flink.table.runtime.types.{CRow, CRowTypeInfo} +import org.apache.flink.table.util.Logging import org.apache.flink.util.Collector -import org.slf4j.{Logger, LoggerFactory} /** * RelNode for a time windowed stream join.
@@ -59,9 +59,8 @@ class DataStreamWindowJoin( ruleDescription: String) extends BiRel(cluster, traitSet, leftNode, rightNode) with CommonJoin - with DataStreamRel { - - lazy val LOG: Logger = LoggerFactory.getLogger(getClass) + with DataStreamRel + with Logging { override def deriveRowType(): RelDataType = schema.relDataType @@ -118,16 +117,13 @@ class DataStreamWindowJoin( val leftDataStream = left.asInstanceOf[DataStreamRel].translateToPlan(tableEnv, queryConfig) val rightDataStream = right.asInstanceOf[DataStreamRel].translateToPlan(tableEnv, queryConfig) - // get the equality keys and other condition + // get the equi-keys and other conditions val joinInfo = JoinInfo.of(leftNode, rightNode, joinCondition) val leftKeys = joinInfo.leftKeys.toIntArray val rightKeys = joinInfo.rightKeys.toIntArray - val relativeWindowSize = leftUpperBound - leftLowerBound - val returnTypeInfo = CRowTypeInfo(schema.typeInfo) - // generate join function val joinFunction = WindowJoinUtil.generateJoinFunction( @@ -144,11 +140,10 @@ class DataStreamWindowJoin( if (relativeWindowSize < 0) { LOG.warn(s"The relative window size ${relativeWindowSize} is negative," + " please check the join conditions.") - createEmptyInnerJoinFunction(leftDataStream, rightDataStream, returnTypeInfo) + createEmptyInnerJoin(leftDataStream, rightDataStream, returnTypeInfo) } else { if (isRowTime) { - // RowTime JoinCoProcessFunction - createRowTimeInnerJoinFunction( + createRowTimeInnerJoin( leftDataStream, rightDataStream, returnTypeInfo, @@ -158,8 +153,7 @@ class DataStreamWindowJoin( rightKeys ) } else { - // Proctime JoinCoProcessFunction - createProcTimeInnerJoinFunction( + createProcTimeInnerJoin( leftDataStream, rightDataStream, returnTypeInfo, @@ -182,7 +176,28 @@ class DataStreamWindowJoin( } } - def createProcTimeInnerJoinFunction( + def createEmptyInnerJoin( + leftDataStream: DataStream[CRow], + rightDataStream: DataStream[CRow], + returnTypeInfo: TypeInformation[CRow]) = { + 
leftDataStream.connect(rightDataStream).process( + new CoProcessFunction[CRow, CRow, CRow] { + override def processElement1( + value: CRow, + ctx: CoProcessFunction[CRow, CRow, CRow]#Context, + out: Collector[CRow]) = { + //Do nothing. + } + override def processElement2( + value: CRow, + ctx: CoProcessFunction[CRow, CRow, CRow]#Context, + out: Collector[CRow]) = { + //Do nothing. + } + }) + } + + def createProcTimeInnerJoin( leftDataStream: DataStream[CRow], rightDataStream: DataStream[CRow], returnTypeInfo: TypeInformation[CRow], @@ -215,27 +230,7 @@ class DataStreamWindowJoin( } } - def createEmptyInnerJoinFunction( - leftDataStream: DataStream[CRow], - rightDataStream: DataStream[CRow], - returnTypeInfo: TypeInformation[CRow]) = { - leftDataStream.connect(rightDataStream).process( - new CoProcessFunction[CRow, CRow, CRow] { - override def processElement1( - value: CRow, - ctx: CoProcessFunction[CRow, CRow, CRow]#Context, - out: Collector[CRow]) = { - //Do nothing. - } - override def processElement2( - value: CRow, - ctx: CoProcessFunction[CRow, CRow, CRow]#Context, - out: Collector[CRow]) = { - //Do nothing. - } - }) - } - def createRowTimeInnerJoinFunction( + def createRowTimeInnerJoin( leftDataStream: DataStream[CRow], rightDataStream: DataStream[CRow], returnTypeInfo: TypeInformation[CRow], diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/join/ProcTimeBoundedStreamInnerJoin.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/join/ProcTimeBoundedStreamInnerJoin.scala index 11d42b9b1a18c..1efac7f74c049 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/join/ProcTimeBoundedStreamInnerJoin.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/join/ProcTimeBoundedStreamInnerJoin.scala @@ -26,7 +26,7 @@ import org.apache.flink.types.Row /** * The function to execute processing time bounded stream inner-join. 
*/ -class ProcTimeBoundedStreamInnerJoin( +final class ProcTimeBoundedStreamInnerJoin( leftLowerBound: Long, leftUpperBound: Long, allowedLateness: Long, @@ -43,31 +43,29 @@ class ProcTimeBoundedStreamInnerJoin( genJoinFuncName, genJoinFuncCode, leftTimeIdx = -1, - rightTimeIdx = -1, - JoinTimeIndicator.PROCTIME) { + rightTimeIdx = -1) { - override def checkRowOutOfDate(timeForRow: Long, watermark: Long) = false + override def isRowTooLate(timeForRow: Long, watermark: Long): Boolean = false override def updateOperatorTime(ctx: CoProcessFunction[CRow, CRow, CRow]#Context): Unit = { - rightOperatorTime = ctx.timerService().currentProcessingTime() leftOperatorTime = ctx.timerService().currentProcessingTime() + rightOperatorTime = leftOperatorTime } override def getTimeForLeftStream( context: CoProcessFunction[CRow, CRow, CRow]#Context, - row: CRow): Long = { - context.timerService().currentProcessingTime() + row: Row): Long = { + leftOperatorTime } override def getTimeForRightStream( context: CoProcessFunction[CRow, CRow, CRow]#Context, - row: CRow): Long = { - context.timerService().currentProcessingTime() + row: Row): Long = { + rightOperatorTime } override def registerTimer( ctx: CoProcessFunction[CRow, CRow, CRow]#Context, - isLeft: Boolean, cleanupTime: Long): Unit = { ctx.timerService.registerProcessingTimeTimer(cleanupTime) } diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/join/ProcTimeWindowInnerJoin.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/join/ProcTimeWindowInnerJoin.scala deleted file mode 100644 index 824037630f285..0000000000000 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/join/ProcTimeWindowInnerJoin.scala +++ /dev/null @@ -1,346 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. 
See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.flink.table.runtime.join - -import java.util -import java.util.{List => JList} - -import org.apache.flink.api.common.functions.FlatJoinFunction -import org.apache.flink.api.common.state._ -import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation} -import org.apache.flink.api.java.typeutils.ListTypeInfo -import org.apache.flink.configuration.Configuration -import org.apache.flink.streaming.api.functions.co.CoProcessFunction -import org.apache.flink.table.codegen.Compiler -import org.apache.flink.table.runtime.CRowWrappingCollector -import org.apache.flink.table.runtime.types.CRow -import org.apache.flink.table.util.Logging -import org.apache.flink.types.Row -import org.apache.flink.util.Collector - -/** - * A CoProcessFunction to support stream join stream, currently just support inner-join - * - * @param leftLowerBound - * the left stream lower bound, and -leftLowerBound is the right stream upper bound - * @param leftUpperBound - * the left stream upper bound, and -leftUpperBound is the right stream lower bound - * @param element1Type the input type of left stream - * @param element2Type the input type of right stream - * @param genJoinFuncName the function code of other non-equi condition - * @param genJoinFuncCode the function name 
of other non-equi condition - * - */ -class ProcTimeWindowInnerJoin( - private val leftLowerBound: Long, - private val leftUpperBound: Long, - private val element1Type: TypeInformation[Row], - private val element2Type: TypeInformation[Row], - private val genJoinFuncName: String, - private val genJoinFuncCode: String) - extends CoProcessFunction[CRow, CRow, CRow] - with Compiler[FlatJoinFunction[Row, Row, Row]] - with Logging { - - private var cRowWrapper: CRowWrappingCollector = _ - - // other condition function - private var joinFunction: FlatJoinFunction[Row, Row, Row] = _ - - // tmp list to store expired records - private var removeList: JList[Long] = _ - - // state to hold left stream element - private var row1MapState: MapState[Long, JList[Row]] = _ - // state to hold right stream element - private var row2MapState: MapState[Long, JList[Row]] = _ - - // state to record last timer of left stream, 0 means no timer - private var timerState1: ValueState[Long] = _ - // state to record last timer of right stream, 0 means no timer - private var timerState2: ValueState[Long] = _ - - // compute window sizes, i.e., how long to keep rows in state. - // window size of -1 means rows do not need to be put into state. 
- private val leftStreamWinSize: Long = if (leftLowerBound <= 0) -leftLowerBound else -1 - private val rightStreamWinSize: Long = if (leftUpperBound >= 0) leftUpperBound else -1 - - override def open(config: Configuration) { - LOG.debug(s"Compiling JoinFunction: $genJoinFuncName \n\n " + - s"Code:\n$genJoinFuncCode") - val clazz = compile( - getRuntimeContext.getUserCodeClassLoader, - genJoinFuncName, - genJoinFuncCode) - LOG.debug("Instantiating JoinFunction.") - joinFunction = clazz.newInstance() - - removeList = new util.ArrayList[Long]() - cRowWrapper = new CRowWrappingCollector() - cRowWrapper.setChange(true) - - // initialize row state - val rowListTypeInfo1: TypeInformation[JList[Row]] = new ListTypeInfo[Row](element1Type) - val mapStateDescriptor1: MapStateDescriptor[Long, JList[Row]] = - new MapStateDescriptor[Long, JList[Row]]("row1mapstate", - BasicTypeInfo.LONG_TYPE_INFO.asInstanceOf[TypeInformation[Long]], rowListTypeInfo1) - row1MapState = getRuntimeContext.getMapState(mapStateDescriptor1) - - val rowListTypeInfo2: TypeInformation[JList[Row]] = new ListTypeInfo[Row](element2Type) - val mapStateDescriptor2: MapStateDescriptor[Long, JList[Row]] = - new MapStateDescriptor[Long, JList[Row]]("row2mapstate", - BasicTypeInfo.LONG_TYPE_INFO.asInstanceOf[TypeInformation[Long]], rowListTypeInfo2) - row2MapState = getRuntimeContext.getMapState(mapStateDescriptor2) - - // initialize timer state - val valueStateDescriptor1: ValueStateDescriptor[Long] = - new ValueStateDescriptor[Long]("timervaluestate1", classOf[Long]) - timerState1 = getRuntimeContext.getState(valueStateDescriptor1) - - val valueStateDescriptor2: ValueStateDescriptor[Long] = - new ValueStateDescriptor[Long]("timervaluestate2", classOf[Long]) - timerState2 = getRuntimeContext.getState(valueStateDescriptor2) - } - - /** - * Process left stream records - * - * @param valueC The input value. 
- * @param ctx The ctx to register timer or get current time - * @param out The collector for returning result values. - * - */ - override def processElement1( - valueC: CRow, - ctx: CoProcessFunction[CRow, CRow, CRow]#Context, - out: Collector[CRow]): Unit = { - - processElement( - valueC, - ctx, - out, - leftStreamWinSize, - timerState1, - row1MapState, - row2MapState, - -leftUpperBound, // right stream lower - -leftLowerBound, // right stream upper - isLeft = true - ) - } - - /** - * Process right stream records - * - * @param valueC The input value. - * @param ctx The ctx to register timer or get current time - * @param out The collector for returning result values. - * - */ - override def processElement2( - valueC: CRow, - ctx: CoProcessFunction[CRow, CRow, CRow]#Context, - out: Collector[CRow]): Unit = { - - processElement( - valueC, - ctx, - out, - rightStreamWinSize, - timerState2, - row2MapState, - row1MapState, - leftLowerBound, // left stream lower - leftUpperBound, // left stream upper - isLeft = false - ) - } - - /** - * Called when a processing timer trigger. - * Expire left/right records which earlier than current time - windowsize. - * - * @param timestamp The timestamp of the firing timer. - * @param ctx The ctx to register timer or get current time - * @param out The collector for returning result values. - */ - override def onTimer( - timestamp: Long, - ctx: CoProcessFunction[CRow, CRow, CRow]#OnTimerContext, - out: Collector[CRow]): Unit = { - - if (timerState1.value == timestamp) { - expireOutTimeRow( - timestamp, - leftStreamWinSize, - row1MapState, - timerState1, - ctx - ) - } - - if (timerState2.value == timestamp) { - expireOutTimeRow( - timestamp, - rightStreamWinSize, - row2MapState, - timerState2, - ctx - ) - } - } - - /** - * Puts an element from the input stream into state and search the other state to - * output records meet the condition, and registers a timer for the current record - * if there is no timer at present. 
- */ - private def processElement( - valueC: CRow, - ctx: CoProcessFunction[CRow, CRow, CRow]#Context, - out: Collector[CRow], - winSize: Long, - timerState: ValueState[Long], - rowMapState: MapState[Long, JList[Row]], - otherRowMapState: MapState[Long, JList[Row]], - otherLowerBound: Long, - otherUpperBound: Long, - isLeft: Boolean): Unit = { - - cRowWrapper.out = out - - val row = valueC.row - - val curProcessTime = ctx.timerService.currentProcessingTime - val otherLowerTime = curProcessTime + otherLowerBound - val otherUpperTime = curProcessTime + otherUpperBound - - if (winSize >= 0) { - // put row into state for later joining. - // (winSize == 0) joins rows received in the same millisecond. - var rowList = rowMapState.get(curProcessTime) - if (rowList == null) { - rowList = new util.ArrayList[Row]() - } - rowList.add(row) - rowMapState.put(curProcessTime, rowList) - - // register a timer to remove the row from state once it is expired - if (timerState.value == 0) { - val cleanupTime = curProcessTime + winSize + 1 - ctx.timerService.registerProcessingTimeTimer(cleanupTime) - timerState.update(cleanupTime) - } - } - - // join row with rows received from the other input - val otherTimeIter = otherRowMapState.keys().iterator() - if (isLeft) { - // go over all timestamps in the other input's state - while (otherTimeIter.hasNext) { - val otherTimestamp = otherTimeIter.next() - if (otherTimestamp < otherLowerTime) { - // other timestamp is expired. Remove it later. 
- removeList.add(otherTimestamp) - } else if (otherTimestamp <= otherUpperTime) { - // join row with all rows from the other input for this timestamp - val otherRows = otherRowMapState.get(otherTimestamp) - var i = 0 - while (i < otherRows.size) { - joinFunction.join(row, otherRows.get(i), cRowWrapper) - i += 1 - } - } - } - } else { - // go over all timestamps in the other input's state - while (otherTimeIter.hasNext) { - val otherTimestamp = otherTimeIter.next() - if (otherTimestamp < otherLowerTime) { - // other timestamp is expired. Remove it later. - removeList.add(otherTimestamp) - } else if (otherTimestamp <= otherUpperTime) { - // join row with all rows from the other input for this timestamp - val otherRows = otherRowMapState.get(otherTimestamp) - var i = 0 - while (i < otherRows.size) { - joinFunction.join(otherRows.get(i), row, cRowWrapper) - i += 1 - } - } - } - } - - // remove rows for expired timestamps - var i = removeList.size - 1 - while (i >= 0) { - otherRowMapState.remove(removeList.get(i)) - i -= 1 - } - removeList.clear() - } - - /** - * Removes records which are outside the join window from the state. - * Registers a new timer if the state still holds records after the clean-up. - */ - private def expireOutTimeRow( - curTime: Long, - winSize: Long, - rowMapState: MapState[Long, JList[Row]], - timerState: ValueState[Long], - ctx: CoProcessFunction[CRow, CRow, CRow]#OnTimerContext): Unit = { - - val expiredTime = curTime - winSize - val keyIter = rowMapState.keys().iterator() - var validTimestamp: Boolean = false - // Search for expired timestamps. - // If we find a non-expired timestamp, remember the timestamp and leave the loop. - // This way we find all expired timestamps if they are sorted without doing a full pass. 
- while (keyIter.hasNext && !validTimestamp) { - val recordTime = keyIter.next - if (recordTime < expiredTime) { - removeList.add(recordTime) - } else { - // we found a timestamp that is still valid - validTimestamp = true - } - } - - // If the state has non-expired timestamps, register a new timer. - // Otherwise clean the complete state for this input. - if (validTimestamp) { - - // Remove expired records from state - var i = removeList.size - 1 - while (i >= 0) { - rowMapState.remove(removeList.get(i)) - i -= 1 - } - removeList.clear() - - val cleanupTime = curTime + winSize + 1 - ctx.timerService.registerProcessingTimeTimer(cleanupTime) - timerState.update(cleanupTime) - } else { - timerState.clear() - rowMapState.clear() - } - } -} diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/join/RowTimeBoundedStreamInnerJoin.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/join/RowTimeBoundedStreamInnerJoin.scala index 798fefdaaf384..a48a9849f2755 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/join/RowTimeBoundedStreamInnerJoin.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/join/RowTimeBoundedStreamInnerJoin.scala @@ -26,7 +26,7 @@ import org.apache.flink.types.Row /** * The function to execute row(event) time bounded stream inner-join. 
*/ -class RowTimeBoundedStreamInnerJoin( +final class RowTimeBoundedStreamInnerJoin( leftLowerBound: Long, leftUpperBound: Long, allowedLateness: Long, @@ -45,37 +45,35 @@ class RowTimeBoundedStreamInnerJoin( genJoinFuncName, genJoinFuncCode, leftTimeIdx, - rightTimeIdx, - JoinTimeIndicator.ROWTIME) { + rightTimeIdx) { - override def checkRowOutOfDate(timeForRow: Long, watermark: Long) = { + override def isRowTooLate(timeForRow: Long, watermark: Long): Boolean = { timeForRow <= watermark - allowedLateness } override def updateOperatorTime(ctx: CoProcessFunction[CRow, CRow, CRow]#Context): Unit = { - rightOperatorTime = + leftOperatorTime = if (ctx.timerService().currentWatermark() > 0) ctx.timerService().currentWatermark() else 0L - leftOperatorTime = + rightOperatorTime = if (ctx.timerService().currentWatermark() > 0) ctx.timerService().currentWatermark() else 0L } override def getTimeForLeftStream( context: CoProcessFunction[CRow, CRow, CRow]#Context, - row: CRow): Long = { - row.row.getField(leftTimeIdx).asInstanceOf[Long] + row: Row): Long = { + row.getField(leftTimeIdx).asInstanceOf[Long] } override def getTimeForRightStream( context: CoProcessFunction[CRow, CRow, CRow]#Context, - row: CRow): Long = { - row.row.getField(rightTimeIdx).asInstanceOf[Long] + row: Row): Long = { + row.getField(rightTimeIdx).asInstanceOf[Long] } override def registerTimer( ctx: CoProcessFunction[CRow, CRow, CRow]#Context, - isLeft: Boolean, cleanupTime: Long): Unit = { // Maybe we can register timers for different streams in the future. 
ctx.timerService.registerEventTimeTimer(cleanupTime) diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/join/TimeBoundedStreamInnerJoin.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/join/TimeBoundedStreamInnerJoin.scala index 2ee0028781bf2..db72b63a5b197 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/join/TimeBoundedStreamInnerJoin.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/join/TimeBoundedStreamInnerJoin.scala @@ -22,13 +22,13 @@ import java.util.{ArrayList, List => JList} import org.apache.flink.api.common.functions.FlatJoinFunction import org.apache.flink.api.common.state._ -import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation} +import org.apache.flink.api.common.typeinfo.TypeInformation import org.apache.flink.api.java.typeutils.ListTypeInfo import org.apache.flink.configuration.Configuration import org.apache.flink.streaming.api.functions.co.CoProcessFunction +import org.apache.flink.table.api.Types import org.apache.flink.table.codegen.Compiler import org.apache.flink.table.runtime.CRowWrappingCollector -import org.apache.flink.table.runtime.join.JoinTimeIndicator.JoinTimeIndicator import org.apache.flink.table.runtime.types.CRow import org.apache.flink.table.util.Logging import org.apache.flink.types.Row @@ -46,7 +46,6 @@ import org.apache.flink.util.Collector * @param rightType the input type of right stream * @param genJoinFuncName the function code of other non-equi conditions * @param genJoinFuncCode the function name of other non-equi conditions - * @param timeIndicator indicate whether joining on proctime or rowtime * */ abstract class TimeBoundedStreamInnerJoin( @@ -58,8 +57,7 @@ abstract class TimeBoundedStreamInnerJoin( private val genJoinFuncName: String, private val genJoinFuncCode: String, private val leftTimeIdx: Int, - private val rightTimeIdx: Int, - private val 
timeIndicator: JoinTimeIndicator) + private val rightTimeIdx: Int) extends CoProcessFunction[CRow, CRow, CRow] with Compiler[FlatJoinFunction[Row, Row, Row]] with Logging { @@ -85,7 +83,7 @@ abstract class TimeBoundedStreamInnerJoin( protected var leftOperatorTime: Long = 0L protected var rightOperatorTime: Long = 0L - //For delayed cleanup + // for delayed cleanup private val cleanupDelay = (leftRelativeSize + rightRelativeSize) / 2 if (allowedLateness < 0) { @@ -117,26 +115,26 @@ abstract class TimeBoundedStreamInnerJoin( val leftListTypeInfo: TypeInformation[JList[Row]] = new ListTypeInfo[Row](leftType) val leftStateDescriptor: MapStateDescriptor[Long, JList[Row]] = new MapStateDescriptor[Long, JList[Row]]( - timeIndicator + "InnerJoinLeftCache", - BasicTypeInfo.LONG_TYPE_INFO.asInstanceOf[TypeInformation[Long]], + "InnerJoinLeftCache", + Types.LONG.asInstanceOf[TypeInformation[Long]], leftListTypeInfo) leftCache = getRuntimeContext.getMapState(leftStateDescriptor) val rightListTypeInfo: TypeInformation[JList[Row]] = new ListTypeInfo[Row](rightType) val rightStateDescriptor: MapStateDescriptor[Long, JList[Row]] = new MapStateDescriptor[Long, JList[Row]]( - timeIndicator + "InnerJoinRightCache", - BasicTypeInfo.LONG_TYPE_INFO.asInstanceOf[TypeInformation[Long]], + "InnerJoinRightCache", + Types.LONG.asInstanceOf[TypeInformation[Long]], rightListTypeInfo) rightCache = getRuntimeContext.getMapState(rightStateDescriptor) // Initialize the timer states. 
val leftTimerStateDesc: ValueStateDescriptor[Long] = - new ValueStateDescriptor[Long](timeIndicator + "InnerJoinLeftTimerState", classOf[Long]) + new ValueStateDescriptor[Long]("InnerJoinLeftTimerState", classOf[Long]) leftTimerState = getRuntimeContext.getState(leftTimerStateDesc) val rightTimerStateDesc: ValueStateDescriptor[Long] = - new ValueStateDescriptor[Long](timeIndicator + "InnerJoinRightTimerState", classOf[Long]) + new ValueStateDescriptor[Long]("InnerJoinRightTimerState", classOf[Long]) rightTimerState = getRuntimeContext.getState(rightTimerStateDesc) } @@ -148,23 +146,46 @@ abstract class TimeBoundedStreamInnerJoin( ctx: CoProcessFunction[CRow, CRow, CRow]#Context, out: Collector[CRow]): Unit = { updateOperatorTime(ctx) - val rowTime: Long = getTimeForLeftStream(ctx, cRowValue) - val oppositeLowerBound: Long = rowTime - rightRelativeSize - val oppositeUpperBound: Long = rowTime + leftRelativeSize - processElement( - cRowValue, - rowTime, - ctx, - out, - leftOperatorTime, - oppositeLowerBound, - oppositeUpperBound, - rightOperatorTime, - rightTimerState, - leftCache, - rightCache, - leftRow = true - ) + val leftRow = cRowValue.row + val timeForLeftRow: Long = getTimeForLeftStream(ctx, leftRow) + val rightQualifiedLowerBound: Long = timeForLeftRow - rightRelativeSize + val rightQualifiedUpperBound: Long = timeForLeftRow + leftRelativeSize + val rightExpirationTime = calExpirationTime(leftOperatorTime, rightRelativeSize) + cRowWrapper.out = out + + if (!isRowTooLate(timeForLeftRow, leftOperatorTime)) { + if (rightOperatorTime < rightQualifiedUpperBound) { + // Put the leftRow into the cache for later use. + var leftRowList = leftCache.get(timeForLeftRow) + if (null == leftRowList) { + leftRowList = new ArrayList[Row](1) + } + leftRowList.add(leftRow) + leftCache.put(timeForLeftRow, leftRowList) + if (rightTimerState.value == 0) { + // Register a timer on the RIGHT stream to remove rows. 
+ registerCleanUpTimer(ctx, timeForLeftRow, rightTimerState, leftRow = true) + } + } + + // Join the leftRow with rows from the right cache. + val rightIterator = rightCache.iterator() + while (rightIterator.hasNext) { + val rightEntry = rightIterator.next + val rightTime = rightEntry.getKey + if (rightTime >= rightQualifiedLowerBound && rightTime <= rightQualifiedUpperBound) { + val rightRows = rightEntry.getValue + var i = 0 + while (i < rightRows.size) { + joinFunction.join(leftRow, rightRows.get(i), cRowWrapper) + i += 1 + } + } else if (rightTime <= rightExpirationTime) { + // eager remove + rightIterator.remove() + }// We could do the short-cutting optimization here once we get a state with ordered keys. + } + } } /** @@ -175,118 +196,46 @@ abstract class TimeBoundedStreamInnerJoin( ctx: CoProcessFunction[CRow, CRow, CRow]#Context, out: Collector[CRow]): Unit = { updateOperatorTime(ctx) - val rowTime: Long = getTimeForRightStream(ctx, cRowValue) - val oppositeLowerBound: Long = rowTime - leftRelativeSize - val oppositeUpperBound: Long = rowTime + rightRelativeSize - processElement( - cRowValue, - rowTime, - ctx, - out, - rightOperatorTime, - oppositeLowerBound, - oppositeUpperBound, - leftOperatorTime, - leftTimerState, - rightCache, - leftCache, - leftRow = false - ) - } - - /** - * Put a row from the input stream into the cache and iterate the opposite cache to - * output join results meeting the conditions. If there is no timer set for the OPPOSITE - * STREAM, register one. 
- */ - private def processElement( - cRowValue: CRow, - timeForRow: Long, - ctx: CoProcessFunction[CRow, CRow, CRow]#Context, - out: Collector[CRow], - myWatermark: Long, - oppositeLowerBound: Long, - oppositeUpperBound: Long, - oppositeWatermark: Long, - oppositeTimeState: ValueState[Long], - rowListCache: MapState[Long, JList[Row]], - oppositeCache: MapState[Long, JList[Row]], - leftRow: Boolean): Unit = { + val rightRow = cRowValue.row + val timeForRightRow: Long = getTimeForRightStream(ctx, rightRow) + val leftQualifiedLowerBound: Long = timeForRightRow - leftRelativeSize + val leftQualifiedUpperBound: Long = timeForRightRow + rightRelativeSize + val leftExpirationTime = calExpirationTime(rightOperatorTime, leftRelativeSize) cRowWrapper.out = out - val row = cRowValue.row - if (!checkRowOutOfDate(timeForRow, myWatermark)) { - // Put the row into the cache for later use. - var rowList = rowListCache.get(timeForRow) - if (null == rowList) { - rowList = new ArrayList[Row](1) - } - rowList.add(row) - rowListCache.put(timeForRow, rowList) - // Register a timer on THE OPPOSITE STREAM to remove rows from the cache once they are - // expired. - if (oppositeTimeState.value == 0) { - registerCleanUpTimer( - ctx, timeForRow, oppositeWatermark, oppositeTimeState, leftRow, firstTimer = true) + + if (!isRowTooLate(timeForRightRow, rightOperatorTime)) { + if (leftOperatorTime < leftQualifiedUpperBound) { + // Put the rightRow into the cache for later use. + var rightRowList = rightCache.get(timeForRightRow) + if (null == rightRowList) { + rightRowList = new ArrayList[Row](1) + } + rightRowList.add(rightRow) + rightCache.put(timeForRightRow, rightRowList) + if (leftTimerState.value == 0) { + // Register a timer on the LEFT stream to remove rows. + registerCleanUpTimer(ctx, timeForRightRow, leftTimerState, leftRow = false) + } } - // Join the row with rows from the opposite stream. 
- val oppositeIterator = oppositeCache.iterator() - while (oppositeIterator.hasNext) { - val oppositeEntry = oppositeIterator.next - val oppositeTime = oppositeEntry.getKey - if (oppositeTime >= oppositeLowerBound && oppositeTime <= oppositeUpperBound) { - val oppositeRows = oppositeEntry.getValue + // Join the rightRow with rows from the left cache. + val leftIterator = leftCache.iterator() + while (leftIterator.hasNext) { + val leftEntry = leftIterator.next + val leftTime = leftEntry.getKey + if (leftTime >= leftQualifiedLowerBound && leftTime <= leftQualifiedUpperBound) { + val leftRows = leftEntry.getValue var i = 0 - if (leftRow) { - while (i < oppositeRows.size) { - joinFunction.join(row, oppositeRows.get(i), cRowWrapper) - i += 1 - } - } else { - while (i < oppositeRows.size) { - joinFunction.join(oppositeRows.get(i), row, cRowWrapper) - i += 1 - } + while (i < leftRows.size) { + joinFunction.join(leftRows.get(i), rightRow, cRowWrapper) + i += 1 } - } - // We could do the short-cutting optimization here once we get a state with ordered keys. + } else if (leftTime <= leftExpirationTime) { + // eager remove + leftIterator.remove() + } // We could do the short-cutting optimization here once we get a state with ordered keys. } } - // We need to deal with the late data in the future. - } - - /** - * Register a timer for cleaning up rows in a specified time. 
- * - * @param ctx the context to register timer - * @param rowTime time for the input row - * @param oppositeWatermark watermark of the opposite stream - * @param timerState stores the timestamp for the next timer - * @param leftRow whether this row comes from the left stream - * @param firstTimer whether this is the first timer - */ - private def registerCleanUpTimer( - ctx: CoProcessFunction[CRow, CRow, CRow]#Context, - rowTime: Long, - oppositeWatermark: Long, - timerState: ValueState[Long], - leftRow: Boolean, - firstTimer: Boolean): Unit = { - val cleanupTime = if (leftRow) { - rowTime + leftRelativeSize + cleanupDelay + allowedLateness + 1 - } else { - rowTime + rightRelativeSize + cleanupDelay + allowedLateness + 1 - } - registerTimer(ctx, !leftRow, cleanupTime) - LOG.debug(s"Register a clean up timer on the ${if (leftRow) "RIGHT" else "LEFT"} state:" - + s" timeForRow = ${rowTime}, cleanupTime should be ${cleanupTime - cleanupDelay}," + - s" but delayed to ${cleanupTime}," + - s" oppositeWatermark = ${oppositeWatermark}") - timerState.update(cleanupTime) - //if cleanupTime <= oppositeWatermark + allowedLateness && firstTimer, we may set the - // backPressureSuggestion = - // if (leftRow) (oppositeWatermark + allowedLateness - cleanupTime) - // else -(oppositeWatermark + allowedLateness - cleanupTime) } /** @@ -307,10 +256,9 @@ abstract class TimeBoundedStreamInnerJoin( // registered timer of the faster stream will be delayed, even if the watermarks have // already been emitted by the source. 
if (leftTimerState.value == timestamp) { - val rightExpirationTime = leftOperatorTime - rightRelativeSize - allowedLateness - 1 + val rightExpirationTime = calExpirationTime(leftOperatorTime, rightRelativeSize) removeExpiredRows( rightExpirationTime, - leftOperatorTime, rightCache, leftTimerState, ctx, @@ -319,10 +267,9 @@ abstract class TimeBoundedStreamInnerJoin( } if (rightTimerState.value == timestamp) { - val leftExpirationTime = rightOperatorTime - leftRelativeSize - allowedLateness - 1 + val leftExpirationTime = calExpirationTime(rightOperatorTime, leftRelativeSize) removeExpiredRows( leftExpirationTime, - rightOperatorTime, leftCache, rightTimerState, ctx, @@ -331,20 +278,55 @@ abstract class TimeBoundedStreamInnerJoin( } } + /** + * Calculate the expiration time with the given operator time and relative window size. + * + * @param operatorTime the operator time + * @param relativeSize the relative window size + * @return the expiration time for cached rows + */ + private def calExpirationTime(operatorTime: Long, relativeSize: Long): Long = { + if (operatorTime < Long.MaxValue) { + operatorTime - relativeSize - allowedLateness - 1 + } else { + Long.MaxValue + } + } + + /** + * Register a timer for cleaning up rows in a specified time. + * + * @param ctx the context to register timer + * @param rowTime time for the input row + * @param timerState stores the timestamp for the next timer + * @param leftRow whether this row comes from the left stream + */ + private def registerCleanUpTimer( + ctx: CoProcessFunction[CRow, CRow, CRow]#Context, + rowTime: Long, + timerState: ValueState[Long], + leftRow: Boolean): Unit = { + val cleanupTime = if (leftRow) { + rowTime + leftRelativeSize + cleanupDelay + allowedLateness + 1 + } else { + rowTime + rightRelativeSize + cleanupDelay + allowedLateness + 1 + } + registerTimer(ctx, cleanupTime) + timerState.update(cleanupTime) + } + /** * Remove the expired rows. 
Register a new timer if the cache still holds valid rows * after the cleaning up. * - * @param expirationTime the expiration time for this cache - * @param oppositeWatermark the watermark of the opposite stream - * @param rowCache the row cache - * @param timerState timer state for the opposite stream - * @param ctx the context to register the cleanup timer - * @param removeLeft whether to remove the left rows + * @param expirationTime the expiration time for this cache + * @param rowCache the row cache + * @param timerState timer state for the opposite stream + * @param ctx the context to register the cleanup timer + * @param removeLeft whether to remove the left rows */ private def removeExpiredRows( expirationTime: Long, - oppositeWatermark: Long, rowCache: MapState[Long, JList[Row]], timerState: ValueState[Long], ctx: CoProcessFunction[CRow, CRow, CRow]#OnTimerContext, @@ -374,10 +356,8 @@ abstract class TimeBoundedStreamInnerJoin( registerCleanUpTimer( ctx, earliestTimestamp, - oppositeWatermark, timerState, - removeLeft, - firstTimer = false) + removeLeft) } else { // The timerState will be 0. timerState.clear() @@ -386,16 +366,17 @@ abstract class TimeBoundedStreamInnerJoin( } /** - * Check if the row is out of date. + * Check if the row is too late to be processed. * * @param timeForRow time of the row - * @param watermark watermark for the stream + * @param operatorTime operator time for current stream * @return true if the row is out of date; false otherwise */ - def checkRowOutOfDate(timeForRow: Long, watermark: Long): Boolean + def isRowTooLate(timeForRow: Long, operatorTime: Long): Boolean /** * Update the operator time of the two streams. + * Must be the first call in all processing methods (i.e., processElement(), onTimer()). 
* * @param ctx the context to acquire watermarks */ @@ -408,7 +389,7 @@ abstract class TimeBoundedStreamInnerJoin( * @param row the target row * @return time for the target row */ - def getTimeForLeftStream(context: CoProcessFunction[CRow, CRow, CRow]#Context, row: CRow): Long + def getTimeForLeftStream(context: CoProcessFunction[CRow, CRow, CRow]#Context, row: Row): Long /** * Return the time for the target row from the right stream. @@ -417,26 +398,15 @@ abstract class TimeBoundedStreamInnerJoin( * @param row the target row * @return time for the target row */ - def getTimeForRightStream(context: CoProcessFunction[CRow, CRow, CRow]#Context, row: CRow): Long + def getTimeForRightStream(context: CoProcessFunction[CRow, CRow, CRow]#Context, row: Row): Long /** * Register a proctime or rowtime timer. * * @param ctx the context to register the timer - * @param isLeft whether this timer should be registered on the left stream * @param cleanupTime timestamp for the timer */ def registerTimer( ctx: CoProcessFunction[CRow, CRow, CRow]#Context, - isLeft: Boolean, cleanupTime: Long): Unit } - -/** - * Defines the rowtime and proctime join indicators. 
- */ -object JoinTimeIndicator extends Enumeration { - type JoinTimeIndicator = Value - val ROWTIME, PROCTIME = Value -} - diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/harness/JoinHarnessTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/harness/JoinHarnessTest.scala index 0b18f95c25e3a..9ebd3ecd122c1 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/harness/JoinHarnessTest.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/harness/JoinHarnessTest.scala @@ -17,30 +17,26 @@ */ package org.apache.flink.table.runtime.harness -import java.util.concurrent.ConcurrentLinkedQueue import java.lang.{Long => JLong} +import java.util.concurrent.ConcurrentLinkedQueue -import org.apache.flink.api.common.typeinfo.BasicTypeInfo._ -import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation} -import org.apache.flink.api.java.typeutils.RowTypeInfo import org.apache.flink.streaming.api.operators.co.KeyedCoProcessOperator import org.apache.flink.streaming.api.watermark.Watermark import org.apache.flink.streaming.runtime.streamrecord.StreamRecord import org.apache.flink.streaming.util.KeyedTwoInputStreamOperatorTestHarness +import org.apache.flink.table.api.Types import org.apache.flink.table.runtime.harness.HarnessTestBase.{RowResultSortComparator, TupleRowKeySelector} -import org.apache.flink.table.runtime.join.{ProcTimeBoundedStreamInnerJoin, ProcTimeWindowInnerJoin, RowTimeBoundedStreamInnerJoin} +import org.apache.flink.table.runtime.join.{ProcTimeBoundedStreamInnerJoin, RowTimeBoundedStreamInnerJoin} import org.apache.flink.table.runtime.types.CRow import org.apache.flink.types.Row +import org.junit.Assert.{assertEquals} import org.junit.Test -import org.junit.Assert.{assertEquals, assertTrue} - -class JoinHarnessTest extends HarnessTestBase{ - private val rT = new RowTypeInfo(Array[TypeInformation[_]]( - 
LONG_TYPE_INFO, - STRING_TYPE_INFO), - Array("a", "b")) +class JoinHarnessTest extends HarnessTestBase { + private val rowType = Types.ROW( + Types.LONG, + Types.STRING) val funcCode: String = """ @@ -76,169 +72,10 @@ class JoinHarnessTest extends HarnessTestBase{ /** a.proctime >= b.proctime - 10 and a.proctime <= b.proctime + 20 **/ @Test - def testNormalProcTimeJoin() { - - val joinProcessFunc = new ProcTimeWindowInnerJoin(-10, 20, rT, rT, "TestJoinFunction", funcCode) - - val operator: KeyedCoProcessOperator[Integer, CRow, CRow, CRow] = - new KeyedCoProcessOperator[Integer, CRow, CRow, CRow](joinProcessFunc) - val testHarness: KeyedTwoInputStreamOperatorTestHarness[Integer, CRow, CRow, CRow] = - new KeyedTwoInputStreamOperatorTestHarness[Integer, CRow, CRow, CRow]( - operator, - new TupleRowKeySelector[Integer](0), - new TupleRowKeySelector[Integer](0), - BasicTypeInfo.INT_TYPE_INFO, - 1, 1, 0) - - testHarness.open() - - // left stream input - testHarness.setProcessingTime(1) - testHarness.processElement1(new StreamRecord( - CRow(Row.of(1L: JLong, "aaa"), true), 1)) - assertEquals(1, testHarness.numProcessingTimeTimers()) - testHarness.setProcessingTime(2) - testHarness.processElement1(new StreamRecord( - CRow(Row.of(2L: JLong, "bbb"), true), 2)) - assertEquals(2, testHarness.numProcessingTimeTimers()) - testHarness.setProcessingTime(3) - testHarness.processElement1(new StreamRecord( - CRow(Row.of(1L: JLong, "aaa2"), true), 3)) - assertEquals(4, testHarness.numKeyedStateEntries()) - assertEquals(2, testHarness.numProcessingTimeTimers()) - - // right stream input and output normally - testHarness.processElement2(new StreamRecord( - CRow(Row.of(1L: JLong, "Hi1"), true), 3)) - testHarness.setProcessingTime(4) - testHarness.processElement2(new StreamRecord( - CRow(Row.of(2L: JLong, "Hello1"), true), 4)) - assertEquals(8, testHarness.numKeyedStateEntries()) - assertEquals(4, testHarness.numProcessingTimeTimers()) - - // expired left stream record at timestamp 1 - 
testHarness.setProcessingTime(12) - assertEquals(8, testHarness.numKeyedStateEntries()) - assertEquals(4, testHarness.numProcessingTimeTimers()) - testHarness.processElement2(new StreamRecord( - CRow(Row.of(1L: JLong, "Hi2"), true), 12)) - - // expired right stream record at timestamp 4 and all left stream - testHarness.setProcessingTime(25) - assertEquals(2, testHarness.numKeyedStateEntries()) - assertEquals(1, testHarness.numProcessingTimeTimers()) - testHarness.processElement1(new StreamRecord( - CRow(Row.of(1L: JLong, "aaa3"), true), 25)) - testHarness.processElement1(new StreamRecord( - CRow(Row.of(2L: JLong, "bbb2"), true), 25)) - testHarness.processElement2(new StreamRecord( - CRow(Row.of(2L: JLong, "Hello2"), true), 25)) - - testHarness.setProcessingTime(45) - assertTrue(testHarness.numKeyedStateEntries() > 0) - testHarness.setProcessingTime(46) - assertEquals(0, testHarness.numKeyedStateEntries()) - val result = testHarness.getOutput - - val expectedOutput = new ConcurrentLinkedQueue[Object]() - - expectedOutput.add(new StreamRecord( - CRow(Row.of(1L: JLong, "aaa", 1L: JLong, "Hi1"), true), 3)) - expectedOutput.add(new StreamRecord( - CRow(Row.of(1L: JLong, "aaa2", 1L: JLong, "Hi1"), true), 3)) - expectedOutput.add(new StreamRecord( - CRow(Row.of(2L: JLong, "bbb", 2L: JLong, "Hello1"), true), 4)) - expectedOutput.add(new StreamRecord( - CRow(Row.of(1L: JLong, "aaa2", 1L: JLong, "Hi2"), true), 12)) - expectedOutput.add(new StreamRecord( - CRow(Row.of(1L: JLong, "aaa3", 1L: JLong, "Hi2"), true), 25)) - expectedOutput.add(new StreamRecord( - CRow(Row.of(2L: JLong, "bbb2", 2L: JLong, "Hello2"), true), 25)) - - verify(expectedOutput, result, new RowResultSortComparator()) - - testHarness.close() - } - - /** a.proctime >= b.proctime - 10 and a.proctime <= b.proctime - 5 **/ - @Test - def testProcTimeJoinSingleNeedStore() { - - val joinProcessFunc = new ProcTimeWindowInnerJoin(-10, -5, rT, rT, "TestJoinFunction", funcCode) - - val operator: 
KeyedCoProcessOperator[Integer, CRow, CRow, CRow] = - new KeyedCoProcessOperator[Integer, CRow, CRow, CRow](joinProcessFunc) - val testHarness: KeyedTwoInputStreamOperatorTestHarness[Integer, CRow, CRow, CRow] = - new KeyedTwoInputStreamOperatorTestHarness[Integer, CRow, CRow, CRow]( - operator, - new TupleRowKeySelector[Integer](0), - new TupleRowKeySelector[Integer](0), - BasicTypeInfo.INT_TYPE_INFO, - 1, 1, 0) - - testHarness.open() - - testHarness.setProcessingTime(1) - testHarness.processElement1(new StreamRecord( - CRow(Row.of(1L: JLong, "aaa1"), true), 1)) - testHarness.setProcessingTime(2) - testHarness.processElement1(new StreamRecord( - CRow(Row.of(2L: JLong, "aaa2"), true), 2)) - testHarness.setProcessingTime(3) - testHarness.processElement1(new StreamRecord( - CRow(Row.of(1L: JLong, "aaa3"), true), 3)) - assertEquals(4, testHarness.numKeyedStateEntries()) - assertEquals(2, testHarness.numProcessingTimeTimers()) - - // Do not store b elements - // not meet a.proctime <= b.proctime - 5 - testHarness.processElement2(new StreamRecord( - CRow(Row.of(1L: JLong, "bbb3"), true), 3)) - assertEquals(4, testHarness.numKeyedStateEntries()) - assertEquals(2, testHarness.numProcessingTimeTimers()) - - // meet a.proctime <= b.proctime - 5 - testHarness.setProcessingTime(7) - testHarness.processElement2(new StreamRecord( - CRow(Row.of(2L: JLong, "bbb7"), true), 7)) - assertEquals(4, testHarness.numKeyedStateEntries()) - assertEquals(2, testHarness.numProcessingTimeTimers()) - - // expire record of stream a at timestamp 1 - testHarness.setProcessingTime(12) - assertEquals(4, testHarness.numKeyedStateEntries()) - assertEquals(2, testHarness.numProcessingTimeTimers()) - testHarness.processElement2(new StreamRecord( - CRow(Row.of(1L: JLong, "bbb12"), true), 12)) - - testHarness.setProcessingTime(13) - assertEquals(2, testHarness.numKeyedStateEntries()) - assertEquals(1, testHarness.numProcessingTimeTimers()) - - // state must be cleaned after the window timer interval has 
passed without new rows. - testHarness.setProcessingTime(23) - assertEquals(0, testHarness.numKeyedStateEntries()) - assertEquals(0, testHarness.numProcessingTimeTimers()) - val result = testHarness.getOutput - - val expectedOutput = new ConcurrentLinkedQueue[Object]() - - expectedOutput.add(new StreamRecord( - CRow(Row.of(2L: JLong, "aaa2", 2L: JLong, "bbb7"), true), 7)) - expectedOutput.add(new StreamRecord( - CRow(Row.of(1L: JLong, "aaa3", 1L: JLong, "bbb12"), true), 12)) - - verify(expectedOutput, result, new RowResultSortComparator()) - - testHarness.close() - } - - /** a.proctime >= b.proctime - 10 and a.proctime <= b.proctime + 20 **/ - @Test - def testNewNormalProcTimeJoin() { + def testProcTimeJoinWithCommonBounds() { val joinProcessFunc = new ProcTimeBoundedStreamInnerJoin( - -10, 20, 0, rT, rT, "TestJoinFunction", funcCode) + -10, 20, 0, rowType, rowType, "TestJoinFunction", funcCode) val operator: KeyedCoProcessOperator[Integer, CRow, CRow, CRow] = new KeyedCoProcessOperator[Integer, CRow, CRow, CRow](joinProcessFunc) @@ -247,73 +84,76 @@ class JoinHarnessTest extends HarnessTestBase{ operator, new TupleRowKeySelector[Integer](0), new TupleRowKeySelector[Integer](0), - BasicTypeInfo.INT_TYPE_INFO, + Types.INT, 1, 1, 0) testHarness.open() - // left stream input testHarness.setProcessingTime(1) testHarness.processElement1(new StreamRecord( - CRow(Row.of(1L: JLong, "aaa"), true), 1)) + CRow(Row.of(1L: JLong, "1a1"), true), 1)) assertEquals(1, testHarness.numProcessingTimeTimers()) testHarness.setProcessingTime(2) testHarness.processElement1(new StreamRecord( - CRow(Row.of(2L: JLong, "bbb"), true), 2)) + CRow(Row.of(2L: JLong, "2a2"), true), 2)) + + // timers for key = 1 and key = 2 assertEquals(2, testHarness.numProcessingTimeTimers()) + testHarness.setProcessingTime(3) testHarness.processElement1(new StreamRecord( - CRow(Row.of(1L: JLong, "aaa2"), true), 3)) + CRow(Row.of(1L: JLong, "1a3"), true), 3)) assertEquals(4, testHarness.numKeyedStateEntries()) + 
+ // The number of timers won't increase. assertEquals(2, testHarness.numProcessingTimeTimers()) - // right stream input and output normally testHarness.processElement2(new StreamRecord( - CRow(Row.of(1L: JLong, "Hi1"), true), 3)) + CRow(Row.of(1L: JLong, "1b3"), true), 3)) testHarness.setProcessingTime(4) testHarness.processElement2(new StreamRecord( - CRow(Row.of(2L: JLong, "Hello1"), true), 4)) - assertEquals(8, testHarness.numKeyedStateEntries()) - assertEquals(4, testHarness.numProcessingTimeTimers()) + CRow(Row.of(2L: JLong, "2b4"), true), 4)) - // expired left stream record at timestamp 1 - testHarness.setProcessingTime(12) + // The number of states should be doubled. assertEquals(8, testHarness.numKeyedStateEntries()) assertEquals(4, testHarness.numProcessingTimeTimers()) + + // Test for -10 boundary (13 - 10 = 3). + // The left row (key = 1) with timestamp = 1 will be eagerly removed here. + testHarness.setProcessingTime(13) testHarness.processElement2(new StreamRecord( - CRow(Row.of(1L: JLong, "Hi2"), true), 12)) + CRow(Row.of(1L: JLong, "1b13"), true), 13)) - // expired right stream record at timestamp 4 and all left stream - testHarness.setProcessingTime(25) - assertEquals(8, testHarness.numKeyedStateEntries()) - assertEquals(4, testHarness.numProcessingTimeTimers()) + // Test for +20 boundary (13 + 20 = 33). + testHarness.setProcessingTime(33) + assertEquals(4, testHarness.numKeyedStateEntries()) + assertEquals(2, testHarness.numProcessingTimeTimers()) testHarness.processElement1(new StreamRecord( - CRow(Row.of(1L: JLong, "aaa3"), true), 25)) + CRow(Row.of(1L: JLong, "1a33"), true), 33)) + testHarness.processElement1(new StreamRecord( - CRow(Row.of(2L: JLong, "bbb2"), true), 25)) + CRow(Row.of(2L: JLong, "2a33"), true), 33)) + + // The left row (key = 2) with timestamp = 2 will be eagerly removed here. 
testHarness.processElement2(new StreamRecord( - CRow(Row.of(2L: JLong, "Hello2"), true), 25)) + CRow(Row.of(2L: JLong, "2b33"), true), 33)) - testHarness.setProcessingTime(45) - assertTrue(testHarness.numKeyedStateEntries() > 0) - testHarness.setProcessingTime(46) - assertEquals(2, testHarness.numKeyedStateEntries()) val result = testHarness.getOutput val expectedOutput = new ConcurrentLinkedQueue[Object]() expectedOutput.add(new StreamRecord( - CRow(Row.of(1L: JLong, "aaa", 1L: JLong, "Hi1"), true), 3)) + CRow(Row.of(1L: JLong, "1a1", 1L: JLong, "1b3"), true), 3)) expectedOutput.add(new StreamRecord( - CRow(Row.of(1L: JLong, "aaa2", 1L: JLong, "Hi1"), true), 3)) + CRow(Row.of(1L: JLong, "1a3", 1L: JLong, "1b3"), true), 3)) expectedOutput.add(new StreamRecord( - CRow(Row.of(2L: JLong, "bbb", 2L: JLong, "Hello1"), true), 4)) + CRow(Row.of(2L: JLong, "2a2", 2L: JLong, "2b4"), true), 4)) expectedOutput.add(new StreamRecord( - CRow(Row.of(1L: JLong, "aaa2", 1L: JLong, "Hi2"), true), 12)) + CRow(Row.of(1L: JLong, "1a3", 1L: JLong, "1b13"), true), 13)) expectedOutput.add(new StreamRecord( - CRow(Row.of(1L: JLong, "aaa3", 1L: JLong, "Hi2"), true), 25)) + CRow(Row.of(1L: JLong, "1a33", 1L: JLong, "1b13"), true), 33)) expectedOutput.add(new StreamRecord( - CRow(Row.of(2L: JLong, "bbb2", 2L: JLong, "Hello2"), true), 25)) + CRow(Row.of(2L: JLong, "2a33", 2L: JLong, "2b33"), true), 33)) verify(expectedOutput, result, new RowResultSortComparator()) @@ -322,10 +162,10 @@ class JoinHarnessTest extends HarnessTestBase{ /** a.proctime >= b.proctime - 10 and a.proctime <= b.proctime - 5 **/ @Test - def testNewProcTimeJoinSingleNeedStore() { + def testProcTimeJoinWithNegativeBounds() { val joinProcessFunc = new ProcTimeBoundedStreamInnerJoin( - -10, -5, 0, rT, rT, "TestJoinFunction", funcCode) + -10, -5, 0, rowType, rowType, "TestJoinFunction", funcCode) val operator: KeyedCoProcessOperator[Integer, CRow, CRow, CRow] = new KeyedCoProcessOperator[Integer, CRow, CRow, 
CRow](joinProcessFunc) @@ -334,49 +174,58 @@ class JoinHarnessTest extends HarnessTestBase{ operator, new TupleRowKeySelector[Integer](0), new TupleRowKeySelector[Integer](0), - BasicTypeInfo.INT_TYPE_INFO, + Types.INT, 1, 1, 0) testHarness.open() testHarness.setProcessingTime(1) testHarness.processElement1(new StreamRecord( - CRow(Row.of(1L: JLong, "aaa1"), true), 1)) + CRow(Row.of(1L: JLong, "1a1"), true), 1)) testHarness.setProcessingTime(2) testHarness.processElement1(new StreamRecord( - CRow(Row.of(2L: JLong, "aaa2"), true), 2)) + CRow(Row.of(2L: JLong, "2a2"), true), 2)) testHarness.setProcessingTime(3) testHarness.processElement1(new StreamRecord( - CRow(Row.of(1L: JLong, "aaa3"), true), 3)) + CRow(Row.of(1L: JLong, "1a3"), true), 3)) assertEquals(4, testHarness.numKeyedStateEntries()) assertEquals(2, testHarness.numProcessingTimeTimers()) - // all the records should be stored + // All the right rows will not be cached. testHarness.processElement2(new StreamRecord( - CRow(Row.of(1L: JLong, "bbb3"), true), 3)) - assertEquals(6, testHarness.numKeyedStateEntries()) - assertEquals(3, testHarness.numProcessingTimeTimers()) + CRow(Row.of(1L: JLong, "1b3"), true), 3)) + assertEquals(4, testHarness.numKeyedStateEntries()) + assertEquals(2, testHarness.numProcessingTimeTimers()) - // meet a.proctime <= b.proctime - 5 testHarness.setProcessingTime(7) - testHarness.processElement2(new StreamRecord( - CRow(Row.of(2L: JLong, "bbb7"), true), 7)) - assertEquals(6, testHarness.numKeyedStateEntries()) - assertEquals(3, testHarness.numProcessingTimeTimers()) - // expire record of stream a at timestamp 1 - testHarness.setProcessingTime(12) + // Meets a.proctime <= b.proctime - 5. + // This row will only be joined without being cached (7 >= 7 - 5). 
+ testHarness.processElement2(new StreamRecord( + CRow(Row.of(2L: JLong, "2b7"), true), 7)) assertEquals(4, testHarness.numKeyedStateEntries()) assertEquals(2, testHarness.numProcessingTimeTimers()) + + testHarness.setProcessingTime(12) + // The left row (key = 1) with timestamp = 1 will be eagerly removed here. testHarness.processElement2(new StreamRecord( - CRow(Row.of(1L: JLong, "bbb12"), true), 12)) + CRow(Row.of(1L: JLong, "1b12"), true), 12)) + // We add a delay (relativeWindowSize / 2) for cleaning up state. + // No timers will be triggered here. testHarness.setProcessingTime(13) assertEquals(4, testHarness.numKeyedStateEntries()) assertEquals(2, testHarness.numProcessingTimeTimers()) - // state must be cleaned after the window timer interval has passed without new rows. - testHarness.setProcessingTime(23) + // Trigger the timer registered by the left row (key = 1) with timestamp = 1 + // (1 + 10 + 2 + 0 + 1 = 14). + // The left row (key = 1) with timestamp = 3 will be removed here. + testHarness.setProcessingTime(14) + assertEquals(2, testHarness.numKeyedStateEntries()) + assertEquals(1, testHarness.numProcessingTimeTimers()) + + // Clean up the left row (key = 2) with timestamp = 2. 
+ testHarness.setProcessingTime(16) assertEquals(0, testHarness.numKeyedStateEntries()) assertEquals(0, testHarness.numProcessingTimeTimers()) val result = testHarness.getOutput @@ -384,9 +233,9 @@ class JoinHarnessTest extends HarnessTestBase{ val expectedOutput = new ConcurrentLinkedQueue[Object]() expectedOutput.add(new StreamRecord( - CRow(Row.of(2L: JLong, "aaa2", 2L: JLong, "bbb7"), true), 7)) + CRow(Row.of(2L: JLong, "2a2", 2L: JLong, "2b7"), true), 7)) expectedOutput.add(new StreamRecord( - CRow(Row.of(1L: JLong, "aaa3", 1L: JLong, "bbb12"), true), 12)) + CRow(Row.of(1L: JLong, "1a3", 1L: JLong, "1b12"), true), 12)) verify(expectedOutput, result, new RowResultSortComparator()) @@ -395,10 +244,10 @@ class JoinHarnessTest extends HarnessTestBase{ /** a.c1 >= b.rowtime - 10 and a.rowtime <= b.rowtime + 20 **/ @Test - def testCommonRowTimeJoin() { + def testRowTimeJoinWithCommonBounds() { val joinProcessFunc = new RowTimeBoundedStreamInnerJoin( - -10, 20, 0, rT, rT, "TestJoinFunction", funcCode, 0, 0) + -10, 20, 0, rowType, rowType, "TestJoinFunction", funcCode, 0, 0) val operator: KeyedCoProcessOperator[String, CRow, CRow, CRow] = new KeyedCoProcessOperator[String, CRow, CRow, CRow](joinProcessFunc) @@ -407,16 +256,15 @@ class JoinHarnessTest extends HarnessTestBase{ operator, new TupleRowKeySelector[String](1), new TupleRowKeySelector[String](1), - BasicTypeInfo.STRING_TYPE_INFO, + Types.STRING, 1, 1, 0) testHarness.open() - // Advance testHarness.processWatermark1(new Watermark(1)) testHarness.processWatermark2(new Watermark(1)) - // Test late data + // Test late data. 
testHarness.processElement1(new StreamRecord[CRow]( CRow(Row.of(1L: JLong, "k1"), true), 0)) @@ -432,6 +280,7 @@ class JoinHarnessTest extends HarnessTestBase{ testHarness.processElement1(new StreamRecord[CRow]( CRow(Row.of(5L: JLong, "k1"), true), 0)) + testHarness.processElement2(new StreamRecord[CRow]( CRow(Row.of(15L: JLong, "k1"), true), 0)) @@ -443,6 +292,8 @@ class JoinHarnessTest extends HarnessTestBase{ testHarness.processElement1(new StreamRecord[CRow]( CRow(Row.of(35L: JLong, "k1"), true), 0)) + // The right rows with timestamp = 2 and 5 will be removed here. + // The left rows with timestamp = 2 and 15 will be removed here. testHarness.processWatermark1(new Watermark(38)) testHarness.processWatermark2(new Watermark(38)) @@ -453,6 +304,7 @@ class JoinHarnessTest extends HarnessTestBase{ assertEquals(6, testHarness.numKeyedStateEntries()) + // The right row with timestamp = 35 will be removed here. testHarness.processWatermark1(new Watermark(61)) testHarness.processWatermark2(new Watermark(61)) @@ -477,10 +329,10 @@ class JoinHarnessTest extends HarnessTestBase{ /** a.rowtime >= b.rowtime - 10 and a.rowtime <= b.rowtime - 7 **/ @Test - def testNegativeRowTimeJoin() { + def testRowTimeJoinWithNegativeBounds() { val joinProcessFunc = new RowTimeBoundedStreamInnerJoin( - -10, -7, 0, rT, rT, "TestJoinFunction", funcCode, 0, 0) + -10, -7, 0, rowType, rowType, "TestJoinFunction", funcCode, 0, 0) val operator: KeyedCoProcessOperator[String, CRow, CRow, CRow] = new KeyedCoProcessOperator[String, CRow, CRow, CRow](joinProcessFunc) @@ -489,40 +341,48 @@ class JoinHarnessTest extends HarnessTestBase{ operator, new TupleRowKeySelector[String](1), new TupleRowKeySelector[String](1), - BasicTypeInfo.STRING_TYPE_INFO, + Types.STRING, 1, 1, 0) testHarness.open() - // Advance testHarness.processWatermark1(new Watermark(1)) testHarness.processWatermark2(new Watermark(1)) + // This row will not be cached. 
testHarness.processElement2(new StreamRecord[CRow]( CRow(Row.of(2L: JLong, "k1"), true), 0)) + assertEquals(0, testHarness.numKeyedStateEntries()) + testHarness.processWatermark1(new Watermark(2)) testHarness.processWatermark2(new Watermark(2)) - assertEquals(0, testHarness.numKeyedStateEntries()) - testHarness.processElement1(new StreamRecord[CRow]( CRow(Row.of(3L: JLong, "k1"), true), 0)) testHarness.processElement2(new StreamRecord[CRow]( CRow(Row.of(3L: JLong, "k1"), true), 0)) + // Test for -10 boundary (13 - 10 = 3). + // This row from the right stream will be cached. + // The clean time for the left stream is 13 - 7 + 1 + 1 = 8 testHarness.processElement2(new StreamRecord[CRow]( CRow(Row.of(13L: JLong, "k1"), true), 0)) + + // Test for -7 boundary (13 - 7 = 6). testHarness.processElement1(new StreamRecord[CRow]( CRow(Row.of(6L: JLong, "k1"), true), 0)) assertEquals(4, testHarness.numKeyedStateEntries()) + // Trigger the left timer with timestamp 8. + // The row with timestamp = 13 will be removed here (13 < 10 + 7). testHarness.processWatermark1(new Watermark(10)) testHarness.processWatermark2(new Watermark(10)) assertEquals(2, testHarness.numKeyedStateEntries()) + // Clear the states. 
testHarness.processWatermark1(new Watermark(18)) testHarness.processWatermark2(new Watermark(18)) diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/sql/JoinITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/sql/JoinITCase.scala index ab7925b264950..602f4d4b52dce 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/sql/JoinITCase.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/sql/JoinITCase.scala @@ -19,6 +19,7 @@ package org.apache.flink.table.runtime.stream.sql import org.apache.flink.api.scala._ +import org.apache.flink.streaming.api.TimeCharacteristic import org.apache.flink.streaming.api.scala.StreamExecutionEnvironment import org.apache.flink.table.api.TableEnvironment import org.apache.flink.table.api.scala._ @@ -30,7 +31,7 @@ import scala.collection.mutable class JoinITCase extends StreamingWithStateTestBase { - /** test process time inner join **/ + /** test proctime inner join **/ @Test def testProcessTimeInnerJoin(): Unit = { val env = StreamExecutionEnvironment.getExecutionEnvironment @@ -39,8 +40,14 @@ class JoinITCase extends StreamingWithStateTestBase { StreamITCase.clear env.setParallelism(1) - val sqlQuery = "SELECT t2.a, t2.c, t1.c from T1 as t1 join T2 as t2 on t1.a = t2.a and " + - "t1.proctime between t2.proctime - interval '5' second and t2.proctime + interval '5' second" + val sqlQuery = + """ + |SELECT t2.a, t2.c, t1.c + |FROM T1 as t1 join T2 as t2 ON + | t1.a = t2.a AND + | t1.proctime BETWEEN t2.proctime - INTERVAL '5' SECOND AND + | t2.proctime + INTERVAL '5' SECOND + |""".stripMargin val data1 = new mutable.MutableList[(Int, Long, String)] data1.+=((1, 1L, "Hi1")) @@ -65,19 +72,25 @@ class JoinITCase extends StreamingWithStateTestBase { env.execute() } - /** test process time inner join with other condition **/ + /** test proctime inner join with other condition 
**/ @Test def testProcessTimeInnerJoinWithOtherCondition(): Unit = { val env = StreamExecutionEnvironment.getExecutionEnvironment val tEnv = TableEnvironment.getTableEnvironment(env) env.setStateBackend(getStateBackend) StreamITCase.clear - env.setParallelism(1) - - val sqlQuery = "SELECT t2.a, t2.c, t1.c from T1 as t1 join T2 as t2 on t1.a = t2.a and " + - "t1.proctime between t2.proctime - interval '5' second " + - "and t2.proctime + interval '5' second " + - "and t1.b > t2.b and t1.b + t2.b < 14" + env.setParallelism(2) + + val sqlQuery = + """ + |SELECT t2.a, t2.c, t1.c + |FROM T1 as t1 JOIN T2 as t2 ON + | t1.a = t2.a AND + | t1.proctime BETWEEN t2.proctime - interval '5' SECOND AND + | t2.proctime + interval '5' second AND + | t1.b > t2.b AND + | t1.b + t2.b < 14 + |""".stripMargin val data1 = new mutable.MutableList[(String, Long, String)] data1.+=(("1", 1L, "Hi1")) @@ -102,5 +115,118 @@ class JoinITCase extends StreamingWithStateTestBase { env.execute() } -} + /** test rowtime inner join **/ + @Test + def testRowTimeInnerJoin(): Unit = { + val env = StreamExecutionEnvironment.getExecutionEnvironment + val tEnv = TableEnvironment.getTableEnvironment(env) + env.setStateBackend(getStateBackend) + env.setStreamTimeCharacteristic(TimeCharacteristic.EventTime) + StreamITCase.clear + env.setParallelism(3) + + val sqlQuery = + """ + |SELECT t2.a, t2.c, t1.c + |FROM T1 as t1 join T2 as t2 ON + | t1.a = t2.a AND + | t1.rt BETWEEN t2.rt - INTERVAL '5' SECOND AND + | t2.rt + INTERVAL '6' SECOND + |""".stripMargin + + val data1 = new mutable.MutableList[(Int, Long, String, Long)] + // for boundary test + data1.+=((1, 999L, "LEFT0.999", 999L)) + data1.+=((1, 1000L, "LEFT1", 1000L)) + data1.+=((1, 2000L, "LEFT2", 2000L)) + data1.+=((1, 3000L, "LEFT3", 3000L)) + data1.+=((2, 4000L, "LEFT4", 4000L)) + data1.+=((1, 5000L, "LEFT5", 5000L)) + data1.+=((1, 6000L, "LEFT6", 6000L)) + + val data2 = new mutable.MutableList[(Int, Long, String, Long)] + data2.+=((1, 6000L, "RIGHT6", 
6000L)) + data2.+=((2, 7000L, "RIGHT7", 7000L)) + + val t1 = env.fromCollection(data1) + .assignAscendingTimestamps(row => row._4) + .toTable(tEnv, 'a, 'b, 'c, 'rt.rowtime) + val t2 = env.fromCollection(data2) + .assignAscendingTimestamps(row => row._4) + .toTable(tEnv, 'a, 'b, 'c, 'rt.rowtime) + + tEnv.registerTable("T1", t1) + tEnv.registerTable("T2", t2) + + val result = tEnv.sql(sqlQuery).toAppendStream[Row] + result.addSink(new StreamITCase.StringSink[Row]) + env.execute() + val expected = new java.util.ArrayList[String] + expected.add("1,RIGHT6,LEFT1") + expected.add("1,RIGHT6,LEFT2") + expected.add("1,RIGHT6,LEFT3") + expected.add("1,RIGHT6,LEFT5") + expected.add("1,RIGHT6,LEFT6") + expected.add("2,RIGHT7,LEFT4") + StreamITCase.compareWithList(expected) + } + /** test rowtime inner join with other conditions **/ + @Test + def testRowTimeInnerJoinWithOtherCondition(): Unit = { + val env = StreamExecutionEnvironment.getExecutionEnvironment + val tEnv = TableEnvironment.getTableEnvironment(env) + env.setStateBackend(getStateBackend) + env.setStreamTimeCharacteristic(TimeCharacteristic.EventTime) + StreamITCase.clear + env.setParallelism(4) + + val sqlQuery = + """ + |SELECT t2.a, t1.c, t2.c + |FROM T1 as t1 JOIN T2 as t2 ON + | t1.a = t2.a AND + | t1.rt > t2.rt - INTERVAL '5' SECOND AND + | t1.rt < t2.rt - INTERVAL '1' SECOND AND + | t1.b < t2.b AND + | t1.b > 3 + |""".stripMargin + + val data1 = new mutable.MutableList[(Int, Long, String, Long)] + data1.+=((1, 4L, "LEFT1", 1000L)) + // for boundary test + data1.+=((1, 8L, "LEFT1.1", 1001L)) + data1.+=((1, 2L, "LEFT2", 2000L)) + data1.+=((1, 7L, "LEFT3", 3000L)) + data1.+=((2, 5L, "LEFT4", 4000L)) + // for boundary test + data1.+=((1, 4L, "LEFT4.9", 4999L)) + data1.+=((1, 4L, "LEFT5", 5000L)) + data1.+=((1, 10L, "LEFT6", 6000L)) + + val data2 = new mutable.MutableList[(Int, Long, String, Long)] + data2.+=((1, 9L, "RIGHT6", 6000L)) + data2.+=((2, 14L, "RIGHT7", 7000L)) + + val t1 = env.fromCollection(data1) + 
.assignAscendingTimestamps(row => row._4) + .toTable(tEnv, 'a, 'b, 'c, 'rt.rowtime) + val t2 = env.fromCollection(data2) + .assignAscendingTimestamps(row => row._4) + .toTable(tEnv, 'a, 'b, 'c, 'rt.rowtime) + + tEnv.registerTable("T1", t1) + tEnv.registerTable("T2", t2) + + val result = tEnv.sql(sqlQuery).toAppendStream[Row] + result.addSink(new StreamITCase.StringSink[Row]) + env.execute() + + val expected = new java.util.ArrayList[String] + expected.add("1,LEFT3,RIGHT6") + expected.add("1,LEFT1.1,RIGHT6") + expected.add("2,LEFT4,RIGHT7") + expected.add("1,LEFT4.9,RIGHT6") + StreamITCase.compareWithList(expected) + } +} From d71c48cf98cc93325d271260c8b826cdf3f5743e Mon Sep 17 00:00:00 2001 From: Xingcan Cui Date: Sat, 30 Sep 2017 22:09:56 +0800 Subject: [PATCH 4/7] Change the logic for joining late data and refine tests --- .../join/ProcTimeBoundedStreamInnerJoin.scala | 2 - .../join/RowTimeBoundedStreamInnerJoin.scala | 4 - .../join/TimeBoundedStreamInnerJoin.scala | 84 +++++++++---------- .../runtime/harness/JoinHarnessTest.scala | 7 +- .../table/runtime/stream/sql/JoinITCase.scala | 54 ++++++++++-- 5 files changed, 93 insertions(+), 58 deletions(-) diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/join/ProcTimeBoundedStreamInnerJoin.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/join/ProcTimeBoundedStreamInnerJoin.scala index 1efac7f74c049..01d2343804bfc 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/join/ProcTimeBoundedStreamInnerJoin.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/join/ProcTimeBoundedStreamInnerJoin.scala @@ -45,8 +45,6 @@ final class ProcTimeBoundedStreamInnerJoin( leftTimeIdx = -1, rightTimeIdx = -1) { - override def isRowTooLate(timeForRow: Long, watermark: Long): Boolean = false - override def updateOperatorTime(ctx: CoProcessFunction[CRow, CRow, CRow]#Context): Unit = { leftOperatorTime = 
ctx.timerService().currentProcessingTime() rightOperatorTime = leftOperatorTime diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/join/RowTimeBoundedStreamInnerJoin.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/join/RowTimeBoundedStreamInnerJoin.scala index a48a9849f2755..aafcdff70974e 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/join/RowTimeBoundedStreamInnerJoin.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/join/RowTimeBoundedStreamInnerJoin.scala @@ -47,10 +47,6 @@ final class RowTimeBoundedStreamInnerJoin( leftTimeIdx, rightTimeIdx) { - override def isRowTooLate(timeForRow: Long, watermark: Long): Boolean = { - timeForRow <= watermark - allowedLateness - } - override def updateOperatorTime(ctx: CoProcessFunction[CRow, CRow, CRow]#Context): Unit = { leftOperatorTime = if (ctx.timerService().currentWatermark() > 0) ctx.timerService().currentWatermark() diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/join/TimeBoundedStreamInnerJoin.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/join/TimeBoundedStreamInnerJoin.scala index db72b63a5b197..61ed10b86654a 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/join/TimeBoundedStreamInnerJoin.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/join/TimeBoundedStreamInnerJoin.scala @@ -80,9 +80,13 @@ abstract class TimeBoundedStreamInnerJoin( private val leftRelativeSize: Long = -leftLowerBound private val rightRelativeSize: Long = leftUpperBound + private var leftExpirationTime: Long = 0L; + private var rightExpirationTime: Long = 0L; + protected var leftOperatorTime: Long = 0L protected var rightOperatorTime: Long = 0L + // for delayed cleanup private val cleanupDelay = (leftRelativeSize + rightRelativeSize) / 2 @@ -150,24 +154,24 
@@ abstract class TimeBoundedStreamInnerJoin( val timeForLeftRow: Long = getTimeForLeftStream(ctx, leftRow) val rightQualifiedLowerBound: Long = timeForLeftRow - rightRelativeSize val rightQualifiedUpperBound: Long = timeForLeftRow + leftRelativeSize - val rightExpirationTime = calExpirationTime(leftOperatorTime, rightRelativeSize) cRowWrapper.out = out - if (!isRowTooLate(timeForLeftRow, leftOperatorTime)) { - if (rightOperatorTime < rightQualifiedUpperBound) { - // Put the leftRow into the cache for later use. - var leftRowList = leftCache.get(timeForLeftRow) - if (null == leftRowList) { - leftRowList = new ArrayList[Row](1) - } - leftRowList.add(leftRow) - leftCache.put(timeForLeftRow, leftRowList) - if (rightTimerState.value == 0) { - // Register a timer on the RIGHT stream to remove rows. - registerCleanUpTimer(ctx, timeForLeftRow, rightTimerState, leftRow = true) - } + if (rightOperatorTime < rightQualifiedUpperBound) { + // Put the leftRow into the cache for later use. + var leftRowList = leftCache.get(timeForLeftRow) + if (null == leftRowList) { + leftRowList = new ArrayList[Row](1) } - + leftRowList.add(leftRow) + leftCache.put(timeForLeftRow, leftRowList) + if (rightTimerState.value == 0) { + // Register a timer on the RIGHT stream to remove rows. + registerCleanUpTimer(ctx, timeForLeftRow, rightTimerState, leftRow = true) + } + } + // We'd like to produce as many results as possible. + if (rightExpirationTime < rightQualifiedUpperBound) { + rightExpirationTime = calExpirationTime(leftOperatorTime, rightRelativeSize) // Join the leftRow with rows from the right cache. 
val rightIterator = rightCache.iterator() while (rightIterator.hasNext) { @@ -180,7 +184,9 @@ abstract class TimeBoundedStreamInnerJoin( joinFunction.join(leftRow, rightRows.get(i), cRowWrapper) i += 1 } - } else if (rightTime <= rightExpirationTime) { + } + + if (rightTime <= rightExpirationTime) { // eager remove rightIterator.remove() }// We could do the short-cutting optimization here once we get a state with ordered keys. @@ -200,24 +206,24 @@ abstract class TimeBoundedStreamInnerJoin( val timeForRightRow: Long = getTimeForRightStream(ctx, rightRow) val leftQualifiedLowerBound: Long = timeForRightRow - leftRelativeSize val leftQualifiedUpperBound: Long = timeForRightRow + rightRelativeSize - val leftExpirationTime = calExpirationTime(rightOperatorTime, leftRelativeSize) cRowWrapper.out = out - if (!isRowTooLate(timeForRightRow, rightOperatorTime)) { - if (leftOperatorTime < leftQualifiedUpperBound) { - // Put the rightRow into the cache for later use. - var rightRowList = rightCache.get(timeForRightRow) - if (null == rightRowList) { - rightRowList = new ArrayList[Row](1) - } - rightRowList.add(rightRow) - rightCache.put(timeForRightRow, rightRowList) - if (leftTimerState.value == 0) { - // Register a timer on the LEFT stream to remove rows. - registerCleanUpTimer(ctx, timeForRightRow, leftTimerState, leftRow = false) - } + if (leftOperatorTime < leftQualifiedUpperBound) { + // Put the rightRow into the cache for later use. + var rightRowList = rightCache.get(timeForRightRow) + if (null == rightRowList) { + rightRowList = new ArrayList[Row](1) } - + rightRowList.add(rightRow) + rightCache.put(timeForRightRow, rightRowList) + if (leftTimerState.value == 0) { + // Register a timer on the LEFT stream to remove rows. + registerCleanUpTimer(ctx, timeForRightRow, leftTimerState, leftRow = false) + } + } + // We'd like to produce as many results as possible. 
+ if (leftExpirationTime < leftQualifiedUpperBound) { + leftExpirationTime = calExpirationTime(rightOperatorTime, leftRelativeSize) // Join the rightRow with rows from the left cache. val leftIterator = leftCache.iterator() while (leftIterator.hasNext) { @@ -230,7 +236,8 @@ abstract class TimeBoundedStreamInnerJoin( joinFunction.join(leftRows.get(i), rightRow, cRowWrapper) i += 1 } - } else if (leftTime <= leftExpirationTime) { + } + if (leftTime <= leftExpirationTime) { // eager remove leftIterator.remove() } // We could do the short-cutting optimization here once we get a state with ordered keys. @@ -256,7 +263,7 @@ abstract class TimeBoundedStreamInnerJoin( // registered timer of the faster stream will be delayed, even if the watermarks have // already been emitted by the source. if (leftTimerState.value == timestamp) { - val rightExpirationTime = calExpirationTime(leftOperatorTime, rightRelativeSize) + rightExpirationTime = calExpirationTime(leftOperatorTime, rightRelativeSize) removeExpiredRows( rightExpirationTime, rightCache, @@ -267,7 +274,7 @@ abstract class TimeBoundedStreamInnerJoin( } if (rightTimerState.value == timestamp) { - val leftExpirationTime = calExpirationTime(rightOperatorTime, leftRelativeSize) + leftExpirationTime = calExpirationTime(rightOperatorTime, leftRelativeSize) removeExpiredRows( leftExpirationTime, leftCache, @@ -365,15 +372,6 @@ abstract class TimeBoundedStreamInnerJoin( } } - /** - * Check if the row is too late to be processed. - * - * @param timeForRow time of the row - * @param operatorTime operator time for current stream - * @return true if the row is out of date; false otherwise - */ - def isRowTooLate(timeForRow: Long, operatorTime: Long): Boolean - /** * Update the operator time of the two streams. * Must be the first call in all processing methods (i.e., processElement(), onTimer()). 
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/harness/JoinHarnessTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/harness/JoinHarnessTest.scala index 9ebd3ecd122c1..192befd759705 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/harness/JoinHarnessTest.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/harness/JoinHarnessTest.scala @@ -268,7 +268,8 @@ class JoinHarnessTest extends HarnessTestBase { testHarness.processElement1(new StreamRecord[CRow]( CRow(Row.of(1L: JLong, "k1"), true), 0)) - assertEquals(0, testHarness.numEventTimeTimers()) + // Though (1L, "k1") is actually late, it will also be cached. + assertEquals(1, testHarness.numEventTimeTimers()) testHarness.processElement1(new StreamRecord[CRow]( CRow(Row.of(2L: JLong, "k1"), true), 0)) @@ -322,6 +323,10 @@ class JoinHarnessTest extends HarnessTestBase { expectedOutput.add(new StreamRecord( CRow(Row.of(40L: JLong, "k2", 39L: JLong, "k2"), true), 0)) + // This result is produced by the late row (1, "k1"). 
+ expectedOutput.add(new StreamRecord( + CRow(Row.of(1L: JLong, "k1", 2L: JLong, "k1"), true), 0)) + val result = testHarness.getOutput verify(expectedOutput, result, new RowResultSortComparator()) testHarness.close() diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/sql/JoinITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/sql/JoinITCase.scala index 602f4d4b52dce..bd879a6a53f45 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/sql/JoinITCase.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/sql/JoinITCase.scala @@ -20,7 +20,9 @@ package org.apache.flink.table.runtime.stream.sql import org.apache.flink.api.scala._ import org.apache.flink.streaming.api.TimeCharacteristic +import org.apache.flink.streaming.api.functions.AssignerWithPunctuatedWatermarks import org.apache.flink.streaming.api.scala.StreamExecutionEnvironment +import org.apache.flink.streaming.api.watermark.Watermark import org.apache.flink.table.api.TableEnvironment import org.apache.flink.table.api.scala._ import org.apache.flink.table.runtime.utils.{StreamITCase, StreamingWithStateTestBase} @@ -74,7 +76,7 @@ class JoinITCase extends StreamingWithStateTestBase { /** test proctime inner join with other condition **/ @Test - def testProcessTimeInnerJoinWithOtherCondition(): Unit = { + def testProcessTimeInnerJoinWithOtherConditions(): Unit = { val env = StreamExecutionEnvironment.getExecutionEnvironment val tEnv = TableEnvironment.getTableEnvironment(env) env.setStateBackend(getStateBackend) @@ -123,7 +125,7 @@ class JoinITCase extends StreamingWithStateTestBase { env.setStateBackend(getStateBackend) env.setStreamTimeCharacteristic(TimeCharacteristic.EventTime) StreamITCase.clear - env.setParallelism(3) + env.setParallelism(1) val sqlQuery = """ @@ -149,10 +151,10 @@ class JoinITCase extends StreamingWithStateTestBase { 
data2.+=((2, 7000L, "RIGHT7", 7000L)) val t1 = env.fromCollection(data1) - .assignAscendingTimestamps(row => row._4) + .assignTimestampsAndWatermarks(new Tuple2WatermarkExtractor) .toTable(tEnv, 'a, 'b, 'c, 'rt.rowtime) val t2 = env.fromCollection(data2) - .assignAscendingTimestamps(row => row._4) + .assignTimestampsAndWatermarks(new Tuple2WatermarkExtractor) .toTable(tEnv, 'a, 'b, 'c, 'rt.rowtime) tEnv.registerTable("T1", t1) @@ -173,12 +175,14 @@ class JoinITCase extends StreamingWithStateTestBase { /** test rowtime inner join with other conditions **/ @Test - def testRowTimeInnerJoinWithOtherCondition(): Unit = { + def testRowTimeInnerJoinWithOtherConditions(): Unit = { val env = StreamExecutionEnvironment.getExecutionEnvironment val tEnv = TableEnvironment.getTableEnvironment(env) env.setStateBackend(getStateBackend) env.setStreamTimeCharacteristic(TimeCharacteristic.EventTime) StreamITCase.clear + + // different parallelisms lead to different join results env.setParallelism(4) val sqlQuery = @@ -189,13 +193,14 @@ class JoinITCase extends StreamingWithStateTestBase { | t1.rt > t2.rt - INTERVAL '5' SECOND AND | t1.rt < t2.rt - INTERVAL '1' SECOND AND | t1.b < t2.b AND - | t1.b > 3 + | t1.b > 2 |""".stripMargin val data1 = new mutable.MutableList[(Int, Long, String, Long)] data1.+=((1, 4L, "LEFT1", 1000L)) // for boundary test data1.+=((1, 8L, "LEFT1.1", 1001L)) + // predicate (t1.b > 2) push down data1.+=((1, 2L, "LEFT2", 2000L)) data1.+=((1, 7L, "LEFT3", 3000L)) data1.+=((2, 5L, "LEFT4", 4000L)) @@ -203,16 +208,23 @@ class JoinITCase extends StreamingWithStateTestBase { data1.+=((1, 4L, "LEFT4.9", 4999L)) data1.+=((1, 4L, "LEFT5", 5000L)) data1.+=((1, 10L, "LEFT6", 6000L)) + // a left late row + data1.+=((1, 3L, "LEFT3.5", 3500L)) val data2 = new mutable.MutableList[(Int, Long, String, Long)] + // just for watermark + data2.+=((1, 1L, "RIGHT1", 1000L)) data2.+=((1, 9L, "RIGHT6", 6000L)) data2.+=((2, 14L, "RIGHT7", 7000L)) + data2.+=((1, 4L, "RIGHT8", 8000L)) + 
// a right late row + data2.+=((1, 10L, "RIGHT5", 5000L)) val t1 = env.fromCollection(data1) - .assignAscendingTimestamps(row => row._4) + .assignTimestampsAndWatermarks(new Tuple2WatermarkExtractor) .toTable(tEnv, 'a, 'b, 'c, 'rt.rowtime) val t2 = env.fromCollection(data2) - .assignAscendingTimestamps(row => row._4) + .assignTimestampsAndWatermarks(new Tuple2WatermarkExtractor) .toTable(tEnv, 'a, 'b, 'c, 'rt.rowtime) tEnv.registerTable("T1", t1) @@ -227,6 +239,32 @@ class JoinITCase extends StreamingWithStateTestBase { expected.add("1,LEFT1.1,RIGHT6") expected.add("2,LEFT4,RIGHT7") expected.add("1,LEFT4.9,RIGHT6") + // produced by the left late rows + expected.add("1,LEFT3.5,RIGHT6") + expected.add("1,LEFT3.5,RIGHT8") + // produced by the right late rows + expected.add("1,LEFT3,RIGHT5") + expected.add("1,LEFT3.5,RIGHT5") + // these two results will only be produced when parallelism >= 2 + expected.add("1,LEFT1,RIGHT5") + expected.add("1,LEFT1.1,RIGHT5") + StreamITCase.compareWithList(expected) } } + +private class Tuple2WatermarkExtractor + extends AssignerWithPunctuatedWatermarks[(Int, Long, String, Long)] { + + override def checkAndGetNextWatermark( + lastElement: (Int, Long, String, Long), + extractedTimestamp: Long): Watermark = { + new Watermark(extractedTimestamp - 1) + } + + override def extractTimestamp( + element: (Int, Long, String, Long), + previousElementTimestamp: Long): Long = { + element._4 + } +} From 8c9bcd6b46f1592c7b53e173598c3d5c90a7bbc1 Mon Sep 17 00:00:00 2001 From: Xingcan Cui Date: Sun, 8 Oct 2017 23:37:20 +0800 Subject: [PATCH 5/7] Feedback addressed. 1. Fix the generic types of KeyedCoProcessOperatorWithWatermarkDelay. 2. Add return types for createEmptyInnerJoin(). 3. Include the equi-join predicates into the remainCondition. 4. Remove the timeIndex parameters for ProcTimeBoundedStreamInnerJoin. 5. Move the getMaxOutputDelay() to RowTimeBoundedStreamInnerJoin. 6. Refine comments. 7. Add a test for String null keys. 8. 
Other minor changes. --- .../datastream/DataStreamWindowJoin.scala | 15 ++-- .../datastream/DataStreamWindowJoinRule.scala | 6 +- .../join/ProcTimeBoundedStreamInnerJoin.scala | 4 +- .../join/RowTimeBoundedStreamInnerJoin.scala | 17 +++-- .../join/TimeBoundedStreamInnerJoin.scala | 75 +++++++++---------- .../table/runtime/stream/sql/JoinITCase.scala | 55 +++++++++----- 6 files changed, 95 insertions(+), 77 deletions(-) diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamWindowJoin.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamWindowJoin.scala index c5fa9314f425b..9358aa36bf570 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamWindowJoin.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamWindowJoin.scala @@ -25,6 +25,7 @@ import org.apache.calcite.rel.{BiRel, RelNode, RelWriter} import org.apache.calcite.rex.RexNode import org.apache.flink.api.common.typeinfo.TypeInformation import org.apache.flink.api.java.functions.NullByteKeySelector +import org.apache.flink.api.java.tuple.Tuple import org.apache.flink.streaming.api.datastream.DataStream import org.apache.flink.streaming.api.functions.co.CoProcessFunction import org.apache.flink.table.api.{StreamQueryConfig, StreamTableEnvironment, TableException} @@ -138,7 +139,7 @@ class DataStreamWindowJoin( joinType match { case JoinRelType.INNER => if (relativeWindowSize < 0) { - LOG.warn(s"The relative window size ${relativeWindowSize} is negative," + + LOG.warn(s"The relative window size $relativeWindowSize is negative," + " please check the join conditions.") createEmptyInnerJoin(leftDataStream, rightDataStream, returnTypeInfo) } else { @@ -179,22 +180,22 @@ class DataStreamWindowJoin( def createEmptyInnerJoin( leftDataStream: DataStream[CRow], rightDataStream: DataStream[CRow], - 
returnTypeInfo: TypeInformation[CRow]) = { + returnTypeInfo: TypeInformation[CRow]): DataStream[CRow] = { leftDataStream.connect(rightDataStream).process( new CoProcessFunction[CRow, CRow, CRow] { override def processElement1( value: CRow, ctx: CoProcessFunction[CRow, CRow, CRow]#Context, - out: Collector[CRow]) = { + out: Collector[CRow]): Unit = { //Do nothing. } override def processElement2( value: CRow, ctx: CoProcessFunction[CRow, CRow, CRow]#Context, - out: Collector[CRow]) = { + out: Collector[CRow]): Unit = { //Do nothing. } - }) + }).returns(returnTypeInfo) } def createProcTimeInnerJoin( @@ -257,7 +258,7 @@ class DataStreamWindowJoin( .transform( "InnerRowtimeWindowJoin", returnTypeInfo, - new KeyedCoProcessOperatorWithWatermarkDelay[CRow, CRow, CRow, CRow]( + new KeyedCoProcessOperatorWithWatermarkDelay[Tuple, CRow, CRow, CRow]( rowTimeInnerJoinFunc, rowTimeInnerJoinFunc.getMaxOutputDelay) ) @@ -267,7 +268,7 @@ class DataStreamWindowJoin( .transform( "InnerRowtimeWindowJoin", returnTypeInfo, - new KeyedCoProcessOperatorWithWatermarkDelay[CRow, CRow, CRow, CRow]( + new KeyedCoProcessOperatorWithWatermarkDelay[java.lang.Byte, CRow, CRow, CRow]( rowTimeInnerJoinFunc, rowTimeInnerJoinFunc.getMaxOutputDelay) ) diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/datastream/DataStreamWindowJoinRule.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/datastream/DataStreamWindowJoinRule.scala index a446864618649..d208d2b38a373 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/datastream/DataStreamWindowJoinRule.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/datastream/DataStreamWindowJoinRule.scala @@ -40,10 +40,9 @@ class DataStreamWindowJoinRule override def matches(call: RelOptRuleCall): Boolean = { val join: FlinkLogicalJoin = call.rel(0).asInstanceOf[FlinkLogicalJoin] - val joinInfo = join.analyzeCondition 
val (windowBounds, remainingPreds) = WindowJoinUtil.extractWindowBoundsFromPredicate( - joinInfo.getRemaining(join.getCluster.getRexBuilder), + join.getCondition, join.getLeft.getRowType.getFieldCount, join.getRowType, join.getCluster.getRexBuilder, @@ -79,13 +78,12 @@ class DataStreamWindowJoinRule val traitSet: RelTraitSet = rel.getTraitSet.replace(FlinkConventions.DATASTREAM) val convLeft: RelNode = RelOptRule.convert(join.getInput(0), FlinkConventions.DATASTREAM) val convRight: RelNode = RelOptRule.convert(join.getInput(1), FlinkConventions.DATASTREAM) - val joinInfo = join.analyzeCondition val leftRowSchema = new RowSchema(convLeft.getRowType) val rightRowSchema = new RowSchema(convRight.getRowType) val (windowBounds, remainCondition) = WindowJoinUtil.extractWindowBoundsFromPredicate( - joinInfo.getRemaining(join.getCluster.getRexBuilder), + join.getCondition, leftRowSchema.arity, join.getRowType, join.getCluster.getRexBuilder, diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/join/ProcTimeBoundedStreamInnerJoin.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/join/ProcTimeBoundedStreamInnerJoin.scala index 01d2343804bfc..ab5a9c3092490 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/join/ProcTimeBoundedStreamInnerJoin.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/join/ProcTimeBoundedStreamInnerJoin.scala @@ -41,9 +41,7 @@ final class ProcTimeBoundedStreamInnerJoin( leftType, rightType, genJoinFuncName, - genJoinFuncCode, - leftTimeIdx = -1, - rightTimeIdx = -1) { + genJoinFuncCode) { override def updateOperatorTime(ctx: CoProcessFunction[CRow, CRow, CRow]#Context): Unit = { leftOperatorTime = ctx.timerService().currentProcessingTime() diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/join/RowTimeBoundedStreamInnerJoin.scala 
b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/join/RowTimeBoundedStreamInnerJoin.scala index aafcdff70974e..5cf5a5372cda3 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/join/RowTimeBoundedStreamInnerJoin.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/join/RowTimeBoundedStreamInnerJoin.scala @@ -43,17 +43,22 @@ final class RowTimeBoundedStreamInnerJoin( leftType, rightType, genJoinFuncName, - genJoinFuncCode, - leftTimeIdx, - rightTimeIdx) { + genJoinFuncCode) { + + /** + * Get the maximum interval between receiving a row and emitting it (as part of a joined result). + * Only reasonable for row time join. + * + * @return the maximum delay for the outputs + */ + def getMaxOutputDelay: Long = Math.max(leftRelativeSize, rightRelativeSize) + allowedLateness override def updateOperatorTime(ctx: CoProcessFunction[CRow, CRow, CRow]#Context): Unit = { leftOperatorTime = if (ctx.timerService().currentWatermark() > 0) ctx.timerService().currentWatermark() else 0L - rightOperatorTime = - if (ctx.timerService().currentWatermark() > 0) ctx.timerService().currentWatermark() - else 0L + // We may set different operator times in the future. 
+ rightOperatorTime = leftOperatorTime } override def getTimeForLeftStream( diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/join/TimeBoundedStreamInnerJoin.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/join/TimeBoundedStreamInnerJoin.scala index 61ed10b86654a..896881893b3df 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/join/TimeBoundedStreamInnerJoin.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/join/TimeBoundedStreamInnerJoin.scala @@ -18,7 +18,8 @@ package org.apache.flink.table.runtime.join -import java.util.{ArrayList, List => JList} +import java.util +import java.util.{List => JList} import org.apache.flink.api.common.functions.FlatJoinFunction import org.apache.flink.api.common.state._ @@ -55,9 +56,7 @@ abstract class TimeBoundedStreamInnerJoin( private val leftType: TypeInformation[Row], private val rightType: TypeInformation[Row], private val genJoinFuncName: String, - private val genJoinFuncCode: String, - private val leftTimeIdx: Int, - private val rightTimeIdx: Int) + private val genJoinFuncCode: String) extends CoProcessFunction[CRow, CRow, CRow] with Compiler[FlatJoinFunction[Row, Row, Row]] with Logging { @@ -77,11 +76,11 @@ abstract class TimeBoundedStreamInnerJoin( // state to record the timer on the right stream. 
0 means no timer set private var rightTimerState: ValueState[Long] = _ - private val leftRelativeSize: Long = -leftLowerBound - private val rightRelativeSize: Long = leftUpperBound + protected val leftRelativeSize: Long = -leftLowerBound + protected val rightRelativeSize: Long = leftUpperBound - private var leftExpirationTime: Long = 0L; - private var rightExpirationTime: Long = 0L; + private var leftExpirationTime: Long = 0L + private var rightExpirationTime: Long = 0L protected var leftOperatorTime: Long = 0L protected var rightOperatorTime: Long = 0L @@ -94,14 +93,6 @@ abstract class TimeBoundedStreamInnerJoin( throw new IllegalArgumentException("The allowed lateness must be non-negative.") } - /** - * Get the maximum interval between receiving a row and emitting it (as part of a joined result). - * Only reasonable for row time join. - * - * @return the maximum delay for the outputs - */ - def getMaxOutputDelay: Long = Math.max(leftRelativeSize, rightRelativeSize) + allowedLateness - override def open(config: Configuration) { LOG.debug(s"Compiling JoinFunction: $genJoinFuncName \n\n " + s"Code:\n$genJoinFuncCode") @@ -155,22 +146,26 @@ abstract class TimeBoundedStreamInnerJoin( val rightQualifiedLowerBound: Long = timeForLeftRow - rightRelativeSize val rightQualifiedUpperBound: Long = timeForLeftRow + leftRelativeSize cRowWrapper.out = out - + // Check if we need to cache the current row. if (rightOperatorTime < rightQualifiedUpperBound) { - // Put the leftRow into the cache for later use. + // Operator time of right stream has not exceeded the upper window bound of the current + // row. Put it into the left cache, since later coming records from the right stream are + // expected to join with it. 
var leftRowList = leftCache.get(timeForLeftRow) if (null == leftRowList) { - leftRowList = new ArrayList[Row](1) + leftRowList = new util.ArrayList[Row](1) } leftRowList.add(leftRow) leftCache.put(timeForLeftRow, leftRowList) if (rightTimerState.value == 0) { // Register a timer on the RIGHT stream to remove rows. - registerCleanUpTimer(ctx, timeForLeftRow, rightTimerState, leftRow = true) + registerCleanUpTimer(ctx, timeForLeftRow, leftRow = true) } } - // We'd like to produce as many results as possible. + // Check if we need to join the current row against cached rows of the right input. if (rightExpirationTime < rightQualifiedUpperBound) { + // Upper bound of current join window has not passed the cache expiration time yet. + // There might be qualifying rows in the cache that the current row needs to be joined with. rightExpirationTime = calExpirationTime(leftOperatorTime, rightRelativeSize) // Join the leftRow with rows from the right cache. val rightIterator = rightCache.iterator() @@ -207,21 +202,23 @@ abstract class TimeBoundedStreamInnerJoin( val leftQualifiedLowerBound: Long = timeForRightRow - leftRelativeSize val leftQualifiedUpperBound: Long = timeForRightRow + rightRelativeSize cRowWrapper.out = out - + // Check if we need to cache the current row. if (leftOperatorTime < leftQualifiedUpperBound) { - // Put the rightRow into the cache for later use. + // Operator time of left stream has not exceeded the upper window bound of the current + // row. Put it into the right cache, since later coming records from the left stream are + // expected to join with it. var rightRowList = rightCache.get(timeForRightRow) if (null == rightRowList) { - rightRowList = new ArrayList[Row](1) + rightRowList = new util.ArrayList[Row](1) } rightRowList.add(rightRow) rightCache.put(timeForRightRow, rightRowList) if (leftTimerState.value == 0) { // Register a timer on the LEFT stream to remove rows. 
- registerCleanUpTimer(ctx, timeForRightRow, leftTimerState, leftRow = false) + registerCleanUpTimer(ctx, timeForRightRow, leftRow = false) } } - // We'd like to produce as many results as possible. + // Check if we need to join the current row against cached rows of the left input. if (leftExpirationTime < leftQualifiedUpperBound) { leftExpirationTime = calExpirationTime(rightOperatorTime, leftRelativeSize) // Join the rightRow with rows from the left cache. @@ -305,21 +302,21 @@ abstract class TimeBoundedStreamInnerJoin( * * @param ctx the context to register timer * @param rowTime time for the input row - * @param timerState stores the timestamp for the next timer * @param leftRow whether this row comes from the left stream */ private def registerCleanUpTimer( ctx: CoProcessFunction[CRow, CRow, CRow]#Context, rowTime: Long, - timerState: ValueState[Long], leftRow: Boolean): Unit = { - val cleanupTime = if (leftRow) { - rowTime + leftRelativeSize + cleanupDelay + allowedLateness + 1 + if (leftRow) { + val cleanupTime = rowTime + leftRelativeSize + cleanupDelay + allowedLateness + 1 + registerTimer(ctx, cleanupTime) + rightTimerState.update(cleanupTime) } else { - rowTime + rightRelativeSize + cleanupDelay + allowedLateness + 1 + val cleanupTime = rowTime + rightRelativeSize + cleanupDelay + allowedLateness + 1 + registerTimer(ctx, cleanupTime) + leftTimerState.update(cleanupTime) } - registerTimer(ctx, cleanupTime) - timerState.update(cleanupTime) } /** @@ -341,11 +338,11 @@ abstract class TimeBoundedStreamInnerJoin( val keysIterator = rowCache.keys().iterator() - // Search for expired timestamps. - // If we find a non-expired timestamp, remember the timestamp and leave the loop. - // This way we find all expired timestamps if they are sorted without doing a full pass. var earliestTimestamp: Long = -1L var rowTime: Long = 0L + + // We remove all expired keys and do not leave the loop early. + // Hence, we do a full pass over the state. 
while (keysIterator.hasNext) { rowTime = keysIterator.next if (rowTime <= expirationTime) { @@ -357,16 +354,14 @@ abstract class TimeBoundedStreamInnerJoin( } } } - // If the cache contains non-expired timestamps, register a new timer. - // Otherwise clear the states. if (earliestTimestamp > 0) { + // There are rows left in the cache. Register a timer to expire them later. registerCleanUpTimer( ctx, earliestTimestamp, - timerState, removeLeft) } else { - // The timerState will be 0. + // No rows left in the cache. Clear the states and the timerState will be 0. timerState.clear() rowCache.clear() } diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/sql/JoinITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/sql/JoinITCase.scala index bd879a6a53f45..976e50f7c4886 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/sql/JoinITCase.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/sql/JoinITCase.scala @@ -27,6 +27,7 @@ import org.apache.flink.table.api.TableEnvironment import org.apache.flink.table.api.scala._ import org.apache.flink.table.runtime.utils.{StreamITCase, StreamingWithStateTestBase} import org.apache.flink.types.Row +import org.hamcrest.CoreMatchers import org.junit._ import scala.collection.mutable @@ -90,8 +91,7 @@ class JoinITCase extends StreamingWithStateTestBase { | t1.a = t2.a AND | t1.proctime BETWEEN t2.proctime - interval '5' SECOND AND | t2.proctime + interval '5' second AND - | t1.b > t2.b AND - | t1.b + t2.b < 14 + | t1.b = t2.b |""".stripMargin val data1 = new mutable.MutableList[(String, Long, String)] @@ -106,6 +106,10 @@ class JoinITCase extends StreamingWithStateTestBase { data2.+=(("1", 5L, "HiHi")) data2.+=(("2", 2L, "HeHe")) + // For null key test + data1.+=((null.asInstanceOf[String], 20L, "leftNull")) + data2.+=((null.asInstanceOf[String], 20L, "rightNull")) + val 
t1 = env.fromCollection(data1).toTable(tEnv, 'a, 'b, 'c, 'proctime.proctime) val t2 = env.fromCollection(data2).toTable(tEnv, 'a, 'b, 'c, 'proctime.proctime) @@ -115,6 +119,9 @@ class JoinITCase extends StreamingWithStateTestBase { val result = tEnv.sql(sqlQuery).toAppendStream[Row] result.addSink(new StreamITCase.StringSink[Row]) env.execute() + + // Assert there is no result with null keys. + Assert.assertFalse(StreamITCase.testResults.toString().contains("null")) } /** test rowtime inner join **/ @@ -183,7 +190,7 @@ class JoinITCase extends StreamingWithStateTestBase { StreamITCase.clear // different parallelisms lead to different join results - env.setParallelism(4) + env.setParallelism(1) val sqlQuery = """ @@ -234,22 +241,36 @@ class JoinITCase extends StreamingWithStateTestBase { result.addSink(new StreamITCase.StringSink[Row]) env.execute() - val expected = new java.util.ArrayList[String] - expected.add("1,LEFT3,RIGHT6") - expected.add("1,LEFT1.1,RIGHT6") - expected.add("2,LEFT4,RIGHT7") - expected.add("1,LEFT4.9,RIGHT6") + // There may be two expected results according to the process order. 
+ val expected1 = new mutable.MutableList[String] + expected1+= "1,LEFT3,RIGHT6" + expected1+= "1,LEFT1.1,RIGHT6" + expected1+= "2,LEFT4,RIGHT7" + expected1+= "1,LEFT4.9,RIGHT6" // produced by the left late rows - expected.add("1,LEFT3.5,RIGHT6") - expected.add("1,LEFT3.5,RIGHT8") + expected1+= "1,LEFT3.5,RIGHT6" + expected1+= "1,LEFT3.5,RIGHT8" // produced by the right late rows - expected.add("1,LEFT3,RIGHT5") - expected.add("1,LEFT3.5,RIGHT5") - // these two results will only be produced when parallelism >= 2 - expected.add("1,LEFT1,RIGHT5") - expected.add("1,LEFT1.1,RIGHT5") - - StreamITCase.compareWithList(expected) + expected1+= "1,LEFT3,RIGHT5" + expected1+= "1,LEFT3.5,RIGHT5" + + val expected2 = new mutable.MutableList[String] + expected2+= "1,LEFT3,RIGHT6" + expected2+= "1,LEFT1.1,RIGHT6" + expected2+= "2,LEFT4,RIGHT7" + expected2+= "1,LEFT4.9,RIGHT6" + // produced by the left late rows + expected2+= "1,LEFT3.5,RIGHT6" + expected2+= "1,LEFT3.5,RIGHT8" + // produced by the right late rows + expected2+= "1,LEFT3,RIGHT5" + expected2+= "1,LEFT1,RIGHT5" + expected2+= "1,LEFT1.1,RIGHT5" + + Assert.assertThat( + StreamITCase.testResults.sorted, + CoreMatchers.either(CoreMatchers.is(expected1.sorted)). + or(CoreMatchers.is(expected2.sorted))) } } From 70164294ef2f53b4b04a56a3c23fb44d74661d63 Mon Sep 17 00:00:00 2001 From: Xingcan Cui Date: Mon, 9 Oct 2017 17:04:05 +0800 Subject: [PATCH 6/7] Add some comments. 
--- .../runtime/join/TimeBoundedStreamInnerJoin.scala | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/join/TimeBoundedStreamInnerJoin.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/join/TimeBoundedStreamInnerJoin.scala index 896881893b3df..7bf3d33f158b9 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/join/TimeBoundedStreamInnerJoin.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/join/TimeBoundedStreamInnerJoin.scala @@ -150,7 +150,7 @@ abstract class TimeBoundedStreamInnerJoin( if (rightOperatorTime < rightQualifiedUpperBound) { // Operator time of right stream has not exceeded the upper window bound of the current // row. Put it into the left cache, since later coming records from the right stream are - // expected to join with it. + // expected to be joined with it. var leftRowList = leftCache.get(timeForLeftRow) if (null == leftRowList) { leftRowList = new util.ArrayList[Row](1) @@ -163,6 +163,9 @@ abstract class TimeBoundedStreamInnerJoin( } } // Check if we need to join the current row against cached rows of the right input. + // The condition here should be rightMinimumTime < rightQualifiedUpperBound. + // I use rightExpirationTime as an approximation of the rightMinimumTime here, + // since rightExpirationTime <= rightMinimumTime is always true. if (rightExpirationTime < rightQualifiedUpperBound) { // Upper bound of current join window has not passed the cache expiration time yet. // There might be qualifying rows in the cache that the current row needs to be joined with. @@ -206,7 +209,7 @@ abstract class TimeBoundedStreamInnerJoin( if (leftOperatorTime < leftQualifiedUpperBound) { // Operator time of left stream has not exceeded the upper window bound of the current // row. 
Put it into the right cache, since later coming records from the left stream are - // expected to join with it. + // expected to be joined with it. var rightRowList = rightCache.get(timeForRightRow) if (null == rightRowList) { rightRowList = new util.ArrayList[Row](1) @@ -219,6 +222,9 @@ abstract class TimeBoundedStreamInnerJoin( } } // Check if we need to join the current row against cached rows of the left input. + // The condition here should be leftMinimumTime < leftQualifiedUpperBound. + // I use leftExpirationTime as an approximation of the leftMinimumTime here, + // since leftExpirationTime <= leftMinimumTime is always true. if (leftExpirationTime < leftQualifiedUpperBound) { leftExpirationTime = calExpirationTime(rightOperatorTime, leftRelativeSize) // Join the rightRow with rows from the left cache. @@ -293,6 +299,7 @@ abstract class TimeBoundedStreamInnerJoin( if (operatorTime < Long.MaxValue) { operatorTime - relativeSize - allowedLateness - 1 } else { + // When operatorTime = Long.MaxValue, it means the stream has reached the end. Long.MaxValue } } From 7d37ac965bb7ace971e8f1709b11f5e8ba36eb7f Mon Sep 17 00:00:00 2001 From: Xingcan Cui Date: Tue, 10 Oct 2017 00:43:53 +0800 Subject: [PATCH 7/7] Change the results for remain condition test. 
--- .../flink/table/api/stream/sql/JoinTest.scala | 22 +++++++++---------- 1 file changed, 10 insertions(+), 12 deletions(-) diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/stream/sql/JoinTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/stream/sql/JoinTest.scala index be72d2e822ce5..81a8fabb81f22 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/stream/sql/JoinTest.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/stream/sql/JoinTest.scala @@ -32,7 +32,7 @@ class JoinTest extends TableTestBase { streamUtil.addTable[(Int, String, Long)]("MyTable2", 'a, 'b, 'c.rowtime, 'proctime.proctime) @Test - def testProcessingTimeInnerJoinWithOnClause() = { + def testProcessingTimeInnerJoinWithOnClause(): Unit = { val sqlQuery = """ @@ -70,7 +70,7 @@ class JoinTest extends TableTestBase { } @Test - def testRowTimeInnerJoinWithOnClause() = { + def testRowTimeInnerJoinWithOnClause(): Unit = { val sqlQuery = """ @@ -108,7 +108,7 @@ class JoinTest extends TableTestBase { } @Test - def testProcessingTimeInnerJoinWithWhereClause() = { + def testProcessingTimeInnerJoinWithWhereClause(): Unit = { val sqlQuery = """ @@ -146,7 +146,7 @@ class JoinTest extends TableTestBase { } @Test - def testRowTimeInnerJoinWithWhereClause() = { + def testRowTimeInnerJoinWithWhereClause(): Unit = { val sqlQuery = """ @@ -251,16 +251,17 @@ class JoinTest extends TableTestBase { "SELECT t1.a, t2.c FROM MyTable3 as t1 join MyTable4 as t2 on t1.a = t2.a and " + "t1.b >= t2.b - interval '10' second and t1.b <= t2.b - interval '5' second and " + "t1.c > t2.c" + // The equi-join predicate should also be included verifyRemainConditionConvert( query, - ">($2, $6)") + "AND(=($0, $4), >($2, $6))") val query1 = "SELECT t1.a, t2.c FROM MyTable3 as t1 join MyTable4 as t2 on t1.a = t2.a and " + "t1.b >= t2.b - interval '10' second and t1.b <= t2.b - interval '5' second " 
verifyRemainConditionConvert( query1, - "") + "=($0, $4)") streamUtil.addTable[(Int, Long, Int)]("MyTable5", 'a, 'b, 'c, 'proctime.proctime) streamUtil.addTable[(Int, Long, Int)]("MyTable6", 'a, 'b, 'c, 'proctime.proctime) @@ -271,7 +272,7 @@ class JoinTest extends TableTestBase { "t1.c > t2.c" verifyRemainConditionConvert( query2, - ">($2, $6)") + "AND(=($0, $4), >($2, $6))") } private def verifyTimeBoundary( @@ -285,10 +286,9 @@ class JoinTest extends TableTestBase { val resultTable = streamUtil.tableEnv.sql(query) val relNode = resultTable.getRelNode val joinNode = relNode.getInput(0).asInstanceOf[LogicalJoin] - val rexNode = joinNode.getCondition val (windowBounds, _) = WindowJoinUtil.extractWindowBoundsFromPredicate( - rexNode, + joinNode.getCondition, 4, joinNode.getRowType, joinNode.getCluster.getRexBuilder, @@ -309,11 +309,9 @@ class JoinTest extends TableTestBase { val resultTable = streamUtil.tableEnv.sql(query) val relNode = resultTable.getRelNode val joinNode = relNode.getInput(0).asInstanceOf[LogicalJoin] - val joinInfo = joinNode.analyzeCondition - val rexNode = joinInfo.getRemaining(joinNode.getCluster.getRexBuilder) val (_, remainCondition) = WindowJoinUtil.extractWindowBoundsFromPredicate( - rexNode, + joinNode.getCondition, 4, joinNode.getRowType, joinNode.getCluster.getRexBuilder,