Skip to content

Commit

Permalink
fix(providers/alibaba): respect soft_fail argument when exception is …
Browse files Browse the repository at this point in the history
…raised (#34157)
  • Loading branch information
Lee-W committed Sep 7, 2023
1 parent 44cb7c6 commit 7696e41
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 22 deletions.
14 changes: 11 additions & 3 deletions airflow/providers/alibaba/cloud/sensors/oss_key.py
Expand Up @@ -21,7 +21,7 @@
from typing import TYPE_CHECKING, Sequence
from urllib.parse import urlsplit

from airflow.exceptions import AirflowException
from airflow.exceptions import AirflowException, AirflowSkipException
from airflow.providers.alibaba.cloud.hooks.oss import OSSHook
from airflow.sensors.base import BaseSensorOperator

Expand Down Expand Up @@ -72,16 +72,24 @@ def poke(self, context: Context):
parsed_url = urlsplit(self.bucket_key)
if self.bucket_name is None:
if parsed_url.netloc == "":
raise AirflowException("If key is a relative path from root, please provide a bucket_name")
# TODO: remove this if block when min_airflow_version is set to higher than 2.7.1
message = "If key is a relative path from root, please provide a bucket_name"
if self.soft_fail:
raise AirflowSkipException(message)
raise AirflowException(message)
self.bucket_name = parsed_url.netloc
self.bucket_key = parsed_url.path.lstrip("/")
else:
if parsed_url.scheme != "" or parsed_url.netloc != "":
raise AirflowException(
# TODO: remove this if block when min_airflow_version is set to higher than 2.7.1
message = (
"If bucket_name is provided, bucket_key"
" should be relative path from root"
" level, rather than a full oss:// url"
)
if self.soft_fail:
raise AirflowSkipException(message)
raise AirflowException(message)

self.log.info("Poking for key : oss://%s/%s", self.bucket_name, self.bucket_key)
return self.get_hook.object_exists(key=self.bucket_key, bucket_name=self.bucket_name)
Expand Down
62 changes: 43 additions & 19 deletions tests/providers/alibaba/cloud/sensors/test_oss_key.py
Expand Up @@ -20,9 +20,13 @@
from unittest import mock
from unittest.mock import PropertyMock

import pytest

from airflow.exceptions import AirflowException, AirflowSkipException
from airflow.providers.alibaba.cloud.sensors.oss_key import OSSKeySensor

OSS_SENSOR_STRING = "airflow.providers.alibaba.cloud.sensors.oss_key.{}"
MODULE_NAME = "airflow.providers.alibaba.cloud.sensors.oss_key"

MOCK_TASK_ID = "test-oss-operator"
MOCK_REGION = "mock_region"
MOCK_BUCKET = "mock_bucket_name"
Expand All @@ -32,41 +36,61 @@
MOCK_CONTENT = "mock_content"


@pytest.fixture
def oss_key_sensor():
return OSSKeySensor(
bucket_key=MOCK_KEY,
oss_conn_id=MOCK_OSS_CONN_ID,
region=MOCK_REGION,
bucket_name=MOCK_BUCKET,
task_id=MOCK_TASK_ID,
)


class TestOSSKeySensor:
def setup_method(self):
self.sensor = OSSKeySensor(
bucket_key=MOCK_KEY,
oss_conn_id=MOCK_OSS_CONN_ID,
region=MOCK_REGION,
bucket_name=MOCK_BUCKET,
task_id=MOCK_TASK_ID,
)

@mock.patch(OSS_SENSOR_STRING.format("OSSHook"))
def test_get_hook(self, mock_service):
self.sensor.get_hook()
@mock.patch(f"{MODULE_NAME}.OSSHook")
def test_get_hook(self, mock_service, oss_key_sensor):
oss_key_sensor.get_hook()
mock_service.assert_called_once_with(oss_conn_id=MOCK_OSS_CONN_ID, region=MOCK_REGION)

@mock.patch(OSS_SENSOR_STRING.format("OSSKeySensor.get_hook"), new_callable=PropertyMock)
def test_poke_exsiting_key(self, mock_service):
@mock.patch(f"{MODULE_NAME}.OSSKeySensor.get_hook", new_callable=PropertyMock)
def test_poke_exsiting_key(self, mock_service, oss_key_sensor):
# Given
mock_service.return_value.object_exists.return_value = True

# When
res = self.sensor.poke(None)
res = oss_key_sensor.poke(None)

# Then
assert res is True
mock_service.return_value.object_exists.assert_called_once_with(key=MOCK_KEY, bucket_name=MOCK_BUCKET)

@mock.patch(OSS_SENSOR_STRING.format("OSSKeySensor.get_hook"), new_callable=PropertyMock)
def test_poke_non_exsiting_key(self, mock_service):
@mock.patch(f"{MODULE_NAME}.OSSKeySensor.get_hook", new_callable=PropertyMock)
def test_poke_non_exsiting_key(self, mock_service, oss_key_sensor):
# Given
mock_service.return_value.object_exists.return_value = False

# When
res = self.sensor.poke(None)
res = oss_key_sensor.poke(None)

# Then
assert res is False
mock_service.return_value.object_exists.assert_called_once_with(key=MOCK_KEY, bucket_name=MOCK_BUCKET)

@pytest.mark.parametrize(
"soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException))
)
@mock.patch(f"{MODULE_NAME}.OSSKeySensor.get_hook", new_callable=PropertyMock)
def test_poke_without_bucket_name(
self, mock_service, oss_key_sensor, soft_fail: bool, expected_exception: AirflowException
):
# Given
oss_key_sensor.soft_fail = soft_fail
oss_key_sensor.bucket_name = None
mock_service.return_value.object_exists.return_value = False

# When, Then
with pytest.raises(
expected_exception, match="If key is a relative path from root, please provide a bucket_name"
):
oss_key_sensor.poke(None)

0 comments on commit 7696e41

Please sign in to comment.