diff --git a/.env b/.env
index 1d44286e..9df95374 100644
--- a/.env
+++ b/.env
@@ -43,3 +43,8 @@ SENTRY_DSN=
 # Configure these with your own Docker registry images
 DOCKER_IMAGE_BACKEND=backend
 DOCKER_IMAGE_FRONTEND=frontend
+
+# AWS
+AWS_ACCESS_KEY_ID=
+AWS_SECRET_ACCESS_KEY=
+AWS_DEFAULT_REGION=
diff --git a/.gitignore b/.gitignore
index c0b7dd63..c807732b 100644
--- a/.gitignore
+++ b/.gitignore
@@ -14,4 +14,7 @@ ENV/
 
 
 # .DS_Store: macOS Finder metadata file that stores folder view settings and icon positions.
-**/.DS_Store
\ No newline at end of file
+**/.DS_Store
+
+# Emacs
+*~
diff --git a/backend/app/api/routes/documents.py b/backend/app/api/routes/documents.py
index 63f0f3a8..185a6144 100644
--- a/backend/app/api/routes/documents.py
+++ b/backend/app/api/routes/documents.py
@@ -1,33 +1,96 @@
-from fastapi import APIRouter
-from sqlmodel import select, and_
+import warnings
+from uuid import UUID, uuid4
 
-from app.api.deps import CurrentUser, SessionDep
+from fastapi import APIRouter, File, UploadFile, HTTPException
+
+from sqlalchemy.exc import NoResultFound, MultipleResultsFound, SQLAlchemyError
+
+from app.crud import DocumentCrud
 from app.models import Document, DocumentList
+from app.api.deps import CurrentUser, SessionDep
+from app.core.cloud import AmazonCloudStorage, CloudStorageError
 
 router = APIRouter(prefix="/documents", tags=["documents"])
 
+def raise_from_unknown(error: Exception):
+    warnings.warn('Unexpected exception "{}": {}'.format(
+        type(error).__name__,
+        error,
+    ))
+    raise HTTPException(status_code=500, detail=str(error))
 
-@router.get(
-    "/ls",
-    response_model=DocumentList,
-)
+@router.get("/ls", response_model=DocumentList)
 def list_docs(
     session: SessionDep,
     current_user: CurrentUser,
     skip: int = 0,
     limit: int = 100,
 ):
