Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 33 additions & 39 deletions msrest/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,31 +113,38 @@ def _get_subtype_map(cls):
return base._subtype_map
return {}

@classmethod
def _flatten_subtype(cls, key, objects):
if not '_subtype_map' in cls.__dict__:
return {}
result = dict(cls._subtype_map[key])
for valuetype in cls._subtype_map[key].values():
result.update(objects[valuetype]._flatten_subtype(key, objects))
return result

@classmethod
def _classify(cls, response, objects):
"""Check the class _subtype_map for any child classes.
We want to ignore any inheirited _subtype_maps.
We want to ignore any inherited _subtype_maps.
Remove the polymorphic key from the initial data.
"""
try:
map = cls.__dict__.get('_subtype_map', {})
for subtype_key in cls.__dict__.get('_subtype_map', {}).keys():
subtype_value = None

for _type, _classes in map.items():
classification = response.get(_type)
try:
return objects[_classes[classification]]
except KeyError:
pass
rest_api_response_key = _decode_attribute_map_key(cls._attribute_map[subtype_key]['key'])
subtype_value = response.pop(rest_api_response_key, None) or response.pop(subtype_key, None)
if subtype_value:
flatten_mapping_type = cls._flatten_subtype(subtype_key, objects)
return objects[flatten_mapping_type[subtype_value]]
return cls

for c in _classes:
try:
_cls = objects[_classes[c]]
return _cls._classify(response, objects)
except (KeyError, TypeError):
continue
raise TypeError("Object cannot be classified futher.")
except AttributeError:
raise TypeError("Object cannot be classified futher.")
def _decode_attribute_map_key(key):
"""This decode a key in an _attribute_map to the actual key we want to look at
inside the received data.

:param str key: A key string from the generated code
"""
return key.replace('\\.', '.')

def _convert_to_datatype(data, data_type, localtypes):
if data is None:
Expand All @@ -157,6 +164,7 @@ def _convert_to_datatype(data, data_type, localtypes):
elif issubclass(data_obj, Enum):
return data
elif not isinstance(data, data_obj):
data_obj = data_obj._classify(data, localtypes)
result = {
key: _convert_to_datatype(
data[key],
Expand Down Expand Up @@ -195,7 +203,7 @@ class Serializer(object):
"unique": lambda x, y: len(x) != len(set(x)),
"multiple": lambda x, y: x % y != 0
}
flattten = re.compile(r"(?<!\\)\.")
flatten = re.compile(r"(?<!\\)\.")

def __init__(self, classes=None):
self.serialize_type = {
Expand Down Expand Up @@ -241,14 +249,12 @@ def _serialize(self, target_obj, data_type=None, **kwargs):

try:
attributes = target_obj._attribute_map
self._classify_data(target_obj, class_name, serialized)

for attr, map in attributes.items():
attr_name = attr
debug_name = "{}.{}".format(class_name, attr_name)
try:
keys = self.flattten.split(map['key'])
keys = [k.replace('\\.', '.') for k in keys]
keys = self.flatten.split(map['key'])
keys = [_decode_attribute_map_key(k) for k in keys]
attr_type = map['type']
orig_attr = getattr(target_obj, attr)
validation = target_obj._validation.get(attr_name, {})
Expand Down Expand Up @@ -278,18 +284,6 @@ def _serialize(self, target_obj, data_type=None, **kwargs):
else:
return serialized

def _classify_data(self, target_obj, class_name, serialized):
"""Check whether this object is a child and therefor needs to be
classified in the message.
"""
try:
for _type, _classes in target_obj._get_subtype_map().items():
for ref, name in _classes.items():
if name == class_name:
serialized[_type] = ref
except AttributeError:
pass # TargetObj has no _subtype_map so we don't need to classify.

def body(self, data, data_type, **kwargs):
"""Serialize data intended for a request body.

Expand Down Expand Up @@ -752,9 +746,9 @@ def __call__(self, target_obj, response_data):
while '.' in key:
dict_keys = self.flatten.split(key)
if len(dict_keys) == 1:
key = dict_keys[0].replace('\\.', '.')
key = _decode_attribute_map_key(dict_keys[0])
break
working_key = dict_keys[0].replace('\\.', '.')
working_key = _decode_attribute_map_key(dict_keys[0])
working_data = working_data.get(working_key, data)
key = '.'.join(dict_keys[1:])

Expand Down Expand Up @@ -786,8 +780,8 @@ def _classify_target(self, target, data):

try:
target = target._classify(data, self.dependencies)
except (TypeError, AttributeError):
pass # Target has no subclasses, so can't classify further.
except AttributeError:
pass # Target is not a Model, no classify
return target, target.__class__.__name__

def _unpack_content(self, raw_data):
Expand Down
Loading