Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,18 @@ curl -X POST /v1/logs \
# => HTTP/1.1 202 Accepted
```

## Usage hooks

Emit token usage metrics for every request. Choose a backend via
`USAGE_BACKEND`:

```bash
export USAGE_BACKEND=prometheus # or openmeter/null
```

A Prometheus counter `attach_usage_tokens_total{user,direction,model}` is
exposed for Grafana dashboards.

## Token quotas

Attach Gateway can enforce per-user token limits. Install the optional
Expand All @@ -244,6 +256,9 @@ counter defaults to the `cl100k_base` encoding; override with
in-memory store works in a single process and is not shared between
workers—requests retried across processes may be double-counted. Use Redis
for production deployments.
If `tiktoken` is not installed, a byte-count fallback is used that counts
roughly four times as many tokens as the `cl100k_base` encoding — install
`tiktoken` in production.

### Enable token quotas

Expand Down
2 changes: 2 additions & 0 deletions attach/gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from middleware.quota import TokenQuotaMiddleware
from middleware.session import session_mw
from proxy.engine import router as proxy_router
from usage.factory import get_usage_backend

# Import version from parent package
from . import __version__
Expand Down Expand Up @@ -144,5 +145,6 @@ def create_app(config: Optional[AttachConfig] = None) -> FastAPI:
memory_backend = get_memory_backend(config.mem_backend, config)
app.state.memory = memory_backend
app.state.config = config
app.state.usage = get_usage_backend(os.getenv("USAGE_BACKEND", "null"))

return app
21 changes: 14 additions & 7 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,12 @@
from middleware.auth import jwt_auth_mw # ← your auth middleware
from middleware.session import session_mw # ← generates session-id header
from proxy.engine import router as proxy_router
from usage.factory import get_usage_backend

# At the top, make the import conditional
try:
from middleware.quota import TokenQuotaMiddleware

QUOTA_AVAILABLE = True
except ImportError:
QUOTA_AVAILABLE = False
Expand Down Expand Up @@ -55,7 +57,7 @@ async def get_memory_events(request: Request, limit: int = 10):
result = (
client.query.get(
"MemoryEvent",
["timestamp", "event", "user", "state"]
["timestamp", "event", "user", "state"],
)
.with_additional(["id"])
.with_limit(limit)
Expand Down Expand Up @@ -110,14 +112,16 @@ async def get_memory_events(request: Request, limit: int = 10):

middlewares = [
# ❶ CORS first (so it executes last and handles responses properly)
Middleware(CORSMiddleware,
allow_origins=["http://localhost:9000", "http://127.0.0.1:9000"],
allow_methods=["*"],
allow_headers=["*"],
allow_credentials=True),
Middleware(
CORSMiddleware,
allow_origins=["http://localhost:9000", "http://127.0.0.1:9000"],
allow_methods=["*"],
allow_headers=["*"],
allow_credentials=True,
),
# ❷ Auth middleware
Middleware(BaseHTTPMiddleware, dispatch=jwt_auth_mw),
# ❸ Session middleware
# ❸ Session middleware
Middleware(BaseHTTPMiddleware, dispatch=session_mw),
]

Expand All @@ -127,6 +131,8 @@ async def get_memory_events(request: Request, limit: int = 10):

# Create app without middleware first
app = FastAPI(title="attach-gateway", middleware=middlewares)
app.state.usage = get_usage_backend(os.getenv("USAGE_BACKEND", "null"))


@app.get("/auth/config")
async def auth_config():
Expand All @@ -136,6 +142,7 @@ async def auth_config():
"audience": os.getenv("OIDC_AUD"),
}


# Add middleware after routes are defined
app.include_router(a2a_router, prefix="/a2a")
app.include_router(logs_router)
Expand Down
55 changes: 48 additions & 7 deletions middleware/quota.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,22 @@

from __future__ import annotations

import logging
import os
import sys
import time
from collections import deque
from typing import Deque, Dict, Optional, Protocol, Tuple
from uuid import uuid4

from fastapi import Request
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import JSONResponse, StreamingResponse

from usage.backends import NullUsageBackend

logger = logging.getLogger(__name__)


class AbstractMeterStore(Protocol):
"""Interface for token accounting backends."""
Expand Down Expand Up @@ -93,11 +100,19 @@ def __init__(self, app, store: Optional[AbstractMeterStore] = None) -> None:
enc_name = os.getenv("QUOTA_ENCODING", "cl100k_base")
try:
import tiktoken
except Exception as imp_err: # pragma: no cover - import guard
raise RuntimeError(
"tiktoken is required for TokenQuotaMiddleware; install with 'attach-gateway[quota]'"
) from imp_err
self.encoder = tiktoken.get_encoding(enc_name)

self.encoder = tiktoken.get_encoding(enc_name)
except Exception: # pragma: no cover - fallback
if "tiktoken" not in sys.modules:
logger.warning(
"tiktoken missing – using byte-count fallback; token metrics may be inflated"
)

class _Simple:
def encode(self, text: str) -> list[int]:
return list(text.encode())

self.encoder = _Simple()

@staticmethod
def _is_textual(mime: str) -> bool:
Expand All @@ -107,55 +122,81 @@ def _num_tokens(self, text: str) -> int:
return len(self.encoder.encode(text))

async def dispatch(self, request: Request, call_next):
if not hasattr(request.app.state, "usage"):
request.app.state.usage = NullUsageBackend()
user = request.headers.get("x-attach-user")
if not user:
user = request.client.host if request.client else "unknown"
usage = {
"user": user,
"project": request.headers.get("x-attach-project", "default"),
"tokens_in": 0,
"tokens_out": 0,
"model": "unknown",
"request_id": request.headers.get("x-request-id") or str(uuid4()),
}

body = await request.body()
content_type = request.headers.get("content-type", "")
tokens = 0
if self._is_textual(content_type):
tokens = self._num_tokens(body.decode("utf-8", "ignore"))
usage["tokens_in"] = tokens
total, oldest = await self.store.increment(user, tokens)
request.state._usage = usage
if total > self.max_tokens:
retry_after = max(0, int(self.window - (time.time() - oldest)))
usage["ts"] = time.time()
await request.app.state.usage.record(**usage)
return JSONResponse(
{"detail": "token quota exceeded", "retry_after": retry_after},
status_code=429,
)

response = await call_next(request)
usage["model"] = response.headers.get("x-llm-model", "unknown")

first_chunk = None
try:
first_chunk = await response.body_iterator.__anext__()
except StopAsyncIteration:
pass

if first_chunk is not None and self._is_textual(response.media_type or ""):
media = response.media_type or response.headers.get("content-type", "")
if first_chunk is not None and self._is_textual(media):
tokens_chunk = self._num_tokens(first_chunk.decode("utf-8", "ignore"))
total, oldest = await self.store.increment(user, tokens_chunk)
if total > self.max_tokens:
retry_after = max(0, int(self.window - (time.time() - oldest)))
usage["tokens_out"] += tokens_chunk
usage["ts"] = time.time()
await request.app.state.usage.record(**usage)
return JSONResponse(
{"detail": "token quota exceeded", "retry_after": retry_after},
status_code=429,
)
usage["tokens_out"] += tokens_chunk

async def stream_with_quota():
nonlocal total, oldest
if first_chunk is not None:
yield first_chunk
async for chunk in response.body_iterator:
tokens_chunk = 0
if self._is_textual(response.media_type or ""):
media = response.media_type or response.headers.get("content-type", "")
if self._is_textual(media):
tokens_chunk = self._num_tokens(chunk.decode("utf-8", "ignore"))
next_total, oldest = await self.store.increment(user, tokens_chunk)
if next_total > self.max_tokens:
break
total = next_total
if tokens_chunk:
usage["tokens_out"] += tokens_chunk
yield chunk

usage["ts"] = time.time()
await request.app.state.usage.record(**usage)

return StreamingResponse(
stream_with_quota(),
status_code=response.status_code,
Expand Down
35 changes: 35 additions & 0 deletions tests/test_usage_prometheus.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import os

import pytest
from fastapi import FastAPI, Request
from httpx import ASGITransport, AsyncClient

from middleware.quota import TokenQuotaMiddleware
from usage.factory import get_usage_backend


@pytest.mark.asyncio
async def test_prometheus_backend_counts_tokens(monkeypatch):
    """Two textual POSTs through the quota middleware must show up as
    positive in/out token counts on the Prometheus usage counter."""
    # Use monkeypatch.setenv so the variables are restored after the test;
    # assigning os.environ directly leaks state into the rest of the suite.
    monkeypatch.setenv("USAGE_BACKEND", "prometheus")
    monkeypatch.setenv("MAX_TOKENS_PER_MIN", "1000")

    app = FastAPI()
    app.add_middleware(TokenQuotaMiddleware)
    app.state.usage = get_usage_backend(os.getenv("USAGE_BACKEND", "null"))

    @app.post("/echo")
    async def echo(request: Request):
        data = await request.json()
        return {"msg": data.get("msg")}

    transport = ASGITransport(app=app)
    headers = {"X-Attach-User": "bob"}
    async with AsyncClient(transport=transport, base_url="http://test") as client:
        await client.post("/echo", json={"msg": "hi"}, headers=headers)
        await client.post("/echo", json={"msg": "there"}, headers=headers)

    c = app.state.usage.counter
    in_val = c.labels(user="bob", direction="in", model="unknown")._value.get()
    out_val = c.labels(user="bob", direction="out", model="unknown")._value.get()
    assert in_val > 0
    assert out_val > 0
    # NOTE(review): `c.values` exists only on the in-memory fallback Counter,
    # so this check assumes prometheus_client is absent in the test env;
    # guard it so the test still passes when the real client is installed.
    if hasattr(c, "values"):
        assert in_val + out_val == sum(c.values.values())
8 changes: 8 additions & 0 deletions usage/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from __future__ import annotations

"""Public API for Attach Gateway usage accounting."""

from .backends import AbstractUsageBackend
from .factory import get_usage_backend

__all__ = ["AbstractUsageBackend", "get_usage_backend"]
86 changes: 86 additions & 0 deletions usage/backends.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
from __future__ import annotations

"""Usage accounting backends for Attach Gateway."""

from typing import Protocol

try:
    from prometheus_client import Counter
except Exception:  # pragma: no cover - optional dep
    # Fallback defined only when prometheus_client is unavailable.  The
    # original code assigned ``Counter = None`` and then unconditionally
    # shadowed it with this class, leaving a dead assignment (and making
    # the ``Counter is None`` guard in PrometheusUsageBackend unreachable).

    class Counter:  # type: ignore[misc]
        """Minimal in-memory stand-in for ``prometheus_client.Counter``.

        Implements only the surface this package uses: ``labels(**kw)``
        returning a handle with ``inc(amount)`` and ``_value.get()``.
        Counts are kept in ``self.values`` keyed by the label-value tuple.
        """

        def __init__(self, name: str, desc: str, labelnames: list[str]) -> None:
            # ``name``/``desc`` are accepted for API parity but not stored.
            self.labelnames = labelnames
            self.values: dict[tuple[str, ...], float] = {}

        def labels(self, **labels):
            """Return an increment/read handle for one label combination."""
            # Missing labels default to "" so the key is always well-formed.
            key = tuple(labels.get(name, "") for name in self.labelnames)
            self.values.setdefault(key, 0.0)
            parent = self

            class _Handle:
                def inc(self, amount: float) -> None:
                    parent.values[key] += amount

                @property
                def _value(self):
                    class _Value:
                        @staticmethod
                        def get() -> float:
                            return parent.values[key]

                    return _Value()

            return _Handle()


class AbstractUsageBackend(Protocol):
    """Structural interface for usage-event sinks.

    Any object exposing an async ``record(**evt)`` coroutine satisfies
    this protocol; each call persists (or discards) one usage event.
    """

    async def record(self, **evt) -> None:
        """Persist a single usage event."""
        ...


class NullUsageBackend:
    """Usage sink that silently discards every event (the default)."""

    async def record(self, **evt) -> None:  # pragma: no cover - trivial
        """Accept the event and drop it."""
        return None


class PrometheusUsageBackend:
    """Usage sink that adds token counts to a labelled Prometheus counter."""

    def __init__(self) -> None:
        # ``Counter`` is the module-level import (or its in-memory fallback).
        if Counter is None:  # pragma: no cover - missing lib
            raise RuntimeError("prometheus_client is required for this backend")
        self.counter = Counter(
            "attach_usage_tokens_total",
            "Total tokens processed by Attach Gateway",
            ["user", "direction", "model"],
        )

    async def record(self, **evt) -> None:
        """Increment the counter once per direction for this event."""
        user = evt.get("user", "unknown")
        model = evt.get("model", "unknown")
        for direction, field in (("in", "tokens_in"), ("out", "tokens_out")):
            amount = int(evt.get(field, 0) or 0)
            self.counter.labels(user=user, direction=direction, model=model).inc(
                amount
            )


class OpenMeterBackend:
    """Placeholder for a future OpenMeter sink; not yet usable."""

    async def record(self, **evt) -> None:  # pragma: no cover - not implemented
        """Always fail: OpenMeter support has not been implemented."""
        raise NotImplementedError
23 changes: 23 additions & 0 deletions usage/factory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from __future__ import annotations

"""Factory for usage backends."""

import logging

from .backends import (
    AbstractUsageBackend,
    NullUsageBackend,
    OpenMeterBackend,
    PrometheusUsageBackend,
)


def get_usage_backend(kind: str) -> AbstractUsageBackend:
    """Return an instance of the requested usage backend.

    Parameters
    ----------
    kind:
        Backend name, case-insensitive: ``"prometheus"``, ``"openmeter"``
        or ``"null"``. ``None``/empty and unrecognized values yield the
        no-op backend.
    """
    normalized = (kind or "null").lower()
    if normalized == "prometheus":
        try:
            return PrometheusUsageBackend()
        except Exception:
            # Don't take the gateway down over a metrics dependency, but
            # don't hide the downgrade either — the previous code silently
            # returned the null backend when prometheus was requested.
            logging.getLogger(__name__).warning(
                "prometheus usage backend unavailable; falling back to null backend"
            )
            return NullUsageBackend()
    if normalized == "openmeter":
        return OpenMeterBackend()
    return NullUsageBackend()