Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -191,16 +191,22 @@ case class SortMergeJoinExec(
}

private def genComparison(ctx: CodegenContext, a: Seq[ExprCode], b: Seq[ExprCode]): String = {
val comparisons = a.zip(b).zipWithIndex.map { case ((l, r), i) =>
// The first key compare always runs, so emit it unguarded. Each subsequent key compare runs
// only when previous keys were equal (comp == 0).
val pairs = a.zip(b).zipWithIndex
val firstCompare = pairs.headOption.map { case ((l, r), i) =>
s"comp = ${ctx.genComp(leftKeys(i).dataType, l.value, r.value)};"
}.getOrElse("comp = 0;")
val restCompares = pairs.tail.map { case ((l, r), i) =>
s"""
|if (comp == 0) {
| comp = ${ctx.genComp(leftKeys(i).dataType, l.value, r.value)};
|}
""".stripMargin.trim
}
s"""
|comp = 0;
|${comparisons.mkString("\n")}
|$firstCompare
|${restCompares.mkString("\n")}
""".stripMargin
}

Expand All @@ -216,10 +222,14 @@ case class SortMergeJoinExec(
val streamedRow = ctx.addMutableState("InternalRow", "streamedRow", forceInline = true)
val bufferedRow = ctx.addMutableState("InternalRow", "bufferedRow", forceInline = true)

// Create variables for join keys from both sides.
// Create variables for join keys from both sides. When all keys are statically non-nullable,
// `isNull` is `FalseLiteral` for each and the null-check disjunction would always be `false`;
// we skip emitting the check (and the dead handler branch) entirely in that case.
val streamedKeyVars = createJoinKey(ctx, streamedRow, streamedKeys, streamedOutput)
val streamedKeysNullable = streamedKeyVars.exists(_.isNull != FalseLiteral)
val streamedAnyNull = streamedKeyVars.map(_.isNull).mkString(" || ")
val bufferedKeyTmpVars = createJoinKey(ctx, bufferedRow, bufferedKeys, bufferedOutput)
val bufferedKeysNullable = bufferedKeyTmpVars.exists(_.isNull != FalseLiteral)
val bufferedAnyNull = bufferedKeyTmpVars.map(_.isNull).mkString(" || ")
// Copy the buffered key as class members so they could be used in next function call.
val bufferedKeyVars = copyKeys(ctx, bufferedKeyTmpVars)
Expand Down Expand Up @@ -287,6 +297,27 @@ case class SortMergeJoinExec(
s"$matches.add((UnsafeRow) $bufferedRow);"
}

val checkStreamedAnyNull = if (streamedKeysNullable) {
s"""
|if ($streamedAnyNull) {
| $handleStreamedAnyNull
|}
""".stripMargin
} else {
""
}

val checkBufferedAnyNull = if (bufferedKeysNullable) {
s"""
|if ($bufferedAnyNull) {
| $bufferedRow = null;
| continue;
|}
""".stripMargin
} else {
""
}

// Generate a function to scan both streamed and buffered sides to find a match.
// Return whether a match is found.
//
Expand Down Expand Up @@ -329,9 +360,7 @@ case class SortMergeJoinExec(
| if (!streamedIter.hasNext()) return false;
| $streamedRow = (InternalRow) streamedIter.next();
| ${streamedKeyVars.map(_.code).mkString("\n")}
| if ($streamedAnyNull) {
| $handleStreamedAnyNull
| }
| $checkStreamedAnyNull
| if (!$matches.isEmpty()) {
| ${genComparison(ctx, streamedKeyVars, matchedKeyVars)}
| if (comp == 0) {
Expand All @@ -348,10 +377,7 @@ case class SortMergeJoinExec(
| }
| $bufferedRow = (InternalRow) bufferedIter.next();
| ${bufferedKeyTmpVars.map(_.code).mkString("\n")}
| if ($bufferedAnyNull) {
| $bufferedRow = null;
| continue;
| }
| $checkBufferedAnyNull
| ${bufferedKeyVars.map(_.code).mkString("\n")}
| }
| ${genComparison(ctx, streamedKeyVars, bufferedKeyVars)}
Expand Down Expand Up @@ -802,10 +828,13 @@ case class SortMergeJoinExec(
val leftInputRow = ctx.addMutableState("InternalRow", "leftInputRow", forceInline = true)
val rightInputRow = ctx.addMutableState("InternalRow", "rightInputRow", forceInline = true)

// Create variables for join keys from both sides.
// Create variables for join keys from both sides. As in `genScanner`, omit the null check
// when all keys are statically non-nullable.
val leftKeyVars = createJoinKey(ctx, leftInputRow, leftKeys, left.output)
val leftKeysNullable = leftKeyVars.exists(_.isNull != FalseLiteral)
val leftAnyNull = leftKeyVars.map(_.isNull).mkString(" || ")
val rightKeyVars = createJoinKey(ctx, rightInputRow, rightKeys, right.output)
val rightKeysNullable = rightKeyVars.exists(_.isNull != FalseLiteral)
val rightAnyNull = rightKeyVars.map(_.isNull).mkString(" || ")
val matchedKeyVars = copyKeys(ctx, leftKeyVars)
val leftMatchedKeyVars = createJoinKey(ctx, leftInputRow, leftKeys, left.output)
Expand Down Expand Up @@ -880,6 +909,30 @@ case class SortMergeJoinExec(
// - Step 3: Buffer rows with same join keys from both sides into `leftBuffer` and
// `rightBuffer`. Reset bit sets for both buffers accordingly (`leftMatched` and
// `rightMatched`).
val checkLeftAnyNull = if (leftKeysNullable) {
s"""
|if ($leftAnyNull) {
| // The left row join key is null, join it with null row
| $outputLeftNoMatch
| return;
|}
""".stripMargin
} else {
""
}

val checkRightAnyNull = if (rightKeysNullable) {
s"""
|if ($rightAnyNull) {
| // The right row join key is null, join it with null row
| $outputRightNoMatch
| return;
|}
""".stripMargin
} else {
""
}

val findNextJoinRowsFuncName = ctx.freshName("findNextJoinRows")
ctx.addNewFunction(findNextJoinRowsFuncName,
s"""
Expand All @@ -898,18 +951,10 @@ case class SortMergeJoinExec(
| }
|
| ${leftKeyVars.map(_.code).mkString("\n")}
| if ($leftAnyNull) {
| // The left row join key is null, join it with null row
| $outputLeftNoMatch
| return;
| }
| $checkLeftAnyNull
|
| ${rightKeyVars.map(_.code).mkString("\n")}
| if ($rightAnyNull) {
| // The right row join key is null, join it with null row
| $outputRightNoMatch
| return;
| }
| $checkRightAnyNull
|
| ${genComparison(ctx, leftKeyVars, rightKeyVars)}
| if (comp < 0) {
Expand Down