Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@ WEAVIATE_URL=http://localhost:8081
MAX_TOKENS_PER_MIN=60000
QUOTA_ENCODING=cl100k_base

USAGE_METERING=null
# Usage metering backend; one of: null, prometheus, openmeter
USAGE_METERING=null

# Development: Auth0 credentials for dev_login script
# AUTH0_DOMAIN=your-domain.auth0.com
Expand Down
68 changes: 52 additions & 16 deletions attach/gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,28 +3,33 @@
"""

import os
from contextlib import asynccontextmanager
from typing import Optional

import weaviate
from fastapi import APIRouter, FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from starlette.middleware.base import BaseHTTPMiddleware
from pydantic import BaseModel

from a2a.routes import router as a2a_router

# Clean relative imports
from auth import verify_jwt
from auth.oidc import _require_env

# from logs import router as logs_router
from logs import router as logs_router
from mem import get_memory_backend
from middleware.auth import jwt_auth_mw
from middleware.quota import TokenQuotaMiddleware
from middleware.session import session_mw
from proxy.engine import router as proxy_router
from usage.factory import _select_backend, get_usage_backend
from usage.metrics import mount_metrics
from utils.env import int_env

# Guard TokenQuotaMiddleware import (matches main.py pattern)
try:
from middleware.quota import TokenQuotaMiddleware
QUOTA_AVAILABLE = True
except ImportError: # optional extra not installed
QUOTA_AVAILABLE = False

# Import version from parent package
from . import __version__

Expand Down Expand Up @@ -52,7 +57,7 @@ async def get_memory_events(request: Request, limit: int = 10):
return {"data": {"Get": {"MemoryEvent": []}}}

result = (
client.query.get("MemoryEvent", ["timestamp", "role", "content"])
client.query.get("MemoryEvent", ["timestamp", "event", "user", "state"])
.with_additional(["id"])
.with_limit(limit)
.with_sort([{"path": ["timestamp"], "order": "desc"}])
Expand Down Expand Up @@ -100,6 +105,21 @@ class AttachConfig(BaseModel):
auth0_client: Optional[str] = None


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Manage application lifespan: startup before ``yield``, shutdown after.

    Startup picks a usage backend via ``_select_backend()``, stores the
    object returned by ``get_usage_backend()`` on ``app.state.usage`` and
    calls ``mount_metrics(app)``.

    Shutdown awaits ``aclose()`` on the usage backend when it defines one.
    """
    # Startup
    backend_selector = _select_backend()
    app.state.usage = get_usage_backend(backend_selector)
    mount_metrics(app)

    try:
        yield
    finally:
        # Shutdown. Kept in ``finally`` so the backend is still closed when
        # an exception or cancellation is thrown into the generator at
        # ``yield`` instead of a normal resume.
        if hasattr(app.state.usage, "aclose"):
            await app.state.usage.aclose()


