In [1]:
# hide
# default_exp utils.utils
from nbdev.showdoc import *
from block_types.utils.nbdev_utils import nbdev_setup, TestRunner

nbdev_setup ()
tst = TestRunner (targets=['dummy'])

# Utils

In [5]:
# export
import sys
import os
import random as python_random
import logging
import shutil
from pathlib import Path
import re
import inspect 
import numpy as np

# block-types
import block_types.config.bt_defaults as dflt

In [6]:
#for tests
import pytest
import numpy as np

## make_reproducible

In [7]:
# export
def make_reproducible ():
    """
    Make results obtained from neural network model reproducible. 
    
    This function should be run at the very beginning. The result 
    of calling this is that the pipeline produces the exact same 
    results as previous runs.
    """
    os.environ['CUDA_VISIBLE_DEVICES'] = ''
    os.environ['PYTHONHASHSEED'] = '0'

    # The below is necessary for starting Numpy generated random numbers
    # in a well-defined initial state.
    np.random.seed(123)

    # The below is necessary for starting core Python generated random numbers
    # in a well-defined state.
    python_random.seed(123)

    # The below set_seed() will make random number generation
    # in the TensorFlow backend have a well-defined initial state.
    # For further details, see:
    # https://www.tensorflow.org/api_docs/python/tf/random/set_seed
    try:
        import tensorflow as tf
        tf.random.set_seed(1234)
    except:
        print ('tensorflow needs to be installed in order to run make_reproducible()')

### Usage example

In [8]:
# exports tests.utils.test_utils
#@pytest.mark.reference_fails
def test_make_reproducible ():
    make_reproducible ()
    x = np.random.rand(10)
    make_reproducible ()
    x2 = np.random.rand(10)
    assert (x==x2).all()

In [9]:
tst.run (test_make_reproducible, tag='dummy')

running test_make_reproducible


2022-02-02 11:37:42.793921: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /opt/intel/compilers_and_libraries_2018.1.163/linux/tbb/lib/intel64_lin/gcc4.7:/opt/intel/compilers_and_libraries_2018.1.163/linux/compiler/lib/intel64_lin:/opt/intel/compilers_and_libraries_2018.1.163/linux/mkl/lib/intel64_lin:/opt/intel/compilers_and_libraries_2018.1.163/linux/tbb/lib/intel64_lin/gcc4.7:/opt/intel/compilers_and_libraries_2018.1.163/linux/compiler/lib/intel64_lin:/opt/intel/compilers_and_libraries_2018.1.163/linux/mkl/lib/intel64_lin::/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64/:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64/
2022-02-02 11:37:42.793961: I tensorflow/stream_executor/cuda/cudart_stub.cc:29] Ignore above cudart dlerror if you do not have a GPU set up on your machine.


## set_logger

In [10]:
# TODO: use a custom class instead of separate functions for dealing with logging

### get_logging_level

In [11]:
# export 
def get_logging_level (verbose):
    return logging.DEBUG if verbose == 2 else logging.INFO if verbose == 1 else logging.WARNING

### delete_logger

In [12]:
#export
def delete_logger (name, path_results='log', filename='logs.txt'):
    if filename is not None and path_results is not None:
        path_to_log_file = f'{path_results}/{filename}'
        if os.path.exists (path_to_log_file):
            os.remove (path_to_log_file)

### set_logger

In [18]:
#export    
def set_logger (name, path_results='log', stdout=True,
                mode='a', just_message = False, filename='logs.txt',
                logging_level=logging.DEBUG, verbose=None, verbose_out=None, 
                print_path=False):
    """Set logger."""
    logger = logging.getLogger(name)
    if verbose is not None:
        logging_level = get_logging_level (verbose)
    if verbose_out is not None:
        logging_level_out = get_logging_level (verbose_out)
    else:
        logging_level_out = logging_level
    logger.setLevel(logging_level)

    for hdlr in logger.handlers[:]:  # remove all old handlers
        logger.removeHandler(hdlr)

    #if not logger.hasHandlers():

    # Create handlers
    if stdout:
        c_handler = logging.StreamHandler()
        c_handler.setLevel(logging_level_out)
        c_format = logging.Formatter('%(message)s')
        c_handler.setFormatter(c_format)
        logger.addHandler(c_handler)
    else:
        logger.removeHandler(sys.stderr)

    if filename is not None and path_results is not None:
        os.makedirs(path_results, exist_ok=True)
        path_to_log_file = f'{path_results}/{filename}'
        #pdb.set_trace()
        if print_path: print (f'log written in {os.path.abspath(path_to_log_file)}')
        f_handler = logging.FileHandler (path_to_log_file, mode = mode)
        f_handler.setLevel(logging_level)
        if just_message:
            f_format = logging.Formatter('%(asctime)s - %(message)s')
        else:
            f_format = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s {%(filename)s:%(funcName)s:%(lineno)d} - %(message)s')
        f_handler.setFormatter(f_format)
        logger.addHandler(f_handler)
    #logger.propagate = 0
    logger.propagate = False

    return logger

