Skip to content

Commit

Permalink
RF/DT: Update documentation + unittest
Browse files Browse the repository at this point in the history
  • Loading branch information
iyerr3 committed Jul 24, 2018
1 parent 0437877 commit 2f6acc4
Show file tree
Hide file tree
Showing 4 changed files with 82 additions and 74 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -612,7 +612,9 @@ tree_surr_display(tree_model)

This is a helper function that creates a table to more easily
view impurity variable importance values for a given model
table.
table. This function rescales the importance values to represent them as
percentages i.e. importance values are scaled to sum to 100.

<pre class="syntax">
get_var_importance(model_table, output_table)
</pre>
Expand Down Expand Up @@ -692,7 +694,7 @@ SELECT pruning_cp, cat_levels_in_text, cat_n_levels, impurity_var_importance, tr
pruning_cp | 0
cat_levels_in_text | {overcast,rain,sunny,False,True}
cat_n_levels | {3,2}
impurity_var_importance | {10.6171201061712,0,89.3828798938288}
impurity_var_importance | {0.102040816326531,0,0.85905612244898}
tree_depth | 5
</pre>
View the summary table:
Expand Down Expand Up @@ -974,7 +976,7 @@ SELECT pruning_cp, cat_levels_in_text, cat_n_levels, impurity_var_importance, tr
pruning_cp | 0
cat_levels_in_text | {medium,none,high,low,unhealthy,good,moderate}
cat_n_levels | {4,3}
impurity_var_importance | {0,40.2340084993653,5.6791213643137,54.086870136321}
impurity_var_importance | {0,0.330612244897959,0.0466666666666666,0.444444444444444}
tree_depth | 3
</pre>
The first 4 levels correspond to cloud ceiling and the next 3 levels
Expand Down Expand Up @@ -1309,7 +1311,7 @@ SELECT pruning_cp, cat_levels_in_text, cat_n_levels, impurity_var_importance, tr
pruning_cp | 0
cat_levels_in_text | {0,1,4,6,8}
cat_n_levels | {2,3}
impurity_var_importance | {0,51.8593201959496,10.976977929129,5.31897402755374,31.8447278473677}
impurity_var_importance | {0,22.6309172500677,4.79024943310653,2.32115,13.8967382920109}
tree_depth | 4
</pre>
The cp values tested and average error and standard deviation are:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1626,6 +1626,8 @@ def get_var_importance(schema_madlib, model_table, output_table, **kwargs):
impurity_var_importance_str = ''
else:
# Decision tree models don't have a OOB variable importance
_assert(is_impurity_imp_col_present,
"Decision tree: Impurity importance not present in output table")
oob_var_importance_str = ''
impurity_var_importance_str = (
"{0} AS impurity_var_importance".
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -623,7 +623,9 @@ for more details on working with tree output formats.
@par Importance Display
This is a helper function that creates a table to more easily
view out-of-bag and impurity variable importance values for a given model
table.
table. This function rescales the importance values to represent them as
percentages i.e. importance values are scaled to sum to 100.

<pre class="syntax">
get_var_importance(model_table, output_table)
</pre>
Expand Down Expand Up @@ -745,8 +747,8 @@ success | t
cat_n_levels | {3,2}
cat_levels_in_text | {overcast,sunny,rain,False,True}
oob_error | 0.64285714285714285714
oob_var_importance | {0.0394736842105263,0.0197368421052632,0,0.00877192982456138}
impurity_var_importance | {26.8150656206627,14.3772272393831,31.1610056528161,22.6465339362903}
oob_var_importance | {0.0525595238095238,0,0.0138095238095238,0.0276190476190476}
impurity_var_importance | {0.254133481284938,0.0837130966399198,0.258520599370744,0.173196167388586}
</pre>
The 'cat_levels_in_text' array shows the
levels of the categorical variables "OUTLOOK" and windy,
Expand All @@ -764,12 +766,13 @@ SELECT madlib.get_var_importance('train_output','imp_output');
SELECT * FROM imp_output ORDER BY oob_var_importance DESC;
</pre>
<pre class="result">
feature | oob_var_importance | impurity_var_importance
-------------+---------------------+-------------------------
"OUTLOOK" | 0.0394736842105263 | 26.8150656206627
windy | 0.0197368421052632 | 14.3772272393831
humidity | 0.00877192982456138 | 22.6465339362903
temperature | 0 | 31.1610056528161
feature | oob_var_importance | impurity_var_importance
-------------+--------------------+-------------------------
"OUTLOOK" | 55.9214692843572 | 33.0230751036133
humidity | 29.3856871437619 | 22.5057714332356
temperature | 14.692843571881 | 33.5931539822541
windy | 0 | 10.877999480897
(4 rows)
</pre>
-# Predict output categories. For the purpose of this
example, we use the same data that was used for training:
Expand Down Expand Up @@ -1017,8 +1020,8 @@ success | t
cat_n_levels | {3,2}
cat_levels_in_text | {overcast,rain,sunny,False,True}
oob_error | 0.57142857142857142857
oob_var_importance | {0.0666666666666667,0.0333333333333333,0.05,0}
impurity_var_importance | {14.4447296351495,16.1522023965976,22.8299129333055,31.57293033484}
oob_var_importance | {0,0.0166666666666667,0.0166666666666667,0.0166666666666667}
impurity_var_importance | {0.143759266026582,0.0342777777777778,0.157507369614512,0.0554953231292517}
</pre>

<b>Random Forest Regression Example</b>
Expand Down Expand Up @@ -1138,24 +1141,24 @@ Review the group table to see variable importance by group:
SELECT * FROM mt_cars_output_group ORDER BY gid;
</pre>
<pre class="result">
-[ RECORD 1 ]-----------+---------------------------------------------------------------------------------------
-[ RECORD 1 ]-----------+----------------------------------------------------------------------------------------
gid | 1
am | 0
success | t
cat_n_levels | {2,3}
cat_levels_in_text | {0,1,4,6,8}
oob_error | 7.77759302106122
oob_var_importance | {2.93612545351474,5.06468471590909,8.01038092500973,0.530339832451499,0}
impurity_var_importance | {13.7025743802763,22.8588174057646,31.0136494271938,6.34452775143336,26.0804243950309}
-[ RECORD 2 ]-----------+---------------------------------------------------------------------------------------
oob_error | 8.64500988190963
oob_var_importance | {3.91269987042436,0,2.28278236607143,0.0994074074074073,3.42585277187264}
impurity_var_importance | {5.07135586863621,3.72145581490929,5.06700415274492,0.594942174008333,8.10909642389614}
-[ RECORD 2 ]-----------+----------------------------------------------------------------------------------------
gid | 2
am | 1
success | t
cat_n_levels | {2,3}
cat_levels_in_text | {0,1,4,6,8}
oob_error | 26.9864629759899
oob_var_importance | {0.689754050925925,0,9.33801817602041,0.703828124999998,8.66631289468697}
impurity_var_importance | {3.3333333044783,11.6666665358472,31.6666657616242,0,33.3333324531319}
oob_error | 16.5197718747446
oob_var_importance | {5.22711111111111,10.0872041666667,9.6875362244898,3.97782,2.99447839506173}
impurity_var_importance | {5.1269704861111,7.04765974920884,20.9817274159476,4.02800949238769,10.5539079705215}
</pre>
Use the helper function to display variable importance:
<pre class="example">
Expand All @@ -1167,16 +1170,16 @@ SELECT * FROM mt_imp_output ORDER BY am, oob_var_importance DESC;
<pre class="result">
am | feature | oob_var_importance | impurity_var_importance
----+---------+--------------------+-------------------------
0 | disp | 8.01038092500973 | 31.0136494271938
0 | cyl | 5.06468471590909 | 22.8588174057646
0 | vs | 2.93612545351474 | 13.7025743802763
0 | qsec | 0.530339832451499 | 6.34452775143336
0 | wt | 0 | 26.0804243950309
1 | disp | 9.33801817602041 | 31.6666657616242
1 | wt | 8.66631289468697 | 33.3333324531319
1 | qsec | 0.703828124999998 | 0
1 | vs | 0.689754050925925 | 3.3333333044783
1 | cyl | 0 | 11.6666665358472
0 | vs | 40.2510395098467 | 22.4755743014842
0 | wt | 35.2427070417256 | 35.9384361725319
0 | disp | 23.4836216045257 | 22.4562880757909
0 | qsec | 1.02263184390195 | 2.63670453886068
0 | cyl | 0 | 16.4929969113323
1 | cyl | 31.5479979891794 | 14.7631219023997
1 | disp | 30.2980259228064 | 43.9515825943964
1 | vs | 16.3479283355324 | 10.7397480823277
1 | qsec | 12.4407373230344 | 8.4376938269215
1 | wt | 9.3653104294474 | 22.1078535939547
</pre>

-# Predict regression output for the same data and compare with original:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ import unittest
from mock import *
import plpy_mock as plpy

m4_changequote(`<!', `!>')
m4_changequote(` <!', `!>')


class GetVarImportanceTestCase(unittest.TestCase):
def setUp(self):
Expand Down Expand Up @@ -59,114 +60,114 @@ class GetVarImportanceTestCase(unittest.TestCase):
def tearDown(self):
self.module_patcher.stop()


def test_validate_var_importance_input(self):
tbl_exists_mock = Mock()
tbl_exists_mock.side_effect = [False]
self.module.table_exists = tbl_exists_mock
# Test for model table does not exist
with self.assertRaises(plpy.PLPYException):
validate_result = self.module._validate_var_importance_input(
self.default_model_table,
self.default_summary_table,
self.default_output_table)
self.default_model_table,
self.default_summary_table,
self.default_output_table)

# Test for model summary table does not exist
tbl_exists_mock.side_effect = [True, False]
with self.assertRaises(plpy.PLPYException):
validate_result = self.module._validate_var_importance_input(
self.default_model_table,
self.default_summary_table,
self.default_output_table)
self.default_model_table,
self.default_summary_table,
self.default_output_table)

# Test for output table already exists
tbl_exists_mock.side_effect = [True, True, True]
with self.assertRaises(plpy.PLPYException):
validate_result = self.module._validate_var_importance_input(
self.default_model_table,
self.default_summary_table,
self.default_output_table)
self.default_model_table,
self.default_summary_table,
self.default_output_table)

