Skip to content
Merged

Dev #10

Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
134 changes: 84 additions & 50 deletions apipod/engine/backend/fastapi/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import inspect
import threading
import logging
from contextlib import asynccontextmanager
from typing import Union, Callable, get_type_hints, Generator, AsyncGenerator, Iterator, AsyncIterator
from fastapi import APIRouter, FastAPI, Response
from fastapi.responses import JSONResponse
Expand Down Expand Up @@ -34,6 +35,7 @@ def __init__(
prefix: str = "", # "/api",
max_upload_file_size_mb: float = None,
job_queue=None,
lifespan=None,
*args,
**kwargs):
"""
Expand All @@ -46,12 +48,17 @@ def __init__(
prefix: The API route prefix
max_upload_file_size_mb: Maximum file size in MB for uploads
job_queue: Optional custom JobQueue implementation
lifespan: Optional async context manager for custom startup/shutdown logic
args: Additional arguments
kwargs: Additional keyword arguments
"""
# Extract user-provided lifespan (explicit param or kwarg) before parent init
user_lifespan = lifespan or kwargs.pop('lifespan', None)

# Initialize parent classes
api_router_params = inspect.signature(APIRouter.__init__).parameters
api_router_kwargs = {k: kwargs.get(k) for k in api_router_params if k in kwargs}
api_router_kwargs.pop('lifespan', None) # handled via composed lifespan below

APIRouter.__init__(self, **api_router_kwargs)
_BaseBackend.__init__(self, title=title, summary=summary, *args, **kwargs)
Expand All @@ -61,27 +68,35 @@ def __init__(

self.status = SERVER_HEALTH.INITIALIZING

# Registry for functions that workers can execute. Keys are function names.
self._job_func_registry: dict = {}
# Stop event and thread handle for in-process worker (dev mode)
self._worker_stop_event = threading.Event()
self._worker_thread: threading.Thread | None = None
self._logger = logging.getLogger(__name__)

# Build a composed lifespan that merges internal worker hooks with the user-provided lifespan
combined_lifespan = self._build_lifespan(user_lifespan)

# Create or use provided FastAPI app
if app is None:
app = FastAPI(
title=self.title,
summary=self.summary,
contact={"name": "SocAIty", "url": "https://www.socaity.ai"}
contact={"name": "SocAIty", "url": "https://www.socaity.ai"},
lifespan=combined_lifespan,
)
else:
# Existing app: replace its lifespan with our composed version
app.router.lifespan_context = combined_lifespan

self.app: FastAPI = app
self.prefix = prefix
self.add_standard_routes()

# Registry for functions that workers can execute. Keys are function names.
self._job_func_registry: dict = {}
# Stop event and thread handle for in-process worker (dev mode)
self._worker_stop_event = threading.Event()
self._worker_thread: threading.Thread | None = None
self._logger = logging.getLogger(__name__)
self._endpoint_configurator = FastApiEndpointConfigurator(self)

# excpetion handling
# Exception handling
_FastAPIExceptionHandler.__init__(self)
if not getattr(self.app.state, "_socaity_exception_handler_added", False):
self.app.add_exception_handler(Exception, self.global_exception_handler)
Expand All @@ -91,50 +106,69 @@ def __init__(
self._orig_openapi_func = self.app.openapi
self.app.openapi = self.custom_openapi

# Start in-process worker on FastAPI startup (dev convenience).
# Only start if a job_queue with `start_worker` exists.
if not getattr(self.app.state, "_socaity_worker_hooks_added", False):
def _startup():
try:
if self.job_queue and hasattr(self.job_queue, "start_worker"):
# Start worker in a daemon thread so it doesn't block uvicorn
def _run():
try:
self.job_queue.start_worker(
func_registry=self._job_func_registry,
worker_name="api-worker",
stop_event=self._worker_stop_event,
)
except Exception:
self._logger.exception("Worker thread exited with exception")

t = threading.Thread(target=_run, daemon=True)
t.start()
self._worker_thread = t
except Exception:
self._logger.exception("Failed to start in-process worker on startup")

def _shutdown():
try:
# Signal local worker to stop
# ------------------------------------------------------------------
# Lifespan & worker lifecycle
# ------------------------------------------------------------------

def _build_lifespan(self, user_lifespan=None):
"""
Build a composed lifespan context manager that runs:
1. Internal worker startup
2. User-provided lifespan (if any)
3. Internal worker shutdown on exit
"""
router_self = self # capture for closure

@asynccontextmanager
async def _combined_lifespan(app):
router_self._start_background_worker()
try:
if user_lifespan:
async with user_lifespan(app):
yield
else:
yield
finally:
router_self._stop_background_worker()

return _combined_lifespan

def _start_background_worker(self):
"""Start the in-process job queue worker in a daemon thread (dev convenience)."""
try:
if self.job_queue and hasattr(self.job_queue, "start_worker"):
def _run():
try:
self._worker_stop_event.set()
self.job_queue.start_worker(
func_registry=self._job_func_registry,
worker_name="api-worker",
stop_event=self._worker_stop_event,
)
except Exception:
pass

# Call job_queue.shutdown if available
if self.job_queue and hasattr(self.job_queue, "shutdown"):
try:
self.job_queue.shutdown()
except Exception:
self._logger.exception("Error shutting down job queue")
except Exception:
self._logger.exception("Error during worker shutdown handler")

# Register handlers
self.app.add_event_handler("startup", _startup)
self.app.add_event_handler("shutdown", _shutdown)
self.app.state._socaity_worker_hooks_added = True
self._logger.exception("Worker thread exited with exception")

thread = threading.Thread(target=_run, daemon=True)
thread.start()
self._worker_thread = thread
except Exception:
self._logger.exception("Failed to start in-process worker on startup")

def _stop_background_worker(self):
"""Signal the background worker to stop and shut down the job queue."""
try:
self._worker_stop_event.set()
except Exception:
pass

if self.job_queue and hasattr(self.job_queue, "shutdown"):
try:
self.job_queue.shutdown()
except Exception:
self._logger.exception("Error shutting down job queue")

# ------------------------------------------------------------------
# Standard routes
# ------------------------------------------------------------------

def add_standard_routes(self):
"""Add standard API routes for status and health checks."""
Expand Down