Skip to content

Commit

Permalink
Allow inheritance in testing.inject_backend_tests
Browse files Browse the repository at this point in the history
  • Loading branch information
niboshi committed Dec 21, 2018
1 parent 28ca03e commit 108fa46
Show file tree
Hide file tree
Showing 3 changed files with 161 additions and 149 deletions.
88 changes: 88 additions & 0 deletions chainer/testing/_gen.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
import collections
import inspect
import sys
import unittest


# A tuple that represents a test case.
# For bare (non-generated) test cases, [1] and [2] are None.
# [0] Test case class
# [1] Module name in whicn the class is defined
# [2] Class name
_TestCaseTuple = collections.namedtuple(
'_TestCaseTuple', ('klass', 'module_name', 'class_name'))


class _ParameterizedTestCaseBundle(object):
def __init__(self, cases):
# cases is a list of _TestCaseTuple's
assert isinstance(cases, list)
assert all(isinstance(tup, _TestCaseTuple) for tup in cases)

self.cases = cases


def make_decorator(test_case_generator):

def f(cases):
if isinstance(cases, _ParameterizedTestCaseBundle):
# The input is a parameterized test case.
cases = cases.cases
else:
# Input is a bare test case, i.e. not one generated from another
# parameterize.
assert issubclass(cases, unittest.TestCase)
cases = [_TestCaseTuple(cases, None, None)]

generated_cases = []
for klass, mod_name, cls_name in cases:
assert issubclass(klass, unittest.TestCase)
if mod_name is not None:
# The input is a parameterized test case.
# Remove it from its module.
delattr(sys.modules[mod_name], cls_name)
else:
# The input is a bare test case
mod_name = klass.__module__

# Generate parameterized test cases out of the input test case.
l = _generate_test_cases(mod_name, klass, test_case_generator)
generated_cases += l

# Return the bundle of generated cases to allow repeated application of
# parameterize decorators.
return _ParameterizedTestCaseBundle(generated_cases)
return f


def _generate_case(base, module, cls_name, mb, method_generator):
# Returns a _TestCaseTuple.
# Add parameters as members

cls = type(cls_name, (base,), mb)

# ismethod for Python 2 and isfunction for Python 3
members = inspect.getmembers(
cls, predicate=lambda _: inspect.ismethod(_) or inspect.isfunction(_))
for name, method in members:
if name.startswith('test_'):
setattr(cls, name, method_generator(method))

# Add new test class to module
setattr(module, cls_name, cls)

return _TestCaseTuple(cls, module.__name__, cls_name)


def _generate_test_cases(module_name, base_class, test_case_generator):
# Returns a list of _TestCaseTuple's holding generated test cases.
module = sys.modules[module_name]

generated_cases = []
for cls_name, members, method_generator in (
test_case_generator(base_class)):
c = _generate_case(
base_class, module, cls_name, members, method_generator)
generated_cases.append(c)

return generated_cases
55 changes: 23 additions & 32 deletions chainer/testing/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import chainer
from chainer import backend
from chainer.backends import cuda
from chainer.testing import _gen
from chainer.testing import attr
import chainerx

Expand Down Expand Up @@ -151,23 +152,30 @@ def get_array(self, np_array):
return self.device.send(np_array)


def _wrap_backend_test_method(impl, param, method_name):
backend_config = BackendConfig(param)
marks = backend_config.get_pytest_marks()
new_method_name = '{}__{}'.format(
method_name, backend_config.get_func_str())
def _test_case_generator(base, method_names, params):
# Defines the logic to generate test case classes parameterized with
# backends.

@functools.wraps(impl)
def func(self, *args, **kwargs):
impl(self, backend_config, *args, **kwargs)
for i_param, param in enumerate(params):
backend_config = BackendConfig(param)
marks = backend_config.get_pytest_marks()
cls_name = '{}_{}'.format(base.__name__, backend_config.get_func_str())
members = {}

func.__name__ = new_method_name
def method_generator(base_method):
# Generates a wrapped test method

# Apply test marks
for mark in marks:
func = mark(func)
@functools.wraps(base_method)
def new_method(self, *args, **kwargs):
return base_method(self, backend_config, *args, **kwargs)

return func, new_method_name
# Apply test marks
for mark in marks:
mark(new_method)

return new_method

yield (cls_name, members, method_generator)


def inject_backend_tests(method_names, params):
Expand All @@ -178,22 +186,5 @@ def inject_backend_tests(method_names, params):
if not all(isinstance(d, dict) for d in params):
raise TypeError('params must be a list of dicts.')

def wrap(case):
if method_names is None:
meth_names = [_ for _ in dir(case) if _.startswith('test_')]
else:
meth_names = method_names

for method_name in meth_names:
impl = getattr(case, method_name)
delattr(case, method_name)
for i_param, param in enumerate(params):
new_impl, new_method_name = _wrap_backend_test_method(
impl, param, method_name)
if hasattr(case, new_method_name):
raise RuntimeError(
'Test fixture already exists: {}'.format(
new_method_name))
setattr(case, new_method_name, new_impl)
return case
return wrap
return _gen.make_decorator(
lambda base: _test_case_generator(base, method_names, params))
167 changes: 50 additions & 117 deletions chainer/testing/parameterized.py
Original file line number Diff line number Diff line change
@@ -1,139 +1,72 @@
import collections
import functools
import inspect
import itertools
import sys
import types
import unittest

