# Growing Sphere

Note: This method only works with one-hot encoding.

In [None]:
#| default_exp methods.sphere

In [None]:
#| include: false
%load_ext autoreload
%autoreload 2
from ipynb_path import *
from nbdev import show_doc
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [None]:
#| export
from __future__ import annotations
from relax.import_essentials import *
from relax.methods.base import CFModule
from relax.utils import auto_reshaping, grad_update, validate_configs
from relax.data_utils import Feature, FeaturesList

In [None]:
#| hide
from relax.data_module import load_data
from relax.ml_model import load_ml_module

In [None]:
#| export
@partial(jit, static_argnums=(2, 5))
def hyper_sphere_coordindates(
    rng_key: jrand.PRNGKey, # Random number generator key
    x: Array, # Input instance with only continuous features. Shape: (1, n_features)
    n_samples: int, # Number of samples
    high: float, # Upper bound
    low: float, # Lower bound
    p_norm: int = 2 # Norm
):
    """Sample `n_samples` points around `x` at `p_norm` distances drawn uniformly between `low` and `high`.

    Returns an array of shape `(n_samples, x.shape[-1])`.
    """
    # Adapted from 
    # https://github.com/carla-recourse/CARLA/blob/24db00aa8616eb2faedea0d6edf6e307cee9d192/carla/recourse_methods/catalog/growing_spheres/library/gs_counterfactuals.py#L8
    direction_key, radius_key = jrand.split(rng_key)
    n_features = x.shape[-1]

    # Random directions: Gaussian draws, one row per sample.
    directions = jrand.normal(direction_key, shape=(n_samples, n_features))
    # Random radii: uniform draws stretched onto the [low, high) interval.
    span = high - low
    radii = jrand.uniform(radius_key, shape=(n_samples,)) * span + low

    # Rescale every direction so that its p-norm equals the sampled radius.
    lengths = jnp.linalg.norm(directions, ord=p_norm, axis=1)
    scale = (radii / lengths)[:, None]
    offsets = directions * scale

    return x + offsets

In [None]:
#| export
@partial(jit, static_argnums=(1, 2))
def sample_categorical(rng_key: jrand.PRNGKey, col_size: int, n_samples: int):
    """Draw `n_samples` category indices in `[0, col_size)`; returns shape `(n_samples, 1)`."""
    # Only the first half of the split is used for sampling.
    sample_key = jrand.split(rng_key)[0]
    # NOTE: `jrand.categorical` interprets this vector as logits. Because every
    # entry is the same constant, the resulting distribution is uniform.
    uniform_weights = jnp.ones(col_size) / col_size
    return jrand.categorical(sample_key, uniform_weights, shape=(n_samples, 1))

In [None]:
#| export
def default_perturb_function(
    rng_key: jrand.PRNGKey,
    x: np.ndarray, # Shape: (1, k)
    n_samples: int,
    high: float,
    low: float,
    p_norm: int
):
    """Perturb every column of `x` at once by delegating to `hyper_sphere_coordindates`."""
    # Arguments are forwarded positionally, in the order the sampler declares them.
    sampler_args = (rng_key, x, n_samples, high, low, p_norm)
    candidates = hyper_sphere_coordindates(*sampler_args)
    return candidates

def perturb_function_with_features(
    rng_key: jrand.PRNGKey, # Random number generator key
    x: np.ndarray, # Shape: (1, k)
    n_samples: int, # Number of perturbed samples to generate
    high, # Upper bound of the sampling distance for continuous features
    low, # Lower bound of the sampling distance for continuous features
    p_norm, # Norm used when sampling continuous features
    feats: FeaturesList, # Features describing the column layout of `x`
):
    """Perturb `x` feature by feature and return an array of shape `(n_samples, k)`.

    Categorical features are replaced by the transformed output of uniformly
    sampled category labels. Every other feature is perturbed with
    `hyper_sphere_coordindates`, applied to that feature's own column slice
    (so the distance bounds apply per feature, not to the whole row).
    """
    def perturb_feature(feat_key, x_feat, feat):
        if feat.is_categorical:
            return feat.transform(
                sample_categorical(
                    feat_key, feat.transformation.num_categories, n_samples
                ) #<== sampled labels
            ) #<== transformed labels
        else: 
            return hyper_sphere_coordindates(
                feat_key, x_feat, n_samples, high, low, p_norm
            ) #<== transformed continuous features
        
    # One independent key per feature.
    feat_keys = jrand.split(rng_key, len(feats))
    perturbed = jnp.repeat(x, n_samples, axis=0)
    for feat_key, (start, end), feat in zip(feat_keys, feats.feature_indices, feats):
        # Use this feature's own key: passing `feat_keys[0]` to every feature
        # would make all features share the same random stream.
        _perturbed_feat = perturb_feature(feat_key, x[:, start: end], feat)
        perturbed = perturbed.at[:, start: end].set(_perturbed_feat)
    return perturbed


