Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,8 @@

import org.apache.flink.core.memory.MemorySegment;

import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.Timeout;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.Timeout;

import java.util.ArrayList;
import java.util.List;
Expand All @@ -30,104 +29,105 @@
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertTrue;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;

/** Tests for {@link BatchShuffleReadBufferPool}. */
public class BatchShuffleReadBufferPoolTest {
@Timeout(value = 60, unit = TimeUnit.SECONDS)
class BatchShuffleReadBufferPoolTest {

@Rule public Timeout timeout = new Timeout(60, TimeUnit.SECONDS);

@Test(expected = IllegalArgumentException.class)
public void testIllegalTotalBytes() {
createBufferPool(0, 1024);
@Test
void testIllegalTotalBytes() {
assertThatThrownBy(() -> createBufferPool(0, 1024))
.isInstanceOf(IllegalArgumentException.class);
}

@Test(expected = IllegalArgumentException.class)
public void testIllegalBufferSize() {
createBufferPool(32 * 1024 * 1024, 0);
@Test
void testIllegalBufferSize() {
assertThatThrownBy(() -> createBufferPool(32 * 1024 * 1024, 0))
.isInstanceOf(IllegalArgumentException.class);
}

@Test
public void testLargeTotalBytes() {
void testLargeTotalBytes() {
BatchShuffleReadBufferPool bufferPool = createBufferPool(Long.MAX_VALUE, 1024);
assertEquals(Integer.MAX_VALUE, bufferPool.getNumTotalBuffers());
assertThat(bufferPool.getNumTotalBuffers()).isEqualTo(Integer.MAX_VALUE);
bufferPool.destroy();
}

@Test(expected = IllegalArgumentException.class)
public void testTotalBytesSmallerThanBufferSize() {
createBufferPool(4096, 32 * 1024);
@Test
void testTotalBytesSmallerThanBufferSize() {
assertThatThrownBy(() -> createBufferPool(4096, 32 * 1024))
.isInstanceOf(IllegalArgumentException.class);
}

@Test
public void testBufferCalculation() {
void testBufferCalculation() {
long totalBytes = 32 * 1024 * 1024;
for (int bufferSize = 4 * 1024; bufferSize <= totalBytes; bufferSize += 1024) {
BatchShuffleReadBufferPool bufferPool = createBufferPool(totalBytes, bufferSize);

assertEquals(totalBytes, bufferPool.getTotalBytes());
assertEquals(totalBytes / bufferSize, bufferPool.getNumTotalBuffers());
assertTrue(bufferPool.getNumBuffersPerRequest() <= bufferPool.getNumTotalBuffers());
assertTrue(bufferPool.getNumBuffersPerRequest() > 0);
assertThat(bufferPool.getTotalBytes()).isEqualTo(totalBytes);
assertThat(bufferPool.getNumTotalBuffers()).isEqualTo(totalBytes / bufferSize);
assertThat(bufferPool.getNumBuffersPerRequest())
.isLessThanOrEqualTo(bufferPool.getNumTotalBuffers());
assertThat(bufferPool.getNumBuffersPerRequest()).isGreaterThan(0);
}
}

@Test
public void testRequestBuffers() throws Exception {
void testRequestBuffers() throws Exception {
BatchShuffleReadBufferPool bufferPool = createBufferPool();
List<MemorySegment> buffers = new ArrayList<>();

try {
buffers.addAll(bufferPool.requestBuffers());
assertEquals(bufferPool.getNumBuffersPerRequest(), buffers.size());
assertThat(buffers).hasSize(bufferPool.getNumBuffersPerRequest());
} finally {
bufferPool.recycle(buffers);
bufferPool.destroy();
}
}

@Test
public void testRecycle() throws Exception {
void testRecycle() throws Exception {
BatchShuffleReadBufferPool bufferPool = createBufferPool();
List<MemorySegment> buffers = bufferPool.requestBuffers();

bufferPool.recycle(buffers);
assertEquals(bufferPool.getNumTotalBuffers(), bufferPool.getAvailableBuffers());
assertThat(bufferPool.getAvailableBuffers()).isEqualTo(bufferPool.getNumTotalBuffers());
}

@Test
public void testBufferOperationTimestampUpdated() throws Exception {
void testBufferOperationTimestampUpdated() throws Exception {
BatchShuffleReadBufferPool bufferPool = new BatchShuffleReadBufferPool(1024, 1024);
long oldTimestamp = bufferPool.getLastBufferOperationTimestamp();
Thread.sleep(100);
List<MemorySegment> buffers = bufferPool.requestBuffers();
assertEquals(1, buffers.size());
assertThat(buffers).hasSize(1);
// The timestamp is updated when requesting buffers successfully
assertTrue(bufferPool.getLastBufferOperationTimestamp() > oldTimestamp);
assertThat(bufferPool.getLastBufferOperationTimestamp()).isGreaterThan(oldTimestamp);

oldTimestamp = bufferPool.getLastBufferOperationTimestamp();
Thread.sleep(100);
bufferPool.recycle(buffers);
// The timestamp is updated when recycling buffers
assertTrue(bufferPool.getLastBufferOperationTimestamp() > oldTimestamp);
assertThat(bufferPool.getLastBufferOperationTimestamp()).isGreaterThan(oldTimestamp);

buffers = bufferPool.requestBuffers();

oldTimestamp = bufferPool.getLastBufferOperationTimestamp();
Thread.sleep(100);
assertEquals(0, bufferPool.requestBuffers().size());
assertThat(bufferPool.requestBuffers()).isEmpty();
// The timestamp is not updated when requesting buffers is failed
assertEquals(oldTimestamp, bufferPool.getLastBufferOperationTimestamp());
assertThat(bufferPool.getLastBufferOperationTimestamp()).isEqualTo(oldTimestamp);

bufferPool.recycle(buffers);
bufferPool.destroy();
}

@Test
public void testBufferFulfilledByRecycledBuffers() throws Exception {
void testBufferFulfilledByRecycledBuffers() throws Exception {
int numRequestThreads = 2;
AtomicReference<Throwable> exception = new AtomicReference<>();
BatchShuffleReadBufferPool bufferPool = createBufferPool();
Expand All @@ -139,7 +139,7 @@ public void testBufferFulfilledByRecycledBuffers() throws Exception {
owners[i] = new Object();
buffers.put(owners[i], bufferPool.requestBuffers());
}
assertEquals(0, bufferPool.getAvailableBuffers());
assertThat(bufferPool.getAvailableBuffers()).isZero();

Thread[] requestThreads = new Thread[numRequestThreads];
for (int i = 0; i < numRequestThreads; ++i) {
Expand Down Expand Up @@ -172,20 +172,20 @@ public void testBufferFulfilledByRecycledBuffers() throws Exception {
requestThread.join();
}

assertNull(exception.get());
assertEquals(0, bufferPool.getAvailableBuffers());
assertEquals(8, buffers.size());
assertThat(exception.get()).isNull();
assertThat(bufferPool.getAvailableBuffers()).isZero();
assertThat(buffers).hasSize(8);
} finally {
for (Object owner : buffers.keySet()) {
bufferPool.recycle(buffers.remove(owner));
}
assertEquals(bufferPool.getNumTotalBuffers(), bufferPool.getAvailableBuffers());
assertThat(bufferPool.getAvailableBuffers()).isEqualTo(bufferPool.getNumTotalBuffers());
bufferPool.destroy();
}
}

@Test
public void testMultipleThreadRequestAndRecycle() throws Exception {
void testMultipleThreadRequestAndRecycle() throws Exception {
int numRequestThreads = 10;
AtomicReference<Throwable> exception = new AtomicReference<>();
BatchShuffleReadBufferPool bufferPool = createBufferPool();
Expand Down Expand Up @@ -220,52 +220,52 @@ public void testMultipleThreadRequestAndRecycle() throws Exception {
requestThread.join();
}

assertNull(exception.get());
assertEquals(bufferPool.getNumTotalBuffers(), bufferPool.getAvailableBuffers());
assertThat(exception.get()).isNull();
assertThat(bufferPool.getAvailableBuffers()).isEqualTo(bufferPool.getNumTotalBuffers());
} finally {
bufferPool.destroy();
}
}

@Test
public void testDestroy() throws Exception {
void testDestroy() throws Exception {
BatchShuffleReadBufferPool bufferPool = createBufferPool();
List<MemorySegment> buffers = bufferPool.requestBuffers();
bufferPool.recycle(buffers);

assertFalse(bufferPool.isDestroyed());
assertEquals(bufferPool.getNumTotalBuffers(), bufferPool.getAvailableBuffers());
assertThat(bufferPool.isDestroyed()).isFalse();
assertThat(bufferPool.getAvailableBuffers()).isEqualTo(bufferPool.getNumTotalBuffers());

buffers = bufferPool.requestBuffers();
assertEquals(
bufferPool.getNumTotalBuffers() - buffers.size(), bufferPool.getAvailableBuffers());
assertThat(bufferPool.getAvailableBuffers())
.isEqualTo(bufferPool.getNumTotalBuffers() - buffers.size());

bufferPool.destroy();
assertTrue(bufferPool.isDestroyed());
assertEquals(0, bufferPool.getAvailableBuffers());
assertThat(bufferPool.isDestroyed()).isTrue();
assertThat(bufferPool.getAvailableBuffers()).isZero();
}

@Test(expected = IllegalStateException.class)
public void testRequestBuffersAfterDestroyed() throws Exception {
@Test
void testRequestBuffersAfterDestroyed() throws Exception {
BatchShuffleReadBufferPool bufferPool = createBufferPool();
bufferPool.requestBuffers();

bufferPool.destroy();
bufferPool.requestBuffers();
assertThatThrownBy(bufferPool::requestBuffers).isInstanceOf(IllegalStateException.class);
}

@Test
public void testRecycleAfterDestroyed() throws Exception {
void testRecycleAfterDestroyed() throws Exception {
BatchShuffleReadBufferPool bufferPool = createBufferPool();
List<MemorySegment> buffers = bufferPool.requestBuffers();
bufferPool.destroy();

bufferPool.recycle(buffers);
assertEquals(0, bufferPool.getAvailableBuffers());
assertThat(bufferPool.getAvailableBuffers()).isZero();
}

@Test
public void testDestroyWhileBlockingRequest() throws Exception {
void testDestroyWhileBlockingRequest() throws Exception {
AtomicReference<Throwable> exception = new AtomicReference<>();
BatchShuffleReadBufferPool bufferPool = createBufferPool();

Expand All @@ -286,7 +286,7 @@ public void testDestroyWhileBlockingRequest() throws Exception {
bufferPool.destroy();
requestThread.join();

assertTrue(exception.get() instanceof IllegalStateException);
assertThat(exception.get()).isInstanceOf(IllegalStateException.class);
}

private BatchShuffleReadBufferPool createBufferPool(long totalBytes, int bufferSize) {
Expand Down
Loading