Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 54 additions & 0 deletions tests/models/test_taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,11 @@
from airflow.utils.types import DagRunType
from tests.models import DEFAULT_DATE
from tests.test_utils import db
from tests.test_utils.asserts import assert_queries_count
from tests.test_utils.config import conf_vars
from tests.test_utils.db import (
clear_db_dags, clear_db_errors, clear_db_pools, clear_db_runs, clear_db_sla_miss,
)


class CallbackWrapper:
Expand Down Expand Up @@ -1699,3 +1703,53 @@ def test_refresh_from_task(pool_override):
assert ti.max_tries == task.retries
assert ti.executor_config == task.executor_config
assert ti.operator == DummyOperator.__name__


class TestRunRawTaskQueriesCount(unittest.TestCase):
"""
These tests are designed to detect changes in the number of queries executed
when calling _run_raw_task
"""

@staticmethod
def _clean():
clear_db_runs()
clear_db_pools()
clear_db_dags()
clear_db_sla_miss()
clear_db_errors()

def setUp(self) -> None:
self._clean()

def tearDown(self) -> None:
self._clean()

@parameterized.expand([
# Expected queries, mark_success
(7, False),
(5, True),
])
def test_execute_queries_count(self, expected_query_count, mark_success):
with create_session() as session:
dag = DAG('test_queries', start_date=DEFAULT_DATE)
task = DummyOperator(task_id='op', dag=dag)
ti = TI(task=task, execution_date=datetime.datetime.now())
ti.state = State.RUNNING
session.merge(ti)

with assert_queries_count(expected_query_count):
ti._run_raw_task(mark_success=mark_success)

def test_execute_queries_count_store_serialized(self):
with create_session() as session:
dag = DAG('test_queries', start_date=DEFAULT_DATE)
task = DummyOperator(task_id='op', dag=dag)
ti = TI(task=task, execution_date=datetime.datetime.now())
ti.state = State.RUNNING
session.merge(ti)

with assert_queries_count(10), patch(
"airflow.models.taskinstance.STORE_SERIALIZED_DAGS", True
):
ti._run_raw_task()