diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/directconnectivity/rntbd/RntbdClientChannelHealthChecker.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/directconnectivity/rntbd/RntbdClientChannelHealthChecker.java index b593de39ed189..6939eebe09da7 100644 --- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/directconnectivity/rntbd/RntbdClientChannelHealthChecker.java +++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/directconnectivity/rntbd/RntbdClientChannelHealthChecker.java @@ -462,6 +462,7 @@ public void channelPingCompleted() { public void channelReadCompleted() { lastReadUpdater.set(this, Instant.now()); + this.resetTransitTimeout(); // we have got a successful read, so reset the transitTimeout count. } public void channelWriteAttempted() { diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/directconnectivity/rntbd/RntbdContextNegotiator.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/directconnectivity/rntbd/RntbdContextNegotiator.java index b6818a1e85cdc..ef629979a21b0 100644 --- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/directconnectivity/rntbd/RntbdContextNegotiator.java +++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/directconnectivity/rntbd/RntbdContextNegotiator.java @@ -77,10 +77,12 @@ private void startRntbdContextRequest(final ChannelHandlerContext context) throw final RntbdContextRequest request = new RntbdContextRequest(Utils.randomUUID(), this.userAgent); final CompletableFuture contextRequestFuture = this.manager.rntbdContextRequestFuture(); + this.manager.getTimestamps().channelWriteAttempted(); super.write(context, request, channel.newPromise().addListener((ChannelFutureListener)future -> { if (future.isSuccess()) { contextRequestFuture.complete(request); + this.manager.getTimestamps().channelWriteCompleted(); return; } diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/directconnectivity/rntbd/RntbdRequestManager.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/directconnectivity/rntbd/RntbdRequestManager.java index 7b8b7baea6cd7..1639da82fb013 100644 --- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/directconnectivity/rntbd/RntbdRequestManager.java +++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/directconnectivity/rntbd/RntbdRequestManager.java @@ -187,7 +187,7 @@ public void channelInactive(final ChannelHandlerContext context) { public void channelRead(final ChannelHandlerContext context, final Object message) { this.traceOperation(context, "channelRead"); - this.timestamps.resetTransitTimeout(); // we have got a successful read, so reset the transitTimeout count. + this.timestamps.channelReadCompleted(); try { if (message.getClass() == RntbdResponse.class) { @@ -231,7 +231,6 @@ public void channelRead(final ChannelHandlerContext context, final Object messag @Override public void channelReadComplete(final ChannelHandlerContext context) { this.traceOperation(context, "channelReadComplete"); - this.timestamps.channelReadCompleted(); context.fireChannelReadComplete(); } @@ -378,6 +377,10 @@ public void userEventTriggered(final ChannelHandlerContext context, final Object if (event instanceof RntbdContext) { this.contextFuture.complete((RntbdContext) event); this.removeContextNegotiatorAndFlushPendingWrites(context); + + // Important: currently the RntbdContext negotiation response will not be captured during channelRead + // need to mark the timestamp here + this.timestamps.channelReadCompleted(); return; } if (event instanceof RntbdContextException) { @@ -578,13 +581,12 @@ public void write(final ChannelHandlerContext context, final Object message, fin if (message instanceof RntbdRequestRecord) { final RntbdRequestRecord record = (RntbdRequestRecord) message; - - this.timestamps.channelWriteAttempted(); record.setTimestamps(this.timestamps); - record.setSendingRequestHasStarted(); - if (!record.isCancelled()) { + record.setSendingRequestHasStarted(); + this.timestamps.channelWriteAttempted(); + context.write(this.addPendingRequestRecord(context, record), promise).addListener(completed -> { record.stage(RntbdRequestRecord.Stage.SENT); if (completed.isSuccess()) { @@ -670,6 +672,10 @@ void pendWrite(final ByteBuf out, final ChannelPromise promise) { this.pendingWrites.add(out, promise); } + public Timestamps getTimestamps() { + return this.timestamps; + } + Timestamps snapshotTimestamps() { return new Timestamps(this.timestamps); } diff --git a/sdk/cosmos/azure-cosmos/src/test/java/com/azure/cosmos/implementation/directconnectivity/ReflectionUtils.java b/sdk/cosmos/azure-cosmos/src/test/java/com/azure/cosmos/implementation/directconnectivity/ReflectionUtils.java index f55be7bb79708..37d69232973ca 100644 --- a/sdk/cosmos/azure-cosmos/src/test/java/com/azure/cosmos/implementation/directconnectivity/ReflectionUtils.java +++ b/sdk/cosmos/azure-cosmos/src/test/java/com/azure/cosmos/implementation/directconnectivity/ReflectionUtils.java @@ -416,8 +416,4 @@ public static AtomicReference getHealthStatus(Uri uri) { public static Set getReplicaValidationScopes(GatewayAddressCache gatewayAddressCache) { return get(Set.class, gatewayAddressCache, "replicaValidationScopes"); } - - public static RntbdClientChannelHealthChecker.Timestamps getTimestamps(RntbdRequestManager rntbdRequestManager) { - return get(RntbdClientChannelHealthChecker.Timestamps.class, rntbdRequestManager, "timestamps"); - } } diff --git a/sdk/cosmos/azure-cosmos/src/test/java/com/azure/cosmos/implementation/directconnectivity/rntbd/RntbdRequestManagerTests.java b/sdk/cosmos/azure-cosmos/src/test/java/com/azure/cosmos/implementation/directconnectivity/rntbd/RntbdRequestManagerTests.java index fc722620d0e80..114e8ddd8b198 100644 --- a/sdk/cosmos/azure-cosmos/src/test/java/com/azure/cosmos/implementation/directconnectivity/rntbd/RntbdRequestManagerTests.java +++ b/sdk/cosmos/azure-cosmos/src/test/java/com/azure/cosmos/implementation/directconnectivity/rntbd/RntbdRequestManagerTests.java @@ -7,12 +7,16 @@ import com.azure.cosmos.implementation.OperationType; import com.azure.cosmos.implementation.ResourceType; import com.azure.cosmos.implementation.RxDocumentServiceRequest; -import com.azure.cosmos.implementation.directconnectivity.ReflectionUtils; import com.azure.cosmos.implementation.directconnectivity.RntbdTransportClient; import com.azure.cosmos.implementation.directconnectivity.Uri; +import io.netty.channel.Channel; import io.netty.channel.ChannelFuture; import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelPipeline; import io.netty.channel.ChannelPromise; +import io.netty.channel.DefaultChannelPromise; +import io.netty.channel.DefaultEventLoop; +import io.netty.channel.SingleThreadEventLoop; import io.netty.handler.logging.LogLevel; import io.netty.handler.ssl.SslContext; import org.mockito.Mockito; @@ -21,13 +25,15 @@ import java.net.URI; import java.net.URISyntaxException; import java.time.Duration; +import java.time.Instant; import static com.azure.cosmos.implementation.TestUtils.mockDiagnosticsClientContext; import static org.assertj.core.api.AssertionsForClassTypes.assertThat; +import static org.mockito.Mockito.doNothing; public class RntbdRequestManagerTests { - @Test + @Test(groups = { "unit" }) public void transitTimeoutTimestampTests() throws URISyntaxException { SslContext sslContextMock = Mockito.mock(SslContext.class); RntbdEndpoint.Config config = new RntbdEndpoint.Config( @@ -43,12 +49,18 @@ public void transitTimeoutTimestampTests() throws URISyntaxException { 30, connectionStateListener, Duration.ofSeconds(1).toNanos()); - RntbdClientChannelHealthChecker.Timestamps timestamps = ReflectionUtils.getTimestamps(rntbdRequestManager); + RntbdClientChannelHealthChecker.Timestamps timestamps = rntbdRequestManager.getTimestamps(); ChannelHandlerContext channelHandlerContext = Mockito.mock(ChannelHandlerContext.class); - ChannelFuture channelFuture = Mockito.mock(ChannelFuture.class); - Mockito.when(channelHandlerContext.write(Mockito.any(), Mockito.any())).thenReturn(channelFuture); - Mockito.when(channelFuture.addListener(Mockito.any())).thenReturn(channelFuture); + + Channel channelMock = Mockito.mock(Channel.class); + SingleThreadEventLoop eventLoopMock = new DefaultEventLoop(); + Mockito.when(channelMock.eventLoop()).thenReturn(eventLoopMock); + + ChannelPromise defaultChannelPromise = new DefaultChannelPromise(channelMock); + defaultChannelPromise.setSuccess(); + + Mockito.when(channelHandlerContext.write(Mockito.any(), Mockito.any())).thenReturn(defaultChannelPromise); RntbdRequestArgs requestArgs = new RntbdRequestArgs( RxDocumentServiceRequest.create(mockDiagnosticsClientContext(), OperationType.Read, ResourceType.Document), @@ -61,18 +73,79 @@ public void transitTimeoutTimestampTests() throws URISyntaxException { ChannelPromise promise = Mockito.mock(ChannelPromise.class); // Test transitTimeout is 0 at start point + Instant previousLastWriteTime = timestamps.lastChannelWriteTime(); + Instant previousLastWriteAttemptTime = timestamps.lastChannelWriteAttemptTime(); + Instant previousLastChannelReadTime = timestamps.lastChannelReadTime(); + rntbdRequestManager.write(channelHandlerContext, rntbdRequestRecord, promise); + assertThat(timestamps.transitTimeoutCount()).isZero(); + assertThat(timestamps.transitTimeoutStartingTime()).isNull(); + assertThat(timestamps.tansitTimeoutWriteCount()).isZero(); + assertThat(timestamps.lastChannelWriteTime()).isAfterOrEqualTo(previousLastWriteTime); + assertThat(timestamps.lastChannelWriteAttemptTime()).isAfterOrEqualTo(previousLastWriteAttemptTime); + assertThat(timestamps.lastChannelReadTime()).isEqualTo(previousLastChannelReadTime); // Test when a transit timeout happens, the transitTimeoutCount is increased rntbdRequestRecord.expire(); assertThat(timestamps.transitTimeoutCount()).isOne(); + assertThat(timestamps.tansitTimeoutWriteCount()).isZero(); + assertThat(timestamps.transitTimeoutStartingTime()).isNotNull(); + assertThat(Duration.between(timestamps.transitTimeoutStartingTime(), Instant.now())).isLessThan(Duration.ofSeconds(5)); // Test when there is channelRead, transitTimeout is cleared out + previousLastWriteTime = timestamps.lastChannelWriteTime(); + previousLastWriteAttemptTime = timestamps.lastChannelWriteAttemptTime(); + previousLastChannelReadTime = timestamps.lastChannelReadTime(); + Mockito.when(channelHandlerContext.flush()).thenReturn(channelHandlerContext); + ChannelFuture closeChannelFuture = Mockito.mock(ChannelFuture.class); Mockito.when(channelHandlerContext.close()).thenReturn(closeChannelFuture); rntbdRequestManager.channelRead(channelHandlerContext, rntbdRequestRecord); assertThat(timestamps.transitTimeoutCount()).isZero(); + assertThat(timestamps.transitTimeoutStartingTime()).isNull(); + assertThat(timestamps.lastChannelReadTime()).isAfterOrEqualTo(previousLastChannelReadTime); + assertThat(timestamps.lastChannelWriteAttemptTime()).isEqualTo(previousLastWriteAttemptTime); + assertThat(timestamps.lastChannelWriteTime()).isAfterOrEqualTo(previousLastWriteTime); + } + + @Test(groups = { "unit" }) + public void rntbdContextResponseTests() { + // Test when getting rntbdContext response, the lastReadTimestamp should be marked + SslContext sslContextMock = Mockito.mock(SslContext.class); + RntbdEndpoint.Config config = new RntbdEndpoint.Config( + new RntbdTransportClient.Options.Builder(ConnectionPolicy.getDefaultPolicy()).build(), + sslContextMock, + LogLevel.INFO); + RntbdClientChannelHealthChecker healthChecker = new RntbdClientChannelHealthChecker(config); + + RntbdConnectionStateListener connectionStateListener = Mockito.mock(RntbdConnectionStateListener.class); + + RntbdRequestManager rntbdRequestManager = new RntbdRequestManager( + healthChecker, + 30, + connectionStateListener, + Duration.ofSeconds(1).toNanos()); + + RntbdClientChannelHealthChecker.Timestamps timestamps = rntbdRequestManager.getTimestamps(); + + ChannelHandlerContext channelHandlerContextMock = Mockito.mock(ChannelHandlerContext.class); + ChannelPipeline channelPipelineMock = Mockito.mock(ChannelPipeline.class); + Mockito.when(channelPipelineMock.fireChannelRegistered()).thenReturn(channelPipelineMock); + Mockito.when(channelHandlerContextMock.channel()).thenReturn(Mockito.mock(Channel.class)); + + RntbdContextNegotiator rntbdContextNegotiatorMock = Mockito.mock(RntbdContextNegotiator.class); + doNothing().when(rntbdContextNegotiatorMock).removeInboundHandler(); + doNothing().when(rntbdContextNegotiatorMock).removeOutboundHandler(); + Mockito.when(channelPipelineMock.get(RntbdContextNegotiator.class)).thenReturn(rntbdContextNegotiatorMock); + Mockito.when(channelHandlerContextMock.pipeline()).thenReturn(channelPipelineMock); + + rntbdRequestManager.channelRegistered(channelHandlerContextMock); + + Instant lastReadTimestamp = timestamps.lastChannelReadTime(); + RntbdContext rntbdContextMock = Mockito.mock(RntbdContext.class); + rntbdRequestManager.userEventTriggered(channelHandlerContextMock, rntbdContextMock); + assertThat(timestamps.lastChannelReadTime()).isAfterOrEqualTo(lastReadTimestamp); } }