diff --git a/.travis.yml b/.travis.yml index 712979c..c000fa8 100644 --- a/.travis.yml +++ b/.travis.yml @@ -4,8 +4,10 @@ python: - "3.4" - "3.5" # command to install dependencies -install: +install: + - make requirements - pip install -r requirements-test.txt + - if [[ $TRAVIS_PYTHON_VERSION == 3.3.* ]]; then pip install -r requirements-py33.txt --use-mirrors; fi - pip install coveralls # command to run tests script: diff --git a/Makefile b/Makefile index 1f3997b..ea0dd49 100644 --- a/Makefile +++ b/Makefile @@ -31,7 +31,7 @@ requirements-docs: run-tests: @echo "Running tests..." - nosetests --with-coverage -d --cover-package=dirty_models --cover-erase + nosetests --with-coverage -d --cover-package=dirty_models --cover-erase -x publish: @echo "Publishing new version on Pypi..." diff --git a/README.rst b/README.rst index 7512526..d386dad 100644 --- a/README.rst +++ b/README.rst @@ -85,6 +85,16 @@ Features Changelog --------- +Version 0.9.0 +------------- + +- New EnumField. +- Fixes on setup.py. +- Fixes on requirements. +- Fixes on formatter iters. +- Fixes on code. + + Version 0.8.1 ------------- @@ -101,10 +111,10 @@ Version 0.8.0 - Raise a RunTimeError exception if two fields use same alias in a model. - Fixed default docstrings. - Cleanup default data. Only real name fields are allowed to use as key. -- Added :meth:`~dirty_models.models.get_attrs_by_path` in order to get all values using path. -- Added :meth:`~dirty_models.models.get_1st_attr_by_path` in order to get first value using path. +- Added :meth:`~dirty_models.models.BaseModel.get_attrs_by_path` in order to get all values using path. +- Added :meth:`~dirty_models.models.BaseModel.get_1st_attr_by_path` in order to get first value using path. - Added option to access fields like in a dictionary, but using wildcards. Only for getters. - See: :meth:`~dirty_models.models.get_1st_attr_by_path`. + See: :meth:`~dirty_models.models.BaseModel.get_1st_attr_by_path`. - Added some documentation. Version 0.7.2 diff --git a/dirty_models/fields.py b/dirty_models/fields.py index c2c2260..5766efe 100644 --- a/dirty_models/fields.py +++ b/dirty_models/fields.py @@ -3,9 +3,11 @@ """ from datetime import datetime, date, time, timedelta +from enum import Enum from collections import Mapping from dateutil.parser import parse as dateutil_parse +from functools import wraps from .model_types import ListModel @@ -79,31 +81,70 @@ def delete_value(self, obj): """Removes field value from model""" obj.delete_field_value(self.name) + def _check_name(self): + if self._name is None: + raise AttributeError("Field name must be set") + def __get__(self, obj, cls=None): if obj is None: return self + + self._check_name() + if self._getter: return self._getter(self, obj, cls) - if self._name is None: - raise AttributeError("Field name must be set") + return self.get_value(obj) def __set__(self, obj, value): + self._check_name() + if self._setter: self._setter(self, obj, value) return - if self._name is None: - raise AttributeError("Field name must be set") if self.check_value(value) or self.can_use_value(value): self.set_value(obj, self.use_value(value)) def __delete__(self, obj): - if self._name is None: - raise AttributeError("Field name must be set") + self._check_name() self.delete_value(obj) +def can_use_enum(func): + """ + Decorator to use Enum value on type checks. + """ + + @wraps(func) + def inner(self, value): + if isinstance(value, Enum): + return self.check_value(value.value) or func(self, value.value) + + return func(self, value) + + return inner + + +def convert_enum(func): + """ + Decorator to use Enum value on type casts. + """ + + @wraps(func) + def inner(self, value): + try: + if self.check_value(value.value): + return value.value + return func(self, value.value) + except AttributeError: + pass + + return func(self, value) + + return inner + + class IntegerField(BaseField): """ It allows to use an integer as value in a field. @@ -114,17 +155,21 @@ class IntegerField(BaseField): * :class:`str` if all characters are digits + * :class:`~enum.Enum` if value of enum can be cast. + """ + @convert_enum def convert_value(self, value): return int(value) def check_value(self, value): return isinstance(value, int) + @can_use_enum def can_use_value(self, value): return isinstance(value, float) \ - or (isinstance(value, str) and value.isdigit()) + or (isinstance(value, str) and value.isdigit()) class FloatField(BaseField): @@ -136,18 +181,22 @@ class FloatField(BaseField): * :class:`int` * :class:`str` if all characters are digits and there is only one dot (``.``). + + * :class:`~enum.Enum` if value of enum can be cast. """ + @convert_enum def convert_value(self, value): return float(value) def check_value(self, value): return isinstance(value, float) + @can_use_enum def can_use_value(self, value): - return isinstance(value, int) or \ - (isinstance(value, str) and - value.replace('.', '', 1).isnumeric()) + return isinstance(value, int) \ + or (isinstance(value, str) and + value.replace('.', '', 1).isnumeric()) class BooleanField(BaseField): @@ -159,8 +208,11 @@ class BooleanField(BaseField): * :class:`int` ``0`` become ``False``, anything else ``True`` * :class:`str` ``true`` and ``yes`` become ``True``, anything else ``False``. It is case-insensitive. + + * :class:`~enum.Enum` if value of enum can be cast. """ + @convert_enum def convert_value(self, value): if isinstance(value, str): if value.lower().strip() in ['true', 'yes']: @@ -173,6 +225,7 @@ def convert_value(self, value): def check_value(self, value): return isinstance(value, bool) + @can_use_enum def can_use_value(self, value): return isinstance(value, (int, str)) @@ -187,14 +240,18 @@ class StringField(BaseField): * :class:`int` * :class:`float` + + * :class:`~enum.Enum` if value of enum can be cast. """ + @convert_enum def convert_value(self, value): return str(value) def check_value(self, value): return isinstance(value, str) + @can_use_enum def can_use_value(self, value): return isinstance(value, (int, float)) @@ -209,6 +266,8 @@ class StringIdField(StringField): * :class:`int` * :class:`float` + + * :class:`~enum.Enum` if value of enum can be cast. """ def set_value(self, obj, value): @@ -328,6 +387,8 @@ class TimeField(DateTimeBaseField): * :class:`int` will be used as timestamp. * :class:`~datetime.datetime` will get time part. + + * :class:`~enum.Enum` if value of enum can be cast. """ def __init__(self, parse_format=None, default_timezone=None, **kwargs): @@ -348,6 +409,7 @@ def __init__(self, parse_format=None, default_timezone=None, **kwargs): super(TimeField, self).__init__(parse_format=parse_format, **kwargs) self.default_timezone = default_timezone + @convert_enum def convert_value(self, value): if isinstance(value, list): return time(*value) @@ -370,6 +432,7 @@ def convert_value(self, value): def check_value(self, value): return isinstance(value, time) + @can_use_enum def can_use_value(self, value): return isinstance(value, (int, str, datetime, list, dict)) @@ -402,8 +465,11 @@ class DateField(DateTimeBaseField): * :class:`int` will be used as timestamp. * :class:`~datetime.datetime` will get date part. + + * :class:`~enum.Enum` if value of enum can be cast. """ + @convert_enum def convert_value(self, value): if isinstance(value, list): return date(*value) @@ -426,6 +492,7 @@ def convert_value(self, value): def check_value(self, value): return type(value) is date + @can_use_enum def can_use_value(self, value): return isinstance(value, (int, str, datetime, list, dict)) @@ -445,6 +512,8 @@ class DateTimeField(DateTimeBaseField): * :class:`int` will be used as timestamp. * :class:`~datetime.date` will set date part. + + * :class:`~enum.Enum` if value of enum can be cast. """ def __init__(self, parse_format=None, default_timezone=None, force_timezone=False, **kwargs): @@ -471,6 +540,7 @@ def __init__(self, parse_format=None, default_timezone=None, force_timezone=Fals self.default_timezone = default_timezone self.force_timezone = force_timezone + @convert_enum def convert_value(self, value): if isinstance(value, list): return datetime(*value) @@ -493,6 +563,7 @@ def convert_value(self, value): def check_value(self, value): return type(value) is datetime + @can_use_enum def can_use_value(self, value): return isinstance(value, (int, str, date, dict, list)) @@ -523,8 +594,10 @@ class TimedeltaField(BaseField): * :class:`int` as seconds. + * :class:`~enum.Enum` if value of enum can be cast. """ + @convert_enum def convert_value(self, value): if isinstance(value, (int, float)): return timedelta(seconds=value) @@ -532,6 +605,7 @@ def convert_value(self, value): def check_value(self, value): return type(value) is timedelta + @can_use_enum def can_use_value(self, value): return isinstance(value, (int, float)) @@ -607,7 +681,6 @@ def __set__(self, obj, value): class InnerFieldTypeMixin: - def __init__(self, field_type=None, **kwargs): self._field_type = None if isinstance(field_type, tuple): diff --git a/dirty_models/models.py b/dirty_models/models.py index f254bcc..0a7cb52 100644 --- a/dirty_models/models.py +++ b/dirty_models/models.py @@ -4,11 +4,12 @@ import itertools from datetime import datetime, date, time, timedelta +from enum import Enum from collections import Mapping from copy import deepcopy -from dirty_models.fields import DateField, TimeField, TimedeltaField +from dirty_models.fields import DateField, TimeField, TimedeltaField, EnumField from .base import BaseData, InnerFieldTypeMixin from .fields import IntegerField, FloatField, BooleanField, StringField, DateTimeField, \ BaseField, ModelField, ArrayField @@ -165,14 +166,15 @@ def __init__(self, data=None, flat=False, *args, **kwargs): BaseModel.__setattr__(self, '__modified_data__', {}) BaseModel.__setattr__(self, '__deleted_fields__', []) - self.unlock() - self.import_data(self.__default_data__) - if isinstance(data, (dict, Mapping)): - self.import_data(data) - self.import_data(kwargs) + from .base import Unlocker + with Unlocker(self): + self.import_data(self.__default_data__) + if isinstance(data, (dict, Mapping)): + self.import_data(data) + self.import_data(kwargs) + if flat: self.flat_data() - self.lock() def __reduce__(self): """ @@ -182,7 +184,7 @@ def __reduce__(self): return recover_model_from_data, (self.__class__, self.export_original_data(), self.export_modified_data(), self.export_deleted_fields(),) - def _get_real_name(self, name): + def get_real_name(self, name): obj = self.get_field_obj(name) try: return obj.name @@ -193,7 +195,7 @@ def set_field_value(self, name, value): """ Set the value to the field modified_data """ - name = self._get_real_name(name) + name = self.get_real_name(name) if name and self._can_write_field(name): if name in self.__deleted_fields__: @@ -214,7 +216,7 @@ def get_field_value(self, name): """ Get the field value from the modified data or the original one """ - name = self._get_real_name(name) + name = self.get_real_name(name) if not name or name in self.__deleted_fields__: return None @@ -227,7 +229,7 @@ def delete_field_value(self, name): """ Mark this field to be deleted """ - name = self._get_real_name(name) + name = self.get_real_name(name) if name and self._can_write_field(name): if name in self.__modified_data__: @@ -240,7 +242,7 @@ def reset_field_value(self, name): """ Resets value of a field """ - name = self._get_real_name(name) + name = self.get_real_name(name) if name and self._can_write_field(name): if name in self.__modified_data__: @@ -258,7 +260,7 @@ def is_modified_field(self, name): """ Returns whether a field is modified or not """ - name = self._get_real_name(name) + name = self.get_real_name(name) if name in self.__modified_data__ or name in self.__deleted_fields__: return True @@ -357,7 +359,7 @@ def get_original_field_value(self, name): """ Returns original field value or None """ - name = self._get_real_name(name) + name = self.get_real_name(name) try: value = self.__original_data__[name] @@ -515,9 +517,7 @@ def __contains__(self, item): @classmethod def get_field_obj(cls, name): obj_field = getattr(cls, name, None) - if not isinstance(obj_field, BaseField): - return None - return obj_field + return obj_field if isinstance(obj_field, BaseField) else None def _get_fields_by_path(self, field): @@ -684,6 +684,8 @@ def _get_field_type(self, key, value): return DateField(name=key) elif isinstance(value, timedelta): return TimedeltaField(name=key) + elif isinstance(value, Enum): + return EnumField(name=key, enum_class=type(value)) elif isinstance(value, (dict, BaseDynamicModel, Mapping)): return ModelField(name=key, model_class=self._dynamic_model or self.__class__) elif isinstance(value, BaseModel): @@ -803,11 +805,12 @@ def __reduce__(self): (self.get_field_type().__class__, self.get_field_type().export_definition())) - def _get_real_name(self, name): - new_name = super(HashMapModel, self)._get_real_name(name) - if not new_name: - return name - return new_name + def get_real_name(self, name): + new_name = super(HashMapModel, self).get_real_name(name) + return new_name if new_name else name + + def get_field_obj(self, name): + return super(HashMapModel, self).get_field_obj(name) or self._field_type def copy(self): """ @@ -895,11 +898,8 @@ def __init__(self, *args, **kwargs): self._dynamic_model = FastDynamicModel super(FastDynamicModel, self).__init__(*args, **kwargs) - def _get_real_name(self, name): - new_name = super(FastDynamicModel, self)._get_real_name(name) - if not new_name: - return name - return new_name + def get_real_name(self, name): + return super(FastDynamicModel, self).get_real_name(name) or name def get_validated_object(self, field_type, value): """ diff --git a/dirty_models/utils.py b/dirty_models/utils.py index 3380fab..789fe62 100644 --- a/dirty_models/utils.py +++ b/dirty_models/utils.py @@ -1,9 +1,12 @@ -import re -from json.encoder import JSONEncoder as BaseJSONEncoder from datetime import date, datetime, time, timedelta -from .fields import MultiTypeField, DateTimeBaseField +from enum import Enum +from json.encoder import JSONEncoder as BaseJSONEncoder + +import re + +from .fields import MultiTypeField from .model_types import ListModel -from .models import BaseModel, HashMapModel +from .models import BaseModel def underscore_to_camel(string): @@ -18,7 +21,6 @@ class BaseFormatterIter: class BaseFieldtypeFormatterIter(BaseFormatterIter): - def __init__(self, obj, field, parent_formatter): self.obj = obj self.field = field @@ -26,20 +28,11 @@ def __init__(self, obj, field, parent_formatter): class ListFormatterIter(BaseFieldtypeFormatterIter): - def __iter__(self): for item in self.obj: yield self.parent_formatter.format_field(self.field, item) -class HashMapFormatterIter(BaseFieldtypeFormatterIter): - - def __iter__(self): - for fieldname in self.obj.get_fields(): - value = self.obj.get_field_value(fieldname) - yield fieldname, self.parent_formatter.format_field(self.field, value) - - class BaseModelFormatterIter(BaseFormatterIter): """ Base formatter iterator for Dirty Models. @@ -52,32 +45,34 @@ def __iter__(self): fields = self.model.get_fields() for fieldname in fields: field = self.model.get_field_obj(fieldname) - yield field.name, self.format_field(field, - self.model.get_field_value(fieldname)) + name = self.model.get_real_name(fieldname) + yield name, self.format_field(field, + self.model.get_field_value(fieldname)) def format_field(self, field, value): if isinstance(field, MultiTypeField): return self.format_field(field.get_field_type_by_value(value), value) - elif isinstance(value, HashMapModel): - return HashMapFormatterIter(obj=value, field=value.get_field_type(), parent_formatter=self) elif isinstance(value, BaseModel): return self.__class__(value) elif isinstance(value, ListModel): return ListFormatterIter(obj=value, field=value.get_field_type(), parent_formatter=self) + elif isinstance(value, Enum): + return self.format_field(field, value.value) return value class ModelFormatterIter(BaseModelFormatterIter): - """ Iterate over model fields formatting them. """ def format_field(self, field, value): - if isinstance(value, (date, datetime, time)) and \ - isinstance(field, DateTimeBaseField): - return field.get_formatted_value(value) + if isinstance(value, (date, datetime, time)) and not isinstance(field, MultiTypeField): + try: + return field.get_formatted_value(value) + except AttributeError: + return str(value) elif isinstance(value, timedelta): return value.total_seconds() @@ -85,15 +80,16 @@ def format_field(self, field, value): class JSONEncoder(BaseJSONEncoder): - """ Json encoder for Dirty Models """ + default_model_iter = ModelFormatterIter + def default(self, obj): if isinstance(obj, BaseModel): - return {k: v for k, v in ModelFormatterIter(obj)} - elif isinstance(obj, (HashMapFormatterIter, ModelFormatterIter)): + return {k: v for k, v in self.default_model_iter(obj)} + elif isinstance(obj, (BaseModelFormatterIter)): return {k: v for k, v in obj} elif isinstance(obj, ListFormatterIter): return list(obj) diff --git a/docs/source/conf.py b/docs/source/conf.py index 1b36daa..d11e9ad 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -65,9 +65,9 @@ # built documents. # # The short X.Y version. -version = '0.7.0' +version = '0.9.0' # The full version, including alpha/beta/rc tags. -release = '0.7.0' +release = '0.9.0' # The language for content autogenerated by Sphinx. Refer to documentation # for a list of supported languages. diff --git a/requirements-docs.txt b/requirements-docs.txt index c72c2a9..3064c7e 100644 --- a/requirements-docs.txt +++ b/requirements-docs.txt @@ -1,3 +1,2 @@ sphinx -python-dateutil iso8601 \ No newline at end of file diff --git a/requirements-py33.txt b/requirements-py33.txt new file mode 100644 index 0000000..37eb288 --- /dev/null +++ b/requirements-py33.txt @@ -0,0 +1 @@ +enum34 \ No newline at end of file diff --git a/requirements-test.txt b/requirements-test.txt index 6db9492..74416d7 100644 --- a/requirements-test.txt +++ b/requirements-test.txt @@ -1,6 +1,5 @@ -pep8 < 1.6.0 +pep8 flake8 coverage nose -python-dateutil iso8601 \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index e69de29..4ea05ed 100644 --- a/requirements.txt +++ b/requirements.txt @@ -0,0 +1 @@ +python-dateutil \ No newline at end of file diff --git a/setup.py b/setup.py index 8ccb5f5..d92a3c3 100644 --- a/setup.py +++ b/setup.py @@ -1,12 +1,27 @@ -from setuptools import setup +import sys import os +import re +from setuptools import setup + +install_requires = ['python-dateutil'] + +if sys.version_info < (3, 4): + install_requires.append('enum34') + +with open(os.path.join(os.path.dirname(__file__), 'README.rst')) as desc_file: + long_desc = desc_file.read() + +invalid_roles = ['meth', 'class'] + +long_desc = re.sub(r':({}):`([^`]+)`'.format('|'.join(invalid_roles)), r'``\2``', long_desc, re.M) setup( name='dirty-models', url='https://github.com/alfred82santa/dirty-models', author='alfred82santa', - version='0.8.1', + version='0.9.0', author_email='alfred82santa@gmail.com', + license='BSD', classifiers=[ 'Intended Audience :: Developers', 'Programming Language :: Python', @@ -17,9 +32,9 @@ 'Development Status :: 4 - Beta'], packages=['dirty_models'], include_package_data=False, - install_requires=['python-dateutil'], + install_requires=install_requires, description="Dirty models for python 3", - long_description=open(os.path.join(os.path.dirname(__file__), 'README.rst')).read(), + long_description=long_desc, test_suite="nose.collector", tests_require="nose", zip_safe=True, diff --git a/tests/dirty_models/tests_fields.py b/tests/dirty_models/tests_fields.py index 33ecef6..5d74f70 100644 --- a/tests/dirty_models/tests_fields.py +++ b/tests/dirty_models/tests_fields.py @@ -15,35 +15,6 @@ class TestFields(TestCase): - def test_int_field_using_int(self): - field = IntegerField() - self.assertTrue(field.check_value(3)) - self.assertEqual(field.use_value(3), 3) - - def test_int_field_desc(self): - field = IntegerField() - self.assertEqual(field.export_definition(), {'alias': None, - 'doc': 'IntegerField field', - 'name': None, - 'read_only': False}) - - def test_int_field_using_float(self): - field = IntegerField() - self.assertFalse(field.check_value(3.0)) - self.assertTrue(field.can_use_value(3.0)) - self.assertEqual(field.use_value(3.0), 3) - - def test_int_field_using_str(self): - field = IntegerField() - self.assertFalse(field.check_value("3")) - self.assertTrue(field.can_use_value("3")) - self.assertEqual(field.use_value("3"), 3) - - def test_int_field_using_dict(self): - field = IntegerField() - self.assertFalse(field.check_value({})) - self.assertFalse(field.can_use_value({})) - def test_float_field_using_int(self): field = FloatField() self.assertFalse(field.check_value(3)) @@ -1298,6 +1269,72 @@ def test_array_field_no_autolist(self): self.assertEqual(self.model.export_data(), {}) +class IntegerFieldFieldTests(TestCase): + + class TestEnum(Enum): + value_1 = 1 + value_2 = '2' + value_3 = 3.2 + value_4 = 'value' + + def test_using_int(self): + field = IntegerField() + self.assertTrue(field.check_value(3)) + self.assertEqual(field.use_value(3), 3) + + def test_desc(self): + field = IntegerField() + self.assertEqual(field.export_definition(), {'alias': None, + 'doc': 'IntegerField field', + 'name': None, + 'read_only': False}) + + def test_using_float(self): + field = IntegerField() + self.assertFalse(field.check_value(3.0)) + self.assertTrue(field.can_use_value(3.0)) + self.assertEqual(field.use_value(3.0), 3) + + def test_using_str(self): + field = IntegerField() + self.assertFalse(field.check_value("3")) + self.assertTrue(field.can_use_value("3")) + self.assertEqual(field.use_value("3"), 3) + + def test_using_dict(self): + field = IntegerField() + self.assertFalse(field.check_value({})) + self.assertFalse(field.can_use_value({})) + + def test_using_int_enum(self): + field = IntegerField() + self.assertFalse(field.check_value(self.TestEnum.value_1)) + self.assertTrue(field.can_use_value(self.TestEnum.value_1)) + self.assertEqual(field.use_value(self.TestEnum.value_1), 1) + + def test_using_str_enum(self): + field = IntegerField() + self.assertFalse(field.check_value(self.TestEnum.value_2)) + self.assertTrue(field.can_use_value(self.TestEnum.value_2)) + self.assertEqual(field.use_value(self.TestEnum.value_2), 2) + + def test_using_float_enum(self): + field = IntegerField() + self.assertFalse(field.check_value(self.TestEnum.value_3)) + self.assertTrue(field.can_use_value(self.TestEnum.value_3)) + self.assertEqual(field.use_value(self.TestEnum.value_3), 3) + + def test_using_str_enum_fail(self): + field = IntegerField() + self.assertFalse(field.check_value(self.TestEnum.value_4)) + self.assertFalse(field.can_use_value(self.TestEnum.value_4)) + + def test_using_enum_fail(self): + field = IntegerField() + self.assertFalse(field.check_value(self.TestEnum)) + self.assertFalse(field.can_use_value(self.TestEnum)) + + class MultiTypeFieldSimpleTypesTests(TestCase): def setUp(self): @@ -1407,10 +1444,10 @@ def test_get_field_type_by_value_fail(self): multi_field.get_field_type_by_value({}) -class AutoreferenceModelTests(TestCase): +class AutoreferenceModelFieldTests(TestCase): def setUp(self): - super(AutoreferenceModelTests, self).setUp() + super(AutoreferenceModelFieldTests, self).setUp() class AutoreferenceModel(BaseModel): multi_field = MultiTypeField(field_types=[IntegerField(), (ArrayField, {"field_type": ModelField()})]) diff --git a/tests/dirty_models/tests_models.py b/tests/dirty_models/tests_models.py index 612bae7..e3e85e6 100644 --- a/tests/dirty_models/tests_models.py +++ b/tests/dirty_models/tests_models.py @@ -1,12 +1,15 @@ import pickle from datetime import datetime, date, time, timedelta +from enum import Enum + from functools import partial from unittest import TestCase from dirty_models.base import Unlocker from dirty_models.fields import (BaseField, IntegerField, FloatField, StringField, DateTimeField, ModelField, - ArrayField, BooleanField, DateField, TimeField, HashMapField, TimedeltaField) + ArrayField, BooleanField, DateField, TimeField, HashMapField, TimedeltaField, + EnumField) from dirty_models.models import BaseModel, DynamicModel, HashMapModel, FastDynamicModel, CamelCaseMeta INITIAL_DATA = { @@ -917,7 +920,7 @@ def tearDown(self): def _get_field_type(self, name): try: - return self.model.__class__.__dict__[name] + return self.model.get_field_obj(name) except KeyError: return None @@ -995,6 +998,15 @@ def test_set_list_value(self): self.assertIsInstance(self._get_field_type('test1'), ArrayField) self.assertIsInstance(self._get_field_type('test1').field_type, StringField) + def test_set_enum_value(self): + class TestEnum(Enum): + value_1 = 1 + + self.model.test1 = TestEnum.value_1 + self.assertEqual(self.model.export_data(), {"test1": TestEnum.value_1}) + self.assertIsInstance(self._get_field_type('test1'), EnumField) + self.assertEqual(self._get_field_type('test1').enum_class, TestEnum) + def test_set_empty_list_value(self): self.model.test1 = [] diff --git a/tests/dirty_models/tests_utils.py b/tests/dirty_models/tests_utils.py index 2342698..467525b 100644 --- a/tests/dirty_models/tests_utils.py +++ b/tests/dirty_models/tests_utils.py @@ -1,16 +1,15 @@ from datetime import datetime, date, timedelta +from enum import Enum from json import dumps, loads from unittest.case import TestCase from dirty_models.fields import StringIdField, IntegerField, DateTimeField, ArrayField, MultiTypeField, ModelField, \ - HashMapField, DateField, TimedeltaField + HashMapField, DateField, TimedeltaField, EnumField from dirty_models.models import BaseModel, DynamicModel, FastDynamicModel -from dirty_models.utils import underscore_to_camel, ModelFormatterIter, ListFormatterIter, HashMapFormatterIter, \ - JSONEncoder +from dirty_models.utils import underscore_to_camel, ModelFormatterIter, ListFormatterIter, JSONEncoder class UnderscoreToCamelTests(TestCase): - def test_no_underscore(self): self.assertEqual(underscore_to_camel('foobar'), 'foobar') @@ -28,6 +27,10 @@ def test_underscore_multi_number(self): class TestModel(BaseModel): + class TestEnum(Enum): + value_1 = 1 + value_2 = '2' + value_3 = date(year=2015, month=7, day=30) test_string_field_1 = StringIdField(name='other_field') test_int_field_1 = IntegerField() @@ -36,16 +39,17 @@ class TestModel(BaseModel): test_array_multitype = ArrayField(field_type=MultiTypeField(field_types=[IntegerField(), DateTimeField( parse_format="%Y-%m-%dT%H:%M:%S" - )])) + )])) test_model_field_1 = ArrayField(field_type=ArrayField(field_type=ModelField())) test_hash_map = HashMapField(field_type=DateField(parse_format="%Y-%m-%d date")) test_timedelta = TimedeltaField() + test_enum = EnumField(enum_class=TestEnum) + test_multi_field = MultiTypeField(field_types=[IntegerField(), + DateField(parse_format="%Y-%m-%d multi date")]) class ModelFormatterIterTests(TestCase): - def test_model_formatter(self): - model = TestModel(data={'test_string_field_1': 'foo', 'test_int_field_1': 4, 'test_datetime': datetime(year=2016, month=5, day=30, @@ -60,7 +64,9 @@ def test_model_formatter(self): 'test_model_field_1': [[{'test_datetime': datetime(year=2015, month=7, day=30, hour=22, minute=22, second=22)}]], 'test_hash_map': {'foo': date(year=2015, month=7, day=30)}, - 'test_timedelta': timedelta(seconds=32.1122)}) + 'test_timedelta': timedelta(seconds=32.1122), + 'test_enum': TestModel.TestEnum.value_3, + 'test_multi_field': date(year=2015, month=7, day=30)}) formatter = ModelFormatterIter(model) data = {k: v for k, v in formatter} @@ -75,9 +81,11 @@ def test_model_formatter(self): self.assertIsInstance(list(data['test_model_field_1'])[0], ListFormatterIter) self.assertEqual({k: v for k, v in list(list(data['test_model_field_1'])[0])[0]}, {'test_datetime': '2015-07-30T22:22:22'}) - self.assertIsInstance(data['test_hash_map'], HashMapFormatterIter) + self.assertIsInstance(data['test_hash_map'], ModelFormatterIter) self.assertEqual({k: v for k, v in data['test_hash_map']}, {'foo': '2015-07-30 date'}) self.assertEqual(data['test_timedelta'], 32.1122) + self.assertEqual(data['test_enum'], str(date(year=2015, month=7, day=30))) + self.assertEqual(data['test_multi_field'], '2015-07-30 multi date') def test_dynamic_model_formatter(self): model = DynamicModel(data={'test_string_field_1': 'foo', @@ -91,7 +99,9 @@ def test_dynamic_model_formatter(self): 'test_model_field_1': [[{'test_datetime': datetime(year=2015, month=7, day=30, hour=22, minute=22, second=22)}]], 'test_hash_map': {'foo': date(year=2015, month=7, day=30)}, - 'test_timedelta': timedelta(seconds=32.1122)}) + 'test_timedelta': timedelta(seconds=32.1122), + 'test_enum': TestModel.TestEnum.value_1, + 'test_multi_field': date(year=2015, month=7, day=30)}) formatter = ModelFormatterIter(model) data = {k: v for k, v in formatter} @@ -121,7 +131,9 @@ def test_fast_dynamic_model_formatter(self): hour=22, minute=22, second=22)}]], 'test_hash_map': {'foo': date(year=2015, month=7, day=30)}, - 'test_timedelta': timedelta(seconds=32.1122)}) + 'test_timedelta': timedelta(seconds=32.1122), + 'test_enum': TestModel.TestEnum.value_1, + 'test_multi_field': date(year=2015, month=7, day=30)}) formatter = ModelFormatterIter(model) data = {k: v for k, v in formatter} @@ -136,12 +148,11 @@ def test_fast_dynamic_model_formatter(self): self.assertIsInstance(data['test_hash_map'], ModelFormatterIter) self.assertEqual({k: v for k, v in data['test_hash_map']}, {'foo': '2015-07-30'}) self.assertEqual(data['test_timedelta'], 32.1122) + self.assertEqual(data['test_multi_field'], '2015-07-30') class JSONEncoderTests(TestCase): - def test_model_json(self): - model = TestModel(data={'test_string_field_1': 'foo', 'test_int_field_1': 4, 'test_datetime': datetime(year=2016, month=5, day=30, @@ -156,7 +167,9 @@ def test_model_json(self): 'test_model_field_1': [[{'test_datetime': datetime(year=2015, month=7, day=30, hour=22, minute=22, second=22)}]], 'test_hash_map': {'foo': date(year=2015, month=7, day=30)}, - 'test_timedelta': timedelta(seconds=32.1122)}) + 'test_timedelta': timedelta(seconds=32.1122), + 'test_enum': TestModel.TestEnum.value_1, + 'test_multi_field': date(year=2015, month=7, day=30)}) json_str = dumps(model, cls=JSONEncoder) @@ -168,12 +181,13 @@ def test_model_json(self): 'test_array_multitype': ['2015-05-30T22:22:22', 4, 5], 'test_model_field_1': [[{'test_datetime': '2015-07-30T22:22:22'}]], 'test_hash_map': {'foo': '2015-07-30 date'}, - 'test_timedelta': 32.1122} + 'test_timedelta': 32.1122, + 'test_enum': 1, + 'test_multi_field': '2015-07-30 multi date'} self.assertEqual(loads(json_str), data) def test_general_use_json(self): - data = {'foo': 3, 'bar': 'str'} json_str = dumps(data, cls=JSONEncoder) self.assertEqual(loads(json_str), data)