# L2C

https://arxiv.org/abs/2209.13446

In [None]:
#| default_exp methods.l2c

In [None]:
#| include: false
%load_ext autoreload
%autoreload 2
from ipynb_path import *
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)

In [None]:
#| export
from __future__ import annotations
from relax.import_essentials import *
from relax.methods.base import ParametricCFModule
from relax.base import BaseConfig
from relax.utils import *
from relax.data_utils import Feature, FeaturesList
from relax.ml_model import MLP, MLPBlock
from relax.data_module import DataModule
from keras_core.random import SeedGenerator
import einops

Using JAX backend.


An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.


In [None]:
#| hide
import torch

In [None]:
#| export
def gumbel_softmax(
    key: jrand.PRNGKey, # Random key
    logits: Array, # Logits for each class. Shape (batch_size, num_classes)
    tau: float, # Temperature for the Gumbel softmax
):
    """Draw a relaxed one-hot sample with the Gumbel-softmax trick.

    Gumbel noise of the same shape as `logits` is added to the logits, the
    sum is divided by the temperature `tau`, and a softmax is taken over the
    last axis. Each slice along the last axis of the result sums to one.
    """

    # Perturb first, then temper and normalise along the class axis.
    noisy_logits = jrand.gumbel(key, shape=logits.shape) + logits
    tempered = noisy_logits / tau
    return jax.nn.softmax(tempered, axis=-1)

In [None]:
#| export
def sample_categorical(
    key: jrand.PRNGKey, # Random key
    logits: Array, # Logits for each class. Shape (batch_size, num_classes)
    tau: float, # Temperature for the Gumbel softmax
    training: bool = True, # Apply gumbel softmax if training
):
    """Sample from a categorical distribution.

    When `training` is true the result is a relaxed sample produced by
    `gumbel_softmax`; otherwise a category is drawn with
    `jrand.categorical` and returned as a one-hot vector over the last axis.
    """

    def _relaxed_branch(_):
        # Soft sample controlled by the temperature `tau`.
        return gumbel_softmax(key, logits, tau=tau)

    def _discrete_branch(_):
        # Hard sample: draw class ids, then one-hot encode them.
        class_ids = jrand.categorical(key, logits=logits, axis=-1)
        n_classes = logits.shape[-1]
        return jax.nn.one_hot(class_ids, n_classes)

    return lax.cond(training, _relaxed_branch, _discrete_branch, None)

In [None]:
# Smoke tests for `sample_categorical` on a 2x3 batch of logits.
logits = jnp.array([[2.0, 1.0, 0.1], [1.0, 2.0, 3.0]])
key = jrand.PRNGKey(0)
# Training path: output keeps the logits' shape and every row sums to 1.
output = sample_categorical(key, logits, tau=0.5, training=True)
assert output.shape == logits.shape
assert jnp.allclose(output.sum(axis=-1), 1.0)
# low temperature -> one-hot
# NOTE(review): Gumbel noise is still added, so matching the logits' argmax
# is expected for this key rather than guaranteed for every key.
output = sample_categorical(key, logits, tau=0.01, training=True)
assert jnp.array_equal(
    output.argmax(axis=-1), logits.argmax(axis=-1)
)
# high temperature -> uniform
output = sample_categorical(key, logits, tau=100, training=True)
assert jnp.max(output) - jnp.min(output) < 0.5

# Evaluation path: a hard one-hot sample with the same shape as the logits.
output = sample_categorical(key, logits, tau=0.5, training=False)
assert output.shape == logits.shape
# NOTE(review): this path draws a random category via `jrand.categorical`,
# so agreement with the argmax depends on the fixed key used here.
assert jnp.array_equal(
    output.argmax(axis=-1), logits.argmax(axis=-1)
)

