Skip to content

Commit

Permalink
Respect soft_fail argument when running SageMakerBaseSensor (#34565)
Browse files Browse the repository at this point in the history
  • Loading branch information
utkarsharma2 committed Sep 22, 2023
1 parent f56acda commit e76b505
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 5 deletions.
10 changes: 6 additions & 4 deletions airflow/providers/amazon/aws/sensors/sagemaker.py
Expand Up @@ -22,7 +22,7 @@

from deprecated import deprecated

from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning, AirflowSkipException
from airflow.providers.amazon.aws.hooks.sagemaker import LogState, SageMakerHook
from airflow.sensors.base import BaseSensorOperator

Expand Down Expand Up @@ -65,9 +65,11 @@ def poke(self, context: Context):
return False
if state in self.failed_states():
failed_reason = self.get_failed_reason_from_response(response)
raise AirflowException(
f"Sagemaker {self.resource_type} failed for the following reason: {failed_reason}"
)
# TODO: remove this if block when min_airflow_version is set to higher than 2.7.1
message = f"Sagemaker {self.resource_type} failed for the following reason: {failed_reason}"
if self.soft_fail:
raise AirflowSkipException(message)
raise AirflowException(message)
return True

def non_terminal_states(self) -> set[str]:
Expand Down
32 changes: 31 additions & 1 deletion tests/providers/amazon/aws/sensors/test_sagemaker_base.py
Expand Up @@ -19,7 +19,7 @@

import pytest

from airflow.exceptions import AirflowException
from airflow.exceptions import AirflowException, AirflowSkipException
from airflow.providers.amazon.aws.sensors.sagemaker import SageMakerBaseSensor


Expand Down Expand Up @@ -109,3 +109,33 @@ def state_from_response(self, response):

with pytest.raises(AirflowException):
sensor.poke(None)

@pytest.mark.parametrize(
"soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException))
)
def test_fail_poke(self, soft_fail, expected_exception):
resource_type = "job"

class SageMakerBaseSensorSubclass(SageMakerBaseSensor):
def non_terminal_states(self):
return ["PENDING", "RUNNING", "CONTINUE"]

def failed_states(self):
return ["FAILED"]

def get_sagemaker_response(self):
return {"SomeKey": {"State": "FAILED"}, "ResponseMetadata": {"HTTPStatusCode": 200}}

def state_from_response(self, response):
return response["SomeKey"]["State"]

sensor = SageMakerBaseSensorSubclass(
task_id="test_task", poke_interval=2, aws_conn_id="aws_test", resource_type=resource_type
)
sensor.soft_fail = soft_fail
message = (
f"Sagemaker {resource_type} failed for the following reason:"
f" {sensor.get_failed_reason_from_response({})}"
)
with pytest.raises(expected_exception, match=message):
sensor.poke(context={})

0 comments on commit e76b505

Please sign in to comment.