Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
DagAccessEntity,
ReadableEventLogsFilterDep,
requires_access_dag,
requires_access_event_log,
)
from airflow.models import Log

Expand All @@ -58,7 +59,7 @@
@event_logs_router.get(
"/{event_log_id}",
responses=create_openapi_http_exception_doc([status.HTTP_404_NOT_FOUND]),
dependencies=[Depends(requires_access_dag("GET", DagAccessEntity.AUDIT_LOG))],
dependencies=[Depends(requires_access_event_log("GET"))],
)
def get_event_log(
event_log_id: int,
Expand Down
29 changes: 29 additions & 0 deletions airflow-core/src/airflow/api_fastapi/core_api/security.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,6 +331,35 @@ async def inner(
return inner


def requires_access_event_log(
method: ResourceMethod,
) -> Callable[[Request, BaseUser, Session], Coroutine[Any, Any, None]]:
"""Wrap ``requires_access_dag`` and extract the dag_id from the event_log_id."""

async def inner(
request: Request,
user: GetUserDep,
session: SessionDep,
) -> None:
dag_id = None

event_log_id_raw = request.path_params.get("event_log_id")
try:
event_log_id = int(event_log_id_raw) if event_log_id_raw is not None else None
except ValueError:
event_log_id = None

if event_log_id is not None:
dag_id = session.scalar(select(Log.dag_id).where(Log.id == event_log_id))

requires_access_dag(method, DagAccessEntity.AUDIT_LOG, dag_id)(
request,
user,
)

return inner


class PermittedPoolFilter(OrmClause[set[str]]):
"""A parameter that filters the permitted pools for the user."""

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,11 @@
from __future__ import annotations

from datetime import datetime, timezone
from unittest import mock

import pytest

from airflow.api_fastapi.auth.managers.models.resource_details import DagAccessEntity, DagDetails
from airflow.models.log import Log
from airflow.utils.session import provide_session

Expand Down Expand Up @@ -196,6 +198,40 @@ def test_should_raises_403_forbidden(self, unauthorized_test_client, setup):
response = unauthorized_test_client.get(f"/eventLogs/{event_log_id}")
assert response.status_code == 403

def test_should_respond_403_when_user_lacks_dag_audit_log_permission(self, test_client, setup):
"""The detail endpoint must enforce the per-DAG audit log permission of the event log's dag_id."""
event_log_id = setup[TASK_INSTANCE_EVENT].id
with mock.patch(
"airflow.api_fastapi.auth.managers.simple.simple_auth_manager.SimpleAuthManager.is_authorized_dag",
return_value=False,
) as mock_is_authorized_dag:
response = test_client.get(f"/eventLogs/{event_log_id}")

assert response.status_code == 403
mock_is_authorized_dag.assert_called_once_with(
method="GET",
access_entity=DagAccessEntity.AUDIT_LOG,
details=DagDetails(id=DAG_ID, team_name=None),
user=mock.ANY,
)

def test_should_authorize_with_event_log_dag_id(self, test_client, setup):
"""When the event log is bound to a DAG, authorization must scope to that DAG id."""
event_log_id = setup[TASK_INSTANCE_EVENT].id
with mock.patch(
"airflow.api_fastapi.auth.managers.simple.simple_auth_manager.SimpleAuthManager.is_authorized_dag",
return_value=True,
) as mock_is_authorized_dag:
response = test_client.get(f"/eventLogs/{event_log_id}")

assert response.status_code == 200
mock_is_authorized_dag.assert_called_once_with(
method="GET",
access_entity=DagAccessEntity.AUDIT_LOG,
details=DagDetails(id=DAG_ID, team_name=None),
user=mock.ANY,
)


class TestGetEventLogs(TestEventLogsEndpoint):
@pytest.mark.parametrize(
Expand Down
89 changes: 89 additions & 0 deletions airflow-core/tests/unit/api_fastapi/core_api/test_security.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
requires_access_connection,
requires_access_connection_bulk,
requires_access_dag,
requires_access_event_log,
requires_access_pool,
requires_access_pool_bulk,
requires_access_variable,
Expand Down Expand Up @@ -384,6 +385,94 @@ async def test_requires_access_backfill_backfill_not_found_falls_back_to_body(
user=user,
)

@pytest.mark.db_test
@pytest.mark.asyncio
@patch.object(DagModel, "get_team_name")
@patch("airflow.api_fastapi.core_api.security.get_auth_manager")
async def test_requires_access_event_log_authorized_from_path(
self, mock_get_auth_manager, mock_get_team_name
):
"""When event_log_id is in path and the Log exists, dag_id from the row is used."""
auth_manager = Mock()
auth_manager.is_authorized_dag.return_value = True
mock_get_auth_manager.return_value = auth_manager
mock_get_team_name.return_value = "team1"

session = Mock()
session.scalar.return_value = "event_log_dag_id"

request = Mock()
request.path_params = {"event_log_id": "42"}
user = Mock()

inner = requires_access_event_log("GET")
await inner(request, user, session)

auth_manager.is_authorized_dag.assert_called_once_with(
method="GET",
access_entity=DagAccessEntity.AUDIT_LOG,
details=DagDetails(id="event_log_dag_id", team_name="team1"),
user=user,
)

@pytest.mark.db_test
@pytest.mark.asyncio
@patch.object(DagModel, "get_team_name")
@patch("airflow.api_fastapi.core_api.security.get_auth_manager")
async def test_requires_access_event_log_unauthorized(self, mock_get_auth_manager, mock_get_team_name):
"""When is_authorized_dag returns False for the event log's dag_id, Forbidden is raised."""
auth_manager = Mock()
auth_manager.is_authorized_dag.return_value = False
mock_get_auth_manager.return_value = auth_manager
mock_get_team_name.return_value = None

session = Mock()
session.scalar.return_value = "unauthorized_dag"

request = Mock()
request.path_params = {"event_log_id": "1"}
user = Mock()

inner = requires_access_event_log("GET")
with pytest.raises(HTTPException, match="Forbidden"):
await inner(request, user, session)

auth_manager.is_authorized_dag.assert_called_once_with(
method="GET",
access_entity=DagAccessEntity.AUDIT_LOG,
details=DagDetails(id="unauthorized_dag", team_name=None),
user=user,
)

@pytest.mark.db_test
@pytest.mark.asyncio
@patch.object(DagModel, "get_team_name")
@patch("airflow.api_fastapi.core_api.security.get_auth_manager")
async def test_requires_access_event_log_row_not_found(self, mock_get_auth_manager, mock_get_team_name):
"""When the Log row does not exist, dag_id is None and the generic AUDIT_LOG check applies."""
auth_manager = Mock()
auth_manager.is_authorized_dag.return_value = True
mock_get_auth_manager.return_value = auth_manager

session = Mock()
session.scalar.return_value = None

request = Mock()
request.path_params = {"event_log_id": "999"}
request.query_params = {}
user = Mock()

inner = requires_access_event_log("GET")
await inner(request, user, session)

auth_manager.is_authorized_dag.assert_called_once_with(
method="GET",
access_entity=DagAccessEntity.AUDIT_LOG,
details=DagDetails(id=None, team_name=None),
user=user,
)
mock_get_team_name.assert_not_called()

@pytest.mark.parametrize(
("url", "expected_is_safe"),
[
Expand Down
Loading