From 80a75e619cbdb7401c6f622be51b566b50fa7947 Mon Sep 17 00:00:00 2001 From: Kenneth Ng <59739226+ngken0995@users.noreply.github.com> Date: Thu, 14 Dec 2023 18:27:51 -0500 Subject: [PATCH] Unify DAG creation/database cleaning fixtures for testing (#3361) * implement global clean_db function * remove unused packages * format lint * hardcode TEST_POOL and DAG_PREFIX * add comment and TODO for TEST_POOL and DAG_PREFIX constant * rename fixture for format standards * rename get_test_dag_id to sample_dag_id_fixture --- catalog/tests/conftest.py | 35 +++++++++++++++++++ .../test_single_run_external_dags_sensor.py | 25 +++++-------- .../tests/dags/common/sensors/test_utils.py | 28 ++++++--------- .../dags/common/test_ingestion_server.py | 23 +++--------- .../providers/test_provider_dag_factory.py | 18 ---------- 5 files changed, 58 insertions(+), 71 deletions(-) diff --git a/catalog/tests/conftest.py b/catalog/tests/conftest.py index 72f866cd19..43bb285261 100644 --- a/catalog/tests/conftest.py +++ b/catalog/tests/conftest.py @@ -1,4 +1,6 @@ import pytest +from airflow.models import DagRun, Pool, TaskInstance +from airflow.utils.session import create_session def pytest_addoption(parser): @@ -24,3 +26,36 @@ def pytest_addoption(parser): # Use this decorator on tests which are expected to take a long time and would best be # run on CI only mark_extended = pytest.mark.skipif("not config.getoption('extended')") + + +def _normalize_test_module_name(request) -> str: + # Extract the test name + name = request.module.__name__ + # Replace periods with two underscores + return name.replace(".", "__") + + +@pytest.fixture +def sample_dag_id_fixture(request): + return f"{_normalize_test_module_name(request)}_dag" + + +@pytest.fixture +def sample_pool_fixture(request): + return f"{_normalize_test_module_name(request)}_pool" + + +@pytest.fixture +def clean_db(sample_dag_id_fixture, sample_pool_fixture): + with create_session() as session: + # synchronize_session='fetch' required here to refresh models + # https://stackoverflow.com/a/51222378 CC BY-SA 4.0 + session.query(DagRun).filter( + DagRun.dag_id.startswith(sample_dag_id_fixture) + ).delete(synchronize_session="fetch") + session.query(TaskInstance).filter( + TaskInstance.dag_id.startswith(sample_dag_id_fixture) + ).delete(synchronize_session="fetch") + session.query(Pool).filter(Pool.pool.startswith(sample_pool_fixture)).delete( + synchronize_session="fetch" + ) diff --git a/catalog/tests/dags/common/sensors/test_single_run_external_dags_sensor.py b/catalog/tests/dags/common/sensors/test_single_run_external_dags_sensor.py index 9d000697d5..7f8867b9d3 100644 --- a/catalog/tests/dags/common/sensors/test_single_run_external_dags_sensor.py +++ b/catalog/tests/dags/common/sensors/test_single_run_external_dags_sensor.py @@ -3,9 +3,8 @@ import pytest from airflow.exceptions import AirflowException -from airflow.models import DagBag, DagRun, Pool, TaskInstance +from airflow.models import DagBag, Pool from airflow.models.dag import DAG -from airflow.utils.session import create_session from airflow.utils.state import State from airflow.utils.timezone import datetime from airflow.utils.types import DagRunType @@ -15,23 +14,14 @@ DEFAULT_DATE = datetime(2022, 1, 1) TEST_TASK_ID = "wait_task" -TEST_POOL = "single_run_external_dags_sensor_test_pool" DEV_NULL = "/dev/null" -DAG_PREFIX = "sreds" # single_run_external_dags_sensor - -@pytest.fixture(autouse=True) -def clean_db(): - with create_session() as session: - # synchronize_session='fetch' required here to refresh models - # https://stackoverflow.com/a/51222378 CC BY-SA 4.0 - session.query(DagRun).filter(DagRun.dag_id.startswith(DAG_PREFIX)).delete( - synchronize_session="fetch" - ) - session.query(TaskInstance).filter( - TaskInstance.dag_id.startswith(DAG_PREFIX) - ).delete(synchronize_session="fetch") - session.query(Pool).filter(id == TEST_POOL).delete() +# unittest.TestCase only allow auto-use fixture which can't retrieve the declared fixtures on conftest.py +# TODO: TEST_POOL/DAG_PREFIX constants can be remove after unittest.TestCase are converted to pytest. +TEST_POOL = ( + "catalog__tests__dags__common__sensors__test_single_run_external_dags_sensor_pool" +) +DAG_PREFIX = "catalog__tests__dags__common__sensors__test_single_run_external_dags_sensor_dag" # single_run_external_dags_sensor def run_sensor(sensor): @@ -75,6 +65,7 @@ def create_dagrun(dag, dag_state): ) +@pytest.mark.usefixtures("clean_db") # This appears to be coming from Airflow internals during testing as a result of # loading the example DAGs: # /opt/airflow/.local/lib/python3.10/site-packages/airflow/example_dags/example_subdag_operator.py:43: RemovedInAirflow3Warning # noqa: E501 diff --git a/catalog/tests/dags/common/sensors/test_utils.py b/catalog/tests/dags/common/sensors/test_utils.py index 9cd7212392..67e599ddf2 100644 --- a/catalog/tests/dags/common/sensors/test_utils.py +++ b/catalog/tests/dags/common/sensors/test_utils.py @@ -1,9 +1,6 @@ from datetime import timedelta -import pytest -from airflow.models import DagRun from airflow.models.dag import DAG -from airflow.utils.session import create_session from airflow.utils.state import State from airflow.utils.timezone import datetime from airflow.utils.types import DagRunType @@ -15,14 +12,8 @@ TEST_DAG = DAG(TEST_DAG_ID, default_args={"owner": "airflow"}) -@pytest.fixture(autouse=True) -def clean_db(): - with create_session() as session: - session.query(DagRun).filter(DagRun.dag_id == TEST_DAG_ID).delete() - - -def _create_dagrun(start_date, conf={}): - return TEST_DAG.create_dagrun( +def _create_dagrun(start_date, sample_dag_id_fixture, conf={}): + return DAG(sample_dag_id_fixture, default_args={"owner": "airflow"}).create_dagrun( start_date=start_date, execution_date=start_date, data_interval=(start_date, start_date), @@ -32,14 +23,17 @@ def _create_dagrun(start_date, conf={}): ) -def test_get_most_recent_dag_run_returns_most_recent_execution_date(): +def test_get_most_recent_dag_run_returns_most_recent_execution_date( + sample_dag_id_fixture, clean_db +): most_recent = datetime(2023, 5, 10) for i in range(3): - _create_dagrun(most_recent - timedelta(days=i)) - - assert get_most_recent_dag_run(TEST_DAG_ID) == most_recent + _create_dagrun(most_recent - timedelta(days=i), sample_dag_id_fixture) + assert get_most_recent_dag_run(sample_dag_id_fixture) == most_recent -def test_get_most_recent_dag_run_returns_empty_list_when_no_runs(): +def test_get_most_recent_dag_run_returns_empty_list_when_no_runs( + sample_dag_id_fixture, clean_db +): # Relies on ``clean_db`` cleaning up DagRuns from other tests - assert get_most_recent_dag_run(TEST_DAG_ID) == [] + assert get_most_recent_dag_run(sample_dag_id_fixture) == [] diff --git a/catalog/tests/dags/common/test_ingestion_server.py b/catalog/tests/dags/common/test_ingestion_server.py index 6f78e341dd..1f28e9adf0 100644 --- a/catalog/tests/dags/common/test_ingestion_server.py +++ b/catalog/tests/dags/common/test_ingestion_server.py @@ -4,9 +4,7 @@ import pytest import requests from airflow.exceptions import AirflowSkipException -from airflow.models import DagRun, TaskInstance from airflow.models.dag import DAG -from airflow.utils.session import create_session from airflow.utils.state import DagRunState, TaskInstanceState from airflow.utils.timezone import datetime from airflow.utils.types import DagRunType @@ -15,30 +13,17 @@ TEST_START_DATE = datetime(2022, 2, 1, 0, 0, 0) -TEST_DAG_ID = "api_healthcheck_test_dag" - - -@pytest.fixture(autouse=True) -def clean_db(): - with create_session() as session: - # synchronize_session='fetch' required here to refresh models - # https://stackoverflow.com/a/51222378 CC BY-SA 4.0 - session.query(DagRun).filter(DagRun.dag_id.startswith(TEST_DAG_ID)).delete( - synchronize_session="fetch" - ) - session.query(TaskInstance).filter( - TaskInstance.dag_id.startswith(TEST_DAG_ID) - ).delete(synchronize_session="fetch") @pytest.fixture() -def index_readiness_dag(): +def index_readiness_dag(sample_dag_id_fixture, clean_db): # Create a DAG that just has an index_readiness_check task - with DAG(dag_id=TEST_DAG_ID, schedule=None, start_date=TEST_START_DATE) as dag: + with DAG( + dag_id=sample_dag_id_fixture, schedule=None, start_date=TEST_START_DATE + ) as dag: ingestion_server.index_readiness_check( media_type="image", index_suffix="my_test_suffix", timeout=timedelta(days=1) ) - return dag diff --git a/catalog/tests/dags/providers/test_provider_dag_factory.py b/catalog/tests/dags/providers/test_provider_dag_factory.py index 0210089997..57010728fe 100644 --- a/catalog/tests/dags/providers/test_provider_dag_factory.py +++ b/catalog/tests/dags/providers/test_provider_dag_factory.py @@ -5,9 +5,7 @@ from airflow import DAG from airflow.exceptions import AirflowSkipException, BackfillUnfinished from airflow.executors.debug_executor import DebugExecutor -from airflow.models import DagRun, TaskInstance from airflow.operators.empty import EmptyOperator -from airflow.utils.session import create_session from pendulum import now from catalog.tests.conftest import mark_extended @@ -21,22 +19,6 @@ from providers.provider_workflows import ProviderWorkflow -DAG_ID = "test_provider_dag_factory" - - -def _clean_dag_from_db(): - with create_session() as session: - session.query(DagRun).filter(DagRun.dag_id == DAG_ID).delete() - session.query(TaskInstance).filter(TaskInstance.dag_id == DAG_ID).delete() - - -@pytest.fixture() -def clean_db(): - _clean_dag_from_db() - yield - _clean_dag_from_db() - - @mark_extended @pytest.mark.parametrize( "side_effect",