Skip to content

Commit

Permalink
Fix side effects in DBT Cloud tests (#39511)
Browse files Browse the repository at this point in the history
  • Loading branch information
Taragolis committed May 9, 2024
1 parent c7c680e commit f973502
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 19 deletions.
16 changes: 13 additions & 3 deletions tests/providers/dbt/cloud/hooks/test_dbt.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from __future__ import annotations

import json
from datetime import timedelta
from typing import Any
from unittest.mock import patch

Expand All @@ -32,7 +33,7 @@
TokenAuth,
fallback_to_default_account,
)
from airflow.utils import db
from airflow.utils import db, timezone

pytestmark = pytest.mark.db_test

Expand Down Expand Up @@ -551,11 +552,20 @@ def test_get_job_run_with_payload(self, mock_http_run, mock_paginate, conn_id, a
for argval in wait_for_job_run_status_test_args
],
)
def test_wait_for_job_run_status(hook, job_run_status, expected_status, expected_output):
def test_wait_for_job_run_status(self, job_run_status, expected_status, expected_output, time_machine):
config = {"run_id": RUN_ID, "timeout": 3, "check_interval": 1, "expected_statuses": expected_status}
hook = DbtCloudHook(ACCOUNT_ID_CONN)

with patch.object(DbtCloudHook, "get_job_run_status") as mock_job_run_status:
# Freeze time for avoid real clock side effects
time_machine.move_to(timezone.datetime(1970, 1, 1), tick=False)

def fake_sleep(seconds):
# Shift frozen time every time we call a ``time.sleep`` during this test case.
time_machine.shift(timedelta(seconds=seconds))

with patch.object(DbtCloudHook, "get_job_run_status") as mock_job_run_status, patch(
"airflow.providers.dbt.cloud.hooks.dbt.time.sleep", side_effect=fake_sleep
):
mock_job_run_status.return_value = job_run_status

if expected_output != "timeout":
Expand Down
62 changes: 46 additions & 16 deletions tests/providers/dbt/cloud/operators/test_dbt.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from __future__ import annotations

import os
from datetime import timedelta
from unittest.mock import MagicMock, patch

