Skip to content

Commit

Permalink
Replace custom 'hidden=True' field attribute with builtin 'exclude=True'
Browse files Browse the repository at this point in the history
  • Loading branch information
gsakkis committed Oct 13, 2023
1 parent d9eb71d commit c87b226
Show file tree
Hide file tree
Showing 5 changed files with 19 additions and 141 deletions.
141 changes: 14 additions & 127 deletions beanie/odm/documents.py
@@ -1,15 +1,12 @@
import asyncio
from typing import (
TYPE_CHECKING,
AbstractSet,
Any,
ClassVar,
Dict,
Iterable,
List,
Mapping,
Optional,
Set,
Type,
TypeVar,
Union,
Expand Down Expand Up @@ -86,7 +83,6 @@
from beanie.odm.utils.parsing import merge_models
from beanie.odm.utils.pydantic import (
IS_PYDANTIC_V2,
get_extra_field_info,
get_field_type,
get_model_dump,
get_model_fields,
Expand All @@ -105,19 +101,23 @@
if IS_PYDANTIC_V2:
from pydantic import model_validator

if TYPE_CHECKING:
from pydantic.typing import AbstractSetIntStr, DictStrAny, MappingIntStrAny

DocType = TypeVar("DocType", bound="Document")
DocumentProjectionType = TypeVar("DocumentProjectionType", bound=BaseModel)


def json_schema_extra(schema: Dict[str, Any], model: Type["Document"]) -> None:
props = {}
for k, v in schema.get("properties", {}).items():
if not v.get("hidden", False):
props[k] = v
schema["properties"] = props
# remove excluded fields from the json schema
properties = schema.get("properties")
if not properties:
return
for k, field in get_model_fields(model).items():
k = field.alias or k
if k not in properties:
continue
field_info = field if IS_PYDANTIC_V2 else field.field_info
if field_info.exclude:
del properties[k]


def document_alias_generator(s: str) -> str:
Expand Down Expand Up @@ -152,33 +152,17 @@ class Document(
else:

class Config:
json_encoders = {
ObjectId: lambda v: str(v),
}
json_encoders = {ObjectId: str}
allow_population_by_field_name = True
fields = {"id": "_id"}

@staticmethod
def schema_extra(
schema: Dict[str, Any], model: Type["Document"]
) -> None:
props = {}
for k, v in schema.get("properties", {}).items():
if not v.get("hidden", False):
props[k] = v
schema["properties"] = props
schema_extra = staticmethod(json_schema_extra)

id: Optional[PydanticObjectId] = Field(
default=None, description="MongoDB document ObjectID"
)

# State
if IS_PYDANTIC_V2:
revision_id: Optional[UUID] = Field(
default=None, json_schema_extra={"hidden": True}
)
else:
revision_id: Optional[UUID] = Field(default=None, hidden=True) # type: ignore
revision_id: Optional[UUID] = Field(default=None, exclude=True)
_previous_revision_id: Optional[UUID] = PrivateAttr(default=None)
_saved_state: Optional[Dict[str, Any]] = PrivateAttr(default=None)
_previous_saved_state: Optional[Dict[str, Any]] = PrivateAttr(default=None)
Expand All @@ -195,9 +179,6 @@ def schema_extra(
# Database
_database_major_version: ClassVar[int] = 4

# Other
_hidden_fields: ClassVar[Set[str]] = set()

def _swap_revision(self):
if self.get_settings().use_revision:
self._previous_revision_id = self.revision_id
Expand Down Expand Up @@ -1062,100 +1043,6 @@ async def inspect_collection(
)
return inspection_result

@classmethod
def get_hidden_fields(cls):
return set(
attribute_name
for attribute_name, model_field in get_model_fields(cls).items()
if get_extra_field_info(model_field, "hidden") is True
)

if IS_PYDANTIC_V2:

def model_dump(
self,
*,
mode="python",
include: Union["AbstractSetIntStr", "MappingIntStrAny"] = None,
exclude: Union["AbstractSetIntStr", "MappingIntStrAny"] = None,
by_alias: bool = False,
exclude_hidden: bool = True,
exclude_unset: bool = False,
exclude_defaults: bool = False,
exclude_none: bool = False,
round_trip: bool = False,
warnings: bool = True,
) -> "DictStrAny":
"""
Overriding of the respective method from Pydantic
Hides fields, marked as "hidden
"""
if exclude_hidden:
if isinstance(exclude, AbstractSet):
exclude = {*self._hidden_fields, *exclude}
elif isinstance(exclude, Mapping):
exclude = dict(
{k: True for k in self._hidden_fields}, **exclude
) # type: ignore
elif exclude is None:
exclude = self._hidden_fields

kwargs = {
"include": include,
"exclude": exclude,
"by_alias": by_alias,
"exclude_unset": exclude_unset,
"exclude_defaults": exclude_defaults,
"exclude_none": exclude_none,
"round_trip": round_trip,
"warnings": warnings,
}

return super().model_dump(**kwargs)

else:

def dict(
self,
*,
include: Union["AbstractSetIntStr", "MappingIntStrAny"] = None,
exclude: Union["AbstractSetIntStr", "MappingIntStrAny"] = None,
by_alias: bool = False,
skip_defaults: bool = False,
exclude_hidden: bool = True,
exclude_unset: bool = False,
exclude_defaults: bool = False,
exclude_none: bool = False,
) -> "DictStrAny":
"""
Overriding of the respective method from Pydantic
Hides fields, marked as "hidden
"""
if exclude_hidden:
if isinstance(exclude, AbstractSet):
exclude = {*self._hidden_fields, *exclude}
elif isinstance(exclude, Mapping):
exclude = dict(
{k: True for k in self._hidden_fields}, **exclude
) # type: ignore
elif exclude is None:
exclude = self._hidden_fields

kwargs = {
"include": include,
"exclude": exclude,
"by_alias": by_alias,
"exclude_unset": exclude_unset,
"exclude_defaults": exclude_defaults,
"exclude_none": exclude_none,
}

# TODO: Remove this check when skip_defaults are no longer supported
if skip_defaults:
kwargs["skip_defaults"] = skip_defaults

return super().dict(**kwargs)

@wrap_with_actions(event_type=EventTypes.VALIDATE_ON_SAVE)
async def validate_self(self, *args, **kwargs):
# TODO: it can be sync, but needs some actions controller improvements
Expand Down
2 changes: 0 additions & 2 deletions beanie/odm/utils/init.py
Expand Up @@ -373,8 +373,6 @@ def check_nested_links(
cls._link_fields[k] = link_info
check_nested_links(link_info, prev_models=[])

cls._hidden_fields = cls.get_hidden_fields()

@staticmethod
def init_actions(cls):
"""
Expand Down
2 changes: 2 additions & 0 deletions tests/odm/documents/test_update.py
Expand Up @@ -63,6 +63,8 @@ async def test_replace(document):
new_doc = document.model_copy(update=update_data)
else:
new_doc = document.copy(update=update_data)
# pydantic v1 doesn't copy excluded fields
new_doc.test_list = document.test_list
# document.test_str = "REPLACED_VALUE"
await new_doc.replace()
new_document = await DocumentTestModel.get(document.id)
Expand Down
13 changes: 2 additions & 11 deletions tests/odm/models.py
Expand Up @@ -145,13 +145,7 @@ class DocumentTestModel(Document):
test_int: int
test_doc: SubDocument
test_str: str

if IS_PYDANTIC_V2:
test_list: List[SubDocument] = Field(
json_schema_extra={"hidden": True}
)
else:
test_list: List[SubDocument] = Field(hidden=True)
test_list: List[SubDocument] = Field(exclude=True)

class Settings:
use_cache = True
Expand Down Expand Up @@ -533,10 +527,7 @@ class House(Document):
roof: Optional[Link[Roof]] = None
yards: Optional[List[Link[Yard]]] = None
height: Indexed(int) = 2
if IS_PYDANTIC_V2:
name: Indexed(str) = Field(json_schema_extra={"hidden": True})
else:
name: Indexed(str) = Field(hidden=True)
name: Indexed(str) = Field(exclude=True)

if IS_PYDANTIC_V2:
model_config = ConfigDict(
Expand Down
2 changes: 1 addition & 1 deletion tests/odm/test_fields.py
Expand Up @@ -106,7 +106,7 @@ async def test_custom_filed_types():
)


async def test_hidden(document):
async def test_excluded(document):
document = await DocumentTestModel.find_one()
if IS_PYDANTIC_V2:
assert "test_list" not in document.model_dump()
Expand Down

0 comments on commit c87b226

Please sign in to comment.