Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: keep Mapping type #70

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
50 changes: 44 additions & 6 deletions pydantic/fields.py
@@ -1,9 +1,10 @@
import warnings
from collections import deque
from collections import defaultdict, deque
from collections.abc import Iterable as CollectionsIterable
from typing import (
TYPE_CHECKING,
Any,
DefaultDict,
Deque,
Dict,
FrozenSet,
Expand Down Expand Up @@ -211,6 +212,8 @@ def Schema(default: Any, **kwargs: Any) -> Any:
SHAPE_ITERABLE = 9
SHAPE_GENERIC = 10
SHAPE_DEQUE = 11
SHAPE_DICT = 12
SHAPE_DEFAULTDICT = 13
SHAPE_NAME_LOOKUP = {
SHAPE_LIST: 'List[{}]',
SHAPE_SET: 'Set[{}]',
Expand All @@ -219,8 +222,12 @@ def Schema(default: Any, **kwargs: Any) -> Any:
SHAPE_FROZENSET: 'FrozenSet[{}]',
SHAPE_ITERABLE: 'Iterable[{}]',
SHAPE_DEQUE: 'Deque[{}]',
SHAPE_DICT: 'Dict[{}]',
SHAPE_DEFAULTDICT: 'DefaultDict[{}]',
}

MAPPING_LIKE_SHAPES: Set[int] = {SHAPE_DEFAULTDICT, SHAPE_DICT, SHAPE_MAPPING}


class ModelField(Representation):
__slots__ = (
Expand Down Expand Up @@ -492,6 +499,14 @@ def _type_analysis(self) -> None: # noqa: C901 (ignore complexity)
elif issubclass(origin, Sequence):
self.type_ = get_args(self.type_)[0]
self.shape = SHAPE_SEQUENCE
elif issubclass(origin, DefaultDict):
self.key_field = self._create_sub_type(get_args(self.type_)[0], 'key_' + self.name, for_keys=True)
self.type_ = get_args(self.type_)[1]
self.shape = SHAPE_DEFAULTDICT
elif issubclass(origin, Dict):
self.key_field = self._create_sub_type(get_args(self.type_)[0], 'key_' + self.name, for_keys=True)
self.type_ = get_args(self.type_)[1]
self.shape = SHAPE_DICT
elif issubclass(origin, Mapping):
self.key_field = self._create_sub_type(get_args(self.type_)[0], 'key_' + self.name, for_keys=True)
self.type_ = get_args(self.type_)[1]
Expand Down Expand Up @@ -608,8 +623,8 @@ def validate(

if self.shape == SHAPE_SINGLETON:
v, errors = self._validate_singleton(v, values, loc, cls)
elif self.shape == SHAPE_MAPPING:
v, errors = self._validate_mapping(v, values, loc, cls)
elif self.shape in MAPPING_LIKE_SHAPES:
v, errors = self._validate_mapping_like(v, values, loc, cls)
elif self.shape == SHAPE_TUPLE:
v, errors = self._validate_tuple(v, values, loc, cls)
elif self.shape == SHAPE_ITERABLE:
Expand Down Expand Up @@ -726,7 +741,7 @@ def _validate_tuple(
else:
return tuple(result), None

def _validate_mapping(
def _validate_mapping_like(
self, v: Any, values: Dict[str, Any], loc: 'LocStr', cls: Optional['ModelOrDc']
) -> 'ValidateReturn':
try:
Expand All @@ -752,8 +767,31 @@ def _validate_mapping(
result[key_result] = value_result
if errors:
return v, errors
else:
elif self.shape == SHAPE_DICT:
return result, None
elif self.shape == SHAPE_DEFAULTDICT:
return defaultdict(self.type_, result), None
else:
return self._get_mapping_value(v, result), None

def _get_mapping_value(self, original: T, converted: Dict[Any, Any]) -> Union[T, Dict[Any, Any]]:
"""
When type is `Mapping[KT, KV]` (or another unsupported mapping), we try to avoid
coercing to `dict` unwillingly.
"""
original_cls = original.__class__

if original_cls in {dict, Dict}:
return converted
elif original_cls in {defaultdict, DefaultDict}:
return defaultdict(self.type_, converted)
else:
try:
# Counter, OrderedDict, UserDict, ...
return original_cls(converted) # type: ignore
except TypeError:
warnings.warn(f'Could not convert dictionary to {original_cls.__name__!r}', UserWarning)
return converted

def _validate_singleton(
self, v: Any, values: Dict[str, Any], loc: 'LocStr', cls: Optional['ModelOrDc']
Expand Down Expand Up @@ -796,7 +834,7 @@ def _type_display(self) -> PyObjectStr:
t = display_as_type(self.type_)

# have to do this since display_as_type(self.outer_type_) is different (and wrong) on python 3.6
if self.shape == SHAPE_MAPPING:
if self.shape in MAPPING_LIKE_SHAPES:
t = f'Mapping[{display_as_type(self.key_field.type_)}, {t}]' # type: ignore
elif self.shape == SHAPE_TUPLE:
t = 'Tuple[{}]'.format(', '.join(display_as_type(f.type_) for f in self.sub_fields)) # type: ignore
Expand Down
5 changes: 3 additions & 2 deletions pydantic/main.py
Expand Up @@ -29,7 +29,7 @@
from .class_validators import ValidatorGroup, extract_root_validators, extract_validators, inherit_validators
from .error_wrappers import ErrorWrapper, ValidationError
from .errors import ConfigError, DictError, ExtraError, MissingError
from .fields import SHAPE_MAPPING, ModelField, ModelPrivateAttr, PrivateAttr, Undefined
from .fields import MAPPING_LIKE_SHAPES, ModelField, ModelPrivateAttr, PrivateAttr, Undefined
from .json import custom_pydantic_encoder, pydantic_encoder
from .parse import Protocol, load_file, load_str_bytes
from .schema import default_ref_template, model_schema
Expand Down Expand Up @@ -524,7 +524,8 @@ def json(
@classmethod
def parse_obj(cls: Type['Model'], obj: Any) -> 'Model':
if cls.__custom_root_type__ and (
not (isinstance(obj, dict) and obj.keys() == {ROOT_KEY}) or cls.__fields__[ROOT_KEY].shape == SHAPE_MAPPING
not (isinstance(obj, dict) and obj.keys() == {ROOT_KEY})
or cls.__fields__[ROOT_KEY].shape in MAPPING_LIKE_SHAPES
):
obj = {ROOT_KEY: obj}
elif not isinstance(obj, dict):
Expand Down
4 changes: 2 additions & 2 deletions pydantic/schema.py
Expand Up @@ -26,10 +26,10 @@
from uuid import UUID

from .fields import (
MAPPING_LIKE_SHAPES,
SHAPE_FROZENSET,
SHAPE_ITERABLE,
SHAPE_LIST,
SHAPE_MAPPING,
SHAPE_SEQUENCE,
SHAPE_SET,
SHAPE_SINGLETON,
Expand Down Expand Up @@ -446,7 +446,7 @@ def field_type_schema(
if field.shape in {SHAPE_SET, SHAPE_FROZENSET}:
f_schema['uniqueItems'] = True

elif field.shape == SHAPE_MAPPING:
elif field.shape in MAPPING_LIKE_SHAPES:
f_schema = {'type': 'object'}
key_field = cast(ModelField, field.key_field)
regex = getattr(key_field.type_, 'regex', None)
Expand Down
62 changes: 61 additions & 1 deletion tests/test_main.py
@@ -1,6 +1,7 @@
import sys
from collections import defaultdict
from enum import Enum
from typing import Any, Callable, ClassVar, Dict, List, Mapping, Optional, Type, get_type_hints
from typing import Any, Callable, ClassVar, DefaultDict, Dict, List, Mapping, Optional, Type, get_type_hints
from uuid import UUID, uuid4

import pytest
Expand Down Expand Up @@ -1473,3 +1474,62 @@ class Item(BaseModel):

assert id(image_1) == id(item.images[0])
assert id(image_2) == id(item.images[1])


def test_mapping_retains_type_subclass():
class Map(dict):
pass

class Model(BaseModel):
field: Mapping[str, Mapping[str, int]]

m = Model(field=Map(outer=Map(inner=42)))
assert isinstance(m.field, Map)
assert isinstance(m.field['outer'], Map)
assert m.field['outer']['inner'] == 42


def test_mapping_retains_type_defaultdict():
class Model(BaseModel):
field: Mapping[str, int]

d = defaultdict(int)
d[1] = '2'
d['3']

m = Model(field=d)
assert isinstance(m.field, defaultdict)
assert m.field['1'] == 2
assert m.field['3'] == 0


def test_mapping_retains_type_dict_fallback():
class Map(dict):
def __init__(self, *args, **kwargs):
if args or kwargs:
raise TypeError('test')
super().__init__(*args, **kwargs)

class Model(BaseModel):
field: Mapping[str, int]

d = Map()
d['one'] = 1
d['two'] = 2

with pytest.warns(UserWarning, match="Could not convert dictionary to 'Map'"):
m = Model(field=d)
assert isinstance(m.field, dict)
assert m.field['one'] == 1
assert m.field['two'] == 2


def test_default_dict():
class Model(BaseModel):
x: DefaultDict[int, str]

d = defaultdict(str)
d['1']
m = Model(x=d)
m.x['a']
assert repr(m) == "Model(x=defaultdict(<class 'str'>, {1: '', 'a': ''}))"