Skip to content

Commit

Permalink
Use batch_is_authorized_dag to check if user has permission to read…
Browse files Browse the repository at this point in the history
… DAGs (#36279)
  • Loading branch information
vincbeck committed Dec 18, 2023
1 parent 3297806 commit a7ab64e
Showing 1 changed file with 14 additions and 5 deletions.
19 changes: 14 additions & 5 deletions airflow/api_connexion/endpoints/dag_source_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,23 +17,25 @@
from __future__ import annotations

from http import HTTPStatus
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Sequence

from flask import Response, current_app, request
from itsdangerous import BadSignature, URLSafeSerializer

from airflow.api_connexion import security
from airflow.api_connexion.exceptions import NotFound, PermissionDenied
from airflow.api_connexion.schemas.dag_source_schema import dag_source_schema
from airflow.api_connexion.security import get_readable_dags
from airflow.auth.managers.models.resource_details import DagAccessEntity
from airflow.auth.managers.models.resource_details import DagAccessEntity, DagDetails
from airflow.models.dag import DagModel
from airflow.models.dagcode import DagCode
from airflow.utils.session import NEW_SESSION, provide_session
from airflow.www.extensions.init_auth_manager import get_auth_manager

if TYPE_CHECKING:
from sqlalchemy.orm import Session

from airflow.auth.managers.models.batch_apis import IsAuthorizedDagRequest


@security.requires_access_dag("GET", DagAccessEntity.CODE)
@provide_session
Expand All @@ -44,9 +46,16 @@ def get_dag_source(*, file_token: str, session: Session = NEW_SESSION) -> Respon
try:
path = auth_s.loads(file_token)
dag_ids = session.query(DagModel.dag_id).filter(DagModel.fileloc == path).all()
readable_dags = get_readable_dags()
requests: Sequence[IsAuthorizedDagRequest] = [
{
"method": "GET",
"details": DagDetails(id=dag_id[0]),
}
for dag_id in dag_ids
]

# Check if user has read access to all the DAGs defined in the file
if any(dag_id[0] not in readable_dags for dag_id in dag_ids):
if not get_auth_manager().batch_is_authorized_dag(requests):
raise PermissionDenied()
dag_source = DagCode.code(path, session=session)
except (BadSignature, FileNotFoundError):
Expand Down

0 comments on commit a7ab64e

Please sign in to comment.