Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,17 +20,16 @@

import org.apache.flink.configuration.Configuration;
import org.apache.flink.configuration.RestOptions;
import org.apache.flink.core.testutils.FlinkAssertions;
import org.apache.flink.core.testutils.OneShotLatch;
import org.apache.flink.runtime.rest.messages.EmptyMessageParameters;
import org.apache.flink.runtime.rest.messages.EmptyRequestBody;
import org.apache.flink.runtime.rest.messages.EmptyResponseBody;
import org.apache.flink.runtime.rest.messages.RuntimeMessageHeaders;
import org.apache.flink.runtime.rest.versioning.RuntimeRestAPIVersion;
import org.apache.flink.testutils.TestingUtils;
import org.apache.flink.testutils.executor.TestExecutorResource;
import org.apache.flink.util.ExceptionUtils;
import org.apache.flink.testutils.executor.TestExecutorExtension;
import org.apache.flink.util.NetUtils;
import org.apache.flink.util.TestLogger;
import org.apache.flink.util.concurrent.Executors;
import org.apache.flink.util.function.CheckedSupplier;

Expand All @@ -41,10 +40,9 @@
import org.apache.flink.shaded.netty4.io.netty.channel.SelectStrategyFactory;
import org.apache.flink.shaded.netty4.io.netty.handler.codec.http.HttpResponseStatus;

import org.junit.Assert;
import org.junit.ClassRule;
import org.junit.Test;
import org.junit.function.ThrowingRunnable;
import org.assertj.core.api.InstanceOfAssertFactories;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;

import java.io.IOException;
import java.net.ServerSocket;
Expand All @@ -56,26 +54,24 @@
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;

import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.empty;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.hasSize;
import static org.hamcrest.Matchers.instanceOf;
import static org.junit.Assert.assertThat;
import static org.junit.Assert.assertThrows;
import static org.assertj.core.api.Assertions.as;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.assertj.core.api.InstanceOfAssertFactories.THROWABLE;

