From 34427e2dc5c961a834dc1e75fa22906dfa7354f9 Mon Sep 17 00:00:00 2001 From: Hammad Tariq Date: Wed, 23 Jul 2025 15:49:59 -0700 Subject: [PATCH 1/2] gateway now mirrors main --- attach/gateway.py | 70 +++++++++++++++++++++++++++++++++++------------ 1 file changed, 53 insertions(+), 17 deletions(-) diff --git a/attach/gateway.py b/attach/gateway.py index ce27650..d0c1dff 100644 --- a/attach/gateway.py +++ b/attach/gateway.py @@ -3,28 +3,33 @@ """ import os +from contextlib import asynccontextmanager from typing import Optional import weaviate from fastapi import APIRouter, FastAPI, HTTPException, Request +from fastapi.middleware.cors import CORSMiddleware +from starlette.middleware.base import BaseHTTPMiddleware from pydantic import BaseModel from a2a.routes import router as a2a_router - -# Clean relative imports -from auth import verify_jwt from auth.oidc import _require_env - -# from logs import router as logs_router +from logs import router as logs_router from mem import get_memory_backend from middleware.auth import jwt_auth_mw -from middleware.quota import TokenQuotaMiddleware from middleware.session import session_mw from proxy.engine import router as proxy_router from usage.factory import _select_backend, get_usage_backend from usage.metrics import mount_metrics from utils.env import int_env +# Guard TokenQuotaMiddleware import (matches main.py pattern) +try: + from middleware.quota import TokenQuotaMiddleware + QUOTA_AVAILABLE = True +except ImportError: # optional extra not installed + QUOTA_AVAILABLE = False + # Import version from parent package from . import __version__ @@ -52,7 +57,7 @@ async def get_memory_events(request: Request, limit: int = 10): return {"data": {"Get": {"MemoryEvent": []}}} result = ( - client.query.get("MemoryEvent", ["timestamp", "role", "content"]) + client.query.get("MemoryEvent", ["timestamp", "event", "user", "state"]) .with_additional(["id"]) .with_limit(limit) .with_sort([{"path": ["timestamp"], "order": "desc"}]) @@ -100,6 +105,21 @@ class AttachConfig(BaseModel): auth0_client: Optional[str] = None +@asynccontextmanager +async def lifespan(app: FastAPI): + """Manage application lifespan - startup and shutdown.""" + # Startup + backend_selector = _select_backend() + app.state.usage = get_usage_backend(backend_selector) + mount_metrics(app) + + yield + + # Shutdown + if hasattr(app.state.usage, 'aclose'): + await app.state.usage.aclose() + + def create_app(config: Optional[AttachConfig] = None) -> FastAPI: """ Create a FastAPI app with Attach Gateway functionality @@ -130,27 +150,43 @@ def create_app(config: Optional[AttachConfig] = None) -> FastAPI: title="Attach Gateway", description="Identity & Memory side-car for LLM engines", version=__version__, + lifespan=lifespan, ) - mount_metrics(app) - # Add middleware - app.middleware("http")(jwt_auth_mw) - app.middleware("http")(session_mw) - limit = int_env("MAX_TOKENS_PER_MIN", 60000) - if limit is not None: + @app.get("/auth/config") + async def auth_config(): + return { + "domain": config.auth0_domain, + "client_id": config.auth0_client, + "audience": config.oidc_audience, + } + + # Add middleware in correct order (CORS outer-most) + app.add_middleware( + CORSMiddleware, + allow_origins=["http://localhost:9000", "http://127.0.0.1:9000"], + allow_methods=["*"], + allow_headers=["*"], + allow_credentials=True, + ) + + app.add_middleware(BaseHTTPMiddleware, dispatch=jwt_auth_mw) + app.add_middleware(BaseHTTPMiddleware, dispatch=session_mw) + + # Only add quota middleware if available and explicitly configured + limit = int_env("MAX_TOKENS_PER_MIN", None) + if QUOTA_AVAILABLE and limit is not None: app.add_middleware(TokenQuotaMiddleware) # Add routes - app.include_router(a2a_router) + app.include_router(a2a_router, prefix="/a2a") app.include_router(proxy_router) - # app.include_router(logs_router) + app.include_router(logs_router) app.include_router(mem_router) # Setup memory backend memory_backend = get_memory_backend(config.mem_backend, config) app.state.memory = memory_backend app.state.config = config - backend_selector = _select_backend() - app.state.usage = get_usage_backend(backend_selector) return app From 79e4e471405b096ea38057de365b2fc3197e70ba Mon Sep 17 00:00:00 2001 From: Hammad Tariq Date: Thu, 24 Jul 2025 17:15:14 -0700 Subject: [PATCH 2/2] header bugs squashed --- .env.example | 3 +- attach/gateway.py | 2 +- main.py | 71 +++++++++++++++------------------------------ middleware/quota.py | 4 ++- 4 files changed, 29 insertions(+), 51 deletions(-) diff --git a/.env.example b/.env.example index 6019812..b821c1c 100644 --- a/.env.example +++ b/.env.example @@ -19,7 +19,8 @@ WEAVIATE_URL=http://localhost:8081 MAX_TOKENS_PER_MIN=60000 QUOTA_ENCODING=cl100k_base -USAGE_METERING=null +# Metering Option (null, prometheus, openmeter) +USAGE_METERING=null # Development: Auth0 credentials for dev_login script # AUTH0_DOMAIN=your-domain.auth0.com diff --git a/attach/gateway.py b/attach/gateway.py index d0c1dff..ba286d9 100644 --- a/attach/gateway.py +++ b/attach/gateway.py @@ -174,7 +174,7 @@ async def auth_config(): app.add_middleware(BaseHTTPMiddleware, dispatch=session_mw) # Only add quota middleware if available and explicitly configured - limit = int_env("MAX_TOKENS_PER_MIN", None) + limit = int_env("MAX_TOKENS_PER_MIN", 60000) if QUOTA_AVAILABLE and limit is not None: app.add_middleware(TokenQuotaMiddleware) diff --git a/main.py b/main.py index 76e1efc..5ac3826 100644 --- a/main.py +++ b/main.py @@ -8,45 +8,36 @@ from contextlib import asynccontextmanager from a2a.routes import router as a2a_router -from auth.oidc import verify_jwt # Fixed: was auth.jwt, now auth.oidc from logs import router as logs_router -from mem import write as mem_write # Import memory write function -from middleware.auth import jwt_auth_mw # ← your auth middleware -from middleware.session import session_mw # ← generates session-id header +from middleware.auth import jwt_auth_mw +from middleware.session import session_mw from proxy.engine import router as proxy_router from usage.factory import _select_backend, get_usage_backend from usage.metrics import mount_metrics from utils.env import int_env -# At the top, make the import conditional try: from middleware.quota import TokenQuotaMiddleware - QUOTA_AVAILABLE = True except ImportError: QUOTA_AVAILABLE = False -# Memory router mem_router = APIRouter(prefix="/mem", tags=["memory"]) @mem_router.get("/events") async def get_memory_events(request: Request, limit: int = 10): - """Fetch recent MemoryEvent objects from Weaviate""" + """Fetch recent MemoryEvent objects from Weaviate.""" try: - # Get user info from request state (set by jwt_auth_mw middleware) user_sub = getattr(request.state, "sub", None) if not user_sub: raise HTTPException(status_code=401, detail="User not authenticated") - # Use the exact same client setup as demo_view_memory.py - client = weaviate.Client("http://localhost:6666") + client = weaviate.Client(os.getenv("WEAVIATE_URL", "http://localhost:6666")) - # Test connection first if not client.is_ready(): raise HTTPException(status_code=503, detail="Weaviate is not ready") - # Check schema the same way as demo_view_memory.py try: schema = client.schema.get() classes = {c["class"] for c in schema.get("classes", [])} @@ -56,7 +47,6 @@ async def get_memory_events(request: Request, limit: int = 10): except Exception: return {"data": {"Get": {"MemoryEvent": []}}} - # Query with descending order by timestamp (newest first) result = ( client.query.get( "MemoryEvent", @@ -68,7 +58,6 @@ async def get_memory_events(request: Request, limit: int = 10): .do() ) - # Check for GraphQL errors like demo_view_memory.py does if "errors" in result: raise HTTPException( status_code=500, detail=f"GraphQL error: {result['errors']}" @@ -79,31 +68,26 @@ async def get_memory_events(request: Request, limit: int = 10): events = result["data"]["Get"]["MemoryEvent"] - # Add the result field from the raw objects for richer display try: raw_objects = client.data_object.get(class_name="MemoryEvent", limit=limit) - # Create a mapping of IDs to full objects id_to_full_object = {} for obj in raw_objects.get("objects", []): obj_id = obj.get("id") if obj_id: id_to_full_object[obj_id] = obj.get("properties", {}) - # Enrich the GraphQL results with data from raw objects for event in events: event_id = event.get("_additional", {}).get("id") if event_id and event_id in id_to_full_object: full_props = id_to_full_object[event_id] - # Add the result field if it exists if "result" in full_props: event["result"] = full_props["result"] - # Add other useful fields for field in ["event", "session_id", "task_id", "user"]: if field in full_props: event[field] = full_props[field] except Exception: - pass # Silently fail if we can't enrich with raw object data + pass return result @@ -113,43 +97,36 @@ async def get_memory_events(request: Request, limit: int = 10): ) -middlewares = [ - # ❶ CORS first (so it executes last and handles responses properly) - Middleware( - CORSMiddleware, - allow_origins=["http://localhost:9000", "http://127.0.0.1:9000"], - allow_methods=["*"], - allow_headers=["*"], - allow_credentials=True, - ), - # ❷ Auth middleware - Middleware(BaseHTTPMiddleware, dispatch=jwt_auth_mw), - # ❸ Session middleware - Middleware(BaseHTTPMiddleware, dispatch=session_mw), -] - -# Only add quota middleware if tiktoken is available and a positive limit is set -limit = int_env("MAX_TOKENS_PER_MIN", 60000) -if QUOTA_AVAILABLE and limit is not None: - middlewares.append(Middleware(TokenQuotaMiddleware)) - @asynccontextmanager async def lifespan(app: FastAPI): """Manage application lifespan - startup and shutdown.""" - # Startup backend_selector = _select_backend() app.state.usage = get_usage_backend(backend_selector) mount_metrics(app) yield - # Shutdown if hasattr(app.state.usage, 'aclose'): await app.state.usage.aclose() -# Create app with lifespan -app = FastAPI(title="attach-gateway", middleware=middlewares, lifespan=lifespan) +app = FastAPI(title="attach-gateway", lifespan=lifespan) +# Add middleware in correct order (CORS outer-most) +app.add_middleware( + CORSMiddleware, + allow_origins=["http://localhost:9000", "http://127.0.0.1:9000"], + allow_methods=["*"], + allow_headers=["*"], + allow_credentials=True, +) + +# Only add quota middleware if available and explicitly configured +limit = int_env("MAX_TOKENS_PER_MIN", 60000) +if QUOTA_AVAILABLE and limit is not None: + app.add_middleware(TokenQuotaMiddleware) + +app.add_middleware(BaseHTTPMiddleware, dispatch=jwt_auth_mw) +app.add_middleware(BaseHTTPMiddleware, dispatch=session_mw) @app.get("/auth/config") async def auth_config(): @@ -159,9 +136,7 @@ async def auth_config(): "audience": os.getenv("OIDC_AUD"), } - -# Add middleware after routes are defined app.include_router(a2a_router, prefix="/a2a") app.include_router(logs_router) app.include_router(mem_router) -app.include_router(proxy_router) # ← ADD THIS BACK +app.include_router(proxy_router) diff --git a/middleware/quota.py b/middleware/quota.py index 2e5fc53..15af934 100644 --- a/middleware/quota.py +++ b/middleware/quota.py @@ -363,9 +363,11 @@ async def dispatch(self, request: Request, call_next): # Re-use the already-read body from here on request._body = raw - user = request.headers.get("x-attach-user") or ( + # Fix: Get user from request.state.sub (set by auth middleware) + user = getattr(request.state, "sub", None) or ( request.client.host if request.client else "unknown" ) + usage = { "user": user, "project": request.headers.get("x-attach-project", "default"),