## LIME

In [None]:
#| default_exp attr.lime

In [None]:
#| include: false
%load_ext autoreload
%autoreload 2
from ipynb_path import *

In [None]:
#| export
from explainax.imports import *
from explainax.attr.base import *
from explainax.linear_model import Lasso, Ridge, LinearModel

In [None]:
#| export
@partial(jit, static_argnums=(2,))
def pairwise_distances(
    x: Array, # [n, k]
    y: Array, # [m, k]
    metric: str = "euclidean" # Supports "euclidean" and "cosine"
) -> Array: # [n, m]
    """Compute the distance between every row of `x` and every row of `y`.

    `metric` is a static argument of the jitted function, so it must be a
    plain Python string. A `ValueError` is raised for any metric other than
    "euclidean" or "cosine".
    """
    def euclidean_distances(u: Array, v: Array) -> float:
        # ||u - v||^2 expanded as u.u - 2 u.v + v.v
        sq_dist = jnp.dot(u, u) - 2 * jnp.dot(u, v) + jnp.dot(v, v)
        # Rounding can push the expansion slightly below zero; clamp before sqrt.
        # `jnp.maximum` is used instead of `jnp.clip(..., a_min=0.)` because the
        # `a_min` keyword is deprecated in recent JAX releases.
        return jnp.sqrt(jnp.maximum(sq_dist, 0.))

    def cosine_distances(u: Array, v: Array) -> float:
        # The 1e-8 term guards against division by zero for all-zero vectors.
        return 1.0 - jnp.dot(u, v) / (jnp.linalg.norm(u) * jnp.linalg.norm(v) + 1e-8)

    if metric == "euclidean":
        pair_fn = euclidean_distances
    elif metric == "cosine":
        pair_fn = cosine_distances
    else:
        raise ValueError(f"metric='{metric}' not supported")

    # Inner vmap runs over the rows of `y`, outer vmap over the rows of `x`.
    dists_fn = vmap(vmap(pair_fn, in_axes=(None, 0)), in_axes=(0, None))
    return dists_fn(x, y)

