Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 11 additions & 12 deletions mongoengine/base/document.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ def __init__(self, *args, **values):
else:
self._data[key] = value

# Set any get_fieldname_display methods
# Set any get_<field>_display methods
self.__set_field_display()

if self._dynamic:
Expand Down Expand Up @@ -1005,19 +1005,18 @@ def _translate_field_name(cls, field, sep='.'):
return '.'.join(parts)

def __set_field_display(self):
"""Dynamically set the display value for a field with choices"""
for attr_name, field in self._fields.items():
if field.choices:
if self._dynamic:
obj = self
else:
obj = type(self)
setattr(obj,
'get_%s_display' % attr_name,
partial(self.__get_field_display, field=field))
"""For each field that specifies choices, create a
get_<field>_display method.
"""
fields_with_choices = [(n, f) for n, f in self._fields.items()
if f.choices]
for attr_name, field in fields_with_choices:
setattr(self,
'get_%s_display' % attr_name,
partial(self.__get_field_display, field=field))

def __get_field_display(self, field):
"""Returns the display value for a choice field"""
"""Return the display value for a choice field"""
value = getattr(self, field.name)
if field.choices and isinstance(field.choices[0], (list, tuple)):
return dict(field.choices).get(value, value)
Expand Down
50 changes: 27 additions & 23 deletions tests/fields/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -1047,7 +1047,7 @@ class BlogPost(Document):
BlogPost.drop_collection()

def test_list_assignment(self):
"""Ensure that list field element assignment and slicing work
"""Ensure that list field element assignment and slicing work
"""
class BlogPost(Document):
info = ListField()
Expand All @@ -1057,12 +1057,12 @@ class BlogPost(Document):
post = BlogPost()
post.info = ['e1', 'e2', 3, '4', 5]
post.save()

post.info[0] = 1
post.save()
post.reload()
self.assertEqual(post.info[0], 1)

post.info[1:3] = ['n2', 'n3']
post.save()
post.reload()
Expand Down Expand Up @@ -1209,7 +1209,7 @@ class Simple(Document):
self.assertEqual(simple.widgets, [4])

def test_list_field_with_negative_indices(self):

class Simple(Document):
widgets = ListField()

Expand Down Expand Up @@ -1823,7 +1823,7 @@ class Person(Document):
'parent': "50a234ea469ac1eda42d347d"})
mongoed = p1.to_mongo()
self.assertTrue(isinstance(mongoed['parent'], ObjectId))

def test_cached_reference_field_get_and_save(self):
"""
Tests #1047: CachedReferenceField creates DBRefs on to_python, but can't save them on to_mongo
Expand All @@ -1835,11 +1835,11 @@ class Animal(Document):
class Ocorrence(Document):
person = StringField()
animal = CachedReferenceField(Animal)

Animal.drop_collection()
Ocorrence.drop_collection()
Ocorrence(person="testte",

Ocorrence(person="testte",
animal=Animal(name="Leopard", tag="heavy").save()).save()
p = Ocorrence.objects.get()
p.person = 'new_testte'
Expand Down Expand Up @@ -3001,28 +3001,32 @@ class Shirt(Document):
('S', 'Small'), ('M', 'Medium'), ('L', 'Large'),
('XL', 'Extra Large'), ('XXL', 'Extra Extra Large')))
style = StringField(max_length=3, choices=(
('S', 'Small'), ('B', 'Baggy'), ('W', 'wide')), default='S')
('S', 'Small'), ('B', 'Baggy'), ('W', 'Wide')), default='W')

Shirt.drop_collection()

shirt = Shirt()
shirt1 = Shirt()
shirt2 = Shirt()

self.assertEqual(shirt.get_size_display(), None)
self.assertEqual(shirt.get_style_display(), 'Small')
# Make sure get_<field>_display returns the default value (or None)
self.assertEqual(shirt1.get_size_display(), None)
self.assertEqual(shirt1.get_style_display(), 'Wide')

shirt.size = "XXL"
shirt.style = "B"
self.assertEqual(shirt.get_size_display(), 'Extra Extra Large')
self.assertEqual(shirt.get_style_display(), 'Baggy')
shirt1.size = 'XXL'
shirt1.style = 'B'
shirt2.size = 'M'
shirt2.style = 'S'
self.assertEqual(shirt1.get_size_display(), 'Extra Extra Large')
self.assertEqual(shirt1.get_style_display(), 'Baggy')
self.assertEqual(shirt2.get_size_display(), 'Medium')
self.assertEqual(shirt2.get_style_display(), 'Small')

# Set as Z - an invalid choice
shirt.size = "Z"
shirt.style = "Z"
self.assertEqual(shirt.get_size_display(), 'Z')
self.assertEqual(shirt.get_style_display(), 'Z')
self.assertRaises(ValidationError, shirt.validate)

Shirt.drop_collection()
shirt1.size = 'Z'
shirt1.style = 'Z'
self.assertEqual(shirt1.get_size_display(), 'Z')
self.assertEqual(shirt1.get_style_display(), 'Z')
self.assertRaises(ValidationError, shirt1.validate)

def test_simple_choices_validation(self):
"""Ensure that value is in a container of allowed values.
Expand Down