# Utils

> Define utility functions for `relax`.

In [None]:
#| default_exp legacy.utils

In [None]:
#| hide
%load_ext autoreload
%autoreload 2
from ipynb_path import *
import warnings

warnings.simplefilter(action='ignore', category=FutureWarning)

In [None]:
#| export
from __future__ import annotations
from relax.legacy.import_essentials import *
import nbdev
from fastcore.basics import AttrDict
from nbdev.showdoc import BasicMarkdownRenderer
from inspect import isclass
from fastcore.test import *
from jax.core import InconclusiveDimensionOperation

Using JAX backend.


## Configurations

In [None]:
#| export
def validate_configs(
    configs: dict | BaseParser,  # A configuration of the model/dataset.
    config_cls: BaseParser,  # The desired configuration class.
) -> BaseParser:
    """Coerce `configs` into an instance of `config_cls` and return it.

    A `dict` is expanded into keyword arguments of `config_cls`; an existing
    `config_cls` instance is returned unchanged. Anything else raises `TypeError`.
    """

    # `config_cls` itself must be a class deriving from `BaseParser`.
    assert isclass(config_cls), f"`config_cls` should be a class."
    assert issubclass(config_cls, BaseParser), \
        f"{config_cls} should be a subclass of `BaseParser`."

    # Raw dictionaries are parsed; every other input is checked as-is.
    parsed = config_cls(**configs) if isinstance(configs, dict) else configs
    if isinstance(parsed, config_cls):
        return parsed
    raise TypeError(
        f"configs should be either a `dict` or an instance of {config_cls.__name__}.")

We define a configuration object (which inherits from `BaseParser`) 
to manage training/model/data configurations.
`validate_configs` ensures that the designated configuration object is returned.

For example, we define a configuration object `LearningConfigs`:

In [None]:
# Demo configuration class with a single `lr` field, used by the
# `validate_configs` examples that follow.
class LearningConfigs(BaseParser):
    lr: float

A configuration can be a `LearningConfigs` instance, or the raw data in a dictionary.

In [None]:
# Raw (unparsed) configuration given as a plain dictionary.
configs_dict = {"lr": 0.01}

`validate_configs` will return a designated configuration object.

In [None]:
# Parse the raw dictionary into the designated configuration class.
configs = validate_configs(configs_dict, LearningConfigs)
# The result is exactly a `LearningConfigs` and carries the same `lr` value.
assert type(configs) == LearningConfigs
assert configs.lr == configs_dict['lr']

In [None]:
#| include: false
from relax.legacy.module import PredictiveTrainingModuleConfigs
# from relax.methods.counternet import CounterNetTrainingModuleConfigs

In [None]:
#| hide
# Check 1: a raw dictionary is parsed into `PredictiveTrainingModuleConfigs`.
configs = {
    'lr': 0.1,
    'sizes': [10, 5],
    'lambda_1': 1.,
    'lambda_2': 1.,
    'lambda_3': 1.,
}
p_config = validate_configs(configs, PredictiveTrainingModuleConfigs)
# cf_config = validate_configs(configs, CounterNetTrainingModuleConfigs)

assert isinstance(p_config, PredictiveTrainingModuleConfigs)
# assert isinstance(cf_config, CounterNetTrainingModuleConfigs)

assert not isinstance(p_config, dict)
# assert not isinstance(cf_config, dict)

# Check 2: an already-parsed configuration object is accepted as well.
p_config = validate_configs(p_config, PredictiveTrainingModuleConfigs)
# cf_config = validate_configs(cf_config, CounterNetTrainingModuleConfigs)

assert isinstance(p_config, PredictiveTrainingModuleConfigs)
# assert isinstance(cf_config, CounterNetTrainingModuleConfigs)


## Categorical normalization

In [None]:
#| export
def cat_normalize(
    cf: jnp.ndarray,  # Unnormalized counterfactual explanations `[n_samples, n_features]`
    cat_arrays: List[List[str]],  # A list of a list of each categorical feature name
    cat_idx: int,  # Index that starts categorical features
    hard: bool = False,  # If `True`, return one-hot vectors; If `False`, return probability normalized via softmax
) -> jnp.ndarray:
    """Ensure generated counterfactual explanations to respect one-hot encoding constraints.

    Columns before `cat_idx` are passed through unchanged. The remaining columns are
    consumed group by group, each group being as wide as the matching entry of
    `cat_arrays`, and every group is replaced by either its arg-max one-hot vector
    (`hard=True`) or its softmax (`hard=False`). The pieces are concatenated back
    along the last axis in their original order.
    """
    # Continuous features are kept as they are.
    normalized_cf = [cf[:, :cat_idx]]

    start = cat_idx  # walk over the categorical groups without rebinding the argument
    for col in cat_arrays:
        n_classes = len(col)
        end = start + n_classes

        cf_cat = lax.cond(
            hard,
            # Both branches of `lax.cond` must return the same dtype: the one-hot
            # output is pinned to the input dtype so it matches the softmax branch
            # for any float dtype, not only the default one.
            lambda x: jax.nn.one_hot(jnp.argmax(x, axis=-1), n_classes, dtype=x.dtype),
            lambda x: jax.nn.softmax(x, axis=-1),
            cf[:, start:end],
        )

        normalized_cf.append(cf_cat)
        start = end
    return jnp.concatenate(normalized_cf, axis=-1)


