Skip to content

Commit

Permalink
Add tests for EmrServerlessJobSensor and `EmrServerlessApplicationS…
Browse files Browse the repository at this point in the history
…ensor` (#39099)
  • Loading branch information
mateuslatrova committed Apr 26, 2024
1 parent 2a913b6 commit 6d09adf
Show file tree
Hide file tree
Showing 3 changed files with 189 additions and 2 deletions.
10 changes: 8 additions & 2 deletions airflow/providers/amazon/aws/sensors/emr.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,11 @@
from deprecated import deprecated

from airflow.configuration import conf
from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning, AirflowSkipException
from airflow.exceptions import (
AirflowException,
AirflowProviderDeprecationWarning,
AirflowSkipException,
)
from airflow.providers.amazon.aws.hooks.emr import EmrContainerHook, EmrHook, EmrServerlessHook
from airflow.providers.amazon.aws.links.emr import EmrClusterLink, EmrLogsLink, get_log_uri
from airflow.providers.amazon.aws.triggers.emr import (
Expand Down Expand Up @@ -231,7 +235,9 @@ def poke(self, context: Context) -> bool:

if state in EmrServerlessHook.APPLICATION_FAILURE_STATES:
# TODO: remove this if check when min_airflow_version is set to higher than 2.7.1
failure_message = f"EMR Serverless job failed: {self.failure_message_from_response(response)}"
failure_message = (
f"EMR Serverless application failed: {self.failure_message_from_response(response)}"
)
if self.soft_fail:
raise AirflowSkipException(failure_message)
raise AirflowException(failure_message)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

from unittest.mock import MagicMock

import pytest

from airflow.exceptions import AirflowException, AirflowSkipException
from airflow.providers.amazon.aws.sensors.emr import EmrServerlessApplicationSensor


class TestEmrServerlessApplicationSensor:
def setup_method(self):
self.app_id = "vzwemreks"
self.job_run_id = "job1234"
self.sensor = EmrServerlessApplicationSensor(
task_id="test_emrcontainer_sensor",
application_id=self.app_id,
aws_conn_id="aws_default",
)

def set_get_application_return_value(self, return_value: dict[str, str]):
self.mock_hook = MagicMock()
self.mock_hook.conn.get_application.return_value = return_value
self.sensor.hook = self.mock_hook

def assert_get_application_was_called_once_with_app_id(self):
self.mock_hook.conn.get_application.assert_called_once_with(applicationId=self.app_id)


class TestPokeReturnValue(TestEmrServerlessApplicationSensor):
@pytest.mark.parametrize(
"state, expected_result",
[
("CREATING", False),
("STARTING", False),
("STOPPING", False),
("CREATED", True),
("STARTED", True),
],
)
def test_poke_returns_expected_result_for_states(self, state, expected_result):
get_application_return_value = {"application": {"state": state}}
self.set_get_application_return_value(get_application_return_value)
assert self.sensor.poke(None) == expected_result
self.assert_get_application_was_called_once_with_app_id()


class TestPokeRaisesAirflowException(TestEmrServerlessApplicationSensor):
@pytest.mark.parametrize("state", ["STOPPED", "TERMINATED"])
def test_poke_raises_airflow_exception_with_failure_states(self, state):
state_details = f"mock {state}"
exception_msg = f"EMR Serverless application failed: {state_details}"
get_job_run_return_value = {"application": {"state": state, "stateDetails": state_details}}
self.set_get_application_return_value(get_job_run_return_value)

with pytest.raises(AirflowException) as ctx:
self.sensor.poke(None)

assert exception_msg == str(ctx.value)
self.assert_get_application_was_called_once_with_app_id()


class TestPokeRaisesAirflowSkipException(TestEmrServerlessApplicationSensor):
def test_when_state_is_failed_and_soft_fail_is_true_poke_should_raise_skip_exception(self):
self.sensor.soft_fail = True
self.set_get_application_return_value(
{"application": {"state": "STOPPED", "stateDetails": "mock stopped"}}
)
with pytest.raises(AirflowSkipException) as ctx:
self.sensor.poke(None)
assert "EMR Serverless application failed: mock stopped" == str(ctx.value)
self.assert_get_application_was_called_once_with_app_id()
self.sensor.soft_fail = False
91 changes: 91 additions & 0 deletions tests/providers/amazon/aws/sensors/test_emr_serverless_job.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

from unittest.mock import MagicMock

import pytest

from airflow.exceptions import AirflowException, AirflowSkipException
from airflow.providers.amazon.aws.sensors.emr import EmrServerlessJobSensor


class TestEmrServerlessJobSensor:
def setup_method(self):
self.app_id = "vzwemreks"
self.job_run_id = "job1234"
self.sensor = EmrServerlessJobSensor(
task_id="test_emrcontainer_sensor",
application_id=self.app_id,
job_run_id=self.job_run_id,
aws_conn_id="aws_default",
)

def set_get_job_run_return_value(self, return_value: dict[str, str]):
self.mock_hook = MagicMock()
self.mock_hook.conn.get_job_run.return_value = return_value
self.sensor.hook = self.mock_hook

def assert_get_job_run_was_called_once_with_app_and_run_id(self):
self.mock_hook.conn.get_job_run.assert_called_once_with(
applicationId=self.app_id, jobRunId=self.job_run_id
)


class TestPokeReturnValue(TestEmrServerlessJobSensor):
@pytest.mark.parametrize(
"state, expected_result",
[
("PENDING", False),
("RUNNING", False),
("SCHEDULED", False),
("SUBMITTED", False),
("SUCCESS", True),
],
)
def test_poke_returns_expected_result_for_states(self, state, expected_result):
get_job_run_return_value = {"jobRun": {"state": state}}
self.set_get_job_run_return_value(get_job_run_return_value)
assert self.sensor.poke(None) == expected_result
self.assert_get_job_run_was_called_once_with_app_and_run_id()


class TestPokeRaisesAirflowException(TestEmrServerlessJobSensor):
@pytest.mark.parametrize("state", ["FAILED", "CANCELLING", "CANCELLED"])
def test_poke_raises_airflow_exception_with_specified_states(self, state):
state_details = f"mock {state}"
exception_msg = f"EMR Serverless job failed: {state_details}"
get_job_run_return_value = {"jobRun": {"state": state, "stateDetails": state_details}}
self.set_get_job_run_return_value(get_job_run_return_value)

with pytest.raises(AirflowException) as ctx:
self.sensor.poke(None)

assert exception_msg == str(ctx.value)
self.assert_get_job_run_was_called_once_with_app_and_run_id()


class TestPokeRaisesAirflowSkipException(TestEmrServerlessJobSensor):
def test_when_state_is_failed_and_soft_fail_is_true_poke_should_raise_skip_exception(self):
self.sensor.soft_fail = True
self.set_get_job_run_return_value({"jobRun": {"state": "FAILED", "stateDetails": "mock failed"}})
with pytest.raises(AirflowSkipException) as ctx:
self.sensor.poke(None)
assert "EMR Serverless job failed: mock failed" == str(ctx.value)
self.assert_get_job_run_was_called_once_with_app_and_run_id()
self.sensor.soft_fail = False

0 comments on commit 6d09adf

Please sign in to comment.