-    statement = (select(Document)
-                 .where(
-                     and_(
-                         Document.owner_id == current_user.id,
-                         Document.deleted_at.is_(None),
-                     ),
-                 )
-                 .offset(skip)
-                 .limit(limit))
-    docs = (session
-            .exec(statement)
-            .all())
-
-    return DocumentList(docs=docs)
+    crud = DocumentCrud(session, current_user.id)
+    try:
+        return crud.read_many(skip, limit)
+    except (ValueError, SQLAlchemyError) as err:
+        raise HTTPException(status_code=500, detail=str(err))
+    except Exception as err:
+        raise_from_unknown(err)
+
+@router.post("/cp")
+def upload_doc(
+    session: SessionDep,
+    current_user: CurrentUser,
+    src: UploadFile = File(...),
+):
+    storage = AmazonCloudStorage(current_user)
+    basename = uuid4()
+    try:
+        object_store_url = storage.put(src, str(basename))
+    except CloudStorageError as err:
+        raise HTTPException(status_code=503, detail=str(err))
+    except Exception as err:
+        raise_from_unknown(err)
+
+    crud = DocumentCrud(session, current_user.id)
+    document = Document(
+        id=basename,
+        fname=src.filename,
+        object_store_url=str(object_store_url)
+    )
+
+    try:
+        return crud.update(document)
+    except SQLAlchemyError as err:
+        raise HTTPException(status_code=503, detail=str(err))
+    except Exception as err:
+        raise_from_unknown(err)
+
+@router.get("/rm/{doc_id}")
+def delete_doc(
+    session: SessionDep,
+    current_user: CurrentUser,
+    doc_id: UUID,
+):
+    crud = DocumentCrud(session, current_user.id)
+    try:
+        return crud.delete(doc_id)
+    except NoResultFound as err:
+        raise HTTPException(status_code=404, detail=str(err))
+    except Exception as err:
+        raise_from_unknown(err)
+
+    # TODO: perform delete on the collection
+
+@router.get("/stat/{doc_id}", response_model=Document)
+def doc_info(
+    session: SessionDep,
+    current_user: CurrentUser,
+    doc_id: UUID,
+):
+    crud = DocumentCrud(session, current_user.id)
+    try:
+        return crud.read_one(doc_id)
+    except NoResultFound as err:
+        raise HTTPException(status_code=404, detail=str(err))
+    except MultipleResultsFound as err:
+        raise HTTPException(status_code=503, detail=str(err))
+    except Exception as err:
+        raise_from_unknown(err)
diff --git a/backend/app/core/cloud/__init__.py b/backend/app/core/cloud/__init__.py
new file mode 100644
index 00000000..cde54f38
--- /dev/null
+++ b/backend/app/core/cloud/__init__.py
@@ -0,0 +1,5 @@
+from .storage import (
+    AmazonCloudStorage,
+    AmazonCloudStorageClient,
+    CloudStorageError,
+)
diff --git a/backend/app/core/cloud/storage.py b/backend/app/core/cloud/storage.py
new file mode 100644
index 00000000..bb1a2a44
--- /dev/null
+++ b/backend/app/core/cloud/storage.py
@@ -0,0 +1,96 @@
+import os
+import functools as ft
+from pathlib import Path
+from dataclasses import dataclass, asdict
+from urllib.parse import ParseResult, urlunparse
+
+import boto3
+from fastapi import UploadFile
+from botocore.exceptions import ClientError
+
+from app.api.deps import CurrentUser
+from app.core.config import settings
+
+class CloudStorageError(Exception):
+    pass
+
+class AmazonCloudStorageClient:
+    @ft.cached_property
+    def client(self):
+        kwargs = {}
+        cred_params = (
+            ('aws_access_key_id', 'AWS_ACCESS_KEY_ID'),
+            ('aws_secret_access_key', 'AWS_SECRET_ACCESS_KEY'),
+            ('region_name', 'AWS_DEFAULT_REGION'),
+        )
+
+        for (i, j) in cred_params:
+            kwargs[i] = os.environ.get(j, getattr(settings, j))
+
+        return boto3.client('s3', **kwargs)
+
+    def create(self):
+        try:
+            # does the bucket exist...
+            self.client.head_bucket(Bucket=settings.AWS_S3_BUCKET)
+        except ClientError as err:
+            response = int(err.response['Error']['Code'])
+            if response != 404:
+                raise CloudStorageError(err) from err
+            # ... if not create it
+            self.client.create_bucket(
+                Bucket=settings.AWS_S3_BUCKET,
+                CreateBucketConfiguration={
+                    'LocationConstraint': settings.AWS_DEFAULT_REGION,
+                },
+            )
+
+@dataclass(frozen=True)
+class SimpleStorageName:
+    Key: str
+    Bucket: str = settings.AWS_S3_BUCKET
+
+    def __str__(self):
+        return urlunparse(self.to_url())
+
+    def to_url(self):
+        kwargs = {
+            'scheme': 's3',
+            'netloc': self.Bucket,
+            'path': self.Key,
+        }
+        for k in ParseResult._fields:
+            kwargs.setdefault(k)
+
+        return ParseResult(**kwargs)
+
+class CloudStorage:
+    def __init__(self, user: CurrentUser):
+        self.user = user
+
+    def put(self, source: UploadFile, basename: str):
+        raise NotImplementedError()
+
+class AmazonCloudStorage(CloudStorage):
+    def __init__(self, user: CurrentUser):
+        super().__init__(user)
+        self.aws = AmazonCloudStorageClient()
+
+    def put(self, source: UploadFile, basename: str):
+        key = Path(str(self.user.id), basename)
+        destination = SimpleStorageName(str(key))
+
+        kwargs = asdict(destination)
+        try:
+            self.aws.client.upload_fileobj(
+                source.file,
+                ExtraArgs={
+                    # 'Metadata': self.user.model_dump(),
+                    'ContentType': source.content_type,
+                },
+                **kwargs,
+            )
+        except ClientError as err:
+            raise CloudStorageError(f'AWS Error: "{err}"') from err
+
+        return destination
diff --git a/backend/app/core/config.py b/backend/app/core/config.py
index d58e03c8..b98e104e 100644
--- a/backend/app/core/config.py
+++ b/backend/app/core/config.py
@@ -95,6 +95,15 @@ def emails_enabled(self) -> bool:
     FIRST_SUPERUSER: EmailStr
     FIRST_SUPERUSER_PASSWORD: str
 
