Skip to content

Commit

Permalink
Added test_random_forest_2
Browse files Browse the repository at this point in the history
  • Loading branch information
GJena committed Apr 16, 2016
1 parent d2e6209 commit 7b27a9a
Showing 1 changed file with 13 additions and 1 deletion.
14 changes: 13 additions & 1 deletion tests.py
Expand Up @@ -81,7 +81,7 @@ def test_decision_tree_2():
assert np.array_equal(result['guess'].values, dtc.predict(testing_features))

def test_random_forest():
"""Ensure that the TPOT random forest method outputs the same as the sklearn random forest"""
"""Ensure that the TPOT random forest method outputs the same as the sklearn random forest when max_features<1"""

tpot_obj = TPOT()
result = tpot_obj._random_forest(training_testing_data, 0)
Expand All @@ -92,6 +92,18 @@ def test_random_forest():

assert np.array_equal(result['guess'].values, rfc.predict(testing_features))

def test_random_forest_2():
"""Ensure that the TPOT random forest method outputs the same as the sklearn random forest when max_features=1"""

tpot_obj = TPOT()
result = tpot_obj._random_forest(training_testing_data, 1)
result = result[result['group'] == 'testing']

rfc = RandomForestClassifier(n_estimators=500, max_features=None, random_state=42, n_jobs=-1)
rfc.fit(training_features, training_classes)

assert np.array_equal(result['guess'].values, rfc.predict(testing_features))

def test_svc():
"""Ensure that the TPOT random forest method outputs the same as the sklearn svc when C>0.0001"""

Expand Down

0 comments on commit 7b27a9a

Please sign in to comment.