Skip to content

Commit

Permalink
Support CloudDataTransferServiceJobStatusSensor without specifying a …
Browse files Browse the repository at this point in the history
…project_id (#30035)
  • Loading branch information
aibazhang committed Mar 14, 2023
1 parent 12b88cc commit 57fb80c
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 2 deletions.
Expand Up @@ -93,7 +93,7 @@ def poke(self, context: Context) -> bool:
impersonation_chain=self.impersonation_chain,
)
operations = hook.list_transfer_operations(
request_filter={"project_id": self.project_id, "job_names": [self.job_name]}
request_filter={"project_id": self.project_id or hook.project_id, "job_names": [self.job_name]}
)

for operation in operations:
Expand Down
Expand Up @@ -72,8 +72,41 @@ def test_wait_for_status_success(self, mock_tool):
@mock.patch(
"airflow.providers.google.cloud.sensors.cloud_storage_transfer_service.CloudDataTransferServiceHook"
)
def test_wait_for_status_success_default_expected_status(self, mock_tool):
def test_wait_for_status_success_without_project_id(self, mock_tool):
operations = [
{
"name": TEST_NAME,
"metadata": {
"status": GcpTransferOperationStatus.SUCCESS,
"counters": TEST_COUNTERS,
},
}
]
mock_tool.return_value.list_transfer_operations.return_value = operations
mock_tool.operations_contain_expected_statuses.return_value = True
mock_tool.return_value.project_id = "project-id"

op = CloudDataTransferServiceJobStatusSensor(
task_id="task-id",
job_name=JOB_NAME,
expected_statuses=GcpTransferOperationStatus.SUCCESS,
)

context = {"ti": (mock.Mock(**{"xcom_push.return_value": None}))}
result = op.poke(context)

mock_tool.return_value.list_transfer_operations.assert_called_once_with(
request_filter={"project_id": "project-id", "job_names": [JOB_NAME]}
)
mock_tool.operations_contain_expected_statuses.assert_called_once_with(
operations=operations, expected_statuses={GcpTransferOperationStatus.SUCCESS}
)
assert result

@mock.patch(
"airflow.providers.google.cloud.sensors.cloud_storage_transfer_service.CloudDataTransferServiceHook"
)
def test_wait_for_status_success_default_expected_status(self, mock_tool):
op = CloudDataTransferServiceJobStatusSensor(
task_id="task-id",
job_name=JOB_NAME,
Expand Down

0 comments on commit 57fb80c

Please sign in to comment.