Skip to content

Commit

Permalink
[python] add weights to loss function: fix tests, update changelog
Browse files Browse the repository at this point in the history
  • Loading branch information
hbaniecki committed May 5, 2024
1 parent 9884571 commit fab61b8
Show file tree
Hide file tree
Showing 6 changed files with 27 additions and 20 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
**token
.vscode/settings.json
**.DS_Store

Expand Down
3 changes: 3 additions & 0 deletions python/dalex/NEWS.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
## Changelog

### development

* added a way to pass `sample_weight` to loss functions in `model_parts()` (variable importance) using `weights` from `dx.Explainer` ([#563](https://github.com/ModelOriented/DALEX/issues/563))

### v1.7.0 (2024-02-28)

Expand Down
2 changes: 1 addition & 1 deletion python/dalex/dalex/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from .aspect import Aspect


__version__ = '1.7.0'
__version__ = '1.7.0.9000'

__all__ = [
"Arena",
Expand Down
2 changes: 0 additions & 2 deletions python/dalex/dalex/_global_checks.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
import pkg_resources
from importlib import import_module
from re import search
import numpy as np
import pandas as pd

# WARNING: below code is parsed by setup.py
# WARNING: each dependency should be in new line
Expand Down
22 changes: 10 additions & 12 deletions python/dalex/dalex/model_explanations/_variable_importance/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,18 +57,18 @@ def loss_after_permutation(data, y, weights, model, predict, loss_function, vari
sampled_rows = rng.choice(np.arange(data.shape[0]), N, replace=False)
sampled_data = data.iloc[sampled_rows, :]
observed = y[sampled_rows]
sample_weights = weights[sampled_rows] if weights is not None else None
sample_weight = weights[sampled_rows] if weights is not None else None
else:
sampled_data = data
observed = y
sample_weights = weights
sample_weight = weights

# loss on the full model or when outcomes are permuted
loss_full = calculate_loss(loss_function, observed, predict(model, sampled_data), sample_weights)
loss_full = calculate_loss(loss_function, observed, predict(model, sampled_data), sample_weight)

sampled_rows2 = rng.choice(range(observed.shape[0]), observed.shape[0], replace=False)
sample_weights_rows2 = sample_weights[sampled_rows2] if sample_weights is not None else None
loss_baseline = calculate_loss(loss_function, observed[sampled_rows2], predict(model, sampled_data), sample_weights_rows2)
sample_weight_rows2 = sample_weight[sampled_rows2] if sample_weight is not None else None
loss_baseline = calculate_loss(loss_function, observed[sampled_rows2], predict(model, sampled_data), sample_weight_rows2)

loss_features = {}
for variables_set_key in variables:
Expand All @@ -79,24 +79,22 @@ def loss_after_permutation(data, y, weights, model, predict, loss_function, vari

predicted = predict(model, ndf)

loss_features[variables_set_key] = calculate_loss(loss_function, observed, predicted, sample_weights)
loss_features[variables_set_key] = calculate_loss(loss_function, observed, predicted, sample_weight)

loss_features['_full_model_'] = loss_full
loss_features['_baseline_'] = loss_baseline

return pd.DataFrame(loss_features, index=[0])


def calculate_loss(loss_function, observed, predicted, sample_weight=None):
    """Evaluate `loss_function` on observed vs. predicted values, passing
    `sample_weight` through when the loss function supports it.

    Parameters
    ----------
    loss_function : callable
        Loss of the form ``loss_function(observed, predicted)``, optionally
        accepting a ``sample_weight`` keyword argument.
    observed : array-like
        Ground-truth target values.
    predicted : array-like
        Model predictions aligned with `observed`.
    sample_weight : array-like, optional
        Per-observation weights forwarded to `loss_function`.

    Returns
    -------
    The value returned by `loss_function`.

    Raises
    ------
    UserWarning
        If `sample_weight` is given but `loss_function` has no
        ``sample_weight`` parameter (weighted loss cannot be computed).
    """
    # Inspect the signature to determine whether the loss function
    # accepts a 'sample_weight' keyword argument.
    loss_args = inspect.signature(loss_function).parameters
    supports_weight = "sample_weight" in loss_args

    if supports_weight:
        return loss_function(observed, predicted, sample_weight=sample_weight)
    else:
        if sample_weight is not None:
            # Fail loudly rather than silently dropping the weights: the
            # caller asked for a weighted loss that cannot be computed.
            raise UserWarning(f"Loss function `{loss_function.__name__}` does not have `sample_weight` argument. Cannot calculate weighted loss.")
        return loss_function(observed, predicted)
17 changes: 12 additions & 5 deletions python/dalex/test/test_variable_importance.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,36 +55,43 @@ def test_loss_after_permutation(self):
variables = {}
for col in self.X.columns:
variables[col] = col
lap = utils.loss_after_permutation(self.X, self.y, self.exp.model, self.exp.predict_function, rmse,
lap = utils.loss_after_permutation(self.X, self.y, None, self.exp.model, self.exp.predict_function, rmse,
variables, 100, np.random)
self.assertIsInstance(lap, pd.DataFrame)
self.assertTrue(np.isin(np.array(['_full_model_', '_baseline_']),
lap.columns).all(), np.random)

with self.assertRaises(UserWarning):
lap = utils.loss_after_permutation(self.X, self.y, self.y, self.exp.model, self.exp.predict_function, rmse,
variables, 100, np.random)
self.assertIsInstance(lap, pd.DataFrame)
self.assertTrue(np.isin(np.array(['_full_model_', '_baseline_']),
lap.columns).all(), np.random)

variables = {'age': 'age', 'embarked': 'embarked'}
lap = utils.loss_after_permutation(self.X, self.y, self.exp.model, self.exp.predict_function, mad,
lap = utils.loss_after_permutation(self.X, self.y, None, self.exp.model, self.exp.predict_function, mad,
variables, 10, np.random)
self.assertIsInstance(lap, pd.DataFrame)
self.assertTrue(np.isin(np.array(['_full_model_', '_baseline_']),
lap.columns).all())

variables = {'embarked': 'embarked'}
lap = utils.loss_after_permutation(self.X, self.y, self.exp.model, self.exp.predict_function, mae,
lap = utils.loss_after_permutation(self.X, self.y, None, self.exp.model, self.exp.predict_function, mae,
variables, None, np.random)
self.assertIsInstance(lap, pd.DataFrame)
self.assertTrue(np.isin(np.array(['_full_model_', '_baseline_']),
lap.columns).all())

variables = {'age': 'age'}
lap = utils.loss_after_permutation(self.X, self.y, self.exp.model, self.exp.predict_function, rmse,
lap = utils.loss_after_permutation(self.X, self.y, None, self.exp.model, self.exp.predict_function, rmse,
variables, None, np.random)
self.assertIsInstance(lap, pd.DataFrame)
self.assertTrue(np.isin(np.array(['_full_model_', '_baseline_']),
lap.columns).all())

variables = {'personal': ['gender', 'age', 'sibsp', 'parch'],
'wealth': ['class', 'fare']}
lap = utils.loss_after_permutation(self.X, self.y, self.exp.model, self.exp.predict_function, mae,
lap = utils.loss_after_permutation(self.X, self.y, None, self.exp.model, self.exp.predict_function, mae,
variables, None, np.random)
self.assertIsInstance(lap, pd.DataFrame)
self.assertTrue(np.isin(np.array(['_full_model_', '_baseline_']),
Expand Down

0 comments on commit fab61b8

Please sign in to comment.