Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

RntbdHealthCheckImprovement-3 #33566

Merged
merged 3 commits into from
Feb 17, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -462,6 +462,7 @@ public void channelPingCompleted() {

public void channelReadCompleted() {
lastReadUpdater.set(this, Instant.now());
this.resetTransitTimeout(); // we have got a successful read, so reset the transitTimeout count.
}

public void channelWriteAttempted() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,10 +77,12 @@ private void startRntbdContextRequest(final ChannelHandlerContext context) throw
final RntbdContextRequest request = new RntbdContextRequest(Utils.randomUUID(), this.userAgent);
final CompletableFuture<RntbdContextRequest> contextRequestFuture = this.manager.rntbdContextRequestFuture();

this.manager.getTimestamps().channelWriteAttempted();
super.write(context, request, channel.newPromise().addListener((ChannelFutureListener)future -> {

if (future.isSuccess()) {
contextRequestFuture.complete(request);
this.manager.getTimestamps().channelWriteCompleted();
return;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ public void channelInactive(final ChannelHandlerContext context) {
public void channelRead(final ChannelHandlerContext context, final Object message) {

this.traceOperation(context, "channelRead");
this.timestamps.resetTransitTimeout(); // we have got a successful read, so reset the transitTimeout count.
this.timestamps.channelReadCompleted();

try {
if (message.getClass() == RntbdResponse.class) {
Expand Down Expand Up @@ -231,7 +231,6 @@ public void channelRead(final ChannelHandlerContext context, final Object messag
@Override
public void channelReadComplete(final ChannelHandlerContext context) {
this.traceOperation(context, "channelReadComplete");
this.timestamps.channelReadCompleted();
context.fireChannelReadComplete();
}

Expand Down Expand Up @@ -378,6 +377,10 @@ public void userEventTriggered(final ChannelHandlerContext context, final Object
if (event instanceof RntbdContext) {
this.contextFuture.complete((RntbdContext) event);
this.removeContextNegotiatorAndFlushPendingWrites(context);

// Important: currently the RntbdContext negotiation response will not be captured during channelRead
// need to mark the timestamp here
this.timestamps.channelReadCompleted();
return;
}
if (event instanceof RntbdContextException) {
Expand Down Expand Up @@ -578,13 +581,12 @@ public void write(final ChannelHandlerContext context, final Object message, fin
if (message instanceof RntbdRequestRecord) {

final RntbdRequestRecord record = (RntbdRequestRecord) message;

this.timestamps.channelWriteAttempted();
record.setTimestamps(this.timestamps);

record.setSendingRequestHasStarted();

if (!record.isCancelled()) {
record.setSendingRequestHasStarted();
this.timestamps.channelWriteAttempted();

context.write(this.addPendingRequestRecord(context, record), promise).addListener(completed -> {
record.stage(RntbdRequestRecord.Stage.SENT);
if (completed.isSuccess()) {
Expand Down Expand Up @@ -670,6 +672,10 @@ void pendWrite(final ByteBuf out, final ChannelPromise promise) {
this.pendingWrites.add(out, promise);
}

public Timestamps getTimestamps() {
return this.timestamps;
}

Timestamps snapshotTimestamps() {
return new Timestamps(this.timestamps);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -416,8 +416,4 @@ public static AtomicReference<Uri.HealthStatus> getHealthStatus(Uri uri) {
public static Set<Uri.HealthStatus> getReplicaValidationScopes(GatewayAddressCache gatewayAddressCache) {
return get(Set.class, gatewayAddressCache, "replicaValidationScopes");
}

public static RntbdClientChannelHealthChecker.Timestamps getTimestamps(RntbdRequestManager rntbdRequestManager) {
return get(RntbdClientChannelHealthChecker.Timestamps.class, rntbdRequestManager, "timestamps");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,16 @@
import com.azure.cosmos.implementation.OperationType;
import com.azure.cosmos.implementation.ResourceType;
import com.azure.cosmos.implementation.RxDocumentServiceRequest;
import com.azure.cosmos.implementation.directconnectivity.ReflectionUtils;
import com.azure.cosmos.implementation.directconnectivity.RntbdTransportClient;
import com.azure.cosmos.implementation.directconnectivity.Uri;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.ChannelPromise;
import io.netty.channel.DefaultChannelPromise;
import io.netty.channel.DefaultEventLoop;
import io.netty.channel.SingleThreadEventLoop;
import io.netty.handler.logging.LogLevel;
import io.netty.handler.ssl.SslContext;
import org.mockito.Mockito;
Expand All @@ -21,13 +25,15 @@
import java.net.URI;
import java.net.URISyntaxException;
import java.time.Duration;
import java.time.Instant;

import static com.azure.cosmos.implementation.TestUtils.mockDiagnosticsClientContext;
import static org.assertj.core.api.AssertionsForClassTypes.assertThat;
import static org.mockito.Mockito.doNothing;

public class RntbdRequestManagerTests {

@Test
@Test(groups = { "unit" })
public void transitTimeoutTimestampTests() throws URISyntaxException {
SslContext sslContextMock = Mockito.mock(SslContext.class);
RntbdEndpoint.Config config = new RntbdEndpoint.Config(
Expand All @@ -43,12 +49,18 @@ public void transitTimeoutTimestampTests() throws URISyntaxException {
30,
connectionStateListener,
Duration.ofSeconds(1).toNanos());
RntbdClientChannelHealthChecker.Timestamps timestamps = ReflectionUtils.getTimestamps(rntbdRequestManager);
RntbdClientChannelHealthChecker.Timestamps timestamps = rntbdRequestManager.getTimestamps();

ChannelHandlerContext channelHandlerContext = Mockito.mock(ChannelHandlerContext.class);
ChannelFuture channelFuture = Mockito.mock(ChannelFuture.class);
Mockito.when(channelHandlerContext.write(Mockito.any(), Mockito.any())).thenReturn(channelFuture);
Mockito.when(channelFuture.addListener(Mockito.any())).thenReturn(channelFuture);

Channel channelMock = Mockito.mock(Channel.class);
SingleThreadEventLoop eventLoopMock = new DefaultEventLoop();
Mockito.when(channelMock.eventLoop()).thenReturn(eventLoopMock);

ChannelPromise defaultChannelPromise = new DefaultChannelPromise(channelMock);
defaultChannelPromise.setSuccess();

Mockito.when(channelHandlerContext.write(Mockito.any(), Mockito.any())).thenReturn(defaultChannelPromise);

RntbdRequestArgs requestArgs = new RntbdRequestArgs(
RxDocumentServiceRequest.create(mockDiagnosticsClientContext(), OperationType.Read, ResourceType.Document),
Expand All @@ -61,18 +73,79 @@ public void transitTimeoutTimestampTests() throws URISyntaxException {
ChannelPromise promise = Mockito.mock(ChannelPromise.class);

// Test transitTimeout is 0 at start point
Instant previousLastWriteTime = timestamps.lastChannelWriteTime();
Instant previousLastWriteAttemptTime = timestamps.lastChannelWriteAttemptTime();
Instant previousLastChannelReadTime = timestamps.lastChannelReadTime();

rntbdRequestManager.write(channelHandlerContext, rntbdRequestRecord, promise);

assertThat(timestamps.transitTimeoutCount()).isZero();
assertThat(timestamps.transitTimeoutStartingTime()).isNull();
assertThat(timestamps.tansitTimeoutWriteCount()).isZero();
assertThat(timestamps.lastChannelWriteTime()).isAfter(previousLastWriteTime);
assertThat(timestamps.lastChannelWriteAttemptTime()).isAfter(previousLastWriteAttemptTime);
assertThat(timestamps.lastChannelReadTime()).isEqualTo(previousLastChannelReadTime);

// Test when a transit timeout happens, the transitTimeoutCount is increased
rntbdRequestRecord.expire();
assertThat(timestamps.transitTimeoutCount()).isOne();
assertThat(timestamps.tansitTimeoutWriteCount()).isZero();
assertThat(timestamps.transitTimeoutStartingTime()).isNotNull();
assertThat(Duration.between(timestamps.transitTimeoutStartingTime(), Instant.now())).isLessThan(Duration.ofSeconds(5));

// Test when there is channelRead, transitTimeout is cleared out
previousLastWriteTime = timestamps.lastChannelWriteTime();
previousLastWriteAttemptTime = timestamps.lastChannelWriteAttemptTime();
previousLastChannelReadTime = timestamps.lastChannelReadTime();

Mockito.when(channelHandlerContext.flush()).thenReturn(channelHandlerContext);

ChannelFuture closeChannelFuture = Mockito.mock(ChannelFuture.class);
Mockito.when(channelHandlerContext.close()).thenReturn(closeChannelFuture);
rntbdRequestManager.channelRead(channelHandlerContext, rntbdRequestRecord);
assertThat(timestamps.transitTimeoutCount()).isZero();
assertThat(timestamps.transitTimeoutStartingTime()).isNull();
assertThat(timestamps.lastChannelReadTime()).isAfter(previousLastChannelReadTime);
assertThat(timestamps.lastChannelWriteAttemptTime()).isEqualTo(previousLastWriteAttemptTime);
assertThat(timestamps.lastChannelWriteTime()).isEqualTo(previousLastWriteTime);
}

@Test(groups = { "unit" })
public void rntbdContextResponseTests() {
// Test when getting rntbdContext response, the lastReadTimestamp should be marked
SslContext sslContextMock = Mockito.mock(SslContext.class);
RntbdEndpoint.Config config = new RntbdEndpoint.Config(
new RntbdTransportClient.Options.Builder(ConnectionPolicy.getDefaultPolicy()).build(),
sslContextMock,
LogLevel.INFO);
RntbdClientChannelHealthChecker healthChecker = new RntbdClientChannelHealthChecker(config);

RntbdConnectionStateListener connectionStateListener = Mockito.mock(RntbdConnectionStateListener.class);

RntbdRequestManager rntbdRequestManager = new RntbdRequestManager(
healthChecker,
30,
connectionStateListener,
Duration.ofSeconds(1).toNanos());

RntbdClientChannelHealthChecker.Timestamps timestamps = rntbdRequestManager.getTimestamps();

ChannelHandlerContext channelHandlerContextMock = Mockito.mock(ChannelHandlerContext.class);
ChannelPipeline channelPipelineMock = Mockito.mock(ChannelPipeline.class);
Mockito.when(channelPipelineMock.fireChannelRegistered()).thenReturn(channelPipelineMock);
Mockito.when(channelHandlerContextMock.channel()).thenReturn(Mockito.mock(Channel.class));

RntbdContextNegotiator rntbdContextNegotiatorMock = Mockito.mock(RntbdContextNegotiator.class);
doNothing().when(rntbdContextNegotiatorMock).removeInboundHandler();
doNothing().when(rntbdContextNegotiatorMock).removeOutboundHandler();
Mockito.when(channelPipelineMock.get(RntbdContextNegotiator.class)).thenReturn(rntbdContextNegotiatorMock);
Mockito.when(channelHandlerContextMock.pipeline()).thenReturn(channelPipelineMock);

rntbdRequestManager.channelRegistered(channelHandlerContextMock);

Instant lastReadTimestamp = timestamps.lastChannelReadTime();
RntbdContext rntbdContextMock = Mockito.mock(RntbdContext.class);
rntbdRequestManager.userEventTriggered(channelHandlerContextMock, rntbdContextMock);
assertThat(timestamps.lastChannelReadTime()).isAfter(lastReadTimestamp);
}
}