From 78430bc8586ae0256a24de2472392564a15f7f8e Mon Sep 17 00:00:00 2001 From: Nandish Jayaram Date: Mon, 17 Dec 2018 17:54:42 -0800 Subject: [PATCH] Minibatch Preprocessor for Deep learning The minibatch preprocessor we currently have in MADlib is bloated for DL tasks. This feature adds a simplified version of creating buffers, and divides each element of the independent array by a normalizing constant for standardization (which is 255.0 by default). This is standard practice with image data. Co-authored-by: Arvind Sridhar Co-authored-by: Domino Valdano --- .../utilities/minibatch_preprocessing.py_in | 101 +++++++++++++++++- .../utilities/minibatch_preprocessing.sql_in | 48 +++++++++ .../test/minibatch_preprocessing.sql_in | 38 +++++++ .../modules/utilities/utilities.py_in | 4 +- 4 files changed, 188 insertions(+), 3 deletions(-) diff --git a/src/ports/postgres/modules/utilities/minibatch_preprocessing.py_in b/src/ports/postgres/modules/utilities/minibatch_preprocessing.py_in index 88433c937..517daeb6e 100644 --- a/src/ports/postgres/modules/utilities/minibatch_preprocessing.py_in +++ b/src/ports/postgres/modules/utilities/minibatch_preprocessing.py_in @@ -39,7 +39,7 @@ from utilities import py_list_to_sql_string from utilities import split_quoted_delimited_str from utilities import unique_string from utilities import validate_module_input_params -from utilities import NUMERIC, INTEGER, TEXT, BOOLEAN, INCLUDE_ARRAY +from utilities import NUMERIC, INTEGER, TEXT, BOOLEAN, INCLUDE_ARRAY, ONLY_ARRAY from mean_std_dev_calculator import MeanStdDevCalculator from validate_args import get_expr_type @@ -51,6 +51,105 @@ m4_changequote(`') MINIBATCH_OUTPUT_DEPENDENT_COLNAME = "dependent_varname" MINIBATCH_OUTPUT_INDEPENDENT_COLNAME = "independent_varname" +class MiniBatchPreProcessorDL: + def __init__(self, schema_madlib, source_table, output_table, + dependent_varname, independent_varname, num_of_buffers, + normalizing_const, **kwargs): + self.schema_madlib = schema_madlib + 
self.source_table = source_table + self.output_table = output_table + self.dependent_varname = dependent_varname + self.independent_varname = independent_varname + self.num_of_buffers = num_of_buffers + self.normalizing_const = normalizing_const + self.module_name = "minibatch_preprocessor_DL" + self.output_summary_table = add_postfix(self.output_table, "_summary") + self._validate_args() + + def minibatch_preprocessor_dl(self): + norm_tbl = unique_string(desp='normalized') + # Create a temp table that has independent var normalized. + scalar_mult_sql = """ + CREATE TEMP TABLE {norm_tbl} AS + SELECT {self.schema_madlib}.array_scalar_mult( + {self.independent_varname}::REAL[], (1/{self.normalizing_const})::REAL) AS x_norm, + {self.dependent_varname} AS y, + row_number() over() AS row_id + FROM {self.source_table} + """.format(**locals()) + plpy.execute(scalar_mult_sql) + # Create the mini-batched output table + sql = """ + CREATE TABLE {self.output_table} AS + SELECT * FROM + ( + SELECT {self.schema_madlib}.agg_array_concat( + ARRAY[{norm_tbl}.x_norm::REAL[]]) AS {x}, + array_agg({norm_tbl}.y) AS {y}, + ({norm_tbl}.row_id%{self.num_of_buffers})::smallint AS buffer_id + FROM {norm_tbl} + GROUP BY buffer_id + ) b + DISTRIBUTED BY (buffer_id) + """.format(x=MINIBATCH_OUTPUT_INDEPENDENT_COLNAME, + y=MINIBATCH_OUTPUT_DEPENDENT_COLNAME, **locals()) + plpy.execute(sql) + plpy.execute("DROP TABLE {}".format(norm_tbl)) + # Create summary table + self._create_output_summary_table() + + def _create_output_summary_table(self): + query = """ + CREATE TABLE {self.output_summary_table} AS + SELECT + $__madlib__${self.source_table}$__madlib__$::TEXT AS source_table, + $__madlib__${self.output_table}$__madlib__$::TEXT AS output_table, + $__madlib__${self.dependent_varname}$__madlib__$::TEXT AS dependent_varname, + $__madlib__${self.independent_varname}$__madlib__$::TEXT AS independent_varname, + {self.num_of_buffers} AS num_of_buffers + """.format(self=self) + plpy.execute(query) + 
+ def _validate_args(self): + validate_module_input_params( + self.source_table, self.output_table, self.independent_varname, + self.dependent_varname, self.module_name, None, + [self.output_summary_table]) + self.independent_vartype = get_expr_type( + self.independent_varname, self.source_table) + _assert(is_valid_psql_type(self.independent_vartype, + NUMERIC | ONLY_ARRAY), + "Invalid independent variable type, should be an array of " \ + "one of {0}".format(','.join(NUMERIC))) + self.dependent_vartype = get_expr_type( + self.dependent_varname, self.source_table) + dep_valid_types = NUMERIC | TEXT | BOOLEAN + _assert(is_valid_psql_type(self.dependent_vartype, dep_valid_types), + "Invalid dependent variable type, should be one of {0}". + format(','.join(dep_valid_types))) + self._validate_num_buffers() + + def _validate_num_buffers(self): + _assert(self.num_of_buffers > 0, + "Number of buffers must be greater than 0.") + rows_in_tbl = plpy.execute(""" + SELECT count(*) AS cnt FROM {} + """.format(self.source_table))[0]['cnt'] + _assert(self.num_of_buffers <= rows_in_tbl, + "Number of buffers cannot exceed the number of rows " \ + "in the source table.") + buffer_size_calculator = MiniBatchBufferSizeCalculator() + independent_var_dim = _tbl_dimension_rownum( + self.schema_madlib, self.source_table, self.independent_varname, + skip_row_count=True) + safe_buffer_size = buffer_size_calculator.calculate_default_buffer_size( + None, rows_in_tbl, independent_var_dim[0]) + safe_num_buffers = rows_in_tbl/safe_buffer_size + if self.num_of_buffers < safe_num_buffers: + plpy.warning("It might be safer to increase the number of " \ + "buffers to greater than {} to avoid 1GB limit.".format( + safe_num_buffers)) + class MiniBatchPreProcessor: + """ diff --git a/src/ports/postgres/modules/utilities/minibatch_preprocessing.sql_in b/src/ports/postgres/modules/utilities/minibatch_preprocessing.sql_in index 58668a171..99388d0f0 100644 --- 
a/src/ports/postgres/modules/utilities/minibatch_preprocessing.sql_in +++ b/src/ports/postgres/modules/utilities/minibatch_preprocessing.sql_in @@ -589,3 +589,51 @@ RETURNS VARCHAR AS $$ return minibatch_preprocessing.MiniBatchDocumentation.minibatch_preprocessor_help(schema_madlib, '') $$ LANGUAGE plpythonu VOLATILE m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `'); + +CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.minibatch_preprocessor_dl( + source_table VARCHAR, + output_table VARCHAR, + dependent_varname VARCHAR, + independent_varname VARCHAR, + num_of_buffers INTEGER, + normalizing_const DOUBLE PRECISION +) RETURNS VOID AS $$ + PythonFunctionBodyOnly(utilities, minibatch_preprocessing) + from utilities.control import MinWarning + with AOControl(False): + with MinWarning('error'): + minibatch_preprocessor_obj = minibatch_preprocessing.MiniBatchPreProcessorDL(**globals()) + minibatch_preprocessor_obj.minibatch_preprocessor_dl() +$$ LANGUAGE plpythonu VOLATILE +m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `'); + +CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.minibatch_preprocessor_dl( + source_table VARCHAR, + output_table VARCHAR, + dependent_varname VARCHAR, + independent_varname VARCHAR, + num_of_buffers INTEGER +) RETURNS VOID AS $$ + SELECT MADLIB_SCHEMA.minibatch_preprocessor_dl($1, $2, $3, $4, $5, 255.0); +$$ LANGUAGE sql VOLATILE +m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `'); + +CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.agg_array_concat_transition(anyarray, anyarray) + RETURNS anyarray + AS 'select $1 || $2' + LANGUAGE SQL + IMMUTABLE; + +CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.agg_array_concat_merge(anyarray, anyarray) + RETURNS anyarray + AS 'select $1 || $2' + LANGUAGE SQL + IMMUTABLE + RETURNS NULL ON NULL INPUT; + +DROP AGGREGATE IF EXISTS MADLIB_SCHEMA.agg_array_concat(anyarray); +CREATE AGGREGATE MADLIB_SCHEMA.agg_array_concat(anyarray) ( + SFUNC = MADLIB_SCHEMA.agg_array_concat_transition, + STYPE = 
anyarray, + PREFUNC = MADLIB_SCHEMA.agg_array_concat_merge + ); diff --git a/src/ports/postgres/modules/utilities/test/minibatch_preprocessing.sql_in b/src/ports/postgres/modules/utilities/test/minibatch_preprocessing.sql_in index b6b2996f3..8de504afc 100644 --- a/src/ports/postgres/modules/utilities/test/minibatch_preprocessing.sql_in +++ b/src/ports/postgres/modules/utilities/test/minibatch_preprocessing.sql_in @@ -348,3 +348,41 @@ SELECT assert( grouping_cols = '"rin$$Ж!#''gs"', 'Summary Validation failed for special chars. Actual:' || __to_char(summary) ) from (select * from minibatch_preprocessing_out_summary order by class_values) summary; + +DROP TABLE IF EXISTS minibatch_preprocessor_dl_input; +CREATE TABLE minibatch_preprocessor_dl_input(id serial, x double precision[]); +INSERT INTO minibatch_preprocessor_dl_input(x) VALUES +(ARRAY[1,2,3,4,5,6]), +(ARRAY[11,2,3,4,5,6]), +(ARRAY[11,22,33,4,5,6]), +(ARRAY[11,22,33,44,5,6]), +(ARRAY[11,22,33,44,65,6]), +(ARRAY[11,22,33,44,65,56]), +(ARRAY[11,22,33,44,65,56]), +(ARRAY[11,22,33,44,65,56]), +(ARRAY[11,22,33,44,65,56]), +(ARRAY[11,22,33,44,65,56]), +(ARRAY[11,22,33,44,65,56]), +(ARRAY[11,22,33,44,65,56]), +(ARRAY[11,22,33,144,65,56]), +(ARRAY[11,22,233,44,65,56]), +(ARRAY[11,22,33,44,65,56]), +(ARRAY[11,22,33,44,65,56]), +(ARRAY[11,22,33,44,65,56]); + +DROP TABLE IF EXISTS minibatch_preprocessor_dl_batch, minibatch_preprocessor_dl_batch_summary; +SELECT minibatch_preprocessor_dl( + 'minibatch_preprocessor_dl_input', + 'minibatch_preprocessor_dl_batch', + 'id', + 'x', + 4); + +SELECT assert(count(*)=4, 'Incorrect number of buffers in minibatch_preprocessor_dl_batch.') +FROM minibatch_preprocessor_dl_batch; + +SELECT assert(array_upper(independent_varname, 1)=5, 'Incorrect buffer size.') +FROM minibatch_preprocessor_dl_batch WHERE buffer_id=1; + +SELECT assert(array_upper(independent_varname, 1)=4, 'Incorrect buffer size.') +FROM minibatch_preprocessor_dl_batch WHERE buffer_id=0; diff --git 
a/src/ports/postgres/modules/utilities/utilities.py_in b/src/ports/postgres/modules/utilities/utilities.py_in index 50c426b74..89907b391 100644 --- a/src/ports/postgres/modules/utilities/utilities.py_in +++ b/src/ports/postgres/modules/utilities/utilities.py_in @@ -875,8 +875,8 @@ def collate_plpy_result(plpy_result_rows): def validate_module_input_params(source_table, output_table, independent_varname, - dependent_varname, module_name, grouping_cols, - other_output_tables=None): + dependent_varname, module_name, + grouping_cols=None, other_output_tables=None): """ This function is supposed to be used for validating params for supervised learning like algos, e.g. linear regression, mlp, etc. since all