diff --git a/backend/app/api/routes/collections.py b/backend/app/api/routes/collections.py index 2a5b49fb..8487e82f 100644 --- a/backend/app/api/routes/collections.py +++ b/backend/app/api/routes/collections.py @@ -13,7 +13,9 @@ from fastapi import Path as FastPath from pydantic import BaseModel, Field, HttpUrl from sqlalchemy.exc import SQLAlchemyError +from sqlmodel import Session +from app.core.db import engine from app.api.deps import CurrentUser, SessionDep, CurrentUserOrgProject from app.core.cloud import AmazonCloudStorage from app.api.routes.responses import handle_openai_error @@ -24,7 +26,7 @@ DocumentCollectionCrud, ) from app.crud.rag import OpenAIVectorStoreCrud, OpenAIAssistantCrud -from app.models import Collection, Document +from app.models import Collection, Document, UserProjectOrg from app.models.collection import CollectionStatus from app.utils import APIResponse, load_description, get_openai_client @@ -209,10 +211,12 @@ def _backout(crud: OpenAIAssistantCrud, assistant_id: str): exc_info=True, ) - +# TODO: Avoid passing ORM / object models (e.g., UserProjectOrg, ResponsePayload) +# into background job functions. When moving to Celery or another task queue, +# this can cause issues with serialization/deserialization. Instead, pass only +# primitive types (IDs, strings, etc.) and rehydrate objects inside the task. def do_create_collection( - session: SessionDep, - current_user: CurrentUserOrgProject, + current_user: UserProjectOrg, request: CreationRequest, payload: ResponsePayload, client: OpenAI, @@ -226,37 +230,41 @@ def do_create_collection( ) storage = AmazonCloudStorage(current_user) - document_crud = DocumentCrud(session, current_user.id) assistant_crud = OpenAIAssistantCrud(client) vector_store_crud = OpenAIVectorStoreCrud(client) - collection_crud = CollectionCrud(session, current_user.id) try: - vector_store = vector_store_crud.create() - - docs = list(request(document_crud)) - flat_docs = [doc for sublist in docs for doc in sublist] + docs = [] + flat_docs = [] + with Session(engine) as session: + document_crud = DocumentCrud(session, current_user.id) + docs = list(request(document_crud)) + flat_docs = [doc for sublist in docs for doc in sublist] file_exts = {doc.fname.split(".")[-1] for doc in flat_docs if "." in doc.fname} file_sizes_kb = [ storage.get_file_size_kb(doc.object_store_url) for doc in flat_docs ] + vector_store = vector_store_crud.create() list(vector_store_crud.update(vector_store.id, storage, docs)) assistant_options = dict(request.extract_super_type(AssistantOptions)) assistant = assistant_crud.create(vector_store.id, **assistant_options) - collection = collection_crud.read_one(UUID(payload.key)) - collection.llm_service_id = assistant.id - collection.llm_service_name = request.model - collection.status = CollectionStatus.successful - collection.updated_at = now() + # Update database with results + with Session(engine) as session: + collection_crud = CollectionCrud(session, current_user.id) + collection = collection_crud.read_one(UUID(payload.key)) + collection.llm_service_id = assistant.id + collection.llm_service_name = request.model + collection.status = CollectionStatus.successful + collection.updated_at = now() - if flat_docs: - DocumentCollectionCrud(session).create(collection, flat_docs) + if flat_docs: + DocumentCollectionCrud(session).create(collection, flat_docs) - collection_crud._update(collection) + collection_crud._update(collection) elapsed = time.time() - start_time logger.info( @@ -272,14 +280,16 @@ def do_create_collection( ) if "assistant" in locals(): _backout(assistant_crud, assistant.id) - try: - collection = collection_crud.read_one(UUID(payload.key)) - collection.status = CollectionStatus.failed - collection.updated_at = now() - message = extract_error_message(err) - collection.error_message = message - collection_crud._update(collection) + try: + with Session(engine) as session: + collection_crud = CollectionCrud(session, current_user.id) + collection = collection_crud.read_one(UUID(payload.key)) + collection.status = CollectionStatus.failed + collection.updated_at = now() + message = extract_error_message(err) + collection.error_message = message + collection_crud._update(collection) except Exception as suberr: logger.warning( f"[do_create_collection] Failed to update collection status | {{'collection_id': '{payload.key}', 'reason': '{str(suberr)}'}}" @@ -317,7 +327,7 @@ def create_collection( collection_crud.create(collection) background_tasks.add_task( - do_create_collection, session, current_user, request, payload, client + do_create_collection, current_user, request, payload, client ) logger.info( @@ -327,9 +337,12 @@ def create_collection( return APIResponse.success_response(data=None, metadata=asdict(payload)) +# TODO: Avoid passing ORM / object models (e.g., UserProjectOrg, ResponsePayload) +# into background job functions. When moving to Celery or another task queue, +# this can cause issues with serialization/deserialization. Instead, pass only +# primitive types (IDs, strings, etc.) and rehydrate objects inside the task. def do_delete_collection( - session: SessionDep, - current_user: CurrentUserOrgProject, + current_user: UserProjectOrg, request: DeletionRequest, payload: ResponsePayload, client: OpenAI, @@ -339,15 +352,21 @@ def do_delete_collection( else: callback = WebHookCallback(request.callback_url, payload) - collection_crud = CollectionCrud(session, current_user.id) try: - collection = collection_crud.read_one(request.collection_id) - assistant = OpenAIAssistantCrud(client) - data = collection_crud.delete(collection, assistant) - logger.info( - f"[do_delete_collection] Collection deleted successfully | {{'collection_id': '{collection.id}'}}" - ) - callback.success(data.model_dump(mode="json")) + with Session(engine) as session: + collection_crud = CollectionCrud(session, current_user.id) + collection = collection_crud.read_one(request.collection_id) + assistant = OpenAIAssistantCrud(client) + # TODO: Decouple OpenAI collection deletion from DB session handling. + # Currently, the call to OpenAI is tightly coupled with the session, + # which may keep the session open until deletion completes. + + data = collection_crud.delete(collection, assistant) + logger.info( + f"[do_delete_collection] Collection deleted successfully | {{'collection_id': '{collection.id}'}}" + ) + callback.success(data.model_dump(mode="json")) + except (ValueError, PermissionError, SQLAlchemyError) as err: logger.error( f"[do_delete_collection] Failed to delete collection | {{'collection_id': '{request.collection_id}', 'error': '{str(err)}'}}", @@ -381,7 +400,7 @@ def delete_collection( payload = ResponsePayload("processing", route) background_tasks.add_task( - do_delete_collection, session, current_user, request, payload, client + do_delete_collection, current_user, request, payload, client ) logger.info( diff --git a/backend/app/tests/api/routes/collections/test_create_collections.py b/backend/app/tests/api/routes/collections/test_create_collections.py index 2bede71a..e9c38afa 100644 --- a/backend/app/tests/api/routes/collections/test_create_collections.py +++ b/backend/app/tests/api/routes/collections/test_create_collections.py @@ -4,7 +4,7 @@ from sqlmodel import Session from fastapi.testclient import TestClient -from unittest.mock import patch +from unittest.mock import patch, MagicMock from app.core.config import settings from app.tests.utils.document import DocumentStore @@ -43,13 +43,30 @@ class TestCollectionRouteCreate: _n_documents = 5 @patch("app.api.routes.collections.get_openai_client") + @patch("fastapi.BackgroundTasks.add_task") + @patch("app.api.routes.collections.Session") def test_create_collection_success( self, + mock_session, + mock_add_task, mock_get_openai_client, client: TestClient, db: Session, user_api_key_header, ): + # Configure mocks + mock_openai_client = get_mock_openai_client_with_vector_store() + mock_get_openai_client.return_value = mock_openai_client + + # # Make background task run synchronously + def run_task_immediately(task_func, *args, **kwargs): + task_func(*args, **kwargs) + + mock_add_task.side_effect = run_task_immediately + + # Mock Session to return the test database session + mock_session.return_value.__enter__.return_value = db + store = DocumentStore(db) documents = store.fill(self._n_documents) doc_ids = [str(doc.id) for doc in documents] @@ -64,9 +81,6 @@ def test_create_collection_success( headers = user_api_key_header - mock_openai_client = get_mock_openai_client_with_vector_store() - mock_get_openai_client.return_value = mock_openai_client - response = client.post( f"{settings.API_V1_STR}/collections/create", json=body, headers=headers ) @@ -88,7 +102,6 @@ def test_create_collection_success( headers=headers, ) assert info_response.status_code == 200 - info_data = info_response.json()["data"] assert collection.status == CollectionStatus.successful.value assert collection.owner_id == user.user_id