Skip to content

Commit

Permalink
Add test for PCA preprocessing
Browse files Browse the repository at this point in the history
  • Loading branch information
andreArtelt committed Jul 16, 2019
1 parent 7c75574 commit e181fda
Showing 1 changed file with 30 additions and 3 deletions.
33 changes: 30 additions & 3 deletions tests/sklearn/test_sklearn_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
import numpy as np
np.random.seed(42)
import sklearn
from sklearn.datasets import load_iris
from sklearn.datasets import load_iris, load_boston
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.linear_model import LogisticRegression, Lasso
from sklearn.preprocessing import StandardScaler, RobustScaler, PolynomialFeatures, Normalizer, MinMaxScaler, MaxAbsScaler
from sklearn.decomposition import PCA
from sklearn.pipeline import make_pipeline
Expand Down Expand Up @@ -305,4 +305,31 @@ def test_pipeline_scaler_poly_softmaxregression():
assert model.predict([x_orig]) == 2

# Compute counterfactual
compute_counterfactuals_poly(model, x_orig, 0)
compute_counterfactuals_poly(model, x_orig, 0)


def test_pipeline_pca_linearregression():
# Load data
X, y = load_boston(True)

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33, random_state=4242)

# Create and fit model
pca = PCA(n_components=4)
model = Lasso()

model = make_pipeline(pca, model)
model.fit(X_train, y_train)

# Select data point for explaining its prediction
x_orig = X_test[1:4][0,:]
y_orig_pred = model.predict([x_orig])
assert y_orig_pred >= 25 and y_orig_pred < 26

# Compute counterfactual
y_target = 20.
y_target_done = lambda z: np.abs(z - y_target) < 3.

x_cf, y_cf, _ = generate_counterfactual(model, x_orig, y_target=y_target, done=y_target_done, regularization="l1", C=0.1, features_whitelist=None, optimizer="bfgs", return_as_dict=False)
assert y_target_done(y_cf)
assert y_target_done(model.predict(np.array([x_cf])))

0 comments on commit e181fda

Please sign in to comment.