From 929bc0719460f7eb91f64b602e830c33967ac4b2 Mon Sep 17 00:00:00 2001
From: Hammad Tariq
Date: Mon, 21 Jul 2025 11:25:24 +0500
Subject: [PATCH] Add mid-stream quota log and expand tail buffer

---
 middleware/quota.py | 411 ++++++++++++++++++++++++++++++++------------
 1 file changed, 300 insertions(+), 111 deletions(-)

diff --git a/middleware/quota.py b/middleware/quota.py
index d5026d3..ba18179 100644
--- a/middleware/quota.py
+++ b/middleware/quota.py
@@ -1,41 +1,53 @@
-"""Token quota middleware for Attach Gateway.
-
-Enforces per-user token limits using a sliding 1-minute window. The
-quota is applied to both the request body and the response body.
-"""
+"""Token quota middleware enforcing a sliding window budget."""
 
 from __future__ import annotations
 
+import json
 import logging
 import os
-import sys
 import time
 from collections import deque
-from typing import Deque, Dict, Optional, Protocol, Tuple
+from typing import (
+    AsyncIterator,
+    Awaitable,
+    Callable,
+    Deque,
+    Dict,
+    Iterable,
+    Protocol,
+    Tuple,
+)
 from uuid import uuid4
 
 from fastapi import Request
 from starlette.middleware.base import BaseHTTPMiddleware
 from starlette.responses import JSONResponse, StreamingResponse
 
+try:
+    import tiktoken  # type: ignore
+except Exception:  # pragma: no cover
+    tiktoken = None
+
 from usage.backends import NullUsageBackend
 
 logger = logging.getLogger(__name__)
 
 
 class AbstractMeterStore(Protocol):
-    """Interface for token accounting backends."""
+    """Token accounting backend."""
 
     async def increment(self, user: str, tokens: int) -> Tuple[int, float]:
-        """Return the running total and timestamp of the oldest entry."""
+        """Increment ``user``'s counter and return ``(total, oldest_ts)``."""
 
+    async def adjust(self, user: str, delta: int) -> Tuple[int, float]:
+        """Atomically adjust ``user``'s total by ``delta``."""
 
-class InMemoryMeterStore:
-    """Simple in-memory sliding window counter.
+    async def peek_total(self, user: str) -> int:
+        """Return ``user``'s current total without mutating state."""
 
-    Not safe for multi-process deployments; each process keeps its own
-    counters. Use :class:`RedisMeterStore` in production.
- """ + +class InMemoryMeterStore: + """Simple in-memory meter for tests and development.""" def __init__(self, window: int = 60) -> None: self.window = window @@ -52,11 +64,21 @@ async def increment(self, user: str, tokens: int) -> Tuple[int, float]: oldest = dq[0][0] if dq else now return total, oldest + async def adjust(self, user: str, delta: int) -> Tuple[int, float]: + return await self.increment(user, delta) + + async def peek_total(self, user: str) -> int: + total, _ = await self.increment(user, 0) + dq = self._data.get(user) + if total and dq: + dq.pop() + return total + class RedisMeterStore: - """Redis backed sliding window counter.""" + """Redis backed sliding window meter.""" - def __init__(self, url: str = "redis://localhost:6379", window: int = 60) -> None: + def __init__(self, url: str, window: int = 60) -> None: import redis.asyncio as redis # type: ignore self.window = window @@ -74,73 +96,200 @@ async def increment(self, user: str, tokens: int) -> Tuple[int, float]: entries = results[-1] total = 0 oldest = now - for member, ts in entries: + for m, ts in entries: try: - _, tok = member.split(":", 1) + _, tok = m.split(":", 1) total += int(tok) except Exception: pass oldest = min(oldest, ts) return total, oldest + async def adjust(self, user: str, delta: int) -> Tuple[int, float]: + now = time.time() + key = f"attach:quota:{user}" + member = f"{now}:{delta}" + async with self.redis.pipeline(transaction=True) as pipe: + await pipe.zadd(key, {member: now}) + await pipe.zremrangebyscore(key, 0, now - self.window) + await pipe.zrange(key, 0, -1, withscores=True) + results = await pipe.execute() + entries = results[-1] + total = 0 + oldest = now + for m, ts in entries: + try: + _, tok = m.split(":", 1) + total += int(tok) + except Exception: + pass + oldest = min(oldest, ts) + return total, oldest -class TokenQuotaMiddleware(BaseHTTPMiddleware): - """FastAPI middleware enforcing per-user token quotas.""" + async def peek_total(self, user: str) -> int: + now = time.time() + key = f"attach:quota:{user}" + member = f"{now}:0:{uuid4().hex}" + async with self.redis.pipeline(transaction=True) as pipe: + await pipe.zadd(key, {member: now}) + await pipe.zremrangebyscore(key, 0, now - self.window) + await pipe.zrange(key, 0, -1, withscores=True) + pipe.zrem(key, member) + results = await pipe.execute() + entries = results[2] + total = 0 + for m, _ in entries: + try: + _, tok = m.split(":", 1) + total += int(tok) + except Exception: + pass + return total - def __init__(self, app, store: Optional[AbstractMeterStore] = None) -> None: - """Create middleware. - Must be added **after** :func:`session_mw` so the ``X-Attach-User`` - header or client IP is available for quota tracking. 
- """ - super().__init__(app) - self.store = store or InMemoryMeterStore() - self.window = 60 - self.max_tokens = int(os.getenv("MAX_TOKENS_PER_MIN", "60000")) - enc_name = os.getenv("QUOTA_ENCODING", "cl100k_base") - try: - import tiktoken +def _is_textual(mime: str) -> bool: + mime = (mime or "").lower() + if not mime or mime == "*/*": + return False + return mime.startswith("text/") or "json" in mime + + +def _encoder_for_model(model: str): + """Return a tiktoken encoder, falling back to byte count.""" + if tiktoken is None: # pragma: no cover - fallback + logger.warning( + "tiktoken missing – using byte-count fallback; token metrics may be inflated" + ) - self.encoder = tiktoken.get_encoding(enc_name) - except Exception: # pragma: no cover - fallback - if "tiktoken" not in sys.modules: - logger.warning( - "tiktoken missing – using byte-count fallback; token metrics may be inflated" - ) + class _Simple: + def encode(self, text: str) -> list[int]: + return list(text.encode()) + + return _Simple() + + try: + return tiktoken.encoding_for_model(model) + except Exception: + try: + return tiktoken.get_encoding("cl100k_base") + except Exception: class _Simple: def encode(self, text: str) -> list[int]: return list(text.encode()) - self.encoder = _Simple() + return _Simple() + + +def _num_tokens(text: str, model: str = "cl100k_base") -> int: + return len(_encoder_for_model(model).encode(text)) - @staticmethod - def _is_textual(mime: str) -> bool: - return mime.startswith("text/") or "json" in mime - def _num_tokens(self, text: str) -> int: - return len(self.encoder.encode(text)) +def num_tokens_from_messages(messages: Iterable[dict], model: str) -> int: + enc = _encoder_for_model(model) + total = 3 + for msg in messages: + total += 4 + for k, v in msg.items(): + total += len(enc.encode(str(v))) + if k == "name": + total -= 1 + return total + + +async def async_iter(data: Iterable[bytes]) -> AsyncIterator[bytes]: + for chunk in data: + yield chunk + + +_SKIP_PATHS = { + "/metrics", + "/mem/events", + "/auth/config", + "/health", + "/docs", + "/redoc", + "/openapi.json", +} + + +class TokenQuotaMiddleware(BaseHTTPMiddleware): + """Apply per-user LLM token quotas.""" + + def __init__(self, app, store: AbstractMeterStore | None = None) -> None: + super().__init__(app) + self.window = int(os.getenv("WINDOW", "60")) + self.max_tokens = int(os.getenv("MAX_TOKENS_PER_MIN", "60000")) + if store is not None: + self.store = store + else: + redis_url = os.getenv("REDIS_URL") + self.store = ( + RedisMeterStore(redis_url, self.window) + if redis_url + else InMemoryMeterStore(self.window) + ) - @staticmethod - def _is_monitoring_endpoint(path: str) -> bool: - """Monitoring endpoints that should skip token-based rate limiting.""" - monitoring_paths = [ - "/metrics", # Prometheus metrics - "/mem/events", # Memory events - "/auth/config", # Auth configuration - ] - return any(path.startswith(p) for p in monitoring_paths) + class _Streamer: + def __init__( + self, + iterator: AsyncIterator[bytes], + *, + user: str, + store: AbstractMeterStore, + max_tokens: int, + is_textual: bool, + ) -> None: + self.iterator = iterator + self.tail = bytearray() + self.on_complete: Callable[[], Awaitable[None]] | None = None + self.user = user + self.store = store + self.limit = max_tokens + self.is_textual = is_textual + self.quota_exceeded = False + + def __aiter__(self) -> AsyncIterator[bytes]: + return self._gen() + + async def _gen(self) -> AsyncIterator[bytes]: + try: + async for chunk in self.iterator: + 
+                    self.tail.extend(chunk)
+                    if len(self.tail) > 8192:
+                        del self.tail[:-8192]
+                    if self.limit:
+                        chunk_tokens = (
+                            _num_tokens(chunk.decode("utf-8", "ignore"))
+                            if self.is_textual
+                            else 0
+                        )
+                        if chunk_tokens:
+                            total, _ = await self.store.adjust(self.user, chunk_tokens)
+                            if total > self.limit:
+                                self.quota_exceeded = True
+                                logger.warning(
+                                    "User %s quota breached mid-stream", self.user
+                                )
+                                return
+                    yield chunk
+            finally:
+                if self.on_complete:
+                    await self.on_complete()
+
+        def get_tail(self) -> bytes:
+            return bytes(self.tail)
 
     async def dispatch(self, request: Request, call_next):
-        # Skip token-based rate limiting for monitoring endpoints
-        if self._is_monitoring_endpoint(request.url.path):
+        if any(request.url.path.startswith(p) for p in _SKIP_PATHS):
             return await call_next(request)
-
+
         if not hasattr(request.app.state, "usage"):
             request.app.state.usage = NullUsageBackend()
-        user = request.headers.get("x-attach-user")
-        if not user:
-            user = request.client.host if request.client else "unknown"
+
+        user = request.headers.get("x-attach-user") or (
+            request.client.host if request.client else "unknown"
+        )
         usage = {
             "user": user,
             "project": request.headers.get("x-attach-project", "default"),
@@ -149,16 +298,27 @@ async def dispatch(self, request: Request, call_next):
             "model": "unknown",
             "request_id": request.headers.get("x-request-id") or str(uuid4()),
         }
+        req_is_text = _is_textual(request.headers.get("content-type", ""))
+
+        tokens_in = 0
+        if req_is_text:
+            raw = await request.body()
+            try:
+                payload = json.loads(raw.decode())
+            except Exception:
+                payload = None
+            if isinstance(payload, dict) and "messages" in payload:
+                model = payload.get("model", "cl100k_base")
+                usage["model"] = model
+                tokens_in = num_tokens_from_messages(payload.get("messages", []), model)
+            else:
+                tokens_in = _num_tokens(raw.decode("utf-8", "ignore"))
+
+        usage["tokens_in"] = tokens_in
 
-        body = await request.body()
-        content_type = request.headers.get("content-type", "")
-        tokens = 0
-        if self._is_textual(content_type):
-            tokens = self._num_tokens(body.decode("utf-8", "ignore"))
-        usage["tokens_in"] = tokens
-        total, oldest = await self.store.increment(user, tokens)
-        request.state._usage = usage
+        total, oldest = await self.store.adjust(user, tokens_in)
         if total > self.max_tokens:
+            await self.store.adjust(user, -tokens_in)
             retry_after = max(0, int(self.window - (time.time() - oldest)))
             usage["ts"] = time.time()
             await request.app.state.usage.record(**usage)
@@ -167,53 +327,82 @@ async def dispatch(self, request: Request, call_next):
                 status_code=429,
             )
 
-        response = await call_next(request)
-        usage["model"] = response.headers.get("x-llm-model", "unknown")
+        resp = await call_next(request)
 
-        first_chunk = None
-        try:
-            first_chunk = await response.body_iterator.__anext__()
-        except StopAsyncIteration:
-            pass
-
-        media = response.media_type or response.headers.get("content-type", "")
-        if first_chunk is not None and self._is_textual(media):
-            tokens_chunk = self._num_tokens(first_chunk.decode("utf-8", "ignore"))
-            total, oldest = await self.store.increment(user, tokens_chunk)
-            if total > self.max_tokens:
-                retry_after = max(0, int(self.window - (time.time() - oldest)))
-                usage["tokens_out"] += tokens_chunk
+        media = resp.media_type or resp.headers.get("content-type", "")
+        resp_is_text = _is_textual(media)
+        streamer = self._Streamer(
+            resp.body_iterator,
+            user=user,
+            store=self.store,
+            max_tokens=self.max_tokens,
+            is_textual=resp_is_text,
+        )
+
+        headers = dict(resp.headers)
+        headers.pop("content-length", None)
+        response = StreamingResponse(
+            content=streamer,
+            status_code=resp.status_code,
+            headers=headers,
+            media_type=resp.media_type,
+        )
+
+        async def finalize() -> None:
+            nonlocal tokens_in
+            if getattr(streamer, "quota_exceeded", False):
+                usage["detail"] = "token quota exceeded mid-stream"
                 usage["ts"] = time.time()
                 await request.app.state.usage.record(**usage)
-                return JSONResponse(
-                    {"detail": "token quota exceeded", "retry_after": retry_after},
-                    status_code=429,
-                )
-            usage["tokens_out"] += tokens_chunk
-
-        async def stream_with_quota():
-            nonlocal total, oldest
-            if first_chunk is not None:
-                yield first_chunk
-            async for chunk in response.body_iterator:
-                tokens_chunk = 0
-                media = response.media_type or response.headers.get("content-type", "")
-                if self._is_textual(media):
-                    tokens_chunk = self._num_tokens(chunk.decode("utf-8", "ignore"))
-                next_total, oldest = await self.store.increment(user, tokens_chunk)
-                if next_total > self.max_tokens:
-                    break
-                total = next_total
-                if tokens_chunk:
-                    usage["tokens_out"] += tokens_chunk
-                yield chunk
+                return
+            tokens_out = 0
+            model = usage.get("model", "unknown")
+            parsed: dict | None = None
+            if resp_is_text:
+                text_tail = streamer.get_tail().decode("utf-8", "ignore")
+                # Best effort: parse the last JSON object in the tail buffer.
+                # rfind('{"') can never beat rfind("{"), so one search suffices.
+                idx = text_tail.rfind("{")
+                if idx != -1:
+                    try:
+                        parsed = json.loads(text_tail[idx:])
+                    except Exception:
+                        parsed = None
+                if isinstance(parsed, dict):
+                    model = parsed.get("model", model)
+                    if "usage" in parsed:
+                        u = parsed.get("usage") or {}
+                        tokens_out = int(u.get("completion_tokens", 0))
+                        prompt_tokens = int(u.get("prompt_tokens", tokens_in))
+                        delta_prompt = prompt_tokens - tokens_in
+                        tokens_in = prompt_tokens
+                        usage["tokens_in"] = tokens_in
+                        await self.store.adjust(user, delta_prompt)
+                    elif "choices" in parsed:
+                        msgs = [
+                            (c.get("message") or {}).get("content", "")
+                            for c in parsed.get("choices", [])
+                        ]
+                        tokens_out = num_tokens_from_messages(
+                            [{"content": m} for m in msgs], model
+                        )
+                if tokens_out == 0:
+                    tokens_out = _num_tokens(text_tail)
+
+            usage["tokens_out"] = tokens_out
+            usage["model"] = model
+
+            await self.store.adjust(user, tokens_out)
+
+            total = await self.store.peek_total(user)
+            if total > self.max_tokens:
+                usage["detail"] = "token quota exceeded post-stream"
+
+            # NOTE: the header block has already been sent by the time the
+            # stream finishes, so these mutations are visible in-process only.
+            response.headers["x-llm-model"] = model
+            response.headers["x-tokens-in"] = str(tokens_in)
+            response.headers["x-tokens-out"] = str(tokens_out)
             usage["ts"] = time.time()
             await request.app.state.usage.record(**usage)
+            logger.info(json.dumps(usage))
 
-        return StreamingResponse(
-            stream_with_quota(),
-            status_code=response.status_code,
-            headers=dict(response.headers),
-            media_type=response.media_type,
-        )
+        streamer.on_complete = finalize
+        return response
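
A minimal wiring sketch, not part of the diff: the `app` object, module path, and
environment defaults below are illustrative assumptions. The patch itself only
defines `TokenQuotaMiddleware` and reads MAX_TOKENS_PER_MIN, WINDOW, and
REDIS_URL from the environment; the pre-patch docstring notes that a session
layer must run first so ``X-Attach-User`` is populated before quota accounting.

    import os

    from fastapi import FastAPI

    from middleware.quota import TokenQuotaMiddleware

    # Assumed budget: 60k tokens per user per 60-second sliding window.
    os.environ.setdefault("MAX_TOKENS_PER_MIN", "60000")
    os.environ.setdefault("WINDOW", "60")
    # Optional: point REDIS_URL at a shared instance so all worker processes
    # meter against one counter; without it each process falls back to its
    # own InMemoryMeterStore.
    # os.environ["REDIS_URL"] = "redis://localhost:6379"

    app = FastAPI()

    # Register so the session middleware processes each request before the
    # quota layer reads X-Attach-User; requests over budget receive a 429
    # with a retry_after hint, and streamed responses are cut off mid-stream
    # once the window total exceeds the limit.
    app.add_middleware(TokenQuotaMiddleware)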