Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -182,12 +182,23 @@ def get_import_errors(
.cte()
)

# Prepare the import errors query by joining with the CTE above.
# Each returned row will be a tuple: (ParseImportError, dag_id).
# ``dag_id`` is NULL for import errors whose file has no Dags at all
# in ``DagModel`` (parse failed before any Dag was defined).
import_errors_stmt = (
select(ParseImportError, file_dags_cte.c.dag_id)
# Visibility filter: include import errors for files that either have no
# Dags at all (parse failed before any Dag was defined) or have at least
# one Dag that the requesting user is authorised to read.
visibility_condition = or_(
files_with_any_dags.c.relative_fileloc.is_(None),
file_dags_cte.c.dag_id.isnot(None),
)

# Deduplicated base statement: one row per distinct ParseImportError.
#
# When a single file contains multiple Dags the join with file_dags_cte
# produces N rows for that import error (one per Dag). Selecting only
# ParseImportError with DISTINCT collapses those N rows back to one so
# that total_entries reflects the number of *import-error objects* and
# limit/offset paginate over import-error objects rather than joined rows.
dedup_stmt = (
select(ParseImportError)
.outerjoin(
files_with_any_dags,
ParseImportError.filename == files_with_any_dags.c.relative_fileloc,
Expand All @@ -199,26 +210,44 @@ def get_import_errors(
ParseImportError.bundle_name == file_dags_cte.c.bundle_name,
),
)
.where(
or_(
files_with_any_dags.c.relative_fileloc.is_(None),
file_dags_cte.c.dag_id.isnot(None),
)
)
.order_by(ParseImportError.id)
.where(visibility_condition)
.distinct()
)

# Paginate the import errors query
import_errors_select, total_entries = paginated_select(
statement=import_errors_stmt,
# Paginate distinct import errors. total_entries now counts import-error
# objects, and limit/offset operate on those objects rather than joined rows.
paginated_stmt, total_entries = paginated_select(
statement=dedup_stmt,
filters=[filename_pattern, filename_prefix_pattern],
order_by=order_by,
offset=offset,
limit=limit,
session=session,
)
paginated_errors = list(session.scalars(paginated_stmt))
paginated_ids = [err.id for err in paginated_errors]

if not paginated_ids:
return ImportErrorCollectionResponse(import_errors=[], total_entries=total_entries)

# Fetch all Dag associations for the paginated import errors. The outer
# join with file_dags_cte is still needed so per-file authorisation
# (detecting co-located Dags the caller cannot read) and stacktrace
# redaction work correctly for each import-error object on this page.
import_errors_stmt = (
select(ParseImportError, file_dags_cte.c.dag_id)
.outerjoin(
file_dags_cte,
and_(
ParseImportError.filename == file_dags_cte.c.relative_fileloc,
ParseImportError.bundle_name == file_dags_cte.c.bundle_name,
),
)
.where(ParseImportError.id.in_(paginated_ids))
.order_by(ParseImportError.id)
)
import_errors_result: Iterable[tuple[ParseImportError, Iterable]] = groupby(
session.execute(import_errors_select), itemgetter(0)
session.execute(import_errors_stmt), itemgetter(0)
)

import_errors = []
Expand Down Expand Up @@ -249,6 +278,12 @@ def get_import_errors(
import_error.stacktrace = REDACTED_STACKTRACE
import_errors.append(import_error)

# Restore the pagination order from the dedup query. The association query
# above is ordered by id so that groupby sees each import error's rows
# consecutively; re-sort here to match the caller-requested ordering that
# was applied to the dedup query.
id_order = {err_id: idx for idx, err_id in enumerate(paginated_ids)}
import_errors.sort(key=lambda err: id_order[err.id])

return ImportErrorCollectionResponse(
import_errors=import_errors,
total_entries=total_entries,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -374,7 +374,7 @@ def test_get_import_errors(
set_mock_auth_manager__get_authorized_dag_ids(mock_get_auth_manager, permitted_dag_model_all)
set_mock_auth_manager__batch_is_authorized_dag(mock_get_auth_manager, True)

with assert_queries_count(5):
with assert_queries_count(6):
response = test_client.get("/importErrors", params=query_params)

assert response.status_code == expected_status_code
Expand Down Expand Up @@ -527,6 +527,88 @@ def test_get_import_errors__no_dag_in_dagmodel(self, mock_get_auth_manager, test
FILENAME3: STACKTRACE3,
}

@pytest.fixture
@provide_session
def import_error_with_multiple_dags(
    self,
    testing_dag_bundle,
    *,
    session: Session = NEW_SESSION,
) -> tuple[ParseImportError, set[str]]:
    """Persist a single import error whose file is shared by three Dags.

    Returns the stored ``ParseImportError`` together with the set of Dag ids
    registered for the same file, so tests can check that ``total_entries``
    and pagination count import-error objects rather than joined rows.
    """
    shared_file = "multi_dag_file.py"
    dag_ids = {"dag_a", "dag_b", "dag_c"}

    parse_error = ParseImportError(
        bundle_name=BUNDLE_NAME,
        filename=shared_file,
        stacktrace="SyntaxError in multi_dag_file",
        timestamp=TIMESTAMP1,
    )
    session.add(parse_error)

    # One DagModel row per Dag id, all pointing at the file that failed to parse.
    session.add_all(
        DagModel(
            fileloc=shared_file,
            relative_fileloc=shared_file,
            dag_id=dag_id,
            is_paused=False,
            bundle_name=BUNDLE_NAME,
        )
        for dag_id in dag_ids
    )
    session.commit()
    return parse_error, dag_ids

@mock.patch("airflow.api_fastapi.core_api.routes.public.import_error.get_auth_manager")
def test_total_entries_counts_distinct_import_errors_when_file_has_multiple_dags(
    self,
    mock_get_auth_manager,
    test_client,
    import_error_with_multiple_dags,
):
    """A file with several Dags must still yield exactly one import error.

    The fixture stores one ParseImportError whose file maps to three
    DagModel rows. Joining the two tables gives three rows for that single
    record, so this test pins that ``total_entries`` is 1 (not 3) and that
    the single import-error object is returned both without a limit and
    with ``limit=1``.
    """
    parse_error, readable_dag_ids = import_error_with_multiple_dags
    set_mock_auth_manager__get_authorized_dag_ids(mock_get_auth_manager, readable_dag_ids)
    set_mock_auth_manager__batch_is_authorized_dag(mock_get_auth_manager, True)

    # Narrow the listing to this file so the autouse fixture's import errors
    # do not take part in the counts below.
    name_filter = {"filename_pattern": parse_error.filename}

    unlimited = test_client.get("/importErrors", params=name_filter)
    assert unlimited.status_code == 200
    payload = unlimited.json()
    # A single import-error object is expected, not one per joined Dag row.
    assert payload["total_entries"] == 1
    assert len(payload["import_errors"]) == 1
    only_entry = payload["import_errors"][0]
    assert only_entry["filename"] == parse_error.filename
    assert only_entry["stack_trace"] == parse_error.stacktrace

    # The same single object must come back when the page size is one.
    limited = test_client.get("/importErrors", params={**name_filter, "limit": 1})
    assert limited.status_code == 200
    limited_payload = limited.json()
    assert limited_payload["total_entries"] == 1
    assert len(limited_payload["import_errors"]) == 1


class TestImportErrorFileAuthorization:
"""Tests that the import error endpoints apply per-file authorization
Expand Down