Merged
7 changes: 0 additions & 7 deletions mongoengine/base/document.py
@@ -272,13 +272,6 @@ def __eq__(self, other):
     def __ne__(self, other):
         return not self.__eq__(other)

-    def __hash__(self):
-        if getattr(self, 'pk', None) is None:
-            # For new object
-            return super(BaseDocument, self).__hash__()
-        else:
-            return hash(self.pk)
-
     def clean(self):
         """
         Hook for doing document level data cleaning before validation is run.
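For context on the deletion above: a pk-based __hash__ only makes sense on classes that actually have a pk, and it has to stay consistent with a pk-based __eq__ (equal objects must hash equal, or dict and set lookups silently miss). A minimal sketch of that contract, using a hypothetical plain-Python stand-in (FakePkDoc is not mongoengine code):

```python
class FakePkDoc(object):
    """Hypothetical stand-in for a saved Document identified by its pk."""

    def __init__(self, pk):
        self.pk = pk

    def __eq__(self, other):
        return isinstance(other, FakePkDoc) and self.pk == other.pk

    def __hash__(self):
        # Must agree with __eq__: same pk -> same hash.
        return hash(self.pk)


a, b = FakePkDoc(1), FakePkDoc(1)
assert a == b and hash(a) == hash(b)
assert len({a, b}) == 1  # equal docs collapse to a single set entry
```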
15 changes: 15 additions & 0 deletions mongoengine/document.py
@@ -60,6 +60,12 @@ class EmbeddedDocument(BaseDocument):
     my_metaclass = DocumentMetaclass
     __metaclass__ = DocumentMetaclass

+    # A generic embedded document doesn't have any immutable properties
+    # that describe it uniquely, hence it shouldn't be hashable. You can
+    # define your own __hash__ method on a subclass if you need your
+    # embedded documents to be hashable.
+    __hash__ = None
+
     def __init__(self, *args, **kwargs):
         super(EmbeddedDocument, self).__init__(*args, **kwargs)
         self._instance = None
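The `__hash__ = None` assignment above is standard Python: it makes instances unhashable, and a subclass can opt back in by defining its own __hash__, as the comment in the diff says. A small sketch of that behavior (plain classes for illustration, not mongoengine code):

```python
class Unhashable(object):
    __hash__ = None  # same device as EmbeddedDocument above


class HashableByValue(Unhashable):
    def __init__(self, value):
        self.value = value

    def __hash__(self):
        # A subclass can restore hashability with its own definition.
        return hash(self.value)


try:
    hash(Unhashable())
except TypeError as exc:
    print(exc)  # e.g. "unhashable type: 'Unhashable'"

print(hash(HashableByValue(42)) == hash(HashableByValue(42)))  # True
```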
@@ -160,6 +166,15 @@ def pk(self, value):
         """Set the primary key."""
         return setattr(self, self._meta['id_field'], value)

+    def __hash__(self):
+        """Return the hash based on the PK of this document. If it's new
+        and doesn't have a PK yet, return the default object hash instead.
+        """
+        if self.pk is None:
+            return super(BaseDocument, self).__hash__()
+        else:
+            return hash(self.pk)
+
     @classmethod
     def _get_db(cls):
         """Some Model using other db_alias"""
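The intended behavior of the new method: saved documents sharing a pk hash (and compare) the same, while unsaved documents fall back to the default identity-based hash so each new instance stays distinct. A rough sketch of that logic with a plain class (NotADocument is hypothetical; it avoids needing a MongoDB connection):

```python
class NotADocument(object):
    """Hypothetical class mimicking the pk-or-identity hashing above."""

    def __init__(self, pk=None):
        self.pk = pk

    def __eq__(self, other):
        if self.pk is None or not isinstance(other, NotADocument):
            return self is other
        return self.pk == other.pk

    def __hash__(self):
        if self.pk is None:
            # New, unsaved "document": fall back to the default object hash.
            return object.__hash__(self)
        return hash(self.pk)


saved_a, saved_b = NotADocument(pk='abc'), NotADocument(pk='abc')
new_a, new_b = NotADocument(), NotADocument()

print(len({saved_a, saved_b}))  # 1 -- same pk collapses to a single entry
print(len({new_a, new_b}))      # 2 -- unsaved docs remain distinct
```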
28 changes: 20 additions & 8 deletions tests/document/instance.py
@@ -2164,7 +2164,7 @@ class User(Document):
         class BlogPost(Document):
             pass

-        # Clear old datas
+        # Clear old data
         User.drop_collection()
         BlogPost.drop_collection()

@@ -2176,17 +2176,18 @@ class BlogPost(Document):
         b1 = BlogPost.objects.create()
         b2 = BlogPost.objects.create()

-        # in List
+        # Make sure docs are properly identified in a list (__eq__ is used
+        # for the comparison).
         all_user_list = list(User.objects.all())

         self.assertTrue(u1 in all_user_list)
         self.assertTrue(u2 in all_user_list)
         self.assertTrue(u3 in all_user_list)
-        self.assertFalse(u4 in all_user_list) # New object
-        self.assertFalse(b1 in all_user_list) # Other object
-        self.assertFalse(b2 in all_user_list) # Other object
+        self.assertTrue(u4 not in all_user_list) # New object
+        self.assertTrue(b1 not in all_user_list) # Other object
+        self.assertTrue(b2 not in all_user_list) # Other object

-        # in Dict
+        # Make sure docs can be used as keys in a dict (__hash__ is used
+        # for hashing the docs).
         all_user_dic = {}
         for u in User.objects.all():
             all_user_dic[u] = "OK"
@@ -2198,9 +2199,20 @@ class BlogPost(Document):
         self.assertEqual(all_user_dic.get(b1, False), False) # Other object
         self.assertEqual(all_user_dic.get(b2, False), False) # Other object

-        # in Set
+        # Make sure docs are properly identified in a set (__hash__ is used
+        # for hashing the docs).
         all_user_set = set(User.objects.all())
         self.assertTrue(u1 in all_user_set)
+        self.assertTrue(u4 not in all_user_set)
+        self.assertTrue(b1 not in all_user_list)
+        self.assertTrue(b2 not in all_user_list)
+
+        # Make sure duplicate docs aren't accepted in the set
+        self.assertEqual(len(all_user_set), 3)
+        all_user_set.add(u1)
+        all_user_set.add(u2)
+        all_user_set.add(u3)
+        self.assertEqual(len(all_user_set), 3)

     def test_picklable(self):
         pickle_doc = PickleTest(number=1, string="One", lists=['1', '2'])