# Utils

> Define utility functions for `relax`.

In [5]:
#| default_exp utils

In [6]:
#| hide
%load_ext autoreload
%autoreload 2
from ipynb_path import *
import warnings

warnings.simplefilter(action='ignore', category=FutureWarning)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [7]:
#| export
from __future__ import annotations
from relax.import_essentials import *
import nbdev
from fastcore.basics import AttrDict
from nbdev.showdoc import BasicMarkdownRenderer
from inspect import isclass
from fastcore.test import *
from jax.core import InconclusiveDimensionOperation

ModuleNotFoundError: No module named 'matplotlib'

## Configurations

In [8]:
#| export
def validate_configs(
    configs: dict | BaseParser,  # Model/dataset configuration: a raw `dict` or a parsed config object.
    config_cls: BaseParser,  # The configuration class to validate against.
) -> BaseParser:  # An instance of `config_cls`.
    """Coerce `configs` into an instance of `config_cls`.

    A `dict` is expanded as keyword arguments to `config_cls`; any other value
    must already be an instance of `config_cls`, otherwise a `TypeError` is raised.
    `config_cls` itself must be a class deriving from `BaseParser` (checked with `assert`).
    """
    assert isclass(config_cls), f"`config_cls` should be a class."
    assert issubclass(config_cls, BaseParser), \
        f"{config_cls} should be a subclass of `BaseParser`."

    # Build the config object from raw data when needed; otherwise use the value as given.
    parsed = config_cls(**configs) if isinstance(configs, dict) else configs
    if isinstance(parsed, config_cls):
        return parsed
    raise TypeError(
        f"configs should be either a `dict` or an instance of {config_cls.__name__}.")

We define a configuration object (which inherits from `BaseParser`) 
to manage training/model/data configurations.
`validate_configs` ensures that the designated configuration object is returned.

For example, we define a configuration object `LearningConfigs`:

In [9]:
# Minimal example configuration class used to demonstrate `validate_configs`.
class LearningConfigs(BaseParser):
    lr: float  # presumably the learning rate

NameError: name 'BaseParser' is not defined

A configuration can be either a `LearningConfigs` instance or the raw data in a dictionary.

In [10]:
# Raw dictionary form of the configuration; the key matches the `lr` field of `LearningConfigs`.
configs_dict = dict(lr=0.01)

`validate_configs` will return a designated configuration object.

In [11]:
# The raw dictionary is turned into a `LearningConfigs` instance holding the same value.
configs = validate_configs(configs_dict, LearningConfigs)
assert type(configs) == LearningConfigs
assert configs.lr == configs_dict['lr']

NameError: name 'LearningConfigs' is not defined

In [12]:
#| include: false
# TODO: add a test for this
# from relax.module import PredictiveTrainingModuleConfigs
# from relax.methods.counternet import CounterNetTrainingModuleConfigs

In [13]:
#| hide
# TODO: add a test for this
# configs = {
#     'lr': 0.1,
#     'sizes': [10, 5],
#     'lambda_1': 1.,
#     'lambda_2': 1.,
#     'lambda_3': 1.,
# }
# p_config = validate_configs(configs, PredictiveTrainingModuleConfigs)
# cf_config = validate_configs(configs, CounterNetTrainingModuleConfigs)

# assert isinstance(p_config, PredictiveTrainingModuleConfigs)
# assert isinstance(cf_config, CounterNetTrainingModuleConfigs)

# assert not isinstance(p_config, dict)
# assert not isinstance(cf_config, dict)

# p_config = validate_configs(p_config, PredictiveTrainingModuleConfigs)
# cf_config = validate_configs(cf_config, CounterNetTrainingModuleConfigs)

# assert isinstance(p_config, PredictiveTrainingModuleConfigs)
# assert isinstance(cf_config, CounterNetTrainingModuleConfigs)


## Serialization

In [14]:
#| export
def _is_array(x):
    """Return `True` if `x` is a NumPy array, a JAX array, or a Python list."""
    return isinstance(x, (np.ndarray, jnp.ndarray, list))

def save_pytree(pytree, saved_dir):
    """Save a pytree to a directory.

    Two files are written under `saved_dir` (which is created if it does not exist):

    * `data.npy`: every leaf of `pytree`, written one after another with `np.save`.
    * `treedef.json`: the structure of `pytree`, with each leaf replaced by a
      boolean telling whether that leaf is an array (see `_is_array`).
    """
    # Create the target directory up front so saving does not fail on a missing path.
    # An empty `saved_dir` refers to the current directory, which needs no creation.
    if saved_dir:
        os.makedirs(saved_dir, exist_ok=True)

    # All leaves share one file; they are read back sequentially in the same order.
    with open(os.path.join(saved_dir, "data.npy"), "wb") as f:
        for x in jax.tree_util.tree_leaves(pytree):
            np.save(f, x)

    # The boolean mask doubles as a JSON-serializable description of the tree structure.
    tree_struct = jax.tree_util.tree_map(_is_array, pytree)
    with open(os.path.join(saved_dir, "treedef.json"), "w") as f:
        json.dump(tree_struct, f)