+    AWS_ACCESS_KEY_ID: str
+    AWS_SECRET_ACCESS_KEY: str
+    AWS_DEFAULT_REGION: str
+
+    @computed_field  # type: ignore[prop-decorator]
+    @property
+    def AWS_S3_BUCKET(self) -> str:
+        return f'ai-platform-documents-{self.ENVIRONMENT}'
+
     def _check_default_secret(self, var_name: str, value: str | None) -> None:
         if value == "changethis":
             message = (
diff --git a/backend/app/core/util.py b/backend/app/core/util.py
new file mode 100644
index 00000000..13fbc2aa
--- /dev/null
+++ b/backend/app/core/util.py
@@ -0,0 +1,4 @@
+from datetime import datetime, timezone
+
+def now():
+    return datetime.now(timezone.utc).replace(tzinfo=None)
diff --git a/backend/app/crud/__init__.py b/backend/app/crud/__init__.py
index c0fbc4e7..85edaa12 100644
--- a/backend/app/crud/__init__.py
+++ b/backend/app/crud/__init__.py
@@ -4,4 +4,6 @@
     create_user,
     get_user_by_email,
     update_user,
-)
\ No newline at end of file
+)
+
+from .document import DocumentCrud
diff --git a/backend/app/crud/document.py b/backend/app/crud/document.py
new file mode 100644
index 00000000..d4ead6cb
--- /dev/null
+++ b/backend/app/crud/document.py
@@ -0,0 +1,69 @@
+from uuid import UUID
+from typing import Optional
+
+from sqlmodel import Session, select, and_
+
+from app.models import Document, DocumentList
+from app.core.util import now
+
+class DocumentCrud:
+    def __init__(self, session: Session, owner_id: UUID):
+        self.session = session
+        self.owner_id = owner_id
+
+    def read_one(self, doc_id: UUID):
+        statement = (
+            select(Document)
+            .where(and_(
+                Document.owner_id == self.owner_id,
+                Document.id == doc_id,
+            ))
+        )
+
+        return self.session.exec(statement).one()
+
+    def read_many(
+        self,
+        skip: Optional[int] = None,
+        limit: Optional[int] = None,
+    ):
+        statement = (
+            select(Document)
+            .where(and_(
+                Document.owner_id == self.owner_id,
+                Document.deleted_at.is_(None),
+            ))
+        )
+        if skip is not None:
+            if skip < 0:
+                raise ValueError(f'Negative skip: {skip}')
+            statement = statement.offset(skip)
+        if limit is not None:
+            if limit < 0:
+                raise ValueError(f'Negative limit: {limit}')
+            statement = statement.limit(limit)
+        docs = self.session.exec(statement).all()
+
+        return DocumentList(docs=docs)
+
+    def update(self, document: Document):
+        if not document.owner_id:
+            document.owner_id = self.owner_id
+        elif document.owner_id != self.owner_id:
+            error = 'Invalid document ownership: owner={} attempter={}'.format(
+                self.owner_id,
+                document.owner_id,
+            )
+            raise PermissionError(error)
+
+        self.session.add(document)
+        self.session.commit()
+        self.session.refresh(document)
+
+        return document
+
+    def delete(self, doc_id: UUID):
+        document = self.read_one(doc_id)
+        document.deleted_at = now()
+
+        return self.update(document)
diff --git a/backend/app/initial_storage.py b/backend/app/initial_storage.py
new file mode 100644
index 00000000..4b48a2a5
--- /dev/null
+++ b/backend/app/initial_storage.py
@@ -0,0 +1,25 @@
+import logging
+
+from botocore.exceptions import ClientError
+
+from app.core.cloud import AmazonCloudStorageClient, CloudStorageError
+from app.core.config import settings
+
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+def init() -> None:
+    aws = AmazonCloudStorageClient()
+    try:
+        aws.create()
+    except CloudStorageError as err:
+        logger.error(err)
+
+def main() -> None:
+    logger.info("START: setup cloud storage")
+    init()
+    logger.info("END: setup cloud storage")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/backend/app/models/document.py b/backend/app/models/document.py
index 4f039d31..363eb6fb 100644
--- a/backend/app/models/document.py
+++ b/backend/app/models/document.py
@@ -1,14 +1,11 @@
 from uuid import UUID, uuid4
-from pathlib import Path
-from datetime import datetime, timezone
+from datetime import datetime
 
 from sqlmodel import Field, Relationship, SQLModel
 
+from app.core.util import now
 from .user import User
 
-def now():
-    return datetime.now(timezone.utc).replace(tzinfo=None)
-
 class Document(SQLModel, table=True):
     id: UUID = Field(
         default_factory=uuid4,
@@ -19,13 +16,12 @@ class Document(SQLModel, table=True):
         nullable=False,
         ondelete='CASCADE',
     )
