Skip to content

Commit

Permalink
add test
Browse files Browse the repository at this point in the history
  • Loading branch information
Ngone51 committed Apr 23, 2020
1 parent 5ca2da0 commit c2a9294
Showing 1 changed file with 40 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer}
import org.apache.spark.network.shuffle.{BlockFetchingListener, DownloadFileManager, ExternalBlockStoreClient}
import org.apache.spark.network.util.LimitedInputStream
import org.apache.spark.shuffle.FetchFailedException
import org.apache.spark.storage.ShuffleBlockFetcherIterator.FetchBlockInfo
import org.apache.spark.util.Utils


Expand Down Expand Up @@ -1071,4 +1072,43 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
val e = intercept[FetchFailedException] { iterator.next() }
assert(e.getMessage.contains("Received a zero-size buffer"))
}

test("SPARK-31521: correct the fetch size when merging blocks into a merged block") {
val bId1 = ShuffleBlockBatchId(0, 0, 0, 5)
val bId2 = ShuffleBlockId(0, 0, 6)
val bId3 = ShuffleBlockId(0, 0, 7)
val block1 = FetchBlockInfo(bId1, 40, 0)
val block2 = FetchBlockInfo(bId2, 50, 0)
val block3 = FetchBlockInfo(bId3, 60, 0)
val inputBlocks = Seq(block1, block2, block3)

val blockManager = mock(classOf[BlockManager])
val localBmId = BlockManagerId("test-client", "test-client", 1)
val taskContext = TaskContext.empty()
doReturn(localBmId).when(blockManager).blockManagerId
val iterator = new ShuffleBlockFetcherIterator(
taskContext,
createMockTransfer(Map.empty),
blockManager,
Iterator.empty,
(_, in) => in,
48 * 1024 * 1024,
Int.MaxValue,
Int.MaxValue,
Int.MaxValue,
true,
false,
taskContext.taskMetrics.createTempShuffleReadMetrics(),
true)

val mergeMethod =
PrivateMethod[Seq[FetchBlockInfo]](Symbol("mergeContinuousShuffleBlockIdsIfNeeded"))
val mergedBlocks = iterator.invokePrivate(mergeMethod(inputBlocks))
assert(mergedBlocks.size === 1)
val mergedBlock = mergedBlocks.head
val mergedBlockId = mergedBlock.blockId.asInstanceOf[ShuffleBlockBatchId]
assert(mergedBlockId.startReduceId === bId1.startReduceId)
assert(mergedBlockId.endReduceId === bId3.reduceId + 1)
assert(mergedBlock.size === inputBlocks.map(_.size).sum)
}
}

0 comments on commit c2a9294

Please sign in to comment.