Skip to content

Commit

Permalink
Fix/settings class in migrations (#409)
Browse files Browse the repository at this point in the history
* fix: use Settings inner class for the migrations log setup

* fix: mypy

* fix: mark root doc with _inheritance_inited

* version: 1.15.2
  • Loading branch information
roman-right committed Nov 9, 2022
1 parent 3ace946 commit b90703c
Show file tree
Hide file tree
Showing 20 changed files with 97 additions and 61 deletions.
2 changes: 1 addition & 1 deletion beanie/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from beanie.odm.views import View
from beanie.odm.union_doc import UnionDoc

__version__ = "1.15.1"
__version__ = "1.15.2"
__all__ = [
# ODM
"Document",
Expand Down
2 changes: 1 addition & 1 deletion beanie/migrations/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ class MigrationLog(Document):
name: str
is_current: bool

class Collection:
class Settings:
name = "migrations_log"


Expand Down
4 changes: 2 additions & 2 deletions beanie/odm/documents.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ async def insert_one(
cls: Type[DocType],
document: DocType,
session: Optional[ClientSession] = None,
bulk_writer: "BulkWriter" = None,
bulk_writer: Optional["BulkWriter"] = None,
link_rule: WriteRules = WriteRules.DO_NOTHING,
) -> Optional[DocType]:
"""
Expand Down Expand Up @@ -855,7 +855,7 @@ def dict(
include: Union["AbstractSetIntStr", "MappingIntStrAny"] = None,
exclude: Union["AbstractSetIntStr", "MappingIntStrAny"] = None,
by_alias: bool = False,
skip_defaults: bool = None,
skip_defaults: bool = False,
exclude_hidden: bool = True,
exclude_unset: bool = False,
exclude_defaults: bool = False,
Expand Down
2 changes: 2 additions & 0 deletions beanie/odm/interfaces/aggregate.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from abc import abstractmethod
from typing import TypeVar, Type, Optional, Union, Dict, Any, overload

from pydantic import BaseModel
Expand All @@ -12,6 +13,7 @@

class AggregateInterface:
@classmethod
@abstractmethod
def find_all(cls) -> FindMany:
pass

Expand Down
42 changes: 24 additions & 18 deletions beanie/odm/interfaces/find.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from abc import abstractmethod
from typing import (
Optional,
List,
Expand All @@ -10,6 +11,7 @@
ClassVar,
TypeVar,
Dict,
TYPE_CHECKING,
)
from collections.abc import Iterable
from pydantic import (
Expand All @@ -22,7 +24,9 @@
from beanie.odm.queries.find import FindOne, FindMany
from beanie.odm.settings.base import ItemSettings

DocType = TypeVar("DocType", bound="FindInterface")
if TYPE_CHECKING:
from beanie.odm.documents import DocType

DocumentProjectionType = TypeVar("DocumentProjectionType", bound=BaseModel)


Expand All @@ -37,16 +41,18 @@ class FindInterface:
_children: ClassVar[Dict[str, Type]]

@classmethod
@abstractmethod
def get_model_type(cls) -> ModelType:
pass

@classmethod
@abstractmethod
def get_settings(cls) -> ItemSettings:
pass

@overload
@classmethod
def find_one(
def find_one( # type: ignore
cls: Type["DocType"],
*args: Union[Mapping[str, Any], bool],
projection_model: None = None,
Expand All @@ -60,7 +66,7 @@ def find_one(

@overload
@classmethod
def find_one(
def find_one( # type: ignore
cls: Type["DocType"],
*args: Union[Mapping[str, Any], bool],
projection_model: Type["DocumentProjectionType"],
Expand All @@ -73,7 +79,7 @@ def find_one(
...

@classmethod
def find_one(
def find_one( # type: ignore
cls: Type["DocType"],
*args: Union[Mapping[str, Any], bool],
projection_model: Optional[Type["DocumentProjectionType"]] = None,
Expand Down Expand Up @@ -107,7 +113,7 @@ def find_one(

@overload
@classmethod
def find_many(
def find_many( # type: ignore
cls: Type["DocType"],
*args: Union[Mapping[str, Any], bool],
projection_model: None = None,
Expand All @@ -124,10 +130,10 @@ def find_many(

@overload
@classmethod
def find_many(
def find_many( # type: ignore
cls: Type["DocType"],
*args: Union[Mapping[str, Any], bool],
projection_model: Type["DocumentProjectionType"] = None,
projection_model: Optional[Type["DocumentProjectionType"]] = None,
skip: Optional[int] = None,
limit: Optional[int] = None,
sort: Union[None, str, List[Tuple[str, SortDirection]]] = None,
Expand All @@ -140,7 +146,7 @@ def find_many(
...

@classmethod
def find_many(
def find_many( # type: ignore
cls: Type["DocType"],
*args: Union[Mapping[str, Any], bool],
projection_model: Optional[Type["DocumentProjectionType"]] = None,
Expand Down Expand Up @@ -182,7 +188,7 @@ def find_many(

@overload
@classmethod
def find(
def find( # type: ignore
cls: Type["DocType"],
*args: Union[Mapping[str, Any], bool],
projection_model: None = None,
Expand All @@ -199,7 +205,7 @@ def find(

@overload
@classmethod
def find(
def find( # type: ignore
cls: Type["DocType"],
*args: Union[Mapping[str, Any], bool],
projection_model: Type["DocumentProjectionType"],
Expand All @@ -215,7 +221,7 @@ def find(
...

@classmethod
def find(
def find( # type: ignore
cls: Type["DocType"],
*args: Union[Mapping[str, Any], bool],
projection_model: Optional[Type["DocumentProjectionType"]] = None,
Expand Down Expand Up @@ -246,7 +252,7 @@ def find(

@overload
@classmethod
def find_all(
def find_all( # type: ignore
cls: Type["DocType"],
skip: Optional[int] = None,
limit: Optional[int] = None,
Expand All @@ -261,7 +267,7 @@ def find_all(

@overload
@classmethod
def find_all(
def find_all( # type: ignore
cls: Type["DocType"],
skip: Optional[int] = None,
limit: Optional[int] = None,
Expand All @@ -275,7 +281,7 @@ def find_all(
...

@classmethod
def find_all(
def find_all( # type: ignore
cls: Type["DocType"],
skip: Optional[int] = None,
limit: Optional[int] = None,
Expand Down Expand Up @@ -311,7 +317,7 @@ def find_all(

@overload
@classmethod
def all(
def all( # type: ignore
cls: Type["DocType"],
projection_model: None = None,
skip: Optional[int] = None,
Expand All @@ -326,7 +332,7 @@ def all(

@overload
@classmethod
def all(
def all( # type: ignore
cls: Type["DocType"],
projection_model: Type["DocumentProjectionType"],
skip: Optional[int] = None,
Expand All @@ -340,7 +346,7 @@ def all(
...

@classmethod
def all(
def all( # type: ignore
cls: Type["DocType"],
projection_model: Optional[Type["DocumentProjectionType"]] = None,
skip: Optional[int] = None,
Expand Down Expand Up @@ -373,7 +379,7 @@ async def count(cls) -> int:
:return: int
"""
return await cls.find_all().count()
return await cls.find_all().count() # type: ignore

@classmethod
def _add_class_id_filter(cls, args: Tuple, with_children: bool = False):
Expand Down
3 changes: 3 additions & 0 deletions beanie/odm/interfaces/getters.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
from abc import abstractmethod

from motor.motor_asyncio import AsyncIOMotorCollection

from beanie.odm.settings.base import ItemSettings


class OtherGettersInterface:
@classmethod
@abstractmethod
def get_settings(cls) -> ItemSettings:
pass

Expand Down
2 changes: 1 addition & 1 deletion beanie/odm/interfaces/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,5 @@ def set_session(self, session: Optional[ClientSession] = None):
:return:
"""
if session is not None:
self.session = session
self.session: Optional[ClientSession] = session
return self
2 changes: 2 additions & 0 deletions beanie/odm/queries/cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,11 @@ async def __anext__(self) -> CursorResultType:
return next_item
return parse_obj(projection, next_item) # type: ignore

@abstractmethod
def _get_cache(self) -> List[Dict[str, Any]]:
...

@abstractmethod
def _set_cache(self, data):
...

Expand Down
3 changes: 2 additions & 1 deletion beanie/odm/queries/delete.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import Type, TYPE_CHECKING, Any, Mapping, Optional, Dict, Generator

from pymongo.client_session import ClientSession
from pymongo.results import DeleteResult

from beanie.odm.bulk import BulkWriter, Operation
Expand All @@ -26,7 +27,7 @@ def __init__(
):
self.document_model = document_model
self.find_query = find_query
self.session = None
self.session: Optional[ClientSession] = None
self.bulk_writer = bulk_writer
self.pymongo_kwargs: Dict[str, Any] = pymongo_kwargs

Expand Down
4 changes: 2 additions & 2 deletions beanie/odm/queries/find.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,7 @@ def find_many(
def find_many(
self: "FindMany[FindQueryResultType]",
*args: Union[Mapping[str, Any], bool],
projection_model: Type[FindQueryProjectionType] = None,
projection_model: Optional[Type[FindQueryProjectionType]] = None,
skip: Optional[int] = None,
limit: Optional[int] = None,
sort: Union[None, str, List[Tuple[str, SortDirection]]] = None,
Expand Down Expand Up @@ -366,7 +366,7 @@ def find(
def find(
self: "FindMany[FindQueryResultType]",
*args: Union[Mapping[str, Any], bool],
projection_model: Type[FindQueryProjectionType] = None,
projection_model: Optional[Type[FindQueryProjectionType]] = None,
skip: Optional[int] = None,
limit: Optional[int] = None,
sort: Union[None, str, List[Tuple[str, SortDirection]]] = None,
Expand Down
4 changes: 3 additions & 1 deletion beanie/odm/queries/update.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,9 @@ async def _update(self) -> UpdateResult:

def __await__(
self,
) -> Generator[Any, None, Union[UpdateResult, InsertOneResult]]:
) -> Generator[
Any, None, Union[UpdateResult, InsertOneResult, Optional["DocType"]]
]:
"""
Run the query
:return:
Expand Down
3 changes: 2 additions & 1 deletion beanie/odm/utils/encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
List,
Mapping,
Union,
Optional,
)
from typing import Any, Callable, Dict, Type
from uuid import UUID
Expand Down Expand Up @@ -61,7 +62,7 @@ def __init__(
exclude: Union[
AbstractSet[Union[str, int]], Mapping[Union[str, int], Any], None
] = None,
custom_encoders: Dict[Type, Callable] = None,
custom_encoders: Optional[Dict[Type, Callable]] = None,
by_alias: bool = True,
to_db: bool = False,
):
Expand Down
14 changes: 9 additions & 5 deletions beanie/odm/utils/init.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,9 @@ class Initializer:
def __init__(
self,
database: AsyncIOMotorDatabase = None,
connection_string: str = None,
document_models: List[
Union[Type["DocType"], Type["View"], str]
connection_string: Optional[str] = None,
document_models: Optional[
List[Union[Type["DocType"], Type["View"], str]]
] = None,
allow_index_dropping: bool = False,
recreate_views: bool = False,
Expand Down Expand Up @@ -314,6 +314,8 @@ async def init_document(self, cls: Type[Document]) -> Optional[Output]:
class_name=cls.__name__,
collection_name=cls.get_collection_name(),
)
if cls.get_settings().is_root:
cls._inheritance_inited = True # TODO refactor. Looks ugly
elif output is not None:
output.class_name = f"{output.class_name}.{cls.__name__}"
cls._class_id = output.class_name
Expand Down Expand Up @@ -432,8 +434,10 @@ async def init_class(

async def init_beanie(
database: AsyncIOMotorDatabase = None,
connection_string: str = None,
document_models: List[Union[Type["DocType"], Type["View"], str]] = None,
connection_string: Optional[str] = None,
document_models: Optional[
List[Union[Type["DocType"], Type["View"], str]]
] = None,
allow_index_dropping: bool = False,
recreate_views: bool = False,
):
Expand Down

0 comments on commit b90703c

Please sign in to comment.