diff --git a/tests/schema/test_core.py b/tests/schema/test_core.py index d230967..509589c 100644 --- a/tests/schema/test_core.py +++ b/tests/schema/test_core.py @@ -104,6 +104,14 @@ class OneToOneLeaf(core.Model): root = core.OneToOneAttribute(OneToOneRoot, related_name='leaf') +class ManyToOneRoot(core.Model): + id = core.SlugAttribute(verbose_name='ID') + + +class ManyToOneLeaf(core.Model): + root = core.ManyToOneAttribute(ManyToOneRoot, related_name='leaves') + + class ManyToManyRoot(core.Model): id = core.SlugAttribute(verbose_name='ID') @@ -679,6 +687,128 @@ def test_validate_manytomany_attribute(self): error = obj.validate() self.assertEqual(error, None) + def test_onetoone_set_related(self): + root = OneToOneRoot() + leaf = OneToOneLeaf() + + root.leaf = leaf + self.assertEqual(leaf.root, root) + + root.leaf = None + self.assertEqual(leaf.root, None) + + leaf.root = root + self.assertEqual(root.leaf, leaf) + + leaf.root = None + self.assertEqual(root.leaf, None) + + def test_manytoone_set_related(self): + roots = [ + ManyToOneRoot(), + ManyToOneRoot(), + ] + leaves = [ + ManyToOneLeaf(), + ManyToOneLeaf(), + ] + + leaves[0].root = roots[0] + self.assertEqual(roots[0].leaves, set(leaves[0:1])) + + leaves[1].root = roots[0] + self.assertEqual(roots[0].leaves, set(leaves[0:2])) + + leaves[0].root = None + self.assertEqual(roots[0].leaves, set(leaves[1:2])) + + roots[0].leaves = set() + self.assertEqual(roots[0].leaves, set()) + self.assertEqual(leaves[1].root, None) + + roots[0].leaves.add(leaves[0]) + self.assertEqual(roots[0].leaves, set(leaves[0:1])) + self.assertEqual(leaves[0].root, roots[0]) + + roots[0].leaves.update(leaves[1:2]) + self.assertEqual(roots[0].leaves, set(leaves[0:2])) + self.assertEqual(leaves[1].root, roots[0]) + + roots[0].leaves.remove(leaves[0]) + self.assertEqual(roots[0].leaves, set(leaves[1:2])) + self.assertEqual(leaves[0].root, None) + + roots[0].leaves = set() + leaves[0].root = roots[0] + leaves[0].root = roots[1] + self.assertEqual(roots[0].leaves, set()) + self.assertEqual(roots[1].leaves, set(leaves[0:1])) + + roots[0].leaves = leaves[0:1] + self.assertEqual(roots[0].leaves, set(leaves[0:1])) + self.assertEqual(roots[1].leaves, set()) + self.assertEqual(leaves[0].root, roots[0]) + + roots[1].leaves = leaves[0:2] + self.assertEqual(roots[0].leaves, set()) + self.assertEqual(roots[1].leaves, set(leaves[0:2])) + self.assertEqual(leaves[0].root, roots[1]) + self.assertEqual(leaves[1].root, roots[1]) + + def test_manytomany_set_related(self): + roots = [ + ManyToManyRoot(), + ManyToManyRoot(), + ] + leaves = [ + ManyToManyLeaf(), + ManyToManyLeaf(), + ] + + roots[0].leaves.add(leaves[0]) + self.assertEqual(leaves[0].roots, set(roots[0:1])) + + roots[0].leaves.remove(leaves[0]) + self.assertEqual(leaves[0].roots, set()) + + roots[0].leaves.add(leaves[0]) + roots[1].leaves.add(leaves[0]) + self.assertEqual(leaves[0].roots, set(roots[0:2])) + + roots[0].leaves.clear() + roots[1].leaves.clear() + self.assertEqual(leaves[0].roots, set()) + self.assertEqual(leaves[1].roots, set()) + + roots[0].leaves = leaves + roots[1].leaves = leaves + self.assertEqual(leaves[0].roots, set(roots[0:2])) + self.assertEqual(leaves[1].roots, set(roots[0:2])) + + # reverse + roots[0].leaves.clear() + roots[1].leaves.clear() + + leaves[0].roots.add(roots[0]) + self.assertEqual(roots[0].leaves, set(leaves[0:1])) + + leaves[0].roots.remove(roots[0]) + self.assertEqual(roots[0].leaves, set()) + + leaves[0].roots.add(roots[0]) + leaves[1].roots.add(roots[0]) + self.assertEqual(roots[0].leaves, set(leaves[0:2])) + + leaves[0].roots.clear() + leaves[1].roots.clear() + self.assertEqual(roots[0].leaves, set()) + self.assertEqual(roots[1].leaves, set()) + + leaves[0].roots = roots + leaves[1].roots = roots + self.assertEqual(roots[0].leaves, set(leaves[0:2])) + self.assertEqual(roots[1].leaves, set(leaves[0:2])) + def test_clean_and_validate_objects(self): grandparent = Grandparent(id='root') parents = [ diff --git a/tests/schema/test_io.py b/tests/schema/test_io.py index 6103cc6..fcb55da 100644 --- a/tests/schema/test_io.py +++ b/tests/schema/test_io.py @@ -23,7 +23,7 @@ class Meta(core.Model.Meta): class Node(core.Model): - id = core.StringAttribute(primary=True, unique=True) + id = core.SlugAttribute(primary=True) root = core.ManyToOneAttribute(Root, related_name='nodes') val1 = core.FloatAttribute() val2 = core.FloatAttribute() @@ -34,12 +34,12 @@ class Meta(core.Model.Meta): class Leaf(core.Model): id = core.StringAttribute(primary=True) - node = core.ManyToOneAttribute(Node, related_name='leaves') + nodes = core.ManyToManyAttribute(Node, related_name='leaves') val1 = core.FloatAttribute() val2 = core.FloatAttribute() class Meta(core.Model.Meta): - attribute_order = ('id', 'node', 'val1', 'val2', ) + attribute_order = ('id', 'nodes', 'val1', 'val2', ) class TestIo(unittest.TestCase): @@ -54,17 +54,17 @@ def tearDown(self): def test_write_read(self): root = Root(id='root', name='root') nodes = [ - Node(root=root, id='node-0', val1=1, val2=2), + Node(root=root, id=u'node-0-\u20ac', val1=1, val2=2), Node(root=root, id='node-1', val1=3, val2=4), Node(root=root, id='node-2', val1=5, val2=6), ] leaves = [ - Leaf(node=nodes[0], id='leaf-0-0', val1=7, val2=8), - Leaf(node=nodes[0], id='leaf-0-1', val1=9, val2=10), - Leaf(node=nodes[1], id='leaf-1-0', val1=11, val2=12), - Leaf(node=nodes[1], id='leaf-1-1', val1=13, val2=14), - Leaf(node=nodes[2], id='leaf-2-0', val1=15, val2=16), - Leaf(node=nodes[2], id='leaf-2-1', val1=17, val2=18), + Leaf(nodes=[nodes[0]], id='leaf-0-0', val1=7, val2=8), + Leaf(nodes=[nodes[0]], id='leaf-0-1', val1=9, val2=10), + Leaf(nodes=[nodes[1]], id='leaf-1-0', val1=11, val2=12), + Leaf(nodes=[nodes[1]], id='leaf-1-1', val1=13, val2=14), + Leaf(nodes=[nodes[2]], id='leaf-2-0', val1=15, val2=16), + Leaf(nodes=[nodes[2]], id='leaf-2-1', val1=17, val2=18), ] objects = set((root, )) | root.get_related() @@ -79,3 +79,6 @@ def test_write_read(self): root2 = objects2[Root].pop() self.assertEqual(root2, root) + + # unicode + self.assertEqual(next(obj for obj in objects2[Node] if obj.val1 == 1).id, u'node-0-\u20ac') diff --git a/tests/schema/test_utils.py b/tests/schema/test_utils.py index bed7e8a..e1d0d38 100644 --- a/tests/schema/test_utils.py +++ b/tests/schema/test_utils.py @@ -5,6 +5,7 @@ :Copyright: 2016, Karr Lab :License: MIT """ +from six import string_types from wc_utils.schema import core, utils import sys import unittest @@ -72,4 +73,4 @@ def test_get_related_errors(self): self.assertEqual(len(errors_by_model[Root]), 1) self.assertEqual(len(errors_by_model[Node]), 2) - self.assertIsInstance(utils.get_object_set_error_string(errors), str) + self.assertIsInstance(utils.get_object_set_error_string(errors), string_types) diff --git a/wc_utils/schema/core.py b/wc_utils/schema/core.py index bdd5410..c1a76d6 100644 --- a/wc_utils/schema/core.py +++ b/wc_utils/schema/core.py @@ -121,7 +121,7 @@ def init_related_attributes(cls): related_class.Meta.primary_attribute.name, related_class.__name__)) if isinstance(attr, ManyToManyAttribute) and not isinstance(related_class.Meta.primary_attribute, (SlugAttribute, IntegerAttribute)): - raise ValueError('Primary attribute {} of related class {} must be unique'.format( + raise ValueError('Primary attribute {} of related class {} must be a slug or integer attribute'.format( related_class.Meta.primary_attribute.name, related_class.__name__)) # check that name doesn't conflict with another attribute @@ -1762,7 +1762,7 @@ def deserialize(self, value, objects): :obj:`tuple` of `object`, `InvalidAttribute`: tuple of cleaned value and cleaning error """ if not value: - return None + return (None, None) related_objs = [] related_classes = chain([self.related_class], get_subclasses(self.related_class)) @@ -1932,7 +1932,7 @@ def deserialize(self, value, objects): :obj:`tuple` of `object`, `InvalidAttribute`: tuple of cleaned value and cleaning error """ if not value: - return None + return (None, None) related_objs = [] related_classes = chain([self.related_class], get_subclasses(self.related_class)) @@ -2098,7 +2098,7 @@ def deserialize(self, values, objects): :obj:`tuple` of `object`, `InvalidAttribute`: tuple of cleaned value and cleaning error """ if not values: - return set() + return (set(), None) deserialized_values = set() errors = [] @@ -2156,7 +2156,7 @@ def discard(self, value): def clear(self): """ Remove all elements from set """ - for value in self: + for value in list(self): self.remove(value) def pop(self): @@ -2179,7 +2179,7 @@ def intersection_update(self, values): Args: values (:obj:`set`): values to intersect with set """ - for value in self: + for value in list(self): if value not in values: self.remove(value) @@ -2199,12 +2199,12 @@ def symmetric_difference_update(self, values): Args: values (:obj:`set`): values to difference with set """ - for value in self: + for value in list(self): if value in values: self.remove(value) values.remove(value) - for value in values: + for value in list(values): if value in self: self.remove(value) values.remove(value) @@ -2224,7 +2224,7 @@ def add(self, value, propagate=True): """ super(ManyToOneRelatedManager, self).add(value) if propagate: - setattr(value, attr.name, obj) + value.__setattr__(self.attribute.name, self.object) def remove(self, value, update_set=True, propagate=True): """ Remove value from set @@ -2233,10 +2233,10 @@ def remove(self, value, update_set=True, propagate=True): value (:obj:`object`): value propagate (:obj:`bool`, optional): propagate change to related attribute """ - if update_set: + if update_set and value in self: super(ManyToOneRelatedManager, self).remove(value) if propagate: - setattr(value, attr.name, None) + value.__setattr__(self.attribute.name, None) class ManyToManyRelatedManager(RelatedManager): @@ -2269,9 +2269,9 @@ def add(self, value, propagate=True): super(ManyToManyRelatedManager, self).add(value) if propagate: if self.related: - getattr(value, self.attribute.name).add(self.object) + getattr(value, self.attribute.name).add(self.object, propagate=False) else: - getattr(value, self.attribute.related_name).add(self.object) + getattr(value, self.attribute.related_name).add(self.object, propagate=False) def remove(self, value, update_set=True, propagate=True): """ Remove value from set @@ -2285,9 +2285,9 @@ def remove(self, value, update_set=True, propagate=True): super(ManyToManyRelatedManager, self).remove(value) if propagate: if self.related: - getattr(value, self.attribute.name).remove(self.object) + getattr(value, self.attribute.name).remove(self.object, propagate=False) else: - getattr(value, self.attribute.related_name).remove(self.object) + getattr(value, self.attribute.related_name).remove(self.object, propagate=False) class InvalidObjectSet(object): diff --git a/wc_utils/schema/utils.py b/wc_utils/schema/utils.py index 31deadf..5fe2e04 100644 --- a/wc_utils/schema/utils.py +++ b/wc_utils/schema/utils.py @@ -7,6 +7,7 @@ """ # todo: add method to compare (difference) models +from __future__ import unicode_literals from wc_utils.schema.core import Model, Attribute, RelatedAttribute, InvalidObjectSet, InvalidObject, clean_and_validate_objects