From bf5fa81c264471729ef06ee4af8a27b41f22b45a Mon Sep 17 00:00:00 2001 From: Rahul Iyer Date: Tue, 17 Jul 2018 17:10:04 -0700 Subject: [PATCH 1/4] DT/RF: Ensure cat features are recorded per group JIRA: MADLIB-1254 If tree_train/forest_train is run with grouping enabled and if one of the groups has a categorical feature with just single level, then the categorical feature is eliminated for that group. If other groups retain that feature, then we end up with incorrect "bins" data structure built as part of DT. This commit fixes this issue by recording the categorical features present in each group separately. Closes #295 --- .../decision_tree.py_in | 190 +++++++++++++----- .../random_forest.py_in | 52 ++--- .../modules/utilities/utilities.py_in | 2 +- 3 files changed, 155 insertions(+), 89 deletions(-) diff --git a/src/ports/postgres/modules/recursive_partitioning/decision_tree.py_in b/src/ports/postgres/modules/recursive_partitioning/decision_tree.py_in index 57c30256e..c9665fe5d 100644 --- a/src/ports/postgres/modules/recursive_partitioning/decision_tree.py_in +++ b/src/ports/postgres/modules/recursive_partitioning/decision_tree.py_in @@ -14,27 +14,31 @@ from operator import itemgetter from itertools import groupby from collections import Iterable -from validation.cross_validation import cross_validation_grouping_w_params +from internal.db_utils import quote_literal + from utilities.control import MinWarning from utilities.control import OptimizerControl from utilities.control import HashaggControl -from utilities.validate_args import get_cols -from utilities.validate_args import get_cols_and_types -from utilities.validate_args import _get_table_schema_names -from utilities.validate_args import get_expr_type -from utilities.validate_args import table_exists -from utilities.validate_args import table_is_empty -from utilities.validate_args import columns_exist_in_table -from utilities.validate_args import is_var_valid -from utilities.validate_args import unquote_ident from utilities.utilities import _assert from utilities.utilities import _array_to_string -from utilities.utilities import extract_keyvalue_params -from utilities.utilities import unique_string from utilities.utilities import add_postfix +from utilities.utilities import extract_keyvalue_params from utilities.utilities import is_psql_numeric_type, is_psql_boolean_type -from utilities.utilities import split_quoted_delimited_str from utilities.utilities import py_list_to_sql_string +from utilities.utilities import split_quoted_delimited_str +from utilities.utilities import unique_string + +from utilities.validate_args import _get_table_schema_names +from utilities.validate_args import columns_exist_in_table +from utilities.validate_args import get_cols +from utilities.validate_args import get_cols_and_types +from utilities.validate_args import get_expr_type +from utilities.validate_args import is_var_valid +from utilities.validate_args import table_is_empty +from utilities.validate_args import table_exists +from utilities.validate_args import unquote_ident + +from validation.cross_validation import cross_validation_grouping_w_params # ------------------------------------------------------------ @@ -265,6 +269,7 @@ def _get_tree_states(schema_madlib, is_classification, split_criterion, dep_n_levels = len(dep_list) if dep_list else 1 + cat_features_info_table = unique_string() if not grouping_cols: # non-grouping case # 3) Find the splitting bins, one dict containing two arrays: # categorical bins and continuous bins @@ -276,12 +281,15 @@ def _get_tree_states(schema_madlib, is_classification, split_criterion, cat_features = bins['cat_features'] if not cat_features and not con_features: plpy.error("Decision tree: None of the input features are valid") + _create_cat_features_info_table(cat_features_info_table, bins) # 4) Run tree train till the training is finished # finished: 0 = running, 1 = finished training, 2 = terminated prematurely tree = _tree_train_using_bins(**locals()) tree['grp_key'] = '' tree['cp'] = grp_key_to_cp[tree['grp_key']] + tree['cat_features'] = cat_features + tree['con_features'] = con_features tree_states = [tree] else: grouping_array_str = get_grouping_array_str(training_table_name, grouping_cols) @@ -305,16 +313,20 @@ def _get_tree_states(schema_madlib, is_classification, split_criterion, if not cat_features and not con_features: plpy.error("Decision tree: None of the input features " "are valid for some groups") + _create_cat_features_info_table(cat_features_info_table, bins) # 3b) Load each group's tree state in memory and set to the initial tree tree_states = _tree_train_grps_using_bins(**locals()) for tree in tree_states: + grp_key = tree['grp_key'] if len(grp_key_to_cp.values()) == 1: # for train w/out CV, the cp value remains the same for # all groups. This is passed as a single-element list. tree['cp'] = grp_key_to_cp.values()[0] else: - tree['cp'] = grp_key_to_cp[tree['grp_key']] + tree['cp'] = grp_key_to_cp[grp_key] + tree['cat_features'] = bins['grp_to_cat_features'][grp_key] + tree['con_features'] = bins['con_features'] # 5) prune the tree using provided 'cp' value and produce a list of # cp values if cross-validation is required (cp_list = [] if not) @@ -333,10 +345,10 @@ def _get_tree_states(schema_madlib, is_classification, split_criterion, importance_vectors = _compute_var_importance( schema_madlib, tree, - len(cat_features), len(con_features)) + len(tree['cat_features']), len(tree['con_features'])) tree.update(**importance_vectors) - return tree_states, bins, dep_list, n_rows + return tree_states, bins, dep_list, n_rows, cat_features_info_table # ------------------------------------------------------------------------- @@ -387,7 +399,8 @@ def _build_tree(schema_madlib, is_classification, split_criterion, with MinWarning(msg_level): plpy.notice("Building tree for cross validation") - tree_states, bins, dep_list, n_rows = _get_tree_states(**locals()) + tree_states, bins, dep_list, n_rows, cat_features_info_table = \ + _get_tree_states(**locals()) all_cols_types = dict([(f, get_expr_type(f, training_table_name)) for f in cat_features + con_features]) @@ -509,7 +522,8 @@ def tree_train(schema_madlib, training_table_name, output_table_name, grp_key_to_cp = {'': cp} # main training function to get trained decision trees plpy.notice("Getting initial tree") - tree_states, bins, dep_list, n_rows = _get_tree_states(**locals()) + tree_states, bins, dep_list, n_rows, cat_features_info_table = \ + _get_tree_states(**locals()) # 5) Perform cross-validation to compute the lowest cp dep_n_levels = len(dep_list) if dep_list else 1 @@ -528,7 +542,8 @@ def _create_output_tables(schema_madlib, training_table_name, output_table_name, id_col_name, dependent_variable, list_of_features, list_of_features_to_exclude, is_classification, n_all_rows, n_rows, dep_list, cp, - all_cols_types, grouping_cols=None, + all_cols_types, cat_features_info_table, + grouping_cols=None, use_existing_tables=False, running_cv=False, n_folds=0, null_proxy=None, **kwargs): if not grouping_cols: @@ -539,8 +554,9 @@ def _create_output_tables(schema_madlib, training_table_name, output_table_name, else: _create_grp_result_table( schema_madlib, tree_states, bins, bins['cat_features'], - bins['con_features'], output_table_name, grouping_cols, - training_table_name, use_existing_tables, running_cv, n_folds) + bins['con_features'], output_table_name, cat_features_info_table, + grouping_cols, training_table_name, use_existing_tables, + running_cv, n_folds) failed_groups = sum(row['finished'] != 1 for row in tree_states) _create_summary_table( @@ -1005,7 +1021,8 @@ def _get_bins_grps( if len(use_cat_features) != len(cat_features): plpy.warning("Decision tree warning: Categorical columns with only " "one value are dropped from the tree model.") - cat_features = [feature for feature in cat_features if feature in use_cat_features] + cat_features = [feature for feature in cat_features + if feature in use_cat_features] # grp_col_to_levels is a list of tuples (pairs) with # first value = group value, @@ -1023,7 +1040,8 @@ def _get_bins_grps( grp_to_col_to_levels = [ (grp_key, dict((row['colname'], row['levels']) for row in items)) for grp_key, items in groupby(all_levels, key=itemgetter('grp_key'))] - if cat_features: + grp_to_cat_features = dict([(g, col_to_levels.keys()) + for (g, col_to_levels) in grp_to_col_to_levels]) # Below statements collect the grp_to_col_to_levels into multiple variables # From above eg. # cat_items_list = [[0,1], [4,6,8], [0,1], [4,6], [0,1], [4,6,8]] @@ -1039,7 +1057,11 @@ def _get_bins_grps( else: cat_n = [] cat_origin = [] - grp_key_cat=[con_splits['grp_key'] for con_splits in con_splits_all] + grp_key_cat = [con_splits['grp_key'] for con_splits in con_splits_all] + grp_to_col_to_levels = [(con_splits['grp_key'], dict()) + for con_splits in con_splits_all] + grp_to_cat_features = dict([(con_splits['grp_key'], list()) + for con_splits in con_splits_all]) if con_features: con = [con_splits['con_splits'] for con_splits in con_splits_all] @@ -1055,10 +1077,79 @@ def _get_bins_grps( cat_n=cat_n, cat_features=cat_features, grp_key_cat=grp_key_cat, - grouping_array_str=grouping_array_str) + grouping_array_str=grouping_array_str, + grp_to_col_to_levels=grp_to_col_to_levels, + grp_to_cat_features=grp_to_cat_features) # ------------------------------------------------------------ +def _create_cat_features_info_table(cat_features_info_table, bins): + # bins['grp_to_col_to_levels'] = + # [ + # ('3', {'vs': [0, 1], 'cyl': [4,6,8]}), + # ('4', {'vs': [0, 1], 'cyl': [4,6]}), + # ('5', {'vs': [0, 1]}) + # ] + # Convert this into a VALUES command and place in a table + # VALUES (('3', ARRAY[2, 3], ARRAY['0', '1', '4', '6', '8']), + # ('4', ARRAY[2, 2], ARRAY['0', '1', '4', '6']), + # ('5', ARRAY[2], ARRAY['0', '1']), + # ) + cat_features_info_values = [] + if 'grp_to_col_to_levels' in bins: + # Grouping enabled, implies the cat levels can be different for + # different groups + for i, (grp_key, col_to_levels) in enumerate(bins['grp_to_col_to_levels'], start=1): + grp_key_str = quote_literal(grp_key) + cat_names_levels = [(c, col_to_levels[c]) for c in bins['cat_features'] + if c in col_to_levels] + if cat_names_levels: + cat_names, cat_levels = zip(*cat_names_levels) + # categorical features in current group (expressed in an array) + cat_names_str = py_list_to_sql_string( + map(quote_literal, cat_names), 'text', True) + # number of levels in each cat feature + cat_n_levels_str = py_list_to_sql_string( + map(len, cat_levels), 'integer', True) + # flatten the levels across all cat features + cat_levels = [quote_literal(each_level) + for sublist in cat_levels + for each_level in sublist] + cat_levels_str = py_list_to_sql_string(cat_levels, 'text', True) + else: + # this is the case if no categorical features present + cat_names_str = cat_n_levels_str = cat_levels_str = "NULL" + + cat_features_info_values.append( + "({i}::INTEGER, {grp_key_str}, {cat_names_str}, {cat_n_levels_str}, {cat_levels_str})". + format(**locals())) + else: + # no grouping + if bins['cat_features']: + cat_names_str = py_list_to_sql_string( + map(quote_literal, bins['cat_features']), 'text', True) + cat_n_levels_str = py_list_to_sql_string(bins['cat_n'], 'integer', True) + cat_levels_str = py_list_to_sql_string( + map(quote_literal, bins['cat_origin']), 'text', True) + else: + cat_names_str = cat_n_levels_str = cat_levels_str = "NULL" + cat_features_info_values.append( + "(1::INTEGER, ''::TEXT, {0}, {1}, {2})".format( + cat_names_str, cat_n_levels_str, cat_levels_str)) + + sql_cat_features_info = """ + CREATE TEMP TABLE {0} AS + SELECT * + FROM ( + VALUES {1} + ) AS q(gid, grp_key, cat_names, cat_n_levels, cat_levels_in_text) + """.format(cat_features_info_table, + ',\n'.join(cat_features_info_values)) + plpy.notice("sql_cat_features_info:\n" + sql_cat_features_info) + plpy.execute(sql_cat_features_info.format(**locals())) +# ------------------------------------------------------------------------------ + + def get_feature_str(schema_madlib, boolean_cats, cat_features, con_features, levels_str, n_levels_str, @@ -1194,7 +1285,7 @@ def _one_step_for_grps( con_features, boolean_cats, bins, n_bins, tree_states, weights, grouping_cols, grouping_array_str, dep_var, min_split, min_bucket, max_depth, filter_null, dep_n_levels, subsample, n_random_features, - max_n_surr=0, null_proxy=None): + cat_features_info_table, max_n_surr=0, null_proxy=None): """ One step of trees training with grouping support """ # The function _map_catlevel_to_int maps a categorical variable value to its @@ -1249,12 +1340,12 @@ def _one_step_for_grps( FROM {training_table_name} as src, ( SELECT - grp_key AS {grp_key}, - finished AS {finished}, - tree_state AS {tree_state}, - con_splits AS {con_splits}, - cat_n_levels AS {cat_n_levels}, - cat_levels_in_text AS {cat_levels_in_text} + grp_key AS {grp_key}, + finished AS {finished}, + tree_state AS {tree_state}, + con_splits AS {con_splits}, + cat_n_levels::INTEGER[] AS {cat_n_levels}, + cat_levels_in_text::TEXT[] AS {cat_levels_in_text} FROM ( SELECT unnest($1) AS grp_key, @@ -1264,11 +1355,11 @@ def _one_step_for_grps( JOIN ( SELECT unnest($4) AS grp_key, - unnest($9) AS con_splits + unnest($5) AS con_splits ) AS con_splits USING (grp_key) JOIN - {schema_madlib}._gen_cat_levels_set($5, $6, $7, $8) AS cat_levels + {cat_features_info_table} USING (grp_key) ) AS needed_data WHERE {grouping_array_str} = {grp_key} @@ -1286,21 +1377,20 @@ def _one_step_for_grps( JOIN ( SELECT unnest($4) AS grp_key, - unnest($9) AS con_splits + unnest($5) AS con_splits ) AS con_splits USING (grp_key) ) s2 USING (grp_key) """ - train_sql = "SELECT grp_key, (result).* from (" + sql + ") sub" + train_sql = "SELECT grp_key, (result).* FROM (" + sql + ") sub" train_sql = train_sql.format(aggregate=train_aggregate, apply_func=train_apply_func, # check_finished="AND " + finished + " = 0", **locals()) train_sql_plan = plpy.prepare(train_sql, - ['text[]', 'integer[]', bytea8arr, 'text[]', - 'text[]', 'integer[]', 'integer', 'text[]', - bytea8arr]) + ['text[]', 'integer[]', bytea8arr, + 'text[]', bytea8arr]) unfinished_trees = [t for t in tree_states if t['finished'] == 0] finished_trees = [t for t in tree_states if t['finished'] != 0] @@ -1312,10 +1402,6 @@ def _one_step_for_grps( [t['finished'] for t in unfinished_trees], [t['tree_state'] for t in unfinished_trees], bins['grp_key_con'], - bins['grp_key_cat'], - bins['cat_n'], - len(cat_features), - bins['cat_origin'], bins['con']])) if max_n_surr > 0: @@ -1347,17 +1433,12 @@ def _one_step_for_grps( **locals()) surr_sql_plan = plpy.prepare(surr_sql, ['text[]', 'integer[]', bytea8arr, - 'text[]', 'text[]', 'integer[]', 'integer', 'text[]', bytea8arr]) surr_trees = list(plpy.execute(surr_sql_plan, [ [t['grp_key'] for t in updated_unfinished], [t['finished'] for t in updated_unfinished], [t['tree_state'] for t in updated_unfinished], bins['grp_key_con'], - bins['grp_key_cat'], - bins['cat_n'], - len(cat_features), - bins['cat_origin'], bins['con']])) surr_dict = dict() @@ -1376,7 +1457,8 @@ def _one_step_for_grps( def _create_grp_result_table( schema_madlib, tree_states, bins, cat_features, - con_features, output_table_name, grouping_cols, + con_features, output_table_name, cat_features_info_table, + grouping_cols, training_table_name, use_existing_tables=False, running_cv=False, k=0): """ Create the output table for grouping case. @@ -1435,7 +1517,7 @@ def _create_grp_result_table( cat_n_levels as {cat_n_levels}, cat_levels_in_text as {cat_levels_in_text} FROM - {schema_madlib}._gen_cat_levels_set($6, $7, $8, $9) + {cat_features_info_table} ) s3 USING ({grp_key}) """ @@ -1920,7 +2002,7 @@ def _compute_var_importance(schema_madlib, tree, Args: @param schema_madlib: str, MADlib schema name - @param tree: Tree data to prune + @param tree: dict. tree['tree_state'] is the trained tree (in byte form) @param n_cat_features: int, Number of categorical features @param n_con_features: int, Number of continuous features @@ -2098,7 +2180,7 @@ def _xvalidate(schema_madlib, tree_states, training_table_name, output_table_nam tree['pruned_depth'] = 0 importance_vectors = _compute_var_importance( schema_madlib, tree, - len(cat_features), len(con_features)) + len(tree['cat_features']), len(tree['con_features'])) tree.update(**importance_vectors) plpy.execute("DROP TABLE {group_to_param_list_table}".format(**locals())) @@ -2235,6 +2317,7 @@ def _tree_train_grps_using_bins( grouping_cols, grouping_array_str, dep_var_str, min_split, min_bucket, max_depth, filter_dep, dep_n_levels, is_classification, split_criterion, + cat_features_info_table, subsample=False, n_random_features=1, tree_terminated=None, max_n_surr=0, null_proxy=None, **kwargs): @@ -2279,7 +2362,8 @@ def _tree_train_grps_using_bins( tree_states, weights, grouping_cols, grouping_array_str, dep_var_str, min_split, min_bucket, max_depth, filter_dep, dep_n_levels, subsample, - n_random_features, max_n_surr, null_proxy) + n_random_features, cat_features_info_table, + max_n_surr, null_proxy) level += 1 plpy.notice("Finished training for level " + str(level)) diff --git a/src/ports/postgres/modules/recursive_partitioning/random_forest.py_in b/src/ports/postgres/modules/recursive_partitioning/random_forest.py_in index a048fa1b4..c06bed808 100644 --- a/src/ports/postgres/modules/recursive_partitioning/random_forest.py_in +++ b/src/ports/postgres/modules/recursive_partitioning/random_forest.py_in @@ -14,17 +14,18 @@ from utilities.control import MinWarning from utilities.control import OptimizerControl from utilities.control import HashaggControl from utilities.utilities import _assert -from utilities.utilities import unique_string from utilities.utilities import add_postfix -from utilities.utilities import split_quoted_delimited_str from utilities.utilities import extract_keyvalue_params from utilities.utilities import py_list_to_sql_string +from utilities.utilities import split_quoted_delimited_str +from utilities.utilities import unique_string + +from utilities.validate_args import cols_in_tbl_valid from utilities.validate_args import get_cols_and_types -from utilities.validate_args import is_var_valid +from utilities.validate_args import get_expr_type from utilities.validate_args import input_tbl_valid +from utilities.validate_args import is_var_valid from utilities.validate_args import output_tbl_valid -from utilities.validate_args import cols_in_tbl_valid -from utilities.validate_args import get_expr_type from decision_tree import _tree_train_using_bins from decision_tree import _tree_train_grps_using_bins @@ -39,6 +40,7 @@ from decision_tree import _get_filter_str from decision_tree import _get_display_header from decision_tree import get_feature_str from decision_tree import _compute_var_importance +from decision_tree import _create_cat_features_info_table # ------------------------------------------------------------ @@ -265,8 +267,7 @@ def forest_train( @param verbose: str, Verbosity of output messages @param sample_ratio: float, subsampling ratio for generating src_view """ - msg_level = "'notice'" if verbose else "'warning'" - + msg_level = "notice" if verbose else "warning" with MinWarning(msg_level): with OptimizerControl(False): # we disable optimizer (ORCA) for platforms that use it @@ -430,31 +431,9 @@ def forest_train( is_classification, dep_n_levels, filter_null, null_proxy) cat_features = bins['cat_features'] - # a table for converting cat_features to integers + # a table for getting information of cat features for each group cat_features_info_table = unique_string() - sql_cat_features_info = """ - CREATE TEMP TABLE {cat_features_info_table} AS - SELECT - gid, - cat_n_levels, - cat_levels_in_text - FROM - ( - SELECT * - FROM {schema_madlib}._gen_cat_levels_set($1, $2, $3, $4) - ) subq - JOIN - {grp_key_to_grp_cols} - USING (grp_key) - """.format(**locals()) - plpy.notice("sql_cat_features_info:\n" + sql_cat_features_info) - plan_cat_features_info = plpy.prepare( - sql_cat_features_info, ['text[]', 'integer[]', 'integer', 'text[]']) - plpy.execute(plan_cat_features_info, [ - bins['grp_key_cat'], - bins['cat_n'], - len(cat_features), - bins['cat_origin']]) + _create_cat_features_info_table(cat_features_info_table, bins) con_splits_table = unique_string() _create_con_splits_table(schema_madlib, con_splits_table, @@ -587,8 +566,11 @@ def forest_train( boolean_cats, num_bins, 'poisson_count', grouping_cols, grouping_array_str, dep, min_split, min_bucket, max_tree_depth, filter_null, dep_n_levels, - is_classification, split_criterion, True, - num_random_features, tree_terminated=tree_terminated, + is_classification, split_criterion, + cat_features_info_table, + subsample=True, + n_random_features=num_random_features, + tree_terminated=tree_terminated, max_n_surr=max_n_surr, null_proxy=null_proxy) # If a tree for a group is terminated (not finished properly), @@ -966,7 +948,7 @@ def _calculate_oob_prediction( 1 -- -1 shifted to 0 for null values ), {schema_madlib}.array_scalar_add( - cat_n_levels, + cat_n_levels::integer[], 1 -- -1 shifted to 0 for null values ) ) AS cat_feature_distributions, @@ -1024,7 +1006,7 @@ def _calculate_oob_prediction( tree, {cat_features_str}::integer[], {con_features_str}::double precision[], - cat_info.cat_n_levels, + cat_info.cat_n_levels::integer[], {num_permutations}, {dep}, {is_classification}, diff --git a/src/ports/postgres/modules/utilities/utilities.py_in b/src/ports/postgres/modules/utilities/utilities.py_in index 87445a788..572c50f44 100644 --- a/src/ports/postgres/modules/utilities/utilities.py_in +++ b/src/ports/postgres/modules/utilities/utilities.py_in @@ -353,7 +353,7 @@ def py_list_to_sql_string(array, array_type=None, long_format=None): array_type += "[]" if not array: - return "'{{ }}'::{0}".format(array_type) + return "ARRAY[]::{0}".format(array_type) else: quote_delimiter = "$__MADLIB_OUTER__$" # This is a quote delimiter that can be used in lieu of From 884e2c334cc59ec644fd9a4919ea8b419ade8ca9 Mon Sep 17 00:00:00 2001 From: Rahul Iyer Date: Thu, 19 Jul 2018 10:26:14 -0700 Subject: [PATCH 2/4] DT: Add explicit cast for columns --- .../modules/recursive_partitioning/decision_tree.py_in | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/ports/postgres/modules/recursive_partitioning/decision_tree.py_in b/src/ports/postgres/modules/recursive_partitioning/decision_tree.py_in index c9665fe5d..934a41ad4 100644 --- a/src/ports/postgres/modules/recursive_partitioning/decision_tree.py_in +++ b/src/ports/postgres/modules/recursive_partitioning/decision_tree.py_in @@ -1121,7 +1121,8 @@ def _create_cat_features_info_table(cat_features_info_table, bins): cat_names_str = cat_n_levels_str = cat_levels_str = "NULL" cat_features_info_values.append( - "({i}::INTEGER, {grp_key_str}, {cat_names_str}, {cat_n_levels_str}, {cat_levels_str})". + "({i}::INTEGER, {grp_key_str}::TEXT, {cat_names_str}::TEXT[], " + "{cat_n_levels_str}::INTEGER[], {cat_levels_str}::TEXT[])". format(**locals())) else: # no grouping @@ -1134,8 +1135,8 @@ def _create_cat_features_info_table(cat_features_info_table, bins): else: cat_names_str = cat_n_levels_str = cat_levels_str = "NULL" cat_features_info_values.append( - "(1::INTEGER, ''::TEXT, {0}, {1}, {2})".format( - cat_names_str, cat_n_levels_str, cat_levels_str)) + "(1::INTEGER, ''::TEXT, {0}::TEXT[], {1}::INTEGER[], {2}::TEXT[])". + format(cat_names_str, cat_n_levels_str, cat_levels_str)) sql_cat_features_info = """ CREATE TEMP TABLE {0} AS From fbffba6f3a5cb6e2294c51e75e2bb2f6cbe9b22a Mon Sep 17 00:00:00 2001 From: Rahul Iyer Date: Thu, 19 Jul 2018 10:48:13 -0700 Subject: [PATCH 3/4] DT: Ensure py_list_to_sql_string behavior is backwards compatible --- .../recursive_partitioning/decision_tree.py_in | 12 ++++++------ src/ports/postgres/modules/utilities/utilities.py_in | 2 +- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/ports/postgres/modules/recursive_partitioning/decision_tree.py_in b/src/ports/postgres/modules/recursive_partitioning/decision_tree.py_in index 934a41ad4..89acd8aa1 100644 --- a/src/ports/postgres/modules/recursive_partitioning/decision_tree.py_in +++ b/src/ports/postgres/modules/recursive_partitioning/decision_tree.py_in @@ -1107,15 +1107,15 @@ def _create_cat_features_info_table(cat_features_info_table, bins): cat_names, cat_levels = zip(*cat_names_levels) # categorical features in current group (expressed in an array) cat_names_str = py_list_to_sql_string( - map(quote_literal, cat_names), 'text', True) + map(quote_literal, cat_names), 'text', long_format=True) # number of levels in each cat feature cat_n_levels_str = py_list_to_sql_string( - map(len, cat_levels), 'integer', True) + map(len, cat_levels), 'integer', long_format=True) # flatten the levels across all cat features cat_levels = [quote_literal(each_level) for sublist in cat_levels for each_level in sublist] - cat_levels_str = py_list_to_sql_string(cat_levels, 'text', True) + cat_levels_str = py_list_to_sql_string(cat_levels, 'text', long_format=True) else: # this is the case if no categorical features present cat_names_str = cat_n_levels_str = cat_levels_str = "NULL" @@ -1128,10 +1128,10 @@ def _create_cat_features_info_table(cat_features_info_table, bins): # no grouping if bins['cat_features']: cat_names_str = py_list_to_sql_string( - map(quote_literal, bins['cat_features']), 'text', True) - cat_n_levels_str = py_list_to_sql_string(bins['cat_n'], 'integer', True) + map(quote_literal, bins['cat_features']), 'text', long_format=True) + cat_n_levels_str = py_list_to_sql_string(bins['cat_n'], 'integer', long_format=True) cat_levels_str = py_list_to_sql_string( - map(quote_literal, bins['cat_origin']), 'text', True) + map(quote_literal, bins['cat_origin']), 'text', long_format=True) else: cat_names_str = cat_n_levels_str = cat_levels_str = "NULL" cat_features_info_values.append( diff --git a/src/ports/postgres/modules/utilities/utilities.py_in b/src/ports/postgres/modules/utilities/utilities.py_in index 572c50f44..c59ddbf31 100644 --- a/src/ports/postgres/modules/utilities/utilities.py_in +++ b/src/ports/postgres/modules/utilities/utilities.py_in @@ -353,7 +353,7 @@ def py_list_to_sql_string(array, array_type=None, long_format=None): array_type += "[]" if not array: - return "ARRAY[]::{0}".format(array_type) + return ("ARRAY[]::{0}" if long_format else "'{{ }}'::{0}").format(array_type) else: quote_delimiter = "$__MADLIB_OUTER__$" # This is a quote delimiter that can be used in lieu of From 8385014bb9477241b00aaa5007a54b2ed5ec5505 Mon Sep 17 00:00:00 2001 From: Rahul Iyer Date: Thu, 19 Jul 2018 13:12:26 -0700 Subject: [PATCH 4/4] Revert utilities change --- src/ports/postgres/modules/utilities/utilities.py_in | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/ports/postgres/modules/utilities/utilities.py_in b/src/ports/postgres/modules/utilities/utilities.py_in index c59ddbf31..87445a788 100644 --- a/src/ports/postgres/modules/utilities/utilities.py_in +++ b/src/ports/postgres/modules/utilities/utilities.py_in @@ -353,7 +353,7 @@ def py_list_to_sql_string(array, array_type=None, long_format=None): array_type += "[]" if not array: - return ("ARRAY[]::{0}" if long_format else "'{{ }}'::{0}").format(array_type) + return "'{{ }}'::{0}".format(array_type) else: quote_delimiter = "$__MADLIB_OUTER__$" # This is a quote delimiter that can be used in lieu of