feat(upload): async improved #2544
Merged: 7 commits, Jun 4, 2024
Changes from all commits
3 changes: 3 additions & 0 deletions backend/celery_config.py
@@ -1,8 +1,11 @@
# celery_config.py
import os

import dotenv
from celery import Celery

dotenv.load_dotenv()

CELERY_BROKER_URL = os.getenv("CELERY_BROKER_URL", "")
CELERY_BROKER_QUEUE_NAME = os.getenv("CELERY_BROKER_QUEUE_NAME", "quivr")

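For orientation, a minimal sketch of how a config module like this is typically consumed to build the Celery app; the constructor call and queue option below are illustrative assumptions, not lines from this PR:

```python
# Sketch only: the broker URL and queue name are the env vars shown above;
# the exact Celery options Quivr uses are not visible in this hunk.
import os

import dotenv
from celery import Celery

dotenv.load_dotenv()

CELERY_BROKER_URL = os.getenv("CELERY_BROKER_URL", "")
CELERY_BROKER_QUEUE_NAME = os.getenv("CELERY_BROKER_QUEUE_NAME", "quivr")

celery = Celery("quivr", broker=CELERY_BROKER_URL)
celery.conf.task_default_queue = CELERY_BROKER_QUEUE_NAME
```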
113 changes: 55 additions & 58 deletions backend/celery_worker.py
@@ -1,11 +1,10 @@
import asyncio
import io
import os
from datetime import datetime, timezone
from tempfile import NamedTemporaryFile
from uuid import UUID

from celery.schedules import crontab
from celery_config import celery
from fastapi import UploadFile
from logger import get_logger
from middlewares.auth.auth_bearer import AuthBearer
from models.files import File
@@ -18,7 +17,7 @@
from modules.notification.entity.notification import NotificationsStatusEnum
from modules.notification.service.notification_service import NotificationService
from modules.onboarding.service.onboarding_service import OnboardingService
from packages.files.crawl.crawler import CrawlWebsite
from packages.files.crawl.crawler import CrawlWebsite, slugify
from packages.files.parsers.github import process_github
from packages.files.processors import filter_file
from packages.utils.telemetry import maybe_send_telemetry
@@ -42,39 +41,36 @@ def process_file_and_notify(
):
try:
supabase_client = get_supabase_client()
tmp_file_name = "tmp-file-" + file_name
tmp_file_name = tmp_file_name.replace("/", "_")
tmp_name = file_name.replace("/", "_")
base_file_name = os.path.basename(file_name)
_, file_extension = os.path.splitext(base_file_name)

with open(tmp_file_name, "wb+") as f:
with NamedTemporaryFile(
suffix="_" + tmp_name, # pyright: ignore reportPrivateUsage=none
) as tmp_file:
res = supabase_client.storage.from_("quivr").download(file_name)
f.write(res)
f.seek(0)
file_content = f.read()

upload_file = UploadFile(
file=f, filename=file_name.split("/")[-1], size=len(file_content)
tmp_file.write(res)
tmp_file.flush()
file_instance = File(
file_name=base_file_name,
tmp_file_path=tmp_file.name,
bytes_content=res,
file_size=len(res),
file_extension=file_extension,
)

file_instance = File(file=upload_file)
loop = asyncio.get_event_loop()
brain_vector_service = BrainVectorService(brain_id)
if delete_file: # TODO fix bug
brain_vector_service.delete_file_from_brain(
file_original_name, only_vectors=True
)
message = loop.run_until_complete(
filter_file(
file=file_instance,
brain_id=brain_id,
original_file_name=file_original_name,
)
)

f.close()
os.remove(tmp_file_name)
message = filter_file(
file=file_instance,
brain_id=brain_id,
original_file_name=file_original_name,
)

if notification_id:

notification_service.update_notification_by_id(
notification_id,
NotificationUpdatableProperties(
@@ -85,63 +81,64 @@ def process_file_and_notify(
brain_service.update_brain_last_update_time(brain_id)

return True

except TimeoutError:
logger.error("TimeoutError")

except Exception as e:
logger.exception(e)
notification_service.update_notification_by_id(
notification_id,
NotificationUpdatableProperties(
status=NotificationsStatusEnum.ERROR,
description=f"An error occurred while processing the file: {e}",
),
)
return False


@celery.task(name="process_crawl_and_notify")
def process_crawl_and_notify(
crawl_website_url,
brain_id,
crawl_website_url: str,
brain_id: UUID,
notification_id=None,
):

crawl_website = CrawlWebsite(url=crawl_website_url)

if not crawl_website.checkGithub():
file_path, file_name = crawl_website.process()

with open(file_path, "rb") as f:
file_content = f.read()

# Create a file-like object in memory using BytesIO
file_object = io.BytesIO(file_content)
upload_file = UploadFile(
file=file_object, filename=file_name, size=len(file_content)
)
file_instance = File(file=upload_file)

loop = asyncio.get_event_loop()
message = loop.run_until_complete(
filter_file(
# Build file data
extracted_content = crawl_website.process()
extracted_content_bytes = extracted_content.encode("utf-8")
file_name = slugify(crawl_website.url) + ".txt"

with NamedTemporaryFile(
suffix="_" + file_name, # pyright: ignore reportPrivateUsage=none
) as tmp_file:
tmp_file.write(extracted_content_bytes)
tmp_file.flush()
file_instance = File(
file_name=file_name,
tmp_file_path=tmp_file.name,
bytes_content=extracted_content_bytes,
file_size=len(extracted_content),
file_extension=".txt",
)
message = filter_file(
file=file_instance,
brain_id=brain_id,
original_file_name=crawl_website_url,
)
)
notification_service.update_notification_by_id(
notification_id,
NotificationUpdatableProperties(
status=NotificationsStatusEnum.SUCCESS,
description=f"Your URL has been properly crawled!",
),
)
else:
loop = asyncio.get_event_loop()
message = loop.run_until_complete(
process_github(
repo=crawl_website.url,
brain_id=brain_id,
notification_service.update_notification_by_id(
notification_id,
NotificationUpdatableProperties(
status=NotificationsStatusEnum.SUCCESS,
description="Your URL has been properly crawled!",
),
)
else:
message = process_github(
repo=crawl_website.url,
brain_id=brain_id,
)

if notification_id:
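To make the new control flow easier to follow, here is a condensed, hedged sketch of what process_file_and_notify does after this change: download the object from Supabase storage into a NamedTemporaryFile, wrap it in the reworked File model, and call filter_file while the temporary file still exists. Dependencies are passed in as parameters so the sketch stays self-contained; notifications, error handling, and the delete_file branch are omitted, and filter_file is assumed to be callable synchronously, as the diff shows.

```python
import os
from tempfile import NamedTemporaryFile


def process_file_sketch(supabase_client, file_cls, filter_file, file_name: str, brain_id):
    """Illustrative reduction of the task body in this PR, not a drop-in replacement."""
    tmp_name = file_name.replace("/", "_")
    base_file_name = os.path.basename(file_name)
    _, file_extension = os.path.splitext(base_file_name)

    with NamedTemporaryFile(suffix="_" + tmp_name) as tmp_file:
        # Download the stored object and flush so readers of tmp_file.name see the bytes.
        res = supabase_client.storage.from_("quivr").download(file_name)
        tmp_file.write(res)
        tmp_file.flush()

        file_instance = file_cls(
            file_name=base_file_name,
            tmp_file_path=tmp_file.name,
            bytes_content=res,
            file_size=len(res),
            file_extension=file_extension,
        )
        # Must run inside the with-block: the context manager deletes the
        # temporary file on exit.
        return filter_file(
            file=file_instance,
            brain_id=brain_id,
            original_file_name=file_name,
        )
```

Note that the call is no longer wrapped in loop.run_until_complete; dropping the manual event-loop juggling is the core of the async cleanup in this task.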
12 changes: 4 additions & 8 deletions backend/main.py
@@ -1,15 +1,9 @@
import os

if __name__ == "__main__":
# import needed here when running main.py to debug backend
# you will need to run pip install python-dotenv
from dotenv import load_dotenv # type: ignore

load_dotenv()
import logging
import os

import litellm
import sentry_sdk
from dotenv import load_dotenv # type: ignore
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import HTMLResponse, JSONResponse
from logger import get_logger
@@ -35,6 +29,8 @@
from sentry_sdk.integrations.fastapi import FastApiIntegration
from sentry_sdk.integrations.starlette import StarletteIntegration

load_dotenv()

# Set the logging level for all loggers to WARNING
logging.basicConfig(level=logging.INFO)
logging.getLogger("httpx").setLevel(logging.WARNING)
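The substance of this hunk is that load_dotenv() now runs unconditionally at import time rather than only under the `__main__` guard, so environment variables are loaded the same way whether the backend is started by a server process or with `python main.py`. Schematically (condensed from the lines above):

```python
# Before: .env was only loaded when main.py was executed directly for debugging.
if __name__ == "__main__":
    from dotenv import load_dotenv
    load_dotenv()

# After: import once at the top and load unconditionally.
from dotenv import load_dotenv

load_dotenv()
```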
74 changes: 18 additions & 56 deletions backend/models/files.py
@@ -1,66 +1,38 @@
import os
import tempfile
from typing import Any, Optional
from uuid import UUID
from pathlib import Path
from typing import List, Optional

from fastapi import UploadFile
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_core.documents import Document
from logger import get_logger
from models.databases.supabase.supabase import SupabaseDB
from models.settings import get_supabase_db
from modules.brain.service.brain_vector_service import BrainVectorService
from packages.files.file import compute_sha1_from_file
from packages.files.file import compute_sha1_from_content
from pydantic import BaseModel

logger = get_logger(__name__)


class File(BaseModel):
id: Optional[UUID] = None
file: Optional[UploadFile] = None
file_name: Optional[str] = ""
file_size: Optional[int] = None
file_sha1: Optional[str] = ""
vectors_ids: Optional[list] = []
file_extension: Optional[str] = ""
content: Optional[Any] = None
file_name: str
tmp_file_path: Path
bytes_content: bytes
file_size: int
file_extension: str
chunk_size: int = 400
chunk_overlap: int = 100
documents: Optional[Document] = None
documents: List[Document] = []
file_sha1: Optional[str] = None
vectors_ids: Optional[list] = []

def __init__(self, **data):
super().__init__(**data)
data["file_sha1"] = compute_sha1_from_content(data["bytes_content"])

Collaborator (review comment on this addition): Loving it <3

@property
def supabase_db(self) -> SupabaseDB:
return get_supabase_db()

def __init__(self, **kwargs):
super().__init__(**kwargs)

if self.file:
self.file_name = self.file.filename
self.file_size = self.file.size # pyright: ignore reportPrivateUsage=none
self.file_extension = os.path.splitext(
self.file.filename # pyright: ignore reportPrivateUsage=none
)[-1].lower()

async def compute_file_sha1(self):
"""
Compute the sha1 of the file using a temporary file
"""
with tempfile.NamedTemporaryFile(
delete=False,
suffix=self.file.filename, # pyright: ignore reportPrivateUsage=none
) as tmp_file:
await self.file.seek(0) # pyright: ignore reportPrivateUsage=none
self.content = (
await self.file.read() # pyright: ignore reportPrivateUsage=none
)
tmp_file.write(self.content)
tmp_file.flush()
self.file_sha1 = compute_sha1_from_file(tmp_file.name)

os.remove(tmp_file.name)

def compute_documents(self, loader_class):
"""
Compute the documents from the file
@@ -69,18 +41,8 @@ def compute_documents(self, loader_class):
loader_class (class): The class of the loader to use to load the file
"""
logger.info(f"Computing documents from file {self.file_name}")

documents = []
with tempfile.NamedTemporaryFile(
delete=False,
suffix=self.file.filename, # pyright: ignore reportPrivateUsage=none
) as tmp_file:
tmp_file.write(self.content) # pyright: ignore reportPrivateUsage=none
tmp_file.flush()
loader = loader_class(tmp_file.name)
documents = loader.load()

os.remove(tmp_file.name)
loader = loader_class(self.tmp_file_path)
documents = loader.load()

text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap
@@ -129,7 +91,7 @@ def file_is_empty(self):
"""
Check if file is empty by checking if the file pointer is at the beginning of the file
"""
return self.file.size < 1 # pyright: ignore reportPrivateUsage=none
return self.file_size < 1 # pyright: ignore reportPrivateUsage=none

def link_file_to_brain(self, brain_id):
self.set_file_vectors_ids()
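A short, hypothetical usage sketch of the reworked model, mirroring how the Celery task above constructs it; the content and the loader choice are placeholders, while File itself is imported from the module changed in this diff:

```python
from pathlib import Path
from tempfile import NamedTemporaryFile

from langchain_community.document_loaders import TextLoader  # assumed loader choice
from models.files import File

content = b"hello quivr"  # placeholder bytes
with NamedTemporaryFile(suffix="_notes.txt") as tmp:
    tmp.write(content)
    tmp.flush()
    file_instance = File(
        file_name="notes.txt",
        tmp_file_path=Path(tmp.name),
        bytes_content=content,
        file_size=len(content),
        file_extension=".txt",
    )
    # compute_documents reads tmp_file_path with the given LangChain loader
    # and then chunks the result with RecursiveCharacterTextSplitter.
    file_instance.compute_documents(TextLoader)
```

Keeping bytes_content alongside tmp_file_path lets the model compute the sha1 from the bytes in `__init__` without reopening the file, while loaders that need a real path still get one.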
2 changes: 1 addition & 1 deletion backend/modules/brain/rags/quivr_rag.py
@@ -314,7 +314,7 @@ def get_chain(self):
self.brain_id
) # pyright: ignore reportPrivateUsage=none

list_files_array = [file.file_name for file in list_files_array]
list_files_array = [file.file_name or file.url for file in list_files_array]
# Max first 10 files
if len(list_files_array) > 20:
list_files_array = list_files_array[:20]
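The added `or` fallback matters because knowledge entries created from a crawled URL presumably carry a url rather than a file_name. The object shape below is an assumption used purely to illustrate the pattern:

```python
from typing import NamedTuple, Optional


class KnowledgeStub(NamedTuple):  # hypothetical stand-in for the knowledge entries
    file_name: Optional[str]
    url: Optional[str]


entries = [KnowledgeStub("report.pdf", None), KnowledgeStub(None, "https://example.com")]
labels = [entry.file_name or entry.url for entry in entries]
print(labels)  # ['report.pdf', 'https://example.com']
```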
4 changes: 2 additions & 2 deletions backend/modules/upload/controller/upload_routes.py
@@ -74,7 +74,7 @@ async def upload_file(
filename_with_brain_id = str(brain_id) + "/" + str(uploadFile.filename)

try:
file_in_storage = upload_file_storage(file_content, filename_with_brain_id)
upload_file_storage(file_content, filename_with_brain_id)

except Exception as e:
print(e)
@@ -104,7 +104,7 @@ async def upload_file(
)[-1].lower(),
)

added_knowledge = knowledge_service.add_knowledge(knowledge_to_add)
knowledge_service.add_knowledge(knowledge_to_add)

process_file_and_notify.delay(
file_name=filename_with_brain_id,
15 changes: 4 additions & 11 deletions backend/packages/files/crawl/crawler.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import os
import re
import tempfile
import unicodedata

from langchain_community.document_loaders import PlaywrightURLLoader
@@ -17,27 +16,21 @@ class CrawlWebsite(BaseModel):
max_pages: int = 100
max_time: int = 60

def process(self):
def process(self) -> str:
# Extract and combine content recursively
loader = PlaywrightURLLoader(
urls=[self.url], remove_selectors=["header", "footer"]
)
data = loader.load()

data = loader.load()
# Now turn the data into a string
logger.info(f"Extracted content from {len(data)} pages")
logger.info(data)
logger.debug(f"Extracted data : {data}")
extracted_content = ""
for page in data:
extracted_content += page.page_content

# Create a file
file_name = slugify(self.url) + ".txt"
temp_file_path = os.path.join(tempfile.gettempdir(), file_name)
with open(temp_file_path, "w") as temp_file:
temp_file.write(extracted_content) # type: ignore

return temp_file_path, file_name
return extracted_content

def checkGithub(self):
return "github.com" in self.url
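Combined with the worker changes above, crawling no longer writes its own temporary file: process() returns the extracted text and the caller decides how to persist it. A short sketch of the new hand-off (the URL is a placeholder; CrawlWebsite and slugify are the names exported by this module):

```python
from packages.files.crawl.crawler import CrawlWebsite, slugify

crawl_website = CrawlWebsite(url="https://example.com/docs")  # placeholder URL
if not crawl_website.checkGithub():
    extracted_content = crawl_website.process()         # now returns a plain string
    file_name = slugify(crawl_website.url) + ".txt"      # the caller derives the file name
    extracted_content_bytes = extracted_content.encode("utf-8")
    # ...the worker then writes these bytes into a NamedTemporaryFile and builds
    # a File, exactly as process_crawl_and_notify does above.
```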