Skip to content

Commit

Permalink
Replace custom 'hidden=True' field attribute with builtin 'exclude=Tr…
Browse files Browse the repository at this point in the history
…ue' (#741)

* Replace custom 'hidden=True' field attribute with builtin 'exclude=True'

* Check and warn in case the 'hidden' parameter is used
  • Loading branch information
gsakkis committed Oct 22, 2023
1 parent 0718894 commit a6f1b8d
Show file tree
Hide file tree
Showing 6 changed files with 62 additions and 136 deletions.
159 changes: 36 additions & 123 deletions beanie/odm/documents.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,13 @@
import asyncio
import warnings
from typing import (
TYPE_CHECKING,
AbstractSet,
Any,
ClassVar,
Dict,
Iterable,
List,
Mapping,
Optional,
Set,
Type,
TypeVar,
Union,
Expand Down Expand Up @@ -105,19 +103,23 @@
if IS_PYDANTIC_V2:
from pydantic import model_validator

if TYPE_CHECKING:
from pydantic.typing import AbstractSetIntStr, DictStrAny, MappingIntStrAny

DocType = TypeVar("DocType", bound="Document")
DocumentProjectionType = TypeVar("DocumentProjectionType", bound=BaseModel)


def json_schema_extra(schema: Dict[str, Any], model: Type["Document"]) -> None:
props = {}
for k, v in schema.get("properties", {}).items():
if not v.get("hidden", False):
props[k] = v
schema["properties"] = props
# remove excluded fields from the json schema
properties = schema.get("properties")
if not properties:
return
for k, field in get_model_fields(model).items():
k = field.alias or k
if k not in properties:
continue
field_info = field if IS_PYDANTIC_V2 else field.field_info
if field_info.exclude:
del properties[k]


def document_alias_generator(s: str) -> str:
Expand Down Expand Up @@ -152,33 +154,17 @@ class Document(
else:

class Config:
json_encoders = {
ObjectId: lambda v: str(v),
}
json_encoders = {ObjectId: str}
allow_population_by_field_name = True
fields = {"id": "_id"}

@staticmethod
def schema_extra(
schema: Dict[str, Any], model: Type["Document"]
) -> None:
props = {}
for k, v in schema.get("properties", {}).items():
if not v.get("hidden", False):
props[k] = v
schema["properties"] = props
schema_extra = staticmethod(json_schema_extra)

id: Optional[PydanticObjectId] = Field(
default=None, description="MongoDB document ObjectID"
)

# State
if IS_PYDANTIC_V2:
revision_id: Optional[UUID] = Field(
default=None, json_schema_extra={"hidden": True}
)
else:
revision_id: Optional[UUID] = Field(default=None, hidden=True) # type: ignore
revision_id: Optional[UUID] = Field(default=None, exclude=True)
_previous_revision_id: Optional[UUID] = PrivateAttr(default=None)
_saved_state: Optional[Dict[str, Any]] = PrivateAttr(default=None)
_previous_saved_state: Optional[Dict[str, Any]] = PrivateAttr(default=None)
Expand All @@ -195,9 +181,6 @@ def schema_extra(
# Database
_database_major_version: ClassVar[int] = 4

# Other
_hidden_fields: ClassVar[Set[str]] = set()

def _swap_revision(self):
if self.get_settings().use_revision:
self._previous_revision_id = self.revision_id
Expand Down Expand Up @@ -1068,98 +1051,28 @@ async def inspect_collection(
return inspection_result

@classmethod
def get_hidden_fields(cls):
return set(
attribute_name
for attribute_name, model_field in get_model_fields(cls).items()
if get_extra_field_info(model_field, "hidden") is True
def check_hidden_fields(cls):
hidden_fields = [
(name, field)
for name, field in get_model_fields(cls).items()
if get_extra_field_info(field, "hidden") is True
]
if not hidden_fields:
return
warnings.warn(
f"{cls.__name__}: 'hidden=True' is deprecated, please use 'exclude=True'",
DeprecationWarning,
)

if IS_PYDANTIC_V2:

def model_dump(
self,
*,
mode="python",
include: Union["AbstractSetIntStr", "MappingIntStrAny"] = None,
exclude: Union["AbstractSetIntStr", "MappingIntStrAny"] = None,
by_alias: bool = False,
exclude_hidden: bool = True,
exclude_unset: bool = False,
exclude_defaults: bool = False,
exclude_none: bool = False,
round_trip: bool = False,
warnings: bool = True,
) -> "DictStrAny":
"""
Overriding of the respective method from Pydantic
Hides fields, marked as "hidden
"""
if exclude_hidden:
if isinstance(exclude, AbstractSet):
exclude = {*self._hidden_fields, *exclude}
elif isinstance(exclude, Mapping):
exclude = dict(
{k: True for k in self._hidden_fields}, **exclude
) # type: ignore
elif exclude is None:
exclude = self._hidden_fields

kwargs = {
"include": include,
"exclude": exclude,
"by_alias": by_alias,
"exclude_unset": exclude_unset,
"exclude_defaults": exclude_defaults,
"exclude_none": exclude_none,
"round_trip": round_trip,
"warnings": warnings,
}

return super().model_dump(**kwargs)

else:

def dict(
self,
*,
include: Union["AbstractSetIntStr", "MappingIntStrAny"] = None,
exclude: Union["AbstractSetIntStr", "MappingIntStrAny"] = None,
by_alias: bool = False,
skip_defaults: bool = False,
exclude_hidden: bool = True,
exclude_unset: bool = False,
exclude_defaults: bool = False,
exclude_none: bool = False,
) -> "DictStrAny":
"""
Overriding of the respective method from Pydantic
Hides fields, marked as "hidden
"""
if exclude_hidden:
if isinstance(exclude, AbstractSet):
exclude = {*self._hidden_fields, *exclude}
elif isinstance(exclude, Mapping):
exclude = dict(
{k: True for k in self._hidden_fields}, **exclude
) # type: ignore
elif exclude is None:
exclude = self._hidden_fields

kwargs = {
"include": include,
"exclude": exclude,
"by_alias": by_alias,
"exclude_unset": exclude_unset,
"exclude_defaults": exclude_defaults,
"exclude_none": exclude_none,
}

# TODO: Remove this check when skip_defaults are no longer supported
if skip_defaults:
kwargs["skip_defaults"] = skip_defaults

return super().dict(**kwargs)
if IS_PYDANTIC_V2:
for name, field in hidden_fields:
field.exclude = True
del field.json_schema_extra["hidden"]
cls.model_rebuild(force=True)
else:
for name, field in hidden_fields:
field.field_info.exclude = True
del field.field_info.extra["hidden"]
cls.__exclude_fields__[name] = True

@wrap_with_actions(event_type=EventTypes.VALIDATE_ON_SAVE)
async def validate_self(self, *args, **kwargs):
Expand Down
2 changes: 1 addition & 1 deletion beanie/odm/utils/init.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,7 +373,7 @@ def check_nested_links(
cls._link_fields[k] = link_info
check_nested_links(link_info, prev_models=[])

cls._hidden_fields = cls.get_hidden_fields()
cls.check_hidden_fields()

@staticmethod
def init_actions(cls):
Expand Down
2 changes: 2 additions & 0 deletions tests/odm/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
DocumentWithCustomIdUUID,
DocumentWithCustomInit,
DocumentWithDecimalField,
DocumentWithDeprecatedHiddenField,
DocumentWithExtras,
DocumentWithHttpUrlField,
DocumentWithIndexMerging1,
Expand Down Expand Up @@ -199,6 +200,7 @@ async def init(db):
DocumentTestModelWithComplexIndex,
DocumentTestModelFailInspection,
DocumentWithBsonEncodersFiledsTypes,
DocumentWithDeprecatedHiddenField,
DocumentWithCustomFiledsTypes,
DocumentWithCustomIdUUID,
DocumentWithCustomIdInt,
Expand Down
2 changes: 2 additions & 0 deletions tests/odm/documents/test_update.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@ async def test_replace(document):
new_doc = document.model_copy(update=update_data)
else:
new_doc = document.copy(update=update_data)
# pydantic v1 doesn't copy excluded fields
new_doc.test_list = document.test_list
# document.test_str = "REPLACED_VALUE"
await new_doc.replace()
new_document = await DocumentTestModel.get(document.id)
Expand Down
20 changes: 9 additions & 11 deletions tests/odm/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,13 +146,7 @@ class DocumentTestModel(Document):
test_int: int
test_doc: SubDocument
test_str: str

if IS_PYDANTIC_V2:
test_list: List[SubDocument] = Field(
json_schema_extra={"hidden": True}
)
else:
test_list: List[SubDocument] = Field(hidden=True)
test_list: List[SubDocument] = Field(exclude=True)

class Settings:
use_cache = True
Expand Down Expand Up @@ -242,6 +236,13 @@ class Settings:
name = "DocumentTestModel"


class DocumentWithDeprecatedHiddenField(Document):
if IS_PYDANTIC_V2:
test_hidden: List[str] = Field(json_schema_extra={"hidden": True})
else:
test_hidden: List[str] = Field(hidden=True)


class DocumentWithCustomIdUUID(Document):
id: UUID = Field(default_factory=uuid4)
name: str
Expand Down Expand Up @@ -534,10 +535,7 @@ class House(Document):
roof: Optional[Link[Roof]] = None
yards: Optional[List[Link[Yard]]] = None
height: Indexed(int) = 2
if IS_PYDANTIC_V2:
name: Indexed(str) = Field(json_schema_extra={"hidden": True})
else:
name: Indexed(str) = Field(hidden=True)
name: Indexed(str) = Field(exclude=True)

if IS_PYDANTIC_V2:
model_config = ConfigDict(
Expand Down
13 changes: 12 additions & 1 deletion tests/odm/test_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
DocumentTestModel,
DocumentWithBsonEncodersFiledsTypes,
DocumentWithCustomFiledsTypes,
DocumentWithDeprecatedHiddenField,
Sample,
)

Expand Down Expand Up @@ -106,14 +107,24 @@ async def test_custom_filed_types():
)


async def test_hidden(document):
async def test_excluded(document):
document = await DocumentTestModel.find_one()
if IS_PYDANTIC_V2:
assert "test_list" not in document.model_dump()
else:
assert "test_list" not in document.dict()


async def test_hidden():
document = DocumentWithDeprecatedHiddenField(test_hidden=["abc", "def"])
await document.insert()
document = await DocumentWithDeprecatedHiddenField.find_one()
if IS_PYDANTIC_V2:
assert "test_hidden" not in document.model_dump()
else:
assert "test_hidden" not in document.dict()


def test_revision_id_not_in_schema():
"""Check if there is a `revision_id` slipping into the schema."""

Expand Down

0 comments on commit a6f1b8d

Please sign in to comment.