# Test to make sure the summary table is postfixed with "_summary"
tbl_exists_mock = Mock()
self.module.table_exists = tbl_exists_mock
tbl_exists_mock.side_effect = [True, True, False]
validate_result = self.module._validate_var_importance_input(
self.default_model_table,
"wrong_table_name",
self.default_output_table)
self.default_model_table,
"wrong_table_name",
self.default_output_table)
with self.assertRaises(AssertionError):
tbl_exists_mock.assert_any_call("model_summary")

# Positive test case, there should be no error
tbl_exists_mock.side_effect = [True, True, False]
validate_result = self.module._validate_var_importance_input(
self.default_model_table,
self.default_summary_table,
self.default_output_table)
self.default_model_table,
self.default_summary_table,
self.default_output_table)
tbl_exists_mock.assert_any_call("model_summary")

def test_get_var_importance_DT(self):
columns_exist_in_table_mock = Mock()
_is_model_for_RF_mock = Mock()
_validate_var_importance_input_mock = Mock()
self.module.columns_exist_in_table = columns_exist_in_table_mock
self.module._is_model_for_RF = _is_model_for_RF_mock
_is_model_for_RF_mock = Mock()
self.module._is_random_forest_model = _is_model_for_RF_mock
self.module._validate_var_importance_input = _validate_var_importance_input_mock

