102 changes: 81 additions & 21 deletions api/endpoints/upload.py
@@ -2,56 +2,116 @@
Upload endpoint for document ingestion.
"""

from fastapi import APIRouter, UploadFile, File, HTTPException
import logging
from fastapi import APIRouter, UploadFile, File, HTTPException, Request, status
import os
import pathlib
from concurrent.futures import ThreadPoolExecutor
from typing import Optional

from services.file_service import save_upload_file
from services.parsing_service import extract_text_from_pdf
from models.schemas import UploadResponse
from services.exceptions import DocumentSaveError, DocumentParseError, DocumentChunkError, DocumentEmbeddingError

from services.logging_config import get_logger, set_log_context, clear_log_context

logger = logging.getLogger(__name__)
logger = get_logger(__name__)
router = APIRouter()

# Configuration: allowed types and size (bytes)
ALLOWED_CONTENT_TYPES = {"application/pdf", "text/plain", "text/markdown"}
MAX_UPLOAD_BYTES = int(os.getenv("MAX_UPLOAD_BYTES", str(25 * 1024 * 1024))) # default 25 MB


def _secure_filename(name: str) -> str:
    # Simple sanitization: take only base name and strip suspicious characters
    base = pathlib.Path(name).name
    # remove path separators and control chars
    return "".join(c for c in base if c.isprintable())


@router.post("/upload", response_model=UploadResponse)
def upload_document(file: UploadFile = File(...)): # noqa: B008
async def upload_document(request: Request, file: UploadFile = File(...)): # noqa: B008
    """Upload and do a light parse to provide a preview.
    This endpoint is async but offloads blocking file IO to a threadpool.
    """
    # Per-request logging context
    request_id = request.headers.get("X-Request-Id") or None
    if request_id:
        set_log_context(request_id=request_id)
    try:
        saved_path = save_upload_file(file.file, file.filename)
        # Basic content-type and size checks
        content_type = (file.content_type or "").lower()
        if content_type not in ALLOWED_CONTENT_TYPES:
            logger.warning("Rejected upload due to content-type", extra={"content_type": content_type})
            raise HTTPException(status_code=status.HTTP_415_UNSUPPORTED_MEDIA_TYPE, detail="Unsupported file type")

        # If client provided Content-Length header, check early
        content_length = request.headers.get("content-length")
        if content_length:
            try:
                if int(content_length) > MAX_UPLOAD_BYTES:
                    logger.warning("Rejected upload due to size header too large", extra={"size": content_length})
                    raise HTTPException(status_code=status.HTTP_413_REQUEST_ENTITY_TOO_LARGE, detail="File too large")
            except ValueError:
                # ignore invalid header and continue with streaming checks
                pass

Comment on lines +50 to +60
⚠️ Potential issue | 🔴 Critical

Block oversized uploads during streaming; avoid saving then rejecting

Header-based checks aren’t enough: a client can omit or misstate Content-Length and force arbitrarily large files to be written to disk. Enforce the size limit during the write, and delete the file if a post-save check fails. Also, avoid creating a per-request ThreadPoolExecutor; use asyncio.to_thread instead.

Apply this diff:

@@
-        # Offload blocking save to threadpool
-        loop = __import__("asyncio").get_running_loop()
-        with ThreadPoolExecutor(max_workers=1) as ex:
-            saved_path = await loop.run_in_executor(ex, save_upload_file, file.file, filename)
+        # Offload blocking save, enforcing size during save
+        saved_path = await asyncio.to_thread(save_upload_file, file.file, filename, MAX_UPLOAD_BYTES)
@@
-        # Quick size check after save
+        # Quick size check after save (defense-in-depth)
         try:
             size = os.path.getsize(saved_path)
             if size > MAX_UPLOAD_BYTES:
-                logger.warning("Saved file exceeds max size", extra={"size": size, "path": saved_path})
-                raise HTTPException(status_code=status.HTTP_413_REQUEST_ENTITY_TOO_LARGE, detail="File too large")
-        except OSError:
-            logger.exception("Failed to stat saved file", extra={"path": saved_path})
+                logger.warning("Saved file exceeds max size", extra={"filename": os.path.basename(saved_path), "size": size})
+                with contextlib.suppress(Exception):
+                    os.remove(saved_path)
+                raise HTTPException(status_code=status.HTTP_413_REQUEST_ENTITY_TOO_LARGE, detail="File too large")
+        except OSError as e:
+            logger.exception("Failed to stat saved file", extra={"filename": os.path.basename(saved_path)})
+            with contextlib.suppress(Exception):
+                os.remove(saved_path)
+            raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Failed to validate saved file") from e

And update the saver to enforce the limit while writing (outside this file):

diff --git a/services/file_service.py b/services/file_service.py
@@
-def save_upload_file(upload_file: IO, filename: str) -> str:
+def save_upload_file(upload_file: IO, filename: str, max_bytes: int | None = None) -> str:
@@
-        with tempfile.NamedTemporaryFile(dir=UPLOAD_DIR, delete=False) as tmp:
-            for chunk in iter(lambda: upload_file.read(8192), b""):
-                tmp.write(chunk)
-            temp_path = tmp.name
+        with tempfile.NamedTemporaryFile(dir=UPLOAD_DIR, delete=False) as tmp:
+            total = 0
+            for chunk in iter(lambda: upload_file.read(8192), b""):
+                tmp.write(chunk)
+                total += len(chunk)
+                if max_bytes is not None and total > max_bytes:
+                    temp_path = tmp.name
+                    raise DocumentSaveError("File exceeds allowed size.")
+            temp_path = tmp.name
@@
-    except OSError as e:
+    except OSError as e:
         logger.exception("OSError during file save: %s", filename)
         with contextlib.suppress(Exception):
             if 'temp_path' in locals() and os.path.exists(temp_path):
                 os.remove(temp_path)
         raise DocumentSaveError("Failed to save file securely.") from e
+    except DocumentSaveError as e:
+        logger.warning("Upload aborted due to size limit")
+        with contextlib.suppress(Exception):
+            if 'temp_path' in locals() and os.path.exists(temp_path):
+                os.remove(temp_path)
+        raise

Also applies to: 64-68, 69-77

🤖 Prompt for AI Agents
In api/endpoints/upload.py around lines 50-60 (and similarly 64-68 and 69-77), the
current logic only checks the Content-Length header and may allow oversized uploads
to be written to disk if the header is missing or incorrect. Update the upload
flow to enforce MAX_UPLOAD_BYTES during the streaming/save operation (move the
size-check logic into the saver so it raises once the cumulative bytes exceed
the limit), ensure the partially written file is deleted if any post-save size
check fails, and replace any per-request ThreadPoolExecutor usage with
asyncio.to_thread for disk I/O offloading. Adjust the calling code here to use
the updated saver that raises on overflow, and handle that exception by
returning HTTP 413 and cleaning up the file.
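
For context, a minimal self-contained sketch of what the suggested saver and caller could look like together; save_upload_file, DocumentSaveError, and UPLOAD_DIR echo names from this PR, but their exact signatures and locations here are assumptions, and handle_upload is a hypothetical caller:

# Sketch of the reviewer's suggestion: enforce the byte limit while streaming
# to disk, remove the partial file on overflow, and offload the blocking write
# with asyncio.to_thread instead of a per-request ThreadPoolExecutor.
# Assumptions: UPLOAD_DIR location, save_upload_file signature, and the
# DocumentSaveError class are stand-ins for this PR's services.
import asyncio
import contextlib
import os
import tempfile
from typing import IO

UPLOAD_DIR = "/tmp/uploads"          # assumed location
MAX_UPLOAD_BYTES = 25 * 1024 * 1024  # 25 MB, matching the endpoint default


class DocumentSaveError(Exception):
    """Raised when a file cannot be saved (stand-in for the PR's exception)."""


def save_upload_file(upload_file: IO[bytes], filename: str, max_bytes: int | None = None) -> str:
    """Stream the upload to a temp file, aborting once max_bytes is exceeded.

    Moving the temp file to its final name derived from `filename` is omitted here.
    """
    os.makedirs(UPLOAD_DIR, exist_ok=True)
    temp_path = None
    try:
        with tempfile.NamedTemporaryFile(dir=UPLOAD_DIR, delete=False) as tmp:
            temp_path = tmp.name
            total = 0
            for chunk in iter(lambda: upload_file.read(8192), b""):
                total += len(chunk)
                if max_bytes is not None and total > max_bytes:
                    raise DocumentSaveError("File exceeds allowed size.")
                tmp.write(chunk)
        return temp_path
    except (OSError, DocumentSaveError):
        # Remove the partial file so oversized or failed uploads never linger on disk.
        with contextlib.suppress(OSError):
            if temp_path and os.path.exists(temp_path):
                os.remove(temp_path)
        raise


async def handle_upload(upload_file: IO[bytes], filename: str) -> str:
    # asyncio.to_thread reuses the default executor rather than creating one per request.
    return await asyncio.to_thread(save_upload_file, upload_file, filename, MAX_UPLOAD_BYTES)

Note that the suggested endpoint diff above also relies on asyncio and contextlib being imported at module level, which the sketch makes explicit.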

        # Sanitize filename
        filename = _secure_filename(file.filename or "upload")

        # Offload blocking save to threadpool
        loop = __import__("asyncio").get_running_loop()
        with ThreadPoolExecutor(max_workers=1) as ex:
            saved_path = await loop.run_in_executor(ex, save_upload_file, file.file, filename)

        # Quick size check after save
        try:
            size = os.path.getsize(saved_path)
            if size > MAX_UPLOAD_BYTES:
                logger.warning("Saved file exceeds max size", extra={"size": size, "path": saved_path})
                raise HTTPException(status_code=status.HTTP_413_REQUEST_ENTITY_TOO_LARGE, detail="File too large")
        except OSError:
            logger.exception("Failed to stat saved file", extra={"path": saved_path})

        parsing_status = "success"
        text_preview = None
        if file.content_type == "application/pdf":
        text_preview: Optional[str] = None

        if content_type == "application/pdf":
            try:
                text = extract_text_from_pdf(saved_path)
                text_preview = text[:500] if text else None
            except DocumentParseError:
                logger.error("Document parse error")
                logger.error("Document parse error", extra={"path": saved_path})

⚠️ Potential issue | 🟠 Major

Use logger.exception in except blocks and chain exceptions

Align with Ruff TRY400 and B904; also avoid logging full paths elsewhere.

Apply this diff:

-                logger.error("Document parse error", extra={"path": saved_path})
+                logger.exception("Document parse error", extra={"filename": os.path.basename(saved_path)})
@@
-                logger.error("Unicode decode error while reading file for preview", extra={"path": saved_path})
+                logger.exception("Unicode decode error while reading file for preview", extra={"filename": os.path.basename(saved_path)})
@@
-                logger.error("OS error while reading file for preview", extra={"path": saved_path})
+                logger.exception("OS error while reading file for preview", extra={"filename": os.path.basename(saved_path)})
@@
-        logger.error("Document save error", extra={"error": str(dse)})
+        logger.exception("Document save error", extra={"error": str(dse)})
         raise HTTPException(status_code=400, detail="Failed to save uploaded document") from dse
@@
-        logger.error("Document processing error", extra={"error": str(de)})
-        raise HTTPException(status_code=422, detail="Error processing document")
+        logger.exception("Document processing error", extra={"error": str(de)})
+        raise HTTPException(status_code=422, detail="Error processing document") from de

Also applies to: 98, 101, 110, 113-114

🧰 Tools
🪛 Ruff (0.13.1)

86-86: Use logging.exception instead of logging.error

Replace with exception

(TRY400)

🤖 Prompt for AI Agents
In api/endpoints/upload.py around lines 86, 98, 101, 110 and 113-114, replace
logger.error(...) calls inside except blocks with logger.exception(...) so the
stack trace is included, avoid logging full filesystem paths by logging only the
filename or a redacted path (e.g., os.path.basename(saved_path) or
"<redacted>"), and when re-raising exceptions ensure you use exception chaining
(raise NewError(...) from e) rather than raising without the original exception;
do this for each listed location.
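
For illustration, a small self-contained sketch of the pattern being requested (parse_document and ProcessingError are hypothetical names, not code from this PR):

# Illustration of the requested pattern: logger.exception inside except blocks,
# filename-only logging, and explicit exception chaining.
import logging
import os

logger = logging.getLogger(__name__)


class ProcessingError(Exception):
    pass


def parse_document(path: str) -> str:
    try:
        with open(path, "rb") as f:
            return f.read(500).decode("utf-8", errors="replace")
    except OSError as e:
        # logger.exception records the traceback (unlike logger.error), and
        # logging only the basename avoids leaking full filesystem paths.
        logger.exception("Failed to read document", extra={"filename": os.path.basename(path)})
        # "raise ... from e" preserves the original cause for debugging (B904).
        raise ProcessingError("Could not read document") from e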

                parsing_status = "failed"
            except Exception:
                logger.exception("Error parsing PDF file for preview: %s", saved_path)
                logger.exception("Error parsing PDF file for preview", extra={"path": saved_path})
                parsing_status = "failed"
        elif file.content_type in {"text/plain", "text/markdown"}:
        elif content_type in {"text/plain", "text/markdown"}:
            try:
                with open(saved_path, "r", encoding="utf-8", errors="replace") as f:
                    text = f.read(500)
                text_preview = text if text else None
            except UnicodeDecodeError:
                parsing_status = "failed"
                logger.error("Unicode decode error while reading file for preview: %s", saved_path)
                logger.error("Unicode decode error while reading file for preview", extra={"path": saved_path})
            except OSError:
                parsing_status = "failed"
                logger.error("OS error while reading file for preview: %s", saved_path)
                logger.error("OS error while reading file for preview", extra={"path": saved_path})

        return UploadResponse(
            filename=file.filename,
            filename=filename,
            message="File uploaded and parsed.",
            parsing_status=parsing_status,
            text_preview=text_preview
            text_preview=text_preview,
        )
    except DocumentSaveError as dse:
        logger.error("Document save error: %s", dse)
        raise HTTPException(status_code=400, detail=str(dse)) from dse
        logger.error("Document save error", extra={"error": str(dse)})
        raise HTTPException(status_code=400, detail="Failed to save uploaded document") from dse
    except (DocumentParseError, DocumentChunkError, DocumentEmbeddingError) as de:
        logger.error("Document processing error: %s", de)
        raise HTTPException(status_code=422, detail=str(de))
    except Exception:
        logger.exception("Unhandled error in upload_document")
        raise HTTPException(status_code=500, detail="Internal server error")
        logger.error("Document processing error", extra={"error": str(de)})
        raise HTTPException(status_code=422, detail="Error processing document")
    finally:
        # Clear per-request logging context
        clear_log_context()
Comment on lines +115 to +117

⚠️ Potential issue | 🟠 Major

Close the uploaded file in finally to avoid resource leaks

Starlette’s UploadFile exposes an async close(). Ensure cleanup.

Apply this diff:

     finally:
-        # Clear per-request logging context
-        clear_log_context()
+        with contextlib.suppress(Exception):
+            await file.close()
+        # Clear per-request logging context
+        clear_log_context()
🤖 Prompt for AI Agents
In api/endpoints/upload.py around lines 115 to 117, the finally block only
clears the per-request logging context and does not close the Starlette UploadFile,
risking a resource leak. Modify the finally block to also await the uploaded
file's async close() (e.g., await uploaded_file.close()) so the file is properly
cleaned up before exiting the request handler, wrapping the call in a safe
try/except if needed so cleanup does not mask the original exception.
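
As a standalone illustration of the cleanup pattern (a minimal hypothetical endpoint using standard FastAPI/Starlette APIs, not this PR's handler):

# Sketch: always close the UploadFile in finally, even when an exception is
# raised, without letting the cleanup mask the original error.
import contextlib

from fastapi import FastAPI, File, UploadFile

app = FastAPI()


@app.post("/upload")
async def upload(file: UploadFile = File(...)) -> dict:
    try:
        contents = await file.read(1024)  # do some work with the upload
        return {"filename": file.filename, "preview_bytes": len(contents)}
    finally:
        # UploadFile.close() is async in Starlette; suppressing errors here
        # keeps cleanup from replacing any exception raised above.
        with contextlib.suppress(Exception):
            await file.close()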