diff --git a/desktop/Desktop/Sources/PendingSaveCounter.swift b/desktop/Desktop/Sources/PendingSaveCounter.swift new file mode 100644 index 00000000000..48aa9800081 --- /dev/null +++ b/desktop/Desktop/Sources/PendingSaveCounter.swift @@ -0,0 +1,76 @@ +import Foundation + +/// Counter-based multi-holder gate for tracking in-flight persistence +/// operations. Companion to `ReentrancyGate` — that type is single-entry +/// (one holder at a time); this one allows multiple concurrent holders +/// and reports "is anything in flight right now?" via `isActive`. +/// +/// Used by `ChatProvider` to prevent the cross-platform message poll +/// from running while any `saveMessage(...)` is mid-flight. The poll +/// reads backend state to detect messages sent from other devices; if +/// it fires between a local save's request and its response, it can +/// observe the just-saved message and treat it as new. The existing +/// 200-char text-prefix merge at `pollForNewMessages` catches most of +/// these, but a counter-based suppression is defense-in-depth: it +/// eliminates the race window entirely instead of relying on text +/// heuristics that fail on short common replies ("Yes", "Got it"). +/// +/// Caller contract: +/// ```swift +/// counter.begin() +/// Task { +/// do { +/// let response = try await APIClient.shared.saveMessage(...) +/// await MainActor.run { +/// // … sync state update … +/// self.counter.end() +/// } +/// } catch { +/// await MainActor.run { self.counter.end() } +/// logError(...) +/// } +/// } +/// ``` +/// +/// Both success and failure paths MUST call `end()`. Missing an `end()` +/// causes the counter to leak upward and permanently suppresses the +/// poll. An extra `end()` at 0 trips an `assert` in debug builds and is +/// a bounded no-op in release builds, so it masks bugs — prefer matched +/// pairs. +/// +/// Tested in `PendingSaveCounterTests`. 
+@MainActor +final class PendingSaveCounter { + private var count: Int = 0 + + /// Invoked each time the count returns to 0 (the last in-flight save + /// completed). Lets the owner re-run any work that was suppressed + /// while saves were active — e.g. a `pollForNewMessages` cycle that + /// was deferred so it wouldn't observe a half-saved message. + var onDrained: (() -> Void)? + + /// True when at least one save is in flight. + var isActive: Bool { count > 0 } + + /// Exposed for tests. Production code should check `isActive` + /// rather than reading the raw count. + var currentCount: Int { count } + + /// Increment the count. Call before launching a save Task (or + /// before `await`ing the inline save). + func begin() { + count += 1 + } + + /// Decrement the count; when it reaches 0, call `onDrained`. Bounded + /// at 0 — a stray call cannot drive the counter negative, which would + /// leave it under-counting saves that are still in flight. The `assert` + /// surfaces an unbalanced `begin()`/`end()` pair in debug builds + /// (zero cost in release) instead of failing silently. + func end() { + assert(count > 0, "PendingSaveCounter: unbalanced end() — no matching begin()") + guard count > 0 else { return } + count -= 1 + if count == 0 { onDrained?() } + } +} diff --git a/desktop/Desktop/Sources/Providers/ChatProvider.swift b/desktop/Desktop/Sources/Providers/ChatProvider.swift index a26434f5ffd..593b6e634bf 100644 --- a/desktop/Desktop/Sources/Providers/ChatProvider.swift +++ b/desktop/Desktop/Sources/Providers/ChatProvider.swift @@ -686,6 +686,22 @@ BROWSER TABS: when you use the browser (Playwright), on your FIRST browser actio init() { log("ChatProvider initialized, will start Claude bridge on first use") + // When the last in-flight save completes, re-run any poll cycle + // that was deferred while saves were active. Keeps suppression + // from permanently dropping a fetch of other-platform messages. 
+ // + // The flag is intentionally NOT cleared here — only + // `pollForNewMessages` clears it, and only once it actually gets + // past its guards and commits to a fetch. Otherwise a retry that + // bails again (e.g. on `isSending` because the next turn is mid- + // stream) would drop the deferral permanently; leaving the flag + // set lets the next drain (e.g. the AI-response save) retry once + // sending has finished. + pendingSaves.onDrained = { [weak self] in + guard let self, self.pollDeferredDuringSave else { return } + Task { [weak self] in await self?.pollForNewMessages() } + } + // Migrate legacy "agentSDK" persisted mode to the new default "piMono". // Pre-6594 installs may have the old agentSDK tag saved; the settings // picker no longer offers it, so leaving it stored would leave the UI @@ -2114,6 +2130,27 @@ BROWSER TABS: when you use the browser (Playwright), on your FIRST browser actio /// Prevents overlapping fetches when activation + Cmd+R fire back-to-back. private let pollGate = ReentrancyGate() + /// Defense-in-depth against the saveMessage / pollForNewMessages + /// race. `isSending` is released *before* the AI message save + /// completes (intentional — to unblock the next query), which opens a + /// window where the poll can observe the just-saved AI message and + /// treat it as new-from-another-platform. The existing 200-char + /// text-prefix merge at `pollForNewMessages` catches most of these, + /// but a counter-based suppression eliminates the race window + /// entirely instead of relying on text heuristics that fail on short + /// common replies ("Yes", "Got it"). Every saveMessage call site + /// begins/ends the counter; the poll skips when the counter is + /// active. Sites are documented inline at each `saveMessage(...)` call. + private let pendingSaves = PendingSaveCounter() + + /// Set when a `pollForNewMessages` cycle bailed *because* a save was + /// in flight. 
`pollForNewMessages` is only triggered by activation / + /// Cmd+R (there is no periodic poll), so a dropped cycle would leave + /// messages from other platforms unfetched until the next activation. + /// `pendingSaves.onDrained` re-runs the poll once saves finish, but + /// only when this flag says one was actually deferred. + private var pollDeferredDuringSave = false + /// Fetch new messages from other platforms (e.g. mobile). /// Merges new messages into the existing array without disrupting the UI. private func pollForNewMessages() async { @@ -2127,12 +2164,29 @@ BROWSER TABS: when you use the browser (Playwright), on your FIRST browser actio // Skip if we're actively sending. Note: isSending is released *before* the AI // message is saved to the backend (to unblock the next query). This means the // poll can run while saveMessage() is still in-flight — see the race note below. + // + // `pendingSaves.isActive` closes the same race window from the save side + // — any in-flight saveMessage (user msg, AI msg, follow-up, partial-on-error, + // proactive notification) keeps the poll suppressed until it lands. This is + // defense-in-depth over the 200-char text-prefix merge below at lines ~2192. guard !isSending, !isLoading, !isLoadingSessions else { return } + // A save in flight means a local message hasn't reconciled its + // server ID yet — defer rather than risk observing it as new. + // Mark the cycle deferred so `pendingSaves.onDrained` re-runs it. + guard !pendingSaves.isActive else { pollDeferredDuringSave = true; return } // Skip if messages haven't been loaded yet (initial load not done) guard !messages.isEmpty || sessionsLoadError != nil else { return } // Skip if there's an active streaming message guard !messages.contains(where: { $0.isStreaming }) else { return } + // Past all the deferral-relevant guards — this cycle is actually + // going to fetch, so any pending deferral is now being honored. 
+ // Cleared HERE (not in onDrained) so a retry that bailed earlier + // on `isSending`/streaming keeps the flag set and gets retried by + // the next drain. The post-fetch recheck below re-sets it if a + // save sneaks in during getMessages. + pollDeferredDuringSave = false + do { let persistedMessages: [ChatMessageDB] @@ -2150,6 +2204,19 @@ BROWSER TABS: when you use the browser (Playwright), on your FIRST browser actio ) } + // A save may have begun *while* getMessages was awaiting — e.g. + // a proactive assistant message appended via appendAssistantMessage + // (FloatingControlBarWindow) after this poll already passed the + // pendingSaves guard above. That message can be in the batch we + // just fetched, carrying a server ID the local copy hasn't adopted + // yet. Re-check here and bail this cycle; the next poll after the + // save lands reconciles it by ID. Without this, the post-guard + // window stays open for the proactive paths. Mark the cycle + // deferred so the drain handler re-runs it — otherwise the + // just-fetched batch (including any genuine new messages from + // other platforms) would be dropped until the next activation. + guard !pendingSaves.isActive else { pollDeferredDuringSave = true; return } + // Build a lookup of existing IDs for fast O(1) checks. let existingIds = Set(messages.map(\.id)) @@ -2248,10 +2315,15 @@ BROWSER TABS: when you use the browser (Playwright), on your FIRST browser actio ) messages.append(userMessage) - // Persist to backend and sync server ID back to prevent poll duplicates + // Persist to backend and sync server ID back to prevent poll duplicates. + // + // saveMessage site 1 of 5: user follow-up message sent + // mid-query. Fire-and-forget Task. `pendingSaves` guards the + // poll for the lifetime of this save. let capturedSessionId = isInDefaultChat ? nil : currentSessionId let capturedAppId = overrideAppId ?? 
selectedAppId let localId = userMessage.id + pendingSaves.begin() Task { [weak self] in do { let response = try await APIClient.shared.saveMessage( @@ -2265,9 +2337,11 @@ BROWSER TABS: when you use the browser (Playwright), on your FIRST browser actio self?.messages[index].id = response.id self?.messages[index].isSynced = true } + self?.pendingSaves.end() } log("Saved follow-up message to backend: \(response.id)") } catch { + await MainActor.run { self?.pendingSaves.end() } logError("Failed to persist follow-up message", error: error) } } @@ -2292,6 +2366,10 @@ BROWSER TABS: when you use the browser (Playwright), on your FIRST browser actio messages.append(aiMessage) + // saveMessage site 2 of 5: AI message synthesized from a + // proactive notification (no bridge query, no streaming). + // Fire-and-forget Task. + pendingSaves.begin() Task { [weak self] in do { let response = try await APIClient.shared.saveMessage( @@ -2305,9 +2383,11 @@ BROWSER TABS: when you use the browser (Playwright), on your FIRST browser actio self?.messages[index].id = response.id self?.messages[index].isSynced = true } + self?.pendingSaves.end() } log("Saved assistant message to backend: \(response.id)") } catch { + await MainActor.run { self?.pendingSaves.end() } logError("Failed to persist assistant message", error: error) } } @@ -2546,6 +2626,13 @@ BROWSER TABS: when you use the browser (Playwright), on your FIRST browser actio let capturedSessionId = sessionId let capturedAppId = overrideAppId ?? selectedAppId if !isFollowUp { + // saveMessage site 3 of 5: user message at turn start. + // Fire-and-forget Task launched before the bridge query so + // it doesn't block streaming. `isSending` already gates the + // poll until the AI response lands, but `pendingSaves` + // provides defense-in-depth in case the save outlives the + // bridge query (slow backend, retry, etc.). 
+ pendingSaves.begin() Task { [weak self] in do { let response = try await APIClient.shared.saveMessage( @@ -2562,9 +2649,11 @@ BROWSER TABS: when you use the browser (Playwright), on your FIRST browser actio self?.messages[index].id = response.id self?.messages[index].isSynced = true } + self?.pendingSaves.end() } log("Saved user message to backend: \(response.id)") } catch { + await MainActor.run { self?.pendingSaves.end() } logError("Failed to persist user message", error: error) // Non-critical - continue with chat } @@ -2816,6 +2905,23 @@ BROWSER TABS: when you use the browser (Playwright), on your FIRST browser actio // before this update runs. let textToSave = queryResult.text.isEmpty ? messageText : queryResult.text if !textToSave.isEmpty { + // saveMessage site 4 of 5 (THE CRITICAL ONE): AI + // response on the success path. `isSending=false` was + // already released a few lines above to unblock the + // next query, so the poll could fire DURING this await + // and observe the just-saved AI message before the + // local UUID has been updated to the server ID below. + // The counter closes that window — `pendingSaves` + // stays active until the save lands AND the in-memory + // ID has been synced. The pre-existing 200-char + // text-prefix merge at `pollForNewMessages` stays as + // a secondary safety net. + // `defer` guarantees the counter is released on every exit + // path — success, throw, or any future early return added + // inside this block — so a missed `end()` can't permanently + // suppress the poll. 
+ pendingSaves.begin() + defer { pendingSaves.end() } do { let toolMetadata = serializeToolCallMetadata(messageId: aiMessageId) let response = try await APIClient.shared.saveMessage( @@ -2928,9 +3034,14 @@ BROWSER TABS: when you use the browser (Playwright), on your FIRST browser actio messages[index].isStreaming = false completeRemainingToolCalls(messageId: aiMessageId) log("Bridge error after partial response — keeping \(messages[index].text.count) chars of streamed text") - // Still try to persist the partial response + // Still try to persist the partial response. + // + // saveMessage site 5 of 5: partial AI + // response after a bridge error. Fire-and-forget + // Task; same counter pattern as the other sites. let partialText = messages[index].text let partialToolMetadata = self.serializeToolCallMetadata(messageId: aiMessageId) + pendingSaves.begin() Task { [weak self] in do { let response = try await APIClient.shared.saveMessage( @@ -2945,9 +3056,11 @@ BROWSER TABS: when you use the browser (Playwright), on your FIRST browser actio self?.messages[syncIndex].id = response.id self?.messages[syncIndex].isSynced = true } + self?.pendingSaves.end() } log("Saved partial AI response to backend: \(response.id)") } catch { + await MainActor.run { self?.pendingSaves.end() } logError("Failed to persist partial AI response", error: error) } } diff --git a/desktop/Desktop/Tests/PendingSaveCounterTests.swift b/desktop/Desktop/Tests/PendingSaveCounterTests.swift new file mode 100644 index 00000000000..32b553ec2d8 --- /dev/null +++ b/desktop/Desktop/Tests/PendingSaveCounterTests.swift @@ -0,0 +1,115 @@ +import XCTest + +@testable import Omi_Computer + +/// Contract tests for `PendingSaveCounter` — the synchronization +/// primitive that gates `pollForNewMessages` against in-flight +/// `saveMessage(...)` calls. 
@MainActor
final class PendingSaveCounterTests: XCTestCase {

    // MARK: - Single holder

    /// A newly created counter reports nothing in flight.
    func testFreshCounterIsInactive() {
        let sut = PendingSaveCounter()

        XCTAssertFalse(sut.isActive)
        XCTAssertEqual(sut.currentCount, 0)
    }

    /// One `begin()` makes the counter active with a count of exactly 1.
    func testBeginActivatesCounter() {
        let sut = PendingSaveCounter()

        sut.begin()

        XCTAssertTrue(sut.isActive)
        XCTAssertEqual(sut.currentCount, 1)
    }

    /// A matched `begin()`/`end()` pair leaves the counter idle at 0.
    func testEndDecrementsCounter() {
        let sut = PendingSaveCounter()

        sut.begin()
        sut.end()

        XCTAssertFalse(sut.isActive)
        XCTAssertEqual(sut.currentCount, 0)
    }

    // MARK: - Multiple holders

    /// Several holders may be outstanding at once; the counter stays
    /// active until every one of them has released.
    func testMultipleHoldersStack() {
        let sut = PendingSaveCounter()

        for _ in 1...3 {
            sut.begin()
        }
        XCTAssertEqual(sut.currentCount, 3)
        XCTAssertTrue(sut.isActive)

        sut.end()
        XCTAssertEqual(sut.currentCount, 2)
        XCTAssertTrue(sut.isActive, "still active until all holders release")

        for _ in 1...2 {
            sut.end()
        }
        XCTAssertEqual(sut.currentCount, 0)
        XCTAssertFalse(sut.isActive)
    }

    /// Balanced use returns the counter to 0 without leaving it stuck:
    /// a second round of `begin()`/`end()` activates and releases it
    /// again. Only balanced sequences are exercised here.
    func testCounterReturnsToZeroAndStaysUsableAcrossRounds() {
        let sut = PendingSaveCounter()

        // Round one: two holders acquire, then both release.
        for _ in 1...2 {
            sut.begin()
        }
        for _ in 1...2 {
            sut.end()
        }
        XCTAssertEqual(sut.currentCount, 0)
        XCTAssertFalse(sut.isActive)

        // Round two: the counter must activate again after draining.
        sut.begin()
        XCTAssertTrue(sut.isActive)
        sut.end()
        XCTAssertFalse(sut.isActive)
    }

    // MARK: - Drain callback

    /// `onDrained` runs only on the `end()` that brings the count back
    /// to 0, never on an intermediate release, and runs again on each
    /// later drain.
    func testOnDrainedFiresOnlyWhenCountReturnsToZero() {
        let sut = PendingSaveCounter()
        var drainCount = 0
        sut.onDrained = {
            drainCount += 1
        }

        sut.begin()
        sut.begin()
        sut.end()
        XCTAssertEqual(drainCount, 0, "still one holder — must not fire yet")

        sut.end()
        XCTAssertEqual(drainCount, 1, "last holder released — fires once")

        // A new begin/end round drains a second time.
        sut.begin()
        sut.end()
        XCTAssertEqual(drainCount, 2)
    }

    // MARK: - Overlapping holders

    /// Two overlapping holders, A and B: B begins before A ends, A ends
    /// first, and the counter stays active until B ends too.
    func testInterleavedBeginAndEnd() {
        let sut = PendingSaveCounter()

        // A begins.
        sut.begin()
        XCTAssertTrue(sut.isActive)

        // B begins while A is still outstanding.
        sut.begin()
        XCTAssertEqual(sut.currentCount, 2)

        // A ends; B keeps the counter active.
        sut.end()
        XCTAssertTrue(sut.isActive, "B is still in flight")

        // B ends; nothing remains in flight.
        sut.end()
        XCTAssertFalse(sut.isActive)
    }
}