Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix EmbeddedModel generics definition with a custom key_name #269

Merged
merged 7 commits into from
Sep 25, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .codecov.yml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
codecov:
require_ci_to_pass: yes
notify:
after_n_builds: 12
after_n_builds: 14
31 changes: 16 additions & 15 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,15 @@ jobs:

compatibility-tests:
runs-on: ubuntu-latest
continue-on-error: true
strategy:
matrix:
python-version:
- 3.7
- 3.8
- 3.9
- "3.7"
- "3.8"
- "3.9"
- "3.10"
#- "3.11" # TODO: Enable once 3.11 is released

steps:
- uses: actions/checkout@v2
Expand Down Expand Up @@ -59,17 +61,19 @@ jobs:

tests:
runs-on: ubuntu-latest
continue-on-error: true
strategy:
matrix:
python-version:
- 3.7
- 3.8
- 3.9
- "3.7"
- "3.8"
- "3.9"
- "3.10"
#- "3.11" # TODO: Enable once 3.11 is released
mongo-version:
- 4.4
- 5.0
- 6.0
- "4.4"
- "5"
- "6"
mongo-mode:
- standalone
include:
Expand Down Expand Up @@ -117,6 +121,7 @@ jobs:

integrated-realworld-test:
runs-on: ubuntu-latest
timeout-minutes: 5
steps:
- uses: actions/checkout@v3
with:
Expand All @@ -135,18 +140,14 @@ jobs:
with:
python-version: "3.10"
cache: "poetry"
- name: Install dependencies
- name: Install dependencies (w/o ODMantic)
working-directory: fastapi-odmantic-realworld-example
run: |
echo "$(grep -v 'odmantic =' ./pyproject.toml)" > pyproject.toml
poetry install
- name: Build current ODMantic version
working-directory: odmantic-current
run: |
flit build
- name: Install current ODMantic version
working-directory: fastapi-odmantic-realworld-example
run: poetry add ../odmantic-current/dist/*.tar.gz
run: poetry run pip install ../odmantic-current/
- name: Start the MongoDB instance
uses: art049/mongodb-cluster-action@v0
id: mongodb-cluster-action
Expand Down
8 changes: 7 additions & 1 deletion Taskfile.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ tasks:
deps:
- task: "mongodb:check"
cmds:
- pytest -rs
- pytest -rs -n auto

default:
desc: |
Expand Down Expand Up @@ -95,3 +95,9 @@ tasks:
- flit install --deps=all
sources:
- pyproject.toml

clean:
cmds:
- rm -rf dist/
- rm -rf htmlcov/ ./.coverage ./coverage.xml
- rm -rf .task/ ./__release_notes__.md ./__version__.txt
6 changes: 6 additions & 0 deletions odmantic/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,12 @@ def __init__(self, key_name: str):
self.foreign_key_name = f"'{key_name}'"


class IncorrectGenericEmbeddedModelValue(ValueError):
def __init__(self, value: Any):
super().__init__("incorrect generic embedded model value")
self.value = value


class DocumentParsingError(ValidationError):
"""Unable to parse the document into an instance.

Expand Down
25 changes: 25 additions & 0 deletions odmantic/field.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,31 @@ def __init__(
self.model = model


class ODMEmbeddedGeneric(ODMField):
# Only dict,set and list are "officially" supported for now
__slots__ = ("model", "generic_origin")
__allowed_operators__ = set(("eq", "ne"))

def __init__(
self,
key_name: str,
model_config: Type[BaseODMConfig],
model: Type["EmbeddedModel"],
generic_origin: Any,
index: bool = False,
unique: bool = False,
):
super().__init__(
primary_field=False,
key_name=key_name,
model_config=model_config,
index=index,
unique=unique,
)
self.model = model
self.generic_origin = generic_origin


class KeyNameProxy(str):
"""Used to provide the `++` operator enabling reference key name creation"""

Expand Down
105 changes: 104 additions & 1 deletion odmantic/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@
from pydantic.tools import parse_obj_as
from pydantic.typing import is_classvar, resolve_annotations
from pydantic.utils import lenient_issubclass
from typing_extensions import dataclass_transform

from odmantic.bson import (
_BSON_SUBSTITUTED_FIELDS,
Expand All @@ -49,6 +48,7 @@
from odmantic.exceptions import (
DocumentParsingError,
ErrorList,
IncorrectGenericEmbeddedModelValue,
KeyNotFoundInDocumentError,
ReferencedDocumentNotFoundError,
)
Expand All @@ -58,12 +58,19 @@
ODMBaseField,
ODMBaseIndexableField,
ODMEmbedded,
ODMEmbeddedGeneric,
ODMField,
ODMFieldInfo,
ODMReference,
)
from odmantic.index import Index, ODMBaseIndex, ODMSingleFieldIndex
from odmantic.reference import ODMReferenceInfo
from odmantic.typing import (
dataclass_transform,
get_first_type_argument_subclassing,
get_origin,
is_type_argument_subclass,
)
from odmantic.utils import (
is_dunder,
raise_on_invalid_collection_name,
Expand Down Expand Up @@ -256,6 +263,42 @@ def __validate_cls_namespace__(name: str, namespace: Dict) -> None: # noqa C901
index=index,
unique=unique,
)

elif is_type_argument_subclass(field_type, EmbeddedModel):
if isinstance(value, ODMFieldInfo):
if value.primary_field:
raise TypeError(
"Declaring a generic type of embedded models as a primary "
f"field is not allowed: {field_name} in {name}"
)
namespace[field_name] = value.pydantic_field_info
key_name = (
value.key_name if value.key_name is not None else field_name
)
index = value.index
unique = value.unique
else:
key_name = field_name
index = False
unique = False
model = get_first_type_argument_subclassing(field_type, EmbeddedModel)
assert model is not None
if len(model.__references__) > 0:
raise TypeError(
"Declaring a generic type of embedded models containing "
f"references is not allowed: {field_name} in {name}"
)
generic_origin = get_origin(field_type)
assert generic_origin is not None
odm_fields[field_name] = ODMEmbeddedGeneric(
model=model,
generic_origin=generic_origin,
key_name=key_name,
model_config=config,
index=index,
unique=unique,
)

elif lenient_issubclass(field_type, Model):
if not isinstance(value, ODMReferenceInfo):
raise TypeError(
Expand Down Expand Up @@ -665,6 +708,16 @@ def __doc(
doc[field.key_name] = raw_doc[field_name][field.model.__primary_field__]
elif isinstance(field, ODMEmbedded):
doc[field.key_name] = self.__doc(raw_doc[field_name], field.model, None)
elif isinstance(field, ODMEmbeddedGeneric):
if field.generic_origin is dict:
doc[field.key_name] = {
item_key: self.__doc(item_value, field.model)
for item_key, item_value in raw_doc[field_name].items()
}
else:
doc[field.key_name] = [
self.__doc(item, field.model) for item in raw_doc[field_name]
]
elif field_name in model.__bson_serialized_fields__:
doc[field.key_name] = model.__fields__[field_name].type_.__bson__(
raw_doc[field_name]
Expand Down Expand Up @@ -766,6 +819,56 @@ def _parse_doc_to_obj( # noqa C901 # TODO: refactor document parsing
)
)
obj[field_name] = value
elif isinstance(field, ODMEmbeddedGeneric):
value = Undefined
raw_value = raw_doc.get(field.key_name, Undefined)
if raw_value is not Undefined:
if isinstance(raw_value, list) and (
field.generic_origin is list
or field.generic_origin is tuple
or field.generic_origin is set
):
value = []
for i, item in enumerate(raw_value):
sub_errors, item = field.model._parse_doc_to_obj(
item, base_loc=base_loc + (field_name, f"[{i}]")
)
if len(sub_errors) > 0:
errors.extend(sub_errors)
else:
value.append(item)
obj[field_name] = value
elif isinstance(raw_value, dict) and field.generic_origin is dict:
value = {}
for item_key, item_value in raw_value.items():
sub_errors, item_value = field.model._parse_doc_to_obj(
item_value,
base_loc=base_loc + (field_name, f'["{item_key}"]'),
)
if len(sub_errors) > 0:
errors.extend(sub_errors)
else:
value[item_key] = item_value
obj[field_name] = value
else:
errors.append(
ErrorWrapper(
exc=IncorrectGenericEmbeddedModelValue(raw_value),
loc=base_loc + (field_name,),
)
)
else:
if not field.is_required_in_doc():
value = field.get_default_importing_value()
if value is Undefined:
errors.append(
ErrorWrapper(
exc=KeyNotFoundInDocumentError(field.key_name),
loc=base_loc + (field_name,),
)
)
else:
obj[field_name] = value
else:
field = cast(ODMField, field)
value = raw_doc.get(field.key_name, Undefined)
Expand Down
34 changes: 32 additions & 2 deletions odmantic/typing.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,40 @@
import sys
from typing import Any
from typing import Callable as TypingCallable
from typing import Tuple, Type, TypeVar, Union

from pydantic.utils import lenient_issubclass

NoArgAnyCallable = TypingCallable[[], Any]

# Handles globally the typing imports from typing or the typing_extensions backport
if sys.version_info < (3, 8):
from typing_extensions import Literal
from typing_extensions import Literal, get_args, get_origin
else:
from typing import Literal, get_args, get_origin # noqa: F401

if sys.version_info < (3, 11):
from typing_extensions import dataclass_transform
else:
from typing import Literal # noqa: F401
# FIXME: add this back to coverage once 3.11 is released
from typing import dataclass_transform # noqa: F401 # pragma: no cover


def is_type_argument_subclass(
type_: Type, class_or_tuple: Union[Type[Any], Tuple[Type[Any], ...], None]
) -> bool:
args = get_args(type_)
return any(lenient_issubclass(arg, class_or_tuple) for arg in args)


T = TypeVar("T")


def get_first_type_argument_subclassing(
type_: Type, cls: Type[T]
) -> Union[Type[T], None]:
args: Tuple[Type, ...] = get_args(type_)
for arg in args:
if lenient_issubclass(arg, cls):
return arg
return None
8 changes: 7 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ test = [
"pytest-xdist ~= 2.1.0",
"pytest-asyncio ~= 0.16.0",
# "pytest-testmon ~= 1.3.1",
"pytest-sugar ~= 0.9.5",
"async-asgi-testclient ~= 1.4.4",
"asyncmock ~= 0.4.2",
"coverage[toml] ~= 6.2",
Expand Down Expand Up @@ -101,4 +102,9 @@ branch = true
[tool.coverage.report]
include = ["odmantic/*", "tests/*"]
omit = ["**/conftest.py"]
exclude_lines = ["pragma: no cover", "if TYPE_CHECKING:", "@pytest.mark.skip", "@abstractmethod"]
exclude_lines = [
"pragma: no cover",
"if TYPE_CHECKING:",
"@pytest.mark.skip",
"@abstractmethod",
]
Loading