In [None]:
#| export
def sample_bernouli(
    key: jrand.PRNGKey, # Random key
    prob: Array, # Probability of drawing 1 for each entry
    tau: float, # Temperature for the Gumbel softmax
    training: bool = True, # Apply gumbel softmax if training
) -> Array:
    """Sample from a Bernoulli distribution.

    When `training` is true a relaxed (binary Gumbel-softmax) sample in
    (0, 1) is returned so gradients can flow through `prob`; otherwise a hard
    0/1 sample is drawn with `jrand.bernoulli` and cast to `prob.dtype`.

    The relaxed sample is
    `sigmoid((log(p) + g1 - log(1 - p) - g2) / tau)`, which equals
    `(p e^{g1})^{1/tau} / ((p e^{g1})^{1/tau} + ((1 - p) e^{g2})^{1/tau})`.
    Lower `tau` pushes the sample towards {0, 1}.
    """

    def sample_ber(key, prob):
        return jrand.bernoulli(key, p=prob).astype(prob.dtype)

    def gumbel_ber(key, prob, tau):
        key_1, key_2 = jrand.split(key)
        gumbel_1 = jrand.gumbel(key_1, shape=prob.shape)
        gumbel_2 = jrand.gumbel(key_2, shape=prob.shape)
        # Keep the logs (and their gradients) finite when prob saturates
        # at exactly 0 or 1.
        eps = 1e-7
        safe_prob = jnp.clip(prob, eps, 1.0 - eps)
        # The temperature must scale the *logits*. Dividing both the
        # numerator and the denominator of the ratio by `tau` cancels out
        # and leaves the sample independent of the temperature.
        logit_gap = (
            (jnp.log(safe_prob) + gumbel_1)
            - (jnp.log1p(-safe_prob) + gumbel_2)
        )
        return jax.nn.sigmoid(logit_gap / tau)

    return lax.cond(
        training,
        lambda _: gumbel_ber(key, prob, tau),
        lambda _: sample_ber(key, prob),
        None,
    )