This function is similar to 
[sklearn.metrics.pairwise_distances](https://scikit-learn.org/stable/modules/generated/sklearn.metrics.pairwise_distances.html).

In [None]:
from sklearn.metrics import pairwise_distances as sk_pairwise_distances

`pairwise_distances` is faster than sklearn's implementation.

In [None]:
# Random benchmark inputs: 1000 points on each side with 28 * 28 features per point.
X = np.random.normal(size=(1000, 28 * 28))
Y = np.random.normal(size=(1000, 28 * 28))

def benchmark_pairwise_distances(metric):
    """Time sklearn's and the JAX `pairwise_distances` on `X`/`Y` for `metric`, then check that both agree."""
    print(f"[{metric}] Sklearn pairwise_distances:")
    %timeit -n 10 sk_pairwise_distances(X, Y, metric=metric)
    print(f"[{metric}] JAX pairwise_distances:")
    # block_until_ready() makes the timing include the actual device computation.
    %timeit -n 10 pairwise_distances(X, Y, metric=metric).block_until_ready()
    # Correctness check: the JAX result must match sklearn's within `allclose` tolerance.
    assert jnp.allclose(
        sk_pairwise_distances(X, Y, metric=metric),
        pairwise_distances(X, Y, metric=metric)
    )

benchmark_pairwise_distances("euclidean")
benchmark_pairwise_distances("cosine")

[euclidean] Sklearn pairwise_distances:
28.6 ms ± 6.82 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
[euclidean] JAX pairwise_distances:
6.27 ms ± 3.02 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
[cosine] Sklearn pairwise_distances:
29.2 ms ± 6.07 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
[cosine] JAX pairwise_distances:
6.4 ms ± 2.82 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [None]:
#| export
def bernoulli_perturb_func(x: Array, prng_key: jrand.PRNGKey, **kwargs) -> Array:
    """Draw a random binary mask shaped like `x` for LIME.

    Every entry is sampled independently and is `True` with probability 0.5.
    Only the shape of `x` is used; extra keyword arguments are ignored.
    """
    mask_shape = x.shape
    keep_prob = jnp.full(mask_shape, 0.5)
    return jrand.bernoulli(prng_key, p=keep_prob, shape=mask_shape)

def gaussian_perturb_func(x: Array, prng_key: jrand.PRNGKey, **kwargs) -> Array:
    """Draw standard-normal noise shaped like `x` for LIME.

    Only the shape of `x` is used: the samples are not shifted by `x`.
    Extra keyword arguments are ignored.
    """
    noise = jrand.normal(prng_key, shape=x.shape)
    return noise

In [None]:
#| export
def _perturb_data(
    x: Array, # [1, k]
    n_samples: int,
    perturb_func: Callable[[Array, jrand.PRNGKey], Array],
    prng_key: jrand.PRNGKey,
) -> Array:
    """Perturb data using perturb_func"""
    perturbed_data = vmap(jit(perturb_func))(
        jnp.repeat(x, n_samples, axis=0), 
        jrand.split(prng_key, n_samples)
    ) 
    return jnp.concatenate([x, perturbed_data], axis=0)

In [None]:
# Smoke test: perturbing one 784-feature instance 100 times must give 101 rows
# (the instance itself followed by the 100 perturbed samples).
# NOTE: this rebinds `X` to a single-row array.
X = np.random.normal(size=(1, 28 * 28))
b_perturbed = _perturb_data(X, 100, bernoulli_perturb_func, jrand.PRNGKey(42))
g_perturbed = _perturb_data(X, 100, gaussian_perturb_func, jrand.PRNGKey(42))
assert b_perturbed.shape == (101, 28 * 28)
assert g_perturbed.shape == (101, 28 * 28)

In [None]:
#| export
@jit
def exp_kernel_func(dists: Array, kernel_width: float) -> Array:
    """Turn distances into LIME similarity weights with an exponential kernel.

    Computes sqrt(exp(-dists^2 / kernel_width^2) + 1e-8) element-wise.
    """
    exponent = -(dists ** 2) / kernel_width ** 2
    similarity = jnp.exp(exponent) + 1e-8
    return jnp.sqrt(similarity)

In [None]:
# Distance from every row of `g_perturbed` (the instance plus its Gaussian samples) to the instance `X`.
distances = pairwise_distances(g_perturbed, X)
# Time the kernel with kernel_width = 0.75 * n_features.
%timeit -n 10 exp_kernel_func(distances, 0.75 * 28 * 28).block_until_ready()

The slowest run took 389.12 times longer than the fastest. This could mean that an intermediate result is being cached.
273 µs ± 657 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [None]:
#| export
def _lime_attribute_single_instance(
    inputs: Array, # [k]
    n_samples: int,
    rng_key: jrand.PRNGKey,
    bb_func: Callable[[Array], Array],
    additional_func_args: Dict,
    input_paramter_name: str,
    perturb_func: Callable[[Array, jrand.PRNGKey, Any], Array],
    kernel_func: Callable[[Array], Array],
    model_regressor: LinearModel,
    pairwise_distances_metric: str,
) -> Tuple[Array, Array]: # (Local explanation, intercept)
    """Fit a weighted linear surrogate around a single instance.

    The instance is perturbed `n_samples` times, the black-box function is
    evaluated on the instance plus its perturbations, each row is weighted by
    `kernel_func` applied to its distance from the instance, and
    `model_regressor` is fitted on the result. Returns the regressor's
    `coef_` and `intercept_`.

    Raises `ValueError` if the black-box output is not one value per row.
    """
    n_rows = n_samples + 1  # the instance itself plus its perturbations
    instance = inputs.reshape(1, -1)
    n_features = instance.shape[1]

    # Perturbed samples, their distance to the instance, and the black-box outputs
    samples = _perturb_data(instance, n_samples, perturb_func, rng_key)
    dists_to_instance = pairwise_distances(
        samples, instance, metric=pairwise_distances_metric
    ).ravel()
    preds = bb_func(**{input_paramter_name: samples}, **additional_func_args)

    # Accept flat outputs by promoting them to a single column
    if len(preds.shape) == 1:
        preds = preds.reshape(-1, 1)
    expected_pred_shape = (n_rows, 1)
    if preds.shape != expected_pred_shape:
        raise ValueError("Black-box function must output a single value for each instance.")

    # Fit the linear surrogate on the weighted samples
    # TODO: implement feature selection
    # https://github.com/marcotcr/lime/blob/fd7eb2e6f760619c29fca0187c07b82157601b32/lime/lime_base.py#L70
    sample_weights = kernel_func(dists_to_instance)
    assert samples.shape == (n_rows, n_features)
    assert preds.shape == expected_pred_shape
    assert sample_weights.shape == (n_rows,)
    model_regressor.fit(samples, preds, weights=sample_weights)
    return (model_regressor.coef_, model_regressor.intercept_)

In [None]:
#| export
class LimeBase(Attribution):
    """LIME attribution: explains a black-box function with a weighted local linear surrogate per instance."""
    def __init__(
        self,
        func: Callable, # A black-box function to be explained
        additional_func_args: Dict = None, # Additional arguments for the black-box function
        model_regressor = None, # Linear regressor to use in explanation. Defaults to `Ridge(alpha=1)`
        kernal_func: Callable = None, # Kernel function for computing similarity. Defaults to `exp_kernel_func`
        kernel_width: float = None, # Kernel width for computing similarity. Defaults to (n_features * 0.75)
        perturb_func: Callable = None, # Perturbation function for generating perturbed instances. Defaults to `bernoulli_perturb_func`
        input_paramter_name: str = "x", # Name of the input parameter for the black-box function
        pairwise_distances_metric: str = "euclidean", # Metric for computing pairwise distances
    ):
        super().__init__(func, additional_func_args)
        self.bb_func = func
        self.model_regressor = model_regressor if model_regressor is not None else Ridge(alpha=1)
        self.kernal_func = kernal_func if kernal_func is not None else exp_kernel_func
        # Left as None here; the default is resolved in `attribute` from the input width.
        self.kernel_width = kernel_width
        self.perturb_func = perturb_func if perturb_func is not None else bernoulli_perturb_func
        self.x_name = input_paramter_name
        self.metric = pairwise_distances_metric

    def attribute(
        self,
        inputs: Array, # [n_instances, n_features]
        n_samples: int = 100, # Number of perturbed samples drawn per instance
        rng_key: jrand.PRNGKey = None, # Defaults to a key built from the configured global seed
        **kwargs # Forwarded to the perturbation function
    ) -> Tuple[Array, Array]: # (Local explanations, intercepts)
        """Compute LIME attributions for every row of `inputs`.

        Each instance is explained with its own key split from `rng_key`, so the
        perturbation samples are independent across instances.

        Raises `ValueError` if `inputs` is not two-dimensional.
        """
        if len(inputs.shape) != 2:
            raise ValueError("Inputs shape must be (n_instances, n_features).")
        if rng_key is None:
            rng_key = jrand.PRNGKey(get_config().global_seed)

        kernel_width = self.kernel_width if self.kernel_width is not None else inputs.shape[-1] * 0.75
        kernel_func = partial(self.kernal_func, kernel_width=kernel_width)
        perturb_func = partial(self.perturb_func, **kwargs)
        additional_func_args = self.additional_func_args if self.additional_func_args is not None else {}

        # NOTE(review): the same regressor object is fitted inside `vmap`; its fitted
        # attributes after this call are presumably tracers - confirm before reading them.
        def explain_one(instance: Array, instance_key: jrand.PRNGKey) -> Tuple[Array, Array]:
            return _lime_attribute_single_instance(
                instance,
                n_samples=n_samples,
                rng_key=instance_key,
                bb_func=self.bb_func,
                additional_func_args=additional_func_args,
                input_paramter_name=self.x_name,
                perturb_func=perturb_func,
                kernel_func=kernel_func,
                model_regressor=self.model_regressor,
                pairwise_distances_metric=self.metric,
            )

        # Reusing a single key for every instance would hand all of them the same
        # random stream; split it so each instance gets its own.
        instance_keys = jrand.split(rng_key, inputs.shape[0])
        exp, intercept = vmap(explain_one)(inputs, instance_keys)
        return (exp, intercept)

#### Tests

In [None]:
from sklearn.datasets import make_regression
import haiku as hk

In [None]:
# Synthetic regression data; the fixed seed keeps the data identical across runs.
xs, ys = make_regression(n_samples=500, n_features=20, random_state=0)

In [None]:
# Fit a LinearModel on the synthetic data and wrap its `predict` in LimeBase.
# LimeBase will call `predict` with the inputs passed as the keyword argument `X`.
linear_model = LinearModel()
linear_model.fit(xs, ys)
lime = LimeBase(linear_model.predict, input_paramter_name="X")

In [None]:
# fit a simple haiku model
def model(x):
    """Forward pass of a small MLP: two hidden layers of width 10 with ReLU, then one output unit."""
    layers = [
        hk.Linear(10),
        jax.nn.relu,
        hk.Linear(10),
        jax.nn.relu,
        hk.Linear(1),
    ]
    return hk.Sequential(layers)(x)


def init(x):
    """Build the transformed network and an SGD optimizer, and initialise parameters and optimizer state from `x`."""
    net = hk.without_apply_rng(hk.transform(model))
    params = net.init(jrand.PRNGKey(42), x)
    opt = optax.sgd(1e-1)
    opt_state = opt.init(params)
    return net, opt, params, opt_state

def loss(params, net, x, y):
    """Mean squared error between the network's predictions on `x` and the targets `y`.

    Predictions and targets are flattened before subtracting. Without this, a
    prediction of shape (n, 1) minus a target of shape (n,) broadcasts to an
    (n, n) matrix and the error is averaged over all prediction/target pairs.
    """
    pred = net.apply(params, x)
    residuals = jnp.reshape(pred, (-1,)) - jnp.reshape(y, (-1,))
    return jnp.mean(residuals ** 2)

@partial(jax.jit, static_argnums=(2,3))
def update(
    params: hk.Params,
    opt_state: optax.OptState,
    net: hk.Transformed,
    opt: optax.GradientTransformation,
    x: jnp.ndarray,
    y: jnp.ndarray
):
    """Run one optimizer step on the batch `(x, y)` and return the new `(params, opt_state)`.

    `net` and `opt` are static arguments of the jitted function.
    """
    loss_grads = jax.grad(loss)(params, net, x, y)
    param_updates, next_opt_state = opt.update(loss_grads, opt_state)
    next_params = optax.apply_updates(params, param_updates)
    return next_params, next_opt_state

def train(
    net: hk.Transformed,
    opt: optax.GradientTransformation,
    params: hk.Params,
    opt_state: optax.OptState,
    x: jnp.ndarray,
    y: jnp.ndarray,
    n_epochs: int = 100,
    batch_size: int = 32,
):
    """Train for `n_epochs` passes over `(x, y)` in consecutive mini-batches and return the final parameters.

    Batches are taken in order without shuffling; the last batch may be smaller than `batch_size`.
    """
    n_rows = x.shape[0]
    batch_starts = range(0, n_rows, batch_size)
    for _epoch in range(n_epochs):
        for start in batch_starts:
            stop = start + batch_size
            params, opt_state = update(
                params, opt_state, net, opt, x[start:stop], y[start:stop]
            )
    return params

def fit_a_model(
    X: Array,
    y: Array,
):
    """Initialise the network on `X`, train it on `(X, y)`, and return `(net, trained_params)`."""
    net, opt, init_params, init_opt_state = init(X)
    trained_params = train(net, opt, init_params, init_opt_state, X, y)
    return net, trained_params


In [None]:
# Train the haiku MLP on the synthetic regression data.
net, params = fit_a_model(xs, ys)

  param = init(shape, dtype)


In [None]:
# Predictions of the trained network for the first ten rows.
# NOTE(review): the recorded output is the same value for all ten rows - worth checking that training actually learned something.
net.apply(params, xs[:10])

DeviceArray([[0.24004635],
             [0.24004635],
             [0.24004635],
             [0.24004635],
             [0.24004635],
             [0.24004635],
             [0.24004635],
             [0.24004635],
             [0.24004635],
             [0.24004635]], dtype=float32)

In [None]:
# Explain one instance of the trained MLP with Gaussian perturbations.
kernel_func = partial(exp_kernel_func, kernel_width=2 * 0.75)

# Use a row of `xs` (20 features, matching the network's input width);
# `X` was rebound earlier to a (1, 784) array and does not fit this model.
_lime_attribute_single_instance(
    xs[:1],
    1000,
    jrand.PRNGKey(42),
    net.apply,
    additional_func_args={"params": params},
    input_paramter_name="x",
    perturb_func=gaussian_perturb_func,
    kernel_func=kernel_func,
    model_regressor=Ridge(alpha=1),
    pairwise_distances_metric="euclidean",
)

(DeviceArray([-5.97378996e-04,  4.81305433e-05,  1.55211031e-03,
              -3.89112538e-04, -1.21736361e-04, -1.05082065e-04,
               1.35615395e-04, -4.70238097e-04,  5.06596552e-05,
               5.96614438e-04,  8.60058644e-04,  1.57593074e-03,
              -1.35515491e-03,  7.79002556e-04,  8.51082150e-04,
               1.22845857e-04,  6.43223408e-04,  6.08801842e-04,
              -3.02861667e-06,  1.18552020e-03], dtype=float32),
 DeviceArray([0.24033982], dtype=float32))

In [None]:
# Explain every row of the regression data with the default LimeBase settings.
lime = LimeBase(
    func=net.apply,
    additional_func_args={"params": params},
)
# `xs` has the 20 features the network was trained on; `X` was rebound earlier to a (1, 784) array.
lime.attribute(xs)

(DeviceArray([[ 0.00221086,  0.00581249,  0.00237516, ...,  0.0043735 ,
                0.00298327,  0.00597063],
              [ 0.00241475,  0.00598198,  0.00254859, ...,  0.00455689,
                0.00296534,  0.00607178],
              [-0.00225598, -0.00591312, -0.00256158, ..., -0.00448116,
               -0.00301292, -0.00606422],
              ...,
              [ 0.00223549,  0.00583388,  0.00241553, ...,  0.00438501,
                0.00298516,  0.00600691],
              [ 0.00221306,  0.00573398,  0.00225197, ...,  0.00415604,
                0.00285654,  0.00581799],
              [ 0.00242912,  0.00596596,  0.00257836, ...,  0.00450296,
                0.0030849 ,  0.00620169]], dtype=float32),
 DeviceArray([[0.24730496],
              [0.24756414],
              [0.23204625],
              [0.24731565],
              [0.24734785],
              [0.24736024],
              [0.24726894],
              [0.24740964],
              [0.24748607],
              [0.2472707 ],
