Skip to content

Commit

Permalink
implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
roman-right committed Sep 9, 2021
1 parent 9dce295 commit 8615709
Show file tree
Hide file tree
Showing 8 changed files with 115 additions and 8 deletions.
8 changes: 8 additions & 0 deletions beanie/exceptions.py
Expand Up @@ -24,3 +24,11 @@ class MigrationException(Exception):

class ReplaceError(Exception):
pass


class StateManagementIsTurnedOff(Exception):
pass


class StateNotSaved(Exception):
pass
Empty file.
57 changes: 56 additions & 1 deletion beanie/odm/documents.py
Expand Up @@ -13,7 +13,7 @@

from bson import ObjectId
from motor.motor_asyncio import AsyncIOMotorDatabase, AsyncIOMotorCollection
from pydantic import ValidationError, parse_obj_as
from pydantic import ValidationError, parse_obj_as, PrivateAttr
from pydantic.main import BaseModel
from pymongo.client_session import ClientSession
from pymongo.results import (
Expand Down Expand Up @@ -44,6 +44,7 @@
from beanie.odm.queries.update import UpdateMany
from beanie.odm.utils.collection import collection_factory
from beanie.odm.utils.dump import get_dict
from beanie.odm.utils.state import saved_state_needed

DocType = TypeVar("DocType", bound="Document")
DocumentProjectionType = TypeVar("DocumentProjectionType", bound=BaseModel)
Expand All @@ -66,6 +67,8 @@ class Document(BaseModel, UpdateMethods):

id: Optional[PydanticObjectId] = None

_saved_state: Optional[Dict[str, Any]] = PrivateAttr(default=None)

def __init__(self, *args, **kwargs):
super(Document, self).__init__(*args, **kwargs)
self.get_motor_collection()
Expand All @@ -84,6 +87,8 @@ async def _sync(self) -> None:
)
for key, value in dict(new_instance).items():
setattr(self, key, value)
if self.use_state_management():
self._save_state()

@wrap_with_actions(EventTypes.INSERT)
async def insert(
Expand All @@ -100,6 +105,8 @@ async def insert(
if not isinstance(new_id, self.__fields__["id"].type_):
new_id = self.__fields__["id"].type_(new_id)
self.id = new_id
if self.use_state_management():
self._save_state()
return self

async def create(
Expand Down Expand Up @@ -432,6 +439,8 @@ async def replace(
await self.find_one({"_id": self.id}).replace_one(
self, session=session
)
if self.use_state_management():
self._save_state()
return self

async def save(
Expand Down Expand Up @@ -648,6 +657,52 @@ async def inspect_collection(
)
return inspection_result

# State management

@classmethod
def use_state_management(cls) -> bool:
collection_meta = cls._get_collection_meta()
return collection_meta.use_state_management

def _save_state(self):
if self.use_state_management():
self._saved_state = self.dict()

@classmethod
def _parse_obj_saving_state(cls: Type[DocType], obj: Any) -> DocType:
result: DocType = cls.parse_obj(obj)
result._save_state()
return result

@saved_state_needed
def rollback(self) -> None:
for key, value in self._saved_state.items(): # type: ignore
setattr(self, key, value)

@property # type: ignore
@saved_state_needed
def is_changed(self) -> bool:
if self._saved_state == self.dict():
return False
return True

@saved_state_needed
def get_changes(self) -> Dict[str, Any]:
changes = {}
if self.is_changed:
current_state = self.dict()
for k, v in self._saved_state.items(): # type: ignore
if v != current_state[k]:
changes[k] = v
return changes

@saved_state_needed
async def save_changes(self) -> None:
if not self.is_changed:
return None
changes = self.get_changes()
await self.set(changes)

class Config:
json_encoders = {
ObjectId: lambda v: str(v),
Expand Down
11 changes: 5 additions & 6 deletions beanie/odm/queries/cursor.py
Expand Up @@ -11,6 +11,7 @@
)

from pydantic.main import BaseModel
from beanie.odm.utils.parsing import parse_obj

CursorResultType = TypeVar("CursorResultType")

Expand Down Expand Up @@ -40,11 +41,9 @@ async def __anext__(self) -> CursorResultType:
self.cursor = self.motor_cursor
next_item = await self.cursor.__anext__()
projection = self.get_projection_model()
return (
projection.parse_obj(next_item)
if projection is not None
else next_item
) # type: ignore
if projection is None:
return next_item
return parse_obj(projection, next_item) # type: ignore

async def to_list(
self, length: Optional[int] = None
Expand All @@ -64,6 +63,6 @@ async def to_list(
if projection is not None:
return cast(
List[CursorResultType],
[projection.parse_obj(i) for i in motor_list],
[parse_obj(projection, i) for i in motor_list],
)
return cast(List[CursorResultType], motor_list)
3 changes: 2 additions & 1 deletion beanie/odm/queries/find.py
Expand Up @@ -40,6 +40,7 @@
UpdateMany,
UpdateOne,
)
from beanie.odm.utils.parsing import parse_obj
from beanie.odm.utils.projection import get_projection

if TYPE_CHECKING:
Expand Down Expand Up @@ -620,5 +621,5 @@ def __await__(
if document is None:
return None
return cast(
FindQueryResultType, self.projection_model.parse_obj(document)
FindQueryResultType, parse_obj(self.projection_model, document)
)
2 changes: 2 additions & 0 deletions beanie/odm/utils/collection.py
Expand Up @@ -20,6 +20,7 @@ def validate(cls, v):

class CollectionInputParameters(BaseModel):
name: str = ""
use_state_management: bool = False
indexes: List[IndexModelField] = []

class Config:
Expand Down Expand Up @@ -95,5 +96,6 @@ class CollectionMeta:
name: str = collection_parameters.name
motor_collection: AsyncIOMotorCollection = collection
indexes: List = found_indexes
use_state_management: bool = collection_parameters.use_state_management

return CollectionMeta
9 changes: 9 additions & 0 deletions beanie/odm/utils/parsing.py
@@ -0,0 +1,9 @@
from typing import Any, Type

from pydantic import BaseModel


def parse_obj(model: Type[BaseModel], data: Any) -> BaseModel:
if hasattr(model, "_parse_obj_saving_state"):
return model._parse_obj_saving_state(data)
return model.parse_obj(data)
33 changes: 33 additions & 0 deletions beanie/odm/utils/state.py
@@ -0,0 +1,33 @@
import inspect
from functools import wraps
from typing import Callable, TYPE_CHECKING

from beanie.exceptions import StateManagementIsTurnedOff, StateNotSaved

if TYPE_CHECKING:
from beanie.odm.documents import DocType


def check_if_state_saved(self: "DocType"):
if not self.use_state_management():
raise StateManagementIsTurnedOff(
"State management is turned off for this document"
)
if self._saved_state is None:
raise StateNotSaved("No state was saved")


def saved_state_needed(f: Callable):
@wraps(f)
def sync_wrapper(self: "DocType", *args, **kwargs):
check_if_state_saved(self)
return f(self, *args, **kwargs)

@wraps(f)
async def async_wrapper(self: "DocType", *args, **kwargs):
check_if_state_saved(self)
return await f(self, *args, **kwargs)

if inspect.iscoroutinefunction(f):
return async_wrapper
return sync_wrapper

0 comments on commit 8615709

Please sign in to comment.