In [None]:
from __future__ import annotations
import relax
from relax.data import load_data, TabularDataModule
from relax.evaluate import generate_cf_explanations, benchmark_cfs
from relax.methods import CounterNet
from relax.import_essentials import *

In [None]:
def gumbel_softmax(rng, logits, tau=1., axis=-1):
    """Draw a relaxed one-hot sample from the Gumbel-softmax (concrete) distribution.

    Args:
        rng: JAX PRNG key used to draw the Gumbel noise.
        logits: Unnormalized log-probabilities; the noise has the same shape.
        tau: Temperature. Smaller values push the sample towards one-hot.
        axis: Axis along which the softmax normalizes.

    Returns:
        An array shaped like `logits` whose entries along `axis` sum to 1.
    """
    perturbed = logits + jax.random.gumbel(rng, logits.shape)
    return jax.nn.softmax(perturbed / tau, axis=axis)

In [None]:
def softmax(
    x: jnp.DeviceArray,
    axis: int | Tuple[int, ...] = None,
    beta: float = 1.0
)  -> jnp.DeviceArray:
    """Numerically stable softmax of `beta * x`.

    Args:
        x: Input logits.
        axis: Axis (or axes) to normalize over. `None` normalizes over all elements.
        beta: Inverse temperature that scales the logits before normalizing.

    Returns:
        An array shaped like `x` whose entries sum to 1 over `axis`.
    """
    scaled = beta * x
    # Subtracting the max keeps exp() from overflowing. The shift cancels out
    # mathematically, so its gradient is stopped.
    shift = lax.stop_gradient(jnp.max(scaled, axis, keepdims=True))
    exps = jnp.exp(scaled - shift)
    return exps / jnp.sum(exps, axis=axis, keepdims=True)

In [None]:
@deprecated
def cat_normalize(
    cf: jnp.ndarray,  # Unnormalized counterfactual explanations `[n_samples, n_features]`
    cat_arrays: List[List[str]],  # A list of a list of each categorical feature name
    cat_idx: int,  # Index that starts categorical features
    beta: float = 1.0, 
    hard: bool = False,  # If `True`, return one-hot vectors; If `False`, return probability normalized via softmax
    soft_fun: Callable = None,
    hard_fun: Callable = None
) -> jnp.ndarray:
    """Ensure generated counterfactual explanations to respect one-hot encoding constraints.

    Columns `[:cat_idx]` are copied unchanged. Each following group of
    `len(categories)` columns is passed through `hard_fun` when `hard` is true
    and through `soft_fun` otherwise.

    Note: `beta` is accepted but not used anywhere in this function.
    """
    pieces = [cf[:, :cat_idx]]
    start = cat_idx

    for categories in cat_arrays:
        stop = start + len(categories)
        segment = cf[:, start:stop]

        # The defaults are bound once, on the first group. The default hard_fun
        # reads `categories` when it is called, and lax.cond invokes it right
        # away, so it always sees the current group.
        if soft_fun is None:
            soft_fun = lambda logits: jax.nn.softmax(logits, axis=-1)
        if hard_fun is None:
            hard_fun = lambda logits: jax.nn.one_hot(jnp.argmax(logits, axis=-1), len(categories))

        pieces.append(
            lax.cond(hard, true_fun=hard_fun, false_fun=soft_fun, operand=segment)
        )
        start = stop

    return jnp.concatenate(pieces, axis=-1)


In [None]:
def cat_normalize_softmax(
    cf: jnp.ndarray,  # Unnormalized counterfactual explanations `[n_samples, n_features]`
    cat_arrays: List[List[str]],  # A list of a list of each categorical feature name
    cat_idx: int,  # Index that starts categorical features
    beta: float = 1.0, 
    hard: bool = False,  # If `True`, return one-hot vectors; If `False`, return probability normalized via softmax
) -> jnp.ndarray:
    """Ensure generated counterfactual explanations to respect one-hot encoding constraints.

    Columns `[:cat_idx]` are copied unchanged. Each following group of
    `len(categories)` columns becomes either:
    - a one-hot of its argmax, when `hard` is true; or
    - `softmax(group, axis=-1, beta=beta)`, otherwise.
    """
    segments = [cf[:, :cat_idx]]
    offset = cat_idx

    for categories in cat_arrays:
        width = len(categories)
        group = cf[:, offset:offset + width]
        segments.append(
            lax.cond(
                hard,
                true_fun=lambda z: jax.nn.one_hot(jnp.argmax(z, axis=-1), width),
                false_fun=lambda z: softmax(z, axis=-1, beta=beta),
                operand=group,
            )
        )
        offset += width

    return jnp.concatenate(segments, axis=-1)


