From f13bf8eca6f9500e1be035e956c5b935ca3dfed5 Mon Sep 17 00:00:00 2001 From: Nikhil Kak Date: Thu, 15 Mar 2018 12:01:08 -0700 Subject: [PATCH 1/7] Add new function for regex match of ARRAY[*] The regex match for ARRAY[*] will be used by other modules like minibatch preprocessor, hence creating a method for it and refactored elastic_net to use this method. Co-authored-by: Jingyi Mei --- src/ports/postgres/modules/elastic_net/elastic_net.py_in | 3 ++- src/ports/postgres/modules/utilities/utilities.py_in | 7 +++++++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/src/ports/postgres/modules/elastic_net/elastic_net.py_in b/src/ports/postgres/modules/elastic_net/elastic_net.py_in index fe09aafc4..555840a35 100644 --- a/src/ports/postgres/modules/elastic_net/elastic_net.py_in +++ b/src/ports/postgres/modules/elastic_net/elastic_net.py_in @@ -9,6 +9,7 @@ from elastic_net_utils import _generate_warmup_lambda_sequence from elastic_net_utils import BINOMIAL_FAMILIES, GAUSSIAN_FAMILIES, OPTIMIZERS from utilities.validate_args import is_col_array +from utilities.utilities import is_string_formatted_as_array_expression from utilities.validate_args import table_exists from utilities.validate_args import table_is_empty from utilities.validate_args import columns_exist_in_table @@ -664,7 +665,7 @@ def analyze_input_str(schema_madlib, tbl_source, col_ind_var, excluded) else: # if input is an expression resulting in an array output - matched = re.match(r"(?i)^array\[(.*)\]", col_ind_var) + matched = is_string_formatted_as_array_expression(col_ind_var) if matched: # array expression starts with the word "ARRAY" outstr_array = _string_to_array(matched.group(1)) diff --git a/src/ports/postgres/modules/utilities/utilities.py_in b/src/ports/postgres/modules/utilities/utilities.py_in index 135404f92..01409bff2 100644 --- a/src/ports/postgres/modules/utilities/utilities.py_in +++ b/src/ports/postgres/modules/utilities/utilities.py_in @@ -186,6 +186,13 @@ def is_psql_numeric_type(arg, exclude=None): return (arg in to_check_types) # ------------------------------------------------------------------------- +def is_string_formatted_as_array_expression(string_to_match): + """ + Return true if the string is formatted as array[], else false + :param string_to_match: + """ + matched = re.match(r"(?i)^array\[(.*)\]", string_to_match) + return matched def _string_to_array(s): """ From 35798d8b0e3932737d88cad33d099c432d46e127 Mon Sep 17 00:00:00 2001 From: Nikhil Kak Date: Thu, 15 Mar 2018 12:01:22 -0700 Subject: [PATCH 2/7] Removed __ from 2 public methods in utils_regularization.py Renamed __utils_ind_var_scales and __utils_ind_var_scales_grouping so that we can access them from within a class, more specifically the minibatch_preprocessing module. The minibatch_preprocessing module will be added in a future commit. 
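For context: Python name-mangles any identifier with two leading
underscores that appears textually inside a class body, so the old names
could not be referenced from methods of the upcoming preprocessor class.
A minimal sketch of the failure mode (hypothetical names, not from this
codebase):

    def __scale(v):               # module-level "private" helper
        return v

    class Preprocessor(object):
        def run(self):
            # compiled as _Preprocessor__scale -> NameError at runtime
            return __scale(42)

Dropping the double-underscore prefix avoids the mangling.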
Co-authored-by: Jingyi Mei --- src/ports/postgres/modules/convex/mlp_igd.py_in | 9 ++++----- .../postgres/modules/convex/utils_regularization.py_in | 8 ++++---- .../postgres/modules/elastic_net/elastic_net_utils.py_in | 8 ++++---- src/ports/postgres/modules/pca/pca.py_in | 4 ++-- src/ports/postgres/modules/stats/cox_prop_hazards.py_in | 2 +- 5 files changed, 15 insertions(+), 16 deletions(-) diff --git a/src/ports/postgres/modules/convex/mlp_igd.py_in b/src/ports/postgres/modules/convex/mlp_igd.py_in index a1c63621a..cb835e527 100644 --- a/src/ports/postgres/modules/convex/mlp_igd.py_in +++ b/src/ports/postgres/modules/convex/mlp_igd.py_in @@ -26,8 +26,8 @@ import math import plpy -from convex.utils_regularization import __utils_ind_var_scales -from convex.utils_regularization import __utils_ind_var_scales_grouping +from convex.utils_regularization import utils_ind_var_scales +from convex.utils_regularization import utils_ind_var_scales_grouping from convex.utils_regularization import __utils_normalize_data from convex.utils_regularization import __utils_normalize_data_grouping @@ -331,7 +331,7 @@ def normalize_data(args): # specific to groups. Store these results in temp tables x_mean_table # __utils_normalize_data_grouping reads the various means and stds # from the tables. - __utils_ind_var_scales_grouping(args["source_table"], + utils_ind_var_scales_grouping(args["source_table"], args["independent_varname"], args["dimension"], args["schema_madlib"], args["grouping_col"], @@ -352,7 +352,7 @@ def normalize_data(args): # When no grouping_col is defined, the mean and std for 'x' # can be defined using strings, stored in x_mean_str, x_std_str. # We don't need a table like how we needed for grouping. - x_scaled_vals = __utils_ind_var_scales(args["source_table"], + x_scaled_vals = utils_ind_var_scales(args["source_table"], args["independent_varname"], args["dimension"], args["schema_madlib"], @@ -451,7 +451,6 @@ def _create_output_table(output_table, temp_output_table, """.format(**locals()) plpy.execute(build_output_query) - def _update_temp_model_table(args, iteration, temp_output_table, first_try): insert_or_create_str = "INSERT INTO {0}" if first_try: diff --git a/src/ports/postgres/modules/convex/utils_regularization.py_in b/src/ports/postgres/modules/convex/utils_regularization.py_in index 712364452..f56a0426a 100644 --- a/src/ports/postgres/modules/convex/utils_regularization.py_in +++ b/src/ports/postgres/modules/convex/utils_regularization.py_in @@ -14,7 +14,7 @@ mad_vec = version_wrapper.select_vecfunc() # ======================================================================== -def __utils_ind_var_scales(tbl_data, col_ind_var, dimension, schema_madlib, +def utils_ind_var_scales(tbl_data, col_ind_var, dimension, schema_madlib, x_mean_table=None, set_zero_std_to_one=False): """ The mean and standard deviation for each dimension of an array stored @@ -69,7 +69,7 @@ def __utils_ind_var_scales(tbl_data, col_ind_var, dimension, schema_madlib, # ======================================================================== -def __utils_ind_var_scales_grouping(tbl_data, col_ind_var, dimension, +def utils_ind_var_scales_grouping(tbl_data, col_ind_var, dimension, schema_madlib, grouping_col, x_mean_table, set_zero_std_to_one=False): """ @@ -188,7 +188,7 @@ def __utils_dep_var_scale_grouping(y_mean_table, tbl_data, grouping_col, def __utils_normalize_data_grouping(y_decenter=True, **kwargs): """ Normalize the independent and dependent variables using the calculated - mean's and std's in 
__utils_ind_var_scales and __utils_dep_var_scale. + mean's and std's in utils_ind_var_scales and __utils_dep_var_scale. Compute the scaled variables by: scaled_value = (origin_value - mean) / std, and special care is needed if std is zero. @@ -246,7 +246,7 @@ def __utils_normalize_data_grouping(y_decenter=True, **kwargs): def __utils_normalize_data(y_decenter=True, **kwargs): """ Normalize the independent and dependent variables using the calculated mean's and std's - in __utils_ind_var_scales and __utils_dep_var_scale. + in utils_ind_var_scales and __utils_dep_var_scale. Compute the scaled variables by: scaled_value = (origin_value - mean) / std, and special care is needed if std is zero. diff --git a/src/ports/postgres/modules/elastic_net/elastic_net_utils.py_in b/src/ports/postgres/modules/elastic_net/elastic_net_utils.py_in index 6785e5acf..7cab42d87 100644 --- a/src/ports/postgres/modules/elastic_net/elastic_net_utils.py_in +++ b/src/ports/postgres/modules/elastic_net/elastic_net_utils.py_in @@ -3,10 +3,10 @@ import math import re from utilities.utilities import _string_to_array from utilities.utilities import _array_to_string -from convex.utils_regularization import __utils_ind_var_scales +from convex.utils_regularization import utils_ind_var_scales from convex.utils_regularization import __utils_dep_var_scale from convex.utils_regularization import __utils_normalize_data -from convex.utils_regularization import __utils_ind_var_scales_grouping +from convex.utils_regularization import utils_ind_var_scales_grouping from convex.utils_regularization import __utils_dep_var_scale_grouping from convex.utils_regularization import __utils_normalize_data_grouping from utilities.validate_args import table_exists @@ -213,7 +213,7 @@ def _compute_data_scales_grouping(args): # mean of dependent variable (y) and the standard deviation for them # specific to groups. Store these results in temp tables x_mean_table # and y_mean_table. 
- __utils_ind_var_scales_grouping(args["tbl_source"], args["col_ind_var"], + utils_ind_var_scales_grouping(args["tbl_source"], args["col_ind_var"], args["dimension"], args["schema_madlib"], args["grouping_col"], args["x_mean_table"]) if args["family"] == "binomial": @@ -226,7 +226,7 @@ def _compute_data_scales_grouping(args): args["schema_madlib"], args["col_ind_var"], args["col_dep_var"]) def _compute_data_scales(args): - args["x_scales"] = __utils_ind_var_scales(args["tbl_source"], + args["x_scales"] = utils_ind_var_scales(args["tbl_source"], args["col_ind_var"], args["dimension"], args["schema_madlib"]) if args["family"] == "binomial": args["y_scale"] = dict(mean=0, std=1) diff --git a/src/ports/postgres/modules/pca/pca.py_in b/src/ports/postgres/modules/pca/pca.py_in index 680e9f6c4..2c881eb72 100644 --- a/src/ports/postgres/modules/pca/pca.py_in +++ b/src/ports/postgres/modules/pca/pca.py_in @@ -4,7 +4,7 @@ @namespace pca """ -from convex.utils_regularization import __utils_ind_var_scales +from convex.utils_regularization import utils_ind_var_scales from linalg.matrix_ops import get_dims from linalg.matrix_ops import create_temp_sparse_matrix_table_with_dims from linalg.matrix_ops import cast_dense_input_table_to_correct_columns @@ -677,7 +677,7 @@ def _recenter_data(schema_madlib, source_table, output_table, row_id, Column mean """ # Step 1: Compute column mean values - x_scales = __utils_ind_var_scales(tbl_data=source_table, + x_scales = utils_ind_var_scales(tbl_data=source_table, col_ind_var=col_name, dimension=dimension, schema_madlib=schema_madlib) diff --git a/src/ports/postgres/modules/stats/cox_prop_hazards.py_in b/src/ports/postgres/modules/stats/cox_prop_hazards.py_in index 74b1eade2..dc17b958f 100644 --- a/src/ports/postgres/modules/stats/cox_prop_hazards.py_in +++ b/src/ports/postgres/modules/stats/cox_prop_hazards.py_in @@ -28,7 +28,7 @@ from utilities.utilities import py_list_to_sql_string from utilities.validate_args import columns_exist_in_table from utilities.utilities import __mad_version from utilities.control import IterationController2S -from convex.utils_regularization import __utils_ind_var_scales +from convex.utils_regularization import utils_ind_var_scales import random # ---------------------------------------------------------------------- From 0125c6260fa8847ded7bfdf7d3f2fb957a23315f Mon Sep 17 00:00:00 2001 From: Nikhil Kak Date: Thu, 15 Mar 2018 12:01:27 -0700 Subject: [PATCH 3/7] Add an optional flag to _tbl_dimension_rownum Add an optional flag to _tbl_dimension_rownum to skip the row count calculation so as to avoid another table scan. Co-authored-by: Jingyi Mei --- src/ports/postgres/modules/utilities/validate_args.py_in | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/ports/postgres/modules/utilities/validate_args.py_in b/src/ports/postgres/modules/utilities/validate_args.py_in index ec1f410bd..2ad1536d3 100644 --- a/src/ports/postgres/modules/utilities/validate_args.py_in +++ b/src/ports/postgres/modules/utilities/validate_args.py_in @@ -544,9 +544,12 @@ def array_col_has_no_null(tbl, col): return True # ------------------------------------------------------------------------- -def _tbl_dimension_rownum(schema_madlib, tbl_source, col_ind_var): +def _tbl_dimension_rownum(schema_madlib, tbl_source, col_ind_var, skip_row_count=False): """ Measure the dimension and row number of source data table + Please note that calculating the row count will incur a pass over the + entire dataset. 
Hence the flag skip_row_count to optionally skip the row + count calculation. """ # independent variable array length dimension = plpy.execute(""" @@ -559,6 +562,9 @@ def _tbl_dimension_rownum(schema_madlib, tbl_source, col_ind_var): # NULLs in the independent variable (x). There is no NULL check made for # the dependent variable (y), since one of the hard assumptions of the # input data to elastic_net is that the dependent variable cannot be NULL. + if skip_row_count: + return dimension, None + row_num = plpy.execute(""" SELECT COUNT(*) FROM {tbl_source} WHERE NOT {schema_madlib}.array_contains_null({col_ind_var}) From 8b0e7d640ea124a6beffa910c2f7acbeff307920 Mon Sep 17 00:00:00 2001 From: Nikhil Kak Date: Thu, 15 Mar 2018 12:01:43 -0700 Subject: [PATCH 4/7] Add new class for mean and std dev calculation. Some of the modules need to calculate the mean and std dev of indep/dep variable and this code is duplicated in all these modules. This new class takes out the common code among all these modules and creates a function for it. A future commit should use this class for calculating the mean and std dev. --- .../utilities/mean_std_dev_calculator.py_in | 54 +++++++++++++++++++ 1 file changed, 54 insertions(+) create mode 100644 src/ports/postgres/modules/utilities/mean_std_dev_calculator.py_in diff --git a/src/ports/postgres/modules/utilities/mean_std_dev_calculator.py_in b/src/ports/postgres/modules/utilities/mean_std_dev_calculator.py_in new file mode 100644 index 000000000..e2a1c4fb9 --- /dev/null +++ b/src/ports/postgres/modules/utilities/mean_std_dev_calculator.py_in @@ -0,0 +1,54 @@ +# coding=utf-8 +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+""" +@file mean_std_dev_calculator.py_in + +@brief + +@namespace utilities + +""" + +from convex.utils_regularization import utils_ind_var_scales +from utilities import _array_to_string + +m4_changequote(`') + +#TODO: use this for all the modules that calculate the std dev and mean for x +# mlp, pca, elastic_net +class MeanStdDevCalculator: + def __init__(self, schema_madlib, source_table, indep_var_array_str, dimension): + self.schema_madlib= schema_madlib + self.source_table= source_table + self.indep_var_array_str = indep_var_array_str + self.dimension = dimension + + def get_mean_and_std_dev_for_ind_var(self): + set_zero_std_to_one = True + + x_scaled_vals = utils_ind_var_scales(self.source_table, + self.indep_var_array_str, + self.dimension, + self.schema_madlib, + None, # do not dump the output to a temp table + set_zero_std_to_one) + x_mean_str = _array_to_string(x_scaled_vals["mean"]) + x_std_str = _array_to_string(x_scaled_vals["std"]) + + return x_mean_str, x_std_str From e17353c0754f5a0b0a47bf5ecb9054e0cfb92848 Mon Sep 17 00:00:00 2001 From: Nikhil Kak Date: Thu, 15 Mar 2018 12:36:35 -0700 Subject: [PATCH 5/7] Add plpy mock and separate out unit tests for utilities The plpy mock file will be used in unit tests to mock out all instances of plpy in the production code. Created a new unit test file for utilities and moved existing unit tests to this file. --- .../utilities/test/unit_tests/plpy_mock.py_in | 34 +++++ .../test/unit_tests/test_utilities.py_in | 122 ++++++++++++++++++ .../modules/utilities/utilities.py_in | 14 +- 3 files changed, 164 insertions(+), 6 deletions(-) create mode 100644 src/ports/postgres/modules/utilities/test/unit_tests/plpy_mock.py_in create mode 100644 src/ports/postgres/modules/utilities/test/unit_tests/test_utilities.py_in diff --git a/src/ports/postgres/modules/utilities/test/unit_tests/plpy_mock.py_in b/src/ports/postgres/modules/utilities/test/unit_tests/plpy_mock.py_in new file mode 100644 index 000000000..305883089 --- /dev/null +++ b/src/ports/postgres/modules/utilities/test/unit_tests/plpy_mock.py_in @@ -0,0 +1,34 @@ +# coding=utf-8 +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +m4_changequote(`') +def __init__(self): + pass + +def error(message): + raise Exception(message) + +def execute(query): + pass + +def warning(query): + pass + +def info(query): + print query diff --git a/src/ports/postgres/modules/utilities/test/unit_tests/test_utilities.py_in b/src/ports/postgres/modules/utilities/test/unit_tests/test_utilities.py_in new file mode 100644 index 000000000..e456fdf33 --- /dev/null +++ b/src/ports/postgres/modules/utilities/test/unit_tests/test_utilities.py_in @@ -0,0 +1,122 @@ +# coding=utf-8 +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import sys +from os import path +# Add utilites module to the pythonpath. +sys.path.append(path.dirname(path.dirname(path.dirname(path.abspath(__file__))))) + + +import unittest +from mock import * +import sys +import plpy_mock as plpy + +m4_changequote(`') +class UtilitiesTestCase(unittest.TestCase): + def setUp(self): + patches = { + 'plpy': plpy + } + self.plpy_mock_execute = MagicMock() + plpy.execute = self.plpy_mock_execute + + self.module_patcher = patch.dict('sys.modules', patches) + self.module_patcher.start() + + import utilities + self.subject = utilities + + self.default_source_table = "source" + self.default_output_table = "output" + self.default_ind_var = "indvar" + self.default_dep_var = "depvar" + self.default_module = "unittest_module" + self.optimizer_params1 = 'max_iter=10, optimizer::text="irls", precision=1e-4' + self.optimizer_params2 = 'max_iter=.01, optimizer=newton-irls, precision=1e-5' + self.optimizer_params3 = 'max_iter=10, 10, optimizer=, lambda={1,"2,2",3,4}' + self.optimizer_params4 = ('max_iter=10, optimizer="irls",' + 'precision=0.02.01, lambda={1,2,3,4}') + self.optimizer_params5 = ('max_iter=10, optimizer="irls",' + 'precision=0.02, PRECISION=2., lambda={1,2,3,4}') + self.optimizer_types = {'max_iter': int, 'optimizer': str, 'optimizer::text': str, + 'lambda': list, 'precision': float} + + def tearDown(self): + self.module_patcher.stop() + + def test_preprocess_optimizer(self): + self.assertEqual(self.subject.preprocess_keyvalue_params(self.optimizer_params1), + ['max_iter=10', 'optimizer::text="irls"', 'precision=1e-4']) + self.assertEqual(self.subject.preprocess_keyvalue_params(self.optimizer_params2), + ['max_iter=.01', 'optimizer=newton-irls', 'precision=1e-5']) + self.assertEqual(self.subject.preprocess_keyvalue_params(self.optimizer_params3), + ['max_iter=10', 'lambda={1,"2,2",3,4}']) + self.assertEqual(self.subject.preprocess_keyvalue_params(self.optimizer_params4), + ['max_iter=10', 'optimizer="irls"', 'precision=0.02.01', 'lambda={1,2,3,4}']) + + def test_extract_optimizers(self): + self.assertEqual({'max_iter': 10, 'optimizer::text': '"irls"', 'precision': 0.0001}, + self.subject.extract_keyvalue_params(self.optimizer_params1, self.optimizer_types)) + self.assertEqual({'max_iter': 10, 'lambda': ['1', '"2,2"', '3', '4']}, + self.subject.extract_keyvalue_params(self.optimizer_params3, self.optimizer_types)) + self.assertEqual({'max_iter': '10', 'optimizer': '"irls"', 'precision': '0.02.01', + 'lambda': '{1,2,3,4}'}, + self.subject.extract_keyvalue_params(self.optimizer_params4)) + self.assertEqual({'max_iter': '10', 'optimizer': '"irls"', + 'PRECISION': '2.', 'precision': '0.02', + 'lambda': '{1,2,3,4}'}, + self.subject.extract_keyvalue_params(self.optimizer_params5, + allow_duplicates=False, + lower_case_names=False + )) + self.assertRaises(ValueError, + 
self.subject.extract_keyvalue_params, self.optimizer_params2, self.optimizer_types) + self.assertRaises(ValueError, + self.subject.extract_keyvalue_params, self.optimizer_params5, allow_duplicates=False) + self.assertRaises(ValueError, + self.subject.extract_keyvalue_params, self.optimizer_params4, self.optimizer_types) + + def test_split_delimited_string(self): + self.assertEqual(['max_iter=10', 'optimizer::text="irls"', 'precision=1e-4'], + self.subject.split_quoted_delimited_str(self.optimizer_params1, quote='"')) + self.assertEqual(['a', 'b', 'c'], self.subject.split_quoted_delimited_str('a, b, c', quote='|')) + self.assertEqual(['a', '|b, c|'], self.subject.split_quoted_delimited_str('a, |b, c|', quote='|')) + self.assertEqual(['a', '"b, c"'], self.subject.split_quoted_delimited_str('a, "b, c"')) + self.assertEqual(['"a^5,6"', 'b', 'c'], self.subject.split_quoted_delimited_str('"a^5,6", b, c', quote='"')) + self.assertEqual(['"A""^5,6"', 'b', 'c'], self.subject.split_quoted_delimited_str('"A""^5,6", b, c', quote='"')) + + def test_collate_plpy_result(self): + plpy_result1 = [{'classes': '4', 'class_count': 3}, + {'classes': '1', 'class_count': 18}, + {'classes': '5', 'class_count': 7}, + {'classes': '3', 'class_count': 3}, + {'classes': '6', 'class_count': 7}, + {'classes': '2', 'class_count': 7}] + self.assertEqual(self.subject.collate_plpy_result(plpy_result1), + {'classes': ['4', '1', '5', '3', '6', '2'], + 'class_count': [3, 18, 7, 3, 7, 7]}) + self.assertEqual(self.subject.collate_plpy_result([]), {}) + self.assertEqual(self.subject.collate_plpy_result([{'class': 'a'}, + {'class': 'b'}, + {'class': 'c'}]), + {'class': ['a', 'b', 'c']}) + +if __name__ == '__main__': + unittest.main() diff --git a/src/ports/postgres/modules/utilities/utilities.py_in b/src/ports/postgres/modules/utilities/utilities.py_in index 01409bff2..39a29c599 100644 --- a/src/ports/postgres/modules/utilities/utilities.py_in +++ b/src/ports/postgres/modules/utilities/utilities.py_in @@ -5,12 +5,14 @@ import time import random from distutils.util import strtobool -if __name__ != "__main__": - from validate_args import _get_table_schema_names - from validate_args import get_first_schema - from validate_args import cols_in_tbl_valid - from validate_args import explicit_bool_to_text - import plpy +from validate_args import _get_table_schema_names +from validate_args import get_first_schema +from validate_args import cols_in_tbl_valid +from validate_args import explicit_bool_to_text +from validate_args import input_tbl_valid +from validate_args import is_var_valid +from validate_args import output_tbl_valid +import plpy m4_changequote(`') From 78b689d87f2b7c2e495088c20b0d62ac5c716fa7 Mon Sep 17 00:00:00 2001 From: Nikhil Kak Date: Thu, 15 Mar 2018 12:39:20 -0700 Subject: [PATCH 6/7] Add a new function to validate input args This function is supposed to be used for validating params for supervised learning like algos, e.g. linear regression, mlp, etc. 
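A sketch of the intended call pattern (table and column names here are
hypothetical):

    from utilities.utilities import validate_module_input_params

    # Raises via plpy.error() unless 'src' exists and is non-empty,
    # 'out' and 'out_summary' do not yet exist, and both variable
    # expressions are valid in 'src'.
    validate_module_input_params('src', 'out', 'ARRAY[x1, x2]', 'y',
                                 'mlp', other_output_tables=['out_summary'])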
Co-authored-by: Orhan Kislal --- .../test/unit_tests/test_utilities.py_in | 100 ++++++++++++++++++ .../modules/utilities/utilities.py_in | 41 +++++++ 2 files changed, 141 insertions(+) diff --git a/src/ports/postgres/modules/utilities/test/unit_tests/test_utilities.py_in b/src/ports/postgres/modules/utilities/test/unit_tests/test_utilities.py_in index e456fdf33..1109eeb71 100644 --- a/src/ports/postgres/modules/utilities/test/unit_tests/test_utilities.py_in +++ b/src/ports/postgres/modules/utilities/test/unit_tests/test_utilities.py_in @@ -61,6 +61,106 @@ class UtilitiesTestCase(unittest.TestCase): def tearDown(self): self.module_patcher.stop() + def test_validate_module_input_params_all_nulls(self): + with self.assertRaises(Exception) as context: + self.subject.validate_module_input_params(None, None, None, None, "unittest_module") + + expected_exception = Exception("unittest_module error: NULL/empty input table name!") + self.assertEqual(expected_exception.message, context.exception.message) + + def test_validate_module_input_params_source_table_null(self): + with self.assertRaises(Exception) as context: + self.subject.validate_module_input_params(None, self.default_output_table, + self.default_ind_var, + self.default_dep_var, + self.default_module) + + expected_exception = "unittest_module error: NULL/empty input table name!" + self.assertEqual(expected_exception, context.exception.message) + + def test_validate_module_input_params_output_table_null(self): + with self.assertRaises(Exception) as context: + self.subject.validate_module_input_params(self.default_source_table, None, + self.default_ind_var, + self.default_dep_var, + self.default_module) + + expected_exception = "unittest_module error: NULL/empty output table name!" + self.assertEqual(expected_exception, context.exception.message) + + @patch('validate_args.table_exists', return_value=Mock()) + def test_validate_module_input_params_output_table_exists(self, + table_exists_mock): + self.subject.input_tbl_valid = Mock() + table_exists_mock.side_effect = [True] + with self.assertRaises(Exception) as context: + self.subject.validate_module_input_params(self.default_source_table, + self.default_output_table, + self.default_ind_var, + self.default_dep_var, + self.default_module) + + expected_exception = "unittest_module error: Output table '{0}' already exists.".format(self.default_output_table) + self.assertTrue(expected_exception in context.exception.message) + + @patch('validate_args.table_exists', return_value=Mock()) + def test_validate_module_input_params_assert_other_tables_dont_exist(self, table_exists_mock): + self.subject.input_tbl_valid = Mock() + table_exists_mock.side_effect = [False, False, True] + with self.assertRaises(Exception) as context: + self.subject.validate_module_input_params(self.default_source_table, + self.default_output_table, + self.default_ind_var, + self.default_dep_var, + self.default_module, + ['foo','bar']) + + expected_exception = "unittest_module error: Output table 'bar' already exists." 
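+        # 'bar' is expected because the table_exists side_effect
+        # [False, False, True] maps, in call order, to the existence
+        # checks for output_table, 'foo' and 'bar'.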
+ self.assertTrue(expected_exception in context.exception.message) + + @patch('validate_args.table_is_empty', return_value=False) + @patch('validate_args.table_exists', return_value=Mock()) + def test_validate_module_input_params_ind_var_null(self, table_exists_mock, + table_is_empty_mock): + table_exists_mock.side_effect = [True, False] + with self.assertRaises(Exception) as context: + self.subject.validate_module_input_params(self.default_source_table, + self.default_output_table, + None, + self.default_dep_var, + self.default_module) + + expected_exception = "unittest_module error: invalid independent_varname ('None') for source_table (source)!" + self.assertEqual(expected_exception, context.exception.message) + # is_var_valid_mock.assert_called_once_with(self.default_source_table, self.default_ind_var) + + @patch('validate_args.table_exists', return_value=Mock()) + @patch('validate_args.table_is_empty', return_value=False) + def test_validate_module_input_params_dep_var_null(self, table_is_empty_mock, table_exists_mock): + table_exists_mock.side_effect = [True, False] + with self.assertRaises(Exception) as context: + self.subject.validate_module_input_params(self.default_source_table, + self.default_output_table, + self.default_ind_var, + None, + self.default_module) + + expected_exception = "unittest_module error: invalid dependent_varname ('None') for source_table (source)!" + self.assertEqual(expected_exception, context.exception.message) + + def test_is_var_valid_all_nulls(self): + self.assertEqual(False, self.subject.is_var_valid(None, None)) + + def test_is_var_valid_var_null(self): + self.assertEqual(False, self.subject.is_var_valid("some_table", None)) + + def test_is_var_valid_var_exists_in_table(self): + self.assertEqual(True, self.subject.is_var_valid("some_var", "some_var")) + + def test_is_var_valid_var_does_not_exist_in_table(self): + self.plpy_mock_execute.side_effect = Exception("var does not exist in tbl") + self.assertEqual(False, self.subject.is_var_valid("some_var", "some_var")) + def test_preprocess_optimizer(self): self.assertEqual(self.subject.preprocess_keyvalue_params(self.optimizer_params1), ['max_iter=10', 'optimizer::text="irls"', 'precision=1e-4']) diff --git a/src/ports/postgres/modules/utilities/utilities.py_in b/src/ports/postgres/modules/utilities/utilities.py_in index 39a29c599..133f4aca6 100644 --- a/src/ports/postgres/modules/utilities/utilities.py_in +++ b/src/ports/postgres/modules/utilities/utilities.py_in @@ -794,6 +794,47 @@ def collate_plpy_result(plpy_result_rows): # ------------------------------------------------------------------------------ +def validate_module_input_params(source_table, output_table, independent_varname, + dependent_varname, module_name, + other_output_tables=None): + """ + This function is supposed to be used for validating params for + supervised learning like algos, e.g. linear regression, mlp, etc. since all + of them need to validate the following 4 parameters. + :param source_table: This table should exist and not be empty + :param output_table: This table should not exist + :param dependent_varname: This should be a valid expression in the source + table + :param independent_varname: This should be a valid expression in the source + table + :param module_name: Name of the module to be printed with the error messages + :param other_output_tables: List of additional output tables to validate. 
+ These tables should not exist + """ + + input_tbl_valid(source_table, module_name) + + output_tbl_valid(output_table, module_name) + + if other_output_tables: + for tbl in other_output_tables: + output_tbl_valid(tbl, module_name) + + _assert(is_var_valid(source_table, independent_varname), + "{module_name} error: invalid independent_varname " + "('{independent_varname}') for source_table " + "({source_table})!".format(module_name=module_name, + independent_varname=independent_varname, + source_table=source_table)) + + _assert(is_var_valid(source_table, dependent_varname), + "{module_name} error: invalid dependent_varname " + "('{dependent_varname}') for source_table " + "({source_table})!".format(module_name=module_name, + dependent_varname=dependent_varname, + source_table=source_table)) +# ------------------------------------------------------------------------ + import unittest From 73315937cf16a2173c22da70fd36c7da7b848724 Mon Sep 17 00:00:00 2001 From: Nikhil Kak Date: Thu, 15 Mar 2018 12:40:09 -0700 Subject: [PATCH 7/7] MiniBatch Pre-Processor: Add new module minibatch_preprocessing JIRA: MADLIB-1200 MiniBatch Preprocessor is a utility function to pre-process the input data for use with models that support mini-batching as an optimization. The main purpose of the function is to prepare the training data for minibatching algorithms. 1. If the dependent variable is boolean or text, perform one hot encoding. N/A for numeric. 2. Typecast independent variable to double precision[] 3. Based on the buffer size, group all the dependent and independent variables in a single tuple representative of the buffer. Notes 1. Ignore null values in independent and dependent variables 2. Standardize the input before packing it. Co-authored-by: Rahul Iyer Co-authored-by: Jingyi Mei Co-authored-by: Nandish Jayaram Co-authored-by: Orhan Kislal --- .../utilities/minibatch_preprocessing.py_in | 581 ++++++++++++++++++ .../utilities/minibatch_preprocessing.sql_in | 221 +++++++ .../test/minibatch_preprocessing.sql_in | 221 +++++++ .../test_minibatch_preprocessing.py_in | 304 +++++++++ 4 files changed, 1327 insertions(+) create mode 100644 src/ports/postgres/modules/utilities/minibatch_preprocessing.py_in create mode 100644 src/ports/postgres/modules/utilities/minibatch_preprocessing.sql_in create mode 100644 src/ports/postgres/modules/utilities/test/minibatch_preprocessing.sql_in create mode 100644 src/ports/postgres/modules/utilities/test/unit_tests/test_minibatch_preprocessing.py_in diff --git a/src/ports/postgres/modules/utilities/minibatch_preprocessing.py_in b/src/ports/postgres/modules/utilities/minibatch_preprocessing.py_in new file mode 100644 index 000000000..f3766d7d9 --- /dev/null +++ b/src/ports/postgres/modules/utilities/minibatch_preprocessing.py_in @@ -0,0 +1,581 @@ +# coding=utf-8 +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. + + +""" +@file minibatch_preprocessing.py_in + +""" +from math import ceil +import plpy + +from utilities import add_postfix +from utilities import _assert +from utilities import get_seg_number +from utilities import is_platform_pg +from utilities import is_psql_numeric_type +from utilities import is_string_formatted_as_array_expression +from utilities import py_list_to_sql_string +from utilities import split_quoted_delimited_str +from utilities import _string_to_array +from utilities import validate_module_input_params +from mean_std_dev_calculator import MeanStdDevCalculator +from validate_args import get_expr_type +from validate_args import output_tbl_valid +from validate_args import _tbl_dimension_rownum + +m4_changequote(`') + +# These are readonly variables, do not modify +MINIBATCH_OUTPUT_DEPENDENT_COLNAME = "dependent_varname" +MINIBATCH_OUTPUT_INDEPENDENT_COLNAME = "independent_varname" + +class MiniBatchPreProcessor: + """ + This class is responsible for executing the main logic of mini batch + preprocessing, which packs multiple rows of selected columns from the + source table into one row based on the buffer size + """ + def __init__(self, schema_madlib, source_table, output_table, + dependent_varname, independent_varname, buffer_size, **kwargs): + self.schema_madlib = schema_madlib + self.source_table = source_table + self.output_table = output_table + self.dependent_varname = dependent_varname + self.independent_varname = independent_varname + self.buffer_size = buffer_size + + self.module_name = "minibatch_preprocessor" + self.output_standardization_table = add_postfix(self.output_table, + "_standardization") + self.output_summary_table = add_postfix(self.output_table, "_summary") + self._validate_minibatch_preprocessor_params() + + def minibatch_preprocessor(self): + # Get array expressions for both dep and indep variables from the + # MiniBatchQueryFormatter class + dependent_var_dbtype = get_expr_type(self.dependent_varname, + self.source_table) + qry_formatter = MiniBatchQueryFormatter(self.source_table) + dep_var_array_str, dep_var_classes_str = qry_formatter.\ + get_dep_var_array_and_classes(self.dependent_varname, + dependent_var_dbtype) + indep_var_array_str = qry_formatter.get_indep_var_array_str( + self.independent_varname) + + standardizer = MiniBatchStandardizer(self.schema_madlib, + self.source_table, + dep_var_array_str, + indep_var_array_str, + self.output_standardization_table) + standardize_query = standardizer.get_query_for_standardizing() + + num_rows_processed, num_missing_rows_skipped = self.\ + _get_skipped_rows_processed_count( + dep_var_array_str, + indep_var_array_str) + calculated_buffer_size = MiniBatchBufferSizeCalculator.\ + calculate_default_buffer_size( + self.buffer_size, + num_rows_processed, + standardizer.independent_var_dimension) + """ + This query does the following: + 1. Standardize the independent variables in the input table + (see MiniBatchStandardizer for more details) + 2. Filter out rows with null values either in dependent/independent + variables + 3. Converts the input dependent/independent variables into arrays + (see MiniBatchQueryFormatter for more details) + 4. Based on the buffer size, pack the dependent/independent arrays into + matrices + + Notes + 1. we are ignoring null in x because + a. matrix_agg does not support null + b. __utils_normalize_data returns null if any element of the array + contains NULL + 2. 
Please keep the null checking where clause of this query in sync with + the query in _get_skipped_rows_processed_count. We are doing this null + check in two places to prevent another pass of the entire dataset. + """ + + # This ID is the unique row id that get assigned to each row after + # preprocessing + unique_row_id = "__id__" + sql = """ + CREATE TABLE {output_table} AS + SELECT {row_id}, + {schema_madlib}.matrix_agg({dep_colname}) as {dep_colname}, + {schema_madlib}.matrix_agg({ind_colname}) as {ind_colname} + FROM ( + SELECT (row_number() OVER (ORDER BY random()) - 1) / {buffer_size} + as {row_id}, * FROM + ( + {standardize_query} + ) sub_query_1 + WHERE NOT {schema_madlib}.array_contains_null({dep_colname}) + AND NOT {schema_madlib}.array_contains_null({ind_colname}) + ) sub_query_2 + GROUP BY {row_id} + {distributed_by_clause} + """.format( + schema_madlib=self.schema_madlib, + source_table=self.source_table, + output_table=self.output_table, + dependent_varname=self.dependent_varname, + independent_varname=self.independent_varname, + buffer_size = calculated_buffer_size, + dep_colname=MINIBATCH_OUTPUT_DEPENDENT_COLNAME, + ind_colname=MINIBATCH_OUTPUT_INDEPENDENT_COLNAME, + row_id = unique_row_id, + distributed_by_clause = '' if is_platform_pg() else + 'DISTRIBUTED RANDOMLY', + **locals()) + plpy.execute(sql) + + + standardizer.create_output_standardization_table() + MiniBatchSummarizer.create_output_summary_table( + self.source_table, + self.output_table, + self.dependent_varname, + self.independent_varname, + calculated_buffer_size, + dep_var_classes_str, + num_rows_processed, + num_missing_rows_skipped, + self.output_summary_table) + + def _validate_minibatch_preprocessor_params(self): + # Test if the independent variable can be typecasted to a double + # precision array and let postgres validate the expression + + # Note that this will not fail for 2d arrays but the standardizer will + # fail because utils_normalize_data will throw an error + typecasted_ind_varname = "{0}::double precision[]".format( + self.independent_varname) + validate_module_input_params(self.source_table, self.output_table, + typecasted_ind_varname, + self.dependent_varname, self.module_name, + [self.output_summary_table, + self.output_standardization_table]) + + num_of_dependent_cols = split_quoted_delimited_str( + self.dependent_varname) + + _assert(len(num_of_dependent_cols) == 1, + "Invalid dependent_varname: only one column name is allowed " + "as input.") + + if self.buffer_size is not None: + _assert(self.buffer_size > 0, + """minibatch_preprocessor: The buffer size has to be a positive + integer or NULL.""") + + def _get_skipped_rows_processed_count(self, dep_var_array, indep_var_array): + # Note: Keep the null checking where clause of this query in sync with + # the main create output table query. 
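+        # A single scan computes both counts: COUNT(*) gives the total
+        # rows in the source table, while SUM(CASE ...) counts only rows
+        # whose dependent and independent arrays are free of NULLs; the
+        # difference is reported as num_missing_rows_skipped.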
+ query = """ + SELECT COUNT(*) AS source_table_row_count, + sum(CASE WHEN + NOT {schema_madlib}.array_contains_null({dep_var_array}) + AND NOT {schema_madlib}.array_contains_null({indep_var_array}) + THEN 1 ELSE 0 END) AS num_rows_processed + FROM {source_table} + """.format( + schema_madlib = self.schema_madlib, + source_table = self.source_table, + dep_var_array = dep_var_array, + indep_var_array = indep_var_array) + result = plpy.execute(query) + + source_table_row_count = result[0]['source_table_row_count'] + num_rows_processed = result[0]['num_rows_processed'] + if not source_table_row_count or not num_rows_processed: + plpy.error("Error while getting the row count of the source table" + "{0}".format(self.source_table)) + num_missing_rows_skipped = source_table_row_count - num_rows_processed + + return num_rows_processed, num_missing_rows_skipped + +class MiniBatchQueryFormatter: + """ + This class is responsible for formatting the independent and dependent + variables into arrays so that they can be matrix agged by the preprocessor + class. + """ + def __init__(self, source_table): + self.source_table = source_table + + def get_dep_var_array_and_classes(self, dependent_varname, + dependent_var_dbtype): + """ + :param dependent_varname: Name of the dependent variable + :param dependent_var_dbtype: Type of the dependent variable as stored in + postgres + :return: + This function returns a tuple of + 1. A string with transformed dependent varname depending on it's type + 2. All the distinct dependent class levels encoded as a string + + If dep_type == numeric , do not encode + 1. dependent_varname = rings + transformed_value = ARRAY[[rings1], [rings2], []] + class_level_str = ARRAY[rings = 'rings1', + rings = 'rings2']::integer[] + 2. dependent_varname = ARRAY[a, b, c] + transformed_value = ARRAY[[a1, b1, c1], [a2, b2, c2], []] + class_level_str = 'NULL::TEXT' + else if dep_type in ("text", "boolean"), encode: + 3. dependent_varname = rings (encoding) + transformed_value = ARRAY[[rings1=1, rings1=2], [rings2=1, + rings2=2], []] + class_level_str = 'NULL::TEXT' + """ + dep_var_class_value_str = 'NULL::TEXT' + if dependent_var_dbtype in ("text", "boolean"): + # for encoding, and since boolean can also be a logical expression, + # there is a () for {dependent_varname} to make the query work + dep_level_sql = """ + SELECT DISTINCT ({dependent_varname}) AS class + FROM {source_table} where ({dependent_varname}) is NOT NULL + """.format(dependent_varname=dependent_varname, + source_table=self.source_table) + dep_levels = plpy.execute(dep_level_sql) + + # this is string sorting + dep_var_classes = sorted( + ["{0}".format(l["class"]) for l in dep_levels]) + + dep_var_array_str = self._get_one_hot_encoded_str(dependent_varname, + dep_var_classes) + dep_var_class_value_str = py_list_to_sql_string(dep_var_classes, + array_type=dependent_var_dbtype) + + elif "[]" in dependent_var_dbtype: + dep_var_array_str = dependent_varname + + elif is_psql_numeric_type(dependent_var_dbtype): + dep_var_array_str = 'ARRAY[{0}]'.format(dependent_varname) + + else: + plpy.error("""Invalid dependent variable type. 
It should be text, + boolean, numeric, or an array.""") + + return dep_var_array_str, dep_var_class_value_str + + def _get_one_hot_encoded_str(self, var_name, var_classes): + one_hot_list = [] + for c in var_classes: + one_hot_list.append("({0}) = '{1}'".format(var_name, c)) + + return 'ARRAY[{0}]::integer[]'.format(','.join(one_hot_list)) + + def get_indep_var_array_str(self, independent_varname): + """ + we assume that all the independent features are either numeric or + already encoded by the user. + Supported formats + 1. ‘ARRAY[x1,x2,x3]’ , where x1,x2,x3 are columns in source table with + scalar values + 2. ‘x1’, where x1 is a single column in source table, with value as an + array, like ARRAY[1,2,3] or {1,2,3} + + we don't deal with a mixture of scalar and array independent variables + """ + typecasted_ind_varname = "{0}::double precision[]".format( + independent_varname) + return typecasted_ind_varname + +class MiniBatchStandardizer: + """ + This class is responsible for + 1. Calculating the mean and std dev for independent variables + 2. Format the query to standardize the input table based on the + calculated mean/std dev + 3. Creating the output standardization table + """ + def __init__(self, schema_madlib, source_table, dep_var_array_str, + indep_var_array_str, output_standardization_table): + self.schema_madlib = schema_madlib + self.source_table = source_table + self.dep_var_array_str = dep_var_array_str + self.indep_var_array_str = indep_var_array_str + self.output_standardization_table = output_standardization_table + + self.x_mean_str = None + self.x_std_dev_str = None + self.source_table_row_count = 0 + self.grouping_cols = "NULL" + self.independent_var_dimension = None + self._calculate_mean_and_std_dev_str() + + def _calculate_mean_and_std_dev_str(self): + self.independent_var_dimension, _ = _tbl_dimension_rownum( + self.schema_madlib, + self.source_table, + self.indep_var_array_str, + skip_row_count=True) + + calculator = MeanStdDevCalculator(self.schema_madlib, + self.source_table, + self.indep_var_array_str, + self.independent_var_dimension) + + self.x_mean_str, self.x_std_dev_str = calculator.\ + get_mean_and_std_dev_for_ind_var() + + if not self.x_mean_str or not self.x_std_dev_str: + plpy.error("mean/stddev for the independent variable" + "cannot be null") + + def get_query_for_standardizing(self): + query=""" + SELECT + {dep_var_array_str} as {dep_colname}, + {schema_madlib}.utils_normalize_data + ( + {indep_var_array_str},'{x_mean_str}'::double precision[], + '{x_std_dev_str}'::double precision[] + ) as {ind_colname} + FROM {source_table} + """.format( + source_table = self.source_table, + schema_madlib = self.schema_madlib, + dep_var_array_str = self.dep_var_array_str, + indep_var_array_str = self.indep_var_array_str, + dep_colname = MINIBATCH_OUTPUT_DEPENDENT_COLNAME, + ind_colname = MINIBATCH_OUTPUT_INDEPENDENT_COLNAME, + x_mean_str = self.x_mean_str, + x_std_dev_str = self.x_std_dev_str) + return query + + def create_output_standardization_table(self): + query = """ + CREATE TABLE {output_standardization_table} AS + select {grouping_cols}::TEXT AS grouping_cols, + '{x_mean_str}'::double precision[] AS mean, + '{x_std_dev_str}'::double precision[] AS std + """.format( + output_standardization_table = self.output_standardization_table, + grouping_cols = self.grouping_cols, + x_mean_str = self.x_mean_str, + x_std_dev_str = self.x_std_dev_str) + plpy.execute(query) + +class MiniBatchSummarizer: + @staticmethod + def create_output_summary_table(source_table, 
output_table, + dep_var_array_str, indep_var_array_str, + buffer_size, class_values, + num_rows_processed, + num_missing_rows_skipped, + output_summary_table): + query = """ + CREATE TABLE {output_summary_table} AS + SELECT '{source_table}'::TEXT AS source_table, + '{output_table}'::TEXT AS output_table, + '{dependent_varname}'::TEXT AS dependent_varname, + '{independent_varname}'::TEXT AS independent_varname, + {buffer_size} AS buffer_size, + {class_values} AS class_values, + {num_rows_processed} AS num_rows_processed, + {num_missing_rows_skipped} AS num_missing_rows_skipped, + {grouping_cols}::TEXT AS grouping_cols + """.format(output_summary_table = output_summary_table, + source_table = source_table, + output_table = output_table, + dependent_varname = dep_var_array_str, + independent_varname = indep_var_array_str, + buffer_size = buffer_size, + class_values = class_values, + num_rows_processed = num_rows_processed, + num_missing_rows_skipped = num_missing_rows_skipped, + grouping_cols = "NULL") + plpy.execute(query) + +class MiniBatchBufferSizeCalculator: + """ + This class is responsible for calculating the buffer size. + This is a work in progress, final formula might change. + """ + @staticmethod + def calculate_default_buffer_size(buffer_size, + num_rows_processed, + independent_var_dimension): + if buffer_size is not None: + return buffer_size + num_of_segments = get_seg_number() + + default_buffer_size = min(75000000.0/independent_var_dimension, + float(num_rows_processed)/num_of_segments) + """ + 1. For float number, we need at least one more buffer for the fraction part, e.g. + if default_buffer_size = 0.25, we need to round it to 1. + 2. Ceiling returns a float in python2. So after ceiling, we cast + default_buffer_size to int, because it will be used to calculate the + row id of the packed input. The query looks like this + + SELECT (row_number() OVER (ORDER BY random()) - 1) / {buffer_size} + + This calculation has to return an int for which buffer_size has + to be an int + """ + return int(ceil(default_buffer_size)) + +class MiniBatchDocumentation: + @staticmethod + def minibatch_preprocessor_help(schema_madlib, message): + method = "minibatch_preprocessor" + summary = """ + ---------------------------------------------------------------- + SUMMARY + ---------------------------------------------------------------- + MiniBatch Preprocessor is a utility function to pre process the input + data for use with models that support mini-batching as an optimization + + #TODO add more here + + For more details on function usage: + SELECT {schema_madlib}.{method}('usage') + + For a small example on using the function: + SELECT {schema_madlib}.{method}('example') + """.format(**locals()) + + usage = """ + --------------------------------------------------------------------------- + USAGE + --------------------------------------------------------------------------- + SELECT {schema_madlib}.{method}( + source_table, -- TEXT. Name of the table containing input + data. Can also be a view + output_table , -- TEXT. Name of the output table for + mini-batching + dependent_varname, -- TEXT. Name of the dependent variable column + independent_varname, -- TEXT. Name of the independent variable + column + buffer_size -- INTEGER. 
Number of source input rows to + pack into batch + ); + + + --------------------------------------------------------------------------- + OUTPUT + --------------------------------------------------------------------------- + The output table produced by MiniBatch Preprocessor contains the + following columns: + + id -- INTEGER. Unique id for packed table. + dependent_varname -- FLOAT8[]. Packed array of dependent variables. + independent_varname -- FLOAT8[]. Packed array of independent + variables. + + --------------------------------------------------------------------------- + The algorithm also creates a summary table named _summary + that has the following columns: + + source_table -- Source table name. + output_table -- Output table name from preprocessor. + dependent_varname -- Dependent variable from the original table. + independent_varname -- Independent variables from the original + table. + buffer_size -- Buffer size used in preprocessing step. + class_values -- Class values of the dependent variable + (‘NULL’(as TEXT type) for non + categorical vars). + num_rows_processed -- The total number of rows that were used in + the computation. + num_missing_rows_skipped -- The total number of rows that were skipped + because of NULL values in them. + grouping_cols -- NULL if no grouping_col was specified + during training, and a comma separated list + of grouping column names if not. + + --------------------------------------------------------------------------- + The algorithm also creates a standardization table that stores some + metadata used during the model training and prediction, and is named + _standardization. It has the following columns: + + grouping_cols -- If grouping_col is specified during training, + a column for each grouping column is created. + mean -- The mean for all input features (used for + normalization). + std -- The standard deviation for all input features (used + for normalization). + """.format(**locals()) + + example = """ + -- Create input table + CREATE TABLE iris_data( + id INTEGER, + attributes NUMERIC[], + class_text text, + class INTEGER, + state VARCHAR + ); + + COPY iris_data (attributes, class_text, class, state) FROM STDIN NULL '?' DELIMITER '|'; + {4.4,3.2,1.3,0.2}|Iris_setosa|1|Alaska + {5.0,3.5,1.6,0.6}|Iris_setosa|1|Alaska + {5.1,3.8,1.9,0.4}|Iris_setosa|1|Alaska + {4.8,3.0,1.4,0.3}|Iris_setosa|1|Alaska + {5.1,3.8,1.6,0.2}|Iris_setosa|1|Alaska + {5.7,2.8,4.5,1.3}|Iris_versicolor|2|Alaska + {6.3,3.3,4.7,1.6}|Iris_versicolor|2|Alaska + {4.9,2.4,3.3,1.0}|Iris_versicolor|2|Alaska + {6.6,2.9,4.6,1.3}|Iris_versicolor|2|Alaska + {5.2,2.7,3.9,1.4}|Iris_versicolor|2|Alaska + {5.0,2.0,3.5,1.0}|Iris_versicolor|2|Alaska + {4.8,3.0,1.4,0.1}|Iris_setosa|1|Tennessee + {4.3,3.0,1.1,0.1}|Iris_setosa|1|Tennessee + {5.8,4.0,1.2,0.2}|Iris_setosa|1|Tennessee + {5.7,4.4,1.5,0.4}|Iris_setosa|1|Tennessee + {5.4,3.9,1.3,0.4}|Iris_setosa|1|Tennessee + {6.0,2.9,4.5,1.5}|Iris_versicolor|2|Tennessee + {5.7,2.6,3.5,1.0}|Iris_versicolor|2|Tennessee + {5.5,2.4,3.8,1.1}|Iris_versicolor|2|Tennessee + {5.5,2.4,3.7,1.0}|Iris_versicolor|2|Tennessee + {5.8,2.7,3.9,1.2}|Iris_versicolor|2|Tennessee + {6.0,2.7,5.1,1.6}|Iris_versicolor|2|Tennessee + \. 
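+        -- The two calls below pack the 22-row iris_data table into
+        -- batches: the first with an explicit buffer size of 3 rows per
+        -- batch, the second with buffer_size omitted so that a default
+        -- buffer size is computed.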
+ + -- #TODO add description here + DROP TABLE IF EXISTS iris_data_batch, iris_data_batch_standardization, iris_data_batch_summary; + SELECT madlib.minibatch_preprocessor('iris_data', 'iris_data_batch', 'class_text', 'attributes', 3); + + + -- #TODO add description here NULL buffer size + DROP TABLE IF EXISTS iris_data_batch, iris_data_batch_standardization, iris_data_batch_summary; + SELECT madlib.minibatch_preprocessor('iris_data', 'iris_data_batch', 'class_text', 'attributes'); + + """ + + if not message: + return summary + elif message.lower() in ('usage', 'help', '?'): + return usage + elif message.lower() == 'example': + return example + return """ + No such option. Use "SELECT {schema_madlib}.minibatch_preprocessor()" + for help. + """.format(**locals()) +# --------------------------------------------------------------------- diff --git a/src/ports/postgres/modules/utilities/minibatch_preprocessing.sql_in b/src/ports/postgres/modules/utilities/minibatch_preprocessing.sql_in new file mode 100644 index 000000000..01d91e5ad --- /dev/null +++ b/src/ports/postgres/modules/utilities/minibatch_preprocessing.sql_in @@ -0,0 +1,221 @@ +/* ----------------------------------------------------------------------- */ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + * + * @file minibatch_preprocessing.sql_in + * @brief TODO + * @date Mar 2018 + * + */ +/* ----------------------------------------------------------------------- */ + +m4_include(`SQLCommon.m4') + +/** +@addtogroup grp_minibatch_preprocessing + +
+<div class="toc"><b>Contents</b><ul>
+<li class="level1"><a href="#minibatch_preprocessor">MiniBatch Preprocessor</a></li>
+<li class="level1"><a href="#example">Examples</a></li>
+</ul></div>
+
+MiniBatch Preprocessor is a utility function to pre-process the input
+data for use with models that support mini-batching as an optimization.
+
+@brief
+@anchor minibatch_preprocessor
+@par MiniBatch Preprocessor
+<pre class="syntax">
+minibatch_preprocessor(
+    source_table,
+    output_table,
+    dependent_varname,
+    independent_varname,
+    buffer_size
+    )
+</pre>
+
+\b Arguments
+
source_table
+
TEXT. Name of the table containing input data. Can also be a view. +
+ +
output_table
+
TEXT. Name of the output table from the preprocessor which will be used + as input to algorithms that support mini-batching. +
+ +
dependent_varname
+
TEXT. Name of the dependent variable column. +
+ +
independent_varname
+
TEXT. Column name or expression list to evaluate for the independent + variable. Will be cast to double when packing. + @note + Supported expressions for independent variable + ‘ARRAY[x1,x2,x3]’ , where x1,x2,x3 are columns in source table with scalar values + ‘x1’, where x1 is a single column in source table, with value as an array, like ARRAY[1,2,3] or {1,2,3} + We might already support expressions that evaluate to array but haven't tested it. + + Not supported + ‘x1,x2,x3’, where x1,x2,x3 are columns in source table with scalar values + ARRAY[x1,x2] where x1 is scalar and x2 is array + ARRAY[x1,x2] where both x1 and x2 are arrays + ARRAY[x1] where x1 is array +
+ +
buffer_size
+
INTEGER. default: ???. Number of source input rows to pack into batch +
+ +
grouping_col (optional)
+
TEXT, default: NULL. + An expression list used to group the input dataset into discrete groups, + running one preprocessing step per group. Similar to the SQL GROUP BY clause. + When this value is NULL, no grouping is used and a single preprocessing step + is performed for the whole data set. +
+
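+
+As a sketch of the two supported forms for independent_varname (reusing
+table and column names that appear elsewhere in this patch; not part of
+the reference syntax above):
+<pre class="example">
+-- Independent variables assembled from scalar columns via an ARRAY
+-- expression, with an explicit buffer size of 4.
+SELECT madlib.minibatch_preprocessor('minibatch_preprocessing_input',
+                                     'minibatch_preprocessing_out',
+                                     'sex',
+                                     'ARRAY[diameter,height,whole,shucked,viscera,shell]',
+                                     4);
+
+-- Independent variable given as a single array-valued column; buffer_size
+-- is omitted, so a default is computed automatically.
+SELECT madlib.minibatch_preprocessor('iris_data', 'iris_data_batch',
+                                     'class_text', 'attributes');
+</pre>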
+
+<b>Output tables</b>
+<br>
+The output table produced by the minibatch preprocessor contains the
+following columns:
+<table class="output">
+<tr><th>id</th><td>INTEGER. Unique id for the packed table.</td></tr>
+<tr><th>dependent_varname</th><td>FLOAT8[]. Packed array of dependent variables.</td></tr>
+<tr><th>independent_varname</th><td>FLOAT8[]. Packed array of independent variables.</td></tr>
+<tr><th>grouping_cols</th><td>TEXT. Names of the grouping columns.</td></tr>
+</table>
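+
+For instance, to see how many source rows landed in each buffer (a quick
+sketch, assuming the iris_data_batch output table from the online help
+example):
+<pre class="example">
+-- Each output row is one buffer; the upper bound of the first array
+-- dimension is the number of source rows packed into that buffer.
+SELECT id, array_upper(independent_varname, 1) AS rows_in_buffer
+FROM iris_data_batch
+ORDER BY id;
+</pre>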
+
+A summary table named \<output_table\>_summary is also created, which has
+the following columns:
+<table class="output">
+<tr><th>source_table</th><td>Name of the source table.</td></tr>
+<tr><th>output_table</th><td>Output table name from preprocessor.</td></tr>
+<tr><th>dependent_varname</th><td>Dependent variable from the input table.</td></tr>
+<tr><th>independent_varname</th><td>Independent variables from the source table.</td></tr>
+<tr><th>buffer_size</th><td>Buffer size used in the preprocessing step.</td></tr>
+<tr><th>class_values</th><td>Class values of the dependent variable
+('NULL' as TEXT type for non-categorical variables, i.e., when the
+dependent variable is not categorical).</td></tr>
+<tr><th>num_rows_processed</th><td>The total number of rows that were used in the computation.</td></tr>
+<tr><th>num_missing_rows_skipped</th><td>The total number of rows that were skipped because of NULL values in them.</td></tr>
+<tr><th>grouping_col</th><td>NULL if no grouping_col was specified, otherwise a comma-separated
+list of the grouping column names.</td></tr>
+</table>
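+
+A quick way to sanity-check a run is to read this table back (a sketch,
+assuming the iris_data_batch_summary name from the online help example):
+<pre class="example">
+-- Confirm how many rows were packed and how many were skipped due to NULLs.
+SELECT buffer_size, num_rows_processed, num_missing_rows_skipped
+FROM iris_data_batch_summary;
+</pre>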
+
+A standardization table named \<output_table\>_standardization is also
+created; it stores metadata used during model training and prediction, and
+has the following columns:
+<table class="output">
+<tr><th>grouping columns</th><td>If grouping_col is specified, a column for each grouping column is created.</td></tr>
+<tr><th>mean</th><td>Mean of each independent variable, computed per group.</td></tr>
+<tr><th>std</th><td>Standard deviation of each independent variable, computed per group.</td></tr>
+</table>
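+
+The stored mean and std implement the usual z-score transform,
+z = (x - mean) / std, applied elementwise to the independent variables
+before packing. A sketch of applying it manually (assuming the iris_data
+and iris_data_batch_standardization tables from the online help example):
+<pre class="example">
+-- Standardize the first feature of each source row the same way the
+-- preprocessor does.
+SELECT (d.attributes[1] - s.mean[1]) / s.std[1] AS z1
+FROM iris_data d, iris_data_batch_standardization s;
+</pre>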
+ +@anchor example +@par Examples + */ + +CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.minibatch_preprocessor( + source_table VARCHAR, + output_table VARCHAR, + dependent_varname VARCHAR, + independent_varname VARCHAR, + buffer_size INTEGER +) RETURNS VOID AS $$ + PythonFunctionBodyOnly(utilities, minibatch_preprocessing) + minibatch_preprocessor_obj = minibatch_preprocessing.MiniBatchPreProcessor(**globals()) + minibatch_preprocessor_obj.minibatch_preprocessor() +$$ LANGUAGE plpythonu VOLATILE +m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `'); + +CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.minibatch_preprocessor( + source_table VARCHAR, + output_table VARCHAR, + dependent_varname VARCHAR, + independent_varname VARCHAR +) RETURNS VOID AS $$ + SELECT MADLIB_SCHEMA.minibatch_preprocessor($1, $2, $3, $4, NULL); +$$ LANGUAGE sql VOLATILE +m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `'); + +CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.minibatch_preprocessor( + message VARCHAR +) RETURNS VARCHAR AS $$ + PythonFunctionBodyOnly(utilities, minibatch_preprocessing) + return minibatch_preprocessing.MiniBatchDocumentation.minibatch_preprocessor_help(schema_madlib, message) +$$ LANGUAGE plpythonu VOLATILE +m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `'); + +CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.minibatch_preprocessor() +RETURNS VARCHAR AS $$ + PythonFunctionBodyOnly(utilities, minibatch_preprocessing) + return minibatch_preprocessing.MiniBatchDocumentation.minibatch_preprocessor_help(schema_madlib, '') +$$ LANGUAGE plpythonu VOLATILE +m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `'); diff --git a/src/ports/postgres/modules/utilities/test/minibatch_preprocessing.sql_in b/src/ports/postgres/modules/utilities/test/minibatch_preprocessing.sql_in new file mode 100644 index 000000000..d49b66f22 --- /dev/null +++ b/src/ports/postgres/modules/utilities/test/minibatch_preprocessing.sql_in @@ -0,0 +1,221 @@ +/* ----------------------------------------------------------------------- *//** + * + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ *
+ *//* ----------------------------------------------------------------------- */
+DROP TABLE IF EXISTS minibatch_preprocessing_input;
+CREATE TABLE minibatch_preprocessing_input(
+    sex TEXT,
+    id SERIAL NOT NULL,
+    length DOUBLE PRECISION,
+    diameter DOUBLE PRECISION,
+    height DOUBLE PRECISION,
+    whole DOUBLE PRECISION,
+    shucked DOUBLE PRECISION,
+    viscera DOUBLE PRECISION,
+    shell DOUBLE PRECISION,
+    rings INTEGER);
+
+INSERT INTO minibatch_preprocessing_input(id,sex,length,diameter,height,whole,shucked,viscera,shell,rings) VALUES
+(1040,'F',0.66,0.475,0.18,1.3695,0.641,0.294,0.335,6),
+(3160,'F',0.34,0.255,0.085,0.204,0.097,0.021,0.05,6),
+(3984,'F',0.585,0.45,0.125,0.874,0.3545,0.2075,0.225,6),
+(2551,'I',0.28,0.22,0.08,0.1315,0.066,0.024,0.03,5),
+(1246,'I',0.385,0.28,0.09,0.228,0.1025,0.042,0.0655,5),
+(519,'M',0.325,0.23,0.09,0.147,0.06,0.034,0.045,4),
+(2382,'M',0.155,0.115,0.025,0.024,0.009,0.005,0.0075,5),
+(698,'M',0.28,0.205,0.1,0.1165,0.0545,0.0285,0.03,5),
+(2381,'M',0.175,0.135,0.04,0.0305,0.011,0.0075,0.01,5),
+(516,'M',0.27,0.195,0.08,0.1,0.0385,0.0195,0.03,6);
+
+-- number of rows = 10, buffer_size = 4, so assert that count = ceil(10/4) = 3
+\set expected_row_count 3
+DROP TABLE IF EXISTS minibatch_preprocessing_out, minibatch_preprocessing_out_standardization, minibatch_preprocessing_out_summary;
+SELECT minibatch_preprocessor('minibatch_preprocessing_input', 'minibatch_preprocessing_out', 'length>0.2', 'ARRAY[diameter,height,whole,shucked,viscera,shell]', 4);
+SELECT assert
+    (
+    row_count = :expected_row_count, 'Row count validation failed for minibatch_preprocessing_out.
+    Expected:' || :expected_row_count || ' Actual: ' || row_count
+    ) from (select count(*) as row_count from minibatch_preprocessing_out) s;
+
+\set expected_dep_row_count '\'' 2,4,4 '\''
+\set expected_dep_col_count '\'' 2,2,2 '\''
+\set expected_indep_row_count '\'' 2,4,4 '\''
+\set expected_indep_col_count '\'' 6,6,6 '\''
+
+-- assert dimensions for both dependent and independent variables
+SELECT assert
+    (
+    str_dep_row_count = :expected_dep_row_count, 'Dependent variable row count failed. Actual: ' || str_dep_row_count || ' Expected:' || :expected_dep_row_count
+    ) from
+    (
+        select array_to_string(array_agg(row_count order by row_count asc), ',') as str_dep_row_count from (select array_upper(dependent_varname,1) as row_count from minibatch_preprocessing_out order by row_count asc) s
+    ) s;
+
+SELECT assert
+    (
+    str_dep_col_count = :expected_dep_col_count, 'Dependent variable col count failed. Actual: ' || str_dep_col_count || ' Expected:' || :expected_dep_col_count
+    ) from
+    (
+        select array_to_string(array_agg(col_count order by col_count asc), ',') as str_dep_col_count from (select array_upper(dependent_varname,2) as col_count from minibatch_preprocessing_out order by col_count asc) s
+    ) s;
+
+SELECT assert
+    (
+    str_indep_row_count = :expected_indep_row_count, 'Independent variable row count failed. Actual: ' || str_indep_row_count || ' Expected:' || :expected_indep_row_count
+    ) from
+    (
+        select array_to_string(array_agg(row_count order by row_count asc), ',') as str_indep_row_count from (select array_upper(independent_varname, 1) as row_count from minibatch_preprocessing_out order by row_count asc) s
+    ) s;
+
+SELECT assert
+    (
+    str_indep_col_count = :expected_indep_col_count, 'Independent variable col count failed.
Actual: ' || str_indep_col_count || ' Expected:' || :expected_indep_col_count + ) from + ( + select array_to_string(array_agg(col_count order by col_count asc), ',') as str_indep_col_count from (select array_upper(independent_varname,2) as col_count from minibatch_preprocessing_out order by col_count asc) s + ) s; + +SELECT assert + ( + source_table = 'minibatch_preprocessing_input' AND + output_table = 'minibatch_preprocessing_out' AND + dependent_varname = 'length>0.2' AND + independent_varname = 'ARRAY[diameter,height,whole,shucked,viscera,shell]' AND + buffer_size = 4 AND + class_values = '{f,t}' AND -- we sort the class values in python + num_rows_processed = 10 AND + num_missing_rows_skipped = 0 AND + grouping_cols is NULL, + 'Summary Validation failed. Expected:' || __to_char(summary) + ) from (select * from minibatch_preprocessing_out_summary) summary; + + +-- Test null values in x and y +\set expected_row_count 1 +DROP TABLE IF EXISTS minibatch_preprocessing_out, minibatch_preprocessing_out_standardization, minibatch_preprocessing_out_summary; + +TRUNCATE TABLE minibatch_preprocessing_input; +INSERT INTO minibatch_preprocessing_input(id,sex,length,diameter,height,whole,shucked,viscera,shell,rings) VALUES +(1040,'F',0.66,0.475,0.18,NULL,0.641,0.294,0.335,6), +(3160,'F',0.34,0.35,0.085,0.204,0.097,0.021,0.05,6), +(3984,NULL,0.585,0.45,0.25,0.874,0.3545,0.2075,0.225,5), +(861,'M',0.595,0.475,NULL,1.1405,0.547,0.231,0.271,6), +(932,NULL,0.445,0.335,0.11,0.4355,0.2025,0.1095,0.1195,6), +(698,'F',0.445,0.335,0.11,0.4355,0.2025,0.1095,0.1195,6), +(922,NULL,0.445,0.335,0.11,NULL,0.2025,0.1095,0.1195,6); +SELECT minibatch_preprocessor('minibatch_preprocessing_input', 'minibatch_preprocessing_out', 'sex', 'ARRAY[length,diameter,height,whole,shucked,viscera,shell]', 2); +SELECT assert + ( + row_count = :expected_row_count, 'Row count validation failed for minibatch_preprocessing_out. + Expected:' || :expected_row_count || ' Actual: ' || row_count + ) from (select count(*) as row_count from minibatch_preprocessing_out) s; +SELECT assert + (num_rows_processed = 2 AND num_missing_rows_skipped = 5, + 'Rows processed/skipped validation failed for minibatch_preprocessing_out_summary. + Actual num_rows_processed:' || num_rows_processed || ', Actual num_missing_rows_skipped: ' || num_missing_rows_skipped + ) from (select * from minibatch_preprocessing_out_summary) s; + +-- Test standardization +DROP TABLE IF EXISTS minibatch_preprocessing_input; +DROP TABLE IF EXISTS minibatch_preprocessing_out, minibatch_preprocessing_out_standardization, minibatch_preprocessing_out_summary; +CREATE TABLE minibatch_preprocessing_input(x1 INTEGER ,x2 INTEGER ,y TEXT); +INSERT INTO minibatch_preprocessing_input(x1,x2,y) VALUES +(2,10,'y1'), +(4,30,'y2'); +SELECT minibatch_preprocessor('minibatch_preprocessing_input', 'minibatch_preprocessing_out', 'y', 'ARRAY[x1,x2]', 2); + +-- since the order is not deterministic, we assert for all possible orders +\set expected_normalized_independent_var1 '\'' {{-1, -1},{1, 1}} '\'' +\set expected_normalized_independent_var2 '\'' {{1, 1},{-1, -1}} '\'' + +SELECT assert +( + independent_varname = :expected_normalized_independent_var1 OR + independent_varname = :expected_normalized_independent_var2, + 'Standardization check failed. Actual: ' || independent_varname +) from +( + select __to_char(independent_varname) as independent_varname from minibatch_preprocessing_out +) s; + + +-- Test that the standardization table gets created. 
+\set expected_row_count 1 +SELECT assert +( + row_count = :expected_row_count, 'Row count validation failed for minibatch_preprocessing_out_standardization. + Expected:' || :expected_row_count || ' Actual: ' || row_count +) from +( + select count(*) as row_count from minibatch_preprocessing_out_standardization +) s; + +-- Test that the summary table gets created. +\set expected_row_count 1 +SELECT assert +( + row_count = :expected_row_count, 'Row count validation failed for minibatch_preprocessing_out_summary. + Expected:' || :expected_row_count || ' Actual: ' || row_count +) from +( + select count(*) as row_count from minibatch_preprocessing_out_summary +) s; + +-- Test for array values in indep column +DROP TABLE IF EXISTS minibatch_preprocessing_input; +DROP TABLE IF EXISTS minibatch_preprocessing_out, minibatch_preprocessing_out_standardization, minibatch_preprocessing_out_summary; +CREATE TABLE minibatch_preprocessing_input( + id INTEGER, + sex TEXT, + attributes double precision[], + rings INTEGER); +TRUNCATE TABLE minibatch_preprocessing_input; +INSERT INTO minibatch_preprocessing_input(id,sex,attributes) VALUES +(1040,'F',ARRAY[0.66,0.475,0.18,NULL,0.641,0.294,0.335]), +(3160,'F',ARRAY[0.34,0.35,0.085,0.204,0.097,0.021,0.05]), +(3984,NULL,ARRAY[0.585,0.45,0.25,0.874,0.3545,0.2075,0.225]), +(861,'M',ARRAY[0.595,0.475,NULL,1.1405,0.547,0.231,0.271]), +(932,NULL,ARRAY[0.445,0.335,0.11,0.4355,0.2025,0.1095,0.1195]), +(NULL,'F',ARRAY[0.445,0.335,0.11,0.4355,0.2025,0.1095,0.1195]), +(922,NULL,ARRAY[0.445,0.335,0.11,NULL,0.2025,0.1095,0.1195]); +SELECT minibatch_preprocessor('minibatch_preprocessing_input', 'minibatch_preprocessing_out', 'sex', 'attributes', 1); +SELECT assert + ( + row_count = 2, 'Row count validation failed for minibatch_preprocessing_out. + Expected:' || 2 || ' Actual: ' || row_count + ) from (select count(*) as row_count from minibatch_preprocessing_out) s; + +-- Test for array values in dep column +DROP TABLE IF EXISTS minibatch_preprocessing_out, minibatch_preprocessing_out_standardization, minibatch_preprocessing_out_summary; +SELECT minibatch_preprocessor('minibatch_preprocessing_input', 'minibatch_preprocessing_out', 'attributes', 'ARRAY[id]', 1); +SELECT assert + ( + row_count = 3, 'Row count validation failed array values in dependent variable. + Expected:' || 3 || ' Actual: ' || row_count + ) from (select count(*) as row_count from minibatch_preprocessing_out) s; + +-- Test for null buffer size +DROP TABLE IF EXISTS minibatch_preprocessing_out, minibatch_preprocessing_out_standardization, minibatch_preprocessing_out_summary; +SELECT minibatch_preprocessor('minibatch_preprocessing_input', 'minibatch_preprocessing_out', 'attributes', 'ARRAY[id]'); +SELECT assert + ( + ind_var_rows = dep_var_rows AND ind_var_rows = buffer_size, 'Row count validation failed for null buffer size. 
+ Buffer size from summary table: ' || buffer_size || ' does not match the output table:' + || ind_var_rows + ) from (select max(array_upper(o.dependent_varname, 1)) as dep_var_rows, max(array_upper(o.independent_varname, 1)) as ind_var_rows , s1.buffer_size from minibatch_preprocessing_out o, minibatch_preprocessing_out_summary s1 group by buffer_size) s; diff --git a/src/ports/postgres/modules/utilities/test/unit_tests/test_minibatch_preprocessing.py_in b/src/ports/postgres/modules/utilities/test/unit_tests/test_minibatch_preprocessing.py_in new file mode 100644 index 000000000..c8b6942db --- /dev/null +++ b/src/ports/postgres/modules/utilities/test/unit_tests/test_minibatch_preprocessing.py_in @@ -0,0 +1,304 @@ +# coding=utf-8 +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import sys +from os import path +# Add utilites module to the pythonpath. +sys.path.append(path.dirname(path.dirname(path.dirname(path.abspath(__file__))))) + +import unittest +from mock import * +import plpy_mock as plpy + +m4_changequote(`') + +class MiniBatchPreProcessingTestCase(unittest.TestCase): + def setUp(self): + self.plpy_mock = Mock(spec='error') + patches = { + 'plpy': plpy, + 'mean_std_dev_calculator': Mock() + } + + # we need to use MagicMock() instead of Mock() for the plpy.execute mock + # to be able to iterate on the return value + self.plpy_mock_execute = MagicMock() + plpy.execute = self.plpy_mock_execute + + self.module_patcher = patch.dict('sys.modules', patches) + self.module_patcher.start() + + + self.default_schema_madlib = "madlib" + self.default_source_table = "source" + self.default_output_table = "output" + self.default_dep_var = "depvar" + self.default_ind_var = "indvar" + self.default_buffer_size = 5 + + import minibatch_preprocessing + self.module = minibatch_preprocessing + self.module.validate_module_input_params = Mock() + self.output_tbl_valid_mock = Mock() + self.module.output_tbl_valid = self.output_tbl_valid_mock + + self.minibatch_query_formatter = self.module.MiniBatchQueryFormatter + self.minibatch_query_formatter.get_dep_var_array_and_classes = Mock( + return_value=("anything1", "anything2")) + self.minibatch_query_formatter.get_indep_var_array_str = Mock( + return_value="anything3") + + self.module.MiniBatchStandardizer = Mock() + self.module.MiniBatchSummarizer = Mock() + self.module.get_expr_type = MagicMock(return_value="anytype") + + def tearDown(self): + self.module_patcher.stop() + + def test_minibatch_preprocessor_executes_query(self): + preprocessor_obj = self.module.MiniBatchPreProcessor(self.default_schema_madlib, + "input", + "out", + self.default_dep_var, + self.default_ind_var, + self.default_buffer_size) + self.plpy_mock_execute.side_effect = [[{"source_table_row_count":5 , + "num_rows_processed":3}], ""] + 
preprocessor_obj.minibatch_preprocessor() + self.assertEqual(2, self.plpy_mock_execute.call_count) + self.assertEqual(self.default_buffer_size, preprocessor_obj.buffer_size) + + def test_minibatch_preprocessor_null_buffer_size_executes_query(self): + preprocessor_obj = self.module.MiniBatchPreProcessor(self.default_schema_madlib, + "input", + "out", + self.default_dep_var, + self.default_ind_var, + None) + self.plpy_mock_execute.side_effect = [[{"source_table_row_count":5 , + "num_rows_processed":3}], ""] + self.module.MiniBatchBufferSizeCalculator.calculate_default_buffer_size = Mock() + preprocessor_obj.minibatch_preprocessor() + self.assertEqual(2, self.plpy_mock_execute.call_count) + + def test_minibatch_preprocessor_multiple_dep_var_raises_exception(self): + with self.assertRaises(Exception): + self.module.MiniBatchPreProcessor(self.default_schema_madlib, + self.default_source_table, + self.default_output_table, + "y1,y2", + self.default_ind_var, + self.default_buffer_size) + + def test_minibatch_preprocessor_buffer_size_zero_fails(self): + with self.assertRaises(Exception): + self.module.MiniBatchPreProcessor(self.default_schema_madlib, + self.default_source_table, + self.default_output_table, + self.default_dep_var, + self.default_ind_var, + 0) + + def test_minibatch_preprocessor_buffer_size_one_passes(self): + #not sure how to assert that an exception has not been raised + preprocessor_obj = self.module.MiniBatchPreProcessor(self.default_schema_madlib, + self.default_source_table, + self.default_output_table, + self.default_dep_var, + self.default_ind_var, + 1) + preprocessor_obj.minibatch_preprocessor() + self.assert_(True) + +class MiniBatchQueryFormatterTestCase(unittest.TestCase): + def setUp(self): + self.default_source_table = "source" + self.default_dep_var = "depvar" + self.default_ind_var = "indvar" + patches = { + 'plpy': plpy, + 'mean_std_dev_calculator': Mock() + } + + # we need to use MagicMock() instead of Mock() for the plpy.execute mock + # to be able to iterate on the return value + self.plpy_mock_execute = MagicMock() + plpy.execute = self.plpy_mock_execute + + self.module_patcher = patch.dict('sys.modules', patches) + self.module_patcher.start() + + import minibatch_preprocessing + self.module = minibatch_preprocessing + self.subject = self.module.MiniBatchQueryFormatter(self.default_source_table) + + def tearDown(self): + self.module_patcher.stop() + + def test_get_dep_var_array_str_text_type(self): + self.plpy_mock_execute.return_value = [{"class":100},{"class":0},{"class":22}] + + dep_var_array_str, _ = self.subject.get_dep_var_array_and_classes\ + (self.default_dep_var, 'text') + + # get_dep_var_array_str does a string sorting on the class levels. Hence the order + # 0,100,22 and not 0,22,100 + self.assertEqual("ARRAY[({0}) = '0',({0}) = '100',({0}) = '22']::integer[]". + format(self.default_dep_var), dep_var_array_str) + + def test_get_dep_var_array_str_boolean_type(self): + self.plpy_mock_execute.return_value = [{"class":3}] + + dep_var_array_str, _ = self.subject.\ + get_dep_var_array_and_classes(self.default_dep_var, 'boolean') + self.assertEqual("ARRAY[({0}) = '3']::integer[]". + format(self.default_dep_var), dep_var_array_str) + + def test_get_dep_var_array_str_array_type(self): + dep_var_array_str, _ = self.subject.\ + get_dep_var_array_and_classes(self.default_dep_var, 'some_array[]') + + self.assertEqual(self.default_dep_var, dep_var_array_str) + + def test_get_dep_var_array_str_numeric_type(self): + dep_var_array_str, _ = self.subject. 
\ + get_dep_var_array_and_classes(self.default_dep_var, 'integer') + + self.assertEqual("ARRAY[{0}]".format(self.default_dep_var), dep_var_array_str) + + def test_get_dep_var_array_str_other_type(self): + with self.assertRaises(Exception): + self.subject.get_dep_var_array_and_classes(self.default_dep_var, 'other') + + def test_get_indep_var_array_str_passes(self): + ind_var_array_str = self.subject.get_indep_var_array_str('ARRAY[x1,x2,x3]') + self.assertEqual("ARRAY[x1,x2,x3]::double precision[]", ind_var_array_str) + +class MiniBatchQueryStandardizerTestCase(unittest.TestCase): + def setUp(self): + self.default_source_table = "source" + self.default_dep_var = "depvar" + self.default_ind_var = "indvar" + self.default_schema = "schema" + self.mean_std_calculator_mock = Mock() + patches = { + 'plpy': plpy, + 'mean_std_dev_calculator': self.mean_std_calculator_mock + } + self.x_mean = "5678" + self.x_std_dev = "4.789" + self.mean_std_calculator_mock.MeanStdDevCalculator.return_value.get_mean_and_std_dev_for_ind_var = Mock(return_value=(self.x_mean, self.x_std_dev)) + + # we need to use MagicMock() instead of Mock() for the plpy.execute mock + # to be able to iterate on the return value + self.plpy_mock_execute = MagicMock() + plpy.execute = self.plpy_mock_execute + + self.module_patcher = patch.dict('sys.modules', patches) + self.module_patcher.start() + + import minibatch_preprocessing + self.module = minibatch_preprocessing + self.subject = self.module.MiniBatchStandardizer(self.default_schema, + self.default_source_table, + self.default_dep_var, + self.default_ind_var, + "out_standardization") + + def tearDown(self): + self.module_patcher.stop() + + def test_get_query_for_standardizing_no_exception(self): + self.subject.get_query_for_standardizing() + + def test_get_query_for_standardizing_null_mean_raises_exception(self): + self.mean_std_calculator_mock.MeanStdDevCalculator.return_value.get_mean_and_std_dev_for_ind_var = Mock(return_value=(None, self.x_std_dev)) + with self.assertRaises(Exception): + self.module.MiniBatchStandardizer(self.default_schema, + self.default_source_table, + self.default_dep_var, + self.default_ind_var, + "does_not_matter") + + def test_get_query_for_standardizing_null_stddev_raises_exception(self): + self.mean_std_calculator_mock.MeanStdDevCalculator.return_value.get_mean_and_std_dev_for_ind_var = Mock(return_value=(self.x_mean, None)) + with self.assertRaises(Exception): + self.module.MiniBatchStandardizer(self.default_schema, + self.default_source_table, + self.default_dep_var, + self.default_ind_var, + "does_not_matter") + + def test_get_calculated_mean_and_std_dev_returns_values(self): + self.subject.get_query_for_standardizing() + mean, std_dev = self.subject.x_mean_str, self.subject.x_std_dev_str + self.assertEqual(self.x_mean, mean) + self.assertEqual(self.x_std_dev, std_dev) + + def test_create_standardization_output_table_executes_query(self): + self.subject.create_output_standardization_table() + expected_query_substr_create_table = "CREATE TABLE out_standardization AS" + self.plpy_mock_execute.assert_called_with(AnyStringWith(expected_query_substr_create_table)) + self.plpy_mock_execute.assert_called_with(AnyStringWith(self.x_mean)) + self.plpy_mock_execute.assert_called_with(AnyStringWith(self.x_std_dev)) + +class MiniBatchBufferSizeCalculatorTestCase(unittest.TestCase): + def setUp(self): + patches = { + 'plpy': plpy, + 'mean_std_dev_calculator': Mock() + } + self.a = 'a' + self.module_patcher = patch.dict('sys.modules', patches) + 
self.module_patcher.start() + import minibatch_preprocessing + self.module = minibatch_preprocessing + self.subject = self.module.MiniBatchBufferSizeCalculator + + def tearDown(self): + self.module_patcher.stop() + + def test_calculate_default_buffer_size_non_none_buffer_size(self): + buffer_size = self.subject.calculate_default_buffer_size(1, 3, 100) + self.assertTrue(isinstance(buffer_size, int)) + self.assertEqual(1, buffer_size) + + def test_calculate_default_buffer_size_none_buffer_size(self): + self.module.get_seg_number = Mock(return_value = 4) + buffer_size = self.subject.calculate_default_buffer_size(None, 100, 1000) + self.assertTrue(isinstance(buffer_size, int)) + self.assertEqual(25, buffer_size) + + def test_calculate_default_buffer_size_none_buffer_size_rounds_to_int(self): + self.module.get_seg_number = Mock(return_value = 5) + buffer_size = self.subject.calculate_default_buffer_size(None, 3, 1000) + self.assertTrue(isinstance(buffer_size, int)) + self.assertEqual(1, buffer_size) + + #TODO add more tests after finalizing the buffer size calculation + +class AnyStringWith(str): + def __eq__(self, other): + return self in other + + +if __name__ == '__main__': + unittest.main() + +# ---------------------------------------------------------------------