diff --git a/backend/app/alembic/versions/40307ab77e9f_add_storage_path_to_project_and_project_to_document_table.py b/backend/app/alembic/versions/40307ab77e9f_add_storage_path_to_project_and_project_to_document_table.py
new file mode 100644
index 00000000..14bee314
--- /dev/null
+++ b/backend/app/alembic/versions/40307ab77e9f_add_storage_path_to_project_and_project_to_document_table.py
@@ -0,0 +1,125 @@
+"""add storage_path to project and project_id to document table
+
+Revision ID: 40307ab77e9f
+Revises: 8725df286943
+Create Date: 2025-08-28 10:54:30.712627
+
+"""
+from alembic import op
+import sqlalchemy as sa
+
+
+# revision identifiers, used by Alembic.
+revision = "40307ab77e9f"
+down_revision = "8725df286943"
+branch_labels = None
+depends_on = None
+
+
+def _require_backfilled(conn, column):
+    """Fail fast if any ``document`` row still has NULL in ``column``.
+
+    Run right after a backfill and before the column is made NOT NULL, so an
+    unmappable row produces an actionable message instead of a bare
+    constraint violation. ``column`` is only ever a literal from this file.
+    """
+    missing = conn.execute(
+        sa.text(f"SELECT COUNT(*) FROM document WHERE {column} IS NULL")
+    ).scalar()
+    if missing:
+        raise RuntimeError(
+            f"{missing} document row(s) still have NULL {column} after the "
+            "apikey backfill; fix or remove them and re-run the migration"
+        )
+
+
+def upgrade():
+    """Move documents from user ownership to project ownership.
+
+    Adds ``project.storage_path`` (unique UUID per project) and replaces
+    ``document.owner_id`` with ``document.project_id`` plus an ``is_deleted`` flag.
+    """
+    conn = op.get_bind()
+
+    # Add as nullable, backfill every existing project, then tighten.
+    op.add_column("project", sa.Column("storage_path", sa.Uuid(), nullable=True))
+    conn.execute(sa.text("UPDATE project SET storage_path = gen_random_uuid()"))
+    op.alter_column("project", "storage_path", nullable=False)
+    op.create_unique_constraint("uq_project_storage_path", "project", ["storage_path"])
+
+    op.add_column("document", sa.Column("project_id", sa.Integer(), nullable=True))
+    op.add_column("document", sa.Column("is_deleted", sa.Boolean(), nullable=True))
+
+    # A document counts as deleted exactly when it carries a deletion timestamp.
+    conn.execute(
+        sa.text("UPDATE document SET is_deleted = (deleted_at IS NOT NULL)")
+    )
+    # MIN() instead of an unordered LIMIT 1: when an owner has API keys in
+    # several projects the chosen project is at least deterministic.
+    conn.execute(
+        sa.text(
+            """
+            UPDATE document
+            SET project_id = (
+                SELECT MIN(apikey.project_id) FROM apikey
+                WHERE apikey.user_id = document.owner_id
+            )
+            """
+        )
+    )
+    # Owners without any API key leave project_id NULL; stop before NOT NULL.
+    _require_backfilled(conn, "project_id")
+
+    op.alter_column("document", "is_deleted", nullable=False)
+    op.alter_column("document", "project_id", nullable=False)
+
+    op.drop_constraint("document_owner_id_fkey", "document", type_="foreignkey")
+    # Named explicitly because downgrade() drops this constraint by name.
+    op.create_foreign_key(
+        "document_project_id_fkey",
+        "document",
+        "project",
+        ["project_id"],
+        ["id"],
+        ondelete="CASCADE",
+    )
+    op.drop_column("document", "owner_id")
+
+
+def downgrade():
+    """Restore user ownership of documents and drop the project-scoped columns."""
+    conn = op.get_bind()
+
+    op.drop_constraint("uq_project_storage_path", "project", type_="unique")
+    op.drop_column("project", "storage_path")
+
+    op.add_column(
+        "document",
+        sa.Column("owner_id", sa.Integer(), autoincrement=False, nullable=True),
+    )
+    # Backfill owner_id from project_id using the apikey mapping (deterministic).
+    conn.execute(
+        sa.text(
+            """
+            UPDATE document
+            SET owner_id = (
+                SELECT MIN(apikey.user_id) FROM apikey
+                WHERE apikey.project_id = document.project_id
+            )
+            """
+        )
+    )
+    _require_backfilled(conn, "owner_id")
+    op.alter_column("document", "owner_id", nullable=False)
+
+    op.drop_constraint("document_project_id_fkey", "document", type_="foreignkey")
+    op.create_foreign_key(
+        "document_owner_id_fkey",
+        "document",
+        "user",
+        ["owner_id"],
+        ["id"],
+        ondelete="CASCADE",
+    )
+    op.drop_column("document", "is_deleted")
+    op.drop_column("document", "project_id")
diff --git a/backend/app/api/routes/collections.py b/backend/app/api/routes/collections.py index 2a5b49fb..bea3ed3a 100644 --- a/backend/app/api/routes/collections.py +++ b/backend/app/api/routes/collections.py @@ -15,7 +15,7 @@ from sqlalchemy.exc import SQLAlchemyError from app.api.deps import CurrentUser, SessionDep, CurrentUserOrgProject -from app.core.cloud import AmazonCloudStorage +from app.core.cloud import get_cloud_storage from app.api.routes.responses import handle_openai_error from app.core.util import now, post_callback from app.crud import ( @@ -24,7 +24,7 @@ DocumentCollectionCrud, ) from app.crud.rag import OpenAIVectorStoreCrud, OpenAIAssistantCrud -from app.models import Collection, Document +from app.models import Collection, Document, DocumentPublic from app.models.collection import CollectionStatus from app.utils import APIResponse, load_description, get_openai_client @@ -225,8 +225,8 @@ def do_create_collection( else
WebHookCallback(request.callback_url, payload) ) - storage = AmazonCloudStorage(current_user) - document_crud = DocumentCrud(session, current_user.id) + storage = get_cloud_storage(session=session, project_id=current_user.project_id) + document_crud = DocumentCrud(session, current_user.project_id) assistant_crud = OpenAIAssistantCrud(client) vector_store_crud = OpenAIVectorStoreCrud(client) collection_crud = CollectionCrud(session, current_user.id) @@ -423,7 +423,7 @@ def list_collections( @router.post( "/docs/{collection_id}", description=load_description("collections/docs.md"), - response_model=APIResponse[List[Document]], + response_model=APIResponse[List[DocumentPublic]], ) def collection_documents( session: SessionDep, diff --git a/backend/app/api/routes/documents.py b/backend/app/api/routes/documents.py index 3924cabf..ee9a0843 100644 --- a/backend/app/api/routes/documents.py +++ b/backend/app/api/routes/documents.py @@ -3,14 +3,14 @@ from typing import List from pathlib import Path -from fastapi import APIRouter, File, UploadFile, Query +from fastapi import APIRouter, File, UploadFile, Query, HTTPException from fastapi import Path as FastPath from app.crud import DocumentCrud, CollectionCrud -from app.models import Document +from app.models import Document, DocumentPublic, Message from app.utils import APIResponse, load_description, get_openai_client from app.api.deps import CurrentUser, SessionDep, CurrentUserOrgProject -from app.core.cloud import AmazonCloudStorage +from app.core.cloud import get_cloud_storage from app.crud.rag import OpenAIAssistantCrud logger = logging.getLogger(__name__) @@ -20,15 +20,15 @@ @router.get( "/list", description=load_description("documents/list.md"), - response_model=APIResponse[List[Document]], + response_model=APIResponse[List[DocumentPublic]], ) def list_docs( session: SessionDep, - current_user: CurrentUser, + current_user: CurrentUserOrgProject, skip: int = Query(0, ge=0), limit: int = Query(100, gt=0, le=100), ): - 
crud = DocumentCrud(session, current_user.id) + crud = DocumentCrud(session, current_user.project_id) data = crud.read_many(skip, limit) return APIResponse.success_response(data) @@ -36,19 +36,19 @@ def list_docs( @router.post( "/upload", description=load_description("documents/upload.md"), - response_model=APIResponse[Document], + response_model=APIResponse[DocumentPublic], ) def upload_doc( session: SessionDep, - current_user: CurrentUser, + current_user: CurrentUserOrgProject, src: UploadFile = File(...), ): - storage = AmazonCloudStorage(current_user) + storage = get_cloud_storage(session=session, project_id=current_user.project_id) document_id = uuid4() object_store_url = storage.put(src, Path(str(document_id))) - crud = DocumentCrud(session, current_user.id) + crud = DocumentCrud(session, current_user.project_id) document = Document( id=document_id, fname=src.filename, @@ -58,10 +58,10 @@ def upload_doc( return APIResponse.success_response(data) -@router.get( +@router.delete( "/remove/{doc_id}", description=load_description("documents/delete.md"), - response_model=APIResponse[Document], + response_model=APIResponse[Message], ) def remove_doc( session: SessionDep, @@ -73,18 +73,21 @@ def remove_doc( ) a_crud = OpenAIAssistantCrud(client) - d_crud = DocumentCrud(session, current_user.id) + d_crud = DocumentCrud(session, current_user.project_id) c_crud = CollectionCrud(session, current_user.id) document = d_crud.delete(doc_id) data = c_crud.delete(document, a_crud) - return APIResponse.success_response(data) + + return APIResponse.success_response( + Message(message="Document Deleted Successfully") + ) @router.delete( "/remove/{doc_id}/permanent", description=load_description("documents/permanent_delete.md"), - response_model=APIResponse[Document], + response_model=APIResponse[Message], ) def permanent_delete_doc( session: SessionDep, @@ -94,11 +97,10 @@ def permanent_delete_doc( client = get_openai_client( session, current_user.organization_id, 
current_user.project_id ) - a_crud = OpenAIAssistantCrud(client) - d_crud = DocumentCrud(session, current_user.id) + d_crud = DocumentCrud(session, current_user.project_id) c_crud = CollectionCrud(session, current_user.id) - storage = AmazonCloudStorage(current_user) + storage = get_cloud_storage(session=session, project_id=current_user.project_id) document = d_crud.read_one(doc_id) @@ -107,19 +109,30 @@ def permanent_delete_doc( storage.delete(document.object_store_url) d_crud.delete(doc_id) - return APIResponse.success_response(document) + return APIResponse.success_response( + Message(message="Document permanently deleted successfully") + ) @router.get( "/info/{doc_id}", description=load_description("documents/info.md"), - response_model=APIResponse[Document], + response_model=APIResponse[DocumentPublic], ) def doc_info( session: SessionDep, - current_user: CurrentUser, + current_user: CurrentUserOrgProject, doc_id: UUID = FastPath(description="Document to retrieve"), + include_url: bool = Query( + False, description="Include a signed URL to access the document" + ), ): - crud = DocumentCrud(session, current_user.id) - data = crud.read_one(doc_id) - return APIResponse.success_response(data) + crud = DocumentCrud(session, current_user.project_id) + document = crud.read_one(doc_id) + + doc_schema = DocumentPublic.model_validate(document, from_attributes=True) + if include_url: + storage = get_cloud_storage(session=session, project_id=current_user.project_id) + doc_schema.signed_url = storage.get_signed_url(document.object_store_url) + + return APIResponse.success_response(doc_schema) diff --git a/backend/app/core/cloud/__init__.py b/backend/app/core/cloud/__init__.py index 545a0235..c29a35ad 100644 --- a/backend/app/core/cloud/__init__.py +++ b/backend/app/core/cloud/__init__.py @@ -3,4 +3,5 @@ AmazonCloudStorageClient, CloudStorage, CloudStorageError, + get_cloud_storage, ) diff --git a/backend/app/core/cloud/storage.py b/backend/app/core/cloud/storage.py index 
0c225d91..a3248a01 100644 --- a/backend/app/core/cloud/storage.py +++ b/backend/app/core/cloud/storage.py @@ -1,16 +1,20 @@ import os +from sqlmodel import Session +from uuid import UUID import logging import functools as ft from pathlib import Path from dataclasses import dataclass, asdict from urllib.parse import ParseResult, urlparse, urlunparse +from abc import ABC, abstractmethod import boto3 from fastapi import UploadFile from botocore.exceptions import ClientError from botocore.response import StreamingBody -from app.api.deps import CurrentUser +from app.crud import get_project_by_id +from app.models import UserProjectOrg from app.core.config import settings from app.utils import mask_string @@ -107,25 +111,47 @@ def from_url(cls, url: str): return cls(Bucket=url.netloc, Key=str(path)) -class CloudStorage: - def __init__(self, user: CurrentUser): - self.user = user +class CloudStorage(ABC): + def __init__(self, project_id: int, storage_path: UUID): + self.project_id = project_id + self.storage_path = str(storage_path) - def put(self, source: UploadFile, basename: str): - raise NotImplementedError() + @abstractmethod + def put(self, source: UploadFile, filepath: Path) -> SimpleStorageName: + """Upload a file to storage""" + pass + @abstractmethod def stream(self, url: str) -> StreamingBody: - raise NotImplementedError() + """Stream a file from storage""" + pass + + @abstractmethod + def get_file_size_kb(self, url: str) -> float: + """Return the file size in KB""" + pass + + @abstractmethod + def get_signed_url(self, url: str, expires_in: int = 3600) -> str: + """Generate a signed URL with an optional expiry""" + pass + + @abstractmethod + def delete(self, url: str) -> None: + """Delete a file from storage""" + pass class AmazonCloudStorage(CloudStorage): - def __init__(self, user: CurrentUser): - super().__init__(user) + def __init__(self, project_id: int, storage_path: UUID): + super().__init__(project_id, storage_path) self.aws = AmazonCloudStorageClient() 
- def put(self, source: UploadFile, basename: Path) -> SimpleStorageName: - key = Path(str(self.user.id), basename) - destination = SimpleStorageName(str(key)) + def put(self, source: UploadFile, file_path: Path) -> SimpleStorageName: + if file_path.is_absolute(): + raise ValueError("file_path must be relative to the project's storage root") + key = Path(self.storage_path) / file_path + destination = SimpleStorageName(key.as_posix()) kwargs = asdict(destination) try: @@ -138,12 +164,12 @@ def put(self, source: UploadFile, basename: Path) -> SimpleStorageName: ) logger.info( f"[AmazonCloudStorage.put] File uploaded successfully | " - f"{{'user_id': '{self.user.id}', 'bucket': '{mask_string(destination.Bucket)}', 'key': '{mask_string(destination.Key)}'}}" + f"{{'project_id': '{self.project_id}', 'bucket': '{mask_string(destination.Bucket)}', 'key': '{mask_string(destination.Key)}'}}" ) except ClientError as err: logger.error( f"[AmazonCloudStorage.put] AWS upload error | " - f"{{'user_id': '{self.user.id}', 'bucket': '{mask_string(destination.Bucket)}', 'key': '{mask_string(destination.Key)}', 'error': '{str(err)}'}}", + f"{{'project_id': '{self.project_id}', 'bucket': '{mask_string(destination.Bucket)}', 'key': '{mask_string(destination.Key)}', 'error': '{str(err)}'}}", exc_info=True, ) raise CloudStorageError(f'AWS Error: "{err}"') from err @@ -157,13 +183,13 @@ def stream(self, url: str) -> StreamingBody: body = self.aws.client.get_object(**kwargs).get("Body") logger.info( f"[AmazonCloudStorage.stream] File streamed successfully | " - f"{{'user_id': '{self.user.id}', 'bucket': '{mask_string(name.Bucket)}', 'key': '{mask_string(name.Key)}'}}" + f"{{'project_id': '{self.project_id}', 'bucket': '{mask_string(name.Bucket)}', 'key': '{mask_string(name.Key)}'}}" ) return body except ClientError as err: logger.error( f"[AmazonCloudStorage.stream] AWS stream error | " - f"{{'user_id': '{self.user.id}', 'bucket': '{mask_string(name.Bucket)}', 'key': 
'{mask_string(name.Key)}', 'error': '{str(err)}'}}", + f"{{'project_id': '{self.project_id}', 'bucket': '{mask_string(name.Bucket)}', 'key': '{mask_string(name.Key)}', 'error': '{str(err)}'}}", exc_info=True, ) raise CloudStorageError(f'AWS Error: "{err}" ({url})') from err @@ -177,13 +203,40 @@ def get_file_size_kb(self, url: str) -> float: size_kb = round(size_bytes / 1024, 2) logger.info( f"[AmazonCloudStorage.get_file_size_kb] File size retrieved successfully | " - f"{{'user_id': '{self.user.id}', 'bucket': '{mask_string(name.Bucket)}', 'key': '{mask_string(name.Key)}', 'size_kb': {size_kb}}}" + f"{{'project_id': '{self.project_id}', 'bucket': '{mask_string(name.Bucket)}', 'key': '{mask_string(name.Key)}', 'size_kb': {size_kb}}}" ) return size_kb except ClientError as err: logger.error( f"[AmazonCloudStorage.get_file_size_kb] AWS head object error | " - f"{{'user_id': '{self.user.id}', 'bucket': '{mask_string(name.Bucket)}', 'key': '{mask_string(name.Key)}', 'error': '{str(err)}'}}", + f"{{'project_id': '{self.project_id}', 'bucket': '{mask_string(name.Bucket)}', 'key': '{mask_string(name.Key)}', 'error': '{str(err)}'}}", + exc_info=True, + ) + raise CloudStorageError(f'AWS Error: "{err}" ({url})') from err + + def get_signed_url(self, url: str, expires_in: int = 3600) -> str: + """ + Generate a signed S3 URL for the given file. 
+ :param url: S3 url (e.g., s3://bucket/key) + :param expires_in: Expiry time in seconds (default: 1 hour) + :return: Signed URL as string + """ + name = SimpleStorageName.from_url(url) + try: + signed_url = self.aws.client.generate_presigned_url( + "get_object", + Params={"Bucket": name.Bucket, "Key": name.Key}, + ExpiresIn=expires_in, + ) + logger.info( + f"[AmazonCloudStorage.get_signed_url] Signed URL generated | " + f"{{'project_id': '{self.project_id}', 'bucket': '{mask_string(name.Bucket)}', 'key': '{mask_string(name.Key)}'}}" + ) + return signed_url + except ClientError as err: + logger.error( + f"[AmazonCloudStorage.get_signed_url] AWS presign error | " + f"{{'project_id': '{self.project_id}', 'bucket': '{mask_string(name.Bucket)}', 'key': '{mask_string(name.Key)}', 'error': '{str(err)}'}}", exc_info=True, ) raise CloudStorageError(f'AWS Error: "{err}" ({url})') from err @@ -195,12 +248,25 @@ def delete(self, url: str) -> None: self.aws.client.delete_object(**kwargs) logger.info( f"[AmazonCloudStorage.delete] File deleted successfully | " - f"{{'user_id': '{self.user.id}', 'bucket': '{mask_string(name.Bucket)}', 'key': '{mask_string(name.Key)}'}}" + f"{{'project_id': '{self.project_id}', 'bucket': '{mask_string(name.Bucket)}', 'key': '{mask_string(name.Key)}'}}" ) except ClientError as err: logger.error( f"[AmazonCloudStorage.delete] AWS delete error | " - f"{{'user_id': '{self.user.id}', 'bucket': '{mask_string(name.Bucket)}', 'key': '{mask_string(name.Key)}', 'error': '{str(err)}'}}", + f"{{'project_id': '{self.project_id}', 'bucket': '{mask_string(name.Bucket)}', 'key': '{mask_string(name.Key)}', 'error': '{str(err)}'}}", exc_info=True, ) raise CloudStorageError(f'AWS Error: "{err}" ({url})') from err + + +def get_cloud_storage(session: Session, project_id: int) -> CloudStorage: + """ + Method to create and configure a cloud storage instance. 
+ """ + project = get_project_by_id(session=session, project_id=project_id) + if not project: + raise ValueError(f"Invalid project_id: {project_id}") + + storage_path = project.storage_path + + return AmazonCloudStorage(project_id=project_id, storage_path=storage_path) diff --git a/backend/app/crud/document.py b/backend/app/crud/document.py index 4d504c59..8703d56a 100644 --- a/backend/app/crud/document.py +++ b/backend/app/crud/document.py @@ -12,22 +12,23 @@ class DocumentCrud: - def __init__(self, session: Session, owner_id: int): + def __init__(self, session: Session, project_id: int): self.session = session - self.owner_id = owner_id + self.project_id = project_id - def read_one(self, doc_id: UUID): + def read_one(self, doc_id: UUID) -> Document: statement = select(Document).where( and_( - Document.owner_id == self.owner_id, Document.id == doc_id, + Document.project_id == self.project_id, + Document.is_deleted.is_(False), ) ) result = self.session.exec(statement).one_or_none() if result is None: logger.warning( - f"[DocumentCrud.read_one] Document not found | {{'doc_id': '{doc_id}', 'owner_id': {self.owner_id}}}" + f"[DocumentCrud.read_one] Document not found | {{'doc_id': '{doc_id}', 'project_id': {self.project_id}}}" ) raise HTTPException(status_code=404, detail="Document not found") @@ -35,14 +36,11 @@ def read_one(self, doc_id: UUID): def read_many( self, - skip: Optional[int] = None, - limit: Optional[int] = None, - ): + skip: int | None = None, + limit: int | None = None, + ) -> list[Document]: statement = select(Document).where( - and_( - Document.owner_id == self.owner_id, - Document.deleted_at.is_(None), - ) + and_(Document.project_id == self.project_id, Document.is_deleted.is_(False)) ) if skip is not None: @@ -51,7 +49,7 @@ def read_many( raise ValueError(f"Negative skip: {skip}") except ValueError as err: logger.error( - f"[DocumentCrud.read_many] Invalid skip value | {{'owner_id': {self.owner_id}, 'skip': {skip}, 'error': '{str(err)}'}}", + 
f"[DocumentCrud.read_many] Invalid skip value | {{'project_id': {self.project_id}, 'skip': {skip}, 'error': '{str(err)}'}}", exc_info=True, ) raise @@ -63,7 +61,7 @@ def read_many( raise ValueError(f"Negative limit: {limit}") except ValueError as err: logger.error( - f"[DocumentCrud.read_many] Invalid limit value | {{'owner_id': {self.owner_id}, 'limit': {limit}, 'error': '{str(err)}'}}", + f"[DocumentCrud.read_many] Invalid limit value | {{'project_id': {self.project_id}, 'limit': {limit}, 'error': '{str(err)}'}}", exc_info=True, ) raise @@ -72,11 +70,12 @@ def read_many( documents = self.session.exec(statement).all() return documents - def read_each(self, doc_ids: List[UUID]): + def read_each(self, doc_ids: list[UUID]): statement = select(Document).where( and_( - Document.owner_id == self.owner_id, + Document.project_id == self.project_id, Document.id.in_(doc_ids), + Document.is_deleted.is_(False), ) ) results = self.session.exec(statement).all() @@ -89,7 +88,7 @@ def read_each(self, doc_ids: List[UUID]): ) except ValueError as err: logger.error( - f"[DocumentCrud.read_each] Mismatch in retrieved documents | {{'owner_id': {self.owner_id}, 'requested_count': {requested_count}, 'retrieved_count': {retrieved_count}}}", + f"[DocumentCrud.read_each] Mismatch in retrieved documents | {{'project_id': {self.project_id}, 'requested_count': {requested_count}, 'retrieved_count': {retrieved_count}}}", exc_info=True, ) raise @@ -97,12 +96,12 @@ def read_each(self, doc_ids: List[UUID]): return results def update(self, document: Document): - if not document.owner_id: - document.owner_id = self.owner_id - elif document.owner_id != self.owner_id: - error = "Invalid document ownership: owner={} attempter={}".format( - self.owner_id, - document.owner_id, + if not document.project_id: + document.project_id = self.project_id + elif document.project_id != self.project_id: + error = "Invalid document ownership: project={} attempter={}".format( + self.project_id, + document.project_id, 
) try: raise PermissionError(error) @@ -118,18 +117,19 @@ def update(self, document: Document): self.session.commit() self.session.refresh(document) logger.info( - f"[DocumentCrud.update] Document updated successfully | {{'doc_id': '{document.id}', 'owner_id': {self.owner_id}}}" + f"[DocumentCrud.update] Document updated successfully | {{'doc_id': '{document.id}', 'project_id': {self.project_id}}}" ) return document def delete(self, doc_id: UUID): document = self.read_one(doc_id) + document.is_deleted = True document.deleted_at = now() document.updated_at = now() updated_document = self.update(document) logger.info( - f"[DocumentCrud.delete] Document deleted successfully | {{'doc_id': '{doc_id}', 'owner_id': {self.owner_id}}}" + f"[DocumentCrud.delete] Document deleted successfully | {{'doc_id': '{doc_id}', 'project_id': {self.project_id}}}" ) return updated_document diff --git a/backend/app/crud/project.py b/backend/app/crud/project.py index a876ae03..570a3b26 100644 --- a/backend/app/crud/project.py +++ b/backend/app/crud/project.py @@ -35,8 +35,7 @@ def create_project(*, session: Session, project_create: ProjectCreate) -> Projec def get_project_by_id(*, session: Session, project_id: int) -> Optional[Project]: - statement = select(Project).where(Project.id == project_id) - return session.exec(statement).first() + return session.get(Project, project_id) def get_project_by_name( diff --git a/backend/app/models/__init__.py b/backend/app/models/__init__.py index 2f5c8bc6..e4a85a4e 100644 --- a/backend/app/models/__init__.py +++ b/backend/app/models/__init__.py @@ -2,7 +2,7 @@ from .auth import Token, TokenPayload from .collection import Collection -from .document import Document +from .document import Document, DocumentPublic from .document_collection import DocumentCollection from .message import Message diff --git a/backend/app/models/document.py b/backend/app/models/document.py index 71da74c1..d900e8ad 100644 --- a/backend/app/models/document.py +++ 
b/backend/app/models/document.py @@ -1,30 +1,47 @@ from uuid import UUID, uuid4 from datetime import datetime -from sqlmodel import Field, Relationship, SQLModel +from sqlmodel import Field, SQLModel from app.core.util import now -from .user import User -class Document(SQLModel, table=True): +class DocumentBase(SQLModel): + project_id: int = Field( + description="The ID of the project to which the document belongs", + foreign_key="project.id", + nullable=False, + ondelete="CASCADE", + ) + fname: str = Field(description="The original filename of the document") + + +class Document(DocumentBase, table=True): id: UUID = Field( default_factory=uuid4, primary_key=True, + description="The unique identifier of the document", ) - owner_id: int = Field( - foreign_key="user.id", - nullable=False, - ondelete="CASCADE", - ) - fname: str object_store_url: str inserted_at: datetime = Field( - default_factory=now, + default_factory=now, description="The timestamp when the document was inserted" ) updated_at: datetime = Field( default_factory=now, + description="The timestamp when the document was last updated", ) + is_deleted: bool = Field(default=False) deleted_at: datetime | None - owner: User = Relationship(back_populates="documents") + +class DocumentPublic(DocumentBase): + id: UUID = Field(description="The unique identifier of the document") + signed_url: str | None = Field( + default=None, description="A signed URL for accessing the document" + ) + inserted_at: datetime = Field( + description="The timestamp when the document was inserted" + ) + updated_at: datetime = Field( + description="The timestamp when the document was last updated" + ) diff --git a/backend/app/models/project.py b/backend/app/models/project.py index 43c0012e..57c3068d 100644 --- a/backend/app/models/project.py +++ b/backend/app/models/project.py @@ -1,3 +1,4 @@ +from uuid import UUID, uuid4 from datetime import datetime from typing import Optional, List from sqlmodel import Field, Relationship, 
SQLModel, UniqueConstraint @@ -34,6 +35,7 @@ class Project(ProjectBase, table=True): organization_id: int = Field( foreign_key="organization.id", index=True, nullable=False, ondelete="CASCADE" ) + storage_path: UUID = Field(default_factory=uuid4, nullable=False, unique=True) inserted_at: datetime = Field(default_factory=now, nullable=False) updated_at: datetime = Field(default_factory=now, nullable=False) diff --git a/backend/app/models/user.py b/backend/app/models/user.py index fa4e3967..fa526ab5 100644 --- a/backend/app/models/user.py +++ b/backend/app/models/user.py @@ -48,9 +48,6 @@ class UpdatePassword(SQLModel): class User(UserBase, table=True): id: int = Field(default=None, primary_key=True) hashed_password: str - documents: list["Document"] = Relationship( - back_populates="owner", cascade_delete=True - ) collections: list["Collection"] = Relationship( back_populates="owner", cascade_delete=True ) diff --git a/backend/app/tests/api/routes/collections/test_create_collections.py b/backend/app/tests/api/routes/collections/test_create_collections.py index 2bede71a..22764df4 100644 --- a/backend/app/tests/api/routes/collections/test_create_collections.py +++ b/backend/app/tests/api/routes/collections/test_create_collections.py @@ -6,6 +6,7 @@ from fastapi.testclient import TestClient from unittest.mock import patch +from app.models import APIKeyPublic from app.core.config import settings from app.tests.utils.document import DocumentStore from app.tests.utils.utils import get_user_from_api_key @@ -35,7 +36,7 @@ class FakeS3Client: def head_object(self, Bucket, Key): return {"ContentLength": 1024} - monkeypatch.setattr("app.api.routes.collections.AmazonCloudStorage", FakeStorage) + monkeypatch.setattr("app.api.routes.collections.get_cloud_storage", FakeStorage) monkeypatch.setattr("boto3.client", lambda service: FakeS3Client()) @@ -48,9 +49,9 @@ def test_create_collection_success( mock_get_openai_client, client: TestClient, db: Session, - user_api_key_header, + 
user_api_key: APIKeyPublic, ): - store = DocumentStore(db) + store = DocumentStore(db, project_id=user_api_key.project_id) documents = store.fill(self._n_documents) doc_ids = [str(doc.id) for doc in documents] @@ -62,7 +63,7 @@ def test_create_collection_success( "temperature": 0.1, } - headers = user_api_key_header + headers = {"X-API-KEY": user_api_key.key} mock_openai_client = get_mock_openai_client_with_vector_store() mock_get_openai_client.return_value = mock_openai_client diff --git a/backend/app/tests/api/routes/documents/test_route_document_info.py b/backend/app/tests/api/routes/documents/test_route_document_info.py index 41b3ebc5..5ed92143 100644 --- a/backend/app/tests/api/routes/documents/test_route_document_info.py +++ b/backend/app/tests/api/routes/documents/test_route_document_info.py @@ -24,7 +24,7 @@ def test_response_is_success( route: Route, crawler: WebCrawler, ): - store = DocumentStore(db) + store = DocumentStore(db=db, project_id=crawler.user_api_key.project_id) response = crawler.get(route.append(store.put())) assert response.is_success @@ -35,7 +35,7 @@ def test_info_reflects_database( route: Route, crawler: WebCrawler, ): - store = DocumentStore(db) + store = DocumentStore(db=db, project_id=crawler.user_api_key.project_id) document = store.put() source = DocumentComparator(document) @@ -44,10 +44,10 @@ def test_info_reflects_database( assert source == target.data def test_cannot_info_unknown_document( - self, db: Session, route: Route, crawler: Route + self, db: Session, route: Route, crawler: WebCrawler ): DocumentStore.clear(db) - maker = DocumentMaker(db) + maker = DocumentMaker(project_id=crawler.user_api_key.project_id, session=db) response = crawler.get(route.append(next(maker))) assert response.is_error diff --git a/backend/app/tests/api/routes/documents/test_route_document_list.py b/backend/app/tests/api/routes/documents/test_route_document_list.py index b488b9ae..8b2ec7a7 100644 --- 
a/backend/app/tests/api/routes/documents/test_route_document_list.py +++ b/backend/app/tests/api/routes/documents/test_route_document_list.py @@ -48,7 +48,7 @@ def test_item_reflects_database( route: QueryRoute, crawler: WebCrawler, ): - store = DocumentStore(db) + store = DocumentStore(db=db, project_id=crawler.user_api_key.project_id) source = DocumentComparator(store.put()) response = httpx_to_standard(crawler.get(route)) @@ -78,7 +78,7 @@ def test_skip_greater_than_limit_is_difference( route: QueryRoute, crawler: WebCrawler, ): - store = DocumentStore(db) + store = DocumentStore(db=db, project_id=crawler.user_api_key.project_id) limit = len(store.fill(self._ndocs)) skip = limit // 2 diff --git a/backend/app/tests/api/routes/documents/test_route_document_permanent_remove.py b/backend/app/tests/api/routes/documents/test_route_document_permanent_remove.py index 10d11f3e..179d247d 100644 --- a/backend/app/tests/api/routes/documents/test_route_document_permanent_remove.py +++ b/backend/app/tests/api/routes/documents/test_route_document_permanent_remove.py @@ -59,8 +59,7 @@ def test_permanent_delete_document_from_s3( aws = AmazonCloudStorageClient() aws.create() - # Setup document in DB and S3 - store = DocumentStore(db) + store = DocumentStore(db=db, project_id=crawler.user_api_key.project_id) document = store.put() s3_key = Path(urlparse(document.object_store_url).path).relative_to("/") aws.client.put_object( @@ -94,7 +93,7 @@ def test_cannot_delete_nonexistent_document( ): DocumentStore.clear(db) - maker = DocumentMaker(db) + maker = DocumentMaker(project_id=crawler.user_api_key.project_id, session=db) response = crawler.delete(route.append(next(maker), suffix="permanent")) assert response.is_error diff --git a/backend/app/tests/api/routes/documents/test_route_document_remove.py b/backend/app/tests/api/routes/documents/test_route_document_remove.py index 292b2b10..7c1e3476 100644 --- a/backend/app/tests/api/routes/documents/test_route_document_remove.py +++ 
b/backend/app/tests/api/routes/documents/test_route_document_remove.py @@ -35,8 +35,8 @@ def test_response_is_success( client = OpenAI(api_key="sk-test-key") mock_get_openai_client.return_value = client - store = DocumentStore(db) - response = crawler.get(route.append(store.put())) + store = DocumentStore(db=db, project_id=crawler.user_api_key.project_id) + response = crawler.delete(route.append(store.put())) assert response.is_success @@ -54,15 +54,15 @@ def test_item_is_soft_removed( client = OpenAI(api_key="sk-test-key") mock_get_openai_client.return_value = client - store = DocumentStore(db) + store = DocumentStore(db=db, project_id=crawler.user_api_key.project_id) document = store.put() - crawler.get(route.append(document)) + crawler.delete(route.append(document)) db.refresh(document) statement = select(Document).where(Document.id == document.id) result = db.exec(statement).one() - assert result.deleted_at is not None + assert result.is_deleted is True @openai_responses.mock() @patch("app.api.routes.documents.get_openai_client") @@ -79,7 +79,10 @@ def test_cannot_remove_unknown_document( mock_get_openai_client.return_value = client DocumentStore.clear(db) - maker = DocumentMaker(db) - response = crawler.get(route.append(next(maker))) + + maker = DocumentMaker( + project_id=crawler.user_api_key.project_id, session=db + ) + response = crawler.delete(route.append(next(maker))) assert response.is_error diff --git a/backend/app/tests/api/routes/documents/test_route_document_upload.py b/backend/app/tests/api/routes/documents/test_route_document_upload.py index 4c6abaaa..25232861 100644 --- a/backend/app/tests/api/routes/documents/test_route_document_upload.py +++ b/backend/app/tests/api/routes/documents/test_route_document_upload.py @@ -9,6 +9,7 @@ from sqlmodel import Session, select from fastapi.testclient import TestClient +from app.models import APIKeyPublic from app.core.cloud import AmazonCloudStorageClient from app.core.config import settings from app.models 
import Document @@ -25,7 +26,7 @@ def put(self, route: Route, scratch: Path): with scratch.open("rb") as fp: return self.client.post( str(route), - headers=self.user_api_key_header, + headers={"X-API-KEY": self.user_api_key.key}, files={ "src": (str(scratch), fp, mtype), }, @@ -45,8 +46,8 @@ def route(): @pytest.fixture -def uploader(client: TestClient, user_api_key_header: dict[str, str]): - return WebUploader(client, user_api_key_header) +def uploader(client: TestClient, user_api_key: APIKeyPublic): + return WebUploader(client, user_api_key) @pytest.fixture(scope="class") @@ -80,6 +81,7 @@ def test_adds_to_database( def test_adds_to_S3( self, + db: Session, route: Route, scratch: Path, uploader: WebUploader, @@ -88,7 +90,13 @@ def test_adds_to_S3( aws.create() response = httpx_to_standard(uploader.put(route, scratch)) - url = urlparse(response.data["object_store_url"]) + doc_id = response.data["id"] + + # Get the document from database to access object_store_url + statement = select(Document).where(Document.id == doc_id) + result = db.exec(statement).one() + + url = urlparse(result.object_store_url) key = Path(url.path) key = key.relative_to(key.root) diff --git a/backend/app/tests/crud/collections/test_crud_collection_create.py b/backend/app/tests/crud/collections/test_crud_collection_create.py index d6feb064..53293d28 100644 --- a/backend/app/tests/crud/collections/test_crud_collection_create.py +++ b/backend/app/tests/crud/collections/test_crud_collection_create.py @@ -12,10 +12,10 @@ class TestCollectionCreate: @openai_responses.mock() def test_create_associates_documents(self, db: Session): - store = DocumentStore(db) - documents = store.fill(self._n_documents) - collection = get_collection(db) + store = DocumentStore(db, project_id=collection.project_id) + + documents = store.fill(self._n_documents) crud = CollectionCrud(db, collection.owner_id) collection = crud.create(collection, documents) diff --git 
a/backend/app/tests/crud/collections/test_crud_collection_delete.py b/backend/app/tests/crud/collections/test_crud_collection_delete.py index 8029c5e8..0a01588b 100644 --- a/backend/app/tests/crud/collections/test_crud_collection_delete.py +++ b/backend/app/tests/crud/collections/test_crud_collection_delete.py @@ -1,11 +1,13 @@ import pytest import openai_responses from openai import OpenAI -from sqlmodel import Session +from sqlmodel import Session, select from app.core.config import settings from app.crud import CollectionCrud +from app.models import APIKey from app.crud.rag import OpenAIAssistantCrud +from app.tests.utils.utils import get_project from app.tests.utils.document import DocumentStore from app.tests.utils.collection import get_collection, uuid_increment @@ -51,14 +53,21 @@ def test_cannot_delete_others_collections(self, db: Session): @openai_responses.mock() def test_delete_document_deletes_collections(self, db: Session): - store = DocumentStore(db) + project = get_project(db) + store = DocumentStore(db, project_id=project.id) documents = store.fill(1) + stmt = select(APIKey).where( + APIKey.project_id == project.id, APIKey.is_deleted == False + ) + api_key = db.exec(stmt).first() + owner_id = api_key.user_id + client = OpenAI(api_key="sk-test-key") resources = [] for _ in range(self._n_collections): - coll = get_collection(db, client) - crud = CollectionCrud(db, coll.owner_id) + coll = get_collection(db, client, owner_id=owner_id) + crud = CollectionCrud(db, owner_id=owner_id) collection = crud.create(coll, documents) resources.append((crud, collection)) diff --git a/backend/app/tests/crud/collections/test_crud_collection_read_all.py b/backend/app/tests/crud/collections/test_crud_collection_read_all.py index b6f63b88..f8cc82fb 100644 --- a/backend/app/tests/crud/collections/test_crud_collection_read_all.py +++ b/backend/app/tests/crud/collections/test_crud_collection_read_all.py @@ -11,14 +11,14 @@ def create_collections(db: Session, n: int): crud = 
None - store = DocumentStore(db) - documents = store.fill(1) openai_mock = OpenAIMock() with openai_mock.router: client = OpenAI(api_key="sk-test-key") for _ in range(n): collection = get_collection(db, client) + store = DocumentStore(db, project_id=collection.project_id) + documents = store.fill(1) if crud is None: crud = CollectionCrud(db, collection.owner_id) crud.create(collection, documents) diff --git a/backend/app/tests/crud/collections/test_crud_collection_read_one.py b/backend/app/tests/crud/collections/test_crud_collection_read_one.py index ed0d31ad..388a68ad 100644 --- a/backend/app/tests/crud/collections/test_crud_collection_read_one.py +++ b/backend/app/tests/crud/collections/test_crud_collection_read_one.py @@ -4,20 +4,19 @@ from sqlmodel import Session from sqlalchemy.exc import NoResultFound -from app.core.config import settings from app.crud import CollectionCrud +from app.core.config import settings from app.tests.utils.document import DocumentStore from app.tests.utils.collection import get_collection, uuid_increment def mk_collection(db: Session): - store = DocumentStore(db) - documents = store.fill(1) - openai_mock = OpenAIMock() with openai_mock.router: client = OpenAI(api_key="sk-test-key") collection = get_collection(db, client) + store = DocumentStore(db, project_id=collection.project_id) + documents = store.fill(1) crud = CollectionCrud(db, collection.owner_id) return crud.create(collection, documents) diff --git a/backend/app/tests/crud/documents/test_crud_document_delete.py b/backend/app/tests/crud/documents/test_crud_document_delete.py index 09329410..6ce8797b 100644 --- a/backend/app/tests/crud/documents/test_crud_document_delete.py +++ b/backend/app/tests/crud/documents/test_crud_document_delete.py @@ -6,15 +6,18 @@ from app.models import Document from app.tests.utils.document import DocumentStore +from app.tests.utils.utils import get_project +from app.tests.utils.test_data import create_test_project from app.core.exception_handlers 
import HTTPException @pytest.fixture def document(db: Session): - store = DocumentStore(db) + project = get_project(db) + store = DocumentStore(db, project.id) document = store.put() - crud = DocumentCrud(db, document.owner_id) + crud = DocumentCrud(db, document.project_id) crud.delete(document.id) statement = select(Document).where(Document.id == document.id) @@ -26,17 +29,18 @@ def test_delete_is_soft(self, document: Document): assert document is not None def test_delete_marks_deleted(self, document: Document): - assert document.deleted_at is not None + assert document.is_deleted is True def test_delete_follows_insert(self, document: Document): assert document.inserted_at <= document.deleted_at def test_cannot_delete_others_documents(self, db: Session): - store = DocumentStore(db) + project = get_project(db) + store = DocumentStore(db, project.id) document = store.put() - other_owner_id = store.documents.owner_id + 1 + other_project = create_test_project(db) - crud = DocumentCrud(db, other_owner_id) + crud = DocumentCrud(db, other_project.id) with pytest.raises(HTTPException) as exc_info: crud.delete(document.id) diff --git a/backend/app/tests/crud/documents/test_crud_document_read_many.py b/backend/app/tests/crud/documents/test_crud_document_read_many.py index 1a6ac6a7..4d23abb9 100644 --- a/backend/app/tests/crud/documents/test_crud_document_read_many.py +++ b/backend/app/tests/crud/documents/test_crud_document_read_many.py @@ -4,11 +4,13 @@ from app.crud import DocumentCrud from app.tests.utils.document import DocumentStore +from app.tests.utils.utils import get_project @pytest.fixture def store(db: Session): - ds = DocumentStore(db) + project = get_project(db) + ds = DocumentStore(db, project.id) ds.fill(TestDatabaseReadMany._ndocs) return ds @@ -22,7 +24,7 @@ def test_number_read_is_expected( db: Session, store: DocumentStore, ): - crud = DocumentCrud(db, store.owner) + crud = DocumentCrud(db, store.project.id) docs = crud.read_many() assert len(docs) == 
self._ndocs @@ -31,15 +33,15 @@ def test_deleted_docs_are_excluded( db: Session, store: DocumentStore, ): - crud = DocumentCrud(db, store.owner) - assert all(x.deleted_at is None for x in crud.read_many()) + crud = DocumentCrud(db, store.project.id) + assert all(x.is_deleted is False for x in crud.read_many()) def test_skip_is_respected( self, db: Session, store: DocumentStore, ): - crud = DocumentCrud(db, store.owner) + crud = DocumentCrud(db, store.project.id) skip = self._ndocs // 2 docs = crud.read_many(skip=skip) @@ -50,7 +52,7 @@ def test_zero_skip_includes_all( db: Session, store: DocumentStore, ): - crud = DocumentCrud(db, store.owner) + crud = DocumentCrud(db, store.project.id) docs = crud.read_many(skip=0) assert len(docs) == self._ndocs @@ -59,7 +61,7 @@ def test_big_skip_is_empty( db: Session, store: DocumentStore, ): - crud = DocumentCrud(db, store.owner) + crud = DocumentCrud(db, store.project.id) skip = self._ndocs + 1 assert not crud.read_many(skip=skip) @@ -68,7 +70,7 @@ def test_negative_skip_raises_exception( db: Session, store: DocumentStore, ): - crud = DocumentCrud(db, store.owner) + crud = DocumentCrud(db, store.project.id) with pytest.raises(ValueError): crud.read_many(skip=-1) @@ -77,7 +79,7 @@ def test_limit_is_respected( db: Session, store: DocumentStore, ): - crud = DocumentCrud(db, store.owner) + crud = DocumentCrud(db, store.project.id) limit = self._ndocs // 2 docs = crud.read_many(limit=limit) @@ -88,7 +90,7 @@ def test_zero_limit_includes_nothing( db: Session, store: DocumentStore, ): - crud = DocumentCrud(db, store.owner) + crud = DocumentCrud(db, store.project.id) assert not crud.read_many(limit=0) def test_negative_limit_raises_exception( @@ -96,7 +98,7 @@ def test_negative_limit_raises_exception( db: Session, store: DocumentStore, ): - crud = DocumentCrud(db, store.owner) + crud = DocumentCrud(db, store.project.id) with pytest.raises(ValueError): crud.read_many(limit=-1) @@ -105,7 +107,7 @@ def 
test_skip_greater_than_limit_is_difference( db: Session, store: DocumentStore, ): - crud = DocumentCrud(db, store.owner) + crud = DocumentCrud(db, store.project.id) limit = self._ndocs skip = limit // 2 docs = crud.read_many(skip=skip, limit=limit) diff --git a/backend/app/tests/crud/documents/test_crud_document_read_one.py b/backend/app/tests/crud/documents/test_crud_document_read_one.py index a3de8f49..3a9e896a 100644 --- a/backend/app/tests/crud/documents/test_crud_document_read_one.py +++ b/backend/app/tests/crud/documents/test_crud_document_read_one.py @@ -5,19 +5,22 @@ from app.crud import DocumentCrud from app.tests.utils.document import DocumentStore +from app.tests.utils.utils import get_project +from app.tests.utils.test_data import create_test_project from app.core.exception_handlers import HTTPException @pytest.fixture def store(db: Session): - return DocumentStore(db) + project = get_project(db) + return DocumentStore(db, project.id) class TestDatabaseReadOne: def test_can_select_valid_id(self, db: Session, store: DocumentStore): document = store.put() - crud = DocumentCrud(db, store.owner) + crud = DocumentCrud(db, store.project.id) result = crud.read_one(document.id) assert result.id == document.id @@ -25,7 +28,7 @@ def test_can_select_valid_id(self, db: Session, store: DocumentStore): def test_cannot_select_invalid_id(self, db: Session, store: DocumentStore): document = next(store.documents) - crud = DocumentCrud(db, store.owner) + crud = DocumentCrud(db, store.project.id) with pytest.raises(HTTPException) as exc_info: crud.read_one(document.id) @@ -35,9 +38,9 @@ def test_cannot_select_invalid_id(self, db: Session, store: DocumentStore): def test_cannot_read_others_documents(self, db: Session, store: DocumentStore): document = store.put() - other = DocumentStore(db) + other_project = create_test_project(db) - crud = DocumentCrud(db, other.owner) + crud = DocumentCrud(db, other_project.id) with pytest.raises(HTTPException) as exc_info: 
crud.read_one(document.id) diff --git a/backend/app/tests/crud/documents/test_crud_document_update.py b/backend/app/tests/crud/documents/test_crud_document_update.py index a14805bc..f38f16ad 100644 --- a/backend/app/tests/crud/documents/test_crud_document_update.py +++ b/backend/app/tests/crud/documents/test_crud_document_update.py @@ -4,17 +4,20 @@ from app.crud import DocumentCrud from app.tests.utils.document import DocumentMaker, DocumentStore +from app.tests.utils.utils import get_project +from app.tests.utils.test_data import create_test_project @pytest.fixture def documents(db: Session): - store = DocumentStore(db) + project = get_project(db) + store = DocumentStore(db, project.id) return store.documents class TestDatabaseUpdate: def test_update_adds_one(self, db: Session, documents: DocumentMaker): - crud = DocumentCrud(db, documents.owner_id) + crud = DocumentCrud(db, documents.project_id) before = crud.read_many() crud.update(next(documents)) @@ -27,7 +30,7 @@ def test_sequential_update_is_ordered( db: Session, documents: DocumentMaker, ): - crud = DocumentCrud(db, documents.owner_id) + crud = DocumentCrud(db, documents.project_id) (a, b) = (crud.update(y) for (_, y) in zip(range(2), documents)) assert a.inserted_at <= b.inserted_at @@ -37,22 +40,22 @@ def test_insert_does_not_delete( db: Session, documents: DocumentMaker, ): - crud = DocumentCrud(db, documents.owner_id) + crud = DocumentCrud(db, documents.project_id) document = crud.update(next(documents)) - assert document.deleted_at is None + assert document.is_deleted is False def test_update_sets_default_owner( self, db: Session, documents: DocumentMaker, ): - crud = DocumentCrud(db, documents.owner_id) + crud = DocumentCrud(db, documents.project_id) document = next(documents) - document.owner_id = None + document.project_id = None result = crud.update(document) - assert result.owner_id == document.owner_id + assert result.project_id == documents.project_id def test_update_respects_owner( self, @@ 
-60,8 +63,9 @@ def test_update_respects_owner( documents: DocumentMaker, ): document = next(documents) - document.owner_id = documents.index.peek() + other_project = create_test_project(db) + document.project_id = other_project.id - crud = DocumentCrud(db, documents.owner_id) + crud = DocumentCrud(db, documents.project_id) with pytest.raises(PermissionError): crud.update(document) diff --git a/backend/app/tests/utils/collection.py b/backend/app/tests/utils/collection.py index 411af994..b2d3ae94 100644 --- a/backend/app/tests/utils/collection.py +++ b/backend/app/tests/utils/collection.py @@ -21,8 +21,9 @@ def uuid_increment(value: UUID): return UUID(int=inc) -def get_collection(db: Session, client=None): - owner_id = get_user_id_by_email(db) +def get_collection(db: Session, client=None, owner_id: int = None) -> Collection: + if owner_id is None: + owner_id = get_user_id_by_email(db) # Step 1: Create real organization and project entries project = create_test_project(db) diff --git a/backend/app/tests/utils/document.py b/backend/app/tests/utils/document.py index e6234a1c..6a08733b 100644 --- a/backend/app/tests/utils/document.py +++ b/backend/app/tests/utils/document.py @@ -12,15 +12,11 @@ from fastapi.testclient import TestClient from app.core.config import settings -from app.models import Document +from app.crud.project import get_project_by_id +from app.models import APIKeyPublic, Document, DocumentPublic, Project from app.utils import APIResponse -from .utils import SequentialUuidGenerator, get_user_id_by_email - - -@ft.cache -def _get_user_id_by_email(db: Session): - return get_user_id_by_email(db) +from .utils import SequentialUuidGenerator def httpx_to_standard(response: Response): @@ -28,48 +24,54 @@ def httpx_to_standard(response: Response): class DocumentMaker: - def __init__(self, db: Session): - self.owner_id = _get_user_id_by_email(db) + def __init__(self, project_id: int, session: Session): + self.project_id = project_id + self.session = session + 
self.project: Project = None self.index = SequentialUuidGenerator() def __iter__(self): return self def __next__(self): + if self.project is None: + self.project = get_project_by_id( + session=self.session, project_id=self.project_id + ) + doc_id = next(self.index) - key = f"{self.owner_id}/{doc_id}.txt" + key = f"{self.project.storage_path}/{doc_id}.txt" object_store_url = f"s3://{settings.AWS_S3_BUCKET}/{key}" return Document( id=doc_id, - owner_id=self.owner_id, + project_id=self.project.id, fname=f"{doc_id}.xyz", object_store_url=object_store_url, + is_deleted=False, ) class DocumentStore: + def __init__(self, db: Session, project_id: int): + self.db = db + self.documents = DocumentMaker(project_id=project_id, session=db) + self.clear(self.db) + @staticmethod def clear(db: Session): db.exec(delete(Document)) db.commit() @property - def owner(self): - return self.documents.owner_id - - def __init__(self, db: Session): - self.db = db - self.documents = DocumentMaker(self.db) - self.clear(self.db) + def project(self): + return self.documents.project def put(self): doc = next(self.documents) - self.db.add(doc) self.db.commit() self.db.refresh(doc) - return doc def extend(self, n: int): @@ -113,22 +115,24 @@ def append(self, doc: Document, suffix: str = None): @dataclass class WebCrawler: client: TestClient - user_api_key_header: dict[str, str] + user_api_key: APIKeyPublic def get(self, route: Route): return self.client.get( str(route), - headers=self.user_api_key_header, + headers={"X-API-KEY": self.user_api_key.key}, ) def delete(self, route: Route): return self.client.delete( str(route), - headers=self.user_api_key_header, + headers={"X-API-KEY": self.user_api_key.key}, ) class DocumentComparator: + """Compare a Document model against the DocumentPublic API response.""" + @ft.singledispatchmethod @staticmethod def to_string(value): @@ -148,15 +152,21 @@ def __init__(self, document: Document): self.document = document def __eq__(self, other: dict): - this = 
dict(self.to_dict()) + this = dict(self.to_public_dict()) return this == other - def to_dict(self): - document = dict(self.document) - for k, v in document.items(): - yield (k, self.to_string(v)) + def to_public_dict(self) -> dict: + """Convert Document to dict matching DocumentPublic schema.""" + field_names = DocumentPublic.model_fields.keys() + + result = {} + for field in field_names: + value = getattr(self.document, field, None) + result[field] = self.to_string(value) + + return result @pytest.fixture -def crawler(client: TestClient, user_api_key_header: dict[str, str]): - return WebCrawler(client, user_api_key_header) +def crawler(client: TestClient, user_api_key: APIKeyPublic): + return WebCrawler(client, user_api_key=user_api_key)