From c6cd3a22ae18227d0ae02775739ace5d19ad041d Mon Sep 17 00:00:00 2001
From: bhavaniravi
Date: Fri, 17 Dec 2021 15:43:13 +0530
Subject: [PATCH 1/2] move emr_container hook

---
 airflow/providers/amazon/aws/hooks/emr.py     | 186 ++++++++++++++++
 .../amazon/aws/hooks/emr_containers.py        | 201 ++----------------
 airflow/providers/amazon/aws/operators/emr.py |  12 +-
 airflow/providers/amazon/aws/sensors/emr.py   |   9 +-
 .../prepare_provider_packages.py              |  11 +-
 .../operators/emr.rst                         |   2 +
 tests/deprecated_classes.py                   |   4 +
 .../amazon/aws/hooks/test_emr_containers.py   |   6 +-
 .../aws/operators/test_emr_containers.py      |  18 +-
 .../amazon/aws/sensors/test_emr_containers.py |  16 +-
 10 files changed, 247 insertions(+), 218 deletions(-)

diff --git a/airflow/providers/amazon/aws/hooks/emr.py b/airflow/providers/amazon/aws/hooks/emr.py
index 1268a426686f..48833f2340fd 100644
--- a/airflow/providers/amazon/aws/hooks/emr.py
+++ b/airflow/providers/amazon/aws/hooks/emr.py
@@ -15,8 +15,11 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
+from time import sleep
 from typing import Any, Dict, List, Optional
 
+from botocore.exceptions import ClientError
+
 from airflow.exceptions import AirflowException
 from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
 
@@ -88,3 +91,186 @@ def create_job_flow(self, job_flow_overrides: Dict[str, Any]) -> Dict[str, Any]:
 
         response = self.get_conn().run_job_flow(**config)
         return response
+
+
+class EmrContainerHook(AwsBaseHook):
+    """
+    Interact with AWS EMR Virtual Cluster to run, poll jobs and return job status
+    Additional arguments (such as ``aws_conn_id``) may be specified and
+    are passed down to the underlying AwsBaseHook.
+
+    .. seealso::
+        :class:`~airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook`
+
+    :param virtual_cluster_id: Cluster ID of the EMR on EKS virtual cluster
+    :type virtual_cluster_id: str
+    """
+
+    INTERMEDIATE_STATES = (
+        "PENDING",
+        "SUBMITTED",
+        "RUNNING",
+    )
+    FAILURE_STATES = (
+        "FAILED",
+        "CANCELLED",
+        "CANCEL_PENDING",
+    )
+    SUCCESS_STATES = ("COMPLETED",)
+
+    def __init__(self, *args: Any, virtual_cluster_id: str = None, **kwargs: Any) -> None:
+        super().__init__(client_type="emr-containers", *args, **kwargs)  # type: ignore
+        self.virtual_cluster_id = virtual_cluster_id
+
+    def submit_job(
+        self,
+        name: str,
+        execution_role_arn: str,
+        release_label: str,
+        job_driver: dict,
+        configuration_overrides: Optional[dict] = None,
+        client_request_token: Optional[str] = None,
+    ) -> str:
+        """
+        Submit a job to the EMR Containers API and return the job ID.
+        A job run is a unit of work, such as a Spark jar, PySpark script,
+        or SparkSQL query, that you submit to Amazon EMR on EKS.
+        See: https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/emr-containers.html#EMRContainers.Client.start_job_run  # noqa: E501
+
+        :param name: The name of the job run.
+        :type name: str
+        :param execution_role_arn: The IAM role ARN associated with the job run.
+        :type execution_role_arn: str
+        :param release_label: The Amazon EMR release version to use for the job run.
+        :type release_label: str
+        :param job_driver: Job configuration details, e.g. the Spark job parameters.
+        :type job_driver: dict
+        :param configuration_overrides: The configuration overrides for the job run,
+            specifically either application configuration or monitoring configuration.
+        :type configuration_overrides: dict
+        :param client_request_token: The client idempotency token of the job run request.
+            Use this if you want to specify a unique ID to prevent two jobs from getting started.
+        :type client_request_token: str
+        :return: Job ID
+        """
+        params = {
+            "name": name,
+            "virtualClusterId": self.virtual_cluster_id,
+            "executionRoleArn": execution_role_arn,
+            "releaseLabel": release_label,
+            "jobDriver": job_driver,
+            "configurationOverrides": configuration_overrides or {},
+        }
+        if client_request_token:
+            params["clientToken"] = client_request_token
+
+        response = self.conn.start_job_run(**params)
+
+        if response['ResponseMetadata']['HTTPStatusCode'] != 200:
+            raise AirflowException(f'Start Job Run failed: {response}')
+        else:
+            self.log.info(
+                "Start Job Run success - Job Id %s and virtual cluster id %s",
+                response['id'],
+                response['virtualClusterId'],
+            )
+            return response['id']
+
+    def get_job_failure_reason(self, job_id: str) -> Optional[str]:
+        """
+        Fetch the reason for a job failure (e.g. error message). Returns None or reason string.
+
+        :param job_id: Id of submitted job run
+        :type job_id: str
+        :return: str
+        """
+        # We absorb any errors if we can't retrieve the job status
+        reason = None
+
+        try:
+            response = self.conn.describe_job_run(
+                virtualClusterId=self.virtual_cluster_id,
+                id=job_id,
+            )
+            failure_reason = response['jobRun']['failureReason']
+            state_details = response["jobRun"]["stateDetails"]
+            reason = f"{failure_reason} - {state_details}"
+        except KeyError:
+            self.log.error('Could not get status of the EMR on EKS job')
+        except ClientError as ex:
+            self.log.error('AWS request failed, check logs for more info: %s', ex)
+
+        return reason
+
+    def check_query_status(self, job_id: str) -> Optional[str]:
+        """
+        Fetch the status of submitted job run. Returns None or one of valid query states.
+        See: https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/emr-containers.html#EMRContainers.Client.describe_job_run  # noqa: E501
+        :param job_id: Id of submitted job run
+        :type job_id: str
+        :return: str
+        """
+        try:
+            response = self.conn.describe_job_run(
+                virtualClusterId=self.virtual_cluster_id,
+                id=job_id,
+            )
+            return response["jobRun"]["state"]
+        except self.conn.exceptions.ResourceNotFoundException:
+            # If the job is not found, we raise an exception as something fatal has happened.
+            raise AirflowException(f'Job ID {job_id} not found on Virtual Cluster {self.virtual_cluster_id}')
+        except ClientError as ex:
+            # If we receive a generic ClientError, we swallow the exception so that the caller can retry.
+            self.log.error('AWS request failed, check logs for more info: %s', ex)
+            return None
+
+    def poll_query_status(
+        self, job_id: str, max_tries: Optional[int] = None, poll_interval: int = 30
+    ) -> Optional[str]:
+        """
+        Poll the status of submitted job run until query state reaches final state.
+        Returns one of the final states.
+
+        :param job_id: Id of submitted job run
+        :type job_id: str
+        :param max_tries: Number of times to poll for query state before function exits
+        :type max_tries: int
+        :param poll_interval: Time (in seconds) to wait between calls to check query status on EMR
+        :type poll_interval: int
+        :return: str
+        """
+        try_number = 1
+        final_query_state = None  # Query state when query reaches final state or max_tries reached
+
+        # TODO: Make this logic a little bit more robust.
+        # Currently this polls until the state is *not* one of the INTERMEDIATE_STATES
+        # While that should work in most cases...it might not. :)
+        while True:
+            query_state = self.check_query_status(job_id)
+            if query_state is None:
+                self.log.info("Try %s: Invalid query state. Retrying again", try_number)
+            elif query_state in self.INTERMEDIATE_STATES:
+                self.log.info("Try %s: Query is still in an intermediate state - %s", try_number, query_state)
+            else:
+                self.log.info("Try %s: Query execution completed. Final state is %s", try_number, query_state)
+                final_query_state = query_state
+                break
+            if max_tries and try_number >= max_tries:  # Break loop if max_tries reached
+                final_query_state = query_state
+                break
+            try_number += 1
+            sleep(poll_interval)
+        return final_query_state
+
+    def stop_query(self, job_id: str) -> Dict:
+        """
+        Cancel the submitted job_run
+
+        :param job_id: Id of submitted job_run
+        :type job_id: str
+        :return: dict
+        """
+        return self.conn.cancel_job_run(
+            virtualClusterId=self.virtual_cluster_id,
+            id=job_id,
+        )
diff --git a/airflow/providers/amazon/aws/hooks/emr_containers.py b/airflow/providers/amazon/aws/hooks/emr_containers.py
index b340beac9c48..1e3b7a0ea4ae 100644
--- a/airflow/providers/amazon/aws/hooks/emr_containers.py
+++ b/airflow/providers/amazon/aws/hooks/emr_containers.py
@@ -15,193 +15,30 @@
 # specific language governing permissions and limitations
 # under the License.
 
-from time import sleep
-from typing import Any, Dict, Optional
+"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.hooks.emr`."""
 
-from botocore.exceptions import ClientError
+import warnings
 
-from airflow.exceptions import AirflowException
-from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
+from airflow.providers.amazon.aws.hooks.emr import EmrContainerHook
 
+warnings.warn(
+    "This module is deprecated. Please use `airflow.providers.amazon.aws.hooks.emr`.",
+    DeprecationWarning,
+    stacklevel=2,
+)
 
-class EMRContainerHook(AwsBaseHook):
-    """
-    Interact with AWS EMR Virtual Cluster to run, poll jobs and return job status
-    Additional arguments (such as ``aws_conn_id``) may be specified and
-    are passed down to the underlying AwsBaseHook.
-
-    .. seealso::
-        :class:`~airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook`
 
-    :param virtual_cluster_id: Cluster ID of the EMR on EKS virtual cluster
-    :type virtual_cluster_id: str
+class EMRContainerHook(EmrContainerHook):
+    """
+    This class is deprecated.
+    Please use :class:`airflow.providers.amazon.aws.hooks.emr.EmrContainerHook`.
     """
 
-    INTERMEDIATE_STATES = (
-        "PENDING",
-        "SUBMITTED",
-        "RUNNING",
-    )
-    FAILURE_STATES = (
-        "FAILED",
-        "CANCELLED",
-        "CANCEL_PENDING",
-    )
-    SUCCESS_STATES = ("COMPLETED",)
-
-    def __init__(self, *args: Any, virtual_cluster_id: str = None, **kwargs: Any) -> None:
-        super().__init__(client_type="emr-containers", *args, **kwargs)  # type: ignore
-        self.virtual_cluster_id = virtual_cluster_id
-
-    def submit_job(
-        self,
-        name: str,
-        execution_role_arn: str,
-        release_label: str,
-        job_driver: dict,
-        configuration_overrides: Optional[dict] = None,
-        client_request_token: Optional[str] = None,
-    ) -> str:
-        """
-        Submit a job to the EMR Containers API and and return the job ID.
-        A job run is a unit of work, such as a Spark jar, PySpark script,
-        or SparkSQL query, that you submit to Amazon EMR on EKS.
-        See: https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/emr-containers.html#EMRContainers.Client.start_job_run  # noqa: E501
-
-        :param name: The name of the job run.
-        :type name: str
-        :param execution_role_arn: The IAM role ARN associated with the job run.
-        :type execution_role_arn: str
-        :param release_label: The Amazon EMR release version to use for the job run.
-        :type release_label: str
-        :param job_driver: Job configuration details, e.g. the Spark job parameters.
-        :type job_driver: dict
-        :param configuration_overrides: The configuration overrides for the job run,
-            specifically either application configuration or monitoring configuration.
-        :type configuration_overrides: dict
-        :param client_request_token: The client idempotency token of the job run request.
-            Use this if you want to specify a unique ID to prevent two jobs from getting started.
-        :type client_request_token: str
-        :return: Job ID
-        """
-        params = {
-            "name": name,
-            "virtualClusterId": self.virtual_cluster_id,
-            "executionRoleArn": execution_role_arn,
-            "releaseLabel": release_label,
-            "jobDriver": job_driver,
-            "configurationOverrides": configuration_overrides or {},
-        }
-        if client_request_token:
-            params["clientToken"] = client_request_token
-
-        response = self.conn.start_job_run(**params)
-
-        if response['ResponseMetadata']['HTTPStatusCode'] != 200:
-            raise AirflowException(f'Start Job Run failed: {response}')
-        else:
-            self.log.info(
-                "Start Job Run success - Job Id %s and virtual cluster id %s",
-                response['id'],
-                response['virtualClusterId'],
-            )
-            return response['id']
-
-    def get_job_failure_reason(self, job_id: str) -> Optional[str]:
-        """
-        Fetch the reason for a job failure (e.g. error message). Returns None or reason string.
-
-        :param job_id: Id of submitted job run
-        :type job_id: str
-        :return: str
-        """
-        # We absorb any errors if we can't retrieve the job status
-        reason = None
-
-        try:
-            response = self.conn.describe_job_run(
-                virtualClusterId=self.virtual_cluster_id,
-                id=job_id,
-            )
-            failure_reason = response['jobRun']['failureReason']
-            state_details = response["jobRun"]["stateDetails"]
-            reason = f"{failure_reason} - {state_details}"
-        except KeyError:
-            self.log.error('Could not get status of the EMR on EKS job')
-        except ClientError as ex:
-            self.log.error('AWS request failed, check logs for more info: %s', ex)
-
-        return reason
-
-    def check_query_status(self, job_id: str) -> Optional[str]:
-        """
-        Fetch the status of submitted job run. Returns None or one of valid query states.
-        See: https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/emr-containers.html#EMRContainers.Client.describe_job_run  # noqa: E501
-        :param job_id: Id of submitted job run
-        :type job_id: str
-        :return: str
-        """
-        try:
-            response = self.conn.describe_job_run(
-                virtualClusterId=self.virtual_cluster_id,
-                id=job_id,
-            )
-            return response["jobRun"]["state"]
-        except self.conn.exceptions.ResourceNotFoundException:
-            # If the job is not found, we raise an exception as something fatal has happened.
-            raise AirflowException(f'Job ID {job_id} not found on Virtual Cluster {self.virtual_cluster_id}')
-        except ClientError as ex:
-            # If we receive a generic ClientError, we swallow the exception so that the
-            self.log.error('AWS request failed, check logs for more info: %s', ex)
-            return None
-
-    def poll_query_status(
-        self, job_id: str, max_tries: Optional[int] = None, poll_interval: int = 30
-    ) -> Optional[str]:
-        """
-        Poll the status of submitted job run until query state reaches final state.
-        Returns one of the final states.
-
-        :param job_id: Id of submitted job run
-        :type job_id: str
-        :param max_tries: Number of times to poll for query state before function exits
-        :type max_tries: int
-        :param poll_interval: Time (in seconds) to wait between calls to check query status on EMR
-        :type poll_interval: int
-        :return: str
-        """
-        try_number = 1
-        final_query_state = None  # Query state when query reaches final state or max_tries reached
-
-        # TODO: Make this logic a little bit more robust.
-        # Currently this polls until the state is *not* one of the INTERMEDIATE_STATES
-        # While that should work in most cases...it might not. :)
-        while True:
-            query_state = self.check_query_status(job_id)
-            if query_state is None:
-                self.log.info("Try %s: Invalid query state. Retrying again", try_number)
-            elif query_state in self.INTERMEDIATE_STATES:
-                self.log.info("Try %s: Query is still in an intermediate state - %s", try_number, query_state)
-            else:
-                self.log.info("Try %s: Query execution completed. Final state is %s", try_number, query_state)
-                final_query_state = query_state
-                break
-            if max_tries and try_number >= max_tries:  # Break loop if max_tries reached
-                final_query_state = query_state
-                break
-            try_number += 1
-            sleep(poll_interval)
-        return final_query_state
-
-    def stop_query(self, job_id: str) -> Dict:
-        """
-        Cancel the submitted job_run
-
-        :param job_id: Id of submitted job_run
-        :type job_id: str
-        :return: dict
-        """
-        return self.conn.cancel_job_run(
-            virtualClusterId=self.virtual_cluster_id,
-            id=job_id,
+    def __init__(self, **kwargs):
+        warnings.warn(
+            """This class is deprecated.
+            Please use `airflow.providers.amazon.aws.hooks.emr.EmrContainerHook`.""",
+            DeprecationWarning,
+            stacklevel=2,
         )
+        super().__init__(**kwargs)
diff --git a/airflow/providers/amazon/aws/operators/emr.py b/airflow/providers/amazon/aws/operators/emr.py
index 407dbd82d903..9d5842c3a53e 100644
--- a/airflow/providers/amazon/aws/operators/emr.py
+++ b/airflow/providers/amazon/aws/operators/emr.py
@@ -30,7 +30,7 @@
 else:
     from cached_property import cached_property
 
-from airflow.providers.amazon.aws.hooks.emr_containers import EMRContainerHook
+from airflow.providers.amazon.aws.hooks.emr import EmrContainerHook
 
 
 class EmrAddStepsOperator(BaseOperator):
@@ -177,9 +177,9 @@ def __init__(
         self.job_id = None
 
     @cached_property
-    def hook(self) -> EMRContainerHook:
-        """Create and return an EMRContainerHook."""
-        return EMRContainerHook(
+    def hook(self) -> EmrContainerHook:
+        """Create and return an EmrContainerHook."""
+        return EmrContainerHook(
             self.aws_conn_id,
             virtual_cluster_id=self.virtual_cluster_id,
         )
@@ -196,13 +196,13 @@ def execute(self, context: dict) -> Optional[str]:
         )
         query_status = self.hook.poll_query_status(self.job_id, self.max_tries, self.poll_interval)
 
-        if query_status in EMRContainerHook.FAILURE_STATES:
+        if query_status in EmrContainerHook.FAILURE_STATES:
             error_message = self.hook.get_job_failure_reason(self.job_id)
             raise AirflowException(
                 f"EMR Containers job failed. Final state is {query_status}. "
                 f"query_execution_id is {self.job_id}. Error: {error_message}"
             )
-        elif not query_status or query_status in EMRContainerHook.INTERMEDIATE_STATES:
+        elif not query_status or query_status in EmrContainerHook.INTERMEDIATE_STATES:
             raise AirflowException(
                 f"Final state of EMR Containers job is {query_status}. "
                 f"Max tries of poll status exceeded, query_execution_id is {self.job_id}."
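For reviewers, a minimal sketch of driving the relocated hook end to end, using only the methods added in the hunks above. The connection ID, virtual cluster ID, role ARN, S3 path, and the sparkSubmitJobDriver payload shape are illustrative assumptions (the payload follows the boto3 start_job_run reference linked in the docstring), not part of this patch:

    from airflow.providers.amazon.aws.hooks.emr import EmrContainerHook

    # Placeholder identifiers -- substitute real values for your account.
    hook = EmrContainerHook(aws_conn_id='aws_default', virtual_cluster_id='vc1234')
    job_id = hook.submit_job(
        name='example-pi-job',
        execution_role_arn='arn:aws:iam::123456789012:role/emr-eks-example',
        release_label='emr-6.3.0-latest',
        job_driver={
            'sparkSubmitJobDriver': {
                'entryPoint': 's3://example-bucket/scripts/pi.py',
                'sparkSubmitParameters': '--conf spark.executor.instances=2',
            }
        },
    )
    # Block until the job leaves INTERMEDIATE_STATES, or give up after 10 polls.
    final_state = hook.poll_query_status(job_id, max_tries=10, poll_interval=30)
    if final_state in EmrContainerHook.FAILURE_STATES:
        print(hook.get_job_failure_reason(job_id))

This is exactly the sequence EmrContainerOperator.execute() performs in the next hunk: submit, poll, and surface the failure reason on a terminal failure state.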
diff --git a/airflow/providers/amazon/aws/sensors/emr.py b/airflow/providers/amazon/aws/sensors/emr.py
index dce7733a026d..31ceba62e974 100644
--- a/airflow/providers/amazon/aws/sensors/emr.py
+++ b/airflow/providers/amazon/aws/sensors/emr.py
@@ -24,8 +24,7 @@
     from cached_property import cached_property
 
 from airflow.exceptions import AirflowException
-from airflow.providers.amazon.aws.hooks.emr import EmrHook
-from airflow.providers.amazon.aws.hooks.emr_containers import EMRContainerHook
+from airflow.providers.amazon.aws.hooks.emr import EmrContainerHook, EmrHook
 from airflow.sensors.base import BaseSensorOperator
 
 
@@ -178,9 +177,9 @@ def poke(self, context: dict) -> bool:
         return True
 
     @cached_property
-    def hook(self) -> EMRContainerHook:
-        """Create and return an EMRContainerHook"""
-        return EMRContainerHook(self.aws_conn_id, virtual_cluster_id=self.virtual_cluster_id)
+    def hook(self) -> EmrContainerHook:
+        """Create and return an EmrContainerHook"""
+        return EmrContainerHook(self.aws_conn_id, virtual_cluster_id=self.virtual_cluster_id)
 
 
 class EmrJobFlowSensor(EmrBaseSensor):
diff --git a/dev/provider_packages/prepare_provider_packages.py b/dev/provider_packages/prepare_provider_packages.py
index 1e1e2c6172b0..c1539278bddb 100755
--- a/dev/provider_packages/prepare_provider_packages.py
+++ b/dev/provider_packages/prepare_provider_packages.py
@@ -2156,11 +2156,12 @@ def summarise_total_vs_bad_and_warnings(total: int, bad: int, warns: List[warnin
     "This module is deprecated. Please use `airflow.providers.amazon.aws.sensors.dms`.",
     'This module is deprecated. Please use `airflow.providers.amazon.aws.operators.emr`.',
     'This module is deprecated. Please use `airflow.providers.amazon.aws.sensors.emr`.',
-    "This module is deprecated. Please use `airflow.providers.amazon.aws.hooks.redshift_cluster` "
-    "or `airflow.providers.amazon.aws.hooks.redshift_sql` as appropriate.",
-    "This module is deprecated. Please use `airflow.providers.amazon.aws.operators.redshift_sql` "
-    "or `airflow.providers.amazon.aws.operators.redshift_cluster` as appropriate.",
-    "This module is deprecated. Please use `airflow.providers.amazon.aws.sensors.redshift_cluster`.",
+    'This module is deprecated. Please use `airflow.providers.amazon.aws.hooks.emr`.',
+    'This module is deprecated. Please use `airflow.hooks.redshift_sql` '
+    'or `airflow.hooks.redshift_cluster` as appropriate.',
+    'This module is deprecated. Please use `airflow.operators.redshift_sql` or '
+    '`airflow.operators.redshift_cluster` as appropriate.',
+    'This module is deprecated. Please use `airflow.sensors.redshift_cluster`.',
 }
diff --git a/docs/apache-airflow-providers-amazon/operators/emr.rst b/docs/apache-airflow-providers-amazon/operators/emr.rst
index 4bf4d3a0a25e..817bf1833921 100644
--- a/docs/apache-airflow-providers-amazon/operators/emr.rst
+++ b/docs/apache-airflow-providers-amazon/operators/emr.rst
@@ -33,10 +33,12 @@ Airflow to AWS EMR integration provides several operators to create and interac
 
  - :class:`~airflow.providers.amazon.aws.sensors.emr.EmrJobFlowSensor`
  - :class:`~airflow.providers.amazon.aws.sensors.emr.EmrStepSensor`
+ - :class:`~airflow.providers.amazon.aws.sensors.emr.EmrContainerSensor`
  - :class:`~airflow.providers.amazon.aws.operators.emr.EmrCreateJobFlowOperator`
  - :class:`~airflow.providers.amazon.aws.operators.emr.EmrAddStepsOperator`
  - :class:`~airflow.providers.amazon.aws.operators.emr.EmrModifyClusterOperator`
 - :class:`~airflow.providers.amazon.aws.operators.emr.EmrTerminateJobFlowOperator`
+ - :class:`~airflow.providers.amazon.aws.operators.emr.EmrContainerOperator`
 
 Two example_dags are provided which showcase these operators in action.
 
diff --git a/tests/deprecated_classes.py b/tests/deprecated_classes.py
index e258c4d5bbcb..11edfddd0fa7 100644
--- a/tests/deprecated_classes.py
+++ b/tests/deprecated_classes.py
@@ -251,6 +251,10 @@
         'airflow.providers.amazon.aws.hooks.logs.AwsLogsHook',
         'airflow.contrib.hooks.aws_logs_hook.AwsLogsHook',
     ),
+    (
+        'airflow.providers.amazon.aws.hooks.emr.EmrContainerHook',
+        'airflow.providers.amazon.aws.hooks.emr_containers.EMRContainerHook',
+    ),
     (
         'airflow.providers.amazon.aws.hooks.emr.EmrHook',
         'airflow.contrib.hooks.emr_hook.EmrHook',
diff --git a/tests/providers/amazon/aws/hooks/test_emr_containers.py b/tests/providers/amazon/aws/hooks/test_emr_containers.py
index 8b8db5d330d5..a755a662b437 100644
--- a/tests/providers/amazon/aws/hooks/test_emr_containers.py
+++ b/tests/providers/amazon/aws/hooks/test_emr_containers.py
@@ -20,7 +20,7 @@
 import unittest
 from unittest import mock
 
-from airflow.providers.amazon.aws.hooks.emr_containers import EMRContainerHook
+from airflow.providers.amazon.aws.hooks.emr import EmrContainerHook
 
 SUBMIT_JOB_SUCCESS_RETURN = {
     'ResponseMetadata': {'HTTPStatusCode': 200},
@@ -29,9 +29,9 @@
 }
 
 
-class TestEMRContainerHook(unittest.TestCase):
+class TestEmrContainerHook(unittest.TestCase):
     def setUp(self):
-        self.emr_containers = EMRContainerHook(virtual_cluster_id='vc1234')
+        self.emr_containers = EmrContainerHook(virtual_cluster_id='vc1234')
 
     def test_init(self):
         assert self.emr_containers.aws_conn_id == 'aws_default'
diff --git a/tests/providers/amazon/aws/operators/test_emr_containers.py b/tests/providers/amazon/aws/operators/test_emr_containers.py
index 17b67cd40aea..24d7eb6f828e 100644
--- a/tests/providers/amazon/aws/operators/test_emr_containers.py
+++ b/tests/providers/amazon/aws/operators/test_emr_containers.py
@@ -23,7 +23,7 @@
 
 from airflow import configuration
 from airflow.exceptions import AirflowException
-from airflow.providers.amazon.aws.hooks.emr_containers import EMRContainerHook
+from airflow.providers.amazon.aws.hooks.emr import EmrContainerHook
 from airflow.providers.amazon.aws.operators.emr import EmrContainerOperator
 
 SUBMIT_JOB_SUCCESS_RETURN = {
@@ -36,7 +36,7 @@
 
 
 class TestEmrContainerOperator(unittest.TestCase):
-    @mock.patch('airflow.providers.amazon.aws.hooks.emr_containers.EMRContainerHook')
+    @mock.patch('airflow.providers.amazon.aws.hooks.emr.EmrContainerHook')
     def setUp(self, emr_hook_mock):
         configuration.load_test_config()
 
@@ -53,8 +53,8 @@ def setUp(self, emr_hook_mock):
             client_request_token=GENERATED_UUID,
         )
 
-    @mock.patch.object(EMRContainerHook, 'submit_job')
-    @mock.patch.object(EMRContainerHook, 'check_query_status')
+    @mock.patch.object(EmrContainerHook, 'submit_job')
+    @mock.patch.object(EmrContainerHook, 'check_query_status')
     def test_execute_without_failure(
         self,
         mock_check_query_status,
@@ -72,7 +72,7 @@ def test_execute_without_failure(
         assert self.emr_container.release_label == '6.3.0-latest'
 
     @mock.patch.object(
-        EMRContainerHook,
+        EmrContainerHook,
         'check_query_status',
         side_effect=['PENDING', 'PENDING', 'SUBMITTED', 'RUNNING', 'COMPLETED'],
     )
@@ -88,9 +88,9 @@ def test_execute_with_polling(self, mock_check_query_status):
         assert self.emr_container.execute(None) == 'job123456'
         assert mock_check_query_status.call_count == 5
 
-    @mock.patch.object(EMRContainerHook, 'submit_job')
-    @mock.patch.object(EMRContainerHook, 'check_query_status')
-    @mock.patch.object(EMRContainerHook, 'get_job_failure_reason')
+    @mock.patch.object(EmrContainerHook, 'submit_job')
+    @mock.patch.object(EmrContainerHook, 'check_query_status')
+    @mock.patch.object(EmrContainerHook, 'get_job_failure_reason')
     def test_execute_with_failure(
         self, mock_get_job_failure_reason, mock_check_query_status, mock_submit_job
     ):
@@ -105,7 +105,7 @@ def test_execute_with_failure(
         assert 'Error: CLUSTER_UNAVAILABLE - Cluster EKS eks123456 does not exist.' in str(ctx.value)
 
     @mock.patch.object(
-        EMRContainerHook,
+        EmrContainerHook,
         'check_query_status',
         side_effect=['PENDING', 'PENDING', 'SUBMITTED', 'RUNNING', 'COMPLETED'],
    )
diff --git a/tests/providers/amazon/aws/sensors/test_emr_containers.py b/tests/providers/amazon/aws/sensors/test_emr_containers.py
index 6ab45c50ec64..4a6dba835c5d 100644
--- a/tests/providers/amazon/aws/sensors/test_emr_containers.py
+++ b/tests/providers/amazon/aws/sensors/test_emr_containers.py
@@ -22,7 +22,7 @@
 import pytest
 
 from airflow.exceptions import AirflowException
-from airflow.providers.amazon.aws.hooks.emr_containers import EMRContainerHook
+from airflow.providers.amazon.aws.hooks.emr import EmrContainerHook
 from airflow.providers.amazon.aws.sensors.emr import EmrContainerSensor
 
 
@@ -37,35 +37,35 @@
             aws_conn_id='aws_default',
         )
 
-    @mock.patch.object(EMRContainerHook, 'check_query_status', side_effect=("PENDING",))
+    @mock.patch.object(EmrContainerHook, 'check_query_status', side_effect=("PENDING",))
     def test_poke_pending(self, mock_check_query_status):
         assert not self.sensor.poke(None)
 
-    @mock.patch.object(EMRContainerHook, 'check_query_status', side_effect=("SUBMITTED",))
+    @mock.patch.object(EmrContainerHook, 'check_query_status', side_effect=("SUBMITTED",))
     def test_poke_submitted(self, mock_check_query_status):
         assert not self.sensor.poke(None)
 
-    @mock.patch.object(EMRContainerHook, 'check_query_status', side_effect=("RUNNING",))
+    @mock.patch.object(EmrContainerHook, 'check_query_status', side_effect=("RUNNING",))
     def test_poke_running(self, mock_check_query_status):
         assert not self.sensor.poke(None)
 
-    @mock.patch.object(EMRContainerHook, 'check_query_status', side_effect=("COMPLETED",))
+    @mock.patch.object(EmrContainerHook, 'check_query_status', side_effect=("COMPLETED",))
     def test_poke_completed(self, mock_check_query_status):
         assert self.sensor.poke(None)
 
-    @mock.patch.object(EMRContainerHook, 'check_query_status', side_effect=("FAILED",))
+    @mock.patch.object(EmrContainerHook, 'check_query_status', side_effect=("FAILED",))
     def test_poke_failed(self, mock_check_query_status):
         with pytest.raises(AirflowException) as ctx:
             self.sensor.poke(None)
         assert 'EMR Containers sensor failed' in str(ctx.value)
 
-    @mock.patch.object(EMRContainerHook, 'check_query_status', side_effect=("CANCELLED",))
+    @mock.patch.object(EmrContainerHook, 'check_query_status', side_effect=("CANCELLED",))
     def test_poke_cancelled(self, mock_check_query_status):
         with pytest.raises(AirflowException) as ctx:
             self.sensor.poke(None)
         assert 'EMR Containers sensor failed' in str(ctx.value)
 
-    @mock.patch.object(EMRContainerHook, 'check_query_status', side_effect=("CANCEL_PENDING",))
+    @mock.patch.object(EmrContainerHook, 'check_query_status', side_effect=("CANCEL_PENDING",))
     def test_poke_cancel_pending(self, mock_check_query_status):
         with pytest.raises(AirflowException) as ctx:
             self.sensor.poke(None)

From f5d76c30b44f8cfd5e64b9ddb9950a061b9f2c48 Mon Sep 17 00:00:00 2001
From: bhavaniravi
Date: Tue, 21 Dec 2021 14:01:15 +0530
Subject: [PATCH 2/2] remove redshift from provider packages

---
 dev/provider_packages/prepare_provider_packages.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/dev/provider_packages/prepare_provider_packages.py b/dev/provider_packages/prepare_provider_packages.py
index c1539278bddb..8c1204dd37c2 100755
--- a/dev/provider_packages/prepare_provider_packages.py
+++ b/dev/provider_packages/prepare_provider_packages.py
@@ -2156,12 +2156,12 @@ def summarise_total_vs_bad_and_warnings(total: int, bad: int, warns: List[warnin
     "This module is deprecated. Please use `airflow.providers.amazon.aws.sensors.dms`.",
     'This module is deprecated. Please use `airflow.providers.amazon.aws.operators.emr`.',
     'This module is deprecated. Please use `airflow.providers.amazon.aws.sensors.emr`.',
+    "This module is deprecated. Please use `airflow.providers.amazon.aws.hooks.redshift_cluster` "
+    "or `airflow.providers.amazon.aws.hooks.redshift_sql` as appropriate.",
+    "This module is deprecated. Please use `airflow.providers.amazon.aws.operators.redshift_sql` "
+    "or `airflow.providers.amazon.aws.operators.redshift_cluster` as appropriate.",
+    "This module is deprecated. Please use `airflow.providers.amazon.aws.sensors.redshift_cluster`.",
     'This module is deprecated. Please use `airflow.providers.amazon.aws.hooks.emr`.',
-    'This module is deprecated. Please use `airflow.hooks.redshift_sql` '
-    'or `airflow.hooks.redshift_cluster` as appropriate.',
-    'This module is deprecated. Please use `airflow.operators.redshift_sql` or '
-    '`airflow.operators.redshift_cluster` as appropriate.',
-    'This module is deprecated. Please use `airflow.sensors.redshift_cluster`.',
 }
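Taken together, the two patches keep the old import path working while steering users toward the new one. A sketch of the expected deprecation behavior, assuming a fresh interpreter (the module-level warning only fires on the first import of the shim); the cluster ID is a placeholder:

    import warnings

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter('always')
        # Old path: resolves via the shim in emr_containers.py and warns
        # once at import time and again in __init__.
        from airflow.providers.amazon.aws.hooks.emr_containers import EMRContainerHook
        hook = EMRContainerHook(virtual_cluster_id='vc1234')

    assert any(issubclass(w.category, DeprecationWarning) for w in caught)

    # New, preferred path: the class the shim now subclasses.
    from airflow.providers.amazon.aws.hooks.emr import EmrContainerHook
    assert issubclass(EMRContainerHook, EmrContainerHook)
    assert isinstance(hook, EmrContainerHook)

Note that the shim's __init__ accepts keyword arguments only (`**kwargs`), so callers constructing the old class positionally would need to switch to keywords or migrate to the new import.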