# L2C

https://arxiv.org/abs/2209.13446

In [None]:
#| default_exp methods.l2c

In [None]:
#| include: false
%load_ext autoreload
%autoreload 2
from ipynb_path import *
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)

In [None]:
#| export
from __future__ import annotations
from relax.import_essentials import *
from relax.methods.base import ParametricCFModule
from relax.base import BaseConfig
from relax.utils import *
from relax.data_utils import Feature, FeaturesList
from relax.ml_model import MLP, MLPBlock
from relax.data_module import DataModule
from keras.random import SeedGenerator

In [None]:
#| hide
import torch
import relax

In [None]:
# Scratch check of the per-feature split/blend logic used by L2C: four features
# occupying column ranges [0,1), [1,2), [2,3), [3,5) of a 5-column array.
# NOTE: `xs` and `prob` bound here are reused by a later test cell.
start_end = jnp.array([[0, 1], [1, 2], [2, 3], [3, 5]])
xs = jrand.normal(jrand.PRNGKey(0), (4, 5),)
cfs = jrand.normal(jrand.PRNGKey(1), (4, 5),)
# One selection probability per feature (4 rows x 4 features).
prob = jrand.uniform(jrand.PRNGKey(2), (4, 4),)

# split xs into 4 parts according to start_end
# (cut points are the end offsets of all but the last feature: 1, 2, 3)
xs_split = jnp.split(xs, start_end[:-1, 1], axis=1)
cfs_split = jnp.split(cfs, start_end[:-1, 1], axis=1)
# split prob into one column per feature
prob_split = jnp.split(prob, start_end.shape[0], axis=1)

def perturb(x, cf, prob):
    """Blend block `x` towards block `cf` with the (broadcast) weight `prob`."""
    return x * (1 - prob) + cf * prob

# Apply `perturb` feature by feature over the three lists of blocks.
perturbed = jax.tree_util.tree_map(
    perturb, xs_split, cfs_split, prob_split
)

## L2C Model

In [None]:
#| export
def gumbel_softmax(
    key: jrand.PRNGKey, # Random key
    logits: Array, # Logits for each class. Shape (batch_size, num_classes)
    tau: float, # Temperature for the Gumbel softmax
):
    """The Gumbel softmax function.

    Adds Gumbel(0, 1) noise to `logits`, divides by the temperature `tau` and
    takes a softmax over the last axis, giving a relaxed one-hot sample whose
    rows sum to one.
    """
    # The first half of a split of `key` drives the noise.
    noise_key = jrand.split(key)[0]
    noisy_logits = logits + jrand.gumbel(noise_key, shape=logits.shape)
    scaled = noisy_logits / tau
    return jax.nn.softmax(scaled, axis=-1)

In [None]:
#| export
def sample_categorical(
    key: jrand.PRNGKey, # Random key
    logits: Array, # Logits for each class. Shape (batch_size, num_classes)
    tau: float, # Temperature for the Gumbel softmax
    training: bool = True, # Apply gumbel softmax if training
):
    """Sample from a categorical distribution.

    When `training` is true the result is a Gumbel-softmax relaxation of the
    sample (temperature `tau`); otherwise it is the hard one-hot encoding of
    the arg-max class of `logits`. Selection goes through `lax.cond`, so
    `training` may also be a traced boolean.
    """
    num_classes = logits.shape[-1]

    def _relaxed(_):
        return gumbel_softmax(key, logits, tau=tau)

    def _hard(_):
        # Deterministic branch: the random key is not consulted here.
        winners = logits.argmax(axis=-1)
        return jax.nn.one_hot(winners, num_classes)

    return lax.cond(training, _relaxed, _hard, None)