In [None]:
#| export
class L2CModel(keras.Model):
    """Generator/selector model used by L2C to produce counterfactuals.

    A generator MLP proposes replacement values for every input column and a
    selector MLP outputs one probability per feature. For each feature slice
    `(start, end)` in `feature_indices`, the counterfactual is the mix
    `generated * p + inputs * (1 - p)`, where `p` is a (relaxed) Bernoulli
    sample drawn from the selector probability of that feature.
    """

    def __init__(
        self,
        generator_layers: list[int],
        selector_layers: list[int],
        feature_indices: list[tuple[int, int]] = None,
        pred_fn: Callable = None,
        alpha: float = 1e-4, # Sparsity regularization
        tau: float = 0.7,
        seed: int = None,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.generator_layers = generator_layers
        self.selector_layers = selector_layers
        self.feature_indices = feature_indices
        self.pred_fn = pred_fn
        self.tau = tau
        self.alpha = alpha
        # `seed or default` would silently discard an explicit seed of 0.
        if seed is None:
            seed = get_config().global_seed
        self.seed_generator = SeedGenerator(seed)

    def set_features_info(self, feature_indices: list[tuple[int, int]]):
        """Set the `(start, end)` column slice of every feature."""
        self.feature_indices = feature_indices
        # TODO: check if the feature indices are valid

    def set_pred_fn(self, pred_fn: Callable):
        """Set the prediction function used to score counterfactuals."""
        self.pred_fn = pred_fn

    def build(self, input_shape):
        """Create the generator and selector networks.

        Raises `ValueError` if `feature_indices` has not been provided.
        """
        if self.feature_indices is None:
            raise ValueError(
                "`feature_indices` is not set; pass it to the constructor "
                "or call `set_features_info` before building the model."
            )
        n_feats = len(self.feature_indices)
        # One logit per input column.
        self.generator = MLP(
            sizes=self.generator_layers,
            output_size=input_shape[-1],
            dropout_rate=0.0,
            last_activation="linear",
        )
        # One selection probability per *feature*: `call` indexes the
        # selector output by feature position, not by input column.
        self.selector = MLP(
            sizes=self.selector_layers,
            output_size=n_feats,
            dropout_rate=0.0,
            last_activation="sigmoid",
        )

    # NOTE(review): this name shadows `keras.Model.compute_loss`, which the
    # built-in Keras training step calls with `x`/`y`/`y_pred` keywords --
    # confirm the training loop never goes through that code path.
    def compute_loss(self, inputs, cfs, probs):
        """Return the validity loss plus the sparsity penalty.

        The target label is the argmin of `pred_fn(inputs)` over the last
        axis; the validity term is the sparse categorical cross-entropy
        between that target and `pred_fn(cfs)`. The sparsity term is `alpha`
        times the sum of absolute values of `probs`.

        Raises `ValueError` if `pred_fn` has not been provided.
        """
        if self.pred_fn is None:
            raise ValueError(
                "`pred_fn` is not set; pass it to the constructor or call "
                "`set_pred_fn` before calling the model."
            )
        y_target = self.pred_fn(inputs).argmin(axis=-1)
        y_pred = self.pred_fn(cfs)
        validity_loss = keras.losses.sparse_categorical_crossentropy(
            y_target, y_pred
        )
        # Entry-wise L1 norm. `jnp.linalg.norm` has no `p` keyword (its
        # parameter is `ord`, and `ord=1` on a matrix is the max column sum).
        sparsity = jnp.abs(probs).sum() * self.alpha
        return validity_loss + sparsity

    def call(self, inputs, training=False):
        """Generate counterfactuals for `inputs` and register the loss."""

        def perturb(cfs, probs, i, start, end):
            # Blend generated and original values of feature `i`.
            return (
                cfs[:, start:end] * probs[:, i : i + 1] +
                inputs[:, start:end] * (1 - probs[:, i : i + 1])
            )

        select_probs = self.selector(inputs, training=training)
        # `SeedGenerator` is not callable; a fresh seed comes from `.next()`.
        probs = sample_bernouli(
            self.seed_generator.next(), select_probs,
            tau=self.tau, training=training
        )
        cfs_logits = self.generator(inputs, training=training)
        # NOTE(review): the categorical sample is taken over the full
        # generator output rather than per feature slice -- confirm intent.
        cfs = sample_categorical(
            self.seed_generator.next(), cfs_logits,
            tau=self.tau, training=training
        )
        cfs = jnp.concatenate([
                perturb(cfs, probs, i, start, end)
                for i, (start, end) in enumerate(self.feature_indices)
            ], axis=-1,
        )
        loss = self.compute_loss(inputs, cfs, probs)
        self.add_loss(loss)
        return cfs


In [None]:
#| export
def qcut(
    x: Array, # Input array
    q: int, # Number of quantiles
    axis: int = 0, # Axis to quantile
) -> tuple[Array, Array]: # (digitized array, quantiles)
    """Quantile binning.

    Computes the `q - 1` interior quantiles of `x` (levels
    `1/q, ..., (q-1)/q`) along `axis` and assigns every element of `x` the
    index of the bin it falls into via `jnp.digitize`.

    For arrays with at most one element no quantiles are computed: the
    result is an all-zero integer array with the shape of `x` and an empty
    quantile array.
    """

    # Handle edge cases: empty array or single element.
    if x.size <= 1:
        # Bin ids are integers on the regular path, so return integers here
        # too instead of zeros in the dtype of `x`.
        return jnp.zeros(x.shape, dtype=jnp.int32), jnp.array([])
    # Drop the 0 and 1 levels: only the interior cut points separate bins.
    levels = jnp.linspace(0, 1, q + 1)[1:-1]
    quantiles = jnp.quantile(x, levels, axis=axis)
    return jnp.digitize(x, quantiles), quantiles

In [None]:
# Smoke tests for `qcut`.
# Quartiles of 0..9: three interior cut points and bin ids 0..3.
digitized, quantiles = qcut(jnp.arange(10), 4)
assert digitized.shape == (10,)
assert quantiles.shape == (3,)
assert digitized.min() == 0
assert digitized.max() == 3
assert jnp.allclose(
    quantiles, jnp.array([2.25, 4.5, 6.75])
)
# Test with empty array: both outputs are empty
x_empty = jnp.array([])
q = 2
digitized_empty, quantiles_empty = qcut(x_empty, q)
assert digitized_empty.size == 0 and quantiles_empty.size == 0
# Test with single element array
x_single = jnp.array([1])
digitized_single, quantiles_single = qcut(x_single, q)
assert digitized_single.size == 1 and quantiles_single.size == 0

# Test with large q value
# More quantiles than data points still yields q - 1 cut points.
xs = jnp.array([1, 2, 3, 4, 5, 6])
q_large = 10
_, quantiles_large = qcut(xs, q_large)
assert len(quantiles_large) == q_large - 1

In [None]:
@dataclass
class Discretizer:
    """Column-wise quantile discretizer.

    `fit` stores, for every column of a 2D array, the column minimum followed
    by the `q - 1` interior quantiles, i.e. the lower edge of each of the `q`
    bins, as `quantiles` with shape `(n_feats, q)`.
    """

    q: int = 4 # Number of quantiles

    def fit(self, xs: Array):
        """Compute and store the per-column bin edges; returns `self`."""
        _, inner_quantiles = jax.vmap(qcut, in_axes=(1, None))(xs, self.q)
        # Prepend the column minimum so that entry `k` of a row is the lower
        # edge of bin `k`.
        self.quantiles = jnp.concatenate(
            [xs.min(axis=0).reshape(-1, 1), inner_quantiles], axis=-1
        ) # (n_feats, q)
        return self

    def transform(self, xs: Array):
        """Map every value to its bin index, column by column.

        Raises `ValueError` if the discretizer has not been fitted.
        """
        if not hasattr(self, "quantiles"):
            raise ValueError("Call fit before transform")
        # Only the interior cut points act as bin boundaries.
        digitized = jax.vmap(jnp.digitize, in_axes=(1, 0))(
            xs, self.quantiles[:, 1:]
        ).T
        return digitized

    def fit_transform(self, xs: Array):
        """Fit on `xs` and return its bin indices."""
        return self.fit(xs).transform(xs)

    def inverse_transform(self, xs: Array):
        """Map bin indices back to the lower edge of their bin.

        `xs` has one column per feature; column `j` is looked up in row `j`
        of `quantiles`, so the output has the same shape as `xs`. Indices are
        assumed to lie in `[0, q - 1]`.

        Raises `ValueError` if the discretizer has not been fitted.
        """
        if not hasattr(self, "quantiles"):
            raise ValueError("Call fit before inverse_transform")
        bin_ids = jnp.asarray(xs).astype(jnp.int32)
        # Indexing `quantiles[xs]` directly would select *feature rows* by
        # bin id; gather along the edge axis per feature instead.
        return jnp.take_along_axis(self.quantiles, bin_ids.T, axis=1).T

In [None]:
# Smoke tests for `Discretizer` on a 5x8 array (5 rows, 8 feature columns).
xs = jnp.arange(40).reshape(5, 8)
discretizer = Discretizer(q=4)
discretizer.fit(xs)
# One row of bin edges per feature, q edges each.
assert discretizer.quantiles.shape == (8, 4)
digitized = discretizer.transform(xs)
assert digitized.shape == (5, 8)

# Every column is expected to map to bin ids [0, 1, 2, 3, 3].
assert jnp.allclose(
    digitized,
    einops.repeat(jnp.array([0, 1, 2, 3, 3]), 'i -> j i', j=8).T
)

# `fit_transform` must agree with `fit` followed by `transform`.
digitized_ft = discretizer.fit_transform(xs)
assert jnp.allclose(
    digitized, digitized_ft
)


In [None]:
class L2CConfig(BaseConfig):
    """Configuration for the L2C counterfactual method.

    Every field has a default so that `L2CConfig()` can be created without
    arguments, which is what `L2C.__init__` does when no config is passed.
    """

    # NOTE(review): the default layer sizes below are placeholders chosen so
    # the config is constructible -- confirm the intended architecture.
    generator_layers: list[int] = Field(
        default_factory=lambda: [64, 64, 64],
        description="Hidden layer sizes of the generator MLP.",
    )
    selector_layers: list[int] = Field(
        default_factory=lambda: [64],
        description="Hidden layer sizes of the selector MLP.",
    )
    alpha: float = Field(1e-4, description="Sparsity regularization.")
    tau: float = Field(0.7, description="Temperature for the Gumbel softmax.")

In [None]:
def discretize_xs(
    xs: Array, # Input array
    datamodule: DataModule, # Data module; `features.features_and_indices` is iterated
    q: int = 4, # Number of quantiles
):
    """Work-in-progress helper to discretize the continuous columns of `xs`.

    NOTE(review): this function is unfinished. It has no `return` statement
    (so it returns `None`), the result of `qcut` for continuous features is
    discarded, and `cont_quantiles` / `feature_indices` are never filled.
    """
    # TODO
    discretized_xs = []
    cont_quantiles = []
    feature_indices = []

    # Each item is unpacked as (feature, (start, end)); the pair is used as
    # a column slice into `xs`.
    for feat, (start, end) in datamodule.features.features_and_indices:
        if feat.is_continuous:
            # Binned values and cut points are computed but not stored yet.
            discretized, quantiles = qcut(xs[:, start:end], q=q)

            
        else:
            # Non-continuous features are passed through unchanged.
            discretized_xs.append(xs[:, start:end])

In [None]:
class L2C(ParametricCFModule):
    def __init__(
        self,
        config: Dict | L2CConfig = None,
        l2c_model: L2CModel = None,
        name: str = "l2c",
    ):
        if config is None:
            config = L2CConfig()
        config = validate_configs(config, L2CConfig)
        name = name or "l2c"
        self.l2c_model = l2c_model
        super().__init__(config=config, name=name)

    def train(
        self, 
        data: DataModule, 
        pred_fn: Callable = None,
        batch_size: int = 128,
        epochs: int = 10,
        **fit_kwargs
    ):
        if not isinstance(data, DataModule):
            raise ValueError(f"Only support `data` to be `DataModule`, "
                             f"got type=`{type(data).__name__}` instead.")
        xs_train, ys_train = data['train']
        