import six

from chainer.testing import _gen

# A tuple that represents a test case.
# For bare (non-generated) test cases, [1] and [2] are None.
# [0] Test case class
# [1] Module name in whicn the class is defined
# [2] Class name
_TestCaseTuple = collections.namedtuple(
'_TestCaseTuple', ('klass', 'module_name', 'class_name'))

def _parameterize_test_case_generator(base, params):
# Defines the logic to generate parameterized test case classes.

class _ParameterizedTestCaseBundle(object):
def __init__(self, cases):
# cases is a list of _TestCaseTuple's
assert isinstance(cases, list)
assert all(isinstance(tup, _TestCaseTuple) for tup in cases)

self.cases = cases


def _gen_case(base, module, i, param):
# Returns a _TestCaseTuple.

cls_name = '%s_param_%d' % (base.__name__, i)

# Add parameters as members

def __str__(self):
name = base.__str__(self)
return '%s parameter: %s' % (name, param)

mb = {'__str__': __str__}
for k, v in six.iteritems(param):
if isinstance(v, types.FunctionType):

def create_new_v():
f = v

def new_v(self, *args, **kwargs):
return f(*args, **kwargs)
return new_v

mb[k] = create_new_v()
else:
mb[k] = v

cls = type(cls_name, (base,), mb)

# Wrap test methods to generate useful error message

def wrap_test_method(method):
@functools.wraps(method)
def wrap(*args, **kwargs):
try:
return method(*args, **kwargs)
except unittest.SkipTest:
raise
except Exception as e:
s = six.StringIO()
s.write('Parameterized test failed.\n\n')
s.write('Base test method: {}.{}\n'.format(
base.__name__, method.__name__))
s.write('Test parameters:\n')
for k, v in six.iteritems(param):
s.write(' {}: {}\n'.format(k, v))
s.write('\n')
s.write('{}: {}\n'.format(type(e).__name__, e))
e_new = AssertionError(s.getvalue())
if sys.version_info < (3,):
six.reraise(AssertionError, e_new, sys.exc_info()[2])
else:
six.raise_from(e_new.with_traceback(e.__traceback__), None)
return wrap

# ismethod for Python 2 and isfunction for Python 3
members = inspect.getmembers(
cls, predicate=lambda _: inspect.ismethod(_) or inspect.isfunction(_))
for name, method in members:
if name.startswith('test_'):
setattr(cls, name, wrap_test_method(method))

# Add new test class to module
setattr(module, cls_name, cls)
for i, param in enumerate(params):
cls_name = '%s_param_%d' % (base.__name__, i)

return _TestCaseTuple(cls, module.__name__, cls_name)
def __str__(self):
name = base.__str__(self)
return '%s parameter: %s' % (name, param)

mb = {'__str__': __str__}
for k, v in six.iteritems(param):
if isinstance(v, types.FunctionType):

def _gen_cases(name, base, params):
# Returns a list of _TestCaseTuple's holding generated test cases.
module = sys.modules[name]
generated_cases = []
for i, param in enumerate(params):
c = _gen_case(base, module, i, param)
generated_cases.append(c)
return generated_cases
def create_new_v():
f = v

def new_v(self, *args, **kwargs):
return f(*args, **kwargs)
return new_v

def parameterize(*params):
def f(cases):
if isinstance(cases, _ParameterizedTestCaseBundle):
# The input is a parameterized test case.
cases = cases.cases
else:
# Input is a bare test case, i.e. not one generated from another
# parameterize.
assert issubclass(cases, unittest.TestCase)
cases = [_TestCaseTuple(cases, None, None)]

generated_cases = []
for klass, mod_name, cls_name in cases:
assert issubclass(klass, unittest.TestCase)
if mod_name is not None:
# The input is a parameterized test case.
# Remove it from its module.
delattr(sys.modules[mod_name], cls_name)
mb[k] = create_new_v()
else:
# The input is a bare test case
mod_name = klass.__module__
mb[k] = v

def method_generator(base_method):
# Generates a wrapped test method

@functools.wraps(base_method)
def new_method(self, *args, **kwargs):
try:
return base_method(self, *args, **kwargs)
except unittest.SkipTest:
raise
except Exception as e:
s = six.StringIO()
s.write('Parameterized test failed.\n\n')
s.write('Base test method: {}.{}\n'.format(
base.__name__, base_method.__name__))
s.write('Test parameters:\n')
for k, v in six.iteritems(param):
s.write(' {}: {}\n'.format(k, v))
s.write('\n')
s.write('{}: {}\n'.format(type(e).__name__, e))
e_new = AssertionError(s.getvalue())
if sys.version_info < (3,):
six.reraise(AssertionError, e_new, sys.exc_info()[2])
else:
six.raise_from(
e_new.with_traceback(e.__traceback__), None)
return new_method

yield (cls_name, mb, method_generator)

# Generate parameterized test cases out of the input test case.
l = _gen_cases(mod_name, klass, params)
generated_cases += l

# Return the bundle of generated cases to allow repeated application of
# parameterize decorators.
return _ParameterizedTestCaseBundle(generated_cases)
return f
def parameterize(*params):
return _gen.make_decorator(
lambda base: _parameterize_test_case_generator(base, params))


def product(parameter):
Expand Down

0 comments on commit 108fa46

Please sign in to comment.