diff --git a/middleware/quota.py b/middleware/quota.py
index ba18179..0037368 100644
--- a/middleware/quota.py
+++ b/middleware/quota.py
@@ -359,11 +359,21 @@ async def finalize() -> None:
     model = usage.get("model", "unknown")
     parsed: dict | None = None
     if resp_is_text:
+        # -- Robustly extract the last JSON object -----------------
         text_tail = streamer.get_tail().decode("utf-8", "ignore")
-        idx = max(text_tail.rfind("{"), text_tail.rfind('{"'))
-        if idx != -1:
+        # 1. Keep the last SSE "data:" frame that carries JSON (skips a trailing "data: [DONE]" sentinel)
+        if "data:" in text_tail:
+            frames = [f.strip() for f in text_tail.split("data:")]
+            json_frames = [f for f in frames if f.startswith("{")]
+            text_tail = json_frames[-1] if json_frames else ""
+        # 2. Strip a bare trailing '[DONE]' token (non-SSE bodies)
+        if text_tail.endswith("[DONE]"):
+            text_tail = text_tail[: text_tail.rfind("[DONE]")].rstrip()
+        # 3. Find the first '{' from the *left* (the frame has been cleaned)
+        brace = text_tail.find("{")
+        if brace != -1:
             try:
-                parsed = json.loads(text_tail[idx:])
+                parsed = json.loads(text_tail[brace:])
             except Exception:
                 parsed = None
     if isinstance(parsed, dict):
@@ -372,9 +382,10 @@ async def finalize() -> None:
             u = parsed.get("usage") or {}
             tokens_out = int(u.get("completion_tokens", 0))
             prompt_tokens = int(u.get("prompt_tokens", tokens_in))
-            delta_prompt = prompt_tokens - tokens_in
-            tokens_in = prompt_tokens
+            delta_prompt = prompt_tokens - tokens_in  # adjust quota window
+            tokens_in = prompt_tokens  # ← canonical value
             usage["tokens_in"] = tokens_in
+            # Replace provisional count with canonical one in the meter
             await self.store.adjust(user, delta_prompt)
         elif "choices" in parsed:
             msgs = [
@@ -390,15 +401,20 @@ async def finalize() -> None:
     usage["tokens_out"] = tokens_out
     usage["model"] = model
+    # All *out* tokens are new – add them once.
     await self.store.adjust(user, tokens_out)
     total = await self.store.peek_total(user)
     if total > self.max_tokens:
         usage["detail"] = "token quota exceeded post-stream"
-    response.headers["x-llm-model"] = model
-    response.headers["x-tokens-in"] = str(tokens_in)
-    response.headers["x-tokens-out"] = str(tokens_out)
+    response.headers.update(
+        {
+            "x-llm-model": model,
+            "x-tokens-in": str(tokens_in),
+            "x-tokens-out": str(tokens_out),
+        }
+    )
     usage["ts"] = time.time()
     await request.app.state.usage.record(**usage)
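
A minimal sketch of the tail-parsing behavior the first hunk implements, assuming
an OpenAI-style SSE body whose final frames are the usage chunk followed by a
"data: [DONE]" sentinel; the helper name extract_last_json and the sample payload
are illustrative, not part of the patch:

    import json

    def extract_last_json(text_tail: str) -> dict | None:
        # Prefer the last SSE frame that carries JSON, so the trailing
        # "data: [DONE]" sentinel is skipped rather than parsed.
        if "data:" in text_tail:
            frames = [f.strip() for f in text_tail.split("data:")]
            json_frames = [f for f in frames if f.startswith("{")]
            text_tail = json_frames[-1] if json_frames else ""
        # Bare "[DONE]" without SSE framing: drop it before scanning.
        if text_tail.endswith("[DONE]"):
            text_tail = text_tail[: text_tail.rfind("[DONE]")].rstrip()
        brace = text_tail.find("{")
        if brace == -1:
            return None
        try:
            parsed = json.loads(text_tail[brace:])
        except Exception:
            return None
        return parsed if isinstance(parsed, dict) else None

    tail = 'data: {"usage": {"prompt_tokens": 11, "completion_tokens": 42}}\n\ndata: [DONE]\n\n'
    assert extract_last_json(tail) == {"usage": {"prompt_tokens": 11, "completion_tokens": 42}}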