diff --git a/desktop/Backend-Rust/src/routes/proxy.rs b/desktop/Backend-Rust/src/routes/proxy.rs index 04a66d5cb15..1620db0c9ec 100644 --- a/desktop/Backend-Rust/src/routes/proxy.rs +++ b/desktop/Backend-Rust/src/routes/proxy.rs @@ -185,7 +185,17 @@ async fn deepgram_ws_proxy( })) } +/// Which side of the proxy terminated first +#[derive(Debug)] +enum ProxyCloseOrigin { + ClientClosed, + UpstreamClosed, + ClientError, + UpstreamError, +} + /// Bidirectional WebSocket proxy between client (axum) and upstream (tokio-tungstenite). +/// When one side closes or errors, a close frame is forwarded to the other side before teardown. async fn proxy_ws_bidirectional( client_socket: axum::extract::ws::WebSocket, upstream_url: &str, @@ -210,47 +220,71 @@ async fn proxy_ws_bidirectional( // Client → Upstream let client_to_upstream = async { - while let Some(Ok(msg)) = client_stream.next().await { - let tung_msg = match msg { - AxumMsg::Text(t) => TungMsg::Text(t), - AxumMsg::Binary(b) => TungMsg::Binary(b), - AxumMsg::Ping(p) => TungMsg::Ping(p), - AxumMsg::Pong(p) => TungMsg::Pong(p), - AxumMsg::Close(_) => { - let _ = upstream_sink.close().await; - return; + while let Some(result) = client_stream.next().await { + match result { + Ok(msg) => { + let tung_msg = match msg { + AxumMsg::Text(t) => TungMsg::Text(t), + AxumMsg::Binary(b) => TungMsg::Binary(b), + AxumMsg::Ping(p) => TungMsg::Ping(p), + AxumMsg::Pong(p) => TungMsg::Pong(p), + AxumMsg::Close(_) => { + let _ = upstream_sink.close().await; + return ProxyCloseOrigin::ClientClosed; + } + }; + if upstream_sink.send(tung_msg).await.is_err() { + return ProxyCloseOrigin::UpstreamError; + } } - }; - if upstream_sink.send(tung_msg).await.is_err() { - return; + Err(_) => return ProxyCloseOrigin::ClientError, } } + ProxyCloseOrigin::ClientClosed }; // Upstream → Client let upstream_to_client = async { - while let Some(Ok(msg)) = upstream_stream.next().await { - let axum_msg = match msg { - TungMsg::Text(t) => AxumMsg::Text(t), - 
TungMsg::Binary(b) => AxumMsg::Binary(b), - TungMsg::Ping(p) => AxumMsg::Ping(p), - TungMsg::Pong(p) => AxumMsg::Pong(p), - TungMsg::Close(_) => { - let _ = client_sink.close().await; - return; + while let Some(result) = upstream_stream.next().await { + match result { + Ok(msg) => { + let axum_msg = match msg { + TungMsg::Text(t) => AxumMsg::Text(t), + TungMsg::Binary(b) => AxumMsg::Binary(b), + TungMsg::Ping(p) => AxumMsg::Ping(p), + TungMsg::Pong(p) => AxumMsg::Pong(p), + TungMsg::Close(_) => { + let _ = client_sink.close().await; + return ProxyCloseOrigin::UpstreamClosed; + } + TungMsg::Frame(_) => continue, + }; + if client_sink.send(axum_msg).await.is_err() { + return ProxyCloseOrigin::ClientError; + } } - TungMsg::Frame(_) => continue, - }; - if client_sink.send(axum_msg).await.is_err() { - return; + Err(_) => return ProxyCloseOrigin::UpstreamError, } } + ProxyCloseOrigin::UpstreamClosed + }; + + // Run both directions concurrently; when either ends, gracefully close the other side + let origin = tokio::select! { + origin = client_to_upstream => origin, + origin = upstream_to_client => origin, }; - // Run both directions concurrently; when either ends, drop both - tokio::select! 
{ - _ = client_to_upstream => {}, - _ = upstream_to_client => {}, + // Forward close frame to the surviving side with a timeout to prevent hanging + let close_timeout = std::time::Duration::from_secs(5); + tracing::debug!("deepgram_ws_proxy: proxy ended ({:?})", origin); + match origin { + ProxyCloseOrigin::UpstreamClosed | ProxyCloseOrigin::UpstreamError => { + let _ = tokio::time::timeout(close_timeout, client_sink.close()).await; + } + ProxyCloseOrigin::ClientClosed | ProxyCloseOrigin::ClientError => { + let _ = tokio::time::timeout(close_timeout, upstream_sink.close()).await; + } } Ok(()) diff --git a/desktop/CHANGELOG.json b/desktop/CHANGELOG.json index e26c3ce5ac2..f57821852ba 100644 --- a/desktop/CHANGELOG.json +++ b/desktop/CHANGELOG.json @@ -1,5 +1,7 @@ { - "unreleased": [], + "unreleased": [ + "Fixed WebSocket transcription disconnects: proper handshake detection, audio buffering during reconnection, unlimited retry with backoff, and thread-safe connection state" + ], "releases": [ { "version": "0.11.186", diff --git a/desktop/Desktop/Sources/TranscriptionService.swift b/desktop/Desktop/Sources/TranscriptionService.swift index a00779c4b94..58300e1a8a9 100644 --- a/desktop/Desktop/Sources/TranscriptionService.swift +++ b/desktop/Desktop/Sources/TranscriptionService.swift @@ -2,10 +2,18 @@ import Foundation /// Service for real-time speech-to-text transcription using DeepGram /// Streams audio over WebSocket and receives transcript segments -class TranscriptionService { +class TranscriptionService: NSObject, URLSessionWebSocketDelegate { // MARK: - Types + /// Connection lifecycle state (thread-safe via stateQueue) + enum ConnectionState { + case disconnected + case connecting + case connected + case reconnecting + } + /// Transcript segment from DeepGram struct TranscriptSegment { let text: String @@ -50,13 +58,27 @@ class TranscriptionService { } } + // MARK: - Thread-safe state + + /// Serial queue protecting all mutable connection state + private let 
stateQueue = DispatchQueue(label: "com.omi.transcription.state") + private var _connectionState: ConnectionState = .disconnected + private var _webSocketTask: URLSessionWebSocketTask? + private var _urlSession: URLSession? + private var _shouldReconnect = false + private var _reconnectAttempts = 0 + private var _connectionGeneration: UInt64 = 0 // Monotonic ID to discard stale delegate callbacks + private var _lastDataReceivedAt: Date? + private var _lastKeepaliveSuccessAt: Date? + + /// Execute a block on the state queue and return its result + private func withState(_ body: () -> T) -> T { + stateQueue.sync { body() } + } + // MARK: - Properties private let apiKey: String - private var webSocketTask: URLSessionWebSocketTask? - private var urlSession: URLSession? - private var isConnected = false - private var shouldReconnect = false // Callbacks private var onTranscript: TranscriptHandler? @@ -90,10 +112,10 @@ class TranscriptionService { }() private let channels: Int // 2 = stereo (mic + system), 1 = mono (mic only for PTT) - // Reconnection - private var reconnectAttempts = 0 - private let maxReconnectAttempts = 10 + // Reconnection — no hard cap; backoff with jitter, retry while shouldReconnect is true private var reconnectTask: Task? + private let maxBackoff: TimeInterval = 60.0 + private let backoffJitterRange: ClosedRange = 0.5...1.5 // Keepalive private var keepaliveTask: Task? @@ -101,16 +123,18 @@ class TranscriptionService { // Watchdog: detect stale connections where WebSocket dies silently private var watchdogTask: Task? - private var lastDataReceivedAt: Date? - private var lastKeepaliveSuccessAt: Date? 
private let watchdogInterval: TimeInterval = 30.0 // Check every 30 seconds private let staleThreshold: TimeInterval = 60.0 // Reconnect if no data for 60 seconds - // Audio buffering + // Audio buffering (outbound send coalescing) private var audioBuffer = Data() private let audioBufferSize = 3200 // ~100ms of 16kHz 16-bit audio (16000 * 2 * 0.1) private let audioBufferLock = NSLock() + // Reconnect audio ring buffer: holds audio produced while disconnected/reconnecting + // 30s of stereo 16kHz 16-bit = ~1.92MB; cap at 960KB (~15s) to stay conservative + private var reconnectBuffer = ReconnectAudioRingBuffer(ttl: 30, maxBytes: 960_000) + // MARK: - Initialization /// Whether this instance uses the backend proxy (no direct Deepgram access) @@ -136,6 +160,7 @@ class TranscriptionService { self.language = language self.vocabulary = vocabulary self.channels = channels + super.init() log("TranscriptionService: Initialized with language=\(language), vocabulary=\(self.vocabulary.count) terms, channels=\(channels), proxy=\(self.useProxy)") } @@ -152,15 +177,17 @@ class TranscriptionService { self.onError = onError self.onConnected = onConnected self.onDisconnected = onDisconnected - self.shouldReconnect = true - self.reconnectAttempts = 0 + withState { + _shouldReconnect = true + _reconnectAttempts = 0 + } connect() } /// Stop the transcription service func stop() { - shouldReconnect = false + withState { _shouldReconnect = false } reconnectTask?.cancel() reconnectTask = nil keepaliveTask?.cancel() @@ -177,7 +204,7 @@ class TranscriptionService { /// Signal Deepgram that no more audio will be sent, but keep connection open /// to receive final transcription results. Call stop() later to fully disconnect. 
func finishStream() { - shouldReconnect = false + withState { _shouldReconnect = false } reconnectTask?.cancel() reconnectTask = nil keepaliveTask?.cancel() @@ -187,10 +214,14 @@ class TranscriptionService { flushAudioBuffer() - guard isConnected, let webSocketTask = webSocketTask else { return } + let task: URLSessionWebSocketTask? = withState { + guard _connectionState == .connected else { return nil } + return _webSocketTask + } + guard let task = task else { return } let closeMsg = "{\"type\": \"CloseStream\"}" - webSocketTask.send(.string(closeMsg)) { error in + task.send(.string(closeMsg)) { error in if let error = error { logError("TranscriptionService: CloseStream send error", error: error) } @@ -198,9 +229,28 @@ class TranscriptionService { log("TranscriptionService: CloseStream sent, waiting for final results") } - /// Send audio data to DeepGram (buffered for efficiency) + /// Send audio data to DeepGram (buffered for efficiency). + /// When disconnected/reconnecting, audio is queued in a ring buffer and replayed on reconnect. func sendAudio(_ data: Data) { - guard isConnected else { return } + guard !data.isEmpty else { return } + + let shouldSendNow: Bool = withState { + reconnectBuffer.prune() + switch _connectionState { + case .connected: + return true + case .connecting, .reconnecting: + reconnectBuffer.append(data) + return false + case .disconnected: + if _shouldReconnect { + reconnectBuffer.append(data) + } + return false + } + } + + guard shouldSendNow else { return } audioBufferLock.lock() audioBuffer.append(data) @@ -230,10 +280,14 @@ class TranscriptionService { /// Actually send an audio chunk to DeepGram private func sendAudioChunk(_ data: Data) { - guard isConnected, let webSocketTask = webSocketTask else { return } + let task: URLSessionWebSocketTask? 
= withState { + guard _connectionState == .connected else { return nil } + return _webSocketTask + } + guard let task = task else { return } let message = URLSessionWebSocketTask.Message.data(data) - webSocketTask.send(message) { [weak self] error in + task.send(message) { [weak self] error in if let error = error { logError("TranscriptionService: Send error", error: error) self?.handleDisconnection() @@ -241,11 +295,51 @@ class TranscriptionService { } } + /// Replay audio buffered during reconnection. + /// Sends chunks sequentially so only the first failure controls the retry path, + /// preventing duplicate rebuffering from concurrent failure callbacks. + private func replayBufferedAudio() { + let (task, chunks): (URLSessionWebSocketTask?, [Data]) = withState { + guard _connectionState == .connected else { return (nil, []) } + return (_webSocketTask, reconnectBuffer.drain()) + } + guard let task = task, !chunks.isEmpty else { return } + + log("TranscriptionService: Replaying \(chunks.count) buffered audio chunks") + replayChunksSequentially(task: task, chunks: chunks, index: 0) + } + + /// Send chunks one at a time; on first failure, re-buffer the rest and reconnect. + private func replayChunksSequentially(task: URLSessionWebSocketTask, chunks: [Data], index: Int) { + guard index < chunks.count else { return } + + task.send(.data(chunks[index])) { [weak self] error in + if let error = error { + logError("TranscriptionService: Replay send error at chunk \(index)", error: error) + if let self = self { + self.withState { + for remaining in chunks[index...] { + self.reconnectBuffer.append(remaining) + } + } + self.handleDisconnection() + } + return + } + // Success — send next chunk + self?.replayChunksSequentially(task: task, chunks: chunks, index: index + 1) + } + } + /// Send Deepgram Finalize message to flush pending transcripts func sendFinalize() { - guard isConnected, let webSocketTask = webSocketTask else { return } + let task: URLSessionWebSocketTask? 
= withState { + guard _connectionState == .connected else { return nil } + return _webSocketTask + } + guard let task = task else { return } let msg = "{\"type\": \"Finalize\"}" - webSocketTask.send(.string(msg)) { error in + task.send(.string(msg)) { error in if let error = error { logError("TranscriptionService: Finalize error", error: error) } @@ -259,12 +353,62 @@ class TranscriptionService { /// Check if connected var connected: Bool { - return isConnected + return withState { _connectionState == .connected } + } + + // MARK: - Test accessors (internal for @testable import) + + /// Current connection state (read-only, for testing) + var testConnectionState: ConnectionState { + withState { _connectionState } + } + + /// Current generation token (read-only, for testing) + var testConnectionGeneration: UInt64 { + withState { _connectionGeneration } + } + + /// Directly set state for testing state machine behavior. + /// Only callable from tests via @testable import. + func testSetState(_ state: ConnectionState) { + withState { _connectionState = state } + } + + /// Expose handleDisconnection for idempotency testing + func testHandleDisconnection() { + handleDisconnection() + } + + /// Expose shouldReconnect setter for testing + func testSetShouldReconnect(_ value: Bool) { + withState { _shouldReconnect = value } + } + + /// Compute reconnect delay: exponential backoff capped at maxBackoff, then jittered. + /// Exposed as static for testability. 
+ static func reconnectDelay( + attempt: Int, + maxBackoff: TimeInterval = 60.0, + jitterRange: ClosedRange = 0.5...1.5 + ) -> TimeInterval { + let baseDelay = min(pow(2.0, Double(attempt)), maxBackoff) + let jitter = Double.random(in: jitterRange) + return baseDelay * jitter } // MARK: - Private Methods private func connect() { + let generation: UInt64 = withState { + guard _connectionState == .disconnected || _connectionState == .reconnecting else { + return 0 // 0 = signal not to proceed + } + _connectionState = .connecting + _connectionGeneration += 1 + return _connectionGeneration + } + guard generation > 0 else { return } + if useProxy { // Proxy mode: get Firebase auth token async, then connect Task { [weak self] in @@ -272,19 +416,28 @@ class TranscriptionService { do { let authService = await MainActor.run { AuthService.shared } let authHeader = try await authService.getAuthHeader() - self.connectWithAuth(authHeader: authHeader) + // Re-check: stop() may have been called while fetching auth token + let stillValid = self.withState { + self._connectionGeneration == generation && self._shouldReconnect && self._connectionState == .connecting + } + guard stillValid else { + log("TranscriptionService: Auth fetched but connection no longer wanted (gen \(generation))") + return + } + self.connectWithAuth(authHeader: authHeader, generation: generation) } catch { logError("TranscriptionService: Failed to get auth token for proxy", error: error) self.onError?(TranscriptionError.connectionFailed(error)) + self.handleDisconnection() } } } else { // Direct Deepgram mode (legacy/developer override) - connectWithAuth(authHeader: "Token \(apiKey)") + connectWithAuth(authHeader: "Token \(apiKey)", generation: generation) } } - private func connectWithAuth(authHeader: String) { + private func connectWithAuth(authHeader: String, generation: UInt64) { // Build WebSocket URL with parameters let wsBase: String if useProxy { @@ -300,6 +453,7 @@ class TranscriptionService { guard 
var components = URLComponents(string: "\(wsBase)\(listenPath)") else { log("TranscriptionService: Invalid URL base: \(wsBase)") onError?(TranscriptionError.connectionFailed(NSError(domain: "Invalid URL", code: -1))) + handleDisconnection() return } var queryItems = [ @@ -328,49 +482,111 @@ class TranscriptionService { guard let url = components.url else { onError?(TranscriptionError.connectionFailed(NSError(domain: "Invalid URL", code: -1))) + handleDisconnection() return } - log("TranscriptionService: Connecting to \(url.absoluteString)") + // Verify this generation is still current before creating network resources + let stillValid = withState { _connectionGeneration == generation && _connectionState == .connecting } + guard stillValid else { + log("TranscriptionService: Connection no longer wanted (gen \(generation))") + return + } + log("TranscriptionService: Connecting to \(url.host ?? "?") (gen \(generation))") // Create URL request with authorization header var request = URLRequest(url: url) request.setValue(authHeader, forHTTPHeaderField: "Authorization") - // Create URLSession and WebSocket task + // Create URLSession with self as delegate to receive WebSocket lifecycle callbacks let configuration = URLSessionConfiguration.default configuration.timeoutIntervalForRequest = 30 configuration.timeoutIntervalForResource = 0 // No resource timeout for long-lived WebSocket - urlSession = URLSession(configuration: configuration) - webSocketTask = urlSession?.webSocketTask(with: request) + let session = URLSession(configuration: configuration, delegate: self, delegateQueue: nil) + let task = session.webSocketTask(with: request) - // Start the connection - webSocketTask?.resume() + withState { + _urlSession = session + _webSocketTask = task + } + + // Start the connection — didOpenWithProtocol delegate will confirm handshake + task.resume() - // Start receiving messages - receiveMessage() + // Start receiving messages immediately (queued until handshake completes) 
+ receiveMessage(generation: generation) - // Mark as connected (DeepGram doesn't send a connect confirmation) - DispatchQueue.main.asyncAfter(deadline: .now() + 0.5) { [weak self] in - guard let self = self, self.webSocketTask?.state == .running else { return } - self.isConnected = true - self.reconnectAttempts = 0 - self.lastDataReceivedAt = Date() - self.lastKeepaliveSuccessAt = Date() - log("TranscriptionService: Connected") - self.startKeepalive() - self.startWatchdog() - self.onConnected?() + // Connect timeout: if handshake hasn't completed in 10s, treat as failure + Task { [weak self] in + try? await Task.sleep(nanoseconds: 10_000_000_000) + guard let self = self else { return } + let shouldTimeout: Bool = self.withState { + self._connectionGeneration == generation && self._connectionState == .connecting + } + if shouldTimeout { + log("TranscriptionService: Connect timeout (gen \(generation))") + self.handleDisconnection() + } } } + // MARK: - URLSessionWebSocketDelegate + + /// Called when WebSocket handshake completes successfully + func urlSession( + _ session: URLSession, + webSocketTask: URLSessionWebSocketTask, + didOpenWithProtocol protocol: String? + ) { + let (isValid, generation): (Bool, UInt64) = withState { + // Only accept if this is the current session AND we still want to be connected + guard _urlSession === session && _shouldReconnect && _connectionState == .connecting else { + return (false, _connectionGeneration) + } + _connectionState = .connected + _reconnectAttempts = 0 + _lastDataReceivedAt = Date() + _lastKeepaliveSuccessAt = Date() + return (true, _connectionGeneration) + } + guard isValid else { + log("TranscriptionService: Ignoring stale didOpen (gen \(generation))") + // Clean up the unwanted session + session.invalidateAndCancel() + return + } + + log("TranscriptionService: Connected (gen \(generation), protocol=\(`protocol` ?? 
"none"))") + startKeepalive() + startWatchdog() + replayBufferedAudio() + onConnected?() + } + + /// Called when WebSocket receives a close frame from server + func urlSession( + _ session: URLSession, + webSocketTask: URLSessionWebSocketTask, + didCloseWith closeCode: URLSessionWebSocketTask.CloseCode, + reason: Data? + ) { + let isCurrentSession: Bool = withState { _urlSession === session } + guard isCurrentSession else { return } + + let reasonText = reason.flatMap { String(data: $0, encoding: .utf8) } ?? "none" + log("TranscriptionService: Server closed connection (code=\(closeCode.rawValue), reason=\(reasonText))") + handleDisconnection() + } + /// Start keepalive ping task to prevent connection timeout private func startKeepalive() { keepaliveTask?.cancel() keepaliveTask = Task { [weak self] in while !Task.isCancelled { try? await Task.sleep(nanoseconds: UInt64(self?.keepaliveInterval ?? 8.0) * 1_000_000_000) - guard !Task.isCancelled, let self = self, self.isConnected else { break } + guard !Task.isCancelled, let self = self else { break } + let isConn = self.withState { self._connectionState == .connected } + guard isConn else { break } self.sendKeepalive() } } @@ -378,17 +594,20 @@ class TranscriptionService { /// Send a keepalive ping to DeepGram private func sendKeepalive() { - guard isConnected, let webSocketTask = webSocketTask else { return } + let task: URLSessionWebSocketTask? 
= withState { + guard _connectionState == .connected else { return nil } + return _webSocketTask + } + guard let task = task else { return } - // Send a small JSON keepalive message let keepalive = "{\"type\": \"KeepAlive\"}" let message = URLSessionWebSocketTask.Message.string(keepalive) - webSocketTask.send(message) { [weak self] error in + task.send(message) { [weak self] error in if let error = error { logError("TranscriptionService: Keepalive error", error: error) self?.handleDisconnection() } else { - self?.lastKeepaliveSuccessAt = Date() + self?.withState { self?._lastKeepaliveSuccessAt = Date() } } } } @@ -399,16 +618,19 @@ class TranscriptionService { watchdogTask = Task { [weak self] in while !Task.isCancelled { try? await Task.sleep(nanoseconds: UInt64(self?.watchdogInterval ?? 30.0) * 1_000_000_000) - guard !Task.isCancelled, let self = self, self.isConnected else { break } + guard !Task.isCancelled, let self = self else { break } + + let (isConn, lastData, lastKeepalive) = self.withState { + (self._connectionState == .connected, self._lastDataReceivedAt, self._lastKeepaliveSuccessAt) + } + guard isConn else { break } - if let lastData = self.lastDataReceivedAt, + if let lastData = lastData, Date().timeIntervalSince(lastData) > self.staleThreshold { // Check if keepalives are still succeeding — if so, the connection // is alive and Deepgram just has nothing to return (silent room). - // Only force reconnect when keepalives have also gone stale. 
- if let lastKeepalive = self.lastKeepaliveSuccessAt, + if let lastKeepalive = lastKeepalive, Date().timeIntervalSince(lastKeepalive) < self.staleThreshold { - // Keepalives working — connection is alive, just no speech to transcribe continue } log("TranscriptionService: Watchdog detected stale connection (no data for \(String(format: "%.0f", Date().timeIntervalSince(lastData)))s, keepalives also failing) - forcing reconnect") @@ -419,61 +641,93 @@ class TranscriptionService { } private func disconnect() { - isConnected = false + let oldSession: URLSession? = withState { + _connectionState = .disconnected + _connectionGeneration += 1 // Invalidate any in-flight receive callbacks + let s = _urlSession + _webSocketTask?.cancel(with: .normalClosure, reason: nil) + _webSocketTask = nil + _urlSession = nil + return s + } keepaliveTask?.cancel() keepaliveTask = nil watchdogTask?.cancel() watchdogTask = nil - webSocketTask?.cancel(with: .normalClosure, reason: nil) - webSocketTask = nil - urlSession?.invalidateAndCancel() - urlSession = nil + oldSession?.invalidateAndCancel() log("TranscriptionService: Disconnected") onDisconnected?() } private func handleDisconnection() { - guard isConnected else { return } + let (shouldAttemptReconnect, attempt): (Bool, Int) = withState { + // Idempotent: if already reconnecting or disconnected, this is a duplicate callback + guard _connectionState == .connected || _connectionState == .connecting else { return (false, 0) } + + _connectionGeneration += 1 // Invalidate any in-flight receive/keepalive callbacks + let oldSession = _urlSession + _connectionState = .reconnecting + _webSocketTask = nil + _urlSession = nil + oldSession?.invalidateAndCancel() + + guard _shouldReconnect else { + _connectionState = .disconnected + return (false, 0) + } + _reconnectAttempts += 1 + return (true, _reconnectAttempts) + } - isConnected = false keepaliveTask?.cancel() keepaliveTask = nil watchdogTask?.cancel() watchdogTask = nil - webSocketTask = nil - 
urlSession?.invalidateAndCancel() - urlSession = nil + + // Salvage any partial audio in the coalescing buffer into the reconnect buffer + audioBufferLock.lock() + let partialAudio = audioBuffer + audioBuffer = Data() + audioBufferLock.unlock() + if !partialAudio.isEmpty { + withState { reconnectBuffer.append(partialAudio) } + } + onDisconnected?() - // Attempt reconnection if enabled - if shouldReconnect && reconnectAttempts < maxReconnectAttempts { - reconnectAttempts += 1 - let delay = min(pow(2.0, Double(reconnectAttempts)), 32.0) // Exponential backoff, max 32s - log("TranscriptionService: Reconnecting in \(delay)s (attempt \(reconnectAttempts))") + guard shouldAttemptReconnect else { return } - reconnectTask = Task { - try? await Task.sleep(nanoseconds: UInt64(delay * 1_000_000_000)) - guard !Task.isCancelled, self.shouldReconnect else { return } - self.connect() - } - } else if reconnectAttempts >= maxReconnectAttempts { - log("TranscriptionService: Max reconnect attempts reached") - onError?(TranscriptionError.webSocketError("Max reconnect attempts reached")) + // Exponential backoff with jitter, no hard cap on attempts + let delay = Self.reconnectDelay(attempt: attempt, maxBackoff: maxBackoff, jitterRange: backoffJitterRange) + log("TranscriptionService: Reconnecting in \(String(format: "%.1f", delay))s (attempt \(attempt))") + + reconnectTask = Task { [weak self] in + try? await Task.sleep(nanoseconds: UInt64(delay * 1_000_000_000)) + guard !Task.isCancelled else { return } + guard let self = self else { return } + let shouldReconnect = self.withState { self._shouldReconnect } + guard shouldReconnect else { return } + self.connect() } } - private func receiveMessage() { - webSocketTask?.receive { [weak self] result in + private func receiveMessage(generation: UInt64) { + let task: URLSessionWebSocketTask? 
= withState { _webSocketTask } + task?.receive { [weak self] result in guard let self = self else { return } + // Discard callbacks from stale connections + let currentGen = self.withState { self._connectionGeneration } + guard currentGen == generation else { return } + switch result { case .success(let message): self.handleMessage(message) - // Continue receiving - self.receiveMessage() + self.receiveMessage(generation: generation) case .failure(let error): - guard self.isConnected else { return } + let isActive = self.withState { self._connectionState == .connected || self._connectionState == .connecting } + guard isActive else { return } logError("TranscriptionService: Receive error", error: error) self.handleDisconnection() } @@ -482,7 +736,7 @@ class TranscriptionService { private func handleMessage(_ message: URLSessionWebSocketTask.Message) { // Track that we received data (for watchdog stale detection) - lastDataReceivedAt = Date() + withState { _lastDataReceivedAt = Date() } switch message { case .string(let text): @@ -737,6 +991,69 @@ extension TranscriptionService { } } +// MARK: - Reconnect Audio Ring Buffer + +/// Bounded ring buffer that holds audio chunks produced while WebSocket is reconnecting. +/// Chunks older than `ttl` or exceeding `maxBytes` are evicted automatically. +/// Internal access level for testability via @testable import. 
+struct ReconnectAudioRingBuffer { + struct Chunk { + let data: Data + let createdAt: Date + } + + let ttl: TimeInterval + let maxBytes: Int + private(set) var chunks: [Chunk] = [] + private(set) var totalBytes = 0 + + init(ttl: TimeInterval = 30, maxBytes: Int = 960_000) { + self.ttl = ttl + self.maxBytes = maxBytes + } + + mutating func append(_ data: Data, now: Date = Date()) { + guard !data.isEmpty else { return } + evictExpired(now: now) + + if data.count >= maxBytes { + let truncated = Data(data.suffix(maxBytes)) + chunks = [Chunk(data: truncated, createdAt: now)] + totalBytes = truncated.count + return + } + + chunks.append(Chunk(data: data, createdAt: now)) + totalBytes += data.count + evictOverflow() + } + + mutating func drain(now: Date = Date()) -> [Data] { + evictExpired(now: now) + let drained = chunks.map(\.data) + chunks.removeAll(keepingCapacity: true) + totalBytes = 0 + return drained + } + + mutating func prune(now: Date = Date()) { + evictExpired(now: now) + } + + private mutating func evictExpired(now: Date) { + while let first = chunks.first, now.timeIntervalSince(first.createdAt) > ttl { + totalBytes -= first.data.count + chunks.removeFirst() + } + } + + private mutating func evictOverflow() { + while totalBytes > maxBytes, !chunks.isEmpty { + totalBytes -= chunks.removeFirst().data.count + } + } +} + /// Response model for Deepgram pre-recorded API private struct BatchResponse: Decodable { let results: BatchResults? 
diff --git a/desktop/Desktop/Tests/OnboardingFlowTests.swift b/desktop/Desktop/Tests/OnboardingFlowTests.swift index 818369b40ef..7678281fd20 100644 --- a/desktop/Desktop/Tests/OnboardingFlowTests.swift +++ b/desktop/Desktop/Tests/OnboardingFlowTests.swift @@ -3,10 +3,16 @@ import XCTest @testable import Omi_Computer final class OnboardingFlowTests: XCTestCase { - func testMergedFlowUsesFiveSteps() { + func testMergedFlowUsesSeventeenSteps() { XCTAssertEqual( - OnboardingFlow.steps, ["Chat", "Notifications", "FloatingBar", "VoiceShortcut", "Tasks"]) - XCTAssertEqual(OnboardingFlow.lastStepIndex, 4) + OnboardingFlow.steps, + [ + "Trust", "Name", "Language", "ScreenRecording", "FullDiskAccess", + "FileScan", "Microphone", "Notifications", "Accessibility", "Automation", + "FloatingBarShortcut", "FloatingBar", "VoiceShortcut", "VoiceDemo", + "Research", "Goal", "Tasks", + ]) + XCTAssertEqual(OnboardingFlow.lastStepIndex, 16) } func testMigrationMovesLegacyVoiceInputToMergedVoiceShortcutStep() { @@ -14,7 +20,10 @@ final class OnboardingFlowTests: XCTestCase { currentStep: 4, hasMigratedVideoStep: true, hasInsertedVoiceShortcutStep: true, - hasMergedVoiceInputStep: false + hasMergedVoiceInputStep: false, + hasRemovedNotificationStep: true, + hasInsertedFloatingBarShortcutStep: true, + hasMigratedPagedIntro: true ) XCTAssertEqual(migrated, 3) @@ -22,10 +31,13 @@ final class OnboardingFlowTests: XCTestCase { func testMigrationClampsOverflowToTasksStep() { let migrated = OnboardingFlow.migratedStep( - currentStep: 9, + currentStep: 99, hasMigratedVideoStep: true, hasInsertedVoiceShortcutStep: true, - hasMergedVoiceInputStep: true + hasMergedVoiceInputStep: true, + hasRemovedNotificationStep: true, + hasInsertedFloatingBarShortcutStep: true, + hasMigratedPagedIntro: true ) XCTAssertEqual(migrated, OnboardingFlow.lastStepIndex) diff --git a/desktop/Desktop/Tests/TranscriptionServiceTests.swift b/desktop/Desktop/Tests/TranscriptionServiceTests.swift new file mode 100644 index 
00000000000..ae5e8c57f8d
--- /dev/null
+++ b/desktop/Desktop/Tests/TranscriptionServiceTests.swift
@@ -0,0 +1,291 @@
+import XCTest
+
+@testable import Omi_Computer
+
+/// Tests for `ReconnectAudioRingBuffer`: append/drain ordering, TTL eviction,
+/// byte-cap eviction, and truncation of chunks larger than the cap.
+final class ReconnectAudioRingBufferTests: XCTestCase {
+
+    // MARK: - Basic append and drain
+
+    func testAppendAndDrain() {
+        var buffer = ReconnectAudioRingBuffer(ttl: 30, maxBytes: 960_000)
+        let chunk1 = Data(repeating: 0x01, count: 100)
+        let chunk2 = Data(repeating: 0x02, count: 200)
+
+        buffer.append(chunk1)
+        buffer.append(chunk2)
+
+        // Compare the whole sequence so a wrong count fails the assertion instead of
+        // trapping on an out-of-range index.
+        XCTAssertEqual(Array(buffer.drain()), [chunk1, chunk2])
+        XCTAssertEqual(buffer.totalBytes, 0)
+    }
+
+    func testDrainClearsBuffer() {
+        var buffer = ReconnectAudioRingBuffer(ttl: 30, maxBytes: 960_000)
+        buffer.append(Data(repeating: 0xAA, count: 500))
+        _ = buffer.drain()
+
+        XCTAssertTrue(buffer.drain().isEmpty)
+    }
+
+    func testEmptyDataIgnored() {
+        var buffer = ReconnectAudioRingBuffer(ttl: 30, maxBytes: 960_000)
+        buffer.append(Data())
+        XCTAssertEqual(buffer.totalBytes, 0)
+        XCTAssertTrue(buffer.drain().isEmpty)
+    }
+
+    // MARK: - TTL eviction
+
+    func testTTLEviction() {
+        var buffer = ReconnectAudioRingBuffer(ttl: 5, maxBytes: 960_000)
+        let now = Date()
+        let recent = Data(repeating: 0x02, count: 200)
+
+        // Add a chunk "5.1 seconds ago"
+        buffer.append(Data(repeating: 0x01, count: 100), now: now.addingTimeInterval(-5.1))
+        // Add a recent chunk
+        buffer.append(recent, now: now)
+
+        let drained = buffer.drain(now: now)
+        XCTAssertEqual(drained.count, 1, "Old chunk should be evicted by TTL")
+        XCTAssertEqual(drained.first, recent)
+    }
+
+    func testPruneEvictsExpired() {
+        var buffer = ReconnectAudioRingBuffer(ttl: 2, maxBytes: 960_000)
+        let now = Date()
+
+        buffer.append(Data(repeating: 0x01, count: 100), now: now.addingTimeInterval(-3))
+        buffer.append(Data(repeating: 0x02, count: 200), now: now)
+
+        buffer.prune(now: now)
+        XCTAssertEqual(buffer.totalBytes, 200)
+    }
+
+    // MARK: - Byte cap eviction
+
+    func testByteCapEviction() {
+        var buffer = ReconnectAudioRingBuffer(ttl: 30, maxBytes: 500)
+        let newest = Data(repeating: 0x02, count: 300)
+
+        buffer.append(Data(repeating: 0x01, count: 300))
+        buffer.append(newest)
+
+        // Total would be 600 > 500, so oldest chunk should be evicted
+        XCTAssertEqual(buffer.totalBytes, 300)
+        let drained = buffer.drain()
+        XCTAssertEqual(drained.count, 1)
+        XCTAssertEqual(drained.first, newest)
+    }
+
+    func testMultipleChunksEvictedForByteCap() {
+        var buffer = ReconnectAudioRingBuffer(ttl: 30, maxBytes: 200)
+        let newest = Data(repeating: 0x04, count: 80)
+
+        buffer.append(Data(repeating: 0x01, count: 80))
+        buffer.append(Data(repeating: 0x02, count: 80))
+        buffer.append(Data(repeating: 0x03, count: 80))
+        // 240 > 200, evict oldest until <= 200
+        buffer.append(newest)
+        // Over the cap again, evict more
+
+        // XCTAssertLessThanOrEqual reports the actual byte count when it fails.
+        XCTAssertLessThanOrEqual(buffer.totalBytes, 200)
+        // The newest chunk is smaller than the cap, so it must survive eviction.
+        XCTAssertEqual(buffer.drain().last, newest)
+    }
+
+    // MARK: - Oversize chunk truncation
+
+    func testOversizeChunkTruncation() {
+        var buffer = ReconnectAudioRingBuffer(ttl: 30, maxBytes: 100)
+
+        // Append a chunk larger than maxBytes. Head and tail hold different bytes so the
+        // final assertion can tell suffix retention apart from prefix retention.
+        let tail = Data(repeating: 0xFF, count: 100)
+        let oversized = Data(repeating: 0x00, count: 400) + tail
+        buffer.append(oversized)
+
+        XCTAssertEqual(buffer.totalBytes, 100, "Should be truncated to maxBytes")
+        let drained = buffer.drain()
+        XCTAssertEqual(drained.count, 1)
+        XCTAssertEqual(drained.first?.count, 100, "Chunk should be truncated to maxBytes")
+        // Should keep the suffix (last 100 bytes)
+        XCTAssertEqual(drained.first, tail)
+    }
+
+    func testOversizeReplacesExistingChunks() {
+        var buffer = ReconnectAudioRingBuffer(ttl: 30, maxBytes: 100)
+
+        buffer.append(Data(repeating: 0x01, count: 50))
+        buffer.append(Data(repeating: 0xFF, count: 200))
+
+        // Oversize replaces everything
+        XCTAssertEqual(buffer.totalBytes, 100)
+        XCTAssertEqual(buffer.drain().count, 1)
+    }
+}
+
+// MARK: - State machine and idempotency tests
+
+/// Exercises disconnection handling on `TranscriptionService` through its test hooks.
+final class TranscriptionServiceStateTests: XCTestCase {
+
+    /// Builds a service with a placeholder API key.
+    ///
+    /// The initializer can throw; presumably this depends on environment configuration
+    /// such as OMI_API_URL (TODO confirm against `TranscriptionService.init`). When it
+    /// does, the calling test is reported as skipped rather than passing without
+    /// having run a single assertion.
+    private func makeService() throws -> TranscriptionService {
+        do {
+            return try TranscriptionService(apiKey: "test-key", channels: 1)
+        } catch {
+            throw XCTSkip("TranscriptionService could not be created: \(error)")
+        }
+    }
+
+    func testInitialStateIsDisconnected() throws {
+        let service = try makeService()
+        XCTAssertEqual(service.testConnectionState, .disconnected)
+        XCTAssertEqual(service.testConnectionGeneration, 0)
+    }
+
+    func testStopFromDisconnectedRemainsDisconnected() throws {
+        let service = try makeService()
+        service.stop()
+        XCTAssertEqual(service.testConnectionState, .disconnected)
+    }
+
+    func testHandleDisconnectionFromDisconnectedIsNoOp() throws {
+        let service = try makeService()
+        let genBefore = service.testConnectionGeneration
+        service.testHandleDisconnection()
+        // Should be a no-op: state stays disconnected, generation unchanged
+        XCTAssertEqual(service.testConnectionState, .disconnected)
+        XCTAssertEqual(service.testConnectionGeneration, genBefore)
+    }
+
+    func testHandleDisconnectionFromConnectedBumpsGeneration() throws {
+        let service = try makeService()
+        service.testSetState(.connected)
+        service.testSetShouldReconnect(false)
+        let genBefore = service.testConnectionGeneration
+        service.testHandleDisconnection()
+        // Should bump generation and transition to disconnected (shouldReconnect=false)
+        XCTAssertEqual(service.testConnectionState, .disconnected)
+        XCTAssertGreaterThan(service.testConnectionGeneration, genBefore)
+    }
+
+    func testHandleDisconnectionIdempotent() throws {
+        let service = try makeService()
+        service.testSetState(.connected)
+        service.testSetShouldReconnect(false)
+        // First call
+        service.testHandleDisconnection()
+        let genAfterFirst = service.testConnectionGeneration
+        let stateAfterFirst = service.testConnectionState
+        // Second call (should be no-op since we're already disconnected)
+        service.testHandleDisconnection()
+        XCTAssertEqual(service.testConnectionState, stateAfterFirst)
+        XCTAssertEqual(
+            service.testConnectionGeneration, genAfterFirst,
+            "Second handleDisconnection should not bump generation again")
+    }
+
+    func testHandleDisconnectionFromReconnectingIsNoOp() throws {
+        let service = try makeService()
+        service.testSetState(.reconnecting)
+        let genBefore = service.testConnectionGeneration
+        service.testHandleDisconnection()
+        // .reconnecting is guarded out — no state change
+        XCTAssertEqual(service.testConnectionState, .reconnecting)
+        XCTAssertEqual(service.testConnectionGeneration, genBefore)
+    }
+
+    func testHandleDisconnectionFromConnectingBumpsGeneration() throws {
+        let service = try makeService()
+        service.testSetState(.connecting)
+        service.testSetShouldReconnect(false)
+        let genBefore = service.testConnectionGeneration
+        service.testHandleDisconnection()
+        XCTAssertEqual(service.testConnectionState, .disconnected)
+        XCTAssertGreaterThan(service.testConnectionGeneration, genBefore)
+    }
+}
+
+// MARK: - Invalid URL construction tests
+
+/// Pins down how `URLComponents` parses base + path strings joined by interpolation.
+final class URLConstructionTests: XCTestCase {
+
+    func testEmptyBaseProducesPathOnlyComponents() {
+        // Simulates what connectWithAuth does with empty base
+        let wsBase = ""
+        let listenPath = "/v1/proxy/deepgram/ws/v1/listen"
+        let components = URLComponents(string: "\(wsBase)\(listenPath)")
+        // Empty base + path still parses, but only as a path: no scheme and no host
+        XCTAssertNotNil(components, "Path-only URL should parse")
+        XCTAssertNil(components?.scheme)
+        XCTAssertNil(components?.host)
+        XCTAssertEqual(components?.path, listenPath)
+    }
+
+    func testMalformedBaseProducesNilComponents() {
+        // A truly malformed URL that URLComponents rejects
+        let wsBase = "wss://[invalid"
+        let listenPath = "/v1/listen"
+        let components = URLComponents(string: "\(wsBase)\(listenPath)")
+        XCTAssertNil(components, "Malformed URL base should produce nil URLComponents")
+    }
+
+    func testValidBaseProducesValidURL() {
+        let wsBase = "wss://api.omi.me"
+        let listenPath = "/v1/proxy/deepgram/ws/v1/listen"
+        let components = URLComponents(string: "\(wsBase)\(listenPath)")
+        XCTAssertNotNil(components)
+        XCTAssertNotNil(components?.url)
+    }
+}
+
+/// Tests for `TranscriptionService.reconnectDelay(attempt:maxBackoff:jitterRange:)`.
+final class ReconnectDelayTests: XCTestCase {
+
+    func testExponentialGrowth() {
+        // With jitter range 1.0...1.0 (no jitter), delays should be exact powers of 2
+        let d1 = TranscriptionService.reconnectDelay(attempt: 1, maxBackoff: 60, jitterRange: 1.0...1.0)
+        let d2 = TranscriptionService.reconnectDelay(attempt: 2, maxBackoff: 60, jitterRange: 1.0...1.0)
+        let d3 = TranscriptionService.reconnectDelay(attempt: 3, maxBackoff: 60, jitterRange: 1.0...1.0)
+        let d5 = TranscriptionService.reconnectDelay(attempt: 5, maxBackoff: 60, jitterRange: 1.0...1.0)
+
+        XCTAssertEqual(d1, 2.0, accuracy: 0.001)
+        XCTAssertEqual(d2, 4.0, accuracy: 0.001)
+        XCTAssertEqual(d3, 8.0, accuracy: 0.001)
+        XCTAssertEqual(d5, 32.0, accuracy: 0.001)
+    }
+
+    func testMaxBackoffCap() {
+        // Attempt 100 should still be capped at maxBackoff
+        let delay = TranscriptionService.reconnectDelay(attempt: 100, maxBackoff: 60, jitterRange: 1.0...1.0)
+        XCTAssertEqual(delay, 60.0, accuracy: 0.001)
+    }
+
+    func testJitterBounds() {
+        // Run many iterations to verify jitter stays within range
+        for _ in 0..<100 {
+            let delay = TranscriptionService.reconnectDelay(attempt: 3, maxBackoff: 60, jitterRange: 0.5...1.5)
+            // Base = 8.0, so range is [4.0, 12.0]
+            XCTAssertGreaterThanOrEqual(delay, 4.0)
+            XCTAssertLessThanOrEqual(delay, 12.0)
+        }
+    }
+
+    func testAttemptZero() {
+        // 2^0 = 1.0
+        let delay = TranscriptionService.reconnectDelay(attempt: 0, maxBackoff: 60, jitterRange: 1.0...1.0)
+        XCTAssertEqual(delay, 1.0, accuracy: 0.001)
+    }
+}