Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 18 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -237,14 +237,15 @@ curl -X POST /v1/logs \
## Usage hooks

Emit token usage metrics for every request. Choose a backend via
`USAGE_BACKEND`:
`USAGE_METERING` (`USAGE_BACKEND` is still accepted as a deprecated alias):

```bash
export USAGE_BACKEND=prometheus # or openmeter/null
export USAGE_METERING=prometheus  # or openmeter / null
```

A Prometheus counter `attach_usage_tokens_total{user,direction,model}` is
exposed for Grafana dashboards.
Set `USAGE_METERING=null` (the default) to disable metering entirely.

> **⚠️ Usage hooks depend on the quota middleware.**
> Make sure `MAX_TOKENS_PER_MIN` is set (any positive number) so the
Expand All @@ -254,9 +255,23 @@ exposed for Grafana dashboards.
```bash
# Enable usage tracking (set any reasonable limit)
export MAX_TOKENS_PER_MIN=60000
export USAGE_BACKEND=prometheus
export USAGE_METERING=prometheus
```

#### OpenMeter (Stripe / ClickHouse)

```bash
pip install "attach-gateway[usage]"
export MAX_TOKENS_PER_MIN=60000
export USAGE_METERING=openmeter
export OPENMETER_API_KEY=...
export OPENMETER_URL=http://localhost:8888 # optional self-host, defaults to https://openmeter.cloud
```

Events land in OpenMeter's `tokens` meter and can be synced to Stripe.

The gateway runs fine without these variables; OpenMeter metering activates only when both `USAGE_METERING=openmeter` and `OPENMETER_API_KEY` are set.

### Scraping metrics

