From 675885c762578c42f2a071c34544ac6855dfa828 Mon Sep 17 00:00:00 2001
From: AmineDiro
Date: Tue, 4 Jun 2024 15:29:27 +0200
Subject: [PATCH] feat(upload): async improved (#2544)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

# Description

Hey, here's a breakdown of what I've done:

- Reducing the number of open file descriptors and the memory footprint: previously, for each uploaded file we opened a `NamedTemporaryFile` to write the content read from Supabase, and because of the dependency on the `langchain` loader classes we couldn't hand the loaders in-memory buffers. That excessive file opening, the extra read syscalls, and the duplicated memory buffers could cause stability issues when ingesting and processing large volumes of documents. Now we open a single temporary file per `process_file_and_notify` call (see the first sketch below), cutting down on all three. Some code paths still reopen temporary files, but this can be improved in later work.
- Removing the `UploadFile` class from `File`: `UploadFile` (a FastAPI abstraction over a `SpooledTemporaryFile` for multipart uploads) was redundant in our `File` setup, since we already download the file from remote storage, read it into memory, and write it to a temporary file. Removing this abstraction streamlines the code and eliminates unnecessary complexity (see the second sketch below).
- `async` function adjustments: I've removed the `async` labeling from functions that aren't truly asynchronous. For instance, calling `filter_file` to process files isn't genuinely async: Starlette's "async" file reading [uses a threadpool for reading the file](https://github.com/encode/starlette/blob/9f16bf5c25e126200701f6e04330864f4a91a898/starlette/datastructures.py#L458). Given that we already leverage `celery` for parallelism (one worker per core), we want reading and processing to happen in the same thread, or at least to minimize thread spawning. And since the rest of the code isn't inherently asynchronous, our bottleneck lies in CPU operations, not in asynchronous I/O.

These changes aim to improve performance and streamline our codebase. Let me know if you have any questions or suggestions for further improvements!
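To make the first and third points concrete, here is a condensed, hypothetical sketch of the new flow in `process_file_and_notify`: one `NamedTemporaryFile` per task, and a plain synchronous call to `filter_file` with no event loop inside the Celery worker. Notification updates, the `delete_file` branch, and error handling are omitted, and the import path for `get_supabase_client` is assumed.

```python
import os
from tempfile import NamedTemporaryFile
from uuid import UUID

from celery_config import celery
from models.files import File
from models.settings import get_supabase_client  # assumed import path
from packages.files.processors import filter_file


@celery.task(name="process_file_and_notify")
def process_file_and_notify(file_name: str, brain_id: UUID, file_original_name: str):
    supabase_client = get_supabase_client()
    base_file_name = os.path.basename(file_name)
    _, file_extension = os.path.splitext(base_file_name)

    # One temporary file for the whole task; it is removed automatically
    # when the context manager exits, so there is no manual os.remove().
    with NamedTemporaryFile(suffix="_" + file_name.replace("/", "_")) as tmp_file:
        content = supabase_client.storage.from_("quivr").download(file_name)
        tmp_file.write(content)
        tmp_file.flush()

        file_instance = File(
            file_name=base_file_name,
            tmp_file_path=tmp_file.name,
            bytes_content=content,
            file_size=len(content),
            file_extension=file_extension,
        )
        # Synchronous processing: Celery already runs one worker per core,
        # so there is no asyncio event loop or loop.run_until_complete here.
        return filter_file(
            file=file_instance,
            brain_id=brain_id,
            original_file_name=file_original_name,
        )
```

The crawl task follows the same pattern: the extracted page content is written once to a temporary file before `filter_file` is called.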
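For the second point, a rough sketch of the slimmed-down `File` model once `UploadFile` is gone: the raw bytes are already in memory and the payload lives in a single temporary file, so the model only carries plain fields. `compute_sha1_from_content` is stood in for by a local helper here; in the code base it lives in `packages.files.file`, and `documents` is typed with langchain's `Document` class.

```python
import hashlib
from pathlib import Path
from typing import Optional

from pydantic import BaseModel


def compute_sha1_from_content(content: bytes) -> str:
    # Stand-in for packages.files.file.compute_sha1_from_content.
    return hashlib.sha1(content).hexdigest()


class File(BaseModel):
    file_name: str
    tmp_file_path: Path
    bytes_content: bytes
    file_size: int
    file_extension: str
    chunk_size: int = 400
    chunk_overlap: int = 100
    documents: list = []  # List[Document] in the real model
    file_sha1: Optional[str] = None
    vectors_ids: Optional[list] = []

    def __init__(self, **data):
        super().__init__(**data)
        # The sha1 comes straight from the in-memory bytes; no temporary
        # file is re-read just to hash the content.
        self.file_sha1 = compute_sha1_from_content(self.bytes_content)
```

The point is that every field is now a plain value validated up front instead of being derived lazily from an `UploadFile`, and `compute_documents` can simply pass `tmp_file_path` to the langchain loader.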
## Checklist before requesting a review - [x] My code follows the style guidelines of this project - [x] I have performed a self-review of my code - [x] I have ideally added tests that prove my fix is effective or that my feature works --------- Signed-off-by: aminediro Co-authored-by: aminediro Co-authored-by: Stan Girard --- backend/celery_config.py | 3 + backend/celery_worker.py | 113 +++++++++--------- backend/main.py | 12 +- backend/models/files.py | 74 +++--------- backend/modules/brain/rags/quivr_rag.py | 2 +- .../upload/controller/upload_routes.py | 4 +- backend/packages/files/crawl/crawler.py | 15 +-- backend/packages/files/parsers/audio.py | 32 +---- backend/packages/files/parsers/code_python.py | 4 +- backend/packages/files/parsers/common.py | 2 +- backend/packages/files/parsers/github.py | 2 +- backend/packages/files/parsers/pdf.py | 6 +- backend/packages/files/parsers/txt.py | 4 +- backend/packages/files/processors.py | 9 +- 14 files changed, 104 insertions(+), 178 deletions(-) diff --git a/backend/celery_config.py b/backend/celery_config.py index bb7eaa2b2d8..85d093168f5 100644 --- a/backend/celery_config.py +++ b/backend/celery_config.py @@ -1,8 +1,11 @@ # celery_config.py import os +import dotenv from celery import Celery +dotenv.load_dotenv() + CELERY_BROKER_URL = os.getenv("CELERY_BROKER_URL", "") CELERY_BROKER_QUEUE_NAME = os.getenv("CELERY_BROKER_QUEUE_NAME", "quivr") diff --git a/backend/celery_worker.py b/backend/celery_worker.py index 34a6b98738e..8baddcecb51 100644 --- a/backend/celery_worker.py +++ b/backend/celery_worker.py @@ -1,11 +1,10 @@ -import asyncio -import io import os from datetime import datetime, timezone +from tempfile import NamedTemporaryFile +from uuid import UUID from celery.schedules import crontab from celery_config import celery -from fastapi import UploadFile from logger import get_logger from middlewares.auth.auth_bearer import AuthBearer from models.files import File @@ -18,7 +17,7 @@ from modules.notification.entity.notification import NotificationsStatusEnum from modules.notification.service.notification_service import NotificationService from modules.onboarding.service.onboarding_service import OnboardingService -from packages.files.crawl.crawler import CrawlWebsite +from packages.files.crawl.crawler import CrawlWebsite, slugify from packages.files.parsers.github import process_github from packages.files.processors import filter_file from packages.utils.telemetry import maybe_send_telemetry @@ -42,39 +41,36 @@ def process_file_and_notify( ): try: supabase_client = get_supabase_client() - tmp_file_name = "tmp-file-" + file_name - tmp_file_name = tmp_file_name.replace("/", "_") + tmp_name = file_name.replace("/", "_") + base_file_name = os.path.basename(file_name) + _, file_extension = os.path.splitext(base_file_name) - with open(tmp_file_name, "wb+") as f: + with NamedTemporaryFile( + suffix="_" + tmp_name, # pyright: ignore reportPrivateUsage=none + ) as tmp_file: res = supabase_client.storage.from_("quivr").download(file_name) - f.write(res) - f.seek(0) - file_content = f.read() - - upload_file = UploadFile( - file=f, filename=file_name.split("/")[-1], size=len(file_content) + tmp_file.write(res) + tmp_file.flush() + file_instance = File( + file_name=base_file_name, + tmp_file_path=tmp_file.name, + bytes_content=res, + file_size=len(res), + file_extension=file_extension, ) - - file_instance = File(file=upload_file) - loop = asyncio.get_event_loop() brain_vector_service = BrainVectorService(brain_id) if delete_file: # TODO fix bug 
brain_vector_service.delete_file_from_brain( file_original_name, only_vectors=True ) - message = loop.run_until_complete( - filter_file( - file=file_instance, - brain_id=brain_id, - original_file_name=file_original_name, - ) - ) - f.close() - os.remove(tmp_file_name) + message = filter_file( + file=file_instance, + brain_id=brain_id, + original_file_name=file_original_name, + ) if notification_id: - notification_service.update_notification_by_id( notification_id, NotificationUpdatableProperties( @@ -85,10 +81,12 @@ def process_file_and_notify( brain_service.update_brain_last_update_time(brain_id) return True + except TimeoutError: logger.error("TimeoutError") except Exception as e: + logger.exception(e) notification_service.update_notification_by_id( notification_id, NotificationUpdatableProperties( @@ -96,52 +94,51 @@ def process_file_and_notify( description=f"An error occurred while processing the file: {e}", ), ) - return False @celery.task(name="process_crawl_and_notify") def process_crawl_and_notify( - crawl_website_url, - brain_id, + crawl_website_url: str, + brain_id: UUID, notification_id=None, ): + crawl_website = CrawlWebsite(url=crawl_website_url) if not crawl_website.checkGithub(): - file_path, file_name = crawl_website.process() - - with open(file_path, "rb") as f: - file_content = f.read() - - # Create a file-like object in memory using BytesIO - file_object = io.BytesIO(file_content) - upload_file = UploadFile( - file=file_object, filename=file_name, size=len(file_content) - ) - file_instance = File(file=upload_file) - - loop = asyncio.get_event_loop() - message = loop.run_until_complete( - filter_file( + # Build file data + extracted_content = crawl_website.process() + extracted_content_bytes = extracted_content.encode("utf-8") + file_name = slugify(crawl_website.url) + ".txt" + + with NamedTemporaryFile( + suffix="_" + file_name, # pyright: ignore reportPrivateUsage=none + ) as tmp_file: + tmp_file.write(extracted_content_bytes) + tmp_file.flush() + file_instance = File( + file_name=file_name, + tmp_file_path=tmp_file.name, + bytes_content=extracted_content_bytes, + file_size=len(extracted_content), + file_extension=".txt", + ) + message = filter_file( file=file_instance, brain_id=brain_id, original_file_name=crawl_website_url, ) - ) - notification_service.update_notification_by_id( - notification_id, - NotificationUpdatableProperties( - status=NotificationsStatusEnum.SUCCESS, - description=f"Your URL has been properly crawled!", - ), - ) - else: - loop = asyncio.get_event_loop() - message = loop.run_until_complete( - process_github( - repo=crawl_website.url, - brain_id=brain_id, + notification_service.update_notification_by_id( + notification_id, + NotificationUpdatableProperties( + status=NotificationsStatusEnum.SUCCESS, + description="Your URL has been properly crawled!", + ), ) + else: + message = process_github( + repo=crawl_website.url, + brain_id=brain_id, ) if notification_id: diff --git a/backend/main.py b/backend/main.py index c6fad6795c5..ddb9d6f1eb9 100644 --- a/backend/main.py +++ b/backend/main.py @@ -1,15 +1,9 @@ -import os - -if __name__ == "__main__": - # import needed here when running main.py to debug backend - # you will need to run pip install python-dotenv - from dotenv import load_dotenv # type: ignore - - load_dotenv() import logging +import os import litellm import sentry_sdk +from dotenv import load_dotenv # type: ignore from fastapi import FastAPI, HTTPException, Request from fastapi.responses import HTMLResponse, JSONResponse from logger import 
get_logger @@ -35,6 +29,8 @@ from sentry_sdk.integrations.fastapi import FastApiIntegration from sentry_sdk.integrations.starlette import StarletteIntegration +load_dotenv() + # Set the logging level for all loggers to WARNING logging.basicConfig(level=logging.INFO) logging.getLogger("httpx").setLevel(logging.WARNING) diff --git a/backend/models/files.py b/backend/models/files.py index f3f38552066..0ca20d2a3fb 100644 --- a/backend/models/files.py +++ b/backend/models/files.py @@ -1,66 +1,38 @@ -import os -import tempfile -from typing import Any, Optional -from uuid import UUID +from pathlib import Path +from typing import List, Optional -from fastapi import UploadFile from langchain.text_splitter import RecursiveCharacterTextSplitter from langchain_core.documents import Document from logger import get_logger from models.databases.supabase.supabase import SupabaseDB from models.settings import get_supabase_db from modules.brain.service.brain_vector_service import BrainVectorService -from packages.files.file import compute_sha1_from_file +from packages.files.file import compute_sha1_from_content from pydantic import BaseModel logger = get_logger(__name__) class File(BaseModel): - id: Optional[UUID] = None - file: Optional[UploadFile] = None - file_name: Optional[str] = "" - file_size: Optional[int] = None - file_sha1: Optional[str] = "" - vectors_ids: Optional[list] = [] - file_extension: Optional[str] = "" - content: Optional[Any] = None + file_name: str + tmp_file_path: Path + bytes_content: bytes + file_size: int + file_extension: str chunk_size: int = 400 chunk_overlap: int = 100 - documents: Optional[Document] = None + documents: List[Document] = [] + file_sha1: Optional[str] = None + vectors_ids: Optional[list] = [] + + def __init__(self, **data): + super().__init__(**data) + data["file_sha1"] = compute_sha1_from_content(data["bytes_content"]) @property def supabase_db(self) -> SupabaseDB: return get_supabase_db() - def __init__(self, **kwargs): - super().__init__(**kwargs) - - if self.file: - self.file_name = self.file.filename - self.file_size = self.file.size # pyright: ignore reportPrivateUsage=none - self.file_extension = os.path.splitext( - self.file.filename # pyright: ignore reportPrivateUsage=none - )[-1].lower() - - async def compute_file_sha1(self): - """ - Compute the sha1 of the file using a temporary file - """ - with tempfile.NamedTemporaryFile( - delete=False, - suffix=self.file.filename, # pyright: ignore reportPrivateUsage=none - ) as tmp_file: - await self.file.seek(0) # pyright: ignore reportPrivateUsage=none - self.content = ( - await self.file.read() # pyright: ignore reportPrivateUsage=none - ) - tmp_file.write(self.content) - tmp_file.flush() - self.file_sha1 = compute_sha1_from_file(tmp_file.name) - - os.remove(tmp_file.name) - def compute_documents(self, loader_class): """ Compute the documents from the file @@ -69,18 +41,8 @@ def compute_documents(self, loader_class): loader_class (class): The class of the loader to use to load the file """ logger.info(f"Computing documents from file {self.file_name}") - - documents = [] - with tempfile.NamedTemporaryFile( - delete=False, - suffix=self.file.filename, # pyright: ignore reportPrivateUsage=none - ) as tmp_file: - tmp_file.write(self.content) # pyright: ignore reportPrivateUsage=none - tmp_file.flush() - loader = loader_class(tmp_file.name) - documents = loader.load() - - os.remove(tmp_file.name) + loader = loader_class(self.tmp_file_path) + documents = loader.load() text_splitter = 
RecursiveCharacterTextSplitter.from_tiktoken_encoder( chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap @@ -129,7 +91,7 @@ def file_is_empty(self): """ Check if file is empty by checking if the file pointer is at the beginning of the file """ - return self.file.size < 1 # pyright: ignore reportPrivateUsage=none + return self.file_size < 1 # pyright: ignore reportPrivateUsage=none def link_file_to_brain(self, brain_id): self.set_file_vectors_ids() diff --git a/backend/modules/brain/rags/quivr_rag.py b/backend/modules/brain/rags/quivr_rag.py index cf99dc7b6ee..f7626416ead 100644 --- a/backend/modules/brain/rags/quivr_rag.py +++ b/backend/modules/brain/rags/quivr_rag.py @@ -314,7 +314,7 @@ def get_chain(self): self.brain_id ) # pyright: ignore reportPrivateUsage=none - list_files_array = [file.file_name for file in list_files_array] + list_files_array = [file.file_name or file.url for file in list_files_array] # Max first 10 files if len(list_files_array) > 20: list_files_array = list_files_array[:20] diff --git a/backend/modules/upload/controller/upload_routes.py b/backend/modules/upload/controller/upload_routes.py index f07cd23a784..4524cfc29b7 100644 --- a/backend/modules/upload/controller/upload_routes.py +++ b/backend/modules/upload/controller/upload_routes.py @@ -74,7 +74,7 @@ async def upload_file( filename_with_brain_id = str(brain_id) + "/" + str(uploadFile.filename) try: - file_in_storage = upload_file_storage(file_content, filename_with_brain_id) + upload_file_storage(file_content, filename_with_brain_id) except Exception as e: print(e) @@ -104,7 +104,7 @@ async def upload_file( )[-1].lower(), ) - added_knowledge = knowledge_service.add_knowledge(knowledge_to_add) + knowledge_service.add_knowledge(knowledge_to_add) process_file_and_notify.delay( file_name=filename_with_brain_id, diff --git a/backend/packages/files/crawl/crawler.py b/backend/packages/files/crawl/crawler.py index 0b913329441..60dd1697f73 100644 --- a/backend/packages/files/crawl/crawler.py +++ b/backend/packages/files/crawl/crawler.py @@ -1,6 +1,5 @@ import os import re -import tempfile import unicodedata from langchain_community.document_loaders import PlaywrightURLLoader @@ -17,27 +16,21 @@ class CrawlWebsite(BaseModel): max_pages: int = 100 max_time: int = 60 - def process(self): + def process(self) -> str: # Extract and combine content recursively loader = PlaywrightURLLoader( urls=[self.url], remove_selectors=["header", "footer"] ) - data = loader.load() + data = loader.load() # Now turn the data into a string logger.info(f"Extracted content from {len(data)} pages") - logger.info(data) + logger.debug(f"Extracted data : {data}") extracted_content = "" for page in data: extracted_content += page.page_content - # Create a file - file_name = slugify(self.url) + ".txt" - temp_file_path = os.path.join(tempfile.gettempdir(), file_name) - with open(temp_file_path, "w") as temp_file: - temp_file.write(extracted_content) # type: ignore - - return temp_file_path, file_name + return extracted_content def checkGithub(self): return "github.com" in self.url diff --git a/backend/packages/files/parsers/audio.py b/backend/packages/files/parsers/audio.py index ce5ee146c9e..6324fae5947 100644 --- a/backend/packages/files/parsers/audio.py +++ b/backend/packages/files/parsers/audio.py @@ -1,5 +1,3 @@ -import os -import tempfile import time import openai @@ -9,33 +7,13 @@ from packages.files.file import compute_sha1_from_content -async def process_audio( - file: File, user, original_file_name, integration=None, 
integration_link=None -): - temp_filename = None - file_sha = "" +def process_audio(file: File, **kwargs): dateshort = time.strftime("%Y%m%d-%H%M%S") file_meta_name = f"audiotranscript_{dateshort}.txt" documents_vector_store = get_documents_vector_store() - try: - upload_file = file.file - with tempfile.NamedTemporaryFile( - delete=False, - suffix=upload_file.filename, # pyright: ignore reportPrivateUsage=none - ) as tmp_file: - await upload_file.seek(0) # pyright: ignore reportPrivateUsage=none - content = ( - await upload_file.read() # pyright: ignore reportPrivateUsage=none - ) - tmp_file.write(content) - tmp_file.flush() - tmp_file.close() - - temp_filename = tmp_file.name - - with open(tmp_file.name, "rb") as audio_file: - transcript = openai.Audio.transcribe("whisper-1", audio_file) + with open(file.tmp_file_path, "rb") as audio_file: + transcript = openai.Audio.transcribe("whisper-1", audio_file) file_sha = compute_sha1_from_content( transcript.text.encode("utf-8") # pyright: ignore reportPrivateUsage=none @@ -70,7 +48,3 @@ async def process_audio( ] documents_vector_store.add_documents(docs_with_metadata) - - finally: - if temp_filename and os.path.exists(temp_filename): - os.remove(temp_filename) diff --git a/backend/packages/files/parsers/code_python.py b/backend/packages/files/parsers/code_python.py index 2c52416bfd9..b3d9af076a1 100644 --- a/backend/packages/files/parsers/code_python.py +++ b/backend/packages/files/parsers/code_python.py @@ -4,10 +4,10 @@ from .common import process_file -async def process_python( +def process_python( file: File, brain_id, original_file_name, integration=None, integration_link=None ): - return await process_file( + return process_file( file=file, loader_class=PythonLoader, brain_id=brain_id, diff --git a/backend/packages/files/parsers/common.py b/backend/packages/files/parsers/common.py index fcd57863e5f..6b56aeb30a6 100644 --- a/backend/packages/files/parsers/common.py +++ b/backend/packages/files/parsers/common.py @@ -21,7 +21,7 @@ logger = get_logger(__name__) -async def process_file( +def process_file( file: File, loader_class, brain_id, diff --git a/backend/packages/files/parsers/github.py b/backend/packages/files/parsers/github.py index d8831d1be08..aa47ea7ce02 100644 --- a/backend/packages/files/parsers/github.py +++ b/backend/packages/files/parsers/github.py @@ -9,7 +9,7 @@ from packages.files.file import compute_sha1_from_content -async def process_github( +def process_github( repo, brain_id, ): diff --git a/backend/packages/files/parsers/pdf.py b/backend/packages/files/parsers/pdf.py index 605bb9fa36f..e43ec96a9e1 100644 --- a/backend/packages/files/parsers/pdf.py +++ b/backend/packages/files/parsers/pdf.py @@ -5,7 +5,11 @@ def process_pdf( - file: File, brain_id, original_file_name, integration=None, integration_link=None + file: File, + brain_id, + original_file_name, + integration=None, + integration_link=None, ): return process_file( file=file, diff --git a/backend/packages/files/parsers/txt.py b/backend/packages/files/parsers/txt.py index fdcf9e4677d..33666241964 100644 --- a/backend/packages/files/parsers/txt.py +++ b/backend/packages/files/parsers/txt.py @@ -4,10 +4,10 @@ from .common import process_file -async def process_txt( +def process_txt( file: File, brain_id, original_file_name, integration=None, integration_link=None ): - return await process_file( + return process_file( file=file, loader_class=TextLoader, brain_id=brain_id, diff --git a/backend/packages/files/processors.py b/backend/packages/files/processors.py index 
d0a3a550b5b..04e70a976e1 100644 --- a/backend/packages/files/processors.py +++ b/backend/packages/files/processors.py @@ -1,4 +1,3 @@ -from fastapi import HTTPException from modules.brain.service.brain_service import BrainService from .parsers.audio import process_audio @@ -52,16 +51,14 @@ def create_response(message, type): # TODO: Move filter_file to a file service to avoid circular imports from models/files.py for File class -async def filter_file( +def filter_file( file, brain_id, original_file_name=None, ): - await file.compute_file_sha1() - file_exists = file.file_already_exists() file_exists_in_brain = file.file_already_exists_in_brain(brain_id) - using_file_name = original_file_name or file.file.filename if file.file else "" + using_file_name = file.file_name brain = brain_service.get_brain_by_id(brain_id) if brain is None: @@ -86,7 +83,7 @@ async def filter_file( if file.file_extension in file_processors: try: - result = await file_processors[file.file_extension]( + result = file_processors[file.file_extension]( file=file, brain_id=brain_id, original_file_name=original_file_name,