The pytree will be stored under a directory with two files: 

* `{saved_dir}/data.npy`: This file stores the flattened leaves.
* `{saved_dir}/treedef.json`: This file stores the pytree structure and, for each leaf, whether it is an array or not. 

For example, a pytree

In [15]:
# Example pytree mixing array leaves ('a', 'f') with a plain int, a bool and a string.
pytree = {
    'a': np.random.randn(5, 1),
    'b': 1,
    'c': {
        
        'd': True,
        'e': "Hello",
        'f': np.array(["a", "b", "c"])
    }
}

NameError: name 'np' is not defined

will be stored as

In [16]:
#| echo: false
# Flattened leaves of the example pytree.
data = jax.tree_util.tree_leaves(pytree)
# Same structure as `pytree`, with every leaf replaced by its "is array" flag.
pytreedef = jax.tree_util.tree_map(_is_array, pytree)
print('data: ', data)
print('treedef: ', pytreedef)

NameError: name 'jax' is not defined

In [17]:
#| export
def load_pytree(saved_dir):
    """Load a pytree from a saved directory.

    Reads `treedef.json` (the tree structure whose leaves are "is array" flags)
    and `data.npy` (the leaves, stored back to back) from `saved_dir`, and
    rebuilds the tree. Leaves flagged as arrays are returned as NumPy arrays;
    all other leaves are converted back to Python scalars with `.item()`.
    """
    with open(os.path.join(saved_dir, "treedef.json"), "r") as fp:
        array_mask = json.load(fp)

    # Flattening the mask yields one flag per stored leaf plus the structure to rebuild.
    is_array_flags, structure = jax.tree_util.tree_flatten(array_mask)

    restored = []
    with open(os.path.join(saved_dir, "data.npy"), "rb") as fp:
        for keep_as_array in is_array_flags:
            # NOTE(review): `allow_pickle=True` can execute arbitrary code when the
            # file is untrusted; only load directories produced by trusted code.
            value = np.load(fp, allow_pickle=True)
            restored.append(value if keep_as_array else value.item())
    return jax.tree_util.tree_unflatten(structure, restored)

In [18]:
# Store a dictionary to disk
pytree = {
    'a': np.random.randn(100, 1),
    'b': 1,
    'c': {
        'd': True,
        'e': "Hello",
        'f': np.array(["a", "b", "c"])
    }
}
# Make sure the target directory exists before saving.
os.makedirs('tmp', exist_ok=True)
save_pytree(pytree, 'tmp')
pytree_loaded = load_pytree('tmp')
# Array leaves must come back with the same values and dtype.
assert np.allclose(pytree['a'], pytree_loaded['a'])
assert pytree['a'].dtype == pytree_loaded['a'].dtype
# Non-array leaves (int, bool, str) must compare equal to the originals.
assert pytree['b'] == pytree_loaded['b']
assert pytree['c']['d'] == pytree_loaded['c']['d']
assert pytree['c']['e'] == pytree_loaded['c']['e']
assert np.all(pytree['c']['f'] == pytree_loaded['c']['f'])

NameError: name 'np' is not defined

In [19]:
# Store a list to disk
pytree = [
    np.random.randn(100, 1),
    {'a': 1, 'b': np.array([1, 2, 3])},
    1,
    [1, 2, 3],
    "good"
]
# Reuses the 'tmp' directory, overwriting the files written by the previous example.
save_pytree(pytree, 'tmp')
pytree_loaded = load_pytree('tmp')

# Arrays keep values and dtype; nested containers and scalars compare equal.
assert np.allclose(pytree[0], pytree_loaded[0])
assert pytree[0].dtype == pytree_loaded[0].dtype
assert pytree[1]['a'] == pytree_loaded[1]['a']
assert np.all(pytree[1]['b'] == pytree_loaded[1]['b'])
assert pytree[2] == pytree_loaded[2]
# A nested Python list is restored as a Python list.
assert pytree[3] == pytree_loaded[3]
assert isinstance(pytree_loaded[3], list)
assert pytree[4] == pytree_loaded[4]

NameError: name 'np' is not defined

In [20]:
#| hide
# Remove the 'tmp' directory used by the serialization examples.
shutil.rmtree('tmp')

NameError: name 'shutil' is not defined

## Vectorization Utils

