Skip to content

Commit

Permalink
Implemented Entity __eq__ method comparing all fields (#350)
Browse files Browse the repository at this point in the history
close #335

(cherry picked from commit ecf91be)
  • Loading branch information
renzon authored and renzon committed Jan 16, 2017
1 parent e23409d commit 32ad3c6
Show file tree
Hide file tree
Showing 2 changed files with 193 additions and 9 deletions.
117 changes: 108 additions & 9 deletions nailgun/entity_mixins.py
@@ -1,10 +1,14 @@
# -*- encoding: utf-8 -*-
"""Defines a set of mixins that provide tools for interacting with entities."""
import json as std_json
from collections import Iterable
from datetime import date, datetime

from fauxfactory import gen_choice
from inflection import pluralize
from nailgun import client, config
from nailgun.entity_fields import IntegerField, OneToManyField, OneToOneField
from nailgun import client, config, signals
from nailgun.entity_fields import (
IntegerField, OneToManyField, OneToOneField, ListField)
import threading
import time

Expand Down Expand Up @@ -76,9 +80,10 @@ def _poll_task(task_id, server_config, poll_rate=None, timeout=None):
timeout = TASK_TIMEOUT

# Implement the timeout.
def raise_task_timeout():
def raise_task_timeout(): # pragma: no cover
"""Raise a KeyboardInterrupt exception in the main thread."""
thread.interrupt_main()

timer = threading.Timer(timeout, raise_task_timeout)

# Poll until the task finishes. The timeout prevents an infinite loop.
Expand All @@ -92,7 +97,7 @@ def raise_task_timeout():
if task_info['state'] in ('paused', 'stopped'):
break
time.sleep(poll_rate)
except KeyboardInterrupt:
except KeyboardInterrupt: # pragma: no cover
# raise_task_timeout will raise a KeyboardInterrupt when the timeout
# expires. Catch the exception and raise TaskTimedOutError
raise TaskTimedOutError(
Expand Down Expand Up @@ -191,6 +196,15 @@ def _payload(fields, values):
values[field_name + '_ids'] = [
entity.id for entity in values.pop(field_name)
]
elif isinstance(field, ListField):
def parse(obj):
"""parse obj payload if it is an Entity"""
if isinstance(obj, Entity):
return _payload(obj.get_fields(), obj.get_values())
return obj

values[field_name] = [
parse(obj) for obj in values[field_name]]
return values


Expand Down Expand Up @@ -485,12 +499,11 @@ def get_fields(self):
def get_values(self):
"""Return a copy of field values on the current object.
This method is almost identical to ``vars(self).copy()``. However, only
instance attributes that correspond to a field are included in the
returned dict.
This method is almost identical to ``vars(self).copy()``. However,
only instance attributes that correspond to a field are included in
the returned dict.
:return: A dict mapping field names to user-provided values.
"""
attrs = vars(self).copy()
attrs.pop('_server_config')
Expand All @@ -510,6 +523,63 @@ def __repr__(self):
)
)

def to_json(self):
r"""Create a JSON encoded string with Entity properties. Ex:
>>> from nailgun import entities, config
>>> kwargs = {
... 'id': 1,
... 'name': 'Nailgun Org',
... }
>>> org = entities.Organization(config.ServerConfig('foo'), \*\*kwargs)
>>> org.to_json()
'{"id": 1, "name": "Nailgun Org"}'
:return: str
"""
return std_json.dumps(self.to_json_dict())

def to_json_dict(self):
"""Create a dct with Entity properties for json encoding.
It can be overridden by subclasses for each standard serialization
doesn't work. By default it call _to_json_dict on OneToOne fields
and build a list calling the same method on each object on OneToMany
fields.
:return: dct
"""
fields, values = self.get_fields(), self.get_values()
json_dct = {}
for field_name, field in fields.items():
if field_name in values:
value = values[field_name]
if value is None:
json_dct[field_name] = None
# This conditions is needed because some times you get
# None on an OneToOneField what lead to an error
# on bellow condition, e.g., calling value.to_json_dict()
# when value is None
elif isinstance(field, OneToOneField):
json_dct[field_name] = value.to_json_dict()
elif isinstance(field, OneToManyField):
json_dct[field_name] = [
entity.to_json_dict() for entity in value
]
else:
json_dct[field_name] = to_json_serializable(value)
return json_dct

def __eq__(self, other):
"""Compare two entities based on their properties. Even nested
objects are considered for equality
:param other: entity to compare self to
:return: boolean indicating if entities are equal or not
"""
if other is None:
return False
return self.to_json_dict() == other.to_json_dict()


class EntityDeleteMixin(object):
"""This mixin provides the ability to delete an entity.
Expand Down Expand Up @@ -541,6 +611,7 @@ def delete_raw(self):
**self._server_config.get_client_kwargs()
)

@signals.emit(sender=signals.SENDER_CLASS, post_result_name='result')
def delete(self, synchronous=True):
"""Delete the current entity.
Expand All @@ -562,12 +633,14 @@ def delete(self, synchronous=True):
out.
"""

response = self.delete_raw()
response.raise_for_status()

if (synchronous is True and
response.status_code == http_client.ACCEPTED):
return _poll_task(response.json()['id'], self._server_config)
if response.status_code == http_client.NO_CONTENT:
elif response.status_code == http_client.NO_CONTENT:
# "The server successfully processed the request, but is not
# returning any content. Usually used as a response to a successful
# delete request."
Expand Down Expand Up @@ -813,6 +886,7 @@ def create_json(self, create_missing=None):
response.raise_for_status()
return response.json()

@signals.emit(sender=signals.SENDER_CLASS, post_result_name='entity')
def create(self, create_missing=None):
"""Create an entity.
Expand Down Expand Up @@ -915,6 +989,7 @@ def update_json(self, fields=None):
response.raise_for_status()
return response.json()

