In [1]:
#default_exp test

In [66]:
#export
from collections.abc import Iterable,Iterator,Generator,Sequence
from numpy import array,ndarray
from copy import copy,deepcopy
import numpy as np
import hashlib,itertools,operator,random,io,re,warnings
from contextlib import redirect_stdout,contextmanager
from functools import partial,reduce
from collections import OrderedDict,defaultdict,Counter,namedtuple

def test_fail(f, msg='', contains=''):
    "Fails with `msg` unless `f()` raises an exception and (optionally) has `contains` in `e.args`"
    try: f()
    except Exception as e:
        assert not contains or contains in str(e)
        return
    assert False,f"Expected exception but none raised. {msg}"

In [3]:
def _fail_with_msg(): raise Exception("foobar")
test_fail(_fail_with_msg, contains="foo")

def _fail_no_msg(): raise Exception()
test_fail(_fail_no_msg)

# `test_fail` must itself fail when nothing is raised, or when `contains` is absent from the message
test_fail(lambda: test_fail(lambda: None))
test_fail(lambda: test_fail(_fail_with_msg, contains="baz"))

In [4]:
#export
def test(a, b, cmp,cname=None):
    "`assert` that `cmp(a,b)`; display inputs and `cname or cmp.__name__` if it fails"
    if cname is None: cname=cmp.__name__
    assert cmp(a,b),f"{cname}:\n{a}\n{b}"

In [18]:
#export
def equals(a,b):
    "Compares `a` and `b` for equality; supports sublists, tensors and arrays too"
    # Types (e.g. from `test_eq(type(a),type(b))`) use plain `==`, never the container logic below
    if one_is_instance(a,b,type): return a==b
    # An object defining `__array_eq__` decides equality itself; `a` is consulted before `b`
    for this, other in ((a,b), (b,a)):
        if hasattr(this, '__array_eq__'): return this.__array_eq__(other)
    if one_is_instance(a, b, ndarray): return np.array_equal(a, b)
    # Strings (whose items are strings again), dicts and sets compare directly with `==`
    if one_is_instance(a, b, (str,dict,set)): return operator.eq(a, b)
    if is_iter(a) or is_iter(b): return all_equal(a, b)
    return operator.eq(a, b)

def is_iter(o):
    "Test whether `o` can be used in a `for` loop; returns a `bool`"
    # Rank 0 tensors in PyTorch (and 0-d numpy arrays) define `__iter__` but cannot be iterated,
    # so also require a non-zero `ndim`. Wrap in `bool` so callers get a predicate, not the rank.
    return isinstance(o, (Iterable,Generator)) and bool(getattr(o,'ndim',1))

def is_coll(o):
    "Test whether `o` is a collection (i.e. has a usable `len`); returns a `bool`"
    # Rank 0 tensors in PyTorch (and 0-d numpy arrays) define `__len__` but it raises,
    # so also require a non-zero `ndim`. Wrap in `bool` so callers get a predicate, not the rank.
    return hasattr(o, '__len__') and bool(getattr(o,'ndim',1))

In [24]:
#export
def all_equal(a,b):
    "Compares whether `a` and `b` are the same length and have the same contents"
    # Both sides must be iterable: `zip_longest` would raise TypeError on a non-iterable `a`
    if not (is_iter(a) and is_iter(b)): return False
    # Pad the shorter side with a private sentinel; the default `None` padding would make
    # `[1, None]` compare equal to `[1]`
    missing = object()
    return all(equals(a_,b_) for a_,b_ in itertools.zip_longest(a,b,fillvalue=missing))

def one_is_instance(a, b, t):
    "True if at least one of `a` and `b` is an instance of `t` (a type or tuple of types)"
    return any(isinstance(o, t) for o in (a, b))

In [29]:
test(['abc'], ['abc'], cmp=all_equal)

In [30]:
#export
def test_eq(a,b):
    "`test` that `a` equals `b` according to `equals`, reporting failures as `==`"
    test(a, b, cmp=equals, cname='==')

In [33]:
test_eq([1,2],[1,2])
test_eq([1,2],map(int,[1,2]))
test_eq(array([1,2]),array([1,2]))
# A list and an array with the same contents compare equal
test_eq([1,2],array([1,2]))
test_eq([array([1,2]),3],[array([1,2]),3])
test_eq(dict(a=1,b=2), dict(b=2,a=1))
test_fail(lambda: test_eq([1,2], 1), contains="==")
test_eq({'a', 'b', 'c'}, {'c', 'a', 'b'})

In [34]:
#export
def nequals(a,b):
    "Negation of `equals`: returns `True` exactly when `equals(a,b)` is falsy"
    return False if equals(a,b) else True

In [35]:
test(['abc'], ['ab'], cmp=nequals)

In [36]:
#export
def test_eq(a,b):
    "`test` that `a==b`"
    # NOTE(review): this cell re-defines `test_eq` with the same behavior as the earlier cell,
    # and both cells are marked `#export` -- consider deleting one of them.
    return test(a, b, equals, cname='==')

In [37]:
#export
def test_eq_type(a,b):
    "`test` that `a==b` and they have the same type (and, for lists/tuples, the same element types)"
    test_eq(a,b)
    test_eq(type(a),type(b))
    # `[1,1] == [1,1.]`, so element types are checked separately. Build lists rather than lazy
    # `map`s so a failure message shows the element types instead of `<map object at ...>`.
    if isinstance(a,(list,tuple)): test_eq([type(o) for o in a], [type(o) for o in b])

