Skip to content

Commit

Permalink
Issue #563: Use weights in Permutation Feature Importance calculation (#564)
Browse files Browse the repository at this point in the history

* Pass explainer weights to loss_after_permutation

* Define weights for sampled data

* Function to handle loss functions with or without sample_weight arg

* Replace loss function calls with wrapper

* Add imports

* Avoid ambiguous truth values

* More explicit warning if weights passed but not used in loss calc
  • Loading branch information
danielarifmurphy committed May 5, 2024
1 parent db2ae5d commit 9884571
Showing 1 changed file with 26 additions and 6 deletions.
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import inspect
import multiprocessing as mp
import warnings
from numpy.random import SeedSequence, default_rng

import numpy as np
Expand All @@ -18,15 +20,15 @@ def calculate_variable_importance(explainer,
if processes == 1:
result = [None] * B
for i in range(B):
result[i] = loss_after_permutation(explainer.data, explainer.y, explainer.model, explainer.predict_function,
result[i] = loss_after_permutation(explainer.data, explainer.y, explainer.weights, explainer.model, explainer.predict_function,
loss_function, variables, N, np.random)
else:
# Create number generator for each iteration
ss = SeedSequence(random_state)
generators = [default_rng(s) for s in ss.spawn(B)]
pool = mp.get_context('spawn').Pool(processes)
result = pool.starmap_async(loss_after_permutation, [
(explainer.data, explainer.y, explainer.model, explainer.predict_function, loss_function, variables, N, generators[i]) for
(explainer.data, explainer.y, explainer.weights, explainer.model, explainer.predict_function, loss_function, variables, N, generators[i]) for
i in range(B)]).get()
pool.close()

Expand All @@ -49,21 +51,24 @@ def calculate_variable_importance(explainer,
return result, raw_permutations


def loss_after_permutation(data, y, model, predict, loss_function, variables, N, rng):
def loss_after_permutation(data, y, weights, model, predict, loss_function, variables, N, rng):
if isinstance(N, int):
N = min(N, data.shape[0])
sampled_rows = rng.choice(np.arange(data.shape[0]), N, replace=False)
sampled_data = data.iloc[sampled_rows, :]
observed = y[sampled_rows]
sample_weights = weights[sampled_rows] if weights is not None else None
else:
sampled_data = data
observed = y
sample_weights = weights

# loss on the full model or when outcomes are permuted
loss_full = loss_function(observed, predict(model, sampled_data))
loss_full = calculate_loss(loss_function, observed, predict(model, sampled_data), sample_weights)

sampled_rows2 = rng.choice(range(observed.shape[0]), observed.shape[0], replace=False)
loss_baseline = loss_function(observed[sampled_rows2], predict(model, sampled_data))
sample_weights_rows2 = sample_weights[sampled_rows2] if sample_weights is not None else None
loss_baseline = calculate_loss(loss_function, observed[sampled_rows2], predict(model, sampled_data), sample_weights_rows2)

loss_features = {}
for variables_set_key in variables:
Expand All @@ -74,9 +79,24 @@ def loss_after_permutation(data, y, model, predict, loss_function, variables, N,

predicted = predict(model, ndf)

loss_features[variables_set_key] = loss_function(observed, predicted)
loss_features[variables_set_key] = calculate_loss(loss_function, observed, predicted, sample_weights)

loss_features['_full_model_'] = loss_full
loss_features['_baseline_'] = loss_baseline

return pd.DataFrame(loss_features, index=[0])


def calculate_loss(loss_function, observed, predicted, sample_weights=None):
    """Evaluate ``loss_function``, forwarding sample weights when supported.

    Parameters
    ----------
    loss_function : callable
        Loss with signature ``(observed, predicted)`` or
        ``(observed, predicted, sample_weight=...)``.
    observed : array-like
        Ground-truth target values.
    predicted : array-like
        Model predictions aligned with ``observed``.
    sample_weights : array-like or None
        Per-observation weights; ignored (with a warning) when the loss
        function does not accept a ``sample_weight`` argument.

    Returns
    -------
    Whatever ``loss_function`` returns.
    """
    # Determine if loss function accepts 'sample_weight'. Some callables
    # (e.g. certain builtins / C-implemented functions) do not expose a
    # signature and make inspect.signature raise ValueError/TypeError;
    # treat those conservatively as unweighted instead of crashing.
    try:
        loss_args = inspect.signature(loss_function).parameters
        supports_weight = "sample_weight" in loss_args
    except (TypeError, ValueError):
        supports_weight = False

    if supports_weight:
        return loss_function(observed, predicted, sample_weight=sample_weights)

    if sample_weights is not None:
        # functools.partial and similar callables may lack __name__.
        name = getattr(loss_function, "__name__", repr(loss_function))
        warnings.warn(
            f"Loss function `{name}` does not have `sample_weight` argument. Calculating unweighted loss."
        )
    return loss_function(observed, predicted)

0 comments on commit 9884571

Please sign in to comment.