Skip to content

Commit

Permalink
REST API: Fix task instance access issue in the batch endpoint (#34315)
Browse files Browse the repository at this point in the history
Currently, the REST API batch endpoint for fetching task instances places no
restriction on which task instances a user can access.
This PR fixes that by restricting the endpoint to DAGs the user is allowed to read.

(cherry picked from commit 3df1af4)
  • Loading branch information
ephraimbuddy committed Oct 5, 2023
1 parent 685c54e commit b328c86
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 2 deletions.
18 changes: 16 additions & 2 deletions airflow/api_connexion/endpoints/task_instance_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

from typing import Any, Iterable, TypeVar

from flask import g
from marshmallow import ValidationError
from sqlalchemy import and_, or_, select
from sqlalchemy.exc import MultipleResultsFound
Expand All @@ -26,7 +27,7 @@

from airflow.api_connexion import security
from airflow.api_connexion.endpoints.request_dict import get_json_request_dict
from airflow.api_connexion.exceptions import BadRequest, NotFound
from airflow.api_connexion.exceptions import BadRequest, NotFound, PermissionDenied
from airflow.api_connexion.parameters import format_datetime, format_parameters
from airflow.api_connexion.schemas.task_instance_schema import (
TaskInstanceCollection,
Expand Down Expand Up @@ -396,10 +397,23 @@ def get_task_instances_batch(session: Session = NEW_SESSION) -> APIResponse:
data = task_instance_batch_form.load(body)
except ValidationError as err:
raise BadRequest(detail=str(err.messages))
dag_ids = data["dag_ids"]
if dag_ids:
cannot_access_dag_ids = set()
for id in dag_ids:
if not get_airflow_app().appbuilder.sm.can_read_dag(id, g.user):
cannot_access_dag_ids.add(id)
if cannot_access_dag_ids:
raise PermissionDenied(
detail=f"User not allowed to access these DAGs: {list(cannot_access_dag_ids)}"
)
else:
dag_ids = get_airflow_app().appbuilder.sm.get_accessible_dag_ids(g.user)

states = _convert_ti_states(data["state"])
base_query = select(TI).join(TI.dag_run)

base_query = _apply_array_filter(base_query, key=TI.dag_id, values=data["dag_ids"])
base_query = _apply_array_filter(base_query, key=TI.dag_id, values=dag_ids)
base_query = _apply_array_filter(base_query, key=TI.run_id, values=data["dag_run_ids"])
base_query = _apply_array_filter(base_query, key=TI.task_id, values=data["task_ids"])
base_query = _apply_range_filter(
Expand Down
39 changes: 39 additions & 0 deletions tests/api_connexion/endpoints/test_task_instance_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import pytest
from sqlalchemy.orm import contains_eager

from airflow.api_connexion.exceptions import EXCEPTIONS_LINK_MAP
from airflow.jobs.job import Job
from airflow.jobs.triggerer_job_runner import TriggererJobRunner
from airflow.models import DagRun, SlaMiss, TaskInstance, Trigger
Expand Down Expand Up @@ -82,6 +83,25 @@ def configured_app(minimal_app_for_api):
(permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_INSTANCE),
],
)
create_user(
app, # type: ignore
username="test_read_only_one_dag",
role_name="TestReadOnlyOneDag",
permissions=[
(permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN),
(permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_INSTANCE),
],
)
# For some reason, "DAG:example_python_operator" is not synced when in the above list of perms,
# so do it manually here:
app.appbuilder.sm.bulk_sync_roles(
[
{
"role": "TestReadOnlyOneDag",
"perms": [(permissions.ACTION_CAN_READ, "DAG:example_python_operator")],
}
]
)
create_user(app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore

yield app
Expand All @@ -90,6 +110,7 @@ def configured_app(minimal_app_for_api):
delete_user(app, username="test_dag_read_only") # type: ignore
delete_user(app, username="test_task_read_only") # type: ignore
delete_user(app, username="test_no_permissions") # type: ignore
delete_user(app, username="test_read_only_one_dag") # type: ignore
delete_roles(app)


Expand Down Expand Up @@ -905,6 +926,24 @@ def test_should_raise_403_forbidden(self):
)
assert response.status_code == 403

# Verify that the batch task-instance endpoint rejects the whole request with 403
# when the payload names at least one DAG the requesting user cannot read.
def test_returns_403_forbidden_when_user_has_access_to_only_some_dags(self, session):
# Create task instances for the helper's default DAG (presumably
# "example_python_operator" -- confirm against create_task_instances) and for
# "example_skip_dag".
self.create_task_instances(session=session)
self.create_task_instances(session=session, dag_id="example_skip_dag")
# Ask for both DAGs in a single batch request.
payload = {"dag_ids": ["example_python_operator", "example_skip_dag"]}

# "test_read_only_one_dag" is set up in the configured_app fixture with DAG read
# permission on "DAG:example_python_operator" only.
response = self.client.post(
"/api/v1/dags/~/dagRuns/~/taskInstances/list",
environ_overrides={"REMOTE_USER": "test_read_only_one_dag"},
json=payload,
)
assert response.status_code == 403
# The error detail must list only the DAG the user is not allowed to access.
assert response.json == {
"detail": "User not allowed to access these DAGs: ['example_skip_dag']",
"status": 403,
"title": "Forbidden",
"type": EXCEPTIONS_LINK_MAP[403],
}

def test_should_raise_400_for_no_json(self):
response = self.client.post(
"/api/v1/dags/~/dagRuns/~/taskInstances/list",
Expand Down

0 comments on commit b328c86

Please sign in to comment.