diff --git a/README.md b/README.md index 9d94594..48e8155 100644 --- a/README.md +++ b/README.md @@ -237,14 +237,15 @@ curl -X POST /v1/logs \ ## Usage hooks Emit token usage metrics for every request. Choose a backend via -`USAGE_BACKEND`: +`USAGE_METERING` (alias `USAGE_BACKEND`): ```bash -export USAGE_BACKEND=prometheus # or openmeter/null +export USAGE_METERING=prometheus # or null ``` A Prometheus counter `attach_usage_tokens_total{user,direction,model}` is exposed for Grafana dashboards. +Set `USAGE_METERING=null` (the default) to disable metering entirely. > **⚠️ Usage hooks depend on the quota middleware.** > Make sure `MAX_TOKENS_PER_MIN` is set (any positive number) so the @@ -254,9 +255,23 @@ exposed for Grafana dashboards. ```bash # Enable usage tracking (set any reasonable limit) export MAX_TOKENS_PER_MIN=60000 -export USAGE_BACKEND=prometheus +export USAGE_METERING=prometheus ``` +#### OpenMeter (Stripe / ClickHouse) + +```bash +pip install "attach-gateway[usage]" +export MAX_TOKENS_PER_MIN=60000 +export USAGE_METERING=openmeter +export OPENMETER_API_KEY=... +export OPENMETER_URL=http://localhost:8888 # optional self-host, defaults to https://openmeter.cloud +``` + +Events land in the tokens meter of OpenMeter and can sync to Stripe. + +The gateway runs fine without these vars; metering activates only when both USAGE_METERING=openmeter and OPENMETER_API_KEY are set. + ### Scraping metrics ```bash diff --git a/attach/gateway.py b/attach/gateway.py index d44d41d..2683e41 100644 --- a/attach/gateway.py +++ b/attach/gateway.py @@ -21,7 +21,7 @@ from middleware.quota import TokenQuotaMiddleware from middleware.session import session_mw from proxy.engine import router as proxy_router -from usage.factory import get_usage_backend +from usage.factory import _select_backend, get_usage_backend from usage.metrics import mount_metrics # Import version from parent package @@ -147,6 +147,7 @@ def create_app(config: Optional[AttachConfig] = None) -> FastAPI: memory_backend = get_memory_backend(config.mem_backend, config) app.state.memory = memory_backend app.state.config = config - app.state.usage = get_usage_backend(os.getenv("USAGE_BACKEND", "null")) + backend_selector = _select_backend() + app.state.usage = get_usage_backend(backend_selector) return app diff --git a/main.py b/main.py index e4f4d81..04236c1 100644 --- a/main.py +++ b/main.py @@ -13,7 +13,7 @@ from middleware.auth import jwt_auth_mw # ← your auth middleware from middleware.session import session_mw # ← generates session-id header from proxy.engine import router as proxy_router -from usage.factory import get_usage_backend +from usage.factory import _select_backend, get_usage_backend from usage.metrics import mount_metrics # At the top, make the import conditional @@ -132,7 +132,8 @@ async def get_memory_events(request: Request, limit: int = 10): # Create app without middleware first app = FastAPI(title="attach-gateway", middleware=middlewares) -app.state.usage = get_usage_backend(os.getenv("USAGE_BACKEND", "null")) +backend_selector = _select_backend() +app.state.usage = get_usage_backend(backend_selector) mount_metrics(app) diff --git a/pyproject.toml b/pyproject.toml index 52dc93a..f74f709 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,7 +38,8 @@ memory = ["weaviate-client>=3.26.7,<4.0.0"] temporal = ["temporalio>=1.5.0"] quota = ["tiktoken>=0.5.0"] usage = [ - "prometheus_client>=0.20.0" + "prometheus_client>=0.20.0", + "openmeter>=1.0.0b188,<2", # openmeter>=1.0.0b188 requires Python >= 3.9 ] dev = [ "pytest>=8.0.0", diff --git a/tests/test_prometheus_fallback.py b/tests/test_prometheus_fallback.py index f1772aa..0536d0c 100644 --- a/tests/test_prometheus_fallback.py +++ b/tests/test_prometheus_fallback.py @@ -1,18 +1,19 @@ -import os import pytest -from usage.factory import get_usage_backend + +from usage.factory import get_usage_backend pytest.importorskip("prometheus_client") -def test_prometheus_backend_falls_back_to_null_when_unavailable(): - os.environ["USAGE_BACKEND"] = "prometheus" - + +def test_prometheus_backend_falls_back_to_null_when_unavailable(monkeypatch): + monkeypatch.setenv("USAGE_METERING", "prometheus") + # This should return NullUsageBackend when prometheus_client unavailable backend = get_usage_backend("prometheus") - - if hasattr(backend, 'counter'): + + if hasattr(backend, "counter"): # prometheus_client available - got PrometheusUsageBackend - assert backend.__class__.__name__ == 'PrometheusUsageBackend' + assert backend.__class__.__name__ == "PrometheusUsageBackend" else: # prometheus_client unavailable - got NullUsageBackend fallback - assert backend.__class__.__name__ == 'NullUsageBackend' \ No newline at end of file + assert backend.__class__.__name__ == "NullUsageBackend" diff --git a/tests/test_usage_openmeter.py b/tests/test_usage_openmeter.py new file mode 100644 index 0000000..171dd54 --- /dev/null +++ b/tests/test_usage_openmeter.py @@ -0,0 +1,66 @@ +import os +import sys +import types + +import pytest + + +class DummyEvents: + def __init__(self): + self.called = None + + async def create(self, **event): + self.called = event + + +class DummyClient: + def __init__(self, api_key: str, base_url: str = "https://openmeter.cloud") -> None: + self.api_key = api_key + self.base_url = base_url + self.events = DummyEvents() + + async def aclose(self) -> None: + pass + + +dummy_module = types.SimpleNamespace(Client=DummyClient) + + +@pytest.mark.asyncio +async def test_openmeter_backend_create(monkeypatch): + monkeypatch.setitem(sys.modules, "openmeter", dummy_module) + if "usage.backends" in sys.modules: + del sys.modules["usage.backends"] + if "usage.factory" in sys.modules: + del sys.modules["usage.factory"] + from usage.factory import get_usage_backend + + monkeypatch.setenv("OPENMETER_API_KEY", "k") + monkeypatch.setenv("OPENMETER_URL", "https://example.com") + + backend = get_usage_backend("openmeter") + assert backend.__class__.__name__ == "OpenMeterBackend" + await backend.record(user="bob", tokens_in=1, tokens_out=2, model="m") + await backend.aclose() + + called = backend.client.events.called + assert called["type"] == "tokens" + assert called["subject"] == "bob" + assert called["project"] is None + assert called["data"] == {"tokens_in": 1, "tokens_out": 2, "model": "m"} + assert "time" in called + + +def test_openmeter_backend_missing_key(monkeypatch): + monkeypatch.setitem(sys.modules, "openmeter", dummy_module) + if "usage.backends" in sys.modules: + del sys.modules["usage.backends"] + if "usage.factory" in sys.modules: + del sys.modules["usage.factory"] + from usage.factory import get_usage_backend + + monkeypatch.setenv("USAGE_METERING", "openmeter") + monkeypatch.delenv("OPENMETER_API_KEY", raising=False) + + backend = get_usage_backend("openmeter") + assert backend.__class__.__name__ == "NullUsageBackend" diff --git a/tests/test_usage_prometheus.py b/tests/test_usage_prometheus.py index ba1573a..d009a0f 100644 --- a/tests/test_usage_prometheus.py +++ b/tests/test_usage_prometheus.py @@ -1,14 +1,15 @@ import os import sys + import pytest from fastapi import FastAPI, Request from httpx import ASGITransport, AsyncClient # Force reload to pick up prometheus_client if just installed -if 'usage.backends' in sys.modules: - del sys.modules['usage.backends'] -if 'usage.factory' in sys.modules: - del sys.modules['usage.factory'] +if "usage.backends" in sys.modules: + del sys.modules["usage.backends"] +if "usage.factory" in sys.modules: + del sys.modules["usage.factory"] from middleware.quota import TokenQuotaMiddleware from usage.factory import get_usage_backend @@ -21,11 +22,11 @@ async def test_prometheus_backend_counts_tokens(monkeypatch): # Verify we actually get PrometheusUsageBackend backend = get_usage_backend("prometheus") - if not hasattr(backend, 'counter'): + if not hasattr(backend, "counter"): pytest.skip("prometheus_client not available, got NullUsageBackend") - - os.environ["USAGE_BACKEND"] = "prometheus" - os.environ["MAX_TOKENS_PER_MIN"] = "1000" + + monkeypatch.setenv("USAGE_METERING", "prometheus") + monkeypatch.setenv("MAX_TOKENS_PER_MIN", "1000") app = FastAPI() app.add_middleware(TokenQuotaMiddleware) app.state.usage = backend # Use the backend we verified diff --git a/usage/backends.py b/usage/backends.py index 4731da9..734c503 100644 --- a/usage/backends.py +++ b/usage/backends.py @@ -2,8 +2,14 @@ """Usage accounting backends for Attach Gateway.""" +import inspect +import logging +import os +from datetime import datetime, timezone from typing import Protocol +logger = logging.getLogger(__name__) + try: from prometheus_client import Counter except Exception: # pragma: no cover - optional dep @@ -80,7 +86,50 @@ async def record(self, **evt) -> None: class OpenMeterBackend: - """Stub for future OpenMeter integration.""" + """Send token usage events to OpenMeter.""" + + def __init__(self) -> None: + try: + from openmeter import Client # type: ignore + except Exception as exc: # pragma: no cover - optional dep + raise ImportError("openmeter package is required") from exc + + api_key = os.getenv("OPENMETER_API_KEY") + if not api_key: + raise ImportError("OPENMETER_API_KEY is required for OpenMeter") - async def record(self, **evt) -> None: # pragma: no cover - not implemented - raise NotImplementedError + url = os.getenv("OPENMETER_URL", "https://openmeter.cloud") + self.client = Client(api_key=api_key, base_url=url) + + async def aclose(self) -> None: + """Close the underlying OpenMeter client.""" + try: + await self.client.aclose() # type: ignore[call-arg] + except Exception: # pragma: no cover - optional + pass + + async def record(self, **evt) -> None: + event = { + "type": "tokens", + "subject": evt.get("user"), + "project": evt.get("project"), + "time": datetime.now(timezone.utc) + .isoformat(timespec="milliseconds") + .replace("+00:00", "Z"), + "data": { + "tokens_in": int(evt.get("tokens_in", 0) or 0), + "tokens_out": int(evt.get("tokens_out", 0) or 0), + "model": evt.get("model"), + }, + } + + create_fn = self.client.events.create + import anyio + + try: + if inspect.iscoroutinefunction(create_fn): + await create_fn(**event) # type: ignore[arg-type] + else: + await anyio.to_thread.run_sync(create_fn, **event) + except Exception as exc: # pragma: no cover - network errors + logger.warning("OpenMeter create failed: %s", exc) diff --git a/usage/factory.py b/usage/factory.py index b6dbb4c..3663db5 100644 --- a/usage/factory.py +++ b/usage/factory.py @@ -2,6 +2,9 @@ """Factory for usage backends.""" +import os +import warnings + from .backends import ( AbstractUsageBackend, NullUsageBackend, @@ -10,6 +13,19 @@ ) +def _select_backend() -> str: + """Return backend name from env vars with deprecation warning.""" + if "USAGE_METERING" in os.environ: + return os.getenv("USAGE_METERING", "null") + if "USAGE_BACKEND" in os.environ: + warnings.warn( + "USAGE_BACKEND is deprecated; rename to USAGE_METERING", + UserWarning, + stacklevel=2, + ) + return os.getenv("USAGE_BACKEND", "null") + + def get_usage_backend(kind: str) -> AbstractUsageBackend: """Return an instance of the requested usage backend.""" kind = (kind or "null").lower() @@ -19,5 +35,8 @@ def get_usage_backend(kind: str) -> AbstractUsageBackend: except Exception: return NullUsageBackend() if kind == "openmeter": - return OpenMeterBackend() + try: + return OpenMeterBackend() + except Exception: + return NullUsageBackend() return NullUsageBackend()