diff --git a/src/ports/postgres/modules/utilities/minibatch_preprocessing.py_in b/src/ports/postgres/modules/utilities/minibatch_preprocessing.py_in
index 4a1c8ae06..401323ed6 100644
--- a/src/ports/postgres/modules/utilities/minibatch_preprocessing.py_in
+++ b/src/ports/postgres/modules/utilities/minibatch_preprocessing.py_in
@@ -89,14 +89,14 @@ class MiniBatchPreProcessor:
                                              self.grouping_cols,
                                              self.output_standardization_table)
 
-        num_rows_processed, num_missing_rows_skipped = self.\
-            _get_skipped_rows_processed_count(
-                dep_var_array_str,
-                indep_var_array_str)
+        total_num_rows_processed, avg_num_rows_processed, \
+        num_missing_rows_skipped = self._get_skipped_rows_processed_count(
+                dep_var_array_str,
+                indep_var_array_str)
         calculated_buffer_size = MiniBatchBufferSizeCalculator.\
             calculate_default_buffer_size(
                 self.buffer_size,
-                num_rows_processed,
+                avg_num_rows_processed,
                 standardizer.independent_var_dimension)
         """
         This query does the following:
@@ -175,7 +175,7 @@ class MiniBatchPreProcessor:
                                           dependent_var_dbtype,
                                           calculated_buffer_size,
                                           dep_var_classes_str,
-                                          num_rows_processed,
+                                          total_num_rows_processed,
                                           num_missing_rows_skipped,
                                           self.grouping_cols
                                           )
@@ -211,27 +211,42 @@ class MiniBatchPreProcessor:
         # Note: Keep the null checking where clause of this query in sync with
         # the main create output table query.
         query = """
-            SELECT COUNT(*) AS source_table_row_count,
-                   sum(CASE WHEN
+            SELECT SUM(source_table_row_count_by_group) AS source_table_row_count,
+                   SUM(num_rows_processed_by_group) AS total_num_rows_processed,
+                   AVG(num_rows_processed_by_group) AS avg_num_rows_processed
+            FROM (
+                SELECT COUNT(*) AS source_table_row_count_by_group,
+                       SUM(CASE WHEN
                         NOT {schema_madlib}.array_contains_null({dep_var_array})
                         AND NOT {schema_madlib}.array_contains_null({indep_var_array})
-                   THEN 1 ELSE 0 END) AS num_rows_processed
+                       THEN 1 ELSE 0 END) AS num_rows_processed_by_group
             FROM {source_table}
+            {group_by_clause}) s
         """.format(
             schema_madlib = self.schema_madlib,
             source_table = self.source_table,
             dep_var_array = dep_var_array,
-            indep_var_array = indep_var_array)
+            indep_var_array = indep_var_array,
+            group_by_clause = "GROUP BY {0}".format(self.grouping_cols) \
+                              if self.grouping_cols else '')
         result = plpy.execute(query)
 
-        source_table_row_count = result[0]['source_table_row_count']
-        num_rows_processed = result[0]['num_rows_processed']
-        if not source_table_row_count or not num_rows_processed:
+        ## SUM and AVG both return float, and we have to cast them into int
+        ## for the summary table. For avg_num_rows_processed we need to ceil
+        ## first so that the minimum won't be 0.
+        source_table_row_count = int(result[0]['source_table_row_count'])
+        total_num_rows_processed = int(result[0]['total_num_rows_processed'])
+        avg_num_rows_processed = int(ceil(result[0]['avg_num_rows_processed']))
+        if not source_table_row_count or not total_num_rows_processed or \
+                not avg_num_rows_processed:
             plpy.error("Error while getting the row count of the source table"
                        "{0}".format(self.source_table))
-        num_missing_rows_skipped = source_table_row_count - num_rows_processed
-        return num_rows_processed, num_missing_rows_skipped
+        num_missing_rows_skipped = source_table_row_count - total_num_rows_processed
+
+        return total_num_rows_processed, avg_num_rows_processed, \
+               num_missing_rows_skipped
+
 
 class MiniBatchQueryFormatter:
     """
@@ -450,7 +465,7 @@ class MiniBatchSummarizer:
                               dependent_var_dbtype,
                               buffer_size,
                               class_values,
