diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java index 469b8fe8a143e..756598a8390f2 100644 --- a/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java @@ -164,4 +164,83 @@ public void cleanerCreateMethodIsDefined() { // path to be hit in normal usage. Assertions.assertTrue(Platform.cleanerCreateMethodIsDefined()); } + + @Test + public void reallocateMemoryGrow() { + long address = Platform.allocateMemory(1024); + try { + for (int i = 0; i < 1024; i++) { + Platform.putByte(null, address + i, (byte) i); + } + + address = Platform.reallocateMemory(address, 1024, 2048); + for (int i = 0; i < 1024; i++) { + Assertions.assertEquals((byte) i, Platform.getByte(null, address + i)); + } + for (int i = 1024; i < 2048; i++) { + Platform.putByte(null, address + i, (byte) i); + } + for (int i = 1024; i < 2048; i++) { + Assertions.assertEquals((byte) i, Platform.getByte(null, address + i)); + } + } finally { + Platform.freeMemory(address); + } + } + + @Test + public void reallocateMemoryShrinkDoesNotOverflow() { + long oldSize = 1024L; + long newSize = 512L; + long sentinelSize = oldSize - newSize; + byte sentinelValue = (byte) 0xAB; + byte sourceValue = (byte) 0xCD; + int pairs = 1000; + + long[] sources = new long[pairs]; + long[] sentinels = new long[pairs]; + + try { + // Allocate all pairs + for (int i = 0; i < pairs; i++) { + sources[i] = Platform.allocateMemory(oldSize); + sentinels[i] = Platform.allocateMemory(sentinelSize); + + Platform.setMemory(sources[i], sourceValue, oldSize); + Platform.setMemory(sentinels[i], sentinelValue, sentinelSize); + } + + // Reallocate source blocks + for (int i = 0; i < pairs; i++) { + long oldAddr = sources[i]; + long newAddr = Platform.reallocateMemory(oldAddr, oldSize, newSize); + sources[i] = newAddr; + } + + // Verify dest content + for (int i = 0; i < pairs; i++) { + for (long j = 0; j < newSize; j++) { + Assertions.assertEquals(sourceValue, Platform.getByte(null, sources[i] + j), + "dest block content corrupted at pair " + i); + } + } + + // Verify sentinel content + for (int i = 0; i < pairs; i++) { + for (long j = 0; j < sentinelSize; j++) { + Assertions.assertEquals(sentinelValue, Platform.getByte(null, sentinels[i] + j), + "sentinel corrupted – overflow write detected at pair " + i); + } + } + + } finally { + // Free all memory + for (long addr : sources) { + if (addr != 0) Platform.freeMemory(addr); + } + for (long addr : sentinels) { + if (addr != 0) Platform.freeMemory(addr); + } + } + } }