From e5af8e12112533150321a292a261489b3523e23a Mon Sep 17 00:00:00 2001 From: Cheng Su Date: Fri, 16 Oct 2020 18:58:45 -0700 Subject: [PATCH 1/6] left semi stream-stream join --- .../UnsupportedOperationChecker.scala | 11 +- .../sql/catalyst/expressions/JoinedRow.scala | 10 + .../analysis/UnsupportedOperationsSuite.scala | 66 ++- .../StreamingSymmetricHashJoinExec.scala | 114 +++-- .../state/SymmetricHashJoinStateManager.scala | 11 +- .../sql/streaming/StreamingJoinSuite.scala | 399 ++++++++++++------ 6 files changed, 439 insertions(+), 172 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala index 46f178f3a9ce2..3df7b726d94d0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala @@ -291,17 +291,17 @@ object UnsupportedOperationChecker extends Logging { throwError("Full outer joins with streaming DataFrames/Datasets are not supported") } - case LeftSemi | LeftAnti => + case LeftAnti => if (right.isStreaming) { - throwError("Left semi/anti joins with a streaming DataFrame/Dataset " + + throwError("Left anti joins with a streaming DataFrame/Dataset " + "on the right are not supported") } // We support streaming left outer joins with static on the right always, and with // stream on both sides under the appropriate conditions. 
- case LeftOuter => + case LeftOuter | LeftSemi => if (!left.isStreaming && right.isStreaming) { - throwError("Left outer join with a streaming DataFrame/Dataset " + + throwError(s"$joinType join with a streaming DataFrame/Dataset " + "on the right and a static DataFrame/Dataset on the left is not supported") } else if (left.isStreaming && right.isStreaming) { val watermarkInJoinKeys = StreamingJoinHelper.isWatermarkInJoinKeys(subPlan) @@ -311,7 +311,8 @@ object UnsupportedOperationChecker extends Logging { left.outputSet, right.outputSet, condition, Some(1000000)).isDefined if (!watermarkInJoinKeys && !hasValidWatermarkRange) { - throwError("Stream-stream outer join between two streaming DataFrame/Datasets " + + throwError( + s"Stream-stream $joinType join between two streaming DataFrame/Datasets " + "is not supported without a watermark in the join keys, or a watermark on " + "the nullable side and an appropriate range condition") } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/JoinedRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/JoinedRow.scala index 7770684a5b399..86871223d66ad 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/JoinedRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/JoinedRow.scala @@ -55,6 +55,16 @@ class JoinedRow extends InternalRow { this } + /** Gets this JoinedRow's left base row. */ + def getLeft: InternalRow = { + row1 + } + + /** Gets this JoinedRow's right base row. 
*/ + def getRight: InternalRow = { + row2 + } + override def toSeq(fieldTypes: Seq[DataType]): Seq[Any] = { assert(fieldTypes.length == row1.numFields + row2.numFields) val (left, right) = fieldTypes.splitAt(row1.numFields) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala index 3ec6fdeedd4b8..f1b8436533d1f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala @@ -490,7 +490,69 @@ class UnsupportedOperationsSuite extends SparkFunSuite { _.join(_, joinType = LeftSemi), streamStreamSupported = false, batchStreamSupported = false, - expectedMsg = "left semi/anti joins") + expectedMsg = "LeftSemi join") + + // Left semi joins: update and complete mode not allowed + assertNotSupportedInStreamingPlan( + s"left semi join with stream-stream relations and update mode", + streamRelation.join(streamRelation, joinType = LeftSemi, + condition = Some(attribute === attribute)), + OutputMode.Update(), + Seq("is not supported in Update output mode")) + assertNotSupportedInStreamingPlan( + s"left semi join with stream-stream relations and complete mode", + Aggregate(Nil, aggExprs("d"), streamRelation.join(streamRelation, joinType = LeftSemi, + condition = Some(attribute === attribute))), + OutputMode.Complete(), + Seq("is not supported in Complete output mode")) + + // Left semi joins: stream-stream allowed with join on watermark attribute + // Note that the attribute need not be watermarked on both sides. 
+ assertSupportedInStreamingPlan( + s"left semi join with stream-stream relations and join on attribute with left watermark", + streamRelation.join(streamRelation, joinType = LeftSemi, + condition = Some(attributeWithWatermark === attribute)), + OutputMode.Append()) + assertSupportedInStreamingPlan( + s"left semi join with stream-stream relations and join on attribute with right watermark", + streamRelation.join(streamRelation, joinType = LeftSemi, + condition = Some(attribute === attributeWithWatermark)), + OutputMode.Append()) + assertNotSupportedInStreamingPlan( + s"left semi join with stream-stream relations and join on non-watermarked attribute", + streamRelation.join(streamRelation, joinType = LeftSemi, + condition = Some(attribute === attribute)), + OutputMode.Append(), + Seq("watermark in the join keys")) + + // Left semi joins: stream-stream allowed with range condition yielding state value watermark + assertSupportedInStreamingPlan( + s"left semi join with stream-stream relations and state value watermark", { + val leftRelation = streamRelation + val rightTimeWithWatermark = + AttributeReference("b", IntegerType)().withMetadata(watermarkMetadata) + val rightRelation = new TestStreamingRelation(rightTimeWithWatermark) + leftRelation.join( + rightRelation, + joinType = LeftSemi, + condition = Some(attribute > rightTimeWithWatermark + 10)) + }, + OutputMode.Append()) + + // Left semi joins: stream-stream not allowed with insufficient range condition + assertNotSupportedInStreamingPlan( + s"left semi join with stream-stream relations and state value watermark", { + val leftRelation = streamRelation + val rightTimeWithWatermark = + AttributeReference("b", IntegerType)().withMetadata(watermarkMetadata) + val rightRelation = new TestStreamingRelation(rightTimeWithWatermark) + leftRelation.join( + rightRelation, + joinType = LeftSemi, + condition = Some(attribute < rightTimeWithWatermark + 10)) + }, + OutputMode.Append(), + Seq("appropriate range condition")) // 
Left anti joins: stream-* not allowed testBinaryOperationInStreamingPlan( @@ -498,7 +560,7 @@ class UnsupportedOperationsSuite extends SparkFunSuite { _.join(_, joinType = LeftAnti), streamStreamSupported = false, batchStreamSupported = false, - expectedMsg = "left semi/anti joins") + expectedMsg = "Left anti join") // Right outer joins: stream-* not allowed testBinaryOperationInStreamingPlan( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala index a52f5f4ac94ae..73e78513a3ef0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala @@ -152,7 +152,8 @@ case class StreamingSymmetricHashJoinExec( } if (stateFormatVersion < 2 && joinType != Inner) { - throw new IllegalArgumentException("The query is using stream-stream outer join with state" + + throw new IllegalArgumentException( + s"The query is using stream-stream $joinType join with state" + s" format version ${stateFormatVersion} - correctness issue is discovered. Please discard" + " the checkpoint and rerun the query. 
See SPARK-26154 for more details.") } @@ -165,7 +166,7 @@ case class StreamingSymmetricHashJoinExec( } require( - joinType == Inner || joinType == LeftOuter || joinType == RightOuter, + joinType == Inner || joinType == LeftOuter || joinType == RightOuter || joinType == LeftSemi, errorMessageForJoinType) require(leftKeys.map(_.dataType) == rightKeys.map(_.dataType)) @@ -185,6 +186,7 @@ case class StreamingSymmetricHashJoinExec( case _: InnerLike => left.output ++ right.output case LeftOuter => left.output ++ right.output.map(_.withNullability(true)) case RightOuter => left.output.map(_.withNullability(true)) ++ right.output + case LeftSemi => left.output case _ => throwBadJoinTypeException() } @@ -193,6 +195,7 @@ case class StreamingSymmetricHashJoinExec( PartitioningCollection(Seq(left.outputPartitioning, right.outputPartitioning)) case LeftOuter => left.outputPartitioning case RightOuter => right.outputPartitioning + case LeftSemi => left.outputPartitioning case _ => throwBadJoinTypeException() } @@ -246,14 +249,21 @@ case class StreamingSymmetricHashJoinExec( // Join one side input using the other side's buffered/state rows. Here is how it is done. // - // - `leftSideJoiner.storeAndJoinWithOtherSide(rightSideJoiner)` generates all rows from - // matching new left input with stored right input, and also stores all the left input + // - `leftSideJoiner.storeAndJoinWithOtherSide(rightSideJoiner)` + // - inner, left outer, right outer join: generates all rows from matching new left input + // with stored right input, and also stores all the left input. + // - left semi join: generates all new left input rows from matching new left input with + // stored right input, and also stores all the non-matched left input. // - // - `rightSideJoiner.storeAndJoinWithOtherSide(leftSideJoiner)` generates all rows from - // matching new right input with stored left input, and also stores all the right input. 
- // It also generates all rows from matching new left input with new right input, since - // the new left input has become stored by that point. This tiny asymmetry is necessary - // to avoid duplication. + // - `rightSideJoiner.storeAndJoinWithOtherSide(leftSideJoiner)` + // - inner, left outer, right outer join: generates all rows from matching new right input + // with stored left input, and also stores all the right input. + // It also generates all rows from matching new left input with new right input, since + // the new left input has become stored by that point. This tiny asymmetry is necessary + // to avoid duplication. + // - left semi join: generates all stored left input rows, from matching new right input + // with stored left input, and also stores all the right input. Note only first-time + // matched left input rows will be generated, this is to guarantee left semi semantics. val leftOutputIter = leftSideJoiner.storeAndJoinWithOtherSide(rightSideJoiner) { (input: InternalRow, matched: InternalRow) => joinedRow.withLeft(input).withRight(matched) } @@ -261,22 +271,21 @@ case class StreamingSymmetricHashJoinExec( (input: InternalRow, matched: InternalRow) => joinedRow.withLeft(matched).withRight(input) } - // We need to save the time that the inner join output iterator completes, since outer join - // output counts as both update and removal time. - var innerOutputCompletionTimeNs: Long = 0 - def onInnerOutputCompletion = { - innerOutputCompletionTimeNs = System.nanoTime + // We need to save the time that the one side hash join output iterator completes, since + // other join (outer and semi) output counts as both update and removal time. + var hashJoinOutputCompletionTimeNs: Long = 0 + def onHashJoinOutputCompletion(): Unit = { + hashJoinOutputCompletionTimeNs = System.nanoTime } - // This is the iterator which produces the inner join rows. 
For outer joins, this will be - // prepended to a second iterator producing outer join rows; for inner joins, this is the full - // output. - val innerOutputIter = CompletionIterator[InternalRow, Iterator[InternalRow]]( - (leftOutputIter ++ rightOutputIter), onInnerOutputCompletion) - + // This is the iterator which produces the inner join rows. For other joins (outer and semi), + // this will be prepended to a second iterator producing other rows; for inner joins, this is + // the full output. + val hashJoinOutputIter = CompletionIterator[InternalRow, Iterator[InternalRow]]( + leftOutputIter ++ rightOutputIter, onHashJoinOutputCompletion()) val outputIter: Iterator[InternalRow] = joinType match { case Inner => - innerOutputIter + hashJoinOutputIter case LeftOuter => // We generate the outer join input by: // * Getting an iterator over the rows that have aged out on the left side. These rows are @@ -311,7 +320,7 @@ case class StreamingSymmetricHashJoinExec( } }.map(pair => joinedRow.withLeft(pair.value).withRight(nullRight)) - innerOutputIter ++ outerOutputIter + hashJoinOutputIter ++ outerOutputIter case RightOuter => // See comments for left outer case. 
def matchesWithLeftSideState(rightKeyValue: UnsafeRowPair) = { @@ -330,11 +339,17 @@ case class StreamingSymmetricHashJoinExec( } }.map(pair => joinedRow.withLeft(nullLeft).withRight(pair.value)) - innerOutputIter ++ outerOutputIter + hashJoinOutputIter ++ outerOutputIter + case LeftSemi => + hashJoinOutputIter case _ => throwBadJoinTypeException() } - val outputProjection = UnsafeProjection.create(left.output ++ right.output, output) + val outputProjection = if (joinType == LeftSemi) { + UnsafeProjection.create(output, output) + } else { + UnsafeProjection.create(left.output ++ right.output, output) + } val outputIterWithMetrics = outputIter.map { row => numOutputRows += 1 outputProjection(row) @@ -347,24 +362,29 @@ case class StreamingSymmetricHashJoinExec( // Processing time between inner output completion and here comes from the outer portion of a // join, and thus counts as removal time as we remove old state from one side while iterating. - if (innerOutputCompletionTimeNs != 0) { + if (hashJoinOutputCompletionTimeNs != 0) { allRemovalsTimeMs += - math.max(NANOSECONDS.toMillis(System.nanoTime - innerOutputCompletionTimeNs), 0) + math.max(NANOSECONDS.toMillis(System.nanoTime - hashJoinOutputCompletionTimeNs), 0) } allRemovalsTimeMs += timeTakenMs { // Remove any remaining state rows which aren't needed because they're below the watermark. // // For inner joins, we have to remove unnecessary state rows from both sides if possible. + // // For outer joins, we have already removed unnecessary state rows from the outer side // (e.g., left side for left outer join) while generating the outer "null" outputs. Now, we // have to remove unnecessary state rows from the other side (e.g., right side for the left // outer join) if possible. In all cases, nothing needs to be outputted, hence the removal // needs to be done greedily by immediately consuming the returned iterator. 
+ // + // For left semi joins, we have to remove unnecessary state rows from both sides if + // possible. val cleanupIter = joinType match { case Inner => leftSideJoiner.removeOldState() ++ rightSideJoiner.removeOldState() case LeftOuter => rightSideJoiner.removeOldState() case RightOuter => leftSideJoiner.removeOldState() + case LeftSemi => leftSideJoiner.removeOldState() ++ rightSideJoiner.removeOldState() case _ => throwBadJoinTypeException() } while (cleanupIter.hasNext) { @@ -481,6 +501,26 @@ case class StreamingSymmetricHashJoinExec( case _ => (_: InternalRow) => Iterator.empty } + val joinOnlyFirstTimeMatchedRow = joinType == LeftSemi && joinSide == RightSide + + val generateOutputIter: (InternalRow, Iterator[JoinedRow]) => Iterator[InternalRow] = + joinSide match { + case LeftSide if joinType == LeftSemi => + (input: InternalRow, joinedRowIter: Iterator[JoinedRow]) => + // For left side of left semi join, generate one left row if there is matched + // rows from right side. Otherwise, generate nothing. + if (joinedRowIter.nonEmpty) { + Iterator.single(input) + } else { + Iterator.empty + } + case RightSide if joinType == LeftSemi => + (_: InternalRow, joinedRowIter: Iterator[JoinedRow]) => + // For right side of left semi join, generate matched left rows only. + joinedRowIter.map(_.getLeft) + case _ => (_: InternalRow, joinedRowIter: Iterator[JoinedRow]) => joinedRowIter + } + nonLateRows.flatMap { row => val thisRow = row.asInstanceOf[UnsafeRow] // If this row fails the pre join filter, that means it can never satisfy the full join @@ -489,8 +529,12 @@ case class StreamingSymmetricHashJoinExec( // the case of inner join). 
if (preJoinFilter(thisRow)) { val key = keyGenerator(thisRow) - val outputIter: Iterator[JoinedRow] = otherSideJoiner.joinStateManager - .getJoinedRows(key, thatRow => generateJoinedRow(thisRow, thatRow), postJoinFilter) + val joinedRowIter: Iterator[JoinedRow] = otherSideJoiner.joinStateManager.getJoinedRows( + key, + thatRow => generateJoinedRow(thisRow, thatRow), + postJoinFilter, + joinOnlyFirstTimeMatchedRow) + val outputIter = generateOutputIter(thisRow, joinedRowIter) new AddingProcessedRowToStateCompletionIterator(key, thisRow, outputIter) } else { generateFilteredJoinedRow(thisRow) @@ -501,13 +545,19 @@ case class StreamingSymmetricHashJoinExec( private class AddingProcessedRowToStateCompletionIterator( key: UnsafeRow, thisRow: UnsafeRow, - subIter: Iterator[JoinedRow]) - extends CompletionIterator[JoinedRow, Iterator[JoinedRow]](subIter) { + subIter: Iterator[InternalRow]) + extends CompletionIterator[InternalRow, Iterator[InternalRow]](subIter) { + private val iteratorNotEmpty: Boolean = super.hasNext override def completion(): Unit = { - val shouldAddToState = // add only if both removal predicates do not match - !stateKeyWatermarkPredicateFunc(key) && !stateValueWatermarkPredicateFunc(thisRow) + val isLeftSemiWithMatch = + joinType == LeftSemi && joinSide == LeftSide && iteratorNotEmpty + // Add to state store only if both removal predicates do not match, + // and the row is not matched for left side of left semi join. 
+ val shouldAddToState = + !stateKeyWatermarkPredicateFunc(key) && !stateValueWatermarkPredicateFunc(thisRow) && + !isLeftSemiWithMatch if (shouldAddToState) { joinStateManager.append(key, thisRow, matched = iteratorNotEmpty) updatedStateRowsCount += 1 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala index 2aa2a18b9eaf4..c7b571262fb87 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala @@ -99,15 +99,22 @@ class SymmetricHashJoinStateManager( /** * Get all the matched values for given join condition, with marking matched. * This method is designed to mark joined rows properly without exposing internal index of row. + * + * @param joinOnlyFirstTimeMatchedRow Only join with first-time matched row. + * This is used for right side of left semi join in + * [[StreamingSymmetricHashJoinExec]] only. 
*/ def getJoinedRows( key: UnsafeRow, generateJoinedRow: InternalRow => JoinedRow, - predicate: JoinedRow => Boolean): Iterator[JoinedRow] = { + predicate: JoinedRow => Boolean, + joinOnlyFirstTimeMatchedRow: Boolean = false): Iterator[JoinedRow] = { val numValues = keyToNumValues.get(key) keyWithIndexToValue.getAll(key, numValues).map { keyIdxToValue => val joinedRow = generateJoinedRow(keyIdxToValue.value) - if (predicate(joinedRow)) { + if (joinOnlyFirstTimeMatchedRow && keyIdxToValue.matched) { + null + } else if (predicate(joinedRow)) { if (!keyIdxToValue.matched) { keyWithIndexToValue.put(key, keyIdxToValue.valueIndex, keyIdxToValue.value, matched = true) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala index b235bf7c3180a..5528cdc1dcabc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala @@ -41,18 +41,174 @@ import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ import org.apache.spark.util.Utils +abstract class StreamingJoinSuite + extends StreamTest with StateStoreMetricsTest with BeforeAndAfter { -class StreamingInnerJoinSuite extends StreamTest with StateStoreMetricsTest with BeforeAndAfter { + import testImplicits._ before { - SparkSession.setActiveSessionInternal(spark) // set this before force initializing 'joinExec' - spark.streams.stateStoreCoordinator // initialize the lazy coordinator + SparkSession.setActiveSessionInternal(spark) // set this before force initializing 'joinExec' + spark.streams.stateStoreCoordinator // initialize the lazy coordinator } after { StateStore.stop() } + protected def setupStream(prefix: String, multiplier: Int): (MemoryStream[Int], DataFrame) = { + val input = MemoryStream[Int] + val df = input.toDF + .select( + 'value as "key", + 
timestamp_seconds($"value") as s"${prefix}Time", + ('value * multiplier) as s"${prefix}Value") + .withWatermark(s"${prefix}Time", "10 seconds") + + (input, df) + } + + protected def setupWindowedJoin(joinType: String) + : (MemoryStream[Int], MemoryStream[Int], DataFrame) = { + + val (input1, df1) = setupStream("left", 2) + val (input2, df2) = setupStream("right", 3) + val windowed1 = df1.select('key, window('leftTime, "10 second"), 'leftValue) + val windowed2 = df2.select('key, window('rightTime, "10 second"), 'rightValue) + val joined = windowed1.join(windowed2, Seq("key", "window"), joinType) + val select = if (joinType == "left_semi") { + joined.select('key, $"window.end".cast("long"), 'leftValue) + } else { + joined.select('key, $"window.end".cast("long"), 'leftValue, 'rightValue) + } + + (input1, input2, select) + } + + protected def setupWindowedJoinWithLeftCondition(joinType: String) + : (MemoryStream[Int], MemoryStream[Int], DataFrame) = { + + val (leftInput, df1) = setupStream("left", 2) + val (rightInput, df2) = setupStream("right", 3) + // Use different schemas to ensure the null row is being generated from the correct side. 
+ val left = df1.select('key, window('leftTime, "10 second"), 'leftValue) + val right = df2.select('key, window('rightTime, "10 second"), 'rightValue.cast("string")) + + val joined = left.join( + right, + left("key") === right("key") + && left("window") === right("window") + && 'leftValue > 4, + joinType) + + val select = if (joinType == "left_semi") { + joined.select(left("key"), left("window.end").cast("long"), 'leftValue) + } else if (joinType == "left_outer") { + joined.select(left("key"), left("window.end").cast("long"), 'leftValue, 'rightValue) + } else if (joinType == "right_outer") { + joined.select(right("key"), right("window.end").cast("long"), 'leftValue, 'rightValue) + } else { + joined + } + + (leftInput, rightInput, select) + } + + protected def setupWindowedJoinWithRightCondition(joinType: String) + : (MemoryStream[Int], MemoryStream[Int], DataFrame) = { + + val (leftInput, df1) = setupStream("left", 2) + val (rightInput, df2) = setupStream("right", 3) + // Use different schemas to ensure the null row is being generated from the correct side. 
+ val left = df1.select('key, window('leftTime, "10 second"), 'leftValue) + val right = df2.select('key, window('rightTime, "10 second"), 'rightValue.cast("string")) + + val joined = left.join( + right, + left("key") === right("key") + && left("window") === right("window") + && 'rightValue.cast("int") > 7, + joinType) + + val select = if (joinType == "left_semi") { + joined.select(left("key"), left("window.end").cast("long"), 'leftValue) + } else if (joinType == "left_outer") { + joined.select(left("key"), left("window.end").cast("long"), 'leftValue, 'rightValue) + } else if (joinType == "right_outer") { + joined.select(right("key"), right("window.end").cast("long"), 'leftValue, 'rightValue) + } else { + joined + } + + (leftInput, rightInput, select) + } + + protected def setupWindowedJoinWithRangeCondition(joinType: String) + : (MemoryStream[(Int, Int)], MemoryStream[(Int, Int)], DataFrame) = { + + val leftInput = MemoryStream[(Int, Int)] + val rightInput = MemoryStream[(Int, Int)] + + val df1 = leftInput.toDF.toDF("leftKey", "time") + .select('leftKey, timestamp_seconds($"time") as "leftTime", ('leftKey * 2) as "leftValue") + .withWatermark("leftTime", "10 seconds") + + val df2 = rightInput.toDF.toDF("rightKey", "time") + .select('rightKey, timestamp_seconds($"time") as "rightTime", + ('rightKey * 3) as "rightValue") + .withWatermark("rightTime", "10 seconds") + + val joined = + df1.join( + df2, + expr("leftKey = rightKey AND " + + "leftTime BETWEEN rightTime - interval 5 seconds AND rightTime + interval 5 seconds"), + joinType) + + val select = if (joinType == "left_semi") { + joined.select('leftKey, 'leftTime.cast("int")) + } else { + joined.select('leftKey, 'rightKey, 'leftTime.cast("int"), 'rightTime.cast("int")) + } + + (leftInput, rightInput, select) + } + + protected def setupWindowedSelfJoin(joinType: String) + : (MemoryStream[(Int, Long)], DataFrame) = { + + val inputStream = MemoryStream[(Int, Long)] + + val df = inputStream.toDS() + 
.select(col("_1").as("value"), timestamp_seconds($"_2").as("timestamp")) + + val leftStream = df.select(col("value").as("leftId"), col("timestamp").as("leftTime")) + + val rightStream = df + // Introduce misses for ease of debugging + .where(col("value") % 2 === 0) + .select(col("value").as("rightId"), col("timestamp").as("rightTime")) + + val joined = leftStream + .withWatermark("leftTime", "5 seconds") + .join( + rightStream.withWatermark("rightTime", "5 seconds"), + expr("leftId = rightId AND rightTime >= leftTime AND " + + "rightTime <= leftTime + interval 5 seconds"), + joinType) + + val select = if (joinType == "left_semi") { + joined.select(col("leftId"), col("leftTime").cast("int")) + } else { + joined.select(col("leftId"), col("leftTime").cast("int"), + col("rightId"), col("rightTime").cast("int")) + } + + (inputStream, select) + } +} + +class StreamingInnerJoinSuite extends StreamingJoinSuite { + import testImplicits._ test("stream stream inner join on non-time column") { val input1 = MemoryStream[Int] @@ -486,58 +642,13 @@ class StreamingInnerJoinSuite extends StreamTest with StateStoreMetricsTest with } -class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with BeforeAndAfter { +class StreamingOuterJoinSuite extends StreamingJoinSuite { import testImplicits._ import org.apache.spark.sql.functions._ - before { - SparkSession.setActiveSessionInternal(spark) // set this before force initializing 'joinExec' - spark.streams.stateStoreCoordinator // initialize the lazy coordinator - } - - after { - StateStore.stop() - } - - private def setupStream(prefix: String, multiplier: Int): (MemoryStream[Int], DataFrame) = { - val input = MemoryStream[Int] - val df = input.toDF - .select( - 'value as "key", - timestamp_seconds($"value") as s"${prefix}Time", - ('value * multiplier) as s"${prefix}Value") - .withWatermark(s"${prefix}Time", "10 seconds") - - return (input, df) - } - - private def setupWindowedJoin(joinType: String): - 
(MemoryStream[Int], MemoryStream[Int], DataFrame) = { - val (input1, df1) = setupStream("left", 2) - val (input2, df2) = setupStream("right", 3) - val windowed1 = df1.select('key, window('leftTime, "10 second"), 'leftValue) - val windowed2 = df2.select('key, window('rightTime, "10 second"), 'rightValue) - val joined = windowed1.join(windowed2, Seq("key", "window"), joinType) - .select('key, $"window.end".cast("long"), 'leftValue, 'rightValue) - - (input1, input2, joined) - } - test("left outer early state exclusion on left") { - val (leftInput, df1) = setupStream("left", 2) - val (rightInput, df2) = setupStream("right", 3) - // Use different schemas to ensure the null row is being generated from the correct side. - val left = df1.select('key, window('leftTime, "10 second"), 'leftValue) - val right = df2.select('key, window('rightTime, "10 second"), 'rightValue.cast("string")) - - val joined = left.join( - right, - left("key") === right("key") - && left("window") === right("window") - && 'leftValue > 4, - "left_outer") - .select(left("key"), left("window.end").cast("long"), 'leftValue, 'rightValue) + val (leftInput, rightInput, joined) = setupWindowedJoinWithLeftCondition("left_outer") testStream(joined)( MultiAddData(leftInput, 1, 2, 3)(rightInput, 3, 4, 5), @@ -554,19 +665,7 @@ class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with } test("left outer early state exclusion on right") { - val (leftInput, df1) = setupStream("left", 2) - val (rightInput, df2) = setupStream("right", 3) - // Use different schemas to ensure the null row is being generated from the correct side. 
- val left = df1.select('key, window('leftTime, "10 second"), 'leftValue) - val right = df2.select('key, window('rightTime, "10 second"), 'rightValue.cast("string")) - - val joined = left.join( - right, - left("key") === right("key") - && left("window") === right("window") - && 'rightValue.cast("int") > 7, - "left_outer") - .select(left("key"), left("window.end").cast("long"), 'leftValue, 'rightValue) + val (leftInput, rightInput, joined) = setupWindowedJoinWithRightCondition("left_outer") testStream(joined)( MultiAddData(leftInput, 3, 4, 5)(rightInput, 1, 2, 3), @@ -583,19 +682,7 @@ class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with } test("right outer early state exclusion on left") { - val (leftInput, df1) = setupStream("left", 2) - val (rightInput, df2) = setupStream("right", 3) - // Use different schemas to ensure the null row is being generated from the correct side. - val left = df1.select('key, window('leftTime, "10 second"), 'leftValue) - val right = df2.select('key, window('rightTime, "10 second"), 'rightValue.cast("string")) - - val joined = left.join( - right, - left("key") === right("key") - && left("window") === right("window") - && 'leftValue > 4, - "right_outer") - .select(right("key"), right("window.end").cast("long"), 'leftValue, 'rightValue) + val (leftInput, rightInput, joined) = setupWindowedJoinWithLeftCondition("right_outer") testStream(joined)( MultiAddData(leftInput, 1, 2, 3)(rightInput, 3, 4, 5), @@ -612,19 +699,7 @@ class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with } test("right outer early state exclusion on right") { - val (leftInput, df1) = setupStream("left", 2) - val (rightInput, df2) = setupStream("right", 3) - // Use different schemas to ensure the null row is being generated from the correct side. 
- val left = df1.select('key, window('leftTime, "10 second"), 'leftValue) - val right = df2.select('key, window('rightTime, "10 second"), 'rightValue.cast("string")) - - val joined = left.join( - right, - left("key") === right("key") - && left("window") === right("window") - && 'rightValue.cast("int") > 7, - "right_outer") - .select(right("key"), right("window.end").cast("long"), 'leftValue, 'rightValue) + val (leftInput, rightInput, joined) = setupWindowedJoinWithRightCondition("right_outer") testStream(joined)( MultiAddData(leftInput, 3, 4, 5)(rightInput, 1, 2, 3), @@ -681,27 +756,8 @@ class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with ("right_outer", Row(null, 2, null, 5)) ).foreach { case (joinType: String, outerResult) => test(s"${joinType.replaceAllLiterally("_", " ")} with watermark range condition") { - import org.apache.spark.sql.functions._ - - val leftInput = MemoryStream[(Int, Int)] - val rightInput = MemoryStream[(Int, Int)] - - val df1 = leftInput.toDF.toDF("leftKey", "time") - .select('leftKey, timestamp_seconds($"time") as "leftTime", ('leftKey * 2) as "leftValue") - .withWatermark("leftTime", "10 seconds") - - val df2 = rightInput.toDF.toDF("rightKey", "time") - .select('rightKey, timestamp_seconds($"time") as "rightTime", - ('rightKey * 3) as "rightValue") - .withWatermark("rightTime", "10 seconds") - - val joined = - df1.join( - df2, - expr("leftKey = rightKey AND " + - "leftTime BETWEEN rightTime - interval 5 seconds AND rightTime + interval 5 seconds"), - joinType) - .select('leftKey, 'rightKey, 'leftTime.cast("int"), 'rightTime.cast("int")) + val (leftInput, rightInput, joined) = setupWindowedJoinWithRangeCondition(joinType) + testStream(joined)( AddData(leftInput, (1, 5), (3, 5)), CheckAnswer(), @@ -780,27 +836,7 @@ class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with } test("SPARK-26187 self left outer join should not return outer nulls for already matched rows") { - val inputStream = 
MemoryStream[(Int, Long)] - - val df = inputStream.toDS() - .select(col("_1").as("value"), timestamp_seconds($"_2").as("timestamp")) - - val leftStream = df.select(col("value").as("leftId"), col("timestamp").as("leftTime")) - - val rightStream = df - // Introduce misses for ease of debugging - .where(col("value") % 2 === 0) - .select(col("value").as("rightId"), col("timestamp").as("rightTime")) - - val query = leftStream - .withWatermark("leftTime", "5 seconds") - .join( - rightStream.withWatermark("rightTime", "5 seconds"), - expr("leftId = rightId AND rightTime >= leftTime AND " + - "rightTime <= leftTime + interval 5 seconds"), - joinType = "leftOuter") - .select(col("leftId"), col("leftTime").cast("int"), - col("rightId"), col("rightTime").cast("int")) + val (inputStream, query) = setupWindowedSelfJoin("left_outer") testStream(query)( AddData(inputStream, (1, 1L), (2, 2L), (3, 3L), (4, 4L), (5, 5L)), @@ -1041,3 +1077,104 @@ class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with ) } } + +class StreamingLeftSemiJoinSuite extends StreamingJoinSuite { + + import testImplicits._ + + test("windowed left semi join") { + val (leftInput, rightInput, joined) = setupWindowedJoin("left_semi") + + testStream(joined)( + MultiAddData(leftInput, 1, 2, 3, 4, 5)(rightInput, 3, 4, 5, 6, 7), + CheckNewAnswer(Row(3, 10, 6), Row(4, 10, 8), Row(5, 10, 10)), + MultiAddData(leftInput, 21)(rightInput, 22), + // Watermark = 11, should remove rows having window=[0,10] + CheckNewAnswer(), + assertNumStateRows(total = 2, updated = 12), + AddData(leftInput, 22), + CheckNewAnswer(Row(22, 30, 44)), + assertNumStateRows(total = 2, updated = 0), + StopStream, + StartStream(), + + AddData(leftInput, 1), + // Row not add as 1 < state key watermark = 12 + CheckNewAnswer(), + AddData(rightInput, 11), + // Row not add as 11 < state key watermark = 12 + CheckNewAnswer() + ) + } + + test("left semi early state exclusion on left") { + val (leftInput, rightInput, joined) = 
setupWindowedJoinWithLeftCondition("left_semi") + + testStream(joined)( + MultiAddData(leftInput, 1, 2, 3)(rightInput, 1, 2, 3, 4, 5), + // The left rows with leftValue <= 4 should not generate their semi join row and + // not get added to the state. + CheckNewAnswer(Row(3, 10, 6)), + assertNumStateRows(total = 4, updated = 4), + // We shouldn't get more semi join rows when the watermark advances. + MultiAddData(leftInput, 20)(rightInput, 21), + CheckNewAnswer(), + AddData(rightInput, 20), + CheckNewAnswer((20, 30, 40)) + ) + } + + test("left semi early state exclusion on right") { + val (leftInput, rightInput, joined) = setupWindowedJoinWithRightCondition("left_semi") + + testStream(joined)( + MultiAddData(leftInput, 3, 4, 5)(rightInput, 1, 2, 3), + // The right rows with rightValue <= 7 should never be added to the state. + // The right row with rightValue = 9 > 7, hence joined and added to state. + CheckNewAnswer(Row(3, 10, 6)), + assertNumStateRows(total = 4, updated = 4), + // We shouldn't get more semi join rows when the watermark advances. + MultiAddData(leftInput, 20)(rightInput, 21), + CheckNewAnswer(), + AddData(rightInput, 20), + CheckNewAnswer((20, 30, 40)) + ) + } + + test("left semi join with watermark range condition") { + val (leftInput, rightInput, joined) = setupWindowedJoinWithRangeCondition("left_semi") + + testStream(joined)( + AddData(leftInput, (1, 5), (3, 5)), + CheckAnswer(), + AddData(rightInput, (1, 10), (2, 5)), + CheckNewAnswer((1, 5)), + AddData(rightInput, (1, 11)), + // No match as left time is too low and left row is already matched. + CheckNewAnswer(), + assertNumStateRows(total = 5, updated = 5), + + // Increase event time watermark to 20s by adding data with time = 30s on left input. + AddData(leftInput, (1, 7), (1, 30)), + CheckNewAnswer((1, 7)), + assertNumStateRows(total = 6, updated = 1), + // Watermark = 30 - 10 = 20, no matched row. 
+ AddData(rightInput, (0, 30)), + CheckNewAnswer(), + assertNumStateRows(total = 2, updated = 1) + ) + } + + test("self left semi join") { + val (inputStream, query) = setupWindowedSelfJoin("left_semi") + + testStream(query)( + AddData(inputStream, (1, 1L), (2, 2L), (3, 3L), (4, 4L), (5, 5L)), + CheckNewAnswer((2, 2), (4, 4)), + AddData(inputStream, (6, 6L), (7, 7L), (8, 8L), (9, 9L), (10, 10L)), + CheckNewAnswer((6, 6), (8, 8), (10, 10)), + AddData(inputStream, (11, 11L), (12, 12L), (13, 13L), (14, 14L), (15, 15L)), + CheckNewAnswer((12, 12), (14, 14)) + ) + } +} From ee166901bbaa3500a9fa8704f2a5598a8ce4b6d7 Mon Sep 17 00:00:00 2001 From: Cheng Su Date: Fri, 16 Oct 2020 20:56:03 -0700 Subject: [PATCH 2/6] Fix unit test --- .../execution/streaming/StreamingSymmetricHashJoinExec.scala | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala index 73e78513a3ef0..7a1febaf75563 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala @@ -151,9 +151,8 @@ case class StreamingSymmetricHashJoinExec( stateWatermarkPredicates = JoinStateWatermarkPredicates(), stateFormatVersion, left, right) } - if (stateFormatVersion < 2 && joinType != Inner) { - throw new IllegalArgumentException( - s"The query is using stream-stream $joinType join with state" + + if (stateFormatVersion < 2 && (joinType == LeftOuter || joinType == RightOuter)) { + throw new IllegalArgumentException("The query is using stream-stream outer join with state" + s" format version ${stateFormatVersion} - correctness issue is discovered. Please discard" + " the checkpoint and rerun the query. 
See SPARK-26154 for more details.") } From 3918727a08c8d0d4c65ccc8ea902f77051b78b1d Mon Sep 17 00:00:00 2001 From: Cheng Su Date: Mon, 19 Oct 2020 11:14:10 -0700 Subject: [PATCH 3/6] Address all comments --- .../analysis/UnsupportedOperationsSuite.scala | 2 +- .../StreamingSymmetricHashJoinExec.scala | 16 ++++++++-------- .../state/SymmetricHashJoinStateManager.scala | 8 ++++---- .../spark/sql/streaming/StreamingJoinSuite.scala | 2 +- 4 files changed, 14 insertions(+), 14 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala index f1b8436533d1f..e692c58f40add 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala @@ -506,7 +506,7 @@ class UnsupportedOperationsSuite extends SparkFunSuite { OutputMode.Complete(), Seq("is not supported in Complete output mode")) - // Left ousemiter joins: stream-stream allowed with join on watermark attribute + // Left semi joins: stream-stream allowed with join on watermark attribute // Note that the attribute need not be watermarked on both sides. 
assertSupportedInStreamingPlan( s"left semi join with stream-stream relations and join on attribute with left watermark", diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala index 7a1febaf75563..ad97a6c9731ea 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala @@ -151,8 +151,9 @@ case class StreamingSymmetricHashJoinExec( stateWatermarkPredicates = JoinStateWatermarkPredicates(), stateFormatVersion, left, right) } - if (stateFormatVersion < 2 && (joinType == LeftOuter || joinType == RightOuter)) { - throw new IllegalArgumentException("The query is using stream-stream outer join with state" + + if (stateFormatVersion < 2 && joinType != Inner) { + throw new IllegalArgumentException( + s"The query is using stream-stream $joinType join with state" + s" format version ${stateFormatVersion} - correctness issue is discovered. Please discard" + " the checkpoint and rerun the query. 
See SPARK-26154 for more details.") } @@ -283,7 +284,7 @@ case class StreamingSymmetricHashJoinExec( leftOutputIter ++ rightOutputIter, onHashJoinOutputCompletion()) val outputIter: Iterator[InternalRow] = joinType match { - case Inner => + case Inner | LeftSemi => hashJoinOutputIter case LeftOuter => // We generate the outer join input by: @@ -339,8 +340,6 @@ case class StreamingSymmetricHashJoinExec( }.map(pair => joinedRow.withLeft(nullLeft).withRight(pair.value)) hashJoinOutputIter ++ outerOutputIter - case LeftSemi => - hashJoinOutputIter case _ => throwBadJoinTypeException() } @@ -369,7 +368,8 @@ case class StreamingSymmetricHashJoinExec( allRemovalsTimeMs += timeTakenMs { // Remove any remaining state rows which aren't needed because they're below the watermark. // - // For inner joins, we have to remove unnecessary state rows from both sides if possible. + // For inner and left semi joins, we have to remove unnecessary state rows from both sides + // if possible. // // For outer joins, we have already removed unnecessary state rows from the outer side // (e.g., left side for left outer join) while generating the outer "null" outputs. Now, we @@ -380,10 +380,10 @@ case class StreamingSymmetricHashJoinExec( // For left semi joins, we have to remove unnecessary state rows from both sides if // possible. 
val cleanupIter = joinType match { - case Inner => leftSideJoiner.removeOldState() ++ rightSideJoiner.removeOldState() + case Inner | LeftSemi => + leftSideJoiner.removeOldState() ++ rightSideJoiner.removeOldState() case LeftOuter => rightSideJoiner.removeOldState() case RightOuter => leftSideJoiner.removeOldState() - case LeftSemi => leftSideJoiner.removeOldState() ++ rightSideJoiner.removeOldState() case _ => throwBadJoinTypeException() } while (cleanupIter.hasNext) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala index c7b571262fb87..2b231b2726981 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala @@ -110,11 +110,11 @@ class SymmetricHashJoinStateManager( predicate: JoinedRow => Boolean, joinOnlyFirstTimeMatchedRow: Boolean = false): Iterator[JoinedRow] = { val numValues = keyToNumValues.get(key) - keyWithIndexToValue.getAll(key, numValues).map { keyIdxToValue => + keyWithIndexToValue.getAll(key, numValues).filterNot { keyIdxToValue => + joinOnlyFirstTimeMatchedRow && keyIdxToValue.matched + }.map { keyIdxToValue => val joinedRow = generateJoinedRow(keyIdxToValue.value) - if (joinOnlyFirstTimeMatchedRow && keyIdxToValue.matched) { - null - } else if (predicate(joinedRow)) { + if (predicate(joinedRow)) { if (!keyIdxToValue.matched) { keyWithIndexToValue.put(key, keyIdxToValue.valueIndex, keyIdxToValue.value, matched = true) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala index 5528cdc1dcabc..1e3284b2c3d88 100644 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala @@ -974,7 +974,7 @@ class StreamingOuterJoinSuite extends StreamingJoinSuite { throw writer.exception.get } assert(e.getMessage.toLowerCase(Locale.ROOT) - .contains("the query is using stream-stream outer join with state format version 1")) + .contains("the query is using stream-stream leftouter join with state format version 1")) } test("SPARK-29438: ensure UNION doesn't lead stream-stream join to use shifted partition IDs") { From 765a233237e309f112058e73cef4f4c516f39a8d Mon Sep 17 00:00:00 2001 From: Cheng Su Date: Tue, 20 Oct 2020 22:50:41 -0700 Subject: [PATCH 4/6] Address all comments --- .../analysis/UnsupportedOperationsSuite.scala | 14 +- .../StreamingSymmetricHashJoinExec.scala | 28 ++-- .../state/SymmetricHashJoinStateManager.scala | 10 +- .../sql/streaming/StreamingJoinSuite.scala | 130 ++++++++++++++++-- 4 files changed, 140 insertions(+), 42 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala index e692c58f40add..77c42d89dca3a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala @@ -494,13 +494,13 @@ class UnsupportedOperationsSuite extends SparkFunSuite { // Left semi joins: update and complete mode not allowed assertNotSupportedInStreamingPlan( - s"left semi join with stream-stream relations and update mode", + "left semi join with stream-stream relations and update mode", streamRelation.join(streamRelation, joinType = LeftSemi, condition = Some(attribute === attribute)), OutputMode.Update(), Seq("is not supported in Update output mode")) 
assertNotSupportedInStreamingPlan( - s"left semi join with stream-stream relations and complete mode", + "left semi join with stream-stream relations and complete mode", Aggregate(Nil, aggExprs("d"), streamRelation.join(streamRelation, joinType = LeftSemi, condition = Some(attribute === attribute))), OutputMode.Complete(), @@ -509,17 +509,17 @@ class UnsupportedOperationsSuite extends SparkFunSuite { // Left semi joins: stream-stream allowed with join on watermark attribute // Note that the attribute need not be watermarked on both sides. assertSupportedInStreamingPlan( - s"left semi join with stream-stream relations and join on attribute with left watermark", + "left semi join with stream-stream relations and join on attribute with left watermark", streamRelation.join(streamRelation, joinType = LeftSemi, condition = Some(attributeWithWatermark === attribute)), OutputMode.Append()) assertSupportedInStreamingPlan( - s"left semi join with stream-stream relations and join on attribute with right watermark", + "left semi join with stream-stream relations and join on attribute with right watermark", streamRelation.join(streamRelation, joinType = LeftSemi, condition = Some(attribute === attributeWithWatermark)), OutputMode.Append()) assertNotSupportedInStreamingPlan( - s"left semi join with stream-stream relations and join on non-watermarked attribute", + "left semi join with stream-stream relations and join on non-watermarked attribute", streamRelation.join(streamRelation, joinType = LeftSemi, condition = Some(attribute === attribute)), OutputMode.Append(), @@ -527,7 +527,7 @@ class UnsupportedOperationsSuite extends SparkFunSuite { // Left semi joins: stream-stream allowed with range condition yielding state value watermark assertSupportedInStreamingPlan( - s"left semi join with stream-stream relations and state value watermark", { + "left semi join with stream-stream relations and state value watermark", { val leftRelation = streamRelation val rightTimeWithWatermark = 
AttributeReference("b", IntegerType)().withMetadata(watermarkMetadata) @@ -541,7 +541,7 @@ class UnsupportedOperationsSuite extends SparkFunSuite { // Left semi joins: stream-stream not allowed with insufficient range condition assertNotSupportedInStreamingPlan( - s"left semi join with stream-stream relations and state value watermark", { + "left semi join with stream-stream relations and state value watermark", { val leftRelation = streamRelation val rightTimeWithWatermark = AttributeReference("b", IntegerType)().withMetadata(watermarkMetadata) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala index ad97a6c9731ea..8b69205530769 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala @@ -250,18 +250,18 @@ case class StreamingSymmetricHashJoinExec( // Join one side input using the other side's buffered/state rows. Here is how it is done. // // - `leftSideJoiner.storeAndJoinWithOtherSide(rightSideJoiner)` - // - inner, left outer, right outer join: generates all rows from matching new left input + // - Inner, Left Outer, Right Outer Join: generates all rows from matching new left input // with stored right input, and also stores all the left input. - // - left semi join: generates all new left input rows from matching new left input with + // - Left Semi Join: generates all new left input rows from matching new left input with // stored right input, and also stores all the non-matched left input. 
// // - `rightSideJoiner.storeAndJoinWithOtherSide(leftSideJoiner)` - // - inner, left outer, right outer join: generates all rows from matching new right input + // - Inner, Left Outer, Right Outer Join: generates all rows from matching new right input // with stored left input, and also stores all the right input. // It also generates all rows from matching new left input with new right input, since // the new left input has become stored by that point. This tiny asymmetry is necessary // to avoid duplication. - // - left semi join: generates all stored left input rows, from matching new right input + // - Left Semi Join: generates all stored left input rows, from matching new right input // with stored left input, and also stores all the right input. Note only first-time // matched left input rows will be generated, this is to guarantee left semi semantics. val leftOutputIter = leftSideJoiner.storeAndJoinWithOtherSide(rightSideJoiner) { @@ -272,14 +272,14 @@ case class StreamingSymmetricHashJoinExec( } // We need to save the time that the one side hash join output iterator completes, since - // other join (outer and semi) output counts as both update and removal time. + // other join output counts as both update and removal time. var hashJoinOutputCompletionTimeNs: Long = 0 def onHashJoinOutputCompletion(): Unit = { hashJoinOutputCompletionTimeNs = System.nanoTime } - // This is the iterator which produces the inner join rows. For other joins (outer and semi), - // this will be prepended to a second iterator producing other rows; for inner joins, this is - // the full output. + // This is the iterator which produces the inner and left semi join rows. For other joins, + // this will be prepended to a second iterator producing other rows; for inner and left semi + // joins, this is the full output. 
val hashJoinOutputIter = CompletionIterator[InternalRow, Iterator[InternalRow]]( leftOutputIter ++ rightOutputIter, onHashJoinOutputCompletion()) @@ -358,8 +358,9 @@ case class StreamingSymmetricHashJoinExec( // All processing time counts as update time. allUpdatesTimeMs += math.max(NANOSECONDS.toMillis(System.nanoTime - updateStartTimeNs), 0) - // Processing time between inner output completion and here comes from the outer portion of a - // join, and thus counts as removal time as we remove old state from one side while iterating. + // Processing time between one side hash join output completion and here comes from the + // outer portion of a join, and thus counts as removal time as we remove old state from + // one side while iterating. if (hashJoinOutputCompletionTimeNs != 0) { allRemovalsTimeMs += math.max(NANOSECONDS.toMillis(System.nanoTime - hashJoinOutputCompletionTimeNs), 0) @@ -376,9 +377,6 @@ case class StreamingSymmetricHashJoinExec( // have to remove unnecessary state rows from the other side (e.g., right side for the left // outer join) if possible. In all cases, nothing needs to be outputted, hence the removal // needs to be done greedily by immediately consuming the returned iterator. - // - // For left semi joins, we have to remove unnecessary state rows from both sides if - // possible. 
val cleanupIter = joinType match { case Inner | LeftSemi => leftSideJoiner.removeOldState() ++ rightSideJoiner.removeOldState() @@ -500,7 +498,7 @@ case class StreamingSymmetricHashJoinExec( case _ => (_: InternalRow) => Iterator.empty } - val joinOnlyFirstTimeMatchedRow = joinType == LeftSemi && joinSide == RightSide + val excludeRowsAlreadyMatched = joinType == LeftSemi && joinSide == RightSide val generateOutputIter: (InternalRow, Iterator[JoinedRow]) => Iterator[InternalRow] = joinSide match { @@ -532,7 +530,7 @@ case class StreamingSymmetricHashJoinExec( key, thatRow => generateJoinedRow(thisRow, thatRow), postJoinFilter, - joinOnlyFirstTimeMatchedRow) + excludeRowsAlreadyMatched) val outputIter = generateOutputIter(thisRow, joinedRowIter) new AddingProcessedRowToStateCompletionIterator(key, thisRow, outputIter) } else { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala index 2b231b2726981..3fae3979757fe 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala @@ -100,18 +100,18 @@ class SymmetricHashJoinStateManager( * Get all the matched values for given join condition, with marking matched. * This method is designed to mark joined rows properly without exposing internal index of row. * - * @param joinOnlyFirstTimeMatchedRow Only join with first-time matched row. - * This is used for right side of left semi join in - * [[StreamingSymmetricHashJoinExec]] only. + * @param excludeRowsAlreadyMatched Do not join with rows already matched previously. + * This is used for right side of left semi join in + * [[StreamingSymmetricHashJoinExec]] only. 
*/ def getJoinedRows( key: UnsafeRow, generateJoinedRow: InternalRow => JoinedRow, predicate: JoinedRow => Boolean, - joinOnlyFirstTimeMatchedRow: Boolean = false): Iterator[JoinedRow] = { + excludeRowsAlreadyMatched: Boolean = false): Iterator[JoinedRow] = { val numValues = keyToNumValues.get(key) keyWithIndexToValue.getAll(key, numValues).filterNot { keyIdxToValue => - joinOnlyFirstTimeMatchedRow && keyIdxToValue.matched + excludeRowsAlreadyMatched && keyIdxToValue.matched }.map { keyIdxToValue => val joinedRow = generateJoinedRow(keyIdxToValue.value) if (predicate(joinedRow)) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala index 1e3284b2c3d88..231bc02273103 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala @@ -1088,22 +1088,47 @@ class StreamingLeftSemiJoinSuite extends StreamingJoinSuite { testStream(joined)( MultiAddData(leftInput, 1, 2, 3, 4, 5)(rightInput, 3, 4, 5, 6, 7), CheckNewAnswer(Row(3, 10, 6), Row(4, 10, 8), Row(5, 10, 10)), + // states + // left: 1, 2, 3, 4 ,5 + // right: 3, 4, 5, 6, 7 + assertNumStateRows(total = 10, updated = 10), MultiAddData(leftInput, 21)(rightInput, 22), - // Watermark = 11, should remove rows having window=[0,10] + // Watermark = 11, should remove rows having window=[0,10]. CheckNewAnswer(), - assertNumStateRows(total = 2, updated = 12), + // states + // left: 21 + // right: 22 + // + // states evicted + // left: 1, 2, 3, 4 ,5 (below watermark) + // right: 3, 4, 5, 6, 7 (below watermark) + assertNumStateRows(total = 2, updated = 2), AddData(leftInput, 22), CheckNewAnswer(Row(22, 30, 44)), + // Unlike inner/outer joins, given left input row matches with right input row, + // we don't buffer the matched left input row to the state store. 
+ // + // states + // left: 21 + // right: 22 assertNumStateRows(total = 2, updated = 0), StopStream, StartStream(), AddData(leftInput, 1), - // Row not add as 1 < state key watermark = 12 + // Row not add as 1 < state key watermark = 12. CheckNewAnswer(), - AddData(rightInput, 11), - // Row not add as 11 < state key watermark = 12 - CheckNewAnswer() + // states + // left: 21 + // right: 22 + assertNumStateRows(total = 2, updated = 0, droppedByWatermark = 1), + AddData(rightInput, 5), + // Row not add as 5 < state key watermark = 12. + CheckNewAnswer(), + // states + // left: 21 + // right: 22 + assertNumStateRows(total = 2, updated = 0, droppedByWatermark = 1) ) } @@ -1111,16 +1136,31 @@ class StreamingLeftSemiJoinSuite extends StreamingJoinSuite { val (leftInput, rightInput, joined) = setupWindowedJoinWithLeftCondition("left_semi") testStream(joined)( - MultiAddData(leftInput, 1, 2, 3)(rightInput, 1, 2, 3, 4, 5), - // The left rows with leftValue <= 4 should not generate their semi join row and + MultiAddData(leftInput, 1, 2, 3)(rightInput, 3, 4, 5), + // The left rows with leftValue <= 4 should not generate their semi join rows and // not get added to the state. CheckNewAnswer(Row(3, 10, 6)), + // states + // left: 3 + // right: 3, 4, 5 assertNumStateRows(total = 4, updated = 4), // We shouldn't get more semi join rows when the watermark advances. MultiAddData(leftInput, 20)(rightInput, 21), CheckNewAnswer(), + // states + // left: 20 + // right: 21 + // + // states evicted + // left: 3 (below watermark) + // right: 3, 4, 5 (below watermark) + assertNumStateRows(total = 2, updated = 2), AddData(rightInput, 20), - CheckNewAnswer((20, 30, 40)) + CheckNewAnswer((20, 30, 40)), + // states + // left: 20 + // right: 21, 20 + assertNumStateRows(total = 3, updated = 1) ) } @@ -1132,12 +1172,27 @@ class StreamingLeftSemiJoinSuite extends StreamingJoinSuite { // The right rows with rightValue <= 7 should never be added to the state. 
// The right row with rightValue = 9 > 7, hence joined and added to state. CheckNewAnswer(Row(3, 10, 6)), + // states + // left: 3, 4, 5 + // right: 3 assertNumStateRows(total = 4, updated = 4), // We shouldn't get more semi join rows when the watermark advances. MultiAddData(leftInput, 20)(rightInput, 21), CheckNewAnswer(), + // states + // left: 20 + // right: 21 + // + // states evicted + // left: 3, 4, 5 (below watermark) + // right: 3 (below watermark) + assertNumStateRows(total = 2, updated = 2), AddData(rightInput, 20), - CheckNewAnswer((20, 30, 40)) + CheckNewAnswer((20, 30, 40)), + // states + // left: 20 + // right: 21, 20 + assertNumStateRows(total = 3, updated = 1) ) } @@ -1146,21 +1201,42 @@ class StreamingLeftSemiJoinSuite extends StreamingJoinSuite { testStream(joined)( AddData(leftInput, (1, 5), (3, 5)), - CheckAnswer(), + CheckNewAnswer(), + // states + // left: (1, 5), (3, 5) + // right: nothing + assertNumStateRows(total = 2, updated = 2), AddData(rightInput, (1, 10), (2, 5)), + // Match left row in the state. CheckNewAnswer((1, 5)), + // states + // left: (1, 5), (3, 5) + // right: (1, 10), (2, 5) + assertNumStateRows(total = 4, updated = 2), AddData(rightInput, (1, 11)), // No match as left time is too low and left row is already matched. CheckNewAnswer(), - assertNumStateRows(total = 5, updated = 5), - - // Increase event time watermark to 20s by adding data with time = 30s on left input. + // states + // left: (1, 5), (3, 5) + // right: (1, 10), (2, 5), (1, 11) + assertNumStateRows(total = 5, updated = 1), + // Increase event time watermark to 20s by adding data with time = 30s on both inputs. AddData(leftInput, (1, 7), (1, 30)), CheckNewAnswer((1, 7)), + // states + // left: (1, 5), (3, 5), (1, 30) + // right: (1, 10), (2, 5), (1, 11) assertNumStateRows(total = 6, updated = 1), // Watermark = 30 - 10 = 20, no matched row. 
AddData(rightInput, (0, 30)), CheckNewAnswer(), + // states + // left: (1, 30) + // right: (0, 30) + // + // states evicted + // left: (1, 5), (3, 5) (below watermark = 20) + // right: (1, 10), (2, 5), (1, 11) (below watermark = 20) assertNumStateRows(total = 2, updated = 1) ) } @@ -1171,10 +1247,34 @@ class StreamingLeftSemiJoinSuite extends StreamingJoinSuite { testStream(query)( AddData(inputStream, (1, 1L), (2, 2L), (3, 3L), (4, 4L), (5, 5L)), CheckNewAnswer((2, 2), (4, 4)), + // batch 1 - global watermark = 0 + // states + // left: (2, 2L), (4, 4L) + // (left rows with value % 2 != 0 is filtered per [[PushDownLeftSemiAntiJoin]]) + // right: (2, 2L), (4, 4L) + assertNumStateRows(total = 4, updated = 4), AddData(inputStream, (6, 6L), (7, 7L), (8, 8L), (9, 9L), (10, 10L)), CheckNewAnswer((6, 6), (8, 8), (10, 10)), + // batch 2 - global watermark = 5 + // states + // left: (2, 2L), (4, 4L), (6, 6L), (8, 8L), (10, 10L) + // right: (6, 6L), (8, 8L), (10, 10L) + // + // states evicted + // left: nothing (it waits for 5 seconds more than watermark due to join condition) + // right: (2, 2L), (4, 4L) + assertNumStateRows(total = 8, updated = 6), AddData(inputStream, (11, 11L), (12, 12L), (13, 13L), (14, 14L), (15, 15L)), - CheckNewAnswer((12, 12), (14, 14)) + CheckNewAnswer((12, 12), (14, 14)), + // batch 3 - global watermark = 9 + // states + // left: (4, 4L), (6, 6L), (8, 8L), (10, 10L), (12, 12L), (14, 14L) + // right: (10, 10L), (12, 12L), (14, 14L) + // + // states evicted + // left: (2, 2L) + // right: (6, 6L), (8, 8L) + assertNumStateRows(total = 9, updated = 4) ) } } From 9cd222f833cc47cd7880028f42fd13d8396e2414 Mon Sep 17 00:00:00 2001 From: Cheng Su Date: Wed, 21 Oct 2020 13:18:34 -0700 Subject: [PATCH 5/6] Address all comments --- .../analysis/UnsupportedOperationChecker.scala | 4 ++-- .../analysis/UnsupportedOperationsSuite.scala | 2 +- .../spark/sql/streaming/StreamingJoinSuite.scala | 11 ++++++----- 3 files changed, 9 insertions(+), 8 deletions(-) diff 
--git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala index 3df7b726d94d0..f8ec392748175 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala @@ -297,8 +297,8 @@ object UnsupportedOperationChecker extends Logging { "on the right are not supported") } - // We support streaming left outer joins with static on the right always, and with - // stream on both sides under the appropriate conditions. + // We support streaming left outer and left semi joins with static on the right always, + // and with stream on both sides under the appropriate conditions. case LeftOuter | LeftSemi => if (!left.isStreaming && right.isStreaming) { throwError(s"$joinType join with a streaming DataFrame/Dataset " + diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala index 77c42d89dca3a..b9943a9744985 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala @@ -523,7 +523,7 @@ class UnsupportedOperationsSuite extends SparkFunSuite { streamRelation.join(streamRelation, joinType = LeftSemi, condition = Some(attribute === attribute)), OutputMode.Append(), - Seq("watermark in the join keys")) + Seq("without a watermark in the join keys")) // Left semi joins: stream-stream allowed with range condition yielding state value watermark assertSupportedInStreamingPlan( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala index 231bc02273103..e9d5b41175241 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala @@ -1213,19 +1213,19 @@ class StreamingLeftSemiJoinSuite extends StreamingJoinSuite { // left: (1, 5), (3, 5) // right: (1, 10), (2, 5) assertNumStateRows(total = 4, updated = 2), - AddData(rightInput, (1, 11)), + AddData(rightInput, (1, 9)), // No match as left time is too low and left row is already matched. CheckNewAnswer(), // states // left: (1, 5), (3, 5) - // right: (1, 10), (2, 5), (1, 11) + // right: (1, 10), (2, 5), (1, 9) assertNumStateRows(total = 5, updated = 1), // Increase event time watermark to 20s by adding data with time = 30s on both inputs. AddData(leftInput, (1, 7), (1, 30)), CheckNewAnswer((1, 7)), // states // left: (1, 5), (3, 5), (1, 30) - // right: (1, 10), (2, 5), (1, 11) + // right: (1, 10), (2, 5), (1, 9) assertNumStateRows(total = 6, updated = 1), // Watermark = 30 - 10 = 20, no matched row. 
AddData(rightInput, (0, 30)), @@ -1236,7 +1236,7 @@ class StreamingLeftSemiJoinSuite extends StreamingJoinSuite { // // states evicted // left: (1, 5), (3, 5) (below watermark = 20) - // right: (1, 10), (2, 5), (1, 11) (below watermark = 20) + // right: (1, 10), (2, 5), (1, 9) (below watermark = 20) assertNumStateRows(total = 2, updated = 1) ) } @@ -1250,8 +1250,9 @@ class StreamingLeftSemiJoinSuite extends StreamingJoinSuite { // batch 1 - global watermark = 0 // states // left: (2, 2L), (4, 4L) - // (left rows with value % 2 != 0 is filtered per [[PushDownLeftSemiAntiJoin]]) + // (left rows with value % 2 != 0 is filtered per [[PushPredicateThroughJoin]]) // right: (2, 2L), (4, 4L) + // (right rows with value % 2 != 0 is filtered per [[PushPredicateThroughJoin]]) assertNumStateRows(total = 4, updated = 4), AddData(inputStream, (6, 6L), (7, 7L), (8, 8L), (9, 9L), (10, 10L)), CheckNewAnswer((6, 6), (8, 8), (10, 10)), From 14871d9d2be6b751687e78dd4d17c2e249b8f205 Mon Sep 17 00:00:00 2001 From: Cheng Su Date: Wed, 21 Oct 2020 15:41:23 -0700 Subject: [PATCH 6/6] Update comment --- .../org/apache/spark/sql/streaming/StreamingJoinSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala index e9d5b41175241..91d1f5de3f211 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala @@ -1214,7 +1214,7 @@ class StreamingLeftSemiJoinSuite extends StreamingJoinSuite { // right: (1, 10), (2, 5) assertNumStateRows(total = 4, updated = 2), AddData(rightInput, (1, 9)), - // No match as left time is too low and left row is already matched. + // No match as left row is already matched. CheckNewAnswer(), // states // left: (1, 5), (3, 5)