diff --git a/shell/agents/Microsoft.Azure.Agent/AzureCopilotReceiver.cs b/shell/agents/Microsoft.Azure.Agent/AzureCopilotReceiver.cs index 3ad26873..b2f958fa 100644 --- a/shell/agents/Microsoft.Azure.Agent/AzureCopilotReceiver.cs +++ b/shell/agents/Microsoft.Azure.Agent/AzureCopilotReceiver.cs @@ -1,5 +1,6 @@ using System.Collections.Concurrent; using System.Net.WebSockets; +using System.Runtime.ExceptionServices; using System.Text.Json; namespace Microsoft.Azure.Agent; @@ -26,7 +27,6 @@ private AzureCopilotReceiver(ClientWebSocket webSocket) } internal int Watermark { get; private set; } - internal BlockingCollection ActivityQueue => _activityQueue; internal static async Task CreateAsync(string streamUrl) { @@ -52,6 +52,7 @@ private async Task ProcessActivities() if (result.MessageType is WebSocketMessageType.Close) { closingMessage = "Close message received"; + _activityQueue.Add(new CopilotActivity { Error = new ConnectionDroppedException("The server websocket is closing. Connection dropped.") }); } } catch (OperationCanceledException) @@ -65,6 +66,7 @@ private async Task ProcessActivities() { // TODO: log the closing request. await _webSocket.CloseAsync(WebSocketCloseStatus.NormalClosure, closingMessage, CancellationToken.None); + _activityQueue.CompleteAdding(); break; } @@ -98,8 +100,20 @@ private async Task ProcessActivities() } } - // TODO: log the current state of the web socket - // TODO: handle error state, such as 'aborted' + // TODO: log the current state of the web socket. + _activityQueue.Add(new CopilotActivity { Error = new ConnectionDroppedException($"The websocket got in '{_webSocket.State}' state. Connection dropped.") }); + _activityQueue.CompleteAdding(); + } + + internal CopilotActivity Take(CancellationToken cancellationToken) + { + CopilotActivity activity = _activityQueue.Take(cancellationToken); + if (activity.Error is not null) + { + ExceptionDispatchInfo.Capture(activity.Error).Throw(); + } + + return activity; } public void Dispose() diff --git a/shell/agents/Microsoft.Azure.Agent/ChatSession.cs b/shell/agents/Microsoft.Azure.Agent/ChatSession.cs index 13d9ef77..6e45637a 100644 --- a/shell/agents/Microsoft.Azure.Agent/ChatSession.cs +++ b/shell/agents/Microsoft.Azure.Agent/ChatSession.cs @@ -142,7 +142,7 @@ private async Task StartConversationAsync(IHost host, CancellationToken cancella while (true) { - CopilotActivity activity = _copilotReceiver.ActivityQueue.Take(cancellationToken); + CopilotActivity activity = _copilotReceiver.Take(cancellationToken); if (activity.IsMessage && activity.IsFromCopilot && _copilotReceiver.Watermark is 0) { activity.ExtractMetadata(out _, out ConversationState conversationState); @@ -259,7 +259,7 @@ internal async Task GetChatResponseAsync(string input, IStatusC while (true) { - CopilotActivity activity = _copilotReceiver.ActivityQueue.Take(cancellationToken); + CopilotActivity activity = _copilotReceiver.Take(cancellationToken); if (activity.ReplyToId != activityId) { diff --git a/shell/agents/Microsoft.Azure.Agent/Schema.cs b/shell/agents/Microsoft.Azure.Agent/Schema.cs index 248f7ae0..acc3f9a0 100644 --- a/shell/agents/Microsoft.Azure.Agent/Schema.cs +++ b/shell/agents/Microsoft.Azure.Agent/Schema.cs @@ -159,7 +159,7 @@ internal CopilotActivity ReadChunk(CancellationToken cancellationToken) return null; } - CopilotActivity activity = _receiver.ActivityQueue.Take(cancellationToken); + CopilotActivity activity = _receiver.Take(cancellationToken); if (!activity.IsMessageUpdate) {