Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-32862][SS] Left semi stream-stream join #30076

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -291,17 +291,17 @@ object UnsupportedOperationChecker extends Logging {
throwError("Full outer joins with streaming DataFrames/Datasets are not supported")
}

case LeftSemi | LeftAnti =>
case LeftAnti =>
if (right.isStreaming) {
throwError("Left semi/anti joins with a streaming DataFrame/Dataset " +
throwError("Left anti joins with a streaming DataFrame/Dataset " +
"on the right are not supported")
}

// We support streaming left outer joins with static on the right always, and with
// stream on both sides under the appropriate conditions.
case LeftOuter =>
// We support streaming left outer and left semi joins with static on the right always,
// and with stream on both sides under the appropriate conditions.
case LeftOuter | LeftSemi =>
if (!left.isStreaming && right.isStreaming) {
throwError("Left outer join with a streaming DataFrame/Dataset " +
throwError(s"$joinType join with a streaming DataFrame/Dataset " +
"on the right and a static DataFrame/Dataset on the left is not supported")
} else if (left.isStreaming && right.isStreaming) {
val watermarkInJoinKeys = StreamingJoinHelper.isWatermarkInJoinKeys(subPlan)
Expand All @@ -311,7 +311,8 @@ object UnsupportedOperationChecker extends Logging {
left.outputSet, right.outputSet, condition, Some(1000000)).isDefined

if (!watermarkInJoinKeys && !hasValidWatermarkRange) {
throwError("Stream-stream outer join between two streaming DataFrame/Datasets " +
throwError(
s"Stream-stream $joinType join between two streaming DataFrame/Datasets " +
"is not supported without a watermark in the join keys, or a watermark on " +
"the nullable side and an appropriate range condition")
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,16 @@ class JoinedRow extends InternalRow {
this
}

/** Gets this JoinedRow's left base row. */
def getLeft: InternalRow = {
row1
}

/** Gets this JoinedRow's right base row. */
def getRight: InternalRow = {
row2
}

override def toSeq(fieldTypes: Seq[DataType]): Seq[Any] = {
assert(fieldTypes.length == row1.numFields + row2.numFields)
val (left, right) = fieldTypes.splitAt(row1.numFields)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -490,15 +490,77 @@ class UnsupportedOperationsSuite extends SparkFunSuite {
_.join(_, joinType = LeftSemi),
streamStreamSupported = false,
batchStreamSupported = false,
expectedMsg = "left semi/anti joins")
expectedMsg = "LeftSemi join")

// Left semi joins: update and complete mode not allowed
c21 marked this conversation as resolved.
Show resolved Hide resolved
assertNotSupportedInStreamingPlan(
"left semi join with stream-stream relations and update mode",
streamRelation.join(streamRelation, joinType = LeftSemi,
condition = Some(attribute === attribute)),
OutputMode.Update(),
Seq("is not supported in Update output mode"))
assertNotSupportedInStreamingPlan(
"left semi join with stream-stream relations and complete mode",
Aggregate(Nil, aggExprs("d"), streamRelation.join(streamRelation, joinType = LeftSemi,
condition = Some(attribute === attribute))),
OutputMode.Complete(),
Seq("is not supported in Complete output mode"))

// Left semi joins: stream-stream allowed with join on watermark attribute
// Note that the attribute need not be watermarked on both sides.
assertSupportedInStreamingPlan(
"left semi join with stream-stream relations and join on attribute with left watermark",
streamRelation.join(streamRelation, joinType = LeftSemi,
condition = Some(attributeWithWatermark === attribute)),
OutputMode.Append())
assertSupportedInStreamingPlan(
"left semi join with stream-stream relations and join on attribute with right watermark",
streamRelation.join(streamRelation, joinType = LeftSemi,
condition = Some(attribute === attributeWithWatermark)),
OutputMode.Append())
assertNotSupportedInStreamingPlan(
"left semi join with stream-stream relations and join on non-watermarked attribute",
streamRelation.join(streamRelation, joinType = LeftSemi,
condition = Some(attribute === attribute)),
OutputMode.Append(),
Seq("without a watermark in the join keys"))

// Left semi joins: stream-stream allowed with range condition yielding state value watermark
assertSupportedInStreamingPlan(
"left semi join with stream-stream relations and state value watermark", {
val leftRelation = streamRelation
val rightTimeWithWatermark =
AttributeReference("b", IntegerType)().withMetadata(watermarkMetadata)
val rightRelation = new TestStreamingRelation(rightTimeWithWatermark)
leftRelation.join(
rightRelation,
joinType = LeftSemi,
condition = Some(attribute > rightTimeWithWatermark + 10))
},
OutputMode.Append())

// Left semi joins: stream-stream not allowed with insufficient range condition
assertNotSupportedInStreamingPlan(
"left semi join with stream-stream relations and state value watermark", {
val leftRelation = streamRelation
val rightTimeWithWatermark =
AttributeReference("b", IntegerType)().withMetadata(watermarkMetadata)
val rightRelation = new TestStreamingRelation(rightTimeWithWatermark)
leftRelation.join(
rightRelation,
joinType = LeftSemi,
condition = Some(attribute < rightTimeWithWatermark + 10))
},
OutputMode.Append(),
Seq("appropriate range condition"))

// Left anti joins: stream-* not allowed
testBinaryOperationInStreamingPlan(
"left anti join",
_.join(_, joinType = LeftAnti),
streamStreamSupported = false,
batchStreamSupported = false,
expectedMsg = "left semi/anti joins")
expectedMsg = "Left anti join")

// Right outer joins: stream-* not allowed
testBinaryOperationInStreamingPlan(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,8 @@ case class StreamingSymmetricHashJoinExec(
}

if (stateFormatVersion < 2 && joinType != Inner) {
throw new IllegalArgumentException("The query is using stream-stream outer join with state" +
throw new IllegalArgumentException(
s"The query is using stream-stream $joinType join with state" +
s" format version ${stateFormatVersion} - correctness issue is discovered. Please discard" +
" the checkpoint and rerun the query. See SPARK-26154 for more details.")
}
Expand All @@ -165,7 +166,7 @@ case class StreamingSymmetricHashJoinExec(
}

require(
joinType == Inner || joinType == LeftOuter || joinType == RightOuter,
joinType == Inner || joinType == LeftOuter || joinType == RightOuter || joinType == LeftSemi,
errorMessageForJoinType)
require(leftKeys.map(_.dataType) == rightKeys.map(_.dataType))

Expand All @@ -185,6 +186,7 @@ case class StreamingSymmetricHashJoinExec(
case _: InnerLike => left.output ++ right.output
case LeftOuter => left.output ++ right.output.map(_.withNullability(true))
case RightOuter => left.output.map(_.withNullability(true)) ++ right.output
case LeftSemi => left.output
case _ => throwBadJoinTypeException()
}

Expand All @@ -193,6 +195,7 @@ case class StreamingSymmetricHashJoinExec(
PartitioningCollection(Seq(left.outputPartitioning, right.outputPartitioning))
case LeftOuter => left.outputPartitioning
case RightOuter => right.outputPartitioning
case LeftSemi => left.outputPartitioning
case _ => throwBadJoinTypeException()
}

Expand Down Expand Up @@ -246,37 +249,43 @@ case class StreamingSymmetricHashJoinExec(

// Join one side input using the other side's buffered/state rows. Here is how it is done.
//
// - `leftSideJoiner.storeAndJoinWithOtherSide(rightSideJoiner)` generates all rows from
// matching new left input with stored right input, and also stores all the left input
// - `leftSideJoiner.storeAndJoinWithOtherSide(rightSideJoiner)`
// - Inner, Left Outer, Right Outer Join: generates all rows from matching new left input
// with stored right input, and also stores all the left input.
// - Left Semi Join: generates all new left input rows from matching new left input with
// stored right input, and also stores all the non-matched left input.
//
// - `rightSideJoiner.storeAndJoinWithOtherSide(leftSideJoiner)` generates all rows from
// matching new right input with stored left input, and also stores all the right input.
// It also generates all rows from matching new left input with new right input, since
// the new left input has become stored by that point. This tiny asymmetry is necessary
// to avoid duplication.
// - `rightSideJoiner.storeAndJoinWithOtherSide(leftSideJoiner)`
// - Inner, Left Outer, Right Outer Join: generates all rows from matching new right input
// with stored left input, and also stores all the right input.
// It also generates all rows from matching new left input with new right input, since
// the new left input has become stored by that point. This tiny asymmetry is necessary
// to avoid duplication.
// - Left Semi Join: generates all stored left input rows, from matching new right input
// with stored left input, and also stores all the right input. Note only first-time
// matched left input rows will be generated, this is to guarantee left semi semantics.
val leftOutputIter = leftSideJoiner.storeAndJoinWithOtherSide(rightSideJoiner) {
(input: InternalRow, matched: InternalRow) => joinedRow.withLeft(input).withRight(matched)
}
val rightOutputIter = rightSideJoiner.storeAndJoinWithOtherSide(leftSideJoiner) {
(input: InternalRow, matched: InternalRow) => joinedRow.withLeft(matched).withRight(input)
}

// We need to save the time that the inner join output iterator completes, since outer join
// output counts as both update and removal time.
var innerOutputCompletionTimeNs: Long = 0
def onInnerOutputCompletion = {
innerOutputCompletionTimeNs = System.nanoTime
// We need to save the time that the one side hash join output iterator completes, since
// other join output counts as both update and removal time.
var hashJoinOutputCompletionTimeNs: Long = 0
def onHashJoinOutputCompletion(): Unit = {
hashJoinOutputCompletionTimeNs = System.nanoTime
}
// This is the iterator which produces the inner join rows. For outer joins, this will be
// prepended to a second iterator producing outer join rows; for inner joins, this is the full
// output.
val innerOutputIter = CompletionIterator[InternalRow, Iterator[InternalRow]](
(leftOutputIter ++ rightOutputIter), onInnerOutputCompletion)

// This is the iterator which produces the inner and left semi join rows. For other joins,
// this will be prepended to a second iterator producing other rows; for inner and left semi
// joins, this is the full output.
val hashJoinOutputIter = CompletionIterator[InternalRow, Iterator[InternalRow]](
leftOutputIter ++ rightOutputIter, onHashJoinOutputCompletion())

val outputIter: Iterator[InternalRow] = joinType match {
case Inner =>
innerOutputIter
case Inner | LeftSemi =>
hashJoinOutputIter
case LeftOuter =>
// We generate the outer join input by:
// * Getting an iterator over the rows that have aged out on the left side. These rows are
Expand Down Expand Up @@ -311,7 +320,7 @@ case class StreamingSymmetricHashJoinExec(
}
}.map(pair => joinedRow.withLeft(pair.value).withRight(nullRight))

innerOutputIter ++ outerOutputIter
hashJoinOutputIter ++ outerOutputIter
case RightOuter =>
// See comments for left outer case.
def matchesWithLeftSideState(rightKeyValue: UnsafeRowPair) = {
Expand All @@ -330,11 +339,15 @@ case class StreamingSymmetricHashJoinExec(
}
}.map(pair => joinedRow.withLeft(nullLeft).withRight(pair.value))

innerOutputIter ++ outerOutputIter
hashJoinOutputIter ++ outerOutputIter
case _ => throwBadJoinTypeException()
}

val outputProjection = UnsafeProjection.create(left.output ++ right.output, output)
val outputProjection = if (joinType == LeftSemi) {
UnsafeProjection.create(output, output)
} else {
UnsafeProjection.create(left.output ++ right.output, output)
}
val outputIterWithMetrics = outputIter.map { row =>
numOutputRows += 1
outputProjection(row)
Expand All @@ -345,24 +358,28 @@ case class StreamingSymmetricHashJoinExec(
// All processing time counts as update time.
allUpdatesTimeMs += math.max(NANOSECONDS.toMillis(System.nanoTime - updateStartTimeNs), 0)

// Processing time between inner output completion and here comes from the outer portion of a
// join, and thus counts as removal time as we remove old state from one side while iterating.
if (innerOutputCompletionTimeNs != 0) {
// Processing time between one side hash join output completion and here comes from the
// outer portion of a join, and thus counts as removal time as we remove old state from
// one side while iterating.
if (hashJoinOutputCompletionTimeNs != 0) {
allRemovalsTimeMs +=
math.max(NANOSECONDS.toMillis(System.nanoTime - innerOutputCompletionTimeNs), 0)
math.max(NANOSECONDS.toMillis(System.nanoTime - hashJoinOutputCompletionTimeNs), 0)
}

allRemovalsTimeMs += timeTakenMs {
// Remove any remaining state rows which aren't needed because they're below the watermark.
//
// For inner joins, we have to remove unnecessary state rows from both sides if possible.
// For inner and left semi joins, we have to remove unnecessary state rows from both sides
// if possible.
//
// For outer joins, we have already removed unnecessary state rows from the outer side
// (e.g., left side for left outer join) while generating the outer "null" outputs. Now, we
// have to remove unnecessary state rows from the other side (e.g., right side for the left
// outer join) if possible. In all cases, nothing needs to be outputted, hence the removal
// needs to be done greedily by immediately consuming the returned iterator.
val cleanupIter = joinType match {
case Inner => leftSideJoiner.removeOldState() ++ rightSideJoiner.removeOldState()
case Inner | LeftSemi =>
leftSideJoiner.removeOldState() ++ rightSideJoiner.removeOldState()
case LeftOuter => rightSideJoiner.removeOldState()
case RightOuter => leftSideJoiner.removeOldState()
case _ => throwBadJoinTypeException()
Expand Down Expand Up @@ -481,6 +498,26 @@ case class StreamingSymmetricHashJoinExec(
case _ => (_: InternalRow) => Iterator.empty
}

val excludeRowsAlreadyMatched = joinType == LeftSemi && joinSide == RightSide

val generateOutputIter: (InternalRow, Iterator[JoinedRow]) => Iterator[InternalRow] =
joinSide match {
case LeftSide if joinType == LeftSemi =>
(input: InternalRow, joinedRowIter: Iterator[JoinedRow]) =>
// For left side of left semi join, generate one left row if there is matched
// rows from right side. Otherwise, generate nothing.
if (joinedRowIter.nonEmpty) {
Iterator.single(input)
} else {
Iterator.empty
}
case RightSide if joinType == LeftSemi =>
(_: InternalRow, joinedRowIter: Iterator[JoinedRow]) =>
// For right side of left semi join, generate matched left rows only.
joinedRowIter.map(_.getLeft)
case _ => (_: InternalRow, joinedRowIter: Iterator[JoinedRow]) => joinedRowIter
}

nonLateRows.flatMap { row =>
val thisRow = row.asInstanceOf[UnsafeRow]
// If this row fails the pre join filter, that means it can never satisfy the full join
Expand All @@ -489,8 +526,12 @@ case class StreamingSymmetricHashJoinExec(
// the case of inner join).
if (preJoinFilter(thisRow)) {
val key = keyGenerator(thisRow)
val outputIter: Iterator[JoinedRow] = otherSideJoiner.joinStateManager
.getJoinedRows(key, thatRow => generateJoinedRow(thisRow, thatRow), postJoinFilter)
val joinedRowIter: Iterator[JoinedRow] = otherSideJoiner.joinStateManager.getJoinedRows(
key,
thatRow => generateJoinedRow(thisRow, thatRow),
postJoinFilter,
excludeRowsAlreadyMatched)
val outputIter = generateOutputIter(thisRow, joinedRowIter)
new AddingProcessedRowToStateCompletionIterator(key, thisRow, outputIter)
} else {
generateFilteredJoinedRow(thisRow)
Expand All @@ -501,13 +542,19 @@ case class StreamingSymmetricHashJoinExec(
private class AddingProcessedRowToStateCompletionIterator(
key: UnsafeRow,
thisRow: UnsafeRow,
subIter: Iterator[JoinedRow])
extends CompletionIterator[JoinedRow, Iterator[JoinedRow]](subIter) {
subIter: Iterator[InternalRow])
extends CompletionIterator[InternalRow, Iterator[InternalRow]](subIter) {

private val iteratorNotEmpty: Boolean = super.hasNext

override def completion(): Unit = {
val shouldAddToState = // add only if both removal predicates do not match
!stateKeyWatermarkPredicateFunc(key) && !stateValueWatermarkPredicateFunc(thisRow)
val isLeftSemiWithMatch =
joinType == LeftSemi && joinSide == LeftSide && iteratorNotEmpty
// Add to state store only if both removal predicates do not match,
// and the row is not matched for left side of left semi join.
val shouldAddToState =
!stateKeyWatermarkPredicateFunc(key) && !stateValueWatermarkPredicateFunc(thisRow) &&
!isLeftSemiWithMatch
if (shouldAddToState) {
joinStateManager.append(key, thisRow, matched = iteratorNotEmpty)
updatedStateRowsCount += 1
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,13 +99,20 @@ class SymmetricHashJoinStateManager(
/**
* Get all the matched values for given join condition, with marking matched.
* This method is designed to mark joined rows properly without exposing internal index of row.
*
* @param excludeRowsAlreadyMatched Do not join with rows already matched previously.
* This is used for right side of left semi join in
* [[StreamingSymmetricHashJoinExec]] only.
*/
def getJoinedRows(
key: UnsafeRow,
generateJoinedRow: InternalRow => JoinedRow,
predicate: JoinedRow => Boolean): Iterator[JoinedRow] = {
predicate: JoinedRow => Boolean,
excludeRowsAlreadyMatched: Boolean = false): Iterator[JoinedRow] = {
val numValues = keyToNumValues.get(key)
keyWithIndexToValue.getAll(key, numValues).map { keyIdxToValue =>
keyWithIndexToValue.getAll(key, numValues).filterNot { keyIdxToValue =>
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It makes more sense to add this filter logic in the predicate param(i.e postJoinFilter for OneSideHashJoiner) for rightSideJoiner only, corresponding to the comment https://github.com/apache/spark/pull/30076/files#diff-6cd66da710d8d54025c1edf658bbec5230e8b4e748f9f2f884a60b1ba1efed42R264

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I thought this first and not proposed because current predicate cannot check the condition. We can still do this via adjusting the type of predicate a bit, but I guess the followup PR would try to separate left semi case of performance which lets us to can revert the change here. For the reason I prefer the small change for now.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, after taking a further look, the joinedRow already dropped the message of matched, so it's hard to do now. +1 for the change now.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FYI I created https://issues.apache.org/jira/browse/SPARK-33211 for this followup.

excludeRowsAlreadyMatched && keyIdxToValue.matched
}.map { keyIdxToValue =>
val joinedRow = generateJoinedRow(keyIdxToValue.value)
if (predicate(joinedRow)) {
if (!keyIdxToValue.matched) {
Expand Down
Loading