diff --git a/src/StrawberryShake/Client/src/Transport.WebSockets/ISocketClient.cs b/src/StrawberryShake/Client/src/Transport.WebSockets/ISocketClient.cs
index 09e43878a37..db419b07629 100644
--- a/src/StrawberryShake/Client/src/Transport.WebSockets/ISocketClient.cs
+++ b/src/StrawberryShake/Client/src/Transport.WebSockets/ISocketClient.cs
@@ -11,6 +11,11 @@ namespace StrawberryShake.Transport.WebSockets;
///
public interface ISocketClient : IAsyncDisposable
{
+ ///
+ /// An event that is called when the message receiving cycle stoped
+ ///
+ event EventHandler ReceiveFinished;
+
///
/// The URI where the socket should connect to
///
@@ -70,4 +75,4 @@ public interface ISocketClient : IAsyncDisposable
string message,
SocketCloseStatus closeStatus,
CancellationToken cancellationToken = default);
-}
\ No newline at end of file
+}
diff --git a/src/StrawberryShake/Client/src/Transport.WebSockets/ISocketOperation.cs b/src/StrawberryShake/Client/src/Transport.WebSockets/ISocketOperation.cs
index 1ac7ef97534..efa69dbeb06 100644
--- a/src/StrawberryShake/Client/src/Transport.WebSockets/ISocketOperation.cs
+++ b/src/StrawberryShake/Client/src/Transport.WebSockets/ISocketOperation.cs
@@ -1,5 +1,7 @@
using System;
using System.Collections.Generic;
+using System.Threading;
+using System.Threading.Tasks;
using StrawberryShake.Transport.WebSockets.Messages;
namespace StrawberryShake.Transport.WebSockets;
@@ -18,4 +20,13 @@ public interface ISocketOperation : IAsyncDisposable
/// CReate an operation message stream.
///
IAsyncEnumerable ReadAsync();
+
+ ///
+ /// Complete the operation
+ ///
+ ///
+ /// A to cancel the completion
+ ///
+ /// A task that is completed once the operation is completed
+ ValueTask CompleteAsync(CancellationToken cancellationToken);
}
diff --git a/src/StrawberryShake/Client/src/Transport.WebSockets/ISocketProtocol.cs b/src/StrawberryShake/Client/src/Transport.WebSockets/ISocketProtocol.cs
index a3d82d092fa..68e702a1d42 100644
--- a/src/StrawberryShake/Client/src/Transport.WebSockets/ISocketProtocol.cs
+++ b/src/StrawberryShake/Client/src/Transport.WebSockets/ISocketProtocol.cs
@@ -88,4 +88,16 @@ public interface ISocketProtocol : IAsyncDisposable
///
///
void Unsubscribe(OnReceiveAsync listener);
-}
\ No newline at end of file
+
+ ///
+ /// Notify the protocol to complete
+ ///
+ /// The id of the operation to stop
+ ///
+ /// A to cancel the notification
+ ///
+ /// A task that is completed once the notification is completed
+ ValueTask NotifyCompletion(
+ string operationId,
+ CancellationToken cancellationToken);
+}
diff --git a/src/StrawberryShake/Client/src/Transport.WebSockets/Session.cs b/src/StrawberryShake/Client/src/Transport.WebSockets/Session.cs
index c974665ec43..7e49e3b58d7 100644
--- a/src/StrawberryShake/Client/src/Transport.WebSockets/Session.cs
+++ b/src/StrawberryShake/Client/src/Transport.WebSockets/Session.cs
@@ -24,6 +24,8 @@ public Session(ISocketClient socketClient)
{
_socketClient = socketClient ??
throw new ArgumentNullException(nameof(socketClient));
+
+ _socketClient.ReceiveFinished += ReceiveFinishHandler;
}
///
@@ -93,6 +95,21 @@ await socketProtocol
}
}
+ ///
+ private async ValueTask CompleteOperation(CancellationToken cancellationToken)
+ {
+ foreach (var operation in _operations)
+ {
+ await operation.Value.CompleteAsync(cancellationToken);
+ }
+ }
+
+ ///
+ private void ReceiveFinishHandler(object? sender, EventArgs args)
+ {
+ _ = CompleteOperation(default);
+ }
+
///
/// Opens a session over the socket
///
@@ -166,6 +183,7 @@ public async ValueTask DisposeAsync()
_operations.Clear();
}
+ _socketClient.ReceiveFinished -= ReceiveFinishHandler;
_socketProtocol?.Unsubscribe(ReceiveMessage);
await _socketClient.DisposeAsync();
}
diff --git a/src/StrawberryShake/Client/src/Transport.WebSockets/SocketOperation.cs b/src/StrawberryShake/Client/src/Transport.WebSockets/SocketOperation.cs
index a966b78be73..c78e39ab0e3 100644
--- a/src/StrawberryShake/Client/src/Transport.WebSockets/SocketOperation.cs
+++ b/src/StrawberryShake/Client/src/Transport.WebSockets/SocketOperation.cs
@@ -54,6 +54,22 @@ public SocketOperation(ISession manager)
public IAsyncEnumerable ReadAsync()
=> new MessageStream(this, _channel);
+ ///
+ public async ValueTask CompleteAsync(CancellationToken cancellationToken)
+ {
+ if (!_disposed)
+ {
+ try
+ {
+ await _channel.Writer.WriteAsync(CompleteOperationMessage.Default, cancellationToken).ConfigureAwait(false);
+ }
+ catch (ChannelClosedException)
+ {
+ // if the channel is closed we will move on.
+ }
+ }
+ }
+
private sealed class MessageStream : IAsyncEnumerable
{
private readonly SocketOperation _operation;
@@ -113,4 +129,4 @@ public async ValueTask DisposeAsync()
_disposed = true;
}
}
-}
\ No newline at end of file
+}
diff --git a/src/StrawberryShake/Client/src/Transport.WebSockets/SocketProtocolBase.cs b/src/StrawberryShake/Client/src/Transport.WebSockets/SocketProtocolBase.cs
index 92409153874..25e7c4dfe4c 100644
--- a/src/StrawberryShake/Client/src/Transport.WebSockets/SocketProtocolBase.cs
+++ b/src/StrawberryShake/Client/src/Transport.WebSockets/SocketProtocolBase.cs
@@ -65,6 +65,14 @@ public void Unsubscribe(OnReceiveAsync listener)
}
}
+ ///
+ public async ValueTask NotifyCompletion(
+ string operationId,
+ CancellationToken cancellationToken)
+ {
+ await Notify(operationId, CompleteOperationMessage.Default, cancellationToken).ConfigureAwait(false);
+ }
+
///
public virtual ValueTask DisposeAsync()
{
@@ -78,4 +86,4 @@ public virtual ValueTask DisposeAsync()
_disposed = true;
return default;
}
-}
\ No newline at end of file
+}
diff --git a/src/StrawberryShake/Client/src/Transport.WebSockets/WebSocketClient.cs b/src/StrawberryShake/Client/src/Transport.WebSockets/WebSocketClient.cs
index 7f43c704003..977580d0072 100644
--- a/src/StrawberryShake/Client/src/Transport.WebSockets/WebSocketClient.cs
+++ b/src/StrawberryShake/Client/src/Transport.WebSockets/WebSocketClient.cs
@@ -19,8 +19,12 @@ public sealed class WebSocketClient : IWebSocketClient
private readonly IReadOnlyList _protocolFactories;
private readonly ClientWebSocket _socket;
private ISocketProtocol? _activeProtocol;
+ private bool _receiveFinishEventTriggered = false;
private bool _disposed;
+ ///
+ public event EventHandler ReceiveFinished = default!;
+
///
/// Creates a new instance of
///
@@ -52,10 +56,22 @@ public sealed class WebSocketClient : IWebSocketClient
public string Name { get; }
///
- public bool IsClosed =>
- _disposed
- || _socket.CloseStatus.HasValue
- || _socket.State == WebSocketState.Aborted;
+ public bool IsClosed
+ {
+ get
+ {
+ var closed = _disposed
+ || _socket.CloseStatus.HasValue
+ || _socket.State == WebSocketState.Aborted;
+
+ if (closed && !_receiveFinishEventTriggered)
+ {
+ _receiveFinishEventTriggered = true;
+ ReceiveFinished?.Invoke(this, EventArgs.Empty);
+ }
+ return closed;
+ }
+ }
///
public WebSocket Socket => _socket;
diff --git a/src/StrawberryShake/Client/test/Transport.WebSocket.Tests/TestHelper/SocketClientStub.cs b/src/StrawberryShake/Client/test/Transport.WebSocket.Tests/TestHelper/SocketClientStub.cs
index 0eb8b56fe2d..e479058a913 100644
--- a/src/StrawberryShake/Client/test/Transport.WebSocket.Tests/TestHelper/SocketClientStub.cs
+++ b/src/StrawberryShake/Client/test/Transport.WebSocket.Tests/TestHelper/SocketClientStub.cs
@@ -17,6 +17,8 @@ public sealed class SocketClientStub : ISocketClient
new(TaskCreationOptions.None);
private bool _isClosed = true;
+ public event EventHandler ReceiveFinished = default!;
+
public SemaphoreSlim Blocker { get; } = new(0);
public Uri? Uri { get; set; }
diff --git a/src/StrawberryShake/CodeGeneration/test/CodeGeneration.CSharp.Tests/Integration/StarWarsOnReviewSubCompletionTest.cs b/src/StrawberryShake/CodeGeneration/test/CodeGeneration.CSharp.Tests/Integration/StarWarsOnReviewSubCompletionTest.cs
index a57596b140c..6c046d172a5 100644
--- a/src/StrawberryShake/CodeGeneration/test/CodeGeneration.CSharp.Tests/Integration/StarWarsOnReviewSubCompletionTest.cs
+++ b/src/StrawberryShake/CodeGeneration/test/CodeGeneration.CSharp.Tests/Integration/StarWarsOnReviewSubCompletionTest.cs
@@ -1,4 +1,6 @@
using System;
+using System.Net.WebSockets;
+using System.Reflection;
using System.Threading.Tasks;
using HotChocolate.AspNetCore.Tests.Utilities;
using HotChocolate.StarWars.Models;
@@ -6,6 +8,7 @@
using Microsoft.AspNetCore.Hosting;
using Microsoft.Extensions.DependencyInjection;
using StrawberryShake.Transport.WebSockets;
+using StrawberryShake.Transport.WebSockets.Protocols;
using Xunit;
namespace StrawberryShake.CodeGeneration.CSharp.Integration.StarWarsOnReviewSubCompletion
@@ -63,11 +66,128 @@ public async Task Watch_StarWarsOnReviewSubCompletion_Test()
{
await Task.Delay(1_000);
}
+
+ // assert
+ Assert.True(commentary is not null && completionTriggered);
session.Dispose();
+ }
+
+ [Fact]
+ public async Task Watch_StarWarsOnReviewSubCompletionPassively_Test()
+ {
+ // arrange
+ using IWebHost host = TestServerHelper.CreateServer(
+ _ => { },
+ out var port);
+ var topicEventSender = host.Services.GetRequiredService();
+
+ var serviceCollection = new ServiceCollection();
+ serviceCollection.AddStarWarsOnReviewSubCompletionClient(
+ profile: StarWarsOnReviewSubCompletionClientProfileKind.Default)
+ .ConfigureHttpClient(
+ c => c.BaseAddress = new Uri("http://localhost:" + port + "/graphql"))
+ .ConfigureWebSocketClient(
+ c => c.Uri = new Uri("ws://localhost:" + port + "/graphql"));
+
+ serviceCollection.AddSingleton();
+
+ // act
+ IServiceProvider services = serviceCollection.BuildServiceProvider();
+ IStarWarsOnReviewSubCompletionClient client = services.GetRequiredService();
+
+ string? commentary = null;
+ bool completionTriggered = false;
+
+ var sub = client.OnReviewSub.Watch();
+ var session = sub.Subscribe(
+ result => commentary = result.Data?.OnReview?.Commentary,
+ () => completionTriggered = true);
+
+ var topic = Episode.NewHope;
+
+ // try to send message 10 times
+ // make sure the subscription connection is successful
+ for (int times = 0; commentary is null && times < 10; times++)
+ {
+ await topicEventSender.SendAsync(topic, new Review { Stars = 1, Commentary = "Commentary" });
+ await Task.Delay(1_000);
+ }
+
+ // simulate network error
+ var monitor = services.GetRequiredService();
+ monitor.AbortSocket();
+
+ //await host.StopAsync();
+
+ // waiting for completion message sent
+ for (int times = 0; !completionTriggered && times < 10; times++)
+ {
+ await Task.Delay(1_000);
+ }
// assert
Assert.True(commentary is not null && completionTriggered);
+
+ session.Dispose();
+ }
+ }
+
+ public class SubscriptionSocketStateMonitor
+ {
+ private const BindingFlags _bindingFlags = BindingFlags.NonPublic | BindingFlags.Instance;
+
+ private readonly ISessionPool _sessionPool;
+ private readonly Type _sessionPoolType;
+ private readonly FieldInfo _sessionsField;
+
+ private readonly FieldInfo _socketOperationsDictionaryField = typeof(Session).GetField("_operations", _bindingFlags)!;
+ private readonly FieldInfo _socketOperationManagerField = typeof(SocketOperation).GetField("_manager", _bindingFlags)!;
+ private readonly FieldInfo _socketProtocolField = typeof(Session)!.GetField("_socketProtocol", _bindingFlags)!;
+ private readonly FieldInfo _protocolReceiverField = typeof(GraphQLWebSocketProtocol).GetField("_receiver", _bindingFlags)!;
+
+ private Type? _sessionInfoType;
+ private PropertyInfo? _sessionProperty;
+ private Type? _receiverType;
+ private FieldInfo? _receiverClientField;
+
+ public SubscriptionSocketStateMonitor(ISessionPool sessionPool)
+ {
+ _sessionPool = sessionPool;
+ _sessionPoolType = _sessionPool.GetType();
+ _sessionsField = _sessionPoolType.GetField("_sessions", _bindingFlags)!;
+ }
+
+ public void AbortSocket()
+ {
+ var sessionInfos = (_sessionsField!.GetValue(_sessionPool) as System.Collections.IDictionary)!.Values;
+
+ foreach (var sessionInfo in sessionInfos)
+ {
+ _sessionInfoType ??= sessionInfo.GetType();
+ _sessionProperty ??= _sessionInfoType.GetProperty("Session")!;
+ var session = _sessionProperty.GetValue(sessionInfo) as Session;
+ var socketOperations = _socketOperationsDictionaryField
+ .GetValue(session) as System.Collections.Concurrent.ConcurrentDictionary;
+
+ foreach (var operation in socketOperations!)
+ {
+ var operationsession = _socketOperationManagerField.GetValue(operation.Value) as Session;
+ var protocol = _socketProtocolField.GetValue(operationsession) as GraphQLWebSocketProtocol;
+
+ var receiver = _protocolReceiverField.GetValue(protocol)!;
+
+ _receiverType ??= receiver.GetType();
+ _receiverClientField ??= _receiverType.GetField("_client", _bindingFlags)!;
+ var client = _receiverClientField.GetValue(receiver) as ISocketClient;
+
+ if (client!.IsClosed is false && client is WebSocketClient webSocketClient)
+ {
+ var socket = typeof(WebSocketClient).GetField("_socket", _bindingFlags)!.GetValue(webSocketClient) as ClientWebSocket;
+ socket!.Abort();
+ }
+ }
+ }
}
}
}