In [None]:
#| default_exp influence.if

In [None]:
#| export
from explainax.imports import *
from explainax.influence.base import Influence
from sklearn.datasets import make_classification
from sklearn import linear_model
import haiku as hk
import jax_dataloader as jdl

In [None]:
#| export
class BaseIF(Influence):
    """Base class for influence-function (IF) explainers.

    Subclasses must override `ihvp`.
    """

    def __init__(
        self, 
        func: Callable, # A black-box function to be explained
        params, # Parameters of the black-box function
        train_dataset: Tuple[Array, Array], # Training dataset
        additional_func_args: Dict = None, # Additional arguments for the black-box function
        input_paramter_name: str = "x", # Name of the input parameter for the black-box function
        **kwargs
    ):
        super().__init__(func, additional_func_args, train_dataset)
        # Keep the arguments that the base `Influence` constructor does not receive;
        # previously they were silently dropped, so subclasses had no access to them.
        self.params = params
        self.input_parameter_name = input_paramter_name

    def ihvp(self, vec: Array) -> Array:
        """Inverse-Hessian-vector product (per the method name); must be overridden."""
        raise NotImplementedError
    

In [None]:
#| export
# https://github.com/pomonam/jax-influence/blob/main/jax_influence/utils.py
def flatten(params):
    """Flatten a pytree of parameters into a single 1-D array.

    Uses `jax.flatten_util.ravel_pytree`, which ravels every leaf and concatenates
    them in pytree traversal order; the unravel function it also returns is discarded.
    """
    # Import the submodule explicitly: `import jax` alone does not guarantee that
    # `jax.flatten_util` is available as an attribute.
    from jax.flatten_util import ravel_pytree

    flat_params, _unravel = ravel_pytree(params)
    return flat_params

def leaves_to_jndarray(pytree):
    """Converts leaves of pytree to jax.numpy arrays.

    The pytree structure is preserved; only the leaves are passed through `jnp.array`.
    """
    # `jax.tree_map` is a deprecated alias (removed in newer JAX releases);
    # `jax.tree_util.tree_map` is the stable spelling.
    return jax.tree_util.tree_map(jnp.array, pytree)

@partial(jit, static_argnums=(0,))
def hvp(f, primals, tangents):
    """Hessian-vector product of a scalar-valued `f`, computed forward-over-reverse.

    The JVP of `jax.grad(f)` at `primals` along `tangents` equals
    H(primals) @ tangents, so the full Hessian is never built. `f` is a static
    (hashable) argument of the jitted function.
    """
    grad_f = jax.grad(f)
    _grad_at_primals, hessian_times_tangent = jax.jvp(grad_f, primals, tangents)
    return hessian_times_tangent

def hessian(f):
    """Return a jitted function computing the Hessian of `f` w.r.t. its first argument.

    `jax.hessian` is forward-over-reverse differentiation, i.e. `jacfwd(jacrev(f))`.
    """
    hessian_fn = jax.hessian(f)
    return jax.jit(hessian_fn)

In [None]:
def train_loss_func(params, model, x, y):
    """Placeholder for a training-loss function; not implemented (body is `...`, returns None).

    NOTE(review): this argument order (params, model, x, y) differs from the way
    `AutogradIF` invokes its `train_loss` (as `train_loss(func, *batch, params)`)
    and from `xe_loss(model, params, inputs, targets)` — confirm the intended convention.
    """
    ...