def create_app(config: Optional[AttachConfig] = None) -> FastAPI:
"""
Create a FastAPI app with Attach Gateway functionality
Expand Down Expand Up @@ -130,27 +150,43 @@ def create_app(config: Optional[AttachConfig] = None) -> FastAPI:
title="Attach Gateway",
description="Identity & Memory side-car for LLM engines",
version=__version__,
lifespan=lifespan,
)
mount_metrics(app)

# Add middleware
app.middleware("http")(jwt_auth_mw)
app.middleware("http")(session_mw)
@app.get("/auth/config")
async def auth_config():
return {
"domain": config.auth0_domain,
"client_id": config.auth0_client,
"audience": config.oidc_audience,
}

# Add middleware in correct order (CORS outer-most)
app.add_middleware(
CORSMiddleware,
allow_origins=["http://localhost:9000", "http://127.0.0.1:9000"],
allow_methods=["*"],
allow_headers=["*"],
allow_credentials=True,
)

app.add_middleware(BaseHTTPMiddleware, dispatch=jwt_auth_mw)
app.add_middleware(BaseHTTPMiddleware, dispatch=session_mw)

# Only add quota middleware if available and explicitly configured
limit = int_env("MAX_TOKENS_PER_MIN", 60000)
if limit is not None:
if QUOTA_AVAILABLE and limit is not None:
app.add_middleware(TokenQuotaMiddleware)

# Add routes
app.include_router(a2a_router)
app.include_router(a2a_router, prefix="/a2a")
app.include_router(proxy_router)
# app.include_router(logs_router)
app.include_router(logs_router)
app.include_router(mem_router)

# Setup memory backend
memory_backend = get_memory_backend(config.mem_backend, config)
app.state.memory = memory_backend
app.state.config = config
backend_selector = _select_backend()
app.state.usage = get_usage_backend(backend_selector)

return app
71 changes: 23 additions & 48 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,45 +8,36 @@
from contextlib import asynccontextmanager

from a2a.routes import router as a2a_router
from auth.oidc import verify_jwt # Fixed: was auth.jwt, now auth.oidc
from logs import router as logs_router
from mem import write as mem_write # Import memory write function
from middleware.auth import jwt_auth_mw # ← your auth middleware
from middleware.session import session_mw # ← generates session-id header
from middleware.auth import jwt_auth_mw
from middleware.session import session_mw
from proxy.engine import router as proxy_router
from usage.factory import _select_backend, get_usage_backend
from usage.metrics import mount_metrics
from utils.env import int_env

# At the top, make the import conditional
try:
from middleware.quota import TokenQuotaMiddleware

QUOTA_AVAILABLE = True
except ImportError:
QUOTA_AVAILABLE = False

# Memory router
mem_router = APIRouter(prefix="/mem", tags=["memory"])


@mem_router.get("/events")
async def get_memory_events(request: Request, limit: int = 10):
"""Fetch recent MemoryEvent objects from Weaviate"""
"""Fetch recent MemoryEvent objects from Weaviate."""
try:
# Get user info from request state (set by jwt_auth_mw middleware)
user_sub = getattr(request.state, "sub", None)
if not user_sub:
raise HTTPException(status_code=401, detail="User not authenticated")

# Use the exact same client setup as demo_view_memory.py
client = weaviate.Client("http://localhost:6666")
client = weaviate.Client(os.getenv("WEAVIATE_URL", "http://localhost:6666"))

# Test connection first
if not client.is_ready():
raise HTTPException(status_code=503, detail="Weaviate is not ready")

# Check schema the same way as demo_view_memory.py
try:
schema = client.schema.get()
classes = {c["class"] for c in schema.get("classes", [])}
Expand All @@ -56,7 +47,6 @@ async def get_memory_events(request: Request, limit: int = 10):
except Exception:
return {"data": {"Get": {"MemoryEvent": []}}}

# Query with descending order by timestamp (newest first)
result = (
client.query.get(
"MemoryEvent",
Expand All @@ -68,7 +58,6 @@ async def get_memory_events(request: Request, limit: int = 10):
.do()
)

# Check for GraphQL errors like demo_view_memory.py does
if "errors" in result:
raise HTTPException(
status_code=500, detail=f"GraphQL error: {result['errors']}"
Expand All @@ -79,31 +68,26 @@ async def get_memory_events(request: Request, limit: int = 10):

events = result["data"]["Get"]["MemoryEvent"]

# Add the result field from the raw objects for richer display
try:
raw_objects = client.data_object.get(class_name="MemoryEvent", limit=limit)

# Create a mapping of IDs to full objects
id_to_full_object = {}
for obj in raw_objects.get("objects", []):
obj_id = obj.get("id")
if obj_id:
id_to_full_object[obj_id] = obj.get("properties", {})

# Enrich the GraphQL results with data from raw objects
for event in events:
event_id = event.get("_additional", {}).get("id")
if event_id and event_id in id_to_full_object:
full_props = id_to_full_object[event_id]
# Add the result field if it exists
if "result" in full_props:
event["result"] = full_props["result"]
# Add other useful fields
for field in ["event", "session_id", "task_id", "user"]:
if field in full_props:
event[field] = full_props[field]
except Exception:
pass # Silently fail if we can't enrich with raw object data
pass

return result

Expand All @@ -113,43 +97,36 @@ async def get_memory_events(request: Request, limit: int = 10):
)


middlewares = [
# ❶ CORS first (so it executes last and handles responses properly)
Middleware(
CORSMiddleware,
allow_origins=["http://localhost:9000", "http://127.0.0.1:9000"],
allow_methods=["*"],
allow_headers=["*"],
allow_credentials=True,
),
# ❷ Auth middleware
Middleware(BaseHTTPMiddleware, dispatch=jwt_auth_mw),
# ❸ Session middleware
Middleware(BaseHTTPMiddleware, dispatch=session_mw),
]

# Only add quota middleware if tiktoken is available and a positive limit is set
limit = int_env("MAX_TOKENS_PER_MIN", 60000)
if QUOTA_AVAILABLE and limit is not None:
middlewares.append(Middleware(TokenQuotaMiddleware))

@asynccontextmanager
async def lifespan(app: FastAPI):
    """Manage application lifespan: startup before ``yield``, shutdown after.

    Startup picks a usage backend via ``_select_backend()``, stores the
    object returned by ``get_usage_backend()`` on ``app.state.usage`` and
    calls ``mount_metrics(app)``.

    Shutdown awaits ``aclose()`` on the usage backend when it defines one.
    """
    # Startup
    backend_selector = _select_backend()
    app.state.usage = get_usage_backend(backend_selector)
    mount_metrics(app)

    try:
        yield
    finally:
        # Shutdown. Kept in ``finally`` so the backend is still closed when
        # an exception or cancellation is thrown into the generator at
        # ``yield`` instead of a normal resume.
        if hasattr(app.state.usage, "aclose"):
            await app.state.usage.aclose()

# Create app with lifespan
app = FastAPI(title="attach-gateway", middleware=middlewares, lifespan=lifespan)
app = FastAPI(title="attach-gateway", lifespan=lifespan)

# Add middleware in correct order (CORS outer-most)
app.add_middleware(
CORSMiddleware,
allow_origins=["http://localhost:9000", "http://127.0.0.1:9000"],
allow_methods=["*"],
allow_headers=["*"],
allow_credentials=True,
)

# Only add quota middleware if available and explicitly configured
limit = int_env("MAX_TOKENS_PER_MIN", 60000)
if QUOTA_AVAILABLE and limit is not None:
app.add_middleware(TokenQuotaMiddleware)

app.add_middleware(BaseHTTPMiddleware, dispatch=jwt_auth_mw)
app.add_middleware(BaseHTTPMiddleware, dispatch=session_mw)

@app.get("/auth/config")
async def auth_config():
Expand All @@ -159,9 +136,7 @@ async def auth_config():
"audience": os.getenv("OIDC_AUD"),
}


# Register routers (middleware is added above)
app.include_router(a2a_router, prefix="/a2a")
app.include_router(logs_router)
app.include_router(mem_router)
app.include_router(proxy_router) # ← ADD THIS BACK
app.include_router(proxy_router)
4 changes: 3 additions & 1 deletion middleware/quota.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,9 +363,11 @@ async def dispatch(self, request: Request, call_next):
# Re-use the already-read body from here on
request._body = raw

user = request.headers.get("x-attach-user") or (
# Fix: Get user from request.state.sub (set by auth middleware)
user = getattr(request.state, "sub", None) or (
request.client.host if request.client else "unknown"
)

usage = {
"user": user,
"project": request.headers.get("x-attach-project", "default"),
Expand Down