Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Issue #563: Use weights in Permutation Feature Importance calculation #564

Merged
merged 7 commits into from
May 5, 2024
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import inspect
import multiprocessing as mp
import warnings
from numpy.random import SeedSequence, default_rng

import numpy as np
Expand All @@ -18,15 +20,15 @@ def calculate_variable_importance(explainer,
if processes == 1:
result = [None] * B
for i in range(B):
result[i] = loss_after_permutation(explainer.data, explainer.y, explainer.model, explainer.predict_function,
result[i] = loss_after_permutation(explainer.data, explainer.y, explainer.weights, explainer.model, explainer.predict_function,
loss_function, variables, N, np.random)
else:
# Create number generator for each iteration
ss = SeedSequence(random_state)
generators = [default_rng(s) for s in ss.spawn(B)]
pool = mp.get_context('spawn').Pool(processes)
result = pool.starmap_async(loss_after_permutation, [
(explainer.data, explainer.y, explainer.model, explainer.predict_function, loss_function, variables, N, generators[i]) for
(explainer.data, explainer.y, explainer.weights, explainer.model, explainer.predict_function, loss_function, variables, N, generators[i]) for
i in range(B)]).get()
pool.close()

Expand All @@ -49,21 +51,24 @@ def calculate_variable_importance(explainer,
return result, raw_permutations


def loss_after_permutation(data, y, model, predict, loss_function, variables, N, rng):
def loss_after_permutation(data, y, weights, model, predict, loss_function, variables, N, rng):
if isinstance(N, int):
N = min(N, data.shape[0])
sampled_rows = rng.choice(np.arange(data.shape[0]), N, replace=False)
sampled_data = data.iloc[sampled_rows, :]
observed = y[sampled_rows]
sample_weights = weights[sampled_rows] if weights is not None else None
else:
sampled_data = data
observed = y
sample_weights = weights

# loss on the full model or when outcomes are permuted
loss_full = loss_function(observed, predict(model, sampled_data))
loss_full = calculate_loss(loss_function, observed, predict(model, sampled_data), sample_weights)

sampled_rows2 = rng.choice(range(observed.shape[0]), observed.shape[0], replace=False)
loss_baseline = loss_function(observed[sampled_rows2], predict(model, sampled_data))
sample_weights_rows2 = sample_weights[sampled_rows2] if sample_weights is not None else None
loss_baseline = calculate_loss(loss_function, observed[sampled_rows2], predict(model, sampled_data), sample_weights_rows2)

loss_features = {}
for variables_set_key in variables:
Expand All @@ -74,9 +79,24 @@ def loss_after_permutation(data, y, model, predict, loss_function, variables, N,

predicted = predict(model, ndf)

loss_features[variables_set_key] = loss_function(observed, predicted)
loss_features[variables_set_key] = calculate_loss(loss_function, observed, predicted, sample_weights)

loss_features['_full_model_'] = loss_full
loss_features['_baseline_'] = loss_baseline

return pd.DataFrame(loss_features, index=[0])


def calculate_loss(loss_function, observed, predicted, sample_weights=None):
    """Evaluate ``loss_function``, forwarding case weights when supported.

    Parameters
    ----------
    loss_function : callable
        Loss with signature ``(observed, predicted)`` or, sklearn-style,
        ``(observed, predicted, sample_weight=...)``.
    observed : array-like
        Ground-truth target values.
    predicted : array-like
        Model predictions aligned with ``observed``.
    sample_weights : array-like, optional
        Per-observation case weights; ``None`` means unweighted.

    Returns
    -------
    The value returned by ``loss_function``.

    Warns
    -----
    UserWarning
        When ``sample_weights`` is given but ``loss_function`` does not
        accept a ``sample_weight`` argument; the unweighted loss is
        returned instead.
    """
    # Determine if the loss function accepts 'sample_weight' (sklearn metrics do).
    # inspect.signature raises TypeError/ValueError for some builtin/C callables;
    # treat those as not supporting weights rather than crashing.
    try:
        supports_weight = "sample_weight" in inspect.signature(loss_function).parameters
    except (TypeError, ValueError):
        supports_weight = False

    if supports_weight:
        return loss_function(observed, predicted, sample_weight=sample_weights)

    if sample_weights is not None:
        # `__name__` is absent on e.g. functools.partial objects and callable
        # class instances; fall back to repr() so the warning never raises.
        name = getattr(loss_function, "__name__", repr(loss_function))
        warnings.warn(
            f"Loss function `{name}` does not have `sample_weight` argument. Calculating unweighted loss."
        )
    return loss_function(observed, predicted)
Loading