In [None]:
class TabularDataModuleSoftmax(TabularDataModule):
    """`TabularDataModule` that relaxes categorical features with a temperature-scaled softmax.

    Call `set_beta` to choose the inverse temperature. If it is never called,
    `apply_constraints` uses beta = 1.0, which is the `set_beta` default.
    """

    def set_beta(self, beta: float = 1.0):
        """Set the inverse temperature used by `apply_constraints` in soft mode."""
        self.beta = beta

    def apply_constraints(
        self, 
        x: jnp.DeviceArray, # input
        cf: jnp.DeviceArray, # Unnormalized counterfactuals
        hard: bool = False # Apply hard constraints or not
    ) -> jnp.DeviceArray:
        """Apply categorical normalization (one-hot if `hard`, beta-scaled softmax otherwise).

        NOTE(review): `x` is not used, and no immutability constraints are applied
        here. Confirm whether the base class's immutability handling should also apply.
        """
        cat_arrays = self.cat_encoder.categories_ \
            if self._configs.discret_cols else []
        # Avoid an AttributeError when set_beta() was never called.
        beta = getattr(self, "beta", 1.0)
        return cat_normalize_softmax(
            cf, cat_arrays=cat_arrays,
            cat_idx=len(self._configs.continous_cols),
            beta=beta, hard=hard,
        )

In [None]:
def cat_normalize_gumbel(
    cf: jnp.ndarray,  # Unnormalized counterfactual explanations `[n_samples, n_features]`
    cat_arrays: List[List[str]],  # A list of a list of each categorical feature name
    cat_idx: int,  # Index that starts categorical features
    tau: float = 1.0,  # Gumbel-softmax temperature; smaller values give samples closer to one-hot
    hard: bool = False,  # If `True`, return one-hot vectors of the raw argmax; If `False`, return a Gumbel-softmax sample
    rng=None,  # PRNG key or int seed for the Gumbel noise; `None` keeps the legacy fixed seed 42
) -> jnp.ndarray:
    """Ensure generated counterfactual explanations to respect one-hot encoding constraints.

    Columns `[:cat_idx]` are copied unchanged. Each following group of
    `len(categories)` columns becomes either:
    - a one-hot of the argmax of the raw (noise-free) values, when `hard` is true; or
    - a Gumbel-softmax sample with temperature `tau`, otherwise.

    Each group gets its own subkey from a sequence seeded by `rng`.
    With the default `rng=None` the seed is always 42, so repeated calls draw
    identical noise. Pass a fresh key to get genuinely new samples.
    """
    cf_cont = cf[:, :cat_idx]
    normalized_cf = [cf_cont]
    keys = hk.PRNGSequence(42 if rng is None else rng)

    for col in cat_arrays:
        cat_end_idx = cat_idx + len(col)
        _cf_cat = cf[:, cat_idx:cat_end_idx]
        subkey = next(keys)
        soft_fun = lambda x: gumbel_softmax(subkey, x, axis=-1, tau=tau)
        hard_fun = lambda x: jax.nn.one_hot(jnp.argmax(x, axis=-1), len(col))

        cf_cat = lax.cond(
            hard,
            true_fun=hard_fun,
            false_fun=soft_fun,
            operand=_cf_cat,
        )

        cat_idx = cat_end_idx
        normalized_cf.append(cf_cat)
    return jnp.concatenate(normalized_cf, axis=-1)


In [None]:
class TabularDataModuleGumbel(TabularDataModule):
    """`TabularDataModule` that relaxes categorical features with Gumbel-softmax samples.

    Call `set_tau` to choose the temperature. If it is never called,
    `apply_constraints` uses tau = 1.0, which is the `set_tau` default.
    """

    def set_tau(self, tau: float = 1.0):
        """Set the Gumbel-softmax temperature used by `apply_constraints` in soft mode."""
        self.tau = tau

    def apply_constraints(
        self, 
        x: jnp.DeviceArray, # input
        cf: jnp.DeviceArray, # Unnormalized counterfactuals
        hard: bool = False # Apply hard constraints or not
    ) -> jnp.DeviceArray:
        """Apply categorical normalization (one-hot if `hard`, Gumbel-softmax otherwise).

        The Gumbel noise is drawn inside `cat_normalize_gumbel`; no key is passed from here.

        NOTE(review): `x` is not used, and no immutability constraints are applied
        here. Confirm whether the base class's immutability handling should also apply.
        """
        cat_arrays = self.cat_encoder.categories_ \
            if self._configs.discret_cols else []
        # Avoid an AttributeError when set_tau() was never called.
        tau = getattr(self, "tau", 1.0)
        return cat_normalize_gumbel(
            cf, cat_arrays=cat_arrays,
            cat_idx=len(self._configs.continous_cols),
            tau=tau, hard=hard
        )

