Skip to content

Commit

Permalink
Refactor assertEqualXML into a testtools matcher
Browse files Browse the repository at this point in the history
Not all tests inheriting from TestCase will need to check XML equality.
Moving this functionality to a matcher leaves it up to the test class to
decided whether or not it needs it.

Change-Id: Ib28ec3b5dd96f662ce0cd90c650434b24c63ad6c
Related-Bug: #1226466
  • Loading branch information
dstanek committed Dec 4, 2013
1 parent 0dd5451 commit 96f1980
Show file tree
Hide file tree
Showing 5 changed files with 141 additions and 42 deletions.
25 changes: 0 additions & 25 deletions keystone/tests/core.py
Expand Up @@ -21,14 +21,12 @@
import re
import shutil
import socket
import StringIO
import sys
import time
import warnings

import fixtures
import logging
from lxml import etree
from paste import deploy
import testtools

Expand Down Expand Up @@ -469,29 +467,6 @@ def safe_repr(obj, short=False):

self.fail(self._formatMessage(msg, standardMsg))

def assertEqualXML(self, a, b):
"""Parses two XML documents from strings and compares the results.
This provides easy-to-read failures.
"""
parser = etree.XMLParser(remove_blank_text=True)

def canonical_xml(s):
s = s.strip()

fp = StringIO.StringIO()
dom = etree.fromstring(s, parser)
dom.getroottree().write_c14n(fp)
s = fp.getvalue()

dom = etree.fromstring(s, parser)
return etree.tostring(dom, pretty_print=True)

a = canonical_xml(a)
b = canonical_xml(b)
self.assertEqual(a.split('\n'), b.split('\n'))

def skip_if_no_ipv6(self):
try:
s = socket.socket(socket.AF_INET6)
Expand Down
62 changes: 62 additions & 0 deletions keystone/tests/matchers.py
@@ -0,0 +1,62 @@
# vim: tabstop=4 shiftwidth=4 softtabstop=4

# Copyright 2013 OpenStack Foundation
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.

import six

from lxml import etree
from testtools import matchers


class XMLEquals(object):
"""Parses two XML documents from strings and compares the results.
"""

def __init__(self, expected):
self.expected = expected

def __str__(self):
return "%s(%r)" % (self.__class__.__name__, self.expected)

def match(self, other):
parser = etree.XMLParser(remove_blank_text=True)

def canonical_xml(s):
s = s.strip()

fp = six.StringIO()
dom = etree.fromstring(s, parser)
dom.getroottree().write_c14n(fp)
s = fp.getvalue()

dom = etree.fromstring(s, parser)
return etree.tostring(dom, pretty_print=True)

expected = canonical_xml(self.expected)
other = canonical_xml(other)
if expected == other:
return
return XMLMismatch(expected, other)


class XMLMismatch(matchers.Mismatch):

def __init__(self, expected, other):
self.expected = expected
self.other = other

def describe(self):
return 'expected = %s\nactual = %s' % (self.expected, self.other)
60 changes: 60 additions & 0 deletions keystone/tests/test_matchers.py
@@ -0,0 +1,60 @@
# vim: tabstop=4 shiftwidth=4 softtabstop=4

# Copyright 2013 OpenStack Foundation
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.

import textwrap

import testtools
from testtools.tests.matchers import helpers

from keystone.tests import matchers


class TestXMLEquals(testtools.TestCase, helpers.TestMatchersInterface):
matches_xml = """
<?xml version="1.0" encoding="UTF-8"?>
<test xmlns="http://docs.openstack.org/identity/api/v2.0">
<success a="a" b="b"/>
</test>
"""
equivalent_xml = """
<?xml version="1.0" encoding="UTF-8"?>
<test xmlns="http://docs.openstack.org/identity/api/v2.0">
<success b="b" a="a"></success>
</test>
"""
mismatches_xml = """
<?xml version="1.0" encoding="UTF-8"?>
<test xmlns="http://docs.openstack.org/identity/api/v2.0">
<nope_it_fails/>
</test>
"""
mismatches_description = textwrap.dedent("""\
expected = <test xmlns="http://docs.openstack.org/identity/api/v2.0">
<success a="a" b="b"/>
</test>
actual = <test xmlns="http://docs.openstack.org/identity/api/v2.0">
<nope_it_fails/>
</test>
""").lstrip()

matches_matcher = matchers.XMLEquals(matches_xml)
matches_matches = [matches_xml, equivalent_xml]
matches_mismatches = [mismatches_xml]
describe_examples = [
(mismatches_description, mismatches_xml, matches_matcher),
]
str_examples = [('XMLEquals(%r)' % matches_xml, matches_matcher)]
15 changes: 8 additions & 7 deletions keystone/tests/test_serializer.py
Expand Up @@ -20,22 +20,23 @@

from keystone.common import serializer
from keystone import tests
from keystone.tests import matchers as ksmatchers


class XmlSerializerTestCase(tests.TestCase):
def assertSerializeDeserialize(self, d, xml, xmlns=None):
self.assertEqualXML(
self.assertThat(
serializer.to_xml(copy.deepcopy(d), xmlns),
xml)
ksmatchers.XMLEquals(xml))
self.assertEqual(serializer.from_xml(xml), d)

# operations should be invertible
self.assertEqual(
serializer.from_xml(serializer.to_xml(copy.deepcopy(d), xmlns)),
d)
self.assertEqualXML(
self.assertThat(
serializer.to_xml(serializer.from_xml(xml), xmlns),
xml)
ksmatchers.XMLEquals(xml))