/** Tests for {@link RestClient}. */
public class RestClientTest extends TestLogger {
@ClassRule
public static final TestExecutorResource<ScheduledExecutorService> EXECUTOR_RESOURCE =
TestingUtils.defaultExecutorResource();
class RestClientTest {

@RegisterExtension
static final TestExecutorExtension<ScheduledExecutorService> EXECUTOR_EXTENSION =
TestingUtils.defaultExecutorExtension();

private static final String unroutableIp = "240.0.0.0";

private static final long TIMEOUT = 10L;

@Test
public void testConnectionTimeout() throws Exception {
void testConnectionTimeout() throws Exception {
final Configuration config = new Configuration();
config.setLong(RestOptions.CONNECTION_TIMEOUT, 1);
try (final RestClient restClient = new RestClient(config, Executors.directExecutor())) {
Expand All @@ -87,28 +83,30 @@ public void testConnectionTimeout() throws Exception {
EmptyMessageParameters.getInstance(),
EmptyRequestBody.getInstance());

final Throwable cause = assertThrows(ExecutionException.class, future::get).getCause();
assertThat(cause, instanceOf(ConnectTimeoutException.class));
assertThat(cause.getMessage(), containsString(unroutableIp));
FlinkAssertions.assertThatFuture(future)
.eventuallyFailsWith(ExecutionException.class)
.withCauseInstanceOf(ConnectTimeoutException.class)
.extracting(Throwable::getCause, as(InstanceOfAssertFactories.THROWABLE))
.hasMessageContaining(unroutableIp);
}
}

@Test
public void testInvalidVersionRejection() throws Exception {
try (final RestClient restClient =
new RestClient(new Configuration(), Executors.directExecutor())) {
CompletableFuture<EmptyResponseBody> invalidVersionResponse =
restClient.sendRequest(
unroutableIp,
80,
new TestMessageHeaders(),
EmptyMessageParameters.getInstance(),
EmptyRequestBody.getInstance(),
Collections.emptyList(),
RuntimeRestAPIVersion.V0);
Assert.fail("The request should have been rejected due to a version mismatch.");
} catch (IllegalArgumentException e) {
// expected
assertThatThrownBy(
() ->
restClient.sendRequest(
unroutableIp,
80,
new TestMessageHeaders(),
EmptyMessageParameters.getInstance(),
EmptyRequestBody.getInstance(),
Collections.emptyList(),
RuntimeRestAPIVersion.V0))
.as("The request should have been rejected due to a version mismatch.")
.isInstanceOf(IllegalArgumentException.class);
}
}

Expand All @@ -119,7 +117,7 @@ public void testConnectionClosedHandling() throws Exception {
config.setLong(RestOptions.IDLENESS_TIMEOUT, 5000L);
try (final ServerSocket serverSocket = new ServerSocket(0);
final RestClient restClient =
new RestClient(config, EXECUTOR_RESOURCE.getExecutor())) {
new RestClient(config, EXECUTOR_EXTENSION.getExecutor())) {

final String targetAddress = "localhost";
final int targetPort = serverSocket.getLocalPort();
Expand Down Expand Up @@ -153,13 +151,9 @@ public void testConnectionClosedHandling() throws Exception {
connectionSocket.close();
}

try {
responseFuture.get();
} catch (ExecutionException ee) {
if (!ExceptionUtils.findThrowable(ee, IOException.class).isPresent()) {
throw ee;
}
}
FlinkAssertions.assertThatFuture(responseFuture)
.eventuallyFailsWith(ExecutionException.class)
.withCauseInstanceOf(IOException.class);
}
}

Expand All @@ -173,7 +167,7 @@ public void testRestClientClosedHandling() throws Exception {

try (final ServerSocket serverSocket = new ServerSocket(0);
final RestClient restClient =
new RestClient(config, EXECUTOR_RESOURCE.getExecutor())) {
new RestClient(config, EXECUTOR_EXTENSION.getExecutor())) {

final String targetAddress = "localhost";
final int targetPort = serverSocket.getLocalPort();
Expand Down Expand Up @@ -202,13 +196,9 @@ public void testRestClientClosedHandling() throws Exception {

restClient.close();

try {
responseFuture.get();
} catch (ExecutionException ee) {
if (!ExceptionUtils.findThrowable(ee, IOException.class).isPresent()) {
throw ee;
}
}
FlinkAssertions.assertThatFuture(responseFuture)
.eventuallyFailsWith(ExecutionException.class)
.withCauseInstanceOf(IOException.class);
} finally {
if (connectionSocket != null) {
connectionSocket.close();
Expand Down Expand Up @@ -236,14 +226,11 @@ public void testCloseClientBeforeRequest() throws Exception {
EmptyMessageParameters.getInstance(),
EmptyRequestBody.getInstance());

// Call get() on the future with a timeout of 0s so we can test that the exception
// thrown is not a TimeoutException, which is what would be thrown if restClient were
// not already closed
final ThrowingRunnable getFuture = () -> future.get(0, TimeUnit.SECONDS);

final Throwable cause = assertThrows(ExecutionException.class, getFuture).getCause();
assertThat(cause, instanceOf(IllegalStateException.class));
assertThat(cause.getMessage(), equalTo("RestClient is already closed"));
FlinkAssertions.assertThatFuture(future)
.eventuallyFailsWith(ExecutionException.class)
.withCauseInstanceOf(IllegalStateException.class)
.extracting(Throwable::getCause, as(THROWABLE))
.hasMessage("RestClient is already closed");
}
}

Expand All @@ -270,7 +257,7 @@ public void testCloseClientWhileProcessingRequest() throws Exception {
new Configuration(), Executors.directExecutor(), selectStrategyFactory)) {
// Check that client's internal collection of pending response futures is empty prior to
// the request
assertThat(restClient.getResponseChannelFutures(), empty());
assertThat(restClient.getResponseChannelFutures()).isEmpty();

final CompletableFuture<?> requestFuture =
restClient.sendRequest(
Expand All @@ -282,24 +269,23 @@ public void testCloseClientWhileProcessingRequest() throws Exception {

// Check that client's internal collection of pending response futures now has one
// entry, presumably due to the call to sendRequest
assertThat(restClient.getResponseChannelFutures(), hasSize(1));
assertThat(restClient.getResponseChannelFutures()).hasSize(1);

// Wait for Netty to start connecting, then while it's paused in the SelectStrategy,
// close the client before unpausing Netty
connectTriggered.await();
final CompletableFuture<Void> closeFuture = restClient.closeAsync();
closeTriggered.trigger();

// Close should complete successfully
closeFuture.get();
FlinkAssertions.assertThatFuture(closeFuture)
.as("Close should have had completed.")
.eventuallySucceeds();

final Throwable cause =
assertThrows(
ExecutionException.class,
() -> requestFuture.get(0, TimeUnit.SECONDS))
.getCause();
assertThat(cause, instanceOf(IllegalStateException.class));
assertThat(cause.getMessage(), equalTo("executor not accepting a task"));
FlinkAssertions.assertThatFuture(requestFuture)
.eventuallyFailsWith(ExecutionException.class)
.withCauseInstanceOf(IllegalStateException.class)
.extracting(Throwable::getCause, as(THROWABLE))
.hasMessage("executor not accepting a task");
}
}

Expand All @@ -318,15 +304,13 @@ public void testResponseChannelFuturesResolvedExceptionallyOnClose() throws Exce

// Ensure the client's internal collection of pending response futures was cleared after
// close
assertThat(restClient.getResponseChannelFutures(), empty());

final Throwable cause =
assertThrows(
ExecutionException.class,
() -> responseChannelFuture.get(0, TimeUnit.SECONDS))
.getCause();
assertThat(cause, instanceOf(IllegalStateException.class));
assertThat(cause.getMessage(), equalTo("RestClient closed before request completed"));
assertThat(restClient.getResponseChannelFutures()).isEmpty();

FlinkAssertions.assertThatFuture(responseChannelFuture)
.eventuallyFailsWith(ExecutionException.class)
.withCauseInstanceOf(IllegalStateException.class)
.extracting(Throwable::getCause, as(THROWABLE))
.hasMessage("RestClient closed before request completed");
}
}

Expand Down
Loading