Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions src/mcp/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,13 @@

def build_mcp(app: FastAPI) -> FastMCP:
"""
Create (or return cached) FastMCP server
that mirrors the FastAPI app.
Create or return a cached FastMCP server that mirrors the given FastAPI app.

Parameters:
app (FastAPI): FastAPI application to mirror; the created FastMCP instance is cached on `app.state.mcp`.

Returns:
FastMCP: The FastMCP instance corresponding to the provided FastAPI app.
"""

if hasattr(app.state, 'mcp'):
Expand All @@ -18,4 +23,4 @@ def build_mcp(app: FastAPI) -> FastMCP:
settings.experimental.enable_new_openapi_parser = True
mcp = FastMCP.from_fastapi(app, name=app.title)
app.state.mcp = mcp # type: ignore[attr-defined]
return mcp
return mcp
57 changes: 56 additions & 1 deletion src/repositories/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,19 @@ def __init__(self):


def repository_exception_handler(method):
"""
Decorator that standardizes error handling and logging for repository coroutine methods.

Parameters:
method (Callable): The asynchronous repository method to wrap.

Returns:
wrapper (Callable): An async wrapper that:
- re-raises PyMongoError after logging the exception,
- re-raises RepositoryNotInitializedException after logging the exception,
- logs any other exception and raises an HTTPException with status 500 and detail 'Unexpected error ocurred',
- always logs completion of the repository method call with the repository name, method name, and kwargs.
"""
@functools.wraps(method)
async def wrapper(self, *args, **kwargs):
try:
Expand Down Expand Up @@ -81,13 +94,31 @@ class RepositoryInterface:
_global_thread_lock = threading.Lock()

def __new__(cls, *args, **kwargs):
"""
Ensure a single thread-safe instance exists for the subclass.

Creates and returns the singleton instance for this subclass, creating it if absent while holding a global thread lock to prevent concurrent instantiation.

Returns:
The singleton instance of the subclass.
"""
with cls._global_thread_lock:
if cls not in cls._global_instances:
instance = super().__new__(cls)
cls._global_instances[cls] = instance
return cls._global_instances[cls]

def __init__(self, model: ApiBaseModel, *, max_pool_size: int = 3):
"""
Initialize the repository instance for a specific API model and configure its connection pool.

Parameters:
model (ApiBaseModel): The API model used for validation and to determine the repository's collection.
max_pool_size (int, optional): Maximum size of the MongoDB connection pool. Defaults to 3.

Notes:
If the instance is already initialized, this constructor will not reconfigure it. Initialization of the underlying connection is started asynchronously.
"""
if not getattr(self, '_initialized', False):
self.model = model
self._max_pool_size = max_pool_size
Expand All @@ -96,6 +127,11 @@ def __init__(self, model: ApiBaseModel, *, max_pool_size: int = 3):

@retry(stop=stop_after_attempt(5), wait=wait_fixed(0.2))
async def _async_init(self):
"""
Perform idempotent, retry-safe asynchronous initialization of the repository instance.

Ensures a per-instance asyncio.Lock exists and acquires it to run initialization exactly once; on success it marks the instance as initialized and sets the internal _initialized_event so awaiters can proceed. If initialization fails, the original exception from _initialize_connection is propagated after logging.
"""
if getattr(self, '_initialized', False):
return

Expand All @@ -117,6 +153,11 @@ async def _async_init(self):
self._initialized_event.set()

def _initialize(self):
"""
Ensure the repository's asynchronous initializer is executed: run it immediately if no event loop is active, otherwise schedule it on the running loop.

If there is no running asyncio event loop, this method runs self._async_init() to completion on the current thread, blocking until it finishes. If an event loop is running, it schedules self._async_init() as a background task on that loop and returns immediately.
"""
try:
loop = asyncio.get_running_loop()
except RuntimeError:
Expand All @@ -125,13 +166,27 @@ def _initialize(self):
loop.create_task(self._async_init())

async def __aenter__(self):
"""
Waits for repository initialization to complete and returns the repository instance.

Returns:
RepositoryInterface: The initialized repository instance.
"""
await self._initialized_event.wait() # Ensure initialization is complete
return self

async def __aexit__(self, exc_type, exc_value, traceback):
await self._initialized_event.wait()

def _initialize_connection(self):
"""
Initialize the MongoDB async client, store the connection string, and bind the collection for this repository instance.

This method fetches the MongoDB connection string from secrets, creates an AsyncMongoClient configured with pool and timeout settings, and sets self._collection to the repository's collection named by the model. On success it logs the initialized client; on failure it raises a ConnectionError.

Raises:
ConnectionError: If the client or collection cannot be initialized.
"""
try:
self._connection_string = Secrets.get_secret(
"MONGODB_CONNECTION_STRING"
Expand Down Expand Up @@ -229,4 +284,4 @@ async def find_by_query(self, query: dict):
parsed_model = self.model.model_validate(read_data)
parsed_model.set_id(str(read_data["_id"]))
parsed_models.append(parsed_model)
return parsed_models
return parsed_models
7 changes: 6 additions & 1 deletion tests/unit/test_mcp/test_mcp_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,11 @@

@pytest.fixture(autouse=True)
def reset_mcp_state():
"""
Ensure the FastAPI app has no lingering MCP state before and after a test.

This fixture deletes app.state.mcp if it exists, yields control to the test, and then deletes app.state.mcp again to guarantee the MCP state is cleared between tests.
"""
if hasattr(rest_app.state, 'mcp'):
delattr(rest_app.state, 'mcp')
yield
Expand Down Expand Up @@ -76,4 +81,4 @@ async def test_combined_app_serves_rest_and_mcp(monkeypatch):
resp_rest = await client.get('/health')
assert resp_rest.status_code == 200
resp_docs = await client.get('/docs')
assert resp_docs.status_code == 200
assert resp_docs.status_code == 200
Loading