In [None]:
# Tests for `sample_categorical`.
logits = jnp.array([[2.0, 1.0, 0.1], [1.0, 2.0, 3.0]])
key = jrand.PRNGKey(0)
# training=True: relaxed samples are probability vectors over the classes
output = sample_categorical(key, logits, tau=0.5, training=True)
assert output.shape == logits.shape
assert jnp.allclose(output.sum(axis=-1), 1.0)
# low temperature -> one-hot
# NOTE(review): argmax(logits + Gumbel noise) need not equal argmax(logits) in
# general; this check relies on the fixed key above.
output = sample_categorical(key, logits, tau=0.001, training=True)
assert jnp.array_equal(
    output.argmax(axis=-1), logits.argmax(axis=-1)
)
# high temperature -> uniform
output = sample_categorical(key, logits, tau=100, training=True)
assert jnp.max(output) - jnp.min(output) < 0.5

# training=False: hard one-hot of the arg-max class
output = sample_categorical(key, logits, tau=0.5, training=False)
assert output.shape == logits.shape
assert jnp.array_equal(
    output.argmax(axis=-1), logits.argmax(axis=-1)
)

In [None]:
#| export
def sample_bernouli(
    key: jrand.PRNGKey, # Random key
    prob: Array, # Bernoulli probabilities in [0, 1]. Shape (batch_size, n_features)
    tau: float, # Temperature for the relaxed (Gumbel) Bernoulli
    training: bool = True, # Apply gumbel relaxation if training
) -> Array:
    """Sample from a bernoulli distribution.

    In training mode a relaxed (Gumbel / binary-Concrete) sample in [0, 1] is
    returned so gradients can flow back to `prob`:

        sigmoid((log p - log(1 - p) + g_1 - g_2) / tau),  g_1, g_2 ~ Gumbel(0, 1)

    Smaller `tau` pushes samples towards hard 0/1 values, larger `tau` towards
    0.5; `tau = 1` equals p*e^g_1 / (p*e^g_1 + (1 - p)*e^g_2).
    Otherwise a hard 0/1 sample is drawn with `jrand.bernoulli` and cast to
    `prob.dtype`.
    """

    def sample_ber(key, prob):
        return jrand.bernoulli(key, p=prob).astype(prob.dtype)

    def gumbel_ber(key, prob, tau):
        key_1, key_2 = jrand.split(key)
        gumbel_1 = jrand.gumbel(key_1, shape=prob.shape)
        gumbel_2 = jrand.gumbel(key_2, shape=prob.shape)
        # Clamp before the logs so p = 0 or p = 1 stay finite (no -inf, and
        # no inf * 0 = NaN in the gradient); interior values are unchanged.
        tiny = jnp.finfo(prob.dtype).tiny
        log_p = jnp.log(jnp.maximum(prob, tiny))
        log_not_p = jnp.log(jnp.maximum(1. - prob, tiny))
        # The temperature must scale the noisy log-odds. Dividing both terms of
        # the ratio p*e^g_1 / (p*e^g_1 + (1-p)*e^g_2) by `tau` cancels it out.
        return jax.nn.sigmoid((log_p - log_not_p + gumbel_1 - gumbel_2) / tau)

    return lax.cond(
        training,
        lambda _: gumbel_ber(key, prob, tau),
        lambda _: sample_ber(key, prob),
        None,
    )

In [None]:
#| export
def split_fn(feature_indices: list[tuple[int, int]]):
    """Build jitted splitters for feature-wise arrays.

    Returns ``(split_xs, split_prob)``:

    * ``split_xs(xs)`` cuts the last axis of ``xs`` at the end offset of every
      ``(start, end)`` pair except the last one, giving one chunk per feature.
    * ``split_prob(prob)`` cuts the last axis of ``prob`` into
      ``len(cut points) + 1`` equal chunks, i.e. one column per feature.
    """
    # Cut points: the `end` offset of every feature but the final one.
    cut_points = tuple(pair[1] for pair in feature_indices[:-1])

    @ft.partial(jit, static_argnums=1)
    def _split_at_cuts(arr, feature_indices):
        return jnp.split(arr, list(feature_indices), axis=-1)

    @ft.partial(jit, static_argnums=1)
    def _split_evenly(arr, feature_indices):
        n_chunks = len(feature_indices) + 1
        return jnp.split(arr, n_chunks, axis=-1)

    split_xs = ft.partial(_split_at_cuts, feature_indices=cut_points)
    split_prob = ft.partial(_split_evenly, feature_indices=cut_points)
    return split_xs, split_prob

