From b2294412bd75a83f8f55d531416cfc2bfff0111a Mon Sep 17 00:00:00 2001 From: Guoqiang Li Date: Tue, 16 Aug 2016 16:43:51 +0800 Subject: [PATCH] Replace ByteBuffer with ChunkedByteBuffer --- .../spark/network/buffer/Allocator.java | 24 ++ .../network/buffer/ChunkedByteBuffer.java | 292 ++++++++++++++++++ .../buffer/ChunkedByteBufferInputStream.java | 120 +++++++ .../buffer/ChunkedByteBufferOutputStream.java | 127 ++++++++ .../buffer/FileSegmentManagedBuffer.java | 18 +- .../spark/network/buffer/ManagedBuffer.java | 2 +- .../network/buffer/NettyManagedBuffer.java | 12 +- .../network/buffer/NioManagedBuffer.java | 18 +- .../network/client/RpcResponseCallback.java | 4 +- .../spark/network/client/TransportClient.java | 20 +- .../network/protocol/ChunkFetchSuccess.java | 3 + .../network/sasl/SaslClientBootstrap.java | 8 +- .../spark/network/sasl/SaslRpcHandler.java | 12 +- .../spark/network/server/NoOpRpcHandler.java | 4 +- .../spark/network/server/RpcHandler.java | 9 +- .../server/TransportRequestHandler.java | 3 +- .../network/ChunkFetchIntegrationSuite.java | 7 +- .../RequestTimeoutIntegrationSuite.java | 27 +- .../spark/network/RpcIntegrationSuite.java | 37 ++- .../org/apache/spark/network/StreamSuite.java | 3 +- .../spark/network/TestManagedBuffer.java | 5 +- .../TransportResponseHandlerSuite.java | 3 +- .../spark/network/sasl/SparkSaslSuite.java | 13 +- .../shuffle/ExternalShuffleBlockHandler.java | 7 +- .../shuffle/ExternalShuffleClient.java | 4 +- .../shuffle/OneForOneBlockFetcher.java | 5 +- .../mesos/MesosExternalShuffleClient.java | 8 +- .../protocol/BlockTransferMessage.java | 9 +- .../network/sasl/SaslIntegrationSuite.java | 21 +- .../shuffle/BlockTransferMessagesSuite.java | 3 +- .../ExternalShuffleBlockHandlerSuite.java | 20 +- .../ExternalShuffleIntegrationSuite.java | 4 +- .../shuffle/OneForOneBlockFetcherSuite.java | 7 +- .../serializer/DummySerializerInstance.java | 7 +- .../spark/broadcast/TorrentBroadcast.scala | 16 +- .../master/ZooKeeperPersistenceEngine.scala | 7 +- .../mesos/MesosExternalShuffleService.scala | 3 +- .../CoarseGrainedExecutorBackend.scala | 5 +- .../org/apache/spark/executor/Executor.scala | 17 +- .../spark/executor/ExecutorBackend.scala | 3 +- .../spark/executor/MesosExecutorBackend.scala | 7 +- .../spark/network/BlockTransferService.scala | 24 +- .../network/netty/NettyBlockRpcServer.scala | 10 +- .../netty/NettyBlockTransferService.scala | 11 +- .../apache/spark/rdd/PairRDDFunctions.scala | 13 +- .../apache/spark/rpc/netty/NettyRpcEnv.scala | 13 +- .../org/apache/spark/rpc/netty/Outbox.scala | 9 +- .../apache/spark/scheduler/DAGScheduler.scala | 5 +- .../apache/spark/scheduler/ResultTask.scala | 3 +- .../spark/scheduler/ShuffleMapTask.scala | 3 +- .../org/apache/spark/scheduler/Task.scala | 17 +- .../spark/scheduler/TaskDescription.scala | 8 +- .../apache/spark/scheduler/TaskResult.scala | 15 +- .../spark/scheduler/TaskResultGetter.scala | 20 +- .../spark/scheduler/TaskSchedulerImpl.scala | 3 +- .../spark/scheduler/TaskSetManager.scala | 9 +- .../cluster/CoarseGrainedClusterMessage.scala | 13 +- .../CoarseGrainedSchedulerBackend.scala | 8 +- .../MesosFineGrainedSchedulerBackend.scala | 7 +- .../local/LocalSchedulerBackend.scala | 5 +- .../spark/serializer/JavaSerializer.scala | 17 +- .../spark/serializer/KryoSerializer.scala | 27 +- .../apache/spark/serializer/Serializer.scala | 7 +- .../spark/serializer/SerializerManager.scala | 6 +- .../apache/spark/storage/BlockManager.scala | 24 +- .../storage/BlockManagerManagedBuffer.scala | 3 +- 
.../org/apache/spark/storage/DiskStore.scala | 2 +- .../spark/storage/memory/MemoryStore.scala | 10 +- .../spark/util/io/ChunkedByteBuffer.scala | 219 ------------- .../io/ChunkedByteBufferOutputStream.scala | 113 ------- .../serializer/TestJavaSerializerImpl.java | 16 +- .../org/apache/spark/DistributedSuite.scala | 4 +- .../spark/broadcast/BroadcastSuite.scala | 5 +- .../master/CustomRecoveryModeFactory.scala | 5 +- .../apache/spark/executor/ExecutorSuite.scala | 3 +- .../spark/io/ChunkedByteBufferSuite.scala | 12 +- .../scala/org/apache/spark/rdd/RDDSuite.scala | 4 +- .../rpc/netty/NettyRpcHandlerSuite.scala | 3 +- .../spark/scheduler/TaskContextSuite.scala | 4 +- .../scheduler/TaskResultGetterSuite.scala | 10 +- ...esosFineGrainedSchedulerBackendSuite.scala | 10 +- .../serializer/KryoSerializerSuite.scala | 2 +- .../spark/serializer/TestSerializer.scala | 9 +- .../BlockStoreShuffleReaderSuite.scala | 4 +- .../spark/storage/BlockManagerSuite.scala | 3 +- .../apache/spark/storage/DiskStoreSuite.scala | 2 +- .../spark/storage/MemoryStoreSuite.scala | 4 +- .../ChunkedByteBufferOutputStreamSuite.scala | 19 +- .../sql/execution/UnsafeRowSerializer.scala | 8 +- .../execution/streaming/HDFSMetadataLog.scala | 5 +- .../rdd/WriteAheadLogBackedBlockRDD.scala | 6 +- .../receiver/ReceivedBlockHandler.scala | 6 +- .../streaming/ReceivedBlockHandlerSuite.scala | 4 +- 93 files changed, 1027 insertions(+), 669 deletions(-) create mode 100644 common/network-common/src/main/java/org/apache/spark/network/buffer/Allocator.java create mode 100644 common/network-common/src/main/java/org/apache/spark/network/buffer/ChunkedByteBuffer.java create mode 100644 common/network-common/src/main/java/org/apache/spark/network/buffer/ChunkedByteBufferInputStream.java create mode 100644 common/network-common/src/main/java/org/apache/spark/network/buffer/ChunkedByteBufferOutputStream.java delete mode 100644 core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala delete mode 100644 core/src/main/scala/org/apache/spark/util/io/ChunkedByteBufferOutputStream.scala diff --git a/common/network-common/src/main/java/org/apache/spark/network/buffer/Allocator.java b/common/network-common/src/main/java/org/apache/spark/network/buffer/Allocator.java new file mode 100644 index 0000000000000..b8bde9cbafb11 --- /dev/null +++ b/common/network-common/src/main/java/org/apache/spark/network/buffer/Allocator.java @@ -0,0 +1,24 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.network.buffer; + +import java.nio.ByteBuffer; + +public interface Allocator { + ByteBuffer allocate(int len); +} \ No newline at end of file diff --git a/common/network-common/src/main/java/org/apache/spark/network/buffer/ChunkedByteBuffer.java b/common/network-common/src/main/java/org/apache/spark/network/buffer/ChunkedByteBuffer.java new file mode 100644 index 0000000000000..d5dd80faecc0d --- /dev/null +++ b/common/network-common/src/main/java/org/apache/spark/network/buffer/ChunkedByteBuffer.java @@ -0,0 +1,292 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.buffer; + +import java.io.*; +import java.nio.ByteBuffer; +import java.nio.MappedByteBuffer; +import java.nio.channels.WritableByteChannel; +import java.util.ArrayList; + +import sun.nio.ch.DirectBuffer; +import com.google.common.base.Preconditions; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.spark.network.util.ByteArrayWritableChannel; + +public class ChunkedByteBuffer implements Externalizable { + private static final Logger logger = LoggerFactory.getLogger(ChunkedByteBuffer.class); + private static final int BUF_SIZE = 0x1000; // 4K + private static final ByteBuffer[] emptyChunks = new ByteBuffer[0]; + private ByteBuffer[] chunks = null; + private boolean disposed = false; + + // For deserialization only + public ChunkedByteBuffer() { + this(emptyChunks); + } + + /** + * Read-only byte buffer which is physically stored as multiple chunks rather than a single + * contiguous array. + * + * @param chunks an array of [[ByteBuffer]]s. Each buffer in this array must have position == 0. + * Ownership of these buffers is transferred to the ChunkedByteBuffer, so if these + * buffers may also be used elsewhere then the caller is responsible for copying + * them as needed. 
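A minimal sketch of the ownership contract described in the constructor javadoc above (the class name and values are illustrative, not part of the patch):

import java.nio.ByteBuffer;
import org.apache.spark.network.buffer.ChunkedByteBuffer;

class ChunkOwnershipSketch {
  public static void main(String[] args) {
    ByteBuffer first = ByteBuffer.wrap(new byte[] {1, 2, 3});
    ByteBuffer shared = ByteBuffer.wrap(new byte[] {4, 5});
    // Every chunk must start at position == 0, and ownership passes to the
    // ChunkedByteBuffer. Hand over a duplicate of any buffer that is still read
    // elsewhere (or a deep copy if its contents may change) so the two users do
    // not disturb each other's position and limit.
    ChunkedByteBuffer cbb =
        new ChunkedByteBuffer(new ByteBuffer[] {first, shared.duplicate()});
    System.out.println(cbb.size()); // 5
  }
}
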
+ */ + public ChunkedByteBuffer(ByteBuffer[] chunks) { + this.chunks = chunks; + Preconditions.checkArgument(chunks != null, "chunks must not be null"); + for (int i = 0; i < chunks.length; i++) { + ByteBuffer bytes = chunks[i]; + // Preconditions.checkArgument(bytes.remaining() > 0, "chunks must be non-empty"); + } + } + + public ChunkedByteBuffer(ByteBuffer chunk) { + this.chunks = new ByteBuffer[1]; + this.chunks[0] = chunk; + } + + public void writeExternal(ObjectOutput out) throws IOException { + out.writeBoolean(disposed); + out.writeInt(chunks.length); + byte[] buf = null; + for (int i = 0; i < chunks.length; i++) { + ByteBuffer buffer = chunks[i].duplicate(); + out.writeInt(buffer.remaining()); + if (buffer.hasArray()) { + out.write(buffer.array(), buffer.arrayOffset() + buffer.position(), + buffer.remaining()); + } else { + if (buf == null) buf = new byte[BUF_SIZE]; + while (buffer.hasRemaining()) { + int r = Math.min(BUF_SIZE, buffer.remaining()); + buffer.get(buf, 0, r); + out.write(buf, 0, r); + } + } + } + } + + public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException { + this.disposed = in.readBoolean(); + ByteBuffer[] buffers = new ByteBuffer[in.readInt()]; + for (int i = 0; i < buffers.length; i++) { + int length = in.readInt(); + byte[] buffer = new byte[length]; + in.readFully(buffer); + buffers[i] = ByteBuffer.wrap(buffer); + } + this.chunks = buffers; + } + + /** + * This size of this buffer, in bytes. + */ + public long size() { + if (chunks == null) return 0L; + int i = 0; + long sum = 0L; + while (i < chunks.length) { + sum += chunks[i].remaining(); + i++; + } + return sum; + } + + /** + * Write this buffer to a channel. + */ + public void writeFully(WritableByteChannel channel) throws IOException { + for (int i = 0; i < chunks.length; i++) { + ByteBuffer bytes = chunks[i].duplicate(); + while (bytes.remaining() > 0) { + channel.write(bytes); + } + } + } + + /** + * Wrap this buffer to view it as a Netty ByteBuf. + */ + public ByteBuf toNetty() { + return Unpooled.wrappedBuffer(getChunks()); + } + + /** + * Copy this buffer into a new byte array. + * + * @throws UnsupportedOperationException if this buffer's size exceeds the maximum array size. + */ + public byte[] toArray() throws IOException, UnsupportedOperationException { + long len = size(); + if (len >= Integer.MAX_VALUE) { + throw new UnsupportedOperationException( + "cannot call toArray because buffer size (" + len + + " bytes) exceeds maximum array size"); + } + ByteArrayWritableChannel byteChannel = new ByteArrayWritableChannel((int) len); + writeFully(byteChannel); + byteChannel.close(); + return byteChannel.getData(); + } + + /** + * Copy this buffer into a new ByteBuffer. + * + * @throws UnsupportedOperationException if this buffer's size exceeds the max ByteBuffer size. + */ + public ByteBuffer toByteBuffer() throws IOException, UnsupportedOperationException { + if (chunks.length == 1) { + return chunks[0].duplicate(); + } else { + return ByteBuffer.wrap(this.toArray()); + } + } + + public ChunkedByteBufferInputStream toInputStream() { + return toInputStream(false); + } + + /** + * Creates an input stream to read data from this ChunkedByteBuffer. + * + * @param dispose if true, [[dispose()]] will be called at the end of the stream + * in order to close any memory-mapped files which back this buffer. 
+ */ + public ChunkedByteBufferInputStream toInputStream(boolean dispose) { + return new ChunkedByteBufferInputStream(this, dispose); + } + + /** + * Make a copy of this ChunkedByteBuffer, copying all of the backing data into new buffers. + * The new buffer will share no resources with the original buffer. + * + * @param allocator a method for allocating byte buffers + */ + public ChunkedByteBuffer copy(Allocator allocator) { + ByteBuffer[] copiedChunks = new ByteBuffer[chunks.length]; + for (int i = 0; i < chunks.length; i++) { + ByteBuffer chunk = chunks[i].duplicate(); + ByteBuffer newChunk = allocator.allocate(chunk.remaining()); + newChunk.put(chunk); + newChunk.flip(); + copiedChunks[i] = newChunk; + } + return new ChunkedByteBuffer(copiedChunks); + } + + /** + * Get duplicates of the ByteBuffers backing this ChunkedByteBuffer. + */ + public ByteBuffer[] getChunks() { + ByteBuffer[] buffs = new ByteBuffer[chunks.length]; + for (int i = 0; i < chunks.length; i++) { + buffs[i] = chunks[i].duplicate(); + } + return buffs; + } + + + /** + * Attempt to clean up a ByteBuffer if it is memory-mapped. This uses an *unsafe* Sun API that + * might cause errors if one attempts to read from the unmapped buffer, but it's better than + * waiting for the GC to find it because that could lead to huge numbers of open files. There's + * unfortunately no standard API to do this. + */ + public void dispose() { + if (!disposed) { + for (int i = 0; i < chunks.length; i++) { + dispose(chunks[i]); + } + disposed = true; + } + } + + public ChunkedByteBuffer slice(long offset, long length) { + long thisSize = size(); + if (offset < 0 || offset > thisSize - length) { + throw new IndexOutOfBoundsException(String.format( + "index: %d, length: %d (expected: range(0, %d))", offset, length, thisSize)); + } + if (length == 0) { + return wrap(new ByteBuffer[0]); + } + ArrayList list = new ArrayList(); + int i = 0; + long sum = 0L; + while (i < chunks.length && length > 0) { + long lastSum = sum + chunks[i].remaining(); + if (lastSum > offset) { + ByteBuffer buffer = chunks[i].duplicate(); + int localLength = (int) Math.min(length, buffer.remaining()); + if (localLength < buffer.remaining()) { + buffer.limit(buffer.position() + localLength); + } + length -= localLength; + list.add(buffer); + } + sum = lastSum; + i++; + } + return wrap(list.toArray(new ByteBuffer[list.size()])); + } + + public ChunkedByteBuffer duplicate() { + return new ChunkedByteBuffer(getChunks()); + } + + public static void dispose(ByteBuffer buffer) { + if (buffer != null && buffer instanceof MappedByteBuffer) { + logger.trace("Unmapping" + buffer); + if (buffer instanceof DirectBuffer) { + DirectBuffer directBuffer = (DirectBuffer) buffer; + if (directBuffer.cleaner() != null) directBuffer.cleaner().clean(); + } + } + } + + public static ChunkedByteBuffer wrap(ByteBuffer chunk) { + return new ChunkedByteBuffer(chunk); + } + + public static ChunkedByteBuffer wrap(ByteBuffer[] chunks) { + return new ChunkedByteBuffer(chunks); + } + + public static ChunkedByteBuffer wrap(byte[] array) { + return wrap(array, 0, array.length); + } + + public static ChunkedByteBuffer wrap(byte[] array, int offset, int length) { + return new ChunkedByteBuffer(ByteBuffer.wrap(array, offset, length)); + } + + public static ChunkedByteBuffer allocate(int capacity) { + return new ChunkedByteBuffer(ByteBuffer.allocate(capacity)); + } + + public static ChunkedByteBuffer allocate(int capacity, Allocator allocator) { + return new ChunkedByteBuffer(allocator.allocate(capacity)); 
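A usage sketch of the factory methods defined above (the class, the byte values, and the direct-memory allocator are assumptions for illustration, not part of the patch):

import java.nio.ByteBuffer;
import org.apache.spark.network.buffer.Allocator;
import org.apache.spark.network.buffer.ChunkedByteBuffer;

class ChunkedByteBufferFactorySketch {
  public static void main(String[] args) {
    // wrap() views existing bytes without copying them.
    ChunkedByteBuffer wrapped = ChunkedByteBuffer.wrap(new byte[] {10, 20, 30, 40});

    // allocate() with a custom Allocator decides where chunks live, e.g. off-heap.
    Allocator direct = new Allocator() {
      public ByteBuffer allocate(int len) {
        return ByteBuffer.allocateDirect(len);
      }
    };
    ChunkedByteBuffer offHeap = ChunkedByteBuffer.allocate(64, direct);

    System.out.println(wrapped.size());             // 4
    System.out.println(wrapped.slice(0, 2).size()); // 2, a view of the first two bytes
    System.out.println(offHeap.size());             // 64
  }
}
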
+ } +} diff --git a/common/network-common/src/main/java/org/apache/spark/network/buffer/ChunkedByteBufferInputStream.java b/common/network-common/src/main/java/org/apache/spark/network/buffer/ChunkedByteBufferInputStream.java new file mode 100644 index 0000000000000..c13c9cc6b47a8 --- /dev/null +++ b/common/network-common/src/main/java/org/apache/spark/network/buffer/ChunkedByteBufferInputStream.java @@ -0,0 +1,120 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.buffer; + +import java.io.IOException; +import java.io.InputStream; +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Iterator; + +import com.google.common.primitives.UnsignedBytes; + +public class ChunkedByteBufferInputStream extends InputStream { + + private ChunkedByteBuffer chunkedByteBuffer; + private boolean dispose; + private Iterator chunks; + private ByteBuffer currentChunk; + + /** + * Reads data from a ChunkedByteBuffer. + * + * @param dispose if true, [[ChunkedByteBuffer.dispose()]] will be called at the end of the stream + * in order to close any memory-mapped files which back the buffer. 
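A short sketch of consuming a buffer through this stream (the class and method names are hypothetical):

import java.io.InputStream;
import org.apache.spark.network.buffer.ChunkedByteBuffer;

class ChunkedReadSketch {
  // Drains `data` and returns the number of bytes read. With dispose = true the
  // stream calls data.dispose() at end-of-stream (or on close()), unmapping any
  // memory-mapped chunks behind it.
  static long drain(ChunkedByteBuffer data) throws Exception {
    long n = 0;
    InputStream in = data.toInputStream(true);
    try {
      while (in.read() != -1) {
        n++;
      }
    } finally {
      in.close();
    }
    return n;
  }
}
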
+ */ + public ChunkedByteBufferInputStream(ChunkedByteBuffer chunkedByteBuffer, boolean dispose) { + this.chunkedByteBuffer = chunkedByteBuffer; + this.dispose = dispose; + this.chunks = Arrays.asList(chunkedByteBuffer.getChunks()).iterator(); + if (chunks.hasNext()) { + currentChunk = chunks.next(); + } else { + currentChunk = null; + } + } + + public int read() throws IOException { + if (currentChunk != null && !currentChunk.hasRemaining() && chunks.hasNext()) { + currentChunk = chunks.next(); + } + if (currentChunk != null && currentChunk.hasRemaining()) { + return UnsignedBytes.toInt(currentChunk.get()); + } else { + close(); + return -1; + } + } + + public int read(byte[] dest, int offset, int length) throws IOException { + if (currentChunk != null && !currentChunk.hasRemaining() && chunks.hasNext()) { + currentChunk = chunks.next(); + } + if (currentChunk != null && currentChunk.hasRemaining()) { + int amountToGet = Math.min(currentChunk.remaining(), length); + currentChunk.get(dest, offset, amountToGet); + return amountToGet; + } else { + close(); + return -1; + } + } + + public long skip(long bytes) throws IOException { + if (currentChunk != null) { + int amountToSkip = (int) Math.min(bytes, currentChunk.remaining()); + currentChunk.position(currentChunk.position() + amountToSkip); + if (currentChunk.remaining() == 0) { + if (chunks.hasNext()) { + currentChunk = chunks.next(); + } else { + close(); + } + } + return amountToSkip; + } else { + return 0L; + } + } + + public void close() throws IOException { + if (chunkedByteBuffer != null && dispose) { + chunkedByteBuffer.dispose(); + } + chunkedByteBuffer = null; + chunks = null; + currentChunk = null; + } + + public ChunkedByteBuffer toChunkedByteBuffer() { + ArrayList list = new ArrayList(); + if (currentChunk != null && !currentChunk.hasRemaining() && chunks.hasNext()) { + currentChunk = chunks.next(); + } + while (currentChunk != null) { + list.add(currentChunk.slice()); + if (chunks.hasNext()) { + currentChunk = chunks.next(); + } else { + currentChunk = null; + } + } + return ChunkedByteBuffer.wrap(list.toArray(new ByteBuffer[list.size()])); + } +} diff --git a/common/network-common/src/main/java/org/apache/spark/network/buffer/ChunkedByteBufferOutputStream.java b/common/network-common/src/main/java/org/apache/spark/network/buffer/ChunkedByteBufferOutputStream.java new file mode 100644 index 0000000000000..b2ada2b631be6 --- /dev/null +++ b/common/network-common/src/main/java/org/apache/spark/network/buffer/ChunkedByteBufferOutputStream.java @@ -0,0 +1,127 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.network.buffer; + +import java.io.IOException; +import java.io.OutputStream; +import java.nio.ByteBuffer; +import java.util.ArrayList; + +import com.google.common.base.Preconditions; + +public class ChunkedByteBufferOutputStream extends OutputStream { + + private final int chunkSize; + private final Allocator allocator; + /** + * Next position to write in the last chunk. + * + * If this equals chunkSize, it means for next write we need to allocate a new chunk. + * This can also never be 0. + */ + private int position; + + private ArrayList chunks = new ArrayList(); + /** Index of the last chunk. Starting with -1 when the chunks array is empty. */ + private int lastChunkIndex = -1; + private boolean toChunkedByteBufferWasCalled = false; + private long _size = 0; + + /** + * An OutputStream that writes to fixed-size chunks of byte arrays. + * + * @param chunkSize size of each chunk, in bytes. + */ + public ChunkedByteBufferOutputStream(int chunkSize, Allocator allocator) { + this.chunkSize = chunkSize; + this.allocator = allocator; + this.position = chunkSize; + } + + public ChunkedByteBufferOutputStream(int chunkSize) { + this(chunkSize, new Allocator() { + public ByteBuffer allocate(int len) { + return ByteBuffer.allocate(len); + } + }); + } + + public long size() { + return _size; + } + + public void write(int b) throws IOException { + allocateNewChunkIfNeeded(); + chunks.get(lastChunkIndex).put((byte) b); + position += 1; + _size += 1; + } + + public void write(byte[] bytes, int off, int len) throws IOException { + int written = 0; + while (written < len) { + allocateNewChunkIfNeeded(); + int thisBatch = Math.min(chunkSize - position, len - written); + chunks.get(lastChunkIndex).put(bytes, written + off, thisBatch); + written += thisBatch; + position += thisBatch; + } + _size += len; + } + + private void allocateNewChunkIfNeeded() { + Preconditions.checkArgument(!toChunkedByteBufferWasCalled, + "cannot write after toChunkedByteBuffer() is called"); + if (position == chunkSize) { + chunks.add(allocator.allocate(chunkSize)); + lastChunkIndex += 1; + position = 0; + } + } + + public ChunkedByteBuffer toChunkedByteBuffer() { + Preconditions.checkArgument(!toChunkedByteBufferWasCalled, + "toChunkedByteBuffer() can only be called once"); + toChunkedByteBufferWasCalled = true; + if (lastChunkIndex == -1) { + return new ChunkedByteBuffer(new ByteBuffer[0]); + } else { + // Copy the first n-1 chunks to the output, and then create an array that fits the last chunk. + // An alternative would have been returning an array of ByteBuffers, with the last buffer + // bounded to only the last chunk's position. However, given our use case in Spark (to put + // the chunks in block manager), only limiting the view bound of the buffer would still + // require the block manager to store the whole chunk. 
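// A worked example of the branch below (the numbers are illustrative, not from the
// patch): with chunkSize = 32 and 70 bytes written, chunks 0 and 1 are full
// 32-byte buffers and are simply flipped, while the last chunk holds only 6 bytes
// (position == 6). The else-branch copies those 6 bytes into a right-sized 6-byte
// buffer and disposes the oversized original, so a consumer such as the block
// manager never has to hold the unused 26-byte tail.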
+ ByteBuffer[] ret = new ByteBuffer[chunks.size()]; + for (int i = 0; i < chunks.size() - 1; i++) { + ret[i] = chunks.get(i); + ret[i].flip(); + } + + if (position == chunkSize) { + ret[lastChunkIndex] = chunks.get(lastChunkIndex); + ret[lastChunkIndex].flip(); + } else { + ret[lastChunkIndex] = allocator.allocate(position); + chunks.get(lastChunkIndex).flip(); + ret[lastChunkIndex].put(chunks.get(lastChunkIndex)); + ret[lastChunkIndex].flip(); + ChunkedByteBuffer.dispose(chunks.get(lastChunkIndex)); + } + return new ChunkedByteBuffer(ret); + } + } +} diff --git a/common/network-common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java b/common/network-common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java index c20fab83c3460..098a1be395d66 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java +++ b/common/network-common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java @@ -55,7 +55,7 @@ public long size() { } @Override - public ByteBuffer nioByteBuffer() throws IOException { + public ChunkedByteBuffer nioByteBuffer() throws IOException { FileChannel channel = null; try { channel = new RandomAccessFile(file, "r").getChannel(); @@ -71,9 +71,21 @@ public ByteBuffer nioByteBuffer() throws IOException { } } buf.flip(); - return buf; + return new ChunkedByteBuffer(buf); } else { - return channel.map(FileChannel.MapMode.READ_ONLY, offset, length); + int pageSize = 32 * 1024; + int numPage = (int) Math.ceil((double) length / pageSize); + ByteBuffer[] buffers = new ByteBuffer[numPage]; + long len = length; + long off = offset; + for (int i = 0; i < buffers.length; i++) { + long pageLen = Math.min(len, pageSize); + buffers[i] = channel.map(FileChannel.MapMode.READ_ONLY, off, pageLen); + len -= pageLen; + off += pageLen; + } + return new ChunkedByteBuffer(buffers); + // return channel.map(FileChannel.MapMode.READ_ONLY, offset, length); } } catch (IOException e) { try { diff --git a/common/network-common/src/main/java/org/apache/spark/network/buffer/ManagedBuffer.java b/common/network-common/src/main/java/org/apache/spark/network/buffer/ManagedBuffer.java index 1861f8d7fd8f3..4a853f3cebe74 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/buffer/ManagedBuffer.java +++ b/common/network-common/src/main/java/org/apache/spark/network/buffer/ManagedBuffer.java @@ -44,7 +44,7 @@ public abstract class ManagedBuffer { * returned ByteBuffer should not affect the content of this buffer. */ // TODO: Deprecate this, usage may require expensive memory mapping or allocation. - public abstract ByteBuffer nioByteBuffer() throws IOException; + public abstract ChunkedByteBuffer nioByteBuffer() throws IOException; /** * Exposes this buffer's data as an InputStream. 
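To make the paging in the FileSegmentManagedBuffer hunk above concrete (the segment length here is an assumed example; only the 32 KB page size comes from the patch): a 100,000-byte segment is mapped as ceil(100000 / 32768) = 4 read-only MappedByteBuffers of 32768, 32768, 32768 and 1696 bytes, so no single mapping, and therefore no single chunk of the resulting ChunkedByteBuffer, is larger than one page.
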
The underlying implementation does not diff --git a/common/network-common/src/main/java/org/apache/spark/network/buffer/NettyManagedBuffer.java b/common/network-common/src/main/java/org/apache/spark/network/buffer/NettyManagedBuffer.java index acc49d968c186..006031460ccbb 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/buffer/NettyManagedBuffer.java +++ b/common/network-common/src/main/java/org/apache/spark/network/buffer/NettyManagedBuffer.java @@ -24,6 +24,7 @@ import com.google.common.base.Objects; import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufInputStream; +import io.netty.buffer.CompositeByteBuf; /** * A {@link ManagedBuffer} backed by a Netty {@link ByteBuf}. @@ -41,8 +42,15 @@ public long size() { } @Override - public ByteBuffer nioByteBuffer() throws IOException { - return buf.nioBuffer(); + public ChunkedByteBuffer nioByteBuffer() throws IOException { + if (buf instanceof CompositeByteBuf) { + CompositeByteBuf compositeByteBuf = (CompositeByteBuf) buf; + ByteBuffer[] buffers = compositeByteBuf.nioBuffers(compositeByteBuf.readerIndex(), + compositeByteBuf.readableBytes()); + return new ChunkedByteBuffer(buffers); + } else { + return new ChunkedByteBuffer(buf.nioBuffer()); + } } @Override diff --git a/common/network-common/src/main/java/org/apache/spark/network/buffer/NioManagedBuffer.java b/common/network-common/src/main/java/org/apache/spark/network/buffer/NioManagedBuffer.java index 631d767715256..31693a1deba2c 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/buffer/NioManagedBuffer.java +++ b/common/network-common/src/main/java/org/apache/spark/network/buffer/NioManagedBuffer.java @@ -29,25 +29,29 @@ * A {@link ManagedBuffer} backed by {@link ByteBuffer}. */ public class NioManagedBuffer extends ManagedBuffer { - private final ByteBuffer buf; + private final ChunkedByteBuffer buf; - public NioManagedBuffer(ByteBuffer buf) { + public NioManagedBuffer(ChunkedByteBuffer buf) { this.buf = buf; } + public NioManagedBuffer(ByteBuffer buf) { + this(new ChunkedByteBuffer(buf)); + } + @Override public long size() { - return buf.remaining(); + return buf.size(); } @Override - public ByteBuffer nioByteBuffer() throws IOException { - return buf.duplicate(); + public ChunkedByteBuffer nioByteBuffer() throws IOException { + return new ChunkedByteBuffer(buf.getChunks()); } @Override public InputStream createInputStream() throws IOException { - return new ByteBufInputStream(Unpooled.wrappedBuffer(buf)); + return buf.toInputStream(); } @Override @@ -62,7 +66,7 @@ public ManagedBuffer release() { @Override public Object convertToNetty() throws IOException { - return Unpooled.wrappedBuffer(buf); + return buf.toNetty(); } @Override diff --git a/common/network-common/src/main/java/org/apache/spark/network/client/RpcResponseCallback.java b/common/network-common/src/main/java/org/apache/spark/network/client/RpcResponseCallback.java index 6afc63f71bb3d..f3a649900d179 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/client/RpcResponseCallback.java +++ b/common/network-common/src/main/java/org/apache/spark/network/client/RpcResponseCallback.java @@ -17,6 +17,8 @@ package org.apache.spark.network.client; +import org.apache.spark.network.buffer.ChunkedByteBuffer; + import java.nio.ByteBuffer; /** @@ -30,7 +32,7 @@ public interface RpcResponseCallback { * After `onSuccess` returns, `response` will be recycled and its content will become invalid. 
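A minimal sketch of a callback under the new signature, copying the payload out before it is recycled (the class is hypothetical; only the interface comes from the patch):

import java.io.IOException;
import org.apache.spark.network.buffer.ChunkedByteBuffer;
import org.apache.spark.network.client.RpcResponseCallback;

class CopyingRpcCallback implements RpcResponseCallback {
  private volatile byte[] copied;

  @Override
  public void onSuccess(ChunkedByteBuffer response) {
    try {
      // Copy the bytes out before returning: the response buffer is recycled
      // once onSuccess completes.
      copied = response.toArray();
    } catch (IOException e) {
      onFailure(e);
    }
  }

  @Override
  public void onFailure(Throwable e) {
    // Surface the error however the caller expects (log, propagate, etc.).
  }
}
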
* Please copy the content of `response` if you want to use it after `onSuccess` returns. */ - void onSuccess(ByteBuffer response); + void onSuccess(ChunkedByteBuffer response); /** Exception either propagated from server or raised on client side. */ void onFailure(Throwable e); diff --git a/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java b/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java index 64a83171e9e90..defa58a6dc257 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java +++ b/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java @@ -34,6 +34,7 @@ import io.netty.channel.Channel; import io.netty.channel.ChannelFuture; import io.netty.channel.ChannelFutureListener; +import org.apache.spark.network.buffer.ChunkedByteBuffer; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -214,7 +215,7 @@ public void operationComplete(ChannelFuture future) throws Exception { * @param callback Callback to handle the RPC's reply. * @return The RPC's id. */ - public long sendRpc(ByteBuffer message, final RpcResponseCallback callback) { + public long sendRpc(ChunkedByteBuffer message, final RpcResponseCallback callback) { final String serverAddr = NettyUtils.getRemoteAddress(channel); final long startTime = System.currentTimeMillis(); logger.trace("Sending RPC to {}", serverAddr); @@ -251,17 +252,12 @@ public void operationComplete(ChannelFuture future) throws Exception { * Synchronously sends an opaque message to the RpcHandler on the server-side, waiting for up to * a specified timeout for a response. */ - public ByteBuffer sendRpcSync(ByteBuffer message, long timeoutMs) { - final SettableFuture result = SettableFuture.create(); - + public ChunkedByteBuffer sendRpcSync(ChunkedByteBuffer message, long timeoutMs) { + final SettableFuture result = SettableFuture.create(); sendRpc(message, new RpcResponseCallback() { @Override - public void onSuccess(ByteBuffer response) { - ByteBuffer copy = ByteBuffer.allocate(response.remaining()); - copy.put(response); - // flip "copy" to make it readable - copy.flip(); - result.set(copy); + public void onSuccess(ChunkedByteBuffer response) { + result.set(response); } @Override @@ -285,14 +281,14 @@ public void onFailure(Throwable e) { * * @param message The message to send. */ - public void send(ByteBuffer message) { + public void send(ChunkedByteBuffer message) { channel.writeAndFlush(new OneWayMessage(new NioManagedBuffer(message))); } /** * Removes any state associated with the given RPC. * - * @param requestId The RPC id returned by {@link #sendRpc(ByteBuffer, RpcResponseCallback)}. + * @param requestId The RPC id returned by {@link #sendRpc(ChunkedByteBuffer, RpcResponseCallback)}. 
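A sketch of the synchronous client path with the new types (the class name, payload, and timeout are illustrative assumptions):

import org.apache.spark.network.buffer.ChunkedByteBuffer;
import org.apache.spark.network.client.TransportClient;

class SyncRpcSketch {
  // Sends a one-byte control message and waits up to 30 seconds for the reply.
  static byte[] ping(TransportClient client) throws Exception {
    ChunkedByteBuffer request = ChunkedByteBuffer.wrap(new byte[] {0x01});
    ChunkedByteBuffer reply = client.sendRpcSync(request, 30000);
    return reply.toArray();
  }
}
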
*/ public void removeRpcRequest(long requestId) { handler.removeRpcRequest(requestId); diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/ChunkFetchSuccess.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/ChunkFetchSuccess.java index 94c2ac9b20e43..9fa030e8e5e59 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/protocol/ChunkFetchSuccess.java +++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/ChunkFetchSuccess.java @@ -18,10 +18,13 @@ package org.apache.spark.network.protocol; import com.google.common.base.Objects; +import com.sun.corba.se.impl.ior.ByteBuffer; import io.netty.buffer.ByteBuf; +import org.apache.spark.network.buffer.ChunkedByteBuffer; import org.apache.spark.network.buffer.ManagedBuffer; import org.apache.spark.network.buffer.NettyManagedBuffer; +import org.apache.spark.network.buffer.NioManagedBuffer; /** * Response to {@link ChunkFetchRequest} when a chunk exists and has been successfully fetched. diff --git a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java index 68381037d6891..0ad1e55618596 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java +++ b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java @@ -25,6 +25,7 @@ import io.netty.buffer.ByteBuf; import io.netty.buffer.Unpooled; import io.netty.channel.Channel; +import org.apache.spark.network.buffer.ChunkedByteBuffer; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -75,10 +76,11 @@ public void doBootstrap(TransportClient client, Channel channel) { SaslMessage msg = new SaslMessage(appId, payload); ByteBuf buf = Unpooled.buffer(msg.encodedLength() + (int) msg.body().size()); msg.encode(buf); - buf.writeBytes(msg.body().nioByteBuffer()); + buf.writeBytes(msg.body().nioByteBuffer().toArray()); - ByteBuffer response = client.sendRpcSync(buf.nioBuffer(), conf.saslRTTimeoutMs()); - payload = saslClient.response(JavaUtils.bufferToArray(response)); + ChunkedByteBuffer response = client.sendRpcSync(ChunkedByteBuffer.wrap(buf.nioBuffer()), + conf.saslRTTimeoutMs()); + payload = saslClient.response(response.toArray()); } client.setClientId(appId); diff --git a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java index c41f5b6873f6c..554d76803fb23 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java +++ b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java @@ -27,6 +27,7 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.apache.spark.network.buffer.ChunkedByteBuffer; import org.apache.spark.network.client.RpcResponseCallback; import org.apache.spark.network.client.TransportClient; import org.apache.spark.network.server.RpcHandler; @@ -74,14 +75,14 @@ class SaslRpcHandler extends RpcHandler { } @Override - public void receive(TransportClient client, ByteBuffer message, RpcResponseCallback callback) { + public void receive(TransportClient client, ChunkedByteBuffer message, RpcResponseCallback callback) { if (isComplete) { // Authentication complete, delegate to base handler. 
delegate.receive(client, message, callback); return; } - ByteBuf nettyBuf = Unpooled.wrappedBuffer(message); + ByteBuf nettyBuf = message.toNetty(); SaslMessage saslMessage; try { saslMessage = SaslMessage.decode(nettyBuf); @@ -98,12 +99,11 @@ public void receive(TransportClient client, ByteBuffer message, RpcResponseCallb byte[] response; try { - response = saslServer.response(JavaUtils.bufferToArray( - saslMessage.body().nioByteBuffer())); + response = saslServer.response(saslMessage.body().nioByteBuffer().toArray()); } catch (IOException ioe) { throw new RuntimeException(ioe); } - callback.onSuccess(ByteBuffer.wrap(response)); + callback.onSuccess(ChunkedByteBuffer.wrap(response)); // Setup encryption after the SASL response is sent, otherwise the client can't parse the // response. It's ok to change the channel pipeline here since we are processing an incoming @@ -125,7 +125,7 @@ public void receive(TransportClient client, ByteBuffer message, RpcResponseCallb } @Override - public void receive(TransportClient client, ByteBuffer message) { + public void receive(TransportClient client, ChunkedByteBuffer message) { delegate.receive(client, message); } diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/NoOpRpcHandler.java b/common/network-common/src/main/java/org/apache/spark/network/server/NoOpRpcHandler.java index 6ed61da5c7eff..7ab8019730c4d 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/server/NoOpRpcHandler.java +++ b/common/network-common/src/main/java/org/apache/spark/network/server/NoOpRpcHandler.java @@ -19,6 +19,7 @@ import java.nio.ByteBuffer; +import org.apache.spark.network.buffer.ChunkedByteBuffer; import org.apache.spark.network.client.RpcResponseCallback; import org.apache.spark.network.client.TransportClient; @@ -31,7 +32,8 @@ public NoOpRpcHandler() { } @Override - public void receive(TransportClient client, ByteBuffer message, RpcResponseCallback callback) { + public void receive(TransportClient client, ChunkedByteBuffer message, + RpcResponseCallback callback) { throw new UnsupportedOperationException("Cannot handle messages"); } diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/RpcHandler.java b/common/network-common/src/main/java/org/apache/spark/network/server/RpcHandler.java index a99c3015b0e05..494f3056e2ad9 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/server/RpcHandler.java +++ b/common/network-common/src/main/java/org/apache/spark/network/server/RpcHandler.java @@ -22,6 +22,7 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.apache.spark.network.buffer.ChunkedByteBuffer; import org.apache.spark.network.client.RpcResponseCallback; import org.apache.spark.network.client.TransportClient; @@ -46,7 +47,7 @@ public abstract class RpcHandler { */ public abstract void receive( TransportClient client, - ByteBuffer message, + ChunkedByteBuffer message, RpcResponseCallback callback); /** @@ -57,14 +58,14 @@ public abstract void receive( /** * Receives an RPC message that does not expect a reply. The default implementation will - * call "{@link #receive(TransportClient, ByteBuffer, RpcResponseCallback)}" and log a warning if + * call "{@link #receive(TransportClient, ChunkedByteBuffer, RpcResponseCallback)}" and log a warning if * any of the callback methods are called. * * @param client A channel client which enables the handler to make requests back to the sender * of this RPC. 
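On the server side, a handler now receives and replies with ChunkedByteBuffer as well. A minimal echo handler might look like the sketch below (the class is hypothetical, mirroring the TestRpcHandler used in the test changes further down in this patch):

import org.apache.spark.network.buffer.ChunkedByteBuffer;
import org.apache.spark.network.client.RpcResponseCallback;
import org.apache.spark.network.client.TransportClient;
import org.apache.spark.network.server.OneForOneStreamManager;
import org.apache.spark.network.server.RpcHandler;
import org.apache.spark.network.server.StreamManager;

class EchoRpcHandler extends RpcHandler {
  private final StreamManager streamManager = new OneForOneStreamManager();

  @Override
  public void receive(TransportClient client, ChunkedByteBuffer message,
      RpcResponseCallback callback) {
    // Echo the request straight back to the caller.
    callback.onSuccess(message);
  }

  @Override
  public StreamManager getStreamManager() {
    return streamManager;
  }
}
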
This will always be the exact same object for a particular channel. * @param message The serialized bytes of the RPC. */ - public void receive(TransportClient client, ByteBuffer message) { + public void receive(TransportClient client, ChunkedByteBuffer message) { receive(client, message, ONE_WAY_CALLBACK); } @@ -86,7 +87,7 @@ private static class OneWayRpcCallback implements RpcResponseCallback { private final Logger logger = LoggerFactory.getLogger(OneWayRpcCallback.class); @Override - public void onSuccess(ByteBuffer response) { + public void onSuccess(ChunkedByteBuffer response) { logger.warn("Response provided for one-way RPC."); } diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java b/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java index bebe88ec5d503..37eef44907800 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java +++ b/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java @@ -26,6 +26,7 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.apache.spark.network.buffer.ChunkedByteBuffer; import org.apache.spark.network.buffer.ManagedBuffer; import org.apache.spark.network.buffer.NioManagedBuffer; import org.apache.spark.network.client.RpcResponseCallback; @@ -157,7 +158,7 @@ private void processRpcRequest(final RpcRequest req) { try { rpcHandler.receive(reverseClient, req.body().nioByteBuffer(), new RpcResponseCallback() { @Override - public void onSuccess(ByteBuffer response) { + public void onSuccess(ChunkedByteBuffer response) { respond(new RpcResponse(req.requestId, new NioManagedBuffer(response))); } diff --git a/common/network-common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java b/common/network-common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java index 6d62eaf35d8cc..aa4d7f3c4b532 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java @@ -32,6 +32,7 @@ import com.google.common.collect.Lists; import com.google.common.collect.Sets; import com.google.common.io.Closeables; +import org.apache.spark.network.buffer.ChunkedByteBuffer; import org.junit.AfterClass; import org.junit.BeforeClass; import org.junit.Test; @@ -107,7 +108,7 @@ public ManagedBuffer getChunk(long streamId, int chunkIndex) { @Override public void receive( TransportClient client, - ByteBuffer message, + ChunkedByteBuffer message, RpcResponseCallback callback) { throw new UnsupportedOperationException(); } @@ -230,8 +231,8 @@ private void assertBufferListsEqual(List list0, ListnewArrayList()); - TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); + TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), + server.getPort()); try { - client.sendRpcSync(ByteBuffer.allocate(13), TIMEOUT_MS); + client.sendRpcSync(ChunkedByteBuffer.wrap(ByteBuffer.allocate(13)), TIMEOUT_MS); fail("Should have failed"); } catch (Exception e) { assertTrue(e.getMessage(), e.getMessage().contains("Expected SaslMessage")); @@ -147,7 +150,8 @@ public void testNoSaslClient() throws IOException { try { // Guessing the right tag byte doesn't magically get you in... 
- client.sendRpcSync(ByteBuffer.wrap(new byte[] { (byte) 0xEA }), TIMEOUT_MS); + client.sendRpcSync(ChunkedByteBuffer.wrap(new byte[] { (byte) 0xEA }), + TIMEOUT_MS); fail("Should have failed"); } catch (Exception e) { assertTrue(e.getMessage(), e.getMessage().contains("java.lang.IndexOutOfBoundsException")); @@ -223,12 +227,13 @@ public void onBlockFetchFailure(String blockId, Throwable t) { new String[] { System.getProperty("java.io.tmpdir") }, 1, "org.apache.spark.shuffle.sort.SortShuffleManager"); RegisterExecutor regmsg = new RegisterExecutor("app-1", "0", executorInfo); - client1.sendRpcSync(regmsg.toByteBuffer(), TIMEOUT_MS); + client1.sendRpcSync(ChunkedByteBuffer.wrap(regmsg.toByteBuffer()), TIMEOUT_MS); // Make a successful request to fetch blocks, which creates a new stream. But do not actually // fetch any blocks, to keep the stream open. OpenBlocks openMessage = new OpenBlocks("app-1", "0", blockIds); - ByteBuffer response = client1.sendRpcSync(openMessage.toByteBuffer(), TIMEOUT_MS); + ChunkedByteBuffer response = client1.sendRpcSync(openMessage.toChunkedByteBuffer(), + TIMEOUT_MS); StreamHandle stream = (StreamHandle) BlockTransferMessage.Decoder.fromByteBuffer(response); long streamId = stream.streamId; @@ -274,7 +279,7 @@ public void onFailure(int chunkIndex, Throwable t) { /** RPC handler which simply responds with the message it received. */ public static class TestRpcHandler extends RpcHandler { @Override - public void receive(TransportClient client, ByteBuffer message, RpcResponseCallback callback) { + public void receive(TransportClient client, ChunkedByteBuffer message, RpcResponseCallback callback) { callback.onSuccess(message); } diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/BlockTransferMessagesSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/BlockTransferMessagesSuite.java index 86c8609e7070b..a1af34defc72f 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/BlockTransferMessagesSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/BlockTransferMessagesSuite.java @@ -36,7 +36,8 @@ public void serializeOpenShuffleBlocks() { } private void checkSerializeDeserialize(BlockTransferMessage msg) { - BlockTransferMessage msg2 = BlockTransferMessage.Decoder.fromByteBuffer(msg.toByteBuffer()); + BlockTransferMessage msg2 = BlockTransferMessage.Decoder. 
+ fromByteBuffer(msg.toChunkedByteBuffer()); assertEquals(msg, msg2); assertEquals(msg.hashCode(), msg2.hashCode()); assertEquals(msg.toString(), msg2.toString()); diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java index c036bc2e8d256..d9b48189066c7 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java @@ -30,6 +30,7 @@ import static org.mockito.Matchers.any; import static org.mockito.Mockito.*; +import org.apache.spark.network.buffer.ChunkedByteBuffer; import org.apache.spark.network.buffer.ManagedBuffer; import org.apache.spark.network.buffer.NioManagedBuffer; import org.apache.spark.network.client.RpcResponseCallback; @@ -62,11 +63,12 @@ public void testRegisterExecutor() { RpcResponseCallback callback = mock(RpcResponseCallback.class); ExecutorShuffleInfo config = new ExecutorShuffleInfo(new String[] {"/a", "/b"}, 16, "sort"); - ByteBuffer registerMessage = new RegisterExecutor("app0", "exec1", config).toByteBuffer(); + ChunkedByteBuffer registerMessage = new RegisterExecutor("app0", "exec1", config). + toChunkedByteBuffer(); handler.receive(client, registerMessage, callback); verify(blockResolver, times(1)).registerExecutor("app0", "exec1", config); - verify(callback, times(1)).onSuccess(any(ByteBuffer.class)); + verify(callback, times(1)).onSuccess(any(ChunkedByteBuffer.class)); verify(callback, never()).onFailure(any(Throwable.class)); // Verify register executor request latency metrics Timer registerExecutorRequestLatencyMillis = (Timer) ((ExternalShuffleBlockHandler) handler) @@ -85,13 +87,13 @@ public void testOpenShuffleBlocks() { ManagedBuffer block1Marker = new NioManagedBuffer(ByteBuffer.wrap(new byte[7])); when(blockResolver.getBlockData("app0", "exec1", "b0")).thenReturn(block0Marker); when(blockResolver.getBlockData("app0", "exec1", "b1")).thenReturn(block1Marker); - ByteBuffer openBlocks = new OpenBlocks("app0", "exec1", new String[] { "b0", "b1" }) - .toByteBuffer(); + ChunkedByteBuffer openBlocks = new OpenBlocks("app0", "exec1", new String[] { "b0", "b1" }) + .toChunkedByteBuffer(); handler.receive(client, openBlocks, callback); verify(blockResolver, times(1)).getBlockData("app0", "exec1", "b0"); verify(blockResolver, times(1)).getBlockData("app0", "exec1", "b1"); - ArgumentCaptor response = ArgumentCaptor.forClass(ByteBuffer.class); + ArgumentCaptor response = ArgumentCaptor.forClass(ChunkedByteBuffer.class); verify(callback, times(1)).onSuccess(response.capture()); verify(callback, never()).onFailure((Throwable) any()); @@ -126,7 +128,7 @@ public void testOpenShuffleBlocks() { public void testBadMessages() { RpcResponseCallback callback = mock(RpcResponseCallback.class); - ByteBuffer unserializableMsg = ByteBuffer.wrap(new byte[] { 0x12, 0x34, 0x56 }); + ChunkedByteBuffer unserializableMsg =ChunkedByteBuffer.wrap(new byte[] { 0x12, 0x34, 0x56 }); try { handler.receive(client, unserializableMsg, callback); fail("Should have thrown"); @@ -134,8 +136,8 @@ public void testBadMessages() { // pass } - ByteBuffer unexpectedMsg = new UploadBlock("a", "e", "b", new byte[1], - new byte[2]).toByteBuffer(); + ChunkedByteBuffer unexpectedMsg = new UploadBlock("a", "e", "b", new byte[1], + new byte[2]).toChunkedByteBuffer(); 
try { handler.receive(client, unexpectedMsg, callback); fail("Should have thrown"); @@ -143,7 +145,7 @@ public void testBadMessages() { // pass } - verify(callback, never()).onSuccess(any(ByteBuffer.class)); + verify(callback, never()).onSuccess(any(ChunkedByteBuffer.class)); verify(callback, never()).onFailure(any(Throwable.class)); } } diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java index 552b5366c5930..8dd77528790a5 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java @@ -256,8 +256,8 @@ private void assertBufferListsEqual(List list0, List list } private void assertBuffersEqual(ManagedBuffer buffer0, ManagedBuffer buffer1) throws Exception { - ByteBuffer nio0 = buffer0.nioByteBuffer(); - ByteBuffer nio1 = buffer1.nioByteBuffer(); + ByteBuffer nio0 = buffer0.nioByteBuffer().toByteBuffer(); + ByteBuffer nio1 = buffer1.nioByteBuffer().toByteBuffer(); int len = nio0.remaining(); assertEquals(nio0.remaining(), nio1.remaining()); diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java index 2590b9ce4c1f1..71cf398f76428 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java @@ -39,6 +39,7 @@ import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; +import org.apache.spark.network.buffer.ChunkedByteBuffer; import org.apache.spark.network.buffer.ManagedBuffer; import org.apache.spark.network.buffer.NettyManagedBuffer; import org.apache.spark.network.buffer.NioManagedBuffer; @@ -135,13 +136,13 @@ private BlockFetchingListener fetchBlocks(final LinkedHashMap ByteBuffer serialize(T t, ClassTag ev1) { + public ChunkedByteBuffer serialize(T t, ClassTag ev1) { throw new UnsupportedOperationException(); } @@ -81,12 +82,12 @@ public DeserializationStream deserializeStream(InputStream s) { } @Override - public T deserialize(ByteBuffer bytes, ClassLoader loader, ClassTag ev1) { + public T deserialize(ChunkedByteBuffer bytes, ClassLoader loader, ClassTag ev1) { throw new UnsupportedOperationException(); } @Override - public T deserialize(ByteBuffer bytes, ClassTag ev1) { + public T deserialize(ChunkedByteBuffer bytes, ClassTag ev1) { throw new UnsupportedOperationException(); } } diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala index e8d6d587b4824..684d2d3966991 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala @@ -27,10 +27,10 @@ import scala.util.Random import org.apache.spark._ import org.apache.spark.internal.Logging import org.apache.spark.io.CompressionCodec +import org.apache.spark.network.buffer.{Allocator, ChunkedByteBuffer, ChunkedByteBufferOutputStream} import org.apache.spark.serializer.Serializer import org.apache.spark.storage.{BlockId, BroadcastBlockId, StorageLevel} import 
org.apache.spark.util.{ByteBufferInputStream, Utils} -import org.apache.spark.util.io.{ChunkedByteBuffer, ChunkedByteBufferOutputStream} /** * A BitTorrent-like implementation of [[org.apache.spark.broadcast.Broadcast]]. @@ -107,7 +107,7 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long) TorrentBroadcast.blockifyObject(value, blockSize, SparkEnv.get.serializer, compressionCodec) blocks.zipWithIndex.foreach { case (block, i) => val pieceId = BroadcastBlockId(id, "piece" + i) - val bytes = new ChunkedByteBuffer(block.duplicate()) + val bytes = ChunkedByteBuffer.wrap(block.duplicate()) if (!blockManager.putBytes(pieceId, bytes, MEMORY_AND_DISK_SER, tellMaster = true)) { throw new SparkException(s"Failed to store $pieceId of $broadcastId in local BlockManager") } @@ -183,7 +183,7 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long) case None => logInfo("Started reading broadcast variable " + id) val startTimeMs = System.currentTimeMillis() - val blocks = readBlocks().flatMap(_.getChunks()) + val blocks = readBlocks() logInfo("Reading broadcast variable " + id + " took" + Utils.getUsedTimeMs(startTimeMs)) val obj = TorrentBroadcast.unBlockifyObject[T]( @@ -220,7 +220,6 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long) } - private object TorrentBroadcast extends Logging { def blockifyObject[T: ClassTag]( @@ -228,7 +227,9 @@ private object TorrentBroadcast extends Logging { blockSize: Int, serializer: Serializer, compressionCodec: Option[CompressionCodec]): Array[ByteBuffer] = { - val cbbos = new ChunkedByteBufferOutputStream(blockSize, ByteBuffer.allocate) + val cbbos = new ChunkedByteBufferOutputStream(blockSize, new Allocator { + override def allocate(len: Int) = ByteBuffer.allocate(len) + }) val out = compressionCodec.map(c => c.compressedOutputStream(cbbos)).getOrElse(cbbos) val ser = serializer.newInstance() val serOut = ser.serializeStream(out) @@ -241,12 +242,11 @@ private object TorrentBroadcast extends Logging { } def unBlockifyObject[T: ClassTag]( - blocks: Array[ByteBuffer], + blocks: Array[ChunkedByteBuffer], serializer: Serializer, compressionCodec: Option[CompressionCodec]): T = { require(blocks.nonEmpty, "Cannot unblockify an empty array of blocks") - val is = new SequenceInputStream( - blocks.iterator.map(new ByteBufferInputStream(_)).asJavaEnumeration) + val is = ChunkedByteBuffer.wrap(blocks.flatMap(_.getChunks)).toInputStream val in: InputStream = compressionCodec.map(c => c.compressedInputStream(is)).getOrElse(is) val ser = serializer.newInstance() val serIn = ser.deserializeStream(in) diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala index af850e4871e57..365afb59444c7 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala @@ -28,6 +28,7 @@ import org.apache.zookeeper.CreateMode import org.apache.spark.SparkConf import org.apache.spark.deploy.SparkCuratorUtil import org.apache.spark.internal.Logging +import org.apache.spark.network.buffer.ChunkedByteBuffer import org.apache.spark.serializer.Serializer @@ -51,7 +52,7 @@ private[master] class ZooKeeperPersistenceEngine(conf: SparkConf, val serializer override def read[T: ClassTag](prefix: String): Seq[T] = { zk.getChildren.forPath(WORKING_DIR).asScala - 
.filter(_.startsWith(prefix)).flatMap(deserializeFromFile[T]) + .filter(_.startsWith(prefix)).flatMap(t => deserializeFromFile[T](t)) } override def close() { @@ -59,7 +60,7 @@ private[master] class ZooKeeperPersistenceEngine(conf: SparkConf, val serializer } private def serializeIntoFile(path: String, value: AnyRef) { - val serialized = serializer.newInstance().serialize(value) + val serialized = serializer.newInstance().serialize(value).toByteBuffer val bytes = new Array[Byte](serialized.remaining()) serialized.get(bytes) zk.create().withMode(CreateMode.PERSISTENT).forPath(path, bytes) @@ -68,7 +69,7 @@ private[master] class ZooKeeperPersistenceEngine(conf: SparkConf, val serializer private def deserializeFromFile[T](filename: String)(implicit m: ClassTag[T]): Option[T] = { val fileData = zk.getData().forPath(WORKING_DIR + "/" + filename) try { - Some(serializer.newInstance().deserialize[T](ByteBuffer.wrap(fileData))) + Some(serializer.newInstance().deserialize[T](ChunkedByteBuffer.wrap(fileData))) } catch { case e: Exception => logWarning("Exception while reading persisted file, deleting", e) diff --git a/core/src/main/scala/org/apache/spark/deploy/mesos/MesosExternalShuffleService.scala b/core/src/main/scala/org/apache/spark/deploy/mesos/MesosExternalShuffleService.scala index 6b297c4600a68..e61844bcb09ed 100644 --- a/core/src/main/scala/org/apache/spark/deploy/mesos/MesosExternalShuffleService.scala +++ b/core/src/main/scala/org/apache/spark/deploy/mesos/MesosExternalShuffleService.scala @@ -25,6 +25,7 @@ import scala.collection.JavaConverters._ import org.apache.spark.{SecurityManager, SparkConf} import org.apache.spark.deploy.ExternalShuffleService import org.apache.spark.internal.Logging +import org.apache.spark.network.buffer.ChunkedByteBuffer import org.apache.spark.network.client.{RpcResponseCallback, TransportClient} import org.apache.spark.network.shuffle.ExternalShuffleBlockHandler import org.apache.spark.network.shuffle.protocol.BlockTransferMessage @@ -62,7 +63,7 @@ private[mesos] class MesosExternalShuffleBlockHandler( s"registered") } connectedApps.put(appId, appState) - callback.onSuccess(ByteBuffer.allocate(0)) + callback.onSuccess(ChunkedByteBuffer.wrap(ByteBuffer.allocate(0))) case Heartbeat(appId) => val address = client.getSocketAddress Option(connectedApps.get(appId)) match { diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index 391b97d73e026..bb31f6705eb41 100644 --- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -30,6 +30,7 @@ import org.apache.spark.TaskState.TaskState import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.deploy.worker.WorkerWatcher import org.apache.spark.internal.Logging +import org.apache.spark.network.buffer.ChunkedByteBuffer import org.apache.spark.rpc._ import org.apache.spark.scheduler.TaskDescription import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ @@ -92,7 +93,7 @@ private[spark] class CoarseGrainedExecutorBackend( if (executor == null) { exitExecutor(1, "Received LaunchTask command but executor was null") } else { - val taskDesc = ser.deserialize[TaskDescription](data.value) + val taskDesc = ser.deserialize[TaskDescription](data) logInfo("Got assigned task " + taskDesc.taskId) executor.launchTask(this, taskId = taskDesc.taskId, attemptNumber = 
taskDesc.attemptNumber, taskDesc.name, taskDesc.serializedTask) @@ -135,7 +136,7 @@ private[spark] class CoarseGrainedExecutorBackend( } } - override def statusUpdate(taskId: Long, state: TaskState, data: ByteBuffer) { + override def statusUpdate(taskId: Long, state: TaskState, data: ChunkedByteBuffer) { val msg = StatusUpdate(executorId, taskId, state, data) driver match { case Some(driverRef) => driverRef.send(msg) diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index fbf2b86db1a2e..61a11e9c73947 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -33,12 +33,12 @@ import org.apache.spark._ import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.Logging import org.apache.spark.memory.TaskMemoryManager +import org.apache.spark.network.buffer.ChunkedByteBuffer import org.apache.spark.rpc.RpcTimeout import org.apache.spark.scheduler.{AccumulableInfo, DirectTaskResult, IndirectTaskResult, Task} import org.apache.spark.shuffle.FetchFailedException import org.apache.spark.storage.{StorageLevel, TaskResultBlockId} import org.apache.spark.util._ -import org.apache.spark.util.io.ChunkedByteBuffer /** * Spark executor, backed by a threadpool to run tasks. @@ -62,7 +62,7 @@ private[spark] class Executor( private val currentFiles: HashMap[String, Long] = new HashMap[String, Long]() private val currentJars: HashMap[String, Long] = new HashMap[String, Long]() - private val EMPTY_BYTE_BUFFER = ByteBuffer.wrap(new Array[Byte](0)) + private val EMPTY_BYTE_BUFFER = ChunkedByteBuffer.wrap(new Array[Byte](0)) private val conf = env.conf @@ -140,7 +140,7 @@ private[spark] class Executor( taskId: Long, attemptNumber: Int, taskName: String, - serializedTask: ByteBuffer): Unit = { + serializedTask: ChunkedByteBuffer): Unit = { val tr = new TaskRunner(context, taskId = taskId, attemptNumber = attemptNumber, taskName, serializedTask) runningTasks.put(taskId, tr) @@ -189,7 +189,7 @@ private[spark] class Executor( val taskId: Long, val attemptNumber: Int, taskName: String, - serializedTask: ByteBuffer) + serializedTask: ChunkedByteBuffer) extends Runnable { /** Whether this task has been killed. */ @@ -327,20 +327,21 @@ private[spark] class Executor( // TODO: do not serialize value twice val directResult = new DirectTaskResult(valueBytes, accumUpdates) val serializedDirectResult = ser.serialize(directResult) - val resultSize = serializedDirectResult.limit + val resultSize = serializedDirectResult.size().toInt // directSend = sending directly back to the driver - val serializedResult: ByteBuffer = { + val serializedResult: ChunkedByteBuffer = { if (maxResultSize > 0 && resultSize > maxResultSize) { logWarning(s"Finished $taskName (TID $taskId). Result is larger than maxResultSize " + s"(${Utils.bytesToString(resultSize)} > ${Utils.bytesToString(maxResultSize)}), " + s"dropping it.") - ser.serialize(new IndirectTaskResult[Any](TaskResultBlockId(taskId), resultSize)) + ser.serialize(new IndirectTaskResult[Any](TaskResultBlockId(taskId), + resultSize)) } else if (resultSize > maxDirectResultSize) { val blockId = TaskResultBlockId(taskId) env.blockManager.putBytes( blockId, - new ChunkedByteBuffer(serializedDirectResult.duplicate()), + serializedDirectResult, StorageLevel.MEMORY_AND_DISK_SER) logInfo( s"Finished $taskName (TID $taskId). 
$resultSize bytes result sent via BlockManager)") diff --git a/core/src/main/scala/org/apache/spark/executor/ExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/ExecutorBackend.scala index 7153323d01a0b..a29bc1b3f6068 100644 --- a/core/src/main/scala/org/apache/spark/executor/ExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/ExecutorBackend.scala @@ -20,11 +20,12 @@ package org.apache.spark.executor import java.nio.ByteBuffer import org.apache.spark.TaskState.TaskState +import org.apache.spark.network.buffer.ChunkedByteBuffer /** * A pluggable interface used by the Executor to send updates to the cluster scheduler. */ private[spark] trait ExecutorBackend { - def statusUpdate(taskId: Long, state: TaskState, data: ByteBuffer): Unit + def statusUpdate(taskId: Long, state: TaskState, data: ChunkedByteBuffer): Unit } diff --git a/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala index 680cfb733e9e6..bc95d64579bc5 100644 --- a/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala @@ -29,6 +29,7 @@ import org.apache.spark.{SparkConf, SparkEnv, TaskState} import org.apache.spark.TaskState.TaskState import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.Logging +import org.apache.spark.network.buffer.ChunkedByteBuffer import org.apache.spark.scheduler.cluster.mesos.MesosTaskLaunchData import org.apache.spark.util.Utils @@ -40,12 +41,12 @@ private[spark] class MesosExecutorBackend var executor: Executor = null var driver: ExecutorDriver = null - override def statusUpdate(taskId: Long, state: TaskState, data: ByteBuffer) { + override def statusUpdate(taskId: Long, state: TaskState, data: ChunkedByteBuffer) { val mesosTaskId = TaskID.newBuilder().setValue(taskId.toString).build() driver.sendStatusUpdate(MesosTaskStatus.newBuilder() .setTaskId(mesosTaskId) .setState(TaskState.toMesos(state)) - .setData(ByteString.copyFrom(data)) + .setData(ByteString.copyFrom(data.toByteBuffer)) .build()) } @@ -90,7 +91,7 @@ private[spark] class MesosExecutorBackend } else { SparkHadoopUtil.get.runAsSparkUser { () => executor.launchTask(this, taskId = taskId, attemptNumber = taskData.attemptNumber, - taskInfo.getName, taskData.serializedTask) + taskInfo.getName, ChunkedByteBuffer.wrap(taskData.serializedTask)) } } } diff --git a/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala index cb9d389dd7ea6..4f2bd16f2e941 100644 --- a/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala @@ -25,10 +25,11 @@ import scala.concurrent.duration.Duration import scala.reflect.ClassTag import org.apache.spark.internal.Logging -import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer} +import org.apache.spark.network.buffer.{ChunkedByteBufferOutputStream, ManagedBuffer, NioManagedBuffer} import org.apache.spark.network.shuffle.{BlockFetchingListener, ShuffleClient} import org.apache.spark.storage.{BlockId, StorageLevel} -import org.apache.spark.util.ThreadUtils +import org.apache.spark.SparkException +import org.apache.spark.util.{ThreadUtils, Utils} private[spark] abstract class BlockTransferService extends ShuffleClient with Closeable with Logging { @@ -95,13 +96,22 @@ 
abstract class BlockTransferService extends ShuffleClient with Closeable with Lo result.failure(exception) } override def onBlockFetchSuccess(blockId: String, data: ManagedBuffer): Unit = { - val ret = ByteBuffer.allocate(data.size.toInt) - ret.put(data.nioByteBuffer()) - ret.flip() - result.success(new NioManagedBuffer(ret)) + result.success(data.retain()) } }) - ThreadUtils.awaitResult(result.future, Duration.Inf) + val data = ThreadUtils.awaitResult(result.future, Duration.Inf) + val dataSize = data.size() + val chunkSize = math.min(data.size(), 32 * 1024).toInt + val out = new ChunkedByteBufferOutputStream(chunkSize) + try { + Utils.copyStream(data.createInputStream(), out, closeStreams = true) + if (out.size() != dataSize) { + throw new SparkException(s"buffer size ${out.size()} but expected $dataSize") + } + } finally { + data.release() + } + new NioManagedBuffer(out.toChunkedByteBuffer) } /** diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala index 2ed8a00df7023..52fb7e1051bd5 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala @@ -25,7 +25,7 @@ import scala.reflect.ClassTag import org.apache.spark.internal.Logging import org.apache.spark.network.BlockDataManager -import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer} +import org.apache.spark.network.buffer.{ChunkedByteBuffer, ManagedBuffer, NioManagedBuffer} import org.apache.spark.network.client.{RpcResponseCallback, TransportClient} import org.apache.spark.network.server.{OneForOneStreamManager, RpcHandler, StreamManager} import org.apache.spark.network.shuffle.protocol.{BlockTransferMessage, OpenBlocks, StreamHandle, UploadBlock} @@ -49,7 +49,7 @@ class NettyBlockRpcServer( override def receive( client: TransportClient, - rpcMessage: ByteBuffer, + rpcMessage: ChunkedByteBuffer, responseContext: RpcResponseCallback): Unit = { val message = BlockTransferMessage.Decoder.fromByteBuffer(rpcMessage) logTrace(s"Received request: $message") @@ -60,20 +60,20 @@ class NettyBlockRpcServer( openBlocks.blockIds.map(BlockId.apply).map(blockManager.getBlockData) val streamId = streamManager.registerStream(appId, blocks.iterator.asJava) logTrace(s"Registered streamId $streamId with ${blocks.size} buffers") - responseContext.onSuccess(new StreamHandle(streamId, blocks.size).toByteBuffer) + responseContext.onSuccess(new StreamHandle(streamId, blocks.size).toChunkedByteBuffer) case uploadBlock: UploadBlock => // StorageLevel and ClassTag are serialized as bytes using our JavaSerializer. 
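For reference, a minimal sketch of the copy pattern the new fetchBlockSync uses above: drain the fetched ManagedBuffer's stream into a ChunkedByteBufferOutputStream and rewrap the result. It assumes the single-argument ChunkedByteBufferOutputStream constructor and the createInputStream/release/size calls shown in this hunk.

import org.apache.spark.SparkException
import org.apache.spark.network.buffer.{ChunkedByteBufferOutputStream, ManagedBuffer, NioManagedBuffer}
import org.apache.spark.util.Utils

// Copy a fetched ManagedBuffer into chunked form without allocating one large ByteBuffer.
def toChunkedManagedBuffer(data: ManagedBuffer, chunkSize: Int = 32 * 1024): ManagedBuffer = {
  val expected = data.size()
  val out = new ChunkedByteBufferOutputStream(chunkSize)  // chunk size in bytes (assumed ctor)
  try {
    Utils.copyStream(data.createInputStream(), out, closeStreams = true)
    if (out.size() != expected) {
      throw new SparkException(s"buffer size ${out.size()} but expected $expected")
    }
  } finally {
    data.release()  // balance the retain() taken when the buffer was handed over
  }
  new NioManagedBuffer(out.toChunkedByteBuffer)
}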
val (level: StorageLevel, classTag: ClassTag[_]) = { serializer .newInstance() - .deserialize(ByteBuffer.wrap(uploadBlock.metadata)) + .deserialize(ChunkedByteBuffer.wrap(uploadBlock.metadata)) .asInstanceOf[(StorageLevel, ClassTag[_])] } val data = new NioManagedBuffer(ByteBuffer.wrap(uploadBlock.blockData)) val blockId = BlockId(uploadBlock.blockId) blockManager.putBlockData(blockId, data, level, classTag) - responseContext.onSuccess(ByteBuffer.allocate(0)) + responseContext.onSuccess(ChunkedByteBuffer.allocate(0)) } } diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala index 33a3219607749..1e30105c2915d 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala @@ -25,7 +25,7 @@ import scala.reflect.ClassTag import org.apache.spark.{SecurityManager, SparkConf} import org.apache.spark.network._ -import org.apache.spark.network.buffer.ManagedBuffer +import org.apache.spark.network.buffer.{ChunkedByteBuffer, ManagedBuffer} import org.apache.spark.network.client.{RpcResponseCallback, TransportClientBootstrap, TransportClientFactory} import org.apache.spark.network.sasl.{SaslClientBootstrap, SaslServerBootstrap} import org.apache.spark.network.server._ @@ -128,14 +128,15 @@ private[spark] class NettyBlockTransferService( // StorageLevel and ClassTag are serialized as bytes using our JavaSerializer. // Everything else is encoded using our binary protocol. - val metadata = JavaUtils.bufferToArray(serializer.newInstance().serialize((level, classTag))) + val metadata = serializer.newInstance().serialize((level, classTag)).toArray // Convert or copy nio buffer into array in order to serialize it. - val array = JavaUtils.bufferToArray(blockData.nioByteBuffer()) + val array = blockData.nioByteBuffer().toArray - client.sendRpc(new UploadBlock(appId, execId, blockId.toString, metadata, array).toByteBuffer, + client.sendRpc(new UploadBlock(appId, execId, blockId.toString, metadata, array). 
+ toChunkedByteBuffer, new RpcResponseCallback { - override def onSuccess(response: ByteBuffer): Unit = { + override def onSuccess(response: ChunkedByteBuffer): Unit = { logTrace(s"Successfully uploaded block $blockId") result.success((): Unit) } diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala index 7d6a8805bc016..cec978d4cb44b 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala @@ -42,6 +42,7 @@ import org.apache.spark.annotation.Experimental import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.executor.OutputMetrics import org.apache.spark.internal.Logging +import org.apache.spark.network.buffer.ChunkedByteBuffer import org.apache.spark.partial.{BoundedDouble, PartialResult} import org.apache.spark.serializer.Serializer import org.apache.spark.util.{SerializableConfiguration, Utils} @@ -162,12 +163,10 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) def aggregateByKey[U: ClassTag](zeroValue: U, partitioner: Partitioner)(seqOp: (U, V) => U, combOp: (U, U) => U): RDD[(K, U)] = self.withScope { // Serialize the zero value to a byte array so that we can get a new clone of it on each key - val zeroBuffer = SparkEnv.get.serializer.newInstance().serialize(zeroValue) - val zeroArray = new Array[Byte](zeroBuffer.limit) - zeroBuffer.get(zeroArray) + val zeroArray = SparkEnv.get.serializer.newInstance().serialize(zeroValue).toArray lazy val cachedSerializer = SparkEnv.get.serializer.newInstance() - val createZero = () => cachedSerializer.deserialize[U](ByteBuffer.wrap(zeroArray)) + val createZero = () => cachedSerializer.deserialize[U](ChunkedByteBuffer.wrap(zeroArray)) // We will clean the combiner closure later in `combineByKey` val cleanedSeqOp = self.context.clean(seqOp) @@ -212,13 +211,11 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) zeroValue: V, partitioner: Partitioner)(func: (V, V) => V): RDD[(K, V)] = self.withScope { // Serialize the zero value to a byte array so that we can get a new clone of it on each key - val zeroBuffer = SparkEnv.get.serializer.newInstance().serialize(zeroValue) - val zeroArray = new Array[Byte](zeroBuffer.limit) - zeroBuffer.get(zeroArray) + val zeroArray = SparkEnv.get.serializer.newInstance().serialize(zeroValue).toArray // When deserializing, use a lazy val to create just one instance of the serializer per task lazy val cachedSerializer = SparkEnv.get.serializer.newInstance() - val createZero = () => cachedSerializer.deserialize[V](ByteBuffer.wrap(zeroArray)) + val createZero = () => cachedSerializer.deserialize[V](ChunkedByteBuffer.wrap(zeroArray)) val cleanedFunc = self.context.clean(func) combineByKeyWithClassTag[V]((v: V) => cleanedFunc(createZero(), v), diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala index 89d2fb9b47971..2fae168902e44 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala @@ -32,6 +32,7 @@ import scala.util.control.NonFatal import org.apache.spark.{SecurityManager, SparkConf} import org.apache.spark.internal.Logging import org.apache.spark.network.TransportContext +import org.apache.spark.network.buffer.ChunkedByteBuffer import org.apache.spark.network.client._ import org.apache.spark.network.netty.SparkTransportConf import 
org.apache.spark.network.sasl.{SaslClientBootstrap, SaslServerBootstrap} @@ -249,11 +250,12 @@ private[netty] class NettyRpcEnv( promise.future.mapTo[T].recover(timeout.addMessageIfTimeout)(ThreadUtils.sameThread) } - private[netty] def serialize(content: Any): ByteBuffer = { + private[netty] def serialize(content: Any): ChunkedByteBuffer = { javaSerializerInstance.serialize(content) } - private[netty] def deserialize[T: ClassTag](client: TransportClient, bytes: ByteBuffer): T = { + private[netty] def deserialize[T: ClassTag](client: TransportClient, + bytes: ChunkedByteBuffer): T = { NettyRpcEnv.currentClient.withValue(client) { deserialize { () => javaSerializerInstance.deserialize[T](bytes) @@ -558,7 +560,7 @@ private[netty] class NettyRpcHandler( override def receive( client: TransportClient, - message: ByteBuffer, + message: ChunkedByteBuffer, callback: RpcResponseCallback): Unit = { val messageToDispatch = internalReceive(client, message) dispatcher.postRemoteMessage(messageToDispatch, callback) @@ -566,12 +568,13 @@ private[netty] class NettyRpcHandler( override def receive( client: TransportClient, - message: ByteBuffer): Unit = { + message: ChunkedByteBuffer): Unit = { val messageToDispatch = internalReceive(client, message) dispatcher.postOneWayMessage(messageToDispatch) } - private def internalReceive(client: TransportClient, message: ByteBuffer): RequestMessage = { + private def internalReceive(client: TransportClient, + message: ChunkedByteBuffer): RequestMessage = { val addr = client.getChannel().remoteAddress().asInstanceOf[InetSocketAddress] assert(addr != null) val clientAddr = RpcAddress(addr.getHostString, addr.getPort) diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/Outbox.scala b/core/src/main/scala/org/apache/spark/rpc/netty/Outbox.scala index 6c090ada5ae9d..35610e3536665 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/Outbox.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/Outbox.scala @@ -25,6 +25,7 @@ import scala.util.control.NonFatal import org.apache.spark.SparkException import org.apache.spark.internal.Logging +import org.apache.spark.network.buffer.ChunkedByteBuffer import org.apache.spark.network.client.{RpcResponseCallback, TransportClient} import org.apache.spark.rpc.{RpcAddress, RpcEnvStoppedException} @@ -36,7 +37,7 @@ private[netty] sealed trait OutboxMessage { } -private[netty] case class OneWayOutboxMessage(content: ByteBuffer) extends OutboxMessage +private[netty] case class OneWayOutboxMessage(content: ChunkedByteBuffer) extends OutboxMessage with Logging { override def sendWith(client: TransportClient): Unit = { @@ -53,9 +54,9 @@ private[netty] case class OneWayOutboxMessage(content: ByteBuffer) extends Outbo } private[netty] case class RpcOutboxMessage( - content: ByteBuffer, + content: ChunkedByteBuffer, _onFailure: (Throwable) => Unit, - _onSuccess: (TransportClient, ByteBuffer) => Unit) + _onSuccess: (TransportClient, ChunkedByteBuffer) => Unit) extends OutboxMessage with RpcResponseCallback { private var client: TransportClient = _ @@ -75,7 +76,7 @@ private[netty] case class RpcOutboxMessage( _onFailure(e) } - override def onSuccess(response: ByteBuffer): Unit = { + override def onSuccess(response: ChunkedByteBuffer): Unit = { _onSuccess(client, response) } diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 4eb7c81f9e8cc..667ebc7b4a857 100644 --- 
a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -987,10 +987,9 @@ class DAGScheduler( // For ResultTask, serialize and broadcast (rdd, func). val taskBinaryBytes: Array[Byte] = stage match { case stage: ShuffleMapStage => - JavaUtils.bufferToArray( - closureSerializer.serialize((stage.rdd, stage.shuffleDep): AnyRef)) + closureSerializer.serialize((stage.rdd, stage.shuffleDep): AnyRef).toArray case stage: ResultStage => - JavaUtils.bufferToArray(closureSerializer.serialize((stage.rdd, stage.func): AnyRef)) + closureSerializer.serialize((stage.rdd, stage.func): AnyRef).toArray } taskBinary = sc.broadcast(taskBinaryBytes) diff --git a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala index 75c6018e214d8..318723f5ddef7 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala @@ -24,6 +24,7 @@ import java.util.Properties import org.apache.spark._ import org.apache.spark.broadcast.Broadcast import org.apache.spark.executor.TaskMetrics +import org.apache.spark.network.buffer.ChunkedByteBuffer import org.apache.spark.rdd.RDD /** @@ -64,7 +65,7 @@ private[spark] class ResultTask[T, U]( val deserializeStartTime = System.currentTimeMillis() val ser = SparkEnv.get.closureSerializer.newInstance() val (rdd, func) = ser.deserialize[(RDD[T], (TaskContext, Iterator[T]) => U)]( - ByteBuffer.wrap(taskBinary.value), Thread.currentThread.getContextClassLoader) + ChunkedByteBuffer.wrap(taskBinary.value), Thread.currentThread.getContextClassLoader) _executorDeserializeTime = System.currentTimeMillis() - deserializeStartTime func(context, rdd.iterator(partition, context)) diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala index 84b3e5ba6c1f3..31690d4bbfdb7 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala @@ -26,6 +26,7 @@ import org.apache.spark._ import org.apache.spark.broadcast.Broadcast import org.apache.spark.executor.TaskMetrics import org.apache.spark.internal.Logging +import org.apache.spark.network.buffer.ChunkedByteBuffer import org.apache.spark.rdd.RDD import org.apache.spark.shuffle.ShuffleWriter @@ -69,7 +70,7 @@ private[spark] class ShuffleMapTask( val deserializeStartTime = System.currentTimeMillis() val ser = SparkEnv.get.closureSerializer.newInstance() val (rdd, dep) = ser.deserialize[(RDD[_], ShuffleDependency[_, _, _])]( - ByteBuffer.wrap(taskBinary.value), Thread.currentThread.getContextClassLoader) + ChunkedByteBuffer.wrap(taskBinary.value), Thread.currentThread.getContextClassLoader) _executorDeserializeTime = System.currentTimeMillis() - deserializeStartTime var writer: ShuffleWriter[Any, Any] = null diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala index 35c4dafe9c19c..93af467e016c0 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala @@ -28,6 +28,7 @@ import org.apache.spark._ import org.apache.spark.executor.TaskMetrics import org.apache.spark.memory.{MemoryMode, TaskMemoryManager} import org.apache.spark.metrics.MetricsSystem +import 
org.apache.spark.network.buffer.{ChunkedByteBuffer, ChunkedByteBufferOutputStream} import org.apache.spark.serializer.SerializerInstance import org.apache.spark.util.{AccumulatorV2, ByteBufferInputStream, ByteBufferOutputStream, Utils} @@ -202,9 +203,9 @@ private[spark] object Task { currentFiles: mutable.Map[String, Long], currentJars: mutable.Map[String, Long], serializer: SerializerInstance) - : ByteBuffer = { + : ChunkedByteBuffer = { - val out = new ByteBufferOutputStream(4096) + val out = new ChunkedByteBufferOutputStream(4 * 1024) val dataOut = new DataOutputStream(out) // Write currentFiles @@ -228,9 +229,9 @@ private[spark] object Task { // Write the task itself and finish dataOut.flush() - val taskBytes = serializer.serialize(task) + val taskBytes = serializer.serialize(task).toByteBuffer Utils.writeByteBuffer(taskBytes, out) - out.toByteBuffer + out.toChunkedByteBuffer } /** @@ -240,10 +241,10 @@ private[spark] object Task { * * @return (taskFiles, taskJars, taskBytes) */ - def deserializeWithDependencies(serializedTask: ByteBuffer) - : (HashMap[String, Long], HashMap[String, Long], Properties, ByteBuffer) = { + def deserializeWithDependencies(serializedTask: ChunkedByteBuffer) + : (HashMap[String, Long], HashMap[String, Long], Properties, ChunkedByteBuffer) = { - val in = new ByteBufferInputStream(serializedTask) + val in = serializedTask.toInputStream val dataIn = new DataInputStream(in) // Read task's files @@ -266,7 +267,7 @@ private[spark] object Task { val taskProps = Utils.deserialize[Properties](propBytes) // Create a sub-buffer for the rest of the data, which is the serialized Task object - val subBuffer = serializedTask.slice() // ByteBufferInputStream will have read just up to task + val subBuffer = in.toChunkedByteBuffer (taskFiles, taskJars, taskProps, subBuffer) } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskDescription.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskDescription.scala index 1c7c81c488c3a..991cb9998e59e 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskDescription.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskDescription.scala @@ -19,6 +19,7 @@ package org.apache.spark.scheduler import java.nio.ByteBuffer +import org.apache.spark.network.buffer.ChunkedByteBuffer import org.apache.spark.util.SerializableBuffer /** @@ -31,13 +32,8 @@ private[spark] class TaskDescription( val executorId: String, val name: String, val index: Int, // Index within this task's TaskSet - _serializedTask: ByteBuffer) + val serializedTask: ChunkedByteBuffer) extends Serializable { - // Because ByteBuffers are not serializable, wrap the task in a SerializableBuffer - private val buffer = new SerializableBuffer(_serializedTask) - - def serializedTask: ByteBuffer = buffer.value - override def toString: String = "TaskDescription(TID=%d, index=%d)".format(taskId, index) } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala index 77fda6fcff959..d366ea260ea49 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala @@ -23,6 +23,7 @@ import java.nio.ByteBuffer import scala.collection.mutable.ArrayBuffer import org.apache.spark.SparkEnv +import org.apache.spark.network.buffer.ChunkedByteBuffer import org.apache.spark.storage.BlockId import org.apache.spark.util.{AccumulatorV2, Utils} @@ -35,28 +36,24 @@ private[spark] case class 
IndirectTaskResult[T](blockId: BlockId, size: Int) /** A TaskResult that contains the task's return value and accumulator updates. */ private[spark] class DirectTaskResult[T]( - var valueBytes: ByteBuffer, + var valueBytes: ChunkedByteBuffer, var accumUpdates: Seq[AccumulatorV2[_, _]]) extends TaskResult[T] with Externalizable { private var valueObjectDeserialized = false private var valueObject: T = _ - def this() = this(null.asInstanceOf[ByteBuffer], null) + def this() = this(null.asInstanceOf[ChunkedByteBuffer], null) override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException { - out.writeInt(valueBytes.remaining) - Utils.writeByteBuffer(valueBytes, out) + valueBytes.writeExternal(out) out.writeInt(accumUpdates.size) accumUpdates.foreach(out.writeObject) } override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException { - val blen = in.readInt() - val byteVal = new Array[Byte](blen) - in.readFully(byteVal) - valueBytes = ByteBuffer.wrap(byteVal) - + valueBytes = new ChunkedByteBuffer() + valueBytes.readExternal(in) val numUpdates = in.readInt if (numUpdates == 0) { accumUpdates = Seq() diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala index 685ef55c66876..7c8041963ed8b 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala @@ -26,6 +26,7 @@ import scala.util.control.NonFatal import org.apache.spark._ import org.apache.spark.TaskState.TaskState import org.apache.spark.internal.Logging +import org.apache.spark.network.buffer.ChunkedByteBuffer import org.apache.spark.serializer.SerializerInstance import org.apache.spark.util.{LongAccumulator, ThreadUtils, Utils} @@ -51,20 +52,20 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedul def enqueueSuccessfulTask( taskSetManager: TaskSetManager, tid: Long, - serializedData: ByteBuffer): Unit = { + serializedData: ChunkedByteBuffer): Unit = { getTaskResultExecutor.execute(new Runnable { override def run(): Unit = Utils.logUncaughtExceptions { try { val (result, size) = serializer.get().deserialize[TaskResult[_]](serializedData) match { case directResult: DirectTaskResult[_] => - if (!taskSetManager.canFetchMoreResults(serializedData.limit())) { + if (!taskSetManager.canFetchMoreResults(serializedData.size())) { return } // deserialize "value" without holding any lock so that it won't block other threads. // We should call it here, so that when it's called again in // "TaskSetManager.handleSuccessfulTask", it does not need to deserialize the value. directResult.value() - (directResult, serializedData.limit()) + (directResult, serializedData.size()) case IndirectTaskResult(blockId, size) => if (!taskSetManager.canFetchMoreResults(size)) { // dropped by executor if size is larger than maxResultSize @@ -83,9 +84,9 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedul return } val deserializedResult = serializer.get().deserialize[DirectTaskResult[_]]( - serializedTaskResult.get.toByteBuffer) + serializedTaskResult.get) sparkEnv.blockManager.master.removeBlock(blockId) - (deserializedResult, size) + (deserializedResult, size.toLong) } // Set the task result size in the accumulator updates received from the executors. 
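A condensed view of the round trip the DirectTaskResult change above relies on, assuming the new ChunkedByteBuffer provides the writeExternal/readExternal methods this hunk calls:

import java.io.{ObjectInput, ObjectOutput}
import org.apache.spark.network.buffer.ChunkedByteBuffer

// The task result bytes now serialize themselves chunk by chunk instead of being
// copied through an intermediate Array[Byte] with an explicit length prefix.
def writeValueBytes(out: ObjectOutput, valueBytes: ChunkedByteBuffer): Unit = {
  valueBytes.writeExternal(out)  // assumed to write the chunk layout plus the bytes
}

def readValueBytes(in: ObjectInput): ChunkedByteBuffer = {
  val valueBytes = new ChunkedByteBuffer()  // empty buffer, repopulated from the stream
  valueBytes.readExternal(in)
  valueBytes
}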
@@ -95,7 +96,7 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedul if (a.name == Some(InternalAccumulator.RESULT_SIZE)) { val acc = a.asInstanceOf[LongAccumulator] assert(acc.sum == 0L, "task result size should not have been set on the executors") - acc.setValue(size.toLong) + acc.setValue(size) acc } else { a @@ -117,16 +118,15 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedul } def enqueueFailedTask(taskSetManager: TaskSetManager, tid: Long, taskState: TaskState, - serializedData: ByteBuffer) { + serializedData: ChunkedByteBuffer) { var reason : TaskEndReason = UnknownReason try { getTaskResultExecutor.execute(new Runnable { override def run(): Unit = Utils.logUncaughtExceptions { val loader = Utils.getContextOrSparkClassLoader try { - if (serializedData != null && serializedData.limit() > 0) { - reason = serializer.get().deserialize[TaskEndReason]( - serializedData, loader) + if (serializedData != null && serializedData.size() > 0) { + reason = serializer.get().deserialize[TaskEndReason](serializedData, loader) } } catch { case cnd: ClassNotFoundException => diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index dc05e764c3951..4cc24a0ec7e00 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -30,6 +30,7 @@ import scala.util.Random import org.apache.spark._ import org.apache.spark.TaskState.TaskState import org.apache.spark.internal.Logging +import org.apache.spark.network.buffer.ChunkedByteBuffer import org.apache.spark.scheduler.SchedulingMode.SchedulingMode import org.apache.spark.scheduler.TaskLocality.TaskLocality import org.apache.spark.scheduler.local.LocalSchedulerBackend @@ -340,7 +341,7 @@ private[spark] class TaskSchedulerImpl( return tasks } - def statusUpdate(tid: Long, state: TaskState, serializedData: ByteBuffer) { + def statusUpdate(tid: Long, state: TaskState, serializedData: ChunkedByteBuffer) { var failedExecutor: Option[String] = None synchronized { try { diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index 2fef447b0a3c1..b6c3105e576c5 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -32,6 +32,7 @@ import org.apache.spark._ import org.apache.spark.internal.Logging import org.apache.spark.scheduler.SchedulingMode._ import org.apache.spark.TaskState.TaskState +import org.apache.spark.network.buffer.ChunkedByteBuffer import org.apache.spark.util.{AccumulatorV2, Clock, SystemClock, Utils} /** @@ -454,7 +455,7 @@ private[spark] class TaskSetManager( } // Serialize and return the task val startTime = clock.getTimeMillis() - val serializedTask: ByteBuffer = try { + val serializedTask: ChunkedByteBuffer = try { Task.serializeWithDependencies(task, sched.sc.addedFiles, sched.sc.addedJars, ser) } catch { // If the task cannot be serialized, then there's no point to re-attempt the task, @@ -465,11 +466,11 @@ private[spark] class TaskSetManager( abort(s"$msg Exception during serialization: $e") throw new TaskNotSerializableException(e) } - if (serializedTask.limit > TaskSetManager.TASK_SIZE_TO_WARN_KB * 1024 && + if (serializedTask.size() > TaskSetManager.TASK_SIZE_TO_WARN_KB * 1024 && 
!emittedTaskSizeWarning) { emittedTaskSizeWarning = true logWarning(s"Stage ${task.stageId} contains a task of very large size " + - s"(${serializedTask.limit / 1024} KB). The maximum recommended task size is " + + s"(${serializedTask.size() / 1024} KB). The maximum recommended task size is " + s"${TaskSetManager.TASK_SIZE_TO_WARN_KB} KB.") } addRunningTask(taskId) @@ -479,7 +480,7 @@ private[spark] class TaskSetManager( // val timeTaken = clock.getTime() - startTime val taskName = s"task ${info.id} in stage ${taskSet.id}" logInfo(s"Starting $taskName (TID $taskId, $host, partition ${task.partitionId}," + - s" $taskLocality, ${serializedTask.limit} bytes)") + s" $taskLocality, ${serializedTask.size()} bytes)") sched.dagScheduler.taskStarted(task, info) return Some(new TaskDescription(taskId = taskId, attemptNumber = attemptNum, execId, diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala index edc8aac5d1515..0f6b436323647 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala @@ -20,6 +20,7 @@ package org.apache.spark.scheduler.cluster import java.nio.ByteBuffer import org.apache.spark.TaskState.TaskState +import org.apache.spark.network.buffer.ChunkedByteBuffer import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.scheduler.ExecutorLossReason import org.apache.spark.util.SerializableBuffer @@ -33,7 +34,7 @@ private[spark] object CoarseGrainedClusterMessages { case object RetrieveLastAllocatedExecutorId extends CoarseGrainedClusterMessage // Driver to executors - case class LaunchTask(data: SerializableBuffer) extends CoarseGrainedClusterMessage + case class LaunchTask(data: ChunkedByteBuffer) extends CoarseGrainedClusterMessage case class KillTask(taskId: Long, executor: String, interruptThread: Boolean) extends CoarseGrainedClusterMessage @@ -55,15 +56,7 @@ private[spark] object CoarseGrainedClusterMessages { extends CoarseGrainedClusterMessage case class StatusUpdate(executorId: String, taskId: Long, state: TaskState, - data: SerializableBuffer) extends CoarseGrainedClusterMessage - - object StatusUpdate { - /** Alternate factory method that takes a ByteBuffer directly for the data field */ - def apply(executorId: String, taskId: Long, state: TaskState, data: ByteBuffer) - : StatusUpdate = { - StatusUpdate(executorId, taskId, state, new SerializableBuffer(data)) - } - } + data: ChunkedByteBuffer) extends CoarseGrainedClusterMessage // Internal messages in driver case object ReviveOffers extends CoarseGrainedClusterMessage diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index 8259923ce31c3..a1069cfbb48d4 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -118,7 +118,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp override def receive: PartialFunction[Any, Unit] = { case StatusUpdate(executorId, taskId, state, data) => - scheduler.statusUpdate(taskId, state, data.value) + scheduler.statusUpdate(taskId, state, data) if (TaskState.isFinished(state)) { 
executorDataMap.get(executorId) match { case Some(executorInfo) => @@ -245,13 +245,13 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp private def launchTasks(tasks: Seq[Seq[TaskDescription]]) { for (task <- tasks.flatten) { val serializedTask = ser.serialize(task) - if (serializedTask.limit >= maxRpcMessageSize) { + if (serializedTask.size() >= maxRpcMessageSize) { scheduler.taskIdToTaskSetManager.get(task.taskId).foreach { taskSetMgr => try { var msg = "Serialized task %s:%d was %d bytes, which exceeds max allowed: " + "spark.rpc.message.maxSize (%d bytes). Consider increasing " + "spark.rpc.message.maxSize or using broadcast variables for large values." - msg = msg.format(task.taskId, task.index, serializedTask.limit, maxRpcMessageSize) + msg = msg.format(task.taskId, task.index, serializedTask.size(), maxRpcMessageSize) taskSetMgr.abort(msg) } catch { case e: Exception => logError("Exception in error callback", e) @@ -265,7 +265,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp logInfo(s"Launching task ${task.taskId} on executor id: ${task.executorId} hostname: " + s"${executorData.executorHost}.") - executorData.executorEndpoint.send(LaunchTask(new SerializableBuffer(serializedTask))) + executorData.executorEndpoint.send(LaunchTask(serializedTask)) } } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackend.scala index f1e48fa7c52e1..eb25dd9422a4c 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackend.scala @@ -28,6 +28,7 @@ import org.apache.mesos.protobuf.ByteString import org.apache.spark.{SparkContext, SparkException, TaskState} import org.apache.spark.executor.MesosExecutorBackend +import org.apache.spark.network.buffer.ChunkedByteBuffer import org.apache.spark.scheduler._ import org.apache.spark.scheduler.cluster.ExecutorInfo import org.apache.spark.util.Utils @@ -358,7 +359,8 @@ private[spark] class MesosFineGrainedSchedulerBackend( .setExecutor(executorInfo) .setName(task.name) .addAllResources(cpuResources.asJava) - .setData(MesosTaskLaunchData(task.serializedTask, task.attemptNumber).toByteString) + .setData(MesosTaskLaunchData(task.serializedTask.toByteBuffer, + task.attemptNumber).toByteString) .build() (taskInfo, finalResources.asJava) } @@ -377,7 +379,8 @@ private[spark] class MesosFineGrainedSchedulerBackend( taskIdToSlaveId.remove(tid) } } - scheduler.statusUpdate(tid, state, status.getData.asReadOnlyByteBuffer) + scheduler.statusUpdate(tid, state, + ChunkedByteBuffer.wrap(status.getData.asReadOnlyByteBuffer)) } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalSchedulerBackend.scala index e386052814039..bf42ad0624028 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/local/LocalSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/local/LocalSchedulerBackend.scala @@ -26,13 +26,14 @@ import org.apache.spark.TaskState.TaskState import org.apache.spark.executor.{Executor, ExecutorBackend} import org.apache.spark.internal.Logging import org.apache.spark.launcher.{LauncherBackend, SparkAppHandle} +import org.apache.spark.network.buffer.ChunkedByteBuffer import 
org.apache.spark.rpc.{RpcCallContext, RpcEndpointRef, RpcEnv, ThreadSafeRpcEndpoint} import org.apache.spark.scheduler._ import org.apache.spark.scheduler.cluster.ExecutorInfo private case class ReviveOffers() -private case class StatusUpdate(taskId: Long, state: TaskState, serializedData: ByteBuffer) +private case class StatusUpdate(taskId: Long, state: TaskState, serializedData: ChunkedByteBuffer) private case class KillTask(taskId: Long, interruptThread: Boolean) @@ -148,7 +149,7 @@ private[spark] class LocalSchedulerBackend( localEndpoint.send(KillTask(taskId, interruptThread)) } - override def statusUpdate(taskId: Long, state: TaskState, serializedData: ByteBuffer) { + override def statusUpdate(taskId: Long, state: TaskState, serializedData: ChunkedByteBuffer) { localEndpoint.send(StatusUpdate(taskId, state, serializedData)) } diff --git a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala index 8b72da2ee01b7..0a227b5abebc1 100644 --- a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala @@ -24,6 +24,7 @@ import scala.reflect.ClassTag import org.apache.spark.SparkConf import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.network.buffer.{Allocator, ChunkedByteBuffer, ChunkedByteBufferOutputStream} import org.apache.spark.util.{ByteBufferInputStream, ByteBufferOutputStream, Utils} private[spark] class JavaSerializationStream( @@ -94,22 +95,24 @@ private[spark] class JavaSerializerInstance( counterReset: Int, extraDebugInfo: Boolean, defaultClassLoader: ClassLoader) extends SerializerInstance { - override def serialize[T: ClassTag](t: T): ByteBuffer = { - val bos = new ByteBufferOutputStream() + override def serialize[T: ClassTag](t: T): ChunkedByteBuffer = { + val bos = new ChunkedByteBufferOutputStream(32 * 1024, new Allocator { + override def allocate(len: Int) = ByteBuffer.allocate(len) + }) val out = serializeStream(bos) out.writeObject(t) out.close() - bos.toByteBuffer + bos.toChunkedByteBuffer } - override def deserialize[T: ClassTag](bytes: ByteBuffer): T = { - val bis = new ByteBufferInputStream(bytes) + override def deserialize[T: ClassTag](bytes: ChunkedByteBuffer): T = { + val bis = bytes.toInputStream() val in = deserializeStream(bis) in.readObject() } - override def deserialize[T: ClassTag](bytes: ByteBuffer, loader: ClassLoader): T = { - val bis = new ByteBufferInputStream(bytes) + override def deserialize[T: ClassTag](bytes: ChunkedByteBuffer, loader: ClassLoader): T = { + val bis = bytes.toInputStream() val in = deserializeStream(bis, loader) in.readObject() } diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala index 1fba552f70501..a9e32733a6775 100644 --- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala @@ -35,6 +35,7 @@ import org.roaringbitmap.RoaringBitmap import org.apache.spark._ import org.apache.spark.api.python.PythonBroadcast import org.apache.spark.internal.Logging +import org.apache.spark.network.buffer.{ChunkedByteBuffer, ChunkedByteBufferOutputStream} import org.apache.spark.network.util.ByteUnit import org.apache.spark.scheduler.{CompressedMapStatus, HighlyCompressedMapStatus} import org.apache.spark.storage._ @@ -81,6 +82,8 @@ class KryoSerializer(conf: SparkConf) 
def newKryoOutput(): KryoOutput = new KryoOutput(bufferSize, math.max(bufferSize, maxBufferSize)) + def newKryoInput(): KryoInput = new KryoInput(bufferSize) + def newKryo(): Kryo = { val instantiator = new EmptyScalaKryoInstantiator val kryo = instantiator.newKryo() @@ -288,10 +291,12 @@ private[spark] class KryoSerializerInstance(ks: KryoSerializer) extends Serializ // Make these lazy vals to avoid creating a buffer unless we use them. private lazy val output = ks.newKryoOutput() - private lazy val input = new KryoInput() + private lazy val input = ks.newKryoInput() - override def serialize[T: ClassTag](t: T): ByteBuffer = { + override def serialize[T: ClassTag](t: T): ChunkedByteBuffer = { output.clear() + val out = new ChunkedByteBufferOutputStream(32 * 1024) + output.setOutputStream(out) val kryo = borrowKryo() try { kryo.writeClassAndObject(output, t) @@ -300,29 +305,37 @@ private[spark] class KryoSerializerInstance(ks: KryoSerializer) extends Serializ throw new SparkException(s"Kryo serialization failed: ${e.getMessage}. To avoid this, " + "increase spark.kryoserializer.buffer.max value.") } finally { + output.close() + output.setOutputStream(null) releaseKryo(kryo) } - ByteBuffer.wrap(output.toBytes) + out.toChunkedByteBuffer } - override def deserialize[T: ClassTag](bytes: ByteBuffer): T = { + override def deserialize[T: ClassTag](bytes: ChunkedByteBuffer): T = { val kryo = borrowKryo() try { - input.setBuffer(bytes.array(), bytes.arrayOffset() + bytes.position(), bytes.remaining()) + input.setInputStream(bytes.toInputStream()) + // input.setBuffer(bytes.array(), bytes.arrayOffset() + bytes.position(), bytes.remaining()) kryo.readClassAndObject(input).asInstanceOf[T] } finally { + input.close() + input.setInputStream(null) releaseKryo(kryo) } } - override def deserialize[T: ClassTag](bytes: ByteBuffer, loader: ClassLoader): T = { + override def deserialize[T: ClassTag](bytes: ChunkedByteBuffer, loader: ClassLoader): T = { val kryo = borrowKryo() val oldClassLoader = kryo.getClassLoader try { kryo.setClassLoader(loader) - input.setBuffer(bytes.array(), bytes.arrayOffset() + bytes.position(), bytes.remaining()) + input.setInputStream(bytes.toInputStream()) + // input.setBuffer(bytes.array(), bytes.arrayOffset() + bytes.position(), bytes.remaining()) kryo.readClassAndObject(input).asInstanceOf[T] } finally { + input.close() + input.setInputStream(null) kryo.setClassLoader(oldClassLoader) releaseKryo(kryo) } diff --git a/core/src/main/scala/org/apache/spark/serializer/Serializer.scala b/core/src/main/scala/org/apache/spark/serializer/Serializer.scala index cb95246d5b0ca..a2408d52ceef9 100644 --- a/core/src/main/scala/org/apache/spark/serializer/Serializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/Serializer.scala @@ -25,6 +25,7 @@ import scala.reflect.ClassTag import org.apache.spark.SparkEnv import org.apache.spark.annotation.{DeveloperApi, Private} +import org.apache.spark.network.buffer.ChunkedByteBuffer import org.apache.spark.util.NextIterator /** @@ -110,11 +111,11 @@ abstract class Serializer { @DeveloperApi @NotThreadSafe abstract class SerializerInstance { - def serialize[T: ClassTag](t: T): ByteBuffer + def serialize[T: ClassTag](t: T): ChunkedByteBuffer - def deserialize[T: ClassTag](bytes: ByteBuffer): T + def deserialize[T: ClassTag](bytes: ChunkedByteBuffer): T - def deserialize[T: ClassTag](bytes: ByteBuffer, loader: ClassLoader): T + def deserialize[T: ClassTag](bytes: ChunkedByteBuffer, loader: ClassLoader): T def serializeStream(s: OutputStream): 
SerializationStream diff --git a/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala b/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala index 9dc274c9fe288..d375c2e91f668 100644 --- a/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala +++ b/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala @@ -24,8 +24,8 @@ import scala.reflect.ClassTag import org.apache.spark.SparkConf import org.apache.spark.io.CompressionCodec +import org.apache.spark.network.buffer.{Allocator, ChunkedByteBuffer, ChunkedByteBufferOutputStream} import org.apache.spark.storage._ -import org.apache.spark.util.io.{ChunkedByteBuffer, ChunkedByteBufferOutputStream} /** * Component which configures serialization and compression for various Spark components, including @@ -128,7 +128,9 @@ private[spark] class SerializerManager(defaultSerializer: Serializer, conf: Spar /** Serializes into a chunked byte buffer. */ def dataSerialize[T: ClassTag](blockId: BlockId, values: Iterator[T]): ChunkedByteBuffer = { - val bbos = new ChunkedByteBufferOutputStream(1024 * 1024 * 4, ByteBuffer.allocate) + val bbos = new ChunkedByteBufferOutputStream(32 * 1024, new Allocator { + override def allocate(len: Int) = ByteBuffer.allocate(len) + }) dataSerializeStream(blockId, bbos, values) bbos.toChunkedByteBuffer } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 015e71d1260ea..ac156e3c47aa2 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -32,7 +32,7 @@ import org.apache.spark.executor.{DataReadMethod, ShuffleWriteMetrics} import org.apache.spark.internal.Logging import org.apache.spark.memory.{MemoryManager, MemoryMode} import org.apache.spark.network._ -import org.apache.spark.network.buffer.{ManagedBuffer, NettyManagedBuffer} +import org.apache.spark.network.buffer._ import org.apache.spark.network.netty.SparkTransportConf import org.apache.spark.network.shuffle.ExternalShuffleClient import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo @@ -42,7 +42,6 @@ import org.apache.spark.shuffle.ShuffleManager import org.apache.spark.storage.memory._ import org.apache.spark.unsafe.Platform import org.apache.spark.util._ -import org.apache.spark.util.io.ChunkedByteBuffer /* Class for returning a fetched block and associated metrics. 
*/ private[spark] class BlockResult( @@ -293,7 +292,7 @@ private[spark] class BlockManager( data: ManagedBuffer, level: StorageLevel, classTag: ClassTag[_]): Boolean = { - putBytes(blockId, new ChunkedByteBuffer(data.nioByteBuffer()), level)(classTag) + putBytes(blockId, data.nioByteBuffer(), level)(classTag) } /** @@ -441,12 +440,12 @@ private[spark] class BlockManager( if (level.deserialized) { val diskValues = serializerManager.dataDeserializeStream( blockId, - diskBytes.toInputStream(dispose = true))(info.classTag) + diskBytes.toInputStream(true))(info.classTag) maybeCacheDiskValuesInMemory(info, blockId, level, diskValues) } else { val stream = maybeCacheDiskBytesInMemory(info, blockId, level, diskBytes) - .map {_.toInputStream(dispose = false)} - .getOrElse { diskBytes.toInputStream(dispose = true) } + .map {_.toInputStream(false)} + .getOrElse { diskBytes.toInputStream( true) } serializerManager.dataDeserializeStream(blockId, stream)(info.classTag) } } @@ -470,8 +469,7 @@ private[spark] class BlockManager( // TODO: This should gracefully handle case where local block is not available. Currently // downstream code will throw an exception. Option( - new ChunkedByteBuffer( - shuffleBlockResolver.getBlockData(blockId.asInstanceOf[ShuffleBlockId]).nioByteBuffer())) + shuffleBlockResolver.getBlockData(blockId.asInstanceOf[ShuffleBlockId]).nioByteBuffer()) } else { blockInfoManager.lockForReading(blockId).map { info => doGetLocalBytes(blockId, info) } } @@ -522,7 +520,7 @@ private[spark] class BlockManager( private def getRemoteValues(blockId: BlockId): Option[BlockResult] = { getRemoteBytes(blockId).map { data => val values = - serializerManager.dataDeserializeStream(blockId, data.toInputStream(dispose = true)) + serializerManager.dataDeserializeStream(blockId, data.toInputStream(true)) new BlockResult(values, DataReadMethod.Network, data.size) } } @@ -586,7 +584,7 @@ private[spark] class BlockManager( } if (data != null) { - return Some(new ChunkedByteBuffer(data)) + return Some(data) } logDebug(s"The value of block $blockId is null") } @@ -1019,7 +1017,11 @@ private[spark] class BlockManager( // If the file size is bigger than the free memory, OOM will happen. So if we // cannot put it into MemoryStore, copyForMemory should not be created. That's why // this action is put into a `() => ChunkedByteBuffer` and created lazily. 
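Several call sites in this patch, including the copy just below, wrap a Scala Int => ByteBuffer function in the Java Allocator interface by hand. A small hypothetical adapter, assuming Allocator exposes only the single allocate(len) method those call sites override, would remove that repetition:

import java.nio.ByteBuffer
import org.apache.spark.network.buffer.{Allocator, ChunkedByteBufferOutputStream}

// Hypothetical helper: adapt an Int => ByteBuffer function to the Java Allocator interface.
def asAllocator(f: Int => ByteBuffer): Allocator = new Allocator {
  override def allocate(len: Int): ByteBuffer = f(len)
}

// Usage mirroring the call sites in this patch:
val bbos = new ChunkedByteBufferOutputStream(32 * 1024, asAllocator(ByteBuffer.allocate))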
- diskBytes.copy(allocator) + val out = new ChunkedByteBufferOutputStream(32 * 1024, new Allocator { + override def allocate(len: Int) = allocator(len) + }) + Utils.copyStream(diskBytes.toInputStream(), out, true) + out.toChunkedByteBuffer }) if (putSucceeded) { diskBytes.dispose() diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerManagedBuffer.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerManagedBuffer.scala index f66f942798550..e49e8ac1093fa 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerManagedBuffer.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerManagedBuffer.scala @@ -17,8 +17,7 @@ package org.apache.spark.storage -import org.apache.spark.network.buffer.{ManagedBuffer, NettyManagedBuffer} -import org.apache.spark.util.io.ChunkedByteBuffer +import org.apache.spark.network.buffer.{ChunkedByteBuffer, ManagedBuffer, NettyManagedBuffer} /** * This [[ManagedBuffer]] wraps a [[ChunkedByteBuffer]] retrieved from the [[BlockManager]] diff --git a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala index ca23e2391ed02..62f391791fa96 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala @@ -25,8 +25,8 @@ import com.google.common.io.Closeables import org.apache.spark.SparkConf import org.apache.spark.internal.Logging +import org.apache.spark.network.buffer.ChunkedByteBuffer import org.apache.spark.util.Utils -import org.apache.spark.util.io.ChunkedByteBuffer /** * Stores BlockManager blocks on disk. diff --git a/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala b/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala index 586339a58d236..b3a09421a9284 100644 --- a/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala @@ -30,12 +30,12 @@ import com.google.common.io.ByteStreams import org.apache.spark.{SparkConf, TaskContext} import org.apache.spark.internal.Logging import org.apache.spark.memory.{MemoryManager, MemoryMode} +import org.apache.spark.network.buffer.{Allocator, ChunkedByteBuffer, ChunkedByteBufferOutputStream} import org.apache.spark.serializer.{SerializationStream, SerializerManager} import org.apache.spark.storage.{BlockId, BlockInfoManager, StorageLevel} import org.apache.spark.unsafe.Platform import org.apache.spark.util.{CompletionIterator, SizeEstimator, Utils} import org.apache.spark.util.collection.SizeTrackingVector -import org.apache.spark.util.io.{ChunkedByteBuffer, ChunkedByteBufferOutputStream} private sealed trait MemoryEntry[T] { def size: Long @@ -326,7 +326,9 @@ private[spark] class MemoryStore( var unrollMemoryUsedByThisBlock = 0L // Underlying buffer for unrolling the block val redirectableStream = new RedirectableOutputStream - val bbos = new ChunkedByteBufferOutputStream(initialMemoryThreshold.toInt, allocator) + val bbos = new ChunkedByteBufferOutputStream(initialMemoryThreshold.toInt, new Allocator { + override def allocate(len: Int) = allocator(len) + }) redirectableStream.setOutputStream(bbos) val serializationStream: SerializationStream = { val ser = serializerManager.getSerializer(classTag).newInstance() @@ -765,7 +767,7 @@ private[storage] class PartiallySerializedBlock[T]( */ def finishWritingToStream(os: OutputStream): Unit = { // `unrolled`'s underlying buffers will be freed once this input 
stream is fully read: - ByteStreams.copy(unrolled.toInputStream(dispose = true), os) + ByteStreams.copy(unrolled.toInputStream(true), os) memoryStore.releaseUnrollMemoryForThisTask(memoryMode, unrollMemory) redirectableOutputStream.setOutputStream(os) while (rest.hasNext) { @@ -784,7 +786,7 @@ private[storage] class PartiallySerializedBlock[T]( def valuesIterator: PartiallyUnrolledIterator[T] = { // `unrolled`'s underlying buffers will be freed once this input stream is fully read: val unrolledIter = serializerManager.dataDeserializeStream( - blockId, unrolled.toInputStream(dispose = true))(classTag) + blockId, unrolled.toInputStream(true))(classTag) new PartiallyUnrolledIterator( memoryStore, unrollMemory, diff --git a/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala b/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala deleted file mode 100644 index 89b0874e3865a..0000000000000 --- a/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala +++ /dev/null @@ -1,219 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.util.io - -import java.io.InputStream -import java.nio.ByteBuffer -import java.nio.channels.WritableByteChannel - -import com.google.common.primitives.UnsignedBytes -import io.netty.buffer.{ByteBuf, Unpooled} - -import org.apache.spark.network.util.ByteArrayWritableChannel -import org.apache.spark.storage.StorageUtils - -/** - * Read-only byte buffer which is physically stored as multiple chunks rather than a single - * contiguous array. - * - * @param chunks an array of [[ByteBuffer]]s. Each buffer in this array must have position == 0. - * Ownership of these buffers is transferred to the ChunkedByteBuffer, so if these - * buffers may also be used elsewhere then the caller is responsible for copying - * them as needed. - */ -private[spark] class ChunkedByteBuffer(var chunks: Array[ByteBuffer]) { - require(chunks != null, "chunks must not be null") - require(chunks.forall(_.position() == 0), "chunks' positions must be 0") - - private[this] var disposed: Boolean = false - - /** - * This size of this buffer, in bytes. - */ - val size: Long = chunks.map(_.limit().asInstanceOf[Long]).sum - - def this(byteBuffer: ByteBuffer) = { - this(Array(byteBuffer)) - } - - /** - * Write this buffer to a channel. - */ - def writeFully(channel: WritableByteChannel): Unit = { - for (bytes <- getChunks()) { - while (bytes.remaining > 0) { - channel.write(bytes) - } - } - } - - /** - * Wrap this buffer to view it as a Netty ByteBuf. - */ - def toNetty: ByteBuf = { - Unpooled.wrappedBuffer(getChunks(): _*) - } - - /** - * Copy this buffer into a new byte array. - * - * @throws UnsupportedOperationException if this buffer's size exceeds the maximum array size. 
- */ - def toArray: Array[Byte] = { - if (size >= Integer.MAX_VALUE) { - throw new UnsupportedOperationException( - s"cannot call toArray because buffer size ($size bytes) exceeds maximum array size") - } - val byteChannel = new ByteArrayWritableChannel(size.toInt) - writeFully(byteChannel) - byteChannel.close() - byteChannel.getData - } - - /** - * Copy this buffer into a new ByteBuffer. - * - * @throws UnsupportedOperationException if this buffer's size exceeds the max ByteBuffer size. - */ - def toByteBuffer: ByteBuffer = { - if (chunks.length == 1) { - chunks.head.duplicate() - } else { - ByteBuffer.wrap(toArray) - } - } - - /** - * Creates an input stream to read data from this ChunkedByteBuffer. - * - * @param dispose if true, [[dispose()]] will be called at the end of the stream - * in order to close any memory-mapped files which back this buffer. - */ - def toInputStream(dispose: Boolean = false): InputStream = { - new ChunkedByteBufferInputStream(this, dispose) - } - - /** - * Get duplicates of the ByteBuffers backing this ChunkedByteBuffer. - */ - def getChunks(): Array[ByteBuffer] = { - chunks.map(_.duplicate()) - } - - /** - * Make a copy of this ChunkedByteBuffer, copying all of the backing data into new buffers. - * The new buffer will share no resources with the original buffer. - * - * @param allocator a method for allocating byte buffers - */ - def copy(allocator: Int => ByteBuffer): ChunkedByteBuffer = { - val copiedChunks = getChunks().map { chunk => - val newChunk = allocator(chunk.limit()) - newChunk.put(chunk) - newChunk.flip() - newChunk - } - new ChunkedByteBuffer(copiedChunks) - } - - /** - * Attempt to clean up a ByteBuffer if it is memory-mapped. This uses an *unsafe* Sun API that - * might cause errors if one attempts to read from the unmapped buffer, but it's better than - * waiting for the GC to find it because that could lead to huge numbers of open files. There's - * unfortunately no standard API to do this. - */ - def dispose(): Unit = { - if (!disposed) { - chunks.foreach(StorageUtils.dispose) - disposed = true - } - } -} - -/** - * Reads data from a ChunkedByteBuffer. - * - * @param dispose if true, [[ChunkedByteBuffer.dispose()]] will be called at the end of the stream - * in order to close any memory-mapped files which back the buffer. 
- */ -private class ChunkedByteBufferInputStream( - var chunkedByteBuffer: ChunkedByteBuffer, - dispose: Boolean) - extends InputStream { - - private[this] var chunks = chunkedByteBuffer.getChunks().iterator - private[this] var currentChunk: ByteBuffer = { - if (chunks.hasNext) { - chunks.next() - } else { - null - } - } - - override def read(): Int = { - if (currentChunk != null && !currentChunk.hasRemaining && chunks.hasNext) { - currentChunk = chunks.next() - } - if (currentChunk != null && currentChunk.hasRemaining) { - UnsignedBytes.toInt(currentChunk.get()) - } else { - close() - -1 - } - } - - override def read(dest: Array[Byte], offset: Int, length: Int): Int = { - if (currentChunk != null && !currentChunk.hasRemaining && chunks.hasNext) { - currentChunk = chunks.next() - } - if (currentChunk != null && currentChunk.hasRemaining) { - val amountToGet = math.min(currentChunk.remaining(), length) - currentChunk.get(dest, offset, amountToGet) - amountToGet - } else { - close() - -1 - } - } - - override def skip(bytes: Long): Long = { - if (currentChunk != null) { - val amountToSkip = math.min(bytes, currentChunk.remaining).toInt - currentChunk.position(currentChunk.position + amountToSkip) - if (currentChunk.remaining() == 0) { - if (chunks.hasNext) { - currentChunk = chunks.next() - } else { - close() - } - } - amountToSkip - } else { - 0L - } - } - - override def close(): Unit = { - if (chunkedByteBuffer != null && dispose) { - chunkedByteBuffer.dispose() - } - chunkedByteBuffer = null - chunks = null - currentChunk = null - } -} diff --git a/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBufferOutputStream.scala b/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBufferOutputStream.scala deleted file mode 100644 index 67b50d1e70437..0000000000000 --- a/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBufferOutputStream.scala +++ /dev/null @@ -1,113 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.util.io - -import java.io.OutputStream -import java.nio.ByteBuffer - -import scala.collection.mutable.ArrayBuffer - -import org.apache.spark.storage.StorageUtils - -/** - * An OutputStream that writes to fixed-size chunks of byte arrays. - * - * @param chunkSize size of each chunk, in bytes. - */ -private[spark] class ChunkedByteBufferOutputStream( - chunkSize: Int, - allocator: Int => ByteBuffer) - extends OutputStream { - - private[this] var toChunkedByteBufferWasCalled = false - - private val chunks = new ArrayBuffer[ByteBuffer] - - /** Index of the last chunk. Starting with -1 when the chunks array is empty. */ - private[this] var lastChunkIndex = -1 - - /** - * Next position to write in the last chunk. 
- * - * If this equals chunkSize, it means for next write we need to allocate a new chunk. - * This can also never be 0. - */ - private[this] var position = chunkSize - private[this] var _size = 0 - - def size: Long = _size - - override def write(b: Int): Unit = { - allocateNewChunkIfNeeded() - chunks(lastChunkIndex).put(b.toByte) - position += 1 - _size += 1 - } - - override def write(bytes: Array[Byte], off: Int, len: Int): Unit = { - var written = 0 - while (written < len) { - allocateNewChunkIfNeeded() - val thisBatch = math.min(chunkSize - position, len - written) - chunks(lastChunkIndex).put(bytes, written + off, thisBatch) - written += thisBatch - position += thisBatch - } - _size += len - } - - @inline - private def allocateNewChunkIfNeeded(): Unit = { - require(!toChunkedByteBufferWasCalled, "cannot write after toChunkedByteBuffer() is called") - if (position == chunkSize) { - chunks += allocator(chunkSize) - lastChunkIndex += 1 - position = 0 - } - } - - def toChunkedByteBuffer: ChunkedByteBuffer = { - require(!toChunkedByteBufferWasCalled, "toChunkedByteBuffer() can only be called once") - toChunkedByteBufferWasCalled = true - if (lastChunkIndex == -1) { - new ChunkedByteBuffer(Array.empty[ByteBuffer]) - } else { - // Copy the first n-1 chunks to the output, and then create an array that fits the last chunk. - // An alternative would have been returning an array of ByteBuffers, with the last buffer - // bounded to only the last chunk's position. However, given our use case in Spark (to put - // the chunks in block manager), only limiting the view bound of the buffer would still - // require the block manager to store the whole chunk. - val ret = new Array[ByteBuffer](chunks.size) - for (i <- 0 until chunks.size - 1) { - ret(i) = chunks(i) - ret(i).flip() - } - if (position == chunkSize) { - ret(lastChunkIndex) = chunks(lastChunkIndex) - ret(lastChunkIndex).flip() - } else { - ret(lastChunkIndex) = allocator(position) - chunks(lastChunkIndex).flip() - ret(lastChunkIndex).put(chunks(lastChunkIndex)) - ret(lastChunkIndex).flip() - StorageUtils.dispose(chunks(lastChunkIndex)) - } - new ChunkedByteBuffer(ret) - } - } -} diff --git a/core/src/test/java/org/apache/spark/serializer/TestJavaSerializerImpl.java b/core/src/test/java/org/apache/spark/serializer/TestJavaSerializerImpl.java index 8aa0636700991..9a13df4ecb006 100644 --- a/core/src/test/java/org/apache/spark/serializer/TestJavaSerializerImpl.java +++ b/core/src/test/java/org/apache/spark/serializer/TestJavaSerializerImpl.java @@ -21,6 +21,7 @@ import java.io.OutputStream; import java.nio.ByteBuffer; +import org.apache.spark.network.buffer.ChunkedByteBuffer; import scala.reflect.ClassTag; @@ -35,18 +36,19 @@ public SerializerInstance newInstance() { } static class SerializerInstanceImpl extends SerializerInstance { - @Override - public ByteBuffer serialize(T t, ClassTag evidence$1) { - return null; - } - @Override - public T deserialize(ByteBuffer bytes, ClassLoader loader, ClassTag evidence$1) { + @Override + public ChunkedByteBuffer serialize(T t, ClassTag evidence$1) { + return null; + } + + @Override + public T deserialize(ChunkedByteBuffer bytes, ClassLoader loader, ClassTag evidence$1) { return null; } @Override - public T deserialize(ByteBuffer bytes, ClassTag evidence$1) { + public T deserialize(ChunkedByteBuffer bytes, ClassTag evidence$1) { return null; } diff --git a/core/src/test/scala/org/apache/spark/DistributedSuite.scala b/core/src/test/scala/org/apache/spark/DistributedSuite.scala index 
6beae842b04d1..1d515ff6f6bb4 100644 --- a/core/src/test/scala/org/apache/spark/DistributedSuite.scala +++ b/core/src/test/scala/org/apache/spark/DistributedSuite.scala @@ -21,8 +21,8 @@ import org.scalatest.concurrent.Timeouts._ import org.scalatest.Matchers import org.scalatest.time.{Millis, Span} +import org.apache.spark.network.buffer.ChunkedByteBuffer import org.apache.spark.storage.{RDDBlockId, StorageLevel} -import org.apache.spark.util.io.ChunkedByteBuffer class NotSerializableClass class NotSerializableExn(val notSer: NotSerializableClass) extends Throwable() {} @@ -216,7 +216,7 @@ class DistributedSuite extends SparkFunSuite with Matchers with LocalSparkContex val bytes = blockTransfer.fetchBlockSync(cmId.host, cmId.port, cmId.executorId, blockId.toString) val deserialized = serializerManager.dataDeserializeStream[Int](blockId, - new ChunkedByteBuffer(bytes.nioByteBuffer()).toInputStream()).toList + bytes.nioByteBuffer().toInputStream()).toList assert(deserialized === (1 to 100).toList) } } diff --git a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala index 973676398ae54..32e93b785ca81 100644 --- a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala +++ b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala @@ -23,6 +23,7 @@ import org.scalatest.Assertions import org.apache.spark._ import org.apache.spark.io.SnappyCompressionCodec +import org.apache.spark.network.buffer.ChunkedByteBuffer import org.apache.spark.rdd.RDD import org.apache.spark.serializer.JavaSerializer import org.apache.spark.storage._ @@ -85,7 +86,9 @@ class BroadcastSuite extends SparkFunSuite with LocalSparkContext { val size = 1 + rand.nextInt(1024 * 10) val data: Array[Byte] = new Array[Byte](size) rand.nextBytes(data) - val blocks = blockifyObject(data, blockSize, serializer, compressionCodec) + val blocks = blockifyObject(data, blockSize, serializer, compressionCodec).map { block => + ChunkedByteBuffer.wrap(block) + } val unblockified = unBlockifyObject[Array[Byte]](blocks, serializer, compressionCodec) assert(unblockified === data) } diff --git a/core/src/test/scala/org/apache/spark/deploy/master/CustomRecoveryModeFactory.scala b/core/src/test/scala/org/apache/spark/deploy/master/CustomRecoveryModeFactory.scala index 4b86da536768c..ab5496ac2aac0 100644 --- a/core/src/test/scala/org/apache/spark/deploy/master/CustomRecoveryModeFactory.scala +++ b/core/src/test/scala/org/apache/spark/deploy/master/CustomRecoveryModeFactory.scala @@ -26,6 +26,7 @@ import scala.reflect.ClassTag import org.apache.spark.SparkConf import org.apache.spark.deploy.master._ +import org.apache.spark.network.buffer.ChunkedByteBuffer import org.apache.spark.serializer.Serializer class CustomRecoveryModeFactory( @@ -65,7 +66,7 @@ class CustomPersistenceEngine(serializer: Serializer) extends PersistenceEngine */ override def persist(name: String, obj: Object): Unit = { CustomPersistenceEngine.persistAttempts += 1 - val serialized = serializer.newInstance().serialize(obj) + val serialized = serializer.newInstance().serialize(obj).toByteBuffer val bytes = new Array[Byte](serialized.remaining()) serialized.get(bytes) data += name -> bytes @@ -86,7 +87,7 @@ class CustomPersistenceEngine(serializer: Serializer) extends PersistenceEngine override def read[T: ClassTag](prefix: String): Seq[T] = { CustomPersistenceEngine.readAttempts += 1 val results = for ((name, bytes) <- data; if name.startsWith(prefix)) - yield 
serializer.newInstance().deserialize[T](ByteBuffer.wrap(bytes)) + yield serializer.newInstance().deserialize[T](ChunkedByteBuffer.wrap(bytes)) results.toSeq } } diff --git a/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala b/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala index 683eeeeb6d661..ef2e238a5fb91 100644 --- a/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala +++ b/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala @@ -31,6 +31,7 @@ import org.apache.spark._ import org.apache.spark.TaskState.TaskState import org.apache.spark.memory.MemoryManager import org.apache.spark.metrics.MetricsSystem +import org.apache.spark.network.buffer.ChunkedByteBuffer import org.apache.spark.rpc.RpcEnv import org.apache.spark.scheduler.{FakeTask, Task} import org.apache.spark.serializer.JavaSerializer @@ -93,7 +94,7 @@ class ExecutorSuite extends SparkFunSuite { // save the returned `taskState` and `testFailedReason` into `executorSuiteHelper` val taskState = invocationOnMock.getArguments()(1).asInstanceOf[TaskState] executorSuiteHelper.taskState = taskState - val taskEndReason = invocationOnMock.getArguments()(2).asInstanceOf[ByteBuffer] + val taskEndReason = invocationOnMock.getArguments()(2).asInstanceOf[ChunkedByteBuffer] executorSuiteHelper.testFailedReason = serializer.newInstance().deserialize(taskEndReason) // let the main test thread check `taskState` and `testFailedReason` diff --git a/core/src/test/scala/org/apache/spark/io/ChunkedByteBufferSuite.scala b/core/src/test/scala/org/apache/spark/io/ChunkedByteBufferSuite.scala index 38b48a4c9e654..c942a3767ba2a 100644 --- a/core/src/test/scala/org/apache/spark/io/ChunkedByteBufferSuite.scala +++ b/core/src/test/scala/org/apache/spark/io/ChunkedByteBufferSuite.scala @@ -22,8 +22,8 @@ import java.nio.ByteBuffer import com.google.common.io.ByteStreams import org.apache.spark.SparkFunSuite +import org.apache.spark.network.buffer.{Allocator, ChunkedByteBuffer} import org.apache.spark.network.util.ByteArrayWritableChannel -import org.apache.spark.util.io.ChunkedByteBuffer class ChunkedByteBufferSuite extends SparkFunSuite { @@ -34,8 +34,8 @@ class ChunkedByteBufferSuite extends SparkFunSuite { assert(emptyChunkedByteBuffer.toArray === Array.empty) assert(emptyChunkedByteBuffer.toByteBuffer.capacity() === 0) assert(emptyChunkedByteBuffer.toNetty.capacity() === 0) - emptyChunkedByteBuffer.toInputStream(dispose = false).close() - emptyChunkedByteBuffer.toInputStream(dispose = true).close() + emptyChunkedByteBuffer.toInputStream(false).close() + emptyChunkedByteBuffer.toInputStream(true).close() } test("getChunks() duplicates chunks") { @@ -46,7 +46,9 @@ class ChunkedByteBufferSuite extends SparkFunSuite { test("copy() does not affect original buffer's position") { val chunkedByteBuffer = new ChunkedByteBuffer(Array(ByteBuffer.allocate(8))) - chunkedByteBuffer.copy(ByteBuffer.allocate) + chunkedByteBuffer.copy(new Allocator { + override def allocate(len: Int): ByteBuffer = ByteBuffer.allocate(len) + }) assert(chunkedByteBuffer.getChunks().head.position() === 0) } @@ -80,7 +82,7 @@ class ChunkedByteBufferSuite extends SparkFunSuite { val chunkedByteBuffer = new ChunkedByteBuffer(Array(empty, bytes1, bytes2)) assert(chunkedByteBuffer.size === bytes1.limit() + bytes2.limit()) - val inputStream = chunkedByteBuffer.toInputStream(dispose = false) + val inputStream = chunkedByteBuffer.toInputStream(false) val bytesFromStream = new Array[Byte](chunkedByteBuffer.size.toInt) 
ByteStreams.readFully(inputStream, bytesFromStream) assert(bytesFromStream === bytes1.array() ++ bytes2.array()) diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala index ad56715656c85..fe987be8198d3 100644 --- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala @@ -188,8 +188,8 @@ class RDDSuite extends SparkFunSuite with SharedSparkContext { val ser = SparkEnv.get.closureSerializer.newInstance() val union = rdd1.union(rdd2) // The UnionRDD itself should be large, but each individual partition should be small. - assert(ser.serialize(union).limit() > 2000) - assert(ser.serialize(union.partitions.head).limit() < 2000) + assert(ser.serialize(union).toByteBuffer.limit() > 2000) + assert(ser.serialize(union.partitions.head).toByteBuffer.limit() < 2000) } test("aggregate") { diff --git a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala index 0c156fef0ae0f..e344b50b616bc 100644 --- a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala @@ -25,6 +25,7 @@ import org.mockito.Matchers._ import org.mockito.Mockito._ import org.apache.spark.SparkFunSuite +import org.apache.spark.network.buffer.ChunkedByteBuffer import org.apache.spark.network.client.{TransportClient, TransportResponseHandler} import org.apache.spark.network.server.StreamManager import org.apache.spark.rpc._ @@ -33,7 +34,7 @@ class NettyRpcHandlerSuite extends SparkFunSuite { val env = mock(classOf[NettyRpcEnv]) val sm = mock(classOf[StreamManager]) - when(env.deserialize(any(classOf[TransportClient]), any(classOf[ByteBuffer]))(any())) + when(env.deserialize(any(classOf[TransportClient]), any(classOf[ChunkedByteBuffer]))(any())) .thenReturn(RequestMessage(RpcAddress("localhost", 12345), null, null)) test("receive") { diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala index 9eda79ace18d0..c6fb0924db0ee 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala @@ -60,7 +60,7 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark } val closureSerializer = SparkEnv.get.closureSerializer.newInstance() val func = (c: TaskContext, i: Iterator[String]) => i.next() - val taskBinary = sc.broadcast(JavaUtils.bufferToArray(closureSerializer.serialize((rdd, func)))) + val taskBinary = sc.broadcast(closureSerializer.serialize((rdd, func)).toArray) val task = new ResultTask[String, String]( 0, 0, taskBinary, rdd.partitions(0), Seq.empty, 0, new Properties, new TaskMetrics) intercept[RuntimeException] { @@ -81,7 +81,7 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark } val closureSerializer = SparkEnv.get.closureSerializer.newInstance() val func = (c: TaskContext, i: Iterator[String]) => i.next() - val taskBinary = sc.broadcast(JavaUtils.bufferToArray(closureSerializer.serialize((rdd, func)))) + val taskBinary = sc.broadcast(closureSerializer.serialize((rdd, func)).toArray) val task = new ResultTask[String, String]( 0, 0, taskBinary, rdd.partitions(0), Seq.empty, 0, new Properties, new TaskMetrics) intercept[RuntimeException] { diff --git 
a/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala index 9e472f900b655..0f64d509ae442 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala @@ -36,9 +36,9 @@ import org.scalatest.concurrent.Eventually._ import org.apache.spark._ import org.apache.spark.storage.TaskResultBlockId import org.apache.spark.TestUtils.JavaSourceFromString +import org.apache.spark.network.buffer.ChunkedByteBuffer import org.apache.spark.util.{MutableURLClassLoader, RpcUtils, Utils} - /** * Removes the TaskResult from the BlockManager before delegating to a normal TaskResultGetter. * @@ -52,11 +52,11 @@ private class ResultDeletingTaskResultGetter(sparkEnv: SparkEnv, scheduler: Task @volatile var removeBlockSuccessfully = false override def enqueueSuccessfulTask( - taskSetManager: TaskSetManager, tid: Long, serializedData: ByteBuffer) { + taskSetManager: TaskSetManager, tid: Long, serializedData: ChunkedByteBuffer) { if (!removedResult) { // Only remove the result once, since we'd like to test the case where the task eventually // succeeds. - serializer.get().deserialize[TaskResult[_]](serializedData) match { + serializer.get().deserialize[TaskResult[_]](serializedData.duplicate()) match { case IndirectTaskResult(blockId, size) => sparkEnv.blockManager.master.removeBlock(blockId) // removeBlock is asynchronous. Need to wait it's removed successfully @@ -71,7 +71,6 @@ private class ResultDeletingTaskResultGetter(sparkEnv: SparkEnv, scheduler: Task case directResult: DirectTaskResult[_] => taskSetManager.abort("Internal error: expect only indirect results") } - serializedData.rewind() removedResult = true } super.enqueueSuccessfulTask(taskSetManager, tid, serializedData) @@ -94,7 +93,8 @@ private class MyTaskResultGetter(env: SparkEnv, scheduler: TaskSchedulerImpl) def taskResults: Seq[DirectTaskResult[_]] = _taskResults - override def enqueueSuccessfulTask(tsm: TaskSetManager, tid: Long, data: ByteBuffer): Unit = { + override def enqueueSuccessfulTask(tsm: TaskSetManager, tid: Long, + data: ChunkedByteBuffer): Unit = { // work on a copy since the super class still needs to use the buffer val newBuffer = data.duplicate() _taskResults += env.closureSerializer.newInstance().deserialize[DirectTaskResult[_]](newBuffer) diff --git a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackendSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackendSuite.scala index fcf39f63915f7..dbcdc34aef894 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackendSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackendSuite.scala @@ -36,8 +36,8 @@ import org.scalatest.mock.MockitoSugar import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkFunSuite} import org.apache.spark.executor.MesosExecutorBackend -import org.apache.spark.scheduler.{LiveListenerBus, SparkListenerExecutorAdded, - TaskDescription, TaskSchedulerImpl, WorkerOffer} +import org.apache.spark.network.buffer.ChunkedByteBuffer +import org.apache.spark.scheduler.{LiveListenerBus, SparkListenerExecutorAdded, TaskDescription, TaskSchedulerImpl, WorkerOffer} import org.apache.spark.scheduler.cluster.ExecutorInfo class MesosFineGrainedSchedulerBackendSuite @@ 
-246,7 +246,8 @@ class MesosFineGrainedSchedulerBackendSuite mesosOffers.get(2).getHostname, (minCpu - backend.mesosExecutorCores).toInt )) - val taskDesc = new TaskDescription(1L, 0, "s1", "n1", 0, ByteBuffer.wrap(new Array[Byte](0))) + val taskDesc = new TaskDescription(1L, 0, "s1", "n1", 0, + ChunkedByteBuffer.wrap(new Array[Byte](0))) when(taskScheduler.resourceOffers(expectedWorkerOffers)).thenReturn(Seq(Seq(taskDesc))) when(taskScheduler.CPUS_PER_TASK).thenReturn(2) @@ -345,7 +346,8 @@ class MesosFineGrainedSchedulerBackendSuite 2 // Deducting 1 for executor )) - val taskDesc = new TaskDescription(1L, 0, "s1", "n1", 0, ByteBuffer.wrap(new Array[Byte](0))) + val taskDesc = new TaskDescription(1L, 0, "s1", "n1", 0, + ChunkedByteBuffer.wrap(new Array[Byte](0))) when(taskScheduler.resourceOffers(expectedWorkerOffers)).thenReturn(Seq(Seq(taskDesc))) when(taskScheduler.CPUS_PER_TASK).thenReturn(1) diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala index 57a82312008e9..27cdd08309434 100644 --- a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala @@ -188,7 +188,7 @@ class KryoSerializerSuite extends SparkFunSuite with SharedSparkContext { def check[T: ClassTag](t: T) { assert(ser.deserialize[T](ser.serialize(t)) === t) // Check that very long ranges don't get written one element at a time - assert(ser.serialize(t).limit < 100) + assert(ser.serialize(t).toByteBuffer.limit < 100) } check(1 to 1000000) check(1 to 1000000 by 2) diff --git a/core/src/test/scala/org/apache/spark/serializer/TestSerializer.scala b/core/src/test/scala/org/apache/spark/serializer/TestSerializer.scala index 17037870f7a15..c86492b887580 100644 --- a/core/src/test/scala/org/apache/spark/serializer/TestSerializer.scala +++ b/core/src/test/scala/org/apache/spark/serializer/TestSerializer.scala @@ -22,6 +22,8 @@ import java.nio.ByteBuffer import scala.reflect.ClassTag +import org.apache.spark.network.buffer.ChunkedByteBuffer + /** * A serializer implementation that always returns two elements in a deserialization stream. 
*/ @@ -31,7 +33,8 @@ class TestSerializer extends Serializer { class TestSerializerInstance extends SerializerInstance { - override def serialize[T: ClassTag](t: T): ByteBuffer = throw new UnsupportedOperationException + override def serialize[T: ClassTag](t: T): ChunkedByteBuffer = + throw new UnsupportedOperationException override def serializeStream(s: OutputStream): SerializationStream = throw new UnsupportedOperationException @@ -39,10 +42,10 @@ class TestSerializerInstance extends SerializerInstance { override def deserializeStream(s: InputStream): TestDeserializationStream = new TestDeserializationStream - override def deserialize[T: ClassTag](bytes: ByteBuffer): T = + override def deserialize[T: ClassTag](bytes: ChunkedByteBuffer): T = throw new UnsupportedOperationException - override def deserialize[T: ClassTag](bytes: ByteBuffer, loader: ClassLoader): T = + override def deserialize[T: ClassTag](bytes: ChunkedByteBuffer, loader: ClassLoader): T = throw new UnsupportedOperationException } diff --git a/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala index dba1172d5fdbd..f4ddc31c22681 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala @@ -23,7 +23,7 @@ import java.nio.ByteBuffer import org.mockito.Mockito.{mock, when} import org.apache.spark._ -import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer} +import org.apache.spark.network.buffer.{ChunkedByteBuffer, ManagedBuffer, NioManagedBuffer} import org.apache.spark.serializer.{JavaSerializer, SerializerManager} import org.apache.spark.storage.{BlockManager, BlockManagerId, ShuffleBlockId} @@ -38,7 +38,7 @@ class RecordingManagedBuffer(underlyingBuffer: NioManagedBuffer) extends Managed var callsToRelease = 0 override def size(): Long = underlyingBuffer.size() - override def nioByteBuffer(): ByteBuffer = underlyingBuffer.nioByteBuffer() + override def nioByteBuffer(): ChunkedByteBuffer = underlyingBuffer.nioByteBuffer() override def createInputStream(): InputStream = underlyingBuffer.createInputStream() override def convertToNetty(): AnyRef = underlyingBuffer.convertToNetty() diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index 87c8628ce97e9..d722b52e230c9 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -37,7 +37,7 @@ import org.apache.spark.broadcast.BroadcastManager import org.apache.spark.executor.DataReadMethod import org.apache.spark.memory.UnifiedMemoryManager import org.apache.spark.network.{BlockDataManager, BlockTransferService} -import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer} +import org.apache.spark.network.buffer.{ChunkedByteBuffer, ManagedBuffer, NioManagedBuffer} import org.apache.spark.network.netty.NettyBlockTransferService import org.apache.spark.network.shuffle.BlockFetchingListener import org.apache.spark.rpc.RpcEnv @@ -46,7 +46,6 @@ import org.apache.spark.serializer.{JavaSerializer, KryoSerializer, SerializerMa import org.apache.spark.shuffle.sort.SortShuffleManager import org.apache.spark.storage.BlockManagerMessages.BlockManagerHeartbeat import org.apache.spark.util._ -import org.apache.spark.util.io.ChunkedByteBuffer 
class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterEach with PrivateMethodTester with LocalSparkContext with ResetSystemProperties { diff --git a/core/src/test/scala/org/apache/spark/storage/DiskStoreSuite.scala b/core/src/test/scala/org/apache/spark/storage/DiskStoreSuite.scala index 9ed5016510d56..e5242710989fa 100644 --- a/core/src/test/scala/org/apache/spark/storage/DiskStoreSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/DiskStoreSuite.scala @@ -21,7 +21,7 @@ import java.nio.{ByteBuffer, MappedByteBuffer} import java.util.Arrays import org.apache.spark.{SparkConf, SparkFunSuite} -import org.apache.spark.util.io.ChunkedByteBuffer +import org.apache.spark.network.buffer.ChunkedByteBuffer class DiskStoreSuite extends SparkFunSuite { diff --git a/core/src/test/scala/org/apache/spark/storage/MemoryStoreSuite.scala b/core/src/test/scala/org/apache/spark/storage/MemoryStoreSuite.scala index c11de826677e0..0312d6ef0a685 100644 --- a/core/src/test/scala/org/apache/spark/storage/MemoryStoreSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/MemoryStoreSuite.scala @@ -27,10 +27,10 @@ import org.scalatest._ import org.apache.spark._ import org.apache.spark.memory.{MemoryMode, StaticMemoryManager} +import org.apache.spark.network.buffer.ChunkedByteBuffer import org.apache.spark.serializer.{KryoSerializer, SerializerManager} import org.apache.spark.storage.memory.{BlockEvictionHandler, MemoryStore, PartiallySerializedBlock, PartiallyUnrolledIterator} import org.apache.spark.util._ -import org.apache.spark.util.io.ChunkedByteBuffer class MemoryStoreSuite extends SparkFunSuite @@ -404,7 +404,7 @@ class MemoryStoreSuite val blockId = BlockId("rdd_3_10") var bytes: ChunkedByteBuffer = null memoryStore.putBytes(blockId, 10000, MemoryMode.ON_HEAP, () => { - bytes = new ChunkedByteBuffer(ByteBuffer.allocate(10000)) + bytes = ChunkedByteBuffer.wrap(ByteBuffer.allocate(10000)) bytes }) assert(memoryStore.getSize(blockId) === 10000) diff --git a/core/src/test/scala/org/apache/spark/util/io/ChunkedByteBufferOutputStreamSuite.scala b/core/src/test/scala/org/apache/spark/util/io/ChunkedByteBufferOutputStreamSuite.scala index 226622075a6cc..4fa52ee801046 100644 --- a/core/src/test/scala/org/apache/spark/util/io/ChunkedByteBufferOutputStreamSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/io/ChunkedByteBufferOutputStreamSuite.scala @@ -17,22 +17,21 @@ package org.apache.spark.util.io -import java.nio.ByteBuffer - import scala.util.Random import org.apache.spark.SparkFunSuite +import org.apache.spark.network.buffer.ChunkedByteBufferOutputStream class ChunkedByteBufferOutputStreamSuite extends SparkFunSuite { test("empty output") { - val o = new ChunkedByteBufferOutputStream(1024, ByteBuffer.allocate) + val o = new ChunkedByteBufferOutputStream(1024) assert(o.toChunkedByteBuffer.size === 0) } test("write a single byte") { - val o = new ChunkedByteBufferOutputStream(1024, ByteBuffer.allocate) + val o = new ChunkedByteBufferOutputStream(1024) o.write(10) val chunkedByteBuffer = o.toChunkedByteBuffer assert(chunkedByteBuffer.getChunks().length === 1) @@ -40,7 +39,7 @@ class ChunkedByteBufferOutputStreamSuite extends SparkFunSuite { } test("write a single near boundary") { - val o = new ChunkedByteBufferOutputStream(10, ByteBuffer.allocate) + val o = new ChunkedByteBufferOutputStream(10) o.write(new Array[Byte](9)) o.write(99) val chunkedByteBuffer = o.toChunkedByteBuffer @@ -49,7 +48,7 @@ class ChunkedByteBufferOutputStreamSuite extends SparkFunSuite { } 
test("write a single at boundary") { - val o = new ChunkedByteBufferOutputStream(10, ByteBuffer.allocate) + val o = new ChunkedByteBufferOutputStream(10) o.write(new Array[Byte](10)) o.write(99) val arrays = o.toChunkedByteBuffer.getChunks().map(_.array()) @@ -61,7 +60,7 @@ class ChunkedByteBufferOutputStreamSuite extends SparkFunSuite { test("single chunk output") { val ref = new Array[Byte](8) Random.nextBytes(ref) - val o = new ChunkedByteBufferOutputStream(10, ByteBuffer.allocate) + val o = new ChunkedByteBufferOutputStream(10) o.write(ref) val arrays = o.toChunkedByteBuffer.getChunks().map(_.array()) assert(arrays.length === 1) @@ -72,7 +71,7 @@ class ChunkedByteBufferOutputStreamSuite extends SparkFunSuite { test("single chunk output at boundary size") { val ref = new Array[Byte](10) Random.nextBytes(ref) - val o = new ChunkedByteBufferOutputStream(10, ByteBuffer.allocate) + val o = new ChunkedByteBufferOutputStream(10) o.write(ref) val arrays = o.toChunkedByteBuffer.getChunks().map(_.array()) assert(arrays.length === 1) @@ -83,7 +82,7 @@ class ChunkedByteBufferOutputStreamSuite extends SparkFunSuite { test("multiple chunk output") { val ref = new Array[Byte](26) Random.nextBytes(ref) - val o = new ChunkedByteBufferOutputStream(10, ByteBuffer.allocate) + val o = new ChunkedByteBufferOutputStream(10) o.write(ref) val arrays = o.toChunkedByteBuffer.getChunks().map(_.array()) assert(arrays.length === 3) @@ -99,7 +98,7 @@ class ChunkedByteBufferOutputStreamSuite extends SparkFunSuite { test("multiple chunk output at boundary size") { val ref = new Array[Byte](30) Random.nextBytes(ref) - val o = new ChunkedByteBufferOutputStream(10, ByteBuffer.allocate) + val o = new ChunkedByteBufferOutputStream(10) o.write(ref) val arrays = o.toChunkedByteBuffer.getChunks().map(_.array()) assert(arrays.length === 3) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala index 8ab553369de6d..42f0f2e506383 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala @@ -24,6 +24,7 @@ import scala.reflect.ClassTag import com.google.common.io.ByteStreams +import org.apache.spark.network.buffer.ChunkedByteBuffer import org.apache.spark.serializer.{DeserializationStream, SerializationStream, Serializer, SerializerInstance} import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.execution.metric.SQLMetric @@ -176,9 +177,10 @@ private class UnsafeRowSerializerInstance( } // These methods are never called by shuffle code. 
- override def serialize[T: ClassTag](t: T): ByteBuffer = throw new UnsupportedOperationException - override def deserialize[T: ClassTag](bytes: ByteBuffer): T = + override def serialize[T: ClassTag](t: T): ChunkedByteBuffer = throw new UnsupportedOperationException - override def deserialize[T: ClassTag](bytes: ByteBuffer, loader: ClassLoader): T = + override def deserialize[T: ClassTag](bytes: ChunkedByteBuffer): T = + throw new UnsupportedOperationException + override def deserialize[T: ClassTag](bytes: ChunkedByteBuffer, loader: ClassLoader): T = throw new UnsupportedOperationException } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala index 698f07b0a187f..09d7b33776ad7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala @@ -29,6 +29,7 @@ import org.apache.hadoop.fs._ import org.apache.hadoop.fs.permission.FsPermission import org.apache.spark.internal.Logging +import org.apache.spark.network.buffer.ChunkedByteBuffer import org.apache.spark.network.util.JavaUtils import org.apache.spark.serializer.JavaSerializer import org.apache.spark.sql.SparkSession @@ -85,11 +86,11 @@ class HDFSMetadataLog[T: ClassTag](sparkSession: SparkSession, path: String) } protected def serialize(metadata: T): Array[Byte] = { - JavaUtils.bufferToArray(serializer.serialize(metadata)) + serializer.serialize(metadata).toArray } protected def deserialize(bytes: Array[Byte]): T = { - serializer.deserialize[T](ByteBuffer.wrap(bytes)) + serializer.deserialize[T](ChunkedByteBuffer.wrap(bytes)) } /** diff --git a/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala b/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala index 53fccd8d5e6ed..a9cabb726fa9d 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala @@ -24,11 +24,11 @@ import scala.reflect.ClassTag import scala.util.control.NonFatal import org.apache.spark._ +import org.apache.spark.network.buffer.ChunkedByteBuffer import org.apache.spark.rdd.BlockRDD import org.apache.spark.storage.{BlockId, StorageLevel} import org.apache.spark.streaming.util._ import org.apache.spark.util.SerializableConfiguration -import org.apache.spark.util.io.ChunkedByteBuffer /** * Partition class for [[org.apache.spark.streaming.rdd.WriteAheadLogBackedBlockRDD]]. 
@@ -158,12 +158,12 @@ class WriteAheadLogBackedBlockRDD[T: ClassTag]( logInfo(s"Read partition data of $this from write ahead log, record handle " + partition.walRecordHandle) if (storeInBlockManager) { - blockManager.putBytes(blockId, new ChunkedByteBuffer(dataRead.duplicate()), storageLevel) + blockManager.putBytes(blockId, ChunkedByteBuffer.wrap(dataRead.duplicate()), storageLevel) logDebug(s"Stored partition data of $this into block manager with level $storageLevel") dataRead.rewind() } serializerManager - .dataDeserializeStream(blockId, new ChunkedByteBuffer(dataRead).toInputStream()) + .dataDeserializeStream(blockId, ChunkedByteBuffer.wrap(dataRead).toInputStream()) .asInstanceOf[Iterator[T]] } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala index 80c07958b41f2..ff179a27012fe 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala @@ -26,12 +26,12 @@ import org.apache.hadoop.fs.Path import org.apache.spark.{SparkConf, SparkException} import org.apache.spark.internal.Logging +import org.apache.spark.network.buffer.ChunkedByteBuffer import org.apache.spark.serializer.SerializerManager import org.apache.spark.storage._ import org.apache.spark.streaming.receiver.WriteAheadLogBasedBlockHandler._ import org.apache.spark.streaming.util.{WriteAheadLogRecordHandle, WriteAheadLogUtils} import org.apache.spark.util.{Clock, SystemClock, ThreadUtils} -import org.apache.spark.util.io.ChunkedByteBuffer /** Trait that represents the metadata related to storage of blocks */ private[streaming] trait ReceivedBlockStoreResult { @@ -87,7 +87,7 @@ private[streaming] class BlockManagerBasedBlockHandler( putResult case ByteBufferBlock(byteBuffer) => blockManager.putBytes( - blockId, new ChunkedByteBuffer(byteBuffer.duplicate()), storageLevel, tellMaster = true) + blockId, ChunkedByteBuffer.wrap(byteBuffer.duplicate()), storageLevel, tellMaster = true) case o => throw new SparkException( s"Could not store $blockId to block manager, unexpected block type ${o.getClass.getName}") @@ -182,7 +182,7 @@ private[streaming] class WriteAheadLogBasedBlockHandler( numRecords = countIterator.count serializedBlock case ByteBufferBlock(byteBuffer) => - new ChunkedByteBuffer(byteBuffer.duplicate()) + ChunkedByteBuffer.wrap(byteBuffer.duplicate()) case _ => throw new Exception(s"Could not push $blockId to block manager, unexpected block type") } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala index feb5c30c6aa14..00a3a1c3e07b0 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala @@ -32,6 +32,7 @@ import org.apache.spark._ import org.apache.spark.broadcast.BroadcastManager import org.apache.spark.internal.Logging import org.apache.spark.memory.StaticMemoryManager +import org.apache.spark.network.buffer.ChunkedByteBuffer import org.apache.spark.network.netty.NettyBlockTransferService import org.apache.spark.rpc.RpcEnv import org.apache.spark.scheduler.LiveListenerBus @@ -41,7 +42,6 @@ import org.apache.spark.storage._ import org.apache.spark.streaming.receiver._ import 
org.apache.spark.streaming.util._ import org.apache.spark.util.{ManualClock, Utils} -import org.apache.spark.util.io.ChunkedByteBuffer class ReceivedBlockHandlerSuite extends SparkFunSuite @@ -163,7 +163,7 @@ class ReceivedBlockHandlerSuite val bytes = reader.read(fileSegment) reader.close() serializerManager.dataDeserializeStream( - generateBlockId(), new ChunkedByteBuffer(bytes).toInputStream()).toList + generateBlockId(), ChunkedByteBuffer.wrap(bytes).toInputStream()).toList } loggedData shouldEqual data }
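
A minimal usage sketch of the API surface this patch converges on, for reviewers following the call-site churn above. It assumes the two-argument ChunkedByteBufferOutputStream constructor and the Allocator interface introduced by this patch; the object name and the toAllocator / roundTrip / wrapExample helpers are illustrative only and are not part of the change, and the 32 KB chunk size is arbitrary. The pattern is the one repeated across MemoryStore, BlockManager and the test suites: Scala allocator functions are wrapped in an anonymous Allocator, and plain ByteBuffers or byte arrays are lifted into ChunkedByteBuffer via ChunkedByteBuffer.wrap.

import java.nio.ByteBuffer

import com.google.common.io.ByteStreams

import org.apache.spark.network.buffer.{Allocator, ChunkedByteBuffer, ChunkedByteBufferOutputStream}

object ChunkedByteBufferUsageSketch {

  // Adapt a Scala `Int => ByteBuffer` allocator to the Java-side Allocator
  // interface, mirroring the MemoryStore/BlockManager call-site changes.
  def toAllocator(allocator: Int => ByteBuffer): Allocator = new Allocator {
    override def allocate(len: Int): ByteBuffer = allocator(len)
  }

  // Write bytes through the chunked output stream and read them back.
  def roundTrip(data: Array[Byte]): Array[Byte] = {
    val out = new ChunkedByteBufferOutputStream(32 * 1024, toAllocator(ByteBuffer.allocate))
    out.write(data)
    val chunked: ChunkedByteBuffer = out.toChunkedByteBuffer
    // toInputStream(true) disposes the backing chunks once the stream is fully read.
    val in = chunked.toInputStream(true)
    val result = new Array[Byte](chunked.size.toInt)
    ByteStreams.readFully(in, result)
    result
  }

  // Lifting an existing buffer, as done in WriteAheadLogBackedBlockRDD and
  // ReceivedBlockHandler: ChunkedByteBuffer.wrap accepts a byte array or a ByteBuffer.
  def wrapExample(bytes: Array[Byte]): ChunkedByteBuffer = ChunkedByteBuffer.wrap(bytes)
}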