diff --git a/src/ansys/openapi/common/_api_client.py b/src/ansys/openapi/common/_api_client.py index d62c11d6..05627f07 100644 --- a/src/ansys/openapi/common/_api_client.py +++ b/src/ansys/openapi/common/_api_client.py @@ -373,7 +373,7 @@ def __deserialize(self, data: SerializedType, klass_name: str) -> DeserializedTy assert isinstance(data, str) return self.__deserialize_datetime(data) else: - assert isinstance(data, dict) + assert isinstance(data, (dict, str)) return self.__deserialize_model(data, klass) def call_api( @@ -794,20 +794,28 @@ def __deserialize_datetime(value: str) -> datetime.datetime: ) def __deserialize_model( - self, data: Dict, klass: Type[ModelBase] - ) -> Union[ModelBase, Dict]: - """Deserialize ``dict`` to model. + self, data: Union[Dict, str], klass: Type[ModelBase] + ) -> Union[ModelBase, Dict, str]: + """Deserialize model representation to model. Given a model type and the serialized data, deserialize into an instance of the model class. Parameters ---------- - data : Dict + data : Union[Dict, str] Serialized representation of the model object. klass : ModelType Type of the model to deserialize. """ + if not klass.swagger_types: + try: + klass.get_real_child_model(klass(), {}) + except NotImplementedError: + return data + except BaseException: + pass + kwargs = {} if klass.swagger_types is not None: for attr, attr_type in klass.swagger_types.items(): diff --git a/src/ansys/openapi/common/_base/_types.py b/src/ansys/openapi/common/_base/_types.py index 29bbeaef..30b9928a 100644 --- a/src/ansys/openapi/common/_base/_types.py +++ b/src/ansys/openapi/common/_base/_types.py @@ -37,7 +37,8 @@ def to_dict(self) -> Dict[str, DeserializedType]: def to_str(self) -> str: ... - def get_real_child_model(self, data: Dict) -> str: + def get_real_child_model(self, data: Union[Dict, str]) -> str: + """Classes with discriminators will override this method and may change the method signature.""" raise NotImplementedError() diff --git a/tests/models/__init__.py b/tests/models/__init__.py index 2453f790..64f11212 100644 --- a/tests/models/__init__.py +++ b/tests/models/__init__.py @@ -1,2 +1,3 @@ from .example_model import ExampleModel from .example_base_model import ExampleBaseModel +from .example_model_with_enum import ExampleModelWithEnum diff --git a/tests/models/example_model_with_enum.py b/tests/models/example_model_with_enum.py new file mode 100644 index 00000000..5691ad8e --- /dev/null +++ b/tests/models/example_model_with_enum.py @@ -0,0 +1,81 @@ +import pprint +from ansys.openapi.common import ModelBase + + +class ExampleModelWithEnum(ModelBase): + """NOTE: This class is auto generated by the swagger code generator program. + + Do not edit the class manually. + """ + + """ + allowed enum values + """ + EXCELLENT = "Excellent" + GOOD = "Good" + ACCEPTABLE = "Acceptable" + POOR = "Poor" + """ + Attributes: + swagger_types (dict): The key is attribute name + and the value is attribute type. + attribute_map (dict): The key is attribute name + and the value is json key in definition. + """ + swagger_types = {} + + attribute_map = {} + + subtype_mapping = {} + + def __init__(self): # noqa: E501 + """ExampleModelWithEnum - a model defined in Swagger""" # noqa: E501 + self.discriminator = None + + def to_dict(self): + """Returns the model properties as a dict""" + result = {} + + for attr in self.swagger_types.keys(): + value = getattr(self, attr) + if isinstance(value, list): + result[attr] = list( + map(lambda x: x.to_dict() if hasattr(x, "to_dict") else x, value) + ) + elif hasattr(value, "to_dict"): + result[attr] = value.to_dict() + elif isinstance(value, dict): + result[attr] = dict( + map( + lambda item: (item[0], item[1].to_dict()) + if hasattr(item[1], "to_dict") + else item, + value.items(), + ) + ) + else: + result[attr] = value + if issubclass(ExampleModelWithEnum, dict): + for key, value in self.items(): + result[key] = value + + return result + + def to_str(self): + """Returns the string representation of the model""" + return pprint.pformat(self.to_dict()) + + def __repr__(self): + """For `print` and `pprint`""" + return self.to_str() + + def __eq__(self, other): + """Returns true if both objects are equal""" + if not isinstance(other, ExampleModelWithEnum): + return False + + return self.__dict__ == other.__dict__ + + def __ne__(self, other): + """Returns true if both objects are not equal""" + return not self == other diff --git a/tests/test_api_client.py b/tests/test_api_client.py index 1c0d3b9f..a1a582c5 100644 --- a/tests/test_api_client.py +++ b/tests/test_api_client.py @@ -184,6 +184,15 @@ def test_serialize_model(self): serialized_model = self._client.sanitize_for_serialization(model_instance) assert serialized_model == model_dict + def test_serialize_enum_model(self): + from . import models + + self._client.setup_client(models) + model_instance = models.ExampleModelWithEnum().GOOD + model_value = "Good" + serialized_model = self._client.sanitize_for_serialization(model_instance) + assert serialized_model == model_value + class TestDeserialization: _test_value_list = ["foo", int(2), 2.0, True] @@ -309,6 +318,16 @@ def test_deserialize_model_with_discriminator(self): assert isinstance(deserialized_model, models.ExampleModel) assert deserialized_model == model_instance + def test_deserialize_enum_model(self): + from . import models + + self._client.setup_client(models) + model_instance = models.ExampleModelWithEnum().GOOD + model_value = "Good" + type_ref = "ExampleModelWithEnum" + serialized_model = self._client._ApiClient__deserialize(model_value, type_ref) + assert serialized_model == model_instance + @pytest.mark.parametrize( ("data", "target_type"), (