Skip to content

Commit

Permalink
LDA & QDA: Extended tests
Browse files Browse the repository at this point in the history
  • Loading branch information
andreArtelt committed May 8, 2020
1 parent 5619300 commit 5794716
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 2 deletions.
17 changes: 16 additions & 1 deletion tests/sklearn/test_sklearn_lda.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,14 @@ def test_lda():
assert y_cf == 0
assert model.predict(np.array([x_cf])) == 0

x_cf, y_cf, delta = generate_counterfactual(model, x_orig, 1, features_whitelist=features_whitelist, regularization="l1", optimizer="mp", return_as_dict=False)
assert y_cf == 1
assert model.predict(np.array([x_cf])) == 1

cf = generate_counterfactual(model, x_orig, 0, features_whitelist=features_whitelist, regularization="l1", optimizer="mp", return_as_dict=True)
assert cf["y_cf"] == 0
assert model.predict(np.array([cf["x_cf"]])) == 0

x_cf, y_cf, delta = generate_counterfactual(model, x_orig, 0, features_whitelist=features_whitelist, regularization="l2", optimizer="mp", return_as_dict=False)
assert y_cf == 0
assert model.predict(np.array([x_cf])) == 0
Expand Down Expand Up @@ -89,4 +97,11 @@ def test_lda():
# Other stuff
from ceml.sklearn import LdaCounterfactual
with pytest.raises(TypeError):
LdaCounterfactual(sklearn.linear_model.LogisticRegression())
LdaCounterfactual(sklearn.linear_model.LogisticRegression())

model = LinearDiscriminantAnalysis()
model.fit(X_train, y_train)
with pytest.raises(AttributeError):
LdaCounterfactual(model)
with pytest.raises(AttributeError):
generate_counterfactual(model, x_orig, 0)
17 changes: 16 additions & 1 deletion tests/sklearn/test_sklearn_qda.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,10 @@ def test_qda():
assert y_cf == 0
assert model.predict(np.array([x_cf])) == 0

cf = generate_counterfactual(model, x_orig, 0, features_whitelist=features_whitelist, regularization="l1", optimizer="mp", return_as_dict=True)
assert cf["y_cf"] == 0
assert model.predict(np.array([cf["x_cf"]])) == 0

x_cf, y_cf, delta = generate_counterfactual(model, x_orig, 0, features_whitelist=features_whitelist, regularization="l2", optimizer="mp", return_as_dict=False)
assert y_cf == 0
assert model.predict(np.array([x_cf])) == 0
Expand Down Expand Up @@ -103,6 +107,10 @@ def test_qda():
assert y_cf == 1
assert model.predict(np.array([x_cf])) == 1

cf = generate_counterfactual(model, x_orig, y_target=1, features_whitelist=features_whitelist, optimizer="mp", return_as_dict=True)
assert cf["y_cf"] == 1
assert model.predict(np.array([cf["x_cf"]])) == 1

x_orig = X_test[0,:]
assert model.predict([x_orig]) == 1

Expand All @@ -113,4 +121,11 @@ def test_qda():
# Other stuff
from ceml.sklearn import QdaCounterfactual
with pytest.raises(TypeError):
QdaCounterfactual(sklearn.linear_model.LogisticRegression())
QdaCounterfactual(sklearn.linear_model.LogisticRegression())

model = QuadraticDiscriminantAnalysis()
model.fit(X_train, y_train)
with pytest.raises(AttributeError):
QdaCounterfactual(model)
with pytest.raises(AttributeError):
generate_counterfactual(model, x_orig, 0)

0 comments on commit 5794716

Please sign in to comment.