diff --git a/apps/code/src/main/services/usage-monitor/schemas.ts b/apps/code/src/main/services/usage-monitor/schemas.ts index 2923b1cc4..f3d3f7e45 100644 --- a/apps/code/src/main/services/usage-monitor/schemas.ts +++ b/apps/code/src/main/services/usage-monitor/schemas.ts @@ -1,3 +1,5 @@ +import type { UsageOutput } from "@main/services/llm-gateway/schemas"; +import { usageOutput } from "@main/services/llm-gateway/schemas"; import { z } from "zod"; export const USAGE_THRESHOLDS = [50, 75, 90, 100] as const; @@ -19,10 +21,15 @@ export const thresholdCrossedEvent = z.object({ export type ThresholdCrossedEvent = z.infer; +export const usageSnapshotOutput = usageOutput.nullable(); +export type UsageSnapshot = UsageOutput | null; + export const UsageMonitorEvent = { ThresholdCrossed: "threshold-crossed", + UsageUpdated: "usage-updated", } as const; export interface UsageMonitorEvents { [UsageMonitorEvent.ThresholdCrossed]: ThresholdCrossedEvent; + [UsageMonitorEvent.UsageUpdated]: UsageOutput; } diff --git a/apps/code/src/main/services/usage-monitor/service.test.ts b/apps/code/src/main/services/usage-monitor/service.test.ts index 0e9fcfcbd..9079e28c0 100644 --- a/apps/code/src/main/services/usage-monitor/service.test.ts +++ b/apps/code/src/main/services/usage-monitor/service.test.ts @@ -23,7 +23,7 @@ vi.mock("../../utils/logger.js", () => ({ }, })); -import { LlmGatewayService } from "../llm-gateway/service"; +import type { LlmGatewayService } from "../llm-gateway/service"; import { UsageMonitorService } from "./service"; function makeUsage(overrides?: { @@ -179,4 +179,41 @@ describe("UsageMonitorService", () => { await expect(service.pollOnce()).resolves.toBeNull(); expect(events).toHaveLength(0); }); + + it("emits UsageUpdated and caches the snapshot on every successful poll", async () => { + const updates: UsageOutput[] = []; + const gateway = mockGateway(makeUsage({ burstPercent: 20 })); + service = new UsageMonitorService(gateway); + service.on(UsageMonitorEvent.UsageUpdated, (u) => updates.push(u)); + + expect(service.getLatest()).toBeNull(); + await service.pollOnce(); + expect(updates).toHaveLength(1); + expect(service.getLatest()?.burst.used_percent).toBe(20); + + await service.pollOnce(); + expect(updates).toHaveLength(2); + }); + + it("does not emit UsageUpdated when the gateway throws", async () => { + const updates: UsageOutput[] = []; + const gateway = { + fetchUsage: vi.fn().mockRejectedValue(new Error("offline")), + } as unknown as LlmGatewayService; + service = new UsageMonitorService(gateway); + service.on(UsageMonitorEvent.UsageUpdated, (u) => updates.push(u)); + + await service.pollOnce(); + expect(updates).toHaveLength(0); + expect(service.getLatest()).toBeNull(); + }); + + it("refreshNow triggers a fresh poll and returns the snapshot", async () => { + const gateway = mockGateway(makeUsage({ burstPercent: 42 })); + service = new UsageMonitorService(gateway); + + const result = await service.refreshNow(); + expect(result?.burst.used_percent).toBe(42); + expect(service.getLatest()?.burst.used_percent).toBe(42); + }); }); diff --git a/apps/code/src/main/services/usage-monitor/service.ts b/apps/code/src/main/services/usage-monitor/service.ts index de6611851..90c57da1e 100644 --- a/apps/code/src/main/services/usage-monitor/service.ts +++ b/apps/code/src/main/services/usage-monitor/service.ts @@ -2,8 +2,8 @@ import { inject, injectable, postConstruct, preDestroy } from "inversify"; import { MAIN_TOKENS } from "../../di/tokens"; import { logger } from "../../utils/logger"; import { TypedEventEmitter } from "../../utils/typed-event-emitter"; -import { LlmGatewayService } from "../llm-gateway/service"; import type { UsageBucket, UsageOutput } from "../llm-gateway/schemas"; +import type { LlmGatewayService } from "../llm-gateway/service"; import { USAGE_THRESHOLDS, UsageMonitorEvent, @@ -25,6 +25,7 @@ export class UsageMonitorService extends TypedEventEmitter { // Snapshot of the most recent thresholdsSeen map so we hit electron-store // only when we actually persist a new threshold. private thresholdsSeen: Record; + private latestUsage: UsageOutput | null = null; constructor( @inject(MAIN_TOKENS.LlmGatewayService) @@ -34,6 +35,16 @@ export class UsageMonitorService extends TypedEventEmitter { this.thresholdsSeen = { ...usageMonitorStore.get("thresholdsSeen", {}) }; } + /** Last successful usage snapshot; null until the first poll succeeds. */ + getLatest(): UsageOutput | null { + return this.latestUsage; + } + + /** Trigger an immediate refresh, returning the resulting snapshot. */ + async refreshNow(): Promise { + return this.pollOnce(); + } + @postConstruct() init(): void { this.pruneStaleEntries(); @@ -54,7 +65,11 @@ export class UsageMonitorService extends TypedEventEmitter { this.isPolling = true; try { const usage = await this.fetchUsageQuietly(); - if (usage) this.processUsage(usage); + if (usage) { + this.latestUsage = usage; + this.emit(UsageMonitorEvent.UsageUpdated, usage); + this.processUsage(usage); + } return usage; } finally { this.isPolling = false; @@ -90,14 +105,7 @@ export class UsageMonitorService extends TypedEventEmitter { const isPro = !!usage.billing_period_end; this.maybeEmit(usage, "burst", usage.burst, userId, product, isPro); - this.maybeEmit( - usage, - "sustained", - usage.sustained, - userId, - product, - isPro, - ); + this.maybeEmit(usage, "sustained", usage.sustained, userId, product, isPro); } private maybeEmit( @@ -145,11 +153,7 @@ export class UsageMonitorService extends TypedEventEmitter { usage: UsageOutput, ): string | null { if (bucket === "sustained") { - return ( - usage.billing_period_end ?? - sustainedFreeAnchor(status) ?? - null - ); + return usage.billing_period_end ?? sustainedFreeAnchor(status) ?? null; } return burstAnchor(status); } @@ -212,4 +216,3 @@ function makeKey( ): string { return `${userId}:${product}:${bucket}:${anchor}:${threshold}`; } - diff --git a/apps/code/src/main/trpc/routers/llm-gateway.ts b/apps/code/src/main/trpc/routers/llm-gateway.ts index fe9862e1a..2c0017dde 100644 --- a/apps/code/src/main/trpc/routers/llm-gateway.ts +++ b/apps/code/src/main/trpc/routers/llm-gateway.ts @@ -1,10 +1,6 @@ import { container } from "../../di/container"; import { MAIN_TOKENS } from "../../di/tokens"; -import { - promptInput, - promptOutput, - usageOutput, -} from "../../services/llm-gateway/schemas"; +import { promptInput, promptOutput } from "../../services/llm-gateway/schemas"; import type { LlmGatewayService } from "../../services/llm-gateway/service"; import { publicProcedure, router } from "../trpc"; @@ -23,10 +19,6 @@ export const llmGatewayRouter = router({ }), ), - usage: publicProcedure - .output(usageOutput) - .query(() => getService().fetchUsage()), - invalidatePlanCache: publicProcedure.mutation(() => getService().invalidatePlanCache(), ), diff --git a/apps/code/src/main/trpc/routers/usage-monitor.ts b/apps/code/src/main/trpc/routers/usage-monitor.ts index d103612db..ef3c63ab9 100644 --- a/apps/code/src/main/trpc/routers/usage-monitor.ts +++ b/apps/code/src/main/trpc/routers/usage-monitor.ts @@ -3,6 +3,7 @@ import { MAIN_TOKENS } from "../../di/tokens"; import { UsageMonitorEvent, type UsageMonitorEvents, + usageSnapshotOutput, } from "../../services/usage-monitor/schemas"; import type { UsageMonitorService } from "../../services/usage-monitor/service"; import { publicProcedure, router } from "../trpc"; @@ -22,4 +23,14 @@ function subscribe(event: K) { export const usageMonitorRouter = router({ onThresholdCrossed: subscribe(UsageMonitorEvent.ThresholdCrossed), + // Stream of full usage snapshots — replaces the renderer's 30s poll. + onUsageUpdated: subscribe(UsageMonitorEvent.UsageUpdated), + // Cached snapshot for the renderer to bootstrap before the first event + // arrives. Null until the first poll completes. + getLatest: publicProcedure + .output(usageSnapshotOutput) + .query(() => getService().getLatest()), + refresh: publicProcedure + .output(usageSnapshotOutput) + .mutation(() => getService().refreshNow()), }); diff --git a/apps/code/src/renderer/features/billing/hooks/useUsage.ts b/apps/code/src/renderer/features/billing/hooks/useUsage.ts index a48e426af..82b48274c 100644 --- a/apps/code/src/renderer/features/billing/hooks/useUsage.ts +++ b/apps/code/src/renderer/features/billing/hooks/useUsage.ts @@ -1,21 +1,50 @@ import { useTRPC } from "@renderer/trpc"; -import { useRendererWindowFocusStore } from "@stores/rendererWindowFocusStore"; -import { useQuery } from "@tanstack/react-query"; - -const USAGE_REFETCH_INTERVAL_MS = 30_000; +import { useMutation, useQuery, useQueryClient } from "@tanstack/react-query"; +import { useSubscription } from "@trpc/tanstack-react-query"; +import { useCallback } from "react"; +/** + * Subscribe to usage snapshots pushed by the main-process `UsageMonitorService`. + * Avoids the renderer doing its own gateway polling — the service is the single + * source of truth and we just consume what it broadcasts every ~30s. + */ export function useUsage({ enabled = true }: { enabled?: boolean } = {}) { const trpc = useTRPC(); - const focused = useRendererWindowFocusStore((s) => s.focused); - const { - data: usage, - isLoading, - refetch, - } = useQuery({ - ...trpc.llmGateway.usage.queryOptions(), + const queryClient = useQueryClient(); + const query = useQuery({ + ...trpc.usageMonitor.getLatest.queryOptions(), enabled, - refetchInterval: focused && enabled ? USAGE_REFETCH_INTERVAL_MS : false, - refetchIntervalInBackground: false, }); - return { usage: usage ?? null, isLoading, refetch }; + const refreshMutation = useMutation( + trpc.usageMonitor.refresh.mutationOptions(), + ); + + useSubscription( + trpc.usageMonitor.onUsageUpdated.subscriptionOptions(undefined, { + enabled, + onData: (data) => { + queryClient.setQueryData( + trpc.usageMonitor.getLatest.queryKey(), + data, + ); + }, + }), + ); + + const refetch = useCallback(async () => { + const fresh = await refreshMutation.mutateAsync(); + if (fresh) { + queryClient.setQueryData( + trpc.usageMonitor.getLatest.queryKey(), + fresh, + ); + } + return fresh; + }, [refreshMutation, queryClient, trpc.usageMonitor.getLatest]); + + return { + usage: query.data ?? null, + isLoading: query.isLoading, + refetch, + }; }