From 8e244d2d712a541fab5f7b92589657ba51820d5f Mon Sep 17 00:00:00 2001 From: zml1206 Date: Tue, 9 Jan 2024 15:43:32 +0800 Subject: [PATCH 1/2] Fix OOM caused by code-gen for full outer sort merge join Fix OOM caused by code-gen for full outer sort merge join --- .../execution/joins/SortMergeJoinExec.scala | 37 ++++++++++++++----- 1 file changed, 28 insertions(+), 9 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala index 8d49b1558d687..0b594ac36085f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala @@ -815,8 +815,10 @@ case class SortMergeJoinExec( v => s"$v = new $matchedClsName(1);", forceInline = true) val rightMatched = ctx.addMutableState(matchedClsName, "rightMatched", v => s"$v = new $matchedClsName(1);", forceInline = true) - val leftIndex = ctx.freshName("leftIndex") - val rightIndex = ctx.freshName("rightIndex") + val leftIndex = ctx.addMutableState(CodeGenerator.JAVA_INT, "leftIndex", + v => s"$v = 0;", forceInline = true) + val rightIndex = ctx.addMutableState(CodeGenerator.JAVA_INT, "rightIndex", + v => s"$v = 0;", forceInline = true) // Generate code for join condition val leftResultVars = genOneSideJoinVars( @@ -878,6 +880,8 @@ case class SortMergeJoinExec( | int comp = 0; | $leftBuffer.clear(); | $rightBuffer.clear(); + | $leftIndex = 0; + | $rightIndex = 0; | | if ($leftInputRow == null) { | $leftInputRow = (InternalRow) leftIter.next(); @@ -959,38 +963,53 @@ case class SortMergeJoinExec( // Scan the left and right buffers to find all matched rows. 
val matchRowsInBuffer = s""" - |int $leftIndex; - |int $rightIndex; - | - |for ($leftIndex = 0; $leftIndex < $leftBuffer.size(); $leftIndex++) { + |while ($leftIndex < $leftBuffer.size()) { | $leftOutputRow = (InternalRow) $leftBuffer.get($leftIndex); - | for ($rightIndex = 0; $rightIndex < $rightBuffer.size(); $rightIndex++) { + | while ($rightIndex < $rightBuffer.size()) { | $rightOutputRow = (InternalRow) $rightBuffer.get($rightIndex); | $conditionCheck { | $consumeFullOuterJoinRow(); | $leftMatched.set($leftIndex); | $rightMatched.set($rightIndex); + | if (shouldStop()) { + | $rightIndex++; + | return; + | } | } + | $rightIndex++; | } - | + | $rightIndex = 0; | if (!$leftMatched.get($leftIndex)) { | | $rightOutputRow = null; | $consumeFullOuterJoinRow(); + | if (shouldStop()) { + | $leftIndex++; + | return; + | } | } + | $leftIndex++; |} | |$leftOutputRow = null; - |for ($rightIndex = 0; $rightIndex < $rightBuffer.size(); $rightIndex++) { + |while ($rightIndex < $rightBuffer.size()) { | if (!$rightMatched.get($rightIndex)) { | // The right row has never matched any left row, join it with null row | $rightOutputRow = (InternalRow) $rightBuffer.get($rightIndex); | $consumeFullOuterJoinRow(); + | if (shouldStop()) { + | $rightIndex++; + | return; + | } | } + | $rightIndex++; |} """.stripMargin s""" + |if ($leftIndex < $leftBuffer.size() || $rightIndex < $rightBuffer.size()) { + | $matchRowsInBuffer + |} |while (($leftInputRow != null || $leftInput.hasNext()) && | ($rightInputRow != null || $rightInput.hasNext())) { | $findNextJoinRowsFuncName($leftInput, $rightInput); From f34e0a5d9fcaf42b2d7465ae21ae3e9b260df07d Mon Sep 17 00:00:00 2001 From: zml1206 Date: Wed, 24 Jan 2024 18:00:27 +0800 Subject: [PATCH 2/2] wrap "$matchRowsInBuffer" as a function --- .../execution/joins/SortMergeJoinExec.scala | 76 ++++++++++--------- 1 file changed, 40 insertions(+), 36 deletions(-) diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala index 0b594ac36085f..92ffdbfaebad4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala @@ -961,59 +961,63 @@ case class SortMergeJoinExec( """.stripMargin) // Scan the left and right buffers to find all matched rows. - val matchRowsInBuffer = + val matchRowsInBufferFuncName = ctx.freshName("matchRowsInBuffer") + ctx.addNewFunction(matchRowsInBufferFuncName, s""" - |while ($leftIndex < $leftBuffer.size()) { - | $leftOutputRow = (InternalRow) $leftBuffer.get($leftIndex); - | while ($rightIndex < $rightBuffer.size()) { - | $rightOutputRow = (InternalRow) $rightBuffer.get($rightIndex); - | $conditionCheck { + |private boolean $matchRowsInBufferFuncName() throws java.io.IOException { + | while ($leftIndex < $leftBuffer.size()) { + | $leftOutputRow = (InternalRow) $leftBuffer.get($leftIndex); + | while ($rightIndex < $rightBuffer.size()) { + | $rightOutputRow = (InternalRow) $rightBuffer.get($rightIndex); + | $conditionCheck { + | $consumeFullOuterJoinRow(); + | $leftMatched.set($leftIndex); + | $rightMatched.set($rightIndex); + | if (shouldStop()) { + | $rightIndex++; + | return true; + | } + | } + | $rightIndex++; + | } + | $rightIndex = 0; + | if (!$leftMatched.get($leftIndex)) { + | $rightOutputRow = null; | $consumeFullOuterJoinRow(); - | $leftMatched.set($leftIndex); - | $rightMatched.set($rightIndex); | if (shouldStop()) { - | $rightIndex++; - | return; + | $leftIndex++; + | return true; | } | } - | $rightIndex++; + | $leftIndex++; | } - | $rightIndex = 0; - | if (!$leftMatched.get($leftIndex)) { | - | $rightOutputRow = null; - | $consumeFullOuterJoinRow(); - | if (shouldStop()) { - | $leftIndex++; - | return; - | } - | } - | $leftIndex++; - |} - | - 
|$leftOutputRow = null; - |while ($rightIndex < $rightBuffer.size()) { - | if (!$rightMatched.get($rightIndex)) { - | // The right row has never matched any left row, join it with null row - | $rightOutputRow = (InternalRow) $rightBuffer.get($rightIndex); - | $consumeFullOuterJoinRow(); - | if (shouldStop()) { - | $rightIndex++; - | return; + | $leftOutputRow = null; + | while ($rightIndex < $rightBuffer.size()) { + | if (!$rightMatched.get($rightIndex)) { + | // The right row has never matched any left row, join it with null row + | $rightOutputRow = (InternalRow) $rightBuffer.get($rightIndex); + | $consumeFullOuterJoinRow(); + | if (shouldStop()) { + | $rightIndex++; + | return true; + | } | } + | $rightIndex++; | } - | $rightIndex++; + | return false; |} - """.stripMargin + """.stripMargin) s""" + |// If we already buffered some matching rows, use them directly |if ($leftIndex < $leftBuffer.size() || $rightIndex < $rightBuffer.size()) { - | $matchRowsInBuffer + | if($matchRowsInBufferFuncName()) return; |} |while (($leftInputRow != null || $leftInput.hasNext()) && | ($rightInputRow != null || $rightInput.hasNext())) { | $findNextJoinRowsFuncName($leftInput, $rightInput); - | $matchRowsInBuffer + | $matchRowsInBufferFuncName(); | if (shouldStop()) return; |} |