Skip to content

Commit

Permalink
Check that we can pass feature_names and coef_scale directly
Browse files Browse the repository at this point in the history
Also add coef_scale to explain_prediction kwargs
  • Loading branch information
lopuhin committed Sep 26, 2016
1 parent 55af1c8 commit 6da8bb3
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 4 deletions.
2 changes: 1 addition & 1 deletion eli5/sklearn/explain_prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@

@singledispatch
def explain_prediction(clf, doc, vec=None, top=_TOP, target_names=None,
feature_names=None, vectorized=False):
feature_names=None, vectorized=False, coef_scale=None):
""" Return an explanation of an estimator """
return {
"estimator": repr(clf),
Expand Down
2 changes: 1 addition & 1 deletion eli5/sklearn/explain_weights.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ def explain_linear_classifier_weights(clf, vec=None, top=_TOP, target_names=None
To print it use utilities from eli5.formatters.
"""
feature_names, coef_scale = handle_hashing_vec(vec, feature_names,
coef_scale)
coef_scale)
feature_names = get_feature_names(clf, vec, feature_names=feature_names)
_extra_caveats = "\n" + HASHING_CAVEATS if is_invhashing(vec) else ''

Expand Down
8 changes: 6 additions & 2 deletions tests/test_sklearn_explain_prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,8 +115,8 @@ def test_explain_hashing_vectorizer(newsgroups_train_binary):
X = vec.fit_transform(docs)
clf.fit(X, y)

get_res = lambda: explain_prediction(
clf, docs[0], ivec, target_names=target_names, top=20)
get_res = lambda **kwargs: explain_prediction(
clf, docs[0], ivec, target_names=target_names, top=20, **kwargs)
res = get_res()
check_explain_linear_binary(res)
assert res == get_res()
Expand All @@ -126,6 +126,10 @@ def test_explain_hashing_vectorizer(newsgroups_train_binary):
pprint(res_vectorized)
assert res_vectorized == res

assert res == get_res(
feature_names=ivec.get_feature_names(always_signed=False),
coef_scale=ivec.column_signs_)


def test_explain_linear_dense():
clf = LogisticRegression()
Expand Down

0 comments on commit 6da8bb3

Please sign in to comment.