diff --git a/src/StrawberryShake/Client/src/Transport.WebSockets/ISocketClient.cs b/src/StrawberryShake/Client/src/Transport.WebSockets/ISocketClient.cs index 09e43878a37..db419b07629 100644 --- a/src/StrawberryShake/Client/src/Transport.WebSockets/ISocketClient.cs +++ b/src/StrawberryShake/Client/src/Transport.WebSockets/ISocketClient.cs @@ -11,6 +11,11 @@ namespace StrawberryShake.Transport.WebSockets; /// public interface ISocketClient : IAsyncDisposable { + /// + /// An event that is called when the message receiving cycle stoped + /// + event EventHandler ReceiveFinished; + /// /// The URI where the socket should connect to /// @@ -70,4 +75,4 @@ public interface ISocketClient : IAsyncDisposable string message, SocketCloseStatus closeStatus, CancellationToken cancellationToken = default); -} \ No newline at end of file +} diff --git a/src/StrawberryShake/Client/src/Transport.WebSockets/ISocketOperation.cs b/src/StrawberryShake/Client/src/Transport.WebSockets/ISocketOperation.cs index 1ac7ef97534..efa69dbeb06 100644 --- a/src/StrawberryShake/Client/src/Transport.WebSockets/ISocketOperation.cs +++ b/src/StrawberryShake/Client/src/Transport.WebSockets/ISocketOperation.cs @@ -1,5 +1,7 @@ using System; using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; using StrawberryShake.Transport.WebSockets.Messages; namespace StrawberryShake.Transport.WebSockets; @@ -18,4 +20,13 @@ public interface ISocketOperation : IAsyncDisposable /// CReate an operation message stream. /// IAsyncEnumerable ReadAsync(); + + /// + /// Complete the operation + /// + /// + /// A to cancel the completion + /// + /// A task that is completed once the operation is completed + ValueTask CompleteAsync(CancellationToken cancellationToken); } diff --git a/src/StrawberryShake/Client/src/Transport.WebSockets/ISocketProtocol.cs b/src/StrawberryShake/Client/src/Transport.WebSockets/ISocketProtocol.cs index a3d82d092fa..68e702a1d42 100644 --- a/src/StrawberryShake/Client/src/Transport.WebSockets/ISocketProtocol.cs +++ b/src/StrawberryShake/Client/src/Transport.WebSockets/ISocketProtocol.cs @@ -88,4 +88,16 @@ public interface ISocketProtocol : IAsyncDisposable /// /// void Unsubscribe(OnReceiveAsync listener); -} \ No newline at end of file + + /// + /// Notify the protocol to complete + /// + /// The id of the operation to stop + /// + /// A to cancel the notification + /// + /// A task that is completed once the notification is completed + ValueTask NotifyCompletion( + string operationId, + CancellationToken cancellationToken); +} diff --git a/src/StrawberryShake/Client/src/Transport.WebSockets/Session.cs b/src/StrawberryShake/Client/src/Transport.WebSockets/Session.cs index c974665ec43..7e49e3b58d7 100644 --- a/src/StrawberryShake/Client/src/Transport.WebSockets/Session.cs +++ b/src/StrawberryShake/Client/src/Transport.WebSockets/Session.cs @@ -24,6 +24,8 @@ public Session(ISocketClient socketClient) { _socketClient = socketClient ?? throw new ArgumentNullException(nameof(socketClient)); + + _socketClient.ReceiveFinished += ReceiveFinishHandler; } /// @@ -93,6 +95,21 @@ await socketProtocol } } + /// + private async ValueTask CompleteOperation(CancellationToken cancellationToken) + { + foreach (var operation in _operations) + { + await operation.Value.CompleteAsync(cancellationToken); + } + } + + /// + private void ReceiveFinishHandler(object? sender, EventArgs args) + { + _ = CompleteOperation(default); + } + /// /// Opens a session over the socket /// @@ -166,6 +183,7 @@ public async ValueTask DisposeAsync() _operations.Clear(); } + _socketClient.ReceiveFinished -= ReceiveFinishHandler; _socketProtocol?.Unsubscribe(ReceiveMessage); await _socketClient.DisposeAsync(); } diff --git a/src/StrawberryShake/Client/src/Transport.WebSockets/SocketOperation.cs b/src/StrawberryShake/Client/src/Transport.WebSockets/SocketOperation.cs index a966b78be73..c78e39ab0e3 100644 --- a/src/StrawberryShake/Client/src/Transport.WebSockets/SocketOperation.cs +++ b/src/StrawberryShake/Client/src/Transport.WebSockets/SocketOperation.cs @@ -54,6 +54,22 @@ public SocketOperation(ISession manager) public IAsyncEnumerable ReadAsync() => new MessageStream(this, _channel); + /// + public async ValueTask CompleteAsync(CancellationToken cancellationToken) + { + if (!_disposed) + { + try + { + await _channel.Writer.WriteAsync(CompleteOperationMessage.Default, cancellationToken).ConfigureAwait(false); + } + catch (ChannelClosedException) + { + // if the channel is closed we will move on. + } + } + } + private sealed class MessageStream : IAsyncEnumerable { private readonly SocketOperation _operation; @@ -113,4 +129,4 @@ public async ValueTask DisposeAsync() _disposed = true; } } -} \ No newline at end of file +} diff --git a/src/StrawberryShake/Client/src/Transport.WebSockets/SocketProtocolBase.cs b/src/StrawberryShake/Client/src/Transport.WebSockets/SocketProtocolBase.cs index 92409153874..25e7c4dfe4c 100644 --- a/src/StrawberryShake/Client/src/Transport.WebSockets/SocketProtocolBase.cs +++ b/src/StrawberryShake/Client/src/Transport.WebSockets/SocketProtocolBase.cs @@ -65,6 +65,14 @@ public void Unsubscribe(OnReceiveAsync listener) } } + /// + public async ValueTask NotifyCompletion( + string operationId, + CancellationToken cancellationToken) + { + await Notify(operationId, CompleteOperationMessage.Default, cancellationToken).ConfigureAwait(false); + } + /// public virtual ValueTask DisposeAsync() { @@ -78,4 +86,4 @@ public virtual ValueTask DisposeAsync() _disposed = true; return default; } -} \ No newline at end of file +} diff --git a/src/StrawberryShake/Client/src/Transport.WebSockets/WebSocketClient.cs b/src/StrawberryShake/Client/src/Transport.WebSockets/WebSocketClient.cs index 7f43c704003..977580d0072 100644 --- a/src/StrawberryShake/Client/src/Transport.WebSockets/WebSocketClient.cs +++ b/src/StrawberryShake/Client/src/Transport.WebSockets/WebSocketClient.cs @@ -19,8 +19,12 @@ public sealed class WebSocketClient : IWebSocketClient private readonly IReadOnlyList _protocolFactories; private readonly ClientWebSocket _socket; private ISocketProtocol? _activeProtocol; + private bool _receiveFinishEventTriggered = false; private bool _disposed; + /// + public event EventHandler ReceiveFinished = default!; + /// /// Creates a new instance of /// @@ -52,10 +56,22 @@ public sealed class WebSocketClient : IWebSocketClient public string Name { get; } /// - public bool IsClosed => - _disposed - || _socket.CloseStatus.HasValue - || _socket.State == WebSocketState.Aborted; + public bool IsClosed + { + get + { + var closed = _disposed + || _socket.CloseStatus.HasValue + || _socket.State == WebSocketState.Aborted; + + if (closed && !_receiveFinishEventTriggered) + { + _receiveFinishEventTriggered = true; + ReceiveFinished?.Invoke(this, EventArgs.Empty); + } + return closed; + } + } /// public WebSocket Socket => _socket; diff --git a/src/StrawberryShake/Client/test/Transport.WebSocket.Tests/TestHelper/SocketClientStub.cs b/src/StrawberryShake/Client/test/Transport.WebSocket.Tests/TestHelper/SocketClientStub.cs index 0eb8b56fe2d..e479058a913 100644 --- a/src/StrawberryShake/Client/test/Transport.WebSocket.Tests/TestHelper/SocketClientStub.cs +++ b/src/StrawberryShake/Client/test/Transport.WebSocket.Tests/TestHelper/SocketClientStub.cs @@ -17,6 +17,8 @@ public sealed class SocketClientStub : ISocketClient new(TaskCreationOptions.None); private bool _isClosed = true; + public event EventHandler ReceiveFinished = default!; + public SemaphoreSlim Blocker { get; } = new(0); public Uri? Uri { get; set; } diff --git a/src/StrawberryShake/CodeGeneration/test/CodeGeneration.CSharp.Tests/Integration/StarWarsOnReviewSubCompletionTest.cs b/src/StrawberryShake/CodeGeneration/test/CodeGeneration.CSharp.Tests/Integration/StarWarsOnReviewSubCompletionTest.cs index a57596b140c..6c046d172a5 100644 --- a/src/StrawberryShake/CodeGeneration/test/CodeGeneration.CSharp.Tests/Integration/StarWarsOnReviewSubCompletionTest.cs +++ b/src/StrawberryShake/CodeGeneration/test/CodeGeneration.CSharp.Tests/Integration/StarWarsOnReviewSubCompletionTest.cs @@ -1,4 +1,6 @@ using System; +using System.Net.WebSockets; +using System.Reflection; using System.Threading.Tasks; using HotChocolate.AspNetCore.Tests.Utilities; using HotChocolate.StarWars.Models; @@ -6,6 +8,7 @@ using Microsoft.AspNetCore.Hosting; using Microsoft.Extensions.DependencyInjection; using StrawberryShake.Transport.WebSockets; +using StrawberryShake.Transport.WebSockets.Protocols; using Xunit; namespace StrawberryShake.CodeGeneration.CSharp.Integration.StarWarsOnReviewSubCompletion @@ -63,11 +66,128 @@ public async Task Watch_StarWarsOnReviewSubCompletion_Test() { await Task.Delay(1_000); } + + // assert + Assert.True(commentary is not null && completionTriggered); session.Dispose(); + } + + [Fact] + public async Task Watch_StarWarsOnReviewSubCompletionPassively_Test() + { + // arrange + using IWebHost host = TestServerHelper.CreateServer( + _ => { }, + out var port); + var topicEventSender = host.Services.GetRequiredService(); + + var serviceCollection = new ServiceCollection(); + serviceCollection.AddStarWarsOnReviewSubCompletionClient( + profile: StarWarsOnReviewSubCompletionClientProfileKind.Default) + .ConfigureHttpClient( + c => c.BaseAddress = new Uri("http://localhost:" + port + "/graphql")) + .ConfigureWebSocketClient( + c => c.Uri = new Uri("ws://localhost:" + port + "/graphql")); + + serviceCollection.AddSingleton(); + + // act + IServiceProvider services = serviceCollection.BuildServiceProvider(); + IStarWarsOnReviewSubCompletionClient client = services.GetRequiredService(); + + string? commentary = null; + bool completionTriggered = false; + + var sub = client.OnReviewSub.Watch(); + var session = sub.Subscribe( + result => commentary = result.Data?.OnReview?.Commentary, + () => completionTriggered = true); + + var topic = Episode.NewHope; + + // try to send message 10 times + // make sure the subscription connection is successful + for (int times = 0; commentary is null && times < 10; times++) + { + await topicEventSender.SendAsync(topic, new Review { Stars = 1, Commentary = "Commentary" }); + await Task.Delay(1_000); + } + + // simulate network error + var monitor = services.GetRequiredService(); + monitor.AbortSocket(); + + //await host.StopAsync(); + + // waiting for completion message sent + for (int times = 0; !completionTriggered && times < 10; times++) + { + await Task.Delay(1_000); + } // assert Assert.True(commentary is not null && completionTriggered); + + session.Dispose(); + } + } + + public class SubscriptionSocketStateMonitor + { + private const BindingFlags _bindingFlags = BindingFlags.NonPublic | BindingFlags.Instance; + + private readonly ISessionPool _sessionPool; + private readonly Type _sessionPoolType; + private readonly FieldInfo _sessionsField; + + private readonly FieldInfo _socketOperationsDictionaryField = typeof(Session).GetField("_operations", _bindingFlags)!; + private readonly FieldInfo _socketOperationManagerField = typeof(SocketOperation).GetField("_manager", _bindingFlags)!; + private readonly FieldInfo _socketProtocolField = typeof(Session)!.GetField("_socketProtocol", _bindingFlags)!; + private readonly FieldInfo _protocolReceiverField = typeof(GraphQLWebSocketProtocol).GetField("_receiver", _bindingFlags)!; + + private Type? _sessionInfoType; + private PropertyInfo? _sessionProperty; + private Type? _receiverType; + private FieldInfo? _receiverClientField; + + public SubscriptionSocketStateMonitor(ISessionPool sessionPool) + { + _sessionPool = sessionPool; + _sessionPoolType = _sessionPool.GetType(); + _sessionsField = _sessionPoolType.GetField("_sessions", _bindingFlags)!; + } + + public void AbortSocket() + { + var sessionInfos = (_sessionsField!.GetValue(_sessionPool) as System.Collections.IDictionary)!.Values; + + foreach (var sessionInfo in sessionInfos) + { + _sessionInfoType ??= sessionInfo.GetType(); + _sessionProperty ??= _sessionInfoType.GetProperty("Session")!; + var session = _sessionProperty.GetValue(sessionInfo) as Session; + var socketOperations = _socketOperationsDictionaryField + .GetValue(session) as System.Collections.Concurrent.ConcurrentDictionary; + + foreach (var operation in socketOperations!) + { + var operationsession = _socketOperationManagerField.GetValue(operation.Value) as Session; + var protocol = _socketProtocolField.GetValue(operationsession) as GraphQLWebSocketProtocol; + + var receiver = _protocolReceiverField.GetValue(protocol)!; + + _receiverType ??= receiver.GetType(); + _receiverClientField ??= _receiverType.GetField("_client", _bindingFlags)!; + var client = _receiverClientField.GetValue(receiver) as ISocketClient; + + if (client!.IsClosed is false && client is WebSocketClient webSocketClient) + { + var socket = typeof(WebSocketClient).GetField("_socket", _bindingFlags)!.GetValue(webSocketClient) as ClientWebSocket; + socket!.Abort(); + } + } + } } } }