Skip to content

Commit

Permalink
Merge pull request #435 from EducationalTestingService/make-mlp-test-…
Browse files Browse the repository at this point in the history
…faster

Make MLP regression and classification tests run faster
  • Loading branch information
desilinguist committed Nov 30, 2018
2 parents 9f7e962 + 4da69cd commit 3f65d27
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 4 deletions.
4 changes: 2 additions & 2 deletions tests/test_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ def test_mlp_classification():
learner = Learner('MLPClassifier')
with warnings.catch_warnings():
warnings.filterwarnings('ignore', category=ConvergenceWarning)
learner.train(train_fs, grid_search=True)
learner.train(train_fs, grid_search=False)

# now generate the predictions on the test set
predictions = learner.predict(test_fs)
Expand All @@ -220,7 +220,7 @@ def test_mlp_classification():
# using make_regression_data. To do this, we just
# make sure that they are correlated
accuracy = accuracy_score(predictions, test_fs.labels)
assert_almost_equal(accuracy, 0.825)
assert_almost_equal(accuracy, 0.858, places=3)


def check_sparse_predict_sampler(use_feature_hashing=False):
Expand Down
4 changes: 2 additions & 2 deletions tests/test_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -611,7 +611,7 @@ def test_ransac_regression():
'SGDRegressor',
'DecisionTreeRegressor',
'SVR'],
[0.95, 0.45, 0.75, 0.65]):
[0.95, 0.45, 0.75, 0.65]):
yield check_ransac_regression, base_estimator_name, pearson_value


Expand All @@ -627,7 +627,7 @@ def check_mlp_regression(use_rescaling=False):
# we don't want to see any convergence warnings during the grid search
with warnings.catch_warnings():
warnings.filterwarnings('ignore', category=ConvergenceWarning)
learner.train(train_fs, grid_search=True, grid_objective='pearson')
learner.train(train_fs, grid_search=False)

# now generate the predictions on the test set
predictions = learner.predict(test_fs)
Expand Down

0 comments on commit 3f65d27

Please sign in to comment.