diff --git a/README.md b/README.md index 8ed1f1e..6d8096f 100644 --- a/README.md +++ b/README.md @@ -234,6 +234,18 @@ curl -X POST /v1/logs \ # => HTTP/1.1 202 Accepted ``` +## Usage hooks + +Emit token usage metrics for every request. Choose a backend via +`USAGE_BACKEND`: + +```bash +export USAGE_BACKEND=prometheus # or openmeter/null +``` + +A Prometheus counter `attach_usage_tokens_total{user,direction,model}` is +exposed for Grafana dashboards. + ## Token quotas Attach Gateway can enforce per-user token limits. Install the optional @@ -244,6 +256,9 @@ counter defaults to the `cl100k_base` encoding; override with in-memory store works in a single process and is not shared between workers—requests retried across processes may be double-counted. Use Redis for production deployments. +If `tiktoken` is missing, a byte-count fallback is used which counts about +four times more tokens than the `cl100k` tokenizer – install `tiktoken` in +production. ### Enable token quotas diff --git a/attach/gateway.py b/attach/gateway.py index 2434389..be7cfbe 100644 --- a/attach/gateway.py +++ b/attach/gateway.py @@ -21,6 +21,7 @@ from middleware.quota import TokenQuotaMiddleware from middleware.session import session_mw from proxy.engine import router as proxy_router +from usage.factory import get_usage_backend # Import version from parent package from . import __version__ @@ -144,5 +145,6 @@ def create_app(config: Optional[AttachConfig] = None) -> FastAPI: memory_backend = get_memory_backend(config.mem_backend, config) app.state.memory = memory_backend app.state.config = config + app.state.usage = get_usage_backend(os.getenv("USAGE_BACKEND", "null")) return app diff --git a/main.py b/main.py index f10e823..f5007e1 100644 --- a/main.py +++ b/main.py @@ -13,10 +13,12 @@ from middleware.auth import jwt_auth_mw # ← your auth middleware from middleware.session import session_mw # ← generates session-id header from proxy.engine import router as proxy_router +from usage.factory import get_usage_backend # At the top, make the import conditional try: from middleware.quota import TokenQuotaMiddleware + QUOTA_AVAILABLE = True except ImportError: QUOTA_AVAILABLE = False @@ -55,7 +57,7 @@ async def get_memory_events(request: Request, limit: int = 10): result = ( client.query.get( "MemoryEvent", - ["timestamp", "event", "user", "state"] + ["timestamp", "event", "user", "state"], ) .with_additional(["id"]) .with_limit(limit) @@ -110,14 +112,16 @@ async def get_memory_events(request: Request, limit: int = 10): middlewares = [ # ❶ CORS first (so it executes last and handles responses properly) - Middleware(CORSMiddleware, - allow_origins=["http://localhost:9000", "http://127.0.0.1:9000"], - allow_methods=["*"], - allow_headers=["*"], - allow_credentials=True), + Middleware( + CORSMiddleware, + allow_origins=["http://localhost:9000", "http://127.0.0.1:9000"], + allow_methods=["*"], + allow_headers=["*"], + allow_credentials=True, + ), # ❷ Auth middleware Middleware(BaseHTTPMiddleware, dispatch=jwt_auth_mw), - # ❸ Session middleware + # ❸ Session middleware Middleware(BaseHTTPMiddleware, dispatch=session_mw), ] @@ -127,6 +131,8 @@ async def get_memory_events(request: Request, limit: int = 10): # Create app without middleware first app = FastAPI(title="attach-gateway", middleware=middlewares) +app.state.usage = get_usage_backend(os.getenv("USAGE_BACKEND", "null")) + @app.get("/auth/config") async def auth_config(): @@ -136,6 +142,7 @@ async def auth_config(): "audience": os.getenv("OIDC_AUD"), } + # Add middleware after routes are defined app.include_router(a2a_router, prefix="/a2a") app.include_router(logs_router) diff --git a/middleware/quota.py b/middleware/quota.py index ecd6553..30304d9 100644 --- a/middleware/quota.py +++ b/middleware/quota.py @@ -6,15 +6,22 @@ from __future__ import annotations +import logging import os +import sys import time from collections import deque from typing import Deque, Dict, Optional, Protocol, Tuple +from uuid import uuid4 from fastapi import Request from starlette.middleware.base import BaseHTTPMiddleware from starlette.responses import JSONResponse, StreamingResponse +from usage.backends import NullUsageBackend + +logger = logging.getLogger(__name__) + class AbstractMeterStore(Protocol): """Interface for token accounting backends.""" @@ -93,11 +100,19 @@ def __init__(self, app, store: Optional[AbstractMeterStore] = None) -> None: enc_name = os.getenv("QUOTA_ENCODING", "cl100k_base") try: import tiktoken - except Exception as imp_err: # pragma: no cover - import guard - raise RuntimeError( - "tiktoken is required for TokenQuotaMiddleware; install with 'attach-gateway[quota]'" - ) from imp_err - self.encoder = tiktoken.get_encoding(enc_name) + + self.encoder = tiktoken.get_encoding(enc_name) + except Exception: # pragma: no cover - fallback + if "tiktoken" not in sys.modules: + logger.warning( + "tiktoken missing – using byte-count fallback; token metrics may be inflated" + ) + + class _Simple: + def encode(self, text: str) -> list[int]: + return list(text.encode()) + + self.encoder = _Simple() @staticmethod def _is_textual(mime: str) -> bool: @@ -107,24 +122,39 @@ def _num_tokens(self, text: str) -> int: return len(self.encoder.encode(text)) async def dispatch(self, request: Request, call_next): + if not hasattr(request.app.state, "usage"): + request.app.state.usage = NullUsageBackend() user = request.headers.get("x-attach-user") if not user: user = request.client.host if request.client else "unknown" + usage = { + "user": user, + "project": request.headers.get("x-attach-project", "default"), + "tokens_in": 0, + "tokens_out": 0, + "model": "unknown", + "request_id": request.headers.get("x-request-id") or str(uuid4()), + } body = await request.body() content_type = request.headers.get("content-type", "") tokens = 0 if self._is_textual(content_type): tokens = self._num_tokens(body.decode("utf-8", "ignore")) + usage["tokens_in"] = tokens total, oldest = await self.store.increment(user, tokens) + request.state._usage = usage if total > self.max_tokens: retry_after = max(0, int(self.window - (time.time() - oldest))) + usage["ts"] = time.time() + await request.app.state.usage.record(**usage) return JSONResponse( {"detail": "token quota exceeded", "retry_after": retry_after}, status_code=429, ) response = await call_next(request) + usage["model"] = response.headers.get("x-llm-model", "unknown") first_chunk = None try: @@ -132,15 +162,20 @@ async def dispatch(self, request: Request, call_next): except StopAsyncIteration: pass - if first_chunk is not None and self._is_textual(response.media_type or ""): + media = response.media_type or response.headers.get("content-type", "") + if first_chunk is not None and self._is_textual(media): tokens_chunk = self._num_tokens(first_chunk.decode("utf-8", "ignore")) total, oldest = await self.store.increment(user, tokens_chunk) if total > self.max_tokens: retry_after = max(0, int(self.window - (time.time() - oldest))) + usage["tokens_out"] += tokens_chunk + usage["ts"] = time.time() + await request.app.state.usage.record(**usage) return JSONResponse( {"detail": "token quota exceeded", "retry_after": retry_after}, status_code=429, ) + usage["tokens_out"] += tokens_chunk async def stream_with_quota(): nonlocal total, oldest @@ -148,14 +183,20 @@ async def stream_with_quota(): yield first_chunk async for chunk in response.body_iterator: tokens_chunk = 0 - if self._is_textual(response.media_type or ""): + media = response.media_type or response.headers.get("content-type", "") + if self._is_textual(media): tokens_chunk = self._num_tokens(chunk.decode("utf-8", "ignore")) next_total, oldest = await self.store.increment(user, tokens_chunk) if next_total > self.max_tokens: break total = next_total + if tokens_chunk: + usage["tokens_out"] += tokens_chunk yield chunk + usage["ts"] = time.time() + await request.app.state.usage.record(**usage) + return StreamingResponse( stream_with_quota(), status_code=response.status_code, diff --git a/tests/test_usage_prometheus.py b/tests/test_usage_prometheus.py new file mode 100644 index 0000000..8609845 --- /dev/null +++ b/tests/test_usage_prometheus.py @@ -0,0 +1,35 @@ +import os + +import pytest +from fastapi import FastAPI, Request +from httpx import ASGITransport, AsyncClient + +from middleware.quota import TokenQuotaMiddleware +from usage.factory import get_usage_backend + + +@pytest.mark.asyncio +async def test_prometheus_backend_counts_tokens(monkeypatch): + os.environ["USAGE_BACKEND"] = "prometheus" + os.environ["MAX_TOKENS_PER_MIN"] = "1000" + app = FastAPI() + app.add_middleware(TokenQuotaMiddleware) + app.state.usage = get_usage_backend(os.getenv("USAGE_BACKEND", "null")) + + @app.post("/echo") + async def echo(request: Request): + data = await request.json() + return {"msg": data.get("msg")} + + transport = ASGITransport(app=app) + headers = {"X-Attach-User": "bob"} + async with AsyncClient(transport=transport, base_url="http://test") as client: + await client.post("/echo", json={"msg": "hi"}, headers=headers) + await client.post("/echo", json={"msg": "there"}, headers=headers) + + c = app.state.usage.counter + in_val = c.labels(user="bob", direction="in", model="unknown")._value.get() + out_val = c.labels(user="bob", direction="out", model="unknown")._value.get() + assert in_val > 0 + assert out_val > 0 + assert in_val + out_val == sum(c.values.values()) diff --git a/usage/__init__.py b/usage/__init__.py new file mode 100644 index 0000000..de197e1 --- /dev/null +++ b/usage/__init__.py @@ -0,0 +1,8 @@ +from __future__ import annotations + +"""Public API for Attach Gateway usage accounting.""" + +from .backends import AbstractUsageBackend +from .factory import get_usage_backend + +__all__ = ["AbstractUsageBackend", "get_usage_backend"] diff --git a/usage/backends.py b/usage/backends.py new file mode 100644 index 0000000..4731da9 --- /dev/null +++ b/usage/backends.py @@ -0,0 +1,86 @@ +from __future__ import annotations + +"""Usage accounting backends for Attach Gateway.""" + +from typing import Protocol + +try: + from prometheus_client import Counter +except Exception: # pragma: no cover - optional dep + Counter = None # type: ignore + + class Counter: # type: ignore[misc] + """Minimal in-memory Counter fallback.""" + + def __init__(self, name: str, desc: str, labelnames: list[str]): + self.labelnames = labelnames + self.values: dict[tuple[str, ...], float] = {} + + def labels(self, **labels): + key = tuple(labels.get(name, "") for name in self.labelnames) + self.values.setdefault(key, 0.0) + + class _Wrapper: + def __init__(self, parent: Counter, k: tuple[str, ...]) -> None: + self.parent = parent + self.k = k + + def inc(self, amt: float) -> None: + self.parent.values[self.k] += amt + + @property + def _value(self): + class V: + def __init__(self, parent: Counter, k: tuple[str, ...]): + self.parent = parent + self.k = k + + def get(self) -> float: + return self.parent.values[self.k] + + return V(self.parent, self.k) + + return _Wrapper(self, key) + + +class AbstractUsageBackend(Protocol): + """Interface for usage event sinks.""" + + async def record(self, **evt) -> None: + """Persist a single usage event.""" + ... + + +class NullUsageBackend: + """No-op usage backend.""" + + async def record(self, **evt) -> None: # pragma: no cover - trivial + return + + +class PrometheusUsageBackend: + """Expose a Prometheus counter for token usage.""" + + def __init__(self) -> None: + if Counter is None: # pragma: no cover - missing lib + raise RuntimeError("prometheus_client is required for this backend") + self.counter = Counter( + "attach_usage_tokens_total", + "Total tokens processed by Attach Gateway", + ["user", "direction", "model"], + ) + + async def record(self, **evt) -> None: + user = evt.get("user", "unknown") + model = evt.get("model", "unknown") + tokens_in = int(evt.get("tokens_in", 0) or 0) + tokens_out = int(evt.get("tokens_out", 0) or 0) + self.counter.labels(user=user, direction="in", model=model).inc(tokens_in) + self.counter.labels(user=user, direction="out", model=model).inc(tokens_out) + + +class OpenMeterBackend: + """Stub for future OpenMeter integration.""" + + async def record(self, **evt) -> None: # pragma: no cover - not implemented + raise NotImplementedError diff --git a/usage/factory.py b/usage/factory.py new file mode 100644 index 0000000..b6dbb4c --- /dev/null +++ b/usage/factory.py @@ -0,0 +1,23 @@ +from __future__ import annotations + +"""Factory for usage backends.""" + +from .backends import ( + AbstractUsageBackend, + NullUsageBackend, + OpenMeterBackend, + PrometheusUsageBackend, +) + + +def get_usage_backend(kind: str) -> AbstractUsageBackend: + """Return an instance of the requested usage backend.""" + kind = (kind or "null").lower() + if kind == "prometheus": + try: + return PrometheusUsageBackend() + except Exception: + return NullUsageBackend() + if kind == "openmeter": + return OpenMeterBackend() + return NullUsageBackend()