Skip to content
1 change: 1 addition & 0 deletions rlib-common/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,5 @@ plugins {
dependencies {
api projects.rlibLoggerApi
api projects.rlibFunctions
testFixturesImplementation libs.lombok
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
package javasabr.rlib.common.util;

import static org.assertj.core.api.Assertions.assertThat;

import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import org.junit.jupiter.api.Test;

/**
* Tests of {@link AwaitUtils} methods.
*
* @author crazyrokr
*/
public class AwaitUtilsTest {

@Test
void shouldAwaitCondition() throws InterruptedException {
// given
var condition = new AtomicBoolean(false);
var thread = new Thread(() -> {
try {
Thread.sleep(100);
condition.set(true);
} catch (InterruptedException e) {
// ignore
}
});

// when
thread.start();
boolean result = AwaitUtils.await(500, TimeUnit.MILLISECONDS, condition::get);

// then
assertThat(result).isTrue();
}

@Test
void shouldTimeoutIfConditionNotMet() throws InterruptedException {
// given
var condition = new AtomicBoolean(false);

// when
boolean result = AwaitUtils.await(100, TimeUnit.MILLISECONDS, condition::get);

// then
assertThat(result).isFalse();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
package javasabr.rlib.common.util;

import java.util.concurrent.TimeUnit;
import java.util.function.Supplier;
import lombok.experimental.UtilityClass;

/**
* The utility class to await some conditions.
*
* @author crazyrokr
*/
@UtilityClass
public final class AwaitUtils {

Comment thread
crazyrokr marked this conversation as resolved.
/**
* Await for the condition during the amount of time units.
*
* @param amount the amount of time units.
* @param unit the time unit.
* @param condition the condition.
* @return true if the condition was met.
* @throws InterruptedException if the current thread was interrupted.
*/
public static boolean await(long amount, TimeUnit unit, Supplier<Boolean> condition) throws InterruptedException {
if (condition.get()) {
return true;
}
var timeoutMillis = unit.toMillis(amount);
var endTime = System.currentTimeMillis() + timeoutMillis;
while (System.currentTimeMillis() < endTime) {
if (condition.get()) {
return true;
}
Thread.sleep(Math.clamp(endTime - System.currentTimeMillis(), 1, 10));
}
return condition.get();
}
}
1 change: 1 addition & 0 deletions rlib-network/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,5 @@ dependencies {
api libs.project.reactor.core
testRuntimeOnly projects.rlibLoggerImpl
loadTestRuntimeOnly projects.rlibLoggerImpl
testImplementation testFixtures(projects.rlibCommon)
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package javasabr.rlib.network.exception;

/**
* Thrown when a network connection has been closed
*
* @since 10.0.0
*/
public class ConnectionClosedException extends NetworkException {

/**
* Creates a new exception for a closed connection
*
* @param remoteAddress the remote address
*/
public ConnectionClosedException(String remoteAddress) {
super("Connection closed: %s".formatted(remoteAddress));
}

Comment thread
crazyrokr marked this conversation as resolved.
/**
* Creates a new exception for a closed connection with a cause
*
* @param remoteAddress the remote address
* @param cause the cause
*/
public ConnectionClosedException(String remoteAddress, Throwable cause) {
super("Connection closed: %s".formatted(remoteAddress), cause);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,23 @@

import java.nio.channels.AsynchronousChannel;
import java.nio.channels.AsynchronousSocketChannel;
import java.util.Collection;
import java.util.Deque;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.locks.StampedLock;
import java.util.function.BiConsumer;
import javasabr.rlib.collections.array.Array;
import javasabr.rlib.collections.array.ArrayFactory;
import javasabr.rlib.collections.array.LockableArray;
import javasabr.rlib.collections.array.MutableArray;
import javasabr.rlib.collections.deque.DequeFactory;
import javasabr.rlib.collections.operation.LockableOperations;
import javasabr.rlib.network.BufferAllocator;
import javasabr.rlib.network.Connection;
import javasabr.rlib.network.Network;
import javasabr.rlib.network.UnsafeConnection;
import javasabr.rlib.network.exception.ConnectionClosedException;
import javasabr.rlib.network.packet.NetworkPacketReader;
import javasabr.rlib.network.packet.NetworkPacketWriter;
import javasabr.rlib.network.packet.ReadableNetworkPacket;
Expand Down Expand Up @@ -64,6 +69,8 @@ public WritablePacketWithFeedback(CompletableFuture<Boolean> attachment, Writabl

final MutableArray<BiConsumer<C, ? super ReadableNetworkPacket<C>>> validPacketSubscribers;
final MutableArray<BiConsumer<C, ? super ReadableNetworkPacket<C>>> invalidPacketSubscribers;
final LockableArray<FluxSink<?>> activeSinks;
final LockableOperations<LockableArray<FluxSink<?>>> activeSinksOperations;

final int maxPacketsByRead;

Expand All @@ -84,6 +91,8 @@ public AbstractConnection(
this.closed = new AtomicBoolean(false);
this.validPacketSubscribers = ArrayFactory.copyOnModifyArray(BiConsumer.class);
this.invalidPacketSubscribers = ArrayFactory.copyOnModifyArray(BiConsumer.class);
this.activeSinks = ArrayFactory.stampedLockBasedArray(FluxSink.class);
this.activeSinksOperations = activeSinks.operations();
this.remoteAddress = String.valueOf(NetworkUtils.getRemoteAddress(channel));
}

Expand Down Expand Up @@ -134,10 +143,12 @@ protected void registerFluxOnReceivedEvents(

validPacketSubscribers.add(validListener);
invalidPacketSubscribers.add(invalidListener);
activeSinksOperations.inWriteLock(sink, Collection::add);

sink.onDispose(() -> {
validPacketSubscribers.remove(validListener);
validPacketSubscribers.remove(invalidListener);
invalidPacketSubscribers.remove(invalidListener);
activeSinksOperations.inWriteLock(sink, Collection::remove);
});

network.inNetworkThread(() -> packetReader().startRead());
Expand All @@ -146,14 +157,22 @@ protected void registerFluxOnReceivedEvents(
protected void registerFluxOnReceivedValidPackets(FluxSink<? super ReadableNetworkPacket<C>> sink) {
BiConsumer<C, ReadableNetworkPacket<C>> listener = (connection, packet) -> sink.next(packet);
validPacketSubscribers.add(listener);
sink.onDispose(() -> validPacketSubscribers.remove(listener));
activeSinksOperations.inWriteLock(sink, Collection::add);
sink.onDispose(() -> {
validPacketSubscribers.remove(listener);
activeSinksOperations.inWriteLock(sink, Collection::remove);
});
network.inNetworkThread(() -> packetReader().startRead());
}

protected void registerFluxOnReceivedInvalidPackets(FluxSink<? super ReadableNetworkPacket<C>> sink) {
BiConsumer<C, ReadableNetworkPacket<C>> listener = (connection, packet) -> sink.next(packet);
invalidPacketSubscribers.add(listener);
sink.onDispose(() -> invalidPacketSubscribers.remove(listener));
activeSinksOperations.inWriteLock(sink, Collection::add);
sink.onDispose(() -> {
invalidPacketSubscribers.remove(listener);
activeSinksOperations.inWriteLock(sink, Collection::remove);
});
network.inNetworkThread(() -> packetReader().startRead());
}

Expand Down Expand Up @@ -184,6 +203,27 @@ protected void doClose() {
clearWaitPackets();
packetReader().close();
packetWriter().close();
notifyActiveSinks();
}

protected void notifyActiveSinks() {
Boolean noActiveSinks = activeSinksOperations.getInReadLock(Array::isEmpty);
if (noActiveSinks) {
return;
}
notifySinksWithError(new ConnectionClosedException(remoteAddress));
activeSinksOperations.inWriteLock(Collection::clear);
}

protected void notifySinksWithError(Throwable error) {
Array<FluxSink<?>> localActiveSinks = activeSinksOperations.getInReadLock(Array::copyOf);
Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do you really want to allocate a full array for such reading?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think so because sink.error(exc) in line 223 activates sink.onDispose() callback making write lock, so this situation causes dead lock. Any better ideas are welcome

https://github.com/crazyrokr/RLib/blob/35939cee7afbfc724d229ea08b733a7005393bdc/rlib-network/src/main/java/javasabr/rlib/network/impl/AbstractConnection.java#L148-L152

for (FluxSink<?> sink : localActiveSinks) {
try {
sink.error(error);
} catch (RuntimeException e) {
log.error(e.getMessage(), "Failed to notify sink of connection closure: "::formatted);
}
}
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -461,10 +461,14 @@ protected void handleFailedReceiving(Throwable exception, ByteBuffer readingBuff
retryReadLater();
}
}
case AsynchronousCloseException ex ->
log.info(remoteAddress(), "[%s] Connection was closed"::formatted);
case ClosedChannelException ex ->
log.info(remoteAddress(), "[%s] Connection was closed"::formatted);
case AsynchronousCloseException ex -> {
log.info(remoteAddress(), "[%s] Connection was closed"::formatted);
connection.close();
}
case ClosedChannelException ex -> {
log.info(remoteAddress(), "[%s] Connection was closed"::formatted);
connection.close();
}
default -> {
log.error(exception);
connection.close();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ protected AbstractSslNetworkPacketReader(
protected void handleReceivedData(int receivedBytes, ByteBuffer readingBuffer) {
if (receivedBytes == -1) {
doHandshake(sslNetworkBuffer(), -1);
handleEmptyReadFromChannel();
return;
}
super.handleReceivedData(receivedBytes, readingBuffer);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
package javasabr.rlib.network;

import static javasabr.rlib.network.util.NetworkUtils.createAllTrustedClientSslContext;
import static javasabr.rlib.network.util.NetworkUtils.createSslContext;
import static org.assertj.core.api.Assertions.assertThat;

import java.net.InetSocketAddress;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import javasabr.rlib.common.util.AwaitUtils;
import javasabr.rlib.network.exception.ConnectionClosedException;
import javasabr.rlib.network.impl.AbstractConnection;
import javasabr.rlib.network.impl.DefaultConnection;
import javasabr.rlib.network.packet.impl.DefaultReadableNetworkPacket;
import javasabr.rlib.network.packet.impl.StringWritableNetworkPacket;
import javasabr.rlib.network.packet.registry.ReadableNetworkPacketRegistry;
import lombok.SneakyThrows;
import org.junit.jupiter.api.Test;

/**
* Checking that the connections are closed correctly
*
* @author crazyrokr
*/
public class ConnectionCloseTest extends BaseNetworkTest {

@Test
void shouldPropagateConnectionCloseToClient() throws InterruptedException {
// given
var packetRegistry = ReadableNetworkPacketRegistry.of(
DefaultReadableNetworkPacket.class,
DefaultConnection.class,
DefaultNetworkTest.ServerPackets.RequestEchoMessage.class,
DefaultNetworkTest.ServerPackets.RequestServerTime.class);
var serverNetwork = NetworkFactory.defaultServerNetwork(packetRegistry);
InetSocketAddress serverAddress = serverNetwork.start();
serverNetwork.onAccept(AbstractConnection::close);
var clientNetwork = NetworkFactory.defaultClientNetwork(packetRegistry);
CountDownLatch closeLatch = new CountDownLatch(1);

// when
try {
clientNetwork
.connectReactive(serverAddress)
.flatMapMany(AbstractConnection::receivedEvents)
.doOnError(e -> {
if (e instanceof ConnectionClosedException) {
closeLatch.countDown();
}
})
.subscribe();

// then
assertThat(closeLatch.await(5000, TimeUnit.MILLISECONDS))
.as("Client should be notified that connection is closed")
.isTrue();
} finally {
// cleanup
clientNetwork.shutdown();
serverNetwork.shutdown();
}
}

@Test
@SneakyThrows
void shouldCloseServerConnectionWhenClientClosesTcpChannelAbruptly() {
// given
try (var keystoreFile = ConnectionCloseTest.class.getResourceAsStream("/ssl/rlib_test_cert.p12");
var testNetwork = buildStringSSLNetwork(
createSslContext(keystoreFile, "test"),
createAllTrustedClientSslContext())) {
var serverConnection = testNetwork.serverToClient;
var clientConnection = testNetwork.clientToServer;
CountDownLatch dataReceivedLatch = new CountDownLatch(1);
serverConnection.onReceiveValidPacket((conn, packet) -> dataReceivedLatch.countDown());
clientConnection.sendInBackground(new StringWritableNetworkPacket<>("handshake"));
assertThat(dataReceivedLatch.await(5, TimeUnit.SECONDS))
.as("Client connection should be closed prior server side verification")
.isTrue();

// when
clientConnection.channel().close();
assertThat(AwaitUtils.await(5, TimeUnit.SECONDS, clientConnection::closed))
.as("Client connection should be closed prior server side verification")
.isTrue();

// then
assertThat(AwaitUtils.await(5, TimeUnit.SECONDS, serverConnection::closed))
.as("Server connection should be closed after receiving EOF from abruptly closed client channel")
.isTrue();
}
}
}
Loading