diff --git a/core/pom.xml b/core/pom.xml index c996d639407..8e08118a57d 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -128,6 +128,17 @@ + + maven-jar-plugin + + + test-jar + + test-jar + + + + maven-shade-plugin @@ -147,6 +158,7 @@ com.datastax.oss.driver.shaded.guava + true diff --git a/core/src/main/java/com/datastax/oss/driver/internal/core/util/concurrent/CompletableFutures.java b/core/src/main/java/com/datastax/oss/driver/internal/core/util/concurrent/CompletableFutures.java index 8a1d9c89c0e..c37c56c9213 100644 --- a/core/src/main/java/com/datastax/oss/driver/internal/core/util/concurrent/CompletableFutures.java +++ b/core/src/main/java/com/datastax/oss/driver/internal/core/util/concurrent/CompletableFutures.java @@ -75,7 +75,8 @@ public static void whenAllDone( } /** Get the result now, when we know for sure that the future is complete. */ - public static T getCompleted(CompletableFuture future) { + public static T getCompleted(CompletionStage stage) { + CompletableFuture future = stage.toCompletableFuture(); Preconditions.checkArgument(future.isDone() && !future.isCompletedExceptionally()); try { return future.get(); @@ -86,7 +87,8 @@ public static T getCompleted(CompletableFuture future) { } /** Get the error now, when we know for sure that the future is failed. */ - public static Throwable getFailed(CompletableFuture future) { + public static Throwable getFailed(CompletionStage stage) { + CompletableFuture future = stage.toCompletableFuture(); Preconditions.checkArgument(future.isCompletedExceptionally()); try { future.get(); diff --git a/core/src/test/java/com/datastax/oss/driver/internal/core/CompletionStageAssert.java b/core/src/test/java/com/datastax/oss/driver/internal/core/CompletionStageAssert.java index aaf8049e909..0484386cd8f 100644 --- a/core/src/test/java/com/datastax/oss/driver/internal/core/CompletionStageAssert.java +++ b/core/src/test/java/com/datastax/oss/driver/internal/core/CompletionStageAssert.java @@ -18,6 +18,7 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.fail; +import java.util.concurrent.CancellationException; import java.util.concurrent.CompletionStage; import java.util.concurrent.ExecutionException; import java.util.concurrent.TimeUnit; @@ -66,8 +67,45 @@ public CompletionStageAssert isFailed() { return isFailed(f -> {}); } + public CompletionStageAssert isCancelled() { + boolean cancelled = false; + try { + actual.toCompletableFuture().get(2, TimeUnit.SECONDS); + } catch (CancellationException e) { + cancelled = true; + } catch (Exception ignored) { + } + if (!cancelled) { + fail("Expected completion stage to be cancelled"); + } + return this; + } + + public CompletionStageAssert isNotCancelled() { + boolean cancelled = false; + try { + actual.toCompletableFuture().get(2, TimeUnit.SECONDS); + } catch (CancellationException e) { + cancelled = true; + } catch (Exception ignored) { + } + if (cancelled) { + fail("Expected completion stage not to be cancelled"); + } + return this; + } + + public CompletionStageAssert isDone() { + assertThat(actual.toCompletableFuture().isDone()) + .overridingErrorMessage("Expected completion stage to be done") + .isTrue(); + return this; + } + public CompletionStageAssert isNotDone() { - assertThat(actual.toCompletableFuture().isDone()).isFalse(); + assertThat(actual.toCompletableFuture().isDone()) + .overridingErrorMessage("Expected completion stage not to be done") + .isFalse(); return this; } } diff --git a/core/src/test/java/com/datastax/oss/driver/internal/core/cql/PoolBehavior.java b/core/src/test/java/com/datastax/oss/driver/internal/core/cql/PoolBehavior.java index 9cf17f0c9ad..32a628028ab 100644 --- a/core/src/test/java/com/datastax/oss/driver/internal/core/cql/PoolBehavior.java +++ b/core/src/test/java/com/datastax/oss/driver/internal/core/cql/PoolBehavior.java @@ -25,7 +25,10 @@ import com.datastax.oss.driver.internal.core.channel.ResponseCallback; import com.datastax.oss.protocol.internal.Frame; import com.datastax.oss.protocol.internal.Message; +import io.netty.channel.ChannelConfig; import io.netty.channel.ChannelFuture; +import io.netty.channel.EventLoop; +import io.netty.channel.socket.DefaultSocketChannelConfig; import io.netty.util.concurrent.GlobalEventExecutor; import io.netty.util.concurrent.Promise; import java.util.concurrent.CompletableFuture; @@ -52,17 +55,25 @@ public PoolBehavior(Node node, boolean createChannel) { this.writePromise = null; } else { this.channel = Mockito.mock(DriverChannel.class); + EventLoop eventLoop = Mockito.mock(EventLoop.class); + ChannelConfig config = Mockito.mock(DefaultSocketChannelConfig.class); this.writePromise = GlobalEventExecutor.INSTANCE.newPromise(); Mockito.when( channel.write( any(Message.class), anyBoolean(), anyMap(), any(ResponseCallback.class))) .thenAnswer( invocation -> { - callbackFuture.complete(invocation.getArgument(3)); + ResponseCallback callback = invocation.getArgument(3); + if (callback.holdStreamId()) { + callback.onStreamIdAssigned(1); + } + callbackFuture.complete(callback); return writePromise; }); ChannelFuture closeFuture = Mockito.mock(ChannelFuture.class); Mockito.when(channel.closeFuture()).thenReturn(closeFuture); + Mockito.when(channel.eventLoop()).thenReturn(eventLoop); + Mockito.when(channel.config()).thenReturn(config); } } @@ -92,6 +103,14 @@ public void setResponseFailure(Throwable cause) { callbackFuture.thenAccept(callback -> callback.onFailure(cause)); } + public Node getNode() { + return node; + } + + public DriverChannel getChannel() { + return channel; + } + /** Mocks a follow-up request on the same channel. */ public void mockFollowupRequest(Class expectedMessage, Frame responseFrame) { Promise writePromise2 = GlobalEventExecutor.INSTANCE.newPromise(); diff --git a/core/src/test/java/com/datastax/oss/driver/internal/core/cql/RequestHandlerTestHarness.java b/core/src/test/java/com/datastax/oss/driver/internal/core/cql/RequestHandlerTestHarness.java index 43b8fb4aec6..9b6c89aa9f4 100644 --- a/core/src/test/java/com/datastax/oss/driver/internal/core/cql/RequestHandlerTestHarness.java +++ b/core/src/test/java/com/datastax/oss/driver/internal/core/cql/RequestHandlerTestHarness.java @@ -37,7 +37,9 @@ import com.datastax.oss.driver.internal.core.channel.DriverChannel; import com.datastax.oss.driver.internal.core.context.InternalDriverContext; import com.datastax.oss.driver.internal.core.context.NettyOptions; +import com.datastax.oss.driver.internal.core.metadata.DefaultMetadata; import com.datastax.oss.driver.internal.core.metadata.LoadBalancingPolicyWrapper; +import com.datastax.oss.driver.internal.core.metrics.SessionMetricUpdater; import com.datastax.oss.driver.internal.core.pool.ChannelPool; import com.datastax.oss.driver.internal.core.servererrors.DefaultWriteTypeRegistry; import com.datastax.oss.driver.internal.core.session.DefaultSession; @@ -70,6 +72,7 @@ public static Builder builder() { } private final ScheduledTaskCapturingEventLoop schedulingEventLoop; + private final Map pools; @Mock private InternalDriverContext context; @Mock private DefaultSession session; @@ -82,6 +85,7 @@ public static Builder builder() { @Mock private SpeculativeExecutionPolicy speculativeExecutionPolicy; @Mock private TimestampGenerator timestampGenerator; @Mock private ProtocolVersionRegistry protocolVersionRegistry; + @Mock private SessionMetricUpdater sessionMetricUpdater; private RequestHandlerTestHarness(Builder builder) { MockitoAnnotations.initMocks(this); @@ -126,7 +130,7 @@ private RequestHandlerTestHarness(Builder builder) { Mockito.when(timestampGenerator.next()).thenReturn(Long.MIN_VALUE); Mockito.when(context.timestampGenerator()).thenReturn(timestampGenerator); - Map pools = builder.buildMockPools(); + pools = builder.buildMockPools(); Mockito.when(session.getChannel(any(Node.class), anyString())) .thenAnswer( invocation -> { @@ -138,12 +142,20 @@ private RequestHandlerTestHarness(Builder builder) { Mockito.when(session.setKeyspace(any(CqlIdentifier.class))) .thenReturn(CompletableFuture.completedFuture(null)); + Mockito.when(session.getMetricUpdater()).thenReturn(sessionMetricUpdater); + + Mockito.when(session.getMetadata()).thenReturn(DefaultMetadata.EMPTY); + Mockito.when(context.protocolVersionRegistry()).thenReturn(protocolVersionRegistry); Mockito.when( protocolVersionRegistry.supports( any(ProtocolVersion.class), any(ProtocolFeature.class))) .thenReturn(true); + if (builder.protocolVersion != null) { + Mockito.when(context.protocolVersion()).thenReturn(builder.protocolVersion); + } + Mockito.when(context.consistencyLevelRegistry()) .thenReturn(new DefaultConsistencyLevelRegistry()); @@ -158,6 +170,11 @@ public InternalDriverContext getContext() { return context; } + public DriverChannel getChannel(Node node) { + ChannelPool pool = pools.get(node); + return pool.next(); + } + /** * Returns the next task that was scheduled on the request handler's admin executor. The test must * run it manually. @@ -174,6 +191,7 @@ public void close() { public static class Builder { private final List poolBehaviors = new ArrayList<>(); private boolean defaultIdempotence; + private ProtocolVersion protocolVersion; /** * Sets the given node as the next one in the query plan; an empty pool will be simulated when @@ -224,6 +242,11 @@ public Builder withDefaultIdempotence(boolean defaultIdempotence) { return this; } + public Builder withProtocolVersion(ProtocolVersion protocolVersion) { + this.protocolVersion = protocolVersion; + return this; + } + /** * Sets the given node as the next one in the query plan; the test code is responsible of * calling the methods on the returned object to complete the write and the query.