diff --git a/network/common/src/test/java/org/apache/spark/network/protocol/MessageWithHeaderSuite.java b/network/common/src/test/java/org/apache/spark/network/protocol/MessageWithHeaderSuite.java index 50035fba5702c..2f9bd7e81fab1 100644 --- a/network/common/src/test/java/org/apache/spark/network/protocol/MessageWithHeaderSuite.java +++ b/network/common/src/test/java/org/apache/spark/network/protocol/MessageWithHeaderSuite.java @@ -26,9 +26,11 @@ import io.netty.channel.FileRegion; import io.netty.util.AbstractReferenceCounted; import org.junit.Test; +import org.mockito.Mockito; import static org.junit.Assert.*; +import org.apache.spark.network.TestManagedBuffer; import org.apache.spark.network.buffer.ManagedBuffer; import org.apache.spark.network.buffer.NettyManagedBuffer; import org.apache.spark.network.util.ByteArrayWritableChannel; @@ -48,7 +50,7 @@ public void testShortWrite() throws Exception { @Test public void testByteBufBody() throws Exception { ByteBuf header = Unpooled.copyLong(42); - ByteBuf body = Unpooled.copyLong(84); + ByteBuf body = Unpooled.copyLong(84).retain(); ManagedBuffer managedBuf = new NettyManagedBuffer(body); MessageWithHeader msg = new MessageWithHeader(managedBuf, header, body, body.readableBytes()); @@ -56,6 +58,17 @@ public void testByteBufBody() throws Exception { assertEquals(msg.count(), result.readableBytes()); assertEquals(42, result.readLong()); assertEquals(84, result.readLong()); + msg.deallocate(); + } + + @Test + public void testDeallocateReleasesManagedBuffer() throws Exception { + ByteBuf header = Unpooled.copyLong(42); + ManagedBuffer managedBuf = Mockito.spy(new TestManagedBuffer(84)); + ByteBuf body = ((ByteBuf) managedBuf.convertToNetty()).retain(); + MessageWithHeader msg = new MessageWithHeader(managedBuf, header, body, body.readableBytes()); + msg.deallocate(); + Mockito.verify(managedBuf, Mockito.times(1)).release(); } private void testFileRegionBody(int totalWrites, int writesPerCall) throws Exception { @@ -70,6 +83,7 @@ private void testFileRegionBody(int totalWrites, int writesPerCall) throws Excep for (long i = 0; i < 8; i++) { assertEquals(i, result.readLong()); } + msg.deallocate(); } private ByteBuf doWrite(MessageWithHeader msg, int minExpectedWrites) throws Exception { diff --git a/network/common/src/test/java/org/apache/spark/network/server/OneForOneStreamManagerSuite.java b/network/common/src/test/java/org/apache/spark/network/server/OneForOneStreamManagerSuite.java index 6356ac6c24f80..c647525d8f1bd 100644 --- a/network/common/src/test/java/org/apache/spark/network/server/OneForOneStreamManagerSuite.java +++ b/network/common/src/test/java/org/apache/spark/network/server/OneForOneStreamManagerSuite.java @@ -23,7 +23,6 @@ import io.netty.channel.Channel; import org.junit.Test; import org.mockito.Mockito; -import static org.mockito.Mockito.*; import org.apache.spark.network.TestManagedBuffer; import org.apache.spark.network.buffer.ManagedBuffer; @@ -40,12 +39,12 @@ public void managedBuffersAreFeedWhenConnectionIsClosed() throws Exception { buffers.add(buffer2); long streamId = manager.registerStream("appId", buffers.iterator()); - Channel dummyChannel = Mockito.mock(Channel.class); + Channel dummyChannel = Mockito.mock(Channel.class, Mockito.RETURNS_SMART_NULLS); manager.registerChannel(dummyChannel, streamId); manager.connectionTerminated(dummyChannel); - Mockito.verify(buffer1, times(1)).release(); - Mockito.verify(buffer2, times(1)).release(); + Mockito.verify(buffer1, Mockito.times(1)).release(); + Mockito.verify(buffer2, Mockito.times(1)).release(); } }