-    fname_external: Path
-    fname_internal: Path
+    fname: str
     object_store_url: str
     created_at: datetime = Field(
         default_factory=now,
     )
-    updated_at: datetime | None
+    # updated_at: datetime | None
     deleted_at: datetime | None
 
     owner: User = Relationship(back_populates='documents')
@@ -33,5 +29,11 @@ class DocumentList(SQLModel):
 class DocumentList(SQLModel):
     docs: list[Document]
 
+    def __bool__(self):
+        return bool(self.docs)
+
     def __len__(self):
         return len(self.docs)
+
+    def __iter__(self):
+        yield from self.docs
diff --git a/backend/app/tests/api/routes/documents/test_delete.py b/backend/app/tests/api/routes/documents/test_delete.py
new file mode 100644
index 00000000..9ec82356
--- /dev/null
+++ b/backend/app/tests/api/routes/documents/test_delete.py
@@ -0,0 +1,59 @@
+import pytest
+from sqlmodel import Session, select
+
+from app.models import Document
+from app.tests.utils.document import (
+    DocumentMaker,
+    DocumentStore,
+    Route,
+    WebCrawler,
+    crawler,
+)
+
+@pytest.fixture
+def route():
+    return Route('rm')
+
+class TestDocumentRouteDelete:
+    def test_response_is_success(
+        self,
+        db: Session,
+        route: Route,
+        crawler: WebCrawler,
+    ):
+        store = DocumentStore(db)
+        response = crawler.get(route.append(store.put()))
+
+        assert response.is_success
+
+    def test_item_is_soft_deleted(
+        self,
+        db: Session,
+        route: Route,
+        crawler: WebCrawler,
+    ):
+        store = DocumentStore(db)
+        document = store.put()
+
+        crawler.get(route.append(document))
+        db.refresh(document)
+        statement = (
+            select(Document)
+            .where(Document.id == document.id)
+        )
+        result = db.exec(statement).one()
+
+        assert result.deleted_at is not None
+
+    def test_cannot_delete_unknown_document(
+        self,
+        db: Session,
+        route: Route,
+        crawler: WebCrawler,
+    ):
+        DocumentStore.clear(db)
+
+        maker = DocumentMaker(db)
+        response = crawler.get(route.append(next(maker)))
+
+        assert response.is_error
diff --git a/backend/app/tests/api/routes/documents/test_ls.py b/backend/app/tests/api/routes/documents/test_ls.py
new file mode 100644
index 00000000..3c5be821
--- /dev/null
+++ b/backend/app/tests/api/routes/documents/test_ls.py
@@ -0,0 +1,72 @@
+import pytest
+from sqlmodel import Session
+
+from app.tests.utils.document import (
+    DocumentComparator,
+    DocumentStore,
+    Route,
+    WebCrawler,
+    crawler,
+)
+
+class QueryRoute(Route):
+    def pushq(self, key, value):
+        qs_args = self.qs_args | {
+            key: value,
+        }
+        return type(self)(self.endpoint, **qs_args)
+
+@pytest.fixture
+def route():
+    return QueryRoute('ls')
+
+class TestDocumentRouteList:
+    def test_response_is_success(self, route: QueryRoute, crawler: WebCrawler):
+        response = crawler.get(route)
+        assert response.is_success
+
+    def test_empty_db_returns_empty_list(
+        self,
+        db: Session,
+        route: QueryRoute,
+        crawler: WebCrawler,
+    ):
+        DocumentStore.clear(db)
+        docs = (crawler
+                .get(route)
+                .json()
+                .get('docs'))
+
+        assert not docs
+
+    def test_item_reflects_database(
+        self,
+        db: Session,
+        route: QueryRoute,
+        crawler: WebCrawler,
+    ):
+        store = DocumentStore(db)
+        source = DocumentComparator(store.put())
+        target = (crawler
+                  .get(route)
+                  .json()
+                  .get('docs')
+                  .pop())
+
+        assert source == target
+
+    def test_negative_skip_produces_error(
+        self,
+        route: QueryRoute,
+        crawler: WebCrawler,
+    ):
+        response = crawler.get(route.pushq('skip', -1))
+        assert response.is_error
+
+    def test_negative_limit_produces_error(
+        self,
+        route: QueryRoute,
+        crawler: WebCrawler,
+    ):
+        response = crawler.get(route.pushq('limit', -1))
+        assert response.is_error
diff --git a/backend/app/tests/api/routes/documents/test_stat.py b/backend/app/tests/api/routes/documents/test_stat.py
new file mode 100644
index 00000000..893430bc
--- /dev/null
+++ b/backend/app/tests/api/routes/documents/test_stat.py
@@ -0,0 +1,55 @@
+import pytest
+from sqlmodel import Session
+
+from app.tests.utils.document import (
+    DocumentComparator,
+    DocumentMaker,
+    DocumentStore,
+    Route,
+    WebCrawler,
+    crawler,
+)
+
+@pytest.fixture
+def route():
+    return Route('stat')
+
+class TestDocumentRouteStat:
+    def test_response_is_success(
+        self,
+        db: Session,
+        route: Route,
+        crawler: WebCrawler,
+    ):
+        store = DocumentStore(db)
+        response = crawler.get(route.append(store.put()))
+
+        assert response.is_success
+
+    def test_stat_reflects_database(
+        self,
+        db: Session,
+        route: Route,
+        crawler: WebCrawler,
+    ):
+        store = DocumentStore(db)
+        document = store.put()
+        source = DocumentComparator(document)
+
+        target = (crawler
+                  .get(route.append(document))
+                  .json())
+
+        assert source == target
+
+    def test_cannot_stat_unknown_document(
+        self,
+        db: Session,
+        route: Route,
+        crawler: WebCrawler,
+    ):
+        DocumentStore.clear(db)
+        maker = DocumentMaker(db)
+        response = crawler.get(route.append(next(maker)))
+
+        assert response.is_error
diff --git a/backend/app/tests/api/routes/documents/test_upload.py b/backend/app/tests/api/routes/documents/test_upload.py
new file mode 100644
index 00000000..73e03600
--- /dev/null
+++ b/backend/app/tests/api/routes/documents/test_upload.py
@@ -0,0 +1,93 @@
+import os
+import mimetypes
+from pathlib import Path
+from tempfile import NamedTemporaryFile
+from urllib.parse import urlparse
+
+import pytest
+from moto import mock_aws
+from sqlmodel import Session, select
+from fastapi.testclient import TestClient
+
+from app.core.cloud import AmazonCloudStorageClient
+from app.core.config import settings
+from app.models import Document
+from app.tests.utils.document import (
+    Route,
+    WebCrawler,
+)
+
+class WebUploader(WebCrawler):
+    def put(self, route: Route, scratch: Path):
+        (mtype, _) = mimetypes.guess_type(str(scratch))
+        with scratch.open('rb') as fp:
+            return self.client.post(
+                str(route),
+                headers=self.superuser_token_headers,
+                files={
+                    'src': (str(scratch), fp, mtype),
+                },
+            )
+
+@pytest.fixture
+def scratch():
+    with NamedTemporaryFile(mode='w', suffix='.txt') as fp:
+        print('Hello World', file=fp, flush=True)
+        yield Path(fp.name)
+
+@pytest.fixture
+def route():
+    return Route('cp')
+
+@pytest.fixture
+def uploader(client: TestClient, superuser_token_headers: dict[str, str]):
+    return WebUploader(client, superuser_token_headers)
+
+@pytest.fixture(scope='class')
+def aws_credentials():
+    os.environ['AWS_ACCESS_KEY_ID'] = 'testing'
+    os.environ['AWS_SECRET_ACCESS_KEY'] = 'testing'
+    os.environ['AWS_SECURITY_TOKEN'] = 'testing'
+    os.environ['AWS_SESSION_TOKEN'] = 'testing'
+    os.environ['AWS_DEFAULT_REGION'] = settings.AWS_DEFAULT_REGION
+
+@mock_aws
+@pytest.mark.usefixtures('aws_credentials')
+class TestDocumentRouteUpload:
+    def test_adds_to_database(
+        self,
+        db: Session,
+        route: Route,
+        scratch: Path,
+        uploader: WebUploader,
+    ):
+        aws = AmazonCloudStorageClient()
+        aws.create()
+
+        response = uploader.put(route, scratch)
+        doc_id = (response
+                  .json()
+                  .get('id'))
+        statement = (
+            select(Document)
+            .where(Document.id == doc_id)
+        )
+        result = db.exec(statement).one()
+
+        assert result.fname == str(scratch)
+
+    def test_adds_to_S3(
+        self,
+        route: Route,
+        scratch: Path,
+        uploader: WebUploader,
+    ):
+        aws = AmazonCloudStorageClient()
+        aws.create()
+
+        response = uploader.put(route, scratch)
+        url = urlparse(response.json().get('object_store_url'))
+        key = Path(url.path)
+        key = key.relative_to(key.root)
+
+        assert aws.client.head_object(Bucket=url.netloc, Key=str(key))
diff --git a/backend/app/tests/crud/documents/test_delete.py b/backend/app/tests/crud/documents/test_delete.py
new file mode 100644
index 00000000..4d78c6f8
--- /dev/null
+++ b/backend/app/tests/crud/documents/test_delete.py
@@ -0,0 +1,41 @@
+import pytest
+from sqlmodel import Session, select
+from sqlalchemy.exc import NoResultFound
+
+from app.crud import DocumentCrud
+from app.models import Document
+
+from app.tests.utils.document import DocumentStore
+
+@pytest.fixture
+def document(db: Session):
+    store = DocumentStore(db)
+    document = store.put()
+
+    crud = DocumentCrud(db, document.owner_id)
+    crud.delete(document.id)
+
+    statement = (
+        select(Document)
+        .where(Document.id == document.id)
+    )
+    return db.exec(statement).one()
+
+class TestDatabaseDelete:
+    def test_delete_is_soft(self, document: Document):
+        assert document is not None
+
+    def test_delete_marks_deleted(self, document: Document):
+        assert document.deleted_at is not None
+
+    def test_delete_follows_insert(self, document: Document):
+        assert document.created_at <= document.deleted_at
+
+    def test_cannot_delete_others_documents(self, db: Session):
+        store = DocumentStore(db)
+        document = store.put()
+        other_owner_id = store.documents.index.peek()
+
+        crud = DocumentCrud(db, other_owner_id)
+        with pytest.raises(NoResultFound):
+            crud.delete(document.id)
diff --git a/backend/app/tests/crud/documents/test_read_many.py b/backend/app/tests/crud/documents/test_read_many.py
new file mode 100644
index 00000000..bc1592b3
--- /dev/null
+++ b/backend/app/tests/crud/documents/test_read_many.py
@@ -0,0 +1,93 @@
+import pytest
+from sqlmodel import Session
+
+from app.crud import DocumentCrud
+
+from app.tests.utils.document import DocumentStore, DocumentIndexGenerator
+
+@pytest.fixture
+def store(db: Session):
+    ds = DocumentStore(db)
+    for _ in ds.fill(TestDatabaseReadMany._ndocs):
+        pass
+
+    return ds
+
+class TestDatabaseReadMany:
+    _ndocs = 10
+
+    def test_number_read_is_expected(
+        self,
+        db: Session,
+        store: DocumentStore,
+    ):
+        crud = DocumentCrud(db, store.owner)
+        docs = crud.read_many()
+        assert len(docs) == self._ndocs
+
+    def test_deleted_docs_are_excluded(
+        self,
+        db: Session,
+        store: DocumentStore,
+    ):
+        crud = DocumentCrud(db, store.owner)
+        assert all(x.deleted_at is None for x in crud.read_many())
+
+    def test_skip_is_respected(
+        self,
+        db: Session,
+        store: DocumentStore,
+    ):
+        crud = DocumentCrud(db, store.owner)
+        skip = self._ndocs // 2
+        doc_ids = set(x.id for x in crud.read_many(skip=skip))
+        index = DocumentIndexGenerator(skip)
+
+        for (_, doc) in zip(range(skip, self._ndocs), index):
+            assert doc in doc_ids
+
+    def test_zero_skip_includes_all(
+        self,
+        db: Session,
+        store: DocumentStore,
+    ):
+        crud = DocumentCrud(db, store.owner)
+        docs = crud.read_many(skip=0)
+        assert len(docs) == self._ndocs
+
+    def test_negative_skip_raises_exception(
+        self,
+        db: Session,
+        store: DocumentStore,
+    ):
+        crud = DocumentCrud(db, store.owner)
+        with pytest.raises(ValueError):
+            crud.read_many(skip=-1)
+
+    def test_limit_is_respected(
+        self,
+        db: Session,
+        store: DocumentStore,
+    ):
+        crud = DocumentCrud(db, store.owner)
+        limit = self._ndocs // 2
+        docs = crud.read_many(limit=limit)
+
+        assert len(docs) == limit
+
+    def test_zero_limit_includes_nothing(
+        self,
+        db: Session,
+        store: DocumentStore,
+    ):
+        crud = DocumentCrud(db, store.owner)
+        assert not crud.read_many(limit=0)
+
+    def test_negative_limit_raises_exception(
+        self,
+        db: Session,
+        store: DocumentStore,
+    ):
+        crud = DocumentCrud(db, store.owner)
+        with pytest.raises(ValueError):
+            crud.read_many(limit=-1)
diff --git a/backend/app/tests/crud/documents/test_read_one.py b/backend/app/tests/crud/documents/test_read_one.py
new file mode 100644
index 00000000..0d25055e
--- /dev/null
+++ b/backend/app/tests/crud/documents/test_read_one.py
@@ -0,0 +1,39 @@
+import pytest
+from sqlmodel import Session
+from sqlalchemy.exc import NoResultFound
+
+from app.crud import DocumentCrud
+
+from app.tests.utils.document import DocumentStore
+
+@pytest.fixture
+def store(db: Session):
+    return DocumentStore(db)
+
+class TestDatabaseReadOne:
+    def test_can_select_valid_id(self, db: Session, store: DocumentStore):
+        document = store.put()
+
+        crud = DocumentCrud(db, store.owner)
+        result = crud.read_one(document.id)
+
+        assert result.id == document.id
+
+    def test_cannot_select_invalid_id(self, db: Session, store: DocumentStore):
+        document = next(store.documents)
+
+        crud = DocumentCrud(db, store.owner)
+        with pytest.raises(NoResultFound):
+            crud.read_one(document.id)
+
+    def test_cannot_read_others_documents(
+        self,
+        db: Session,
+        store: DocumentStore,
+    ):
+        document = store.put()
+        other = DocumentStore(db)
+
+        crud = DocumentCrud(db, other.owner)
+        with pytest.raises(NoResultFound):
+            crud.read_one(document.id)
diff --git a/backend/app/tests/crud/documents/test_update.py b/backend/app/tests/crud/documents/test_update.py
new file mode 100644
index 00000000..6a14217c
--- /dev/null
+++ b/backend/app/tests/crud/documents/test_update.py
@@ -0,0 +1,65 @@
+import pytest
+from sqlmodel import Session
+
+from app.crud import DocumentCrud
+
+from app.tests.utils.document import DocumentMaker, DocumentStore
+
+@pytest.fixture
+def documents(db: Session):
+    store = DocumentStore(db)
+    return store.documents
+
+class TestDatabaseUpdate:
+    def test_update_adds_one(self, db: Session, documents: DocumentMaker):
+        crud = DocumentCrud(db, documents.owner_id)
+
+        before = crud.read_many()
+        crud.update(next(documents))
+        after = crud.read_many()
+
+        assert len(before) + 1 == len(after)
+
+    def test_sequential_update_is_ordered(
+        self,
+        db: Session,
+        documents: DocumentMaker,
+    ):
+        crud = DocumentCrud(db, documents.owner_id)
+        (a, b) = (crud.update(y) for (_, y) in zip(range(2), documents))
+
+        assert a.created_at <= b.created_at
+
+    def test_insert_does_not_delete(
+        self,
+        db: Session,
+        documents: DocumentMaker,
+    ):
+        crud = DocumentCrud(db, documents.owner_id)
+        document = crud.update(next(documents))
+
+        assert document.deleted_at is None
+
+    def test_update_sets_default_owner(
+        self,
+        db: Session,
+        documents: DocumentMaker,
+    ):
+        crud = DocumentCrud(db, documents.owner_id)
+        document = next(documents)
+        document.owner_id = None
+        result = crud.update(document)
+
+        assert result.owner_id == documents.owner_id
+
+    def test_update_respects_owner(
+        self,
+        db: Session,
+        documents: DocumentMaker,
+    ):
+        document = next(documents)
+        document.owner_id = documents.index.peek()
+
+        crud = DocumentCrud(db, documents.owner_id)
+        with pytest.raises(PermissionError):
+            crud.update(document)
diff --git a/backend/app/tests/utils/document.py b/backend/app/tests/utils/document.py
new file mode 100644
index 00000000..3ee4111c
--- /dev/null
+++ b/backend/app/tests/utils/document.py
@@ -0,0 +1,155 @@
+import itertools as it
+import functools as ft
+from uuid import UUID
+from pathlib import Path
+from datetime import datetime
+from dataclasses import dataclass
+from urllib.parse import ParseResult, urlunparse
+
+import pytest
+from sqlmodel import Session, delete
+from fastapi.testclient import TestClient
+
+from app.core.config import settings
+from app.crud.user import get_user_by_email
+from app.models import Document
+
+@ft.cache
+def get_user_id_by_email(db: Session):
+    user = get_user_by_email(session=db, email=settings.FIRST_SUPERUSER)
+    return user.id
+
+class DocumentIndexGenerator:
+    def __init__(self, start=0):
+        self.start = start
+
+    def __iter__(self):
+        return self
+
+    def __next__(self):
+        uu_id = self.peek()
+        self.start += 1
+        return uu_id
+
+    def peek(self):
+        return UUID(int=self.start)
+
+class DocumentMaker:
+    def __init__(self, db: Session):
+        self.owner_id = get_user_id_by_email(db)
+        self.index = DocumentIndexGenerator()
+
+    def __iter__(self):
+        return self
+
+    def __next__(self):
+        doc_id = next(self.index)
+        args = str(doc_id).split('-')
+        fname = Path('/', *args).with_suffix('.xyz')
+
+        return Document(
+            id=doc_id,
+            owner_id=self.owner_id,
+            fname=fname.name,
+            object_store_url=fname.as_uri(),
+        )
+
+class DocumentStore:
+    @staticmethod
+    def clear(db: Session):
+        db.exec(delete(Document))
+        db.commit()
+
+    @property
+    def owner(self):
+        return self.documents.owner_id
+
+    def __init__(self, db: Session):
+        self.db = db
+        self.documents = DocumentMaker(self.db)
+        self.clear(self.db)
+
+    def put(self):
+        doc = next(self.documents)
+
+        self.db.add(doc)
+        self.db.commit()
+        self.db.refresh(doc)
+
+        return doc
+
+    def extend(self, n: int):
+        for _ in range(n):
+            yield self.put()
+
+    def fill(self, n: int):
+        return list(self.extend(n))
+
+class Route:
+    _empty = ParseResult(*it.repeat('', len(ParseResult._fields)))
+    _root = Path(settings.API_V1_STR, 'documents')
+
+    def __init__(self, endpoint, **qs_args):
+        self.endpoint = endpoint
+        self.qs_args = qs_args
+
+    def __str__(self):
+        return urlunparse(self.to_url())
+
+    def to_url(self):
+        path = self._root.joinpath(self.endpoint)
+        kwargs = {
+            'path': str(path),
+        }
+        if self.qs_args:
+            query = '&'.join(it.starmap('{}={}'.format, self.qs_args.items()))
+            kwargs['query'] = query
+
+        return self._empty._replace(**kwargs)
+
+    def append(self, doc: Document):
+        endpoint = Path(self.endpoint, str(doc.id))
+        return type(self)(endpoint, **self.qs_args)
+
+@dataclass
+class WebCrawler:
+    client: TestClient
+    superuser_token_headers: dict[str, str]
+
+    def get(self, route: Route):
+        return self.client.get(
+            str(route),
+            headers=self.superuser_token_headers,
+        )
+
+class DocumentComparator:
+    @ft.singledispatchmethod
+    @staticmethod
+    def to_string(value):
+        return value
+
+    @to_string.register
+    @staticmethod
+    def _(value: UUID):
+        return str(value)
+
+    @to_string.register
+    @staticmethod
+    def _(value: datetime):
+        return value.isoformat()
+
+    def __init__(self, document: Document):
+        self.document = document
+
+    def __eq__(self, other: dict):
+        this = dict(self.to_dict())
+        return this == other
+
+    def to_dict(self):
+        document = dict(self.document)
+        for (k, v) in document.items():
+            yield (k, self.to_string(v))
+
+@pytest.fixture
+def crawler(client: TestClient, superuser_token_headers: dict[str, str]):
+    return WebCrawler(client, superuser_token_headers)
diff --git a/backend/pyproject.toml b/backend/pyproject.toml
index 1c77b83d..72f0143d 100644
--- a/backend/pyproject.toml
+++ b/backend/pyproject.toml
@@ -21,6 +21,8 @@ dependencies = [
     "pydantic-settings<3.0.0,>=2.2.1",
     "sentry-sdk[fastapi]<2.0.0,>=1.40.6",
     "pyjwt<3.0.0,>=2.8.0",
+    "boto3>=1.37.20",
+    "moto[s3]>=5.1.1",
 ]
 
 [tool.uv]
diff --git a/backend/scripts/prestart.sh b/backend/scripts/prestart.sh
index 1b395d51..ce08c17e 100644
--- a/backend/scripts/prestart.sh
+++ b/backend/scripts/prestart.sh
@@ -9,5 +9,12 @@ python app/backend_pre_start.py
 # Run migrations
 alembic upgrade head
 
-# Create initial data in DB
-python app/initial_data.py
+# Initialize services
+services=(
+    app/initial_data.py
+    app/initial_storage.py
+)
+
+for i in "${services[@]}"; do
+    python "$i"
+done