A tabular data point is encoded as 
$$x = [\underbrace{x_{0}, x_{1}, ..., x_{m}}_{\text{cont features}}, 
\underbrace{x_{m+1}^{c=1},..., x_{m+p}^{c=1}}_{\text{cat feature (1)}}, ..., 
\underbrace{x_{k-q}^{c=i},..., x_{k}^{c=i}}_{\text{cat feature (i)}}]$$

`cat_normalize` ensures that the generated `cf` satisfies the categorical constraints, 
i.e., $\sum_j x^{c=i}_j=1, x^{c=i}_j \geq 0, \forall c=[1, ..., i]$.

`cat_idx` is the index of the first categorical feature. 
In the above example, `cat_idx` is `m+1`.

For example, let's define a valid input data point:

In [None]:
# Toy dataset: two numeric columns followed by two categorical columns.
# The numeric columns are cast back with `.astype(float)` before use below.
x = np.array([
    [1., .9, 'dog', 'gray'],
    [.3, .3, 'cat', 'gray'],
    [.7, .1, 'fish', 'red'],
    [1., .6, 'dog', 'gray'],
    [.1, .2, 'fish', 'yellow']
])

We encode the categorical features via the `OneHotEncoder` in sklearn.

In [None]:
from sklearn.preprocessing import OneHotEncoder

In [None]:
# The first two columns are continuous; categorical columns start at index 2.
cat_idx = 2
ohe = OneHotEncoder(sparse_output=False)
# One-hot encode the categorical columns and cast the continuous ones to float.
x_cat = ohe.fit_transform(x[:, cat_idx:])
x_cont = x[:, :cat_idx].astype(float)
# Final layout: [continuous features | one-hot categorical features].
x_transformed = np.concatenate(
    (x_cont, x_cat), axis=1
)

If `hard=True`, the categorical features are in one-hot format.

In [None]:
# Random (unnormalized) counterfactuals with the same shape as the encoded data.
cfs = np.random.randn(*x_transformed.shape)
# `hard=True`: every categorical group becomes a one-hot vector.
cfs = cat_normalize(cfs, ohe.categories_, 
    cat_idx=cat_idx, hard=True)
cfs[:1]

No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


Array([[-0.47835127, -0.32345298,  1.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  1.        ]], dtype=float32)

If `hard=False`, the categorical features are normalized via softmax function.

In [None]:
# Random (unnormalized) counterfactuals with the same shape as the encoded data.
cfs = np.random.randn(*x_transformed.shape)
# `hard=False`: every categorical group is normalized with softmax.
cfs = cat_normalize(cfs, ohe.categories_, 
    cat_idx=cat_idx, hard=False)
n_cat_feats = len(ohe.categories_)

# Each softmax group sums to 1, so the categorical part of every row must sum to
# the number of categorical features. The comparison is two-sided: a signed
# difference checked against a small positive bound would also accept sums that
# are far too small.
assert jnp.allclose(cfs[:, cat_idx:].sum(axis=1), n_cat_feats, atol=1e-6)

## Training Utils

In [None]:
#| export
def make_model(
    m_configs: Dict[str, Any],  # model configs
    model: hk.Module,  # haiku module, built as `model(m_configs)`
) -> hk.Transformed:
    """Wrap `model(m_configs)` into a forward function and pass it to `hk.transform`.

    Example:
        net = make_model(m_configs, PredictiveModel)
        params = net.init(...)
    """
    def forward(x, is_training: bool = True):
        # The module is constructed inside the function handed to `hk.transform`.
        net = model(m_configs)
        return net(x, is_training)

    return hk.transform(forward)


In [None]:
#| export
def make_hk_module(
    module: hk.Module, # haiku module 
    *args, # haiku module arguments
    **kargs, # haiku module arguments
) -> hk.Transformed:
    """Wrap `module(*args, **kargs)` into a forward function and pass it to `hk.transform`."""

    def forward(x, is_training: bool = True):
        # The module is constructed inside the function handed to `hk.transform`.
        net = module(*args, **kargs)
        return net(x, is_training)

    return hk.transform(forward)


In [None]:
#| export
def init_net_opt(
    net: hk.Transformed,
    opt: optax.GradientTransformation,
    X: jax.Array,
    key: random.PRNGKey,
) -> Tuple[hk.Params, optax.OptState]:
    """Initialize the network parameters and the matching optimizer state.

    `X` is passed through `device_put` and used, together with `key`, to call
    `net.init` in training mode; the optimizer state is created from the
    resulting parameters.
    """
    params = net.init(key, device_put(X), is_training=True)
    return params, opt.init(params)


