diff --git a/backend/app/api/docs/collections/create.md b/backend/app/api/docs/collections/create.md index 3917d7c1..c1f36aac 100644 --- a/backend/app/api/docs/collections/create.md +++ b/backend/app/api/docs/collections/create.md @@ -10,7 +10,9 @@ pipeline: * Attach the Vector Store to an OpenAI [Assistant](https://platform.openai.com/docs/api-reference/assistants). Use parameters in the request body relevant to an Assistant to flesh out - its configuration. + its configuration. Note that an Assistant will only be created when you pass both + "model" and "instructions" in the request body; otherwise, only a vector store will be + created from the documents given. If any one of the OpenAI interactions fail, all OpenAI resources are cleaned up. If a Vector Store is unable to be created, for example, @@ -23,5 +25,5 @@ The immediate response from the endpoint is `collection_job` object which is going to contain the collection "job ID", status and action type ("CREATE"). Once the collection has been created, information about the collection will be returned to the user via the callback URL. If a callback URL is not provided, -clients can poll the `collection job info` endpoint with the `id` in the +clients can check the `collection job info` endpoint with the `id` in the `collection_job` object returned as it is the `job id`, to retrieve the same information. diff --git a/backend/app/api/docs/collections/delete.md b/backend/app/api/docs/collections/delete.md index 63a1e3cf..78471b02 100644 --- a/backend/app/api/docs/collections/delete.md +++ b/backend/app/api/docs/collections/delete.md @@ -10,4 +10,6 @@ documents can still be accessed via the documents endpoints. The response from t endpoint will be a `collection_job` object which will contain the collection `job ID`, status and action type ("DELETE"). when you take the id returned and use the collection job info endpoint, if the job is successful, you will get the status as successful and nothing will -be returned as the collection as it has been deleted and marked as deleted. +be returned for the collection as it has been deleted. Additionally, if a `callback_url` was +provided in the request body, you will receive a message indicating whether the deletion was +successful. diff --git a/backend/app/api/docs/collections/job_info.md b/backend/app/api/docs/collections/job_info.md index e785967b..8af2df5e 100644 --- a/backend/app/api/docs/collections/job_info.md +++ b/backend/app/api/docs/collections/job_info.md @@ -1,5 +1,4 @@ -Retrieve information about a collection job by the collection job ID. This endpoint can be considered the polling endpoint for collection creation job. This endpoint provides detailed status and metadata for a specific collection job -in the AI platform. It is especially useful for: +Retrieve information about a collection job by the collection job ID. This endpoint provides detailed status and metadata for a specific collection job in the AI platform. It is especially useful for: * Fetching the collection job object containing the ID which will be collection job id, collection ID, status of the job as well as error message.
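The create-then-poll flow documented above can be exercised end to end; the sketch below is illustrative only. The base URL and the `X-API-KEY` header name are assumptions (they are not defined in this diff); the endpoint paths, request fields, and the `collection_job` response shape come from the docs and tests in this PR.

```python
import time

import requests  # assumed HTTP client; any client works

BASE = "http://localhost:8000/api/v1"       # assumption: local deployment
HEADERS = {"X-API-KEY": "project-api-key"}  # assumption: illustrative auth header

# Passing both "model" and "instructions" also creates an Assistant;
# omitting both creates a vector store only (enforced by the model
# validator added in app/models/collection.py below).
body = {
    "model": "gpt-4o",
    "instructions": "Answer using only the attached documents.",
    "documents": ["f3e86a17-1e6f-41ec-b020-5b08eebef928"],  # illustrative document UUID
    "batch_size": 1,
    "callback_url": None,  # no webhook, so poll the job info endpoint instead
}

job = requests.post(
    f"{BASE}/collections/create", json=body, headers=HEADERS
).json()["data"]
assert job["action_type"] == "CREATE" and job["status"] == "PENDING"

# Poll the collection job info endpoint with the job id returned above.
while True:
    info = requests.get(
        f"{BASE}/collections/info/jobs/{job['id']}", headers=HEADERS
    ).json()["data"]
    if info["status"] in ("SUCCESSFUL", "FAILED"):
        break
    time.sleep(2)
```

Dropping "model" and "instructions" from `body` takes the vector-store-only path, in which case the immediate response also carries the `with_assistant: false` metadata note added in this PR.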
diff --git a/backend/app/api/routes/collections.py b/backend/app/api/routes/collections.py index c6210beb..83b89ec5 100644 --- a/backend/app/api/routes/collections.py +++ b/backend/app/api/routes/collections.py @@ -3,10 +3,9 @@ from uuid import UUID from typing import List -from fastapi import APIRouter, Query +from fastapi import APIRouter, Query, HTTPException from fastapi import Path as FastPath - from app.api.deps import SessionDep, CurrentUserOrgProject from app.crud import ( CollectionCrud, @@ -20,13 +19,11 @@ CollectionJobCreate, ) from app.models.collection import ( - ResponsePayload, CreationRequest, DeletionRequest, CollectionPublic, ) from app.utils import APIResponse, load_description -from app.services.collections.helpers import extract_error_message from app.services.collections import ( create_collection as create_service, delete_collection as delete_service, @@ -55,22 +52,31 @@ def create_collection( ) ) - this = inspect.currentframe() - route = router.url_path_for(this.f_code.co_name) - payload = ResponsePayload( - status="processing", route=route, key=str(collection_job.id) + # True iff both model and instructions were provided in the request body + with_assistant = bool( + getattr(request, "model", None) and getattr(request, "instructions", None) ) create_service.start_job( db=session, request=request, - payload=payload, collection_job_id=collection_job.id, project_id=current_user.project_id, organization_id=current_user.organization_id, + with_assistant=with_assistant, ) - return APIResponse.success_response(collection_job) + metadata = None + if not with_assistant: + metadata = { + "note": ( + "This job will create a vector store only (no Assistant). " + "Assistant creation happens when both 'model' and 'instructions' are included." 
+ ), + "with_assistant": False, + } + + return APIResponse.success_response(collection_job, metadata=metadata) @router.post( @@ -82,30 +88,19 @@ def delete_collection( current_user: CurrentUserOrgProject, request: DeletionRequest, ): - collection_crud = CollectionCrud(session, current_user.project_id) - collection = collection_crud.read_one(request.collection_id) - collection_job_crud = CollectionJobCrud(session, current_user.project_id) collection_job = collection_job_crud.create( CollectionJobCreate( action_type=CollectionActionType.DELETE, project_id=current_user.project_id, status=CollectionJobStatus.PENDING, - collection_id=collection.id, + collection_id=request.collection_id, ) ) - this = inspect.currentframe() - route = router.url_path_for(this.f_code.co_name) - payload = ResponsePayload( - status="processing", route=route, key=str(collection_job.id) - ) - delete_service.start_job( db=session, request=request, - payload=payload, - collection=collection, collection_job_id=collection_job.id, project_id=current_user.project_id, organization_id=current_user.organization_id, diff --git a/backend/app/crud/collection/collection.py b/backend/app/crud/collection/collection.py index d218ef2a..971b6fb3 100644 --- a/backend/app/crud/collection/collection.py +++ b/backend/app/crud/collection/collection.py @@ -20,19 +20,6 @@ def __init__(self, session: Session, project_id: int): self.project_id = project_id def _update(self, collection: Collection): - if not collection.project_id: - collection.project_id = self.project_id - elif collection.project_id != self.project_id: - err = ( - f"Invalid collection ownership: owner_project={self.project_id} " - f"attempter={collection.project_id}" - ) - logger.error( - "[CollectionCrud._update] Permission error | " - f"{{'collection_id': '{collection.id}', 'error': '{err}'}}" - ) - raise PermissionError(err) - self.session.add(collection) self.session.commit() self.session.refresh(collection) @@ -53,29 +40,28 @@ def _exists(self, collection: Collection) -> bool: return present def create( - self, - collection: Collection, - documents: Optional[list[Document]] = None, - ): + self, collection: Collection, documents: Optional[list[Document]] = None + ) -> Collection: + existing = None try: existing = self.read_one(collection.id) except HTTPException as e: - if e.status_code == 404: - self.session.add(collection) - self.session.commit() - self.session.refresh(collection) - else: + if e.status_code != 404: raise - else: + + if existing is not None: logger.warning( "[CollectionCrud.create] Collection already present | " f"{{'collection_id': '{collection.id}'}}" ) return existing + self.session.add(collection) + self.session.commit() + self.session.refresh(collection) + if documents: - dc_crud = DocumentCollectionCrud(self.session) - dc_crud.create(collection, documents) + DocumentCollectionCrud(self.session).create(collection, documents) return collection @@ -116,6 +102,12 @@ def read_all(self): collections = self.session.exec(statement).all() return collections + def delete_by_id(self, collection_id: UUID) -> Collection: + coll = self.read_one(collection_id) + coll.deleted_at = now() + + return self._update(coll) + @ft.singledispatchmethod def delete(self, model, remote): # remote should be an OpenAICrud try: @@ -145,7 +137,10 @@ def _(self, model: Document, remote): DocumentCollection, DocumentCollection.collection_id == Collection.id, ) - .where(DocumentCollection.document_id == model.id) + .where( + DocumentCollection.document_id == model.id, + 
Collection.deleted_at.is_(None), + ) .distinct() ) diff --git a/backend/app/models/collection.py b/backend/app/models/collection.py index 9e5f866f..194ad983 100644 --- a/backend/app/models/collection.py +++ b/backend/app/models/collection.py @@ -3,7 +3,7 @@ from typing import Any, Optional from sqlmodel import Field, Relationship, SQLModel -from pydantic import HttpUrl +from pydantic import HttpUrl, model_validator from app.core.util import now from .organization import Organization @@ -36,21 +36,6 @@ class Collection(SQLModel, table=True): project: Project = Relationship(back_populates="collections") -class ResponsePayload(SQLModel): - """Response metadata for background jobs—gives status, route, a UUID key, - and creation time.""" - - status: str - route: str - key: str = Field(default_factory=lambda: str(uuid4())) - time: datetime = Field(default_factory=now) - - @classmethod - def now(cls): - """Returns current UTC time without timezone info""" - return now() - - # pydantic models - class DocumentOptions(SQLModel): documents: list[UUID] = Field( @@ -73,27 +58,57 @@ class AssistantOptions(SQLModel): # Fields to be passed along to OpenAI. They must be a subset of # parameters accepted by the OpenAI.clien.beta.assistants.create # API. - model: str = Field( + model: Optional[str] = Field( + default=None, description=( + "**[To Be Deprecated]** " "OpenAI model to attach to this assistant. The model " "must be compatable with the assistants API; see the " "OpenAI [model documentation](https://platform.openai.com/docs/models/compare) for more." ), ) - instructions: str = Field( + + instructions: Optional[str] = Field( + default=None, description=( - "Assistant instruction. Sometimes referred to as the " '"system" prompt.' + "**[To Be Deprecated]** " + "Assistant instruction. Sometimes referred to as the " + '"system" prompt.' ), ) temperature: float = Field( default=1e-6, description=( + "**[To Be Deprecated]** " "Model temperature. The default is slightly " "greater-than zero because it is [unknown how OpenAI " "handles zero](https://community.openai.com/t/clarifications-on-setting-temperature-0/886447/5)." ), ) + @model_validator(mode="before") + def _assistant_fields_all_or_none(cls, values: dict[str, Any]) -> dict[str, Any]: + def norm(x: Any) -> Any: + if x is None: + return None + if isinstance(x, str): + s = x.strip() + return s if s else None + return x # let Pydantic handle non-strings + + model = norm(values.get("model")) + instructions = norm(values.get("instructions")) + + if (model is None) ^ (instructions is None): + raise ValueError( + "To create an Assistant, provide BOTH 'model' and 'instructions'. " + "If you only want a vector store, remove both fields." 
+ ) + + values["model"] = model + values["instructions"] = instructions + return values + class CallbackRequest(SQLModel): callback_url: Optional[HttpUrl] = Field( @@ -108,7 +123,7 @@ class CreationRequest( CallbackRequest, ): def extract_super_type(self, cls: "CreationRequest"): - for field_name in cls.__fields__.keys(): + for field_name in cls.model_fields.keys(): field_value = getattr(self, field_name) yield (field_name, field_value) diff --git a/backend/app/services/collections/create_collection.py b/backend/app/services/collections/create_collection.py index d424c533..3f412774 100644 --- a/backend/app/services/collections/create_collection.py +++ b/backend/app/services/collections/create_collection.py @@ -17,12 +17,10 @@ from app.crud.rag import OpenAIVectorStoreCrud, OpenAIAssistantCrud from app.models import ( CollectionJobStatus, - CollectionJob, Collection, CollectionJobUpdate, ) from app.models.collection import ( - ResponsePayload, CreationRequest, AssistantOptions, ) @@ -41,9 +39,9 @@ def start_job( db: Session, request: CreationRequest, - payload: ResponsePayload, project_id: int, collection_job_id: UUID, + with_assistant: bool, organization_id: int, ) -> str: trace_id = correlation_id.get() or "N/A" @@ -57,9 +55,9 @@ def start_job( function_path="app.services.collections.create_collection.execute_job", project_id=project_id, job_id=str(collection_job_id), - payload=payload.model_dump(), trace_id=trace_id, - request=request.model_dump(), + with_assistant=with_assistant, + request=request.model_dump(mode="json"), organization_id=organization_id, ) @@ -75,9 +73,9 @@ def execute_job( request: dict, project_id: int, organization_id: int, - payload: dict, task_id: str, job_id: str, + with_assistant: bool, task_instance, ) -> None: """ @@ -85,107 +83,145 @@ def execute_job( """ start_time = time.time() - try: - with Session(engine) as session: - creation_request = CreationRequest(**request) - payload = ResponsePayload(**payload) + assistant = None + assistant_crud = None + vector_store = None + vector_store_crud = None + collection_job = None + callback = None - job_id = UUID(job_id) + try: + creation_request = CreationRequest(**request) + job_uuid = UUID(job_id) + with Session(engine) as session: collection_job_crud = CollectionJobCrud(session, project_id) - collection_job = collection_job_crud.read_one(job_id) + collection_job = collection_job_crud.read_one(job_uuid) collection_job = collection_job_crud.update( - job_id, + job_uuid, CollectionJobUpdate( - task_id=task_id, status=CollectionJobStatus.PROCESSING + task_id=task_id, + status=CollectionJobStatus.PROCESSING, ), ) client = get_openai_client(session, organization_id, project_id) + storage = get_cloud_storage(session=session, project_id=project_id) - callback = ( - SilentCallback(payload) - if creation_request.callback_url is None - else WebHookCallback(creation_request.callback_url, payload) + document_crud = DocumentCrud(session, project_id) + docs_batches = batch_documents( + document_crud, + creation_request.documents, + creation_request.batch_size, ) + flat_docs = [doc for batch in docs_batches for doc in batch] - storage = get_cloud_storage(session=session, project_id=project_id) - document_crud = DocumentCrud(session, project_id) - assistant_crud = OpenAIAssistantCrud(client) - vector_store_crud = OpenAIVectorStoreCrud(client) + callback = ( + SilentCallback(collection_job) + if not creation_request.callback_url + else WebHookCallback(creation_request.callback_url, collection_job) + ) - try: - vector_store = 
vector_store_crud.create() + vector_store_crud = OpenAIVectorStoreCrud(client) + vector_store = vector_store_crud.create() + list(vector_store_crud.update(vector_store.id, storage, docs_batches)) - docs_batches = batch_documents( - document_crud, - creation_request.documents, - creation_request.batch_size, - ) - flat_docs = [doc for batch in docs_batches for doc in batch] + if with_assistant: + assistant_crud = OpenAIAssistantCrud(client) + assistant_options = dict( + creation_request.extract_super_type(AssistantOptions) + ) + assistant_options = { + k: v for k, v in assistant_options.items() if v is not None + } + + assistant = assistant_crud.create(vector_store.id, **assistant_options) + llm_service_id = assistant.id + llm_service_name = assistant_options.get("model") or "assistant" + + logger.info( + "[execute_job] Assistant created | assistant_id=%s, vector_store_id=%s", + assistant.id, + vector_store.id, + ) + else: + llm_service_id = vector_store.id + llm_service_name = "openai vector store" + logger.info( + "[execute_job] Skipping assistant creation | with_assistant=False" + ) - file_exts = { - doc.fname.split(".")[-1] for doc in flat_docs if "." in doc.fname - } - file_sizes_kb = [ - storage.get_file_size_kb(doc.object_store_url) for doc in flat_docs - ] + file_exts = {doc.fname.split(".")[-1] for doc in flat_docs if "." in doc.fname} + file_sizes_kb = [ + storage.get_file_size_kb(doc.object_store_url) for doc in flat_docs + ] - list(vector_store_crud.update(vector_store.id, storage, docs_batches)) + with Session(engine) as session: + collection_crud = CollectionCrud(session, project_id) + + collection_id = uuid4() + collection = Collection( + id=collection_id, + project_id=project_id, + organization_id=organization_id, + llm_service_id=llm_service_id, + llm_service_name=llm_service_name, + ) + collection_crud.create(collection) + collection = collection_crud.read_one(collection.id) - assistant_options = dict( - creation_request.extract_super_type(AssistantOptions) - ) - assistant = assistant_crud.create(vector_store.id, **assistant_options) - - collection_id = uuid4() - collection_crud = CollectionCrud(session, project_id) - collection = Collection( - id=collection_id, - project_id=project_id, - organization_id=organization_id, - llm_service_id=assistant.id, - llm_service_name=creation_request.model, - ) + if flat_docs: + DocumentCollectionCrud(session).create(collection, flat_docs) - collection_crud.create(collection) - collection_data = collection_crud.read_one(collection.id) + collection_job_crud = CollectionJobCrud(session, project_id) + collection_job_crud.update( + collection_job.id, + CollectionJobUpdate( + status=CollectionJobStatus.SUCCESSFUL, + collection_id=collection.id, + ), + ) - if flat_docs: - DocumentCollectionCrud(session).create(collection_data, flat_docs) + success_payload = collection.model_dump(mode="json", exclude_none=True) - collection_job_crud.update( - collection_job.id, - CollectionJobUpdate( - status=CollectionJobStatus.SUCCESSFUL, - collection_id=collection.id, - ), - ) + elapsed = time.time() - start_time + logger.info( + "[create_collection.execute_job] Collection created: %s | Time: %.2fs | Files: %d | Sizes: %s KB | Types: %s", + collection_id, + elapsed, + len(flat_docs), + file_sizes_kb, + list(file_exts), + ) - elapsed = time.time() - start_time - logger.info( - "[create_collection.execute_job] Collection created: %s | Time: %.2fs | Files: %d | Sizes: %s KB | Types: %s", - collection_id, - elapsed, - len(flat_docs), - file_sizes_kb, - 
list(file_exts), - ) + if callback: + callback.success(success_payload) - callback.success(collection.model_dump(mode="json")) + except Exception as err: + logger.error( + "[create_collection.execute_job] Collection Creation Failed | {'collection_job_id': '%s', 'error': '%s'}", + job_id, + str(err), + exc_info=True, + ) - except Exception as err: - logger.error( - "[create_collection.execute_job] Collection Creation Failed | " - "{'collection_job_id': '%s', 'error': '%s'}", - job_id, - str(err), - exc_info=True, + try: + if assistant is not None and assistant_crud is not None: + _backout(assistant_crud, assistant.id) + elif vector_store is not None and vector_store_crud is not None: + _backout(vector_store_crud, vector_store.id) + else: + logger.warning( + "[create_collection._backout] Skipping: no resource/crud available" ) - - if "assistant" in locals(): - _backout(assistant_crud, assistant.id) - + except Exception: + logger.warning("[create_collection.execute_job] Backout failed") + + try: + with Session(engine) as session: + collection_job_crud = CollectionJobCrud(session, project_id) + if collection_job is None: + collection_job = collection_job_crud.read_one(UUID(job_id)) collection_job_crud.update( collection_job.id, CollectionJobUpdate( @@ -193,20 +229,10 @@ def execute_job( error_message=str(err), ), ) + except Exception: + logger.warning( + "[create_collection.execute_job] Failed to mark job as FAILED" + ) - callback.fail(str(err)) - - except Exception as outer_err: - logger.error( - "[create_collection.execute_job] Unexpected Error during collection creation: %s", - str(outer_err), - exc_info=True, - ) - - collection_job_crud.update( - collection_job.id, - CollectionJobUpdate( - status=CollectionJobStatus.FAILED, - error_message=str(outer_err), - ), - ) + if callback: + callback.fail(str(err)) diff --git a/backend/app/services/collections/delete_collection.py b/backend/app/services/collections/delete_collection.py index 088647c3..9335374e 100644 --- a/backend/app/services/collections/delete_collection.py +++ b/backend/app/services/collections/delete_collection.py @@ -3,17 +3,15 @@ from sqlmodel import Session from asgi_correlation_id import correlation_id -from sqlalchemy.exc import SQLAlchemyError from app.core.db import engine from app.crud import CollectionCrud, CollectionJobCrud -from app.crud.rag import OpenAIAssistantCrud +from app.crud.rag import OpenAIAssistantCrud, OpenAIVectorStoreCrud from app.models import CollectionJobStatus, CollectionJobUpdate -from app.models.collection import Collection, DeletionRequest +from app.models.collection import DeletionRequest from app.services.collections.helpers import ( SilentCallback, WebHookCallback, - ResponsePayload, ) from app.celery.utils import start_low_priority_job from app.utils import get_openai_client @@ -25,10 +23,8 @@ def start_job( db: Session, request: DeletionRequest, - collection: Collection, project_id: int, collection_job_id: UUID, - payload: ResponsePayload, organization_id: int, ) -> str: trace_id = correlation_id.get() or "N/A" @@ -42,23 +38,21 @@ def start_job( function_path="app.services.collections.delete_collection.execute_job", project_id=project_id, job_id=str(collection_job_id), - collection_id=str(collection.id), + collection_id=str(request.collection_id), trace_id=trace_id, - request=request.model_dump(), - payload=payload.model_dump(), + request=request.model_dump(mode="json"), organization_id=organization_id, ) logger.info( "[delete_collection.start_job] Job scheduled to delete collection | " - 
f"Job_id={collection_job_id}, project_id={project_id}, task_id={task_id}, collection_id={collection.id}" + f"Job_id={collection_job_id}, project_id={project_id}, task_id={task_id}, collection_id={request.collection_id}" ) return collection_job_id def execute_job( request: dict, - payload: dict, project_id: int, organization_id: int, task_id: str, @@ -67,89 +61,134 @@ def execute_job( task_instance, ) -> None: deletion_request = DeletionRequest(**request) - payload = ResponsePayload(**payload) - callback = ( - SilentCallback(payload) - if deletion_request.callback_url is None - else WebHookCallback(deletion_request.callback_url, payload) - ) + collection_id = UUID(collection_id) + job_id = UUID(job_id) - if not isinstance(collection_id, UUID): - collection_id = UUID(str(collection_id)) - if not isinstance(job_id, UUID): - job_id = UUID(str(job_id)) + collection_job = None + client = None - try: - with Session(engine) as session: - client = get_openai_client(session, organization_id, project_id) + with Session(engine) as session: + client = get_openai_client(session, organization_id, project_id) - collection_job_crud = CollectionJobCrud(session, project_id) - collection_job = collection_job_crud.read_one(job_id) - collection_job = collection_job_crud.update( - job_id, - CollectionJobUpdate( - task_id=task_id, status=CollectionJobStatus.PROCESSING - ), + collection_job_crud = CollectionJobCrud(session, project_id) + collection_job = collection_job_crud.read_one(job_id) + collection_job = collection_job_crud.update( + job_id, + CollectionJobUpdate(task_id=task_id, status=CollectionJobStatus.PROCESSING), + ) + + collection = CollectionCrud(session, project_id).read_one(collection_id) + + service = (collection.llm_service_name or "").strip().lower() + is_vector = service == "openai vector store" + llm_service_id = ( + ( + getattr(collection, "vector_store_id", None) + or getattr(collection, "llm_service_id", None) ) + if is_vector + else ( + getattr(collection, "assistant_id", None) + or getattr(collection, "llm_service_id", None) + ) + ) - assistant_crud = OpenAIAssistantCrud(client) - collection_crud = CollectionCrud(session, project_id) + callback = ( + SilentCallback(collection_job) + if not deletion_request.callback_url + else WebHookCallback(deletion_request.callback_url, collection_job) + ) - collection = collection_crud.read_one(collection_id) + try: + if not llm_service_id: + raise ValueError( + f"Missing llm service id for service '{collection.llm_service_name}' on collection {collection_id}" + ) - try: - result = collection_crud.delete(collection, assistant_crud) + if is_vector: + OpenAIVectorStoreCrud(client).delete(llm_service_id) + else: + OpenAIAssistantCrud(client).delete(llm_service_id) + except Exception as err: + try: + with Session(engine) as session: + collection_job_crud = CollectionJobCrud(session, project_id) collection_job_crud.update( collection_job.id, CollectionJobUpdate( - status=CollectionJobStatus.SUCCESSFUL, - error_message=None, + status=CollectionJobStatus.FAILED, error_message=str(err) ), ) + collection_job = collection_job_crud.read_one(collection_job.id) + except Exception: + logger.warning( + "[delete_collection.execute_job] Failed to mark job as FAILED" + ) - logger.info( - "[delete_collection.execute_job] Collection deleted successfully | {'collection_id': '%s', 'job_id': '%s'}", - str(collection.id), - str(job_id), - ) - callback.success(result.model_dump(mode="json")) + logger.error( + "[delete_collection.execute_job] Failed to delete collection remotely 
| " + "{'collection_id': '%s', 'error': '%s', 'job_id': '%s'}", + str(collection_id), + str(err), + str(job_id), + exc_info=True, + ) + + if callback: + callback.collection_job = collection_job + callback.fail(str(err)) + return - except (ValueError, PermissionError, SQLAlchemyError) as err: + try: + with Session(engine) as session: + CollectionCrud(session, project_id).delete_by_id(collection_id) + collection_job_crud = CollectionJobCrud(session, project_id) + collection_job_crud.update( + collection_job.id, + CollectionJobUpdate( + status=CollectionJobStatus.SUCCESSFUL, error_message=None + ), + ) + collection_job = collection_job_crud.read_one(collection_job.id) + + logger.info( + "[delete_collection.execute_job] Collection deleted successfully | {'collection_id': '%s', 'job_id': '%s'}", + str(collection_id), + str(job_id), + ) + + if callback: + callback.collection_job = collection_job + callback.success({"collection_id": str(collection_id), "deleted": True}) + + except Exception as err: + try: + with Session(engine) as session: + collection_job_crud = CollectionJobCrud(session, project_id) collection_job_crud.update( collection_job.id, CollectionJobUpdate( - status=CollectionJobStatus.FAILED, - error_message=str(err), + status=CollectionJobStatus.FAILED, error_message=str(err) ), ) - - logger.error( - "[delete_collection.execute_job] Failed to delete collection | {'collection_id': '%s', 'error': '%s', 'job_id': '%s'}", - str(collection.id), - str(err), - str(job_id), - exc_info=True, - ) - callback.fail(str(err)) - - except Exception as err: - collection_job_crud.update( - collection_job.id, - CollectionJobUpdate( - status=CollectionJobStatus.FAILED, - error_message=str(err), - ), - ) + collection_job = collection_job_crud.read_one(collection_job.id) + except Exception: + logger.warning( + "[delete_collection.execute_job] Failed to mark job as FAILED" + ) logger.error( - "[delete_collection.execute_job] Unexpected error during deletion | " + "[delete_collection.execute_job] Unexpected error during local deletion | " "{'collection_id': '%s', 'error': '%s', 'error_type': '%s', 'job_id': '%s'}", - str(collection.id), + str(collection_id), str(err), type(err).__name__, str(job_id), exc_info=True, ) - callback.fail(str(err)) + + if callback: + callback.collection_job = collection_job + callback.fail(str(err)) diff --git a/backend/app/services/collections/helpers.py b/backend/app/services/collections/helpers.py index 158994c6..0c5070a1 100644 --- a/backend/app/services/collections/helpers.py +++ b/backend/app/services/collections/helpers.py @@ -4,15 +4,12 @@ import re from uuid import UUID from typing import List -from dataclasses import asdict, replace from pydantic import HttpUrl from openai import OpenAIError from app.core.util import post_callback from app.crud.document import DocumentCrud -from app.models.collection import ResponsePayload -from app.crud.rag import OpenAIAssistantCrud from app.utils import APIResponse @@ -70,10 +67,9 @@ def batch_documents( return docs_batches -# functions related to callback handling - class CallbackHandler: - def __init__(self, payload: ResponsePayload): - self.payload = payload + def __init__(self, collection_job): + self.collection_job = collection_job def fail(self, body): raise NotImplementedError() @@ -84,46 +80,50 @@ def success(self, body): class SilentCallback(CallbackHandler): def fail(self, body): - logger.info(f"[SilentCallback.fail] Silent callback failure") + logger.info("[SilentCallback.fail] Silent callback failure") return def 
success(self, body): - logger.info(f"[SilentCallback.success] Silent callback success") + logger.info("[SilentCallback.success] Silent callback success") return class WebHookCallback(CallbackHandler): - def __init__(self, url: HttpUrl, payload: ResponsePayload): - super().__init__(payload) + def __init__(self, url: HttpUrl, collection_job): + super().__init__(collection_job) self.url = url logger.info( f"[WebHookCallback.init] Initialized webhook callback | {{'url': '{url}'}}" ) - def __call__(self, response: APIResponse, status: str): - time = ResponsePayload.now() - payload = replace(self.payload, status=status, time=time) - response.metadata = asdict(payload) + def __call__(self, response: APIResponse): logger.info( - f"[WebHookCallback.call] Posting callback | {{'url': '{self.url}', 'status': '{status}'}}" + f"[WebHookCallback.call] Posting callback | {{'url': '{self.url}'}}" ) post_callback(self.url, response) def fail(self, body): - logger.warning(f"[WebHookCallback.fail] Callback failed | {{'body': '{body}'}}") - self(APIResponse.failure_response(body), "incomplete") + logger.warning( + f"[WebHookCallback.fail] Callback failed | {{'error': '{body}'}}" + ) + response = APIResponse.failure_response( + error=str(body), + metadata={"collection_job_id": str(getattr(self.collection_job, "id", ""))}, + ) + self(response) def success(self, body): - logger.info(f"[WebHookCallback.success] Callback succeeded") - self(APIResponse.success_response(body), "complete") + logger.info("[WebHookCallback.success] Callback succeeded") + response = APIResponse.success_response(body) + self(response) -def _backout(crud: OpenAIAssistantCrud, assistant_id: str): - """Best-effort cleanup: attempt to delete the assistant by ID""" +def _backout(crud, llm_service_id: str): + """Best-effort cleanup: attempt to delete the remote resource by its ID.""" try: - crud.delete(assistant_id) + crud.delete(llm_service_id) except OpenAIError as err: logger.error( - f"[backout] Failed to delete assistant | {{'assistant_id': '{assistant_id}', 'error': '{str(err)}'}}", + f"[backout] Failed to delete resource | {{'llm_service_id': '{llm_service_id}', 'error': '{str(err)}'}}", exc_info=True, ) diff --git a/backend/app/tests/api/routes/collections/test_collection_info.py b/backend/app/tests/api/routes/collections/test_collection_info.py index 2317ef24..4ae3341b 100644 --- a/backend/app/tests/api/routes/collections/test_collection_info.py +++ b/backend/app/tests/api/routes/collections/test_collection_info.py @@ -1,85 +1,30 @@ -from uuid import uuid4, UUID -from typing import Optional - from fastapi.testclient import TestClient from sqlmodel import Session from app.core.config import settings from app.core.util import now +from app.tests.utils.utils import get_project +from app.tests.utils.collection import get_collection, get_collection_job from app.models import ( - Collection, - CollectionJobCreate, CollectionActionType, CollectionJobStatus, - CollectionJobUpdate, ) -from app.crud import CollectionJobCrud, CollectionCrud - - -def create_collection( - db: Session, - user, - with_llm: bool = False, -): - """Create a Collection row (optionally prefilled with LLM service fields).""" - llm_service_id = None - llm_service_name = None - if with_llm: - llm_service_id = f"asst_{uuid4()}" - llm_service_name = "gpt-4o" - - collection = Collection( - id=uuid4(), - organization_id=user.organization_id, - project_id=user.project_id, - llm_service_id=llm_service_id, - llm_service_name=llm_service_name, - ) - - return CollectionCrud(db, user.project_id).create(collection) - - -def create_collection_job( - db: Session, - 
user, - collection_id: Optional[UUID] = None, - action_type: CollectionActionType = CollectionActionType.CREATE, - status: CollectionJobStatus = CollectionJobStatus.PENDING, -): - """Create a CollectionJob row (uses create schema for clarity).""" - job_in = CollectionJobCreate( - collection_id=collection_id, - project_id=user.project_id, - action_type=action_type, - status=status, - ) - collection_job = CollectionJobCrud(db, user.project_id).create(job_in) - - if collection_job.status == CollectionJobStatus.FAILED: - job_in = CollectionJobUpdate( - error_message="Something went wrong during the collection job process." - ) - collection_job = CollectionJobCrud(db, user.project_id).update( - collection_job.id, job_in - ) - - return collection_job def test_collection_info_processing( - db: Session, client: "TestClient", user_api_key_header, user_api_key + db: Session, client: "TestClient", user_api_key_header ): headers = user_api_key_header + project = get_project(db, "Dalgo") - collection_job = create_collection_job(db, user_api_key) + collection_job = get_collection_job(db, project) - response = client.get( + resp = client.get( f"{settings.API_V1_STR}/collections/info/jobs/{collection_job.id}", headers=headers, ) - - assert response.status_code == 200 - data = response.json()["data"] + assert resp.status_code == 200 + data = resp.json()["data"] assert data["status"] == CollectionJobStatus.PENDING assert data["inserted_at"] is not None @@ -88,22 +33,23 @@ def test_collection_info_processing( def test_collection_info_successful( - db: Session, client: "TestClient", user_api_key_header, user_api_key + db: Session, client: "TestClient", user_api_key_header ): headers = user_api_key_header + project = get_project(db, "Dalgo") - collection = create_collection(db, user_api_key, with_llm=True) - collection_job = create_collection_job( - db, user_api_key, collection.id, status=CollectionJobStatus.SUCCESSFUL + collection = get_collection(db, project) + + collection_job = get_collection_job( + db, project, collection_id=collection.id, status=CollectionJobStatus.SUCCESSFUL ) - response = client.get( + resp = client.get( f"{settings.API_V1_STR}/collections/info/jobs/{collection_job.id}", headers=headers, ) - - assert response.status_code == 200 - data = response.json()["data"] + assert resp.status_code == 200 + data = resp.json()["data"] assert data["id"] == str(collection_job.id) assert data["status"] == CollectionJobStatus.SUCCESSFUL @@ -117,22 +63,23 @@ def test_collection_info_successful( assert col["llm_service_name"] == "gpt-4o" -def test_collection_info_failed( - db: Session, client: "TestClient", user_api_key_header, user_api_key -): +def test_collection_info_failed(db: Session, client: "TestClient", user_api_key_header): headers = user_api_key_header + project = get_project(db, "Dalgo") - collection_job = create_collection_job( - db, user_api_key, status=CollectionJobStatus.FAILED + collection_job = get_collection_job( + db, + project, + status=CollectionJobStatus.FAILED, + error_message="something went wrong", ) - response = client.get( + resp = client.get( f"{settings.API_V1_STR}/collections/info/jobs/{collection_job.id}", headers=headers, ) - - assert response.status_code == 200 - data = response.json()["data"] + assert resp.status_code == 200 + data = resp.json()["data"] assert data["status"] == CollectionJobStatus.FAILED assert data["error_message"] is not None diff --git a/backend/app/tests/api/routes/collections/test_create_collections.py 
b/backend/app/tests/api/routes/collections/test_create_collections.py index 2b5d786b..6aa723cd 100644 --- a/backend/app/tests/api/routes/collections/test_create_collections.py +++ b/backend/app/tests/api/routes/collections/test_create_collections.py @@ -1,49 +1,136 @@ -from uuid import UUID +from uuid import UUID, uuid4 from unittest.mock import patch from fastapi.testclient import TestClient -from unittest.mock import patch -from app.models.collection import Collection, CreationRequest +from app.core.config import settings +from app.models import CollectionJobStatus, CollectionActionType +from app.models.collection import CreationRequest + + +def _extract_metadata(body: dict) -> dict | None: + return body.get("metadata") or body.get("meta") -def test_collection_creation_success( - client: TestClient, user_api_key_header: dict[str, str], user_api_key +@patch("app.api.routes.collections.create_service.start_job") +def test_collection_creation_with_assistant_calls_start_job_and_returns_job( + mock_start_job, + client: TestClient, + user_api_key_header: dict[str, str], + user_api_key, ): - with patch("app.api.routes.collections.create_service.start_job") as mock_job_start: - creation_data = CreationRequest( - model="gpt-4o", - instructions="string", - temperature=0.000001, - documents=[UUID("f3e86a17-1e6f-41ec-b020-5b08eebef928")], - batch_size=1, - callback_url=None, - ) - - resp = client.post( - "/api/v1/collections/create", - json=creation_data.model_dump(mode="json"), - headers=user_api_key_header, - ) - - assert resp.status_code == 200 - body = resp.json() - - data = body["data"] - assert isinstance(data, dict) - assert data["action_type"] == "CREATE" - assert data["status"] == "PENDING" - assert data["project_id"] == user_api_key.project_id - assert data["collection_id"] is None - assert data["task_id"] is None - assert "trace_id" in data - assert data["inserted_at"] - assert data["updated_at"] - - job_key = data["id"] - - mock_job_start.assert_called_once() - kwargs = mock_job_start.call_args.kwargs - assert "db" in kwargs - assert kwargs["request"] == creation_data - assert kwargs["collection_job_id"] == UUID(job_key) + creation_data = CreationRequest( + model="gpt-4o", + instructions="string", + temperature=0.000001, + documents=[UUID("f3e86a17-1e6f-41ec-b020-5b08eebef928")], + batch_size=1, + callback_url=None, + ) + + resp = client.post( + f"{settings.API_V1_STR}/collections/create", + json=creation_data.model_dump(mode="json"), + headers=user_api_key_header, + ) + + assert resp.status_code == 200 + body = resp.json() + + data = body["data"] + assert data["action_type"] == CollectionActionType.CREATE + assert data["status"] == CollectionJobStatus.PENDING + assert data["project_id"] == user_api_key.project_id + assert data["collection_id"] is None + assert data["task_id"] is None + assert "trace_id" in data + assert data["inserted_at"] + assert data["updated_at"] + + assert _extract_metadata(body) in (None, {}) + + mock_start_job.assert_called_once() + kwargs = mock_start_job.call_args.kwargs + assert "db" in kwargs + assert kwargs["project_id"] == user_api_key.project_id + assert kwargs["organization_id"] == user_api_key.organization_id + assert kwargs["with_assistant"] is True + + returned_job_id = UUID(data["id"]) + assert kwargs["collection_job_id"] == returned_job_id + + assert kwargs["request"].model_dump(mode="json") == creation_data.model_dump( + mode="json" + ) + + +@patch("app.api.routes.collections.create_service.start_job") +def 
test_collection_creation_vector_only_adds_metadata_and_sets_with_assistant_false( + mock_start_job, + client: TestClient, + user_api_key_header: dict[str, str], + user_api_key, +): + creation_data = CreationRequest( + temperature=0.000001, + documents=[str(uuid4())], + batch_size=1, + callback_url=None, + ) + + resp = client.post( + f"{settings.API_V1_STR}/collections/create", + json=creation_data.model_dump(mode="json"), + headers=user_api_key_header, + ) + + assert resp.status_code == 200 + body = resp.json() + + data = body["data"] + assert data["action_type"] == CollectionActionType.CREATE + assert data["status"] == CollectionJobStatus.PENDING + assert data["project_id"] == user_api_key.project_id + + meta = _extract_metadata(body) + assert isinstance(meta, dict) + assert meta.get("with_assistant") is False + assert "vector store only" in meta.get("note", "").lower() + + mock_start_job.assert_called_once() + kwargs = mock_start_job.call_args.kwargs + assert kwargs["project_id"] == user_api_key.project_id + assert kwargs["organization_id"] == user_api_key.organization_id + assert kwargs["with_assistant"] is False + assert kwargs["collection_job_id"] == UUID(data["id"]) + assert kwargs["request"].model_dump(mode="json") == creation_data.model_dump( + mode="json" + ) + + +def test_collection_creation_vector_only_request_validation_error( + client: TestClient, user_api_key_header: dict[str, str] +): + payload = { + "model": "gpt-4o", + "temperature": 0.000001, + "documents": [str(uuid4())], + "batch_size": 1, + "callback_url": None, + } + + resp = client.post( + f"{settings.API_V1_STR}/collections/create", + json=payload, + headers=user_api_key_header, + ) + + assert resp.status_code == 422 + body = resp.json() + assert body["success"] is False + assert body["data"] is None + assert body["metadata"] is None + assert ( + "To create an Assistant, provide BOTH 'model' and 'instructions'" + in body["error"] + ) diff --git a/backend/app/tests/crud/collections/collection/test_crud_collection_create.py b/backend/app/tests/crud/collections/collection/test_crud_collection_create.py index 925f595e..fc52cd08 100644 --- a/backend/app/tests/crud/collections/collection/test_crud_collection_create.py +++ b/backend/app/tests/crud/collections/collection/test_crud_collection_create.py @@ -1,11 +1,12 @@ +from uuid import uuid4 + import openai_responses from sqlmodel import Session, select from app.crud import CollectionCrud -from app.models import DocumentCollection +from app.models import DocumentCollection, Collection from app.tests.utils.document import DocumentStore from app.tests.utils.utils import get_project -from app.tests.utils.collection import get_collection class TestCollectionCreate: @@ -14,7 +15,14 @@ class TestCollectionCreate: @openai_responses.mock() def test_create_associates_documents(self, db: Session): project = get_project(db) - collection = get_collection(db, project_id=project.id) + collection = Collection( + id=uuid4(), + project_id=project.id, + organization_id=project.organization_id, + llm_service_id="asst_dummy", + llm_service_name="gpt-4o", + ) + store = DocumentStore(db, project_id=collection.project_id) documents = store.fill(self._n_documents) diff --git a/backend/app/tests/crud/collections/collection/test_crud_collection_delete.py b/backend/app/tests/crud/collections/collection/test_crud_collection_delete.py index e151a1c6..a2668b19 100644 --- a/backend/app/tests/crud/collections/collection/test_crud_collection_delete.py +++ 
b/backend/app/tests/crud/collections/collection/test_crud_collection_delete.py @@ -1,15 +1,37 @@ import pytest + import openai_responses from openai import OpenAI from sqlmodel import Session, select from app.core.config import settings from app.crud import CollectionCrud -from app.models import APIKey +from app.models import APIKey, Collection from app.crud.rag import OpenAIAssistantCrud from app.tests.utils.utils import get_project from app.tests.utils.document import DocumentStore -from app.tests.utils.collection import get_collection, uuid_increment + + +def get_collection_for_delete( + db: Session, client=None, project_id: int = None +) -> Collection: + project = get_project(db) + if client is None: + client = OpenAI(api_key="test_api_key") + + vector_store = client.vector_stores.create() + assistant = client.beta.assistants.create( + model="gpt-4o", + tools=[{"type": "file_search"}], + tool_resources={"file_search": {"vector_store_ids": [vector_store.id]}}, + ) + + return Collection( + organization_id=project.organization_id, + project_id=project_id, + llm_service_id=assistant.id, + llm_service_name="gpt-4o", + ) class TestCollectionDelete: @@ -21,7 +43,7 @@ def test_delete_marks_deleted(self, db: Session): client = OpenAI(api_key="sk-test-key") assistant = OpenAIAssistantCrud(client) - collection = get_collection(db, client, project_id=project.id) + collection = get_collection_for_delete(db, client, project_id=project.id) crud = CollectionCrud(db, collection.project_id) collection_ = crud.delete(collection, assistant) @@ -34,26 +56,13 @@ def test_delete_follows_insert(self, db: Session): assistant = OpenAIAssistantCrud(client) project = get_project(db) - collection = get_collection(db, project_id=project.id) + collection = get_collection_for_delete(db, project_id=project.id) crud = CollectionCrud(db, collection.project_id) collection_ = crud.delete(collection, assistant) assert collection_.inserted_at <= collection_.deleted_at - @openai_responses.mock() - def test_cannot_delete_others_collections(self, db: Session): - client = OpenAI(api_key="sk-test-key") - - assistant = OpenAIAssistantCrud(client) - project = get_project(db) - collection = get_collection(db, project_id=project.id) - c_id = uuid_increment(collection.id) - - crud = CollectionCrud(db, c_id) - with pytest.raises(PermissionError): - crud.delete(collection, assistant) - @openai_responses.mock() def test_delete_document_deletes_collections(self, db: Session): project = get_project(db) @@ -68,7 +77,7 @@ def test_delete_document_deletes_collections(self, db: Session): client = OpenAI(api_key="sk-test-key") resources = [] for _ in range(self._n_collections): - coll = get_collection(db, client, project_id=project.id) + coll = get_collection_for_delete(db, client, project_id=project.id) crud = CollectionCrud(db, project_id=project.id) collection = crud.create(coll, documents) resources.append((crud, collection)) diff --git a/backend/app/tests/crud/collections/collection/test_crud_collection_read_all.py b/backend/app/tests/crud/collections/collection/test_crud_collection_read_all.py index d1f329a2..a9da3523 100644 --- a/backend/app/tests/crud/collections/collection/test_crud_collection_read_all.py +++ b/backend/app/tests/crud/collections/collection/test_crud_collection_read_all.py @@ -1,4 +1,5 @@ import pytest + from openai_responses import OpenAIMock from openai import OpenAI from sqlmodel import Session @@ -17,7 +18,7 @@ def create_collections(db: Session, n: int): with openai_mock.router: client = 
OpenAI(api_key="sk-test-key") for _ in range(n): - collection = get_collection(db, client, project_id=project.id) + collection = get_collection(db, project=project) store = DocumentStore(db, project_id=collection.project_id) documents = store.fill(1) if crud is None: diff --git a/backend/app/tests/crud/collections/collection/test_crud_collection_read_one.py b/backend/app/tests/crud/collections/collection/test_crud_collection_read_one.py index acf7d39a..ceb46c1a 100644 --- a/backend/app/tests/crud/collections/collection/test_crud_collection_read_one.py +++ b/backend/app/tests/crud/collections/collection/test_crud_collection_read_one.py @@ -6,10 +6,9 @@ from sqlmodel import Session from app.crud import CollectionCrud -from app.core.config import settings from app.tests.utils.document import DocumentStore from app.tests.utils.utils import get_project -from app.tests.utils.collection import get_collection, uuid_increment +from app.tests.utils.collection import get_collection def mk_collection(db: Session): @@ -17,7 +16,7 @@ def mk_collection(db: Session): project = get_project(db) with openai_mock.router: client = OpenAI(api_key="sk-test-key") - collection = get_collection(db, client, project_id=project.id) + collection = get_collection(db, project=project) store = DocumentStore(db, project_id=collection.project_id) documents = store.fill(1) crud = CollectionCrud(db, collection.project_id) diff --git a/backend/app/tests/services/collections/test_create_collection.py b/backend/app/tests/services/collections/test_create_collection.py index 430e7b4b..3107a887 100644 --- a/backend/app/tests/services/collections/test_create_collection.py +++ b/backend/app/tests/services/collections/test_create_collection.py @@ -12,10 +12,11 @@ from app.core.config import settings from app.crud import CollectionCrud, CollectionJobCrud, DocumentCollectionCrud from app.models import CollectionJobStatus, CollectionJob, CollectionActionType -from app.models.collection import CreationRequest, ResponsePayload +from app.models.collection import CreationRequest from app.services.collections.create_collection import start_job, execute_job from app.tests.utils.openai import get_mock_openai_client_with_vector_store from app.tests.utils.utils import get_project +from app.tests.utils.collection import get_collection_job from app.tests.utils.document import DocumentStore @@ -61,11 +62,16 @@ def test_start_job_creates_collection_job_and_schedules_task(db: Session): batch_size=1, callback_url=None, ) - route = "/collections/create" - payload = ResponsePayload(status="processing", route=route) job_id = uuid4() - _ = create_collection_job_for_create(db, project, job_id) + _ = get_collection_job( + db, + project, + job_id=job_id, + action_type=CollectionActionType.CREATE, + status=CollectionJobStatus.PENDING, + collection_id=None, + ) with patch( "app.services.collections.create_collection.start_low_priority_job" @@ -76,8 +82,8 @@ def test_start_job_creates_collection_job_and_schedules_task(db: Session): db=db, request=request, project_id=project.id, - payload=payload, collection_job_id=job_id, + with_assistant=True, organization_id=project.organization_id, ) @@ -102,10 +108,7 @@ def test_start_job_creates_collection_job_and_schedules_task(db: Session): assert kwargs["project_id"] == project.id assert kwargs["organization_id"] == project.organization_id assert kwargs["job_id"] == str(job_id) - assert kwargs["request"] == request.model_dump() - - passed_payload = kwargs.get("payload", kwargs.get("payload_data")) - assert passed_payload 
== payload.model_dump() + assert kwargs["request"] == request.model_dump(mode="json") @pytest.mark.usefixtures("aws_credentials") @@ -141,20 +144,18 @@ def test_execute_job_success_flow_updates_job_and_creates_collection( batch_size=1, callback_url=None, ) - sample_payload = ResponsePayload(status="pending", route="/test/route") mock_client = get_mock_openai_client_with_vector_store() mock_get_openai_client.return_value = mock_client job_id = uuid4() - job_crud = CollectionJobCrud(db, project.id) - job_crud.create( - CollectionJob( - id=job_id, - project_id=project.id, - status=CollectionJobStatus.PENDING, - action_type=CollectionActionType.CREATE.value, - ) + _ = get_collection_job( + db, + project, + job_id=job_id, + action_type=CollectionActionType.CREATE, + status=CollectionJobStatus.PENDING, + collection_id=None, ) task_id = uuid4() @@ -165,10 +166,10 @@ def test_execute_job_success_flow_updates_job_and_creates_collection( execute_job( request=sample_request.model_dump(), - payload=sample_payload.model_dump(), project_id=project.id, organization_id=project.organization_id, task_id=str(task_id), + with_assistant=True, job_id=str(job_id), task_instance=None, ) @@ -188,3 +189,67 @@ def test_execute_job_success_flow_updates_job_and_creates_collection( docs = DocumentCollectionCrud(db).read(created_collection, skip=0, limit=10) assert len(docs) == 1 assert docs[0].fname == document.fname + + +@pytest.mark.usefixtures("aws_credentials") +@mock_aws +@patch("app.services.collections.create_collection.get_openai_client") +def test_execute_job_assistant_create_failure_marks_failed_and_deletes_vector( + mock_get_openai_client, db +): + project = get_project(db) + + job = get_collection_job( + db, + project, + job_id=uuid4(), + action_type=CollectionActionType.CREATE, + status=CollectionJobStatus.PENDING, + collection_id=None, + ) + + req = CreationRequest( + model="gpt-4o", + instructions="string", + temperature=0.0, + documents=[], + batch_size=1, + callback_url=None, + ) + + _ = mock_get_openai_client.return_value + + with patch( + "app.services.collections.create_collection.Session" + ) as SessionCtor, patch( + "app.services.collections.create_collection.OpenAIVectorStoreCrud" + ) as MockVS, patch( + "app.services.collections.create_collection.OpenAIAssistantCrud" + ) as MockAsst: + SessionCtor.return_value.__enter__.return_value = db + SessionCtor.return_value.__exit__.return_value = False + + MockVS.return_value.create.return_value = type( + "Vector store", (), {"id": "vs_123"} + )() + MockVS.return_value.update.return_value = [] + + MockAsst.return_value.create.side_effect = RuntimeError("assistant boom") + + task_id = str(uuid4()) + execute_job( + request=req.model_dump(), + project_id=project.id, + organization_id=project.organization_id, + task_id=task_id, + with_assistant=True, + job_id=str(job.id), + task_instance=None, + ) + + failed = CollectionJobCrud(db, project.id).read_one(job.id) + assert failed.task_id == task_id + assert failed.status == CollectionJobStatus.FAILED + assert "assistant boom" in (failed.error_message or "") + + MockVS.return_value.delete.assert_called_once_with("vs_123") diff --git a/backend/app/tests/services/collections/test_delete_collection.py b/backend/app/tests/services/collections/test_delete_collection.py index f6f55c6a..096ae5ed 100644 --- a/backend/app/tests/services/collections/test_delete_collection.py +++ b/backend/app/tests/services/collections/test_delete_collection.py @@ -1,63 +1,28 @@ from unittest.mock import patch, MagicMock -from uuid import 
uuid4, UUID +from uuid import uuid4 -from sqlmodel import Session from sqlalchemy.exc import SQLAlchemyError from app.models.collection import ( DeletionRequest, - Collection, - ResponsePayload, ) from app.tests.utils.utils import get_project -from app.crud import CollectionCrud, CollectionJobCrud -from app.models import CollectionJobStatus, CollectionJob, CollectionActionType +from app.crud import CollectionJobCrud +from app.models import CollectionJobStatus, CollectionActionType +from app.tests.utils.collection import get_collection, get_collection_job from app.services.collections.delete_collection import start_job, execute_job -def create_collection(db: Session, project): - collection = Collection( - id=uuid4(), - project_id=project.id, - organization_id=project.organization_id, - llm_service_id="asst-nasjnl", - llm_service_name="gpt-4o", - ) - return CollectionCrud(db, project.id).create(collection) - - -def create_collection_job( - db: Session, - project, - collection, - job_id: UUID | None = None, -): - if job_id is None: - job_id = uuid4() - job_crud = CollectionJobCrud(db, project.id) - return job_crud.create( - CollectionJob( - id=job_id, - action_type=CollectionActionType.DELETE, - project_id=project.id, - collection_id=collection.id, - status=CollectionJobStatus.PENDING, - ) - ) - - -def test_start_job_creates_collection_job_and_schedules_task(db: Session): +def test_start_job_creates_collection_job_and_schedules_task(db): """ - - start_job should update an existing CollectionJob (status=processing, action=delete) + - start_job should update an existing CollectionJob (status=PENDING, action=DELETE) - schedule the task with the provided job_id and collection_id - - return the same job_id (string) + - return the same job_id (UUID) """ project = get_project(db) - created_collection = create_collection(db, project) + created_collection = get_collection(db, project) req = DeletionRequest(collection_id=created_collection.id) - route = "/collections/delete" - payload = ResponsePayload(status="processing", route=route) with patch( "app.services.collections.delete_collection.start_low_priority_job" @@ -65,20 +30,20 @@ def test_start_job_creates_collection_job_and_schedules_task(db: Session): mock_schedule.return_value = "fake-task-id" collection_job_id = uuid4() - precreated = create_collection_job( - db=db, - project=project, - collection=created_collection, + _ = get_collection_job( + db, + project, job_id=collection_job_id, + action_type=CollectionActionType.DELETE, + status=CollectionJobStatus.PENDING, + collection_id=created_collection.id, ) returned = start_job( db=db, request=req, - collection=created_collection, project_id=project.id, collection_job_id=collection_job_id, - payload=payload, organization_id=project.organization_id, ) @@ -103,25 +68,30 @@ def test_start_job_creates_collection_job_and_schedules_task(db: Session): assert kwargs["organization_id"] == project.organization_id assert kwargs["job_id"] == str(job.id) assert kwargs["collection_id"] == str(created_collection.id) - assert kwargs["request"] == req.model_dump() - assert kwargs["payload"] == payload.model_dump() + assert kwargs["request"] == req.model_dump(mode="json") assert "trace_id" in kwargs @patch("app.services.collections.delete_collection.get_openai_client") def test_execute_job_delete_success_updates_job_and_calls_delete( - mock_get_openai_client, db: Session + mock_get_openai_client, db ): """ - execute_job should set task_id on the CollectionJob - - call CollectionCrud.delete(collection, 
+        MockAssistantCrud.return_value.delete.return_value = None
 
         task_id = uuid4()
         req = DeletionRequest(collection_id=collection.id)
-        payload = ResponsePayload(status="processing", route="/test/delete")
 
         execute_job(
-            request=req.model_dump(),
-            payload=payload.model_dump(),
+            request=req.model_dump(mode="json"),
            project_id=project.id,
             organization_id=project.organization_id,
             task_id=str(task_id),
             job_id=str(job.id),
-            collection_id=collection.id,
+            collection_id=str(collection.id),
             task_instance=None,
         )
 
@@ -167,26 +131,30 @@ def test_execute_job_delete_success_updates_job_and_calls_delete(
         MockCollectionCrud.assert_called_with(db, project.id)
         collection_crud_instance.read_one.assert_called_once_with(collection.id)
-        collection_crud_instance.delete.assert_called_once()
-        args, kwargs = collection_crud_instance.delete.call_args
-        assert isinstance(args[0], Collection)
+        MockAssistantCrud.assert_called_once()
+        MockAssistantCrud.return_value.delete.assert_called_once_with("asst_123")
+
+        collection_crud_instance.delete_by_id.assert_called_once_with(collection.id)
 
         mock_get_openai_client.assert_called_once()
 
 
 @patch("app.services.collections.delete_collection.get_openai_client")
-def test_execute_job_delete_failure_marks_job_failed(
-    mock_get_openai_client, db: Session
-):
+def test_execute_job_delete_failure_marks_job_failed(mock_get_openai_client, db):
     """
-    When CollectionCrud.delete raises (e.g., SQLAlchemyError),
-    the job should be marked failed and error_message set.
+    When the remote delete (OpenAIAssistantCrud.delete) raises,
+    the job should be marked FAILED and error_message set.
     """
     project = get_project(db)
-    collection = create_collection(db, project)
-
-    job = create_collection_job(db, project, collection)
+    collection = get_collection(db, project, assistant_id="asst_123")
+    job = get_collection_job(
+        db,
+        project,
+        action_type=CollectionActionType.DELETE,
+        status=CollectionJobStatus.PENDING,
+        collection_id=collection.id,
+    )
 
     mock_get_openai_client.return_value = MagicMock()
@@ -202,15 +170,17 @@ def test_execute_job_delete_failure_marks_job_failed(
         collection_crud_instance = MockCollectionCrud.return_value
         collection_crud_instance.read_one.return_value = collection
-        collection_crud_instance.delete.side_effect = SQLAlchemyError("boom")
+
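+        # simulate the remote delete blowing up so the job must be marked FAILED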
""" project = get_project(db) - collection = create_collection(db, project) - - job = create_collection_job(db, project, collection) + collection = get_collection(db, project, assistant_id="asst_123") + job = get_collection_job( + db, + project, + action_type=CollectionActionType.DELETE, + status=CollectionJobStatus.PENDING, + collection_id=collection.id, + ) mock_get_openai_client.return_value = MagicMock() @@ -202,15 +169,16 @@ def test_execute_job_delete_failure_marks_job_failed( collection_crud_instance = MockCollectionCrud.return_value collection_crud_instance.read_one.return_value = collection - collection_crud_instance.delete.side_effect = SQLAlchemyError("boom") + + MockAssistantCrud.return_value.delete.side_effect = SQLAlchemyError( + "something went wrong" + ) task_id = uuid4() req = DeletionRequest(collection_id=collection.id) - payload = ResponsePayload(status="processing", route="/test/delete") execute_job( - request=req.model_dump(), - payload=payload.model_dump(), + request=req.model_dump(mode="json"), project_id=project.id, organization_id=project.organization_id, task_id=str(task_id), @@ -222,7 +190,16 @@ def test_execute_job_delete_failure_marks_job_failed( failed_job = CollectionJobCrud(db, project.id).read_one(job.id) assert failed_job.task_id == str(task_id) assert failed_job.status == CollectionJobStatus.FAILED - assert failed_job.error_message and "boom" in failed_job.error_message + assert ( + failed_job.error_message + and "something went wrong" in failed_job.error_message + ) - MockAssistantCrud.assert_called_once() MockCollectionCrud.assert_called_with(db, project.id) + collection_crud_instance.read_one.assert_called_once_with(collection.id) + + MockAssistantCrud.assert_called_once() + MockAssistantCrud.return_value.delete.assert_called_once_with("asst_123") + + collection_crud_instance.delete_by_id.assert_not_called() + mock_get_openai_client.assert_called_once() diff --git a/backend/app/tests/services/collections/test_helpers.py b/backend/app/tests/services/collections/test_helpers.py new file mode 100644 index 00000000..3c6149c4 --- /dev/null +++ b/backend/app/tests/services/collections/test_helpers.py @@ -0,0 +1,161 @@ +from __future__ import annotations + +import json +from types import SimpleNamespace +from uuid import uuid4 + +from app.services.collections import helpers + + +def test_extract_error_message_parses_json_and_strips_prefix(): + payload = {"error": {"message": "Inner JSON message"}} + err = Exception(f"Error code: 400 - {json.dumps(payload)}") + msg = helpers.extract_error_message(err) + assert msg == "Inner JSON message" + + +def test_extract_error_message_parses_python_dict_repr(): + payload = {"error": {"message": "Dict-repr message"}} + err = Exception(str(payload)) + msg = helpers.extract_error_message(err) + assert msg == "Dict-repr message" + + +def test_extract_error_message_falls_back_to_clean_text_and_truncates(): + long_text = "x" * 1500 + err = Exception(long_text) + msg = helpers.extract_error_message(err) + assert len(msg) == 1000 + assert msg == long_text[:1000] + + +def test_extract_error_message_handles_non_matching_bodies(): + err = Exception("some random error without structure") + msg = helpers.extract_error_message(err) + assert msg == "some random error without structure" + + +# batch documents + + +class FakeDocumentCrud: + def __init__(self): + self.calls = [] + + def read_each(self, ids): + self.calls.append(list(ids)) + return [ + SimpleNamespace( + id=i, fname=f"{i}.txt", object_store_url=f"s3://bucket/{i}.txt" + ) + 
+    def __init__(self):
+        self.calls = []
+
+    def read_each(self, ids):
+        self.calls.append(list(ids))
+        return [
+            SimpleNamespace(
+                id=i, fname=f"{i}.txt", object_store_url=f"s3://bucket/{i}.txt"
+            )
+            for i in ids
+        ]
+
+
+def test_batch_documents_even_chunks():
+    crud = FakeDocumentCrud()
+    ids = [uuid4() for _ in range(6)]
+    batches = helpers.batch_documents(crud, ids, batch_size=3)
+
+    # read_each called with chunks [0:3], [3:6]
+    assert crud.calls == [ids[0:3], ids[3:6]]
+    # output mirrors what read_each returned
+    assert len(batches) == 2
+    assert [d.id for d in batches[0]] == ids[0:3]
+    assert [d.id for d in batches[1]] == ids[3:6]
+
+
+def test_batch_documents_ragged_last_chunk():
+    crud = FakeDocumentCrud()
+    ids = [uuid4() for _ in range(5)]
+    batches = helpers.batch_documents(crud, ids, batch_size=2)
+
+    assert crud.calls == [ids[0:2], ids[2:4], ids[4:5]]
+    assert [d.id for d in batches[0]] == ids[0:2]
+    assert [d.id for d in batches[1]] == ids[2:4]
+    assert [d.id for d in batches[2]] == ids[4:5]
+
+
+def test_batch_documents_empty_input():
+    crud = FakeDocumentCrud()
+    batches = helpers.batch_documents(crud, [], batch_size=3)
+    assert batches == []
+    assert crud.calls == []
+
+
+def test_silent_callback_is_noop():
+    job = SimpleNamespace(id=uuid4())
+    cb = helpers.SilentCallback(job)
+    # neither call should raise or post anywhere
+    cb.success({"ok": True})
+    cb.fail("oops")
+
+
+def test_webhook_callback_success_posts(monkeypatch):
+    job = SimpleNamespace(id=uuid4())
+    url = "https://example.com/hook"
+    sent = {}
+
+    def fake_post_callback(u, response):
+        sent["url"] = u
+        sent["response"] = response
+
+    monkeypatch.setattr(helpers, "post_callback", fake_post_callback)
+
+    cb = helpers.WebHookCallback(url=url, collection_job=job)
+    payload = {"collection_id": "abc123", "deleted": True}
+    cb.success(payload)
+
+    assert sent["url"] == url
+    assert isinstance(sent["response"], helpers.APIResponse)
+    assert getattr(sent["response"], "data", None) == payload
+    assert getattr(sent["response"], "success", None) is True
+
+
+def test_webhook_callback_fail_posts_with_job_id(monkeypatch):
+    job_id = uuid4()
+    job = SimpleNamespace(id=job_id)
+    url = "https://example.com/hook"
+    sent = {}
+
+    def fake_post_callback(u, response):
+        sent["url"] = u
+        sent["response"] = response
+
+    monkeypatch.setattr(helpers, "post_callback", fake_post_callback)
+
+    cb = helpers.WebHookCallback(url=url, collection_job=job)
+    cb.fail("boom")
+
+    assert sent["url"] == url
+    assert isinstance(sent["response"], helpers.APIResponse)
+    assert getattr(sent["response"], "success", None) is False
+    meta = getattr(sent["response"], "metadata", {}) or {}
+    assert meta.get("collection_job_id") == str(job_id)
+    assert "boom" in (getattr(sent["response"], "error", "") or "")
+
+
+# _backout
+
+
+def test_backout_calls_delete_and_swallows_openai_error(monkeypatch):
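+    """_backout should call delete exactly once and swallow OpenAIError."""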
+    class Crud:
+        def __init__(self):
+            self.calls = 0
+
+        def delete(self, resource_id: str):
+            self.calls += 1
+
+    crud = Crud()
+    helpers._backout(crud, "rsrc_1")
+    assert crud.calls == 1
+
+    class DummyOpenAIError(Exception):
+        pass
+
+    monkeypatch.setattr(helpers, "OpenAIError", DummyOpenAIError)
+
+    class FailingCrud:
+        def delete(self, resource_id: str):
+            raise DummyOpenAIError("nope")
+
+    helpers._backout(FailingCrud(), "rsrc_2")
diff --git a/backend/app/tests/utils/collection.py b/backend/app/tests/utils/collection.py
index 36e1ecf0..429bfc8b 100644
--- a/backend/app/tests/utils/collection.py
+++ b/backend/app/tests/utils/collection.py
@@ -1,14 +1,15 @@
-from uuid import UUID
-from uuid import uuid4
+from uuid import UUID, uuid4
+from typing import Optional
 
-from openai import OpenAI
 from sqlmodel import Session
 
-from app.core.config import settings
-from app.models import Collection, Organization, Project
-from app.tests.utils.utils import get_user_id_by_email, get_project
-from app.tests.utils.test_data import create_test_project
-from app.tests.utils.test_data import create_test_api_key
+from app.models import (
+    Collection,
+    CollectionActionType,
+    CollectionJob,
+    CollectionJobStatus,
+)
+from app.crud import CollectionCrud, CollectionJobCrud
 
 
 class constants:
@@ -17,25 +18,85 @@ class constants:
 
 
 def uuid_increment(value: UUID):
-    inc = int(value) + 1  # hopefully doesn't overflow!
+    inc = int(value) + 1
     return UUID(int=inc)
 
 
-def get_collection(db: Session, client=None, project_id: int = None) -> Collection:
-    project = get_project(db)
-    if client is None:
-        client = OpenAI(api_key="test_api_key")
+def get_collection(
+    db: Session,
+    project,
+    *,
+    assistant_id: Optional[str] = None,
+    model: str = "gpt-4o",
+    collection_id: Optional[UUID] = None,
+) -> Collection:
+    """
+    Create a Collection configured for the Assistant path.
+    execute_job will treat this as `is_vector = False` and use the assistant id.
+    """
+    if assistant_id is None:
+        assistant_id = f"asst_{uuid4().hex}"
 
-    vector_store = client.vector_stores.create()
-    assistant = client.beta.assistants.create(
-        model=constants.openai_model,
-        tools=[{"type": "file_search"}],
-        tool_resources={"file_search": {"vector_store_ids": [vector_store.id]}},
+    collection = Collection(
+        id=collection_id or uuid4(),
+        project_id=project.id,
+        organization_id=project.organization_id,
+        llm_service_name=model,
+        llm_service_id=assistant_id,
     )
+    return CollectionCrud(db, project.id).create(collection)
+
+
+def get_vector_store_collection(
+    db: Session,
+    project,
+    *,
+    vector_store_id: Optional[str] = None,
+    collection_id: Optional[UUID] = None,
+) -> Collection:
+    """
+    Create a Collection configured for the Vector Store path.
+    execute_job will treat this as `is_vector = True` and use the vector store id.
+    """
+    if vector_store_id is None:
+        vector_store_id = f"vs_{uuid4().hex}"
 
-    return Collection(
+    collection = Collection(
+        id=collection_id or uuid4(),
+        project_id=project.id,
         organization_id=project.organization_id,
-        project_id=project_id,
-        llm_service_id=assistant.id,
-        llm_service_name=constants.llm_service_name,
+        llm_service_name="openai vector store",
+        llm_service_id=vector_store_id,
     )
+    return CollectionCrud(db, project.id).create(collection)
+
+
+def get_collection_job(
+    db: Session,
+    project,
+    *,
+    action_type: CollectionActionType = CollectionActionType.CREATE,
+    status: CollectionJobStatus = CollectionJobStatus.PENDING,
+    collection_id: Optional[UUID] = None,
+    error_message: Optional[str] = None,
+    job_id: Optional[UUID] = None,
+) -> CollectionJob:
+    """
+    Generic seed for a CollectionJob row. Illustrative use:
+
+        get_collection_job(
+            db,
+            project,
+            action_type=CollectionActionType.DELETE,
+            collection_id=collection.id,
+        )
+    """
+    job = CollectionJob(
+        id=job_id or uuid4(),
+        project_id=project.id,
+        action_type=action_type.value if hasattr(action_type, "value") else action_type,
+        status=status.value if hasattr(status, "value") else status,
+        error_message=error_message,
+        collection_id=collection_id,
+    )
+    return CollectionJobCrud(db, project.id).create(job)