def test_auth_request(self):
d = {
Expand Down Expand Up @@ -162,7 +163,7 @@ def test_policy_list(self):
<policy id="ab12cd"/>
</policies>
"""
self.assertEqualXML(serializer.to_xml(d), xml)
self.assertThat(serializer.to_xml(d), ksmatchers.XMLEquals(xml))

def test_values_list(self):
d = {
Expand All @@ -183,7 +184,7 @@ def test_values_list(self):
</objects>
"""

self.assertEqualXML(serializer.to_xml(d), xml)
self.assertThat(serializer.to_xml(d), ksmatchers.XMLEquals(xml))

def test_collection_list(self):
d = {
Expand Down Expand Up @@ -296,7 +297,7 @@ def test_v2_links_special_case(self):
identity-service/2.0/identity-dev-guide-2.0.pdf" type="application/pdf"/>
</object>
"""
self.assertEqualXML(serializer.to_xml(d), xml)
self.assertThat(serializer.to_xml(d), ksmatchers.XMLEquals(xml))

def test_xml_with_namespaced_attribute_to_dict(self):
expected = {
Expand Down
21 changes: 11 additions & 10 deletions keystone/tests/test_versions.py
Expand Up @@ -22,6 +22,7 @@
from keystone.openstack.common.fixture import moxstubout
from keystone.openstack.common import jsonutils
from keystone import tests
from keystone.tests import matchers


CONF = config.CONF
Expand Down Expand Up @@ -341,47 +342,47 @@ def test_public_versions(self):
self.assertEqual(resp.status_int, 300)
data = resp.body
expected = self.VERSIONS_RESPONSE % dict(port=CONF.public_port)
self.assertEqualXML(data, expected)
self.assertThat(data, matchers.XMLEquals(expected))

def test_admin_versions(self):
client = self.client(self.admin_app)
resp = client.get('/', headers=self.REQUEST_HEADERS)
self.assertEqual(resp.status_int, 300)
data = resp.body
expected = self.VERSIONS_RESPONSE % dict(port=CONF.admin_port)
self.assertEqualXML(data, expected)
self.assertThat(data, matchers.XMLEquals(expected))

def test_public_version_v2(self):
client = self.client(self.public_app)
resp = client.get('/v2.0/', headers=self.REQUEST_HEADERS)
self.assertEqual(resp.status_int, 200)
data = resp.body
expected = self.v2_VERSION_RESPONSE % dict(port=CONF.public_port)
self.assertEqualXML(data, expected)
self.assertThat(data, matchers.XMLEquals(expected))

def test_admin_version_v2(self):
client = self.client(self.admin_app)
resp = client.get('/v2.0/', headers=self.REQUEST_HEADERS)
self.assertEqual(resp.status_int, 200)
data = resp.body
expected = self.v2_VERSION_RESPONSE % dict(port=CONF.admin_port)
self.assertEqualXML(data, expected)
self.assertThat(data, matchers.XMLEquals(expected))

def test_public_version_v3(self):
client = self.client(self.public_app)
resp = client.get('/v3/', headers=self.REQUEST_HEADERS)
self.assertEqual(resp.status_int, 200)
data = resp.body
expected = self.v3_VERSION_RESPONSE % dict(port=CONF.public_port)
self.assertEqualXML(data, expected)
self.assertThat(data, matchers.XMLEquals(expected))

def test_admin_version_v3(self):
client = self.client(self.public_app)
resp = client.get('/v3/', headers=self.REQUEST_HEADERS)
self.assertEqual(resp.status_int, 200)
data = resp.body
expected = self.v3_VERSION_RESPONSE % dict(port=CONF.admin_port)
self.assertEqualXML(data, expected)
self.assertThat(data, matchers.XMLEquals(expected))

def test_v2_disabled(self):
self.stubs.Set(controllers, '_VERSIONS', ['v3'])
Expand All @@ -392,7 +393,7 @@ def test_v2_disabled(self):
self.assertEqual(resp.status_int, 200)
data = resp.body
expected = self.v3_VERSION_RESPONSE % dict(port=CONF.public_port)
self.assertEqualXML(data, expected)
self.assertThat(data, matchers.XMLEquals(expected))

# only v3 information should be displayed by requests to /
v3_only_response = ((self.DOC_INTRO + '<versions %(namespace)s>' +
Expand All @@ -404,7 +405,7 @@ def test_v2_disabled(self):
resp = client.get('/', headers=self.REQUEST_HEADERS)
self.assertEqual(resp.status_int, 300)
data = resp.body
self.assertEqualXML(data, v3_only_response)
self.assertThat(data, matchers.XMLEquals(v3_only_response))

def test_v3_disabled(self):
self.stubs.Set(controllers, '_VERSIONS', ['v2.0'])
Expand All @@ -415,7 +416,7 @@ def test_v3_disabled(self):
self.assertEqual(resp.status_int, 200)
data = resp.body
expected = self.v2_VERSION_RESPONSE % dict(port=CONF.public_port)
self.assertEqualXML(data, expected)
self.assertThat(data, matchers.XMLEquals(expected))

# only v2 information should be displayed by requests to /
v2_only_response = ((self.DOC_INTRO + '<versions %(namespace)s>' +
Expand All @@ -427,4 +428,4 @@ def test_v3_disabled(self):
resp = client.get('/', headers=self.REQUEST_HEADERS)
self.assertEqual(resp.status_int, 300)
data = resp.body
self.assertEqualXML(data, v2_only_response)
self.assertThat(data, matchers.XMLEquals(v2_only_response))

0 comments on commit 96f1980

Please sign in to comment.