Minibatch Preprocessing: change default buffer size formula for grouping #256

Closed
55 changes: 35 additions & 20 deletions src/ports/postgres/modules/utilities/minibatch_preprocessing.py_in
@@ -89,14 +89,14 @@ class MiniBatchPreProcessor:
self.grouping_cols,
self.output_standardization_table)

num_rows_processed, num_missing_rows_skipped = self.\
_get_skipped_rows_processed_count(
dep_var_array_str,
indep_var_array_str)
total_num_rows_processed, avg_num_rows_processed, \
num_missing_rows_skipped = self._get_skipped_rows_processed_count(
dep_var_array_str,
indep_var_array_str)
calculated_buffer_size = MiniBatchBufferSizeCalculator.\
calculate_default_buffer_size(
self.buffer_size,
num_rows_processed,
avg_num_rows_processed,
standardizer.independent_var_dimension)
"""
This query does the following:
@@ -174,7 +174,7 @@ class MiniBatchPreProcessor:
self.independent_varname,
calculated_buffer_size,
dep_var_classes_str,
num_rows_processed,
total_num_rows_processed,
num_missing_rows_skipped,
self.grouping_cols
)
@@ -210,27 +210,42 @@ class MiniBatchPreProcessor:
# Note: Keep the null checking where clause of this query in sync with
# the main create output table query.
query = """
SELECT COUNT(*) AS source_table_row_count,
sum(CASE WHEN
SELECT SUM(source_table_row_count_by_group) AS source_table_row_count,
SUM(num_rows_processed_by_group) AS total_num_rows_processed,
AVG(num_rows_processed_by_group) AS avg_num_rows_processed
FROM (
SELECT COUNT(*) AS source_table_row_count_by_group,
SUM(CASE WHEN
NOT {schema_madlib}.array_contains_null({dep_var_array})
AND NOT {schema_madlib}.array_contains_null({indep_var_array})
THEN 1 ELSE 0 END) AS num_rows_processed
THEN 1 ELSE 0 END) AS num_rows_processed_by_group
FROM {source_table}
{group_by_clause}) s
""".format(
schema_madlib = self.schema_madlib,
source_table = self.source_table,
dep_var_array = dep_var_array,
indep_var_array = indep_var_array)
indep_var_array = indep_var_array,
group_by_clause = "GROUP BY {0}".format(self.grouping_cols) \
if self.grouping_cols else '')
result = plpy.execute(query)

source_table_row_count = result[0]['source_table_row_count']
num_rows_processed = result[0]['num_rows_processed']
if not source_table_row_count or not num_rows_processed:
## SUM and AVG both return float, and we have to cast them into int for the
## summary table. For avg_num_rows_processed we need to ceil first so
## that the minimum won't be 0.
source_table_row_count = int(result[0]['source_table_row_count'])
total_num_rows_processed = int(result[0]['total_num_rows_processed'])
avg_num_rows_processed = int(ceil(result[0]['avg_num_rows_processed']))
if not source_table_row_count or not total_num_rows_processed or \
not avg_num_rows_processed:
plpy.error("Error while getting the row count of the source table"
"{0}".format(self.source_table))
num_missing_rows_skipped = source_table_row_count - num_rows_processed

return num_rows_processed, num_missing_rows_skipped
num_missing_rows_skipped = source_table_row_count - total_num_rows_processed

return total_num_rows_processed, avg_num_rows_processed, \
num_missing_rows_skipped
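
To make the new per-group counting concrete, here is a minimal standalone Python sketch (illustrative only, not MADlib code) of what the rewritten query plus the cast/ceil step compute. The toy rows, group keys, and printed values are assumptions for the example; None stands in for a NULL that the WHERE-style check would skip.

from math import ceil

# Toy stand-in for the source table: (group value, dependent value, independent values).
rows = [
    ('a', 1.0,  [1.0, 2.0]),
    ('a', None, [1.0, 2.0]),   # skipped: NULL dependent value
    ('a', 2.0,  [3.0, 4.0]),
    ('b', 0.0,  [5.0, 6.0]),
    ('b', 1.0,  [7.0, None]),  # skipped: NULL inside the independent array
]

# Inner query: per-group row count and per-group count of usable rows.
per_group_total, per_group_processed = {}, {}
for grp, dep, indep in rows:
    per_group_total[grp] = per_group_total.get(grp, 0) + 1
    if dep is not None and all(v is not None for v in indep):
        per_group_processed[grp] = per_group_processed.get(grp, 0) + 1
    else:
        per_group_processed.setdefault(grp, 0)

# Outer query: SUM(...) and AVG(...) over the per-group counts,
# with the same int cast and ceil as the Python code above.
source_table_row_count = sum(per_group_total.values())
total_num_rows_processed = sum(per_group_processed.values())
avg_num_rows_processed = int(ceil(
    float(total_num_rows_processed) / len(per_group_processed)))
num_missing_rows_skipped = source_table_row_count - total_num_rows_processed

print("{0} {1} {2} {3}".format(source_table_row_count, total_num_rows_processed,
                               avg_num_rows_processed, num_missing_rows_skipped))  # 5 3 2 2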


class MiniBatchQueryFormatter:
"""
@@ -444,7 +459,7 @@ class MiniBatchSummarizer:
def create_output_summary_table(output_summary_table, source_table,
output_table, dep_var_array_str,
indep_var_array_str, buffer_size,
class_values, num_rows_processed,
class_values, total_num_rows_processed,
num_missing_rows_skipped, grouping_cols):
# 1. All the string columns are surrounded by "$$" to take care of
# special characters in the column name.
@@ -459,7 +474,7 @@
$${independent_varname}$$::TEXT AS independent_varname,
{buffer_size} AS buffer_size,
{class_values} AS class_values,
{num_rows_processed} AS num_rows_processed,
{total_num_rows_processed} AS num_rows_processed,
{num_missing_rows_skipped} AS num_missing_rows_skipped,
{grouping_cols}::TEXT AS grouping_cols
""".format(output_summary_table = output_summary_table,
@@ -469,7 +484,7 @@
independent_varname = indep_var_array_str,
buffer_size = buffer_size,
class_values = class_values,
num_rows_processed = num_rows_processed,
total_num_rows_processed = total_num_rows_processed,
num_missing_rows_skipped = num_missing_rows_skipped,
grouping_cols = "$$" + grouping_cols + "$$"
if grouping_cols else "NULL")
@@ -482,14 +497,14 @@ class MiniBatchBufferSizeCalculator:
"""
@staticmethod
def calculate_default_buffer_size(buffer_size,
num_rows_processed,
avg_num_rows_processed,
independent_var_dimension):
if buffer_size is not None:
return buffer_size
num_of_segments = get_seg_number()

default_buffer_size = min(75000000.0/independent_var_dimension,
float(num_rows_processed)/num_of_segments)
float(avg_num_rows_processed)/num_of_segments)
"""
1. For a float number, we need at least one more buffer for the fractional part, e.g.
if default_buffer_size = 0.25, we need to round it up to 1.
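
For the formula change itself, a small worked example (a hypothetical helper mirroring calculate_default_buffer_size, with made-up numbers) shows how the average per-group row count now drives the default buffer size. The exact rounding code is collapsed in this diff, so the ceil below follows the docstring comment above.

from math import ceil

def sketch_default_buffer_size(avg_num_rows_processed,
                               independent_var_dimension,
                               num_of_segments):
    # Cap the buffer by the fixed 75,000,000 / dimension bound used in the
    # original formula and by the average per-group rows spread over segments.
    size = min(75000000.0 / independent_var_dimension,
               float(avg_num_rows_processed) / num_of_segments)
    # Round up so a fractional result (e.g. 0.25) still yields a buffer of 1.
    return int(ceil(size))

# E.g. 1000 rows per group on average, 4-dimensional features, 8 segments:
# min(75000000 / 4, 1000 / 8) = 125, so each buffer holds about 125 rows.
print(sketch_default_buffer_size(1000, 4, 8))   # 125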
@@ -143,7 +143,7 @@ SELECT assert
num_rows_processed = 10 AND
num_missing_rows_skipped = 0 AND
grouping_cols = 'rings',
'Summary Validation failed for grouping col. Expected:' || __to_char(summary)
'Summary Validation failed for grouping col. Actual:' || __to_char(summary)
) from (select * from minibatch_preprocessing_out_summary) summary;

-- Test that the standardization table gets created.
@@ -282,5 +282,5 @@ SELECT assert
num_rows_processed = 1 AND
num_missing_rows_skipped = 0 AND
grouping_cols = '"rin!#''gs"',
'Summary Validation failed for special chars. Expected:' || __to_char(summary)
'Summary Validation failed for special chars. Actual:' || __to_char(summary)
) from (select * from minibatch_preprocessing_out_summary) summary;
@@ -74,51 +74,53 @@ class MiniBatchPreProcessingTestCase(unittest.TestCase):

def test_minibatch_preprocessor_executes_query(self):
preprocessor_obj = self.module.MiniBatchPreProcessor(self.default_schema_madlib,
"input",
"out",
self.default_dep_var,
self.default_ind_var,
self.grouping_cols,
self.default_buffer_size)
"input",
"out",
self.default_dep_var,
self.default_ind_var,
self.grouping_cols,
self.default_buffer_size)
self.plpy_mock_execute.side_effect = [[{"source_table_row_count":5 ,
"num_rows_processed":3}], ""]
"total_num_rows_processed":3,
"avg_num_rows_processed": 2}], ""]
preprocessor_obj.minibatch_preprocessor()
self.assertEqual(2, self.plpy_mock_execute.call_count)
self.assertEqual(self.default_buffer_size, preprocessor_obj.buffer_size)

def test_minibatch_preprocessor_null_buffer_size_executes_query(self):
preprocessor_obj = self.module.MiniBatchPreProcessor(self.default_schema_madlib,
"input",
"out",
self.default_dep_var,
self.default_ind_var,
self.grouping_cols,
None)
"input",
"out",
self.default_dep_var,
self.default_ind_var,
self.grouping_cols,
None)
self.plpy_mock_execute.side_effect = [[{"source_table_row_count":5 ,
"num_rows_processed":3}], ""]
"total_num_rows_processed":3,
"avg_num_rows_processed": 2}], ""]
self.module.MiniBatchBufferSizeCalculator.calculate_default_buffer_size = Mock()
preprocessor_obj.minibatch_preprocessor()
self.assertEqual(2, self.plpy_mock_execute.call_count)

def test_minibatch_preprocessor_multiple_dep_var_raises_exception(self):
with self.assertRaises(plpy.PLPYException):
self.module.MiniBatchPreProcessor(self.default_schema_madlib,
self.default_source_table,
self.default_output_table,
"y1,y2",
self.default_ind_var,
self.grouping_cols,
self.default_buffer_size)
self.default_source_table,
self.default_output_table,
"y1,y2",
self.default_ind_var,
self.grouping_cols,
self.default_buffer_size)

def test_minibatch_preprocessor_buffer_size_zero_fails(self):
with self.assertRaises(plpy.PLPYException):
self.module.MiniBatchPreProcessor(self.default_schema_madlib,
self.default_source_table,
self.default_output_table,
self.default_dep_var,
self.default_ind_var,
self.grouping_cols,
0)
self.default_source_table,
self.default_output_table,
self.default_dep_var,
self.default_ind_var,
self.grouping_cols,
0)

def test_minibatch_preprocessor_buffer_size_one_passes(self):
#not sure how to assert that an exception has not been raised
11 changes: 7 additions & 4 deletions src/ports/postgres/modules/utilities/utilities.py_in
@@ -36,16 +36,19 @@ def is_platform_hawq():


def get_seg_number():
""" Find out how many primary segments exist in the distribution
Might be useful for partitioning data.
""" Find out how many primary segments(not include master segment) exist
in the distribution. Might be useful for partitioning data.
"""
if is_platform_pg():
return 1
else:
return plpy.execute("""
count = plpy.execute("""
SELECT count(*) from gp_segment_configuration
WHERE role = 'p'
WHERE role = 'p' and content != -1
""")[0]['count']
## In case some weird GPDB configuration happens, always return a
## primary segment count >= 1
return max(1, count)
# ------------------------------------------------------------------------------


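
As a footnote on the utilities.py_in change, the sketch below (with an assumed, simplified gp_segment_configuration layout) illustrates why the added content != -1 filter matters: the Greenplum master also has role = 'p' but content = -1, so counting role = 'p' alone over-counts by one.

# Toy stand-in for gp_segment_configuration rows (illustrative only).
catalog = [
    {'role': 'p', 'content': -1},   # master: role 'p' but content -1
    {'role': 'p', 'content': 0},    # primary segment 0
    {'role': 'p', 'content': 1},    # primary segment 1
    {'role': 'm', 'content': 0},    # mirror of segment 0, not counted
]
primaries = [r for r in catalog
             if r['role'] == 'p' and r['content'] != -1]
print(max(1, len(primaries)))   # 2: master excluded, clamped to at least 1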