# Utils

In [1]:
# default_exp utils

In [2]:
# hide
%load_ext autoreload
%autoreload 2
from ipynb_path import *

In [3]:
# export
from cfnet.import_essentials import *


## Validate configs

In [4]:
# export
def validate_configs(
    configs: Union[Dict[str, Any], BaseParser], # raw config values, or an already-built config object
    config_cls: BaseParser # the config class itself (not an instance) to coerce into
):
    """Return `configs` as an instance of `config_cls`.

    An object that is already an instance of `config_cls` is returned unchanged;
    anything else is unpacked as keyword arguments, i.e. `config_cls(**configs)`.
    """
    if isinstance(configs, config_cls):
        return configs
    return config_cls(**configs)

In [5]:
from cfnet.training_module import PredictiveTrainingModuleConfigs, CounterNetTrainingModuleConfigs

configs = {
    'lr': 0.1,
    'lambda_1': 1.,
    'lambda_2': 1.,
    'lambda_3': 1.,
}
# a plain dict is converted into an instance of the requested config class
p_config = validate_configs(configs, PredictiveTrainingModuleConfigs)
cf_config = validate_configs(configs, CounterNetTrainingModuleConfigs)

assert isinstance(p_config, PredictiveTrainingModuleConfigs)
assert isinstance(cf_config, CounterNetTrainingModuleConfigs)

assert not isinstance(p_config, dict)
assert not isinstance(cf_config, dict)

# an object that is already an instance of the class is passed through untouched
assert validate_configs(p_config, PredictiveTrainingModuleConfigs) is p_config
assert validate_configs(cf_config, CounterNetTrainingModuleConfigs) is cf_config

In [6]:
# export
def load_json(f_name: str) -> Dict[str, Any]:
    """Read and parse the JSON file at path `f_name`.

    The file is decoded explicitly as UTF-8 (the encoding RFC 8259 mandates
    for JSON) instead of the platform's locale default, so files containing
    non-ASCII text load identically on every OS.

    NOTE(review): the return annotation assumes the top-level JSON value is an
    object; `json.load` returns whatever the document holds (list, number, ...).
    """
    with open(f_name, encoding="utf-8") as f:
        return json.load(f)

## Categorical normalization

In [7]:
# export
def cat_normalize(cf, cat_arrays, cat_idx: int, hard: bool=False):
    """Normalize the categorical column groups of `cf`.

    Columns `cf[:, :cat_idx]` are copied through unchanged. After them, each
    entry of `cat_arrays` claims the next `len(entry)` columns as one group.
    Every group is replaced by either
      * `hard=True`:  a one-hot vector at the group's argmax, or
      * `hard=False`: a softmax over the group.
    The pieces are concatenated along the last axis. Columns past the final
    group are not included in the result.

    NOTE(review): `cf` is indexed as `cf[:, ...]`, so it must have at least two
    axes; presumably (batch, features) — confirm with callers. Presumably each
    entry of `cat_arrays` lists the categories of one encoded feature.
    """
    def _normalize_group(block, n_categories):
        # lax.cond keeps `hard` usable as a traced value under jit.
        return lax.cond(
            hard,
            true_fun=lambda x: jax.nn.one_hot(jnp.argmax(x, axis=-1), n_categories),
            false_fun=lambda x: jax.nn.softmax(x, axis=-1),
            operand=block,
        )

    pieces = [cf[:, :cat_idx]]
    start = cat_idx
    for categories in cat_arrays:
        stop = start + len(categories)
        pieces.append(_normalize_group(cf[:, start:stop], len(categories)))
        start = stop
    return jnp.concatenate(pieces, axis=-1)

## Loss functions

In [8]:
# export
def binary_cross_entropy(y_pred: chex.Array, y: chex.Array, eps: float = 1e-7) -> chex.Array:
    """Element-wise binary cross-entropy `-(y*log(p) + (1-y)*log(1-p))`.

    Args:
        y_pred: predicted probabilities `p`, expected to lie in [0, 1].
        y: targets, broadcastable against `y_pred` (typically 0/1 labels).
        eps: `y_pred` is clipped to `[eps, 1 - eps]` before taking logs, so a
            saturated prediction of exactly 0 or 1 gives a large but finite
            loss instead of `inf`/`nan` (which would poison gradients).

    Returns:
        The per-element loss with the broadcast shape of `y_pred` and `y`.
        No reduction is applied.
    """
    y_pred = jnp.clip(y_pred, eps, 1. - eps)
    return -(y * jnp.log(y_pred) + (1 - y) * jnp.log(1 - y_pred))

## Metrics

In [24]:
#export
def accuracy(y_true: jnp.ndarray, y_pred: jnp.ndarray) -> jnp.ndarray:
    """Fraction of positions where rounded `y_pred` equals rounded `y_true`.

    Both inputs are passed through `jnp.round` before comparison (ties round to
    the even value, as in `numpy.round`, so 0.5 -> 0). The mean is taken over
    every element.

    The return annotation is `jnp.ndarray` rather than `jnp.DeviceArray`,
    because the latter no longer exists in recent JAX releases.
    """
    y_true, y_pred = map(jnp.round, (y_true, y_pred))
    return jnp.mean(jnp.equal(y_true, y_pred))

def dist(x: jnp.ndarray, cf: jnp.ndarray, ord: int = 2) -> jnp.ndarray:
    """Average `ord`-norm distance between `x` and `cf`.

    The norm of `x - cf` is taken over the last axis. For 2-D inputs this is
    one distance per row, and the result is the mean of those row distances.
    For inputs with more axes, the norms within each leading-axis slice are
    summed first and then averaged across the leading axis.

    The return annotation is `jnp.ndarray` rather than `jnp.DeviceArray`,
    because the latter no longer exists in recent JAX releases.
    """
    row_norms = jnp.linalg.norm(x - cf, ord=ord, axis=-1, keepdims=True)
    return jnp.mean(vmap(jnp.sum)(row_norms))

def proximity(x: jnp.ndarray, cf: jnp.ndarray) -> jnp.ndarray:
    """Average L1 distance between `x` and `cf`; shorthand for `dist(x, cf, ord=1)`.

    The return annotation is `jnp.ndarray` rather than `jnp.DeviceArray`,
    because the latter no longer exists in recent JAX releases.
    """
    return dist(x, cf, ord=1)

In [26]:
m = jnp.array([
    [1., 2., 3., 1.],
    [1., -1., 4., 1.],
])
n = jnp.array([
    [0., -1., 3., 1.],
    [1., 2., 4., 1.],
])
# row differences are [1, 3, 0, 0] and [0, -3, 0, 0]
assert proximity(m, n).item() == 3.5                             # L1 norms 4 and 3
assert jnp.allclose(dist(m, n), (jnp.sqrt(10.) + 3.) / 2)        # L2 norms sqrt(10) and 3
assert dist(m, m).item() == 0.0

# rounded predictions [0, 1, 0, 0] vs labels [0, 1, 1, 0] -> 3 of 4 correct
assert accuracy(jnp.array([0., 1., 1., 0.]), jnp.array([0.2, 0.7, 0.4, 0.49])).item() == 0.75