## LIME

In [None]:
#| default_exp attr.lime

In [None]:
#| include: false
%load_ext autoreload
%autoreload 2
from ipynb_path import *

In [None]:
#| export
from jax_interpret.imports import *
from jax_interpret.attr.base import *
from jax_interpret.linear_model import Lasso, Ridge

In [None]:
def pairwise_distances(
    x: jnp.ndarray, # [n, k]
    y: jnp.ndarray, # [m, k]
):
    """Placeholder for a pairwise-distance routine between `x` and `y`.

    Not implemented yet: the function performs no computation and every
    call returns `None`. The shape comments on the parameters suggest the
    intended inputs are two 2-D arrays sharing their last dimension
    (TODO confirm the metric and output shape when implementing).
    """
    return None

In [None]:
def default_perturb_func(x: jnp.ndarray, prng_key: jrand.PRNGKey, **kwargs) -> jnp.ndarray:
    """Default perturbation function for LIME.

    Draws an independent Bernoulli(0.5) sample for every element of `x`,
    producing a random boolean mask with the same shape as `x`.

    Args:
        x: Input array; only its shape is used.
        prng_key: JAX PRNG key used to draw the samples.
        **kwargs: Accepted for interface compatibility and ignored.

    Returns:
        A boolean array of shape `x.shape`.
    """
    # `jax.random.bernoulli` names its probability argument `p` (not `probs`).
    # A scalar probability broadcasts to any requested shape, so both 1-D and
    # batched 2-D inputs are supported.
    return jrand.bernoulli(prng_key, p=0.5, shape=x.shape)


In [None]:
class LimeBase(Attribution):
    """Base class for LIME-style attribution.

    Holds the black-box function together with the components configured for
    the explanation: a linear surrogate regressor, a similarity kernel and a
    perturbation function.
    """

    def __init__(
        self,
        func: Callable, # A black-box function to be explained
        additional_func_args: Dict = None, # Additional arguments for the black-box function
        model_regressor = None, # Linear regressor to use in explanation
        kernal_func: Callable = None, # Kernel function for computing similarity
        perturb_func: Callable = None, # Perturbation function for generating perturbed instances
    ): 
        super().__init__(func, additional_func_args)
        # Fall back to a ridge regressor when no surrogate model is supplied.
        self.model_regressor = model_regressor if model_regressor is not None else Ridge(alpha=1)
        # Keep the user-supplied callables; previously these arguments were
        # accepted but silently dropped. The parameter keeps its original
        # (misspelt) name `kernal_func` so existing keyword callers still work.
        self.kernal_func = kernal_func
        # Fall back to the module's default perturbation when none is given.
        self.perturb_func = perturb_func if perturb_func is not None else default_perturb_func

    def attribute_single(self, inputs: jnp.ndarray, **kwargs) -> jnp.ndarray:
        """Compute attribution for a single input.

        A 1-D input is reshaped to a single-row 2-D array; the input must
        then be 2-D with exactly one row.

        NOTE: only the input validation is implemented so far; after
        validating, the method falls through and returns `None`.

        Raises:
            ValueError: If `inputs` is not 1-D or 2-D, or holds more than
                one instance.
        """
        if len(inputs.shape) == 1:
            # Promote a flat feature vector to shape (1, k).
            inputs = inputs.reshape(1, -1)
        if len(inputs.shape) != 2:
            raise ValueError("Input must be a 1D or 2D array")
        if inputs.shape[0] != 1:
            raise ValueError("Input must be a single instance")

    def attribute(self, inputs: jnp.ndarray, **kwargs) -> jnp.ndarray:
        """Compute attribution for a given input.

        Raises:
            NotImplementedError: Always; this base class provides no
                implementation.
        """
        raise NotImplementedError