93 changes: 56 additions & 37 deletions backend/app/api/routes/collections.py
@@ -13,7 +13,9 @@
from fastapi import Path as FastPath
from pydantic import BaseModel, Field, HttpUrl
from sqlalchemy.exc import SQLAlchemyError
from sqlmodel import Session

from app.core.db import engine
from app.api.deps import CurrentUser, SessionDep, CurrentUserOrgProject
from app.core.cloud import AmazonCloudStorage
from app.api.routes.responses import handle_openai_error
@@ -24,7 +26,7 @@
DocumentCollectionCrud,
)
from app.crud.rag import OpenAIVectorStoreCrud, OpenAIAssistantCrud
from app.models import Collection, Document
from app.models import Collection, Document, UserProjectOrg
from app.models.collection import CollectionStatus
from app.utils import APIResponse, load_description, get_openai_client

@@ -209,10 +211,12 @@ def _backout(crud: OpenAIAssistantCrud, assistant_id: str):
exc_info=True,
)


# TODO: Avoid passing ORM / object models (e.g., UserProjectOrg, ResponsePayload)
# into background job functions. When moving to Celery or another task queue,
# this can cause issues with serialization/deserialization. Instead, pass only
# primitive types (IDs, strings, etc.) and rehydrate objects inside the task.
def do_create_collection(
session: SessionDep,
current_user: CurrentUserOrgProject,
current_user: UserProjectOrg,
request: CreationRequest,
payload: ResponsePayload,
client: OpenAI,
@@ -226,37 +230,41 @@ def do_create_collection(
)

storage = AmazonCloudStorage(current_user)
document_crud = DocumentCrud(session, current_user.id)
assistant_crud = OpenAIAssistantCrud(client)
vector_store_crud = OpenAIVectorStoreCrud(client)
collection_crud = CollectionCrud(session, current_user.id)

try:
vector_store = vector_store_crud.create()

docs = list(request(document_crud))
flat_docs = [doc for sublist in docs for doc in sublist]
docs = []
flat_docs = []
with Session(engine) as session:
document_crud = DocumentCrud(session, current_user.id)
docs = list(request(document_crud))
flat_docs = [doc for sublist in docs for doc in sublist]

file_exts = {doc.fname.split(".")[-1] for doc in flat_docs if "." in doc.fname}
file_sizes_kb = [
storage.get_file_size_kb(doc.object_store_url) for doc in flat_docs
]

vector_store = vector_store_crud.create()
list(vector_store_crud.update(vector_store.id, storage, docs))

assistant_options = dict(request.extract_super_type(AssistantOptions))
assistant = assistant_crud.create(vector_store.id, **assistant_options)

collection = collection_crud.read_one(UUID(payload.key))
collection.llm_service_id = assistant.id
collection.llm_service_name = request.model
collection.status = CollectionStatus.successful
collection.updated_at = now()
# Update database with results
with Session(engine) as session:
collection_crud = CollectionCrud(session, current_user.id)
collection = collection_crud.read_one(UUID(payload.key))
collection.llm_service_id = assistant.id
collection.llm_service_name = request.model
collection.status = CollectionStatus.successful
collection.updated_at = now()

if flat_docs:
DocumentCollectionCrud(session).create(collection, flat_docs)
if flat_docs:
DocumentCollectionCrud(session).create(collection, flat_docs)

collection_crud._update(collection)
collection_crud._update(collection)

elapsed = time.time() - start_time
logger.info(
@@ -272,14 +280,16 @@
)
if "assistant" in locals():
_backout(assistant_crud, assistant.id)
try:
collection = collection_crud.read_one(UUID(payload.key))
collection.status = CollectionStatus.failed
collection.updated_at = now()
message = extract_error_message(err)
collection.error_message = message

collection_crud._update(collection)
try:
with Session(engine) as session:
collection_crud = CollectionCrud(session, current_user.id)
collection = collection_crud.read_one(UUID(payload.key))
collection.status = CollectionStatus.failed
collection.updated_at = now()
message = extract_error_message(err)
collection.error_message = message
collection_crud._update(collection)
except Exception as suberr:
logger.warning(
f"[do_create_collection] Failed to update collection status | {{'collection_id': '{payload.key}', 'reason': '{str(suberr)}'}}"
@@ -317,7 +327,7 @@ def create_collection(
collection_crud.create(collection)

background_tasks.add_task(
do_create_collection, session, current_user, request, payload, client
do_create_collection, current_user, request, payload, client
)

logger.info(
Expand All @@ -327,9 +337,12 @@ def create_collection(
return APIResponse.success_response(data=None, metadata=asdict(payload))


# TODO: Avoid passing ORM / object models (e.g., UserProjectOrg, ResponsePayload)
# into background job functions. When moving to Celery or another task queue,
# this can cause issues with serialization/deserialization. Instead, pass only
# primitive types (IDs, strings, etc.) and rehydrate objects inside the task.
def do_delete_collection(
session: SessionDep,
current_user: CurrentUserOrgProject,
current_user: UserProjectOrg,
request: DeletionRequest,
payload: ResponsePayload,
client: OpenAI,
@@ -339,15 +352,21 @@ def do_delete_collection(
else:
callback = WebHookCallback(request.callback_url, payload)

collection_crud = CollectionCrud(session, current_user.id)
try:
collection = collection_crud.read_one(request.collection_id)
assistant = OpenAIAssistantCrud(client)
data = collection_crud.delete(collection, assistant)
logger.info(
f"[do_delete_collection] Collection deleted successfully | {{'collection_id': '{collection.id}'}}"
)
callback.success(data.model_dump(mode="json"))
with Session(engine) as session:
collection_crud = CollectionCrud(session, current_user.id)
collection = collection_crud.read_one(request.collection_id)
assistant = OpenAIAssistantCrud(client)
# TODO: Decouple OpenAI collection deletion from DB session handling.
# Currently, the call to OpenAI is tightly coupled with the session,
# which may keep the session open until deletion completes.

data = collection_crud.delete(collection, assistant)
Review comment — Collaborator (PR author):
As the call to OpenAI is tightly coupled with the collection delete, the session may stay open until the collection is deleted from OpenAI.
@kartpop @AkhileshNegi — should we decouple this from the session?
Not doing this right now because it may require changes in the document endpoint as well. Or, the better question: do we need a background job to delete a collection at all?

Reply — Collaborator:
Good point. If we are not doing this right now, at least put a TODO comment in the code in this PR itself so that we can revisit it later.
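One possible shape for that decoupling, sketched under two assumptions: that OpenAIAssistantCrud exposes a delete(assistant_id), and that a DB-only mark_deleted() could be split out of CollectionCrud.delete (today, delete(collection, assistant) does both while the session is open):

with Session(engine) as session:
    # First short-lived session: only long enough to look up the assistant id.
    collection_crud = CollectionCrud(session, current_user.id)
    collection = collection_crud.read_one(request.collection_id)
    assistant_id = collection.llm_service_id

# Remote deletion with no DB session held open.
OpenAIAssistantCrud(client).delete(assistant_id)  # assumed method, not confirmed by this diff

with Session(engine) as session:
    # Second short-lived session: record the result locally.
    collection_crud = CollectionCrud(session, current_user.id)
    collection = collection_crud.read_one(request.collection_id)
    data = collection_crud.mark_deleted(collection)  # hypothetical DB-only variant of delete()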

logger.info(
f"[do_delete_collection] Collection deleted successfully | {{'collection_id': '{collection.id}'}}"
)
callback.success(data.model_dump(mode="json"))

except (ValueError, PermissionError, SQLAlchemyError) as err:
logger.error(
f"[do_delete_collection] Failed to delete collection | {{'collection_id': '{request.collection_id}', 'error': '{str(err)}'}}",
Expand Down Expand Up @@ -381,7 +400,7 @@ def delete_collection(
payload = ResponsePayload("processing", route)

background_tasks.add_task(
do_delete_collection, session, current_user, request, payload, client
do_delete_collection, current_user, request, payload, client
)

logger.info(
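Both TODOs in this file point at the same target shape for the background jobs: pass only primitives (a user id, the collection key, the model name), rehydrate whatever else is needed inside the task, and open a task-local Session(engine) there, because the request-scoped session is closed by the time BackgroundTasks runs the function. A minimal sketch under those assumptions — the function name and signature are hypothetical, and the code in this PR still passes the UserProjectOrg object:

def do_create_collection_by_ids(user_id: int, collection_key: str, model: str):
    # Celery-friendly signature: primitives in, objects rebuilt inside the task.
    with Session(engine) as session:  # task-local session, never the request's SessionDep
        collection_crud = CollectionCrud(session, user_id)
        collection = collection_crud.read_one(UUID(collection_key))
        collection.llm_service_name = model
        collection.updated_at = now()
        collection_crud._update(collection)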
@@ -4,7 +4,7 @@

from sqlmodel import Session
from fastapi.testclient import TestClient
from unittest.mock import patch
from unittest.mock import patch, MagicMock

from app.core.config import settings
from app.tests.utils.document import DocumentStore
@@ -43,13 +43,30 @@ class TestCollectionRouteCreate:
_n_documents = 5

@patch("app.api.routes.collections.get_openai_client")
@patch("fastapi.BackgroundTasks.add_task")
@patch("app.api.routes.collections.Session")
def test_create_collection_success(
self,
mock_session,
mock_add_task,
mock_get_openai_client,
client: TestClient,
db: Session,
user_api_key_header,
):
# Configure mocks
mock_openai_client = get_mock_openai_client_with_vector_store()
mock_get_openai_client.return_value = mock_openai_client

# Make background task run synchronously
def run_task_immediately(task_func, *args, **kwargs):
task_func(*args, **kwargs)

mock_add_task.side_effect = run_task_immediately

# Mock Session to return the test database session
mock_session.return_value.__enter__.return_value = db

store = DocumentStore(db)
documents = store.fill(self._n_documents)
doc_ids = [str(doc.id) for doc in documents]
@@ -64,9 +81,6 @@ def test_create_collection_success(

headers = user_api_key_header

mock_openai_client = get_mock_openai_client_with_vector_store()
mock_get_openai_client.return_value = mock_openai_client

response = client.post(
f"{settings.API_V1_STR}/collections/create", json=body, headers=headers
)
@@ -88,7 +102,6 @@ def test_create_collection_success(
headers=headers,
)
assert info_response.status_code == 200
info_data = info_response.json()["data"]

assert collection.status == CollectionStatus.successful.value
assert collection.owner_id == user.user_id
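The two new mocks are what let this test cover the background path deterministically: the add_task side_effect runs the task inline during the request, and patching Session at app.api.routes.collections (where the module imports it, not where sqlmodel defines it) makes the task's with Session(engine) block yield the test's own database session. The same pattern in isolation — a sketch with a stand-in db, not the repo's fixture setup:

from unittest.mock import MagicMock, patch

@patch("fastapi.BackgroundTasks.add_task")
@patch("app.api.routes.collections.Session")
def test_task_runs_inline(mock_session, mock_add_task):
    # Run the queued background task immediately instead of deferring it.
    mock_add_task.side_effect = lambda task_func, *args, **kwargs: task_func(*args, **kwargs)

    # Session(...) returns a context manager whose __enter__ yields our
    # stand-in db, so the task reads and writes where the test can assert.
    fake_db = MagicMock()
    mock_session.return_value.__enter__.return_value = fake_db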