diff --git a/backend/celery_config.py b/backend/celery_config.py
index bb7eaa2b2d8..85d093168f5 100644
--- a/backend/celery_config.py
+++ b/backend/celery_config.py
@@ -1,8 +1,11 @@
 # celery_config.py
 import os
 
+import dotenv
 from celery import Celery
 
+dotenv.load_dotenv()
+
 CELERY_BROKER_URL = os.getenv("CELERY_BROKER_URL", "")
 CELERY_BROKER_QUEUE_NAME = os.getenv("CELERY_BROKER_QUEUE_NAME", "quivr")
 
diff --git a/backend/celery_worker.py b/backend/celery_worker.py
index 34a6b98738e..8baddcecb51 100644
--- a/backend/celery_worker.py
+++ b/backend/celery_worker.py
@@ -1,11 +1,10 @@
-import asyncio
-import io
 import os
 from datetime import datetime, timezone
+from tempfile import NamedTemporaryFile
+from uuid import UUID
 
 from celery.schedules import crontab
 from celery_config import celery
-from fastapi import UploadFile
 from logger import get_logger
 from middlewares.auth.auth_bearer import AuthBearer
 from models.files import File
@@ -18,7 +17,7 @@
 from modules.notification.entity.notification import NotificationsStatusEnum
 from modules.notification.service.notification_service import NotificationService
 from modules.onboarding.service.onboarding_service import OnboardingService
-from packages.files.crawl.crawler import CrawlWebsite
+from packages.files.crawl.crawler import CrawlWebsite, slugify
 from packages.files.parsers.github import process_github
 from packages.files.processors import filter_file
 from packages.utils.telemetry import maybe_send_telemetry
@@ -42,39 +41,36 @@ def process_file_and_notify(
 ):
     try:
         supabase_client = get_supabase_client()
-        tmp_file_name = "tmp-file-" + file_name
-        tmp_file_name = tmp_file_name.replace("/", "_")
+        tmp_name = file_name.replace("/", "_")
+        base_file_name = os.path.basename(file_name)
+        _, file_extension = os.path.splitext(base_file_name)
 
-        with open(tmp_file_name, "wb+") as f:
+        with NamedTemporaryFile(
+            suffix="_" + tmp_name,  # pyright: ignore reportPrivateUsage=none
+        ) as tmp_file:
             res = supabase_client.storage.from_("quivr").download(file_name)
-            f.write(res)
-            f.seek(0)
-            file_content = f.read()
-
-            upload_file = UploadFile(
-                file=f, filename=file_name.split("/")[-1], size=len(file_content)
+            tmp_file.write(res)
+            tmp_file.flush()
+            file_instance = File(
+                file_name=base_file_name,
+                tmp_file_path=tmp_file.name,
+                bytes_content=res,
+                file_size=len(res),
+                file_extension=file_extension,
             )
-
-            file_instance = File(file=upload_file)
-            loop = asyncio.get_event_loop()
             brain_vector_service = BrainVectorService(brain_id)
             if delete_file:  # TODO fix bug
                 brain_vector_service.delete_file_from_brain(
                     file_original_name, only_vectors=True
                 )
 
-            message = loop.run_until_complete(
-                filter_file(
-                    file=file_instance,
-                    brain_id=brain_id,
-                    original_file_name=file_original_name,
-                )
-            )
-            f.close()
-            os.remove(tmp_file_name)
+            message = filter_file(
+                file=file_instance,
+                brain_id=brain_id,
+                original_file_name=file_original_name,
+            )
 
             if notification_id:
-
                 notification_service.update_notification_by_id(
                     notification_id,
                     NotificationUpdatableProperties(
@@ -85,10 +81,12 @@ def process_file_and_notify(
             brain_service.update_brain_last_update_time(brain_id)
 
             return True
+
     except TimeoutError:
         logger.error("TimeoutError")
 
     except Exception as e:
+        logger.exception(e)
         notification_service.update_notification_by_id(
             notification_id,
             NotificationUpdatableProperties(
@@ -96,52 +94,51 @@ def process_file_and_notify(
                 description=f"An error occurred while processing the file: {e}",
             ),
         )
-        return False
 
 
 @celery.task(name="process_crawl_and_notify")
 def process_crawl_and_notify(
-    crawl_website_url,
-    brain_id,
+    crawl_website_url: str,
+    brain_id: UUID,
     notification_id=None,
 ):
+
     crawl_website = CrawlWebsite(url=crawl_website_url)
 
     if not crawl_website.checkGithub():
-        file_path, file_name = crawl_website.process()
-
-        with open(file_path, "rb") as f:
-            file_content = f.read()
-
-        # Create a file-like object in memory using BytesIO
-        file_object = io.BytesIO(file_content)
-        upload_file = UploadFile(
-            file=file_object, filename=file_name, size=len(file_content)
-        )
-        file_instance = File(file=upload_file)
-
-        loop = asyncio.get_event_loop()
-        message = loop.run_until_complete(
-            filter_file(
+        # Build file data
+        extracted_content = crawl_website.process()
+        extracted_content_bytes = extracted_content.encode("utf-8")
+        file_name = slugify(crawl_website.url) + ".txt"
+
+        with NamedTemporaryFile(
+            suffix="_" + file_name,  # pyright: ignore reportPrivateUsage=none
+        ) as tmp_file:
+            tmp_file.write(extracted_content_bytes)
+            tmp_file.flush()
+            file_instance = File(
+                file_name=file_name,
+                tmp_file_path=tmp_file.name,
+                bytes_content=extracted_content_bytes,
+                file_size=len(extracted_content),
+                file_extension=".txt",
+            )
+            message = filter_file(
                 file=file_instance,
                 brain_id=brain_id,
                 original_file_name=crawl_website_url,
             )
-        )
-        notification_service.update_notification_by_id(
-            notification_id,
-            NotificationUpdatableProperties(
-                status=NotificationsStatusEnum.SUCCESS,
-                description=f"Your URL has been properly crawled!",
-            ),
-        )
-    else:
-        loop = asyncio.get_event_loop()
-        message = loop.run_until_complete(
-            process_github(
-                repo=crawl_website.url,
-                brain_id=brain_id,
+            notification_service.update_notification_by_id(
+                notification_id,
+                NotificationUpdatableProperties(
+                    status=NotificationsStatusEnum.SUCCESS,
+                    description="Your URL has been properly crawled!",
+                ),
             )
+    else:
+        message = process_github(
+            repo=crawl_website.url,
+            brain_id=brain_id,
         )
 
     if notification_id:
diff --git a/backend/main.py b/backend/main.py
index c6fad6795c5..ddb9d6f1eb9 100644
--- a/backend/main.py
+++ b/backend/main.py
@@ -1,15 +1,9 @@
-import os
-
-if __name__ == "__main__":
-    # import needed here when running main.py to debug backend
-    # you will need to run pip install python-dotenv
-    from dotenv import load_dotenv  # type: ignore
-
-    load_dotenv()
 import logging
+import os
 
 import litellm
 import sentry_sdk
+from dotenv import load_dotenv  # type: ignore
 from fastapi import FastAPI, HTTPException, Request
 from fastapi.responses import HTMLResponse, JSONResponse
 from logger import get_logger
@@ -35,6 +29,8 @@
 from sentry_sdk.integrations.fastapi import FastApiIntegration
 from sentry_sdk.integrations.starlette import StarletteIntegration
 
+load_dotenv()
+
 # Set the logging level for all loggers to WARNING
 logging.basicConfig(level=logging.INFO)
 logging.getLogger("httpx").setLevel(logging.WARNING)
diff --git a/backend/models/files.py b/backend/models/files.py
index f3f38552066..0ca20d2a3fb 100644
--- a/backend/models/files.py
+++ b/backend/models/files.py
@@ -1,66 +1,38 @@
-import os
-import tempfile
-from typing import Any, Optional
-from uuid import UUID
+from pathlib import Path
+from typing import List, Optional
 
-from fastapi import UploadFile
 from langchain.text_splitter import RecursiveCharacterTextSplitter
 from langchain_core.documents import Document
 from logger import get_logger
 from models.databases.supabase.supabase import SupabaseDB
 from models.settings import get_supabase_db
 from modules.brain.service.brain_vector_service import BrainVectorService
-from packages.files.file import compute_sha1_from_file
+from packages.files.file import compute_sha1_from_content
 from pydantic import BaseModel
 
 logger = get_logger(__name__)
 
 
 class File(BaseModel):
-    id: Optional[UUID] = None
-    file: Optional[UploadFile] = None
-    file_name: Optional[str] = ""
-    file_size: Optional[int] = None
-    file_sha1: Optional[str] = ""
-    vectors_ids: Optional[list] = []
-    file_extension: Optional[str] = ""
-    content: Optional[Any] = None
+    file_name: str
+    tmp_file_path: Path
+    bytes_content: bytes
+    file_size: int
+    file_extension: str
     chunk_size: int = 400
     chunk_overlap: int = 100
-    documents: Optional[Document] = None
+    documents: List[Document] = []
+    file_sha1: Optional[str] = None
+    vectors_ids: Optional[list] = []
+
+    def __init__(self, **data):
+        super().__init__(**data)
+        self.file_sha1 = compute_sha1_from_content(self.bytes_content)
 
     @property
     def supabase_db(self) -> SupabaseDB:
         return get_supabase_db()
 
-    def __init__(self, **kwargs):
-        super().__init__(**kwargs)
-
-        if self.file:
-            self.file_name = self.file.filename
-            self.file_size = self.file.size  # pyright: ignore reportPrivateUsage=none
-            self.file_extension = os.path.splitext(
-                self.file.filename  # pyright: ignore reportPrivateUsage=none
-            )[-1].lower()
-
-    async def compute_file_sha1(self):
-        """
-        Compute the sha1 of the file using a temporary file
-        """
-        with tempfile.NamedTemporaryFile(
-            delete=False,
-            suffix=self.file.filename,  # pyright: ignore reportPrivateUsage=none
-        ) as tmp_file:
-            await self.file.seek(0)  # pyright: ignore reportPrivateUsage=none
-            self.content = (
-                await self.file.read()  # pyright: ignore reportPrivateUsage=none
-            )
-            tmp_file.write(self.content)
-            tmp_file.flush()
-            self.file_sha1 = compute_sha1_from_file(tmp_file.name)
-
-            os.remove(tmp_file.name)
-
     def compute_documents(self, loader_class):
         """
         Compute the documents from the file
@@ -69,18 +41,8 @@ def compute_documents(self, loader_class):
             loader_class (class): The class of the loader to use to load the file
         """
         logger.info(f"Computing documents from file {self.file_name}")
-
-        documents = []
-        with tempfile.NamedTemporaryFile(
-            delete=False,
-            suffix=self.file.filename,  # pyright: ignore reportPrivateUsage=none
-        ) as tmp_file:
-            tmp_file.write(self.content)  # pyright: ignore reportPrivateUsage=none
-            tmp_file.flush()
-            loader = loader_class(tmp_file.name)
-            documents = loader.load()
-
-            os.remove(tmp_file.name)
+        loader = loader_class(self.tmp_file_path)
+        documents = loader.load()
 
         text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
             chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap
@@ -129,7 +91,7 @@ def file_is_empty(self):
         """
         Check if file is empty by checking if the file pointer is at the beginning of the file
         """
-        return self.file.size < 1  # pyright: ignore reportPrivateUsage=none
+        return self.file_size < 1  # pyright: ignore reportPrivateUsage=none
 
     def link_file_to_brain(self, brain_id):
         self.set_file_vectors_ids()
diff --git a/backend/modules/brain/rags/quivr_rag.py b/backend/modules/brain/rags/quivr_rag.py
index cf99dc7b6ee..f7626416ead 100644
--- a/backend/modules/brain/rags/quivr_rag.py
+++ b/backend/modules/brain/rags/quivr_rag.py
@@ -314,7 +314,7 @@ def get_chain(self):
             self.brain_id
         )  # pyright: ignore reportPrivateUsage=none
 
-        list_files_array = [file.file_name for file in list_files_array]
+        list_files_array = [file.file_name or file.url for file in list_files_array]
         # Max first 10 files
         if len(list_files_array) > 20:
             list_files_array = list_files_array[:20]
diff --git a/backend/modules/upload/controller/upload_routes.py b/backend/modules/upload/controller/upload_routes.py
index f07cd23a784..4524cfc29b7 100644
--- a/backend/modules/upload/controller/upload_routes.py
+++ b/backend/modules/upload/controller/upload_routes.py
@@ -74,7 +74,7 @@ async def upload_file(
     filename_with_brain_id = str(brain_id) + "/" + str(uploadFile.filename)
 
     try:
-        file_in_storage = upload_file_storage(file_content, filename_with_brain_id)
+        upload_file_storage(file_content, filename_with_brain_id)
     except Exception as e:
         print(e)
 
@@ -104,7 +104,7 @@
         )[-1].lower(),
     )
 
-    added_knowledge = knowledge_service.add_knowledge(knowledge_to_add)
+    knowledge_service.add_knowledge(knowledge_to_add)
 
     process_file_and_notify.delay(
         file_name=filename_with_brain_id,
diff --git a/backend/packages/files/crawl/crawler.py b/backend/packages/files/crawl/crawler.py
index 0b913329441..60dd1697f73 100644
--- a/backend/packages/files/crawl/crawler.py
+++ b/backend/packages/files/crawl/crawler.py
@@ -1,6 +1,5 @@
 import os
 import re
-import tempfile
 import unicodedata
 
 from langchain_community.document_loaders import PlaywrightURLLoader
@@ -17,27 +16,21 @@ class CrawlWebsite(BaseModel):
     max_pages: int = 100
     max_time: int = 60
 
-    def process(self):
+    def process(self) -> str:
         # Extract and combine content recursively
         loader = PlaywrightURLLoader(
             urls=[self.url], remove_selectors=["header", "footer"]
         )
-        data = loader.load()
+        data = loader.load()  # Now turn the data into a string
 
         logger.info(f"Extracted content from {len(data)} pages")
-        logger.info(data)
+        logger.debug(f"Extracted data : {data}")
 
         extracted_content = ""
         for page in data:
             extracted_content += page.page_content
 
-        # Create a file
-        file_name = slugify(self.url) + ".txt"
-        temp_file_path = os.path.join(tempfile.gettempdir(), file_name)
-        with open(temp_file_path, "w") as temp_file:
-            temp_file.write(extracted_content)  # type: ignore
-
-        return temp_file_path, file_name
+        return extracted_content
 
     def checkGithub(self):
         return "github.com" in self.url
diff --git a/backend/packages/files/parsers/audio.py b/backend/packages/files/parsers/audio.py
index ce5ee146c9e..6324fae5947 100644
--- a/backend/packages/files/parsers/audio.py
+++ b/backend/packages/files/parsers/audio.py
@@ -1,5 +1,3 @@
-import os
-import tempfile
 import time
 
 import openai
@@ -9,33 +7,13 @@ from packages.files.file import compute_sha1_from_content
 
 
-async def process_audio(
-    file: File, user, original_file_name, integration=None, integration_link=None
-):
-    temp_filename = None
-    file_sha = ""
+def process_audio(file: File, **kwargs):
     dateshort = time.strftime("%Y%m%d-%H%M%S")
     file_meta_name = f"audiotranscript_{dateshort}.txt"
     documents_vector_store = get_documents_vector_store()
 
-    try:
-        upload_file = file.file
-        with tempfile.NamedTemporaryFile(
-            delete=False,
-            suffix=upload_file.filename,  # pyright: ignore reportPrivateUsage=none
-        ) as tmp_file:
-            await upload_file.seek(0)  # pyright: ignore reportPrivateUsage=none
-            content = (
-                await upload_file.read()  # pyright: ignore reportPrivateUsage=none
-            )
-            tmp_file.write(content)
-            tmp_file.flush()
-            tmp_file.close()
-
-            temp_filename = tmp_file.name
-
-        with open(tmp_file.name, "rb") as audio_file:
-            transcript = openai.Audio.transcribe("whisper-1", audio_file)
+    with open(file.tmp_file_path, "rb") as audio_file:
+        transcript = openai.Audio.transcribe("whisper-1", audio_file)
 
         file_sha = compute_sha1_from_content(
             transcript.text.encode("utf-8")  # pyright: ignore reportPrivateUsage=none
         )
@@ -70,7 +48,3 @@ async def process_audio(
         ]
 
         documents_vector_store.add_documents(docs_with_metadata)
-
-    finally:
-        if temp_filename and os.path.exists(temp_filename):
-            os.remove(temp_filename)
diff --git a/backend/packages/files/parsers/code_python.py b/backend/packages/files/parsers/code_python.py
index 2c52416bfd9..b3d9af076a1 100644
--- a/backend/packages/files/parsers/code_python.py
+++ b/backend/packages/files/parsers/code_python.py
@@ -4,10 +4,10 @@
 from .common import process_file
 
 
-async def process_python(
+def process_python(
     file: File, brain_id, original_file_name, integration=None, integration_link=None
 ):
-    return await process_file(
+    return process_file(
         file=file,
         loader_class=PythonLoader,
         brain_id=brain_id,
diff --git a/backend/packages/files/parsers/common.py b/backend/packages/files/parsers/common.py
index fcd57863e5f..6b56aeb30a6 100644
--- a/backend/packages/files/parsers/common.py
+++ b/backend/packages/files/parsers/common.py
@@ -21,7 +21,7 @@
 logger = get_logger(__name__)
 
 
-async def process_file(
+def process_file(
     file: File,
     loader_class,
     brain_id,
diff --git a/backend/packages/files/parsers/github.py b/backend/packages/files/parsers/github.py
index d8831d1be08..aa47ea7ce02 100644
--- a/backend/packages/files/parsers/github.py
+++ b/backend/packages/files/parsers/github.py
@@ -9,7 +9,7 @@
 from packages.files.file import compute_sha1_from_content
 
 
-async def process_github(
+def process_github(
     repo,
     brain_id,
 ):
diff --git a/backend/packages/files/parsers/pdf.py b/backend/packages/files/parsers/pdf.py
index 605bb9fa36f..e43ec96a9e1 100644
--- a/backend/packages/files/parsers/pdf.py
+++ b/backend/packages/files/parsers/pdf.py
@@ -5,7 +5,11 @@
 
 
 def process_pdf(
-    file: File, brain_id, original_file_name, integration=None, integration_link=None
+    file: File,
+    brain_id,
+    original_file_name,
+    integration=None,
+    integration_link=None,
 ):
     return process_file(
         file=file,
diff --git a/backend/packages/files/parsers/txt.py b/backend/packages/files/parsers/txt.py
index fdcf9e4677d..33666241964 100644
--- a/backend/packages/files/parsers/txt.py
+++ b/backend/packages/files/parsers/txt.py
@@ -4,10 +4,10 @@
 from .common import process_file
 
 
-async def process_txt(
+def process_txt(
     file: File, brain_id, original_file_name, integration=None, integration_link=None
 ):
-    return await process_file(
+    return process_file(
         file=file,
         loader_class=TextLoader,
         brain_id=brain_id,
diff --git a/backend/packages/files/processors.py b/backend/packages/files/processors.py
index d0a3a550b5b..04e70a976e1 100644
--- a/backend/packages/files/processors.py
+++ b/backend/packages/files/processors.py
@@ -1,4 +1,3 @@
-from fastapi import HTTPException
 from modules.brain.service.brain_service import BrainService
 
 from .parsers.audio import process_audio
@@ -52,16 +51,14 @@ def create_response(message, type):
 
 
 # TODO: Move filter_file to a file service to avoid circular imports from models/files.py for File class
-async def filter_file(
+def filter_file(
     file,
     brain_id,
     original_file_name=None,
 ):
-    await file.compute_file_sha1()
-
     file_exists = file.file_already_exists()
     file_exists_in_brain = file.file_already_exists_in_brain(brain_id)
-    using_file_name = original_file_name or file.file.filename if file.file else ""
+    using_file_name = file.file_name
 
     brain = brain_service.get_brain_by_id(brain_id)
     if brain is None:
@@ -86,7 +83,7 @@
 
     if file.file_extension in file_processors:
         try:
-            result = await file_processors[file.file_extension](
+            result = file_processors[file.file_extension](
                 file=file,
                 brain_id=brain_id,
                 original_file_name=original_file_name,
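
For reference, a minimal sketch (not part of the diff) of how the refactored synchronous ingestion path fits together once the changes above are applied. The `ingest_bytes` helper, its arguments, and the hard-coded extension fallback are hypothetical illustrations; the `File` fields and the `filter_file` signature are taken from the hunks above.

# Illustrative sketch only, assuming the backend modules from this diff are importable.
import os
from tempfile import NamedTemporaryFile

from models.files import File
from packages.files.processors import filter_file


def ingest_bytes(content: bytes, file_name: str, brain_id):
    # Hypothetical helper: write the raw bytes to a NamedTemporaryFile and hand
    # its path to File, instead of wrapping a FastAPI UploadFile and driving
    # filter_file through an asyncio event loop as the old code did.
    _, file_extension = os.path.splitext(file_name)
    with NamedTemporaryFile(suffix="_" + file_name.replace("/", "_")) as tmp_file:
        tmp_file.write(content)
        tmp_file.flush()
        file_instance = File(
            file_name=os.path.basename(file_name),
            tmp_file_path=tmp_file.name,
            bytes_content=content,
            file_size=len(content),
            file_extension=file_extension or ".txt",
        )
        # filter_file is plain synchronous after this change, so the Celery task
        # can call it directly and return its message.
        return filter_file(
            file=file_instance,
            brain_id=brain_id,
            original_file_name=file_name,
        )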