Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Return stacktrace from the DAG file in test_should_not_do_database_queries #39331

Merged
merged 1 commit into from
Apr 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion tests/always/test_example_dags.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def test_should_be_importable(example):
@pytest.mark.db_test
@pytest.mark.parametrize("example", example_dags_except_db_exception(), ids=relative_path)
def test_should_not_do_database_queries(example):
with assert_queries_count(0):
with assert_queries_count(0, stacklevel_from_module=example.rsplit(os.sep, 1)[-1]):
DagBag(
dag_folder=example,
include_examples=False,
Expand Down
2 changes: 2 additions & 0 deletions tests/deprecations_ignore.yml
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,9 @@
- tests/always/test_connection.py::TestConnection::test_connection_get_uri_from_uri
- tests/always/test_connection.py::TestConnection::test_connection_test_success
- tests/always/test_connection.py::TestConnection::test_from_json_extra
# `test_should_be_importable` and `test_should_not_do_database_queries` should be resolved together
- tests/always/test_example_dags.py::test_should_be_importable
- tests/always/test_example_dags.py::test_should_not_do_database_queries


# API
Expand Down
78 changes: 65 additions & 13 deletions tests/test_utils/asserts.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,12 @@
from __future__ import annotations

import logging
import os
import re
import traceback
from collections import Counter
from contextlib import contextmanager
from typing import NamedTuple

from sqlalchemy import event

Expand All @@ -40,6 +42,47 @@ def _trim(s):
assert first_trim == second_trim, msg


class QueriesTraceRecord(NamedTuple):
module: str
name: str
lineno: int | None

@classmethod
def from_frame(cls, frame_summary: traceback.FrameSummary):
return cls(
module=frame_summary.filename.rsplit(os.sep, 1)[-1],
name=frame_summary.name,
lineno=frame_summary.lineno,
)

def __str__(self):
return f"{self.module}:{self.name}:{self.lineno}"


class QueriesTraceInfo(NamedTuple):
traces: tuple[QueriesTraceRecord, ...]

@classmethod
def from_traceback(cls, trace: traceback.StackSummary) -> QueriesTraceInfo:
records = [
QueriesTraceRecord.from_frame(f)
for f in trace
if "sqlalchemy" not in f.filename
and __file__ != f.filename
and ("session.py" not in f.filename and f.name != "wrapper")
]
return cls(traces=tuple(records))

def module_level(self, module: str) -> int:
stacklevel = 0
for ix, record in enumerate(reversed(self.traces), start=1):
if record.module == module:
stacklevel = ix
if stacklevel == 0:
raise LookupError(f"Unable to find module {stacklevel} in traceback")
return stacklevel


class CountQueries:
"""
Counts the number of queries sent to Airflow Database in a given context.
Expand All @@ -48,8 +91,10 @@ class CountQueries:
not be included.
"""

def __init__(self):
self.result = Counter()
def __init__(self, *, stacklevel: int = 1, stacklevel_from_module: str | None = None):
self.result: Counter[str] = Counter()
self.stacklevel = stacklevel
self.stacklevel_from_module = stacklevel_from_module

def __enter__(self):
event.listen(airflow.settings.engine, "after_cursor_execute", self.after_cursor_execute)
Expand All @@ -60,31 +105,38 @@ def __exit__(self, type_, value, tb):
log.debug("Queries count: %d", sum(self.result.values()))

def after_cursor_execute(self, *args, **kwargs):
stack = [
f
for f in traceback.extract_stack()
if "sqlalchemy" not in f.filename
and __file__ != f.filename
and ("session.py" not in f.filename and f.name != "wrapper")
]
stack_info = ">".join([f"{f.filename.rpartition('/')[-1]}:{f.name}:{f.lineno}" for f in stack][-5:])
self.result[f"{stack_info}"] += 1
stack = QueriesTraceInfo.from_traceback(traceback.extract_stack())
if not self.stacklevel_from_module:
stacklevel = self.stacklevel
else:
stacklevel = stack.module_level(self.stacklevel_from_module)

stack_info = " > ".join(map(str, stack.traces[-stacklevel:]))
self.result[stack_info] += 1


count_queries = CountQueries


@contextmanager
def assert_queries_count(expected_count: int, message_fmt: str | None = None, margin: int = 0):
def assert_queries_count(
expected_count: int,
message_fmt: str | None = None,
margin: int = 0,
stacklevel: int = 5,
stacklevel_from_module: str | None = None,
):
"""
Asserts that the number of queries is as expected with the margin applied
The margin is helpful in case of complex cases where we do not want to change it every time we
changed queries, but we want to catch cases where we spin out of control
:param expected_count: expected number of queries
:param message_fmt: message printed optionally if the number is exceeded
:param margin: margin to add to expected number of calls
:param stacklevel: limits the output stack trace to that numbers of frame
:param stacklevel_from_module: Filter stack trace from specific module
"""
with count_queries() as result:
with count_queries(stacklevel=stacklevel, stacklevel_from_module=stacklevel_from_module) as result:
yield None

count = sum(result.values())
Expand Down