@signals.emit(sender=signals.SENDER_CLASS, post_result_name='entity')
def update(self, fields=None):
"""Update the current entity.
Expand Down Expand Up @@ -1158,6 +1233,7 @@ def search_normalize(self, results):
normalized.append(attrs)
return normalized

@signals.emit(sender=signals.SENDER_CLASS, post_result_name='entities')
def search(self, fields=None, query=None, filters=None):
"""Search for entities.
Expand Down Expand Up @@ -1298,3 +1374,26 @@ def search_filter(entities, filters):
if getattr(entity, field_name) == field_value
]
return filtered


def to_json_serializable(obj):
""" Transforms obj into a json serializable object.
:param obj: entity or any json serializable object
:return: serializable object
"""
if isinstance(obj, Entity):
return obj.to_json_dict()

if isinstance(obj, dict):
return {k: to_json_serializable(v) for k, v in obj.items()}
elif isinstance(obj, (list, tuple)):
return [to_json_serializable(v) for v in obj]
elif isinstance(obj, datetime):
return obj.strftime('%Y-%m-%d %H:%M:%S')
elif isinstance(obj, date):
return obj.strftime('%Y-%m-%d')

return obj
85 changes: 85 additions & 0 deletions tests/test_entity_mixins.py
Expand Up @@ -7,6 +7,7 @@
from nailgun import client, config, entity_mixins
from nailgun.entity_fields import (
IntegerField,
ListField,
OneToManyField,
OneToOneField,
StringField,
Expand Down Expand Up @@ -60,6 +61,25 @@ def __init__(self, server_config=None, **kwargs):
super(SampleEntityTwo, self).__init__(server_config, **kwargs)


class SampleEntityThree(entity_mixins.Entity):
"""An entity with foreign key fields as One to One and ListField.
This class has a :class:`nailgun.entity_fields.OneToOneField` called
"one_to_one" pointing to :class:`tests.test_entity_mixins.SampleEntityTwo`.
This class has a :class:`nailgun.entity_fields.ListField` called "list"
containing instances of :class:`tests.test_entity_mixins.SampleEntity`.
"""

def __init__(self, server_config=None, **kwargs):
self._fields = {
'one_to_one': OneToOneField(SampleEntityTwo),
'list': ListField()
}
super(SampleEntityThree, self).__init__(server_config, **kwargs)


class EntityWithCreate(entity_mixins.Entity, entity_mixins.EntityCreateMixin):
"""Inherits from :class:`nailgun.entity_mixins.EntityCreateMixin`."""

Expand Down Expand Up @@ -343,6 +363,70 @@ def test_bad_value_error(self):
with self.assertRaises(entity_mixins.BadValueError):
SampleEntityTwo(self.cfg, one_to_many=1)

def test_eq_none(self):
"""Test method ``nailgun.entity_mixins.Entity.__eq__`` against None
Assert that ``__eq__`` returns False when compared to None.
"""
alice = SampleEntity(self.cfg, id=1, name='Alice')
self.assertFalse(alice.__eq__(None))

def test_eq(self):
"""Test method ``nailgun.entity_mixins.Entity.__eq__``.
Assert that ``__eq__`` works comparing all attributes, even from
nested structures.
"""
# Testing simple properties
alice = SampleEntity(self.cfg, id=1, name='Alice')
alice_clone = SampleEntity(self.cfg, id=1, name='Alice')
self.assertEqual(alice, alice_clone)

alice_id_2 = SampleEntity(self.cfg, id=2, name='Alice')
self.assertNotEqual(alice, alice_id_2)

# Testing OneToMany nested objects

john = SampleEntityTwo(self.cfg, one_to_many=[alice, alice_id_2])
john_clone = SampleEntityTwo(self.cfg, one_to_many=[alice, alice_id_2])
self.assertEqual(john, john_clone)

john_different_order = SampleEntityTwo(self.cfg, one_to_many=[
alice_id_2, alice,
])
self.assertNotEqual(john, john_different_order)

john_missing_alice = SampleEntityTwo(self.cfg, one_to_many=[alice])
self.assertNotEqual(john, john_missing_alice)

john_without_alice = SampleEntityTwo(self.cfg)
self.assertNotEqual(john, john_without_alice)

# Testing OneToOne nested objects

mary = SampleEntityThree(self.cfg, one_to_one=john)
mary_clone = SampleEntityThree(
self.cfg, one_to_one=john_clone)
self.assertEqual(mary, mary_clone)

mary_different = SampleEntityThree(
self.cfg, one_to_one=john_different_order)
self.assertNotEqual(mary, mary_different)

mary_none_john = SampleEntityThree(self.cfg, one_to_one=None)
mary_none_john.to_json_dict()
self.assertNotEqual(mary, mary_none_john)

# Testing List nested objects
# noqa pylint:disable=attribute-defined-outside-init
mary.list = [alice]
self.assertNotEqual(mary, mary_clone)
# noqa pylint:disable=attribute-defined-outside-init
mary_clone.list = [alice_clone]
self.assertEqual(mary, mary_clone)

def test_repr_v1(self):
"""Test method ``nailgun.entity_mixins.Entity.__repr__``.
Expand Down Expand Up @@ -545,6 +629,7 @@ def setUpClass(cls):
``test_entity`` is a class having one to one and one to many fields.
"""

class TestEntity(entity_mixins.Entity, entity_mixins.EntityReadMixin):
"""An entity with several different types of fields."""

Expand Down

0 comments on commit 32ad3c6

Please sign in to comment.