In [12]:
# hide
# default_exp utils.nbdev_utils
from nbdev.showdoc import *

# nbdev utils

> Temporary home for nbdev utilities, including helpers for running tests with nbdev.

In [13]:
#export
import os
import joblib
from pathlib import Path
import socket

In [14]:
import pytest
import os

In [15]:
#export 
def cd_root (max_levels=10, marker='settings.ini'):
    """Change the working directory to the nearest ancestor that contains `marker`.

    Checks the current directory and up to `max_levels` parent directories for an
    entry named `marker` (by default `settings.ini`, which marks the project root).
    The working directory is changed only when such a directory is found; otherwise
    it is left untouched (previously a failed search left the process in an
    arbitrary ancestor directory).

    Args:
        max_levels: how many parent directories above the current one to inspect.
        marker: file or directory name whose presence identifies the root.

    Returns:
        The `Path` of the directory changed into, or None if no root was found.
    """
    candidate = Path.cwd()
    for _ in range(max_levels + 1):
        if (candidate / marker).exists():
            os.chdir(candidate)
            return candidate
        if candidate.parent == candidate:
            # Reached the filesystem root; there is nothing further up to inspect.
            break
        candidate = candidate.parent
    return None

In [16]:
#export   
def nbdev_setup (no_warnings=True):
    """Prepare a notebook session for nbdev work.

    Args:
        no_warnings: when True (the default), install a global "ignore"
            filter in the `warnings` module before doing anything else.

    Side effect: calls `cd_root()` to move into the project root directory.
    """
    if no_warnings:
        import warnings
        warnings.filterwarnings ("ignore")
    cd_root ()

In [17]:
# Move into the project root (searched upward for settings.ini) so relative paths,
# such as TestRunner's default config path, are resolved from there.
cd_root ()

