Minibatch Preprocessor: change default buffer size formula to fit grouping

- This commit changes the formula used to compute the default buffer
  size. Previously, we used num_rows_processed/num_of_segments to estimate
  the data distribution on each segment. To adapt this to the grouping
  scenario, we now use avg_num_rows_processed/num_of_segments, which
  estimates the per-segment distribution when there is more than one group
  of data (see the sketch after this list). The remaining code changes
  follow from this one.
- This commit also modifies get_seg_number() to count only primary
  segments. Previously, the function returned the total number of
  segments, including the master segment.
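For illustration, here is a minimal standalone sketch (not part of the commit) of how the new default is derived. The helper name and the example numbers are made up; the 75000000.0 constant and the min/ceil logic mirror calculate_default_buffer_size in the diff below.

from math import ceil

def default_buffer_size(avg_num_rows_processed, independent_var_dimension,
                        num_of_segments):
    # Cap by a fixed per-dimension budget and by the average number of rows
    # per group that lands on each primary segment.
    size = min(75000000.0 / independent_var_dimension,
               float(avg_num_rows_processed) / num_of_segments)
    # Round up so a fractional result (e.g. 0.25) still yields a buffer of 1.
    return int(ceil(size))

# Made-up example: an average of 1000 rows per group, 10-dimensional
# independent variables, 4 primary segments -> min(7500000.0, 250.0) -> 250
print(default_buffer_size(1000, 10, 4))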

Closes #256
Jingyi Mei authored and njayaram2 committed Apr 10, 2018
1 parent 3c443e1 commit 3e519dc
Showing 4 changed files with 72 additions and 52 deletions.
55 changes: 35 additions & 20 deletions src/ports/postgres/modules/utilities/minibatch_preprocessing.py_in
@@ -89,14 +89,14 @@ class MiniBatchPreProcessor:
self.grouping_cols,
self.output_standardization_table)

num_rows_processed, num_missing_rows_skipped = self.\
_get_skipped_rows_processed_count(
dep_var_array_str,
indep_var_array_str)
total_num_rows_processed, avg_num_rows_processed, \
num_missing_rows_skipped = self._get_skipped_rows_processed_count(
dep_var_array_str,
indep_var_array_str)
calculated_buffer_size = MiniBatchBufferSizeCalculator.\
calculate_default_buffer_size(
self.buffer_size,
num_rows_processed,
avg_num_rows_processed,
standardizer.independent_var_dimension)
"""
This query does the following:
@@ -175,7 +175,7 @@ class MiniBatchPreProcessor:
dependent_var_dbtype,
calculated_buffer_size,
dep_var_classes_str,
num_rows_processed,
total_num_rows_processed,
num_missing_rows_skipped,
self.grouping_cols
)
@@ -211,27 +211,42 @@ class MiniBatchPreProcessor:
# Note: Keep the null checking where clause of this query in sync with
# the main create output table query.
query = """
SELECT COUNT(*) AS source_table_row_count,
sum(CASE WHEN
SELECT SUM(source_table_row_count_by_group) AS source_table_row_count,
SUM(num_rows_processed_by_group) AS total_num_rows_processed,
AVG(num_rows_processed_by_group) AS avg_num_rows_processed
FROM (
SELECT COUNT(*) AS source_table_row_count_by_group,
SUM(CASE WHEN
NOT {schema_madlib}.array_contains_null({dep_var_array})
AND NOT {schema_madlib}.array_contains_null({indep_var_array})
THEN 1 ELSE 0 END) AS num_rows_processed
THEN 1 ELSE 0 END) AS num_rows_processed_by_group
FROM {source_table}
{group_by_clause}) s
""".format(
schema_madlib = self.schema_madlib,
source_table = self.source_table,
dep_var_array = dep_var_array,
indep_var_array = indep_var_array)
indep_var_array = indep_var_array,
group_by_clause = "GROUP BY {0}".format(self.grouping_cols) \
if self.grouping_cols else '')
result = plpy.execute(query)

source_table_row_count = result[0]['source_table_row_count']
num_rows_processed = result[0]['num_rows_processed']
if not source_table_row_count or not num_rows_processed:
## SUM and AVG both return float, and we have to cast them to int for
## the summary table. For avg_num_rows_processed we need to ceil first
## so that the minimum won't be 0
source_table_row_count = int(result[0]['source_table_row_count'])
total_num_rows_processed = int(result[0]['total_num_rows_processed'])
avg_num_rows_processed = int(ceil(result[0]['avg_num_rows_processed']))
if not source_table_row_count or not total_num_rows_processed or \
not avg_num_rows_processed:
plpy.error("Error while getting the row count of the source table "
"{0}".format(self.source_table))
num_missing_rows_skipped = source_table_row_count - num_rows_processed

return num_rows_processed, num_missing_rows_skipped
num_missing_rows_skipped = source_table_row_count - total_num_rows_processed

return total_num_rows_processed, avg_num_rows_processed, \
num_missing_rows_skipped


class MiniBatchQueryFormatter:
"""
@@ -450,7 +465,7 @@ class MiniBatchSummarizer:
dependent_var_dbtype,
buffer_size,
class_values,
num_rows_processed,
total_num_rows_processed,
num_missing_rows_skipped,
grouping_cols):
# 1. All the string columns are surrounded by "$$" to take care of
@@ -467,7 +482,7 @@ class MiniBatchSummarizer:
$${dependent_var_dbtype}$$::TEXT AS dependent_vartype,
{buffer_size} AS buffer_size,
{class_values} AS class_values,
{num_rows_processed} AS num_rows_processed,
{total_num_rows_processed} AS num_rows_processed,
{num_missing_rows_skipped} AS num_missing_rows_skipped,
{grouping_cols}::TEXT AS grouping_cols
""".format(output_summary_table = output_summary_table,
@@ -478,7 +493,7 @@ class MiniBatchSummarizer:
dependent_var_dbtype = dependent_var_dbtype,
buffer_size = buffer_size,
class_values = class_values,
num_rows_processed = num_rows_processed,
total_num_rows_processed = total_num_rows_processed,
num_missing_rows_skipped = num_missing_rows_skipped,
grouping_cols = "$$" + grouping_cols + "$$"
if grouping_cols else "NULL")
@@ -491,14 +506,14 @@ class MiniBatchBufferSizeCalculator:
"""
@staticmethod
def calculate_default_buffer_size(buffer_size,
num_rows_processed,
avg_num_rows_processed,
independent_var_dimension):
if buffer_size is not None:
return buffer_size
num_of_segments = get_seg_number()

default_buffer_size = min(75000000.0/independent_var_dimension,
float(num_rows_processed)/num_of_segments)
float(avg_num_rows_processed)/num_of_segments)
"""
1. For a float result, we need one more buffer for the fractional part, e.g.
if default_buffer_size = 0.25, we need to round it up to 1.
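For illustration, here is a plain-Python rendering (a sketch, not the module's code) of what the rewritten row-count query above computes; the rows and group keys are invented.

from collections import defaultdict
from math import ceil

# Each tuple stands in for a source row: (group key, row contains NULLs).
rows = [("a", False), ("a", False), ("a", True), ("b", False)]

count_by_group = defaultdict(int)      # inner query: COUNT(*) per group
processed_by_group = defaultdict(int)  # inner query: SUM(CASE ...) per group
for group, has_nulls in rows:
    count_by_group[group] += 1
    processed_by_group[group] += 0 if has_nulls else 1

# Outer query: SUM and AVG over the per-group results.
source_table_row_count = sum(count_by_group.values())        # 4
total_num_rows_processed = sum(processed_by_group.values())  # 3
# ceil before casting to int so a small average is never rounded down to 0.
avg_num_rows_processed = int(ceil(
    float(total_num_rows_processed) / len(count_by_group)))  # ceil(1.5) -> 2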
@@ -144,7 +144,7 @@ SELECT assert
num_rows_processed = 10 AND
num_missing_rows_skipped = 0 AND
grouping_cols = 'rings',
'Summary Validation failed for grouping col. Expected:' || __to_char(summary)
'Summary Validation failed for grouping col. Actual:' || __to_char(summary)
) from (select * from minibatch_preprocessing_out_summary) summary;

-- Test that the standardization table gets created.
@@ -283,5 +283,5 @@ SELECT assert
num_rows_processed = 1 AND
num_missing_rows_skipped = 0 AND
grouping_cols = '"rin!#''gs"',
'Summary Validation failed for special chars. Expected:' || __to_char(summary)
'Summary Validation failed for special chars. Actual:' || __to_char(summary)
) from (select * from minibatch_preprocessing_out_summary) summary;
@@ -74,51 +74,53 @@ class MiniBatchPreProcessingTestCase(unittest.TestCase):

def test_minibatch_preprocessor_executes_query(self):
preprocessor_obj = self.module.MiniBatchPreProcessor(self.default_schema_madlib,
"input",
"out",
self.default_dep_var,
self.default_ind_var,
self.grouping_cols,
self.default_buffer_size)
"input",
"out",
self.default_dep_var,
self.default_ind_var,
self.grouping_cols,
self.default_buffer_size)
self.plpy_mock_execute.side_effect = [[{"source_table_row_count":5 ,
"num_rows_processed":3}], ""]
"total_num_rows_processed":3,
"avg_num_rows_processed": 2}], ""]
preprocessor_obj.minibatch_preprocessor()
self.assertEqual(2, self.plpy_mock_execute.call_count)
self.assertEqual(self.default_buffer_size, preprocessor_obj.buffer_size)

def test_minibatch_preprocessor_null_buffer_size_executes_query(self):
preprocessor_obj = self.module.MiniBatchPreProcessor(self.default_schema_madlib,
"input",
"out",
self.default_dep_var,
self.default_ind_var,
self.grouping_cols,
None)
"input",
"out",
self.default_dep_var,
self.default_ind_var,
self.grouping_cols,
None)
self.plpy_mock_execute.side_effect = [[{"source_table_row_count":5 ,
"num_rows_processed":3}], ""]
"total_num_rows_processed":3,
"avg_num_rows_processed": 2}], ""]
self.module.MiniBatchBufferSizeCalculator.calculate_default_buffer_size = Mock()
preprocessor_obj.minibatch_preprocessor()
self.assertEqual(2, self.plpy_mock_execute.call_count)

def test_minibatch_preprocessor_multiple_dep_var_raises_exception(self):
with self.assertRaises(plpy.PLPYException):
self.module.MiniBatchPreProcessor(self.default_schema_madlib,
self.default_source_table,
self.default_output_table,
"y1,y2",
self.default_ind_var,
self.grouping_cols,
self.default_buffer_size)
self.default_source_table,
self.default_output_table,
"y1,y2",
self.default_ind_var,
self.grouping_cols,
self.default_buffer_size)

def test_minibatch_preprocessor_buffer_size_zero_fails(self):
with self.assertRaises(plpy.PLPYException):
self.module.MiniBatchPreProcessor(self.default_schema_madlib,
self.default_source_table,
self.default_output_table,
self.default_dep_var,
self.default_ind_var,
self.grouping_cols,
0)
self.default_source_table,
self.default_output_table,
self.default_dep_var,
self.default_ind_var,
self.grouping_cols,
0)

def test_minibatch_preprocessor_buffer_size_one_passes(self):
#not sure how to assert that an exception has not been raised
11 changes: 7 additions & 4 deletions src/ports/postgres/modules/utilities/utilities.py_in
@@ -36,16 +36,19 @@ def is_platform_hawq():


def get_seg_number():
""" Find out how many primary segments exist in the distribution
Might be useful for partitioning data.
""" Find out how many primary segments (excluding the master segment)
exist in the distribution. Might be useful for partitioning data.
"""
if is_platform_pg():
return 1
else:
return plpy.execute("""
count = plpy.execute("""
SELECT count(*) from gp_segment_configuration
WHERE role = 'p'
WHERE role = 'p' and content != -1
""")[0]['count']
## In case an unusual GPDB configuration happens, always return a
## primary segment count >= 1
return max(1, count)
# ------------------------------------------------------------------------------


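For illustration, a hypothetical gp_segment_configuration snapshot showing the effect of the new filter in get_seg_number(); the segment list is invented, while the role/content conditions match the query above.

segments = [
    {"content": -1, "role": "p"},  # master
    {"content": 0, "role": "p"},   # primary segment 0
    {"content": 0, "role": "m"},   # mirror of segment 0
    {"content": 1, "role": "p"},   # primary segment 1
]
count = sum(1 for s in segments
            if s["role"] == "p" and s["content"] != -1)
print(max(1, count))  # -> 2; max() guards against odd configurations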
