Add better handling for ingesting duplicates
NolanTrem committed Jun 13, 2024
1 parent efff384 commit 2658f60
Showing 1 changed file with 124 additions and 28 deletions.
152 changes: 124 additions & 28 deletions r2r/main/r2r_app.py
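At a high level, the change collects the IDs of documents already known to the vector store, splits each incoming batch into documents to ingest and documents to skip, raises HTTP 409 when everything in the batch is a duplicate, and reports both groups in the response. A minimal, self-contained sketch of that split (the helper name below is illustrative, not from the r2r codebase; the real code additionally gates on the `version` argument):

# Sketch of the duplicate-split pattern this commit introduces.
def split_new_and_existing(incoming_ids, existing_ids):
    existing = {str(i) for i in existing_ids}
    new, skipped = [], []
    for doc_id in incoming_ids:
        (skipped if str(doc_id) in existing else new).append(doc_id)
    return new, skipped

# Example: only "b" is already stored, so it is the only one skipped.
new, skipped = split_new_and_existing(["a", "b", "c"], {"b"})
assert new == ["a", "c"] and skipped == ["b"]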
@@ -289,7 +289,27 @@ async def aingest_documents(
)

document_infos = []
skipped_documents = []
processed_documents = []
existing_document_ids = [
str(doc_info.document_id)
for doc_info in self.providers.vector_db.get_documents_info()
]

for iteration, document in enumerate(documents):
if (
version is not None
and str(document.id) in existing_document_ids
):
logger.error(f"Document with ID {document.id} already exists.")
if len(documents) == 1:
raise HTTPException(
status_code=409,
detail=f"Document with ID {document.id} already exists.",
)
skipped_documents.append(document.title or str(document.id))
continue

document_metadata = (
metadatas[iteration] if metadatas else document.metadata
)
@@ -319,14 +339,47 @@ async def aingest_documents(
)
)

processed_documents.append(document.title or str(document.id))

if skipped_documents and len(skipped_documents) == len(documents):
logger.error("All provided documents already exist.")
raise HTTPException(
status_code=409,
detail="All provided documents already exist. Use the update endpoint to update these documents.",
)

if skipped_documents:
logger.warning(
f"Skipped ingestion for the following documents since they already exist: {', '.join(skipped_documents)}. Use the update endpoint to update these documents."
)

await self.ingestion_pipeline.run(
input=to_async_generator(documents),
versions=versions,
input=to_async_generator(
[
doc
for doc in documents
if str(doc.id) not in existing_document_ids
]
),
versions=[
info.version
for info in document_infos
if info.created_at == info.updated_at
],
run_manager=self.run_manager,
)

self.providers.vector_db.upsert_documents_info(document_infos)
return {"results": "Entries upserted successfully."}
return {
"processed_documents": [
f"Document '{title}' processed successfully."
for title in processed_documents
],
"skipped_documents": [
f"Document '{title}' skipped since it already exists."
for title in skipped_documents
],
}
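The pipeline call above now filters both of its arguments: only documents whose IDs were not already present are streamed in, and versions are taken from the infos whose created_at equals updated_at, which presumably distinguishes rows created during this call from pre-existing ones. A hedged illustration of that filter (DocInfo is a stand-in, not the real DocumentInfo class):

from dataclasses import dataclass
from datetime import datetime

@dataclass
class DocInfo:  # illustrative stand-in for r2r's document info records
    version: str
    created_at: datetime
    updated_at: datetime

now = datetime(2024, 6, 13)
infos = [
    DocInfo("v0", now, now),                   # created in this call -> kept
    DocInfo("v3", datetime(2024, 1, 1), now),  # updated earlier row -> dropped
]
# Mirrors the versions=[...] comprehension in the diff above.
versions = [info.version for info in infos if info.created_at == info.updated_at]
assert versions == ["v0"]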

class IngestDocumentsRequest(BaseModel):
documents: list[Document]
@@ -337,6 +390,10 @@ async def ingest_documents_app(self, request: IngestDocumentsRequest):
) as run_id:
try:
return await self.aingest_documents(request.documents)

except HTTPException as he:
raise he

except Exception as e:
await self.logging_connection.log(
log_id=run_id,
@@ -445,10 +502,7 @@ async def update_documents_app(self, request: UpdateDocumentsRequest):
logger.error(
f"update_documents_app(documents={request.documents}) - \n\n{str(e)})"
)
logger.error(
f"update_documents_app(documents={request.documents}) - \n\n{str(e)})"
)
raise HTTPException(status_code=500, detail=str(e))
raise HTTPException(status_code=500, detail=str(e)) from e
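The re-raise in update_documents_app now uses raise ... from e, which records the original exception as __cause__ so the underlying traceback survives the conversion to an HTTP 500. A minimal illustration (HTTPError is a stand-in for FastAPI's HTTPException so the snippet runs without FastAPI installed):

class HTTPError(Exception):
    def __init__(self, status_code, detail):
        super().__init__(detail)
        self.status_code = status_code

try:
    try:
        {}["missing"]                        # the original failure
    except Exception as e:
        raise HTTPError(500, str(e)) from e  # chains `e` as __cause__
except HTTPError as err:
    assert isinstance(err.__cause__, KeyError)  # original error is preserved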

@syncable
async def aingest_files(
@@ -482,6 +536,13 @@ async def aingest_files(
try:
documents = []
document_infos = []
skipped_documents = []
processed_documents = []
existing_document_ids = [
str(doc_info.document_id)
for doc_info in self.providers.vector_db.get_documents_info()
]

for iteration, file in enumerate(files):
logger.info(f"Processing file: {file.filename}")
if (
@@ -510,14 +571,27 @@ async def aingest_files(
detail=f"'{file_extension}' is not a valid DocumentType.",
)

file_content = await file.read()
logger.info(f"File read successfully: {file.filename}")

document_id = (
generate_id_from_label(file.filename)
if document_ids is None
else document_ids[iteration]
)
if (
version is not None
and str(document_id) in existing_document_ids
):
logger.error(f"File with ID {document_id} already exists.")
if len(files) == 1:
raise HTTPException(
status_code=409,
detail=f"File with ID {document_id} already exists.",
)
skipped_documents.append(file.filename)
continue

file_content = await file.read()
logger.info(f"File read successfully: {file.filename}")

document_metadata = metadatas[iteration] if metadatas else {}
document_title = (
document_metadata.get("title", None) or file.filename
@@ -556,6 +630,20 @@ async def aingest_files(
)
)

processed_documents.append(file.filename)

if skipped_documents and len(skipped_documents) == len(files):
logger.error("All uploaded documents already exist.")
raise HTTPException(
status_code=409,
detail="All uploaded documents already exist. Use the update endpoint to update these documents.",
)

if skipped_documents:
logger.warning(
f"Skipped ingestion for the following documents since they already exist: {', '.join(skipped_documents)}. Use the update endpoint to update these documents."
)

# Run the pipeline asynchronously
await self.ingestion_pipeline.run(
input=to_async_generator(documents),
@@ -567,10 +655,14 @@ async def aingest_files(
self.providers.vector_db.upsert_documents_info(document_infos)

return {
"results": [
f"File '{file.filename}' processed successfully."
for file in files
]
"processed_documents": [
f"File '{filename}' processed successfully."
for filename in processed_documents
],
"skipped_documents": [
f"File '{filename}' skipped since it already exists."
for filename in skipped_documents
],
}
except Exception as e:
raise e
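From a caller's perspective, a single duplicate upload now comes back as HTTP 409, while a mixed batch succeeds and lists what was skipped. A hedged client-side sketch; the host, route, and multipart field name are assumptions for the example, not documented R2R endpoints:

import requests

resp = requests.post(
    "http://localhost:8000/ingest_files",
    files={"files": ("notes.txt", b"hello world", "text/plain")},
)
if resp.status_code == 409:
    # Everything uploaded already exists; the detail points at the update endpoint.
    print("duplicate:", resp.json()["detail"])
else:
    body = resp.json()
    print("processed:", body.get("processed_documents", []))
    print("skipped:", body.get("skipped_documents", []))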
@@ -632,6 +724,10 @@ async def ingest_files_app(
document_ids=ids_list,
user_ids=user_ids_list,
)

except HTTPException as he:
raise he

except Exception as e:
logger.error(f"ingest_files() - \n\n{str(e)})")
await self.logging_connection.log(
@@ -1223,27 +1319,27 @@ async def aanalytics(
analysis_type = analysis_config[0]
if analysis_type == "bar_chart":
extract_key = analysis_config[1]
results[
filter_key
] = AnalysisTypes.generate_bar_chart_data(
filtered_logs[filter_key], extract_key
results[filter_key] = (
AnalysisTypes.generate_bar_chart_data(
filtered_logs[filter_key], extract_key
)
)
elif analysis_type == "basic_statistics":
extract_key = analysis_config[1]
results[
filter_key
] = AnalysisTypes.calculate_basic_statistics(
filtered_logs[filter_key], extract_key
results[filter_key] = (
AnalysisTypes.calculate_basic_statistics(
filtered_logs[filter_key], extract_key
)
)
elif analysis_type == "percentile":
extract_key = analysis_config[1]
percentile = int(analysis_config[2])
results[
filter_key
] = AnalysisTypes.calculate_percentile(
filtered_logs[filter_key],
extract_key,
percentile,
results[filter_key] = (
AnalysisTypes.calculate_percentile(
filtered_logs[filter_key],
extract_key,
percentile,
)
)
else:
logger.warning(
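The aanalytics hunk is a pure formatting change: each results[filter_key] assignment is re-wrapped with the value in parentheses instead of splitting the subscript across lines, with no change in behaviour. For readers unfamiliar with the surrounding code, here is a small dispatch-table sketch of what those branches do; the handlers are trivial lambdas standing in for the AnalysisTypes methods, and the percentile case (which takes a third config element) is omitted:

handlers = {
    "bar_chart": lambda logs, key: {"kind": "bar_chart", "key": key, "n": len(logs)},
    "basic_statistics": lambda logs, key: {"kind": "stats", "key": key, "n": len(logs)},
}

def run_analysis(filtered_logs, filter_key, analysis_config):
    analysis_type, extract_key = analysis_config[0], analysis_config[1]
    handler = handlers.get(analysis_type)
    if handler is None:
        return None  # the real code logs a warning for unknown analysis types
    return handler(filtered_logs[filter_key], extract_key)

print(run_analysis({"search_latency": [0.2, 0.4]}, "search_latency",
                   ["bar_chart", "query"]))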