In [74]:
#export
class TestRunner ():
    """Lightweight registry/runner for tests written inside notebooks.

    The runner remembers every test name it has seen (`all_tests`), the names
    explicitly selected to run (`do_test`) and a mapping tag -> [test names]
    (`tags`). With `save=True` this state, together with `targets`,
    `remote_targets` and `localhostname`, is written with joblib to
    `path_config` on every `run` call. With `load=True` and an existing
    `path_config`, any of those constructor arguments left as None is restored
    from that file.

    Which tags are "active" depends on the machine: when the current hostname
    differs from `localhostname` the runner is considered remote and uses
    `remote_targets`; otherwise it uses `targets`.
    """

    # The class name starts with "Test": stop pytest from trying to collect it
    # as a test class when it is imported into a test module.
    __test__ = False

    def __init__ (self, do_all=False, do_test=None, all_tests=None, tags=None, targets=None,
                  remote_targets=None, load=True, save=True, path_config='config_test/test_names.pk',
                  localhostname=None, show=True):
        """Create the runner, optionally restoring persisted state.

        Args:
            do_all: run every test that is not excluded.
            do_test: list of test names to run; default [] (or the persisted list).
            all_tests: list of known test names; default [] (or the persisted list).
            tags: dict mapping tag -> list of test names; default {} (or persisted).
            targets: tag, or list/tuple/set of tags, to run on the local host; default [].
            remote_targets: tag, or list/tuple/set of tags, to run on other hosts;
                default ['dummy'].
            load: restore arguments left as None from `path_config` if the file exists.
            save: persist state to `path_config` after every `run`; the parent
                directory is created here.
            path_config: joblib file that holds the persisted state.
            localhostname: hostname regarded as local; default 'DataScience-VMs-03'.
            show: default for `run(show=...)`, i.e. print the name of each test run.
        """
        if save:
            Path(path_config).parent.mkdir(parents=True, exist_ok=True)

        if load and Path(path_config).exists():
            # Explicitly passed arguments take precedence over persisted values.
            # NOTE(review): joblib.load unpickles; only point path_config at trusted files.
            (do_test_, all_tests_, tags_, targets_,
             remote_targets_, localhostname_) = joblib.load (path_config)
            do_test = do_test_ if do_test is None else do_test
            all_tests = all_tests_ if all_tests is None else all_tests
            tags = tags_ if tags is None else tags
            targets = targets_ if targets is None else targets
            remote_targets = remote_targets_ if remote_targets is None else remote_targets
            localhostname = localhostname_ if localhostname is None else localhostname
        else:
            do_test = [] if do_test is None else do_test
            all_tests = [] if all_tests is None else all_tests
            tags = {} if tags is None else tags
            targets = [] if targets is None else targets
            remote_targets = ['dummy'] if remote_targets is None else remote_targets
            # NOTE(review): environment-specific default host; pass localhostname explicitly elsewhere.
            localhostname = 'DataScience-VMs-03' if localhostname is None else localhostname

        self.do_test = do_test
        self.all_tests = all_tests
        self.tags = tags
        self.do_all = do_all
        # Normalise both target collections so `tag in targets` compares whole tags,
        # never substrings of a bare string.
        self.targets = self._as_list (targets)
        self.remote_targets = self._as_list (remote_targets)
        self.save = save
        self.path_config = path_config
        self.hostname = socket.gethostname()
        self.localhostname = localhostname
        self.is_remote = self.localhostname != self.hostname
        self.show = show

    @staticmethod
    def _as_list (value):
        """Return `value` as a list: lists as-is, tuples/sets converted, anything else wrapped."""
        if isinstance(value, list):
            return value
        if isinstance(value, (tuple, set, frozenset)):
            return list(value)
        return [value]

    def _save (self):
        """Persist the runner state to `path_config`, in the field order `__init__` loads."""
        joblib.dump ([self.do_test, self.all_tests, self.tags, self.targets,
                      self.remote_targets, self.localhostname], self.path_config)

    def _should_run (self, name, do, exclude, tag):
        """Decide whether test `name` executes, given the per-call flags and the active targets."""
        active_targets = self.remote_targets if self.is_remote else self.targets
        return ((name in self.do_test) or do or (self.do_all and not exclude)
                or (tag is not None and tag in active_targets))

    def run (self, test_func, data_func=None, do=False, include=False, debug=False,
            exclude=False, tag=None, show=None):
        """Register `test_func` and execute it if it is selected.

        Args:
            test_func: the test; its `__name__` is the key used in every registry.
            data_func: optional zero-argument callable. Its result is passed as the
                single positional argument to `test_func`.
            do: force execution for this call.
            include: add the test to `do_test` so it stays selected.
            debug: run the test under `pdb.runcall` instead of calling it directly.
            exclude: keep the test out of `all_tests` and out of `do_all` runs.
            tag: tag to register the test under. The test runs when this tag is
                one of the active targets.
            show: print "running <name>" before a non-debug run; None uses `self.show`.

        Exceptions raised by the test (e.g. AssertionError) propagate to the caller.
        """
        name = test_func.__name__
        show = self.show if show is None else show
        if (name not in self.all_tests) and not exclude:
            self.all_tests.append (name)
        if include and name not in self.do_test:
            self.do_test.append (name)
        if tag is not None:
            # Append to the tag's list without discarding names already under it.
            # Previously, re-running an already-tagged test reset the list to [name].
            tagged = self.tags.setdefault (tag, [])
            if name not in tagged:
                tagged.append (name)
        if self.save:
            self._save ()
        if not self._should_run (name, do, exclude, tag):
            return
        args = [] if data_func is None else [data_func ()]
        if debug:
            import pdb
            pdb.runcall (test_func, *args)
        else:
            if show:
                print (f'running {name}')
            test_func (*args)

In [75]:
# Shared runner for this notebook. With the defaults (load=True, save=True) its state
# is restored from, and written back to, config_test/test_names.pk.
tst = TestRunner ()

In [78]:
# export test.utils.test_nbdev_utils
def example_people_data():
    """Sample data provider for `my_first_test`: always the integer 5."""
    people_count = 5
    return people_count

def myf (x):
    """Return `x * 2`: doubles numbers, repeats strings and lists twice."""
    doubled = x * 2
    return doubled

def my_first_test (example_people_data):
    """Sample passing test: doubling the supplied data with `myf` must give 10."""
    print ('first passes')
    doubled = myf (example_people_data)
    assert doubled == 10