In [None]:
# Smoke test: both perturbation functions should return one row per requested
# sample and keep the column count of the input.
dm = load_data('adult')
x = dm.xs[:1]  # slice (rather than index) so the leading axis is kept
assert x.ndim == 2
# 100 samples; 29 is presumably the encoded width of the 'adult' data -- verify
# against the data module if the dataset configuration changes.
assert perturb_function_with_features(
    jrand.PRNGKey(0), x, 100, 1, 0, 2, feats=dm.features
).shape == (100, 29)
assert default_perturb_function(
    jrand.PRNGKey(0), x, 100, 1, 0, 2,
).shape == (100, 29)

In [None]:
#| export
def _growing_spheres(
    rng_key: jrand.PRNGKey, # Random number generator key
    y_target: Array, # Target label
    x: Array, # Input instance. Shape: (1, n_features)
    pred_fn: Callable, # Prediction function
    n_steps: int, # Number of steps
    n_samples: int,  # Number of samples to sample
    step_size: float, # Step size
    p_norm: int, # Norm
    perturb_fn: Callable, # Perturbation function
    apply_constraints_fn: Callable # Apply immutable constraints
): 
    """Search for the closest counterfactual of `x` by sampling in growing shells.

    At step `c` (starting from 0), `n_samples` candidates are drawn with
    `perturb_fn` using the distance bounds `[step_size * c, step_size * (c + 1)]`,
    constrained with `apply_constraints_fn`, and classified with `pred_fn`.
    Among the candidates predicted as `y_target`, the one closest to `x` replaces
    the current best counterfactual if it is strictly closer.

    Returns an array shaped like `x`. If no valid candidate is found in any
    step, `x` itself is returned.

    Raises `ValueError` if `p_norm` is neither 1 nor 2.
    """
    @jit
    def dist_fn(x, cf):
        # `p_norm` is a Python value captured by the closure, so this branch is
        # resolved once at trace time.
        if p_norm == 1:
            return jnp.abs(cf - x).sum(axis=1)
        elif p_norm == 2:
            return jnp.linalg.norm(cf - x, ord=2, axis=1)
        else:
            raise ValueError("Only p_norm = 1 or 2 is supported")
    
    @loop_tqdm(n_steps)
    def step(i, state):
        candidate_cf, count, rng_key = state
        # Derive a dedicated sampling key; the carried key is never consumed
        # directly, so no key is used twice across iterations.
        rng_key, sample_key = jrand.split(rng_key)
        low, high = step_size * count, step_size * (count + 1)
        # Sample around x
        candidates = perturb_fn(sample_key, x, n_samples, high=high, low=low, p_norm=p_norm)
        
        # Apply immutable constraints
        candidates = apply_constraints_fn(x, candidates, hard=True)
        assert candidates.shape[1] == x.shape[1], f"candidates.shape = {candidates.shape}, x.shape = {x.shape}"

        # Calculate distance. Shape: (n_samples,)
        dist = dist_fn(x, candidates)

        # Calculate counterfactual labels
        candidate_preds = pred_fn(candidates).argmax(axis=1)
        is_valid = (candidate_preds == y_target).reshape(-1)

        # Invalidate candidates that do not reach the target label. The mask is
        # a column for the 2-D candidates but stays 1-D for the 1-D distances;
        # a column mask on `dist` would broadcast it to (n_samples, n_samples)
        # and make the `argmin` below return a meaningless flat index.
        candidates = jnp.where(is_valid.reshape(-1, 1), candidates, jnp.inf)
        dist = jnp.where(is_valid, dist, jnp.inf)

        closest_idx = dist.argmin()
        candidate_cf_update = candidates[closest_idx].reshape(1, -1)

        # Keep the new candidate only if it is strictly closer than the current
        # best (an all-invalid step has distance `inf` and never wins).
        candidate_cf = jnp.where(
            dist[closest_idx] < dist_fn(x, candidate_cf).mean(),
            candidate_cf_update, 
            candidate_cf
        )
        return candidate_cf, count + 1, rng_key
    
    y_target = y_target.reshape(1, -1).argmax(axis=1)
    # `inf` marks "no counterfactual found yet".
    candidate_cf = jnp.ones_like(x) * jnp.inf
    count = 0
    state = (candidate_cf, count, rng_key)
    candidate_cf, _, _ = lax.fori_loop(0, n_steps, step, state)
    # if `inf` is found, return the original input
    candidate_cf = jnp.where(jnp.isinf(candidate_cf), x, candidate_cf)
    return candidate_cf

