/
test_example_dags.py
97 lines (69 loc) · 2.7 KB
/
test_example_dags.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
from __future__ import annotations
from pathlib import Path
import airflow
import pytest
from airflow.models.dagbag import DagBag
from airflow.utils.db import create_default_connections
from airflow.utils.session import provide_session
from packaging.version import Version
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_exponential
from .sql.operators import utils as test_utils
# Exception classes that make wrapper_run_dag retry. The list stays empty when
# the optional Google / pandas-gbq packages are not installed.
RETRY_ON_EXCEPTIONS = []
try:
    from google.api_core.exceptions import Forbidden, TooManyRequests
    from pandas_gbq.exceptions import GenericGBQException
except ModuleNotFoundError:
    # Optional dependencies are missing: nothing to register.
    pass
else:
    RETRY_ON_EXCEPTIONS.extend([Forbidden, TooManyRequests, GenericGBQException])
@retry(
    retry=retry_if_exception_type(tuple(RETRY_ON_EXCEPTIONS)),
    wait=wait_exponential(multiplier=10, min=10, max=60),  # values in seconds
    stop=stop_after_attempt(3),
)
def wrapper_run_dag(dag):
    """Run ``dag`` through ``test_utils.run_dag`` under a tenacity retry policy.

    The call is retried only when it raises one of ``RETRY_ON_EXCEPTIONS``,
    for at most three attempts, with an exponential wait bounded to
    10-60 seconds between attempts.
    """
    test_utils.run_dag(dag)
@provide_session
def get_session(session=None):
    """Return the session supplied by ``provide_session``.

    The session is first handed to ``create_default_connections`` and then
    returned to the caller unchanged.
    """
    create_default_connections(session)
    return session
@pytest.fixture()
def session():
    """Pytest fixture that returns the result of ``get_session()``."""
    return get_session()
# Example DAG files keyed by the Airflow version string that get_dag_bag
# compares the installed Airflow version against.
MIN_VER_DAG_FILE: dict[str, list[str]] = {
    "2.3": [
        "example_dynamic_task_template.py",
        "example_bigquery_dynamic_map_task.py",
    ],
    "2.4": ["example_datasets.py"],
}

# Same mapping with the keys parsed into Version objects, inserted from the
# highest version down to the lowest.
MIN_VER_DAG_FILE_VER: dict[Version, list[str]] = {
    Version(raw_version): dag_files
    for raw_version, dag_files in sorted(
        MIN_VER_DAG_FILE.items(),
        key=lambda entry: Version(entry[0]),
        reverse=True,
    )
}
def get_dag_bag() -> DagBag:
    """Create a DagBag from the example DAGs, ignoring files the installed Airflow is too old for.

    ``.airflowignore`` in the ``example_dags`` directory is rewritten on every
    call. For each entry of ``MIN_VER_DAG_FILE_VER`` whose version is greater
    than the installed Airflow version, the listed file names are written to
    it, one per line. The resulting file content is printed for debugging.

    Returns:
        A ``DagBag`` built from the ``example_dags`` directory with
        ``include_examples=False``.
    """
    example_dags_dir = Path(__file__).parent.parent / "example_dags"
    airflow_ignore_file = example_dags_dir / ".airflowignore"
    # Loop-invariant: parse the installed Airflow version only once.
    installed_version = Version(airflow.__version__)

    # Opening in "w+" truncates the file, so entries from earlier runs never linger.
    with open(airflow_ignore_file, "w+") as ignore_file:
        for min_version, files in MIN_VER_DAG_FILE_VER.items():
            if installed_version < min_version:
                print(f"Adding {files} to .airflowignore")
                # Distinct names for the handle and the loop variable: the
                # original reused `file` for both, which was easy to misread.
                ignore_file.writelines(f"{file_name}\n" for file_name in files)

    # Read back only after the handle is closed so the content is flushed to disk.
    print(".airflowignore contents: ")
    print(airflow_ignore_file.read_text())
    return DagBag(example_dags_dir, include_examples=False)
# DAG ids that must run in this relative order (producer before consumer).
PRE_DEFINED_ORDER = [
    "example_dataset_producer",
    "example_dataset_consumer",
]


def order(dag_id: str) -> int:
    """Return the sort key for ``dag_id``.

    DAG ids listed in ``PRE_DEFINED_ORDER`` map to their position in that
    list; every other id maps to ``-1`` and therefore sorts ahead of them.
    """
    try:
        return PRE_DEFINED_ORDER.index(dag_id)
    except ValueError:
        # Not a pre-ordered DAG.
        return -1
# Built once at import time: the parametrize decorator below needs the DAG ids
# while the module is being collected, and both tests read this same DagBag.
dag_bag = get_dag_bag()
@pytest.mark.parametrize("dag_id", sorted(dag_bag.dag_ids, key=order))
def test_example_dag(session, dag_id: str):
    """Run one example DAG, looked up by id in the module-level DagBag.

    Parametrized over every DAG id in the bag, sorted with ``order`` as the key.
    """
    wrapper_run_dag(dag_bag.get_dag(dag_id))
def test_example_dags_loaded_with_no_errors():
    """Check that the DagBag holds at least one DAG and recorded no import errors."""
    loaded_dags = dag_bag.dags
    assert loaded_dags
    import_failures = dag_bag.import_errors
    assert not import_failures