In [38]:
# Same value and same type (including element types) passes
test_eq_type(1, 1)
test_eq_type([1,1], [1,1])
# A type mismatch anywhere -- scalar, container or element -- fails
test_fail(lambda: test_eq_type(1, 1.))
test_fail(lambda: test_eq_type([1,1], (1,1)))
test_fail(lambda: test_eq_type([1,1], [1,1.]))

In [39]:
#export
def test_ne(a,b):
    "`test` that `a!=b`, i.e. that `equals(a,b)` is falsy"
    test(a, b, cmp=nequals, cname='!=')

In [40]:
test_ne([1,2],[1])
test_ne([1,2],[1,3])
test_ne(array([1,2]),array([1,1]))
# Arrays of different lengths are not equal
test_ne(array([1,2]),array([1,2,3]))
test_ne([array([1,2]),3],[array([1,2])])
test_ne([3,4],array([3]))
test_ne([3,4],array([3,5]))
test_ne(dict(a=1,b=2), ['a', 'b'])
test_ne(['a', 'b'], dict(a=1,b=2))

In [41]:
#export
def is_close(a,b,eps=1e-5):
    "Is `a` within `eps` of `b`? Array-likes (anything with `__array__`) and iterables are compared elementwise"
    # Materialize one-shot iterators (generators, `map`, ...): `np.array(gen)` would wrap the
    # generator object itself in a 0-d object array instead of consuming it
    if isinstance(a, Iterator): a = list(a)
    if isinstance(b, Iterator): b = list(b)
    if hasattr(a, '__array__') or hasattr(b,'__array__'):
        # NOTE: shapes are not checked -- numpy broadcasting applies, so e.g. `[1.,1.]` vs `[1.]` is close
        return (abs(a-b)<eps).all()
    if isinstance(a, Iterable) or isinstance(b, Iterable):
        return is_close(np.array(a), np.array(b), eps=eps)
    return abs(a-b)<eps

In [42]:
#export
def test_close(a,b,eps=1e-5):
    "`test` that `a` is within `eps` of `b` (as judged by `is_close`), reporting failures as `close`"
    close_within_eps = partial(is_close, eps=eps)
    test(a, b, close_within_eps, cname='close')

In [45]:
test_close(1,1.001,eps=1e-2)
test_fail(lambda: test_close(1,1.001))
test_close([-0.001,1.001], [0.,1.], eps=1e-2)
test_close((-0.001,1.001), [0.,1.], eps=1e-2)
test_close(np.array([-0.001,1.001]), np.array([0.,1.]), eps=1e-2)
test_fail(lambda: test_close(np.array([0.,1.1]), np.array([0.,1.]), eps=1e-2))

In [46]:
#export
def test_is(a,b):
    "`test` that `a` and `b` are the same object (`a is b`), reporting failures as `is`"
    test(a, b, cmp=operator.is_, cname='is')

In [47]:
a = [1]
test_is(a, a)
# Equal but distinct objects are not identical
test_fail(lambda: test_is([1], [1]))

In [48]:
#export
def test_shuffled(a,b):
    "`test` that `a` and `b` are shuffled versions of each other: not equal, yet holding the same items with the same counts"
    # Not equal as sequences...
    test(a, b, nequals, '!=')
    # ...but equal as multisets (items must be hashable for `Counter`)
    test(Counter(a), Counter(b), equals, '==')

In [55]:
a = list(range(50))
b = copy(a)
# Use a seeded, private RNG so the cell is reproducible without touching global `random` state
random.Random(42).shuffle(b)
test_shuffled(a,b)
test_fail(lambda:test_shuffled(a,a))

In [56]:
# Same items but different multiplicities are not a shuffle
a, b = 'abc', 'abcabc'
test_fail(lambda: test_shuffled(a, b))

In [57]:
# Items of mixed types work as long as they are hashable (required by `Counter`)
a, b = ['a', 42, True], [42, True, 'a']
test_shuffled(a, b)

In [58]:
#export
def test_stdout(f, exp, regex=False):
    "Test that `f` prints `exp` to stdout, optionally checking as `regex`"
    s = io.StringIO()
    with redirect_stdout(s): f()
    if regex: assert re.search(exp, s.getvalue()) is not None
    else: test_eq(s.getvalue(), f'{exp}\n' if len(exp) > 0 else '')

In [63]:
test_stdout(lambda: print('hi'), 'hi')
test_fail(lambda: test_stdout(lambda: print('hi'), 'ho'))
test_stdout(lambda: 1+1, '')
test_stdout(lambda: print('hi there!'), r'^hi.*!$', regex=True)
# The regex branch must fail too when the pattern does not match
test_fail(lambda: test_stdout(lambda: print('hi there!'), r'^ho', regex=True))

In [64]:
#export
def test_warns(f, show=False):
    "Test that calling `f` emits at least one warning; if `show`, print each recorded warning"
    with warnings.catch_warnings(record=True) as w:
        # Record every warning regardless of the active filters (an "ignore" filter would
        # hide it, an "error" filter would raise it instead of recording it)
        warnings.simplefilter("always")
        f()
        test_ne(len(w), 0)
        if show:
            for e in w: print(f"{e.category}: {e.message}")

In [67]:
test_warns(lambda: warnings.warn("Oh no!"))
# With `show=True` each recorded warning is printed as "<category>: <message>"
test_stdout(lambda: test_warns(lambda: warnings.warn("Oh no!"), show=True), "<class 'UserWarning'>: Oh no!")
test_fail(lambda: test_warns(lambda: 2+2))