Skip to content

Commit

Permalink
RntbdHealthCheckImprovement-3 (#33566)
Browse files Browse the repository at this point in the history
* update

* isAfter -isAfterOrEqualTo

* Update RntbdRequestManagerTests.java

---------

Co-authored-by: annie-mac <annie-mac@annie-macs-MacBook-Pro.local>
Co-authored-by: Fabian Meiswinkel <fabian@meiswinkel.com>
  • Loading branch information
3 people committed Feb 17, 2023
1 parent 3fabf97 commit 329beb9
Show file tree
Hide file tree
Showing 5 changed files with 94 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -462,6 +462,7 @@ public void channelPingCompleted() {

public void channelReadCompleted() {
lastReadUpdater.set(this, Instant.now());
this.resetTransitTimeout(); // we have got a successful read, so reset the transitTimeout count.
}

public void channelWriteAttempted() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,10 +77,12 @@ private void startRntbdContextRequest(final ChannelHandlerContext context) throw
final RntbdContextRequest request = new RntbdContextRequest(Utils.randomUUID(), this.userAgent);
final CompletableFuture<RntbdContextRequest> contextRequestFuture = this.manager.rntbdContextRequestFuture();

this.manager.getTimestamps().channelWriteAttempted();
super.write(context, request, channel.newPromise().addListener((ChannelFutureListener)future -> {

if (future.isSuccess()) {
contextRequestFuture.complete(request);
this.manager.getTimestamps().channelWriteCompleted();
return;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ public void channelInactive(final ChannelHandlerContext context) {
public void channelRead(final ChannelHandlerContext context, final Object message) {

this.traceOperation(context, "channelRead");
this.timestamps.resetTransitTimeout(); // we have got a successful read, so reset the transitTimeout count.
this.timestamps.channelReadCompleted();

try {
if (message.getClass() == RntbdResponse.class) {
Expand Down Expand Up @@ -231,7 +231,6 @@ public void channelRead(final ChannelHandlerContext context, final Object messag
@Override
public void channelReadComplete(final ChannelHandlerContext context) {
this.traceOperation(context, "channelReadComplete");
this.timestamps.channelReadCompleted();
context.fireChannelReadComplete();
}

Expand Down Expand Up @@ -378,6 +377,10 @@ public void userEventTriggered(final ChannelHandlerContext context, final Object
if (event instanceof RntbdContext) {
this.contextFuture.complete((RntbdContext) event);
this.removeContextNegotiatorAndFlushPendingWrites(context);

// Important: currently the RntbdContext negotiation response will not be captured during channelRead
// need to mark the timestamp here
this.timestamps.channelReadCompleted();
return;
}
if (event instanceof RntbdContextException) {
Expand Down Expand Up @@ -578,13 +581,12 @@ public void write(final ChannelHandlerContext context, final Object message, fin
if (message instanceof RntbdRequestRecord) {

final RntbdRequestRecord record = (RntbdRequestRecord) message;

this.timestamps.channelWriteAttempted();
record.setTimestamps(this.timestamps);

record.setSendingRequestHasStarted();

if (!record.isCancelled()) {
record.setSendingRequestHasStarted();
this.timestamps.channelWriteAttempted();

context.write(this.addPendingRequestRecord(context, record), promise).addListener(completed -> {
record.stage(RntbdRequestRecord.Stage.SENT);
if (completed.isSuccess()) {
Expand Down Expand Up @@ -670,6 +672,10 @@ void pendWrite(final ByteBuf out, final ChannelPromise promise) {
this.pendingWrites.add(out, promise);
}

public Timestamps getTimestamps() {
return this.timestamps;
}

Timestamps snapshotTimestamps() {
return new Timestamps(this.timestamps);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -416,8 +416,4 @@ public static AtomicReference<Uri.HealthStatus> getHealthStatus(Uri uri) {
public static Set<Uri.HealthStatus> getReplicaValidationScopes(GatewayAddressCache gatewayAddressCache) {
return get(Set.class, gatewayAddressCache, "replicaValidationScopes");
}

public static RntbdClientChannelHealthChecker.Timestamps getTimestamps(RntbdRequestManager rntbdRequestManager) {
return get(RntbdClientChannelHealthChecker.Timestamps.class, rntbdRequestManager, "timestamps");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,16 @@
import com.azure.cosmos.implementation.OperationType;
import com.azure.cosmos.implementation.ResourceType;
import com.azure.cosmos.implementation.RxDocumentServiceRequest;
import com.azure.cosmos.implementation.directconnectivity.ReflectionUtils;
import com.azure.cosmos.implementation.directconnectivity.RntbdTransportClient;
import com.azure.cosmos.implementation.directconnectivity.Uri;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.ChannelPromise;
import io.netty.channel.DefaultChannelPromise;
import io.netty.channel.DefaultEventLoop;
import io.netty.channel.SingleThreadEventLoop;
import io.netty.handler.logging.LogLevel;
import io.netty.handler.ssl.SslContext;
import org.mockito.Mockito;
Expand All @@ -21,13 +25,15 @@
import java.net.URI;
import java.net.URISyntaxException;
import java.time.Duration;
import java.time.Instant;

import static com.azure.cosmos.implementation.TestUtils.mockDiagnosticsClientContext;
import static org.assertj.core.api.AssertionsForClassTypes.assertThat;
import static org.mockito.Mockito.doNothing;

public class RntbdRequestManagerTests {

@Test
@Test(groups = { "unit" })
public void transitTimeoutTimestampTests() throws URISyntaxException {
SslContext sslContextMock = Mockito.mock(SslContext.class);
RntbdEndpoint.Config config = new RntbdEndpoint.Config(
Expand All @@ -43,12 +49,18 @@ public void transitTimeoutTimestampTests() throws URISyntaxException {
30,
connectionStateListener,
Duration.ofSeconds(1).toNanos());
RntbdClientChannelHealthChecker.Timestamps timestamps = ReflectionUtils.getTimestamps(rntbdRequestManager);
RntbdClientChannelHealthChecker.Timestamps timestamps = rntbdRequestManager.getTimestamps();

ChannelHandlerContext channelHandlerContext = Mockito.mock(ChannelHandlerContext.class);
ChannelFuture channelFuture = Mockito.mock(ChannelFuture.class);
Mockito.when(channelHandlerContext.write(Mockito.any(), Mockito.any())).thenReturn(channelFuture);
Mockito.when(channelFuture.addListener(Mockito.any())).thenReturn(channelFuture);

Channel channelMock = Mockito.mock(Channel.class);
SingleThreadEventLoop eventLoopMock = new DefaultEventLoop();
Mockito.when(channelMock.eventLoop()).thenReturn(eventLoopMock);

ChannelPromise defaultChannelPromise = new DefaultChannelPromise(channelMock);
defaultChannelPromise.setSuccess();

Mockito.when(channelHandlerContext.write(Mockito.any(), Mockito.any())).thenReturn(defaultChannelPromise);

RntbdRequestArgs requestArgs = new RntbdRequestArgs(
RxDocumentServiceRequest.create(mockDiagnosticsClientContext(), OperationType.Read, ResourceType.Document),
Expand All @@ -61,18 +73,79 @@ public void transitTimeoutTimestampTests() throws URISyntaxException {
ChannelPromise promise = Mockito.mock(ChannelPromise.class);

// Test transitTimeout is 0 at start point
Instant previousLastWriteTime = timestamps.lastChannelWriteTime();
Instant previousLastWriteAttemptTime = timestamps.lastChannelWriteAttemptTime();
Instant previousLastChannelReadTime = timestamps.lastChannelReadTime();

rntbdRequestManager.write(channelHandlerContext, rntbdRequestRecord, promise);

assertThat(timestamps.transitTimeoutCount()).isZero();
assertThat(timestamps.transitTimeoutStartingTime()).isNull();
assertThat(timestamps.tansitTimeoutWriteCount()).isZero();
assertThat(timestamps.lastChannelWriteTime()).isAfterOrEqualTo(previousLastWriteTime);
assertThat(timestamps.lastChannelWriteAttemptTime()).isAfterOrEqualTo(previousLastWriteAttemptTime);
assertThat(timestamps.lastChannelReadTime()).isEqualTo(previousLastChannelReadTime);

// Test when a transit timeout happens, the transitTimeoutCount is increased
rntbdRequestRecord.expire();
assertThat(timestamps.transitTimeoutCount()).isOne();
assertThat(timestamps.tansitTimeoutWriteCount()).isZero();
assertThat(timestamps.transitTimeoutStartingTime()).isNotNull();
assertThat(Duration.between(timestamps.transitTimeoutStartingTime(), Instant.now())).isLessThan(Duration.ofSeconds(5));

// Test when there is channelRead, transitTimeout is cleared out
previousLastWriteTime = timestamps.lastChannelWriteTime();
previousLastWriteAttemptTime = timestamps.lastChannelWriteAttemptTime();
previousLastChannelReadTime = timestamps.lastChannelReadTime();

Mockito.when(channelHandlerContext.flush()).thenReturn(channelHandlerContext);

ChannelFuture closeChannelFuture = Mockito.mock(ChannelFuture.class);
Mockito.when(channelHandlerContext.close()).thenReturn(closeChannelFuture);
rntbdRequestManager.channelRead(channelHandlerContext, rntbdRequestRecord);
assertThat(timestamps.transitTimeoutCount()).isZero();
assertThat(timestamps.transitTimeoutStartingTime()).isNull();
assertThat(timestamps.lastChannelReadTime()).isAfterOrEqualTo(previousLastChannelReadTime);
assertThat(timestamps.lastChannelWriteAttemptTime()).isEqualTo(previousLastWriteAttemptTime);
assertThat(timestamps.lastChannelWriteTime()).isAfterOrEqualTo(previousLastWriteTime);
}

@Test(groups = { "unit" })
public void rntbdContextResponseTests() {
// Test when getting rntbdContext response, the lastReadTimestamp should be marked
SslContext sslContextMock = Mockito.mock(SslContext.class);
RntbdEndpoint.Config config = new RntbdEndpoint.Config(
new RntbdTransportClient.Options.Builder(ConnectionPolicy.getDefaultPolicy()).build(),
sslContextMock,
LogLevel.INFO);
RntbdClientChannelHealthChecker healthChecker = new RntbdClientChannelHealthChecker(config);

RntbdConnectionStateListener connectionStateListener = Mockito.mock(RntbdConnectionStateListener.class);

RntbdRequestManager rntbdRequestManager = new RntbdRequestManager(
healthChecker,
30,
connectionStateListener,
Duration.ofSeconds(1).toNanos());

RntbdClientChannelHealthChecker.Timestamps timestamps = rntbdRequestManager.getTimestamps();

ChannelHandlerContext channelHandlerContextMock = Mockito.mock(ChannelHandlerContext.class);
ChannelPipeline channelPipelineMock = Mockito.mock(ChannelPipeline.class);
Mockito.when(channelPipelineMock.fireChannelRegistered()).thenReturn(channelPipelineMock);
Mockito.when(channelHandlerContextMock.channel()).thenReturn(Mockito.mock(Channel.class));

RntbdContextNegotiator rntbdContextNegotiatorMock = Mockito.mock(RntbdContextNegotiator.class);
doNothing().when(rntbdContextNegotiatorMock).removeInboundHandler();
doNothing().when(rntbdContextNegotiatorMock).removeOutboundHandler();
Mockito.when(channelPipelineMock.get(RntbdContextNegotiator.class)).thenReturn(rntbdContextNegotiatorMock);
Mockito.when(channelHandlerContextMock.pipeline()).thenReturn(channelPipelineMock);

rntbdRequestManager.channelRegistered(channelHandlerContextMock);

Instant lastReadTimestamp = timestamps.lastChannelReadTime();
RntbdContext rntbdContextMock = Mockito.mock(RntbdContext.class);
rntbdRequestManager.userEventTriggered(channelHandlerContextMock, rntbdContextMock);
assertThat(timestamps.lastChannelReadTime()).isAfterOrEqualTo(lastReadTimestamp);
}
}

0 comments on commit 329beb9

Please sign in to comment.