DT/RF: Add function to report importance scores
JIRA: MADLIB-925

This commit adds a new MADlib function (get_var_importance) that reports the
importance scores of decision tree and random forest models by unnesting the
importance values along with their corresponding features.

Closes #295

Co-authored-by: Rahul Iyer <riyer@apache.org>
Co-authored-by: Jingyi Mei <jmei@pivotal.io>
Co-authored-by: Orhan Kislal <okislal@pivotal.io>
4 people committed Jul 25, 2018
1 parent 2aac418 commit 7f6e291
Showing 13 changed files with 696 additions and 84 deletions.
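
In user-facing terms, the new get_var_importance function reads a trained model table and writes per-feature importance scores to an output table. A minimal usage sketch, drawn from the help text added further down in this commit ('tree_out' and 'var_imp_out' are illustrative table names):

    -- Report per-feature importance scores for a trained decision tree
    SELECT madlib.get_var_importance('tree_out', 'var_imp_out');
    -- One row per feature, pairing the feature name with its score
    SELECT * FROM var_imp_out;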
11 changes: 2 additions & 9 deletions src/modules/recursive_partitioning/decision_tree.cpp
@@ -487,7 +487,7 @@ print_decision_tree::run(AnyType &args){
}

AnyType
-get_variable_importance::run(AnyType &args){
+compute_variable_importance::run(AnyType &args){
Tree dt = args[0].getAs<ByteString>();
const int n_cat_features = args[1].getAs<int>();
const int n_con_features = args[2].getAs<int>();
@@ -496,19 +496,12 @@ get_variable_importance::run(AnyType &args){
ColumnVector con_var_importance = ColumnVector::Zero(n_con_features);
dt.computeVariableImportance(cat_var_importance, con_var_importance);

-// Variable importance is scaled to represent a percentage. Even though
-// the importance values are split between categorical and continuous, the
-// percentages are relative to the combined set.
ColumnVector combined_var_imp(n_cat_features + n_con_features);
combined_var_imp << cat_var_importance, con_var_importance;
-
-// Avoid divide by zero by adding a small number
-double total_var_imp = combined_var_imp.sum();
-double VAR_IMP_EPSILON = 1e-6;
-combined_var_imp *= (100.0 / (total_var_imp + VAR_IMP_EPSILON));
return combined_var_imp;
}


AnyType
display_text_tree::run(AnyType &args){
Tree dt = args[0].getAs<ByteString>();
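
With this change, compute_variable_importance returns raw (unnormalized) importance values; the percentage scaling deleted above moves into the new normalize_sum_array function in random_forest.cpp. The "unnesting" mentioned in the commit message amounts to a parallel unnest of feature names and scores, roughly as in this sketch (the literal arrays stand in for a model's features and scores; the actual wrapper SQL may differ):

    SELECT unnest(ARRAY['cat_a', 'cat_b', 'con_x']) AS feature,
           unnest(ARRAY[12.5, 2.5, 85.0]) AS impurity_var_importance;
    -- Equal-length arrays unnest in lockstep, yielding one row per feature.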
2 changes: 1 addition & 1 deletion src/modules/recursive_partitioning/decision_tree.hpp
@@ -14,7 +14,7 @@ DECLARE_UDF(recursive_partitioning, compute_surr_stats_transition)
DECLARE_UDF(recursive_partitioning, dt_surr_apply)

DECLARE_UDF(recursive_partitioning, print_decision_tree)
-DECLARE_UDF(recursive_partitioning, get_variable_importance)
+DECLARE_UDF(recursive_partitioning, compute_variable_importance)
DECLARE_UDF(recursive_partitioning, predict_dt_response)
DECLARE_UDF(recursive_partitioning, predict_dt_prob)

15 changes: 15 additions & 0 deletions src/modules/recursive_partitioning/random_forest.cpp
@@ -204,6 +204,21 @@ rf_con_imp_score::run(AnyType &args) {
// ------------------------------------------------------------


+AnyType
+normalize_sum_array::run(AnyType &args){
+    const MappedColumnVector input_vector = args[0].getAs<MappedColumnVector>();
+    const double sum_target = args[1].getAs<double>();
+
+    double sum_input_vector = input_vector.sum();
+    // Avoid divide by zero by dividing by a small number if sum is small
+    double VAR_IMP_EPSILON = 1e-6;
+    if (sum_input_vector < VAR_IMP_EPSILON)
+        sum_input_vector = VAR_IMP_EPSILON;
+    ColumnVector output_vector = input_vector * sum_target / sum_input_vector;
+    return output_vector;
+}


} // namespace recursive_partitioning
} // namespace modules
} // namespace madlib
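
The new normalize_sum_array function rescales an input array so its elements sum to sum_target (for example, 100 to express importance scores as percentages), clamping a near-zero sum to epsilon to avoid division by zero. For input {1, 3} and a target sum of 100 the result is {25, 75}; the equivalent arithmetic as a plain SQL sketch (illustrative only, not a call into the new UDF):

    SELECT 1.0 * 100.0 / (1.0 + 3.0) AS first,   -- 25
           3.0 * 100.0 / (1.0 + 3.0) AS second;  -- 75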
1 change: 1 addition & 0 deletions src/modules/recursive_partitioning/random_forest.hpp
@@ -7,3 +7,4 @@

DECLARE_UDF(recursive_partitioning, rf_cat_imp_score)
DECLARE_UDF(recursive_partitioning, rf_con_imp_score)
+DECLARE_UDF(recursive_partitioning, normalize_sum_array)
src/ports/postgres/modules/recursive_partitioning/decision_tree.py_in
@@ -21,6 +21,9 @@ from utilities.control import OptimizerControl
from utilities.control import HashaggControl
from utilities.utilities import _assert
from utilities.utilities import _array_to_string
+from utilities.utilities import _check_groups
+from utilities.utilities import extract_keyvalue_params
+from utilities.utilities import unique_string
from utilities.utilities import add_postfix
from utilities.utilities import extract_keyvalue_params
from utilities.utilities import is_psql_numeric_type, is_psql_boolean_type
@@ -2012,7 +2015,7 @@ def _compute_var_importance(schema_madlib, tree,
impurity_var_importance: Array of importance values
"""
var_imp_sql = """
-SELECT {schema_madlib}._get_var_importance(
+SELECT {schema_madlib}._compute_var_importance(
$1, -- trained decision tree
{n_cat_features},
{n_con_features}) AS impurity_var_importance
Expand Down Expand Up @@ -2412,7 +2415,6 @@ def _tree_error(schema_madlib, source_table, dependent_varname,
plpy.execute(sql)
# ------------------------------------------------------------

-
def tree_train_help_message(schema_madlib, message, **kwargs):
""" Help message for Decision Tree
"""
@@ -2567,6 +2569,10 @@ SELECT madlib.tree_train(
5);

SELECT madlib.tree_display('tree_out');
+-- View the impurity importance value of each feature
+DROP TABLE IF EXISTS var_imp_out;
+SELECT madlib.get_var_importance('tree_out', 'var_imp_out');
+SELECT * FROM var_imp_out;
"""
else:
help_string = "No such option. Use {schema_madlib}.tree_train('usage')"
@@ -2643,3 +2649,60 @@ SELECT * FROM tree_predict_out;
help_string = "No such option. Use {schema_madlib}.tree_predict('usage')"
return help_string.format(schema_madlib=schema_madlib)
# ------------------------------------------------------------

+def tree_importance_help_message(schema_madlib, message, **kwargs):
+    """ Help message for Decision Tree get_var_importance
+    """
+    if not message:
+        help_string = """
+------------------------------------------------------------
+SUMMARY
+------------------------------------------------------------
+Functionality: Decision Tree and Random Forest Importance Values Display
+
+Create a table to record the importance values for a decision
+tree (trained using {schema_madlib}.tree_train) or a random
+forest (trained using {schema_madlib}.forest_train).
+
+For more details on the function usage:
+    SELECT {schema_madlib}.get_var_importance('usage');
+For an example on using this function:
+    SELECT {schema_madlib}.get_var_importance('example');
+"""
+    elif message.lower().strip() in ['usage', 'help', '?']:
+        help_string = """
+------------------------------------------------------------
+USAGE
+------------------------------------------------------------
+SELECT {schema_madlib}.get_var_importance(
+    'model_table',  -- Model table name (output of tree_train
+                    -- or forest_train)
+    'output_table'  -- Table name to store the importance values
+);
+
+------------------------------------------------------------
+OUTPUT
+------------------------------------------------------------
+The output table ('output_table' above) has three columns:
+'feature'                 : The name of the feature.
+'impurity_var_importance' : Impurity importance score for the
+    variable. For random forest models, this column is only
+    available if the importance parameter was set to True
+    during training.
+'oob_var_importance'      : Out-of-bag variable importance
+    score for the variable. This column is not available for
+    decision tree models.
+"""
+    elif message.lower().strip() in ['example', 'examples']:
+        help_string = """
+------------------------------------------------------------
+EXAMPLE
+------------------------------------------------------------
+-- Assuming the example of tree_train() or forest_train() has been run
+SELECT {schema_madlib}.get_var_importance('train_output', 'imp_output');
+SELECT * FROM imp_output;
+"""
+    else:
+        help_string = "No such option. Use {schema_madlib}.get_var_importance('usage')"
+    return help_string.format(schema_madlib=schema_madlib)
+# ------------------------------------------------------------
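
Given the output columns documented in the usage text above, a natural follow-up query ranks features by their importance score ('var_imp_out' is an illustrative table name from the tree_train example):

    SELECT feature, impurity_var_importance
    FROM var_imp_out
    ORDER BY impurity_var_importance DESC;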
