Skip to content

Commit

Permalink
[SPARK-26674][CORE] Consolidate CompositeByteBuf when reading large f…
Browse files Browse the repository at this point in the history
…rame

## What changes were proposed in this pull request?

Currently, TransportFrameDecoder does not consolidate the buffers read from the network, which may waste memory. In most cases a ByteBuf's writerIndex is far less than its capacity, so we can reduce memory consumption by consolidating the buffers.

This PR implements that optimization.

Related code:
https://github.com/apache/spark/blob/9a30e23211e165a44acc0dbe19693950f7a7cc73/common/network-common/src/main/java/org/apache/spark/network/util/TransportFrameDecoder.java#L143

## How was this patch tested?

UT

Please review http://spark.apache.org/contributing.html before opening a pull request.

Closes #23602 from liupc/Reduce-memory-consumption-in-TransportFrameDecoder.

Lead-authored-by: liupengcheng <liupengcheng@xiaomi.com>
Co-authored-by: Liupengcheng <liupengcheng@xiaomi.com>
Signed-off-by: Marcelo Vanzin <vanzin@cloudera.com>
  • Loading branch information
liupengcheng authored and Marcelo Vanzin committed Feb 26, 2019
1 parent 4baa2d4 commit 52a180f
Show file tree
Hide file tree
Showing 2 changed files with 127 additions and 17 deletions.
Expand Up @@ -19,6 +19,7 @@

