From 3e519dcce66d0dd4bfcc3c45f246f476c26e26d7 Mon Sep 17 00:00:00 2001
From: Jingyi Mei
Date: Tue, 3 Apr 2018 17:50:57 -0700
Subject: [PATCH] Minibatch Preprocessor: change default buffer size formula
 to fit grouping

- This commit changes the calculation formula for the default buffer size.
  Previously, we used num_rows_processed/num_of_segments to estimate how the
  data is distributed across segments. To fit the grouping scenario, we now
  use avg_num_rows_processed/num_of_segments, so that the estimate still
  holds when the data contains more than one group. The remaining code
  changes follow from this change. (An illustrative sketch of the new
  computation follows the patch.)

- This commit also modifies get_seg_number() to return only the number of
  primary segments. Previously, the function returned the total segment
  count, including the master segment.

Closes #256
---
 .../utilities/minibatch_preprocessing.py_in   | 55 ++++++++++++-------
 .../test/minibatch_preprocessing.sql_in       |  4 +-
 .../test_minibatch_preprocessing.py_in        | 54 +++++++++---------
 .../modules/utilities/utilities.py_in         | 11 ++--
 4 files changed, 72 insertions(+), 52 deletions(-)

diff --git a/src/ports/postgres/modules/utilities/minibatch_preprocessing.py_in b/src/ports/postgres/modules/utilities/minibatch_preprocessing.py_in
index 4a1c8ae06..401323ed6 100644
--- a/src/ports/postgres/modules/utilities/minibatch_preprocessing.py_in
+++ b/src/ports/postgres/modules/utilities/minibatch_preprocessing.py_in
@@ -89,14 +89,14 @@ class MiniBatchPreProcessor:
                                                 self.grouping_cols,
                                                 self.output_standardization_table)
 
-        num_rows_processed, num_missing_rows_skipped = self.\
-            _get_skipped_rows_processed_count(
-                dep_var_array_str,
-                indep_var_array_str)
+        total_num_rows_processed, avg_num_rows_processed, \
+            num_missing_rows_skipped = self._get_skipped_rows_processed_count(
+                dep_var_array_str,
+                indep_var_array_str)
         calculated_buffer_size = MiniBatchBufferSizeCalculator.\
             calculate_default_buffer_size(
                 self.buffer_size,
-                num_rows_processed,
+                avg_num_rows_processed,
                 standardizer.independent_var_dimension)
         """
         This query does the following:
@@ -175,7 +175,7 @@ class MiniBatchPreProcessor:
             dependent_var_dbtype,
             calculated_buffer_size,
             dep_var_classes_str,
-            num_rows_processed,
+            total_num_rows_processed,
             num_missing_rows_skipped,
             self.grouping_cols
             )
@@ -211,27 +211,42 @@ class MiniBatchPreProcessor:
         # Note: Keep the null checking where clause of this query in sync with
         # the main create output table query.
query = """ - SELECT COUNT(*) AS source_table_row_count, - sum(CASE WHEN + SELECT SUM(source_table_row_count_by_group) AS source_table_row_count, + SUM(num_rows_processed_by_group) AS total_num_rows_processed, + AVG(num_rows_processed_by_group) AS avg_num_rows_processed + FROM ( + SELECT COUNT(*) AS source_table_row_count_by_group, + SUM(CASE WHEN NOT {schema_madlib}.array_contains_null({dep_var_array}) AND NOT {schema_madlib}.array_contains_null({indep_var_array}) - THEN 1 ELSE 0 END) AS num_rows_processed + THEN 1 ELSE 0 END) AS num_rows_processed_by_group FROM {source_table} + {group_by_clause}) s """.format( schema_madlib = self.schema_madlib, source_table = self.source_table, dep_var_array = dep_var_array, - indep_var_array = indep_var_array) + indep_var_array = indep_var_array, + group_by_clause = "GROUP BY {0}".format(self.grouping_cols) \ + if self.grouping_cols else '') result = plpy.execute(query) - source_table_row_count = result[0]['source_table_row_count'] - num_rows_processed = result[0]['num_rows_processed'] - if not source_table_row_count or not num_rows_processed: + ## SUM and AVG both return float, and we have to cast them into int fo + ## summary table. For avg_num_rows_processed we need to ceil first so + ## that the minimum won't be 0 + source_table_row_count = int(result[0]['source_table_row_count']) + total_num_rows_processed = int(result[0]['total_num_rows_processed']) + avg_num_rows_processed = int(ceil(result[0]['avg_num_rows_processed'])) + if not source_table_row_count or not total_num_rows_processed or \ + not avg_num_rows_processed: plpy.error("Error while getting the row count of the source table" "{0}".format(self.source_table)) - num_missing_rows_skipped = source_table_row_count - num_rows_processed - return num_rows_processed, num_missing_rows_skipped + num_missing_rows_skipped = source_table_row_count - total_num_rows_processed + + return total_num_rows_processed, avg_num_rows_processed, \ + num_missing_rows_skipped + class MiniBatchQueryFormatter: """ @@ -450,7 +465,7 @@ class MiniBatchSummarizer: dependent_var_dbtype, buffer_size, class_values, - num_rows_processed, + total_num_rows_processed, num_missing_rows_skipped, grouping_cols): # 1. All the string columns are surrounded by "$$" to take care of @@ -467,7 +482,7 @@ class MiniBatchSummarizer: $${dependent_var_dbtype}$$::TEXT AS dependent_vartype, {buffer_size} AS buffer_size, {class_values} AS class_values, - {num_rows_processed} AS num_rows_processed, + {total_num_rows_processed} AS num_rows_processed, {num_missing_rows_skipped} AS num_missing_rows_skipped, {grouping_cols}::TEXT AS grouping_cols """.format(output_summary_table = output_summary_table, @@ -478,7 +493,7 @@ class MiniBatchSummarizer: dependent_var_dbtype = dependent_var_dbtype, buffer_size = buffer_size, class_values = class_values, - num_rows_processed = num_rows_processed, + total_num_rows_processed = total_num_rows_processed, num_missing_rows_skipped = num_missing_rows_skipped, grouping_cols = "$$" + grouping_cols + "$$" if grouping_cols else "NULL") @@ -491,14 +506,14 @@ class MiniBatchBufferSizeCalculator: """ @staticmethod def calculate_default_buffer_size(buffer_size, - num_rows_processed, + avg_num_rows_processed, independent_var_dimension): if buffer_size is not None: return buffer_size num_of_segments = get_seg_number() default_buffer_size = min(75000000.0/independent_var_dimension, - float(num_rows_processed)/num_of_segments) + float(avg_num_rows_processed)/num_of_segments) """ 1. 
         1. For a float value, we need at least one more buffer for the
            fractional part, e.g. if default_buffer_size = 0.25, we need to
            round it up to 1.
diff --git a/src/ports/postgres/modules/utilities/test/minibatch_preprocessing.sql_in b/src/ports/postgres/modules/utilities/test/minibatch_preprocessing.sql_in
index 04c7fb5f5..2f8d80253 100644
--- a/src/ports/postgres/modules/utilities/test/minibatch_preprocessing.sql_in
+++ b/src/ports/postgres/modules/utilities/test/minibatch_preprocessing.sql_in
@@ -144,7 +144,7 @@ SELECT assert
         num_rows_processed = 10 AND
         num_missing_rows_skipped = 0 AND
         grouping_cols = 'rings',
-        'Summary Validation failed for grouping col. Expected:' || __to_char(summary)
+        'Summary Validation failed for grouping col. Actual:' || __to_char(summary)
         ) from (select * from minibatch_preprocessing_out_summary) summary;
 
 -- Test that the standardization table gets created.
@@ -283,5 +283,5 @@ SELECT assert
         num_rows_processed = 1 AND
         num_missing_rows_skipped = 0 AND
         grouping_cols = '"rin!#''gs"',
-        'Summary Validation failed for special chars. Expected:' || __to_char(summary)
+        'Summary Validation failed for special chars. Actual:' || __to_char(summary)
         ) from (select * from minibatch_preprocessing_out_summary) summary;
diff --git a/src/ports/postgres/modules/utilities/test/unit_tests/test_minibatch_preprocessing.py_in b/src/ports/postgres/modules/utilities/test/unit_tests/test_minibatch_preprocessing.py_in
index 548a6dcb3..879d77d48 100644
--- a/src/ports/postgres/modules/utilities/test/unit_tests/test_minibatch_preprocessing.py_in
+++ b/src/ports/postgres/modules/utilities/test/unit_tests/test_minibatch_preprocessing.py_in
@@ -74,28 +74,30 @@ class MiniBatchPreProcessingTestCase(unittest.TestCase):
 
     def test_minibatch_preprocessor_executes_query(self):
         preprocessor_obj = self.module.MiniBatchPreProcessor(self.default_schema_madlib,
-                                                            "input",
-                                                            "out",
-                                                            self.default_dep_var,
-                                                            self.default_ind_var,
-                                                            self.grouping_cols,
-                                                            self.default_buffer_size)
+                                                             "input",
+                                                             "out",
+                                                             self.default_dep_var,
+                                                             self.default_ind_var,
+                                                             self.grouping_cols,
+                                                             self.default_buffer_size)
         self.plpy_mock_execute.side_effect = [[{"source_table_row_count":5 ,
-                                                "num_rows_processed":3}], ""]
+                                                "total_num_rows_processed":3,
+                                                "avg_num_rows_processed": 2}], ""]
         preprocessor_obj.minibatch_preprocessor()
         self.assertEqual(2, self.plpy_mock_execute.call_count)
         self.assertEqual(self.default_buffer_size, preprocessor_obj.buffer_size)
 
     def test_minibatch_preprocessor_null_buffer_size_executes_query(self):
         preprocessor_obj = self.module.MiniBatchPreProcessor(self.default_schema_madlib,
-                                                            "input",
-                                                            "out",
-                                                            self.default_dep_var,
-                                                            self.default_ind_var,
-                                                            self.grouping_cols,
-                                                            None)
+                                                             "input",
+                                                             "out",
+                                                             self.default_dep_var,
+                                                             self.default_ind_var,
+                                                             self.grouping_cols,
+                                                             None)
         self.plpy_mock_execute.side_effect = [[{"source_table_row_count":5 ,
-                                                "num_rows_processed":3}], ""]
+                                                "total_num_rows_processed":3,
+                                                "avg_num_rows_processed": 2}], ""]
         self.module.MiniBatchBufferSizeCalculator.calculate_default_buffer_size = Mock()
         preprocessor_obj.minibatch_preprocessor()
         self.assertEqual(2, self.plpy_mock_execute.call_count)
@@ -103,22 +105,22 @@ class MiniBatchPreProcessingTestCase(unittest.TestCase):
     def test_minibatch_preprocessor_multiple_dep_var_raises_exception(self):
         with self.assertRaises(plpy.PLPYException):
             self.module.MiniBatchPreProcessor(self.default_schema_madlib,
-                                            self.default_source_table,
-                                            self.default_output_table,
-                                            "y1,y2",
-                                            self.default_ind_var,
-                                            self.grouping_cols,
-                                            self.default_buffer_size)
+                                              self.default_source_table,
+                                              self.default_output_table,
+                                              "y1,y2",
+                                              self.default_ind_var,
+                                              self.grouping_cols,
+                                              self.default_buffer_size)
 
     def test_minibatch_preprocessor_buffer_size_zero_fails(self):
         with self.assertRaises(plpy.PLPYException):
             self.module.MiniBatchPreProcessor(self.default_schema_madlib,
-                                            self.default_source_table,
-                                            self.default_output_table,
-                                            self.default_dep_var,
-                                            self.default_ind_var,
-                                            self.grouping_cols,
-                                            0)
+                                              self.default_source_table,
+                                              self.default_output_table,
+                                              self.default_dep_var,
+                                              self.default_ind_var,
+                                              self.grouping_cols,
+                                              0)
 
     def test_minibatch_preprocessor_buffer_size_one_passes(self):
         # not sure how to assert that an exception has not been raised
diff --git a/src/ports/postgres/modules/utilities/utilities.py_in b/src/ports/postgres/modules/utilities/utilities.py_in
index 320082cf1..40ca40a59 100644
--- a/src/ports/postgres/modules/utilities/utilities.py_in
+++ b/src/ports/postgres/modules/utilities/utilities.py_in
@@ -36,16 +36,19 @@ def is_platform_hawq():
 
 
 def get_seg_number():
-    """ Find out how many primary segments exist in the distribution
-        Might be useful for partitioning data.
+    """ Find out how many primary segments (excluding the master segment)
+        exist in the distribution. Might be useful for partitioning data.
     """
     if is_platform_pg():
         return 1
     else:
-        return plpy.execute("""
+        count = plpy.execute("""
             SELECT count(*) from gp_segment_configuration
-            WHERE role = 'p'
+            WHERE role = 'p' and content != -1
             """)[0]['count']
+        ## In case of an unusual GPDB configuration, always return a
+        ## primary segment count >= 1.
+        return max(1, count)
 # ------------------------------------------------------------------------------
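
For reference, below is a minimal standalone Python sketch of the new default
buffer size computation, illustrative only and not part of the patch. The
function name and the example numbers are hypothetical; the 75000000.0 cap,
the division by the number of primary segments, and the ceil-based rounding
mirror MiniBatchBufferSizeCalculator.calculate_default_buffer_size above.

    from math import ceil

    def sketch_default_buffer_size(avg_num_rows_processed,
                                   independent_var_dimension,
                                   num_of_segments):
        # Cap the buffer by the independent variable dimension, then spread
        # the average per-group row count across the primary segments.
        size = min(75000000.0 / independent_var_dimension,
                   float(avg_num_rows_processed) / num_of_segments)
        # Round up so a fractional result (e.g. 0.25) still yields a buffer
        # of at least one row.
        return int(ceil(size))

    # Hypothetical example: groups averaging 500 processed rows each,
    # 10-dimensional independent variables, 4 primary segments:
    # min(7500000.0, 125.0) -> 125
    print(sketch_default_buffer_size(500, 10, 4))  # prints 125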