diff --git a/.env.example b/.env.example index c6ef8ca..b821c1c 100644 --- a/.env.example +++ b/.env.example @@ -19,6 +19,9 @@ WEAVIATE_URL=http://localhost:8081 MAX_TOKENS_PER_MIN=60000 QUOTA_ENCODING=cl100k_base +# Metering Option (null, prometheus, openmeter) +USAGE_METERING=null + # Development: Auth0 credentials for dev_login script # AUTH0_DOMAIN=your-domain.auth0.com # AUTH0_CLIENT=your-client-id diff --git a/AGENTS.md b/AGENTS.md index 4398f82..f74fb94 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -13,3 +13,18 @@ This repo is used for the Attach Gateway service. Follow these guidelines for co ## Development Tools - Code should be formatted with `black` and imports sorted with `isort`. + +## πŸ”’ Memory & /mem/events are **read-only** + +> **Do not touch any memory-related code.** + +* **Off-limits files / symbols** + * `mem/**` + * `main.py` β†’ the `/mem/events` route and **all** `MemoryEvent` logic + * Any Weaviate queries, inserts, or schema + +* PRs that change, remove, or β€œrefactor” these areas **will be rejected**. + Only work on the explicitly assigned task (e.g. billing hooks). + +* If your change needs to interact with memory, open an issue first and wait + for maintainer approval. \ No newline at end of file diff --git a/README.md b/README.md index 8ed1f1e..d28ba46 100644 --- a/README.md +++ b/README.md @@ -234,23 +234,75 @@ curl -X POST /v1/logs \ # => HTTP/1.1 202 Accepted ``` +## Usage hooks + +Emit token usage metrics for every request. Choose a backend via +`USAGE_METERING` (alias `USAGE_BACKEND`): + +```bash +export USAGE_METERING=prometheus # or null +``` + +A Prometheus counter `attach_usage_tokens_total{user,direction,model}` is +exposed for Grafana dashboards. +Set `USAGE_METERING=null` (the default) to disable metering entirely. 
+ +> **⚠️ Usage hooks depend on the quota middleware.** +> Make sure `MAX_TOKENS_PER_MIN` is set (any positive number) so the +> `TokenQuotaMiddleware` is enabled; the middleware is what records usage +> events that feed Prometheus. + +```bash +# Enable usage tracking (set any reasonable limit) +export MAX_TOKENS_PER_MIN=60000 +export USAGE_METERING=prometheus +``` + +#### OpenMeter (Stripe / ClickHouse) + +```bash +# No additional dependencies needed - uses direct HTTP API +export MAX_TOKENS_PER_MIN=60000 # Required: enables quota middleware +export USAGE_METERING=openmeter # Required: activates OpenMeter backend +export OPENMETER_API_KEY=your-api-key-here # Required: API authentication +export OPENMETER_URL=https://openmeter.cloud # Optional: defaults to https://openmeter.cloud +``` + +Events are sent directly to OpenMeter's HTTP API and are processed by the LLM tokens meter for billing integration with Stripe. + +> **⚠️ All three variables are required for OpenMeter to work:** +> - `MAX_TOKENS_PER_MIN` enables the quota middleware that records usage events +> - `USAGE_METERING=openmeter` activates the OpenMeter backend +> - `OPENMETER_API_KEY` provides authentication to OpenMeter's API + +The gateway gracefully falls back to `NullUsageBackend` if any required variable is missing. + +### Scraping metrics + +```bash +curl -H "Authorization: Bearer $JWT" http://localhost:8080/metrics +``` + ## Token quotas Attach Gateway can enforce per-user token limits. Install the optional -dependency with `pip install attach-gateway[quota]` and set +dependency with `pip install attach-dev[quota]` and set `MAX_TOKENS_PER_MIN` in your environment to enable the middleware. The counter defaults to the `cl100k_base` encoding; override with `QUOTA_ENCODING` if your model uses a different tokenizer. The default in-memory store works in a single process and is not shared between workersβ€”requests retried across processes may be double-counted. Use Redis for production deployments. 
+If `tiktoken` is missing, a byte-count fallback is used which counts about +four times more tokens than the `cl100k` tokenizer – install `tiktoken` in +production. ### Enable token quotas ```bash # Optional: Enable token quotas export MAX_TOKENS_PER_MIN=60000 -pip install tiktoken # or pip install attach-gateway[quota] +pip install tiktoken # or pip install attach-dev[quota] ``` To customize the tokenizer: @@ -258,12 +310,6 @@ To customize the tokenizer: export QUOTA_ENCODING=cl100k_base # default ``` -## Roadmap - -* **v0.2** β€” Protected‑resource metadata endpoint (OAuth 2.1), enhanced DID resolvers. -* **v0.3** β€” Token‑exchange (RFC 8693) for on‑behalf‑of delegation. -* **v0.4** β€” Attach Store v1 (Git‑style, policy guards). - --- ## License diff --git a/attach/__init__.py b/attach/__init__.py index 994d780..349db7c 100644 --- a/attach/__init__.py +++ b/attach/__init__.py @@ -4,11 +4,20 @@ Add OIDC SSO, agent-to-agent handoff, and pluggable memory to any Python project. """ -__version__ = "0.2.2" +__version__ = "0.3.7" __author__ = "Hammad Tariq" __email__ = "hammad@attach.dev" -# Clean imports - no sys.path hacks needed since everything will be in the wheel -from .gateway import create_app, AttachConfig +# Remove this line that causes early failure: +# from .gateway import create_app, AttachConfig + +# Optional: Add lazy import for convenience +def create_app(*args, **kwargs): + from .gateway import create_app as _real + return _real(*args, **kwargs) + +def AttachConfig(*args, **kwargs): + from .gateway import AttachConfig as _real + return _real(*args, **kwargs) __all__ = ["create_app", "AttachConfig", "__version__"] \ No newline at end of file diff --git a/attach/__main__.py b/attach/__main__.py index 122d419..99490f2 100644 --- a/attach/__main__.py +++ b/attach/__main__.py @@ -2,7 +2,7 @@ CLI entry point - replaces the need for main.py in wheel """ import uvicorn -from .gateway import create_app +import click def main(): """Run Attach Gateway server""" 
@@ -13,17 +13,42 @@ def main(): except ImportError: pass # python-dotenv not installed, that's OK for production - import click - @click.command() @click.option("--host", default="0.0.0.0", help="Host to bind to") @click.option("--port", default=8080, help="Port to bind to") @click.option("--reload", is_flag=True, help="Enable auto-reload") def cli(host: str, port: int, reload: bool): - app = create_app() - uvicorn.run(app, host=host, port=port, reload=reload) + try: + # Import here AFTER .env is loaded and CLI is parsed + from .gateway import create_app + app = create_app() + uvicorn.run(app, host=host, port=port, reload=reload) + except RuntimeError as e: + _friendly_exit(e) + except Exception as e: # unexpected crash + click.echo(f"❌ Startup failed: {e}", err=True) + raise click.Abort() cli() +def _friendly_exit(err): + """Convert RuntimeError to clean user message.""" + err_str = str(err) + + if "OPENMETER_API_KEY" in err_str: + msg = (f"❌ {err}\n\n" + "πŸ’‘ Fix:\n" + " export OPENMETER_API_KEY=\"sk_live_...\"\n" + " (or) export USAGE_METERING=null # to disable metering\n\n" + "πŸ“– See README.md for complete setup") + else: + msg = (f"❌ {err}\n\n" + "πŸ’‘ Required environment variables:\n" + " export OIDC_ISSUER=\"https://your-domain.auth0.com/\"\n" + " export OIDC_AUD=\"your-api-identifier\"\n\n" + "πŸ“– See README.md for complete setup instructions") + + raise click.ClickException(msg) + if __name__ == "__main__": main() \ No newline at end of file diff --git a/attach/gateway.py b/attach/gateway.py index 2434389..855ca9f 100644 --- a/attach/gateway.py +++ b/attach/gateway.py @@ -3,24 +3,33 @@ """ import os +from contextlib import asynccontextmanager from typing import Optional import weaviate from fastapi import APIRouter, FastAPI, HTTPException, Request +from fastapi.middleware.cors import CORSMiddleware +from starlette.middleware.base import BaseHTTPMiddleware from pydantic import BaseModel from a2a.routes import router as a2a_router - -# Clean relative 
imports -from auth import verify_jwt from auth.oidc import _require_env - -# from logs import router as logs_router +import logs +logs_router = logs.router from mem import get_memory_backend from middleware.auth import jwt_auth_mw -from middleware.quota import TokenQuotaMiddleware from middleware.session import session_mw from proxy.engine import router as proxy_router +from usage.factory import _select_backend, get_usage_backend +from usage.metrics import mount_metrics +from utils.env import int_env + +# Guard TokenQuotaMiddleware import (matches main.py pattern) +try: + from middleware.quota import TokenQuotaMiddleware + QUOTA_AVAILABLE = True +except ImportError: # optional extra not installed + QUOTA_AVAILABLE = False # Import version from parent package from . import __version__ @@ -49,7 +58,7 @@ async def get_memory_events(request: Request, limit: int = 10): return {"data": {"Get": {"MemoryEvent": []}}} result = ( - client.query.get("MemoryEvent", ["timestamp", "role", "content"]) + client.query.get("MemoryEvent", ["timestamp", "event", "user", "state"]) .with_additional(["id"]) .with_limit(limit) .with_sort([{"path": ["timestamp"], "order": "desc"}]) @@ -97,6 +106,21 @@ class AttachConfig(BaseModel): auth0_client: Optional[str] = None +@asynccontextmanager +async def lifespan(app: FastAPI): + """Manage application lifespan - startup and shutdown.""" + # Startup + backend_selector = _select_backend() + app.state.usage = get_usage_backend(backend_selector) + mount_metrics(app) + + yield + + # Shutdown + if hasattr(app.state.usage, 'aclose'): + await app.state.usage.aclose() + + def create_app(config: Optional[AttachConfig] = None) -> FastAPI: """ Create a FastAPI app with Attach Gateway functionality @@ -127,17 +151,38 @@ def create_app(config: Optional[AttachConfig] = None) -> FastAPI: title="Attach Gateway", description="Identity & Memory side-car for LLM engines", version=__version__, + lifespan=lifespan, + ) + + @app.get("/auth/config") + async def 
auth_config(): + return { + "domain": config.auth0_domain, + "client_id": config.auth0_client, + "audience": config.oidc_audience, + } + + # Add middleware in correct order (CORS outer-most) + app.add_middleware( + CORSMiddleware, + allow_origins=["http://localhost:9000", "http://127.0.0.1:9000"], + allow_methods=["*"], + allow_headers=["*"], + allow_credentials=True, ) + + # Only add quota middleware if available and explicitly configured + limit = int_env("MAX_TOKENS_PER_MIN", 60000) + if QUOTA_AVAILABLE and limit is not None: + app.add_middleware(TokenQuotaMiddleware) - # Add middleware - app.middleware("http")(jwt_auth_mw) - app.middleware("http")(session_mw) - app.add_middleware(TokenQuotaMiddleware) + app.add_middleware(BaseHTTPMiddleware, dispatch=jwt_auth_mw) + app.add_middleware(BaseHTTPMiddleware, dispatch=session_mw) # Add routes - app.include_router(a2a_router) + app.include_router(a2a_router, prefix="/a2a") app.include_router(proxy_router) - # app.include_router(logs_router) + app.include_router(logs_router) app.include_router(mem_router) # Setup memory backend diff --git a/logs.py b/logs/__init__.py similarity index 100% rename from logs.py rename to logs/__init__.py diff --git a/main.py b/main.py index f10e823..1a99d96 100644 --- a/main.py +++ b/main.py @@ -5,43 +5,40 @@ from fastapi.middleware import Middleware from fastapi.middleware.cors import CORSMiddleware from starlette.middleware.base import BaseHTTPMiddleware +from contextlib import asynccontextmanager from a2a.routes import router as a2a_router -from auth.oidc import verify_jwt # Fixed: was auth.jwt, now auth.oidc -from logs import router as logs_router -from mem import write as mem_write # Import memory write function -from middleware.auth import jwt_auth_mw # ← your auth middleware -from middleware.session import session_mw # ← generates session-id header +import logs +logs_router = logs.router +from middleware.auth import jwt_auth_mw +from middleware.session import session_mw from 
proxy.engine import router as proxy_router +from usage.factory import _select_backend, get_usage_backend +from usage.metrics import mount_metrics +from utils.env import int_env -# At the top, make the import conditional try: from middleware.quota import TokenQuotaMiddleware QUOTA_AVAILABLE = True except ImportError: QUOTA_AVAILABLE = False -# Memory router mem_router = APIRouter(prefix="/mem", tags=["memory"]) @mem_router.get("/events") async def get_memory_events(request: Request, limit: int = 10): - """Fetch recent MemoryEvent objects from Weaviate""" + """Fetch recent MemoryEvent objects from Weaviate.""" try: - # Get user info from request state (set by jwt_auth_mw middleware) user_sub = getattr(request.state, "sub", None) if not user_sub: raise HTTPException(status_code=401, detail="User not authenticated") - # Use the exact same client setup as demo_view_memory.py - client = weaviate.Client("http://localhost:6666") + client = weaviate.Client(os.getenv("WEAVIATE_URL", "http://localhost:6666")) - # Test connection first if not client.is_ready(): raise HTTPException(status_code=503, detail="Weaviate is not ready") - # Check schema the same way as demo_view_memory.py try: schema = client.schema.get() classes = {c["class"] for c in schema.get("classes", [])} @@ -51,11 +48,10 @@ async def get_memory_events(request: Request, limit: int = 10): except Exception: return {"data": {"Get": {"MemoryEvent": []}}} - # Query with descending order by timestamp (newest first) result = ( client.query.get( "MemoryEvent", - ["timestamp", "event", "user", "state"] + ["timestamp", "event", "user", "state"], ) .with_additional(["id"]) .with_limit(limit) @@ -63,7 +59,6 @@ async def get_memory_events(request: Request, limit: int = 10): .do() ) - # Check for GraphQL errors like demo_view_memory.py does if "errors" in result: raise HTTPException( status_code=500, detail=f"GraphQL error: {result['errors']}" @@ -74,31 +69,26 @@ async def get_memory_events(request: Request, limit: int = 
10): events = result["data"]["Get"]["MemoryEvent"] - # Add the result field from the raw objects for richer display try: raw_objects = client.data_object.get(class_name="MemoryEvent", limit=limit) - # Create a mapping of IDs to full objects id_to_full_object = {} for obj in raw_objects.get("objects", []): obj_id = obj.get("id") if obj_id: id_to_full_object[obj_id] = obj.get("properties", {}) - # Enrich the GraphQL results with data from raw objects for event in events: event_id = event.get("_additional", {}).get("id") if event_id and event_id in id_to_full_object: full_props = id_to_full_object[event_id] - # Add the result field if it exists if "result" in full_props: event["result"] = full_props["result"] - # Add other useful fields for field in ["event", "session_id", "task_id", "user"]: if field in full_props: event[field] = full_props[field] except Exception: - pass # Silently fail if we can't enrich with raw object data + pass return result @@ -108,25 +98,36 @@ async def get_memory_events(request: Request, limit: int = 10): ) -middlewares = [ - # ❢ CORS first (so it executes last and handles responses properly) - Middleware(CORSMiddleware, - allow_origins=["http://localhost:9000", "http://127.0.0.1:9000"], - allow_methods=["*"], - allow_headers=["*"], - allow_credentials=True), - # ❷ Auth middleware - Middleware(BaseHTTPMiddleware, dispatch=jwt_auth_mw), - # ❸ Session middleware - Middleware(BaseHTTPMiddleware, dispatch=session_mw), -] - -# Only add quota middleware if tiktoken is available AND user configured it -if QUOTA_AVAILABLE and os.getenv("MAX_TOKENS_PER_MIN"): - middlewares.append(Middleware(TokenQuotaMiddleware)) - -# Create app without middleware first -app = FastAPI(title="attach-gateway", middleware=middlewares) +@asynccontextmanager +async def lifespan(app: FastAPI): + """Manage application lifespan - startup and shutdown.""" + backend_selector = _select_backend() + app.state.usage = get_usage_backend(backend_selector) + mount_metrics(app) + + 
yield + + if hasattr(app.state.usage, 'aclose'): + await app.state.usage.aclose() + +app = FastAPI(title="attach-gateway", lifespan=lifespan) + +# Add middleware in correct order (CORS outer-most) +app.add_middleware( + CORSMiddleware, + allow_origins=["http://localhost:9000", "http://127.0.0.1:9000"], + allow_methods=["*"], + allow_headers=["*"], + allow_credentials=True, +) + +# Only add quota middleware if available and explicitly configured +limit = int_env("MAX_TOKENS_PER_MIN", 60000) +if QUOTA_AVAILABLE and limit is not None: + app.add_middleware(TokenQuotaMiddleware) + +app.add_middleware(BaseHTTPMiddleware, dispatch=jwt_auth_mw) +app.add_middleware(BaseHTTPMiddleware, dispatch=session_mw) @app.get("/auth/config") async def auth_config(): @@ -136,8 +137,7 @@ async def auth_config(): "audience": os.getenv("OIDC_AUD"), } -# Add middleware after routes are defined app.include_router(a2a_router, prefix="/a2a") app.include_router(logs_router) app.include_router(mem_router) -app.include_router(proxy_router) # ← ADD THIS BACK +app.include_router(proxy_router) diff --git a/middleware/auth.py b/middleware/auth.py index e9fd448..01d7db3 100644 --- a/middleware/auth.py +++ b/middleware/auth.py @@ -29,8 +29,12 @@ async def jwt_auth_mw(request: Request, call_next): β€’ Verifies it with `auth.oidc.verify_jwt`. β€’ Stores the `sub` claim in `request.state.sub` for downstream middleware. β€’ Rejects the request with 401 on any failure. - β€’ Skips authentication for excluded paths. + β€’ Skips authentication for excluded paths and OPTIONS requests. 
""" + # Skip authentication for OPTIONS requests (CORS preflight) + if request.method == "OPTIONS": + return await call_next(request) + # Skip authentication for excluded paths if request.url.path in EXCLUDED_PATHS: return await call_next(request) @@ -48,6 +52,4 @@ async def jwt_auth_mw(request: Request, call_next): # attach the user id (sub) for the session-middleware request.state.sub = claims["sub"] - - # continue down the middleware stack / route handler return await call_next(request) diff --git a/middleware/quota.py b/middleware/quota.py index ecd6553..15af934 100644 --- a/middleware/quota.py +++ b/middleware/quota.py @@ -1,34 +1,60 @@ -"""Token quota middleware for Attach Gateway. - -Enforces per-user token limits using a sliding 1-minute window. The -quota is applied to both the request body and the response body. -""" - from __future__ import annotations +"""Token quota middleware enforcing a sliding window budget.""" + +import json +import logging import os import time from collections import deque -from typing import Deque, Dict, Optional, Protocol, Tuple +from typing import ( + AsyncIterator, + Awaitable, + Callable, + Deque, + Dict, + Iterable, + Protocol, + Tuple, +) +from uuid import uuid4 from fastapi import Request from starlette.middleware.base import BaseHTTPMiddleware from starlette.responses import JSONResponse, StreamingResponse +# ───────────────────────────────────────────────────────── +# Tokenizer setup +# ───────────────────────────────────────────────────────── +try: + import tiktoken # type: ignore +except Exception: # pragma: no cover + tiktoken = None + +# β‰ˆ cl100k_base: ~4 bytes / token for typical English +_APPROX_BYTES_PER_TOKEN = 4 + +from usage.backends import NullUsageBackend +from utils.env import int_env + +logger = logging.getLogger(__name__) + class AbstractMeterStore(Protocol): - """Interface for token accounting backends.""" + """Token accounting backend.""" async def increment(self, user: str, tokens: int) -> Tuple[int, 
float]: - """Return the running total and timestamp of the oldest entry.""" + """Increment ``user``'s counter and return ``(total, oldest_ts)``.""" + async def adjust(self, user: str, delta: int) -> Tuple[int, float]: + """Atomically adjust ``user``'s total by ``delta``.""" + + async def peek_total(self, user: str) -> int: + """Return ``user``'s current total without mutating state.""" -class InMemoryMeterStore: - """Simple in-memory sliding window counter. - Not safe for multi-process deployments; each process keeps its own - counters. Use :class:`RedisMeterStore` in production. - """ +class InMemoryMeterStore: + """Simple in-memory meter for tests and development.""" def __init__(self, window: int = 60) -> None: self.window = window @@ -45,11 +71,21 @@ async def increment(self, user: str, tokens: int) -> Tuple[int, float]: oldest = dq[0][0] if dq else now return total, oldest + async def adjust(self, user: str, delta: int) -> Tuple[int, float]: + return await self.increment(user, delta) + + async def peek_total(self, user: str) -> int: + total, _ = await self.increment(user, 0) + dq = self._data.get(user) + if total and dq: + dq.pop() + return total + class RedisMeterStore: - """Redis backed sliding window counter.""" + """Redis backed sliding window meter.""" - def __init__(self, url: str = "redis://localhost:6379", window: int = 60) -> None: + def __init__(self, url: str, window: int = 60) -> None: import redis.asyncio as redis # type: ignore self.window = window @@ -67,98 +103,407 @@ async def increment(self, user: str, tokens: int) -> Tuple[int, float]: entries = results[-1] total = 0 oldest = now - for member, ts in entries: + for m, ts in entries: + try: + _, tok = m.split(":", 1) + total += int(tok) + except Exception: + pass + oldest = min(oldest, ts) + return total, oldest + + async def adjust(self, user: str, delta: int) -> Tuple[int, float]: + now = time.time() + key = f"attach:quota:{user}" + member = f"{now}:{delta}" + async with 
self.redis.pipeline(transaction=True) as pipe: + await pipe.zadd(key, {member: now}) + await pipe.zremrangebyscore(key, 0, now - self.window) + await pipe.zrange(key, 0, -1, withscores=True) + results = await pipe.execute() + entries = results[-1] + total = 0 + oldest = now + for m, ts in entries: try: - _, tok = member.split(":", 1) + _, tok = m.split(":", 1) total += int(tok) except Exception: pass oldest = min(oldest, ts) return total, oldest + async def peek_total(self, user: str) -> int: + now = time.time() + key = f"attach:quota:{user}" + member = f"{now}:0:{uuid4().hex}" + async with self.redis.pipeline(transaction=True) as pipe: + await pipe.zadd(key, {member: now}) + await pipe.zremrangebyscore(key, 0, now - self.window) + await pipe.zrange(key, 0, -1, withscores=True) + pipe.zrem(key, member) + results = await pipe.execute() + entries = results[2] + total = 0 + for m, _ in entries: + try: + _, tok = m.split(":", 1) + total += int(tok) + except Exception: + pass + return total -class TokenQuotaMiddleware(BaseHTTPMiddleware): - """FastAPI middleware enforcing per-user token quotas.""" - def __init__(self, app, store: Optional[AbstractMeterStore] = None) -> None: - """Create middleware. +def _is_textual(mime: str) -> bool: + mime = (mime or "").lower() + if not mime or mime == "*/*": + return False + return mime.startswith("text/") or "json" in mime - Must be added **after** :func:`session_mw` so the ``X-Attach-User`` - header or client IP is available for quota tracking. 
- """ - super().__init__(app) - self.store = store or InMemoryMeterStore() - self.window = 60 - self.max_tokens = int(os.getenv("MAX_TOKENS_PER_MIN", "60000")) - enc_name = os.getenv("QUOTA_ENCODING", "cl100k_base") + +# --------------------------------------------------------------------------- +# Token-count helpers +# --------------------------------------------------------------------------- + +def _encoder_for_model(model: str): + """Return a tiktoken encoder, falling back to byte count.""" + if tiktoken is None: # fallback: 1 token β‰ˆ 4 bytes + + class _Approx: + def encode(self, text: str) -> list[int]: + # Never return 0 β†’ always count at least 1 token + return [0] * max(1, len(text) // _APPROX_BYTES_PER_TOKEN) + + return _Approx() + + try: + return tiktoken.encoding_for_model(model) + except Exception: try: - import tiktoken - except Exception as imp_err: # pragma: no cover - import guard - raise RuntimeError( - "tiktoken is required for TokenQuotaMiddleware; install with 'attach-gateway[quota]'" - ) from imp_err - self.encoder = tiktoken.get_encoding(enc_name) + return tiktoken.get_encoding("cl100k_base") + except Exception: + + class _Approx: + def encode(self, text: str) -> list[int]: + return [0] * max(1, len(text) // _APPROX_BYTES_PER_TOKEN) + + return _Approx() - @staticmethod - def _is_textual(mime: str) -> bool: - return mime.startswith("text/") or "json" in mime - def _num_tokens(self, text: str) -> int: - return len(self.encoder.encode(text)) +def _num_tokens(text: str, model: str = "cl100k_base") -> int: + return len(_encoder_for_model(model).encode(text)) + + +def num_tokens_from_messages(messages: Iterable[dict], model: str) -> int: + enc = _encoder_for_model(model) + total = 3 + for msg in messages: + total += 4 + for k, v in msg.items(): + total += len(enc.encode(str(v))) + if k == "name": + total -= 1 + return total + + +async def async_iter(data: Iterable[bytes]) -> AsyncIterator[bytes]: + for chunk in data: + yield chunk + + 
+_SKIP_PATHS = { + "/metrics", + "/mem/events", + "/auth/config", + "/health", + "/docs", + "/redoc", + "/openapi.json", +} + + +class TokenQuotaMiddleware(BaseHTTPMiddleware): + """Apply per-user LLM token quotas.""" + + def __init__(self, app, store: AbstractMeterStore | None = None) -> None: + super().__init__(app) + self.window = int(os.getenv("WINDOW", "60")) + self.max_tokens: int | None = int_env("MAX_TOKENS_PER_MIN", 60000) + if store is not None: + self.store = store + else: + redis_url = os.getenv("REDIS_URL") + self.store = ( + RedisMeterStore(redis_url, self.window) + if redis_url + else InMemoryMeterStore(self.window) + ) + + class _Streamer: + def __init__( + self, + iterator: AsyncIterator[bytes], + *, + user: str, + store: AbstractMeterStore, + max_tokens: int | None, + is_textual: bool, + ) -> None: + self.iterator = iterator + self.tail = bytearray() + self.on_complete: Callable[[], Awaitable[None]] | None = None + self.user = user + self.store = store + self.limit = max_tokens + self.is_textual = is_textual + self.quota_exceeded = False + + def __aiter__(self) -> AsyncIterator[bytes]: + return self._gen() + + async def _gen(self) -> AsyncIterator[bytes]: + try: + async for chunk in self.iterator: + self.tail.extend(chunk) + if len(self.tail) > 8192: + del self.tail[:-8192] + if self.limit: + chunk_tokens = ( + _num_tokens(chunk.decode("utf-8", "ignore")) + if self.is_textual + else 0 + ) + if chunk_tokens: + total, _ = await self.store.adjust(self.user, chunk_tokens) + if total > self.limit: + await self.store.adjust(self.user, -chunk_tokens) + self.quota_exceeded = True + logger.warning( + "User %s quota breached mid-stream", self.user + ) + return + yield chunk + finally: + if self.on_complete: + await self.on_complete() + + def get_tail(self) -> bytes: + return bytes(self.tail) + + class _BufferedResponse(StreamingResponse): + def __init__( + self, streamer: "TokenQuotaMiddleware._Streamer", **kwargs + ) -> None: + super().__init__(streamer, 
**kwargs) + self._streamer = streamer + + async def stream_response(self, send): + chunks: list[bytes] = [] + async for chunk in self._streamer: + chunks.append(chunk) + if getattr(self._streamer, "quota_exceeded", False): + retry_after = max( + 0, + int( + self._streamer.window + - (time.time() - getattr(self._streamer, "oldest", 0)) + ), + ) + self.status_code = 429 + self.raw_headers = [ + (b"content-type", b"application/json"), + (b"retry-after", str(retry_after).encode()), + ] + iterator = async_iter( + [ + json.dumps( + { + "detail": "token quota exceeded", + "retry_after": retry_after, + } + ).encode() + ] + ) + else: + iterator = async_iter(chunks) + await send( + { + "type": "http.response.start", + "status": self.status_code, + "headers": self.raw_headers, + } + ) + async for chunk in iterator: + if not isinstance(chunk, (bytes, memoryview)): + chunk = chunk.encode(self.charset) + await send( + {"type": "http.response.body", "body": chunk, "more_body": True} + ) + await send({"type": "http.response.body", "body": b"", "more_body": False}) async def dispatch(self, request: Request, call_next): - user = request.headers.get("x-attach-user") - if not user: - user = request.client.host if request.client else "unknown" - - body = await request.body() - content_type = request.headers.get("content-type", "") - tokens = 0 - if self._is_textual(content_type): - tokens = self._num_tokens(body.decode("utf-8", "ignore")) - total, oldest = await self.store.increment(user, tokens) - if total > self.max_tokens: + if any(request.url.path.startswith(p) for p in _SKIP_PATHS): + return await call_next(request) + + if not hasattr(request.app.state, "usage"): + request.app.state.usage = NullUsageBackend() + + # ── OPTIONAL request-size guard (default 1 MB) ─────────────── + max_bytes = int(os.getenv("MAX_REQUEST_BYTES", "1000000")) + raw = await request.body() + if len(raw) > max_bytes: + return JSONResponse( + { + "detail": "request too large", + "limit_bytes": max_bytes, + 
}, + status_code=413, + ) + + # Re-use the already-read body from here on + request._body = raw + + # Fix: Get user from request.state.sub (set by auth middleware) + user = getattr(request.state, "sub", None) or ( + request.client.host if request.client else "unknown" + ) + + usage = { + "user": user, + "project": request.headers.get("x-attach-project", "default"), + "tokens_in": 0, + "tokens_out": 0, + "model": "unknown", + "request_id": request.headers.get("x-request-id") or str(uuid4()), + } + req_is_text = _is_textual(request.headers.get("content-type", "")) + + tokens_in = 0 + if req_is_text: + # raw already read above + try: + payload = json.loads(raw.decode()) + except Exception: + payload = None + if isinstance(payload, dict) and "messages" in payload: + model = payload.get("model", "cl100k_base") + usage["model"] = model + tokens_in = num_tokens_from_messages(payload.get("messages", []), model) + else: + tokens_in = _num_tokens(raw.decode("utf-8", "ignore")) + + usage["tokens_in"] = tokens_in + + total, oldest = await self.store.adjust(user, tokens_in) + if self.max_tokens is not None and total > self.max_tokens: + await self.store.adjust(user, -tokens_in) retry_after = max(0, int(self.window - (time.time() - oldest))) + usage["ts"] = time.time() + await request.app.state.usage.record(**usage) return JSONResponse( {"detail": "token quota exceeded", "retry_after": retry_after}, status_code=429, ) - response = await call_next(request) + resp = await call_next(request) - first_chunk = None - try: - first_chunk = await response.body_iterator.__anext__() - except StopAsyncIteration: - pass - - if first_chunk is not None and self._is_textual(response.media_type or ""): - tokens_chunk = self._num_tokens(first_chunk.decode("utf-8", "ignore")) - total, oldest = await self.store.increment(user, tokens_chunk) - if total > self.max_tokens: - retry_after = max(0, int(self.window - (time.time() - oldest))) - return JSONResponse( - {"detail": "token quota exceeded", 
"retry_after": retry_after}, - status_code=429, - ) + media = resp.media_type or resp.headers.get("content-type", "") + resp_is_text = _is_textual(media) + streamer = self._Streamer( + resp.body_iterator, + user=user, + store=self.store, + max_tokens=self.max_tokens, + is_textual=resp_is_text, + ) + streamer.window = self.window + streamer.oldest = oldest - async def stream_with_quota(): - nonlocal total, oldest - if first_chunk is not None: - yield first_chunk - async for chunk in response.body_iterator: - tokens_chunk = 0 - if self._is_textual(response.media_type or ""): - tokens_chunk = self._num_tokens(chunk.decode("utf-8", "ignore")) - next_total, oldest = await self.store.increment(user, tokens_chunk) - if next_total > self.max_tokens: - break - total = next_total - yield chunk - - return StreamingResponse( - stream_with_quota(), - status_code=response.status_code, - headers=dict(response.headers), - media_type=response.media_type, + headers = dict(resp.headers) + headers.pop("content-length", None) + response = self._BufferedResponse( + streamer, + status_code=resp.status_code, + headers=headers, + media_type=resp.media_type, ) + + async def finalize() -> None: + nonlocal tokens_in, oldest + if getattr(streamer, "quota_exceeded", False): + usage["detail"] = "token quota exceeded mid-stream" + retry_after = max(0, int(self.window - (time.time() - oldest))) + usage["ts"] = time.time() + await request.app.state.usage.record(**usage) + return + tokens_out = 0 + model = usage.get("model", "unknown") + parsed: dict | None = None + if resp_is_text: + # -- Robustly extract the last JSON object ----------------- + text_tail = streamer.get_tail().decode("utf-8", "ignore") + # 1. Split SSE frames if present: keep only the part after the final "data: " + if "data:" in text_tail: + *_, last_frame = text_tail.strip().split("data:") + text_tail = last_frame.strip() + # 2. 
Strip the trailing '[DONE]' token if it exists + if text_tail.endswith("[DONE]"): + text_tail = text_tail[: text_tail.rfind("[DONE]")].rstrip() + # 3. Find the first '{' from the *left* (because the frame has been cleaned) + brace = text_tail.find("{") + parsed = None + if brace != -1: + try: + parsed = json.loads(text_tail[brace:]) + except Exception: + parsed = None + if isinstance(parsed, dict): + model = parsed.get("model", model) + if "usage" in parsed: + u = parsed.get("usage") or {} + tokens_out = int(u.get("completion_tokens", 0)) + prompt_tokens = int(u.get("prompt_tokens", tokens_in)) + delta_prompt = prompt_tokens - tokens_in # adjust quota window + tokens_in = prompt_tokens # ← canonical value + usage["tokens_in"] = tokens_in + # Replace provisional count with canonical one in the meter + await self.store.adjust(user, delta_prompt) + elif "choices" in parsed: + msgs = [ + (c.get("message") or {}).get("content", "") + for c in parsed.get("choices", []) + ] + tokens_out = num_tokens_from_messages( + [{"content": m} for m in msgs], model + ) + if tokens_out == 0: + tokens_out = _num_tokens(text_tail) + + usage["tokens_out"] = tokens_out + usage["model"] = model + + # All *out* tokens are new – add them once. 
+ await self.store.adjust(user, tokens_out) + + total = await self.store.peek_total(user) + if self.max_tokens is not None and total > self.max_tokens: + retry_after = self.window # simple worst-case + response.status_code = 429 + response.headers["Retry-After"] = str(retry_after) + usage["detail"] = "token quota exceeded post-stream" + response.headers["content-type"] = "application/json" + + response.headers.update( + { + "x-llm-model": model, + "x-tokens-in": str(tokens_in), + "x-tokens-out": str(tokens_out), + } + ) + + usage["ts"] = time.time() + await request.app.state.usage.record(**usage) + logger.info(json.dumps(usage)) + + streamer.on_complete = finalize + return response diff --git a/pyproject.toml b/pyproject.toml index a4ea8f6..35daaa7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,12 +31,16 @@ dependencies = [ "python-jose[cryptography]>=3.3.0", "click>=8.0.0", "python-dotenv>=1.0.0", + "weaviate-client>=3.26.7,<4.0.0", ] [project.optional-dependencies] -memory = ["weaviate-client>=3.26.7,<4.0.0"] -temporal = ["temporalio>=1.5.0"] quota = ["tiktoken>=0.5.0"] +usage = ["prometheus_client>=0.20.0"] +full = [ + "tiktoken>=0.5.0", + "prometheus_client>=0.20.0" +] dev = [ "pytest>=8.0.0", "pytest-asyncio>=0.23.0", @@ -53,7 +57,7 @@ attach-gateway = "attach.__main__:main" # Include your existing modules with cleaner targeting [tool.hatch.build.targets.wheel] -packages = ["attach", "auth", "middleware", "mem", "proxy", "a2a", "attach_pydid"] +packages = ["attach", "auth", "middleware", "mem", "proxy", "a2a", "attach_pydid", "usage", "utils", "logs"] # Dynamic version from attach/__init__.py diff --git a/tests/test_metrics_endpoint.py b/tests/test_metrics_endpoint.py new file mode 100644 index 0000000..45af1a5 --- /dev/null +++ b/tests/test_metrics_endpoint.py @@ -0,0 +1,48 @@ +import os + +import pytest +from fastapi import FastAPI, Request +from httpx import ASGITransport, AsyncClient + +from middleware.quota import TokenQuotaMiddleware +from 
usage.factory import get_usage_backend +from usage.metrics import mount_metrics + + +def setup_module(): + # ✅ Fixed: Use new variable name + os.environ["USAGE_METERING"] = "prometheus" + + +@pytest.fixture +def app(): + app = FastAPI() + mount_metrics(app) + # ✅ Fixed: Use new variable name + app.state.usage = get_usage_backend(os.getenv("USAGE_METERING", "null")) + return app + + +@pytest.mark.asyncio +async def test_metrics_endpoint(monkeypatch): + os.environ["MAX_TOKENS_PER_MIN"] = "1000" + + app = FastAPI() + mount_metrics(app) + + app.add_middleware(TokenQuotaMiddleware) + app.state.usage = get_usage_backend(os.getenv("USAGE_METERING", "null")) + + @app.post("/echo") + async def echo(request: Request): + data = await request.json() + return {"msg": data.get("msg")} + + transport = ASGITransport(app=app) + headers = {"X-Attach-User": "bob"} + async with AsyncClient(transport=transport, base_url="http://test") as client: + await client.post("/echo", json={"msg": "hi"}, headers=headers) + resp = await client.get("/metrics") + + assert "attach_usage_tokens_total" in resp.text + assert "bob" in resp.text diff --git a/tests/test_prometheus_fallback.py b/tests/test_prometheus_fallback.py new file mode 100644 index 0000000..0536d0c --- /dev/null +++ b/tests/test_prometheus_fallback.py @@ -0,0 +1,19 @@ +import pytest + +from usage.factory import get_usage_backend + +pytest.importorskip("prometheus_client") + + +def test_prometheus_backend_falls_back_to_null_when_unavailable(monkeypatch): + monkeypatch.setenv("USAGE_METERING", "prometheus") + + # This should return NullUsageBackend when prometheus_client unavailable + backend = get_usage_backend("prometheus") + + if hasattr(backend, "counter"): + # prometheus_client available - got PrometheusUsageBackend + assert backend.__class__.__name__ == "PrometheusUsageBackend" + else: + # prometheus_client unavailable - got NullUsageBackend fallback + assert backend.__class__.__name__ == "NullUsageBackend" diff --git
a/tests/test_token_quota_middleware.py b/tests/test_token_quota_middleware.py index 00a638f..487b41b 100644 --- a/tests/test_token_quota_middleware.py +++ b/tests/test_token_quota_middleware.py @@ -6,7 +6,9 @@ pytest.importorskip("tiktoken") -from middleware.quota import TokenQuotaMiddleware +from starlette.responses import StreamingResponse + +from middleware.quota import InMemoryMeterStore, TokenQuotaMiddleware @pytest.mark.asyncio @@ -44,3 +46,29 @@ async def echo(request: Request): resp = await client.post("/echo", json={"msg": "hello"}, headers=headers) assert resp.status_code == 429 assert "retry_after" in resp.json() + + +@pytest.mark.asyncio +async def test_midstream_over_limit_rolls_back(monkeypatch): + monkeypatch.setenv("MAX_TOKENS_PER_MIN", "5") + store = InMemoryMeterStore() + app = FastAPI() + app.add_middleware(TokenQuotaMiddleware, store=store) + + @app.get("/stream") + async def stream(): + async def gen(): + yield b"hi" + yield b"aaaaaaa" + + return StreamingResponse(gen(), media_type="text/plain") + + headers = {"X-Attach-User": "carol"} + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as client: + resp = await client.get("/stream", headers=headers) + + assert resp.status_code == 429 + assert resp.json()["detail"] == "token quota exceeded" + total = await store.peek_total("carol") + assert total == 2 diff --git a/tests/test_usage_openmeter.py b/tests/test_usage_openmeter.py new file mode 100644 index 0000000..171dd54 --- /dev/null +++ b/tests/test_usage_openmeter.py @@ -0,0 +1,66 @@ +import os +import sys +import types + +import pytest + + +class DummyEvents: + def __init__(self): + self.called = None + + async def create(self, **event): + self.called = event + + +class DummyClient: + def __init__(self, api_key: str, base_url: str = "https://openmeter.cloud") -> None: + self.api_key = api_key + self.base_url = base_url + self.events = DummyEvents() + + async def aclose(self) -> None: + 
pass + + +dummy_module = types.SimpleNamespace(Client=DummyClient) + + +@pytest.mark.asyncio +async def test_openmeter_backend_create(monkeypatch): + monkeypatch.setitem(sys.modules, "openmeter", dummy_module) + if "usage.backends" in sys.modules: + del sys.modules["usage.backends"] + if "usage.factory" in sys.modules: + del sys.modules["usage.factory"] + from usage.factory import get_usage_backend + + monkeypatch.setenv("OPENMETER_API_KEY", "k") + monkeypatch.setenv("OPENMETER_URL", "https://example.com") + + backend = get_usage_backend("openmeter") + assert backend.__class__.__name__ == "OpenMeterBackend" + await backend.record(user="bob", tokens_in=1, tokens_out=2, model="m") + await backend.aclose() + + called = backend.client.events.called + assert called["type"] == "tokens" + assert called["subject"] == "bob" + assert called["project"] is None + assert called["data"] == {"tokens_in": 1, "tokens_out": 2, "model": "m"} + assert "time" in called + + +def test_openmeter_backend_missing_key(monkeypatch): + monkeypatch.setitem(sys.modules, "openmeter", dummy_module) + if "usage.backends" in sys.modules: + del sys.modules["usage.backends"] + if "usage.factory" in sys.modules: + del sys.modules["usage.factory"] + from usage.factory import get_usage_backend + + monkeypatch.setenv("USAGE_METERING", "openmeter") + monkeypatch.delenv("OPENMETER_API_KEY", raising=False) + + backend = get_usage_backend("openmeter") + assert backend.__class__.__name__ == "NullUsageBackend" diff --git a/tests/test_usage_prometheus.py b/tests/test_usage_prometheus.py new file mode 100644 index 0000000..d009a0f --- /dev/null +++ b/tests/test_usage_prometheus.py @@ -0,0 +1,51 @@ +import os +import sys + +import pytest +from fastapi import FastAPI, Request +from httpx import ASGITransport, AsyncClient + +# Force reload to pick up prometheus_client if just installed +if "usage.backends" in sys.modules: + del sys.modules["usage.backends"] +if "usage.factory" in sys.modules: + del 
sys.modules["usage.factory"] + +from middleware.quota import TokenQuotaMiddleware +from usage.factory import get_usage_backend + +# At the top, skip test if prometheus_client not available +pytest.importorskip("prometheus_client") # ← Skip entire test if not installed + + +@pytest.mark.asyncio +async def test_prometheus_backend_counts_tokens(monkeypatch): + # Verify we actually get PrometheusUsageBackend + backend = get_usage_backend("prometheus") + if not hasattr(backend, "counter"): + pytest.skip("prometheus_client not available, got NullUsageBackend") + + monkeypatch.setenv("USAGE_METERING", "prometheus") + monkeypatch.setenv("MAX_TOKENS_PER_MIN", "1000") + app = FastAPI() + app.add_middleware(TokenQuotaMiddleware) + app.state.usage = backend # Use the backend we verified + + @app.post("/echo") + async def echo(request: Request): + data = await request.json() + return {"msg": data.get("msg")} + + transport = ASGITransport(app=app) + headers = {"X-Attach-User": "bob"} + async with AsyncClient(transport=transport, base_url="http://test") as client: + await client.post("/echo", json={"msg": "hi"}, headers=headers) + await client.post("/echo", json={"msg": "there"}, headers=headers) + + c = app.state.usage.counter + in_val = c.labels(user="bob", direction="in", model="unknown")._value.get() + out_val = c.labels(user="bob", direction="out", model="unknown")._value.get() + assert in_val > 0 + assert out_val > 0 + # Removed: assert in_val + out_val == sum(c.values.values()) + # The 'values' attribute only exists in the fallback Counter, not the real Prometheus Counter diff --git a/usage/__init__.py b/usage/__init__.py new file mode 100644 index 0000000..de197e1 --- /dev/null +++ b/usage/__init__.py @@ -0,0 +1,8 @@ +from __future__ import annotations + +"""Public API for Attach Gateway usage accounting.""" + +from .backends import AbstractUsageBackend +from .factory import get_usage_backend + +__all__ = ["AbstractUsageBackend", "get_usage_backend"] diff --git 
a/usage/backends.py b/usage/backends.py new file mode 100644 index 0000000..f589ce3 --- /dev/null +++ b/usage/backends.py @@ -0,0 +1,175 @@ +from __future__ import annotations + +"""Usage accounting backends for Attach Gateway.""" + +import inspect +import logging +import os +from datetime import datetime, timezone +from typing import Protocol + +logger = logging.getLogger(__name__) + +try: + from prometheus_client import Counter +except Exception: # pragma: no cover - optional dep + Counter = None # type: ignore + + class Counter: # type: ignore[misc] + """Minimal in-memory Counter fallback.""" + + def __init__(self, name: str, desc: str, labelnames: list[str]): + self.labelnames = labelnames + self.values: dict[tuple[str, ...], float] = {} + + def labels(self, **labels): + key = tuple(labels.get(name, "") for name in self.labelnames) + self.values.setdefault(key, 0.0) + + class _Wrapper: + def __init__(self, parent: Counter, k: tuple[str, ...]) -> None: + self.parent = parent + self.k = k + + def inc(self, amt: float) -> None: + self.parent.values[self.k] += amt + + @property + def _value(self): + class V: + def __init__(self, parent: Counter, k: tuple[str, ...]): + self.parent = parent + self.k = k + + def get(self) -> float: + return self.parent.values[self.k] + + return V(self.parent, self.k) + + return _Wrapper(self, key) + + +class AbstractUsageBackend(Protocol): + """Interface for usage event sinks.""" + + async def record(self, **evt) -> None: + """Persist a single usage event.""" + ... 
+ + +class NullUsageBackend: + """No-op usage backend.""" + + async def record(self, **evt) -> None: # pragma: no cover - trivial + return + + +class PrometheusUsageBackend: + """Expose a Prometheus counter for token usage.""" + + def __init__(self) -> None: + if Counter is None: # pragma: no cover - missing lib + raise RuntimeError("prometheus_client is required for this backend") + self.counter = Counter( + "attach_usage_tokens_total", + "Total tokens processed by Attach Gateway", + ["user", "direction", "model"], + ) + + async def record(self, **evt) -> None: + user = evt.get("user", "unknown") + model = evt.get("model", "unknown") + tokens_in = int(evt.get("tokens_in", 0) or 0) + tokens_out = int(evt.get("tokens_out", 0) or 0) + self.counter.labels(user=user, direction="in", model=model).inc(tokens_in) + self.counter.labels(user=user, direction="out", model=model).inc(tokens_out) + + +class OpenMeterBackend: + """Send token usage events to OpenMeter.""" + + def __init__(self) -> None: + api_key = os.getenv("OPENMETER_API_KEY") + if not api_key: + raise ImportError("OPENMETER_API_KEY is required for OpenMeter") + + self.api_key = api_key + self.base_url = os.getenv("OPENMETER_URL", "https://openmeter.cloud") + + # Use httpx instead of buggy OpenMeter SDK + try: + import httpx + self.client = httpx.AsyncClient( + timeout=30.0 + ) + except ImportError as exc: + raise ImportError("httpx is required for OpenMeter") from exc + + async def aclose(self) -> None: + """Close the underlying HTTP client.""" + if hasattr(self.client, 'aclose'): + await self.client.aclose() + + async def record(self, **evt) -> None: + try: + from uuid import uuid4 + except ImportError as exc: + return + + base_time = datetime.now(timezone.utc).isoformat(timespec="milliseconds").replace("+00:00", "Z") + user = evt.get("user") + model = evt.get("model") + + tokens_in = int(evt.get("tokens_in", 0) or 0) + tokens_out = int(evt.get("tokens_out", 0) or 0) + + # Send separate events for input and 
output tokens + events_to_send = [] + + if tokens_in > 0: + events_to_send.append({ + "specversion": "1.0", + "type": "prompt", # ← Changed from "tokens" to "prompt" + "id": str(uuid4()), + "time": base_time, + "source": "attach-gateway", + "subject": user, + "data": { + "tokens": tokens_in, + "model": model, + "type": "input" # ← This stays the same + } + }) + + if tokens_out > 0: + events_to_send.append({ + "specversion": "1.0", + "type": "prompt", + "id": str(uuid4()), + "time": base_time, + "source": "attach-gateway", + "subject": user, + "data": { + "tokens": tokens_out, # ← Single tokens field + "model": model, + "type": "output" # ← Add type field + } + }) + + # Send each event + for event in events_to_send: + try: + response = await self.client.post( + f"{self.base_url}/api/v1/events", + json=event, + headers={ + "Authorization": f"Bearer {self.api_key}", + "Content-Type": "application/cloudevents+json" + } + ) + + if response.status_code not in [200, 201, 202, 204]: + logger.warning(f"OpenMeter error: {response.status_code}") + + except Exception as exc: + logger.warning("OpenMeter request failed: %s", exc) diff --git a/usage/factory.py b/usage/factory.py new file mode 100644 index 0000000..61ed842 --- /dev/null +++ b/usage/factory.py @@ -0,0 +1,57 @@ +from __future__ import annotations + +"""Factory for usage backends.""" + +import os +import warnings +import logging + +from .backends import ( + AbstractUsageBackend, + NullUsageBackend, + OpenMeterBackend, + PrometheusUsageBackend, +) + +log = logging.getLogger(__name__) + + +def _select_backend() -> str: + """Return backend name from env vars with deprecation warning.""" + if "USAGE_METERING" in os.environ: + return os.getenv("USAGE_METERING", "null") + if "USAGE_BACKEND" in os.environ: # old name, keep BC + warnings.warn( + "USAGE_BACKEND is deprecated; use USAGE_METERING", + UserWarning, + stacklevel=2, + ) + return os.getenv("USAGE_BACKEND", "null") + + +def get_usage_backend(kind: str) -> 
AbstractUsageBackend: + """Return an instance of the requested usage backend.""" + kind = (kind or "null").lower() + + if kind == "prometheus": + try: + return PrometheusUsageBackend() + except ImportError as exc: + log.warning( + "Prometheus metering unavailable: %s – " + "falling back to NullUsageBackend. " + "Install with: pip install 'attach-dev[usage]'", + exc + ) + return NullUsageBackend() + + if kind == "openmeter": + # fail-fast on bad config + if not os.getenv("OPENMETER_API_KEY"): + raise RuntimeError( + "USAGE_METERING=openmeter requires OPENMETER_API_KEY. " + "Set the variable or change USAGE_METERING=null to disable." + ) + return OpenMeterBackend() # exceptions inside bubble up + + return NullUsageBackend() diff --git a/usage/metrics.py b/usage/metrics.py new file mode 100644 index 0000000..3e5d027 --- /dev/null +++ b/usage/metrics.py @@ -0,0 +1,37 @@ +from __future__ import annotations + +"""Utilities for exposing Attach Gateway usage metrics.""" + +from fastapi import FastAPI +from fastapi.responses import PlainTextResponse, Response + + +def mount_metrics(app: FastAPI) -> None: + """Attach a Prometheus-compatible ``/metrics`` route to ``app``.""" + + @app.get("/metrics", include_in_schema=False) + async def metrics() -> Response: # noqa: D401 + usage = getattr(app.state, "usage", None) + if usage is None: + return PlainTextResponse("# No usage backend configured\n") + try: + from prometheus_client import CONTENT_TYPE_LATEST, REGISTRY, generate_latest + except ImportError: + counter = getattr(usage, "counter", None) + if counter is None or not hasattr(counter, "values"): + return PlainTextResponse("# No metrics available\n") + lines = [ + "# HELP attach_usage_tokens_total Total tokens processed by Attach Gateway", + "# TYPE attach_usage_tokens_total counter", + ] + for (u, d, m), v in counter.values.items(): + lines.append( + f'attach_usage_tokens_total{{user="{u}",direction="{d}",model="{m}"}} {v}' + ) + return PlainTextResponse("\n".join(lines) + 
"\n") + else: + if not hasattr(usage, "counter"): + return PlainTextResponse("# No usage counter\n") + return PlainTextResponse( + generate_latest(REGISTRY), media_type=CONTENT_TYPE_LATEST + ) diff --git a/utils/env.py b/utils/env.py new file mode 100644 index 0000000..909534c --- /dev/null +++ b/utils/env.py @@ -0,0 +1,25 @@ +"""Env helpers used across entry-points.""" + +import logging +import os + +log = logging.getLogger(__name__) + + +def int_env(var: str, default: int | None = None) -> int | None: + """Read $VAR as positive int. + - '', 'null', 'none', 'false', 'infinite' -> None + - invalid / non-positive -> default + """ + val = os.getenv(var) + if val is None: + return default + val = val.strip().lower() + if val in {"", "null", "none", "false", "infinite"}: + return None + try: + num = int(val) + return num if num > 0 else default + except ValueError: + log.warning("⚠️ %s=%s is not a valid int; using default=%s", var, val, default) + return default