In [None]:
# Experiment 1: softmax relaxation of categorical features, beta = 1.0.
_, d_configs = load_data("adult", return_config=True)

dm = TabularDataModuleSoftmax(d_configs)
dm.set_beta(1.0)

cfnet = CounterNet()
cf_exp_softmax_1 = generate_cf_explanations(
    cfnet,
    dm,
    t_configs=dict(n_epochs=100, batch_size=128),
)


CounterNet contains parametric models. Starts training before generating explanations...


  "`monitor_metrics` is not specified in `CheckpointManager`. No checkpoints will be stored."
Epoch 99: 100%|██████████| 191/191 [00:01<00:00, 138.99batch/s, train/train_loss_1=0.0527, train/train_loss_2=0.000198, train/train_loss_3=0.112] 


In [None]:
# Experiment 2: softmax relaxation with a sharper inverse temperature, beta = 10.0.
_, d_configs = load_data("adult", return_config=True)

dm = TabularDataModuleSoftmax(d_configs)
dm.set_beta(10.0)

cfnet = CounterNet()
cf_exp_softmax_10 = generate_cf_explanations(
    cfnet,
    dm,
    t_configs=dict(n_epochs=100, batch_size=128),
)


  "`monitor_metrics` is not specified in `CheckpointManager`. No checkpoints will be stored."


CounterNet contains parametric models. Starts training before generating explanations...


Epoch 99: 100%|██████████| 191/191 [00:01<00:00, 126.56batch/s, train/train_loss_1=0.0394, train/train_loss_2=0.000265, train/train_loss_3=0.114] 


In [None]:
# Experiment 3: Gumbel-softmax relaxation, tau = 1.0.
_, d_configs = load_data("adult", return_config=True)

dm = TabularDataModuleGumbel(d_configs)
dm.set_tau(1.0)

cfnet = CounterNet()
cf_exp_gumbel_1 = generate_cf_explanations(
    cfnet,
    dm,
    t_configs=dict(n_epochs=100, batch_size=128),
)


  "`monitor_metrics` is not specified in `CheckpointManager`. No checkpoints will be stored."


CounterNet contains parametric models. Starts training before generating explanations...


Epoch 99: 100%|██████████| 191/191 [00:01<00:00, 136.88batch/s, train/train_loss_1=0.0789, train/train_loss_2=0.00317, train/train_loss_3=0.103]  


In [None]:
# Experiment 4: Gumbel-softmax relaxation, near-hard temperature tau = 1e-4.
# NOTE(review): the variable suffix "_001" does not obviously correspond to
# tau = 1e-4. Confirm the intended temperature before renaming (the
# benchmark cell below refers to this name).
_, d_configs = load_data("adult", return_config=True)

dm = TabularDataModuleGumbel(d_configs)
dm.set_tau(1e-4)

cfnet = CounterNet()
cf_exp_gumbel_001 = generate_cf_explanations(
    cfnet,
    dm,
    t_configs=dict(n_epochs=100, batch_size=128),
)


  "`monitor_metrics` is not specified in `CheckpointManager`. No checkpoints will be stored."


CounterNet contains parametric models. Starts training before generating explanations...


Epoch 99: 100%|██████████| 191/191 [00:01<00:00, 133.40batch/s, train/train_loss_1=0.0518, train/train_loss_2=0.00358, train/train_loss_3=0.113]  


In [None]:
# Every result row is labelled "adult, CounterNet", so rows can only be told
# apart by position. Presumably they follow this list's order (verify that
# benchmark_cfs preserves input order).
benchmark_cfs(
    [
        cf_exp_softmax_1,   # softmax, beta = 1.0
        cf_exp_softmax_10,  # softmax, beta = 10.0
        cf_exp_gumbel_1,    # Gumbel-softmax, tau = 1.0
        cf_exp_gumbel_001,  # Gumbel-softmax, tau = 1e-4
    ]
)

Unnamed: 0,Unnamed: 1,acc,validity,proximity
adult,CounterNet,0.833067,0.998157,6.7232056
adult,CounterNet,0.830488,0.999631,6.6402793
adult,CounterNet,0.83147,0.998526,6.2795486
adult,CounterNet,0.830488,0.998894,6.853478
