From d924be6c86283908dccf3c37956b069a2e37ee95 Mon Sep 17 00:00:00 2001 From: jkolobok Date: Mon, 25 Apr 2022 02:46:56 +0300 Subject: [PATCH] ConnectionListener#onDisconnected_is_called_twice_for_every_disconnect, Fixes #164 & Fixes #219 - moved exception handler - Add tests - Don't reset HeartbeatHandler timeout on message send --- .../VanillaClusteredNetworkContext.java | 6 +- .../cluster/handlers/HeartbeatHandler.java | 38 ++-- .../network/cluster/handlers/UberHandler.java | 19 +- .../network/ConnectionListenerTest.java | 214 ++++++++++++++++++ .../chronicle/network/UberHandlerTest.java | 133 ++--------- .../chronicle/network/test/TestCluster.java | 14 ++ .../network/test/TestClusterContext.java | 95 ++++++++ .../test/TestClusteredNetworkContext.java | 91 ++++++++ 8 files changed, 467 insertions(+), 143 deletions(-) create mode 100644 src/test/java/net/openhft/chronicle/network/ConnectionListenerTest.java create mode 100644 src/test/java/net/openhft/chronicle/network/test/TestCluster.java create mode 100644 src/test/java/net/openhft/chronicle/network/test/TestClusterContext.java create mode 100644 src/test/java/net/openhft/chronicle/network/test/TestClusteredNetworkContext.java diff --git a/src/main/java/net/openhft/chronicle/network/cluster/VanillaClusteredNetworkContext.java b/src/main/java/net/openhft/chronicle/network/cluster/VanillaClusteredNetworkContext.java index eaac7d6c52d..a51e0669e3d 100644 --- a/src/main/java/net/openhft/chronicle/network/cluster/VanillaClusteredNetworkContext.java +++ b/src/main/java/net/openhft/chronicle/network/cluster/VanillaClusteredNetworkContext.java @@ -18,15 +18,13 @@ package net.openhft.chronicle.network.cluster; +import net.openhft.chronicle.core.Jvm; import net.openhft.chronicle.core.threads.EventLoop; import net.openhft.chronicle.network.VanillaNetworkContext; import org.jetbrains.annotations.NotNull; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; public class VanillaClusteredNetworkContext, C extends ClusterContext> extends VanillaNetworkContext implements ClusteredNetworkContext { - private static final Logger LOGGER = LoggerFactory.getLogger(VanillaClusteredNetworkContext.class); @NotNull private final EventLoop eventLoop; @@ -57,7 +55,7 @@ public C clusterContext() { } private boolean logMissedHeartbeat() { - LOGGER.warn("Missed heartbeat on network context " + socketChannel()); + Jvm.warn().on(VanillaClusteredNetworkContext.class, "Missed heartbeat on network context " + socketChannel()); return false; } } diff --git a/src/main/java/net/openhft/chronicle/network/cluster/handlers/HeartbeatHandler.java b/src/main/java/net/openhft/chronicle/network/cluster/handlers/HeartbeatHandler.java index 3066f687f61..a139308369c 100644 --- a/src/main/java/net/openhft/chronicle/network/cluster/handlers/HeartbeatHandler.java +++ b/src/main/java/net/openhft/chronicle/network/cluster/handlers/HeartbeatHandler.java @@ -46,9 +46,10 @@ public final class HeartbeatHandler> extend private final long heartbeatIntervalMs; private final long heartbeatTimeoutMs; private final AtomicBoolean hasHeartbeats = new AtomicBoolean(); + private final AtomicBoolean closed; private volatile long lastTimeMessageReceived; @Nullable - private ConnectionListener connectionMonitor; + private ConnectionListener connectionListener; @Nullable private Timer timer; @@ -60,6 +61,7 @@ public HeartbeatHandler(@NotNull WireIn w) { } private HeartbeatHandler(long heartbeatTimeoutMs, long heartbeatIntervalMs) { + closed = new AtomicBoolean(false); this.heartbeatTimeoutMs = heartbeatTimeoutMs; this.heartbeatIntervalMs = heartbeatIntervalMs; validateHeartbeatParameters(this.heartbeatTimeoutMs, this.heartbeatIntervalMs); @@ -100,7 +102,7 @@ public void onInitialize(@NotNull WireOut outWire) { @NotNull final WriteMarshallable heartbeatMessage = new HeartbeatMessage(); - connectionMonitor = nc().acquireConnectionListener(); + connectionListener = nc().acquireConnectionListener(); timer = new Timer(nc().eventLoop()); startPeriodicHeartbeatCheck(); startPeriodicallySendingHeartbeats(heartbeatMessage); @@ -134,12 +136,19 @@ public void onRead(@NotNull WireIn inWire, @NotNull WireOut outWire) { @Override public void close() { - if (connectionMonitor != null) - connectionMonitor.onDisconnected(localIdentifier(), remoteIdentifier(), nc().isAcceptor()); - lastTimeMessageReceived = Long.MAX_VALUE; - Closeable closable = closable(); - if (closable != null && !closable.isClosed()) { - Closeable.closeQuietly(closable); + if (closed.compareAndSet(false, true)) { + if (connectionListener != null) { + try { + connectionListener.onDisconnected(localIdentifier(), remoteIdentifier(), nc().isAcceptor()); + } catch (Exception e) { + Jvm.error().on(getClass(), "Exception thrown by ConnectionListener#onDisconnected", e); + } + } + lastTimeMessageReceived = Long.MAX_VALUE; + Closeable closable = closable(); + if (closable != null && !closable.isClosed()) { + Closeable.closeQuietly(closable); + } } } @@ -224,9 +233,6 @@ public boolean action() throws InvalidEventHandlerException { if (hasHeartbeats != prev) { if (!hasHeartbeats) { - connectionMonitor.onDisconnected(HeartbeatHandler.this.localIdentifier(), - HeartbeatHandler.this.remoteIdentifier(), HeartbeatHandler.this.nc().isAcceptor()); - final Runnable socketReconnector = HeartbeatHandler.this.nc().socketReconnector(); if (socketReconnector == null) Jvm.warn().on(getClass(), "socketReconnector == null"); @@ -237,8 +243,12 @@ public boolean action() throws InvalidEventHandlerException { throw newClosedInvalidEventHandlerException(); } else - connectionMonitor.onConnected(HeartbeatHandler.this.localIdentifier(), - HeartbeatHandler.this.remoteIdentifier(), HeartbeatHandler.this.nc().isAcceptor()); + try { + connectionListener.onConnected(HeartbeatHandler.this.localIdentifier(), + HeartbeatHandler.this.remoteIdentifier(), HeartbeatHandler.this.nc().isAcceptor()); + } catch (RuntimeException e) { + Jvm.error().on(HeartbeatCheckHandler.class, "Exception thrown by ConnectionListener#onConnected", e); + } } return true; @@ -291,4 +301,4 @@ InvalidEventHandlerException newClosedInvalidEventHandlerException() { return new InvalidEventHandlerException("closed"); } -} \ No newline at end of file +} diff --git a/src/main/java/net/openhft/chronicle/network/cluster/handlers/UberHandler.java b/src/main/java/net/openhft/chronicle/network/cluster/handlers/UberHandler.java index 234dca58259..e6775413289 100644 --- a/src/main/java/net/openhft/chronicle/network/cluster/handlers/UberHandler.java +++ b/src/main/java/net/openhft/chronicle/network/cluster/handlers/UberHandler.java @@ -22,7 +22,6 @@ import net.openhft.chronicle.core.io.Closeable; import net.openhft.chronicle.core.io.ClosedIllegalStateException; import net.openhft.chronicle.core.threads.EventLoop; -import net.openhft.chronicle.network.ConnectionListener; import net.openhft.chronicle.network.api.session.SubHandler; import net.openhft.chronicle.network.api.session.WritableSubHandler; import net.openhft.chronicle.network.cluster.ClusteredNetworkContext; @@ -149,17 +148,6 @@ protected void performClose() { if (connectionChangedNotifier != null) { eventEmitterToken = connectionChangedNotifier.onConnectionChanged(false, nc, eventEmitterToken); } - - try { - if (nc != null) { - final ConnectionListener listener = nc.acquireConnectionListener(); - if (listener != null) - listener.onDisconnected(localIdentifier, remoteIdentifier(), nc.isAcceptor()); - } - } catch (Exception e) { - Jvm.error().on(getClass(), "close:", e); - throw Jvm.rethrow(e); - } Closeable.closeQuietly(writers); writers.clear(); super.performClose(); @@ -187,7 +175,7 @@ protected void onRead(@NotNull final DocumentContext dc, @NotNull final WireOut } } - onMessageReceivedOrWritten(); + onMessageReceived(); final Wire inWire = dc.wire(); if (dc.isMetaData()) { @@ -250,7 +238,6 @@ public void performIdleWork() { @Override protected void onBytesWritten() { - onMessageReceivedOrWritten(); } /** @@ -286,7 +273,7 @@ protected void onWrite(@NotNull final WireOut outWire) { } } - private void onMessageReceivedOrWritten() { + private void onMessageReceived() { final HeartbeatEventHandler heartbeatEventHandler = heartbeatEventHandler(); if (heartbeatEventHandler != null) heartbeatEventHandler.onMessageReceived(); @@ -299,4 +286,4 @@ public String toString() { ", localIdentifier=" + localIdentifier + '}'; } -} \ No newline at end of file +} diff --git a/src/test/java/net/openhft/chronicle/network/ConnectionListenerTest.java b/src/test/java/net/openhft/chronicle/network/ConnectionListenerTest.java new file mode 100644 index 00000000000..7e0185d2403 --- /dev/null +++ b/src/test/java/net/openhft/chronicle/network/ConnectionListenerTest.java @@ -0,0 +1,214 @@ +package net.openhft.chronicle.network; + +import net.openhft.chronicle.core.Jvm; +import net.openhft.chronicle.core.threads.InvalidEventHandlerException; +import net.openhft.chronicle.network.cluster.HostDetails; +import net.openhft.chronicle.network.test.TestClusterContext; +import net.openhft.chronicle.testframework.Waiters; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import java.io.IOException; + +import static net.openhft.chronicle.network.TCPRegistry.createServerSocketChannelFor; +import static net.openhft.chronicle.network.test.TestClusterContext.forHosts; +import static org.junit.jupiter.api.Assertions.assertEquals; + +class ConnectionListenerTest extends NetworkTestCommon { + + private HostDetails initiatorHost; + private HostDetails acceptorHost; + private CountingConnectionListener initiatorCounter; + private CountingConnectionListener acceptorCounter; + + @BeforeEach + void setUp() throws IOException { + createServerSocketChannelFor("initiator", "acceptor"); + initiatorHost = new HostDetails().hostId(2).connectUri("initiator"); + acceptorHost = new HostDetails().hostId(1).connectUri("acceptor"); + acceptorCounter = new CountingConnectionListener(); + initiatorCounter = new CountingConnectionListener(); + } + + @Test + void onConnectAndOnDisconnectAreCalledOnce_OnOrderlyConnectionAndDisconnection() { + try (TestClusterContext acceptorCtx = forHosts(acceptorHost, initiatorHost); + TestClusterContext initiatorCtx = forHosts(initiatorHost, acceptorHost)) { + + acceptorCtx.addConnectionListener(acceptorCounter); + initiatorCtx.addConnectionListener(initiatorCounter); + + acceptorCtx.cluster().start(acceptorHost.hostId()); + initiatorCtx.cluster().start(initiatorHost.hostId()); + + Waiters.waitForCondition("acceptor and initiator to connect", + () -> acceptorCounter.onConnectedCalls > 0 && initiatorCounter.onConnectedCalls > 0, + 5_000); + } + assertEquals(1, acceptorCounter.onConnectedCalls); + assertEquals(1, acceptorCounter.onDisconnectedCalls); + assertEquals(1, initiatorCounter.onConnectedCalls); + assertEquals(1, initiatorCounter.onDisconnectedCalls); + } + + @Test + void onConnectAndOnDisconnectAreCalledOnce_WhenConnectionTimesOut_InUberHandler() { + expectException("missed heartbeat, lastTimeMessageReceived="); + try (TestClusterContext acceptorCtx = forHosts(acceptorHost, initiatorHost); + TestClusterContext initiatorCtx = forHosts(initiatorHost, acceptorHost)) { + initiatorCtx.overrideNetworkContextTimeout(5_000); // we want the heartbeat handler to timeout + initiatorCtx.disableReconnect(); + acceptorCtx.overrideNetworkContextTimeout(5_000); // we want the heartbeat handler to timeout + acceptorCtx.disableReconnect(); + + initiatorCtx.heartbeatTimeoutMs(1_000); // set to minimum, initiator dictates + + acceptorCtx.addConnectionListener(acceptorCounter); + initiatorCtx.addConnectionListener(initiatorCounter); + + acceptorCtx.cluster().start(acceptorHost.hostId()); + initiatorCtx.cluster().start(initiatorHost.hostId()); + + Waiters.waitForCondition("acceptor and initiator to connect", + () -> acceptorCounter.onConnectedCalls > 0 && initiatorCounter.onConnectedCalls > 0, + 5_000); + // jam up the acceptor event loop to trigger an initiator timeout + acceptorCtx.cluster().clusterContext().eventLoop().addHandler(() -> { + Jvm.pause(3_000); + throw InvalidEventHandlerException.reusable(); + }); + Waiters.waitForCondition("initiator to timeout", + () -> initiatorCounter.onDisconnectedCalls > 0, 3_000); + } + assertEquals(1, acceptorCounter.onConnectedCalls); + assertEquals(1, acceptorCounter.onDisconnectedCalls); + assertEquals(1, initiatorCounter.onConnectedCalls); + assertEquals(1, initiatorCounter.onDisconnectedCalls); + } + + @Test + void onConnectAndOnDisconnectAreCalledOnce_WhenConnectionTimesOut_InTcpHandler() { + expectException("Missed heartbeat on network context"); + try (TestClusterContext acceptorCtx = forHosts(acceptorHost, initiatorHost); + TestClusterContext initiatorCtx = forHosts(initiatorHost, acceptorHost)) { + initiatorCtx.overrideNetworkContextTimeout(1_000); // we want the TcpEventHandler to timeout + initiatorCtx.disableReconnect(); + acceptorCtx.overrideNetworkContextTimeout(1_000); // we want the TcpEventHandler to timeout + acceptorCtx.disableReconnect(); + + initiatorCtx.heartbeatTimeoutMs(5_000); + + acceptorCtx.addConnectionListener(acceptorCounter); + initiatorCtx.addConnectionListener(initiatorCounter); + + acceptorCtx.cluster().start(acceptorHost.hostId()); + initiatorCtx.cluster().start(initiatorHost.hostId()); + + Waiters.waitForCondition("acceptor and initiator to connect", + () -> acceptorCounter.onConnectedCalls > 0 && initiatorCounter.onConnectedCalls > 0, + 5_000); + + // jam up the acceptor event loop to trigger an initiator timeout + acceptorCtx.cluster().clusterContext().eventLoop().addHandler(() -> { + Jvm.pause(3_000); + throw InvalidEventHandlerException.reusable(); + }); + Waiters.waitForCondition("initiator to timeout", + () -> initiatorCounter.onDisconnectedCalls == 1, 3_000); + } + assertEquals(1, acceptorCounter.onConnectedCalls); + assertEquals(1, acceptorCounter.onDisconnectedCalls); + assertEquals(1, initiatorCounter.onConnectedCalls); + assertEquals(1, initiatorCounter.onDisconnectedCalls); + } + + @Test + void onConnectAndOnDisconnectAreNotCalled_WhenNoConnectionIsEstablished_Initiator() { + try (TestClusterContext acceptorCtx = forHosts(acceptorHost, initiatorHost); + TestClusterContext initiatorCtx = forHosts(initiatorHost, acceptorHost)) { + + acceptorCtx.addConnectionListener(acceptorCounter); + initiatorCtx.addConnectionListener(initiatorCounter); + + // only start initiator + initiatorCtx.cluster().start(initiatorHost.hostId()); + Jvm.pause(1_000); + } + assertEquals(0, acceptorCounter.onConnectedCalls); + assertEquals(0, acceptorCounter.onDisconnectedCalls); + assertEquals(0, initiatorCounter.onConnectedCalls); + assertEquals(0, initiatorCounter.onDisconnectedCalls); + } + + @Test + void onConnectAndOnDisconnectAreNotCalled_WhenNoConnectionIsEstablished_Acceptor() { + try (TestClusterContext acceptorCtx = forHosts(acceptorHost, initiatorHost); + TestClusterContext initiatorCtx = forHosts(initiatorHost, acceptorHost)) { + + acceptorCtx.addConnectionListener(acceptorCounter); + initiatorCtx.addConnectionListener(initiatorCounter); + + // only start acceptor + acceptorCtx.cluster().start(initiatorHost.hostId()); + Jvm.pause(1_000); + } + assertEquals(0, acceptorCounter.onConnectedCalls); + assertEquals(0, acceptorCounter.onDisconnectedCalls); + assertEquals(0, initiatorCounter.onConnectedCalls); + assertEquals(0, initiatorCounter.onDisconnectedCalls); + } + + @Test + void onConnectAndOnDisconnect_WillLogWhenAnExceptionIsThrown() { + expectException("Something went wrong - onConnect"); + expectException("Something went wrong - onDisconnect"); + try (TestClusterContext acceptorCtx = forHosts(acceptorHost, initiatorHost); + TestClusterContext initiatorCtx = forHosts(initiatorHost, acceptorHost)) { + + acceptorCtx.addConnectionListener(new ThrowingConnectionListener()); + initiatorCtx.addConnectionListener(initiatorCounter); + + acceptorCtx.cluster().start(acceptorHost.hostId()); + initiatorCtx.cluster().start(initiatorHost.hostId()); + + Waiters.waitForCondition("acceptor and initiator to connect", + () -> initiatorCounter.onConnectedCalls > 0, + 5_000); + + // this shouldn't trigger a disconnect + Jvm.pause(1_000); + assertEquals(0, initiatorCounter.onDisconnectedCalls); + } + assertEquals(1, initiatorCounter.onConnectedCalls); + assertEquals(1, initiatorCounter.onDisconnectedCalls); + } + + private static class ThrowingConnectionListener implements ConnectionListener { + + @Override + public void onConnected(int localIdentifier, int remoteIdentifier, boolean isAcceptor) { + throw new RuntimeException("Something went wrong - onConnect"); + } + + @Override + public void onDisconnected(int localIdentifier, int remoteIdentifier, boolean isAcceptor) { + throw new RuntimeException("Something went wrong - onDisconnect"); + } + } + + private static class CountingConnectionListener implements ConnectionListener { + + private int onConnectedCalls = 0; + private int onDisconnectedCalls = 0; + + @Override + public void onConnected(int localIdentifier, int remoteIdentifier, boolean isAcceptor) { + onConnectedCalls++; + } + + @Override + public void onDisconnected(int localIdentifier, int remoteIdentifier, boolean isAcceptor) { + onDisconnectedCalls++; + } + } +} \ No newline at end of file diff --git a/src/test/java/net/openhft/chronicle/network/UberHandlerTest.java b/src/test/java/net/openhft/chronicle/network/UberHandlerTest.java index 8ee0b70c978..b636fd2bebd 100644 --- a/src/test/java/net/openhft/chronicle/network/UberHandlerTest.java +++ b/src/test/java/net/openhft/chronicle/network/UberHandlerTest.java @@ -24,22 +24,19 @@ import net.openhft.chronicle.core.io.Closeable; import net.openhft.chronicle.core.io.IORuntimeException; import net.openhft.chronicle.core.threads.EventLoop; -import net.openhft.chronicle.core.util.ThrowingFunction; -import net.openhft.chronicle.network.api.TcpHandler; import net.openhft.chronicle.network.api.session.WritableSubHandler; import net.openhft.chronicle.network.cluster.AbstractSubHandler; -import net.openhft.chronicle.network.cluster.Cluster; import net.openhft.chronicle.network.cluster.HostDetails; -import net.openhft.chronicle.network.cluster.VanillaClusteredNetworkContext; import net.openhft.chronicle.network.cluster.handlers.Registerable; import net.openhft.chronicle.network.cluster.handlers.RejectedHandlerException; import net.openhft.chronicle.network.cluster.handlers.UberHandler; import net.openhft.chronicle.network.connection.CoreFields; import net.openhft.chronicle.network.connection.VanillaWireOutPublisher; +import net.openhft.chronicle.network.test.TestClusterContext; +import net.openhft.chronicle.network.test.TestClusteredNetworkContext; import net.openhft.chronicle.threads.Pauser; import net.openhft.chronicle.threads.TimingPauser; import net.openhft.chronicle.wire.*; -import org.apache.mina.util.IdentityHashSet; import org.jetbrains.annotations.NotNull; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; @@ -49,7 +46,6 @@ import java.util.ArrayList; import java.util.List; import java.util.Map; -import java.util.Set; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.RejectedExecutionException; import java.util.concurrent.TimeUnit; @@ -57,12 +53,12 @@ import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; -import java.util.function.Function; import java.util.stream.IntStream; import static net.openhft.chronicle.network.HeaderTcpHandler.HANDLER; import static net.openhft.chronicle.network.cluster.handlers.UberHandler.uberHandler; import static net.openhft.chronicle.network.connection.CoreFields.*; +import static net.openhft.chronicle.network.test.TestClusterContext.forHosts; import static org.junit.jupiter.api.Assertions.*; class UberHandlerTest extends NetworkTestCommon { @@ -109,8 +105,8 @@ void testUberHandlerWithMultipleSubHandlersAndHeartbeats() throws IOException, T HostDetails initiatorHost = new HostDetails().hostId(2).connectUri("initiator"); HostDetails acceptorHost = new HostDetails().hostId(1).connectUri("acceptor"); - try (MyClusterContext acceptorCtx = clusterContext(acceptorHost, initiatorHost); - MyClusterContext initiatorCtx = clusterContext(initiatorHost, acceptorHost)) { + try (TestClusterContext acceptorCtx = forHosts(acceptorHost, initiatorHost); + TestClusterContext initiatorCtx = forHosts(initiatorHost, acceptorHost)) { acceptorCtx.cluster().start(acceptorHost.hostId()); initiatorCtx.cluster().start(initiatorHost.hostId()); @@ -166,8 +162,8 @@ void testHandlerWillCloseWhenHostIdsAreWrong() throws IOException { HostDetails acceptorHost = new HostDetails().hostId(1).connectUri("acceptor"); HostDetails acceptorHostWithInvalidId = new HostDetails().hostId(98).connectUri("acceptor"); - try (MyClusterContext acceptorCtx = clusterContext(acceptorHost, initiatorHost); - MyClusterContext initiatorCtx = clusterContext(initiatorHost, acceptorHostWithInvalidId)) { + try (TestClusterContext acceptorCtx = forHosts(acceptorHost, initiatorHost); + TestClusterContext initiatorCtx = forHosts(initiatorHost, acceptorHostWithInvalidId)) { acceptorCtx.cluster().start(acceptorHost.hostId()); initiatorCtx.cluster().start(initiatorHost.hostId()); @@ -189,8 +185,8 @@ void newConnectionListenersAreExecutedOnEventLoopForExistingConnections() throws HostDetails initiatorHost = new HostDetails().hostId(2).connectUri("initiator"); HostDetails acceptorHost = new HostDetails().hostId(1).connectUri("acceptor"); - try (MyClusterContext acceptorCtx = clusterContext(acceptorHost, initiatorHost); - MyClusterContext initiatorCtx = clusterContext(initiatorHost, acceptorHost)) { + try (TestClusterContext acceptorCtx = forHosts(acceptorHost, initiatorHost); + TestClusterContext initiatorCtx = forHosts(initiatorHost, acceptorHost)) { acceptorCtx.cluster().start(acceptorHost.hostId()); initiatorCtx.cluster().start(initiatorHost.hostId()); @@ -247,8 +243,8 @@ void testBusyWritingHandlersAreCalledFirstInRoundRobin() throws IOException, Tim HostDetails initiatorHost = new HostDetails().hostId(2).connectUri("initiator"); HostDetails acceptorHost = new HostDetails().hostId(1).connectUri("acceptor"); - try (MyClusterContext acceptorCtx = clusterContext(acceptorHost, initiatorHost); - MyClusterContext initiatorCtx = clusterContext(initiatorHost, acceptorHost)) { + try (TestClusterContext acceptorCtx = forHosts(acceptorHost, initiatorHost); + TestClusterContext initiatorCtx = forHosts(initiatorHost, acceptorHost)) { acceptorCtx.cluster().start(acceptorHost.hostId()); initiatorCtx.cluster().start(initiatorHost.hostId()); @@ -382,7 +378,7 @@ public void registry(Map registry) { } } - private static class WritableRejectingSubHandler extends RejectingSubHandler implements WritableSubHandler { + private static class WritableRejectingSubHandler extends RejectingSubHandler implements WritableSubHandler { @Override public void onWrite(WireOut outWire) { @@ -395,7 +391,7 @@ public void onWrite(WireOut outWire) { } } - private static class RejectingSubHandler extends AbstractSubHandler implements Marshallable { + private static class RejectingSubHandler extends AbstractSubHandler implements Marshallable { protected boolean rejected = false; @@ -437,89 +433,8 @@ private void sendHandler(WireOut wireOut, int cid, Marshallable handler) { .writeEventName(CoreFields.handler).typedMarshallable(handler); } - private void sendMessageToHandler(WireOut wireOut, int cid) { - wireOut.writeEventName(csp).text(TEST_HANDLERS_CSP) - .writeEventName(CoreFields.cid).int64(cid); - } - - @NotNull - private MyClusterContext clusterContext(HostDetails... clusterHosts) { - MyClusterContext ctx = new MyClusterContext().wireType(WireType.BINARY).localIdentifier((byte) clusterHosts[0].hostId()); - ctx.heartbeatIntervalMs(500); - MyCluster cluster = new MyCluster(ctx); - for (HostDetails details : clusterHosts) { - cluster.hostDetails.put(String.valueOf(details.hostId()), details); - } - return ctx; - } - - static class MyClusteredNetworkContext extends VanillaClusteredNetworkContext { - - public Set connectionListeners = new IdentityHashSet<>(); - - public MyClusteredNetworkContext(@NotNull MyClusterContext clusterContext) { - super(clusterContext); - } - - @Override - public void addConnectionListener(ConnectionListener connectionListener) { - connectionListeners.add(connectionListener); - } - - @Override - public void removeConnectionListener(ConnectionListener connectionListener) { - connectionListeners.remove(connectionListener); - } - } - - static class MyCluster extends Cluster { - MyCluster(MyClusterContext clusterContext) { - super(); - clusterContext(clusterContext); - clusterContext.cluster(this); - } - } - - static class MyClusterContext extends net.openhft.chronicle.network.cluster.ClusterContext { - @Override - protected String clusterNamePrefix() { - return ""; - } - - @NotNull - @Override - public ThrowingFunction, IOException> tcpEventHandlerFactory() { - return nc -> { - if (nc.isAcceptor()) { - nc.wireOutPublisher(new VanillaWireOutPublisher(wireType())); - } - final TcpEventHandler handler = new TcpEventHandler<>(nc); - final Function> factory = - unused -> new HeaderTcpHandler<>(handler, o -> (TcpHandler) o); - final WireTypeSniffingTcpHandler sniffer = new WireTypeSniffingTcpHandler<>(handler, factory); - handler.tcpHandler(sniffer); - return handler; - }; - } - - @Override - protected void defaults() { - if (this.wireType() == null) - this.wireType(WireType.BINARY); - - if (this.wireOutPublisherFactory() == null) - this.wireOutPublisherFactory(VanillaWireOutPublisher::new); - - if (serverThreadingStrategy() == null) - this.serverThreadingStrategy(ServerThreadingStrategy.SINGLE_THREADED); - - if (this.networkContextFactory() == null) - this.networkContextFactory(MyClusteredNetworkContext::new); - } - } - static class Sender extends AbstractCompleteFlaggingHandler - implements WritableSubHandler, Marshallable { + implements WritableSubHandler, Marshallable { public Sender() { } @@ -556,7 +471,7 @@ public void onWrite(WireOut outWire) { } static class Receiver extends AbstractCompleteFlaggingHandler - implements WritableSubHandler, Marshallable { + implements WritableSubHandler, Marshallable { @Override public void onRead(@NotNull WireIn inWire, @NotNull WireOut outWire) { @@ -582,7 +497,7 @@ public void onWrite(WireOut outWire) { } static class PingPongHandler extends AbstractCompleteFlaggingHandler implements - Marshallable, WritableSubHandler { + Marshallable, WritableSubHandler { private static final int LOGGING_INTERVAL = 50; @@ -691,7 +606,7 @@ private void sendPingPongCid(WireOut outWire) { /** * Just a common way of knowing when all handlers have stopped writing */ - abstract static class AbstractCompleteFlaggingHandler extends AbstractSubHandler { + abstract static class AbstractCompleteFlaggingHandler extends AbstractSubHandler { private boolean flaggedComplete = false; protected void flagComplete() { @@ -705,15 +620,15 @@ protected void flagComplete() { static class UberHandlerTestHarness extends AbstractCloseable { - private final MyClusterContext clusterContext; - private final MyClusteredNetworkContext nc; - private final UberHandler uberHandler; + private final TestClusterContext clusterContext; + private final TestClusteredNetworkContext nc; + private final UberHandler uberHandler; private final Wire inWire; private final Wire outWire; public UberHandlerTestHarness() { - clusterContext = new MyClusterContext(); - nc = new MyClusteredNetworkContext(clusterContext); + clusterContext = new TestClusterContext(); + nc = new TestClusteredNetworkContext(clusterContext); nc.wireOutPublisher(new VanillaWireOutPublisher(clusterContext.wireType())); uberHandler = createHandler(); uberHandler.nc(nc); @@ -721,7 +636,7 @@ public UberHandlerTestHarness() { outWire = WireType.BINARY.apply(Bytes.allocateElasticOnHeap()); } - private UberHandler createHandler() { + private UberHandler createHandler() { Wire wire = WireType.BINARY.apply(Bytes.allocateElasticOnHeap()); uberHandler(123, 456, WireType.BINARY).writeMarshallable(wire); try (final DocumentContext documentContext = wire.readingDocument()) { @@ -751,7 +666,7 @@ private void sendMessageToCurrentHandler() { uberHandler.process(inWire.bytes(), outWire.bytes(), nc); } - public MyClusteredNetworkContext nc() { + public TestClusteredNetworkContext nc() { return nc; } diff --git a/src/test/java/net/openhft/chronicle/network/test/TestCluster.java b/src/test/java/net/openhft/chronicle/network/test/TestCluster.java new file mode 100644 index 00000000000..effeb130a20 --- /dev/null +++ b/src/test/java/net/openhft/chronicle/network/test/TestCluster.java @@ -0,0 +1,14 @@ +package net.openhft.chronicle.network.test; + +import net.openhft.chronicle.network.cluster.Cluster; + +/** + * A very minimal {@link Cluster} implementation for use in tests + */ +public class TestCluster extends Cluster { + public TestCluster(TestClusterContext clusterContext) { + super(); + clusterContext(clusterContext); + clusterContext.cluster(this); + } +} diff --git a/src/test/java/net/openhft/chronicle/network/test/TestClusterContext.java b/src/test/java/net/openhft/chronicle/network/test/TestClusterContext.java new file mode 100644 index 00000000000..0f5be870c46 --- /dev/null +++ b/src/test/java/net/openhft/chronicle/network/test/TestClusterContext.java @@ -0,0 +1,95 @@ +package net.openhft.chronicle.network.test; + +import net.openhft.chronicle.core.util.ThrowingFunction; +import net.openhft.chronicle.network.*; +import net.openhft.chronicle.network.api.TcpHandler; +import net.openhft.chronicle.network.cluster.ClusterContext; +import net.openhft.chronicle.network.cluster.HostDetails; +import net.openhft.chronicle.network.connection.VanillaWireOutPublisher; +import net.openhft.chronicle.wire.WireType; +import org.apache.mina.util.IdentityHashSet; +import org.jetbrains.annotations.NotNull; + +import java.io.IOException; +import java.util.Set; +import java.util.function.Function; + +/** + * A very minimal {@link ClusterContext} implementation for use in tests + */ +public class TestClusterContext extends ClusterContext { + + private final Set connectionListeners = new IdentityHashSet<>(); + private Long overrideNetworkContextTimeout; + private boolean disableReconnect; + + @NotNull + public static TestClusterContext forHosts(HostDetails... clusterHosts) { + TestClusterContext ctx = new TestClusterContext().wireType(WireType.BINARY).localIdentifier((byte) clusterHosts[0].hostId()); + ctx.heartbeatIntervalMs(500); + TestCluster cluster = new TestCluster(ctx); + for (HostDetails details : clusterHosts) { + cluster.hostDetails.put(String.valueOf(details.hostId()), details); + } + return ctx; + } + + public void addConnectionListener(ConnectionListener connectionListener) { + connectionListeners.add(connectionListener); + } + + public void disableReconnect() { + disableReconnect = true; + } + + public TestClusterContext overrideNetworkContextTimeout(long overrideNetworkContextTimeout) { + this.overrideNetworkContextTimeout = overrideNetworkContextTimeout; + return this; + } + + @Override + protected String clusterNamePrefix() { + return ""; + } + + @NotNull + @Override + public ThrowingFunction, IOException> tcpEventHandlerFactory() { + return nc -> { + if (nc.isAcceptor()) { + nc.wireOutPublisher(new VanillaWireOutPublisher(wireType())); + } + final TcpEventHandler handler = new TcpEventHandler<>(nc); + final Function> factory = + unused -> new HeaderTcpHandler<>(handler, o -> (TcpHandler) o); + final WireTypeSniffingTcpHandler sniffer = new WireTypeSniffingTcpHandler<>(handler, factory); + handler.tcpHandler(sniffer); + return handler; + }; + } + + @Override + protected void defaults() { + if (this.wireType() == null) + this.wireType(WireType.BINARY); + + if (this.wireOutPublisherFactory() == null) + this.wireOutPublisherFactory(VanillaWireOutPublisher::new); + + if (serverThreadingStrategy() == null) + this.serverThreadingStrategy(ServerThreadingStrategy.SINGLE_THREADED); + + if (this.networkContextFactory() == null) + this.networkContextFactory((TestClusterContext clusterContext) -> { + final TestClusteredNetworkContext nc = new TestClusteredNetworkContext(clusterContext); + connectionListeners.forEach(nc::addConnectionListener); + if (overrideNetworkContextTimeout != null) { + nc.heartbeatTimeoutMsOverride(overrideNetworkContextTimeout); + } + if (disableReconnect) { + nc.disableReconnect(); + } + return nc; + }); + } +} diff --git a/src/test/java/net/openhft/chronicle/network/test/TestClusteredNetworkContext.java b/src/test/java/net/openhft/chronicle/network/test/TestClusteredNetworkContext.java new file mode 100644 index 00000000000..a721a8074c5 --- /dev/null +++ b/src/test/java/net/openhft/chronicle/network/test/TestClusteredNetworkContext.java @@ -0,0 +1,91 @@ +package net.openhft.chronicle.network.test; + +import net.openhft.chronicle.network.ConnectionListener; +import net.openhft.chronicle.network.cluster.VanillaClusteredNetworkContext; +import org.apache.mina.util.IdentityHashSet; +import org.jetbrains.annotations.NotNull; +import org.jetbrains.annotations.Nullable; + +import java.util.Collection; +import java.util.Set; + +/** + * A very minimal {@link net.openhft.chronicle.network.cluster.ClusteredNetworkContext} implementation for use in tests + */ +public class TestClusteredNetworkContext extends VanillaClusteredNetworkContext { + + public final Set connectionListeners = new IdentityHashSet<>(); + private final ConnectionListener compositeConnectionListener = new CompositeConnectionListener(connectionListeners); + private Long heartbeatTimeoutMsOverride; + private boolean disableReconnect; + + public TestClusteredNetworkContext(@NotNull TestClusterContext clusterContext) { + super(clusterContext); + } + + /** + * Disable the reconnection, calls to {@link #socketReconnector()} will return a no-op {@link Runnable} + */ + public void disableReconnect() { + disableReconnect = true; + } + + /** + * This will override the NetworkContext#heartbeatTimeoutMs that's set on this context + * + * @param heartbeatTimeoutMsOverride The desired return value for heartbeatTimeoutMs + */ + public void heartbeatTimeoutMsOverride(Long heartbeatTimeoutMsOverride) { + this.heartbeatTimeoutMsOverride = heartbeatTimeoutMsOverride; + } + + @Override + public Runnable socketReconnector() { + if (disableReconnect) { + return () -> { + }; + } + return super.socketReconnector(); + } + + @Override + public long heartbeatTimeoutMs() { + if (heartbeatTimeoutMsOverride != null) { + return heartbeatTimeoutMsOverride; + } + return super.heartbeatTimeoutMs(); + } + + @Override + public void addConnectionListener(ConnectionListener connectionListener) { + connectionListeners.add(connectionListener); + } + + @Override + public void removeConnectionListener(ConnectionListener connectionListener) { + connectionListeners.remove(connectionListener); + } + + @Override + public @Nullable ConnectionListener acquireConnectionListener() { + return compositeConnectionListener; + } + + static class CompositeConnectionListener implements ConnectionListener { + private final Collection listenerCollection; + + public CompositeConnectionListener(Collection listenerCollection) { + this.listenerCollection = listenerCollection; + } + + @Override + public void onConnected(int localIdentifier, int remoteIdentifier, boolean isAcceptor) { + listenerCollection.forEach(cl -> cl.onConnected(localIdentifier, remoteIdentifier, isAcceptor)); + } + + @Override + public void onDisconnected(int localIdentifier, int remoteIdentifier, boolean isAcceptor) { + listenerCollection.forEach(cl -> cl.onDisconnected(localIdentifier, remoteIdentifier, isAcceptor)); + } + } +}