diff --git a/protocols/raft/src/main/java/io/atomix/protocols/raft/proxy/impl/RaftProxyConnection.java b/protocols/raft/src/main/java/io/atomix/protocols/raft/proxy/impl/RaftProxyConnection.java index b982bcbe84..06323fb3b2 100644 --- a/protocols/raft/src/main/java/io/atomix/protocols/raft/proxy/impl/RaftProxyConnection.java +++ b/protocols/raft/src/main/java/io/atomix/protocols/raft/proxy/impl/RaftProxyConnection.java @@ -67,6 +67,7 @@ public class RaftProxyConnection { private final MemberSelector selector; private final ThreadContext context; private NodeId currentNode; + private int selectionId; public RaftProxyConnection(RaftClientProtocol protocol, MemberSelector selector, ThreadContext context, LoggerContext loggerContext) { this.protocol = checkNotNull(protocol, "protocol cannot be null"); @@ -119,9 +120,9 @@ public Collection members() { public CompletableFuture openSession(OpenSessionRequest request) { CompletableFuture future = new CompletableFuture<>(); if (context.isCurrentContext()) { - sendRequest(request, protocol::openSession, next(), future); + sendRequest(request, protocol::openSession, future); } else { - context.execute(() -> sendRequest(request, protocol::openSession, next(), future)); + context.execute(() -> sendRequest(request, protocol::openSession, future)); } return future; } @@ -135,9 +136,9 @@ public CompletableFuture openSession(OpenSessionRequest req public CompletableFuture closeSession(CloseSessionRequest request) { CompletableFuture future = new CompletableFuture<>(); if (context.isCurrentContext()) { - sendRequest(request, protocol::closeSession, next(), future); + sendRequest(request, protocol::closeSession, future); } else { - context.execute(() -> sendRequest(request, protocol::closeSession, next(), future)); + context.execute(() -> sendRequest(request, protocol::closeSession, future)); } return future; } @@ -151,9 +152,9 @@ public CompletableFuture closeSession(CloseSessionRequest public CompletableFuture keepAlive(KeepAliveRequest request) { CompletableFuture future = new CompletableFuture<>(); if (context.isCurrentContext()) { - sendRequest(request, protocol::keepAlive, next(), future); + sendRequest(request, protocol::keepAlive, future); } else { - context.execute(() -> sendRequest(request, protocol::keepAlive, next(), future)); + context.execute(() -> sendRequest(request, protocol::keepAlive, future)); } return future; } @@ -167,9 +168,9 @@ public CompletableFuture keepAlive(KeepAliveRequest request) public CompletableFuture query(QueryRequest request) { CompletableFuture future = new CompletableFuture<>(); if (context.isCurrentContext()) { - sendRequest(request, protocol::query, next(), future); + sendRequest(request, protocol::query, future); } else { - context.execute(() -> sendRequest(request, protocol::query, next(), future)); + context.execute(() -> sendRequest(request, protocol::query, future)); } return future; } @@ -183,9 +184,9 @@ public CompletableFuture query(QueryRequest request) { public CompletableFuture command(CommandRequest request) { CompletableFuture future = new CompletableFuture<>(); if (context.isCurrentContext()) { - sendRequest(request, protocol::command, next(), future); + sendRequest(request, protocol::command, future); } else { - context.execute(() -> sendRequest(request, protocol::command, next(), future)); + context.execute(() -> sendRequest(request, protocol::command, future)); } return future; } @@ -199,9 +200,9 @@ public CompletableFuture command(CommandRequest request) { public CompletableFuture metadata(MetadataRequest request) { CompletableFuture future = new CompletableFuture<>(); if (context.isCurrentContext()) { - sendRequest(request, protocol::metadata, next(), future); + sendRequest(request, protocol::metadata, future); } else { - context.execute(() -> sendRequest(request, protocol::metadata, next(), future)); + context.execute(() -> sendRequest(request, protocol::metadata, future)); } return future; } @@ -209,12 +210,21 @@ public CompletableFuture metadata(MetadataRequest request) { /** * Sends the given request attempt to the cluster. */ - protected void sendRequest(T request, BiFunction> sender, NodeId member, CompletableFuture future) { - if (member != null) { - log.trace("Sending {} to {}", request, member); - sender.apply(member, request).whenCompleteAsync((r, e) -> { + protected void sendRequest(T request, BiFunction> sender, CompletableFuture future) { + sendRequest(request, sender, 0, future); + } + + /** + * Sends the given request attempt to the cluster. + */ + protected void sendRequest(T request, BiFunction> sender, int count, CompletableFuture future) { + NodeId node = next(); + if (node != null) { + log.trace("Sending {} to {}", request, node); + int selectionId = this.selectionId; + sender.apply(node, request).whenCompleteAsync((r, e) -> { if (e != null || r != null) { - handleResponse(request, sender, member, r, e, future); + handleResponse(request, sender, count, selectionId, node, r, e, future); } else { future.complete(null); } @@ -228,29 +238,29 @@ protected void sendRequest(T req * Resends a request due to a request failure, resetting the connection if necessary. */ @SuppressWarnings("unchecked") - protected void retryRequest(Throwable cause, T request, BiFunction sender, NodeId member, CompletableFuture future) { + protected void retryRequest(Throwable cause, T request, BiFunction sender, int count, int selectionId, CompletableFuture future) { // If the connection has not changed, reset it and connect to the next server. - if (this.currentNode == member) { + if (this.selectionId == selectionId) { log.trace("Resetting connection. Reason: {}", cause.getMessage()); this.currentNode = null; } // Attempt to send the request again. - sendRequest(request, sender, next(), future); + sendRequest(request, sender, count, future); } /** * Handles a response from the cluster. */ @SuppressWarnings("unchecked") - protected void handleResponse(T request, BiFunction sender, NodeId member, RaftResponse response, Throwable error, CompletableFuture future) { + protected void handleResponse(T request, BiFunction sender, int count, int selectionId, NodeId node, RaftResponse response, Throwable error, CompletableFuture future) { if (error == null) { - log.trace("Received {} from {}", response, member); + log.trace("Received {} from {}", response, node); if (COMPLETE_PREDICATE.test(response)) { future.complete(response); selector.reset(); } else { - retryRequest(response.error().createException(), request, sender, member, future); + retryRequest(response.error().createException(), request, sender, count + 1, selectionId, future); } } else { if (error instanceof CompletionException) { @@ -258,7 +268,11 @@ protected void handleResponse(T request, BiFunction send } log.debug("{} failed! Reason: {}", request, error); if (error instanceof ConnectException || error instanceof TimeoutException || error instanceof ClosedChannelException) { - retryRequest(error, request, sender, member, future); + if (count < selector.members().size() + 1) { + retryRequest(error, request, sender, count + 1, selectionId, future); + } else { + future.completeExceptionally(error); + } } else { future.completeExceptionally(error); } @@ -278,6 +292,7 @@ protected NodeId next() { if (selector.leader() != null) { selector.reset(null, selector.members()); this.currentNode = selector.next(); + this.selectionId++; return currentNode; } else { log.debug("Failed to connect to the cluster"); @@ -286,6 +301,7 @@ protected NodeId next() { } } else { this.currentNode = selector.next(); + this.selectionId++; return currentNode; } }