Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from __future__ import annotations

from collections.abc import Sequence
from typing import Literal

import structlog
from fastapi import HTTPException, Query, status
Expand All @@ -27,6 +28,8 @@
from sqlalchemy.orm import joinedload
from sqlalchemy.orm.session import Session

from airflow.api_fastapi.app import get_auth_manager
from airflow.api_fastapi.auth.managers.models.resource_details import DagAccessEntity, DagDetails
from airflow.api_fastapi.common.dagbag import DagBagDep, get_latest_version_of_dag
from airflow.api_fastapi.common.db.common import SessionDep
from airflow.api_fastapi.core_api.datamodels.common import (
Expand All @@ -45,6 +48,7 @@
from airflow.api_fastapi.core_api.security import GetUserDep
from airflow.api_fastapi.core_api.services.public.common import BulkService
from airflow.listeners.listener import get_listener_manager
from airflow.models.dag import DagModel
from airflow.models.taskinstance import TaskInstance as TI
from airflow.serialization.definitions.dag import SerializedDAG
from airflow.utils.state import TaskInstanceState
Expand Down Expand Up @@ -201,6 +205,8 @@ def _categorize_entities(
self,
entities: Sequence[str | BulkTaskInstanceBody],
results: BulkActionResponse,
method: Literal["PUT", "DELETE"],
action_name: str,
) -> tuple[set[tuple[str, str, str, int]], set[tuple[str, str, str]]]:
"""
Validate entities and categorize them into specific and all map index update sets.
Expand All @@ -211,6 +217,7 @@ def _categorize_entities(
"""
specific_map_index_task_keys = set()
all_map_index_task_keys = set()
dag_authorization_cache: dict[str, bool] = {}

for entity in entities:
dag_id, dag_run_id, task_id, map_index = self._extract_task_identifiers(entity)
Expand All @@ -229,6 +236,23 @@ def _categorize_entities(
)
continue

if dag_id not in dag_authorization_cache:
team_name = DagModel.get_team_name(dag_id, session=self.session)
dag_authorization_cache[dag_id] = get_auth_manager().is_authorized_dag(
method=method,
access_entity=DagAccessEntity.TASK_INSTANCE,
details=DagDetails(id=dag_id, team_name=team_name),
user=self.user,
)
if not dag_authorization_cache[dag_id]:
results.errors.append(
{
"error": f"User is not authorized to {action_name} task instances for DAG '{dag_id}'",
"status_code": status.HTTP_403_FORBIDDEN,
}
)
continue

# Separate logic for "update all" vs "update specific"
if map_index is not None:
specific_map_index_task_keys.add((dag_id, dag_run_id, task_id, map_index))
Expand Down Expand Up @@ -318,7 +342,7 @@ def handle_bulk_update(
"""Bulk Update Task Instances."""
# Validate and categorize entities into specific and all map index update sets
update_specific_map_index_task_keys, update_all_map_index_task_keys = self._categorize_entities(
action.entities, results
action.entities, results, method="PUT", action_name=action.action
)

try:
Expand Down Expand Up @@ -420,7 +444,7 @@ def handle_bulk_delete(
"""Bulk delete task instances."""
# Validate and categorize entities into specific and all map index delete sets
delete_specific_map_index_task_keys, delete_all_map_index_task_keys = self._categorize_entities(
action.entities, results
action.entities, results, method="DELETE", action_name=action.action
)

try:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,18 +26,22 @@

import pendulum
import pytest
from fastapi.testclient import TestClient
from sqlalchemy import delete, func, select, update

from airflow._shared.timezones.timezone import datetime
from airflow.api_fastapi.auth.managers.simple.user import SimpleAuthManagerUser
from airflow.dag_processing.bundles.manager import DagBundlesManager
from airflow.dag_processing.dagbag import DagBag, sync_bag_to_db
from airflow.jobs.job import Job
from airflow.jobs.triggerer_job_runner import TriggererJobRunner
from airflow.models import DagRun, Log, TaskInstance
from airflow.models import DagModel, DagRun, Log, TaskInstance
from airflow.models.dag_version import DagVersion
from airflow.models.dagbundle import DagBundleModel
from airflow.models.renderedtifields import RenderedTaskInstanceFields as RTIF
from airflow.models.taskinstancehistory import TaskInstanceHistory
from airflow.models.taskmap import TaskMap
from airflow.models.team import Team
from airflow.models.trigger import Trigger
from airflow.providers.standard.operators.empty import EmptyOperator
from airflow.sdk import BaseOperator, TaskGroup
Expand All @@ -50,6 +54,7 @@
from tests_common.test_utils.config import conf_vars
from tests_common.test_utils.db import (
clear_db_runs,
clear_db_teams,
clear_rendered_ti_fields,
)
from tests_common.test_utils.logs import check_last_log
Expand Down Expand Up @@ -5502,6 +5507,14 @@ class TestBulkTaskInstances(TestTaskInstanceEndpoint):
BASH_TASK_ID = "also_run_this"
WILDCARD_ENDPOINT = "/dags/~/dagRuns/~/taskInstances"

@pytest.fixture(autouse=True)
def clean_db(self, session):
    """Call the ``clear_db_runs`` and ``clear_db_teams`` helpers around every test of this class."""
    resetters = (clear_db_runs, clear_db_teams)
    # Setup: runs first, then teams.
    for reset in resetters:
        reset()
    yield
    # Teardown mirrors the setup in reverse order: teams first, then runs.
    for reset in reversed(resetters):
        reset()

@pytest.mark.parametrize(
("default_ti", "actions", "expected_results", "endpoint_url", "setup_dags"),
[
Expand Down Expand Up @@ -6069,10 +6082,24 @@ def test_bulk_task_instances(
):
# Setup task instances
if setup_dags:
for dag_id in setup_dags:
if setup_dags == [self.BASH_DAG_ID, self.DAG_ID]:
self.create_task_instances(
session,
task_instances=[{"task_id": self.BASH_TASK_ID, "state": default_ti[0]["state"]}],
dag_id=self.BASH_DAG_ID,
update_extras=True,
)
self.create_task_instances(
session, task_instances=default_ti, dag_id=dag_id, update_extras=True
session,
task_instances=[{"task_id": self.TASK_ID, "state": default_ti[1]["state"]}],
dag_id=self.DAG_ID,
update_extras=True,
)
else:
for dag_id in setup_dags:
self.create_task_instances(
session, task_instances=default_ti, dag_id=dag_id, update_extras=True
)
else:
self.create_task_instances(session, task_instances=default_ti)

Expand Down Expand Up @@ -6141,6 +6168,154 @@ def test_bulk_update_mapped_task_instance_state_is_persisted(
f"Expected map_index={mi} to remain running, got {ti.state!r}"
)

def test_bulk_task_instances_rejects_unauthorized_dag_ids_from_request_body(self, test_client, session):
    """
    Bulk update reports a 403 error for a DAG named in the request body that the caller cannot edit.

    ``BASH_DAG_ID`` is moved into a bundle linked to a team, and the request is sent by a
    ``SimpleAuthManagerUser`` that belongs to no team. The response is expected to be a 200 whose
    ``update`` section lists the ``DAG_ID`` task instance as a success and a 403 entry for
    ``BASH_DAG_ID``.
    """
    bundle_name = "restricted-bundle-update"
    team_name = "restricted-team-update"

    # One running task instance per DAG: the soon-to-be restricted one first, then the open one.
    targets = ((self.BASH_DAG_ID, self.BASH_TASK_ID), (self.DAG_ID, self.TASK_ID))
    for dag_id, task_id in targets:
        self.create_task_instances(
            session,
            task_instances=[{"task_id": task_id, "state": State.RUNNING}],
            dag_id=dag_id,
            update_extras=True,
        )

    # Attach a team to a new bundle and move the bash DAG into that bundle.
    bundle = DagBundleModel(name=bundle_name)
    team = Team(name=team_name)
    bundle.teams.append(team)
    session.add_all([bundle, team])
    session.flush()
    move_dag_to_bundle = (
        update(DagModel).where(DagModel.dag_id == self.BASH_DAG_ID).values(bundle_name=bundle_name)
    )
    session.execute(move_dag_to_bundle)
    session.commit()

    # Sign a token for a user that is a member of no team.
    manager = test_client.app.state.auth_manager
    limited_user = SimpleAuthManagerUser(username="limited-user", role="user", teams=[])
    bearer = manager._get_token_signer().generate(manager.serialize_user(limited_user))

    payload = {
        "actions": [
            {
                "action": "update",
                "entities": [
                    {
                        "dag_id": dag_id,
                        "dag_run_id": self.RUN_ID,
                        "task_id": task_id,
                        "new_state": "success",
                    }
                    for dag_id, task_id in targets
                ],
            }
        ]
    }

    revocation_check = mock.patch(
        "airflow.models.revoked_token.RevokedToken.is_revoked", return_value=False
    )
    with (
        revocation_check,
        TestClient(
            test_client.app,
            headers={"Authorization": f"Bearer {bearer}"},
            base_url=str(test_client.base_url),
        ) as limited_client,
    ):
        response = limited_client.patch(self.WILDCARD_ENDPOINT, json=payload)

    assert response.status_code == 200
    outcome = response.json()["update"]
    assert outcome["success"] == [f"{self.DAG_ID}.{self.RUN_ID}.{self.TASK_ID}[-1]"]
    assert outcome["errors"] == [
        {
            "error": f"User is not authorized to update task instances for DAG '{self.BASH_DAG_ID}'",
            "status_code": 403,
        }
    ]

def test_bulk_delete_rejects_unauthorized_dag_ids_from_request_body(self, test_client, session):
    """
    Bulk delete reports a 403 error for a DAG named in the request body that the caller cannot delete from.

    ``BASH_DAG_ID`` is moved into a bundle linked to a team, and the request is sent by a
    ``SimpleAuthManagerUser`` that belongs to no team. The response is expected to be a 200 whose
    ``delete`` section lists the ``DAG_ID`` task instance as a success and a 403 entry for
    ``BASH_DAG_ID``.
    """
    bundle_name = "restricted-bundle-delete"
    team_name = "restricted-team-delete"

    # One successful task instance per DAG: the soon-to-be restricted one first, then the open one.
    targets = ((self.BASH_DAG_ID, self.BASH_TASK_ID), (self.DAG_ID, self.TASK_ID))
    for dag_id, task_id in targets:
        self.create_task_instances(
            session,
            task_instances=[{"task_id": task_id, "state": State.SUCCESS}],
            dag_id=dag_id,
            update_extras=True,
        )

    # Attach a team to a new bundle and move the bash DAG into that bundle.
    bundle = DagBundleModel(name=bundle_name)
    team = Team(name=team_name)
    bundle.teams.append(team)
    session.add_all([bundle, team])
    session.flush()
    move_dag_to_bundle = (
        update(DagModel).where(DagModel.dag_id == self.BASH_DAG_ID).values(bundle_name=bundle_name)
    )
    session.execute(move_dag_to_bundle)
    session.commit()

    # Sign a token for a user that is a member of no team.
    manager = test_client.app.state.auth_manager
    limited_user = SimpleAuthManagerUser(username="limited-user", role="user", teams=[])
    bearer = manager._get_token_signer().generate(manager.serialize_user(limited_user))

    payload = {
        "actions": [
            {
                "action": "delete",
                "entities": [
                    {
                        "dag_id": dag_id,
                        "dag_run_id": self.RUN_ID,
                        "task_id": task_id,
                    }
                    for dag_id, task_id in targets
                ],
            }
        ]
    }

    revocation_check = mock.patch(
        "airflow.models.revoked_token.RevokedToken.is_revoked", return_value=False
    )
    with (
        revocation_check,
        TestClient(
            test_client.app,
            headers={"Authorization": f"Bearer {bearer}"},
            base_url=str(test_client.base_url),
        ) as limited_client,
    ):
        response = limited_client.patch(self.WILDCARD_ENDPOINT, json=payload)

    assert response.status_code == 200
    outcome = response.json()["delete"]
    assert outcome["success"] == [f"{self.DAG_ID}.{self.RUN_ID}.{self.TASK_ID}[-1]"]
    assert outcome["errors"] == [
        {
            "error": f"User is not authorized to delete task instances for DAG '{self.BASH_DAG_ID}'",
            "status_code": 403,
        }
    ]

def test_should_respond_401(self, unauthenticated_test_client):
    """A PATCH with an empty JSON body from the unauthenticated client must yield HTTP 401."""
    unauthenticated_reply = unauthenticated_test_client.patch(self.ENDPOINT_URL, json={})
    observed_status = unauthenticated_reply.status_code
    assert observed_status == 401
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,15 @@

from __future__ import annotations

from unittest import mock

import pytest

from airflow.api_fastapi.auth.managers.base_auth_manager import BaseAuthManager
from airflow.api_fastapi.core_api.datamodels.common import BulkActionResponse, BulkBody
from airflow.api_fastapi.core_api.datamodels.task_instances import BulkTaskInstanceBody
from airflow.api_fastapi.core_api.services.public.task_instances import BulkTaskInstanceService
from airflow.models import DagModel
from airflow.providers.standard.operators.bash import BashOperator

from tests_common.test_utils.db import (
Expand Down Expand Up @@ -53,6 +57,10 @@ def teardown_method(self):
self.clear_db()

class MockUser:
username = "test_user"
role = "admin"
teams = ["team1"]

def get_id(self) -> str:
return "test_user"

Expand Down Expand Up @@ -184,6 +192,10 @@ def teardown_method(self):
self.clear_db()

class MockUser:
username = "test_user"
role = "admin"
teams = ["team1"]

def get_id(self) -> str:
return "test_user"

Expand Down Expand Up @@ -260,6 +272,10 @@ def teardown_method(self):
self.clear_db()

class MockUser:
username = "test_user"
role = "admin"
teams = ["team1"]

def get_id(self) -> str:
return "test_user"

Expand Down Expand Up @@ -380,7 +396,6 @@ def test_categorize_entities(
expected_error_count,
):
"""Test _categorize_entities with different entity configurations and wildcard validation."""

user = self.MockUser()
bulk_request = BulkBody(actions=[])
service = BulkTaskInstanceService(
Expand All @@ -393,9 +408,18 @@ def test_categorize_entities(
)

results = BulkActionResponse()
specific_map_index_task_keys, all_map_index_task_keys = service._categorize_entities(
entities, results
)
with (
mock.patch.object(DagModel, "get_team_name", return_value="team1"),
mock.patch(
"airflow.api_fastapi.core_api.services.public.task_instances.get_auth_manager"
) as mock_get_auth_manager,
):
auth_manager = mock.create_autospec(BaseAuthManager, instance=True, spec_set=True)
auth_manager.is_authorized_dag.return_value = True
mock_get_auth_manager.return_value = auth_manager
specific_map_index_task_keys, all_map_index_task_keys = service._categorize_entities(
entities, results, method="PUT", action_name="update"
)

assert specific_map_index_task_keys == expected_specific_keys
assert all_map_index_task_keys == expected_all_keys
Expand Down
Loading