-
Notifications
You must be signed in to change notification settings - Fork 5
Refactor Collection background tasks to Remove session dependency #353
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -13,7 +13,9 @@ | |
| from fastapi import Path as FastPath | ||
| from pydantic import BaseModel, Field, HttpUrl | ||
| from sqlalchemy.exc import SQLAlchemyError | ||
| from sqlmodel import Session | ||
|
|
||
| from app.core.db import engine | ||
| from app.api.deps import CurrentUser, SessionDep, CurrentUserOrgProject | ||
| from app.core.cloud import AmazonCloudStorage | ||
| from app.api.routes.responses import handle_openai_error | ||
|
|
@@ -24,7 +26,7 @@ | |
| DocumentCollectionCrud, | ||
| ) | ||
| from app.crud.rag import OpenAIVectorStoreCrud, OpenAIAssistantCrud | ||
| from app.models import Collection, Document | ||
| from app.models import Collection, Document, UserProjectOrg | ||
| from app.models.collection import CollectionStatus | ||
| from app.utils import APIResponse, load_description, get_openai_client | ||
|
|
||
|
|
@@ -209,10 +211,12 @@ def _backout(crud: OpenAIAssistantCrud, assistant_id: str): | |
| exc_info=True, | ||
| ) | ||
|
|
||
|
|
||
| # TODO: Avoid passing ORM / object models (e.g., UserProjectOrg, ResponsePayload) | ||
| # into background job functions. When moving to Celery or another task queue, | ||
| # this can cause issues with serialization/deserialization. Instead, pass only | ||
| # primitive types (IDs, strings, etc.) and rehydrate objects inside the task. | ||
| def do_create_collection( | ||
| session: SessionDep, | ||
| current_user: CurrentUserOrgProject, | ||
| current_user: UserProjectOrg, | ||
| request: CreationRequest, | ||
| payload: ResponsePayload, | ||
| client: OpenAI, | ||
|
|
@@ -226,37 +230,41 @@ def do_create_collection( | |
| ) | ||
|
|
||
| storage = AmazonCloudStorage(current_user) | ||
| document_crud = DocumentCrud(session, current_user.id) | ||
| assistant_crud = OpenAIAssistantCrud(client) | ||
| vector_store_crud = OpenAIVectorStoreCrud(client) | ||
| collection_crud = CollectionCrud(session, current_user.id) | ||
|
|
||
| try: | ||
| vector_store = vector_store_crud.create() | ||
|
|
||
| docs = list(request(document_crud)) | ||
| flat_docs = [doc for sublist in docs for doc in sublist] | ||
| docs = [] | ||
| flat_docs = [] | ||
| with Session(engine) as session: | ||
| document_crud = DocumentCrud(session, current_user.id) | ||
| docs = list(request(document_crud)) | ||
| flat_docs = [doc for sublist in docs for doc in sublist] | ||
|
|
||
| file_exts = {doc.fname.split(".")[-1] for doc in flat_docs if "." in doc.fname} | ||
| file_sizes_kb = [ | ||
| storage.get_file_size_kb(doc.object_store_url) for doc in flat_docs | ||
| ] | ||
|
|
||
| vector_store = vector_store_crud.create() | ||
| list(vector_store_crud.update(vector_store.id, storage, docs)) | ||
|
|
||
| assistant_options = dict(request.extract_super_type(AssistantOptions)) | ||
| assistant = assistant_crud.create(vector_store.id, **assistant_options) | ||
|
|
||
| collection = collection_crud.read_one(UUID(payload.key)) | ||
| collection.llm_service_id = assistant.id | ||
| collection.llm_service_name = request.model | ||
| collection.status = CollectionStatus.successful | ||
| collection.updated_at = now() | ||
| # Update database with results | ||
| with Session(engine) as session: | ||
| collection_crud = CollectionCrud(session, current_user.id) | ||
| collection = collection_crud.read_one(UUID(payload.key)) | ||
| collection.llm_service_id = assistant.id | ||
| collection.llm_service_name = request.model | ||
| collection.status = CollectionStatus.successful | ||
| collection.updated_at = now() | ||
|
|
||
| if flat_docs: | ||
| DocumentCollectionCrud(session).create(collection, flat_docs) | ||
| if flat_docs: | ||
| DocumentCollectionCrud(session).create(collection, flat_docs) | ||
|
|
||
| collection_crud._update(collection) | ||
| collection_crud._update(collection) | ||
|
|
||
| elapsed = time.time() - start_time | ||
| logger.info( | ||
|
|
@@ -272,14 +280,16 @@ def do_create_collection( | |
| ) | ||
| if "assistant" in locals(): | ||
| _backout(assistant_crud, assistant.id) | ||
| try: | ||
| collection = collection_crud.read_one(UUID(payload.key)) | ||
| collection.status = CollectionStatus.failed | ||
| collection.updated_at = now() | ||
| message = extract_error_message(err) | ||
| collection.error_message = message | ||
|
|
||
| collection_crud._update(collection) | ||
| try: | ||
| with Session(engine) as session: | ||
| collection_crud = CollectionCrud(session, current_user.id) | ||
| collection = collection_crud.read_one(UUID(payload.key)) | ||
| collection.status = CollectionStatus.failed | ||
| collection.updated_at = now() | ||
| message = extract_error_message(err) | ||
| collection.error_message = message | ||
| collection_crud._update(collection) | ||
| except Exception as suberr: | ||
avirajsingh7 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| logger.warning( | ||
| f"[do_create_collection] Failed to update collection status | {{'collection_id': '{payload.key}', 'reason': '{str(suberr)}'}}" | ||
|
|
@@ -317,7 +327,7 @@ def create_collection( | |
| collection_crud.create(collection) | ||
|
|
||
| background_tasks.add_task( | ||
| do_create_collection, session, current_user, request, payload, client | ||
| do_create_collection, current_user, request, payload, client | ||
| ) | ||
|
|
||
| logger.info( | ||
|
|
@@ -327,9 +337,12 @@ def create_collection( | |
| return APIResponse.success_response(data=None, metadata=asdict(payload)) | ||
|
|
||
|
|
||
| # TODO: Avoid passing ORM / object models (e.g., UserProjectOrg, ResponsePayload) | ||
| # into background job functions. When moving to Celery or another task queue, | ||
| # this can cause issues with serialization/deserialization. Instead, pass only | ||
| # primitive types (IDs, strings, etc.) and rehydrate objects inside the task. | ||
| def do_delete_collection( | ||
| session: SessionDep, | ||
| current_user: CurrentUserOrgProject, | ||
| current_user: UserProjectOrg, | ||
| request: DeletionRequest, | ||
| payload: ResponsePayload, | ||
| client: OpenAI, | ||
|
|
@@ -339,15 +352,21 @@ def do_delete_collection( | |
| else: | ||
| callback = WebHookCallback(request.callback_url, payload) | ||
|
|
||
| collection_crud = CollectionCrud(session, current_user.id) | ||
| try: | ||
| collection = collection_crud.read_one(request.collection_id) | ||
| assistant = OpenAIAssistantCrud(client) | ||
| data = collection_crud.delete(collection, assistant) | ||
| logger.info( | ||
| f"[do_delete_collection] Collection deleted successfully | {{'collection_id': '{collection.id}'}}" | ||
| ) | ||
| callback.success(data.model_dump(mode="json")) | ||
| with Session(engine) as session: | ||
| collection_crud = CollectionCrud(session, current_user.id) | ||
| collection = collection_crud.read_one(request.collection_id) | ||
| assistant = OpenAIAssistantCrud(client) | ||
| # TODO: Decouple OpenAI collection deletion from DB session handling. | ||
| # Currently, the call to OpenAI is tightly coupled with the session, | ||
| # which may keep the session open until deletion completes. | ||
|
|
||
| data = collection_crud.delete(collection, assistant) | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. As the call to OpenAI is tightly coupled with collection deletion, the session may stay open until the collection has been deleted from OpenAI. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Good point. If we are not doing this right now, at least put a TODO comment here in the code in this PR itself so that we can revisit this later. |
||
| logger.info( | ||
| f"[do_delete_collection] Collection deleted successfully | {{'collection_id': '{collection.id}'}}" | ||
| ) | ||
| callback.success(data.model_dump(mode="json")) | ||
|
|
||
avirajsingh7 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| except (ValueError, PermissionError, SQLAlchemyError) as err: | ||
| logger.error( | ||
| f"[do_delete_collection] Failed to delete collection | {{'collection_id': '{request.collection_id}', 'error': '{str(err)}'}}", | ||
|
|
@@ -381,7 +400,7 @@ def delete_collection( | |
| payload = ResponsePayload("processing", route) | ||
|
|
||
| background_tasks.add_task( | ||
| do_delete_collection, session, current_user, request, payload, client | ||
| do_delete_collection, current_user, request, payload, client | ||
| ) | ||
|
|
||
| logger.info( | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.