diff --git a/aiorest_ws/conf/global_settings.py b/aiorest_ws/conf/global_settings.py index bd32280..32c7886 100644 --- a/aiorest_ws/conf/global_settings.py +++ b/aiorest_ws/conf/global_settings.py @@ -32,6 +32,7 @@ UNICODE_JSON = True COMPACT_JSON = True COERCE_DECIMAL_TO_STRING = True +UPLOADED_FILES_USE_URL = True # ----------------------------------------------- # Database diff --git a/aiorest_ws/db/orm/django/compat.py b/aiorest_ws/db/orm/django/compat.py new file mode 100644 index 0000000..f07ced3 --- /dev/null +++ b/aiorest_ws/db/orm/django/compat.py @@ -0,0 +1,77 @@ +# -*- coding: utf-8 -*- +""" +The `compat` module provides support for backwards compatibility with older +versions of Django, and compatibility wrappers around optional packages. +""" +import inspect + +import django + +from django.apps import apps +from django.core.exceptions import ImproperlyConfigured +from django.db import models + + +__all__ = [ + 'DecimalValidator', 'postgres_fields', 'JSONField', + '_resolve_model', 'get_related_model', 'get_remote_field', + 'value_from_object' +] + + +try: + from django.core.validators import DecimalValidator +except ImportError: + DecimalValidator = None + + +try: + from django.contrib.postgres import fields as postgres_fields +except ImportError: + postgres_fields = None + + +try: + from django.contrib.postgres.fields import JSONField +except ImportError: + JSONField = None + + +def _resolve_model(obj): + """ + Resolve supplied `obj` to a Django model class. + `obj` must be a Django model class itself, or a string + representation of one. Useful in situations like GH #1225 where + Django may not have resolved a string-based reference to a model in + another model's foreign key definition. + String representations should have the format: + 'appname.ModelName' + """ + if isinstance(obj, str) and len(obj.split('.')) == 2: + app_name, model_name = obj.split('.') + resolved_model = apps.get_model(app_name, model_name) + if resolved_model is None: + msg = "Django did not return a model for {0}.{1}" + raise ImproperlyConfigured(msg.format(app_name, model_name)) + return resolved_model + elif inspect.isclass(obj) and issubclass(obj, models.Model): + return obj + raise ValueError("{0} is not a Django model".format(obj)) + + +def get_related_model(field): + if django.VERSION < (1, 9): + return _resolve_model(field.rel.to) + return field.remote_field.model + + +# field.rel is deprecated from 1.9 onwards +def get_remote_field(field, **kwargs): + if 'default' in kwargs: + return getattr(field, 'remote_field', kwargs['default']) + + return field.remote_field + + +def value_from_object(field, obj): + return field.value_from_object(obj) diff --git a/aiorest_ws/db/orm/django/field_mapping.py b/aiorest_ws/db/orm/django/field_mapping.py new file mode 100644 index 0000000..0e0f53f --- /dev/null +++ b/aiorest_ws/db/orm/django/field_mapping.py @@ -0,0 +1,270 @@ +# -*- coding: utf-8 -*- +""" +Helper functions for mapping model fields to a dictionary of default +keyword arguments that should be used for their equivalent serializer fields. +""" +from django.core import validators +from django.db import models +from django.utils.text import capfirst + +from aiorest_ws.db.orm.django.compat import DecimalValidator +from aiorest_ws.db.orm.django.validators import UniqueValidator +from aiorest_ws.utils.field_mapping import needs_label + +__all__ = [ + 'NUMERIC_FIELD_TYPES', 'get_detail_view_name', 'get_field_kwargs', + 'get_relation_kwargs', 'get_nested_relation_kwargs', 'get_url_kwargs' +] + + +NUMERIC_FIELD_TYPES = ( + models.IntegerField, models.FloatField, models.DecimalField +) + + +def get_detail_view_name(model): + """ + Given a model class, return the view name to use for URL relationships + that refer to instances of the model. + """ + return '%(model_name)s-detail' % { + 'app_label': model._meta.app_label, + 'model_name': model._meta.object_name.lower() + } + + +def get_field_kwargs(field_name, model_field): + """ + Creates a default instance of a basic non-relational field. + """ + kwargs = {} + validator_kwarg = list(model_field.validators) + + # The following will only be used by ModelField classes. + # Gets removed for everything else. + kwargs['model_field'] = model_field + + if model_field.verbose_name and needs_label(model_field.verbose_name, field_name): # NOQA + kwargs['label'] = capfirst(model_field.verbose_name) + + if model_field.help_text: + kwargs['help_text'] = model_field.help_text + + max_digits = getattr(model_field, 'max_digits', None) + if max_digits is not None: + kwargs['max_digits'] = max_digits + + decimal_places = getattr(model_field, 'decimal_places', None) + if decimal_places is not None: + kwargs['decimal_places'] = decimal_places + + if isinstance(model_field, models.AutoField) or not model_field.editable: + # If this field is read-only, then return early. + # Further keyword arguments are not valid. + kwargs['read_only'] = True + return kwargs + + if model_field.has_default() or model_field.blank or model_field.null: + kwargs['required'] = False + + is_nullable_field = not isinstance(model_field, models.NullBooleanField) + if model_field.null and is_nullable_field: + kwargs['allow_null'] = True + + if model_field.blank and (isinstance(model_field, models.CharField) or + isinstance(model_field, models.TextField)): + kwargs['allow_blank'] = True + + if isinstance(model_field, models.FilePathField): + kwargs['path'] = model_field.path + + if model_field.match is not None: + kwargs['match'] = model_field.match + + if model_field.recursive is not False: + kwargs['recursive'] = model_field.recursive + + if model_field.allow_files is not True: + kwargs['allow_files'] = model_field.allow_files + + if model_field.allow_folders is not False: + kwargs['allow_folders'] = model_field.allow_folders + + if model_field.choices: + # If this model field contains choices, then return early. + # Further keyword arguments are not valid. + kwargs['choices'] = model_field.choices + return kwargs + + # Our decimal validation is handled in the field code, not validator code. + # (In Django 1.9+ this differs from previous style) + if isinstance(model_field, models.DecimalField) and DecimalValidator: + validator_kwarg = [ + validator for validator in validator_kwarg + if not isinstance(validator, DecimalValidator) + ] + + # Ensure that max_length is passed explicitly as a keyword arg, + # rather than as a validator. + max_length = getattr(model_field, 'max_length', None) + if max_length is not None and (isinstance(model_field, models.CharField) or + isinstance(model_field, models.TextField)): + kwargs['max_length'] = max_length + validator_kwarg = [ + validator for validator in validator_kwarg + if not isinstance(validator, validators.MaxLengthValidator) + ] + + # Ensure that min_length is passed explicitly as a keyword arg, + # rather than as a validator. + min_length = next(( + validator.limit_value for validator in validator_kwarg + if isinstance(validator, validators.MinLengthValidator) + ), None) + if min_length is not None and isinstance(model_field, models.CharField): + kwargs['min_length'] = min_length + validator_kwarg = [ + validator for validator in validator_kwarg + if not isinstance(validator, validators.MinLengthValidator) + ] + + # Ensure that max_value is passed explicitly as a keyword arg, + # rather than as a validator. + max_value = next(( + validator.limit_value for validator in validator_kwarg + if isinstance(validator, validators.MaxValueValidator) + ), None) + if max_value is not None and isinstance(model_field, NUMERIC_FIELD_TYPES): + kwargs['max_value'] = max_value + validator_kwarg = [ + validator for validator in validator_kwarg + if not isinstance(validator, validators.MaxValueValidator) + ] + + # Ensure that max_value is passed explicitly as a keyword arg, + # rather than as a validator. + min_value = next(( + validator.limit_value for validator in validator_kwarg + if isinstance(validator, validators.MinValueValidator) + ), None) + if min_value is not None and isinstance(model_field, NUMERIC_FIELD_TYPES): + kwargs['min_value'] = min_value + validator_kwarg = [ + validator for validator in validator_kwarg + if not isinstance(validator, validators.MinValueValidator) + ] + + # URLField does not need to include the URLValidator argument, + # as it is explicitly added in. + if isinstance(model_field, models.URLField): + validator_kwarg = [ + validator for validator in validator_kwarg + if not isinstance(validator, validators.URLValidator) + ] + + # EmailField does not need to include the validate_email argument, + # as it is explicitly added in. + if isinstance(model_field, models.EmailField): + validator_kwarg = [ + validator for validator in validator_kwarg + if validator is not validators.validate_email + ] + + # SlugField do not need to include the 'validate_slug' argument + if isinstance(model_field, models.SlugField): + validator_kwarg = [ + validator for validator in validator_kwarg + if validator is not validators.validate_slug + ] + + # for IPAddressField exclude the 'validate_ipv46_address' argument + if isinstance(model_field, models.GenericIPAddressField): + validator_kwarg = [ + validator for validator in validator_kwarg + if validator is not validators.validate_ipv46_address + ] + + if getattr(model_field, 'unique', False): + unique_error_message = model_field.error_messages.get('unique', None) + if unique_error_message: + unique_error_message = unique_error_message % { + 'model_name': model_field.model._meta.verbose_name, + 'field_label': model_field.verbose_name + } + validator = UniqueValidator( + queryset=model_field.model._default_manager, + message=unique_error_message) + validator_kwarg.append(validator) + + if validator_kwarg: + kwargs['validators'] = validator_kwarg + + return kwargs + + +def get_relation_kwargs(field_name, relation_info): + """ + Creates a default instance of a flat relational field. + """ + model_field = relation_info.model_field + related_model = relation_info.related_model + to_many = relation_info.to_many + to_field = relation_info.to_field + has_through_model = relation_info.has_through_model + + kwargs = { + 'queryset': related_model._default_manager, + 'view_name': get_detail_view_name(related_model) + } + + if to_many: + kwargs['many'] = True + + if to_field: + kwargs['to_field'] = to_field + + if has_through_model: + kwargs['read_only'] = True + kwargs.pop('queryset', None) + + if model_field: + if model_field.verbose_name and needs_label(model_field, field_name): + kwargs['label'] = capfirst(model_field.verbose_name) + help_text = model_field.help_text + if help_text: + kwargs['help_text'] = help_text + if not model_field.editable: + kwargs['read_only'] = True + kwargs.pop('queryset', None) + if kwargs.get('read_only', False): + # If this field is read-only, then return early. + # No further keyword arguments are valid. + return kwargs + + if model_field.has_default() or model_field.blank or model_field.null: + kwargs['required'] = False + if model_field.null: + kwargs['allow_null'] = True + if model_field.validators: + kwargs['validators'] = model_field.validators + if getattr(model_field, 'unique', False): + queryset = model_field.model._default_manager + validator = UniqueValidator(queryset=queryset) + kwargs['validators'] = kwargs.get('validators', []) + [validator] + if to_many and not model_field.blank: + kwargs['allow_empty'] = False + + return kwargs + + +def get_nested_relation_kwargs(relation_info): + kwargs = {'read_only': True} + if relation_info.to_many: + kwargs['many'] = True + return kwargs + + +def get_url_kwargs(model_field): + return { + 'view_name': get_detail_view_name(model_field) + } diff --git a/aiorest_ws/db/orm/django/fields.py b/aiorest_ws/db/orm/django/fields.py new file mode 100644 index 0000000..1edac0d --- /dev/null +++ b/aiorest_ws/db/orm/django/fields.py @@ -0,0 +1,400 @@ +# -*- coding: utf-8 -*- +""" +Field entities, implemented for support Django ORM. + +Every class, represented here, is associated with one certain field type of +table relatively to Django ORM. Each of them field also used later for +serializing/deserializing object of ORM. +""" +import datetime +import re +import uuid + +from aiorest_ws.conf import settings +from aiorest_ws.db.orm import fields +from aiorest_ws.db.orm.django.compat import get_remote_field, \ + value_from_object +from aiorest_ws.db.orm.fields import empty +from aiorest_ws.db.orm.validators import MaxLengthValidator +from aiorest_ws.utils.date.dateparse import parse_duration + +from django.forms import FilePathField as DjangoFilePathField +from django.forms import ImageField as DjangoImageField +from django.core.exceptions import ValidationError as DjangoValidationError +from django.core.validators import EmailValidator, RegexValidator, \ + URLValidator, ip_address_validators +from django.utils import six +from django.utils.duration import duration_string +from django.utils.encoding import is_protected_type +from django.utils.ipv6 import clean_ipv6_address + +__all__ = ( + 'IntegerField', 'BooleanField', 'CharField', 'ChoiceField', + 'MultipleChoiceField', 'FloatField', 'NullBooleanField', 'DecimalField', + 'TimeField', 'DateField', 'DateTimeField', 'DurationField', 'ListField', + 'DictField', 'HStoreField', 'JSONField', 'ModelField', 'ReadOnlyField', + 'SerializerMethodField', 'EmailField', 'RegexField', 'SlugField', + 'URLField', 'UUIDField', 'IPAddressField', 'FilePathField', 'FileField', + 'ImageField', 'CreateOnlyDefault' +) + + +class IntegerField(fields.IntegerField): + pass + + +class BooleanField(fields.BooleanField): + pass + + +class CharField(fields.CharField): + pass + + +class ChoiceField(fields.ChoiceField): + pass + + +class MultipleChoiceField(ChoiceField): + default_error_messages = { + 'invalid_choice': u'"{input}" is not a valid choice.', + 'not_a_list': u'Expected a list of items but got type "{input_type}".', + 'empty': u'This selection may not be empty.' + } + + def __init__(self, *args, **kwargs): + self.allow_empty = kwargs.pop('allow_empty', True) + super(MultipleChoiceField, self).__init__(*args, **kwargs) + + def get_value(self, dictionary): + if self.field_name not in dictionary: + if getattr(self.root, 'partial', False): + return empty + return dictionary.get(self.field_name, empty) + + def to_internal_value(self, data): + if isinstance(data, type('')) or not hasattr(data, '__iter__'): + self.raise_error('not_a_list', input_type=type(data).__name__) + if not self.allow_empty and len(data) == 0: + self.raise_error('empty') + + return { + super(MultipleChoiceField, self).to_internal_value(item) + for item in data + } + + def to_representation(self, value): + return { + self.choice_strings_to_values.get(str(item), item) + for item in value + } + + +class FloatField(fields.FloatField): + pass + + +class NullBooleanField(fields.NullBooleanField): + pass + + +class DecimalField(fields.DecimalField): + pass + + +class TimeField(fields.TimeField): + pass + + +class DateField(fields.DateField): + pass + + +class DateTimeField(fields.DateTimeField): + pass + + +class DurationField(fields.AbstractField): + default_error_messages = { + 'invalid': u"Duration has wrong format. Use one of these formats " + u"instead: {format}.", + } + + def to_internal_value(self, value): + if isinstance(value, datetime.timedelta): + return value + parsed = parse_duration(str(value)) + if parsed is not None: + return parsed + self.raise_error('invalid', format='[DD] [HH:[MM:]]ss[.uuuuuu]') + + def to_representation(self, value): + return duration_string(value) + + +class ListField(fields.ListField): + pass + + +class DictField(fields.DictField): + pass + + +class HStoreField(fields.HStoreField): + pass + + +class JSONField(fields.JSONField): + pass + + +class ModelField(fields.ModelField): + default_error_messages = { + 'max_length': u'Ensure this field has no more than {max_length} ' + u'characters.' + } + + def __init__(self, model_field, **kwargs): + # The `max_length` option is supported by Django's base `Field` class, + # so we'd better support it here. + max_length = kwargs.pop('max_length', None) + super(ModelField, self).__init__(model_field, **kwargs) + if max_length is not None: + message = self.error_messages['max_length'].format( + max_length=max_length + ) + self.validators.append( + MaxLengthValidator(max_length, message=message) + ) + + def to_internal_value(self, data): + rel = get_remote_field(self.model_field, default=None) + if rel is not None: + return rel.to._meta.get_field(rel.field_name).to_python(data) + return self.model_field.to_python(data) + + def to_representation(self, obj): + value = value_from_object(self.model_field, obj) + if is_protected_type(value): + return value + return self.model_field.value_to_string(obj) + + +class ReadOnlyField(fields.ReadOnlyField): + pass + + +class SerializerMethodField(fields.SerializerMethodField): + pass + + +class CreateOnlyDefault(fields.CreateOnlyDefault): + pass + + +class EmailField(CharField): + default_error_messages = { + "invalid": u"Enter a valid email address." + } + + def __init__(self, **kwargs): + super(EmailField, self).__init__(**kwargs) + validator = EmailValidator(message=self.error_messages['invalid']) + self.validators.append(validator) + + +class RegexField(CharField): + default_error_messages = { + 'invalid': u"This value does not match the required pattern." + } + + def __init__(self, regex, **kwargs): + super(RegexField, self).__init__(**kwargs) + validator = RegexValidator( + regex, message=self.error_messages['invalid'] + ) + self.validators.append(validator) + + +class SlugField(CharField): + default_error_messages = { + 'invalid': u'Enter a valid "slug" consisting of letters, numbers, ' + u'underscores or hyphens.' + } + + def __init__(self, **kwargs): + super(SlugField, self).__init__(**kwargs) + slug_regex = re.compile(r'^[-a-zA-Z0-9_]+$') + validator = RegexValidator( + slug_regex, message=self.error_messages['invalid'] + ) + self.validators.append(validator) + + +class URLField(CharField): + default_error_messages = { + 'invalid': u"Enter a valid URL." + } + + def __init__(self, **kwargs): + super(URLField, self).__init__(**kwargs) + validator = URLValidator(message=self.error_messages['invalid']) + self.validators.append(validator) + + +class UUIDField(fields.AbstractField): + valid_formats = ('hex_verbose', 'hex', 'int', 'urn') + + default_error_messages = { + 'invalid': u'"{value}" is not a valid UUID.' + } + + def __init__(self, **kwargs): + self.uuid_format = kwargs.pop('format', 'hex_verbose') + if self.uuid_format not in self.valid_formats: + raise ValueError( + 'Invalid format for uuid representation. ' + 'Must be one of "{0}"'.format('", "'.join(self.valid_formats)) + ) + super(UUIDField, self).__init__(**kwargs) + + def to_internal_value(self, data): + if not isinstance(data, uuid.UUID): + try: + if isinstance(data, int): + return uuid.UUID(int=data) + elif isinstance(data, str): + return uuid.UUID(hex=data) + else: + self.raise_error('invalid', value=data) + except ValueError: + self.raise_error('invalid', value=data) + return data + + def to_representation(self, value): + if self.uuid_format == 'hex_verbose': + return str(value) + else: + return getattr(value, self.uuid_format) + + +class IPAddressField(CharField): + """Support both IPAddressField and GenericIPAddressField""" + + default_error_messages = { + 'invalid': u"Enter a valid IPv4 or IPv6 address." + } + + def __init__(self, protocol='both', **kwargs): + self.protocol = protocol.lower() + self.unpack_ipv4 = (self.protocol == 'both') + super(IPAddressField, self).__init__(**kwargs) + validators, error_message = ip_address_validators( + protocol, self.unpack_ipv4 + ) + self.validators.extend(validators) + + def to_internal_value(self, data): + if not isinstance(data, six.string_types): + self.raise_error('invalid', value=data) + + if ':' in data: + try: + if self.protocol in ('both', 'ipv6'): + return clean_ipv6_address(data, self.unpack_ipv4) + except DjangoValidationError: + self.raise_error('invalid', value=data) + + return super(IPAddressField, self).to_internal_value(data) + + +class FilePathField(ChoiceField): + default_error_messages = { + 'invalid_choice': u'"{input}" is not a valid path choice.' + } + + def __init__(self, path, match=None, recursive=False, allow_files=True, + allow_folders=False, required=None, **kwargs): + # Defer to Django's FilePathField implementation to get the + # valid set of choices. + field = DjangoFilePathField( + path, match=match, recursive=recursive, allow_files=allow_files, + allow_folders=allow_folders, required=required + ) + kwargs['choices'] = field.choices + super(FilePathField, self).__init__(**kwargs) + + +class FileField(fields.AbstractField): + default_error_messages = { + 'required': u'No file was submitted.', + 'invalid': u'The submitted data was not a file. Check the encoding ' + u'type on the form.', + 'no_name': u'No filename could be determined.', + 'empty': u'The submitted file is empty.', + 'max_length': u'Ensure this filename has at most {max_length} ' + u'characters (it has {length}).', + } + + def __init__(self, *args, **kwargs): + self.max_length = kwargs.pop('max_length', None) + self.allow_empty_file = kwargs.pop('allow_empty_file', False) + if 'use_url' in kwargs: + self.use_url = kwargs.pop('use_url') + super(FileField, self).__init__(*args, **kwargs) + + def to_internal_value(self, data): + try: + # `UploadedFile` objects should have name and size attributes. + file_name = data.name + file_size = data.size + except AttributeError: + self.raise_error('invalid') + + if not file_name: + self.raise_error('no_name') + if not self.allow_empty_file and not file_size: + self.raise_error('empty') + if self.max_length and len(file_name) > self.max_length: + self.raise_error( + 'max_length', max_length=self.max_length, length=len(file_name) + ) + + return data + + def to_representation(self, value): + use_url = getattr(self, 'use_url', settings.UPLOADED_FILES_USE_URL) + + if not value: + return None + + if use_url: + if not getattr(value, 'url', None): + # If the file has not been saved it may not have a URL. + return None + url = value.url + return url + return value.name + + +class ImageField(FileField): + default_error_messages = { + 'invalid_image': u'Upload a valid image. The file you uploaded was ' + u'either not an image or a corrupted image.' + } + + def __init__(self, *args, **kwargs): + self._DjangoImageField = kwargs.pop( + '_DjangoImageField', DjangoImageField + ) + super(ImageField, self).__init__(*args, **kwargs) + + def to_internal_value(self, data): + # Image validation is a bit grungy, so we'll just outright + # defer to Django's implementation so we don't need to + # consider it, or treat PIL as a test dependency. + file_object = super(ImageField, self).to_internal_value(data) + django_field = self._DjangoImageField() + django_field.error_messages = self.error_messages + django_field.to_python(file_object) + return file_object diff --git a/aiorest_ws/db/orm/django/model_meta.py b/aiorest_ws/db/orm/django/model_meta.py new file mode 100644 index 0000000..00c9145 --- /dev/null +++ b/aiorest_ws/db/orm/django/model_meta.py @@ -0,0 +1,179 @@ +# -*- coding: utf-8 -*- +""" +Helper function for returning the field information that is associated +with a model class. This includes returning all the forward and reverse +relationships and their associated metadata. +""" +from collections import OrderedDict + +from aiorest_ws.utils.structures import FieldInfo, RelationInfo +from aiorest_ws.db.orm.django.compat import get_related_model, \ + get_remote_field + +__all__ = ( + '_get_pk', '_get_fields', '_get_to_field', '_get_forward_relationships', + '_get_reverse_relationships', '_merge_fields_and_pk', + '_merge_relationships', 'get_field_info', 'is_abstract_model' +) + + +def _get_pk(opts): + pk = opts.pk + rel = get_remote_field(pk) + + while rel and rel.parent_link: + # If model is a child via multi-table inheritance, use parent's pk. + pk = get_related_model(pk)._meta.pk + rel = get_remote_field(pk) + + return pk + + +def _get_fields(opts): + fields = OrderedDict() + opts_fields = [ + field for field in opts.fields + if field.serialize and not get_remote_field(field) + ] + for field in opts_fields: + fields[field.name] = field + + return fields + + +def _get_to_field(field): + return getattr(field, 'to_fields', None) and field.to_fields[0] + + +def _get_forward_relationships(opts): + """ + Returns an `OrderedDict` of field names to `RelationInfo`. + """ + forward_relations = OrderedDict() + forwards_fields = [ + field for field in opts.fields + if field.serialize and get_remote_field(field) + ] + for field in forwards_fields: + forward_relations[field.name] = RelationInfo( + model_field=field, + related_model=get_related_model(field), + to_many=False, + to_field=_get_to_field(field), + has_through_model=False + ) + + # Deal with forward many-to-many relationships. + many_to_many_fields = [ + field for field in opts.many_to_many + if field.serialize + ] + for field in many_to_many_fields: + forward_relations[field.name] = RelationInfo( + model_field=field, + related_model=get_related_model(field), + to_many=True, + # many-to-many do not have to_fields + to_field=None, + has_through_model=( + not get_remote_field(field).through._meta.auto_created + ) + ) + + return forward_relations + + +def _get_reverse_relationships(opts): + """ + Returns an `OrderedDict` of field names to `RelationInfo`. + """ + # Note that we have a hack here to handle internal API differences for + # this internal API across Django 1.7 -> Django 1.8. + # See: https://code.djangoproject.com/ticket/24208 + + reverse_relations = OrderedDict() + all_related_objects = [ + r for r in opts.related_objects + if not r.field.many_to_many + ] + for relation in all_related_objects: + accessor_name = relation.get_accessor_name() + related = getattr(relation, 'related_model', relation.model) + reverse_relations[accessor_name] = RelationInfo( + model_field=None, + related_model=related, + to_many=get_remote_field(relation.field).multiple, + to_field=_get_to_field(relation.field), + has_through_model=False + ) + + # Deal with reverse many-to-many relationships. + all_related_many_to_many_objects = [ + r for r in opts.related_objects + if r.field.many_to_many + ] + for relation in all_related_many_to_many_objects: + + has_through_model = False + through = getattr(get_remote_field(relation.field), 'through', None) + if through is not None: + remote_field = get_remote_field(relation.field) + has_through_model = not remote_field.through._meta.auto_created + + accessor_name = relation.get_accessor_name() + related = getattr(relation, 'related_model', relation.model) + reverse_relations[accessor_name] = RelationInfo( + model_field=None, + related_model=related, + to_many=True, + # manytomany do not have to_fields + to_field=None, + has_through_model=has_through_model + ) + + return reverse_relations + + +def _merge_fields_and_pk(pk, fields): + fields_and_pk = OrderedDict() + fields_and_pk['pk'] = pk + fields_and_pk[pk.name] = pk + fields_and_pk.update(fields) + + return fields_and_pk + + +def _merge_relationships(forward_relations, reverse_relations): + return OrderedDict( + list(forward_relations.items()) + + list(reverse_relations.items()) + ) + + +def get_field_info(model): + """ + Given a model class, returns a `FieldInfo` instance, which is a + `namedtuple`, containing metadata about the various field types on the + model including information about their relationships. + """ + opts = model._meta.concrete_model._meta + + pk = _get_pk(opts) + fields = _get_fields(opts) + forward_relations = _get_forward_relationships(opts) + reverse_relations = _get_reverse_relationships(opts) + fields_and_pk = _merge_fields_and_pk(pk, fields) + relationships = _merge_relationships(forward_relations, reverse_relations) + + return FieldInfo(pk, fields, forward_relations, reverse_relations, + fields_and_pk, relationships) + + +def is_abstract_model(model): + """ + Given a model class, returns a boolean True if it is abstract and False + if it is not. + """ + has_meta_attribute = hasattr(model, '_meta') + is_abstract = hasattr(model._meta, 'abstract') and model._meta.abstract + return has_meta_attribute and is_abstract diff --git a/aiorest_ws/db/orm/django/relations.py b/aiorest_ws/db/orm/django/relations.py new file mode 100644 index 0000000..a55e73f --- /dev/null +++ b/aiorest_ws/db/orm/django/relations.py @@ -0,0 +1,193 @@ +# -*- coding: utf-8 -*- +""" +Module, which provide classes and function for related and nested field. +""" +from collections import OrderedDict + +from aiorest_ws.db.orm import relations +from aiorest_ws.exceptions import ImproperlyConfigured +from aiorest_ws.utils.fields import get_attribute, is_simple_callable + +from django.core.exceptions import ObjectDoesNotExist +from django.db.models import Manager +from django.db.models.query import QuerySet +from django.utils.encoding import smart_text + + +__all__ = ( + 'ManyRelatedField', 'RelatedField', 'StringRelatedField', + 'PrimaryKeyRelatedField', 'HyperlinkedRelatedField', + 'HyperlinkedIdentityField', 'SlugRelatedField' +) + + +class ManyRelatedField(relations.ManyRelatedField): + """ + Relationships with `many=True` transparently get coerced into instead being + a ManyRelatedField with a child relationship. + The `ManyRelatedField` class is responsible for handling iterating through + the values and passing each one to the child relationship. + This class is treated as private API. + You shouldn't generally need to be using this class directly yourself, + and should instead simply set 'many=True' on the relationship. + """ + + def get_attribute(self, instance): + # Can't have any relationships if not created + if hasattr(instance, 'pk') and instance.pk is None: + return [] + + relationship = get_attribute(instance, self.source_attrs) + if hasattr(relationship, 'all'): + relationship.all() + return relationship + + +class RelatedField(relations.RelatedField): + many_related_field = ManyRelatedField + + def get_queryset(self): + queryset = self.queryset + if isinstance(queryset, (QuerySet, Manager)): + # Ensure queryset is re-evaluated whenever used. + # Note that actually a `Manager` class may also be used as the + # queryset argument. This occurs on ModelSerializer fields, + # as it allows us to generate a more expressive 'repr' output + # for the field. + # Eg: 'MyRelationship(queryset=ExampleModel.objects.all())' + queryset = queryset.all() + return queryset + + def get_attribute(self, instance): + if self.use_pk_only_optimization() and self.source_attrs: + # Optimized case, return a mock object only containing the + # pk attribute. + try: + instance = get_attribute(instance, self.source_attrs[:-1]) + value = instance.serializable_value(self.source_attrs[-1]) + if is_simple_callable(value): + # Handle edge case where the relationship `source` argument + # points to a `get_relationship()` method on the model + value = value().pk + return relations.PKOnlyObject(pk=value) + except AttributeError: + pass + + # Standard case, return the object instance. + return get_attribute(instance, self.source_attrs) + + def get_choices(self, cutoff=None): + queryset = self.get_queryset() + if queryset is None: + # Ensure that field.choices returns something sensible + # even when accessed with a read-only field. + return {} + + if cutoff is not None: + queryset = queryset[:cutoff] + + return OrderedDict([ + (self.to_representation(item), self.display_value(item)) + for item in queryset + ]) + + @property + def choices(self): + return self.get_choices() + + @property + def grouped_choices(self): + return self.choices + + def display_value(self, instance): + return str(instance) + + +class StringRelatedField(relations.StringRelatedField, RelatedField): + pass + + +class PrimaryKeyRelatedField(relations.PrimaryKeyRelatedField, + RelatedField): + + def to_internal_value(self, data): + if self.pk_field is not None: + data = self.pk_field.to_internal_value(data) + try: + return self.get_queryset().get(pk=data) + except ObjectDoesNotExist: + self.raise_error('does_not_exist', pk_value=data) + except (TypeError, ValueError): + self.raise_error('incorrect_type', data_type=type(data).__name__) + + def to_representation(self, value): + if self.pk_field is not None: + return self.pk_field.to_representation(value.pk) + return value.pk + + +class HyperlinkedRelatedField(relations.HyperlinkedRelatedField, + RelatedField): + lookup_field = 'pk' + + def use_pk_only_optimization(self): + return self.lookup_field == 'pk' + + def get_object(self, view_name, view_args, view_kwargs): + """ + Return the object corresponding to a matched URL. + Takes the matched URL conf arguments, and should return an + object instance, or raise an `ObjectDoesNotExist` exception. + """ + try: + lookup_value = view_kwargs[self.lookup_url_kwarg] + lookup_kwargs = {self.lookup_field: lookup_value} + return self.get_queryset().get(**lookup_kwargs) + except ObjectDoesNotExist: + self.raise_error('does_not_exist') + except KeyError: + raise ImproperlyConfigured( + "Missing primary key in the endpoint path. For fixing it just " + "specify for a requested endpoint URL with included " + "`{field_name}` parameter in the path or override" + "`lookup_url_kwarg` in the constructor for the concrete field." + .format(field_name=self.lookup_url_kwarg) + ) + except (TypeError, ValueError): + self.raise_error( + 'incorrect_type', data_type=type(view_kwargs).__name__ + ) + + def is_saved_in_database(self, obj): + if not obj or not obj.pk: + return False + return True + + def get_lookup_value(self, obj): + pk = getattr(obj, self.lookup_field) + return pk if isinstance(pk, (tuple, list)) else (pk, ) + + +class HyperlinkedIdentityField(relations.HyperlinkedIdentityField, + HyperlinkedRelatedField): + pass + + +class SlugRelatedField(relations.SlugRelatedField, RelatedField): + """ + A read-write field that represents the target of the relationship + by a unique 'slug' attribute. + """ + def to_internal_value(self, data): + try: + return self.get_queryset().get(**{self.slug_field: data}) + except ObjectDoesNotExist: + self.raise_error( + 'does_not_exist', slug_name=self.slug_field, + value=smart_text(data) + ) + except (TypeError, ValueError): + self.raise_error('invalid') + + def to_representation(self, obj): + return getattr(obj, self.slug_field) diff --git a/aiorest_ws/db/orm/django/serializers.py b/aiorest_ws/db/orm/django/serializers.py new file mode 100644 index 0000000..684d1ca --- /dev/null +++ b/aiorest_ws/db/orm/django/serializers.py @@ -0,0 +1,476 @@ +# -*- coding: utf-8 -*- +""" +List, model and hyperlinked serializer classes for Django ORM. + +As you can see there, we are inherited from all base classes and implement +logic according to the work of SQLAlchemy ORM. For the most situations using +ModelSerializer class will be enough. +""" +import traceback + +from aiorest_ws.conf import settings +from aiorest_ws.db.orm.abstract import empty +from aiorest_ws.db.orm.exceptions import ValidationError +from aiorest_ws.db.orm.serializers import \ + ListSerializer as BaseListSerializer, \ + ModelSerializer as BaseModelSerializer, HyperlinkedModelSerializerMixin, \ + raise_errors_on_nested_writes +from aiorest_ws.db.orm.django import model_meta +from aiorest_ws.db.orm.django.compat import postgres_fields, \ + JSONField as ModelJSONField +from aiorest_ws.db.orm.django.field_mapping import get_field_kwargs, \ + get_nested_relation_kwargs, get_relation_kwargs, get_url_kwargs +from aiorest_ws.utils.field_mapping import ClassLookupDict + +from django.core.exceptions import ValidationError as DjangoValidationError +from django.db import models +from django.db.models import DurationField as ModelDurationField +from django.db.models.fields import Field as DjangoModelField +from django.db.models.fields import FieldDoesNotExist +from django.utils import timezone + +from aiorest_ws.db.orm.django.fields import * # NOQA +from aiorest_ws.db.orm.django.relations import * # NOQA + +__all__ = ( + 'get_validation_error_detail', 'ListSerializer', 'ModelSerializer', + 'HyperlinkedModelSerializer', +) + + +def get_validation_error_detail(exc): + assert isinstance(exc, (ValidationError, DjangoValidationError)) + + if isinstance(exc, DjangoValidationError): + # Normally you should raise `serializers.ValidationError` + # inside your codebase, but we handle Django's validation + # exception class as well for simpler compat. + # Eg. Calling Model.clean() explicitly inside Serializer.validate() + return {settings.REST_CONFIG['NON_FIELD_ERRORS_KEY']: exc.args} + elif isinstance(exc.detail, dict): + # If errors may be a dict we use the standard {key: list of values}. + # Here we ensure that all the values are *lists* of errors. + return { + key: value if isinstance(value, (list, dict)) else [value] + for key, value in exc.detail.items() + } + elif isinstance(exc.detail, list): + # Errors raised as a list are non-field errors. + return {settings.REST_CONFIG['NON_FIELD_ERRORS_KEY']: exc.detail} + # Errors raised as a string are non-field errors. + return {settings.REST_CONFIG['NON_FIELD_ERRORS_KEY']: [exc.detail]} + + +class ListSerializer(BaseListSerializer): + + def run_validation(self, data=empty): + """ + We override the default `run_validation`, because the validation + performed by validators and the `.validate()` method should + be coerced into an error dictionary with a 'non_fields_error' key. + """ + (is_empty_value, data) = self.validate_empty_values(data) + if is_empty_value: + return data + + value = self.to_internal_value(data) + try: + self.run_validators(value) + value = self.validate(value) + assert value is not None, '.validate() should return the ' \ + 'validated data' + except (ValidationError, DjangoValidationError) as exc: + raise ValidationError(detail=get_validation_error_detail(exc)) + + return value + + +class ModelSerializer(BaseModelSerializer): + """ + Base serializer for Django models. + + This class automatically generate "scheme" for further serializing and + de-serializing data, according to used model, which is specified by user. + For indicated fields by programmer, ModelSerializer will not apply "scheme + construction rule". They are will have skipped while processing. Otherwise, + for every field which is not pre-defined by user, will be found the most + suitable and compatible field class. + """ + serializer_field_mapping = { + models.AutoField: IntegerField, + models.BigIntegerField: IntegerField, + models.BooleanField: BooleanField, + models.CharField: CharField, + models.CommaSeparatedIntegerField: CharField, + models.DateField: DateField, + models.DateTimeField: DateTimeField, + models.DecimalField: DecimalField, + models.EmailField: EmailField, + models.Field: ModelField, + models.FileField: FileField, + models.FloatField: FloatField, + models.ImageField: ImageField, + models.IntegerField: IntegerField, + models.NullBooleanField: NullBooleanField, + models.PositiveIntegerField: IntegerField, + models.PositiveSmallIntegerField: IntegerField, + models.SlugField: SlugField, + models.SmallIntegerField: IntegerField, + models.TextField: CharField, + models.TimeField: TimeField, + models.URLField: URLField, + models.GenericIPAddressField: IPAddressField, + models.FilePathField: FilePathField, + } + if ModelDurationField is not None: + serializer_field_mapping[ModelDurationField] = DurationField + if ModelJSONField is not None: + serializer_field_mapping[ModelJSONField] = JSONField + serializer_related_field = PrimaryKeyRelatedField + serializer_related_to_field = SlugRelatedField + serializer_url_field = HyperlinkedIdentityField + serializer_choice_field = ChoiceField + default_list_serializer = ListSerializer + + def is_abstract_model(self, model): + """ + Check the passed model is abstract. + """ + return model_meta.is_abstract_model(model) + + def get_field_info(self, model): + """ + Get metadata about field in the passed model. + """ + return model_meta.get_field_info(model) + + # Default `create` and `update` behavior... + def create(self, validated_data): + """ + We have a bit of extra checking around this in order to provide + descriptive messages when something goes wrong, but this method is + essentially just: + + return ExampleModel.objects.create(**validated_data) + + If there are many to many fields present on the instance then they + cannot be set until the model is instantiated, in which case the + implementation is like so: + + example_relationship = validated_data.pop('example_relationship') + instance = ExampleModel.objects.create(**validated_data) + instance.example_relationship = example_relationship + return instance + + The default implementation also does not handle nested relationships. + If you want to support writable nested relationships you'll need + to write an explicit `.create()` method. + """ + raise_errors_on_nested_writes('create', self, validated_data) + + ModelClass = self.Meta.model + + # Remove many-to-many relationships from validated_data. + # They are not valid arguments to the default `.create()` method, + # as they require that the instance has already been saved. + info = model_meta.get_field_info(ModelClass) + many_to_many = {} + for field_name, relation_info in info.relations.items(): + if relation_info.to_many and (field_name in validated_data): + many_to_many[field_name] = validated_data.pop(field_name) + + try: + instance = ModelClass.objects.create(**validated_data) + except TypeError: + tb = traceback.format_exc() + msg = ( + 'Got a `TypeError` when calling `%s.objects.create()`. ' + 'This may be because you have a writable field on the ' + 'serializer class that is not a valid argument to ' + '`%s.objects.create()`. You may need to make the field ' + 'read-only, or override the %s.create() method to handle ' + 'this correctly.\nOriginal exception was:\n %s' % + ( + ModelClass.__name__, + ModelClass.__name__, + self.__class__.__name__, + tb + ) + ) + raise TypeError(msg) + + # Save many-to-many relationships after the instance is created. + if many_to_many: + for field_name, value in many_to_many.items(): + setattr(instance, field_name, value) + + return instance + + def update(self, instance, validated_data): + raise_errors_on_nested_writes('update', self, validated_data) + + # Simply set each attribute on the instance, and then save it. + # Note that unlike `.create()` we don't need to treat many-to-many + # relationships as being a special case. During updates we already + # have an instance pk for the relationships to be associated with. + for attr, value in validated_data.items(): + setattr(instance, attr, value) + instance.save() + + return instance + + def run_validation(self, data=empty): + """ + We override the default `run_validation`, because the validation + performed by validators and the `.validate()` method should + be coerced into an error dictionary with a 'non_fields_error' key. + """ + (is_empty_value, data) = self.validate_empty_values(data) + if is_empty_value: + return data + + for field in self.fields.values(): + for validator in field.validators: + if hasattr(validator, 'set_context'): + validator.set_context(field) + + value = self.to_internal_value(data) + try: + self.run_validators(value) + value = self.validate(value) + assert value is not None, '.validate() should return the ' \ + 'validated data' + except (ValidationError, DjangoValidationError) as exc: + raise ValidationError(detail=get_validation_error_detail(exc)) + + return value + + def get_default_field_names(self, declared_fields, model_info): + """ + Return the default list of field names that will be used if the + `Meta.fields` option is not specified. + """ + return ( + [model_info.pk.name] + + list(declared_fields.keys()) + + list(model_info.fields.keys()) + + list(model_info.forward_relations.keys()) + ) + + def _get_unique_constraint_names(self, model, model_fields, field_names): + """ + Return a set of field names, for each column unique constraint. + Used internally by `get_uniqueness_extra_kwargs`. + """ + unique_constraint_names = set() + + for model_field in model_fields.values(): + # Include each of the `unique_for_*` field names. + unique_constraint_names |= { + model_field.unique_for_date, model_field.unique_for_month, + model_field.unique_for_year + } + + unique_constraint_names -= {None} + return unique_constraint_names + + def _get_unique_together_constraints(self, model, model_fields, field_names): # NOQA + """ + Return a set of field names for a multiple unique constraints. + Used internally by `get_uniqueness_extra_kwargs`. + """ + unique_constraint_names = set() + + for parent_class in [model] + list(model._meta.parents.keys()): + for unique_together_list in parent_class._meta.unique_together: + if set(field_names).issuperset(set(unique_together_list)): + unique_constraint_names |= set(unique_together_list) + + return unique_constraint_names + + def _get_unique_field(self, model, unique_field_name): + """ + Return a field by his name from a model. + Used internally by `get_uniqueness_extra_kwargs`. + """ + return model._meta.get_field(unique_field_name) + + def _get_default_field_value(self, unique_constraint_field): + """ + Return a default value for a passed field. + Used internally by `get_uniqueness_extra_kwargs`. + """ + default = empty + + if getattr(unique_constraint_field, 'auto_now_add', None): + default = CreateOnlyDefault(timezone.now) + elif getattr(unique_constraint_field, 'auto_now', None): + default = timezone.now + elif unique_constraint_field.has_default(): + default = unique_constraint_field.default + + return default + + def build_field(self, field_name, info, model_class, nested_depth): + """ + Create regular model fields. + """ + if field_name in info.fields_and_pk: + model_field = info.fields_and_pk[field_name] + return self.build_standard_field(field_name, model_field) + + elif field_name in info.relations: + relation_info = info.relations[field_name] + if not nested_depth: + return self.build_relational_field(field_name, relation_info) + else: + return self.build_nested_field( + field_name, relation_info, nested_depth + ) + + elif hasattr(model_class, field_name): + return self.build_property_field(field_name, model_class) + + elif field_name == self.url_field_name: + return self.build_url_field(field_name, model_class) + + return self.build_unknown_field(field_name, model_class) + + def build_standard_field(self, *args, **kwargs): + """ + Create regular model fields. + """ + field_name, model_field = args + field_mapping = ClassLookupDict(self.serializer_field_mapping) + + field_class = field_mapping[model_field] + field_kwargs = get_field_kwargs(field_name, model_field) + + if 'choices' in field_kwargs: + # Fields with choices get coerced into `ChoiceField` + # instead of using their regular typed field. + field_class = self.serializer_choice_field + # Some model fields may introduce kwargs that would not be valid + # for the choice field. We need to strip these out. + valid_kwargs = { + 'read_only', 'write_only', + 'required', 'default', 'initial', 'source', + 'label', 'help_text', 'style', + 'error_messages', 'validators', 'allow_null', 'allow_blank', + 'choices' + } + for key in list(field_kwargs.keys()): + if key not in valid_kwargs: + field_kwargs.pop(key) + + if not issubclass(field_class, ModelField): + # `model_field` is only valid for the fallback case of + # `ModelField`, which is used when no other typed field + # matched to the model field. + field_kwargs.pop('model_field', None) + + if not issubclass(field_class, CharField) and not issubclass(field_class, ChoiceField): # NOQA + # `allow_blank` is only valid for textual fields. + field_kwargs.pop('allow_blank', None) + + if postgres_fields and isinstance(model_field, postgres_fields.ArrayField): # NOQA + # Populate the `child` argument on `ListField` instances generated + # for the PostgreSQL specific `ArrayField`. + child_model_field = model_field.base_field + child_field_class, child_field_kwargs = self.build_standard_field( + 'child', child_model_field + ) + field_kwargs['child'] = child_field_class(**child_field_kwargs) + + return field_class, field_kwargs + + def build_relational_field(self, *args, **kwargs): + """ + Create fields for forward and reverse relationships. + """ + field_name, relation_info = args + field_class = self.serializer_related_field + field_kwargs = get_relation_kwargs(field_name, relation_info) + + to_field = field_kwargs.pop('to_field', None) + if to_field and not relation_info.related_model._meta.get_field(to_field).primary_key: # NOQA + field_kwargs['slug_field'] = to_field + field_class = self.serializer_related_to_field + + # `view_name` is only valid for hyperlinked relationships. + if not issubclass(field_class, HyperlinkedRelatedField): + field_kwargs.pop('view_name', None) + + return field_class, field_kwargs + + def build_nested_field(self, *args, **kwargs): + """ + Create nested fields for forward and reverse relationships. + """ + field_name, relation_info, nested_depth = args + + class NestedSerializer(ModelSerializer): + + class Meta: + model = relation_info.related_model + depth = nested_depth - 1 + + field_class = NestedSerializer + field_kwargs = get_nested_relation_kwargs(relation_info) + + return field_class, field_kwargs + + def build_property_field(self, *args, **kwargs): + """ + Create a read only field for model methods and properties. + """ + field_class = ReadOnlyField + field_kwargs = {} + + return field_class, field_kwargs + + def build_url_field(self, field_name, model_class): + """ + Create a field representing the object's own URL. + """ + field_class = self.serializer_url_field + field_kwargs = get_url_kwargs(model_class) + + return field_class, field_kwargs + + def _bind_field(self, model, source, model_fields): + """ + Bind passed field to model serializer. + Used internally by `_get_model_fields`. + """ + try: + field = model._meta.get_field(source) + if isinstance(field, DjangoModelField): + model_fields[source] = field + except FieldDoesNotExist: + pass + + +class HyperlinkedModelSerializer(HyperlinkedModelSerializerMixin, + ModelSerializer): + """ + A type of `ModelSerializer` that uses hyperlinked relationships instead + of primary key relationships. Specifically: + + * A 'url' field is included instead of the 'id' field. + * Relationships to other instances are hyperlinks, instead of primary keys. + """ + serializer_related_field = HyperlinkedRelatedField + + def build_nested_field(self, field_name, relation_info, nested_depth): + + class NestedSerializer(HyperlinkedModelSerializer): + + class Meta: + model = relation_info.related_model + depth = nested_depth - 1 + + field_class = NestedSerializer + field_kwargs = get_nested_relation_kwargs(relation_info) + + return field_class, field_kwargs diff --git a/aiorest_ws/db/orm/django/validators.py b/aiorest_ws/db/orm/django/validators.py new file mode 100644 index 0000000..9a3d3e0 --- /dev/null +++ b/aiorest_ws/db/orm/django/validators.py @@ -0,0 +1,89 @@ +# -*- coding: utf-8 -*- +""" +We perform uniqueness checks explicitly on the serializer class, rather +the using Django's `.full_clean()`. + +This gives us better separation of concerns, allows us to use single-step +object creation, and makes it possible to switch between using the implicit +`ModelSerializer` class and an equivalent explicit `Serializer` class. +""" +from django.db import DataError +from django.utils.translation import ugettext_lazy as _ + +from aiorest_ws.db.orm.exceptions import ValidationError +from aiorest_ws.utils.representation import smart_repr + +__all__ = [ + 'qs_exists', 'qs_filter', 'UniqueValidator' +] + + +# Robust filter and exist implementations. Ensures that queryset.exists() for +# an invalid value returns `False`, rather than raising an error. +# Refs https://github.com/tomchristie/django-rest-framework/issues/3381 + +def qs_exists(queryset): + try: + return queryset.exists() + except (TypeError, ValueError, DataError): + return False + + +def qs_filter(queryset, **kwargs): + try: + return queryset.filter(**kwargs) + except (TypeError, ValueError, DataError): + return queryset.none() + + +class UniqueValidator(object): + """ + Validator that corresponds to `unique=True` on a model field. + Should be applied to an individual field on the serializer. + """ + message = _('This field must be unique.') + + def __init__(self, queryset, message=None): + self.queryset = queryset + self.serializer_field = None + self.message = message or self.message + + def set_context(self, serializer_field): + """ + This hook is called by the serializer instance, + prior to the validation call being made. + """ + # Determine the underlying model field name. This may not be the + # same as the serializer field name if `source=<>` is set. + self.field_name = serializer_field.source_attrs[-1] + # Determine the existing instance, if this is an update operation. + self.instance = getattr(serializer_field.parent, 'instance', None) + + def filter_queryset(self, value, queryset): + """ + Filter the queryset to all instances matching the given attribute. + """ + filter_kwargs = {self.field_name: value} + return qs_filter(queryset, **filter_kwargs) + + def exclude_current_instance(self, queryset): + """ + If an instance is being updated, then do not include + that instance itself as a uniqueness conflict. + """ + if self.instance is not None: + return queryset.exclude(pk=self.instance.pk) + return queryset + + def __call__(self, value): + queryset = self.queryset + queryset = self.filter_queryset(value, queryset) + queryset = self.exclude_current_instance(queryset) + if qs_exists(queryset): + raise ValidationError(self.message) + + def __repr__(self): + return repr('<%s(queryset=%s)>' % ( + self.__class__.__name__, + smart_repr(self.queryset) + )) diff --git a/aiorest_ws/db/orm/fields.py b/aiorest_ws/db/orm/fields.py index 8078287..2332083 100644 --- a/aiorest_ws/db/orm/fields.py +++ b/aiorest_ws/db/orm/fields.py @@ -221,7 +221,7 @@ def __init__(self, choices, **kwargs): # Allows us to deal with eg. integer choices while supporting either # integer or string input, but still get the correct datatype out. self.choice_strings_to_values = { - key: key for key in self.choices.keys() + str(key): key for key in self.choices.keys() } self.allow_blank = kwargs.pop('allow_blank', False) @@ -232,14 +232,14 @@ def to_internal_value(self, data): return '' try: - return self.choice_strings_to_values[data] + return self.choice_strings_to_values[str(data)] except KeyError: self.raise_error('invalid_choice', input=data) def to_representation(self, value): if value in ('', None): return value - return self.choice_strings_to_values.get(value, value) + return self.choice_strings_to_values.get(str(value), value) class FloatField(AbstractField): @@ -836,7 +836,10 @@ def to_representation(self, data): """ List of object instances -> List of dicts of primitive datatypes. """ - return [self.child.to_representation(item) for item in data] + return [ + self.child.to_representation(item) if item is not None else None + for item in data + ] class DictField(AbstractField): @@ -880,7 +883,7 @@ def to_representation(self, value): List of object instances -> List of dicts of primitive datatypes. """ return { - str(key): self.child.to_representation(val) + str(key): self.child.to_representation(val) if val is not None else None # NOQA for key, val in value.items() } @@ -927,11 +930,6 @@ class ModelField(AbstractField): This is used by `ModelSerializer` when dealing with custom model fields, that do not have a serializer field to be mapped to. """ - default_error_messages = { - 'max_length': "Ensure this field has no more than {max_length} " - "characters." - } - def __init__(self, model_field, **kwargs): self.model_field = model_field super(ModelField, self).__init__(**kwargs) @@ -981,6 +979,7 @@ def bind(self, field_name, parent): # The method name should default to `get_{field_name}`. if self.method_name is None: self.method_name = default_method_name + super(SerializerMethodField, self).bind(field_name, parent) def to_representation(self, value): diff --git a/aiorest_ws/db/orm/relations.py b/aiorest_ws/db/orm/relations.py index fa18661..ce32f33 100644 --- a/aiorest_ws/db/orm/relations.py +++ b/aiorest_ws/db/orm/relations.py @@ -48,6 +48,8 @@ class Hyperlink(str): We use this for hyperlinked URLs that may render as a named link in some contexts, or render as a plain URL in others. """ + is_hyperlink = True + def __new__(self, url, name): ret = str.__new__(self, url) ret.name = name @@ -56,8 +58,6 @@ def __new__(self, url, name): def __getnewargs__(self): return str(self), self.name - is_hyperlink = True - class PKOnlyObject(object): """ diff --git a/aiorest_ws/db/orm/serializers.py b/aiorest_ws/db/orm/serializers.py index fa2b63a..d0f4a93 100644 --- a/aiorest_ws/db/orm/serializers.py +++ b/aiorest_ws/db/orm/serializers.py @@ -962,12 +962,14 @@ def get_uniqueness_extra_kwargs(self, field_names, declared_fields, # arguments to deal with `unique_for` dates that are required to # be in the input data in order to validate it. unique_constraint_names = self._get_unique_constraint_names( - model, model_fields + model, model_fields, field_names ) # Include each of "unique multiple columns" field names, # so long as all the field names are included on the serializer. - unique_constraint_names |= self._get_unique_together_constraints(model) + unique_constraint_names |= self._get_unique_together_constraints( + model, model_fields, field_names + ) # Now we have all the field names that have uniqueness constraints # applied, we can add the extra 'required=...' or 'default=...' @@ -1041,7 +1043,7 @@ def _get_model_fields(self, field_names, declared_fields, extra_kwargs): return model_fields - def _get_unique_constraint_names(self, model, model_fields): + def _get_unique_constraint_names(self, model, model_fields, field_names): """ Return a set of field names, for each column unique constraint. Used internally by `get_uniqueness_extra_kwargs`. @@ -1049,7 +1051,7 @@ def _get_unique_constraint_names(self, model, model_fields): raise NotImplementedError('`_get_unique_constraint_names()` ' 'must be implemented.') - def _get_unique_together_constraints(self, model): + def _get_unique_together_constraints(self, model, model_fields, field_names): # NOQA """ Return a set of field names for a multiple unique constraints. Used internally by `get_uniqueness_extra_kwargs`. diff --git a/aiorest_ws/db/orm/sqlalchemy/serializers.py b/aiorest_ws/db/orm/sqlalchemy/serializers.py index f7d697d..a22c957 100644 --- a/aiorest_ws/db/orm/sqlalchemy/serializers.py +++ b/aiorest_ws/db/orm/sqlalchemy/serializers.py @@ -37,9 +37,7 @@ def get_validation_error_detail(exc): assert isinstance(exc, (ValidationError, AssertionError)) if isinstance(exc, AssertionError): - return { - settings.REST_CONFIG['NON_FIELD_ERRORS_KEY']: exc.args - } + return {settings.REST_CONFIG['NON_FIELD_ERRORS_KEY']: exc.args} elif isinstance(exc.detail, dict): # If errors may be a dict we use the standard {key: list of values}. # Here we ensure that all the values are *lists* of errors. @@ -295,7 +293,37 @@ def get_default_field_names(self, declared_fields, model_info): list(model_info.forward_relations.keys()) )) - def _get_unique_constraint_names(self, model, model_fields): + def _get_model_fields(self, field_names, declared_fields, extra_kwargs): + """ + Returns all the model fields that are being mapped to by fields + on the serializer class. + Returned as a dict of 'model field name' -> 'model field'. + Used internally by `get_uniqueness_field_options`. + """ + model = getattr(self.Meta, 'model') + model_fields = {} + + for field_name in field_names: + if field_name in declared_fields: + # If the field is declared on the serializer + field = declared_fields[field_name] + source = field.source or field_name + else: + try: + source = extra_kwargs[field_name]['source'] + except KeyError: + source = field_name + + if '.' in source or source == '*': + # Model fields will always have a simple source mapping, + # they can't be nested attribute lookups. + continue + + self._bind_field(model, source, model_fields) + + return model_fields + + def _get_unique_constraint_names(self, model, model_fields, field_names): """ Return a set of field names, for each column unique constraint. Used internally by `get_uniqueness_extra_kwargs`. @@ -315,7 +343,7 @@ def _get_unique_constraint_names(self, model, model_fields): unique_constraint_names -= {None} return unique_constraint_names - def _get_unique_together_constraints(self, model): + def _get_unique_together_constraints(self, model, model_fields, field_names): # NOQA """ Return a set of field names for a multiple unique constraints. Used internally by `get_uniqueness_extra_kwargs`. diff --git a/aiorest_ws/utils/date/dateparse.py b/aiorest_ws/utils/date/dateparse.py index c6f55ac..7e3743a 100644 --- a/aiorest_ws/utils/date/dateparse.py +++ b/aiorest_ws/utils/date/dateparse.py @@ -12,8 +12,9 @@ from aiorest_ws.utils.date.timezone import utc, get_fixed_timezone __all__ = [ - 'date_re', 'time_re', 'datetime_re', - 'parse_date', 'parse_time', 'parse_datetime', 'parse_timedelta' + 'date_re', 'time_re', 'datetime_re', 'standard_duration_re', + 'iso8601_duration_re', 'parse_date', 'parse_time', 'parse_datetime', + 'parse_timedelta' ] date_re = re.compile( @@ -32,6 +33,30 @@ r'(?PZ|[+-]\d{2}(?::?\d{2})?)?$' ) +standard_duration_re = re.compile( + r'^' + r'(?:(?P-?\d+) (days?, )?)?' + r'((?:(?P\d+):)(?=\d+:\d+))?' + r'(?:(?P\d+):)?' + r'(?P\d+)' + r'(?:\.(?P\d{1,6})\d{0,6})?' + r'$' +) + +# Support the sections of ISO 8601 date representation that are accepted by +# timedelta +iso8601_duration_re = re.compile( + r'^(?P[-+]?)' + r'P' + r'(?:(?P\d+(.\d+)?)D)?' + r'(?:T' + r'(?:(?P\d+(.\d+)?)H)?' + r'(?:(?P\d+(.\d+)?)M)?' + r'(?:(?P\d+(.\d+)?)S)?' + r')?' + r'$' +) + def parse_date(value): """Parses a string and return a datetime.date. @@ -127,3 +152,20 @@ def parse_timedelta(value): d = d.groupdict(0) return datetime.timedelta(**dict(((k, float(v)) for k, v in d.items()))) + + +def parse_duration(value): + """Parses a duration string and returns a datetime.timedelta. + The preferred format for durations in Django is '%d %H:%M:%S.%f'. + Also supports ISO 8601 representation. + """ + match = standard_duration_re.match(value) + if not match: + match = iso8601_duration_re.match(value) + if match: + kw = match.groupdict() + sign = -1 if kw.pop('sign', '+') == '-' else 1 + if kw.get('microseconds'): + kw['microseconds'] = kw['microseconds'].ljust(6, '0') + kw = {k: float(v) for k, v in kw.items() if v is not None} + return sign * datetime.timedelta(**kw) diff --git a/examples/django_orm/__init__.py b/examples/django_orm/__init__.py new file mode 100644 index 0000000..40a96af --- /dev/null +++ b/examples/django_orm/__init__.py @@ -0,0 +1 @@ +# -*- coding: utf-8 -*- diff --git a/examples/django_orm/client.py b/examples/django_orm/client.py new file mode 100644 index 0000000..b3aeca7 --- /dev/null +++ b/examples/django_orm/client.py @@ -0,0 +1,111 @@ +# -*- coding: utf-8 -*- +import asyncio +import json + +from hashlib import sha256 +from autobahn.asyncio.websocket import WebSocketClientProtocol, \ + WebSocketClientFactory + + +def hash_password(password): + return sha256(password.encode('utf-8')).hexdigest() + + +class HelloClientProtocol(WebSocketClientProtocol): + + def onOpen(self): + # Create new manufacturer + request = { + 'method': 'POST', + 'url': '/manufacturer/', + 'event_name': 'create-manufacturer', + 'args': { + "name": 'Ford' + } + } + self.sendMessage(json.dumps(request).encode('utf8')) + + # Get information about Audi + request = { + 'method': 'GET', + 'url': '/manufacturer/Audi/', + 'event_name': 'get-manufacturer-detail' + } + self.sendMessage(json.dumps(request).encode('utf8')) + + # Get cars list + request = { + 'method': 'GET', + 'url': '/cars/', + 'event_name': 'get-cars-list' + } + self.sendMessage(json.dumps(request).encode('utf8')) + + # Create new car + request = { + 'method': 'POST', + 'url': '/cars/', + 'event_name': 'create-car', + 'args': { + 'name': 'M5', + 'manufacturer': 2 + } + } + self.sendMessage(json.dumps(request).encode('utf8')) + + # Trying to create new car with same info, but we have taken an error + self.sendMessage(json.dumps(request).encode('utf8')) + + # # Update existing object + request = { + 'method': 'PUT', + 'url': '/cars/Q5/', + 'event_name': 'partial-update-car', + 'args': { + 'name': 'Q7' + } + } + self.sendMessage(json.dumps(request).encode('utf8')) + + # Get the list of manufacturers + request = { + 'method': 'GET', + 'url': '/manufacturer/', + 'event_name': 'get-manufacturer-list', + 'args': {} + } + self.sendMessage(json.dumps(request).encode('utf8')) + + # Update manufacturer + request = { + 'method': 'PUT', + 'url': '/manufacturer/Audi/', + 'event_name': 'update-manufacturer', + 'args': { + 'name': 'Not Audi' + } + } + self.sendMessage(json.dumps(request).encode('utf8')) + + # Get car by name + request = { + 'method': 'GET', + 'url': '/cars/TT/', + 'event_name': 'get-car-detail', + 'args': {} + } + self.sendMessage(json.dumps(request).encode('utf8')) + + def onMessage(self, payload, isBinary): + print("Result: {0}".format(payload.decode('utf8'))) + + +if __name__ == '__main__': + factory = WebSocketClientFactory("ws://localhost:8080", debug=False) + factory.protocol = HelloClientProtocol + + loop = asyncio.get_event_loop() + coro = loop.create_connection(factory, '127.0.0.1', 8080) + loop.run_until_complete(coro) + loop.run_forever() + loop.close() diff --git a/examples/django_orm/server/__init__.py b/examples/django_orm/server/__init__.py new file mode 100644 index 0000000..40a96af --- /dev/null +++ b/examples/django_orm/server/__init__.py @@ -0,0 +1 @@ +# -*- coding: utf-8 -*- diff --git a/examples/django_orm/server/app/__init__.py b/examples/django_orm/server/app/__init__.py new file mode 100644 index 0000000..5a6eafd --- /dev/null +++ b/examples/django_orm/server/app/__init__.py @@ -0,0 +1,17 @@ +# -*- coding: utf-8 -*- +import sys +from os import path + +from django import setup as django_setup +from django.conf import settings as django_settings + +from aiorest_ws.conf import settings + + +sys.path.insert(0, path.abspath(path.join(path.dirname(__file__), '..'))) +django_settings.configure( + DATABASES=settings.DATABASES, + INSTALLED_APPS=settings.INSTALLED_APPS +) +django_setup() + diff --git a/examples/django_orm/server/app/db.py b/examples/django_orm/server/app/db.py new file mode 100644 index 0000000..fcf86fc --- /dev/null +++ b/examples/django_orm/server/app/db.py @@ -0,0 +1,23 @@ +# -*- coding: utf-8 -*- +from django.db import models + + +class Manufacturer(models.Model): + name = models.CharField(max_length=30) + + class Meta: + app_label = 'django_orm_example' + + def __str__(self): + return '' % (self.id, self.name) + + +class Car(models.Model): + name = models.CharField(max_length=30, unique=True) + manufacturer = models.ForeignKey(Manufacturer, related_name='cars') + + class Meta: + app_label = 'django_orm_example' + + def __str__(self): + return '' % (self.id, self.name, self.manufacturer) diff --git a/examples/django_orm/server/app/serializers.py b/examples/django_orm/server/app/serializers.py new file mode 100644 index 0000000..9a9a862 --- /dev/null +++ b/examples/django_orm/server/app/serializers.py @@ -0,0 +1,16 @@ +# -*- coding: utf-8 -*- +from aiorest_ws.db.orm.django import serializers + +from app.db import Manufacturer, Car + + +class ManufacturerSerializer(serializers.ModelSerializer): + + class Meta: + model = Manufacturer + + +class CarSerializer(serializers.ModelSerializer): + + class Meta: + model = Car diff --git a/examples/django_orm/server/app/urls.py b/examples/django_orm/server/app/urls.py new file mode 100644 index 0000000..e1121a7 --- /dev/null +++ b/examples/django_orm/server/app/urls.py @@ -0,0 +1,14 @@ +# -*- coding: utf-8 -*- +from aiorest_ws.routers import SimpleRouter + +from app.views import ManufacturerListView, ManufacturerView, CarListView, \ + CarView + + +router = SimpleRouter() +router.register('/manufacturer/{name}', ManufacturerView, ['GET', 'PUT'], + name='manufacturer-detail') +router.register('/manufacturer/', ManufacturerListView, ['POST']) +router.register('/cars/{name}', CarView, ['GET', 'PUT'], + name='car-detail') +router.register('/cars/', CarListView, ['POST']) diff --git a/examples/django_orm/server/app/views.py b/examples/django_orm/server/app/views.py new file mode 100644 index 0000000..ca4b460 --- /dev/null +++ b/examples/django_orm/server/app/views.py @@ -0,0 +1,98 @@ +# -*- coding: utf-8 -*- +from aiorest_ws.db.orm.exceptions import ValidationError +from aiorest_ws.views import MethodBasedView + +from django.core.exceptions import ObjectDoesNotExist + +from app.db import Manufacturer, Car +from app.serializers import ManufacturerSerializer, CarSerializer + + +class ManufacturerListView(MethodBasedView): + + def get(self, request, *args, **kwargs): + instances = Manufacturer.objects.all() + serializer = ManufacturerSerializer(instances, many=True) + return serializer.data + + def post(self, request, *args, **kwargs): + data = kwargs.get('params', None) + if not data: + raise ValidationError('You must provide arguments for create.') + + serializer = ManufacturerSerializer(data=data) + serializer.is_valid(raise_exception=True) + serializer.save() + return serializer.data + + +class ManufacturerView(MethodBasedView): + + def get_manufacturer(self, name): + try: + manufacturer = Manufacturer.objects.get(name__iexact=name) + except ObjectDoesNotExist: + raise ValidationError("The requested object does not exist") + + return manufacturer + + def get(self, request, name, *args, **kwargs): + manufacturer = self.get_manufacturer(name) + serializer = ManufacturerSerializer(manufacturer) + return serializer.data + + def put(self, request, name, *args, **kwargs): + data = kwargs.get('params', None) + if not data: + raise ValidationError('You must provide arguments for create.') + + instance = self.get_manufacturer(name) + serializer = ManufacturerSerializer(instance, data=data, partial=True) + serializer.is_valid(raise_exception=True) + serializer.save() + return serializer.data + + +class CarListView(MethodBasedView): + + def get(self, request, *args, **kwargs): + data = Car.objects.all() + serializer = CarSerializer(data, many=True) + return serializer.data + + def post(self, request, *args, **kwargs): + data = kwargs.get('params', None) + if not data: + raise ValidationError('You must provide arguments for create.') + + serializer = CarSerializer(data=data) + serializer.is_valid(raise_exception=True) + serializer.save() + return serializer.data + + +class CarView(MethodBasedView): + + def get_car(self, name): + try: + car = Car.objects.get(name__iexact=name) + except ObjectDoesNotExist: + raise ValidationError("The requested object does not exist") + + return car + + def get(self, request, name, *args, **kwargs): + instance = self.get_car(name) + serializer = CarSerializer(instance) + return serializer.data + + def put(self, request, name, *args, **kwargs): + data = kwargs.get('params', None) + if not data: + raise ValidationError('You must provide data for update.') + + instance = self.get_car(name) + serializer = CarSerializer(instance, data=data, partial=True) + serializer.is_valid(raise_exception=True) + serializer.save() + return serializer.data diff --git a/examples/django_orm/server/create_db.py b/examples/django_orm/server/create_db.py new file mode 100644 index 0000000..c8a205b --- /dev/null +++ b/examples/django_orm/server/create_db.py @@ -0,0 +1,25 @@ +# -*- coding: utf-8 -*- +from app.db import Manufacturer, Car + +from django.db import connections + + +if __name__ == '__main__': + # Create tables in memory + conn = connections['default'] + with conn.schema_editor() as editor: + editor.create_model(Manufacturer) + editor.create_model(Car) + + # Initialize the database + data = { + 'Audi': ['A8', 'Q5', 'TT'], + 'BMW': ['M3', 'i8'], + 'Mercedes-Benz': ['C43 AMG W202', 'C450 AMG 4MATIC'] + } + + for name, models in data.items(): + manufacturer = Manufacturer.objects.create(name=name) + + for model in models: + Car.objects.create(name=model, manufacturer=manufacturer) diff --git a/examples/django_orm/server/run_server.py b/examples/django_orm/server/run_server.py new file mode 100644 index 0000000..ddef191 --- /dev/null +++ b/examples/django_orm/server/run_server.py @@ -0,0 +1,13 @@ +# -*- coding: utf-8 -*- +import os + + +def get_module_content(filename): + with open(filename, 'r') as module: + return module.read() + + +if __name__ == __name__: + os.environ.setdefault("AIORESTWS_SETTINGS_MODULE", "settings") + exec(get_module_content('./create_db.py')) + exec(get_module_content('./server.py')) diff --git a/examples/django_orm/server/server.py b/examples/django_orm/server/server.py new file mode 100644 index 0000000..ceb8349 --- /dev/null +++ b/examples/django_orm/server/server.py @@ -0,0 +1,19 @@ +# -*- coding: utf-8 -*- +from aiorest_ws.app import Application +from aiorest_ws.command_line import CommandLine +from aiorest_ws.routers import SimpleRouter + +from app.urls import router + +main_router = SimpleRouter() +main_router.include(router) + + +if __name__ == '__main__': + cmd = CommandLine() + cmd.define('-ip', default='127.0.0.1', help='used ip', type=str) + cmd.define('-port', default=8080, help='listened port', type=int) + args = cmd.parse_command_line() + + app = Application() + app.run(host=args.ip, port=args.port, router=main_router) diff --git a/examples/django_orm/server/settings.py b/examples/django_orm/server/settings.py new file mode 100644 index 0000000..2404649 --- /dev/null +++ b/examples/django_orm/server/settings.py @@ -0,0 +1,10 @@ +# -*- coding: utf-8 -*- + +USE_ORM_ENGINE = True +DATABASES = { + 'default': { + 'ENGINE': 'django.db.backends.sqlite3', + 'NAME': ':memory:', + } +} +INSTALLED_APPS = ("app", ) diff --git a/examples/sqlalchemy_orm/server/app/views.py b/examples/sqlalchemy_orm/server/app/views.py index aece86e..0ecd857 100644 --- a/examples/sqlalchemy_orm/server/app/views.py +++ b/examples/sqlalchemy_orm/server/app/views.py @@ -19,7 +19,7 @@ def post(self, *args, **kwargs): if not data: raise ValidationError('You must provide arguments for create.') - created_obj_data = data.get('list' , []) + created_obj_data = data.get('list', []) if not data: raise ValidationError('You must provide a list of objects.') diff --git a/tests/db/orm/django/__init__.py b/tests/db/orm/django/__init__.py new file mode 100644 index 0000000..eb9e0b9 --- /dev/null +++ b/tests/db/orm/django/__init__.py @@ -0,0 +1,15 @@ +# -*- coding: utf-8 -*- +from django import setup +from django.conf import settings + + +if not settings.configured: + settings.configure( + DATABASES={ + 'default': { + 'ENGINE': 'django.db.backends.sqlite3', + 'NAME': ':memory:', + } + }, + ) + setup() diff --git a/tests/db/orm/django/base.py b/tests/db/orm/django/base.py new file mode 100644 index 0000000..58e7f53 --- /dev/null +++ b/tests/db/orm/django/base.py @@ -0,0 +1,31 @@ +# -*- coding: utf-8 -*- +import unittest + +from django.apps import apps +from django.db import connections + + +class DjangoUnitTest(unittest.TestCase): + + models = [] + apps = () + + @classmethod + def setUpClass(cls): + super(DjangoUnitTest, cls).setUpClass() + if apps: + apps.populate(cls.apps) + + conn = connections['default'] + with conn.schema_editor() as editor: + for model in cls.models: + editor.create_model(model) + + @classmethod + def tearDownClass(cls): + super(DjangoUnitTest, cls).tearDownClass() + + conn = connections['default'] + with conn.schema_editor() as editor: + for model in cls.models: + editor.delete_model(model) diff --git a/tests/db/orm/django/test_fields.py b/tests/db/orm/django/test_fields.py new file mode 100644 index 0000000..1b49919 --- /dev/null +++ b/tests/db/orm/django/test_fields.py @@ -0,0 +1,1558 @@ +# -*- coding: utf-8 -*- +import datetime +import os +import re +import uuid +from decimal import Decimal + +from aiorest_ws.conf import settings +from aiorest_ws.db.orm.abstract import empty, SkipField +from aiorest_ws.db.orm.django import fields +from aiorest_ws.db.orm.exceptions import ValidationError +from aiorest_ws.utils.date import timezone + +from django.db import models +from django.core.exceptions import ValidationError as DjangoValidationError + +from tests.db.orm.django.base import DjangoUnitTest + + +class TestIntegerField(DjangoUnitTest): + + def test_init_without_borders(self): + instance = fields.IntegerField() + self.assertIsNone(instance.max_value) + self.assertIsNone(instance.min_value) + + def test_init_with_min_value(self): + instance = fields.IntegerField(min_value=10) + self.assertEqual(instance.min_value, 10) + self.assertIsNone(instance.max_value) + + def test_init_with_max_value(self): + instance = fields.IntegerField(max_value=10) + self.assertIsNone(instance.min_value) + self.assertEqual(instance.max_value, 10) + + def test_init_with_min_and_max_values(self): + instance = fields.IntegerField(min_value=1, max_value=10) + self.assertEqual(instance.min_value, 1) + self.assertEqual(instance.max_value, 10) + + def test_to_internal_value(self): + instance = fields.IntegerField() + self.assertEqual(instance.to_internal_value(1), 1) + + def test_to_internal_value_raises_max_string_length_exception(self): + instance = fields.IntegerField() + + with self.assertRaises(ValidationError): + data = 'value' * 250 + instance.to_internal_value(data) + + def test_to_internal_value_raises_validate_exception(self): + instance = fields.IntegerField() + + with self.assertRaises(ValidationError): + instance.to_internal_value('object') + + def test_to_to_representation(self): + instance = fields.IntegerField() + self.assertEqual(instance.to_representation('1'), 1) + + +class TestBooleanField(DjangoUnitTest): + + def test_init_raises_assertion_error(self): + with self.assertRaises(AssertionError): + fields.BooleanField(allow_null=True) + + def test_to_internal_value_returns_true_value(self): + instance = fields.BooleanField() + + for value in instance.TRUE_VALUES: + self.assertTrue(instance.to_internal_value(value)) + + def test_to_internal_value_returns_false_value(self): + instance = fields.BooleanField() + + for value in instance.FALSE_VALUES: + self.assertFalse(instance.to_internal_value(value)) + + def test_to_internal_value_raises_validate_exception(self): + instance = fields.BooleanField() + + with self.assertRaises(ValidationError): + instance.to_internal_value(object()) + + def test_to_internal_value_with_unhashable_value(self): + instance = fields.BooleanField() + + with self.assertRaises(ValidationError): + instance.to_internal_value({}) + + def test_to_representation_returns_true_value(self): + instance = fields.BooleanField() + + for value in instance.TRUE_VALUES: + self.assertTrue(instance.to_representation(value)) + + def test_to_representation_returns_false_value(self): + instance = fields.BooleanField() + + for value in instance.FALSE_VALUES: + self.assertFalse(instance.to_representation(value)) + + def test_to_representation_return_true_for_not_defined_values(self): + instance = fields.BooleanField() + self.assertTrue(instance.to_representation(object())) + self.assertTrue(instance.to_representation("value")) + + def test_to_representation_return_false_for_not_defined_values(self): + instance = fields.BooleanField() + self.assertFalse(instance.to_representation(())) + self.assertFalse(instance.to_representation("")) + + +class TestNullBooleanField(DjangoUnitTest): + + def test_to_internal_value_returns_true_value(self): + instance = fields.NullBooleanField() + + for value in instance.TRUE_VALUES: + self.assertTrue(instance.to_internal_value(value)) + + def test_to_internal_value_returns_false_value(self): + instance = fields.NullBooleanField() + + for value in instance.FALSE_VALUES: + self.assertFalse(instance.to_internal_value(value)) + + def test_to_internal_value_returns_none_value(self): + instance = fields.NullBooleanField() + + for value in instance.NULL_VALUES: + self.assertIsNone(instance.to_internal_value(value)) + + def test_to_internal_value_raises_validate_exception(self): + instance = fields.NullBooleanField() + + with self.assertRaises(ValidationError): + instance.to_internal_value(object()) + + def test_to_representation_returns_true_value(self): + instance = fields.NullBooleanField() + + for value in instance.TRUE_VALUES: + self.assertTrue(instance.to_representation(value)) + + def test_to_representation_returns_false_value(self): + instance = fields.NullBooleanField() + + for value in instance.FALSE_VALUES: + self.assertFalse(instance.to_representation(value)) + + def test_to_representation_returns_none_value(self): + instance = fields.NullBooleanField() + + for value in instance.NULL_VALUES: + self.assertIsNone(instance.to_representation(value)) + + def test_to_representation_return_true_for_not_defined_values(self): + instance = fields.NullBooleanField() + self.assertTrue(instance.to_representation(object())) + self.assertTrue(instance.to_representation("value")) + + def test_to_representation_return_false_for_not_defined_values(self): + instance = fields.NullBooleanField() + self.assertFalse(instance.to_representation(())) + self.assertFalse(instance.to_representation("")) + + +class TestCharField(DjangoUnitTest): + + def test_init_default(self): + instance = fields.CharField() + self.assertFalse(instance.allow_blank) + self.assertTrue(instance.trim_whitespace) + self.assertIsNone(instance.min_length) + self.assertIsNone(instance.max_length) + + def test_run_validation(self): + instance = fields.CharField() + self.assertEqual(instance.to_internal_value('test'), 'test') + + def test_run_validation_raise_validation_error_for_too_short_string(self): + instance = fields.CharField(min_length=5) + + with self.assertRaises(ValidationError): + instance.run_validation('test') + + def test_run_validation_raise_validation_error_for_too_long_string(self): + instance = fields.CharField(max_length=3) + + with self.assertRaises(ValidationError): + instance.run_validation('test') + + def test_run_validation_raise_validation_error_for_blank_field(self): + instance = fields.CharField(allow_blank=False) + + with self.assertRaises(ValidationError): + instance.run_validation('') + + def test_run_validation_returns_empty_string(self): + instance = fields.CharField(allow_blank=True) + self.assertEqual(instance.run_validation(''), '') + + def test_to_internal_value(self): + instance = fields.CharField() + self.assertEqual(instance.to_internal_value(' value '), 'value') + + def test_to_internal_value_without_trim_whitespace(self): + instance = fields.CharField(trim_whitespace=False) + self.assertEqual(instance.to_internal_value(' value '), ' value ') + + def test_disallow_blank_with_trim_whitespace(self): + instance = fields.CharField(allow_blank=False, trim_whitespace=True) + + with self.assertRaises(ValidationError): + instance.run_validation(' ') + + def test_to_representation(self): + instance = fields.CharField() + self.assertEqual(instance.to_representation('test'), 'test') + + +class TestChoiceField(DjangoUnitTest): + + choices = ( + (1, 'one'), + (2, 'two'), + (3, 'three') + ) + + def test_to_internal_value(self): + instance = fields.ChoiceField(choices=self.choices) + self.assertEqual(instance.to_internal_value(1), 1) + + def test_to_internal_value_for_empty_string(self): + instance = fields.ChoiceField(self.choices, allow_blank=True) + self.assertEqual(instance.to_internal_value(''), '') + + def test_to_internal_value_raise_validation_error(self): + instance = fields.ChoiceField(self.choices) + + with self.assertRaises(ValidationError): + instance.to_internal_value(4) + + def test_to_representation(self): + instance = fields.ChoiceField(self.choices) + self.assertEqual(instance.to_representation(2), 2) + + def test_to_representation_empty_string(self): + instance = fields.ChoiceField(self.choices) + self.assertEqual(instance.to_representation(''), '') + + def test_to_representation_none_value(self): + instance = fields.ChoiceField(self.choices) + self.assertEqual(instance.to_representation(None), None) + + def test_to_representation_not_found_key(self): + instance = fields.ChoiceField(self.choices) + self.assertEqual(instance.to_representation(4), 4) + + +class TestMultipleChoiceField(DjangoUnitTest): + + choices = ( + (1, 'one'), + (2, 'two'), + (3, 'three') + ) + + def test_get_value_returns_empty(self): + + class FakeSerializer(): + parent = None + partial = True + + instance = fields.MultipleChoiceField(self.choices) + instance.bind('choice', self) + instance.parent = FakeSerializer() + self.assertEqual(instance.get_value({}), empty) + + def test_get_value_returns_value_from_dictionary(self): + instance = fields.MultipleChoiceField(self.choices) + instance.bind('choice', self) + self.assertEqual(instance.get_value({'choice': 1}), 1) + + def test_to_internal_value(self): + instance = fields.MultipleChoiceField(self.choices) + instance.bind('choice', self) + self.assertEqual(instance.to_internal_value((1, )), {1}) + + def test_to_internal_value_raises_not_a_list_error(self): + instance = fields.MultipleChoiceField(self.choices) + instance.bind('choice', self) + self.assertRaises( + ValidationError, + instance.to_internal_value, '' + ) + + def test_to_internal_value_raises_empty_error(self): + instance = fields.MultipleChoiceField(self.choices, allow_empty=False) + instance.bind('choice', self) + self.assertRaises( + ValidationError, + instance.to_internal_value, () + ) + + def test_to_representation_returns_dictionary(self): + instance = fields.MultipleChoiceField(self.choices) + instance.bind('choice', self) + self.assertEqual(instance.to_representation((1, )), {1}) + + +class TestFloatField(DjangoUnitTest): + + def test_init_default(self): + instance = fields.FloatField() + self.assertIsNone(instance.min_value) + self.assertIsNone(instance.max_value) + + def test_run_validation_without_borders(self): + instance = fields.FloatField() + self.assertEqual(instance.run_validation(5.0), 5.0) + + def test_run_validation_with_defined_min_value(self): + instance = fields.FloatField(min_value=10.0) + + with self.assertRaises(ValidationError): + instance.run_validation(5.0) + + def test_run_validation_with_defined_max_value(self): + instance = fields.FloatField(max_value=10.0) + + with self.assertRaises(ValidationError): + instance.run_validation(11.0) + + def test_to_internal_value(self): + instance = fields.FloatField() + self.assertEqual(instance.to_internal_value(5), 5) + + def test_to_internal_value_raises_validation_error(self): + instance = fields.FloatField() + + with self.assertRaises(ValidationError): + instance.to_internal_value(None) + + def test_to_internal_value_raises_validation_error_for_too_long_str(self): + instance = fields.FloatField(max_value=10) + + with self.assertRaises(ValidationError): + instance.to_internal_value('test' * 255) + + def test_to_representation(self): + instance = fields.FloatField() + self.assertEqual(instance.to_representation(5), 5.0) + + +class TestDecimalField(DjangoUnitTest): + + def test_init_default(self): + instance = fields.DecimalField(max_digits=5, decimal_places=2) + self.assertEqual(instance.max_digits, 5) + self.assertEqual(instance.decimal_places, 2) + self.assertEqual(instance.max_whole_digits, 3) + + def test_init_with_not_defined_max_whole_digits(self): + instance = fields.DecimalField(max_digits=None, decimal_places=None) + self.assertIsNone(instance.max_digits) + self.assertIsNone(instance.decimal_places) + self.assertIsNone(instance.max_whole_digits) + + def test_run_validation(self): + instance = fields.DecimalField(max_digits=5, decimal_places=2) + self.assertEqual(instance.run_validation(99), 99) + + def test_run_validation_raises_validation_error_for_gt_max_value(self): + instance = fields.DecimalField( + max_digits=5, decimal_places=2, max_value=90 + ) + + with self.assertRaises(ValidationError): + instance.run_validation(99) + + def test_run_validation_raises_validation_error_for_lt_min_value(self): + instance = fields.DecimalField( + max_digits=5, decimal_places=2, min_value=10, + ) + + with self.assertRaises(ValidationError): + instance.run_validation(9) + + def test_validate_precision_with_exponent(self): + instance = fields.DecimalField(max_digits=5, decimal_places=0) + + value = Decimal('12345') + self.assertEqual(instance.validate_precision(value), value) + + def test_validate_precision_with_digittuple(self): + instance = fields.DecimalField(max_digits=7, decimal_places=2) + + value = Decimal('12345.0') + self.assertEqual(instance.validate_precision(value), value) + + def test_validate_precision_with_fraction(self): + instance = fields.DecimalField(max_digits=7, decimal_places=5) + + value = Decimal('0.01234') + self.assertEqual(instance.validate_precision(value), value) + + def test_validate_precision_raise_validation_exc_max_digits(self): + instance = fields.DecimalField(max_digits=5, decimal_places=2) + + with self.assertRaises(ValidationError): + instance.validate_precision(Decimal('1234500.0')) + + def test_validate_precision_raise_validation_exc_max_decimal_places(self): + instance = fields.DecimalField(max_digits=9, decimal_places=0) + + with self.assertRaises(ValidationError): + instance.validate_precision(Decimal('1234500.0')) + + def test_validate_precision_raise_validation_exc_max_whole_digits(self): + instance = fields.DecimalField(max_digits=9, decimal_places=7) + + with self.assertRaises(ValidationError): + instance.validate_precision(Decimal('1234500.0')) + + def test_to_internal_value(self): + instance = fields.DecimalField(max_digits=10, decimal_places=5) + self.assertEqual( + instance.to_internal_value(12345.0), Decimal('12345.0') + ) + + def test_to_internal_value_raises_validation_error_for_max_length(self): + instance = fields.DecimalField(max_digits=10, decimal_places=5) + + with self.assertRaises(ValidationError): + instance.to_internal_value('test' * 255) + + def test_to_internal_value_raises_validation_error_for_not_decimal(self): + instance = fields.DecimalField(max_digits=10, decimal_places=5) + + with self.assertRaises(ValidationError): + instance.to_internal_value('None') + + def test_to_internal_value_raises_validation_error_for_NaN(self): + instance = fields.DecimalField(max_digits=10, decimal_places=5) + + with self.assertRaises(ValidationError): + instance.to_internal_value('NaN') + + def test_to_internal_value_raises_validation_error_for_infinity(self): + instance = fields.DecimalField(max_digits=10, decimal_places=5) + + with self.assertRaises(ValidationError): + instance.to_internal_value(float('inf')) # positive infinite + instance.to_internal_value(-float('inf')) # negative infinite + + def test_to_representation_with_decimal_as_a_string(self): + instance = fields.DecimalField(max_digits=10, decimal_places=5) + self.assertEqual(instance.to_representation('12345.0'), '12345.00000') + + def test_to_representation_without_coerce_to_string(self): + instance = fields.DecimalField( + max_digits=10, decimal_places=5, coerce_to_string=False + ) + value = Decimal('12345.0') + self.assertEqual(instance.to_representation(value), value) + + def test_quantize(self): + instance = fields.DecimalField(max_digits=10, decimal_places=5) + + value = Decimal('12345.0') + self.assertEqual(instance.quantize(value), value) + + +class TestTimeField(DjangoUnitTest): + + def test_run_validation_raise_validation_error(self): + instance = fields.TimeField() + + with self.assertRaises(ValidationError): + instance.run_validation('value') + + def test_to_internal_value_string(self): + instance = fields.TimeField() + self.assertEqual( + instance.to_internal_value('03:00'), + datetime.time(3, 0) + ) + + def test_to_internal_value_time(self): + instance = fields.TimeField() + self.assertEqual( + instance.to_internal_value(datetime.time(3, 0)), + datetime.time(3, 0) + ) + + def test_to_internal_value_for_non_iso8601(self): + instance = fields.TimeField(input_formats=('%H:%M', )) + self.assertEqual( + instance.to_internal_value('10:00'), + datetime.time(10, 0) + ) + + def test_to_internal_value_raises_validation_error_with_empty_format(self): + instance = fields.TimeField(input_formats=()) + + with self.assertRaises(ValidationError): + instance.to_internal_value('99:99') + + def test_to_internal_value_raises_validation_error_for_a_wrong_type(self): + instance = fields.TimeField() + + with self.assertRaises(ValidationError): + instance.to_internal_value(None) + + def test_to_internal_value_raises_error_for_wrong_value_and_format(self): + instance = fields.TimeField(input_formats=('%H:%M',)) + + with self.assertRaises(ValidationError): + instance.to_internal_value('99:99') + + def test_to_internal_value_raises_error_for_none_value_and_format(self): + instance = fields.TimeField(input_formats=('%H:%M',)) + + with self.assertRaises(ValidationError): + instance.to_internal_value(None) + + def test_to_internal_value_raises_validation_error_for_wrong_value(self): + instance = fields.TimeField() + + with self.assertRaises(ValidationError): + instance.to_internal_value('99:99') + + def test_to_representation(self): + instance = fields.TimeField(format='%H:%M:%S') + timestamp = datetime.time(3, 0) + self.assertEqual( + instance.to_representation(timestamp), '03:00:00' + ) + + def test_to_representation_returns_none_for_empty_string(self): + instance = fields.TimeField() + self.assertIsNone(instance.to_representation('')) + + def test_to_representation_returns_none(self): + instance = fields.TimeField() + self.assertIsNone(instance.to_representation(None)) + + def test_to_representation_return_value(self): + instance = fields.TimeField(format=None) + timestamp = datetime.time(13, 0) + self.assertEqual(instance.to_representation(timestamp), timestamp) + + def test_to_representation_parse_string_into_iso8601_string(self): + instance = fields.TimeField() + self.assertEqual( + instance.to_representation('10:00:00'), '10:00:00' + ) + + def test_to_representation_parse_time_into_iso8601_string(self): + instance = fields.TimeField() + self.assertEqual( + instance.to_representation(datetime.time(10, 0)), '10:00:00' + ) + + def test_to_representation_raise_assertion_error(self): + instance = fields.TimeField(format=settings.ISO_8601) + with self.assertRaises(AssertionError): + instance.to_representation(datetime.datetime(2000, 1, 1, 10, 00)) + + +class TestDateFields(DjangoUnitTest): + + def test_run_validation_raises_validation_error_for_wrong_value(self): + instance = fields.DateField() + + with self.assertRaises(ValidationError): + instance.run_validation('value') + + def test_run_validation_raises_validation_error_for_wrong_datetime(self): + instance = fields.DateField() + + with self.assertRaises(ValidationError): + instance.run_validation('2001-99-99') + + def test_run_validation_raises_validation_error_for_a_wrong_type(self): + instance = fields.DateField() + + with self.assertRaises(ValidationError): + instance.run_validation(datetime.datetime(2000, 1, 1, 1, 0)) + + def test_to_internal_value_raises_validation_error_for_wrong_type(self): + instance = fields.DateField() + + with self.assertRaises(ValidationError): + instance.to_internal_value(datetime.datetime(2000, 1, 1, 1, 0)) + + def test_to_internal_value_returns_date_instance(self): + instance = fields.DateField() + value = datetime.date(2000, 1, 1) + self.assertEqual(instance.to_internal_value(value), value) + + def test_to_internal_value_returns_parsed_string_for_iso8601(self): + instance = fields.DateField() + self.assertEqual( + instance.to_internal_value('2000-01-01'), + datetime.date(2000, 1, 1) + ) + + def test_to_internal_value_raises_validation_error_for_wrong_date(self): + instance = fields.DateField() + + with self.assertRaises(ValidationError): + instance.to_internal_value('2000-99-99') + + def test_to_internal_value_raises_validation_error_for_wrong_value(self): + instance = fields.DateField() + + with self.assertRaises(ValidationError): + instance.to_internal_value(None) + + def test_to_internal_value_with_user_format(self): + instance = fields.DateField(input_formats=('%Y-%m-%d', )) + self.assertEqual( + instance.to_internal_value('2000-01-01'), + datetime.date(2000, 1, 1) + ) + + def test_to_internal_value_with_format_raises_error_for_wrong_value(self): + instance = fields.DateField(input_formats=('%Y-%m-%d', )) + + with self.assertRaises(ValidationError): + instance.to_internal_value('2000-99-99') + + def test_to_internal_value_with_format_raises_error_for_wrong_type(self): + instance = fields.DateField(input_formats=('%Y-%m-%d', )) + + with self.assertRaises(ValidationError): + instance.to_internal_value(None) + + def test_to_representation_returns_none(self): + instance = fields.DateField() + self.assertIsNone(instance.to_representation(None)) + + def test_to_representation_returns_empty_string(self): + instance = fields.DateField() + self.assertEqual(instance.to_representation(''), None) + + def test_to_representation_with_none_output_format(self): + instance = fields.DateField(format=None) + self.assertEqual( + instance.to_representation('2000-01-01'), '2000-01-01' + ) + + def test_to_representation_raises_assertion_error_for_a_wrong_type(self): + instance = fields.DateField() + + with self.assertRaises(AssertionError): + instance.to_representation(datetime.datetime(2000, 1, 1)) + + def test_to_representation_returns_value_in_uso8601_for_string(self): + instance = fields.DateField() + self.assertEqual( + instance.to_representation('2000-01-01'), '2000-01-01' + ) + + def test_to_representation_returns_value_in_uso8601_for_date(self): + instance = fields.DateField() + self.assertEqual( + instance.to_representation(datetime.date(2000, 1, 1)), '2000-01-01' + ) + + def test_to_representation_with_custom_date_format(self): + instance = fields.DateField(format="%Y-%m-%d") + self.assertEqual( + instance.to_representation(datetime.date(2000, 1, 1)), '2000-01-01' + ) + + +class TestDateTimeField(DjangoUnitTest): + + def test_run_validation_raises_validation_error_for_a_wrong_type(self): + instance = fields.DateTimeField() + + with self.assertRaises(ValidationError): + instance.run_validation(None) + + def test_run_validation_raises_validation_error_for_a_wrong_value(self): + instance = fields.DateTimeField() + + with self.assertRaises(ValidationError): + instance.run_validation('value') + + def test_enforce_timezone_returns_naive_datetime(self): + instance = fields.DateTimeField() + value = datetime.datetime(2000, 1, 1, 10, 0) + self.assertEqual(instance.enforce_timezone(value), value) + + def test_enforce_timezone_returns_aware_datetime_with_utc_timezone(self): + instance = fields.DateTimeField(default_timezone=timezone.UTC()) + self.assertEqual( + instance.enforce_timezone(datetime.datetime(2000, 1, 1, 10, 0)), + datetime.datetime(2000, 1, 1, 10, 0, tzinfo=timezone.UTC()) + ) + + def test_enforce_timezone_returns_naive_datetime_with_utc_timezone(self): + instance = fields.DateTimeField() + value = datetime.datetime(2000, 1, 1, 10, 0, tzinfo=timezone.UTC()) + self.assertEqual( + instance.enforce_timezone(value), + datetime.datetime(2000, 1, 1, 10, 0) + ) + + def test_to_internal_value_returns_datetime_with_enforce_datetime(self): + instance = fields.DateTimeField() + self.assertEqual( + instance.to_internal_value(datetime.datetime(2000, 1, 1)), + datetime.datetime(2000, 1, 1) + ) + + def test_to_internal_value_raises_validation_error_for_date_type(self): + instance = fields.DateTimeField() + + with self.assertRaises(ValidationError): + instance.to_internal_value(datetime.date(2000, 1, 1)) + + def test_to_internal_value_returns_datetime_in_iso8601(self): + instance = fields.DateTimeField() + + self.assertEqual( + instance.to_internal_value('2000-01-01 10:00'), + datetime.datetime(2000, 1, 1, 10, 0) + ) + + def test_to_internal_value_raises_validation_error_with_a_wrong_type(self): + instance = fields.DateTimeField() + + with self.assertRaises(ValidationError): + instance.to_internal_value(None) + + def test_to_internal_value_raises_validation_error_for_invalid_value(self): + instance = fields.DateTimeField() + + with self.assertRaises(ValidationError): + instance.to_internal_value('2000-99-99 10:00') + + def test_to_internal_value_with_format_returns_datetime(self): + instance = fields.DateTimeField(input_formats=("%Y-%m-%d %H:%M", )) + self.assertEqual( + instance.to_internal_value('2000-01-01 10:00'), + datetime.datetime(2000, 1, 1, 10, 0) + ) + + def test_to_internal_value_with_format_raises_exc_for_a_wrong_type(self): + instance = fields.DateTimeField(input_formats=("%Y-%m-%d %H:%M",)) + + with self.assertRaises(ValidationError): + instance.to_internal_value(None) + + def test_to_internal_value_with_format_raises_exc_for_invalid_value(self): + instance = fields.DateTimeField(input_formats=("%Y-%m-%d %H:%M",)) + + with self.assertRaises(ValidationError): + instance.to_internal_value('2000-99-99 10:00') + + def test_to_representation_returns_none_for_empty_string(self): + instance = fields.DateTimeField() + self.assertIsNone(instance.to_representation('')) + + def test_to_representation_returns_none_for_none_type(self): + instance = fields.DateTimeField() + self.assertIsNone(instance.to_representation(None)) + + def test_to_representation_with_none_output_format_returns_value(self): + instance = fields.DateTimeField(format=None) + self.assertEqual( + instance.to_representation('2000-01-01T10:00:00Z'), + '2000-01-01T10:00:00Z' + ) + + def test_to_representation_returns_value_in_iso8601(self): + instance = fields.DateTimeField() + value = datetime.datetime(2000, 1, 1, 10, 0, tzinfo=timezone.UTC()) + self.assertEqual( + instance.to_representation(value), + '2000-01-01T10:00:00Z' + ) + + def test_to_representation_returns_value_in_custom_format(self): + instance = fields.DateTimeField(format="%Y-%m-%d %H:%M") + self.assertEqual( + instance.to_representation(datetime.datetime(2000, 1, 1, 10, 0)), + '2000-01-01 10:00' + ) + + +class TestDurationField(DjangoUnitTest): + + def test_to_internal_value(self): + instance = fields.DurationField() + timedelta = datetime.timedelta(days=3) + self.assertEqual(instance.to_internal_value(timedelta), timedelta) + + def test_to_internal_value_from_string(self): + instance = fields.DurationField() + value = '3 10:00:00.123456' + timedelta = datetime.timedelta(days=3, hours=10, microseconds=123456) + self.assertEqual(instance.to_internal_value(value), timedelta) + + def test_to_internal_value_raises_invalid_error(self): + instance = fields.DurationField() + value = 'invalid_value' + self.assertRaises( + ValidationError, + instance.to_internal_value, value + ) + + def test_to_representation(self): + instance = fields.DurationField() + data = datetime.timedelta(days=3, hours=10, microseconds=123456) + output = '3 10:00:00.123456' + self.assertEqual(instance.to_representation(data), output) + + +class CustomStringField(models.CharField): + pass + + +class TestModelField(DjangoUnitTest): + + class User(models.Model): + name = CustomStringField(max_length=30) + + class Meta: + app_label = 'test_django_model_field' + + def __str__(self): + return '' % (self.name) + + class Car(models.Model): + name = models.CharField(max_length=30) + max_speed = models.FloatField(null=True, blank=True) + manufacturer = models.ForeignKey( + "test_django_model_field.User", related_name='cars' + ) + + class Meta: + app_label = 'test_django_model_field' + + def __str__(self): + return '' % (self.name, self.manufacturer) + + models = (User, Car) + apps = ('test_django_model_field', ) + + def test_to_internal_value(self): + model_field = self.User._meta.get_field('name') + instance = fields.ModelField(model_field, max_length=30) + self.assertEqual(instance.to_internal_value('data'), 'data') + + def test_to_internal_value_with_related_field(self): + model_field = self.Car._meta.get_field('manufacturer') + instance = fields.ModelField(model_field) + self.assertEqual(instance.to_internal_value(None), None) + + def test_to_representation(self): + + class FakeObject(): + name = 'i535' + + model_field = self.Car._meta.get_field('name') + instance = fields.ModelField(model_field) + self.assertEqual(instance.to_representation(FakeObject), 'i535') + + def test_to_representation_with_protected_field(self): + + class FakeObject(): + max_speed = 100 + + model_field = self.Car._meta.get_field('max_speed') + instance = fields.ModelField(model_field) + self.assertEqual(instance.to_representation(FakeObject), 100.0) + + +class TestListField(DjangoUnitTest): + + def test_init_raises_assertion_error_for_defined_child_as_a_class(self): + + with self.assertRaises(AssertionError): + fields.ListField(child=fields.IntegerField) + + def test_init_raises_assertion_error_for_child_without_source(self): + + with self.assertRaises(AssertionError): + fields.ListField(child=fields.IntegerField(source='value')) + + def test_get_value(self): + instance = fields.ListField(child=fields.IntegerField()) + instance.bind('test', None) + self.assertEqual(instance.get_value({'test': [1, 2, 3]}), [1, 2, 3]) + + def test_get_value_returns_empty_value(self): + + class FakeModelSerializer(object): + partial = True + parent = None + + instance = fields.ListField(child=fields.IntegerField()) + instance.bind('test', FakeModelSerializer()) + self.assertEqual(instance.get_value({'key': [1, 2, 3]}), empty) + + def test_to_internal_value_empty_list(self): + instance = fields.ListField(child=fields.IntegerField()) + self.assertEqual(instance.to_internal_value([]), []) + + def test_to_internal_value_for_integer_list(self): + instance = fields.ListField(child=fields.IntegerField()) + self.assertEqual(instance.to_internal_value([1, 2, 3]), [1, 2, 3]) + + def test_to_internal_value_for_list_with_integer_as_a_string(self): + instance = fields.ListField(child=fields.IntegerField()) + self.assertEqual( + instance.to_internal_value(['1', '2', '3']), + [1, 2, 3] + ) + + def test_to_internal_value_raises_validation_error_for_wrong_type(self): + instance = fields.ListField(child=fields.IntegerField()) + + with self.assertRaises(ValidationError): + instance.to_internal_value({"key": "value"}) + + def test_to_internal_value_raises_validation_error_for_wrong_value(self): + instance = fields.ListField(child=fields.IntegerField()) + + with self.assertRaises(ValidationError): + instance.to_internal_value([1, 2, 'error']) + + def test_to_internal_value_raises_validation_error_for_empty_list(self): + instance = fields.ListField( + child=fields.IntegerField(), allow_empty=False + ) + + with self.assertRaises(ValidationError): + instance.to_internal_value([]) + + def test_to_representation(self): + instance = fields.ListField(child=fields.IntegerField()) + self.assertEqual(instance.to_representation([1, 2, 3]), [1, 2, 3]) + + def test_to_representation_for_string_list(self): + instance = fields.ListField(child=fields.IntegerField()) + self.assertEqual( + instance.to_representation(['1', '2', '3']), + [1, 2, 3] + ) + + def test_to_iternal_value_without_child_instance(self): + instance = fields.ListField() + self.assertEqual( + instance.to_internal_value([1, '2', True, [4, 5, 6]]), + [1, '2', True, [4, 5, 6]] + ) + + def test_to_iternal_value_without_child_instance_raises_an_error(self): + instance = fields.ListField() + + with self.assertRaises(ValidationError): + instance.to_internal_value('value') + + def test_to_representation_without_child_instance(self): + instance = fields.ListField() + self.assertEqual( + instance.to_representation([1, '2', True, [4, 5, 6]]), + [1, '2', True, [4, 5, 6]] + ) + + +class TestDictField(DjangoUnitTest): + + def test_init_raises_assertion_error_for_defined_child_as_a_class(self): + + with self.assertRaises(AssertionError): + fields.DictField(child=fields.CharField) + + def test_init_raises_assertion_error_for_child_without_source(self): + + with self.assertRaises(AssertionError): + fields.DictField(child=fields.CharField(source='value')) + + def test_get_value(self): + instance = fields.DictField(child=fields.CharField()) + instance.bind('test', None) + self.assertEqual(instance.get_value({'test': 'value'}), 'value') + + def test_get_value_returns_empty_value(self): + instance = fields.DictField(child=fields.CharField()) + instance.bind('test', None) + self.assertEqual(instance.get_value({'key': 'value'}), empty) + + def test_to_internal_value(self): + instance = fields.DictField(child=fields.CharField()) + self.assertEqual( + instance.to_internal_value({'a': 1, 'b': '2'}), + {'a': '1', 'b': '2'} + ) + + def test_to_internal_value_raises_validation_error_for_a_wrong_type(self): + instance = fields.DictField(child=fields.CharField()) + + with self.assertRaises(ValidationError): + instance.to_internal_value('value') + + def test_to_internal_value_raises_validation_error_for_a_wrong_value(self): + instance = fields.DictField(child=fields.CharField()) + + with self.assertRaises(ValidationError): + instance.to_internal_value({'key': None}) + + def test_to_representation(self): + instance = fields.DictField(child=fields.CharField()) + self.assertEqual( + instance.to_representation({'a': 1, 'b': '2'}), + {'a': '1', 'b': '2'} + ) + + def test_to_iternal_value_without_child_instance(self): + instance = fields.DictField() + self.assertEqual( + instance.to_internal_value({'a': 1, 'b': [1, 2], 'c': 'c'}), + {'a': 1, 'b': [1, 2], 'c': 'c'} + ) + + def test_to_iternal_value_without_child_instance_raises_an_error(self): + instance = fields.DictField() + + with self.assertRaises(ValidationError): + instance.to_internal_value('value') + + def test_to_representation_without_child_instance(self): + instance = fields.DictField() + self.assertEqual( + instance.to_representation({'a': 1, 'b': [1, 2], 'c': 'c'}), + {'a': 1, 'b': [1, 2], 'c': 'c'} + ) + + +class TestHStoreField(DjangoUnitTest): + + def test_init_raises_assertion_error_for_defined_child_as_a_class(self): + with self.assertRaises(AssertionError): + fields.HStoreField(child=fields.CharField) + + def test_init_raises_assertion_error_for_child_without_source(self): + with self.assertRaises(AssertionError): + fields.HStoreField(child=fields.CharField(source='value')) + + def test_get_value(self): + instance = fields.HStoreField() + instance.bind('test', None) + self.assertEqual(instance.get_value({'test': 'value'}), 'value') + + def test_get_value_returns_empty_value(self): + instance = fields.HStoreField(child=fields.CharField()) + instance.bind('test', None) + self.assertEqual(instance.get_value({'key': 'value'}), empty) + + def test_to_internal_value(self): + instance = fields.HStoreField(child=fields.CharField()) + self.assertEqual( + instance.to_internal_value({'a': 1, 'b': '2'}), + {'a': '1', 'b': '2'} + ) + + def test_to_internal_value_raises_validation_error_for_a_wrong_type(self): + instance = fields.HStoreField(child=fields.CharField()) + + with self.assertRaises(ValidationError): + instance.to_internal_value('value') + + def test_to_internal_value_raises_validation_error_for_a_wrong_value(self): + instance = fields.HStoreField(child=fields.CharField()) + + with self.assertRaises(ValidationError): + instance.to_internal_value({'key': None}) + + def test_to_representation(self): + instance = fields.HStoreField(child=fields.CharField()) + self.assertEqual( + instance.to_representation({'a': 1, 'b': '2'}), + {'a': '1', 'b': '2'} + ) + + def test_to_iternal_value_without_child_instance(self): + instance = fields.HStoreField() + self.assertEqual( + instance.to_internal_value({'a': 1, 'b': [1, 2], 'c': 'c'}), + {'a': '1', 'b': '[1, 2]', 'c': 'c'} + ) + + def test_to_iternal_value_without_child_instance_raises_an_error(self): + instance = fields.HStoreField() + + with self.assertRaises(ValidationError): + instance.to_internal_value('value') + + def test_to_representation_without_child_instance(self): + instance = fields.HStoreField() + self.assertEqual( + instance.to_representation({'a': 1, 'b': [1, 2], 'c': 'c'}), + {'a': '1', 'b': '[1, 2]', 'c': 'c'} + ) + + +class TestJSONField(DjangoUnitTest): + + # simple JSON + + def test_run_validation_raises_validation_error(self): + instance = fields.JSONField() + + with self.assertRaises(ValidationError): + instance.run_validation({'key': [1, '2', [3, set()]]}) + + def test_to_internal_value(self): + instance = fields.JSONField() + self.assertEqual( + instance.to_internal_value({'key': [0.0, 1, [2, 'nested', {}]]}), + {'key': [0.0, 1, [2, 'nested', {}]]} + ) + + def test_to_internal_value_raises_validation_error_for_a_wrong_type(self): + instance = fields.JSONField() + + with self.assertRaises(ValidationError): + instance.to_internal_value(set()) + + def test_to_representation(self): + instance = fields.JSONField() + self.assertEqual( + instance.to_representation({'key': [0.0, 1, [2, 'nested', {}]]}), + {'key': [0.0, 1, [2, 'nested', {}]]} + ) + + # binary JSON + + def test_run_validation_raises_validation_error_for_binary_json(self): + instance = fields.JSONField(binary=True) + + with self.assertRaises(ValidationError): + instance.run_validation("{'key': \"str)") + + def test_to_internal_value_for_binary_json(self): + instance = fields.JSONField(binary=True) + self.assertEqual( + instance.to_internal_value(b'{"key": [0.0, 1, [2, {}]]}'), + {"key": [0.0, 1, [2, {}]]} + ) + + def test_to_representation_for_binary_json(self): + instance = fields.JSONField(binary=True) + self.assertEqual( + instance.to_representation({'key': [0.0, 1, [2, {}]]}), + b'{"key": [0.0, 1, [2, {}]]}' + ) + + +class TestReadOnlyField(DjangoUnitTest): + + def test_to_representation(self): + instance = fields.ReadOnlyField() + self.assertEqual(instance.to_representation('value'), 'value') + + +class TestSerializerMethodField(DjangoUnitTest): + + class FakeModelSerializer(object): + + def get_none(self, obj): # NOQA + return {"key": "value"} + + def test_bind(self): + model_serializer = self.FakeModelSerializer() + instance = fields.SerializerMethodField() + instance.bind('none', model_serializer) + self.assertEqual(instance.parent, model_serializer) + self.assertEqual(instance.method_name, 'get_none') + + def test_bind_raises_assertion_error_for_method_name_argument(self): + model_serializer = self.FakeModelSerializer() + instance = fields.SerializerMethodField(method_name='get_none') + + with self.assertRaises(AssertionError): + instance.bind('none', model_serializer) + + def test_to_internal_value_raises_not_implemented_error(self): + instance = fields.SerializerMethodField() + + with self.assertRaises(NotImplementedError): + instance.to_internal_value({"key": "value"}) + + def test_to_representation(self): + model_serializer = self.FakeModelSerializer() + instance = fields.SerializerMethodField() + instance.bind('none', model_serializer) + self.assertEqual( + instance.to_representation(object()), + {"key": "value"} + ) + + +class TestCreateOnlyField(DjangoUnitTest): + + class FakeDefault(object): + + @staticmethod + def set_context(field): + field.test_attr = 'value' + + def test_can_set_context_returns_false(self): + instance = fields.CreateOnlyDefault('default') + instance.is_update = False + self.assertFalse(instance._can_set_context()) + + def test_can_set_context_returns_true_for_instance_class(self): + instance = fields.CreateOnlyDefault(self.FakeDefault) + instance.is_update = False + self.assertTrue(instance._can_set_context()) + + def test_can_set_context_returns_false_for_instance_class(self): + instance = fields.CreateOnlyDefault(self.FakeDefault) + instance.is_update = True + self.assertFalse(instance._can_set_context()) + + def test_can_set_context_returns_false_for_not_instance_class(self): + instance = fields.CreateOnlyDefault('default') + instance.is_update = True + self.assertFalse(instance._can_set_context()) + + def test_set_context_function_add_context_for_default(self): + instance = fields.CreateOnlyDefault(self.FakeDefault) + fake_parent = type('FakeParent', (), {'instance': None}) + serializer_field = fields.IntegerField() + serializer_field.bind('pk', fake_parent) + instance.set_context(serializer_field) + self.assertFalse(instance.is_update) + self.assertTrue(hasattr(serializer_field, 'test_attr')) + self.assertEqual(serializer_field.test_attr, 'value') + + def test_set_context_function_not_add_context_for_default(self): + instance = fields.CreateOnlyDefault(self.FakeDefault) + fake_parent = type('FakeParent', (), {'instance': object()}) + serializer_field = fields.IntegerField() + serializer_field.bind('pk', fake_parent) + instance.set_context(serializer_field) + self.assertTrue(instance.is_update) + self.assertFalse(hasattr(serializer_field, 'test_attr')) + + def test_call_returns_default(self): + instance = fields.CreateOnlyDefault('value') + instance.is_update = False + self.assertEqual(instance(), instance.default) + + def test_call_returns_default_from_callable(self): + instance = fields.CreateOnlyDefault(self.FakeDefault) + instance.is_update = False + self.assertIsInstance(instance(), self.FakeDefault) + + def test_call_raise_skip_field_exception(self): + instance = fields.CreateOnlyDefault(self.FakeDefault) + instance.is_update = True + + with self.assertRaises(SkipField): + instance() + + def test_repr(self): + instance = fields.CreateOnlyDefault('value') + self.assertEqual(instance.__repr__(), 'CreateOnlyDefault(value)') + + +class TestEmailField(DjangoUnitTest): + + def test_run_validation(self): + instance = fields.EmailField() + email = 'admin@email.com' + self.assertEqual(instance.run_validation(email), email) + + def test_run_validation_raises_invalid_email_error(self): + instance = fields.EmailField() + email = 'invalid_email' + self.assertRaises( + DjangoValidationError, + instance.run_validation, email + ) + + +class TestRegexField(DjangoUnitTest): + + regex = re.compile(r'[0-9]+') + + def test_run_validation(self): + instance = fields.RegexField(self.regex) + value = '123456789' + self.assertEqual(instance.run_validation(value), value) + + def test_run_validation_raises_invalid_value_error(self): + instance = fields.RegexField(self.regex) + value = 'invalid_value' + self.assertRaises( + DjangoValidationError, + instance.run_validation, value + ) + + +class TestSlugField(DjangoUnitTest): + + def test_run_validation(self): + instance = fields.SlugField() + value = 'valid-slug' + self.assertEqual(instance.run_validation(value), value) + + def test_run_validation_raises_invalid_value_error(self): + instance = fields.SlugField() + value = 'invalid-slug-with-$' + self.assertRaises( + DjangoValidationError, + instance.run_validation, value + ) + + +class TestURLField(DjangoUnitTest): + + def test_run_validation(self): + instance = fields.URLField() + url = 'http://my-test-website.com' + self.assertEqual(instance.run_validation(url), url) + + def test_run_validation_raises_invalid_value_error(self): + instance = fields.URLField() + url = 'not_http://definitelynotwebsite.com' + self.assertRaises( + DjangoValidationError, + instance.run_validation, url + ) + + +class TestUUIDField(DjangoUnitTest): + + def test_init_with_invalid_uuid_format(self): + self.assertRaises( + ValueError, + fields.UUIDField, format='unknown_format' + ) + + def test_run_validation(self): + instance = fields.UUIDField() + value = uuid.UUID('56c48c7e-a2b3-11e6-8bbf-0c4de9c846b0') + self.assertEqual(instance.run_validation(value), value) + + def test_run_validation_raises_invalid_value_error(self): + instance = fields.UUIDField() + value = 'invalid-uuid4' + self.assertRaises( + ValidationError, + instance.run_validation, value + ) + + def test_to_internal_value(self): + instance = fields.UUIDField() + value = uuid.UUID('56c48c7e-a2b3-11e6-8bbf-0c4de9c846b0') + self.assertEqual(instance.to_internal_value(value), value) + + def test_to_internal_value_with_integer_data(self): + instance = fields.UUIDField() + value = 115334147392221633161905320179804948144 + self.assertEqual( + instance.to_internal_value(value), + uuid.UUID('56c48c7e-a2b3-11e6-8bbf-0c4de9c846b0') + ) + + def test_to_internal_value_with_string_data(self): + instance = fields.UUIDField() + value = '56c48c7e-a2b3-11e6-8bbf-0c4de9c846b0' + self.assertEqual( + instance.to_internal_value(value), + uuid.UUID('56c48c7e-a2b3-11e6-8bbf-0c4de9c846b0') + ) + + def test_to_internal_value_raises_value_error_for_a_wrong_input(self): + instance = fields.UUIDField() + value = '56c48c7e-a2b3-11e6-wrong' + self.assertRaises(ValidationError, instance.to_internal_value, value) + + def test_to_internal_value_raises_value_error_for_invalid_type(self): + instance = fields.UUIDField() + value = ('not_uuid', ) + self.assertRaises(ValidationError, instance.to_internal_value, value) + + def test_to_representation(self): + instance = fields.UUIDField() + value = uuid.UUID('56c48c7e-a2b3-11e6-8bbf-0c4de9c846b0') + self.assertEqual(instance.to_representation(value), str(value)) + + def test_to_representation_with_custom_format(self): + instance = fields.UUIDField(format='int') + value = uuid.UUID('56c48c7e-a2b3-11e6-8bbf-0c4de9c846b0') + self.assertEqual(instance.to_representation(value), int(value)) + + +class TestIPAddressField(DjangoUnitTest): + + def test_run_validation(self): + instance = fields.IPAddressField() + ip = '127.0.0.1' + self.assertEqual(instance.run_validation(ip), ip) + + def test_run_validation_raises_invalid_value_error(self): + instance = fields.IPAddressField() + ip = (127, 0, 0, 1) + self.assertRaises(ValidationError, instance.run_validation, ip) + + def test_to_internal_value(self): + instance = fields.IPAddressField() + ip = '2931:dba:85a3:42:127a:1a2f:552:7011' + self.assertEqual(instance.to_internal_value(ip), ip) + + def test_to_internal_value_raises_invalid_value_error(self): + instance = fields.IPAddressField() + ip = 1270001 + self.assertRaises(ValidationError, instance.to_internal_value, ip) + + def test_to_internal_value_raises_invalid_value_error_for_ipv6(self): + instance = fields.IPAddressField() + ip = '1234:1234:1234:1234:1234:1234:1234:12345' + self.assertRaises(ValidationError, instance.to_internal_value, ip) + + +class TestFilePathField(DjangoUnitTest): + + path = os.path.abspath(os.path.dirname(__file__)) + + def test_run_validation(self): + instance = fields.FilePathField(self.path) + self.assertEqual(instance.run_validation(__file__), __file__) + + def test_run_validation_raise_invalid_error(self): + instance = fields.FilePathField(self.path) + self.assertRaises(ValidationError, instance.run_validation, 'path') + + +class MockFile: + + def __init__(self, name='', size=0, url=''): + self.name = name + self.size = size + self.url = url + + def __eq__(self, other): + return ( + isinstance(other, MockFile) and + self.name == other.name and + self.size == other.size and + self.url == other.url + ) + + +class TestFileField(DjangoUnitTest): + + def test_use_url(self): + instance = fields.FileField(max_length=10, use_url=True) + self.assertTrue(instance.use_url) + + def test_to_internal_value(self): + instance = fields.FileField(max_length=10) + test_file = MockFile(name='file.txt', size=10) + self.assertEqual(instance.to_internal_value(test_file), test_file) + + def test_to_internal_value_raises_error_for_invalid_attribute(self): + instance = fields.FileField(max_length=10) + test_file = 'not_file' + self.assertRaises( + ValidationError, + instance.to_internal_value, test_file + ) + + def test_to_internal_value_raises_error_for_file_without_name(self): + instance = fields.FileField(max_length=10) + test_file = MockFile(name='', size=10) + self.assertRaises( + ValidationError, + instance.to_internal_value, test_file + ) + + def test_to_internal_value_raises_error_for_empty_file(self): + instance = fields.FileField(max_length=10) + test_file = MockFile(name='file.txt', size=0) + self.assertRaises( + ValidationError, + instance.to_internal_value, test_file + ) + + def test_to_internal_value_raises_error_for_limit_file_size(self): + instance = fields.FileField(max_length=10) + test_file = MockFile(name='_' * 15, size=10) + self.assertRaises( + ValidationError, + instance.to_internal_value, test_file + ) + + def test_to_representation(self): + instance = fields.FileField(max_length=10) + test_file = MockFile(name='file.txt', url='/file.txt') + self.assertEqual(instance.to_representation(test_file), '/file.txt') + + def test_to_representation_without_using_url(self): + instance = fields.FileField(max_length=10, use_url=False) + test_file = MockFile(name='file.txt', url='/file.txt') + self.assertEqual(instance.to_representation(test_file), 'file.txt') + + def test_to_representation_for_not_saved_file(self): + instance = fields.FileField(max_length=10) + test_file = MockFile(name='file.txt') + self.assertIsNone(instance.to_representation(test_file)) + + def test_to_representation_for_empty_string(self): + instance = fields.FileField(max_length=10) + test_file = '' + self.assertIsNone(instance.to_representation(test_file)) + + +class TestImageField(DjangoUnitTest): + + def test_to_internal_value_with_passed_validation(self): + + class PassImageValidation(object): + def to_python(self, value): + return value + + instance = fields.ImageField(_DjangoImageField=PassImageValidation) + test_file = MockFile(name='file.txt', size=10) + self.assertEqual(instance.to_internal_value(test_file), test_file) + + def test_to_internal_value_with_failed_validation(self): + + class FailImageValidation(object): + def to_python(self, value): + raise ValidationError(self.error_messages['invalid_image']) + + instance = fields.ImageField(_DjangoImageField=FailImageValidation) + test_file = MockFile(name='file.txt', size=10) + self.assertRaises( + ValidationError, + instance.to_internal_value, test_file + ) diff --git a/tests/db/orm/sqlalchemy/test_fields.py b/tests/db/orm/sqlalchemy/test_fields.py index cdb8150..7548a91 100644 --- a/tests/db/orm/sqlalchemy/test_fields.py +++ b/tests/db/orm/sqlalchemy/test_fields.py @@ -1185,6 +1185,13 @@ def test_to_representation_for_custom_field(self): ) +class TestReadOnlyField(unittest.TestCase): + + def test_to_representation(self): + instance = fields.ReadOnlyField() + self.assertEqual(instance.to_representation('value'), 'value') + + class TestSerializerMethodField(unittest.TestCase): class FakeModelSerializer(object): diff --git a/tests/db/orm/sqlalchemy/test_serializers.py b/tests/db/orm/sqlalchemy/test_serializers.py index ca52fd2..accecee 100644 --- a/tests/db/orm/sqlalchemy/test_serializers.py +++ b/tests/db/orm/sqlalchemy/test_serializers.py @@ -459,6 +459,7 @@ class Meta: self.assertIsNone(instance.run_validation(None)) + @override_settings(SQLALCHEMY_SESSION=SESSION) def test_run_validation_raises_error_for_assert(self): class UserSerializer(ModelSerializer): @@ -482,6 +483,7 @@ def validate(self, data): self.assertRaises(ValidationError, instance.run_validation, data) + @override_settings(SQLALCHEMY_SESSION=SESSION) def test_run_validation_raises_error_for_validation_error(self): class AdminNameValidator(BaseValidator): diff --git a/tests/db/orm/test_serializers.py b/tests/db/orm/test_serializers.py index 80b6962..4cdc48c 100644 --- a/tests/db/orm/test_serializers.py +++ b/tests/db/orm/test_serializers.py @@ -1841,10 +1841,12 @@ def get_default_field_names(self, declared_fields, model_info): def _get_unique_field(self, model, unique_field_name): return getattr(model, unique_field_name) - def _get_unique_constraint_names(self, model, model_fields): + def _get_unique_constraint_names(self, model, model_fields, + field_names): return {'pk', } - def _get_unique_together_constraints(self, model): + def _get_unique_together_constraints(self, model, model_fields, + field_names): return set() def _get_default_field_value(self, unique_constraint_field): @@ -1897,10 +1899,12 @@ def get_default_field_names(self, declared_fields, model_info): def _get_unique_field(self, model, unique_field_name): return getattr(model, unique_field_name) - def _get_unique_constraint_names(self, model, model_fields): + def _get_unique_constraint_names(self, model, model_fields, + field_names): return {'pk', } - def _get_unique_together_constraints(self, model): + def _get_unique_together_constraints(self, model, model_fields, + field_names): return {'key', 'value'} def _get_default_field_value(self, unique_constraint_field): @@ -1957,10 +1961,12 @@ def get_default_field_names(self, declared_fields, model_info): def _get_unique_field(self, model, unique_field_name): return getattr(model, unique_field_name) - def _get_unique_constraint_names(self, model, model_fields): + def _get_unique_constraint_names(self, model, model_fields, + field_names): return {'pk', 'user'} - def _get_unique_together_constraints(self, model): + def _get_unique_together_constraints(self, model, model_fields, + field_names): return set() def _get_default_field_value(self, unique_constraint_field): @@ -2051,7 +2057,7 @@ class FakeModel(object): with self.assertRaises(NotImplementedError): instance._get_unique_constraint_names( - FakeModel, {'pk': fields.IntegerField()} + FakeModel, {'pk': fields.IntegerField()}, ['pk', ] ) def test_get_unique_together_constraint_raises_not_implemented_error(self): @@ -2062,7 +2068,9 @@ class FakeModel(object): instance = ModelSerializer() with self.assertRaises(NotImplementedError): - instance._get_unique_together_constraints(FakeModel) + instance._get_unique_together_constraints( + FakeModel, {'pk': fields.IntegerField()}, ['pk', ] + ) def test_get_unique_field_raises_not_implemented_error(self):