diff --git a/flink-core/src/main/java/org/apache/flink/core/memory/HybridMemorySegment.java b/flink-core/src/main/java/org/apache/flink/core/memory/HybridMemorySegment.java index 23ea740829e40..1bfb50eed15b3 100644 --- a/flink-core/src/main/java/org/apache/flink/core/memory/HybridMemorySegment.java +++ b/flink-core/src/main/java/org/apache/flink/core/memory/HybridMemorySegment.java @@ -60,6 +60,8 @@ public final class HybridMemorySegment extends MemorySegment { */ @Nullable private ByteBuffer offHeapBuffer; + @Nullable private final Runnable cleaner; + /** * Wrapping is not allowed when the underlying memory is unsafe. Unsafe memory can be actively * released, without reference counting. Therefore, access from wrapped buffers, which may not @@ -80,7 +82,7 @@ public final class HybridMemorySegment extends MemorySegment { * @throws IllegalArgumentException Thrown, if the given ByteBuffer is not direct. */ HybridMemorySegment(@Nonnull ByteBuffer buffer, @Nullable Object owner) { - this(buffer, owner, true); + this(buffer, owner, true, null); } /** @@ -94,12 +96,18 @@ public final class HybridMemorySegment extends MemorySegment { * @param buffer The byte buffer whose memory is represented by this memory segment. * @param owner The owner references by this memory segment. * @param allowWrap Whether wrapping {@link ByteBuffer}s from the segment is allowed. + * @param cleaner The cleaner to be called on free segment. * @throws IllegalArgumentException Thrown, if the given ByteBuffer is not direct. */ - HybridMemorySegment(@Nonnull ByteBuffer buffer, @Nullable Object owner, boolean allowWrap) { + HybridMemorySegment( + @Nonnull ByteBuffer buffer, + @Nullable Object owner, + boolean allowWrap, + @Nullable Runnable cleaner) { super(getByteBufferAddress(buffer), buffer.capacity(), owner); this.offHeapBuffer = buffer; this.allowWrap = allowWrap; + this.cleaner = cleaner; } /** @@ -114,6 +122,7 @@ public final class HybridMemorySegment extends MemorySegment { super(buffer, owner); this.offHeapBuffer = null; this.allowWrap = true; + this.cleaner = null; } // ------------------------------------------------------------------------- @@ -123,6 +132,9 @@ public final class HybridMemorySegment extends MemorySegment { @Override public void free() { super.free(); + if (cleaner != null) { + cleaner.run(); + } offHeapBuffer = null; // to enable GC of unsafe memory } diff --git a/flink-core/src/main/java/org/apache/flink/core/memory/MemorySegmentFactory.java b/flink-core/src/main/java/org/apache/flink/core/memory/MemorySegmentFactory.java index 09874962b107d..112d6de9fdfde 100644 --- a/flink-core/src/main/java/org/apache/flink/core/memory/MemorySegmentFactory.java +++ b/flink-core/src/main/java/org/apache/flink/core/memory/MemorySegmentFactory.java @@ -175,8 +175,9 @@ public static MemorySegment allocateOffHeapUnsafeMemory( int size, Object owner, Runnable customCleanupAction) { long address = MemoryUtils.allocateUnsafe(size); ByteBuffer offHeapBuffer = MemoryUtils.wrapUnsafeMemoryWithByteBuffer(address, size); - MemoryUtils.createMemoryGcCleaner(offHeapBuffer, address, customCleanupAction); - return new HybridMemorySegment(offHeapBuffer, owner, false); + Runnable cleaner = + MemoryUtils.createMemoryGcCleaner(offHeapBuffer, address, customCleanupAction); + return new HybridMemorySegment(offHeapBuffer, owner, false, cleaner); } /** diff --git a/flink-core/src/test/java/org/apache/flink/core/memory/HybridOffHeapUnsafeMemorySegmentTest.java b/flink-core/src/test/java/org/apache/flink/core/memory/HybridOffHeapUnsafeMemorySegmentTest.java index d6663578beeb5..2f77153597e65 100644 --- a/flink-core/src/test/java/org/apache/flink/core/memory/HybridOffHeapUnsafeMemorySegmentTest.java +++ b/flink-core/src/test/java/org/apache/flink/core/memory/HybridOffHeapUnsafeMemorySegmentTest.java @@ -22,6 +22,10 @@ import org.junit.runner.RunWith; import org.junit.runners.Parameterized; +import java.util.concurrent.CompletableFuture; + +import static org.junit.Assert.assertTrue; + /** Tests for the {@link HybridMemorySegment} in off-heap mode using unsafe memory. */ @RunWith(Parameterized.class) public class HybridOffHeapUnsafeMemorySegmentTest extends MemorySegmentTestBase { @@ -45,4 +49,13 @@ MemorySegment createSegment(int size, Object owner) { public void testByteBufferWrapping() { createSegment(10).wrap(1, 2); } + + @Test + public void testCallCleanerOnFree() { + final CompletableFuture cleanerFuture = new CompletableFuture<>(); + MemorySegmentFactory.allocateOffHeapUnsafeMemory( + 10, null, () -> cleanerFuture.complete(null)) + .free(); + assertTrue(cleanerFuture.isDone()); + } }