In [21]:
#| exporti
def _reshape_x(x: Array):
    """Return `x` as a single-row input together with its original shape.

    A 1-D input of shape (k,) is reshaped to (1, k). An input with more than
    one dimension is returned unchanged, but only if its first dimension is 1;
    otherwise a `ValueError` is raised.
    """
    original_shape = x.shape
    ndim = len(original_shape)
    # Anything with a leading dimension other than 1 is a batch, not a single instance.
    if ndim > 1 and original_shape[0] != 1:
        raise ValueError(
            f"""Invalid Input Shape: Require `x.shape` = (1, k) or (k, ),
but got `x.shape` = {x.shape}. This method expects a single input instance."""
        )
    reshaped = x.reshape(1, -1) if ndim == 1 else x
    return reshaped, original_shape

In [22]:
#| export
def auto_reshaping(
    reshape_argname: str, # The name of the argument to be reshaped.
    reshape_output: bool = True, # Whether to reshape the output. Useful to set `False` when returning multiple cfs.
):
    """
    Decorator to automatically reshape a function's input into (1, k),
    and its output back to the input's original shape.

    The decorated function is called with keyword arguments only. The wrapper
    raises `ValueError` when `reshape_argname` is not an argument of the
    function, when the function does not return an `Array`, or (with
    `reshape_output=True`) when the output cannot be reshaped to the input's shape.
    """
    # Imported locally so the decorator does not depend on a module-level import.
    import functools

    def decorator(func):
        # Keep the decorated function's name and docstring visible on the wrapper.
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # Resolve positional and keyword arguments into one name -> value mapping,
            # so the target argument can be found however it was passed.
            kwargs = inspect.getcallargs(func, *args, **kwargs)
            if reshape_argname not in kwargs:
                raise ValueError(
                    f"Invalid argument name: `{reshape_argname}` is not a valid argument name.")
            reshaped_x, x_shape = _reshape_x(kwargs[reshape_argname])
            kwargs[reshape_argname] = reshaped_x
            # Call the function.
            cf = func(**kwargs)
            if not isinstance(cf, Array): 
                raise ValueError(
                    f"Invalid return type: must be a `jax.Array`, but got `{type(cf).__name__}`.")
            if reshape_output:
                try: 
                    cf = cf.reshape(x_shape)
                except (InconclusiveDimensionOperation, TypeError) as e:
                    # Chain the original error so the underlying reshape failure stays visible.
                    raise ValueError(
                        f"Invalid return shape: Require `cf.shape` = {cf.shape} "
                        f"is not compatible with `x.shape` = {x_shape}.") from e
            return cf

        return wrapper
    return decorator

This decorator ensures that the specified input argument and the output 
of a function have the same shape. 
This is particularly useful when using `jax.vmap`.

In [23]:
# Mapping over a (10, 10) batch yields a (10, 10) result when the output
# is reshaped back to the shape of each mapped input.
@auto_reshaping('x')
def f_vmap(x): return x * jnp.ones((10,))
assert vmap(f_vmap)(jnp.ones((10, 10))).shape == (10, 10)

# With `reshape_output=False` the output is not reshaped, giving (10, 1, 10) overall.
@auto_reshaping('x', reshape_output=False)
def f_vmap(x): return x * jnp.ones((10,))
assert vmap(f_vmap)(jnp.ones((10, 10))).shape == (10, 1, 10)

NameError: name 'vmap' is not defined

In [24]:
#| hide
# Both (k,) and (1, k) inputs reach the function with a leading dimension of 1,
# and the output comes back in the caller's original shape.
@auto_reshaping('x')
def f_1(x):
    assert x.shape[0] == 1
    return x

assert f_1(jnp.ones(10)).shape == (10,)
assert f_1(jnp.ones((1, 10))).shape == (1, 10)

# The decorator also works on top of `jit`, and the reshaped argument
# does not have to be the first one.
@auto_reshaping('x')
@jit
def f_2(y, x):
    assert x.shape[0] == 1
    return x

assert f_2(None, jnp.ones(10)).shape == (10,)
assert f_2(None, jnp.ones((1, 10))).shape == (1, 10)

# A (10, 10) batch is rejected: only a single instance is accepted.
@auto_reshaping('x')
def f_3(x, y): return x
test_fail(f_3, args=(jnp.ones((10, 10)), None), 
          contains='Invalid Input Shape: Require `x.shape` = (1, k)')

# An output of 3 elements cannot be reshaped to the (10,) input shape.
@auto_reshaping('x')
def f_4(x, y): return jnp.arange(3)
test_fail(f_4, args=(jnp.ones((10, )), None), 
          contains='Invalid return shape: Require `cf.shape`')

# Returning a tuple instead of a single array is rejected.
@auto_reshaping('x')
def f_5(x, y): return jnp.array([1, 2, 3]), jnp.array([1, 2, 3])
test_fail(f_5, args=(jnp.ones((10, )), None), 
          contains='Invalid return type: must be a `jax.Array`, but got `tuple`.')