# Test for impurity_var_importance column absent in model_table
_is_model_for_RF_mock.side_effect = [False]
columns_exist_in_table_mock.side_effect = [False]
with self.assertRaises(plpy.PLPYException):
validate_result = self.module.get_var_importance(
self.default_schema_madlib,
self.default_model_table,
self.default_output_table)
self.default_schema_madlib,
self.default_model_table,
self.default_output_table)

# Positive test case, there should be no error
_is_model_for_RF_mock.side_effect = [False]
columns_exist_in_table_mock.side_effect = [True]
validate_result = self.module.get_var_importance(
self.default_schema_madlib,
self.default_model_table,
self.default_output_table)
self.default_schema_madlib,
self.default_model_table,
self.default_output_table)

def test_is_RF_model_with_imp_pre_1_15(self):
tbl_exists_mock = Mock()
tbl_exists_mock.side_effect = [False]
self.module.table_exists = tbl_exists_mock

# Test for group table not existing
with self.assertRaises(plpy.PLPYException):
validate_result = self.module._is_impurity_importance_in_group_table(
self.default_group_table,
self.default_summary_table)
validate_result = self.module._is_impurity_importance_in_model(
self.default_group_table, self.default_summary_table)

# Error check for isImportance being False
# Check if error is raised if random forest was run with importance = False
tbl_exists_mock.side_effect = [True]
self.plpy_mock_execute.return_value = [{'importance':False}]
self.plpy_mock_execute.return_value = [{'importance': False}]
with self.assertRaises(plpy.PLPYException):
validate_result = self.module._is_impurity_importance_in_group_table(
self.default_group_table,
self.default_summary_table)
validate_result = self.module._is_impurity_importance_in_model(
self.default_group_table,
self.default_summary_table,
is_RF=True)

# Assert RF model is < 1.15
columns_exist_in_table_mock = Mock()
self.module.columns_exist_in_table = columns_exist_in_table_mock
tbl_exists_mock.side_effect = [True]
self.plpy_mock_execute.return_value = [{'importance':True}]
self.plpy_mock_execute.return_value = [{'importance': True}]
columns_exist_in_table_mock.side_effect = [True]
validate_result = self.module._is_impurity_importance_in_group_table(
self.default_group_table,
self.default_summary_table)
validate_result = self.module._is_impurity_importance_in_model(
self.default_group_table, self.default_summary_table)
self.assertTrue(validate_result)

# Assert RF model is >= 1.15
tbl_exists_mock.side_effect = [True]
self.plpy_mock_execute.return_value = [{'importance':True}]
self.plpy_mock_execute.return_value = [{'importance': True}]
columns_exist_in_table_mock.side_effect = [False]
validate_result = self.module._is_impurity_importance_in_group_table(
self.default_group_table,
self.default_summary_table)
validate_result = self.module._is_impurity_importance_in_model(
self.default_group_table, self.default_summary_table)
self.assertFalse(validate_result)


if __name__ == '__main__':
unittest.main()

0 comments on commit 2f6acc4

Please sign in to comment.