In [None]:
#| export 
def grad_update(
    grads: Dict[str, jnp.ndarray],
    params: hk.Params,
    opt_state: optax.OptState,
    opt: optax.GradientTransformation,
) -> Tuple[hk.Params, optax.OptState]:
    """Apply one optimizer step and return the updated parameters and optimizer state."""
    # `opt.update` yields the parameter updates plus the next optimizer state.
    updates, next_opt_state = opt.update(grads, opt_state, params)
    return optax.apply_updates(params, updates), next_opt_state


In [None]:
#| export
def check_cat_info(method):
    """Decorator that warns when a module's categorical information looks unset.

    The wrapped callable receives the module as its first argument. When the
    module has `cat_idx == 0` and `cat_arrays == []`, a `RuntimeWarning` is
    emitted before `method` is called. The return value of `method` is passed
    through unchanged.
    """
    from functools import wraps  # local import keeps this block self-contained

    @wraps(method)  # keep the wrapped method's `__name__` / `__doc__`
    def inner(cf_module, *args, **kwargs):
        warning_msg = f"""This CFExplanationModule might not be updated with categorical information.
You should try `{cf_module.name}.update_cat_info(dm)` before generating cfs.
        """
        if cf_module.cat_idx == 0 and cf_module.cat_arrays == []:
            # stacklevel=2 attributes the warning to the caller of the method
            # rather than to this wrapper.
            warnings.warn(warning_msg, RuntimeWarning, stacklevel=2)
        return method(cf_module, *args, **kwargs)

    return inner


## Helper functions

In [None]:
#| export
def load_json(
    f_name: str,  # file name
) -> Dict[str, Any]:
    """Read the JSON file `f_name` and return its parsed content."""
    # Decode as UTF-8 explicitly instead of relying on the platform's default
    # text encoding, which differs between systems.
    with open(f_name, encoding="utf-8") as f:
        return json.load(f)


## Loss Functions

In [None]:
#| export
def binary_cross_entropy(
    preds: jax.Array, # The predicted values
    labels: jax.Array # The ground-truth labels
) -> jax.Array: # Loss value
    """Element-wise binary cross-entropy between `preds` and `labels` (no reduction)."""

    # Keep predictions strictly inside (0, 1) so neither logarithm sees 0.
    eps = 1e-7
    clipped = jnp.clip(preds, eps, 1 - eps)

    # Contribution of the positive class and of the negative class.
    positive_term = labels * jnp.log(clipped)
    negative_term = (1 - labels) * jnp.log(1 - clipped)

    return -positive_term - negative_term

In [None]:
#| export
def sigmoid(x):
    """Sigmoid of `x`, written in terms of `tanh`."""
    # Formulation taken from https://stackoverflow.com/a/68293931
    half_x = x / 2
    shifted = jnp.tanh(half_x) + 1
    return 0.5 * shifted

## Metrics

In [None]:
#| export
def accuracy(y_true: jnp.ndarray, y_pred: jnp.ndarray) -> jax.Array:
    """Fraction of entries whose `jnp.round`-ed label and prediction are equal."""
    rounded_true = jnp.round(y_true)
    rounded_pred = jnp.round(y_pred)
    matches = jnp.equal(rounded_true, rounded_pred)
    return jnp.mean(matches)


def dist(x: jnp.ndarray, cf: jnp.ndarray, ord: int = 2) -> jax.Array:
    """Mean over the leading axis of the `ord`-norm of `x - cf` taken along the last axis."""
    # Norm along the last axis; `keepdims=True` keeps a trailing axis of size 1.
    row_norms = jnp.linalg.norm(x - cf, ord=ord, axis=-1, keepdims=True)
    # Reduce each leading-axis entry to a scalar, then average the scalars.
    per_sample = vmap(jnp.sum)(row_norms)
    return jnp.mean(per_sample)


def proximity(x: jnp.ndarray, cf: jnp.ndarray) -> jax.Array:
    """Return `dist(x, cf)` evaluated with the L1 norm (`ord=1`)."""
    l1_distance = dist(x, cf, ord=1)
    return l1_distance

In [None]:
#| include: false
# Row-wise L1 distances between `m` and `n` are 4 and 3, so their mean is 3.5.
m = jnp.array([
    [1., 2., 3., 1.],
    [1., -1., 4., 1.],
])
n = jnp.array([
    [0., -1., 3., 1.],
    [1., 2., 4., 1.],
])
assert proximity(m, n).item() == 3.5

## Config

In [None]:
#| exporti
@dataclass
class Config:
    """Container for the two module-level settings below."""
    rng_reserve_size: int
    global_seed: int

    @classmethod
    def default(cls) -> Config:
        """Build a `Config` holding the built-in default values."""
        defaults = {"rng_reserve_size": 1, "global_seed": 42}
        return cls(**defaults)

# Module-level configuration instance, created once with the default values.
main_config = Config.default()

In [None]:
#| export
def get_config() -> Config: 
    """Return the module-level `main_config` object."""
    current = main_config
    return current