NameError: name 'jnp' is not defined

## Gradient Utils

In [25]:
#| export
def grad_update(
    grads, # A pytree of gradients.
    params, # A pytree of parameters.
    opt_state: optax.OptState, # The optimizer state passed to `opt.update`.
    opt: optax.GradientTransformation, # The optimizer whose `update` is applied.
): # Return (upt_params, upt_opt_state)
    """Run `opt.update` on the gradients and apply the resulting updates to `params`.

    Returns the updated parameters together with the optimizer state
    returned by `opt.update`.
    """
    param_updates, next_opt_state = opt.update(grads, opt_state, params)
    return optax.apply_updates(params, param_updates), next_opt_state

## Helper functions

In [26]:
#| export
def load_json(
    f_name: str,  # Path of the JSON file to read.
) -> Dict[str, Any]:  # The parsed content of the file.
    """Load a JSON file and return its parsed content."""
    # Read as UTF-8 explicitly so the result does not depend on the platform's
    # default text encoding.
    with open(f_name, encoding="utf-8") as f:
        return json.load(f)


## Config

In [27]:
#| exporti
@dataclass
class Config:
    """Container for the global settings exposed through `get_config` / `set_config`."""
    rng_reserve_size: int  # Number of random number generators to reserve.
    global_seed: int  # Global seed for random number generators.

    @classmethod
    def default(cls) -> Config:
        """Build a configuration holding the default values."""
        defaults = dict(rng_reserve_size=1, global_seed=42)
        return cls(**defaults)

# Module-level configuration instance, initialised with the defaults.
main_config = Config.default()

NameError: name 'dataclass' is not defined

In [28]:
#| export
def get_config() -> Config: 
    """Return the module-level configuration object `main_config`."""
    return main_config

In [29]:
# | export
def set_config(
        *,
        rng_reserve_size: int = None,
        global_seed: int = None,
        **kwargs
) -> None:
    """
    set_config() sets the global configurations.

    Every provided value is validated before any of them is applied, so a call
    that raises leaves the global configuration unchanged.

    :param rng_reserve_size: set the number of random number generators to reserve.
    :param global_seed: set the global seed for random number generators.
    :param kwargs: Additional keyword arguments. They are accepted but ignored.
    :raises TypeError: if a provided value is not an integer.
    :raises ValueError: if a provided value is negative.
    """

    def arg_check(arg, arg_value, arg_min):
        """
        arg_check() checks the validity of the argument and returns the argument value.
        :param arg: The name of the argument.
        :param arg_value: The value of the argument.
        :param arg_min: The minimum value of the argument.
        :return: The argument value, or None when no value was provided.
        """
        if arg_value is None:
            return None
        if not isinstance(arg_value, int):
            raise TypeError(f"`{arg}` must be an integer, but got {type(arg_value).__name__}.")
        if arg_value < arg_min:
            raise ValueError(f"`{arg}` must be non-negative, but got {arg_value}.")
        return arg_value

    # Validate everything first: an invalid later argument must not leave the
    # configuration half-updated with the earlier one.
    validated = {
        'rng_reserve_size': arg_check('rng_reserve_size', rng_reserve_size, 0),
        'global_seed': arg_check('global_seed', global_seed, 0),
    }

    # Only the settings that were explicitly provided are overwritten.
    for name, value in validated.items():
        if value is not None:
            setattr(main_config, name, value)

In [30]:
# Generic test cases. They assume the configuration still holds its default values.
# Calling without arguments changes nothing.
set_config()
assert get_config().rng_reserve_size == 1 and get_config().global_seed == 42
# Each setting can be updated on its own, or both together.
set_config(rng_reserve_size=100)
assert get_config().rng_reserve_size == 100
set_config(global_seed=1234)
assert get_config().global_seed == 1234
set_config(rng_reserve_size=2, global_seed=234)
assert get_config().rng_reserve_size == 2 and get_config().global_seed == 234
set_config()
assert get_config().rng_reserve_size == 2 and get_config().global_seed == 234
# Unknown keyword arguments leave the configuration unchanged.
set_config(lol = 80)
assert get_config().rng_reserve_size == 2 and get_config().global_seed == 234
# Tests for invalid inputs: negative values and non-integers are rejected.
test_fail(set_config, kwargs={'rng_reserve_size': -1}, contains='must be non-negative')
test_fail(set_config, kwargs={'rng_reserve_size': 22.7}, contains='must be an integer')
test_fail(set_config, kwargs={'global_seed': -4}, contains='must be non-negative')
test_fail(set_config, kwargs={'global_seed': 3.14}, contains='must be an integer')

NameError: name 'main_config' is not defined