diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleSpillWriter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleSpillWriter.java
index b0e2b6022ef21..8e0a21ec6b3a5 100644
--- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleSpillWriter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleSpillWriter.java
@@ -69,13 +69,13 @@ public final class UnsafeShuffleSpillWriter {
   private final LinkedList<SpillInfo> spills = new LinkedList<SpillInfo>();
 
   public UnsafeShuffleSpillWriter(
-    TaskMemoryManager memoryManager,
-    ShuffleMemoryManager shuffleMemoryManager,
-    BlockManager blockManager,
-    TaskContext taskContext,
-    int initialSize,
-    int numPartitions,
-    SparkConf conf) throws IOException {
+      TaskMemoryManager memoryManager,
+      ShuffleMemoryManager shuffleMemoryManager,
+      BlockManager blockManager,
+      TaskContext taskContext,
+      int initialSize,
+      int numPartitions,
+      SparkConf conf) throws IOException {
     this.memoryManager = memoryManager;
     this.shuffleMemoryManager = shuffleMemoryManager;
     this.blockManager = blockManager;
@@ -266,7 +266,7 @@ public SpillInfo[] closeAndGetSpills() throws IOException {
     if (sorter != null) {
       writeSpillFile();
     }
-    return (SpillInfo[]) spills.toArray();
+    return spills.toArray(new SpillInfo[0]);
   }
 }
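A note on the closeAndGetSpills() fix above: the no-argument LinkedList.toArray() returns Object[] regardless of the list's element type, so the old (SpillInfo[]) cast compiles but throws ClassCastException at runtime. The typed overload toArray(T[]) returns an array of the correct runtime type. A minimal standalone sketch of the pitfall (demo code, not part of the patch):

    import java.util.LinkedList;

    public final class ToArrayDemo {
      public static void main(String[] args) {
        LinkedList<String> strings = new LinkedList<String>();
        strings.add("a");

        // Throws ClassCastException at runtime: toArray() always returns an
        // Object[], whatever the list's element type was (type erasure).
        // String[] broken = (String[]) strings.toArray();

        // Correct: the typed overload returns a true String[].
        String[] ok = strings.toArray(new String[0]);
        System.out.println(ok.length); // prints 1
      }
    }
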
diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java
index 839c854963ccf..47fe214634abb 100644
--- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java
@@ -23,9 +23,12 @@
 import java.io.IOException;
 import java.nio.ByteBuffer;
 import java.nio.channels.FileChannel;
+import java.util.Iterator;
 
+import org.apache.spark.shuffle.ShuffleMemoryManager;
 import scala.Option;
 import scala.Product2;
+import scala.collection.JavaConversions;
 import scala.reflect.ClassTag;
 import scala.reflect.ClassTag$;
 
@@ -50,14 +53,18 @@ public class UnsafeShuffleWriter<K, V> implements ShuffleWriter<K, V> {
 
   private static final int SER_BUFFER_SIZE = 1024 * 1024; // TODO: tune this
   private static final ClassTag<Object> OBJECT_CLASS_TAG = ClassTag$.MODULE$.Object();
 
+  private final BlockManager blockManager;
   private final IndexShuffleBlockManager shuffleBlockManager;
-  private final BlockManager blockManager = SparkEnv.get().blockManager();
-  private final int shuffleId;
-  private final int mapId;
   private final TaskMemoryManager memoryManager;
+  private final ShuffleMemoryManager shuffleMemoryManager;
   private final SerializerInstance serializer;
   private final Partitioner partitioner;
   private final ShuffleWriteMetrics writeMetrics;
+  private final int shuffleId;
+  private final int mapId;
+  private final TaskContext taskContext;
+  private final SparkConf sparkConf;
+  private MapStatus mapStatus = null;
 
   /**
@@ -68,19 +75,31 @@ public class UnsafeShuffleWriter<K, V> implements ShuffleWriter<K, V> {
   private boolean stopping = false;
 
   public UnsafeShuffleWriter(
+      BlockManager blockManager,
       IndexShuffleBlockManager shuffleBlockManager,
+      TaskMemoryManager memoryManager,
+      ShuffleMemoryManager shuffleMemoryManager,
       UnsafeShuffleHandle<K, V> handle,
       int mapId,
-      TaskContext context) {
+      TaskContext taskContext,
+      SparkConf sparkConf) {
+    this.blockManager = blockManager;
     this.shuffleBlockManager = shuffleBlockManager;
+    this.memoryManager = memoryManager;
+    this.shuffleMemoryManager = shuffleMemoryManager;
     this.mapId = mapId;
-    this.memoryManager = context.taskMemoryManager();
     final ShuffleDependency<K, V, V> dep = handle.dependency();
     this.shuffleId = dep.shuffleId();
     this.serializer = Serializer.getSerializer(dep.serializer()).newInstance();
     this.partitioner = dep.partitioner();
     this.writeMetrics = new ShuffleWriteMetrics();
-    context.taskMetrics().shuffleWriteMetrics_$eq(Option.apply(writeMetrics));
+    taskContext.taskMetrics().shuffleWriteMetrics_$eq(Option.apply(writeMetrics));
+    this.taskContext = taskContext;
+    this.sparkConf = sparkConf;
+  }
+
+  public void write(Iterator<Product2<K, V>> records) {
+    write(JavaConversions.asScalaIterator(records));
   }
 
   public void write(scala.collection.Iterator<Product2<K, V>> records) {
@@ -101,12 +120,12 @@ private SpillInfo[] insertRecordsIntoSorter(
       scala.collection.Iterator<Product2<K, V>> records) throws Exception {
     final UnsafeShuffleSpillWriter sorter = new UnsafeShuffleSpillWriter(
       memoryManager,
-      SparkEnv$.MODULE$.get().shuffleMemoryManager(),
-      SparkEnv$.MODULE$.get().blockManager(),
-      TaskContext.get(),
+      shuffleMemoryManager,
+      blockManager,
+      taskContext,
       4096, // Initial size (TODO: tune this!)
       partitioner.numPartitions(),
-      SparkEnv$.MODULE$.get().conf()
+      sparkConf
     );
 
     final byte[] serArray = new byte[SER_BUFFER_SIZE];
diff --git a/core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala
index 0dd34b372f624..14f29a36ec4f6 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala
@@ -88,12 +88,17 @@ private[spark] class UnsafeShuffleManager(conf: SparkConf) extends ShuffleManage
       context: TaskContext): ShuffleWriter[K, V] = {
     handle match {
       case unsafeShuffleHandle: UnsafeShuffleHandle[K, V] =>
+        val env = SparkEnv.get
         // TODO: do we need to do anything to register the shuffle here?
         new UnsafeShuffleWriter(
+          env.blockManager,
           shuffleBlockResolver.asInstanceOf[IndexShuffleBlockManager],
+          context.taskMemoryManager(),
+          env.shuffleMemoryManager,
           unsafeShuffleHandle,
           mapId,
-          context)
+          context,
+          env.conf)
       case other =>
         sortShuffleManager.getWriter(handle, mapId, context)
     }
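Two things are worth calling out in the writer changes above. First, UnsafeShuffleWriter now receives BlockManager, ShuffleMemoryManager, TaskContext, and SparkConf through its constructor instead of reaching for SparkEnv.get(), which is what lets the test suite below assemble a writer entirely from Mockito mocks. Second, the new write(java.util.Iterator) overload spares Java callers from building a Scala iterator by hand: JavaConversions.asScalaIterator wraps the Java iterator lazily, pulling elements on demand rather than copying. A small standalone sketch of that conversion (Scala 2.10-era API, as used in this patch; the demo class name is mine):

    import java.util.Arrays;
    import java.util.Iterator;

    import scala.Tuple2;
    import scala.collection.JavaConversions;

    public final class IteratorBridgeDemo {
      public static void main(String[] args) {
        Iterator<Tuple2<Integer, Integer>> javaIter =
          Arrays.asList(new Tuple2<Integer, Integer>(1, 1)).iterator();

        // The wrapper is lazy: elements are pulled from javaIter on demand,
        // so nothing is materialized up front.
        scala.collection.Iterator<Tuple2<Integer, Integer>> scalaIter =
          JavaConversions.asScalaIterator(javaIter);

        while (scalaIter.hasNext()) {
          System.out.println(scalaIter.next()); // prints (1,1)
        }
      }
    }
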
diff --git a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java
new file mode 100644
index 0000000000000..8ba548420bd4b
--- /dev/null
+++ b/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java
@@ -0,0 +1,172 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.unsafe;
+
+import java.io.File;
+import java.io.InputStream;
+import java.io.OutputStream;
+import java.util.ArrayList;
+import java.util.UUID;
+
+import scala.*;
+import scala.runtime.AbstractFunction1;
+
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+import org.mockito.invocation.InvocationOnMock;
+import org.mockito.stubbing.Answer;
+import static org.mockito.AdditionalAnswers.returnsFirstArg;
+import static org.mockito.AdditionalAnswers.returnsSecondArg;
+import static org.mockito.Mockito.*;
+
+import org.apache.spark.*;
+import org.apache.spark.serializer.Serializer;
+import org.apache.spark.shuffle.IndexShuffleBlockManager;
+import org.apache.spark.executor.ShuffleWriteMetrics;
+import org.apache.spark.executor.TaskMetrics;
+import org.apache.spark.serializer.SerializerInstance;
+import org.apache.spark.shuffle.ShuffleMemoryManager;
+import org.apache.spark.storage.*;
+import org.apache.spark.unsafe.memory.ExecutorMemoryManager;
+import org.apache.spark.unsafe.memory.MemoryAllocator;
+import org.apache.spark.unsafe.memory.TaskMemoryManager;
+import org.apache.spark.util.Utils;
+import org.apache.spark.serializer.KryoSerializer;
+import org.apache.spark.scheduler.MapStatus;
+
+public class UnsafeShuffleWriterSuite {
+
+  final TaskMemoryManager memoryManager =
+    new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP));
+  // The spill writer computes its sort-key prefixes from the partition ids
+  // that this partitioner assigns to the records.
+  final HashPartitioner hashPartitioner = new HashPartitioner(4);
+
+  ShuffleMemoryManager shuffleMemoryManager;
+  BlockManager blockManager;
+  IndexShuffleBlockManager shuffleBlockManager;
+  DiskBlockManager diskBlockManager;
+  File tempDir;
+  TaskContext taskContext;
+  SparkConf sparkConf;
+
+  // Identity "compression": hands the underlying stream back unchanged.
+  private static final class CompressStream
+      extends AbstractFunction1<OutputStream, OutputStream> {
+    @Override
+    public OutputStream apply(OutputStream stream) {
+      return stream;
+    }
+  }
+
+  @Before
+  public void setUp() {
+    shuffleMemoryManager = mock(ShuffleMemoryManager.class);
+    diskBlockManager = mock(DiskBlockManager.class);
+    blockManager = mock(BlockManager.class);
+    shuffleBlockManager = mock(IndexShuffleBlockManager.class);
+    // Note: createTempDir$default$1() returns the default root (java.io.tmpdir),
+    // so tempDir points at the system temp directory.
+    tempDir = new File(Utils.createTempDir$default$1());
+    taskContext = mock(TaskContext.class);
+    sparkConf = new SparkConf();
+    when(taskContext.taskMetrics()).thenReturn(new TaskMetrics());
+    // Grant every memory request in full.
+    when(shuffleMemoryManager.tryToAcquire(anyLong())).then(returnsFirstArg());
+    when(blockManager.diskBlockManager()).thenReturn(diskBlockManager);
+    // Route every temp block (i.e. every spill file) into tempDir.
+    when(diskBlockManager.createTempLocalBlock()).thenAnswer(
+        new Answer<Tuple2<TempLocalBlockId, File>>() {
+      @Override
+      public Tuple2<TempLocalBlockId, File> answer(
+          InvocationOnMock invocationOnMock) throws Throwable {
+        TempLocalBlockId blockId = new TempLocalBlockId(UUID.randomUUID());
+        File file = File.createTempFile("spillFile", ".spill", tempDir);
+        return Tuple2$.MODULE$.apply(blockId, file);
+      }
+    });
+    // Build real DiskBlockObjectWriters over the mocked files, with the
+    // identity compression function so output bytes hit disk unmodified.
+    when(blockManager.getDiskWriter(
+        any(BlockId.class),
+        any(File.class),
+        any(SerializerInstance.class),
+        anyInt(),
+        any(ShuffleWriteMetrics.class))).thenAnswer(new Answer<DiskBlockObjectWriter>() {
+      @Override
+      public DiskBlockObjectWriter answer(InvocationOnMock invocationOnMock) throws Throwable {
+        Object[] args = invocationOnMock.getArguments();
+
+        return new DiskBlockObjectWriter(
+          (BlockId) args[0],
+          (File) args[1],
+          (SerializerInstance) args[2],
+          (Integer) args[3],
+          new CompressStream(),
+          false,
+          (ShuffleWriteMetrics) args[4]
+        );
+      }
+    });
+    // Decompression is likewise a pass-through.
+    when(blockManager.wrapForCompression(any(BlockId.class), any(InputStream.class)))
+      .then(returnsSecondArg());
+  }
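+
+  // End-to-end sanity check: write five records through the 4-way hash
+  // partitioner, then verify that the per-partition lengths recorded via
+  // writeIndexFile() sum to the length of the merged data file. The setUp()
+  // mocks confine spill and merge output to tempDir and make compression a
+  // no-op, so raw file lengths are directly comparable.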
+  @Test
+  public void basicShuffleWriting() throws Exception {
+
+    final ShuffleDependency<Object, Object, Object> dep = mock(ShuffleDependency.class);
+    when(dep.serializer()).thenReturn(Option.<Serializer>apply(new KryoSerializer(sparkConf)));
+    when(dep.partitioner()).thenReturn(hashPartitioner);
+
+    final File mergedOutputFile = File.createTempFile("mergedoutput", "", tempDir);
+    when(shuffleBlockManager.getDataFile(anyInt(), anyInt())).thenReturn(mergedOutputFile);
+    // Capture the partition lengths that the writer reports to the index file.
+    final long[] partitionSizes = new long[hashPartitioner.numPartitions()];
+    doAnswer(new Answer<Void>() {
+      @Override
+      public Void answer(InvocationOnMock invocationOnMock) throws Throwable {
+        long[] receivedPartitionSizes = (long[]) invocationOnMock.getArguments()[2];
+        System.arraycopy(
+          receivedPartitionSizes, 0, partitionSizes, 0, receivedPartitionSizes.length);
+        return null;
+      }
+    }).when(shuffleBlockManager).writeIndexFile(anyInt(), anyInt(), any(long[].class));
+
+    final UnsafeShuffleWriter<Object, Object> writer = new UnsafeShuffleWriter<Object, Object>(
+      blockManager,
+      shuffleBlockManager,
+      memoryManager,
+      shuffleMemoryManager,
+      new UnsafeShuffleHandle<Object, Object>(0, 1, dep),
+      0, // map id
+      taskContext,
+      sparkConf
+    );
+
+    final ArrayList<Product2<Object, Object>> numbersToSort =
+      new ArrayList<Product2<Object, Object>>();
+    numbersToSort.add(new Tuple2<Object, Object>(5, 5));
+    numbersToSort.add(new Tuple2<Object, Object>(1, 1));
+    numbersToSort.add(new Tuple2<Object, Object>(3, 3));
+    numbersToSort.add(new Tuple2<Object, Object>(2, 2));
+    numbersToSort.add(new Tuple2<Object, Object>(4, 4));
+
+    writer.write(numbersToSort.iterator());
+    final MapStatus mapStatus = writer.stop(true).get();
+
+    long sumOfPartitionSizes = 0;
+    for (long size: partitionSizes) {
+      sumOfPartitionSizes += size;
+    }
+    Assert.assertEquals(mergedOutputFile.length(), sumOfPartitionSizes);
+
+    // TODO: test that the temporary spill files were cleaned up after the merge.
+  }
+
+}
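One way the trailing TODO might be discharged, sketched here rather than part of the patch: the setUp() stubs create every spill file via File.createTempFile("spillFile", ".spill", tempDir), so after stop(true) a scan of tempDir for that suffix should come up empty. This is a hypothetical follow-up assertion, written against the naming convention the mocks above establish:

    // Hypothetical cleanup check for the TODO above; assumes all spill files
    // carry the ".spill" suffix assigned by the createTempLocalBlock() stub.
    for (File file : tempDir.listFiles()) {
      if (file.getName().endsWith(".spill")) {
        Assert.fail("Leaked spill file: " + file.getAbsolutePath());
      }
    }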