diff --git a/rlib-common/build.gradle b/rlib-common/build.gradle index 1499fe59..8c7df445 100644 --- a/rlib-common/build.gradle +++ b/rlib-common/build.gradle @@ -6,4 +6,5 @@ plugins { dependencies { api projects.rlibLoggerApi api projects.rlibFunctions + testFixturesImplementation libs.lombok } diff --git a/rlib-common/src/test/java/javasabr/rlib/common/util/AwaitUtilsTest.java b/rlib-common/src/test/java/javasabr/rlib/common/util/AwaitUtilsTest.java new file mode 100644 index 00000000..92958d7b --- /dev/null +++ b/rlib-common/src/test/java/javasabr/rlib/common/util/AwaitUtilsTest.java @@ -0,0 +1,48 @@ +package javasabr.rlib.common.util; + +import static org.assertj.core.api.Assertions.assertThat; + +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import org.junit.jupiter.api.Test; + +/** + * Tests of {@link AwaitUtils} methods. + * + * @author crazyrokr + */ +public class AwaitUtilsTest { + + @Test + void shouldAwaitCondition() throws InterruptedException { + // given + var condition = new AtomicBoolean(false); + var thread = new Thread(() -> { + try { + Thread.sleep(100); + condition.set(true); + } catch (InterruptedException e) { + // ignore + } + }); + + // when + thread.start(); + boolean result = AwaitUtils.await(500, TimeUnit.MILLISECONDS, condition::get); + + // then + assertThat(result).isTrue(); + } + + @Test + void shouldTimeoutIfConditionNotMet() throws InterruptedException { + // given + var condition = new AtomicBoolean(false); + + // when + boolean result = AwaitUtils.await(100, TimeUnit.MILLISECONDS, condition::get); + + // then + assertThat(result).isFalse(); + } +} diff --git a/rlib-common/src/testFixtures/java/javasabr/rlib/common/util/AwaitUtils.java b/rlib-common/src/testFixtures/java/javasabr/rlib/common/util/AwaitUtils.java new file mode 100644 index 00000000..869b3d6e --- /dev/null +++ b/rlib-common/src/testFixtures/java/javasabr/rlib/common/util/AwaitUtils.java @@ -0,0 +1,38 @@ +package javasabr.rlib.common.util; + +import java.util.concurrent.TimeUnit; +import java.util.function.Supplier; +import lombok.experimental.UtilityClass; + +/** + * The utility class to await some conditions. + * + * @author crazyrokr + */ +@UtilityClass +public final class AwaitUtils { + + /** + * Await for the condition during the amount of time units. + * + * @param amount the amount of time units. + * @param unit the time unit. + * @param condition the condition. + * @return true if the condition was met. + * @throws InterruptedException if the current thread was interrupted. + */ + public static boolean await(long amount, TimeUnit unit, Supplier condition) throws InterruptedException { + if (condition.get()) { + return true; + } + var timeoutMillis = unit.toMillis(amount); + var endTime = System.currentTimeMillis() + timeoutMillis; + while (System.currentTimeMillis() < endTime) { + if (condition.get()) { + return true; + } + Thread.sleep(Math.clamp(endTime - System.currentTimeMillis(), 1, 10)); + } + return condition.get(); + } +} diff --git a/rlib-network/build.gradle b/rlib-network/build.gradle index 8ddfd5d0..c8763542 100644 --- a/rlib-network/build.gradle +++ b/rlib-network/build.gradle @@ -11,4 +11,5 @@ dependencies { api libs.project.reactor.core testRuntimeOnly projects.rlibLoggerImpl loadTestRuntimeOnly projects.rlibLoggerImpl + testImplementation testFixtures(projects.rlibCommon) } diff --git a/rlib-network/src/main/java/javasabr/rlib/network/exception/ConnectionClosedException.java b/rlib-network/src/main/java/javasabr/rlib/network/exception/ConnectionClosedException.java new file mode 100644 index 00000000..6964489e --- /dev/null +++ b/rlib-network/src/main/java/javasabr/rlib/network/exception/ConnectionClosedException.java @@ -0,0 +1,28 @@ +package javasabr.rlib.network.exception; + +/** + * Thrown when a network connection has been closed + * + * @since 10.0.0 + */ +public class ConnectionClosedException extends NetworkException { + + /** + * Creates a new exception for a closed connection + * + * @param remoteAddress the remote address + */ + public ConnectionClosedException(String remoteAddress) { + super("Connection closed: %s".formatted(remoteAddress)); + } + + /** + * Creates a new exception for a closed connection with a cause + * + * @param remoteAddress the remote address + * @param cause the cause + */ + public ConnectionClosedException(String remoteAddress, Throwable cause) { + super("Connection closed: %s".formatted(remoteAddress), cause); + } +} diff --git a/rlib-network/src/main/java/javasabr/rlib/network/impl/AbstractConnection.java b/rlib-network/src/main/java/javasabr/rlib/network/impl/AbstractConnection.java index 7e8e04c6..fb12d4c5 100644 --- a/rlib-network/src/main/java/javasabr/rlib/network/impl/AbstractConnection.java +++ b/rlib-network/src/main/java/javasabr/rlib/network/impl/AbstractConnection.java @@ -4,18 +4,23 @@ import java.nio.channels.AsynchronousChannel; import java.nio.channels.AsynchronousSocketChannel; +import java.util.Collection; import java.util.Deque; import java.util.concurrent.CompletableFuture; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.locks.StampedLock; import java.util.function.BiConsumer; +import javasabr.rlib.collections.array.Array; import javasabr.rlib.collections.array.ArrayFactory; +import javasabr.rlib.collections.array.LockableArray; import javasabr.rlib.collections.array.MutableArray; import javasabr.rlib.collections.deque.DequeFactory; +import javasabr.rlib.collections.operation.LockableOperations; import javasabr.rlib.network.BufferAllocator; import javasabr.rlib.network.Connection; import javasabr.rlib.network.Network; import javasabr.rlib.network.UnsafeConnection; +import javasabr.rlib.network.exception.ConnectionClosedException; import javasabr.rlib.network.packet.NetworkPacketReader; import javasabr.rlib.network.packet.NetworkPacketWriter; import javasabr.rlib.network.packet.ReadableNetworkPacket; @@ -64,6 +69,8 @@ public WritablePacketWithFeedback(CompletableFuture attachment, Writabl final MutableArray>> validPacketSubscribers; final MutableArray>> invalidPacketSubscribers; + final LockableArray> activeSinks; + final LockableOperations>> activeSinksOperations; final int maxPacketsByRead; @@ -84,6 +91,8 @@ public AbstractConnection( this.closed = new AtomicBoolean(false); this.validPacketSubscribers = ArrayFactory.copyOnModifyArray(BiConsumer.class); this.invalidPacketSubscribers = ArrayFactory.copyOnModifyArray(BiConsumer.class); + this.activeSinks = ArrayFactory.stampedLockBasedArray(FluxSink.class); + this.activeSinksOperations = activeSinks.operations(); this.remoteAddress = String.valueOf(NetworkUtils.getRemoteAddress(channel)); } @@ -134,10 +143,12 @@ protected void registerFluxOnReceivedEvents( validPacketSubscribers.add(validListener); invalidPacketSubscribers.add(invalidListener); + activeSinksOperations.inWriteLock(sink, Collection::add); sink.onDispose(() -> { validPacketSubscribers.remove(validListener); - validPacketSubscribers.remove(invalidListener); + invalidPacketSubscribers.remove(invalidListener); + activeSinksOperations.inWriteLock(sink, Collection::remove); }); network.inNetworkThread(() -> packetReader().startRead()); @@ -146,14 +157,22 @@ protected void registerFluxOnReceivedEvents( protected void registerFluxOnReceivedValidPackets(FluxSink> sink) { BiConsumer> listener = (connection, packet) -> sink.next(packet); validPacketSubscribers.add(listener); - sink.onDispose(() -> validPacketSubscribers.remove(listener)); + activeSinksOperations.inWriteLock(sink, Collection::add); + sink.onDispose(() -> { + validPacketSubscribers.remove(listener); + activeSinksOperations.inWriteLock(sink, Collection::remove); + }); network.inNetworkThread(() -> packetReader().startRead()); } protected void registerFluxOnReceivedInvalidPackets(FluxSink> sink) { BiConsumer> listener = (connection, packet) -> sink.next(packet); invalidPacketSubscribers.add(listener); - sink.onDispose(() -> invalidPacketSubscribers.remove(listener)); + activeSinksOperations.inWriteLock(sink, Collection::add); + sink.onDispose(() -> { + invalidPacketSubscribers.remove(listener); + activeSinksOperations.inWriteLock(sink, Collection::remove); + }); network.inNetworkThread(() -> packetReader().startRead()); } @@ -184,6 +203,27 @@ protected void doClose() { clearWaitPackets(); packetReader().close(); packetWriter().close(); + notifyActiveSinks(); + } + + protected void notifyActiveSinks() { + Boolean noActiveSinks = activeSinksOperations.getInReadLock(Array::isEmpty); + if (noActiveSinks) { + return; + } + notifySinksWithError(new ConnectionClosedException(remoteAddress)); + activeSinksOperations.inWriteLock(Collection::clear); + } + + protected void notifySinksWithError(Throwable error) { + Array> localActiveSinks = activeSinksOperations.getInReadLock(Array::copyOf); + for (FluxSink sink : localActiveSinks) { + try { + sink.error(error); + } catch (RuntimeException e) { + log.error(e.getMessage(), "Failed to notify sink of connection closure: "::formatted); + } + } } /** diff --git a/rlib-network/src/main/java/javasabr/rlib/network/packet/impl/AbstractNetworkPacketReader.java b/rlib-network/src/main/java/javasabr/rlib/network/packet/impl/AbstractNetworkPacketReader.java index de7e5b8a..ee683870 100644 --- a/rlib-network/src/main/java/javasabr/rlib/network/packet/impl/AbstractNetworkPacketReader.java +++ b/rlib-network/src/main/java/javasabr/rlib/network/packet/impl/AbstractNetworkPacketReader.java @@ -461,10 +461,14 @@ protected void handleFailedReceiving(Throwable exception, ByteBuffer readingBuff retryReadLater(); } } - case AsynchronousCloseException ex -> - log.info(remoteAddress(), "[%s] Connection was closed"::formatted); - case ClosedChannelException ex -> - log.info(remoteAddress(), "[%s] Connection was closed"::formatted); + case AsynchronousCloseException ex -> { + log.info(remoteAddress(), "[%s] Connection was closed"::formatted); + connection.close(); + } + case ClosedChannelException ex -> { + log.info(remoteAddress(), "[%s] Connection was closed"::formatted); + connection.close(); + } default -> { log.error(exception); connection.close(); diff --git a/rlib-network/src/main/java/javasabr/rlib/network/packet/impl/AbstractSslNetworkPacketReader.java b/rlib-network/src/main/java/javasabr/rlib/network/packet/impl/AbstractSslNetworkPacketReader.java index 6ab75309..5b7033d8 100644 --- a/rlib-network/src/main/java/javasabr/rlib/network/packet/impl/AbstractSslNetworkPacketReader.java +++ b/rlib-network/src/main/java/javasabr/rlib/network/packet/impl/AbstractSslNetworkPacketReader.java @@ -76,6 +76,7 @@ protected AbstractSslNetworkPacketReader( protected void handleReceivedData(int receivedBytes, ByteBuffer readingBuffer) { if (receivedBytes == -1) { doHandshake(sslNetworkBuffer(), -1); + handleEmptyReadFromChannel(); return; } super.handleReceivedData(receivedBytes, readingBuffer); diff --git a/rlib-network/src/test/java/javasabr/rlib/network/ConnectionCloseTest.java b/rlib-network/src/test/java/javasabr/rlib/network/ConnectionCloseTest.java new file mode 100644 index 00000000..e831d302 --- /dev/null +++ b/rlib-network/src/test/java/javasabr/rlib/network/ConnectionCloseTest.java @@ -0,0 +1,93 @@ +package javasabr.rlib.network; + +import static javasabr.rlib.network.util.NetworkUtils.createAllTrustedClientSslContext; +import static javasabr.rlib.network.util.NetworkUtils.createSslContext; +import static org.assertj.core.api.Assertions.assertThat; + +import java.net.InetSocketAddress; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import javasabr.rlib.common.util.AwaitUtils; +import javasabr.rlib.network.exception.ConnectionClosedException; +import javasabr.rlib.network.impl.AbstractConnection; +import javasabr.rlib.network.impl.DefaultConnection; +import javasabr.rlib.network.packet.impl.DefaultReadableNetworkPacket; +import javasabr.rlib.network.packet.impl.StringWritableNetworkPacket; +import javasabr.rlib.network.packet.registry.ReadableNetworkPacketRegistry; +import lombok.SneakyThrows; +import org.junit.jupiter.api.Test; + +/** + * Checking that the connections are closed correctly + * + * @author crazyrokr + */ +public class ConnectionCloseTest extends BaseNetworkTest { + + @Test + void shouldPropagateConnectionCloseToClient() throws InterruptedException { + // given + var packetRegistry = ReadableNetworkPacketRegistry.of( + DefaultReadableNetworkPacket.class, + DefaultConnection.class, + DefaultNetworkTest.ServerPackets.RequestEchoMessage.class, + DefaultNetworkTest.ServerPackets.RequestServerTime.class); + var serverNetwork = NetworkFactory.defaultServerNetwork(packetRegistry); + InetSocketAddress serverAddress = serverNetwork.start(); + serverNetwork.onAccept(AbstractConnection::close); + var clientNetwork = NetworkFactory.defaultClientNetwork(packetRegistry); + CountDownLatch closeLatch = new CountDownLatch(1); + + // when + try { + clientNetwork + .connectReactive(serverAddress) + .flatMapMany(AbstractConnection::receivedEvents) + .doOnError(e -> { + if (e instanceof ConnectionClosedException) { + closeLatch.countDown(); + } + }) + .subscribe(); + + // then + assertThat(closeLatch.await(5000, TimeUnit.MILLISECONDS)) + .as("Client should be notified that connection is closed") + .isTrue(); + } finally { + // cleanup + clientNetwork.shutdown(); + serverNetwork.shutdown(); + } + } + + @Test + @SneakyThrows + void shouldCloseServerConnectionWhenClientClosesTcpChannelAbruptly() { + // given + try (var keystoreFile = ConnectionCloseTest.class.getResourceAsStream("/ssl/rlib_test_cert.p12"); + var testNetwork = buildStringSSLNetwork( + createSslContext(keystoreFile, "test"), + createAllTrustedClientSslContext())) { + var serverConnection = testNetwork.serverToClient; + var clientConnection = testNetwork.clientToServer; + CountDownLatch dataReceivedLatch = new CountDownLatch(1); + serverConnection.onReceiveValidPacket((conn, packet) -> dataReceivedLatch.countDown()); + clientConnection.sendInBackground(new StringWritableNetworkPacket<>("handshake")); + assertThat(dataReceivedLatch.await(5, TimeUnit.SECONDS)) + .as("Client connection should be closed prior server side verification") + .isTrue(); + + // when + clientConnection.channel().close(); + assertThat(AwaitUtils.await(5, TimeUnit.SECONDS, clientConnection::closed)) + .as("Client connection should be closed prior server side verification") + .isTrue(); + + // then + assertThat(AwaitUtils.await(5, TimeUnit.SECONDS, serverConnection::closed)) + .as("Server connection should be closed after receiving EOF from abruptly closed client channel") + .isTrue(); + } + } +}