import pytest
Expand Down Expand Up @@ -208,7 +209,7 @@ def test_dbt_run_job_op_async(self, mock_trigger_job_run, mock_dbt_hook, mock_jo
ids=["default_account", "explicit_account"],
)
def test_execute_wait_for_termination(
self, mock_run_job, conn_id, account_id, job_run_status, expected_output
self, mock_run_job, conn_id, account_id, job_run_status, expected_output, time_machine
):
operator = DbtCloudRunJobOperator(
task_id=TASK_ID, dbt_cloud_conn_id=conn_id, account_id=account_id, dag=self.dag, **self.config
Expand All @@ -224,7 +225,19 @@ def test_execute_wait_for_termination(
assert operator.schema_override == self.config["schema_override"]
assert operator.additional_run_config == self.config["additional_run_config"]

with patch.object(DbtCloudHook, "get_job_run") as mock_get_job_run:
# Freeze time for avoid real clock side effects
time_machine.move_to(timezone.datetime(1970, 1, 1), tick=False)

def fake_sleep(seconds):
# Shift frozen time every time we call a ``time.sleep`` during this test case.
# Because we freeze a time, we also need to add a small shift
# which is emulating time which we spent in a loop
overall_delta = timedelta(seconds=seconds) + timedelta(microseconds=42)
time_machine.shift(overall_delta)

with patch.object(DbtCloudHook, "get_job_run") as mock_get_job_run, patch(
"airflow.providers.dbt.cloud.hooks.dbt.time.sleep", side_effect=fake_sleep
):
mock_get_job_run.return_value.json.return_value = {
"data": {"status": job_run_status, "id": RUN_ID}
}
Expand Down Expand Up @@ -445,7 +458,7 @@ def setup_method(self):
[(ACCOUNT_ID_CONN, None), (NO_ACCOUNT_ID_CONN, ACCOUNT_ID)],
ids=["default_account", "explicit_account"],
)
def test_get_json_artifact(self, mock_get_artifact, conn_id, account_id):
def test_get_json_artifact(self, mock_get_artifact, conn_id, account_id, tmp_path, monkeypatch):
operator = DbtCloudGetJobRunArtifactOperator(
task_id=TASK_ID,
dbt_cloud_conn_id=conn_id,
Expand All @@ -456,7 +469,11 @@ def test_get_json_artifact(self, mock_get_artifact, conn_id, account_id):
)

mock_get_artifact.return_value.json.return_value = {"data": "file contents"}
return_value = operator.execute(context={})
with monkeypatch.context() as ctx:
# Let's change current working directory to temp,
# otherwise the output file will be created in the current working directory
ctx.chdir(tmp_path)
return_value = operator.execute(context={})

mock_get_artifact.assert_called_once_with(
run_id=RUN_ID,
Expand All @@ -466,7 +483,7 @@ def test_get_json_artifact(self, mock_get_artifact, conn_id, account_id):
)

assert operator.output_file_name == f"{RUN_ID}_path-to-my-manifest.json"
assert os.path.exists(operator.output_file_name)
assert os.path.exists(tmp_path / operator.output_file_name)
assert return_value == operator.output_file_name

@patch("airflow.providers.dbt.cloud.hooks.dbt.DbtCloudHook.get_job_run_artifact")
Expand All @@ -475,7 +492,7 @@ def test_get_json_artifact(self, mock_get_artifact, conn_id, account_id):
[(ACCOUNT_ID_CONN, None), (NO_ACCOUNT_ID_CONN, ACCOUNT_ID)],
ids=["default_account", "explicit_account"],
)
def test_get_json_artifact_with_step(self, mock_get_artifact, conn_id, account_id):
def test_get_json_artifact_with_step(self, mock_get_artifact, conn_id, account_id, tmp_path, monkeypatch):
operator = DbtCloudGetJobRunArtifactOperator(
task_id=TASK_ID,
dbt_cloud_conn_id=conn_id,
Expand All @@ -487,7 +504,11 @@ def test_get_json_artifact_with_step(self, mock_get_artifact, conn_id, account_i
)

mock_get_artifact.return_value.json.return_value = {"data": "file contents"}
return_value = operator.execute(context={})
with monkeypatch.context() as ctx:
# Let's change current working directory to temp,
# otherwise the output file will be created in the current working directory
ctx.chdir(tmp_path)
return_value = operator.execute(context={})

mock_get_artifact.assert_called_once_with(
run_id=RUN_ID,
Expand All @@ -497,7 +518,7 @@ def test_get_json_artifact_with_step(self, mock_get_artifact, conn_id, account_i
)

assert operator.output_file_name == f"{RUN_ID}_path-to-my-manifest.json"
assert os.path.exists(operator.output_file_name)
assert os.path.exists(tmp_path / operator.output_file_name)
assert return_value == operator.output_file_name

@patch("airflow.providers.dbt.cloud.hooks.dbt.DbtCloudHook.get_job_run_artifact")
Expand All @@ -506,7 +527,7 @@ def test_get_json_artifact_with_step(self, mock_get_artifact, conn_id, account_i
[(ACCOUNT_ID_CONN, None), (NO_ACCOUNT_ID_CONN, ACCOUNT_ID)],
ids=["default_account", "explicit_account"],
)
def test_get_text_artifact(self, mock_get_artifact, conn_id, account_id):
def test_get_text_artifact(self, mock_get_artifact, conn_id, account_id, tmp_path, monkeypatch):
operator = DbtCloudGetJobRunArtifactOperator(
task_id=TASK_ID,
dbt_cloud_conn_id=conn_id,
Expand All @@ -517,7 +538,11 @@ def test_get_text_artifact(self, mock_get_artifact, conn_id, account_id):
)

mock_get_artifact.return_value.text = "file contents"
return_value = operator.execute(context={})
with monkeypatch.context() as ctx:
# Let's change current working directory to temp,
# otherwise the output file will be created in the current working directory
ctx.chdir(tmp_path)
return_value = operator.execute(context={})

mock_get_artifact.assert_called_once_with(
run_id=RUN_ID,
Expand All @@ -527,7 +552,7 @@ def test_get_text_artifact(self, mock_get_artifact, conn_id, account_id):
)

assert operator.output_file_name == f"{RUN_ID}_path-to-my-model.sql"
assert os.path.exists(operator.output_file_name)
assert os.path.exists(tmp_path / operator.output_file_name)
assert return_value == operator.output_file_name

@patch("airflow.providers.dbt.cloud.hooks.dbt.DbtCloudHook.get_job_run_artifact")
Expand All @@ -536,7 +561,7 @@ def test_get_text_artifact(self, mock_get_artifact, conn_id, account_id):
[(ACCOUNT_ID_CONN, None), (NO_ACCOUNT_ID_CONN, ACCOUNT_ID)],
ids=["default_account", "explicit_account"],
)
def test_get_text_artifact_with_step(self, mock_get_artifact, conn_id, account_id):
def test_get_text_artifact_with_step(self, mock_get_artifact, conn_id, account_id, tmp_path, monkeypatch):
operator = DbtCloudGetJobRunArtifactOperator(
task_id=TASK_ID,
dbt_cloud_conn_id=conn_id,
Expand All @@ -548,7 +573,11 @@ def test_get_text_artifact_with_step(self, mock_get_artifact, conn_id, account_i
)

mock_get_artifact.return_value.text = "file contents"
return_value = operator.execute(context={})
with monkeypatch.context() as ctx:
# Let's change current working directory to temp,
# otherwise the output file will be created in the current working directory
ctx.chdir(tmp_path)
return_value = operator.execute(context={})

mock_get_artifact.assert_called_once_with(
run_id=RUN_ID,
Expand All @@ -558,7 +587,7 @@ def test_get_text_artifact_with_step(self, mock_get_artifact, conn_id, account_i
)

assert operator.output_file_name == f"{RUN_ID}_path-to-my-model.sql"
assert os.path.exists(operator.output_file_name)
assert os.path.exists(tmp_path / operator.output_file_name)
assert return_value == operator.output_file_name

@patch("airflow.providers.dbt.cloud.hooks.dbt.DbtCloudHook.get_job_run_artifact")
Expand All @@ -568,14 +597,15 @@ def test_get_text_artifact_with_step(self, mock_get_artifact, conn_id, account_i
ids=["default_account", "explicit_account"],
)
def test_get_artifact_with_specified_output_file(self, mock_get_artifact, conn_id, account_id, tmp_path):
specified_output_file = (tmp_path / "run_results.json").as_posix()
operator = DbtCloudGetJobRunArtifactOperator(
task_id=TASK_ID,
dbt_cloud_conn_id=conn_id,
run_id=RUN_ID,
account_id=account_id,
path="run_results.json",
dag=self.dag,
output_file_name=tmp_path / "run_results.json",
output_file_name=specified_output_file,
)

mock_get_artifact.return_value.json.return_value = {"data": "file contents"}
Expand All @@ -588,7 +618,7 @@ def test_get_artifact_with_specified_output_file(self, mock_get_artifact, conn_i
step=None,
)

assert operator.output_file_name == tmp_path / "run_results.json"
assert operator.output_file_name == specified_output_file
assert os.path.exists(operator.output_file_name)
assert return_value == operator.output_file_name

Expand Down

0 comments on commit f973502

Please sign in to comment.