Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace custom 'hidden=True' field attribute with builtin 'exclude=True' #741

Merged
merged 2 commits into from
Oct 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
159 changes: 36 additions & 123 deletions beanie/odm/documents.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,13 @@
import asyncio
import warnings
from typing import (
TYPE_CHECKING,
AbstractSet,
Any,
ClassVar,
Dict,
Iterable,
List,
Mapping,
Optional,
Set,
Type,
TypeVar,
Union,
Expand Down Expand Up @@ -105,19 +103,23 @@
if IS_PYDANTIC_V2:
from pydantic import model_validator

if TYPE_CHECKING:
from pydantic.typing import AbstractSetIntStr, DictStrAny, MappingIntStrAny

DocType = TypeVar("DocType", bound="Document")
DocumentProjectionType = TypeVar("DocumentProjectionType", bound=BaseModel)


def json_schema_extra(schema: Dict[str, Any], model: Type["Document"]) -> None:
props = {}
for k, v in schema.get("properties", {}).items():
if not v.get("hidden", False):
props[k] = v
schema["properties"] = props
# remove excluded fields from the json schema
properties = schema.get("properties")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't remember how this call is implemented in Pydantic. Could you please check if the schema was not given as a mutable object here? If so, modifying it can lead to unexpected errors - it is better to clone it then.
I'll double-check this point from my side too.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not exactly sure what you mean, schema is a mutable object (dict) as shown in the annotation. The json_schema_extra callable returns None so it's only called for its side-effects, mutating schema. It would have no effect if the schema was cloned.

The calling code in Pydantic v2 is here: https://github.com/pydantic/pydantic/blob/main/pydantic/json_schema.py#L1364-L1368

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep, you are right. This is expected behavior.

if not properties:
return
for k, field in get_model_fields(model).items():
k = field.alias or k
if k not in properties:
continue
field_info = field if IS_PYDANTIC_V2 else field.field_info
if field_info.exclude:
del properties[k]


def document_alias_generator(s: str) -> str:
Expand Down Expand Up @@ -152,33 +154,17 @@ class Document(
else:

class Config:
json_encoders = {
ObjectId: lambda v: str(v),
}
json_encoders = {ObjectId: str}
allow_population_by_field_name = True
fields = {"id": "_id"}

@staticmethod
def schema_extra(
schema: Dict[str, Any], model: Type["Document"]
) -> None:
props = {}
for k, v in schema.get("properties", {}).items():
if not v.get("hidden", False):
props[k] = v
schema["properties"] = props
schema_extra = staticmethod(json_schema_extra)

id: Optional[PydanticObjectId] = Field(
default=None, description="MongoDB document ObjectID"
)

# State
if IS_PYDANTIC_V2:
revision_id: Optional[UUID] = Field(
default=None, json_schema_extra={"hidden": True}
)
else:
revision_id: Optional[UUID] = Field(default=None, hidden=True) # type: ignore
revision_id: Optional[UUID] = Field(default=None, exclude=True)
_previous_revision_id: Optional[UUID] = PrivateAttr(default=None)
_saved_state: Optional[Dict[str, Any]] = PrivateAttr(default=None)
_previous_saved_state: Optional[Dict[str, Any]] = PrivateAttr(default=None)
Expand All @@ -195,9 +181,6 @@ def schema_extra(
# Database
_database_major_version: ClassVar[int] = 4

# Other
_hidden_fields: ClassVar[Set[str]] = set()

def _swap_revision(self):
if self.get_settings().use_revision:
self._previous_revision_id = self.revision_id
Expand Down Expand Up @@ -1068,98 +1051,28 @@ async def inspect_collection(
return inspection_result

@classmethod
def get_hidden_fields(cls):
return set(
attribute_name
for attribute_name, model_field in get_model_fields(cls).items()
if get_extra_field_info(model_field, "hidden") is True
def check_hidden_fields(cls):
hidden_fields = [
(name, field)
for name, field in get_model_fields(cls).items()
if get_extra_field_info(field, "hidden") is True
]
if not hidden_fields:
return
warnings.warn(
f"{cls.__name__}: 'hidden=True' is deprecated, please use 'exclude=True'",
DeprecationWarning,
)

if IS_PYDANTIC_V2:

def model_dump(
self,
*,
mode="python",
include: Union["AbstractSetIntStr", "MappingIntStrAny"] = None,
exclude: Union["AbstractSetIntStr", "MappingIntStrAny"] = None,
by_alias: bool = False,
exclude_hidden: bool = True,
exclude_unset: bool = False,
exclude_defaults: bool = False,
exclude_none: bool = False,
round_trip: bool = False,
warnings: bool = True,
) -> "DictStrAny":
"""
Overriding of the respective method from Pydantic
Hides fields, marked as "hidden
"""
if exclude_hidden:
if isinstance(exclude, AbstractSet):
exclude = {*self._hidden_fields, *exclude}
elif isinstance(exclude, Mapping):
exclude = dict(
{k: True for k in self._hidden_fields}, **exclude
) # type: ignore
elif exclude is None:
exclude = self._hidden_fields

kwargs = {
"include": include,
"exclude": exclude,
"by_alias": by_alias,
"exclude_unset": exclude_unset,
"exclude_defaults": exclude_defaults,
"exclude_none": exclude_none,
"round_trip": round_trip,
"warnings": warnings,
}

return super().model_dump(**kwargs)

else:

def dict(
self,
*,
include: Union["AbstractSetIntStr", "MappingIntStrAny"] = None,
exclude: Union["AbstractSetIntStr", "MappingIntStrAny"] = None,
by_alias: bool = False,
skip_defaults: bool = False,
exclude_hidden: bool = True,
exclude_unset: bool = False,
exclude_defaults: bool = False,
exclude_none: bool = False,
) -> "DictStrAny":
"""
Overriding of the respective method from Pydantic
Hides fields, marked as "hidden
"""
if exclude_hidden:
if isinstance(exclude, AbstractSet):
exclude = {*self._hidden_fields, *exclude}
elif isinstance(exclude, Mapping):
exclude = dict(
{k: True for k in self._hidden_fields}, **exclude
) # type: ignore
elif exclude is None:
exclude = self._hidden_fields

kwargs = {
"include": include,
"exclude": exclude,
"by_alias": by_alias,
"exclude_unset": exclude_unset,
"exclude_defaults": exclude_defaults,
"exclude_none": exclude_none,
}

# TODO: Remove this check when skip_defaults are no longer supported
if skip_defaults:
kwargs["skip_defaults"] = skip_defaults

return super().dict(**kwargs)
if IS_PYDANTIC_V2:
for name, field in hidden_fields:
field.exclude = True
del field.json_schema_extra["hidden"]
cls.model_rebuild(force=True)
else:
for name, field in hidden_fields:
field.field_info.exclude = True
del field.field_info.extra["hidden"]
cls.__exclude_fields__[name] = True

@wrap_with_actions(event_type=EventTypes.VALIDATE_ON_SAVE)
async def validate_self(self, *args, **kwargs):
Expand Down
2 changes: 1 addition & 1 deletion beanie/odm/utils/init.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,7 +373,7 @@ def check_nested_links(
cls._link_fields[k] = link_info
check_nested_links(link_info, prev_models=[])

cls._hidden_fields = cls.get_hidden_fields()
cls.check_hidden_fields()

@staticmethod
def init_actions(cls):
Expand Down
2 changes: 2 additions & 0 deletions tests/odm/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
DocumentWithCustomIdUUID,
DocumentWithCustomInit,
DocumentWithDecimalField,
DocumentWithDeprecatedHiddenField,
DocumentWithExtras,
DocumentWithHttpUrlField,
DocumentWithIndexMerging1,
Expand Down Expand Up @@ -199,6 +200,7 @@ async def init(db):
DocumentTestModelWithComplexIndex,
DocumentTestModelFailInspection,
DocumentWithBsonEncodersFiledsTypes,
DocumentWithDeprecatedHiddenField,
DocumentWithCustomFiledsTypes,
DocumentWithCustomIdUUID,
DocumentWithCustomIdInt,
Expand Down
2 changes: 2 additions & 0 deletions tests/odm/documents/test_update.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@ async def test_replace(document):
new_doc = document.model_copy(update=update_data)
else:
new_doc = document.copy(update=update_data)
# pydantic v1 doesn't copy excluded fields
new_doc.test_list = document.test_list
# document.test_str = "REPLACED_VALUE"
await new_doc.replace()
new_document = await DocumentTestModel.get(document.id)
Expand Down
20 changes: 9 additions & 11 deletions tests/odm/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,13 +146,7 @@ class DocumentTestModel(Document):
test_int: int
test_doc: SubDocument
test_str: str

if IS_PYDANTIC_V2:
test_list: List[SubDocument] = Field(
json_schema_extra={"hidden": True}
)
else:
test_list: List[SubDocument] = Field(hidden=True)
test_list: List[SubDocument] = Field(exclude=True)

class Settings:
use_cache = True
Expand Down Expand Up @@ -242,6 +236,13 @@ class Settings:
name = "DocumentTestModel"


class DocumentWithDeprecatedHiddenField(Document):
if IS_PYDANTIC_V2:
test_hidden: List[str] = Field(json_schema_extra={"hidden": True})
else:
test_hidden: List[str] = Field(hidden=True)


class DocumentWithCustomIdUUID(Document):
id: UUID = Field(default_factory=uuid4)
name: str
Expand Down Expand Up @@ -534,10 +535,7 @@ class House(Document):
roof: Optional[Link[Roof]] = None
yards: Optional[List[Link[Yard]]] = None
height: Indexed(int) = 2
if IS_PYDANTIC_V2:
name: Indexed(str) = Field(json_schema_extra={"hidden": True})
else:
name: Indexed(str) = Field(hidden=True)
name: Indexed(str) = Field(exclude=True)

if IS_PYDANTIC_V2:
model_config = ConfigDict(
Expand Down
13 changes: 12 additions & 1 deletion tests/odm/test_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
DocumentTestModel,
DocumentWithBsonEncodersFiledsTypes,
DocumentWithCustomFiledsTypes,
DocumentWithDeprecatedHiddenField,
Sample,
)

Expand Down Expand Up @@ -106,14 +107,24 @@ async def test_custom_filed_types():
)


async def test_hidden(document):
async def test_excluded(document):
document = await DocumentTestModel.find_one()
if IS_PYDANTIC_V2:
assert "test_list" not in document.model_dump()
else:
assert "test_list" not in document.dict()


async def test_hidden():
document = DocumentWithDeprecatedHiddenField(test_hidden=["abc", "def"])
await document.insert()
document = await DocumentWithDeprecatedHiddenField.find_one()
if IS_PYDANTIC_V2:
assert "test_hidden" not in document.model_dump()
else:
assert "test_hidden" not in document.dict()


def test_revision_id_not_in_schema():
"""Check if there is a `revision_id` slipping into the schema."""

Expand Down