-                              num_rows_processed,
+                              total_num_rows_processed,
                               num_missing_rows_skipped,
                               grouping_cols):
         # 1. All the string columns are surrounded by "$$" to take care of
@@ -467,7 +482,7 @@ class MiniBatchSummarizer:
                    $${dependent_var_dbtype}$$::TEXT AS dependent_vartype,
                    {buffer_size} AS buffer_size,
                    {class_values} AS class_values,
-                   {num_rows_processed} AS num_rows_processed,
+                   {total_num_rows_processed} AS num_rows_processed,
                    {num_missing_rows_skipped} AS num_missing_rows_skipped,
                    {grouping_cols}::TEXT AS grouping_cols
            """.format(output_summary_table = output_summary_table,
@@ -478,7 +493,7 @@ class MiniBatchSummarizer:
                       dependent_var_dbtype = dependent_var_dbtype,
                       buffer_size = buffer_size,
                       class_values = class_values,
-                      num_rows_processed = num_rows_processed,
+                      total_num_rows_processed = total_num_rows_processed,
                       num_missing_rows_skipped = num_missing_rows_skipped,
                       grouping_cols = "$$" + grouping_cols + "$$"
                                       if grouping_cols else "NULL")
@@ -491,14 +506,14 @@ class MiniBatchBufferSizeCalculator:
     """
     @staticmethod
     def calculate_default_buffer_size(buffer_size,
-                                      num_rows_processed,
+                                      avg_num_rows_processed,
                                       independent_var_dimension):
         if buffer_size is not None:
             return buffer_size
 
         num_of_segments = get_seg_number()
         default_buffer_size = min(75000000.0/independent_var_dimension,
-                                  float(num_rows_processed)/num_of_segments)
+                                  float(avg_num_rows_processed)/num_of_segments)
         """
         1. For float number, we need at least one more buffer for the fraction
            part, e.g. if default_buffer_size = 0.25, we need to round it to 1.
diff --git a/src/ports/postgres/modules/utilities/test/minibatch_preprocessing.sql_in b/src/ports/postgres/modules/utilities/test/minibatch_preprocessing.sql_in
index 04c7fb5f5..2f8d80253 100644
--- a/src/ports/postgres/modules/utilities/test/minibatch_preprocessing.sql_in
+++ b/src/ports/postgres/modules/utilities/test/minibatch_preprocessing.sql_in
@@ -144,7 +144,7 @@ SELECT assert
         num_rows_processed = 10 AND
         num_missing_rows_skipped = 0 AND
         grouping_cols = 'rings',
-        'Summary Validation failed for grouping col. Expected:' || __to_char(summary)
+        'Summary Validation failed for grouping col. Actual:' || __to_char(summary)
         ) from (select * from minibatch_preprocessing_out_summary) summary;
 
 -- Test that the standardization table gets created.
@@ -283,5 +283,5 @@ SELECT assert
         num_rows_processed = 1 AND
         num_missing_rows_skipped = 0 AND
         grouping_cols = '"rin!#''gs"',
-        'Summary Validation failed for special chars. Expected:' || __to_char(summary)
+        'Summary Validation failed for special chars. Actual:' || __to_char(summary)
         ) from (select * from minibatch_preprocessing_out_summary) summary;
diff --git a/src/ports/postgres/modules/utilities/test/unit_tests/test_minibatch_preprocessing.py_in b/src/ports/postgres/modules/utilities/test/unit_tests/test_minibatch_preprocessing.py_in
index 548a6dcb3..879d77d48 100644
--- a/src/ports/postgres/modules/utilities/test/unit_tests/test_minibatch_preprocessing.py_in
+++ b/src/ports/postgres/modules/utilities/test/unit_tests/test_minibatch_preprocessing.py_in
@@ -74,28 +74,30 @@ class MiniBatchPreProcessingTestCase(unittest.TestCase):
 
     def test_minibatch_preprocessor_executes_query(self):
         preprocessor_obj = self.module.MiniBatchPreProcessor(self.default_schema_madlib,
-                                                     "input",
-                                                     "out",
-                                                     self.default_dep_var,
-                                                     self.default_ind_var,
-                                                     self.grouping_cols,
-                                                     self.default_buffer_size)
+                                                             "input",
+                                                             "out",
+                                                             self.default_dep_var,
+                                                             self.default_ind_var,
+                                                             self.grouping_cols,
+                                                             self.default_buffer_size)
         self.plpy_mock_execute.side_effect = [[{"source_table_row_count":5 ,
-                                                "num_rows_processed":3}], ""]
+                                                "total_num_rows_processed":3,
+                                                "avg_num_rows_processed": 2}], ""]
         preprocessor_obj.minibatch_preprocessor()
         self.assertEqual(2, self.plpy_mock_execute.call_count)
         self.assertEqual(self.default_buffer_size, preprocessor_obj.buffer_size)
 
     def test_minibatch_preprocessor_null_buffer_size_executes_query(self):
         preprocessor_obj = self.module.MiniBatchPreProcessor(self.default_schema_madlib,
-                                                     "input",
-                                                     "out",
-                                                     self.default_dep_var,
-                                                     self.default_ind_var,
-                                                     self.grouping_cols,
-                                                     None)
+                                                             "input",
+                                                             "out",
+                                                             self.default_dep_var,
+                                                             self.default_ind_var,
+                                                             self.grouping_cols,
+                                                             None)
         self.plpy_mock_execute.side_effect = [[{"source_table_row_count":5 ,
-                                                "num_rows_processed":3}], ""]
+                                                "total_num_rows_processed":3,
+                                                "avg_num_rows_processed": 2}], ""]
         self.module.MiniBatchBufferSizeCalculator.calculate_default_buffer_size = Mock()
         preprocessor_obj.minibatch_preprocessor()
         self.assertEqual(2, self.plpy_mock_execute.call_count)
@@ -103,22 +105,22 @@ class MiniBatchPreProcessingTestCase(unittest.TestCase):
     def test_minibatch_preprocessor_multiple_dep_var_raises_exception(self):
         with self.assertRaises(plpy.PLPYException):
             self.module.MiniBatchPreProcessor(self.default_schema_madlib,
-                                      self.default_source_table,
-                                      self.default_output_table,
-                                      "y1,y2",
-                                      self.default_ind_var,
-                                      self.grouping_cols,
-                                      self.default_buffer_size)
+                                              self.default_source_table,
+                                              self.default_output_table,
+                                              "y1,y2",
+                                              self.default_ind_var,
+                                              self.grouping_cols,
+                                              self.default_buffer_size)
 
     def test_minibatch_preprocessor_buffer_size_zero_fails(self):
         with self.assertRaises(plpy.PLPYException):
             self.module.MiniBatchPreProcessor(self.default_schema_madlib,
-                                      self.default_source_table,
-                                      self.default_output_table,
-                                      self.default_dep_var,
-                                      self.default_ind_var,
-                                      self.grouping_cols,
-                                      0)
+                                              self.default_source_table,
+                                              self.default_output_table,
+                                              self.default_dep_var,
+                                              self.default_ind_var,
+                                              self.grouping_cols,
+                                              0)
 
     def test_minibatch_preprocessor_buffer_size_one_passes(self):
         #not sure how to assert that an exception has not been raised
diff --git a/src/ports/postgres/modules/utilities/utilities.py_in b/src/ports/postgres/modules/utilities/utilities.py_in
index 320082cf1..40ca40a59 100644
--- a/src/ports/postgres/modules/utilities/utilities.py_in
+++ b/src/ports/postgres/modules/utilities/utilities.py_in
@@ -36,16 +36,19 @@ def is_platform_hawq():
 
 
 def get_seg_number():
-    """ Find out how many primary segments exist in the distribution
-        Might be useful for partitioning data.
+    """ Find out how many primary segments (not including the master) exist
+        in the distribution. Might be useful for partitioning data.
     """
     if is_platform_pg():
         return 1
     else:
-        return plpy.execute("""
+        count = plpy.execute("""
             SELECT count(*) from gp_segment_configuration
-            WHERE role = 'p'
+            WHERE role = 'p' and content != -1
             """)[0]['count']
+        ## In case some weird gpdb configuration happens, always return a
+        ## primary segment count >= 1.
+        return max(1, count)
 
 # ------------------------------------------------------------------------------
 