def second_fails ():
    """Sample test that always fails with AssertionError."""
    print ('second fails')
    # Raise explicitly rather than `assert False`: assert statements are stripped
    # under `python -O`, which would make this "always failing" test pass.
    raise AssertionError ('second_fails always fails')
    
def third_fails ():
    """Sample test that always fails with AssertionError."""
    print ('third fails')
    # Raise explicitly rather than `assert False`: assert statements are stripped
    # under `python -O`, which would make this "always failing" test pass.
    raise AssertionError ('third_fails always fails')
    
def test_test_runner ():
    """Single test: registration, forced run, and persistence of the runner state."""
    import tempfile
    with tempfile.TemporaryDirectory () as tmp_dir:
        # Isolated config file, so the test neither depends on the current directory
        # nor overwrites the notebook's shared config_test/test_names.pk.
        path_config = os.path.join (tmp_dir, 'config_test', 'test_names.pk')
        tst_ = TestRunner (do_test=None, all_tests=None, load=False, path_config=path_config)
        tst_.run (my_first_test, example_people_data, True)
        assert tst_.all_tests == ['my_first_test']
        assert os.listdir (os.path.dirname (path_config)) == ['test_names.pk']

        do_test_, all_tests_, tags_, targets_, remote_targets_, localhostname_ = joblib.load (path_config)
        assert all_tests_ == ['my_first_test']
        assert remote_targets_ == ['dummy']
        assert tags_ == {}
    
def test_test_runner_two_tests ():
    """Tagged tests: only targeted tags run, and the registries are persisted and reloaded."""
    import tempfile
    with tempfile.TemporaryDirectory () as tmp_dir:
        # Isolated config file, so the test neither reads nor overwrites the
        # notebook's shared config_test/test_names.pk.
        path_config = os.path.join (tmp_dir, 'config_test', 'test_names.pk')
        tst_ = TestRunner (do_test=None, all_tests=None, targets='dummy', load=False,
                           path_config=path_config)
        assert tst_.do_test==[]
        assert tst_.all_tests==[]
        tst_.run (my_first_test, example_people_data, tag='dummy')
        # 'slow' is not an active target, so this failing test must not execute.
        tst_.run (second_fails, tag='slow')
        with pytest.raises (AssertionError):
            tst_.run (third_fails, tag='dummy')

        assert tst_.all_tests == ['my_first_test', 'second_fails', 'third_fails']
        assert tst_.tags == {'dummy': ['my_first_test', 'third_fails'], 'slow': ['second_fails']}
        assert tst_.targets==['dummy']
        assert tst_.do_test==[]

        do_test_, all_tests_, tags_, targets_, remote_targets_, localhostname_ = joblib.load (path_config)

        assert all_tests_ == ['my_first_test', 'second_fails', 'third_fails']
        assert tags_ == {'dummy': ['my_first_test', 'third_fails'], 'slow': ['second_fails']}
        assert targets_==['dummy']
        assert do_test_==[]

        # load=True restores the registries from the isolated config file ...
        tst_ = TestRunner (do_test=None, all_tests=None, load=True, path_config=path_config)
        assert tst_.all_tests == ['my_first_test', 'second_fails', 'third_fails']

        # ... while load=False starts from empty registries.
        tst_ = TestRunner (do_test=None, all_tests=None, load=False, path_config=path_config)
        assert tst_.all_tests == []
    
def test_test_runner_two_targets ():
    """Several local targets: tests tagged with any of them run, other tags are skipped."""
    import socket
    # Pin localhostname to this machine so `targets` (not `remote_targets`) are the
    # active tags wherever the test runs; save=False leaves the shared config untouched.
    tst_ = TestRunner (targets=['dummy','slow'], load=False, save=False,
                       localhostname=socket.gethostname ())
    assert not tst_.is_remote
    tst_.run (my_first_test, example_people_data, tag='slow')
    # 'other' is not a target, so this failing test must not execute.
    tst_.run (second_fails, tag='other')
    with pytest.raises (AssertionError):
        tst_.run (third_fails, tag='dummy')

In [None]:
# Run the TestRunner self-tests through the shared runner; do=True forces each one
# to execute regardless of its tags or the configured targets.
tst.run (test_test_runner, do=True)
tst.run (test_test_runner_two_tests, do=True)
tst.run (test_test_runner_two_targets, do=True)