109 changes: 109 additions & 0 deletions airflow/providers/amazon/aws/example_dags/example_emr_change_policy.py
@@ -0,0 +1,109 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import os
from datetime import datetime

from airflow import DAG
from airflow.models.baseoperator import chain
from airflow.providers.amazon.aws.operators.emr import (
    EmrAddStepsOperator,
    EmrChangePolicyOperator,
    EmrCreateJobFlowOperator,
)
from airflow.providers.amazon.aws.sensors.emr import EmrStepSensor

JOB_FLOW_ROLE = os.getenv('EMR_JOB_FLOW_ROLE', 'EMR_EC2_DefaultRole')
SERVICE_ROLE = os.getenv('EMR_SERVICE_ROLE', 'EMR_DefaultRole')

SPARK_STEPS = [
    {
        'Name': 'calculate_pi',
        'ActionOnFailure': 'CONTINUE',
        'HadoopJarStep': {
            'Jar': 'command-runner.jar',
            'Args': ['/usr/lib/spark/bin/run-example', 'SparkPi', '10'],
        },
    }
]

JOB_FLOW_OVERRIDES = {
    'Name': 'PiCalc',
    'ReleaseLabel': 'emr-6.4.0',
    'Applications': [{'Name': 'Spark'}],
    'Instances': {
        'InstanceGroups': [
            {
                'Name': 'Primary node',
                'Market': 'ON_DEMAND',
                'InstanceRole': 'MASTER',
                'InstanceType': 'm5.xlarge',
                'InstanceCount': 1,
            },
        ],
        'KeepJobFlowAliveWhenNoSteps': True,
        'TerminationProtected': False,
    },
    'JobFlowRole': JOB_FLOW_ROLE,
    'ServiceRole': SERVICE_ROLE,
}


with DAG(
    dag_id='example_emr_change_policy',
    schedule_interval=None,
    start_date=datetime(2021, 1, 1),
    tags=['example'],
    catchup=False,
) as dag:

    cluster_creator = EmrCreateJobFlowOperator(
        task_id='create_job_flow',
        job_flow_overrides=JOB_FLOW_OVERRIDES,
    )

    # [START howto_operator_emr_auto_terminate_policy]
    # Attach an auto-termination policy: EMR terminates the cluster once it has
    # been idle for `idle_timeout` seconds (here, five minutes).
    change_policy = EmrChangePolicyOperator(
        task_id='change_policy',
        job_flow_id=cluster_creator.output,
        idle_timeout=300,
    )
    # [END howto_operator_emr_auto_terminate_policy]

    # [START howto_operator_emr_add_steps]
    step_adder = EmrAddStepsOperator(
        task_id='add_steps',
        job_flow_id=cluster_creator.output,
        steps=SPARK_STEPS,
    )
    # [END howto_operator_emr_add_steps]

    # [START howto_sensor_emr_step_sensor]
    step_checker = EmrStepSensor(
        task_id='watch_step',
        job_flow_id=cluster_creator.output,
        step_id=step_adder.output[0],
    )
    # [END howto_sensor_emr_step_sensor]


    chain(
        cluster_creator,
        change_policy,
        step_adder,
        step_checker,
    )
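
Once the policy task has run, the attached policy can be checked outside Airflow. A minimal verification sketch with boto3, assuming default AWS credentials and a placeholder cluster id (not part of this change):

import boto3

emr = boto3.client('emr')
# GetAutoTerminationPolicy returns the policy attached by the operator above.
response = emr.get_auto_termination_policy(ClusterId='j-XXXXXXXXXXXXX')
print(response['AutoTerminationPolicy']['IdleTimeout'])  # expected: 300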
77 changes: 77 additions & 0 deletions airflow/providers/amazon/aws/operators/emr.py
@@ -404,3 +404,80 @@ def execute(self, context: 'Context') -> None:
            raise AirflowException(f'JobFlow termination failed: {response}')
        else:
            self.log.info('JobFlow with id %s terminated', self.job_flow_id)


class EmrChangePolicyOperator(BaseOperator):
    """
    An operator that attaches an auto-termination policy to an EMR cluster (job flow).
    Note: auto-termination policies are supported with Amazon EMR versions 5.30.0 and 6.1.0 and later.

    .. seealso::
        For more information on how to use this operator, take a look at the guide:
        :ref:`howto/operator:EmrChangePolicyOperator`

    :param idle_timeout: Time in seconds after which an idle EMR cluster auto-terminates.
        The timeout must be between 60 seconds and 604800 seconds (7 days). (templated)
    :param job_flow_id: id of the JobFlow to attach the auto-termination policy to (templated)
    :param job_flow_name: name of the JobFlow to attach the auto-termination policy to. Use as an
        alternative to passing job_flow_id. The operator will search for the id of exactly one
        JobFlow with a matching name in one of the states in param cluster_states, and fail
        otherwise. (templated)
    :param cluster_states: Acceptable cluster states when searching for JobFlow id by job_flow_name.
        (templated)
    :param aws_conn_id: aws connection to use
    """

    template_fields: Sequence[str] = ('job_flow_id', 'job_flow_name', 'cluster_states', 'idle_timeout')
    template_ext: Sequence[str] = ('.json',)
    ui_color = '#f9c915'

    def __init__(
        self,
        idle_timeout: int,
        job_flow_id: Optional[str] = None,
        job_flow_name: Optional[str] = None,
        cluster_states: Optional[List[str]] = None,
        aws_conn_id: str = 'aws_default',
        **kwargs
    ):
        if kwargs.get('xcom_push') is not None:
            raise AirflowException("'xcom_push' was deprecated, use 'do_xcom_push' instead")
        if not (job_flow_id is None) ^ (job_flow_name is None):
            raise AirflowException('Exactly one of job_flow_id or job_flow_name must be specified.')
        super().__init__(**kwargs)
        self.aws_conn_id = aws_conn_id
        self.job_flow_id = job_flow_id
        self.job_flow_name = job_flow_name
        self.cluster_states = cluster_states
        self.idle_timeout = idle_timeout

    def execute(self, context: 'Context') -> None:
        emr_hook = EmrHook(aws_conn_id=self.aws_conn_id)

        emr = emr_hook.get_conn()

        # Resolve the JobFlow id from the cluster name when no explicit id was given.
        job_flow_id = self.job_flow_id or emr_hook.get_cluster_id_by_name(
            self.job_flow_name, self.cluster_states
        )

        if not job_flow_id:
            raise AirflowException(f'No cluster found for name: {self.job_flow_name}')

        if self.do_xcom_push:
            context['ti'].xcom_push(key='job_flow_id', value=job_flow_id)

        self.log.info('Adding auto-termination policy to %s', job_flow_id)

        # PutAutoTerminationPolicy attaches the policy; EMR then terminates the
        # cluster once it has been idle for the configured number of seconds.
        response = emr.put_auto_termination_policy(
            ClusterId=job_flow_id, AutoTerminationPolicy={'IdleTimeout': self.idle_timeout}
        )

        if response['ResponseMetadata']['HTTPStatusCode'] != 200:
            raise AirflowException(f'Adding auto-termination policy to the cluster failed: {response}')
        else:
            self.log.info('Cluster will auto-terminate after being idle for %s seconds', self.idle_timeout)
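
For illustration, a minimal usage sketch of the lookup-by-name path (the task id, cluster name, and timeout below are hypothetical placeholders, not part of this change):

change_policy_by_name = EmrChangePolicyOperator(
    task_id='change_policy_by_name',
    job_flow_name='PiCalc',
    cluster_states=['RUNNING', 'WAITING'],
    idle_timeout=3600,  # seconds; must be between 60 and 604800
)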
14 changes: 14 additions & 0 deletions docs/apache-airflow-providers-amazon/operators/emr.rst
@@ -95,6 +95,20 @@ To add Steps to an existing EMR Job Flow you can use

.. _howto/operator:EmrChangePolicyOperator:

Add an auto-termination policy to an EMR Job Flow
-------------------------------------------------

To add an auto-termination policy to an existing EMR Job Flow you can use
:class:`~airflow.providers.amazon.aws.operators.emr.EmrChangePolicyOperator`.
Auto-termination policies are supported with Amazon EMR versions 5.30.0 and 6.1.0 and later.

.. exampleinclude:: /../../airflow/providers/amazon/aws/example_dags/example_emr_change_policy.py
    :language: python
    :dedent: 4
    :start-after: [START howto_operator_emr_auto_terminate_policy]
    :end-before: [END howto_operator_emr_auto_terminate_policy]

.. _howto/operator:EmrTerminateJobFlowOperator:

Terminate an EMR Job Flow
-------------------------

157 changes: 157 additions & 0 deletions tests/providers/amazon/aws/operators/test_emr_change_policy.py
@@ -0,0 +1,157 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import unittest
from unittest.mock import MagicMock, patch

import pytest

from airflow.exceptions import AirflowException
from airflow.models import DAG
from airflow.providers.amazon.aws.operators.emr import EmrChangePolicyOperator
from airflow.utils import timezone

DEFAULT_DATE = timezone.datetime(2017, 1, 1)

PUT_AUTO_TERMINATION_POLICY_SUCCESS_RETURN = {'ResponseMetadata': {'HTTPStatusCode': 200}}

PUT_AUTO_TERMINATION_POLICY_ERROR_RETURN = {'ResponseMetadata': {'HTTPStatusCode': 400}}


class TestEmrChangePolicyOperator(unittest.TestCase):
    def setUp(self):
        self.args = {'owner': 'airflow', 'start_date': DEFAULT_DATE}

        # Mock out the emr_client (moto has incorrect response)
        self.emr_client_mock = MagicMock()

        # Mock out the emr_client creator
        emr_session_mock = MagicMock()
        emr_session_mock.client.return_value = self.emr_client_mock
        self.boto3_session_mock = MagicMock(return_value=emr_session_mock)

        self.mock_context = MagicMock()

        self.operator = EmrChangePolicyOperator(
            idle_timeout=600,
            task_id='test_task',
            job_flow_id='j-8989898989',
            aws_conn_id='aws_default',
            dag=DAG('test_dag_id', default_args=self.args),
        )

    def test_init(self):
        assert self.operator.job_flow_id == 'j-8989898989'
        assert self.operator.aws_conn_id == 'aws_default'
        assert self.operator.idle_timeout == 600

    def test_init_fails_with_no_job_flow_arguments(self):
        with pytest.raises(AirflowException):
            EmrChangePolicyOperator(
                idle_timeout=600,
                task_id='test_task',
                dag=DAG('test_dag_id', default_args=self.args),
            )

    def test_execute_calls_put_auto_termination_policy(self):
        dag = DAG(dag_id='test', default_args=self.args)

        self.emr_client_mock.put_auto_termination_policy.return_value = (
            PUT_AUTO_TERMINATION_POLICY_SUCCESS_RETURN
        )

        test_task = EmrChangePolicyOperator(
            task_id='test_task',
            job_flow_id='j-8989898989',
            aws_conn_id='aws_default',
            idle_timeout=600,
            dag=dag,
            do_xcom_push=False,
        )

        with patch('boto3.session.Session', self.boto3_session_mock):
            test_task.execute(None)

        self.emr_client_mock.put_auto_termination_policy.assert_called_once_with(
            ClusterId='j-8989898989', AutoTerminationPolicy={'IdleTimeout': 600}
        )

    def test_init_with_cluster_name(self):
        expected_job_flow_id = 'j-1231231234'

        self.emr_client_mock.put_auto_termination_policy.return_value = (
            PUT_AUTO_TERMINATION_POLICY_SUCCESS_RETURN
        )

        with patch('boto3.session.Session', self.boto3_session_mock):
            with patch(
                'airflow.providers.amazon.aws.hooks.emr.EmrHook.get_cluster_id_by_name'
            ) as mock_get_cluster_id_by_name:
                mock_get_cluster_id_by_name.return_value = expected_job_flow_id

                operator = EmrChangePolicyOperator(
                    idle_timeout=600,
                    task_id='test_task',
                    job_flow_name='test_cluster',
                    cluster_states=['RUNNING', 'WAITING'],
                    aws_conn_id='aws_default',
                    dag=DAG('test_dag_id', default_args=self.args),
                )
                operator.execute(self.mock_context)

        ti = self.mock_context['ti']
        ti.xcom_push.assert_called_once_with(key='job_flow_id', value=expected_job_flow_id)

    def test_init_with_non_existent_cluster_name(self):
        cluster_name = 'test_cluster'

        with patch(
            'airflow.providers.amazon.aws.hooks.emr.EmrHook.get_cluster_id_by_name'
        ) as mock_get_cluster_id_by_name:
            mock_get_cluster_id_by_name.return_value = None

            operator = EmrChangePolicyOperator(
                idle_timeout=600,
                task_id='test_task',
                job_flow_name=cluster_name,
                cluster_states=['RUNNING', 'WAITING'],
                aws_conn_id='aws_default',
                dag=DAG('test_dag_id', default_args=self.args),
            )

            with pytest.raises(AirflowException) as ctx:
                operator.execute(self.mock_context)
            assert str(ctx.value) == f'No cluster found for name: {cluster_name}'

    def test_execute_returns_error(self):
        self.emr_client_mock.put_auto_termination_policy.return_value = (
            PUT_AUTO_TERMINATION_POLICY_ERROR_RETURN
        )

        with patch('boto3.session.Session', self.boto3_session_mock):
            with pytest.raises(AirflowException):
                self.operator.execute(self.mock_context)