```bash
Expand Down
5 changes: 3 additions & 2 deletions attach/gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from middleware.quota import TokenQuotaMiddleware
from middleware.session import session_mw
from proxy.engine import router as proxy_router
from usage.factory import get_usage_backend
from usage.factory import _select_backend, get_usage_backend
from usage.metrics import mount_metrics

# Import version from parent package
Expand Down Expand Up @@ -147,6 +147,7 @@ def create_app(config: Optional[AttachConfig] = None) -> FastAPI:
memory_backend = get_memory_backend(config.mem_backend, config)
app.state.memory = memory_backend
app.state.config = config
app.state.usage = get_usage_backend(os.getenv("USAGE_BACKEND", "null"))
backend_selector = _select_backend()
app.state.usage = get_usage_backend(backend_selector)

return app
5 changes: 3 additions & 2 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from middleware.auth import jwt_auth_mw # ← your auth middleware
from middleware.session import session_mw # ← generates session-id header
from proxy.engine import router as proxy_router
from usage.factory import get_usage_backend
from usage.factory import _select_backend, get_usage_backend
from usage.metrics import mount_metrics

# At the top, make the import conditional
Expand Down Expand Up @@ -132,7 +132,8 @@ async def get_memory_events(request: Request, limit: int = 10):

# Create app without middleware first
app = FastAPI(title="attach-gateway", middleware=middlewares)
app.state.usage = get_usage_backend(os.getenv("USAGE_BACKEND", "null"))
backend_selector = _select_backend()
app.state.usage = get_usage_backend(backend_selector)
mount_metrics(app)


Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@ memory = ["weaviate-client>=3.26.7,<4.0.0"]
temporal = ["temporalio>=1.5.0"]
quota = ["tiktoken>=0.5.0"]
usage = [
"prometheus_client>=0.20.0"
"prometheus_client>=0.20.0",
"openmeter>=1.0.0b188,<2", # openmeter>=1.0.0b188 requires Python >= 3.9
]
dev = [
"pytest>=8.0.0",
Expand Down
19 changes: 10 additions & 9 deletions tests/test_prometheus_fallback.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,19 @@
import os
import pytest
from usage.factory import get_usage_backend

from usage.factory import get_usage_backend

pytest.importorskip("prometheus_client")

def test_prometheus_backend_falls_back_to_null_when_unavailable():
os.environ["USAGE_BACKEND"] = "prometheus"


def test_prometheus_backend_falls_back_to_null_when_unavailable(monkeypatch):
monkeypatch.setenv("USAGE_METERING", "prometheus")

# This should return NullUsageBackend when prometheus_client unavailable
backend = get_usage_backend("prometheus")
if hasattr(backend, 'counter'):

if hasattr(backend, "counter"):
# prometheus_client available - got PrometheusUsageBackend
assert backend.__class__.__name__ == 'PrometheusUsageBackend'
assert backend.__class__.__name__ == "PrometheusUsageBackend"
else:
# prometheus_client unavailable - got NullUsageBackend fallback
assert backend.__class__.__name__ == 'NullUsageBackend'
assert backend.__class__.__name__ == "NullUsageBackend"
66 changes: 66 additions & 0 deletions tests/test_usage_openmeter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
import os
import sys
import types

import pytest


class DummyEvents:
    """Fake ``events`` endpoint that remembers the last event it was given."""

    def __init__(self) -> None:
        # Keyword arguments of the most recent ``create`` call; ``None`` until
        # ``create`` has been awaited at least once.
        self.called = None

    async def create(self, **event) -> None:
        """Store the event fields instead of sending them anywhere."""
        self.called = event


class DummyClient:
    """Minimal fake client exposing ``api_key``, ``base_url`` and ``events``."""

    def __init__(self, api_key: str, base_url: str = "https://openmeter.cloud") -> None:
        # Keep the constructor arguments so tests can inspect what was passed.
        self.events = DummyEvents()
        self.base_url = base_url
        self.api_key = api_key

    async def aclose(self) -> None:
        """Do nothing; the fake holds no connection to release."""
        return None


dummy_module = types.SimpleNamespace(Client=DummyClient)


@pytest.mark.asyncio
async def test_openmeter_backend_create(monkeypatch):
    """With an API key configured, ``record`` hands a ``tokens`` event to the client."""
    monkeypatch.setitem(sys.modules, "openmeter", dummy_module)
    # Evict any cached usage modules so the import below executes them afresh.
    for cached_name in ("usage.backends", "usage.factory"):
        sys.modules.pop(cached_name, None)
    from usage.factory import get_usage_backend

    monkeypatch.setenv("OPENMETER_API_KEY", "k")
    monkeypatch.setenv("OPENMETER_URL", "https://example.com")

    backend = get_usage_backend("openmeter")
    assert backend.__class__.__name__ == "OpenMeterBackend"
    await backend.record(user="bob", tokens_in=1, tokens_out=2, model="m")
    await backend.aclose()

    # The dummy events object keeps the keyword arguments of the last call.
    sent = backend.client.events.called
    assert sent["type"] == "tokens"
    assert sent["subject"] == "bob"
    assert sent["project"] is None
    assert sent["data"] == {"tokens_in": 1, "tokens_out": 2, "model": "m"}
    assert "time" in sent


def test_openmeter_backend_missing_key(monkeypatch):
    """Without ``OPENMETER_API_KEY`` the factory returns the null backend."""
    monkeypatch.setitem(sys.modules, "openmeter", dummy_module)
    # Evict any cached usage modules so the import below executes them afresh.
    for cached_name in ("usage.backends", "usage.factory"):
        sys.modules.pop(cached_name, None)
    from usage.factory import get_usage_backend

    monkeypatch.setenv("USAGE_METERING", "openmeter")
    monkeypatch.delenv("OPENMETER_API_KEY", raising=False)

    fallback = get_usage_backend("openmeter")
    assert fallback.__class__.__name__ == "NullUsageBackend"
17 changes: 9 additions & 8 deletions tests/test_usage_prometheus.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
import os
import sys

import pytest
from fastapi import FastAPI, Request
from httpx import ASGITransport, AsyncClient

# Force reload to pick up prometheus_client if just installed
if 'usage.backends' in sys.modules:
del sys.modules['usage.backends']
if 'usage.factory' in sys.modules:
del sys.modules['usage.factory']
if "usage.backends" in sys.modules:
del sys.modules["usage.backends"]
if "usage.factory" in sys.modules:
del sys.modules["usage.factory"]

from middleware.quota import TokenQuotaMiddleware
from usage.factory import get_usage_backend
Expand All @@ -21,11 +22,11 @@
async def test_prometheus_backend_counts_tokens(monkeypatch):
# Verify we actually get PrometheusUsageBackend
backend = get_usage_backend("prometheus")
if not hasattr(backend, 'counter'):
if not hasattr(backend, "counter"):
pytest.skip("prometheus_client not available, got NullUsageBackend")
os.environ["USAGE_BACKEND"] = "prometheus"
os.environ["MAX_TOKENS_PER_MIN"] = "1000"

monkeypatch.setenv("USAGE_METERING", "prometheus")
monkeypatch.setenv("MAX_TOKENS_PER_MIN", "1000")
app = FastAPI()
app.add_middleware(TokenQuotaMiddleware)
app.state.usage = backend # Use the backend we verified
Expand Down
55 changes: 52 additions & 3 deletions usage/backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,14 @@

"""Usage accounting backends for Attach Gateway."""

import inspect
import logging
import os
from datetime import datetime, timezone
from typing import Protocol

logger = logging.getLogger(__name__)

try:
from prometheus_client import Counter
except Exception: # pragma: no cover - optional dep
Expand Down Expand Up @@ -80,7 +86,50 @@ async def record(self, **evt) -> None:


class OpenMeterBackend:
"""Stub for future OpenMeter integration."""
"""Send token usage events to OpenMeter."""

def __init__(self) -> None:
try:
from openmeter import Client # type: ignore
except Exception as exc: # pragma: no cover - optional dep
raise ImportError("openmeter package is required") from exc

api_key = os.getenv("OPENMETER_API_KEY")
if not api_key:
raise ImportError("OPENMETER_API_KEY is required for OpenMeter")

async def record(self, **evt) -> None: # pragma: no cover - not implemented
raise NotImplementedError
url = os.getenv("OPENMETER_URL", "https://openmeter.cloud")
self.client = Client(api_key=api_key, base_url=url)

async def aclose(self) -> None:
    """Close the underlying OpenMeter client on a best-effort basis.

    Any exception raised while closing is logged at debug level and not
    re-raised, so a failing close never propagates to the caller.
    """
    try:
        await self.client.aclose()  # type: ignore[call-arg]
    except Exception as exc:  # pragma: no cover - optional
        # Deliberately non-fatal, but leave a trace instead of hiding it.
        logger.debug("OpenMeter client close failed: %s", exc)

async def record(self, **evt) -> None:
    """Send one token-usage event to OpenMeter.

    Reads ``user``, ``project``, ``tokens_in``, ``tokens_out`` and ``model``
    from ``evt``. Missing or falsy token counts are sent as ``0``. The event
    is stamped with the current UTC time in ISO-8601 form with a ``Z`` suffix.

    Delivery is best effort: any exception raised by the client is logged as
    a warning and not re-raised.
    """
    import asyncio
    import functools

    timestamp = (
        datetime.now(timezone.utc)
        .isoformat(timespec="milliseconds")
        .replace("+00:00", "Z")
    )
    event = {
        "type": "tokens",
        "subject": evt.get("user"),
        "project": evt.get("project"),
        "time": timestamp,
        "data": {
            "tokens_in": int(evt.get("tokens_in", 0) or 0),
            "tokens_out": int(evt.get("tokens_out", 0) or 0),
            "model": evt.get("model"),
        },
    }

    create_fn = self.client.events.create
    try:
        if inspect.iscoroutinefunction(create_fn):
            await create_fn(**event)  # type: ignore[arg-type]
        else:
            # Blocking client: run it on a worker thread. Thread-offload
            # helpers forward positional arguments only, so the keyword
            # arguments are bound up front with functools.partial.
            loop = asyncio.get_running_loop()
            await loop.run_in_executor(None, functools.partial(create_fn, **event))
    except Exception as exc:  # pragma: no cover - network errors
        logger.warning("OpenMeter create failed: %s", exc)
21 changes: 20 additions & 1 deletion usage/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@

"""Factory for usage backends."""

import os
import warnings

from .backends import (
AbstractUsageBackend,
NullUsageBackend,
Expand All @@ -10,6 +13,19 @@
)


def _select_backend() -> str:
    """Return the usage backend name configured through the environment.

    ``USAGE_METERING`` takes precedence. If it is unset, the legacy
    ``USAGE_BACKEND`` variable is honoured and a ``UserWarning`` is emitted
    asking for the rename. When neither variable is set, ``"null"`` is
    returned so the result is always a string.
    """
    metering = os.environ.get("USAGE_METERING")
    if metering is not None:
        return metering

    legacy = os.environ.get("USAGE_BACKEND")
    if legacy is not None:
        warnings.warn(
            "USAGE_BACKEND is deprecated; rename to USAGE_METERING",
            UserWarning,
            stacklevel=2,
        )
        return legacy

    # Neither variable is set: metering is disabled.
    return "null"


def get_usage_backend(kind: str) -> AbstractUsageBackend:
"""Return an instance of the requested usage backend."""
kind = (kind or "null").lower()
Expand All @@ -19,5 +35,8 @@ def get_usage_backend(kind: str) -> AbstractUsageBackend:
except Exception:
return NullUsageBackend()
if kind == "openmeter":
return OpenMeterBackend()
try:
return OpenMeterBackend()
except Exception:
return NullUsageBackend()
return NullUsageBackend()