import java.util.LinkedList;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.CompositeByteBuf;
Expand Down Expand Up @@ -48,14 +49,30 @@ public class TransportFrameDecoder extends ChannelInboundHandlerAdapter {
// Capacity (in bytes) of frameLenBuf, the buffer dedicated to the frame-length header.
private static final int LENGTH_SIZE = 8;
// Exclusive upper bound on a frame's size; decodeNext() rejects anything not below it.
private static final int MAX_FRAME_SIZE = Integer.MAX_VALUE;
// Sentinel meaning "the size of the next frame has not been decoded yet".
private static final int UNKNOWN_FRAME_SIZE = -1;
// Default consolidation threshold (20 MiB) used by the no-argument constructor.
private static final long CONSOLIDATE_THRESHOLD = 20 * 1024 * 1024;

// Buffers received but not yet consumed; decodeNext() takes data from the head of this list.
private final LinkedList<ByteBuf> buffers = new LinkedList<>();
// Fixed-capacity scratch buffer of LENGTH_SIZE bytes.
// NOTE(review): its use is outside this view; presumably it collects a length header that
// is split across input buffers - confirm against decodeFrameSize().
private final ByteBuf frameLenBuf = Unpooled.buffer(LENGTH_SIZE, LENGTH_SIZE);
// frameBuf is consolidated once its capacity has grown by more than this many bytes since
// the previous consolidation.
private final long consolidateThreshold;

// Composite buffer accumulating the frame currently being assembled; null when no
// multi-buffer frame is in progress.
private CompositeByteBuf frameBuf = null;
// Capacity of frameBuf recorded at the last consolidation (0 if none has happened yet).
private long consolidatedFrameBufSize = 0;
// Component count of frameBuf recorded at the last consolidation (0 if none has happened yet).
private int consolidatedNumComponents = 0;

// NOTE(review): maintained outside this view; presumably the number of readable bytes
// currently held in buffers - confirm.
private long totalSize = 0;
// Size of the frame being decoded; reset to UNKNOWN_FRAME_SIZE once a frame is handed out.
private long nextFrameSize = UNKNOWN_FRAME_SIZE;
// Bytes still missing from the frame being assembled in decodeNext().
private int frameRemainingBytes = UNKNOWN_FRAME_SIZE;
// NOTE(review): set and read outside this view; volatile, so it is presumably accessed from
// more than one thread - confirm.
private volatile Interceptor interceptor;

/**
 * Creates a decoder that uses the default consolidation threshold
 * ({@code CONSOLIDATE_THRESHOLD}, 20 MiB).
 */
public TransportFrameDecoder() {
this(CONSOLIDATE_THRESHOLD);
}

/**
 * Creates a decoder with an explicit consolidation threshold. Package-private and marked
 * {@code @VisibleForTesting} so that tests can exercise different thresholds.
 *
 * @param consolidateThreshold number of bytes the in-progress frame buffer's capacity may
 *        grow by, since the last consolidation, before it is consolidated again
 */
@VisibleForTesting
TransportFrameDecoder(long consolidateThreshold) {
this.consolidateThreshold = consolidateThreshold;
}

@Override
public void channelRead(ChannelHandlerContext ctx, Object data) throws Exception {
ByteBuf in = (ByteBuf) data;
Expand Down Expand Up @@ -123,30 +140,56 @@ private long decodeFrameSize() {

/**
 * Tries to produce the next complete frame from the data buffered so far.
 *
 * NOTE(review): this listing appears to interleave two revisions of the method (for
 * example, both a loop over a local {@code remaining} and a loop over the
 * {@code frameRemainingBytes} field are present), so read it as a rendered diff rather
 * than as compilable source. The description below follows the {@code frameBuf}-based path.
 *
 * On that path the frame size is validated once, when no frame is in progress
 * ({@code frameBuf == null}). If the first pending buffer already holds the whole frame it
 * is returned directly; otherwise the pending buffers are appended to a composite
 * {@code frameBuf}, which is kept across calls until the frame is complete. Whenever
 * {@code frameBuf}'s capacity has grown by more than {@code consolidateThreshold} since
 * the last consolidation, the components added since then are consolidated.
 *
 * @return the next frame, or {@code null} when the frame size is not yet known,
 *         {@code buffers} is empty, or the frame is still missing bytes
 */
private ByteBuf decodeNext() {
long frameSize = decodeFrameSize();
if (frameSize == UNKNOWN_FRAME_SIZE || totalSize < frameSize) {
if (frameSize == UNKNOWN_FRAME_SIZE) {
return null;
}

// Reset size for next frame.
nextFrameSize = UNKNOWN_FRAME_SIZE;

Preconditions.checkArgument(frameSize < MAX_FRAME_SIZE, "Too large frame: %s", frameSize);
Preconditions.checkArgument(frameSize > 0, "Frame length should be positive: %s", frameSize);
if (frameBuf == null) {
Preconditions.checkArgument(frameSize < MAX_FRAME_SIZE,
"Too large frame: %s", frameSize);
Preconditions.checkArgument(frameSize > 0,
"Frame length should be positive: %s", frameSize);
frameRemainingBytes = (int) frameSize;

// If the first buffer holds the entire frame, return it.
int remaining = (int) frameSize;
if (buffers.getFirst().readableBytes() >= remaining) {
return nextBufferForFrame(remaining);
// If buffers is empty, then return immediately for more input data.
if (buffers.isEmpty()) {
return null;
}
// Otherwise, if the first buffer holds the entire frame, we attempt to
// build frame with it and return.
if (buffers.getFirst().readableBytes() >= frameRemainingBytes) {
// Reset buf and size for next frame.
frameBuf = null;
nextFrameSize = UNKNOWN_FRAME_SIZE;
return nextBufferForFrame(frameRemainingBytes);
}
// Other cases, create a composite buffer to manage all the buffers.
frameBuf = buffers.getFirst().alloc().compositeBuffer(Integer.MAX_VALUE);
}

// Otherwise, create a composite buffer.
CompositeByteBuf frame = buffers.getFirst().alloc().compositeBuffer(Integer.MAX_VALUE);
while (remaining > 0) {
ByteBuf next = nextBufferForFrame(remaining);
remaining -= next.readableBytes();
frame.addComponent(next).writerIndex(frame.writerIndex() + next.readableBytes());
// Move pending data into frameBuf until the frame is complete or buffers runs dry.
while (frameRemainingBytes > 0 && !buffers.isEmpty()) {
ByteBuf next = nextBufferForFrame(frameRemainingBytes);
frameRemainingBytes -= next.readableBytes();
frameBuf.addComponent(true, next);
}
assert remaining == 0;
// If the delta size of frameBuf exceeds the threshold, then we do consolidation
// to reduce memory consumption.
if (frameBuf.capacity() - consolidatedFrameBufSize > consolidateThreshold) {
// Only the components added since the previous consolidation are merged.
int newNumComponents = frameBuf.numComponents() - consolidatedNumComponents;
frameBuf.consolidate(consolidatedNumComponents, newNumComponents);
consolidatedFrameBufSize = frameBuf.capacity();
consolidatedNumComponents = frameBuf.numComponents();
}
// Frame still incomplete: keep frameBuf as-is and wait for more input.
if (frameRemainingBytes > 0) {
return null;
}

// Reset buf and size for next frame.
ByteBuf frame = frameBuf;
frameBuf = null;
consolidatedFrameBufSize = 0;
consolidatedNumComponents = 0;
nextFrameSize = UNKNOWN_FRAME_SIZE;
return frame;
}

Expand Down
Expand Up @@ -27,11 +27,15 @@
import io.netty.channel.ChannelHandlerContext;
import org.junit.AfterClass;
import org.junit.Test;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import static org.junit.Assert.*;
import static org.mockito.Mockito.*;

public class TransportFrameDecoderSuite {

// Logger for this suite; testConsolidationPerf uses it to report decode timings.
private static final Logger logger = LoggerFactory.getLogger(TransportFrameDecoderSuite.class);
// Random source shared by the tests in this suite.
// NOTE(review): its usages are outside this view.
private static Random RND = new Random();

@AfterClass
Expand All @@ -47,6 +51,69 @@ public void testFrameDecoding() throws Exception {
verifyAndCloseDecoder(decoder, ctx, data);
}

/**
 * Decodes several 300 MiB frames, fed in 32 KiB pieces, with a range of consolidation
 * thresholds. For each threshold it logs the time spent inside the decoder and checks that
 * every frame is delivered with the expected total size.
 *
 * Time is accumulated with {@link System#nanoTime()}: a single 32 KiB read usually takes
 * well under a millisecond, so summing {@code currentTimeMillis()} deltas would mostly add
 * zeros and would also be exposed to wall-clock adjustments.
 *
 * Delivered bytes are counted at delivery time, while the frame is still live, because each
 * frame is released as soon as its message has been processed to keep memory bounded.
 */
@Test
public void testConsolidationPerf() throws Exception {
  long[] testingConsolidateThresholds = new long[] {
      ByteUnit.MiB.toBytes(1),
      ByteUnit.MiB.toBytes(5),
      ByteUnit.MiB.toBytes(10),
      ByteUnit.MiB.toBytes(20),
      ByteUnit.MiB.toBytes(30),
      ByteUnit.MiB.toBytes(50),
      ByteUnit.MiB.toBytes(80),
      ByteUnit.MiB.toBytes(100),
      ByteUnit.MiB.toBytes(300),
      ByteUnit.MiB.toBytes(500),
      Long.MAX_VALUE };
  for (long threshold : testingConsolidateThresholds) {
    TransportFrameDecoder decoder = new TransportFrameDecoder(threshold);
    ChannelHandlerContext ctx = mock(ChannelHandlerContext.class);
    List<ByteBuf> retained = new ArrayList<>();
    // Single-element array so the answer lambda below can accumulate into it.
    long[] totalBytesGot = new long[1];
    when(ctx.fireChannelRead(any())).thenAnswer(in -> {
      ByteBuf buf = (ByteBuf) in.getArguments()[0];
      // Record the size now; the frame is released before the final assertions run.
      totalBytesGot[0] += buf.capacity();
      retained.add(buf);
      return null;
    });

    // Testing multiple messages
    int numMessages = 3;
    long targetBytes = ByteUnit.MiB.toBytes(300);
    int pieceBytes = (int) ByteUnit.KiB.toBytes(32);
    for (int i = 0; i < numMessages; i++) {
      try {
        long writtenBytes = 0;
        long totalNanos = 0;
        // Frame header: the declared length covers the 8-byte header itself plus the payload.
        ByteBuf buf = Unpooled.buffer(8);
        buf.writeLong(8 + targetBytes);
        decoder.channelRead(ctx, buf);
        while (writtenBytes < targetBytes) {
          // Allocate twice the written size so that capacity exceeds writerIndex, which is
          // the situation consolidation is meant to help with.
          buf = Unpooled.buffer(pieceBytes * 2);
          ByteBuf writtenBuf = Unpooled.buffer(pieceBytes).writerIndex(pieceBytes);
          buf.writeBytes(writtenBuf);
          writtenBuf.release();
          long start = System.nanoTime();
          decoder.channelRead(ctx, buf);
          totalNanos += System.nanoTime() - start;
          writtenBytes += pieceBytes;
        }
        logger.info("Writing 300MiB frame buf with consolidation of threshold {} took {} milis",
            threshold, totalNanos / 1_000_000L);
      } finally {
        // Free delivered frames right away so that at most one 300 MiB frame is held.
        for (ByteBuf buf : retained) {
          release(buf);
        }
      }
    }
    assertEquals(numMessages, retained.size());
    assertEquals(targetBytes * numMessages, totalBytesGot[0]);
  }
}

@Test
public void testInterception() throws Exception {
int interceptedReads = 3;
Expand Down

0 comments on commit 52a180f

Please sign in to comment.