In [None]:
class AutogradIF(Influence):
    """Influence functions backed by an explicitly materialised, batch-averaged Hessian.

    On construction, the Hessian of `train_loss` with respect to the flattened
    `params` is computed with JAX autodiff on every full mini-batch of
    `train_dataset` (the incomplete last batch is dropped) and averaged over batches.
    """

    def __init__(
        self, 
        func: Callable, # A black-box function to be explained
        params, # Parameters of the black-box function
        train_dataset: Tuple[Array, Array], # Training dataset
        train_loss: Callable, # Loss function for training, called as `train_loss(func, *batch, params)`
        additional_func_args: Dict = None, # Additional arguments for the black-box function
        input_paramter_name: str = "x", # Name of the input parameter for the black-box function
        batch_size: int = 100, # Mini-batch size used when accumulating the Hessian
        **kwargs
    ):
        from jax.flatten_util import ravel_pytree

        super().__init__(func, additional_func_args, train_dataset)
        self.params = params

        # Differentiate w.r.t. a flat parameter vector so that every per-batch
        # Hessian is a plain (d, d) array that can be summed; pytree-valued
        # Hessians (e.g. nested Haiku dicts) do not support `+=`.
        flat_params, unravel = ravel_pytree(params)

        def _batch_loss(flat_p, *batch):
            # Same calling convention as the original
            # `partial(train_loss, func, *batch)(params)`.
            return train_loss(func, *batch, unravel(flat_p))

        # Build the jitted Hessian once, outside the loop: a fresh closure per
        # batch would be re-traced and re-compiled on every iteration.
        batch_hessian = hessian(_batch_loss)

        train_dl = jdl.DataLoader(train_dataset, batch_size=batch_size, drop_last=True)
        dim = flat_params.shape[0]
        hess_sum = jnp.zeros((dim, dim), dtype=flat_params.dtype)
        n_batches = 0
        for batch in train_dl:
            hess_sum = hess_sum + batch_hessian(flat_params, *batch)
            n_batches += 1
        if n_batches == 0:
            raise ValueError(
                f"train_dataset yields no full batch of size {batch_size}; "
                "use a smaller batch_size."
            )
        # Averaged (d, d) Hessian, in the parameter order of `ravel_pytree(params)`.
        self.hess = hess_sum / n_batches

    def ihvp(self, vec: Array) -> Array:
        """Return H^{-1} @ vec by solving `self.hess @ out = vec`.

        `vec` must be laid out like `ravel_pytree(params)` (leading dimension d).
        NOTE: no damping is applied, so a singular Hessian yields inf/nan results.
        """
        return jnp.linalg.solve(self.hess, vec)
        

In [None]:
def load_dummy_data(num_samples=50000, num_features=10, seed=0, train_ratio=0.8):
    """Generate a synthetic binary-classification dataset and split it into train/test.

    Args:
        num_samples: Total number of samples produced by `make_classification`.
        num_features: Number of features; all are informative (no redundant/repeated ones).
        seed: Seed for both the data generator and the train/test permutation.
        train_ratio: Fraction of samples (rounded) assigned to the training split.

    Returns:
        (x_train, y_train, x_test, y_test); labels have a trailing axis of size 1.

    Raises:
        ValueError: If `train_ratio` is outside [0, 1].
    """
    if not 0.0 <= train_ratio <= 1.0:
        raise ValueError(f"train_ratio must be in [0, 1], got {train_ratio}")
    state = np.random.RandomState(seed)

    x, y = make_classification(num_samples, num_features, random_state=seed, n_informative=num_features, n_redundant=0, n_repeated=0, n_classes=2, n_clusters_per_class=1)
    y = np.expand_dims(y, -1)

    n_total = x.shape[0]
    # Random permutation of all sample indices (same RNG call as before, so splits are unchanged).
    permutation = state.choice(np.arange(n_total), n_total, replace=False)
    size_train = int(np.round(n_total * train_ratio))
    index_train = permutation[:size_train]
    index_val = permutation[size_train:]
    x_train, y_train = x[index_train, :], y[index_train, :]
    x_test, y_test = x[index_val, :], y[index_val, :]

    return x_train, y_train, x_test, y_test

In [None]:
class BinaryLogisticRegression(hk.Module):
    """Linear layer followed by a sigmoid: outputs the probability of the positive class.

    Args:
        out_features: Output size of the linear layer (1 for binary classification).
        bias: Whether the linear layer has a bias term.
        name: Optional Haiku module name.
    """

    def __init__(self, out_features, bias=False, name=None):
        super().__init__(name=name)
        self.bias = bias
        self.out_features = out_features

    def __call__(self, x):
        logits = hk.Linear(self.out_features, with_bias=self.bias)(x)
        # `xe_loss` takes log(outputs) and log(1 - outputs), so the outputs must be
        # probabilities in (0, 1), not raw logits.
        return jax.nn.sigmoid(logits)

