Skip to content

Commit

Permalink
dbt, openlineage: set run_id after defer, do not log error if operato…
Browse files Browse the repository at this point in the history
…r has no run_id set (#34270)

* dbt, openlineage: set run_id after defer, do not log error if operator has no run_id set

Signed-off-by: Maciej Obuchowski <obuchowski.maciej@gmail.com>

* fix other tests

Signed-off-by: Maciej Obuchowski <obuchowski.maciej@gmail.com>

---------

Signed-off-by: Maciej Obuchowski <obuchowski.maciej@gmail.com>
  • Loading branch information
mobuchowski committed Sep 11, 2023
1 parent be665f2 commit 87fd884
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 8 deletions.
8 changes: 5 additions & 3 deletions airflow/providers/dbt/cloud/operators/dbt.py
Expand Up @@ -117,7 +117,7 @@ def __init__(
self.timeout = timeout
self.check_interval = check_interval
self.additional_run_config = additional_run_config or {}
self.run_id: int
self.run_id: int | None = None
self.deferrable = deferrable

def execute(self, context: Context):
Expand All @@ -135,12 +135,13 @@ def execute(self, context: Context):
additional_run_config=self.additional_run_config,
)
self.run_id = trigger_job_response.json()["data"]["id"]
print(self.run_id)
job_run_url = trigger_job_response.json()["data"]["href"]
# Push the ``job_run_url`` value to XCom regardless of what happens during execution so that the job
# run can be monitored via the operator link.
context["ti"].xcom_push(key="job_run_url", value=job_run_url)

if self.wait_for_termination:
if self.wait_for_termination and isinstance(self.run_id, int):
if self.deferrable is False:
self.log.info("Waiting for job run %s to terminate.", str(self.run_id))

Expand Down Expand Up @@ -197,6 +198,7 @@ def execute_complete(self, context: Context, event: dict[str, Any]) -> int:
if event["status"] == "error":
raise AirflowException(event["message"])
self.log.info(event["message"])
self.run_id = event["run_id"]
return int(event["run_id"])

def on_kill(self) -> None:
Expand Down Expand Up @@ -225,7 +227,7 @@ def get_openlineage_facets_on_complete(self, task_instance) -> OperatorLineage:
"""
from airflow.providers.openlineage.extractors import OperatorLineage

if self.wait_for_termination is True:
if isinstance(self.run_id, int) and self.wait_for_termination is True:
return generate_openlineage_events_from_dbt_cloud_run(operator=self, task_instance=task_instance)
return OperatorLineage()

Expand Down
20 changes: 17 additions & 3 deletions tests/providers/dbt/cloud/operators/test_dbt_cloud.py
Expand Up @@ -63,6 +63,12 @@
}


def mock_response_json(response: dict):
run_response = MagicMock(**response)
run_response.json.return_value = response
return run_response


def setup_module():
# Connection with ``account_id`` specified
conn_account_id = Connection(
Expand Down Expand Up @@ -125,7 +131,10 @@ def test_execute_succeeded_before_getting_deferred(
)
@patch("airflow.providers.dbt.cloud.operators.dbt.DbtCloudRunJobOperator.defer")
@patch("airflow.providers.dbt.cloud.hooks.dbt.DbtCloudHook.get_connection")
@patch("airflow.providers.dbt.cloud.hooks.dbt.DbtCloudHook.trigger_job_run")
@patch(
"airflow.providers.dbt.cloud.hooks.dbt.DbtCloudHook.trigger_job_run",
return_value=mock_response_json(DEFAULT_ACCOUNT_JOB_RUN_RESPONSE),
)
def test_execute_failed_before_getting_deferred(
self, mock_trigger_job_run, mock_dbt_hook, mock_defer, mock_job_run_status
):
Expand Down Expand Up @@ -154,7 +163,10 @@ def test_execute_failed_before_getting_deferred(
"airflow.providers.dbt.cloud.hooks.dbt.DbtCloudHook.get_job_run_status",
)
@patch("airflow.providers.dbt.cloud.hooks.dbt.DbtCloudHook.get_connection")
@patch("airflow.providers.dbt.cloud.hooks.dbt.DbtCloudHook.trigger_job_run")
@patch(
"airflow.providers.dbt.cloud.hooks.dbt.DbtCloudHook.trigger_job_run",
return_value=mock_response_json(DEFAULT_ACCOUNT_JOB_RUN_RESPONSE),
)
def test_dbt_run_job_op_async(self, mock_trigger_job_run, mock_dbt_hook, mock_job_run_status, status):
"""
Asserts that a task is deferred and an DbtCloudRunJobTrigger will be fired
Expand All @@ -174,7 +186,9 @@ def test_dbt_run_job_op_async(self, mock_trigger_job_run, mock_dbt_hook, mock_jo
dbt_op.execute(MagicMock())
assert isinstance(exc.value.trigger, DbtCloudRunJobTrigger), "Trigger is not a DbtCloudRunJobTrigger"

@patch.object(DbtCloudHook, "trigger_job_run", return_value=MagicMock(**DEFAULT_ACCOUNT_JOB_RUN_RESPONSE))
@patch.object(
DbtCloudHook, "trigger_job_run", return_value=mock_response_json(DEFAULT_ACCOUNT_JOB_RUN_RESPONSE)
)
@pytest.mark.parametrize(
"job_run_status, expected_output",
[
Expand Down
8 changes: 6 additions & 2 deletions tests/providers/dbt/cloud/utils/test_openlineage.py
Expand Up @@ -22,6 +22,7 @@
from airflow.providers.dbt.cloud.hooks.dbt import DbtCloudHook
from airflow.providers.dbt.cloud.operators.dbt import DbtCloudRunJobOperator
from airflow.providers.dbt.cloud.utils.openlineage import generate_openlineage_events_from_dbt_cloud_run
from airflow.providers.openlineage.extractors import OperatorLineage

TASK_ID = "dbt_test"
DAG_ID = "dbt_dag"
Expand Down Expand Up @@ -130,7 +131,10 @@ def test_generate_events(
)

mock_build_task_instance_run_id.return_value = TASK_UUID

generate_openlineage_events_from_dbt_cloud_run(mock_operator, task_instance=mock_task_instance)

assert mock_client.emit.call_count == 4

def test_do_not_raise_error_if_runid_not_set_on_operator(self):
operator = DbtCloudRunJobOperator(task_id="dbt-job-runid-taskid", job_id=1500)
assert operator.run_id is None
assert operator.get_openlineage_facets_on_complete(MagicMock()) == OperatorLineage()

0 comments on commit 87fd884

Please sign in to comment.