diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index d819c8f..0785560 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -6,41 +6,10 @@ jobs: test: runs-on: ubuntu-latest - services: - postgres: - image: postgres:alpine - ports: - - 5432:5432 - env: - POSTGRES_USER: fastapiusers - POSTGRES_PASSWORD: fastapiuserspassword - # Set health checks to wait until postgres has started - options: >- - --health-cmd pg_isready - --health-interval 10s - --health-timeout 5s - --health-retries 5 - - mariadb: - image: mariadb - ports: - - 3306:3306 - env: - MARIADB_ROOT_PASSWORD: fastapiuserspassword - MARIADB_DATABASE: fastapiusers - MARIADB_USER: fastapiusers - MARIADB_PASSWORD: fastapiuserspassword - strategy: fail-fast: false matrix: python_version: [3.9, '3.10', '3.11', '3.12', '3.13'] - database_url: - [ - "sqlite+aiosqlite:///./test-fastapiusers.db", - "postgresql+asyncpg://fastapiusers:fastapiuserspassword@localhost:5432/fastapiusers", - "mysql+aiomysql://root:fastapiuserspassword@localhost:3306/fastapiusers", - ] steps: - uses: actions/checkout@v4 @@ -57,19 +26,12 @@ jobs: run: | hatch run lint-check - name: Test - env: - DATABASE_URL: ${{ matrix.database_url }} run: | hatch run test-cov-xml - - uses: codecov/codecov-action@v5 - with: - token: ${{ secrets.CODECOV_TOKEN }} - fail_ci_if_error: true - verbose: true - name: Build and install it on system host run: | hatch build - pip install dist/fastapi_users_db_sqlalchemy-*.whl + pip install dist/fastapi_users_db_dynamodb-*.whl python test_build.py release: diff --git a/fastapi_users_db_dynamodb/__init__.py b/fastapi_users_db_dynamodb/__init__.py index e63b588..53bfaca 100644 --- a/fastapi_users_db_dynamodb/__init__.py +++ b/fastapi_users_db_dynamodb/__init__.py @@ -1,300 +1,200 @@ """FastAPI Users database adapter for AWS DynamoDB. This adapter mirrors the SQLAlchemy adapter's public API and return types as closely -as reasonably possible while using DynamoDB via aioboto3. +as reasonably possible while using DynamoDB via `aiopynamodb`. Usage notes: -- You can pass a long-lived aioboto3 resource (created once during app startup) - via the `dynamodb_resource` parameter to avoid creating a resource on every call: - async with aioboto3.Session().resource("dynamodb", region_name=...) as resource: - adapter = DynamoDBUserDatabase( - session, user_table, "users", oauth_account_table, "oauth_accounts", - dynamodb_resource=resource - ) - If you don't provide `dynamodb_resource`, this adapter will create a short-lived - resource per operation (safe, but less optimal). +- This adapter is expected to function correctly, but it is still advisable to exercise + caution in production environments (yet). +- The Database will create non existent tables by default. You can customize the configuration + inside `config.py` using the `get` and `set` methods. +- For now, tables will require ON-DEMAND mode, since traffic is unpredictable in all auth tables! """ -from __future__ import annotations - import uuid -from contextlib import asynccontextmanager -from typing import TYPE_CHECKING, Any, Generic, get_type_hints +from typing import TYPE_CHECKING, Any, Generic, Optional -import aioboto3 -from boto3.dynamodb.conditions import Attr -from botocore.exceptions import ClientError +from aiopynamodb.attributes import BooleanAttribute, NumberAttribute, UnicodeAttribute +from aiopynamodb.exceptions import DeleteError, PutError +from aiopynamodb.indexes import AllProjection, GlobalSecondaryIndex +from aiopynamodb.models import Model from fastapi_users.db.base import BaseUserDatabase from fastapi_users.models import ID, OAP, UP -from pydantic import BaseModel, ConfigDict, Field -from fastapi_users_db_dynamodb._aioboto3_patch import * # noqa: F403 -from fastapi_users_db_dynamodb.generics import UUID_ID +from . import config +from ._generics import UUID_ID +from .attributes import GUID, TransformingUnicodeAttribute +from .config import __version__ # noqa: F401 +from .tables import ensure_tables_exist -__version__ = "1.0.0" -DATABASE_USERTABLE_PRIMARY_KEY: str = "id" +class DynamoDBBaseUserTable(Model, Generic[ID]): + """Base user table schema for DynamoDB.""" + __tablename__: str = config.get("DATABASE_USERTABLE_NAME") -class DynamoDBBaseUserTable(BaseModel, Generic[ID]): - """Base user table schema for DynamoDB.""" + class Meta: + table_name: str = config.get("DATABASE_USERTABLE_NAME") + region: str = config.get("DATABASE_REGION") + billing_mode: str = config.get("DATABASE_BILLING_MODE").value - model_config = ConfigDict(arbitrary_types_allowed=True, from_attributes=True) + class EmailIndex(GlobalSecondaryIndex): + class Meta: + index_name: str = "email-index" + projection = AllProjection() - __tablename__ = "user" + email = TransformingUnicodeAttribute(transform=str.lower, hash_key=True) - if TYPE_CHECKING: + if TYPE_CHECKING: # pragma: no cover id: ID - email: str = Field(..., description="The email of the user") - hashed_password: str = Field(..., description="The hashed password of the user") - is_active: bool = Field( - default=True, description="Whether the user is marked as active in the database" - ) - is_superuser: bool = Field( - default=False, description="Whether the user has admin rights" - ) - is_verified: bool = Field( - default=False, description="Whether the user has verified their email" - ) + email: str + hashed_password: str + is_active: bool + is_superuser: bool + is_verified: bool + else: + email = TransformingUnicodeAttribute(transform=str.lower, null=False) + hashed_password = UnicodeAttribute(null=False) + is_active = BooleanAttribute(default=True, null=False) + is_superuser = BooleanAttribute(default=False, null=False) + is_verified = BooleanAttribute(default=False, null=False) + + # Global Secondary Index + email_index = EmailIndex() class DynamoDBBaseUserTableUUID(DynamoDBBaseUserTable[UUID_ID]): - id: UUID_ID = Field(default_factory=uuid.uuid4, description="The ID for the user") + if TYPE_CHECKING: # pragma: no cover + id: UUID_ID + else: + id: GUID = GUID(hash_key=True, default=uuid.uuid4) -class DynamoDBBaseOAuthAccountTable(Generic[ID]): +class DynamoDBBaseOAuthAccountTable(Model, Generic[ID]): """Base OAuth account table schema for DynamoDB.""" - __tablename__ = "oauth_account" + __tablename__: str = config.get("DATABASE_OAUTHTABLE_NAME") + + class Meta: + table_name: str = config.get("DATABASE_OAUTHTABLE_NAME") + region: str = config.get("DATABASE_REGION") + billing_mode: str = config.get("DATABASE_BILLING_MODE").value + + class AccountIdIndex(GlobalSecondaryIndex): + class Meta: + index_name: str = "account_id-index" + projection = AllProjection() + + account_id = UnicodeAttribute(hash_key=True) - if TYPE_CHECKING: + class OAuthNameIndex(GlobalSecondaryIndex): + class Meta: + index_name: str = "oauth_name-index" + projection = AllProjection() + + oauth_name = UnicodeAttribute(hash_key=True) + + class UserIdIndex(GlobalSecondaryIndex): + class Meta: + index_name = "user_id-index" + projection = AllProjection() + + user_id = GUID(hash_key=True) + + if TYPE_CHECKING: # pragma: no cover id: ID - oauth_name: str = Field(..., description="The name of the OAuth social provider") - access_token: str = Field( - ..., description="The access token linked with the OAuth account" - ) - expires_at: int | None = Field( - default=None, description="The timestamp at which this account expires" - ) - refresh_token: str | None = Field( - default=None, description="The refresh token associated with this OAuth account" - ) - account_id: str = Field(..., description="The ID of this OAuth account") - account_email: str = Field( - ..., description="The email associated with this OAuth account" - ) + oauth_name: str + access_token: str + expires_at: Optional[int] + refresh_token: Optional[str] + account_id: str + account_email: str + else: + oauth_name = UnicodeAttribute(null=False) + access_token = UnicodeAttribute(null=False) + expires_at = NumberAttribute(null=True) + refresh_token = UnicodeAttribute(null=True) + account_id = UnicodeAttribute(null=False) + account_email = TransformingUnicodeAttribute(transform=str.lower, null=False) + + # Global Secondary Index + account_id_index = AccountIdIndex() + oauth_name_index = OAuthNameIndex() + user_id_index = UserIdIndex() class DynamoDBBaseOAuthAccountTableUUID(DynamoDBBaseOAuthAccountTable[UUID_ID]): - id: UUID_ID = Field( - default_factory=uuid.uuid4, description="The ID for the OAuth account" - ) - user_id: UUID_ID = Field( - ..., description="The user ID this OAuth account belongs to" - ) + if TYPE_CHECKING: # pragma: no cover + id: UUID_ID + user_id: UUID_ID + else: + id: GUID = GUID(hash_key=True, default=uuid.uuid4) + user_id: GUID = GUID(null=False) class DynamoDBUserDatabase(Generic[UP, ID], BaseUserDatabase[UP, ID]): """ - Database adapter for AWS DynamoDB using aioboto3. - - :param session: aioboto3.Session instance (not an actual DynamoDB resource). - :param user_table: Python class used to construct returned user objects (callable). - :param user_table_name: DynamoDB table name for users. - :param oauth_account_table: Optional class to construct oauth-account objects. - :param oauth_table_name: Optional DynamoDB table name for oauth accounts. - :param dynamodb_resource: Optional aioboto3 resource object (async context manager result) - created with `async with session.resource("dynamodb") as r:`. If - provided, the adapter will reuse it (recommended). + Database adapter for AWS DynamoDB using aiopynamodb. """ - session: aioboto3.Session user_table: type[UP] oauth_account_table: type[DynamoDBBaseOAuthAccountTable] | None - user_table_name: str - primary_key: str = DATABASE_USERTABLE_PRIMARY_KEY - oauth_account_table_name: str | None - _resource: Any | None - _resource_region: str | None def __init__( self, - session: aioboto3.Session, user_table: type[UP], - user_table_name: str, - primary_key: str = DATABASE_USERTABLE_PRIMARY_KEY, oauth_account_table: type[DynamoDBBaseOAuthAccountTable] | None = None, - oauth_account_table_name: str | None = None, - dynamodb_resource: Any | None = None, - dynamodb_resource_region: str | None = None, ): - self.session = session self.user_table = user_table self.oauth_account_table = oauth_account_table - self.user_table_name = user_table_name - self.primary_key = primary_key - self.oauth_account_table_name = oauth_account_table_name - - self._resource = dynamodb_resource - self._resource_region = dynamodb_resource_region - - @asynccontextmanager - async def _table(self, table_name: str, region: str | None = None): - """Async context manager that yields a Table object. - - If a long-lived resource was provided at init, it's reused (no enter/exit). - Otherwise a short-lived resource is created and cleaned up per call. - """ - if self._resource is not None: - table = await self._resource.Table(table_name) - yield table - else: - if region is None: - raise ValueError( - "Parameter `region` must be specified when `dynamodb_resource` is omitted" - ) - async with self.session.resource( - "dynamodb", region_name=region - ) as dynamodb: - table = await dynamodb.Table(table_name) - yield table - - def _serialize_for_dynamodb(self, data: dict[str, Any]) -> dict[str, Any]: - """Convert UUIDs and other incompatible types for DynamoDB.""" - result = {} - for key, value in data.items(): - if isinstance(value, uuid.UUID): - result[key] = str(value) - elif isinstance(value, list): - result[key] = [str(v) if isinstance(v, uuid.UUID) else v for v in value] - elif isinstance(value, dict): - result[key] = self._serialize_for_dynamodb(value) - else: - result[key] = value - return result - - def _ensure_id_str(self, value: Any) -> str: - """Normalize id to string for DynamoDB keys.""" - return str(value) - - def _extract_id_from_user(self, user_obj: Any) -> str: - """Extract the `id` from a user object/dict/ORM/Pydantic model.""" - - if isinstance(user_obj, dict): - idv = user_obj.get("id") - - elif hasattr(user_obj, "model_dump") and callable( - getattr(user_obj, "model_dump") - ): - try: - idv = user_obj.model_dump().get("id") - except Exception: - idv = getattr(user_obj, "id", None) - - elif hasattr(user_obj, "id"): - idv = getattr(user_obj, "id", None) - - elif hasattr(user_obj, "__dict__"): - idv = vars(user_obj).get("id") - else: - raise ValueError("Cannot extract 'id' from provided user object") - if idv is None: - raise ValueError("User object has no 'id' field") - return self._ensure_id_str(idv) - - def _item_to_user(self, item: dict[str, Any] | None) -> UP | None: - """Convert a DynamoDB item (dict) to an instance of user_table (UP).""" - if item is None: - return None - try: - hints = get_type_hints(self.user_table) - if ( - "id" in hints - and hints["id"] is uuid.UUID - and isinstance(item.get("id"), str) - ): - item = {**item, "id": uuid.UUID(item["id"])} - except Exception: - pass - - return self.user_table(**item) - - def _ensure_email_lower(self, data: dict[str, Any]) -> None: - """Lower-case email in-place if present.""" - if "email" in data and isinstance(data["email"], str): - data["email"] = data["email"].lower() - - async def get( + async def _hydrate_oauth_accounts( self, - id: ID | str, + user: UP, instant_update: bool = False, - ) -> UP | None: - """Get a user by id and hydrate oauth_accounts if available.""" - id_str = self._ensure_id_str(id) - - async with self._table(self.user_table_name, self._resource_region) as table: - resp = await table.get_item( - Key={self.primary_key: id_str}, - ConsistentRead=instant_update, - ) - item = resp.get("Item") - user = self._item_to_user(item) + ) -> UP: + """ + Populate the `oauth_accounts` list of a user by querying the OAuth table. + This mimics SQLAlchemy's lazy relationship loading. + """ + if self.oauth_account_table is None: + return user + await ensure_tables_exist(self.oauth_account_table) - if user is None: - return None + user.oauth_accounts = [] # type: ignore - if self.oauth_account_table and self.oauth_account_table_name: - async with self._table( - self.oauth_account_table_name, self._resource_region - ) as oauth_table: - resp = await oauth_table.scan( - FilterExpression=Attr("user_id").eq(id_str), - ConsistentRead=instant_update, - ) - accounts = resp.get("Items", []) - user.oauth_accounts = [ # type: ignore - self.oauth_account_table(**acc) for acc in accounts - ] + async for oauth_acc in self.oauth_account_table.user_id_index.query( # type: ignore + user.id, + consistent_read=instant_update, + ): + user.oauth_accounts.append(oauth_acc) # type: ignore return user - async def get_by_email( - self, - email: str, - instant_update: bool = False, - ) -> UP | None: - """Get a user by email (case-insensitive: emails are stored lowercased).""" - email_norm = email.lower() - async with self._table(self.user_table_name, self._resource_region) as table: - resp = await table.scan( - FilterExpression=Attr("email").eq(email_norm), - Limit=1, - ConsistentRead=instant_update, - ) - items = resp.get("Items", []) - if not items: - return None - user = self._item_to_user(items[0]) + async def get(self, id: ID, instant_update: bool = False) -> UP | None: + """Get a user by id and hydrate oauth_accounts if available.""" + await ensure_tables_exist(self.user_table) # type: ignore - if user is None: + try: + user = await self.user_table.get(id, consistent_read=instant_update) # type: ignore + user = await self._hydrate_oauth_accounts(user, instant_update) + return user + except self.user_table.DoesNotExist: # type: ignore return None - user_id = self._ensure_id_str(user.id) - if self.oauth_account_table and self.oauth_account_table_name: - async with self._table( - self.oauth_account_table_name, self._resource_region - ) as oauth_table: - resp = await oauth_table.scan( - FilterExpression=Attr("user_id").eq(user_id), - ConsistentRead=instant_update, - ) - accounts = resp.get("Items", []) - user.oauth_accounts = [ # type: ignore - self.oauth_account_table(**acc) for acc in accounts - ] + async def get_by_email(self, email: str, instant_update: bool = False) -> UP | None: + """Get a user by email using the email GSI (case-insensitive).""" + await ensure_tables_exist(self.user_table) # type: ignore - return user + email_lower = email.lower() + async for user in self.user_table.email_index.query( # type: ignore + email_lower, + consistent_read=instant_update, + limit=1, + ): + user = await self._hydrate_oauth_accounts(user, instant_update) + return user + return None async def get_by_oauth_account( self, @@ -302,138 +202,105 @@ async def get_by_oauth_account( account_id: str, instant_update: bool = False, ) -> UP | None: - """Find a user by oauth provider and provider account id.""" - if self.oauth_account_table is None or self.oauth_account_table_name is None: + """Find a user by oauth provider and account_id.""" + if self.oauth_account_table is None: raise NotImplementedError() + await ensure_tables_exist(self.user_table, self.oauth_account_table) # type: ignore - async with self._table( - self.oauth_account_table_name, self._resource_region - ) as oauth_table: - resp = await oauth_table.scan( - FilterExpression=Attr("oauth_name").eq(oauth) - & Attr("account_id").eq(account_id), - Limit=1, - ConsistentRead=instant_update, - ) - items = resp.get("Items", []) - if not items: - return None - - user_id = items[0].get("user_id") - if not user_id: + async for oauth_acc in self.oauth_account_table.account_id_index.query( + account_id, + consistent_read=instant_update, + filter_condition=self.oauth_account_table.oauth_name == oauth, # type: ignore + limit=1, + ): + try: + user = await self.user_table.get( # type: ignore + oauth_acc.user_id, + consistent_read=instant_update, + ) + user = await self._hydrate_oauth_accounts(user, instant_update) + return user + except self.user_table.DoesNotExist: # type: ignore # pragma: no cover return None + return None - return await self.get(user_id) - - async def create(self, create_dict: dict[str, Any]) -> UP: + async def create(self, create_dict: dict[str, Any] | UP) -> UP: """Create a new user and return an instance of UP.""" - item = dict(create_dict) - if "id" not in item or item["id"] is None: - item["id"] = str(uuid.uuid4()) - else: - item["id"] = self._ensure_id_str(item["id"]) - - self._ensure_email_lower(item) + await ensure_tables_exist(self.user_table) # type: ignore - async with self._table(self.user_table_name, self._resource_region) as table: - try: - await table.put_item( - Item=self._serialize_for_dynamodb(item), - ConditionExpression="attribute_not_exists(#id)", - ExpressionAttributeNames={"#id": self.primary_key}, - ) - except ClientError as e: - if ( - e.response.get("Error", {}).get("Code") - == "ConditionalCheckFailedException" - ): - raise ValueError(f"User {item['id']} already exists.") - raise - - refreshed_user = self._item_to_user(item) - if refreshed_user is None: - raise ValueError("Could not cast DB item to User model") - return refreshed_user - - async def update( - self, - user: UP, - update_dict: dict[str, Any], - instant_update: bool = False, - ) -> UP: - """Update a user with update_dict and return the updated UP instance.""" - user_id = self._extract_id_from_user(user) - async with self._table(self.user_table_name, self._resource_region) as table: - resp = await table.get_item( - Key={self.primary_key: user_id}, - ConsistentRead=instant_update, + if isinstance(create_dict, dict): + user = self.user_table(**create_dict) + else: + user = create_dict + try: + await user.save( # type: ignore + condition=self.user_table.id.does_not_exist() + & self.user_table.email.does_not_exist() # type: ignore ) - current = resp.get("Item", None) - - if not current: - raise ValueError("User not found") - - merged = {**current, **update_dict} + except PutError as e: + if e.cause_response_code == "ConditionalCheckFailedException": + raise ValueError( + "User account could not be created because it already exists." + ) from e + raise ValueError( # pragma: no cover + "User account could not be created because the table does not exist." + ) from e + return user - self._ensure_email_lower(merged) + async def update(self, user: UP, update_dict: dict[str, Any]) -> UP: + """Update a user with update_dict and return the updated UP instance.""" + await ensure_tables_exist(self.user_table) # type: ignore - try: - await table.put_item( - Item=self._serialize_for_dynamodb(merged), - ConditionExpression="attribute_exists(#id)", - ExpressionAttributeNames={"#id": self.primary_key}, - ) - except ClientError as e: - if ( - e.response.get("Error", {}).get("Code") - == "ConditionalCheckFailedException" - ): - raise ValueError(f"User {user_id} does not exist.") - raise - - refreshed_user = self._item_to_user(merged) - if refreshed_user is None: - raise ValueError("Could not cast DB item to User model") - return refreshed_user + try: + for k, v in update_dict.items(): + setattr(user, k, v) + await user.save(condition=self.user_table.id.exists()) # type: ignore + return user + except PutError as e: + if e.cause_response_code == "ConditionalCheckFailedException": + raise ValueError( + "User account could not be updated because it does not exist." + ) from e + raise ValueError( # pragma: no cover + "User account could not be updated because the table does not exist." + ) from e async def delete(self, user: UP) -> None: """Delete a user.""" - user_id = self._extract_id_from_user(user) - async with self._table(self.user_table_name, self._resource_region) as table: - try: - await table.delete_item( - Key={self.primary_key: user_id}, - ConditionExpression="attribute_exists(#id)", - ExpressionAttributeNames={"#id": self.primary_key}, - ) - except ClientError as e: - if ( - e.response.get("Error", {}).get("Code") - == "ConditionalCheckFailedException" - ): - raise ValueError(f"User {user_id} does not exist.") - raise + await ensure_tables_exist(self.user_table) # type: ignore + + try: + await user.delete(condition=self.user_table.id.exists()) # type: ignore + except DeleteError as e: + raise ValueError("User account could not be deleted.") from e + except PutError as e: # pragma: no cover + if e.cause_response_code == "ConditionalCheckFailedException": + raise ValueError( + "User account could not be deleted because it does not exist." + ) from e async def add_oauth_account(self, user: UP, create_dict: dict[str, Any]) -> UP: - """Add an OAuth account for `user`. Returns the refreshed user (UP).""" - if self.oauth_account_table is None or self.oauth_account_table_name is None: + """Add an OAuth account and return the refreshed user (UP).""" + if self.oauth_account_table is None: raise NotImplementedError() + await ensure_tables_exist(self.user_table, self.oauth_account_table) # type: ignore - oauth_item = dict(create_dict) - if "id" not in oauth_item or oauth_item["id"] is None: - oauth_item["id"] = str(uuid.uuid4()) - - user_id = self._extract_id_from_user(user) - oauth_item["user_id"] = user_id - - async with self._table( - self.oauth_account_table_name, self._resource_region - ) as oauth_table: - await oauth_table.put_item(Item=self._serialize_for_dynamodb(oauth_item)) - - if hasattr(user, "oauth_accounts"): - oauth_obj = self.oauth_account_table(**oauth_item) - user.oauth_accounts.append(oauth_obj) # type: ignore + try: + create_dict["user_id"] = getattr(create_dict, "user_id", user.id) + oauth_account = self.oauth_account_table(**create_dict) + await oauth_account.save( + condition=self.oauth_account_table.id.does_not_exist() # type: ignore + & self.oauth_account_table.account_id.does_not_exist() # type: ignore + ) + user.oauth_accounts.append(oauth_account) # type: ignore + except PutError as e: # pragma: no cover + if e.cause_response_code == "ConditionalCheckFailedException": + raise ValueError( + "OAuth account could not be added because it already exists." + ) from e + raise ValueError( + "OAuth account could not be added because the table does not exist." + ) from e return user @@ -444,43 +311,30 @@ async def update_oauth_account( update_dict: dict[str, Any], ) -> UP: """Update an OAuth account and return the refreshed user (UP).""" - if self.oauth_account_table is None or self.oauth_account_table_name is None: + if self.oauth_account_table is None: raise NotImplementedError() + await ensure_tables_exist(self.user_table, self.oauth_account_table) # type: ignore - oauth_item = ( - oauth_account.model_dump() # type: ignore - if hasattr(oauth_account, "model_dump") - else vars(oauth_account) - ) - - updated_item = {**oauth_item, **update_dict} - - for field in ("id", "user_id", "oauth_name", "account_id"): - updated_item[field] = getattr(oauth_account, field, oauth_item.get(field)) - - async with self._table( - self.oauth_account_table_name, self._resource_region - ) as oauth_table: - try: - await oauth_table.put_item( - Item=self._serialize_for_dynamodb(updated_item), - ConditionExpression="attribute_exists(#id)", - ExpressionAttributeNames={"#id": self.primary_key}, - ) - except ClientError as e: - if ( - e.response.get("Error", {}).get("Code") - == "ConditionalCheckFailedException" - ): - raise ValueError( - f"OAuth account with ID {updated_item['id']} does not exist." - ) - raise - - if hasattr(user, "oauth_accounts"): - for idx, account in enumerate(user.oauth_accounts): # type: ignore - if str(getattr(account, "id", None)) == str(updated_item["id"]): - user.oauth_accounts[idx] = type(oauth_account)(**updated_item) # type: ignore - break + try: + for k, v in update_dict.items(): + setattr(oauth_account, k, v) + await oauth_account.save( # type: ignore + condition=self.oauth_account_table.id.exists() # type: ignore + & self.oauth_account_table.account_id.exists() # type: ignore + ) + except PutError as e: + if e.cause_response_code == "ConditionalCheckFailedException": + raise ValueError( + "OAuth account could not be updated because it does not exist." + ) from e + raise ValueError( # pragma: no cover + "OAuth account could not be updated because the table does not exist." + ) from e + + for acc in user.oauth_accounts: # type: ignore + if acc.id == oauth_account.id: + for k, v in update_dict.items(): + setattr(acc, k, v) + break return user diff --git a/fastapi_users_db_dynamodb/_aioboto3_patch.py b/fastapi_users_db_dynamodb/_aioboto3_patch.py deleted file mode 100644 index 5592c37..0000000 --- a/fastapi_users_db_dynamodb/_aioboto3_patch.py +++ /dev/null @@ -1,87 +0,0 @@ -import inspect -from binascii import crc32 - -import aiobotocore.endpoint -import aiobotocore.retryhandler -from aiobotocore.endpoint import HttpxStreamingBody, StreamingBody # type: ignore -from aiobotocore.retryhandler import ChecksumError, logger # type: ignore - -try: - import httpx -except ImportError: - httpx = None - - -async def _fixed_check_response(self, attempt_number, response): - http_response = response[0] - expected_crc = http_response.headers.get(self._header_name) - if expected_crc is None: - logger.debug( - "crc32 check skipped, the %s header is not in the http response.", - self._header_name, - ) - else: - if inspect.isawaitable(http_response.content): - data_buf = await http_response.content - else: - data_buf = http_response.content - - actual_crc32 = crc32(data_buf) & 0xFFFFFFFF - if not actual_crc32 == int(expected_crc): - logger.debug( - "retry needed: crc32 check failed, expected != actual: %s != %s", - int(expected_crc), - actual_crc32, - ) - raise ChecksumError( - checksum_type="crc32", - expected_checksum=int(expected_crc), - actual_checksum=actual_crc32, - ) - - -async def convert_to_response_dict(http_response, operation_model): - """Convert an HTTP response object to a request dict. - - This converts the HTTP response object to a dictionary. - - :type http_response: botocore.awsrequest.AWSResponse - :param http_response: The HTTP response from an AWS service request. - - :rtype: dict - :return: A response dictionary which will contain the following keys: - * headers (dict) - * status_code (int) - * body (string or file-like object) - - """ - response_dict = { - "headers": http_response.headers, - "status_code": http_response.status_code, - "context": { - "operation_name": operation_model.name, - }, - } - if response_dict["status_code"] >= 300: - if inspect.isawaitable(http_response.content): - response_dict["body"] = await http_response.content - else: - response_dict["body"] = http_response.content - elif operation_model.has_event_stream_output: - response_dict["body"] = http_response.raw - elif operation_model.has_streaming_output: - if httpx and isinstance(http_response.raw, httpx.Response): - response_dict["body"] = HttpxStreamingBody(http_response.raw) - else: - length = response_dict["headers"].get("content-length") - response_dict["body"] = StreamingBody(http_response.raw, length) - else: - if inspect.isawaitable(http_response.content): - response_dict["body"] = await http_response.content - else: - response_dict["body"] = http_response.content - return response_dict - - -aiobotocore.retryhandler.AioCRC32Checker._check_response = _fixed_check_response # type: ignore -aiobotocore.endpoint.convert_to_response_dict = convert_to_response_dict diff --git a/fastapi_users_db_dynamodb/_generics.py b/fastapi_users_db_dynamodb/_generics.py new file mode 100644 index 0000000..09f63f5 --- /dev/null +++ b/fastapi_users_db_dynamodb/_generics.py @@ -0,0 +1,14 @@ +"""FastAPI Users DynamoDB generics.""" + +import uuid +from datetime import datetime, timezone + +UUID_ID = uuid.UUID + + +def now_utc() -> datetime: + """ + Returns the current time in UTC with timezone awareness. + Equivalent to the old implementation. + """ + return datetime.now(timezone.utc) diff --git a/fastapi_users_db_dynamodb/access_token.py b/fastapi_users_db_dynamodb/access_token.py index fd0d853..df408b7 100644 --- a/fastapi_users_db_dynamodb/access_token.py +++ b/fastapi_users_db_dynamodb/access_token.py @@ -1,122 +1,64 @@ -"""FastAPI Users access token database adapter for AWS DynamoDB. +"""FastAPI Users access token database adapter for AWS DynamoDB.""" -This adapter mirrors the SQLAlchemy adapter's public API and return types as closely -as reasonably possible while using DynamoDB via aioboto3. -""" +from datetime import datetime +from typing import TYPE_CHECKING, Any, Generic -from __future__ import annotations - -import uuid -from contextlib import asynccontextmanager -from datetime import datetime, timezone -from typing import TYPE_CHECKING, Any, Generic, get_type_hints - -import aioboto3 -from botocore.exceptions import ClientError +from aiopynamodb.attributes import UnicodeAttribute, UTCDateTimeAttribute +from aiopynamodb.exceptions import DeleteError, PutError +from aiopynamodb.indexes import AllProjection, GlobalSecondaryIndex +from aiopynamodb.models import Model from fastapi_users.authentication.strategy.db import AP, AccessTokenDatabase from fastapi_users.models import ID -from pydantic import BaseModel, ConfigDict, Field -from fastapi_users_db_dynamodb._aioboto3_patch import * # noqa: F403 -from fastapi_users_db_dynamodb.generics import UUID_ID +from . import config +from ._generics import UUID_ID, now_utc +from .attributes import GUID +from .tables import ensure_tables_exist -DATABASE_TOKENTABLE_PRIMARY_KEY: str = "token" - -class DynamoDBBaseAccessTokenTable(BaseModel, Generic[ID]): +class DynamoDBBaseAccessTokenTable(Model, Generic[ID]): """Base access token table schema for DynamoDB.""" - model_config = ConfigDict(arbitrary_types_allowed=True, from_attributes=True) + __tablename__: str = config.get("DATABASE_TOKENTABLE_NAME") + + class Meta: + table_name: str = config.get("DATABASE_TOKENTABLE_NAME") + region: str = config.get("DATABASE_REGION") + billing_mode: str = config.get("DATABASE_BILLING_MODE").value - __tablename__ = "accesstoken" + class CreatedAtIndex(GlobalSecondaryIndex): + class Meta: + index_name: str = "created_at-index" + projection = AllProjection() - token: str = Field(..., description="The token value of the AccessToken object") - created_at: datetime = Field( - default_factory=lambda: datetime.now(timezone.utc), - description="The date of creation of the AccessToken object", - ) - if TYPE_CHECKING: + created_at = UnicodeAttribute(hash_key=True) + + if TYPE_CHECKING: # pragma: no cover user_id: ID + token: str + created_at: datetime + else: + token = UnicodeAttribute(hash_key=True) + created_at = UTCDateTimeAttribute(default=now_utc, null=False) + + # Global Secondary Index + created_at_index = CreatedAtIndex() class DynamoDBBaseAccessTokenTableUUID(DynamoDBBaseAccessTokenTable[UUID_ID]): - user_id: UUID_ID = Field(..., description="The user ID this token belongs to") + if TYPE_CHECKING: # pragma: no cover + user_id: UUID_ID + else: + user_id: GUID = GUID(null=False) class DynamoDBAccessTokenDatabase(Generic[AP], AccessTokenDatabase[AP]): - """ - Access token database adapter for AWS DynamoDB using aioboto3. + """Access token database adapter for AWS DynamoDB using aiopynamodb.""" - :param session: aioboto3.Session instance (not an actual DynamoDB resource). - :param access_token_table: Python class used to construct returned objects (callable). - :param table_name: DynamoDB table name for access tokens. - :param dynamodb_resource: Optional aioboto3 resource object (async context manager result). - """ - - session: aioboto3.Session access_token_table: type[AP] - table_name: str - primary_key: str = DATABASE_TOKENTABLE_PRIMARY_KEY - _resource: Any | None - _resource_region: str | None - def __init__( - self, - session: aioboto3.Session, - access_token_table: type[AP], - table_name: str, - primary_key: str = DATABASE_TOKENTABLE_PRIMARY_KEY, - dynamodb_resource: Any | None = None, - dynamodb_resource_region: Any | None = None, - ): - self.session = session + def __init__(self, access_token_table: type[AP]): self.access_token_table = access_token_table - self.table_name = table_name - self.primary_key = primary_key - self._resource = dynamodb_resource - self._resource_region = dynamodb_resource_region - - @asynccontextmanager - async def _table(self, table_name: str, region: str | None = None): - """Async context manager that yields a Table object.""" - if self._resource is not None: - table = await self._resource.Table(table_name) - yield table - else: - if region is None: - raise ValueError( - "Parameter `region` must be specified when `dynamodb_resource` is omitted" - ) - async with self.session.resource( - "dynamodb", region_name=region - ) as dynamodb: - table = await dynamodb.Table(table_name) - yield table - - def _item_to_access_token(self, item: dict[str, Any] | None) -> AP | None: - """Convert a DynamoDB item (dict) to an instance of access_token_table (AP).""" - if item is None: - return None - - try: - hints = get_type_hints(self.access_token_table) - if ( - "user_id" in hints - and hints["user_id"] is UUID_ID - and isinstance(item.get("user_id"), str) - ): - item = {**item, "user_id": UUID_ID(item["user_id"])} - - if "created_at" in item and isinstance(item["created_at"], str): - item["created_at"] = datetime.fromisoformat(item["created_at"]) - except Exception: - pass - - return self.access_token_table(**item) - - def _ensure_token(self, token: Any) -> str: - """Normalize token to string for DynamoDB keys.""" - return str(token) async def get_by_token( self, @@ -125,114 +67,69 @@ async def get_by_token( instant_update: bool = False, ) -> AP | None: """Retrieve an access token by token string.""" - async with self._table(self.table_name, self._resource_region) as table: - resp = await table.get_item( - Key={self.primary_key: self._ensure_token(token)}, - ConsistentRead=instant_update, - ) - item = resp.get("Item") + await ensure_tables_exist(self.access_token_table) # type: ignore - if item is None: - return None + try: + token_obj = await self.access_token_table.get( # type: ignore + token, + consistent_read=instant_update, + ) if max_age is not None: - created_at = datetime.fromisoformat(item["created_at"]) - if created_at < max_age: + if token_obj.created_at < max_age: return None + return token_obj + except self.access_token_table.DoesNotExist: # type: ignore + return None - return self._item_to_access_token(item) - - async def create(self, create_dict: dict[str, Any]) -> AP: + async def create(self, create_dict: dict[str, Any] | AP) -> AP: """Create a new access token and return an instance of AP.""" - item = dict(create_dict) - - if "token" not in item or item["token"] is None: - item["token"] = uuid.uuid4().hex[:43] - if "created_at" not in item or not isinstance(item["created_at"], str): - item["created_at"] = datetime.now(timezone.utc).isoformat() - if isinstance(item.get("user_id"), uuid.UUID): - item["user_id"] = str(item["user_id"]) - - async with self._table(self.table_name, self._resource_region) as table: - try: - await table.put_item( - Item=item, - ConditionExpression="attribute_not_exists(#token)", - ExpressionAttributeNames={"#token": self.primary_key}, - ) - except ClientError as e: - if ( - e.response.get("Error", {}).get("Code") - == "ConditionalCheckFailedException" - ): - raise ValueError(f"Token {item['token']} already exists.") - raise - - access_token = self._item_to_access_token(item) - if access_token is None: - raise ValueError("Could not cast DB item to AccessToken model") - - return access_token + await ensure_tables_exist(self.access_token_table) # type: ignore + + if isinstance(create_dict, dict): + token = self.access_token_table(**create_dict) + else: + token = create_dict + try: + await token.save(condition=self.access_token_table.token.does_not_exist()) # type: ignore + except PutError as e: + if e.cause_response_code == "ConditionalCheckFailedException": + raise ValueError( + "Access token could not be created because it already exists." + ) from e + raise ValueError( # pragma: no cover + "Access token could not be created because the table does not exist." + ) from e + return token async def update(self, access_token: AP, update_dict: dict[str, Any]) -> AP: """Update an existing access token.""" + await ensure_tables_exist(self.access_token_table) # type: ignore - token_dict: dict = ( - access_token.model_dump() # type: ignore - if hasattr(access_token, "model_dump") and callable(access_token.model_dump) # type: ignore - else vars(access_token) - if hasattr(access_token, "__dict__") - else dict(access_token) - if isinstance(access_token, dict) - else vars(access_token) - ) - - token_dict.update(update_dict) - - if isinstance(token_dict.get("user_id"), uuid.UUID): - token_dict["user_id"] = str(token_dict["user_id"]) - if isinstance(token_dict.get("created_at"), datetime): - token_dict["created_at"] = token_dict["created_at"].isoformat() - - async with self._table(self.table_name, self._resource_region) as table: - try: - await table.put_item( - Item=token_dict, - ConditionExpression="attribute_exists(#token)", - ExpressionAttributeNames={"#token": self.primary_key}, - ) - except ClientError as e: - if ( - e.response.get("Error", {}).get("Code") - == "ConditionalCheckFailedException" - ): - raise ValueError(f"Token {token_dict['token']} does not exist.") - raise - - updated = self._item_to_access_token(token_dict) - if updated is None: - raise ValueError("Could not cast DB item to AccessToken model") - return updated + try: + for k, v in update_dict.items(): + setattr(access_token, k, v) + await access_token.save(condition=self.access_token_table.token.exists()) # type: ignore + return access_token + except PutError as e: + if e.cause_response_code == "ConditionalCheckFailedException": + raise ValueError( + "Access token could not be updated because it does not exist." + ) from e + raise ValueError( # pragma: no cover + "Access token could not be updated because the table does not exist." + ) from e async def delete(self, access_token: AP) -> None: """Delete an access token.""" - token = getattr(access_token, "token", None) or ( - access_token.get("token") if isinstance(access_token, dict) else None - ) - if token is None: - raise ValueError("access_token has no 'token' field") - - async with self._table(self.table_name, self._resource_region) as table: - try: - await table.delete_item( - Key={self.primary_key: self._ensure_token(token)}, - ConditionExpression="attribute_exists(#token)", - ExpressionAttributeNames={"#token": self.primary_key}, - ) - except ClientError as e: - if ( - e.response.get("Error", {}).get("Code") - == "ConditionalCheckFailedException" - ): - raise ValueError(f"Token {token} does not exist.") - raise + await ensure_tables_exist(self.access_token_table) # type: ignore + + try: + await access_token.delete(condition=self.access_token_table.token.exists()) # type: ignore + except DeleteError as e: + raise ValueError("Access token could not be deleted.") from e + except PutError as e: # pragma: no cover + if e.cause_response_code == "ConditionalCheckFailedException": + raise ValueError( + "Access token could not be deleted because it does not exist." + ) from e diff --git a/fastapi_users_db_dynamodb/attributes.py b/fastapi_users_db_dynamodb/attributes.py new file mode 100644 index 0000000..f42484c --- /dev/null +++ b/fastapi_users_db_dynamodb/attributes.py @@ -0,0 +1,57 @@ +from typing import Callable + +from aiopynamodb.attributes import Attribute, UnicodeAttribute +from aiopynamodb.constants import STRING + +from ._generics import UUID_ID + + +class GUID(Attribute[UUID_ID]): + """ + Custom PynamoDB attribute to store UUIDs as strings. + Ensures value is always a UUID object in Python. + """ + + attr_type = STRING + python_type = UUID_ID + + def serialize(self, value): + if value is None: + return None + if isinstance(value, UUID_ID): + return str(value) + return str(UUID_ID(value)) + + def deserialize(self, value): + if value is None: + return None + if not isinstance(value, UUID_ID): + return UUID_ID(value) + return value + + +class TransformingUnicodeAttribute(UnicodeAttribute): + """ + A UnicodeAttribute that automatically transforms its value. + + Example: lowercasing, uppercasing, capitalizing. + """ + + def __init__(self, transform: Callable[[str], str] | None = None, **kwargs): + """ + :param transform: A callable to transform the string (e.g., str.lower, str.upper) + :param kwargs: Other UnicodeAttribute kwargs + """ + super().__init__(**kwargs) + self.transform = transform + + def serialize(self, value): + if value is not None and self.transform: + value = self.transform(value) + return super().serialize(value) + + def deserialize(self, value): + value = super().deserialize(value) + if value is not None and self.transform: + value = self.transform(value) + return value diff --git a/fastapi_users_db_dynamodb/config.py b/fastapi_users_db_dynamodb/config.py new file mode 100644 index 0000000..6076b15 --- /dev/null +++ b/fastapi_users_db_dynamodb/config.py @@ -0,0 +1,50 @@ +from enum import StrEnum +from typing import Any, Literal, TypedDict + +__version__ = "1.0.0" + + +# Right now, only ON-DEMAND mode is supported! +class BillingMode(StrEnum): + PAY_PER_REQUEST = "PAY_PER_REQUEST" + # PROVISIONED = "PROVISIONED" + + def __str__(self) -> str: + return self.value + + +class __ConfigMap(TypedDict): + DATABASE_REGION: str + # DATABASE_BILLING_MODE: BillingMode + DATABASE_BILLING_MODE: Literal[BillingMode.PAY_PER_REQUEST] + DATABASE_USERTABLE_NAME: str + DATABASE_OAUTHTABLE_NAME: str + DATABASE_TOKENTABLE_NAME: str + + +def __create_config(): + __config_map: __ConfigMap = { + "DATABASE_REGION": "eu-central-1", + "DATABASE_BILLING_MODE": BillingMode.PAY_PER_REQUEST, + "DATABASE_USERTABLE_NAME": "user", + "DATABASE_OAUTHTABLE_NAME": "oauth_account", + "DATABASE_TOKENTABLE_NAME": "accesstoken", + } + + def get(key: str, default: Any = None) -> Any: + return __config_map.get(key, default) + + def set(key: str, value: Any) -> None: + if key not in __config_map: + raise KeyError(f"Unknown config key: {key}") + expected_type = type(__config_map[key]) + if not isinstance(value, expected_type): + raise TypeError( + f"Invalid type for '{key}'. Expected {expected_type.__name__}, got {type(value).__name__}." + ) + __config_map[key] = value + + return get, set + + +get, set = __create_config() diff --git a/fastapi_users_db_dynamodb/generics.py b/fastapi_users_db_dynamodb/generics.py deleted file mode 100644 index 5129349..0000000 --- a/fastapi_users_db_dynamodb/generics.py +++ /dev/null @@ -1,96 +0,0 @@ -"""FastAPI Users DynamoDB generics for UUID and timestamp handling. - -This module replaces SQLAlchemy-specific TypeDecorators with DynamoDB-friendly -helpers while keeping the same public API for compatibility. -""" - -from __future__ import annotations - -import uuid -from datetime import datetime, timezone - -from pydantic import UUID4 - -UUID_ID = uuid.UUID - - -class GUID(UUID4): - """ - Platform-independent GUID type. - - Kept for API compatibility with the old SQLAlchemy-based code. - In DynamoDB, this behaves as a lightweight UUID validator/converter. - """ - - python_type = UUID4 - - def __init__(self, *args, **kwargs): - """DynamoDB does not need type decorators, but we mimic SQLAlchemy API.""" - pass - - @staticmethod - def to_storage(value: UUID_ID | str | None) -> str | None: - """Convert UUID or string to a DynamoDB-storable string.""" - if value is None: - return None - return str(value) if isinstance(value, UUID_ID) else str(UUID_ID(value)) - - @staticmethod - def from_storage(value: str | UUID_ID | None) -> UUID_ID | None: - """Convert a stored string back into a UUID object.""" - if value is None: - return None - return value if isinstance(value, UUID_ID) else UUID_ID(value) - - def __eq__(self, other): - """Override equality to ensure correct comparison with UUID_ID.""" - if isinstance(other, UUID_ID): - return self.int == other.int # Direct comparison of UUIDs as integers - return False # Handle comparison with non-UUID types - - def __repr__(self): - """Override the string representation for better debugging.""" - return f"" - - -def now_utc() -> datetime: - """ - Returns the current time in UTC with timezone awareness. - Equivalent to the old implementation. - """ - return datetime.now(timezone.utc) - - -class TIMESTAMPAware(datetime): - """ - Kept for API compatibility. - - In SQLAlchemy, this handled database-specific timestamp behavior. - In DynamoDB, timestamps are stored as ISO 8601 strings and always - returned as timezone-aware datetimes. - """ - - python_type = datetime - - def __init__(self, *args, **kwargs): - """DynamoDB does not require dialect-level timestamp handling.""" - pass - - @staticmethod - def to_storage(value: datetime | None) -> str | None: - """Convert datetime to an ISO 8601 string for DynamoDB storage.""" - if value is None: - return None - if value.tzinfo is None: - value = value.replace(tzinfo=timezone.utc) - return value.isoformat() - - @staticmethod - def from_storage(value: str | datetime | None) -> datetime | None: - """Convert stored ISO 8601 string to timezone-aware datetime.""" - if value is None: - return None - if isinstance(value, datetime): - return value if value.tzinfo else value.replace(tzinfo=timezone.utc) - dt = datetime.fromisoformat(value) - return dt if dt.tzinfo else dt.replace(tzinfo=timezone.utc) diff --git a/fastapi_users_db_dynamodb/tables.py b/fastapi_users_db_dynamodb/tables.py new file mode 100644 index 0000000..15b22c3 --- /dev/null +++ b/fastapi_users_db_dynamodb/tables.py @@ -0,0 +1,76 @@ +from __future__ import annotations + +from types import CoroutineType +from typing import Any, Protocol + +from aiopynamodb.models import Model + +from . import config + +__tables_cache: set[type[CreatableTable]] = set() + + +class CreatableTable(Protocol): + """ + A protocol class representing PynamoDB tables (`Model`) which can be created. + """ + + @classmethod + async def exists(cls) -> CoroutineType[Any, Any, bool] | bool: ... + + @classmethod + async def delete_table(cls) -> CoroutineType[Any, Any, Any] | Any: ... + + @classmethod + async def create_table( + cls, + *, + wait: bool = ..., + billing_mode: str | None = ..., + ) -> CoroutineType[Any, Any, Any] | Any: ... + + +def __check_creatable_table(cls: type[Any]): + """Check if an object is of type `CreatableTable`""" + if not issubclass(cls, Model): + raise TypeError(f"{cls.__name__} must be a subclass of Model") + if ( + not hasattr(cls, "exists") + or not hasattr(cls, "delete_table") + or not hasattr(cls, "create_table") + ): + raise TypeError( + f"{cls.__name__} must implement exists(), delete_table() and create_table()" + ) + + +async def ensure_tables_exist(*tables: type[CreatableTable]) -> None: + """ + Ensure that all given DynamoDB tables exist. + Will be called automatically from the DB instance. + """ + global __tables_cache + + for table_cls in tables: + __check_creatable_table(table_cls) + if table_cls not in __tables_cache: + if not await table_cls.exists(): + await table_cls.create_table( + billing_mode=config.get("DATABASE_BILLING_MODE").value, + wait=True, + ) + __tables_cache.add(table_cls) + + +async def delete_tables(*tables: type[CreatableTable]) -> None: + """ + Delete all given DynamoDB tables, if existent. + """ + global __tables_cache + + for table_cls in tables: + __check_creatable_table(table_cls) + if await table_cls.exists(): + await table_cls.delete_table() + if table_cls in __tables_cache: + __tables_cache.remove(table_cls) diff --git a/pyproject.toml b/pyproject.toml index f8cf0dd..ad1af8b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,12 +24,12 @@ allow-direct-references = true [tool.hatch.version] source = "regex_commit" commit_extra_args = ["-e"] -path = "fastapi_users_db_dynamodb/__init__.py" +path = "fastapi_users_db_dynamodb/config.py" [tool.hatch.envs.default] installer = "uv" dependencies = [ - "aioboto3", + "aiopynamodb@git+https://github.com/AppSolves/AioPynamoDB@master#egg=aiopynamodb>=1.0.1", "pytest", "pytest-asyncio", "black", @@ -41,7 +41,6 @@ dependencies = [ "asgi_lifespan", "ruff", "moto[all]", - "types-aioboto3", "types-aiobotocore", ] @@ -73,9 +72,10 @@ authors = [ ] description = "FastAPI Users database adapter for AWS DynamoDB" readme = "README.md" +license = "Apache-2.0" +license-files = [ "LICENSE" ] dynamic = ["version"] classifiers = [ - "License :: OSI Approved :: Apache 2.0 License", "Development Status :: 5 - Production/Stable", "Framework :: FastAPI", "Framework :: AsyncIO", @@ -91,7 +91,7 @@ classifiers = [ requires-python = ">=3.9" dependencies = [ "fastapi-users >= 14.0.0", - "aioboto3 >= 15.0.0", + "aiopynamodb@git+https://github.com/AppSolves/AioPynamoDB@master#egg=aiopynamodb>=1.0.1", ] [project.urls] diff --git a/tests/conftest.py b/tests/conftest.py index a3dd11a..1e2f175 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -3,10 +3,8 @@ import pytest from fastapi_users import schemas - -DATABASE_REGION: str = "eu-central-1" -DATABASE_TOKENTABLE_PRIMARY_KEY: str = "token" -DATABASE_USERTABLE_PRIMARY_KEY: str = "id" +from moto import mock_aws +from pydantic import UUID4 class User(schemas.BaseUser): @@ -25,6 +23,18 @@ class UserOAuth(User, schemas.BaseOAuthAccountMixin): pass +@pytest.fixture(scope="session", autouse=True) +def global_moto_mock(): + """ + Start Moto DynamoDB mock before any test runs, + and stop it after all tests are done. + """ + m = mock_aws() + m.start() + yield + m.stop() + + @pytest.fixture def oauth_account1() -> dict[str, Any]: return { @@ -48,5 +58,5 @@ def oauth_account2() -> dict[str, Any]: @pytest.fixture -def user_id() -> uuid.UUID: +def user_id() -> UUID4: return uuid.uuid4() diff --git a/tests/tables.py b/tests/tables.py deleted file mode 100644 index 1c984cd..0000000 --- a/tests/tables.py +++ /dev/null @@ -1,27 +0,0 @@ -import aioboto3 -import botocore.exceptions - - -async def ensure_table_exists( - session: aioboto3.Session, - table_name: str, - primary_key: str, - region: str, -): - async with session.client("dynamodb", region_name=region) as client: - try: - await client.describe_table(TableName=table_name) - except botocore.exceptions.ClientError: - await client.create_table( - TableName=table_name, - KeySchema=[ - {"AttributeName": primary_key, "KeyType": "HASH"}, - ], - AttributeDefinitions=[ - {"AttributeName": primary_key, "AttributeType": "S"}, - ], - ProvisionedThroughput={"ReadCapacityUnits": 5, "WriteCapacityUnits": 5}, - ) - - waiter = client.get_waiter("table_exists") - await waiter.wait(TableName=table_name) diff --git a/tests/test_access_token.py b/tests/test_access_token.py index 80c527f..2f55fa8 100644 --- a/tests/test_access_token.py +++ b/tests/test_access_token.py @@ -1,79 +1,63 @@ from collections.abc import AsyncGenerator -from datetime import datetime, timedelta, timezone +from datetime import timedelta -import aioboto3 import pytest import pytest_asyncio -from moto import mock_aws -from pydantic import UUID4, BaseModel +from aiopynamodb.models import Model +from pydantic import UUID4 -from fastapi_users_db_dynamodb import DynamoDBBaseUserTableUUID, DynamoDBUserDatabase -from fastapi_users_db_dynamodb._aioboto3_patch import * # noqa: F403 +from fastapi_users_db_dynamodb import ( + DynamoDBBaseUserTableUUID, + DynamoDBUserDatabase, + config, +) +from fastapi_users_db_dynamodb._generics import now_utc from fastapi_users_db_dynamodb.access_token import ( DynamoDBAccessTokenDatabase, DynamoDBBaseAccessTokenTableUUID, ) -from tests.conftest import ( - DATABASE_REGION, - DATABASE_TOKENTABLE_PRIMARY_KEY, - DATABASE_USERTABLE_PRIMARY_KEY, -) -from tests.tables import ensure_table_exists -class Base(BaseModel): +class Base(Model): pass class AccessToken(DynamoDBBaseAccessTokenTableUUID, Base): - pass + __tablename__: str = config.get("DATABASE_TOKENTABLE_NAME") + "_test" + + class Meta: + table_name: str = config.get("DATABASE_TOKENTABLE_NAME") + "_test" + region: str = config.get("DATABASE_REGION") + billing_mode: str = config.get("DATABASE_BILLING_MODE").value class User(DynamoDBBaseUserTableUUID, Base): - pass + __tablename__: str = config.get("DATABASE_USERTABLE_NAME") + "_test" + + class Meta: + table_name: str = config.get("DATABASE_USERTABLE_NAME") + "_test" + region: str = config.get("DATABASE_REGION") + billing_mode: str = config.get("DATABASE_BILLING_MODE").value @pytest_asyncio.fixture async def dynamodb_access_token_db( user_id: UUID4, ) -> AsyncGenerator[DynamoDBAccessTokenDatabase[AccessToken]]: - with mock_aws(): - session = aioboto3.Session() - user_table_name = "users_test" - token_table_name = "access_tokens_test" - await ensure_table_exists( - session, user_table_name, DATABASE_USERTABLE_PRIMARY_KEY, DATABASE_REGION - ) - await ensure_table_exists( - session, token_table_name, DATABASE_TOKENTABLE_PRIMARY_KEY, DATABASE_REGION - ) - - user_db = DynamoDBUserDatabase( - session, - DynamoDBBaseUserTableUUID, - user_table_name, - DATABASE_USERTABLE_PRIMARY_KEY, - dynamodb_resource_region=DATABASE_REGION, - ) - user = await user_db.create( - User( - id=user_id, - email="lancelot@camelot.bt", - hashed_password="guinevere", - ) # type: ignore + user_db = DynamoDBUserDatabase(User) + user = await user_db.create( + User( + id=user_id, + email="lancelot@camelot.bt", + hashed_password="guinevere", ) + ) - token_db = DynamoDBAccessTokenDatabase( - session, - AccessToken, - token_table_name, - DATABASE_TOKENTABLE_PRIMARY_KEY, - dynamodb_resource_region=DATABASE_REGION, - ) + token_db = DynamoDBAccessTokenDatabase(AccessToken) - yield token_db + yield token_db - await user_db.delete(user) + await user_db.delete(user) @pytest.mark.asyncio @@ -89,7 +73,7 @@ async def test_queries( assert access_token.user_id == user_id # Update - new_time = datetime.now(timezone.utc) + new_time = now_utc() updated_access_token = await dynamodb_access_token_db.update( access_token, {"created_at": new_time} ) @@ -102,23 +86,44 @@ async def test_queries( assert token_obj is not None token_obj = await dynamodb_access_token_db.get_by_token( - access_token.token, max_age=datetime.now(timezone.utc) + timedelta(hours=1) + access_token.token, max_age=now_utc() + timedelta(hours=1) ) assert token_obj is None token_obj = await dynamodb_access_token_db.get_by_token( - access_token.token, max_age=datetime.now(timezone.utc) - timedelta(hours=1) + access_token.token, max_age=now_utc() - timedelta(hours=1) ) assert token_obj is not None token_obj = await dynamodb_access_token_db.get_by_token("NOT_EXISTING_TOKEN") assert token_obj is None + # Create existing + with pytest.raises( + ValueError, + match="Access token could not be created because it already exists.", + ): + token = AccessToken() + token.token = "TOKEN" + token.user_id = user_id + await dynamodb_access_token_db.create(token) + # Delete await dynamodb_access_token_db.delete(access_token) + with pytest.raises(ValueError, match="Access token could not be deleted"): + await dynamodb_access_token_db.delete(access_token) + deleted_token = await dynamodb_access_token_db.get_by_token(access_token.token) assert deleted_token is None + # Update non-existent + new_time = now_utc() + with pytest.raises( + ValueError, + match="Access token could not be updated because it does not exist.", + ): + await dynamodb_access_token_db.update(access_token, {"created_at": new_time}) + @pytest.mark.asyncio async def test_insert_existing_token( @@ -127,7 +132,11 @@ async def test_insert_existing_token( ): access_token_create = {"token": "TOKEN", "user_id": user_id} + token = await dynamodb_access_token_db.get_by_token(access_token_create["token"]) + if token: + await dynamodb_access_token_db.delete(token) + await dynamodb_access_token_db.create(access_token_create) - with pytest.raises(Exception): + with pytest.raises(ValueError): await dynamodb_access_token_db.create(access_token_create) diff --git a/tests/test_misc.py b/tests/test_misc.py new file mode 100644 index 0000000..ee0b04a --- /dev/null +++ b/tests/test_misc.py @@ -0,0 +1,72 @@ +import pytest +from aiopynamodb.models import Model + +from fastapi_users_db_dynamodb import config +from fastapi_users_db_dynamodb.attributes import GUID +from fastapi_users_db_dynamodb.tables import delete_tables, ensure_tables_exist + + +class NotAModel: + pass + + +class IncompleteModel(Model): + pass + + +class ValidModel(Model): + class Meta: + table_name: str = "valid_model_test" + region: str = config.get("DATABASE_REGION") + billing_mode: str = config.get("DATABASE_BILLING_MODE").value + + +@pytest.mark.asyncio +async def test_tables_invalid_models(monkeypatch): + with pytest.raises(TypeError, match="must be a subclass of Model"): + await ensure_tables_exist(NotAModel) # type: ignore + + with pytest.raises(AttributeError, match="PynamoDB Models require a"): + await ensure_tables_exist(IncompleteModel) + + with pytest.raises(AttributeError, match="PynamoDB Models require a"): + await delete_tables(IncompleteModel) + + await ensure_tables_exist(ValidModel) + assert await ValidModel.exists() + await delete_tables(ValidModel) + assert not await ValidModel.exists() + + monkeypatch.delattr(Model, "exists", raising=True) + with pytest.raises(TypeError): + await ensure_tables_exist(IncompleteModel) + + +def test_config(monkeypatch): + billing_mode = config.BillingMode.PAY_PER_REQUEST + assert billing_mode.value == str(billing_mode) + + local_get, local_set = config.__create_config() + monkeypatch.setattr(config, "get", local_get) + monkeypatch.setattr(config, "set", local_set) + + with pytest.raises(KeyError, match="Unknown config key"): + config.set("non_existent_key", "some_value") + + with pytest.raises(TypeError, match="Invalid type for"): + config.set("DATABASE_BILLING_MODE", 1001) + + region = "us-east-1" + config.set("DATABASE_REGION", region) + assert config.get("DATABASE_REGION") == region + + +def test_attributes(user_id): + id = GUID() + assert id.serialize(None) is None + + user_id_str = str(user_id) + assert user_id_str == id.serialize(user_id_str) + + assert id.deserialize(None) is None + assert user_id == id.deserialize(user_id) diff --git a/tests/test_users.py b/tests/test_users.py index 3f908bb..5c7562d 100644 --- a/tests/test_users.py +++ b/tests/test_users.py @@ -1,89 +1,73 @@ import random as rd from collections.abc import AsyncGenerator -from typing import Any +from typing import TYPE_CHECKING, Any -import aioboto3 import pytest import pytest_asyncio -from moto import mock_aws -from pydantic import BaseModel, ConfigDict, Field +from aiopynamodb.attributes import UnicodeAttribute +from aiopynamodb.models import Model from fastapi_users_db_dynamodb import ( UUID_ID, DynamoDBBaseOAuthAccountTableUUID, DynamoDBBaseUserTableUUID, DynamoDBUserDatabase, + config, ) -from fastapi_users_db_dynamodb._aioboto3_patch import * # noqa: F403 -from tests.conftest import DATABASE_REGION, DATABASE_USERTABLE_PRIMARY_KEY -from tests.tables import ensure_table_exists -class Base(BaseModel): - model_config = ConfigDict(arbitrary_types_allowed=True, from_attributes=True) +class Base(Model): + pass class User(DynamoDBBaseUserTableUUID, Base): - first_name: str | None = Field(default=None, description="First name of the user") + __tablename__: str = config.get("DATABASE_USERTABLE_NAME") + "_test" + class Meta: + table_name: str = config.get("DATABASE_USERTABLE_NAME") + "_test" + region: str = config.get("DATABASE_REGION") + billing_mode: str = config.get("DATABASE_BILLING_MODE").value -class OAuthBase(BaseModel): - model_config = ConfigDict(arbitrary_types_allowed=True, from_attributes=True) + if TYPE_CHECKING: + first_name: str | None = None + else: + first_name = UnicodeAttribute(null=True) -class OAuthAccount(OAuthBase, DynamoDBBaseOAuthAccountTableUUID): +class OAuthBase(Model): + pass + + +class OAuthAccount(DynamoDBBaseOAuthAccountTableUUID, OAuthBase): pass class UserOAuth(DynamoDBBaseUserTableUUID, OAuthBase): - first_name: str | None = Field(default=None, description="First name of the user") - oauth_accounts: list[OAuthAccount] = Field( - default_factory=list, description="Linked OAuth accounts" - ) + __tablename__: str = config.get("DATABASE_OAUTHTABLE_NAME") + "_test" + + class Meta: + table_name: str = config.get("DATABASE_OAUTHTABLE_NAME") + "_test" + region: str = config.get("DATABASE_REGION") + billing_mode: str = config.get("DATABASE_BILLING_MODE").value + + if TYPE_CHECKING: + first_name: str | None = None + else: + first_name = UnicodeAttribute(null=True) + + oauth_accounts: list[OAuthAccount] = [] @pytest_asyncio.fixture async def dynamodb_user_db() -> AsyncGenerator[DynamoDBUserDatabase, None]: - with mock_aws(): - session = aioboto3.Session() - table_name = "users_test" - await ensure_table_exists( - session, table_name, DATABASE_USERTABLE_PRIMARY_KEY, DATABASE_REGION - ) - - db = DynamoDBUserDatabase( - session, - User, - table_name, - DATABASE_USERTABLE_PRIMARY_KEY, - dynamodb_resource_region=DATABASE_REGION, - ) - yield db + db = DynamoDBUserDatabase(User) + yield db @pytest_asyncio.fixture async def dynamodb_user_db_oauth() -> AsyncGenerator[DynamoDBUserDatabase, None]: - with mock_aws(): - session = aioboto3.Session() - user_table_name = "users_test_oauth" - oauth_table_name = "oauth_accounts_test" - await ensure_table_exists( - session, user_table_name, DATABASE_USERTABLE_PRIMARY_KEY, DATABASE_REGION - ) - await ensure_table_exists( - session, oauth_table_name, DATABASE_USERTABLE_PRIMARY_KEY, DATABASE_REGION - ) - - db = DynamoDBUserDatabase( - session, - UserOAuth, - user_table_name, - DATABASE_USERTABLE_PRIMARY_KEY, - OAuthAccount, # type: ignore - oauth_table_name, - dynamodb_resource_region=DATABASE_REGION, - ) - yield db + db = DynamoDBUserDatabase(UserOAuth, OAuthAccount) + yield db @pytest.mark.asyncio @@ -100,6 +84,14 @@ async def test_queries(dynamodb_user_db: DynamoDBUserDatabase[User, UUID_ID]): # Update user updated_user = await dynamodb_user_db.update(user, {"is_superuser": True}) assert updated_user.is_superuser is True + with pytest.raises( + ValueError, + match="User account could not be updated because it does not exist.", + ): + fake_user = User() + fake_user.email = "blabla@gmail.com" + fake_user.hashed_password = "crypticpassword" + await dynamodb_user_db.update(fake_user, {"is_superuser": True}) # Get by id id_user = await dynamodb_user_db.get(user.id) @@ -123,6 +115,8 @@ async def test_queries(dynamodb_user_db: DynamoDBUserDatabase[User, UUID_ID]): # Delete user await dynamodb_user_db.delete(user) + with pytest.raises(ValueError, match="User account could not be deleted"): + await dynamodb_user_db.delete(user) deleted_user = await dynamodb_user_db.get(user.id) assert deleted_user is None @@ -131,7 +125,7 @@ async def test_queries(dynamodb_user_db: DynamoDBUserDatabase[User, UUID_ID]): await dynamodb_user_db.get_by_oauth_account("foo", "bar") with pytest.raises(NotImplementedError): await dynamodb_user_db.add_oauth_account(user, {}) - with pytest.raises(ValueError): + with pytest.raises(NotImplementedError): oauth_account = OAuthAccount() # type: ignore await dynamodb_user_db.update_oauth_account(user, oauth_account, {}) # type: ignore @@ -144,7 +138,13 @@ async def test_insert_existing_email( "email": "lancelot@camelot.bt", "hashed_password": "guinevere", } - await dynamodb_user_db.create(user_create) + user = await dynamodb_user_db.create(user_create) + with pytest.raises( + ValueError, + match="User account could not be created because it already exists.", + ): + user_create["id"] = str(user.id) + await dynamodb_user_db.create(user_create) with pytest.raises(ValueError): # oder eigene Exception existing = await dynamodb_user_db.get_by_email(user_create["email"]) @@ -176,6 +176,7 @@ async def test_queries_oauth( dynamodb_user_db_oauth: DynamoDBUserDatabase[UserOAuth, UUID_ID], oauth_account1: dict[str, Any], oauth_account2: dict[str, Any], + user_id: UUID_ID, ): # Test OAuth accounts user_create = {"email": "lancelot@camelot.bt", "hashed_password": "guinevere"} @@ -205,13 +206,13 @@ def _get_account(_user: UserOAuth): ) assert _get_account(user).access_token == "NEW_TOKEN" # type: ignore - #! IMPORTANT: Since DynamoDB uses eventual consistency, we need a small delay (e.g. `time.sleep(0.01)`) \ + #! NOTE: Since DynamoDB uses eventual consistency, we need a small delay (e.g. `time.sleep(0.01)`) \ #! to ensure the user was fully updated. In production, this should be negligible. \ - #! Alternatively, the `get` and `update` methods of the `DynamoDBDatabase` class allow users \ + #! Alternatively, most methods of the `DynamoDBDatabase` class (e.g. `get`, `update`, ...) allow users \ #! to enable consistent reads via the `instant_update` argument. # Get by id - id_user = await dynamodb_user_db_oauth.get(user.id, instant_update=True) + id_user = await dynamodb_user_db_oauth.get(user.id) assert id_user is not None assert id_user.id == user.id assert _get_account(id_user).access_token == "NEW_TOKEN" # type: ignore @@ -232,3 +233,19 @@ def _get_account(_user: UserOAuth): # Unknown OAuth account unknown_oauth_user = await dynamodb_user_db_oauth.get_by_oauth_account("foo", "bar") assert unknown_oauth_user is None + + with pytest.raises( + ValueError, + match="OAuth account could not be updated because it does not exist.", + ): + user = UserOAuth() + oauth_account = OAuthAccount() + oauth_account.user_id = user_id + oauth_account.oauth_name = "blabla_provider" + oauth_account.account_id = "blabla_id" + oauth_account.account_email = "blabla@gmail.com" + await dynamodb_user_db_oauth.update_oauth_account( + user, + oauth_account, + {"access_token": "NEW_TOKEN"}, + )