### set_empty_logger

In [14]:
#export
def set_empty_logger ():
    return set_logger ('no_logging', stdout=False, filename=None, verbose=0)

### set_verbosity

In [15]:
# export
def set_verbosity (name=None, logger=None, logging_level=logging.DEBUG, verbose=None, verbose_out=None):
    """Set logger."""
    if logger is None:
        assert name is not None, 'either logger or name must be not None'
        logger = logging.getLogger(name)
    if verbose is not None:
        logging_level = get_logging_level (verbose)
    if verbose_out is not None:
        logging_level_out = get_logging_level (verbose_out)
    else:
        logging_level_out = logging_level
    logger.setLevel(logging_level)

    for hdlr in logger.handlers[:]:  # remove all old handlers
        hdlr.setLevel(logging_level)

### Usage example

In [16]:
# exports tests.utils.test_utils
#@pytest.mark.reference_fails
def test_set_logger ():
    path_results = 'test_logger'
    logger = set_logger ('test', path_results=path_results)
    assert os.listdir ('test_logger')==['logs.txt']
    assert logger.level==logging.DEBUG

    logger = set_logger ('test', path_results=path_results, verbose=1)
    assert logger.level==logging.INFO

    set_verbosity (logger=logger, verbose=0)
    assert logger.level==logging.WARNING

    for hdlr in logger.handlers[:]:  
        assert hdlr.level==logging.WARNING
        
    delete_logger ('test', path_results=path_results)
    assert os.listdir ('test_logger')==[]
    
    logger = set_empty_logger ()
    logger.critical ('this does not show up')
    assert os.listdir ('test_logger')==[]
        
    shutil.rmtree (path_results)

In [20]:
tst.run (test_set_logger, tag='dummy')

