Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions airflow/gcp/hooks/dataflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@ class DataflowJobStatus:
"""
JOB_STATE_DONE = "JOB_STATE_DONE"
JOB_STATE_RUNNING = "JOB_STATE_RUNNING"
JOB_TYPE_STREAMING = "JOB_TYPE_STREAMING"
JOB_STATE_FAILED = "JOB_STATE_FAILED"
JOB_STATE_CANCELLED = "JOB_STATE_CANCELLED"
JOB_STATE_PENDING = "JOB_STATE_PENDING"
Expand All @@ -58,6 +57,15 @@ class DataflowJobStatus:
END_STATES = SUCCEEDED_END_STATES | FAILED_END_STATES


class DataflowJobType:
"""
Helper class with Dataflow job types.
"""
JOB_TYPE_UNKNOWN = "JOB_TYPE_UNKNOWN"
JOB_TYPE_BATCH = "JOB_TYPE_BATCH"
JOB_TYPE_STREAMING = "JOB_TYPE_STREAMING"


class _DataflowJob(LoggingMixin):
def __init__(
self,
Expand Down Expand Up @@ -178,7 +186,7 @@ def check_dataflow_job_state(self, job) -> bool:
raise Exception("Google Cloud Dataflow job {} was cancelled.".format(
job['name']))
elif DataflowJobStatus.JOB_STATE_RUNNING == job['currentState'] and \
DataflowJobStatus.JOB_TYPE_STREAMING == job['type']:
DataflowJobType.JOB_TYPE_STREAMING == job['type']:
return True
elif job['currentState'] in {DataflowJobStatus.JOB_STATE_RUNNING,
DataflowJobStatus.JOB_STATE_PENDING}:
Expand Down
293 changes: 201 additions & 92 deletions tests/gcp/hooks/test_dataflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@

from parameterized import parameterized

from airflow.gcp.hooks.dataflow import DataFlowHook, DataflowJobStatus, _Dataflow, _DataflowJob
from airflow.gcp.hooks.dataflow import (
DataFlowHook, DataflowJobStatus, DataflowJobType, _Dataflow, _DataflowJob,
)
from tests.compat import MagicMock, mock

TASK_ID = 'test-dataflow-operator'
Expand Down Expand Up @@ -98,6 +100,31 @@ def test_dataflow_client_creation(self, mock_build, mock_authorize):
)
self.assertEqual(mock_build.return_value, result)

@mock.patch(DATAFLOW_STRING.format('uuid.uuid4'))
@mock.patch(DATAFLOW_STRING.format('_DataflowJob'))
@mock.patch(DATAFLOW_STRING.format('_Dataflow'))
@mock.patch(DATAFLOW_STRING.format('DataFlowHook.get_conn'))
def test_start_python_dataflow(
self, mock_conn, mock_dataflow, mock_dataflowjob, mock_uuid
):
mock_uuid.return_value = MOCK_UUID
mock_conn.return_value = None
dataflow_instance = mock_dataflow.return_value
dataflow_instance.wait_for_done.return_value = None
dataflowjob_instance = mock_dataflowjob.return_value
dataflowjob_instance.wait_for_done.return_value = None
self.dataflow_hook.start_python_dataflow(
job_name=JOB_NAME, variables=DATAFLOW_OPTIONS_PY,
dataflow=PY_FILE, py_options=PY_OPTIONS)
expected_cmd = ["python2", '-m', PY_FILE,
'--region=us-central1',
'--runner=DataflowRunner', '--project=test',
'--labels=foo=bar',
'--staging_location=gs://test/staging',
'--job_name={}-{}'.format(JOB_NAME, MOCK_UUID)]
self.assertListEqual(sorted(mock_dataflow.call_args[0][0]),
sorted(expected_cmd))

@parameterized.expand([
('default_to_python2', "python2"),
('major_version_2', 'python2'),
Expand All @@ -108,8 +135,9 @@ def test_dataflow_client_creation(self, mock_build, mock_authorize):
@mock.patch(DATAFLOW_STRING.format('_DataflowJob'))
@mock.patch(DATAFLOW_STRING.format('_Dataflow'))
@mock.patch(DATAFLOW_STRING.format('DataFlowHook.get_conn'))
def test_start_python_dataflow(self, name, py, mock_conn,
mock_dataflow, mock_dataflowjob, mock_uuid):
def test_start_python_dataflow_with_custom_interpreter(
self, name, py_interpreter, mock_conn, mock_dataflow, mock_dataflowjob, mock_uuid
):
del name # unused variable
mock_uuid.return_value = MOCK_UUID
mock_conn.return_value = None
Expand All @@ -120,9 +148,8 @@ def test_start_python_dataflow(self, name, py, mock_conn,
self.dataflow_hook.start_python_dataflow(
job_name=JOB_NAME, variables=DATAFLOW_OPTIONS_PY,
dataflow=PY_FILE, py_options=PY_OPTIONS,
py_interpreter=py)
expected_interpreter = py if py else DEFAULT_PY_INTERPRETER
expected_cmd = [expected_interpreter, '-m', PY_FILE,
py_interpreter=py_interpreter)
expected_cmd = [py_interpreter, '-m', PY_FILE,
'--region=us-central1',
'--runner=DataflowRunner', '--project=test',
'--labels=foo=bar',
Expand Down Expand Up @@ -178,95 +205,34 @@ def test_start_java_dataflow_with_job_class(self, mock_conn, mock_dataflow, mock
self.assertListEqual(sorted(mock_dataflow.call_args[0][0]),
sorted(expected_cmd))

@mock.patch('airflow.gcp.hooks.dataflow._Dataflow.log')
@mock.patch('subprocess.Popen')
@mock.patch('select.select')
def test_dataflow_wait_for_done_logging(self, mock_select, mock_popen, mock_logging):
mock_logging.info = MagicMock()
mock_logging.warning = MagicMock()
mock_proc = MagicMock()
mock_proc.stderr = MagicMock()
mock_proc.stderr.readlines = MagicMock(return_value=['test\n', 'error\n'])
mock_stderr_fd = MagicMock()
mock_proc.stderr.fileno = MagicMock(return_value=mock_stderr_fd)
mock_proc_poll = MagicMock()
mock_select.return_value = [[mock_stderr_fd]]

def poll_resp_error():
mock_proc.return_code = 1
return True

mock_proc_poll.side_effect = [None, poll_resp_error]
mock_proc.poll = mock_proc_poll
mock_popen.return_value = mock_proc
dataflow = _Dataflow(['test', 'cmd'])
mock_logging.info.assert_called_once_with('Running command: %s', 'test cmd')
self.assertRaises(Exception, dataflow.wait_for_done)

def test_valid_dataflow_job_name(self):
job_name = self.dataflow_hook._build_dataflow_job_name(
job_name=JOB_NAME, append_job_name=False
)

self.assertEqual(job_name, JOB_NAME)

def test_fix_underscore_in_job_name(self):
job_name_with_underscore = 'test_example'
fixed_job_name = job_name_with_underscore.replace(
'_', '-'
)
@parameterized.expand([
(JOB_NAME, JOB_NAME, False),
('test-example', 'test_example', False),
('test-dataflow-pipeline-12345678', JOB_NAME, True),
('test-example-12345678', 'test_example', True),
('df-job-1', 'df-job-1', False),
('df-job', 'df-job', False),
('dfjob', 'dfjob', False),
('dfjob1', 'dfjob1', False),
])
@mock.patch(DATAFLOW_STRING.format('uuid.uuid4'), return_value=MOCK_UUID)
def test_valid_dataflow_job_name(self, expected_result, job_name, append_job_name, mock_uuid4):
job_name = self.dataflow_hook._build_dataflow_job_name(
job_name=job_name_with_underscore, append_job_name=False
)

self.assertEqual(job_name, fixed_job_name)

def test_invalid_dataflow_job_name(self):
invalid_job_name = '9test_invalid_name'
fixed_name = invalid_job_name.replace(
'_', '-')

with self.assertRaises(ValueError) as e:
self.dataflow_hook._build_dataflow_job_name(
job_name=invalid_job_name, append_job_name=False
)
# Test whether the job_name is present in the Error msg
self.assertIn('Invalid job_name ({})'.format(fixed_name),
str(e.exception))

def test_dataflow_job_regex_check(self):
self.assertEqual(self.dataflow_hook._build_dataflow_job_name(
job_name='df-job-1', append_job_name=False
), 'df-job-1')

self.assertEqual(self.dataflow_hook._build_dataflow_job_name(
job_name='df-job', append_job_name=False
), 'df-job')

self.assertEqual(self.dataflow_hook._build_dataflow_job_name(
job_name='dfjob', append_job_name=False
), 'dfjob')

self.assertEqual(self.dataflow_hook._build_dataflow_job_name(
job_name='dfjob1', append_job_name=False
), 'dfjob1')

self.assertRaises(
ValueError,
self.dataflow_hook._build_dataflow_job_name,
job_name='1dfjob', append_job_name=False
job_name=job_name, append_job_name=append_job_name
)

self.assertRaises(
ValueError,
self.dataflow_hook._build_dataflow_job_name,
job_name='dfjob@', append_job_name=False
)
self.assertEqual(expected_result, job_name)

@parameterized.expand([
("1dfjob@", ),
("dfjob@", ),
("df^jo", )
])
def test_build_dataflow_job_name_with_invalid_value(self, job_name):
self.assertRaises(
ValueError,
self.dataflow_hook._build_dataflow_job_name,
job_name='df^jo', append_job_name=False
job_name=job_name, append_job_name=False
)


Expand Down Expand Up @@ -346,10 +312,15 @@ def test_dataflow_job_init_without_job_id(self):
def test_dataflow_job_wait_for_multiple_jobs(self):
job = {"id": TEST_JOB_ID, "name": TEST_JOB_NAME, "currentState": DataflowJobStatus.JOB_STATE_DONE}

self.mock_dataflow.projects.return_value.locations.return_value. \
jobs.return_value.list.return_value.execute.return_value = {
"jobs": [job, job]
}
(
self.mock_dataflow.projects.return_value.
locations.return_value.
jobs.return_value.
list.return_value.
execute.return_value
) = {
"jobs": [job, job]
}

dataflow_job = _DataflowJob(
dataflow=self.mock_dataflow,
Expand All @@ -371,6 +342,119 @@ def test_dataflow_job_wait_for_multiple_jobs(self):

self.assertEqual(dataflow_job.get(), [job, job])

def test_dataflow_job_wait_for_multiple_jobs_and_one_failed(self):
(
self.mock_dataflow.projects.return_value.
locations.return_value.
jobs.return_value.
list.return_value.
execute.return_value
) = {
"jobs": [
{"id": "id-1", "name": "name-1", "currentState": DataflowJobStatus.JOB_STATE_DONE},
{"id": "id-2", "name": "name-2", "currentState": DataflowJobStatus.JOB_STATE_FAILED}
]
}

dataflow_job = _DataflowJob(
dataflow=self.mock_dataflow,
project_number=TEST_PROJECT,
name="name-",
location=TEST_LOCATION,
poll_sleep=0,
job_id=None,
num_retries=20,
multiple_jobs=True
)
with self.assertRaisesRegex(Exception, 'Google Cloud Dataflow job name-2 has failed\\.'):
dataflow_job.wait_for_done()

def test_dataflow_job_wait_for_multiple_jobs_and_one_cancelled(self):
(
self.mock_dataflow.projects.return_value.
locations.return_value.
jobs.return_value.
list.return_value.
execute.return_value
) = {
"jobs": [
{"id": "id-1", "name": "name-1", "currentState": DataflowJobStatus.JOB_STATE_DONE},
{"id": "id-2", "name": "name-2", "currentState": DataflowJobStatus.JOB_STATE_CANCELLED}
]
}

dataflow_job = _DataflowJob(
dataflow=self.mock_dataflow,
project_number=TEST_PROJECT,
name="name-",
location=TEST_LOCATION,
poll_sleep=0,
job_id=None,
num_retries=20,
multiple_jobs=True
)
with self.assertRaisesRegex(Exception, 'Google Cloud Dataflow job name-2 was cancelled\\.'):
dataflow_job.wait_for_done()

def test_dataflow_job_wait_for_multiple_jobs_and_one_unknown(self):
(
self.mock_dataflow.projects.return_value.
locations.return_value.
jobs.return_value.
list.return_value.
execute.return_value
) = {
"jobs": [
{"id": "id-1", "name": "name-1", "currentState": DataflowJobStatus.JOB_STATE_DONE},
{"id": "id-2", "name": "name-2", "currentState": "unknown"}
]
}

dataflow_job = _DataflowJob(
dataflow=self.mock_dataflow,
project_number=TEST_PROJECT,
name="name-",
location=TEST_LOCATION,
poll_sleep=0,
job_id=None,
num_retries=20,
multiple_jobs=True
)
with self.assertRaisesRegex(Exception, 'Google Cloud Dataflow job name-2 was unknown state: unknown'):
dataflow_job.wait_for_done()

def test_dataflow_job_wait_for_multiple_jobs_and_streaming_jobs(self):
mock_jobs_list = (
self.mock_dataflow.projects.return_value.
locations.return_value.
jobs.return_value.
list
)
mock_jobs_list.return_value.execute.return_value = {
"jobs": [
{
"id": "id-2",
"name": "name-2",
"currentState": DataflowJobStatus.JOB_STATE_RUNNING,
"type": DataflowJobType.JOB_TYPE_STREAMING
}
]
}

dataflow_job = _DataflowJob(
dataflow=self.mock_dataflow,
project_number=TEST_PROJECT,
name="name-",
location=TEST_LOCATION,
poll_sleep=0,
job_id=None,
num_retries=20,
multiple_jobs=True
)
dataflow_job.wait_for_done()

self.assertEqual(1, mock_jobs_list.call_count)

def test_dataflow_job_wait_for_single_jobs(self):
job = {"id": TEST_JOB_ID, "name": TEST_JOB_NAME, "currentState": DataflowJobStatus.JOB_STATE_DONE}

Expand Down Expand Up @@ -412,3 +496,28 @@ def test_data_flow_valid_job_id(self):
def test_data_flow_missing_job_id(self):
cmd = ['echo', 'unit testing']
self.assertEqual(_Dataflow(cmd).wait_for_done(), None)

@mock.patch('airflow.gcp.hooks.dataflow._Dataflow.log')
@mock.patch('subprocess.Popen')
@mock.patch('select.select')
def test_dataflow_wait_for_done_logging(self, mock_select, mock_popen, mock_logging):
mock_logging.info = MagicMock()
mock_logging.warning = MagicMock()
mock_proc = MagicMock()
mock_proc.stderr = MagicMock()
mock_proc.stderr.readlines = MagicMock(return_value=['test\n', 'error\n'])
mock_stderr_fd = MagicMock()
mock_proc.stderr.fileno = MagicMock(return_value=mock_stderr_fd)
mock_proc_poll = MagicMock()
mock_select.return_value = [[mock_stderr_fd]]

def poll_resp_error():
mock_proc.return_code = 1
return True

mock_proc_poll.side_effect = [None, poll_resp_error]
mock_proc.poll = mock_proc_poll
mock_popen.return_value = mock_proc
dataflow = _Dataflow(['test', 'cmd'])
mock_logging.info.assert_called_once_with('Running command: %s', 'test cmd')
self.assertRaises(Exception, dataflow.wait_for_done)