diff --git a/tests/gcp/operators/test_mlengine.py b/tests/gcp/operators/test_mlengine.py index cc5d9127bb2a7..9485ea913e1b0 100644 --- a/tests/gcp/operators/test_mlengine.py +++ b/tests/gcp/operators/test_mlengine.py @@ -88,128 +88,119 @@ def setUp(self): }, schedule_interval='@daily') - def test_success_with_model(self): - with patch('airflow.gcp.operators.mlengine.MLEngineHook') \ - as mock_hook: - - input_with_model = self.INPUT_MISSING_ORIGIN.copy() - input_with_model['modelName'] = \ - 'projects/test-project/models/test_model' - success_message = self.SUCCESS_MESSAGE_MISSING_INPUT.copy() - success_message['predictionInput'] = input_with_model - - hook_instance = mock_hook.return_value - hook_instance.get_job.side_effect = HttpError( - resp=httplib2.Response({ - 'status': 404 - }), content=b'some bytes') - hook_instance.create_job.return_value = success_message - - prediction_task = MLEngineBatchPredictionOperator( - job_id='test_prediction', - project_id='test-project', - region=input_with_model['region'], - data_format=input_with_model['dataFormat'], - input_paths=input_with_model['inputPaths'], - output_path=input_with_model['outputPath'], - model_name=input_with_model['modelName'].split('/')[-1], - dag=self.dag, - task_id='test-prediction') - prediction_output = prediction_task.execute(None) - - mock_hook.assert_called_once_with('google_cloud_default', None) - hook_instance.create_job.assert_called_once_with( - project_id='test-project', - job={ - 'jobId': 'test_prediction', - 'predictionInput': input_with_model - }, - use_existing_job_fn=ANY - ) - self.assertEqual(success_message['predictionOutput'], - prediction_output) - - def test_success_with_version(self): - with patch('airflow.gcp.operators.mlengine.MLEngineHook') \ - as mock_hook: - - input_with_version = self.INPUT_MISSING_ORIGIN.copy() - input_with_version['versionName'] = \ - 'projects/test-project/models/test_model/versions/test_version' - success_message = self.SUCCESS_MESSAGE_MISSING_INPUT.copy() - success_message['predictionInput'] = input_with_version - - hook_instance = mock_hook.return_value - hook_instance.get_job.side_effect = HttpError( - resp=httplib2.Response({ - 'status': 404 - }), content=b'some bytes') - hook_instance.create_job.return_value = success_message - - prediction_task = MLEngineBatchPredictionOperator( - job_id='test_prediction', - project_id='test-project', - region=input_with_version['region'], - data_format=input_with_version['dataFormat'], - input_paths=input_with_version['inputPaths'], - output_path=input_with_version['outputPath'], - model_name=input_with_version['versionName'].split('/')[-3], - version_name=input_with_version['versionName'].split('/')[-1], - dag=self.dag, - task_id='test-prediction') - prediction_output = prediction_task.execute(None) - - mock_hook.assert_called_once_with('google_cloud_default', None) - hook_instance.create_job.assert_called_once_with( - project_id='test-project', - job={ - 'jobId': 'test_prediction', - 'predictionInput': input_with_version - }, - use_existing_job_fn=ANY - ) - self.assertEqual(success_message['predictionOutput'], - prediction_output) - - def test_success_with_uri(self): - with patch('airflow.gcp.operators.mlengine.MLEngineHook') \ - as mock_hook: - - input_with_uri = self.INPUT_MISSING_ORIGIN.copy() - input_with_uri['uri'] = 'gs://my_bucket/my_models/savedModel' - success_message = self.SUCCESS_MESSAGE_MISSING_INPUT.copy() - success_message['predictionInput'] = input_with_uri - - hook_instance = mock_hook.return_value - hook_instance.get_job.side_effect = HttpError( - resp=httplib2.Response({ - 'status': 404 - }), content=b'some bytes') - hook_instance.create_job.return_value = success_message + @patch('airflow.gcp.operators.mlengine.MLEngineHook') + def test_success_with_model(self, mock_hook): + input_with_model = self.INPUT_MISSING_ORIGIN.copy() + input_with_model['modelName'] = \ + 'projects/test-project/models/test_model' + success_message = self.SUCCESS_MESSAGE_MISSING_INPUT.copy() + success_message['predictionInput'] = input_with_model + + hook_instance = mock_hook.return_value + hook_instance.get_job.side_effect = HttpError( + resp=httplib2.Response({ + 'status': 404 + }), content=b'some bytes') + hook_instance.create_job.return_value = success_message + + prediction_task = MLEngineBatchPredictionOperator( + job_id='test_prediction', + project_id='test-project', + region=input_with_model['region'], + data_format=input_with_model['dataFormat'], + input_paths=input_with_model['inputPaths'], + output_path=input_with_model['outputPath'], + model_name=input_with_model['modelName'].split('/')[-1], + dag=self.dag, + task_id='test-prediction') + prediction_output = prediction_task.execute(None) + + mock_hook.assert_called_once_with('google_cloud_default', None) + hook_instance.create_job.assert_called_once_with( + project_id='test-project', + job={ + 'jobId': 'test_prediction', + 'predictionInput': input_with_model + }, + use_existing_job_fn=ANY + ) + self.assertEqual(success_message['predictionOutput'], prediction_output) - prediction_task = MLEngineBatchPredictionOperator( - job_id='test_prediction', - project_id='test-project', - region=input_with_uri['region'], - data_format=input_with_uri['dataFormat'], - input_paths=input_with_uri['inputPaths'], - output_path=input_with_uri['outputPath'], - uri=input_with_uri['uri'], - dag=self.dag, - task_id='test-prediction') - prediction_output = prediction_task.execute(None) + @patch('airflow.gcp.operators.mlengine.MLEngineHook') + def test_success_with_version(self, mock_hook): + input_with_version = self.INPUT_MISSING_ORIGIN.copy() + input_with_version['versionName'] = \ + 'projects/test-project/models/test_model/versions/test_version' + success_message = self.SUCCESS_MESSAGE_MISSING_INPUT.copy() + success_message['predictionInput'] = input_with_version + + hook_instance = mock_hook.return_value + hook_instance.get_job.side_effect = HttpError( + resp=httplib2.Response({ + 'status': 404 + }), content=b'some bytes') + hook_instance.create_job.return_value = success_message + + prediction_task = MLEngineBatchPredictionOperator( + job_id='test_prediction', + project_id='test-project', + region=input_with_version['region'], + data_format=input_with_version['dataFormat'], + input_paths=input_with_version['inputPaths'], + output_path=input_with_version['outputPath'], + model_name=input_with_version['versionName'].split('/')[-3], + version_name=input_with_version['versionName'].split('/')[-1], + dag=self.dag, + task_id='test-prediction') + prediction_output = prediction_task.execute(None) + + mock_hook.assert_called_once_with('google_cloud_default', None) + hook_instance.create_job.assert_called_once_with( + project_id='test-project', + job={ + 'jobId': 'test_prediction', + 'predictionInput': input_with_version + }, + use_existing_job_fn=ANY + ) + self.assertEqual(success_message['predictionOutput'], prediction_output) - mock_hook.assert_called_once_with('google_cloud_default', None) - hook_instance.create_job.assert_called_once_with( - project_id='test-project', - job={ - 'jobId': 'test_prediction', - 'predictionInput': input_with_uri - }, - use_existing_job_fn=ANY - ) - self.assertEqual(success_message['predictionOutput'], - prediction_output) + @patch('airflow.gcp.operators.mlengine.MLEngineHook') + def test_success_with_uri(self, mock_hook): + input_with_uri = self.INPUT_MISSING_ORIGIN.copy() + input_with_uri['uri'] = 'gs://my_bucket/my_models/savedModel' + success_message = self.SUCCESS_MESSAGE_MISSING_INPUT.copy() + success_message['predictionInput'] = input_with_uri + + hook_instance = mock_hook.return_value + hook_instance.get_job.side_effect = HttpError( + resp=httplib2.Response({ + 'status': 404 + }), content=b'some bytes') + hook_instance.create_job.return_value = success_message + + prediction_task = MLEngineBatchPredictionOperator( + job_id='test_prediction', + project_id='test-project', + region=input_with_uri['region'], + data_format=input_with_uri['dataFormat'], + input_paths=input_with_uri['inputPaths'], + output_path=input_with_uri['outputPath'], + uri=input_with_uri['uri'], + dag=self.dag, + task_id='test-prediction') + prediction_output = prediction_task.execute(None) + + mock_hook.assert_called_once_with('google_cloud_default', None) + hook_instance.create_job.assert_called_once_with( + project_id='test-project', + job={ + 'jobId': 'test_prediction', + 'predictionInput': input_with_uri + }, + use_existing_job_fn=ANY + ) + self.assertEqual(success_message['predictionOutput'], prediction_output) def test_invalid_model_origin(self): # Test that both uri and model is given @@ -251,59 +242,56 @@ def test_invalid_model_origin(self): 'model, a model & version combination, or a URI to a savedModel.', str(context.exception)) - def test_http_error(self): + @patch('airflow.gcp.operators.mlengine.MLEngineHook') + def test_http_error(self, mock_hook): http_error_code = 403 + input_with_model = self.INPUT_MISSING_ORIGIN.copy() + input_with_model['modelName'] = \ + 'projects/experimental/models/test_model' + + hook_instance = mock_hook.return_value + hook_instance.create_job.side_effect = HttpError( + resp=httplib2.Response({ + 'status': http_error_code + }), + content=b'Forbidden') + + with self.assertRaises(HttpError) as context: + prediction_task = MLEngineBatchPredictionOperator( + job_id='test_prediction', + project_id='test-project', + region=input_with_model['region'], + data_format=input_with_model['dataFormat'], + input_paths=input_with_model['inputPaths'], + output_path=input_with_model['outputPath'], + model_name=input_with_model['modelName'].split('/')[-1], + dag=self.dag, + task_id='test-prediction') + prediction_task.execute(None) - with patch('airflow.gcp.operators.mlengine.MLEngineHook') \ - as mock_hook: - input_with_model = self.INPUT_MISSING_ORIGIN.copy() - input_with_model['modelName'] = \ - 'projects/experimental/models/test_model' - - hook_instance = mock_hook.return_value - hook_instance.create_job.side_effect = HttpError( - resp=httplib2.Response({ - 'status': http_error_code - }), - content=b'Forbidden') - - with self.assertRaises(HttpError) as context: - prediction_task = MLEngineBatchPredictionOperator( - job_id='test_prediction', - project_id='test-project', - region=input_with_model['region'], - data_format=input_with_model['dataFormat'], - input_paths=input_with_model['inputPaths'], - output_path=input_with_model['outputPath'], - model_name=input_with_model['modelName'].split('/')[-1], - dag=self.dag, - task_id='test-prediction') - prediction_task.execute(None) - - mock_hook.assert_called_once_with('google_cloud_default', None) - hook_instance.create_job.assert_called_once_with( - 'test-project', { - 'jobId': 'test_prediction', - 'predictionInput': input_with_model - }, ANY) - - self.assertEqual(http_error_code, context.exception.resp.status) - - def test_failed_job_error(self): - with patch('airflow.gcp.operators.mlengine.MLEngineHook') \ - as mock_hook: - hook_instance = mock_hook.return_value - hook_instance.create_job.return_value = { - 'state': 'FAILED', - 'errorMessage': 'A failure message' - } - task_args = self.BATCH_PREDICTION_DEFAULT_ARGS.copy() - task_args['uri'] = 'a uri' - - with self.assertRaises(RuntimeError) as context: - MLEngineBatchPredictionOperator(**task_args).execute(None) - - self.assertEqual('A failure message', str(context.exception)) + mock_hook.assert_called_once_with('google_cloud_default', None) + hook_instance.create_job.assert_called_once_with( + 'test-project', { + 'jobId': 'test_prediction', + 'predictionInput': input_with_model + }, ANY) + + self.assertEqual(http_error_code, context.exception.resp.status) + + @patch('airflow.gcp.operators.mlengine.MLEngineHook') + def test_failed_job_error(self, mock_hook): + hook_instance = mock_hook.return_value + hook_instance.create_job.return_value = { + 'state': 'FAILED', + 'errorMessage': 'A failure message' + } + task_args = self.BATCH_PREDICTION_DEFAULT_ARGS.copy() + task_args['uri'] = 'a uri' + + with self.assertRaises(RuntimeError) as context: + MLEngineBatchPredictionOperator(**task_args).execute(None) + + self.assertEqual('A failure message', str(context.exception)) class TestMLEngineTrainingOperator(unittest.TestCase): @@ -328,96 +316,92 @@ class TestMLEngineTrainingOperator(unittest.TestCase): } } - def test_success_create_training_job(self): - with patch('airflow.gcp.operators.mlengine.MLEngineHook') \ - as mock_hook: - success_response = self.TRAINING_INPUT.copy() - success_response['state'] = 'SUCCEEDED' - hook_instance = mock_hook.return_value - hook_instance.create_job.return_value = success_response - - training_op = MLEngineTrainingOperator( - **self.TRAINING_DEFAULT_ARGS) - training_op.execute(None) - - mock_hook.assert_called_once_with( - gcp_conn_id='google_cloud_default', delegate_to=None) - # Make sure only 'create_job' is invoked on hook instance - self.assertEqual(len(hook_instance.mock_calls), 1) - hook_instance.create_job.assert_called_once_with( - project_id='test-project', job=self.TRAINING_INPUT, use_existing_job_fn=ANY) + @patch('airflow.gcp.operators.mlengine.MLEngineHook') + def test_success_create_training_job(self, mock_hook): + success_response = self.TRAINING_INPUT.copy() + success_response['state'] = 'SUCCEEDED' + hook_instance = mock_hook.return_value + hook_instance.create_job.return_value = success_response + + training_op = MLEngineTrainingOperator( + **self.TRAINING_DEFAULT_ARGS) + training_op.execute(None) + + mock_hook.assert_called_once_with( + gcp_conn_id='google_cloud_default', delegate_to=None) + # Make sure only 'create_job' is invoked on hook instance + self.assertEqual(len(hook_instance.mock_calls), 1) + hook_instance.create_job.assert_called_once_with( + project_id='test-project', job=self.TRAINING_INPUT, use_existing_job_fn=ANY) - def test_success_create_training_job_with_optional_args(self): + @patch('airflow.gcp.operators.mlengine.MLEngineHook') + def test_success_create_training_job_with_optional_args(self, mock_hook): training_input = copy.deepcopy(self.TRAINING_INPUT) training_input['trainingInput']['runtimeVersion'] = '1.6' training_input['trainingInput']['pythonVersion'] = '3.5' training_input['trainingInput']['jobDir'] = 'gs://some-bucket/jobs/test_training' - with patch('airflow.gcp.operators.mlengine.MLEngineHook') \ - as mock_hook: - success_response = self.TRAINING_INPUT.copy() - success_response['state'] = 'SUCCEEDED' - hook_instance = mock_hook.return_value - hook_instance.create_job.return_value = success_response + success_response = self.TRAINING_INPUT.copy() + success_response['state'] = 'SUCCEEDED' + hook_instance = mock_hook.return_value + hook_instance.create_job.return_value = success_response + training_op = MLEngineTrainingOperator( + runtime_version='1.6', + python_version='3.5', + job_dir='gs://some-bucket/jobs/test_training', + **self.TRAINING_DEFAULT_ARGS) + training_op.execute(None) + + mock_hook.assert_called_once_with(gcp_conn_id='google_cloud_default', delegate_to=None) + # Make sure only 'create_job' is invoked on hook instance + self.assertEqual(len(hook_instance.mock_calls), 1) + hook_instance.create_job.assert_called_once_with( + project_id='test-project', job=training_input, use_existing_job_fn=ANY) + + @patch('airflow.gcp.operators.mlengine.MLEngineHook') + def test_http_error(self, mock_hook): + http_error_code = 403 + hook_instance = mock_hook.return_value + hook_instance.create_job.side_effect = HttpError( + resp=httplib2.Response({ + 'status': http_error_code + }), + content=b'Forbidden') + + with self.assertRaises(HttpError) as context: training_op = MLEngineTrainingOperator( - runtime_version='1.6', - python_version='3.5', - job_dir='gs://some-bucket/jobs/test_training', **self.TRAINING_DEFAULT_ARGS) training_op.execute(None) - mock_hook.assert_called_once_with(gcp_conn_id='google_cloud_default', delegate_to=None) - # Make sure only 'create_job' is invoked on hook instance - self.assertEqual(len(hook_instance.mock_calls), 1) - hook_instance.create_job.assert_called_once_with( - project_id='test-project', job=training_input, use_existing_job_fn=ANY) + mock_hook.assert_called_once_with( + gcp_conn_id='google_cloud_default', delegate_to=None) + # Make sure only 'create_job' is invoked on hook instance + self.assertEqual(len(hook_instance.mock_calls), 1) + hook_instance.create_job.assert_called_once_with( + project_id='test-project', job=self.TRAINING_INPUT, use_existing_job_fn=ANY) + self.assertEqual(http_error_code, context.exception.resp.status) - def test_http_error(self): - http_error_code = 403 - with patch('airflow.gcp.operators.mlengine.MLEngineHook') \ - as mock_hook: - hook_instance = mock_hook.return_value - hook_instance.create_job.side_effect = HttpError( - resp=httplib2.Response({ - 'status': http_error_code - }), - content=b'Forbidden') - - with self.assertRaises(HttpError) as context: - training_op = MLEngineTrainingOperator( - **self.TRAINING_DEFAULT_ARGS) - training_op.execute(None) - - mock_hook.assert_called_once_with( - gcp_conn_id='google_cloud_default', delegate_to=None) - # Make sure only 'create_job' is invoked on hook instance - self.assertEqual(len(hook_instance.mock_calls), 1) - hook_instance.create_job.assert_called_once_with( - project_id='test-project', job=self.TRAINING_INPUT, use_existing_job_fn=ANY) - self.assertEqual(http_error_code, context.exception.resp.status) - - def test_failed_job_error(self): - with patch('airflow.gcp.operators.mlengine.MLEngineHook') \ - as mock_hook: - failure_response = self.TRAINING_INPUT.copy() - failure_response['state'] = 'FAILED' - failure_response['errorMessage'] = 'A failure message' - hook_instance = mock_hook.return_value - hook_instance.create_job.return_value = failure_response - - with self.assertRaises(RuntimeError) as context: - training_op = MLEngineTrainingOperator( - **self.TRAINING_DEFAULT_ARGS) - training_op.execute(None) - - mock_hook.assert_called_once_with( - gcp_conn_id='google_cloud_default', delegate_to=None) - # Make sure only 'create_job' is invoked on hook instance - self.assertEqual(len(hook_instance.mock_calls), 1) - hook_instance.create_job.assert_called_once_with( - project_id='test-project', job=self.TRAINING_INPUT, use_existing_job_fn=ANY) - self.assertEqual('A failure message', str(context.exception)) + @patch('airflow.gcp.operators.mlengine.MLEngineHook') + def test_failed_job_error(self, mock_hook): + failure_response = self.TRAINING_INPUT.copy() + failure_response['state'] = 'FAILED' + failure_response['errorMessage'] = 'A failure message' + hook_instance = mock_hook.return_value + hook_instance.create_job.return_value = failure_response + + with self.assertRaises(RuntimeError) as context: + training_op = MLEngineTrainingOperator( + **self.TRAINING_DEFAULT_ARGS) + training_op.execute(None) + + mock_hook.assert_called_once_with( + gcp_conn_id='google_cloud_default', delegate_to=None) + # Make sure only 'create_job' is invoked on hook instance + self.assertEqual(len(hook_instance.mock_calls), 1) + hook_instance.create_job.assert_called_once_with( + project_id='test-project', job=self.TRAINING_INPUT, use_existing_job_fn=ANY) + self.assertEqual('A failure message', str(context.exception)) class TestMLEngineModelOperator(unittest.TestCase): @@ -542,23 +526,22 @@ class TestMLEngineVersionOperator(unittest.TestCase): 'task_id': 'test-version' } - def test_success_create_version(self): - with patch('airflow.gcp.operators.mlengine.MLEngineHook') \ - as mock_hook: - success_response = {'name': 'some-name', 'done': True} - hook_instance = mock_hook.return_value - hook_instance.create_version.return_value = success_response - - training_op = MLEngineVersionOperator( - version=TEST_VERSION, - **self.VERSION_DEFAULT_ARGS) - training_op.execute(None) + @patch('airflow.gcp.operators.mlengine.MLEngineHook') + def test_success_create_version(self, mock_hook): + success_response = {'name': 'some-name', 'done': True} + hook_instance = mock_hook.return_value + hook_instance.create_version.return_value = success_response - mock_hook.assert_called_once_with(gcp_conn_id='google_cloud_default', delegate_to=None) - # Make sure only 'create_version' is invoked on hook instance - self.assertEqual(len(hook_instance.mock_calls), 1) - hook_instance.create_version.assert_called_once_with( - project_id='test-project', model_name='test-model', version_spec=TEST_VERSION) + training_op = MLEngineVersionOperator( + version=TEST_VERSION, + **self.VERSION_DEFAULT_ARGS) + training_op.execute(None) + + mock_hook.assert_called_once_with(gcp_conn_id='google_cloud_default', delegate_to=None) + # Make sure only 'create_version' is invoked on hook instance + self.assertEqual(len(hook_instance.mock_calls), 1) + hook_instance.create_version.assert_called_once_with( + project_id='test-project', model_name='test-model', version_spec=TEST_VERSION) class TestMLEngineCreateVersion(unittest.TestCase):