In [None]:
def safe_log(x):
    """Natural log of `x` with inputs floored at 1e-10, avoiding log(0) = -inf."""
    floor = 1e-10
    clipped = jnp.maximum(x, floor)
    return jnp.log(clipped)

def xe_loss(model, params, inputs, targets, weight_decay=0.0):
    """Mean binary cross-entropy of `model` on (inputs, targets), plus an optional L2 term.

    Args:
        model: Object with an `apply(params, inputs)` method (e.g. a Haiku transform);
            its outputs are treated as positive-class probabilities.
        params: Model parameters (pytree).
        inputs: Input batch passed to `model.apply`.
        targets: Binary labels, combined element-wise with the model outputs.
        weight_decay: Python float; adds `0.5 * weight_decay * ||params||^2`. With
            `weight_decay = 1 / (C * n)` this equals sklearn's L2-regularised
            `LogisticRegression(C=C)` objective on n samples divided by `C * n`.

    Returns:
        Scalar loss.
    """
    probs = model.apply(params, inputs)
    loss = -(targets * safe_log(probs) + (1 - targets) * safe_log(1 - probs)).mean()
    if weight_decay:
        sq_norm = sum(jnp.sum(jnp.square(leaf)) for leaf in jax.tree_util.tree_leaves(params))
        loss = loss + 0.5 * weight_decay * sq_norm
    return loss

In [None]:
# Synthetic train/test split using the generator's default size, feature count and seed.
x_train, y_train, x_test, y_test = load_dummy_data(num_samples=50000, num_features=10, seed=0)

Train a scikit-learn logistic regression model (no intercept) whose L2 penalty corresponds to `weight_decay`.

In [None]:
weight_decay = 0.01
train_sample_num = x_train.shape[0]

# sklearn minimises 0.5 * ||w||^2 + C * sum(log_loss); dividing by C * n shows that
# C = 1 / (n * weight_decay) is mean(log_loss) + 0.5 * weight_decay * ||w||^2.
c = 1.0 / (train_sample_num * weight_decay)
sk_model = linear_model.LogisticRegression(
    C=c,
    solver="lbfgs",
    tol=1e-10,
    max_iter=1000,
    fit_intercept=False,
)
sk_model.fit(x_train, y_train.ravel())

Load the fitted scikit-learn coefficients into the Haiku model's parameters.

In [None]:
model = hk.without_apply_rng(
    hk.transform(lambda *args: BinaryLogisticRegression(1, bias=False)(*args))
)
params = hk.data_structures.to_mutable_dict(model.init(jax.random.PRNGKey(42), x_train))

# sklearn's binary coef_ has shape (1, n_features); hk.Linear's weight is
# (n_features, 1), hence the transpose.
params["binary_logistic_regression/linear"]["w"] = jnp.array(sk_model.coef_).T

In [None]:
# Flattened Haiku weights, which were just copied from `sk_model.coef_`.
flatten(params=params)

Array([-0.33640587, -0.16284819, -0.6649091 , -0.20092598,  0.7574257 ,
       -0.1551791 , -0.41654733, -0.11855818,  0.38400975,  0.3220472 ],      dtype=float32)

Delete one training sample and retrain the model on the remaining data.

In [None]:
delete_idx = 42  # index of the training sample to leave out
x_train_minus_one = np.delete(x_train, delete_idx, axis=0)
y_train_minus_one = np.delete(y_train, delete_idx, axis=0)

# Retrain on the reduced set. Keep the same per-sample weight decay, so C is
# rescaled with the new number of training samples.
c = 1.0 / (x_train_minus_one.shape[0] * weight_decay)
sk_model_retrain = linear_model.LogisticRegression(
    C=c, solver="lbfgs", tol=1e-10, max_iter=1000, fit_intercept=False)
# Fit on the data WITHOUT the deleted sample (previously the full x_train/y_train
# was used by mistake, making the "retrained" model identical to the original).
sk_model_retrain.fit(x_train_minus_one, y_train_minus_one.ravel())