In [None]:
# `split_fn` yields one chunk per feature for both inputs and probabilities
# (reuses `xs` (4, 5) and `prob` (4, 4) bound in an earlier cell).
start_end = [(0, 1), (1, 2), (2, 3), (3, 5)]
split_xs, split_prob = split_fn(start_end)
assert len(split_xs(xs)) == len(start_end)
assert len(split_prob(prob)) == len(start_end) 

In [None]:
#| export
class L2CModel(keras.Model):
    """Keras model behind L2C: a counterfactual generator plus a feature selector.

    Both MLPs read the input (in `L2C`, the discretized features). For every
    feature block, the selector gives a probability of changing that feature,
    and the generator gives candidate logits. The counterfactual blends the
    original block with a (relaxed) categorical sample from those logits.
    Features flagged in `immutable_mask` are never changed. `call` registers
    the validity and sparsity losses with `add_loss`.
    """

    def __init__(
        self,
        generator_layers: list[int],
        selector_layers: list[int],
        feature_indices: list[tuple[int, int]] = None,
        immutable_mask: Array = None,
        pred_fn: Callable = None,
        alpha: float = 1e-4, # Sparsity regularization
        tau: float = 0.7,
        seed: int = None,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.generator_layers = generator_layers
        self.selector_layers = selector_layers
        self.pred_fn = pred_fn
        self.tau = tau
        self.alpha = alpha
        # `seed=0` is a valid seed; only fall back to the global seed when unset.
        if seed is None:
            seed = get_config().global_seed
        self.seed_generator = SeedGenerator(seed)
        self.set_features_info(feature_indices)
        self.set_immutable_mask(immutable_mask)

        # Splitters: inputs/generator outputs by per-feature column ranges,
        # selector outputs / immutable mask into one column per feature.
        self.split_xs_fn, self.split_prob_fn = split_fn(self.feature_indices)

    @property
    def start_end_indices(self):
        """End offsets of every feature except the last (the split points)."""
        feature_indices = jnp.array(list(map(lambda x: list(x), self.feature_indices)))
        return feature_indices[:-1, 1]

    def set_features_info(self, feature_indices: list[tuple[int, int]]):
        self.feature_indices = feature_indices
        # TODO: check if the feature indices are valid

    def set_immutable_mask(self, immutable_mask: Array):
        self.immutable_mask = immutable_mask

    def set_pred_fn(self, pred_fn: Callable):
        self.pred_fn = pred_fn

    def build(self, input_shape):
        n_feats = len(self.feature_indices)
        # Generator: one logit per input column.
        self.generator = MLP(
            sizes=self.generator_layers,
            output_size=input_shape[-1],
            dropout_rate=0.0,
            last_activation="linear",
        )
        # Selector: one selection probability per feature.
        self.selector = MLP(
            sizes=self.selector_layers,
            output_size=n_feats,
            dropout_rate=0.0,
            last_activation="sigmoid",
        )

    def compute_l2c_loss(self, inputs, cfs, probs):
        """Return ``(validity_loss, sparsity)`` for a batch.

        validity_loss: mean sparse cross-entropy between `pred_fn(cfs)` and the
            class with the lowest score under `pred_fn(inputs)`.
        sparsity: `alpha` times the per-sample L1 norm of the selection
            probabilities `probs`, averaged over the batch.
        """
        y_target = self.pred_fn(inputs).argmin(axis=-1)
        y_pred = self.pred_fn(cfs)
        validity_loss = keras.losses.sparse_categorical_crossentropy(
            y_target, y_pred
        ).mean()
        # Element-wise L1 per sample, averaged over the batch. Note that
        # `jnp.linalg.norm(probs, ord=1)` on a 2-D array is the matrix 1-norm
        # (largest column sum), which is not an L1 sparsity penalty.
        sparsity = jnp.abs(probs).sum(axis=-1).mean() * self.alpha
        return validity_loss, sparsity

    def forward(self, rng_key, inputs, training=False):
        """Build counterfactuals for `inputs`.

        Returns ``(cfs, probs)``: the blended counterfactuals (same width as
        `inputs`) and the sampled per-feature selection probabilities (one
        column per feature, before immutable features are masked out).
        """
        key_select, key_cat = jrand.split(rng_key)
        select_probs = self.selector(inputs, training=training)
        probs = sample_bernouli(
            key_select, select_probs,
            tau=self.tau, training=training
        )
        cfs_logits = self.generator(inputs, training=training)

        xs = self.split_xs_fn(inputs)
        cfs = self.split_xs_fn(cfs_logits)
        probs = self.split_prob_fn(probs)
        immutables = self.split_prob_fn(self.immutable_mask)
        # JAX keys must not be reused: give every feature its own key for its
        # Gumbel noise instead of sharing one key across all features.
        cat_keys = jrand.split(key_cat, len(xs))

        def perturb(x, cf, prob, immutable, key):
            cf = sample_categorical(
                key, cf, tau=self.tau, training=training
            )
            # Immutable features get a zero selection probability.
            prob = prob * (1 - immutable)
            return x * (1 - prob) + cf * prob

        cfs = [
            perturb(x, cf, prob, immutable, cat_keys[i])
            for i, (x, cf, prob, immutable)
            in enumerate(zip(xs, cfs, probs, immutables))
        ]
        cfs = jnp.concatenate(cfs, axis=-1)
        probs = jnp.concatenate(probs, axis=-1)
        return cfs, probs

    def call(self, inputs, training=False):
        rng_key = self.seed_generator.next()
        cfs, probs = self.forward(rng_key, inputs, training=training)
        validity_loss, sparsity = self.compute_l2c_loss(inputs, cfs, probs)
        self.add_loss(validity_loss)
        self.add_loss(sparsity)
        return cfs

## Discretizer

In [None]:
#| export
def qcut(
    x: Array, # Input array
    q: int, # Number of quantiles
    axis: int = 0, # Axis to quantile
) -> tuple[Array, Array]: # (digitized array, quantiles)
    """Quantile binning.

    Computes the ``q - 1`` interior quantiles of `x` (levels 1/q, ..., (q-1)/q)
    along `axis`. Every element of `x` is assigned its bin index in
    ``[0, q - 1]`` with `jnp.digitize`. Because `jnp.digitize` needs 1-D bins,
    `x` is in practice expected to be 1-D.

    Inputs with at most one element are not binned: every element gets bin 0
    and an empty quantile array is returned.
    """
    # Handle edge cases: empty array or single element.
    if x.size <= 1:
        # Integer bin ids, like the `jnp.digitize` result of the regular path
        # (`zeros_like(x)` would inherit a float dtype from a float `x`).
        return jnp.zeros(x.shape, dtype=jnp.int32), jnp.array([])
    levels = jnp.linspace(0, 1, q + 1)[1:-1]
    quantiles = jnp.quantile(x, levels, axis=axis)
    digitized = jnp.digitize(x, quantiles)
    return digitized, quantiles

In [None]:
# qcut on 0..9 with 4 bins: cut points at the 25/50/75% quantiles.
digitized, quantiles = qcut(jnp.arange(10), 4)
assert digitized.shape == (10,)
assert quantiles.shape == (3,)
assert jnp.allclose(
    digitized, jnp.array([0,0,0,1,1,2,2,3,3,3])
)

# Interior quantiles only (min/max are not returned).
quantiles_true = jnp.array([0, 2.25, 4.5, 6.75, 9])
assert jnp.allclose(
    quantiles, quantiles_true[1:-1]
)
# Test with empty array
x_empty = jnp.array([])
q = 2
digitized_empty, quantiles_empty = qcut(x_empty, q)
assert digitized_empty.size == 0 and quantiles_empty.size == 0
# Test with single element array
x_single = jnp.array([1])
digitized_single, quantiles_single = qcut(x_single, q)
assert digitized_single.size == 1 and quantiles_single.size == 0

# Test with large q value (more bins than distinct values still yields q - 1 cut points)
# NOTE: this rebinds the module-level `xs`.
xs = jnp.array([1, 2, 3, 4, 5, 6])
q_large = 10
_, quantiles_large = qcut(xs, q_large)
assert len(quantiles_large) == q_large - 1

In [None]:
#| export
def qcut_inverse(
    digitized: Array, # Digitized One-Hot Encoding Array
    quantiles: Array, # Quantiles
) -> Array:
    """Inverse of qcut.

    Multiplies the one-hot bin memberships by `quantiles`. With a 1-D vector
    of per-bin values, each row maps to its bin's value. With an identity
    matrix, the encoding is returned unchanged. A 1-D product gets a trailing
    axis appended.
    """
    restored = jnp.matmul(digitized, quantiles)
    needs_column_axis = restored.ndim == 1
    return restored[..., None] if needs_column_axis else restored

In [None]:
digitized, quantiles = qcut(jnp.arange(10), 4)
ohe_digitized = jax.nn.one_hot(digitized, 4)
# continuous feats
# (a 1-D vector of per-bin values -> a single column of restored values)
quantiles_inv = qcut_inverse(ohe_digitized, jnp.arange(4))
assert quantiles_inv.shape == (10, 1)
# discrete feats
# (an identity matrix returns the one-hot encoding unchanged)
quantiles_inv = qcut_inverse(ohe_digitized, jnp.identity(4))
assert jnp.array_equal(quantiles_inv, ohe_digitized)

In [None]:
#| export
def cut_quantiles(
    quantiles: Array, # Quantiles
    xs: Array, # Input array
):
    """Return the midpoint of every bin delimited by `quantiles`.

    The bin edges are the column-wise minimum of `xs`, then `quantiles`, then
    the column-wise maximum of `xs`. The result holds the midpoints of
    consecutive edges: one more value than there are quantiles.
    """
    lower_edge = xs.min(axis=0, keepdims=True)
    upper_edge = xs.max(axis=0, keepdims=True)
    edges = jnp.concatenate([lower_edge, quantiles, upper_edge])
    left, right = edges[:-1], edges[1:]
    return (right + left) / 2

In [None]:
#| export
def discretize_xs(
    xs: Array, # Input array
    is_categorical_and_indices: list[tuple[bool, tuple[int, int]]], # Features list
    q: int = 4, # Number of quantiles
) -> tuple[list[Array], list[Array], list[Array], list[list[int, int]]]: # (discretized array, indices_and_quantiles_and_mid)
    """Discretize continuous features.

    For each feature, given as ``(is_categorical, (start, end))`` columns of `xs`:

    * categorical: the columns are kept as is, the quantiles are ``None`` and
      the inverse map is an identity matrix of width ``end - start``;
    * continuous: the column is binned into `q` quantile bins (one-hot
      encoded) and the bin midpoints are kept for the inverse map.

    Returns four per-feature lists: the discretized blocks, the quantiles, the
    inverse-map matrices/midpoints, and the ``[start, end]`` column range of each
    block in the discretized space.
    """
    discretized_xs, quantiles_feats, mid_quantiles, feature_indices = [], [], [], []
    offset = 0

    for is_categorical, (start, end) in is_categorical_and_indices:
        if not is_categorical:
            bins, quantiles = qcut(xs[:, start:end].reshape(-1), q=q)
            mid = cut_quantiles(quantiles, xs[:, start])
            block = jax.nn.one_hot(bins, q)
            width = block.shape[-1]
        else:
            block = xs[:, start:end]
            quantiles = None
            mid = jnp.identity(end - start)
            width = end - start

        discretized_xs.append(block)
        quantiles_feats.append(quantiles)
        mid_quantiles.append(mid)
        feature_indices.append([offset, offset + width])
        offset += width

    return discretized_xs, quantiles_feats, mid_quantiles, feature_indices

In [None]:
# Discretize the "dummy" dataset with the default q=4 bins.
dm = relax.load_data("dummy")
xs, ys = dm['train']
is_categorical_and_indices = [
    (feat.is_categorical, indices) for feat, indices in zip(dm.features, dm.features.feature_indices)
]
discretized_xs, quantiles_feats, mid_quantiles, feature_indices = discretize_xs(xs, is_categorical_and_indices)
# These checks only pass if every "dummy" feature is continuous: categorical
# features would have `None` quantiles and identity-width blocks.
assert len(discretized_xs) == len(is_categorical_and_indices)
assert all(discretized_xs[i].shape[1] == 4 for i in range(len(discretized_xs)))

assert len(quantiles_feats) == len(is_categorical_and_indices)
assert all(len(quantiles_feats[i]) == 3 for i in range(len(quantiles_feats)))
assert len(mid_quantiles) == len(is_categorical_and_indices)
assert all(len(mid_quantiles[i]) == 4 for i in range(len(mid_quantiles)))

In [None]:
#| export
class Discretizer:
    """Discretize continuous features.

    `fit` learns per-feature quantiles with `discretize_xs`. `transform`
    one-hot encodes each continuous feature into `q` quantile bins and passes
    categorical features through unchanged. `inverse_transform` maps each
    block back with `qcut_inverse`, using the learned bin midpoints for
    continuous features and an identity matrix for categorical ones.
    """

    def __init__(
        self, 
        is_cat_and_indices: list[tuple[bool, tuple[int, int]]], # Features list
        q: int = 4 # Number of quantiles
    ):
        self.is_cat_and_indices = is_cat_and_indices
        self.q = q

    @property
    def transform_indices(self):
        """Split points of the original space: end offset of every feature but the last."""
        return [x[1][1] for x in self.is_cat_and_indices[:-1]]

    @property
    def inverse_transform_indices(self):
        """Split points of the discretized space (available after `fit`)."""
        return [x[1] for x in self.indices[:-1]]

    def fit(self, xs: Array):
        """Learn quantiles, bin midpoints and discretized column ranges from `xs`."""
        _, self.quantiles, self.mid_quantiles, self.indices = discretize_xs(
            xs, self.is_cat_and_indices, self.q
        )
        return self

    def transform(self, xs: Array):
        """Map `xs` to the discretized (one-hot binned) feature space."""
        def digitize_fn(x, quantile):
            # Categorical features have no quantiles and are kept as is.
            if quantile is None:
                return x
            digitized = jnp.digitize(x.reshape(-1), quantile)
            return jax.nn.one_hot(digitized, self.q)

        digitized_xs = jnp.split(xs, self.transform_indices, axis=-1) # [feat_1, ..., feat_n]
        # `self.quantiles` holds `None` for categorical features; each entry is
        # paired with the matching feature block and passed to `digitize_fn`.
        digitized_xs = jax.tree_util.tree_map(
            digitize_fn, digitized_xs, self.quantiles
        )
        return jnp.concatenate(digitized_xs, axis=-1)

    def fit_transform(self, xs: Array):
        """`fit` on `xs`, then return `transform(xs)`."""
        self.fit(xs)
        return self.transform(xs)

    def inverse_transform(self, xs: Array):
        """Map discretized `xs` back to the original feature space."""
        xs = jnp.split(xs, self.inverse_transform_indices, axis=-1)
        xs = jax.tree_util.tree_map(
            lambda x, q: qcut_inverse(x, q), xs, self.mid_quantiles
        )
        return jnp.concatenate(xs, axis=-1)

    def inversed_transform_pytree(self, xs: list[Array]):
        """Like `inverse_transform`, but for an already split list of feature blocks."""
        # Inverse of the binning is `qcut_inverse` against the stored midpoints
        # (`qcut` expects an integer bin count, not a midpoint array).
        xs = jax.tree_util.tree_map(
            lambda x, q: qcut_inverse(x, q), xs, self.mid_quantiles)
        return jnp.concatenate(xs, axis=-1)

    def get_pred_fn(self, pred_fn: Callable[[Array], Array]):
        """Wrap `pred_fn` so it accepts discretized inputs."""
        def _pred_fn(xs: Array):
            return pred_fn(self.inverse_transform(xs))
        return _pred_fn


In [None]:
# End-to-end check of `Discretizer` on the "adult" dataset.
dm = relax.load_data("adult")
xs, ys = dm['train']
is_categorical_and_indices = [
    (feat.is_categorical, indices) for feat, indices in zip(dm.features, dm.features.feature_indices)
]

dis = Discretizer(is_categorical_and_indices)
dis.fit(xs)
digitized_xs_1 = dis.transform(xs)
# Discretized width for adult: q one-hot bins per continuous feature plus the
# categorical columns.
assert digitized_xs_1.shape == (xs.shape[0], 35)
# assert jnp.array_equal(jnp.concatenate(discretized_xs, axis=-1), digitized_xs_1)
# The round trip restores the original width (values become bin midpoints).
inversed_xs = dis.inverse_transform(digitized_xs_1)
assert xs.shape == inversed_xs.shape
# assert jnp.unique(inversed_xs).size == xs.shape[1] * 4

# The wrapped predictor accepts discretized inputs.
ml_module = relax.load_ml_module("adult")
pred_fn = dis.get_pred_fn(ml_module.pred_fn)
# digitized_xs_1 = split_xs(xs)
y = pred_fn(digitized_xs_1)
assert y.shape == (xs.shape[0], 2)

# Gradients flow through `inverse_transform` back to the discretized input.
def f(x, y):
    y_pred = pred_fn(x)
    return jnp.mean((y_pred - y) ** 2)

grad = jax.grad(f)(digitized_xs_1, ys)
assert grad.shape == digitized_xs_1.shape

## L2C Module

In [None]:
#| export
class L2CConfig(BaseConfig):
    """Hyper-parameters of the L2C counterfactual module."""

    # Hidden layer sizes of the generator MLP.
    generator_layers: list[int] = Field(
        default=[64, 64, 64], description="Generator MLP layers."
    )
    # Hidden layer sizes of the selector MLP.
    selector_layers: list[int] = Field(
        default=[64], description="Selector MLP layers."
    )
    # Optimizer settings.
    lr: float = Field(default=1e-3, description="Model learning rate.")
    opt_name: str = Field(
        default="adam", description="Optimizer name of training L2C."
    )
    # Weight of the sparsity penalty on the selection probabilities.
    alpha: float = Field(default=1e-4, description="Sparsity regularization.")
    tau: float = Field(
        default=0.7, description="Temperature for the Gumbel softmax."
    )
    # Bins per continuous feature used by the discretizer.
    q: int = Field(default=4, description="Number of quantiles.")

In [None]:
#| export
class L2C(ParametricCFModule):
    """L2C counterfactual explanation module.

    `train` discretizes the training data with a `Discretizer` and fits an
    `L2CModel` on it. `generate_cf` runs the trained model on the discretized
    input and maps the result back to the original feature space.
    """

    def __init__(
        self,
        config: Dict | L2CConfig = None,
        l2c_model: L2CModel = None,
        name: str = "l2c",
    ):
        if config is None:
            config = L2CConfig()
        config = validate_configs(config, L2CConfig)
        name = name or "l2c"
        self.l2c_model = l2c_model
        super().__init__(config=config, name=name)

    def train(
        self, 
        data: DataModule, 
        pred_fn: Callable,
        batch_size: int = 128,
        epochs: int = 10,
        **fit_kwargs
    ):
        """Fit the discretizer and the L2C model on ``data['train']``.

        `pred_fn` operates on the original feature space. It is wrapped so
        the model can query it with discretized inputs. Extra keyword
        arguments are forwarded to `keras.Model.fit`.

        Raises:
            ValueError: if `data` is not a `DataModule`.
        """
        if not isinstance(data, DataModule):
            raise ValueError(f"Only support `data` to be `DataModule`, "
                             f"got type=`{type(data).__name__}` instead.")
        
        xs_train, ys_train = data['train']
        self.discretizer = Discretizer(
            [(feat.is_categorical, indices) for feat, indices in zip(data.features, data.features.feature_indices)],
            q=self.config.q
        )
        discretized_xs_train = self.discretizer.fit_transform(xs_train)
        pred_fn = self.discretizer.get_pred_fn(pred_fn)
        features_indices = self.discretizer.indices

        self.l2c_model = L2CModel(
            generator_layers=self.config.generator_layers,
            selector_layers=self.config.selector_layers,
            feature_indices=features_indices,
            # One entry per feature; 1 marks an immutable feature.
            immutable_mask=jnp.array([feat.is_immutable for feat in data.features], dtype=jnp.int32),
            pred_fn=pred_fn,
            alpha=self.config.alpha,
            tau=self.config.tau,
        )
        # The model registers its own losses via `add_loss`, hence `loss=None`.
        self.l2c_model.compile(
            optimizer=keras.optimizers.get({
                'class_name': self.config.opt_name, 
                'config': {'learning_rate': self.config.lr}
            }),
            loss=None,
        )
        self.l2c_model.fit(
            discretized_xs_train, ys_train,
            epochs=epochs,
            batch_size=batch_size,
            **fit_kwargs
        )
        self._is_trained = True
        return self
    
    @auto_reshaping('x')
    def generate_cf(
        self,
        x: Array,
        pred_fn: Callable = None,
        y_target: Array = None,
        rng_key: jrand.PRNGKey = None,
        **kwargs
    ) -> Array:
        """Generate a counterfactual for `x` with the trained L2C model.

        `pred_fn` and `y_target` are accepted for interface compatibility but
        are not used. If `rng_key` is omitted, a key derived from the global
        seed is used.
        """
        if rng_key is None:
            # `forward` splits the key; `jrand.split(None)` would fail.
            rng_key = jrand.PRNGKey(get_config().global_seed)

        @jax.jit
        def _generate_cf(x: Array):
            discretized_x = self.discretizer.transform(x)
            cfs, _ = self.l2c_model.forward(rng_key, discretized_x, training=False)
            return self.discretizer.inverse_transform(cfs)
        return _generate_cf(x)

In [None]:
# Load the "adult" data module and its ML module (whose `pred_fn` is explained below).
dm = relax.load_data('adult')
ml_module = relax.load_ml_module('adult')

In [None]:
# Train L2C with the default config and generate counterfactuals for `dm`
# (the training log is printed below).
l2c = L2C()
exp = relax.generate_cf_explanations(
    l2c, dm, ml_module.pred_fn,
)

Epoch 1/10
[1m191/191[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 7ms/step - loss: 0.8718 
Epoch 2/10
[1m191/191[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 697us/step - loss: 0.1734
Epoch 3/10
[1m191/191[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 703us/step - loss: 0.1531
Epoch 4/10
[1m191/191[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 693us/step - loss: 0.1465
Epoch 5/10
[1m191/191[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 684us/step - loss: 0.1387
Epoch 6/10
[1m191/191[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 687us/step - loss: 0.1392
Epoch 7/10
[1m191/191[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 680us/step - loss: 0.1365
Epoch 8/10
[1m191/191[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 689us/step - loss: 0.1390
Epoch 9/10
[1m191/191[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 723us/step - loss: 0.1389
Epoch 10/10
[1m191/191[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s

In [None]:
# Report metrics (acc / validity / proximity) for the generated counterfactuals.
relax.benchmark_cfs([exp])

Unnamed: 0,Unnamed: 1,acc,validity,proximity
adult,l2c,0.827124,0.981235,7.798692


In [None]:
# Generate counterfactuals row by row with `jax.vmap`; keyword arguments are
# mapped over their leading axis too, so each row gets its own PRNG key.
partial_gen = ft.partial(l2c.generate_cf, pred_fn=ml_module.pred_fn)
cfs = jax.vmap(partial_gen)(dm.xs, rng_key=jrand.split(jrand.PRNGKey(0), dm.xs.shape[0]))