> [0;32m/tmp/ipykernel_63467/4030641049.py[0m(4)[0;36mtest_set_logger[0;34m()[0m
[0;32m      2 [0;31m[0;31m#@pytest.mark.reference_fails[0m[0;34m[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m      3 [0;31m[0;32mdef[0m [0mtest_set_logger[0m [0;34m([0m[0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m----> 4 [0;31m    [0mpath_results[0m [0;34m=[0m [0;34m'test_logger'[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m      5 [0;31m    [0mlogger[0m [0;34m=[0m [0mset_logger[0m [0;34m([0m[0;34m'test'[0m[0;34m,[0m [0mpath_results[0m[0;34m=[0m[0mpath_results[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m      6 [0;31m    [0;32massert[0m [0mos[0m[0;34m.[0m[0mlistdir[0m [0;34m([0m[0;34m'test_logger'[0m[0;34m)[0m[0;34m==[0m[0;34m[[0m[0;34m'logs.txt'[0m[0;34m][0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  b set_empty_logger


Breakpoint 3 at /tmp/ipykernel_63467/500908706.py:2


ipdb>  c


> [0;32m/tmp/ipykernel_63467/500908706.py[0m(3)[0;36mset_empty_logger[0;34m()[0m
[0;32m      1 [0;31m[0;31m#export[0m[0;34m[0m[0;34m[0m[0;34m[0m[0m
[0m[1;31m3[0;32m     2 [0;31m[0;32mdef[0m [0mset_empty_logger[0m [0;34m([0m[0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m----> 3 [0;31m    [0;32mreturn[0m [0mset_logger[0m [0;34m([0m[0;34m'no_logging'[0m[0;34m,[0m [0mstdout[0m[0;34m=[0m[0;32mFalse[0m[0;34m,[0m [0mfilename[0m[0;34m=[0m[0;32mNone[0m[0;34m,[0m [0mverbose[0m[0;34m=[0m[0;36m0[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  s


--Call--
> [0;32m/tmp/ipykernel_63467/3883365675.py[0m(2)[0;36mset_logger[0;34m()[0m
[0;32m      1 [0;31m[0;31m#export[0m[0;34m[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m----> 2 [0;31mdef set_logger (name, path_results='log', stdout=True,
[0m[0;32m      3 [0;31m                [0mmode[0m[0;34m=[0m[0;34m'a'[0m[0;34m,[0m [0mjust_message[0m [0;34m=[0m [0;32mFalse[0m[0;34m,[0m [0mfilename[0m[0;34m=[0m[0;34m'logs.txt'[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m      4 [0;31m                [0mlogging_level[0m[0;34m=[0m[0mlogging[0m[0;34m.[0m[0mDEBUG[0m[0;34m,[0m [0mverbose[0m[0;34m=[0m[0;32mNone[0m[0;34m,[0m [0mverbose_out[0m[0;34m=[0m[0;32mNone[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m      5 [0;31m                print_path=False):
[0m


ipdb>  ll


[0;32m----> 2 [0;31mdef set_logger (name, path_results='log', stdout=True,
[0m[1;32m      3 [0m                [0mmode[0m[0;34m=[0m[0;34m'a'[0m[0;34m,[0m [0mjust_message[0m [0;34m=[0m [0;32mFalse[0m[0;34m,[0m [0mfilename[0m[0;34m=[0m[0;34m'logs.txt'[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[1;32m      4 [0m                [0mlogging_level[0m[0;34m=[0m[0mlogging[0m[0;34m.[0m[0mDEBUG[0m[0;34m,[0m [0mverbose[0m[0;34m=[0m[0;32mNone[0m[0;34m,[0m [0mverbose_out[0m[0;34m=[0m[0;32mNone[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[1;32m      5 [0m                print_path=False):
[1;32m      6 [0m    [0;34m"""Set logger."""[0m[0;34m[0m[0;34m[0m[0m
[1;32m      7 [0m    [0mlogger[0m [0;34m=[0m [0mlogging[0m[0;34m.[0m[0mgetLogger[0m[0;34m([0m[0mname[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[1;32m      8 [0m    [0;32mif[0m [0mverbose[0m [0;32mis[0m [0;32mnot[0m [0;32mNone[0m[0;34m:[0m[0;34m[0m[0;34m[0m

ipdb>  b 22


Breakpoint 4 at /tmp/ipykernel_63467/3883365675.py:22


ipdb>  c


> [0;32m/tmp/ipykernel_63467/3883365675.py[0m(22)[0;36mset_logger[0;34m()[0m
[0;32m     20 [0;31m[0;34m[0m[0m
[0m[0;32m     21 [0;31m    [0;31m# Create handlers[0m[0;34m[0m[0;34m[0m[0;34m[0m[0m
[0m[1;31m4[0;32m--> 22 [0;31m    [0;32mif[0m [0mstdout[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     23 [0;31m        [0mc_handler[0m [0;34m=[0m [0mlogging[0m[0;34m.[0m[0mStreamHandler[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     24 [0;31m        [0mc_handler[0m[0;34m.[0m[0msetLevel[0m[0;34m([0m[0mlogging_level_out[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  logger.handlers


[]


ipdb>  n


> [0;32m/tmp/ipykernel_63467/3883365675.py[0m(29)[0;36mset_logger[0;34m()[0m
[0;32m     27 [0;31m        [0mlogger[0m[0;34m.[0m[0maddHandler[0m[0;34m([0m[0mc_handler[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     28 [0;31m[0;34m[0m[0m
[0m[0;32m---> 29 [0;31m    [0;32mif[0m [0mfilename[0m [0;32mis[0m [0;32mnot[0m [0;32mNone[0m [0;32mand[0m [0mpath_results[0m [0;32mis[0m [0;32mnot[0m [0;32mNone[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     30 [0;31m        [0mos[0m[0;34m.[0m[0mmakedirs[0m[0;34m([0m[0mpath_results[0m[0;34m,[0m [0mexist_ok[0m[0;34m=[0m[0;32mTrue[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     31 [0;31m        [0mpath_to_log_file[0m [0;34m=[0m [0;34mf'{path_results}/{filename}'[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  


> [0;32m/tmp/ipykernel_63467/3883365675.py[0m(43)[0;36mset_logger[0;34m()[0m
[0;32m     41 [0;31m        [0mlogger[0m[0;34m.[0m[0maddHandler[0m[0;34m([0m[0mf_handler[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     42 [0;31m    [0;31m#logger.propagate = 0[0m[0;34m[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 43 [0;31m    [0mlogger[0m[0;34m.[0m[0mpropagate[0m [0;34m=[0m [0;32mFalse[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     44 [0;31m[0;34m[0m[0m
[0m[0;32m     45 [0;31m    [0;32mreturn[0m [0mlogger[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  n


> [0;32m/tmp/ipykernel_63467/3883365675.py[0m(45)[0;36mset_logger[0;34m()[0m
[0;32m     41 [0;31m        [0mlogger[0m[0;34m.[0m[0maddHandler[0m[0;34m([0m[0mf_handler[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     42 [0;31m    [0;31m#logger.propagate = 0[0m[0;34m[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     43 [0;31m    [0mlogger[0m[0;34m.[0m[0mpropagate[0m [0;34m=[0m [0;32mFalse[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     44 [0;31m[0;34m[0m[0m
[0m[0;32m---> 45 [0;31m    [0;32mreturn[0m [0mlogger[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  logger.handlers


[]


ipdb>  logger.critical('hello')


hello


ipdb>  logger.removeHandler(sys.stderr)
ipdb>  logger.critical('hello')


hello


ipdb>  logger.propagate


False


ipdb>  logger.removeHandler(sys.stdout)
ipdb>  logger.critical('hello')


hello


ipdb>  logger.handlers


[]




hello


ipdb>  logger.info('hello')
ipdb>  quit


## remove_previous_results

In [11]:
# export
def remove_previous_results (path_results=dflt.path_results):
    """Remove folder containing previous results, if exists."""
    if Path(path_results).exists():
        shutil.rmtree(path_results)

## set_tf_loglevel

In [12]:
# export
def set_tf_loglevel(level):
    if level >= logging.FATAL:
        os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
    if level >= logging.ERROR:
        os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
    if level >= logging.WARNING:
        os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1'
    else:
        os.environ['TF_CPP_MIN_LOG_LEVEL'] = '0'
    logging.getLogger('tensorflow').setLevel(level)
    
    try:
        import tensorflow as tf
    except:
        print ('tensorflow needs to be installed in order to call set_tf_loglevel()')
        
    tf.get_logger().setLevel(level)

### Usage example

In [13]:
# exports tests.utils.test_utils
#@pytest.mark.reference_fails
def test_set_tf_loglevel ():
    set_tf_loglevel (logging.DEBUG)
    assert logging.getLogger('tensorflow').getEffectiveLevel()==logging.DEBUG

In [14]:
tst.run (test_set_tf_loglevel, tag='dummy', debug=False)

running test_set_tf_loglevel


## store_attrs

### _store_attr

In [15]:
#export
def argnames(f, frame=False):
    "Names of arguments to function or frame `f`"
    code = getattr(f, 'f_code' if frame else '__code__')
    return code.co_varnames[:code.co_argcount+code.co_kwonlyargcount]

In [16]:
#export
def _store_attr(self, overwrite=False, error_if_present=False, ignore=set(), **attrs):
    stored = getattr(self, '__stored_args__', None)
    for n,v in attrs.items():
        if hasattr(self, n) and not overwrite:
            if (error_if_present and getattr(self, n) is not v and n not in ignore 
                and not callable(getattr(self, n))):
                raise RuntimeError (f'field {n} already present in {self}')
            continue
        setattr(self, n, v)
        if stored is not None: stored[n] = v

### obtain_class_specific_attrs

In [17]:
#export
def get_specific_dict_param (self, **kwargs):
    if (hasattr(self, 'name') and 
        kwargs.get(self.name) is not None and
        isinstance(kwargs[self.name], dict)):
        k = self.name
    elif (hasattr(self, 'class_name') and 
        kwargs.get(self.class_name) is not None and
        isinstance(kwargs[self.class_name], dict)):
        k = self.class_name
    elif (hasattr(self, 'group') and 
        kwargs.get(self.group) is not None and
        isinstance(kwargs[self.group], dict)):
        k = self.group
    elif (hasattr(self, 'hierarchy_level') and 
        kwargs.get('levels') is not None and
        isinstance(kwargs['levels'], dict) and
        'until' in kwargs['levels'] and 
        self.hierarchy_level <= kwargs['levels']['until']):
        k = 'levels'
    else:
        k = None
    
    return k

def obtain_class_specific_attrs (self, **kwargs):
    """Overwrites parameters in kwargs with those found in a dictionary of the same name 
    given to this component.

    Checks if there is a parameter whose name is the name of the class or the name given 
    to this component. In that case, it overwrites the parameters in kwargs with those 
    found in that dictionary. The parameters in kwargs can be used as *global* parameters
    for multiple components, while parameters specific of one component can be set using 
    a dictionary with the name of that component. See example below.
    """
    k = get_specific_dict_param (self, **kwargs)
    
    if k is not None:
        config = kwargs[k]
    else:
        config = {}

    return config

### get_hierarchy_level

In [18]:
# export
def get_hierarchy_level (base_class=object):
    stack = inspect.stack()
    hierarchy_level=0
    last_type = None
    for frame_number in range(1, len(stack)):
        fr = sys._getframe(frame_number)
        fr_stack = stack[frame_number]
        if fr is not fr_stack[0]:
            raise RuntimeError ('fr is not fr_stack[0]')
            
        args = argnames(fr, True)
        if len(args) > 0:
            self = fr.f_locals[args[0]]
            if last_type is None:
                last_type = type(self)
            if ((fr_stack.function == '__init__') and 
                isinstance(self, base_class) and 
                (type(self) != last_type) ):
                hierarchy_level += 1
                last_type = type(self)
    return hierarchy_level

#### test get_hierarchy_level

In [19]:
# exports tests.utils.test_utils
#@pytest.mark.reference_fails
def test_get_hierarchy ():
    def f (**kwargs):
        return B(**kwargs)

    class A ():
        def __init__ (self, x=3, **kwargs):
            self.hierarchy_level = get_hierarchy_level(base_class=A)

    class B(A):
        def __init__ (self, y=10, **kwargs):
            super().__init__ (**kwargs)
            self.ab = A (**kwargs)

    class C(B):
        def __init__ (self, z=100, **kwargs):
            super().__init__ (**kwargs)
            self.a = A(**kwargs)
            self.b = f(**kwargs)
    class D(C):
        def __init__ (self, h=100, **kwargs):
            super().__init__ (**kwargs)
            self.c = C(**kwargs)
            self.b = f(**kwargs)
    a = A()
    b = B()
    c = C()
    d = D()

    assert (a.hierarchy_level==0 and b.hierarchy_level==0 and c.hierarchy_level==0
            and c.a.hierarchy_level==1 and c.b.hierarchy_level==1 and c.ab.hierarchy_level==1 
            and c.b.ab.hierarchy_level==2 
            and d.hierarchy_level == 0 and d.a.hierarchy_level == 1 and d.b.hierarchy_level == 1
            and d.ab.hierarchy_level == 1 
            and d.b.hierarchy_level==1 and d.b.ab.hierarchy_level==2 and d.c.b.ab.hierarchy_level==3)

In [20]:
tst.run (test_get_hierarchy, tag='dummy')

running test_get_hierarchy


### replace_attr_and_store

In [21]:
#export
def replace_attr_and_store (names=None, but='', store_args=None, 
                            recursive=True, base_class=object, 
                            replace_generic_attr=True, overwrite=False,
                            error_if_present=False, ignore=set(), overwrite_name=True, 
                            self=None, include_first=False, **attrs):
    """
    Replaces generic attributes and stores them into attrs in `self`.
        
    If kwargs contains an attribute called the same way as the class of
    self, all the keys in that dictionary are considered class-specific
    attributes whose value overwrites any attribute in kwargs of the same
    name.
    
    The function is called recursively in the hierarchy of parent classes, 
    from the leaf to the root class, until it reaches an ascendant that 
    is not an instance of `base_class`. 
    
    Most of the implementation is taken from fastcore library, `store_attrs`
    function.
    """
    frame_number=1
    stack = inspect.stack()
    original_type = None
    input_attrs = attrs
    while True:
        fr = sys._getframe(frame_number)
        fr_stack = stack[frame_number]
        if fr is not fr_stack[0]:
            raise RuntimeError ('fr is not fr_stack[0]')
        
        args = argnames(fr, True)
        if recursive:
            if len(args) > 0:
                self = fr.f_locals[args[0]]
                if not isinstance(self, base_class):
                    break
                if fr_stack.function != '__init__':
                    break
                if original_type is None:
                    original_type = type(self)
                    
                if type(self) != original_type:
                    break
            else:
                break
        else:
            if self is not None:
                if include_first:
                    args = [self] + list(args) 
            elif len(args) > 0:
                self = fr.f_locals[args[0]]
            else:
                raise RuntimeError ('self not found')
        
        if store_args is None: store_args = not hasattr(self,'__slots__')
        if store_args and not hasattr(self, '__stored_args__'): self.__stored_args__ = {}
        if names and isinstance(names,str): names = re.split(', *', names)
        #pdb.set_trace()
        ns = names if names is not None else getattr(self, '__slots__', args[1:])
        added = {n:fr.f_locals[n] for n in ns}
        attrs = {**input_attrs, **added}
        if replace_generic_attr and 'kwargs' in fr.f_locals:
            class_specific_attrs = obtain_class_specific_attrs (self, **fr.f_locals['kwargs'])
            attrs.update(class_specific_attrs)
        else:
            class_specific_attrs={}
        if isinstance(but,str): but = re.split(', *', but)
        attrs = {k:v for k,v in attrs.items() if k not in but}
        _store_attr(self, overwrite=overwrite, error_if_present=error_if_present, 
                    ignore=ignore, **attrs)
        if overwrite_name and ('name' in class_specific_attrs 
                               or 'class_name' in class_specific_attrs):
            new_attrs = {k:class_specific_attrs[k] for k in ['name', 'class_name'] 
                         if k in class_specific_attrs}
            _store_attr(self, overwrite=True, error_if_present=error_if_present, 
                        ignore=ignore, **new_attrs)
        
        if not recursive:
            break
        
        frame_number += 1
        

#### test replace_attr_and_store

In [22]:
ignore=set()
ignore.update ({3,4})
ignore

{3, 4}

In [28]:
# exports tests.utils.test_utils
#@pytest.mark.reference_fails
def test_replace_attr_and_store ():
    def f (**kwargs):
        return B(**kwargs)

    class A ():
        def __init__ (self, x=3, **kwargs):
            replace_attr_and_store (base_class=A)

    class B(A):
        def __init__ (self, y=10, **kwargs):
            super().__init__ (**kwargs)
            self.ab = A (**kwargs)

    class C(A):
        def __init__ (self, z=100, **kwargs):
            super().__init__ (**kwargs)
            self.a = A(**kwargs)
            self.b = f(**kwargs)
    a = A()
    b = B()
    c = C()

    assert a.x==3 and b.y==10 and b.x==3 and c.z==100 and c.x==3 and c.a.x==3 and c.b.y==10 and c.b.x==3

    with pytest.raises (AttributeError):
        print (c.y)

    with pytest.raises (AttributeError):
        print (c.a.y)

    with pytest.raises (AttributeError):
        print (c.a.z)

    with pytest.raises (AttributeError):
        print (b.z)

    with pytest.raises (AttributeError):
        print (b.ab.y)
        
    # **************************************************
    # test changing the argument
    # **************************************************
    def f2 (y=10, **kwargs):
        y = 2*y
        return B2 (y=y, **kwargs)

    class A2 ():
        def __init__ (self, x='hello', **kwargs):
            x=f'{x} world'
            replace_attr_and_store (base_class=A2)

    class B2 (A2):
        def __init__ (self, y=10, **kwargs):
            super().__init__ (**kwargs)
            self.ab = A2 (**kwargs)

    class C2 (A2):
        def __init__ (self, z=100, **kwargs):
            super().__init__ (**kwargs)
            self.a = A2(**kwargs)
            self.b = f2 (**kwargs)
    a = A2()
    b = B2()
    c = C2()
    assert (a.x, b.x, b.ab.x) == ('hello world', 'hello world', 'hello world')
    assert c.b.y == 20

    a = A2 ('hey')
    b = B2 ()
    assert (a.x, b.x, b.ab.x) == ('hey world', 'hello world', 'hello world')

    c = C2 (y=3, x='Hi')
    assert (c.b.y, c.b.x) == (6, 'Hi world')
    
    # **************************************************
    # test error_if_present
    # **************************************************
    
    def f3 (**kwargs):
        return B(**kwargs)

    class A3 ():
        def __init__ (self, x=3, ignore=set(), **kwargs):
            replace_attr_and_store (base_class=A3, error_if_present=True,
                                    ignore=ignore, but='ignore')

    class B3(A3):
        def __init__ (self, y=10, **kwargs):
            super().__init__ (**kwargs)
            self.ab = A3 (**kwargs)
            
    
    a = A3()
    
    b = B3()
    
    b2 = B3(x=5, y=20)
    
    assert a.x==3 and b.y==10 and b.x==3 and b.ab.x==3 and b2.x==5 and b2.y==20 and b2.ab.x==5
    
    # *******************
    # test using same field in B4 and in A3, but
    # B4 passes that value to A3 in super()
    # *****************
    class B4(A3):
        def __init__ (self, x=30, y=10, **kwargs):
            super().__init__ (x=x, **kwargs)
            self.ab = A3 (**kwargs)

    b3 = B4 ()
    assert b3.x==30 and b3.ab.x==3 and b3.y==10
    
    # *******************
    # test using same field in B4 and in A3, but
    # B4 passes that value to A3 in super(),
    # after modifying it
    # *****************
    class B5(A3):
        def __init__ (self, x=30, y=10, **kwargs):
            x = x*2
            super().__init__ (x=x, **kwargs)
            self.ab = A3 (**kwargs)

    b3 = B5 ()
    assert b3.x==60 and b3.ab.x==3 and b3.y==10
    
    b3 = B5 (x=6)
    assert b3.x==12 and b3.ab.x==3 and b3.y==10
    
    # *******************
    # test using same field in D and in A3, but
    # the field is modified in a parent B5
    # *****************
    class D(B5):
        def __init__ (self, x=40, z=100, **kwargs):
            super().__init__ (x=x, **kwargs)
            self.b = B5(**kwargs)
    
    with pytest.raises (RuntimeError):
        d = D ()
        
    d = D(ignore={'x'})
    assert d.x==80 and d.y==10 and d.z==100 and d.b.x==60 and d.b.y==10
    
    d = D (x=9, ignore={'x'})
    assert d.x==18 and d.y==10 and d.z==100 and d.b.x==60 and d.b.y==10
    
    assert not hasattr(d, 'ignore')
    
    # *******************
    # test having a field with same name
    # *******************
    class C3(A3):
        def __init__ (self, x=4, z=100, **kwargs):
            super().__init__ (**kwargs)
            self.a = A3(**kwargs)
            self.b = f3(**kwargs)
    
    with pytest.raises (RuntimeError):
        c = C3()
        
    # **************************************************
    # test overwrite
    # **************************************************
    class A4 ():
        def __init__ (self, x=3, **kwargs):
            replace_attr_and_store (base_class=A4, overwrite=True)
            
    class C5(A4):
        def __init__ (self, x=4, z=100, **kwargs):
            super().__init__ (x=x, **kwargs)
            self.a = A4(**kwargs)
            
    c = C5 ()
    assert c.x == 4 and c.a.x==3

In [29]:
tst.run (test_replace_attr_and_store, tag='dummy', debug=False)

running test_replace_attr_and_store


#### test replace_attr_and_store without recursiveness

In [26]:
# exports tests.utils.test_utils
#@pytest.mark.reference_fails
def test_replace_attr_and_store_no_rec ():
    # test without recursiveness
    from sklearn.utils import Bunch

    def f (x=3, y=4, z=5, **kwargs):
        estimator = Bunch ()
        replace_attr_and_store (recursive=False,
                                self=estimator, include_first=True)
        return estimator

    estimator = f (y=40, z=50)

    stored_args = estimator.pop('__stored_args__')
    assert estimator=={'x': 3, 'y': 40, 'z': 50}
    assert stored_args == estimator

    # test without kwargs
    def g (x=3, y=4, z=5):
        estimator = Bunch ()
        replace_attr_and_store (recursive=False,
                                self=estimator, include_first=True)
        return estimator

    estimator = g (y=40, z=50)

    stored_args = estimator.pop('__stored_args__')
    assert estimator=={'x': 3, 'y': 40, 'z': 50}
    assert stored_args == estimator

    # test without kwargs and with replace_generic_attr=False
    def h (x=3, y=4, z=5):
        estimator = Bunch ()
        replace_attr_and_store (recursive=False,
                                self=estimator, include_first=True, replace_generic_attr=False)
        return estimator

    estimator = h (y=40, z=50)

    stored_args = estimator.pop('__stored_args__')
    assert estimator=={'x': 3, 'y': 40, 'z': 50}
    assert stored_args == estimator

In [27]:
tst.run (test_replace_attr_and_store_no_rec, tag='dummy')

running test_replace_attr_and_store_no_rec