In [None]:
#| export
class GSConfig(BaseParser):
    """Configuration of the `GrowingSphere` counterfactual method."""
    n_steps: int = 100  # Number of search steps (iterations of the sampling loop)
    n_samples: int = 1000  # Number of candidates sampled at every step
    step_size: float = 0.05  # Amount by which the sampling distance bounds grow per step
    p_norm: int = 2  # Norm used for sampling and distances; only 1 and 2 are supported


In [None]:
#| export
class GrowingSphere(CFModule):
    """Counterfactual explanation module built on the growing-spheres search."""

    def __init__(self, config: dict | GSConfig = None, *, name: str = None, perturb_fn = None):
        # Fall back to the default configuration, then validate whatever we have.
        config = validate_configs(GSConfig() if config is None else config, GSConfig)
        # May stay `None`; `before_generate_cf` then picks a perturbation function.
        self.perturb_fn = perturb_fn
        super().__init__(config, name="GrowingSphere" if name is None else name)

    def before_generate_cf(self, *args, **kwargs):
        """Choose a perturbation function unless one was supplied by the user."""
        if self.perturb_fn is not None:
            return
        if hasattr(self, 'data_module'):
            # Feature-aware sampling, bound to the features of the data module.
            self.perturb_fn = ft.partial(
                perturb_function_with_features, feats=self.data_module.features
            )
        else:
            self.perturb_fn = default_perturb_function
        
    @auto_reshaping('x')
    def generate_cf(
        self,
        x: Array,  # `x` shape: (k,), where `k` is the number of features
        pred_fn: Callable[[Array], Array],
        y_target: Array = None,
        rng_key: jnp.ndarray = None,
        **kwargs,
    ) -> jnp.DeviceArray:
        """Generate a counterfactual for `x` with the configured growing-spheres search.

        When `y_target` is omitted it is derived as `1 - pred_fn(x)`.
        Raises `ValueError` when `rng_key` is `None`.
        """
        # TODO: Currently assumes binary classification.
        y_target = (
            1 - pred_fn(x) if y_target is None
            else jnp.array(y_target, copy=True)
        )
        if rng_key is None:
            raise ValueError("`rng_key` must be provided, but got `None`.")
        
        cfg = self.config
        search_kwargs = dict(
            rng_key=rng_key,
            x=x,
            y_target=y_target,
            pred_fn=pred_fn,
            n_steps=cfg.n_steps,
            n_samples=cfg.n_samples,
            step_size=cfg.step_size,
            p_norm=cfg.p_norm,
            perturb_fn=self.perturb_fn,
            apply_constraints_fn=self.apply_constraints_fn,
        )
        return _growing_spheres(**search_kwargs)

In [None]:
# Load the 'dummy' data module and model used by the demo cells that follow.
dm = load_data('dummy')
model = load_ml_module('dummy')
xs_train, ys_train = dm['train']
xs_test, ys_test = dm['test']
# Remember the test-set shape to check the shape of generated counterfactuals.
x_shape = xs_test.shape

In [None]:
# Set up the module by hand: attach the data, install the constraint function,
# then let `before_generate_cf` choose the perturbation function.
gs = GrowingSphere()
gs.set_data_module(dm)
gs.init_fns(apply_constraints_fn=dm.apply_constraints)
gs.before_generate_cf()

# Generate a counterfactual for a single test instance.
cf = gs.generate_cf(xs_test[0], pred_fn=model.pred_fn, rng_key=jax.random.PRNGKey(0))

In [None]:
#| eval: false
# Batched generation: vmap over the test instances with one PRNG key each.
partial_gen = partial(gs.generate_cf, pred_fn=model.pred_fn)
cfs = jax.vmap(jit(partial_gen))(xs_test, rng_key=jrand.split(jrand.PRNGKey(0), len(xs_test)))

# One counterfactual row per test instance, same width as the inputs.
assert cfs.shape == (x_shape[0], x_shape[1])

# Validity: agreement between the flipped (rounded) predictions on the inputs
# and the predictions on the counterfactuals.
print("Validity: ", keras.metrics.binary_accuracy(
    (1 - model.pred_fn(xs_test)).round(),
    model.pred_fn(cfs[:, :])
).mean())