[AIRFLOW-3059] Log how many rows are read from Postgres (#3905)
To know how much data is being read from Postgres, it is useful to log
this to the Airflow log.

Previously, when there was no data, it would still create a single empty
file. This is not something we want, so we've changed this behaviour.

Refactored the tests to use Postgres itself, since we already have it
running. This makes the tests more realistic than mocking everything.
Fokko authored and kaxil committed Sep 16, 2018
1 parent e56c1df commit 1e79dae
Showing 2 changed files with 94 additions and 60 deletions.
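For context, here is a minimal sketch (not part of this commit) of how the operator is typically wired into a DAG; the DAG id, connection ID, table, bucket and schedule below are placeholders. With this change the task log reports something like "Received 3 rows over 1 files", and a query that returns no rows no longer uploads an empty file to GCS.

from datetime import datetime

from airflow import DAG
from airflow.contrib.operators.postgres_to_gcs_operator import \
    PostgresToGoogleCloudStorageOperator

# Placeholder DAG: assumes a 'postgres_default' connection and an existing GCS bucket.
with DAG('postgres_to_gcs_example', start_date=datetime(2018, 9, 1),
         schedule_interval=None) as dag:
    upload = PostgresToGoogleCloudStorageOperator(
        task_id='postgres_to_gcs',
        postgres_conn_id='postgres_default',
        sql='SELECT * FROM my_table',
        bucket='my-bucket',
        filename='my_table_{}.ndjson',  # '{}' is replaced with the file counter
    )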
54 changes: 32 additions & 22 deletions airflow/contrib/operators/postgres_to_gcs_operator.py
@@ -133,28 +133,38 @@ def _write_local_data_files(self, cursor):
            contain the data for the GCS objects.
        """
        schema = list(map(lambda schema_tuple: schema_tuple[0], cursor.description))
        file_no = 0
        tmp_file_handle = NamedTemporaryFile(delete=True)
        tmp_file_handles = {self.filename.format(file_no): tmp_file_handle}

        for row in cursor:
            # Convert datetime objects to utc seconds, and decimals to floats
            row = map(self.convert_types, row)
            row_dict = dict(zip(schema, row))

            s = json.dumps(row_dict, sort_keys=True)
            if PY3:
                s = s.encode('utf-8')
            tmp_file_handle.write(s)

            # Append newline to make dumps BigQuery compatible.
            tmp_file_handle.write(b'\n')

            # Stop if the file exceeds the file size limit.
            if tmp_file_handle.tell() >= self.approx_max_file_size_bytes:
                file_no += 1
                tmp_file_handle = NamedTemporaryFile(delete=True)
                tmp_file_handles[self.filename.format(file_no)] = tmp_file_handle
        tmp_file_handles = {}
        row_no = 0

        def _create_new_file():
            handle = NamedTemporaryFile(delete=True)
            filename = self.filename.format(len(tmp_file_handles))
            tmp_file_handles[filename] = handle
            return handle

        # Don't create a file if there is nothing to write
        if cursor.rowcount > 0:
            tmp_file_handle = _create_new_file()

            for row in cursor:
                # Convert datetime objects to utc seconds, and decimals to floats
                row = map(self.convert_types, row)
                row_dict = dict(zip(schema, row))

                s = json.dumps(row_dict, sort_keys=True)
                if PY3:
                    s = s.encode('utf-8')
                tmp_file_handle.write(s)

                # Append newline to make dumps BigQuery compatible.
                tmp_file_handle.write(b'\n')

                # Stop if the file exceeds the file size limit.
                if tmp_file_handle.tell() >= self.approx_max_file_size_bytes:
                    tmp_file_handle = _create_new_file()
                row_no += 1

        self.log.info('Received %s rows over %s files', row_no, len(tmp_file_handles))

        return tmp_file_handles

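A note on the new `cursor.rowcount > 0` guard: with the plain client-side cursors that psycopg2 creates by default (which is what PostgresHook.get_conn().cursor() hands out unless a named cursor is requested), the row count is already known right after execute(), before any rows are fetched. A minimal sketch, assuming a reachable Postgres behind the 'postgres_default' connection:

from airflow.hooks.postgres_hook import PostgresHook

hook = PostgresHook(postgres_conn_id='postgres_default')
with hook.get_conn() as conn:
    with conn.cursor() as cur:
        # A query with an empty result: rowcount is 0, so no temporary file would be created.
        cur.execute("SELECT 1 WHERE FALSE")
        print(cur.rowcount)  # 0

        # A query with rows: rowcount is known immediately after execute().
        cur.execute("SELECT generate_series(1, 3)")
        print(cur.rowcount)  # 3

Named (server-side) cursors behave differently and report -1 until rows are fetched, but the operator appears to rely on the hook's default cursor, so the check is cheap and reliable here.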
100 changes: 62 additions & 38 deletions tests/contrib/operators/test_postgres_to_gcs_operator.py
@@ -7,9 +7,9 @@
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -25,40 +25,66 @@
import sys
import unittest

from airflow.contrib.operators.postgres_to_gcs_operator import PostgresToGoogleCloudStorageOperator
from airflow.hooks.postgres_hook import PostgresHook
from airflow.contrib.operators.postgres_to_gcs_operator import \
    PostgresToGoogleCloudStorageOperator

try:
    from unittest import mock
    from unittest.mock import patch
except ImportError:
    try:
        import mock
        from mock import patch
    except ImportError:
        mock = None

PY3 = sys.version_info[0] == 3
TABLES = {'postgres_to_gcs_operator', 'postgres_to_gcs_operator_empty'}

TASK_ID = 'test-postgres-to-gcs'
POSTGRES_CONN_ID = 'postgres_conn_test'
SQL = 'select 1'
POSTGRES_CONN_ID = 'postgres_default'
SQL = 'SELECT * FROM postgres_to_gcs_operator'
BUCKET = 'gs://test'
FILENAME = 'test_{}.ndjson'
# we expect the psycopg cursor to return encoded strs in py2 and decoded in py3
if PY3:
    ROWS = [('mock_row_content_1', 42), ('mock_row_content_2', 43), ('mock_row_content_3', 44)]
    CURSOR_DESCRIPTION = (('some_str', 0), ('some_num', 1005))
else:
    ROWS = [(b'mock_row_content_1', 42), (b'mock_row_content_2', 43), (b'mock_row_content_3', 44)]
    CURSOR_DESCRIPTION = ((b'some_str', 0), (b'some_num', 1005))

NDJSON_LINES = [
    b'{"some_num": 42, "some_str": "mock_row_content_1"}\n',
    b'{"some_num": 43, "some_str": "mock_row_content_2"}\n',
    b'{"some_num": 44, "some_str": "mock_row_content_3"}\n'
]
SCHEMA_FILENAME = 'schema_test.json'
SCHEMA_JSON = b'[{"mode": "NULLABLE", "name": "some_str", "type": "STRING"}, {"mode": "REPEATED", "name": "some_num", "type": "INTEGER"}]'
SCHEMA_JSON = b'[{"mode": "NULLABLE", "name": "some_str", "type": "STRING"}, ' \
b'{"mode": "NULLABLE", "name": "some_num", "type": "INTEGER"}]'


class PostgresToGoogleCloudStorageOperatorTest(unittest.TestCase):
    def setUp(self):
        postgres = PostgresHook()
        with postgres.get_conn() as conn:
            with conn.cursor() as cur:
                for table in TABLES:
                    cur.execute("DROP TABLE IF EXISTS {} CASCADE;".format(table))
                    cur.execute("CREATE TABLE {}(some_str varchar, some_num integer);"
                                .format(table))

                cur.execute(
                    "INSERT INTO postgres_to_gcs_operator VALUES(%s, %s);",
                    ('mock_row_content_1', 42)
                )
                cur.execute(
                    "INSERT INTO postgres_to_gcs_operator VALUES(%s, %s);",
                    ('mock_row_content_2', 43)
                )
                cur.execute(
                    "INSERT INTO postgres_to_gcs_operator VALUES(%s, %s);",
                    ('mock_row_content_3', 44)
                )

    def tearDown(self):
        postgres = PostgresHook()
        with postgres.get_conn() as conn:
            with conn.cursor() as cur:
                for table in TABLES:
                    cur.execute("DROP TABLE IF EXISTS {} CASCADE;".format(table))

    def test_init(self):
        """Test PostgresToGoogleCloudStorageOperator instance is properly initialized."""
        op = PostgresToGoogleCloudStorageOperator(
@@ -68,9 +94,8 @@ def test_init(self):
        self.assertEqual(op.bucket, BUCKET)
        self.assertEqual(op.filename, FILENAME)

    @mock.patch('airflow.contrib.operators.postgres_to_gcs_operator.PostgresHook')
    @mock.patch('airflow.contrib.operators.postgres_to_gcs_operator.GoogleCloudStorageHook')
    def test_exec_success(self, gcs_hook_mock_class, pg_hook_mock_class):
    @patch('airflow.contrib.operators.postgres_to_gcs_operator.GoogleCloudStorageHook')
    def test_exec_success(self, gcs_hook_mock_class):
        """Test the execute function in case where the run is successful."""
        op = PostgresToGoogleCloudStorageOperator(
            task_id=TASK_ID,
@@ -79,10 +104,6 @@ def test_exec_success(self, gcs_hook_mock_class, pg_hook_mock_class):
            bucket=BUCKET,
            filename=FILENAME)

        pg_hook_mock = pg_hook_mock_class.return_value
        pg_hook_mock.get_conn().cursor().__iter__.return_value = iter(ROWS)
        pg_hook_mock.get_conn().cursor().description = CURSOR_DESCRIPTION

        gcs_hook_mock = gcs_hook_mock_class.return_value

        def _assert_upload(bucket, obj, tmp_filename, content_type):
@@ -96,16 +117,9 @@ def _assert_upload(bucket, obj, tmp_filename, content_type):

        op.execute(None)

        pg_hook_mock_class.assert_called_once_with(postgres_conn_id=POSTGRES_CONN_ID)
        pg_hook_mock.get_conn().cursor().execute.assert_called_once_with(SQL, None)

    @mock.patch('airflow.contrib.operators.postgres_to_gcs_operator.PostgresHook')
    @mock.patch('airflow.contrib.operators.postgres_to_gcs_operator.GoogleCloudStorageHook')
    def test_file_splitting(self, gcs_hook_mock_class, pg_hook_mock_class):
    @patch('airflow.contrib.operators.postgres_to_gcs_operator.GoogleCloudStorageHook')
    def test_file_splitting(self, gcs_hook_mock_class):
        """Test that ndjson is split by approx_max_file_size_bytes param."""
        pg_hook_mock = pg_hook_mock_class.return_value
        pg_hook_mock.get_conn().cursor().__iter__.return_value = iter(ROWS)
        pg_hook_mock.get_conn().cursor().description = CURSOR_DESCRIPTION

        gcs_hook_mock = gcs_hook_mock_class.return_value
        expected_upload = {
@@ -129,13 +143,23 @@ def _assert_upload(bucket, obj, tmp_filename, content_type):
            approx_max_file_size_bytes=len(expected_upload[FILENAME.format(0)]))
        op.execute(None)

    @mock.patch('airflow.contrib.operators.postgres_to_gcs_operator.PostgresHook')
    @mock.patch('airflow.contrib.operators.postgres_to_gcs_operator.GoogleCloudStorageHook')
    def test_schema_file(self, gcs_hook_mock_class, pg_hook_mock_class):
    @patch('airflow.contrib.operators.postgres_to_gcs_operator.GoogleCloudStorageHook')
    def test_empty_query(self, gcs_hook_mock_class):
        """If the sql returns no rows, we should not upload any files"""
        gcs_hook_mock = gcs_hook_mock_class.return_value

        op = PostgresToGoogleCloudStorageOperator(
            task_id=TASK_ID,
            sql='SELECT * FROM postgres_to_gcs_operator_empty',
            bucket=BUCKET,
            filename=FILENAME)
        op.execute(None)

        assert not gcs_hook_mock.upload.called, 'No data means no files in the bucket'

    @patch('airflow.contrib.operators.postgres_to_gcs_operator.GoogleCloudStorageHook')
    def test_schema_file(self, gcs_hook_mock_class):
        """Test writing schema files."""
        pg_hook_mock = pg_hook_mock_class.return_value
        pg_hook_mock.get_conn().cursor().__iter__.return_value = iter(ROWS)
        pg_hook_mock.get_conn().cursor().description = CURSOR_DESCRIPTION

        gcs_hook_mock = gcs_hook_mock_class.return_value

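One last aside, not part of the diff: the expected NDJSON lines in the test constants list "some_num" before "some_str" because the operator serialises each row with json.dumps(..., sort_keys=True), so keys come out alphabetically regardless of the column order in the table. A quick self-contained check:

import json

row_dict = {'some_str': 'mock_row_content_1', 'some_num': 42}
line = json.dumps(row_dict, sort_keys=True).encode('utf-8') + b'\n'
assert line == b'{"some_num": 42, "some_str": "mock_row_content_1"}\n'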
