From d843ac1aa9a6b78fde045e2014d16c05f565e75a Mon Sep 17 00:00:00 2001 From: Vincent Beck Date: Mon, 25 Sep 2023 11:30:30 -0400 Subject: [PATCH 01/31] Use auth manager APIs `is_authorized_` to check user permissions in Rest API and other places --- .../endpoints/config_endpoint.py | 7 +- .../endpoints/connection_endpoint.py | 12 +- .../api_connexion/endpoints/dag_endpoint.py | 16 +- .../endpoints/dag_run_endpoint.py | 66 ++----- .../endpoints/dag_source_endpoint.py | 4 +- .../endpoints/dag_warning_endpoint.py | 4 +- .../endpoints/dataset_endpoint.py | 9 +- .../endpoints/event_log_endpoint.py | 6 +- .../endpoints/extra_link_endpoint.py | 10 +- .../endpoints/import_error_endpoint.py | 6 +- .../api_connexion/endpoints/log_endpoint.py | 10 +- .../endpoints/plugin_endpoint.py | 3 +- .../api_connexion/endpoints/pool_endpoint.py | 11 +- .../endpoints/provider_endpoint.py | 3 +- .../api_connexion/endpoints/task_endpoint.py | 16 +- .../endpoints/task_instance_endpoint.py | 85 ++------- .../endpoints/variable_endpoint.py | 10 +- .../api_connexion/endpoints/xcom_endpoint.py | 20 +- airflow/api_connexion/security.py | 178 +++++++++++++++++- airflow/auth/managers/base_auth_manager.py | 38 +++- airflow/auth/managers/fab/fab_auth_manager.py | 88 ++++++--- .../managers/fab/security_manager/override.py | 71 ++++++- .../auth/managers/models/resource_details.py | 34 +++- airflow/www/auth.py | 12 +- airflow/www/extensions/init_jinja_globals.py | 10 +- airflow/www/security_manager.py | 153 +++------------ airflow/www/templates/airflow/dag.html | 4 +- airflow/www/views.py | 40 ++-- .../endpoints/test_dag_endpoint.py | 10 +- .../endpoints/test_log_endpoint.py | 4 +- .../managers/fab/test_fab_auth_manager.py | 2 +- tests/auth/managers/test_base_auth_manager.py | 37 +++- tests/www/test_security.py | 80 +++++--- tests/www/views/conftest.py | 1 + 34 files changed, 606 insertions(+), 454 deletions(-) diff --git a/airflow/api_connexion/endpoints/config_endpoint.py 
b/airflow/api_connexion/endpoints/config_endpoint.py index 38f6f32c22b2d..9ffedc465f33e 100644 --- a/airflow/api_connexion/endpoints/config_endpoint.py +++ b/airflow/api_connexion/endpoints/config_endpoint.py @@ -24,7 +24,6 @@ from airflow.api_connexion.exceptions import NotFound, PermissionDenied from airflow.api_connexion.schemas.config_schema import Config, ConfigOption, ConfigSection, config_schema from airflow.configuration import conf -from airflow.security import permissions from airflow.settings import json LINE_SEP = "\n" # `\n` cannot appear in f-strings @@ -66,7 +65,7 @@ def _config_to_json(config: Config) -> str: return json.dumps(config_schema.dump(config), indent=4) -@security.requires_access([(permissions.ACTION_CAN_READ, permissions.RESOURCE_CONFIG)]) +@security.requires_access_configuration("GET") def get_config(*, section: str | None = None) -> Response: """Get current configuration.""" serializer = { @@ -103,8 +102,8 @@ def get_config(*, section: str | None = None) -> Response: ) -@security.requires_access([(permissions.ACTION_CAN_READ, permissions.RESOURCE_CONFIG)]) -def get_value(section: str, option: str) -> Response: +@security.requires_access_configuration("GET") +def get_value(*, section: str, option: str) -> Response: serializer = { "text/plain": _config_to_text, "application/json": _config_to_json, diff --git a/airflow/api_connexion/endpoints/connection_endpoint.py b/airflow/api_connexion/endpoints/connection_endpoint.py index 1444421a8443c..16d9afb5b9e33 100644 --- a/airflow/api_connexion/endpoints/connection_endpoint.py +++ b/airflow/api_connexion/endpoints/connection_endpoint.py @@ -53,7 +53,7 @@ RESOURCE_EVENT_PREFIX = "connection" -@security.requires_access([(permissions.ACTION_CAN_DELETE, permissions.RESOURCE_CONNECTION)]) +@security.requires_access_connection("DELETE") @provide_session @action_logging( event=action_event_from_permission( @@ -73,7 +73,7 @@ def delete_connection(*, connection_id: str, session: Session = NEW_SESSION) 
-> return NoContent, HTTPStatus.NO_CONTENT -@security.requires_access([(permissions.ACTION_CAN_READ, permissions.RESOURCE_CONNECTION)]) +@security.requires_access_connection("GET") @provide_session def get_connection(*, connection_id: str, session: Session = NEW_SESSION) -> APIResponse: """Get a connection entry.""" @@ -86,7 +86,7 @@ def get_connection(*, connection_id: str, session: Session = NEW_SESSION) -> API return connection_schema.dump(connection) -@security.requires_access([(permissions.ACTION_CAN_READ, permissions.RESOURCE_CONNECTION)]) +@security.requires_access_connection("GET") @format_parameters({"limit": check_limit}) @provide_session def get_connections( @@ -109,7 +109,7 @@ def get_connections( ) -@security.requires_access([(permissions.ACTION_CAN_EDIT, permissions.RESOURCE_CONNECTION)]) +@security.requires_access_connection("PUT") @provide_session @action_logging( event=action_event_from_permission( @@ -147,7 +147,7 @@ def patch_connection( return connection_schema.dump(connection) -@security.requires_access([(permissions.ACTION_CAN_CREATE, permissions.RESOURCE_CONNECTION)]) +@security.requires_access_connection("POST") @provide_session @action_logging( event=action_event_from_permission( @@ -176,7 +176,7 @@ def post_connection(*, session: Session = NEW_SESSION) -> APIResponse: raise AlreadyExists(detail=f"Connection already exist. ID: {conn_id}") -@security.requires_access([(permissions.ACTION_CAN_CREATE, permissions.RESOURCE_CONNECTION)]) +@security.requires_access_connection("POST") def test_connection() -> APIResponse: """ Test an API connection. 
diff --git a/airflow/api_connexion/endpoints/dag_endpoint.py b/airflow/api_connexion/endpoints/dag_endpoint.py index 5aac030ecbae0..a67699feb2b79 100644 --- a/airflow/api_connexion/endpoints/dag_endpoint.py +++ b/airflow/api_connexion/endpoints/dag_endpoint.py @@ -34,9 +34,9 @@ dag_schema, dags_collection_schema, ) +from airflow.api_connexion.security import requires_authentication from airflow.exceptions import AirflowException, DagNotFound from airflow.models.dag import DagModel, DagTag -from airflow.security import permissions from airflow.utils.airflow_flask_app import get_airflow_app from airflow.utils.db import get_query_count from airflow.utils.session import NEW_SESSION, provide_session @@ -48,7 +48,7 @@ from airflow.api_connexion.types import APIResponse, UpdateMask -@security.requires_access([(permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG)]) +@security.requires_access_dag("GET") @provide_session def get_dag(*, dag_id: str, session: Session = NEW_SESSION) -> APIResponse: """Get basic information about a DAG.""" @@ -60,7 +60,7 @@ def get_dag(*, dag_id: str, session: Session = NEW_SESSION) -> APIResponse: return dag_schema.dump(dag) -@security.requires_access([(permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG)]) +@security.requires_access_dag("GET") def get_dag_details(*, dag_id: str) -> APIResponse: """Get details of DAG.""" dag: DAG = get_airflow_app().dag_bag.get_dag(dag_id) @@ -69,7 +69,7 @@ def get_dag_details(*, dag_id: str) -> APIResponse: return dag_detail_schema.dump(dag) -@security.requires_access([(permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG)]) +@requires_authentication @format_parameters({"limit": check_limit}) @provide_session def get_dags( @@ -96,7 +96,7 @@ def get_dags( if dag_id_pattern: dags_query = dags_query.where(DagModel.dag_id.ilike(f"%{dag_id_pattern}%")) - readable_dags = get_airflow_app().appbuilder.sm.get_accessible_dag_ids(g.user) + readable_dags = get_airflow_app().appbuilder.sm.get_permitted_dag_ids(g.user) 
dags_query = dags_query.where(DagModel.dag_id.in_(readable_dags)) if tags: @@ -110,7 +110,7 @@ def get_dags( return dags_collection_schema.dump(DAGCollection(dags=dags, total_entries=total_entries)) -@security.requires_access([(permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG)]) +@security.requires_access_dag("PUT") @provide_session def patch_dag(*, dag_id: str, update_mask: UpdateMask = None, session: Session = NEW_SESSION) -> APIResponse: """Update the specific DAG.""" @@ -132,7 +132,7 @@ def patch_dag(*, dag_id: str, update_mask: UpdateMask = None, session: Session = return dag_schema.dump(dag) -@security.requires_access([(permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG)]) +@security.requires_authentication @format_parameters({"limit": check_limit}) @provide_session def patch_dags(limit, session, offset=0, only_active=True, tags=None, dag_id_pattern=None, update_mask=None): @@ -180,7 +180,7 @@ def patch_dags(limit, session, offset=0, only_active=True, tags=None, dag_id_pat return dags_collection_schema.dump(DAGCollection(dags=dags, total_entries=total_entries)) -@security.requires_access([(permissions.ACTION_CAN_DELETE, permissions.RESOURCE_DAG)]) +@security.requires_access_dag("DELETE") @provide_session def delete_dag(dag_id: str, session: Session = NEW_SESSION) -> APIResponse: """Delete the specific DAG.""" diff --git a/airflow/api_connexion/endpoints/dag_run_endpoint.py b/airflow/api_connexion/endpoints/dag_run_endpoint.py index d56fd469638ca..06b288754fc12 100644 --- a/airflow/api_connexion/endpoints/dag_run_endpoint.py +++ b/airflow/api_connexion/endpoints/dag_run_endpoint.py @@ -56,6 +56,7 @@ TaskInstanceReferenceCollection, task_instance_reference_collection_schema, ) +from airflow.auth.managers.models.resource_details import DagAccessEntity from airflow.models import DagModel, DagRun from airflow.security import permissions from airflow.utils.airflow_flask_app import get_airflow_app @@ -76,12 +77,7 @@ RESOURCE_EVENT_PREFIX = "dag_run" 
-@security.requires_access( - [ - (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_DELETE, permissions.RESOURCE_DAG_RUN), - ], -) +@security.requires_access_dag("DELETE", DagAccessEntity.RUN) @provide_session def delete_dag_run(*, dag_id: str, dag_run_id: str, session: Session = NEW_SESSION) -> APIResponse: """Delete a DAG Run.""" @@ -93,12 +89,7 @@ def delete_dag_run(*, dag_id: str, dag_run_id: str, session: Session = NEW_SESSI return NoContent, HTTPStatus.NO_CONTENT -@security.requires_access( - [ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN), - ], -) +@security.requires_access_dag("GET", DagAccessEntity.RUN) @provide_session def get_dag_run(*, dag_id: str, dag_run_id: str, session: Session = NEW_SESSION) -> APIResponse: """Get a DAG Run.""" @@ -111,13 +102,8 @@ def get_dag_run(*, dag_id: str, dag_run_id: str, session: Session = NEW_SESSION) return dagrun_schema.dump(dag_run) -@security.requires_access( - [ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DATASET), - ], -) +@security.requires_access_dag("GET", DagAccessEntity.RUN) +@security.requires_access_dag("GET", DagAccessEntity.DATASET) @provide_session def get_upstream_dataset_events( *, dag_id: str, dag_run_id: str, session: Session = NEW_SESSION @@ -194,12 +180,7 @@ def _fetch_dag_runs( return session.scalars(query.offset(offset).limit(limit)).all(), total_entries -@security.requires_access( - [ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN), - ], -) +@security.requires_access_dag("GET", DagAccessEntity.RUN) @format_parameters( { "start_date_gte": format_datetime, @@ -262,12 +243,7 @@ def get_dag_runs( return dagrun_collection_schema.dump(DAGRunCollection(dag_runs=dag_run, 
total_entries=total_entries)) -@security.requires_access( - [ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN), - ], -) +@security.requires_access_dag("GET", DagAccessEntity.RUN) @provide_session def get_dag_runs_batch(*, session: Session = NEW_SESSION) -> APIResponse: """Get list of DAG Runs.""" @@ -307,12 +283,7 @@ def get_dag_runs_batch(*, session: Session = NEW_SESSION) -> APIResponse: return dagrun_collection_schema.dump(DAGRunCollection(dag_runs=dag_runs, total_entries=total_entries)) -@security.requires_access( - [ - (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_CREATE, permissions.RESOURCE_DAG_RUN), - ], -) +@security.requires_access_dag("POST", DagAccessEntity.RUN) @provide_session @action_logging( event=action_event_from_permission( @@ -378,12 +349,7 @@ def post_dag_run(*, dag_id: str, session: Session = NEW_SESSION) -> APIResponse: raise AlreadyExists(detail=f"DAGRun with DAG ID: '{dag_id}' and DAGRun ID: '{run_id}' already exists") -@security.requires_access( - [ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG_RUN), - ], -) +@security.requires_access_dag("PUT", DagAccessEntity.RUN) @provide_session def update_dag_run_state(*, dag_id: str, dag_run_id: str, session: Session = NEW_SESSION) -> APIResponse: """Set a state of a dag run.""" @@ -410,12 +376,7 @@ def update_dag_run_state(*, dag_id: str, dag_run_id: str, session: Session = NEW return dagrun_schema.dump(dag_run) -@security.requires_access( - [ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG_RUN), - ], -) +@security.requires_access_dag("PUT", DagAccessEntity.RUN) @provide_session def clear_dag_run(*, dag_id: str, dag_run_id: str, session: Session = NEW_SESSION) -> APIResponse: """Clear a dag run.""" @@ -461,12 +422,7 @@ def clear_dag_run(*, dag_id: str, 
dag_run_id: str, session: Session = NEW_SESSIO return dagrun_schema.dump(dag_run) -@security.requires_access( - [ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG_RUN), - ], -) +@security.requires_access_dag("PUT", DagAccessEntity.RUN) @provide_session def set_dag_run_note(*, dag_id: str, dag_run_id: str, session: Session = NEW_SESSION) -> APIResponse: """Set the note for a dag run.""" diff --git a/airflow/api_connexion/endpoints/dag_source_endpoint.py b/airflow/api_connexion/endpoints/dag_source_endpoint.py index b191630815004..3ee80ee857a4c 100644 --- a/airflow/api_connexion/endpoints/dag_source_endpoint.py +++ b/airflow/api_connexion/endpoints/dag_source_endpoint.py @@ -24,11 +24,11 @@ from airflow.api_connexion import security from airflow.api_connexion.exceptions import NotFound from airflow.api_connexion.schemas.dag_source_schema import dag_source_schema +from airflow.auth.managers.models.resource_details import DagAccessEntity from airflow.models.dagcode import DagCode -from airflow.security import permissions -@security.requires_access([(permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_CODE)]) +@security.requires_access_dag("GET", DagAccessEntity.CODE) def get_dag_source(*, file_token: str) -> Response: """Get source code using file token.""" secret_key = current_app.config["SECRET_KEY"] diff --git a/airflow/api_connexion/endpoints/dag_warning_endpoint.py b/airflow/api_connexion/endpoints/dag_warning_endpoint.py index 367b0ae104571..fa1e03ae64a9c 100644 --- a/airflow/api_connexion/endpoints/dag_warning_endpoint.py +++ b/airflow/api_connexion/endpoints/dag_warning_endpoint.py @@ -28,8 +28,8 @@ DagWarningCollection, dag_warning_collection_schema, ) +from airflow.auth.managers.models.resource_details import DagAccessEntity from airflow.models.dagwarning import DagWarning as DagWarningModel -from airflow.security import permissions from airflow.utils.airflow_flask_app import get_airflow_app 
from airflow.utils.db import get_query_count from airflow.utils.session import NEW_SESSION, provide_session @@ -40,7 +40,7 @@ from airflow.api_connexion.types import APIResponse -@security.requires_access([(permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_WARNING)]) +@security.requires_access_dag("GET", DagAccessEntity.WARNING) @format_parameters({"limit": check_limit}) @provide_session def get_dag_warnings( diff --git a/airflow/api_connexion/endpoints/dataset_endpoint.py b/airflow/api_connexion/endpoints/dataset_endpoint.py index 81fe872fca72a..152ac6eecb2cf 100644 --- a/airflow/api_connexion/endpoints/dataset_endpoint.py +++ b/airflow/api_connexion/endpoints/dataset_endpoint.py @@ -32,7 +32,6 @@ dataset_schema, ) from airflow.models.dataset import DatasetEvent, DatasetModel -from airflow.security import permissions from airflow.utils.db import get_query_count from airflow.utils.session import NEW_SESSION, provide_session @@ -42,9 +41,9 @@ from airflow.api_connexion.types import APIResponse -@security.requires_access([(permissions.ACTION_CAN_READ, permissions.RESOURCE_DATASET)]) +@security.requires_access_dataset("GET") @provide_session -def get_dataset(uri: str, session: Session = NEW_SESSION) -> APIResponse: +def get_dataset(*, uri: str, session: Session = NEW_SESSION) -> APIResponse: """Get a Dataset.""" dataset = session.scalar( select(DatasetModel) @@ -59,7 +58,7 @@ def get_dataset(uri: str, session: Session = NEW_SESSION) -> APIResponse: return dataset_schema.dump(dataset) -@security.requires_access([(permissions.ACTION_CAN_READ, permissions.RESOURCE_DATASET)]) +@security.requires_access_dataset("GET") @format_parameters({"limit": check_limit}) @provide_session def get_datasets( @@ -86,7 +85,7 @@ def get_datasets( return dataset_collection_schema.dump(DatasetCollection(datasets=datasets, total_entries=total_entries)) -@security.requires_access([(permissions.ACTION_CAN_READ, permissions.RESOURCE_DATASET)]) +@security.requires_access_dataset("GET") 
@provide_session @format_parameters({"limit": check_limit}) def get_dataset_events( diff --git a/airflow/api_connexion/endpoints/event_log_endpoint.py b/airflow/api_connexion/endpoints/event_log_endpoint.py index e195cfdcc23ef..21e5c3351bf65 100644 --- a/airflow/api_connexion/endpoints/event_log_endpoint.py +++ b/airflow/api_connexion/endpoints/event_log_endpoint.py @@ -28,8 +28,8 @@ event_log_collection_schema, event_log_schema, ) +from airflow.auth.managers.models.resource_details import DagAccessEntity from airflow.models import Log -from airflow.security import permissions from airflow.utils.session import NEW_SESSION, provide_session if TYPE_CHECKING: @@ -38,7 +38,7 @@ from airflow.api_connexion.types import APIResponse -@security.requires_access([(permissions.ACTION_CAN_READ, permissions.RESOURCE_AUDIT_LOG)]) +@security.requires_access_dag("GET", DagAccessEntity.AUDIT_LOG) @provide_session def get_event_log(*, event_log_id: int, session: Session = NEW_SESSION) -> APIResponse: """Get a log entry.""" @@ -48,7 +48,7 @@ def get_event_log(*, event_log_id: int, session: Session = NEW_SESSION) -> APIRe return event_log_schema.dump(event_log) -@security.requires_access([(permissions.ACTION_CAN_READ, permissions.RESOURCE_AUDIT_LOG)]) +@security.requires_access_dag("GET", DagAccessEntity.AUDIT_LOG) @format_parameters({"limit": check_limit}) @provide_session def get_event_logs( diff --git a/airflow/api_connexion/endpoints/extra_link_endpoint.py b/airflow/api_connexion/endpoints/extra_link_endpoint.py index ec92dd51ee7ad..2e9954587c071 100644 --- a/airflow/api_connexion/endpoints/extra_link_endpoint.py +++ b/airflow/api_connexion/endpoints/extra_link_endpoint.py @@ -22,8 +22,8 @@ from airflow.api_connexion import security from airflow.api_connexion.exceptions import NotFound +from airflow.auth.managers.models.resource_details import DagAccessEntity from airflow.exceptions import TaskNotFound -from airflow.security import permissions from airflow.utils.airflow_flask_app 
import get_airflow_app from airflow.utils.session import NEW_SESSION, provide_session @@ -35,13 +35,7 @@ from airflow.models.dagbag import DagBag -@security.requires_access( - [ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_INSTANCE), - ], -) +@security.requires_access_dag("GET", DagAccessEntity.TASK_INSTANCE) @provide_session def get_extra_links( *, diff --git a/airflow/api_connexion/endpoints/import_error_endpoint.py b/airflow/api_connexion/endpoints/import_error_endpoint.py index f2b9a88311f37..81459b604e9ee 100644 --- a/airflow/api_connexion/endpoints/import_error_endpoint.py +++ b/airflow/api_connexion/endpoints/import_error_endpoint.py @@ -28,8 +28,8 @@ import_error_collection_schema, import_error_schema, ) +from airflow.auth.managers.models.resource_details import DagAccessEntity from airflow.models.errors import ImportError as ImportErrorModel -from airflow.security import permissions from airflow.utils.session import NEW_SESSION, provide_session if TYPE_CHECKING: @@ -38,7 +38,7 @@ from airflow.api_connexion.types import APIResponse -@security.requires_access([(permissions.ACTION_CAN_READ, permissions.RESOURCE_IMPORT_ERROR)]) +@security.requires_access_dag("GET", DagAccessEntity.IMPORT_ERRORS) @provide_session def get_import_error(*, import_error_id: int, session: Session = NEW_SESSION) -> APIResponse: """Get an import error.""" @@ -52,7 +52,7 @@ def get_import_error(*, import_error_id: int, session: Session = NEW_SESSION) -> return import_error_schema.dump(error) -@security.requires_access([(permissions.ACTION_CAN_READ, permissions.RESOURCE_IMPORT_ERROR)]) +@security.requires_access_dag("GET", DagAccessEntity.IMPORT_ERRORS) @format_parameters({"limit": check_limit}) @provide_session def get_import_errors( diff --git a/airflow/api_connexion/endpoints/log_endpoint.py b/airflow/api_connexion/endpoints/log_endpoint.py index 
126b8634e3cfd..239f08ecdaf40 100644 --- a/airflow/api_connexion/endpoints/log_endpoint.py +++ b/airflow/api_connexion/endpoints/log_endpoint.py @@ -27,9 +27,9 @@ from airflow.api_connexion import security from airflow.api_connexion.exceptions import BadRequest, NotFound from airflow.api_connexion.schemas.log_schema import LogResponseObject, logs_schema +from airflow.auth.managers.models.resource_details import DagAccessEntity from airflow.exceptions import TaskNotFound from airflow.models import TaskInstance, Trigger -from airflow.security import permissions from airflow.utils.airflow_flask_app import get_airflow_app from airflow.utils.log.log_reader import TaskLogReader from airflow.utils.session import NEW_SESSION, provide_session @@ -40,13 +40,7 @@ from airflow.api_connexion.types import APIResponse -@security.requires_access( - [ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_INSTANCE), - ], -) +@security.requires_access_dag("GET", DagAccessEntity.TASK_LOGS) @provide_session def get_log( *, diff --git a/airflow/api_connexion/endpoints/plugin_endpoint.py b/airflow/api_connexion/endpoints/plugin_endpoint.py index 02ba435d52a9b..500bd65749062 100644 --- a/airflow/api_connexion/endpoints/plugin_endpoint.py +++ b/airflow/api_connexion/endpoints/plugin_endpoint.py @@ -22,13 +22,12 @@ from airflow.api_connexion.parameters import check_limit, format_parameters from airflow.api_connexion.schemas.plugin_schema import PluginCollection, plugin_collection_schema from airflow.plugins_manager import get_plugin_info -from airflow.security import permissions if TYPE_CHECKING: from airflow.api_connexion.types import APIResponse -@security.requires_access([(permissions.ACTION_CAN_READ, permissions.RESOURCE_PLUGIN)]) +@security.requires_access_website() @format_parameters({"limit": check_limit}) def get_plugins(*, limit: int, offset: int = 0) -> 
APIResponse: """Get plugins endpoint.""" diff --git a/airflow/api_connexion/endpoints/pool_endpoint.py b/airflow/api_connexion/endpoints/pool_endpoint.py index 735d777e4ca75..0fbb2c8a23f4d 100644 --- a/airflow/api_connexion/endpoints/pool_endpoint.py +++ b/airflow/api_connexion/endpoints/pool_endpoint.py @@ -30,7 +30,6 @@ from airflow.api_connexion.parameters import apply_sorting, check_limit, format_parameters from airflow.api_connexion.schemas.pool_schema import PoolCollection, pool_collection_schema, pool_schema from airflow.models.pool import Pool -from airflow.security import permissions from airflow.utils.session import NEW_SESSION, provide_session if TYPE_CHECKING: @@ -39,7 +38,7 @@ from airflow.api_connexion.types import APIResponse, UpdateMask -@security.requires_access([(permissions.ACTION_CAN_DELETE, permissions.RESOURCE_POOL)]) +@security.requires_access_pool("DELETE") @provide_session def delete_pool(*, pool_name: str, session: Session = NEW_SESSION) -> APIResponse: """Delete a pool.""" @@ -52,7 +51,7 @@ def delete_pool(*, pool_name: str, session: Session = NEW_SESSION) -> APIRespons return Response(status=HTTPStatus.NO_CONTENT) -@security.requires_access([(permissions.ACTION_CAN_READ, permissions.RESOURCE_POOL)]) +@security.requires_access_pool("GET") @provide_session def get_pool(*, pool_name: str, session: Session = NEW_SESSION) -> APIResponse: """Get a pool.""" @@ -62,7 +61,7 @@ def get_pool(*, pool_name: str, session: Session = NEW_SESSION) -> APIResponse: return pool_schema.dump(obj) -@security.requires_access([(permissions.ACTION_CAN_READ, permissions.RESOURCE_POOL)]) +@security.requires_access_pool("GET") @format_parameters({"limit": check_limit}) @provide_session def get_pools( @@ -82,7 +81,7 @@ def get_pools( return pool_collection_schema.dump(PoolCollection(pools=pools, total_entries=total_entries)) -@security.requires_access([(permissions.ACTION_CAN_EDIT, permissions.RESOURCE_POOL)]) +@security.requires_access_pool("PUT") @provide_session 
def patch_pool( *, @@ -138,7 +137,7 @@ def patch_pool( return pool_schema.dump(pool) -@security.requires_access([(permissions.ACTION_CAN_CREATE, permissions.RESOURCE_POOL)]) +@security.requires_access_pool("POST") @provide_session def post_pool(*, session: Session = NEW_SESSION) -> APIResponse: """Create a pool.""" diff --git a/airflow/api_connexion/endpoints/provider_endpoint.py b/airflow/api_connexion/endpoints/provider_endpoint.py index 75bba31218d05..a64368dce3587 100644 --- a/airflow/api_connexion/endpoints/provider_endpoint.py +++ b/airflow/api_connexion/endpoints/provider_endpoint.py @@ -27,7 +27,6 @@ provider_collection_schema, ) from airflow.providers_manager import ProvidersManager -from airflow.security import permissions if TYPE_CHECKING: from airflow.api_connexion.types import APIResponse @@ -46,7 +45,7 @@ def _provider_mapper(provider: ProviderInfo) -> Provider: ) -@security.requires_access([(permissions.ACTION_CAN_READ, permissions.RESOURCE_PROVIDER)]) +@security.requires_access_website() def get_providers() -> APIResponse: """Get providers.""" providers = [_provider_mapper(d) for d in ProvidersManager().providers.values()] diff --git a/airflow/api_connexion/endpoints/task_endpoint.py b/airflow/api_connexion/endpoints/task_endpoint.py index 70b6e4b8aba41..46cc816550ffc 100644 --- a/airflow/api_connexion/endpoints/task_endpoint.py +++ b/airflow/api_connexion/endpoints/task_endpoint.py @@ -22,8 +22,8 @@ from airflow.api_connexion import security from airflow.api_connexion.exceptions import BadRequest, NotFound from airflow.api_connexion.schemas.task_schema import TaskCollection, task_collection_schema, task_schema +from airflow.auth.managers.models.resource_details import DagAccessEntity from airflow.exceptions import TaskNotFound -from airflow.security import permissions from airflow.utils.airflow_flask_app import get_airflow_app if TYPE_CHECKING: @@ -31,12 +31,7 @@ from airflow.api_connexion.types import APIResponse -@security.requires_access( - [ - 
(permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_INSTANCE), - ], -) +@security.requires_access_dag("GET", DagAccessEntity.TASK_INSTANCE) def get_task(*, dag_id: str, task_id: str) -> APIResponse: """Get simplified representation of a task.""" dag: DAG = get_airflow_app().dag_bag.get_dag(dag_id) @@ -50,12 +45,7 @@ def get_task(*, dag_id: str, task_id: str) -> APIResponse: return task_schema.dump(task) -@security.requires_access( - [ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_INSTANCE), - ], -) +@security.requires_access_dag("GET", DagAccessEntity.TASK_INSTANCE) def get_tasks(*, dag_id: str, order_by: str = "task_id") -> APIResponse: """Get tasks for DAG.""" dag: DAG = get_airflow_app().dag_bag.get_dag(dag_id) diff --git a/airflow/api_connexion/endpoints/task_instance_endpoint.py b/airflow/api_connexion/endpoints/task_instance_endpoint.py index f0d530958b810..6aaddf084f3f1 100644 --- a/airflow/api_connexion/endpoints/task_instance_endpoint.py +++ b/airflow/api_connexion/endpoints/task_instance_endpoint.py @@ -41,11 +41,11 @@ task_instance_reference_schema, task_instance_schema, ) +from airflow.auth.managers.models.resource_details import DagAccessEntity, DagDetails from airflow.models import SlaMiss from airflow.models.dagrun import DagRun as DR from airflow.models.operator import needs_expansion from airflow.models.taskinstance import TaskInstance as TI, clear_task_instances -from airflow.security import permissions from airflow.utils.airflow_flask_app import get_airflow_app from airflow.utils.db import get_query_count from airflow.utils.session import NEW_SESSION, provide_session @@ -61,13 +61,7 @@ T = TypeVar("T") -@security.requires_access( - [ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN), - (permissions.ACTION_CAN_READ, 
permissions.RESOURCE_TASK_INSTANCE), - ], -) +@security.requires_access_dag("GET", DagAccessEntity.TASK_INSTANCE) @provide_session def get_task_instance( *, @@ -109,13 +103,7 @@ def get_task_instance( return task_instance_schema.dump(task_instance) -@security.requires_access( - [ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_INSTANCE), - ], -) +@security.requires_access_dag("GET", DagAccessEntity.TASK_INSTANCE) @provide_session def get_mapped_task_instance( *, @@ -161,13 +149,7 @@ def get_mapped_task_instance( "updated_at_lte": format_datetime, }, ) -@security.requires_access( - [ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_INSTANCE), - ], -) +@security.requires_access_dag("GET", DagAccessEntity.TASK_INSTANCE) @provide_session def get_mapped_task_instances( *, @@ -305,13 +287,7 @@ def _apply_range_filter(query: Select, key: ClauseElement, value_range: tuple[T, "updated_at_lte": format_datetime, }, ) -@security.requires_access( - [ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_INSTANCE), - ], -) +@security.requires_access_dag("GET", DagAccessEntity.TASK_INSTANCE) @provide_session def get_task_instances( *, @@ -386,13 +362,7 @@ def get_task_instances( ) -@security.requires_access( - [ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_INSTANCE), - ], -) +@security.requires_access_dag("GET", DagAccessEntity.TASK_INSTANCE) @provide_session def get_task_instances_batch(session: Session = NEW_SESSION) -> APIResponse: """Get list of task 
instances.""" @@ -405,7 +375,7 @@ def get_task_instances_batch(session: Session = NEW_SESSION) -> APIResponse: if dag_ids: cannot_access_dag_ids = set() for id in dag_ids: - if not get_airflow_app().appbuilder.sm.can_read_dag(id, g.user): + if not get_auth_manager().is_authorized_dag(method="GET", details=DagDetails(id=id), user=g.user): cannot_access_dag_ids.add(id) if cannot_access_dag_ids: raise PermissionDenied( @@ -461,13 +431,7 @@ def get_task_instances_batch(session: Session = NEW_SESSION) -> APIResponse: ) -@security.requires_access( - [ - (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN), - (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_TASK_INSTANCE), - ], -) +@security.requires_access_dag("PUT", DagAccessEntity.TASK_INSTANCE) @provide_session def post_clear_task_instances(*, dag_id: str, session: Session = NEW_SESSION) -> APIResponse: """Clear task instances.""" @@ -527,13 +491,7 @@ def post_clear_task_instances(*, dag_id: str, session: Session = NEW_SESSION) -> ) -@security.requires_access( - [ - (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN), - (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_TASK_INSTANCE), - ], -) +@security.requires_access_dag("PUT", DagAccessEntity.TASK_INSTANCE) @provide_session def post_set_task_instances_state(*, dag_id: str, session: Session = NEW_SESSION) -> APIResponse: """Set a state of task instances.""" @@ -600,13 +558,7 @@ def set_mapped_task_instance_note( return set_task_instance_note(dag_id=dag_id, dag_run_id=dag_run_id, task_id=task_id, map_index=map_index) -@security.requires_access( - [ - (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN), - (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_TASK_INSTANCE), - ], -) +@security.requires_access_dag("PUT", DagAccessEntity.TASK_INSTANCE) @provide_session def 
patch_task_instance( *, dag_id: str, dag_run_id: str, task_id: str, map_index: int = -1, session: Session = NEW_SESSION @@ -646,13 +598,7 @@ def patch_task_instance( return task_instance_reference_schema.dump(ti) -@security.requires_access( - [ - (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN), - (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_TASK_INSTANCE), - ], -) +@security.requires_access_dag("PUT", DagAccessEntity.TASK_INSTANCE) @provide_session def patch_mapped_task_instance( *, dag_id: str, dag_run_id: str, task_id: str, map_index: int, session: Session = NEW_SESSION @@ -663,14 +609,7 @@ def patch_mapped_task_instance( ) -@security.requires_access( - [ - (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_TASK_INSTANCE), - ], -) +@security.requires_access_dag("PUT", DagAccessEntity.TASK_INSTANCE) @provide_session def set_task_instance_note( *, dag_id: str, dag_run_id: str, task_id: str, map_index: int = -1, session: Session = NEW_SESSION diff --git a/airflow/api_connexion/endpoints/variable_endpoint.py b/airflow/api_connexion/endpoints/variable_endpoint.py index 54d5ac744b6c3..05157298e7181 100644 --- a/airflow/api_connexion/endpoints/variable_endpoint.py +++ b/airflow/api_connexion/endpoints/variable_endpoint.py @@ -43,7 +43,7 @@ RESOURCE_EVENT_PREFIX = "variable" -@security.requires_access([(permissions.ACTION_CAN_DELETE, permissions.RESOURCE_VARIABLE)]) +@security.requires_access_variable("DELETE") @action_logging( event=action_event_from_permission( prefix=RESOURCE_EVENT_PREFIX, @@ -57,7 +57,7 @@ def delete_variable(*, variable_key: str) -> Response: return Response(status=HTTPStatus.NO_CONTENT) -@security.requires_access([(permissions.ACTION_CAN_READ, permissions.RESOURCE_VARIABLE)]) 
+@security.requires_access_variable("GET") @provide_session def get_variable(*, variable_key: str, session: Session = NEW_SESSION) -> Response: """Get a variable by key.""" @@ -67,7 +67,7 @@ def get_variable(*, variable_key: str, session: Session = NEW_SESSION) -> Respon return variable_schema.dump(var) -@security.requires_access([(permissions.ACTION_CAN_READ, permissions.RESOURCE_VARIABLE)]) +@security.requires_access_variable("GET") @format_parameters({"limit": check_limit}) @provide_session def get_variables( @@ -92,7 +92,7 @@ def get_variables( ) -@security.requires_access([(permissions.ACTION_CAN_EDIT, permissions.RESOURCE_VARIABLE)]) +@security.requires_access_variable("PUT") @provide_session @action_logging( event=action_event_from_permission( @@ -126,7 +126,7 @@ def patch_variable( return variable_schema.dump(variable) -@security.requires_access([(permissions.ACTION_CAN_CREATE, permissions.RESOURCE_VARIABLE)]) +@security.requires_access_variable("POST") @action_logging( event=action_event_from_permission( prefix=RESOURCE_EVENT_PREFIX, diff --git a/airflow/api_connexion/endpoints/xcom_endpoint.py b/airflow/api_connexion/endpoints/xcom_endpoint.py index 73bdd8562e9a5..a26845a59cd83 100644 --- a/airflow/api_connexion/endpoints/xcom_endpoint.py +++ b/airflow/api_connexion/endpoints/xcom_endpoint.py @@ -26,8 +26,8 @@ from airflow.api_connexion.exceptions import BadRequest, NotFound from airflow.api_connexion.parameters import check_limit, format_parameters from airflow.api_connexion.schemas.xcom_schema import XComCollection, xcom_collection_schema, xcom_schema +from airflow.auth.managers.models.resource_details import DagAccessEntity from airflow.models import DagRun as DR, XCom -from airflow.security import permissions from airflow.settings import conf from airflow.utils.airflow_flask_app import get_airflow_app from airflow.utils.db import get_query_count @@ -39,14 +39,7 @@ from airflow.api_connexion.types import APIResponse -@security.requires_access( - [ - 
(permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_INSTANCE), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_XCOM), - ], -) +@security.requires_access_dag("GET", DagAccessEntity.XCOM) @format_parameters({"limit": check_limit}) @provide_session def get_xcom_entries( @@ -85,14 +78,7 @@ def get_xcom_entries( return xcom_collection_schema.dump(XComCollection(xcom_entries=query, total_entries=total_entries)) -@security.requires_access( - [ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_INSTANCE), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_XCOM), - ], -) +@security.requires_access_dag("GET", DagAccessEntity.XCOM) @provide_session def get_xcom_entry( *, diff --git a/airflow/api_connexion/security.py b/airflow/api_connexion/security.py index b108adc2c36b5..4784769c079ce 100644 --- a/airflow/api_connexion/security.py +++ b/airflow/api_connexion/security.py @@ -17,12 +17,25 @@ from __future__ import annotations from functools import wraps -from typing import Callable, Sequence, TypeVar, cast +from typing import TYPE_CHECKING, Callable, Sequence, TypeVar, cast from flask import Response from airflow.api_connexion.exceptions import PermissionDenied, Unauthenticated +from airflow.auth.managers.models.resource_details import ( + ConfigurationDetails, + ConnectionDetails, + DagAccessEntity, + DagDetails, + DatasetDetails, + PoolDetails, + VariableDetails, +) from airflow.utils.airflow_flask_app import get_airflow_app +from airflow.www.extensions.init_auth_manager import get_auth_manager + +if TYPE_CHECKING: + from airflow.auth.managers.base_auth_manager import ResourceMethod T = TypeVar("T", bound=Callable) @@ -55,3 +68,166 @@ def decorated(*args, **kwargs): return cast(T, decorated) return 
requires_access_decorator + + +def _requires_access(*, is_authorized_callback: Callable[[], bool], func: Callable, args, kwargs): + """ + Define the behavior whether the user is authorized to access the resource. + + :param is_authorized_callback: callback to execute to figure whether the user is authorized to access + the resource? + :param func: the function to call if the user is authorized + :param args: the arguments of ``func`` + :param kwargs: the keyword arguments ``func`` + + :meta private: + """ + check_authentication() + if is_authorized_callback(): + return func(*args, **kwargs) + raise PermissionDenied() + + +def requires_authentication(func: T): + """Decorator for functions that require authentication.""" + + @wraps(func) + def decorated(*args, **kwargs): + check_authentication() + return func(*args, **kwargs) + + return cast(T, decorated) + + +def requires_access_configuration(method: ResourceMethod) -> Callable[[T], T]: + def requires_access_decorator(func: T): + @wraps(func) + def decorated(*args, **kwargs): + section: str | None = kwargs.get("section") + return _requires_access( + is_authorized_callback=lambda: get_auth_manager().is_authorized_configuration( + method=method, details=ConfigurationDetails(section=section) + ), + func=func, + args=args, + kwargs=kwargs, + ) + + return cast(T, decorated) + + return requires_access_decorator + + +def requires_access_connection(method: ResourceMethod) -> Callable[[T], T]: + def requires_access_decorator(func: T): + @wraps(func) + def decorated(*args, **kwargs): + connection_id: str | None = kwargs.get("connection_id") + return _requires_access( + is_authorized_callback=lambda: get_auth_manager().is_authorized_connection( + method=method, details=ConnectionDetails(conn_id=connection_id) + ), + func=func, + args=args, + kwargs=kwargs, + ) + + return cast(T, decorated) + + return requires_access_decorator + + +def requires_access_dag( + method: ResourceMethod, access_entity: DagAccessEntity | None = None 
+) -> Callable[[T], T]: + def requires_access_decorator(func: T): + @wraps(func) + def decorated(*args, **kwargs): + dag_id: str | None = kwargs.get("dag_id") + return _requires_access( + is_authorized_callback=lambda: get_auth_manager().is_authorized_dag( + method=method, + access_entity=access_entity, + details=DagDetails(id=dag_id), + ), + func=func, + args=args, + kwargs=kwargs, + ) + + return cast(T, decorated) + + return requires_access_decorator + + +def requires_access_dataset(method: ResourceMethod) -> Callable[[T], T]: + def requires_access_decorator(func: T): + @wraps(func) + def decorated(*args, **kwargs): + uri: str | None = kwargs.get("uri") + return _requires_access( + is_authorized_callback=lambda: get_auth_manager().is_authorized_dataset( + method=method, details=DatasetDetails(uri=uri) + ), + func=func, + args=args, + kwargs=kwargs, + ) + + return cast(T, decorated) + + return requires_access_decorator + + +def requires_access_pool(method: ResourceMethod) -> Callable[[T], T]: + def requires_access_decorator(func: T): + @wraps(func) + def decorated(*args, **kwargs): + pool_name: str | None = kwargs.get("pool_name") + return _requires_access( + is_authorized_callback=lambda: get_auth_manager().is_authorized_pool( + method=method, details=PoolDetails(name=pool_name) + ), + func=func, + args=args, + kwargs=kwargs, + ) + + return cast(T, decorated) + + return requires_access_decorator + + +def requires_access_variable(method: ResourceMethod) -> Callable[[T], T]: + def requires_access_decorator(func: T): + @wraps(func) + def decorated(*args, **kwargs): + variable_key: str | None = kwargs.get("variable_key") + return _requires_access( + is_authorized_callback=lambda: get_auth_manager().is_authorized_variable( + method=method, details=VariableDetails(key=variable_key) + ), + func=func, + args=args, + kwargs=kwargs, + ) + + return cast(T, decorated) + + return requires_access_decorator + + +def requires_access_website() -> Callable[[T], T]: + def 
requires_access_decorator(func: T): + @wraps(func) + def decorated(*args, **kwargs): + return _requires_access( + is_authorized_callback=lambda: get_auth_manager().is_authorized_website(), + func=func, + args=args, + kwargs=kwargs, + ) + + return cast(T, decorated) + + return requires_access_decorator diff --git a/airflow/auth/managers/base_auth_manager.py b/airflow/auth/managers/base_auth_manager.py index 29695dae12aef..71a6a94fdc235 100644 --- a/airflow/auth/managers/base_auth_manager.py +++ b/airflow/auth/managers/base_auth_manager.py @@ -28,9 +28,13 @@ from airflow.auth.managers.models.base_user import BaseUser from airflow.auth.managers.models.resource_details import ( + ConfigurationDetails, ConnectionDetails, DagAccessEntity, DagDetails, + DatasetDetails, + PoolDetails, + VariableDetails, ) from airflow.cli.cli_config import CLICommand from airflow.www.security_manager import AirflowSecurityManagerV2 @@ -82,12 +86,14 @@ def is_authorized_configuration( self, *, method: ResourceMethod, + details: ConfigurationDetails | None = None, user: BaseUser | None = None, ) -> bool: """ Return whether the user is authorized to perform a given action on configuration. :param method: the method to perform + :param details: optional details about the configuration :param user: the user to perform the action on. If not provided (or None), it uses the current user """ @@ -110,14 +116,14 @@ def is_authorized_connection( self, *, method: ResourceMethod, - connection_details: ConnectionDetails | None = None, + details: ConnectionDetails | None = None, user: BaseUser | None = None, ) -> bool: """ Return whether the user is authorized to perform a given action on a connection. :param method: the method to perform - :param connection_details: optional details about the connection + :param details: optional details about the connection :param user: the user to perform the action on. 
If not provided (or None), it uses the current user """ @@ -126,17 +132,17 @@ def is_authorized_dag( self, *, method: ResourceMethod, - dag_access_entity: DagAccessEntity | None = None, - dag_details: DagDetails | None = None, + access_entity: DagAccessEntity | None = None, + details: DagDetails | None = None, user: BaseUser | None = None, ) -> bool: """ Return whether the user is authorized to perform a given action on a DAG. :param method: the method to perform - :param dag_access_entity: the kind of DAG information the authorization request is about. + :param access_entity: the kind of DAG information the authorization request is about. If not provided, the authorization request is about the DAG itself - :param dag_details: optional details about the DAG + :param details: optional details about the DAG :param user: the user to perform the action on. If not provided (or None), it uses the current user """ @@ -145,12 +151,30 @@ def is_authorized_dataset( self, *, method: ResourceMethod, + details: DatasetDetails | None = None, user: BaseUser | None = None, ) -> bool: """ Return whether the user is authorized to perform a given action on a dataset. :param method: the method to perform + :param details: optional details about the dataset + :param user: the user to perform the action on. If not provided (or None), it uses the current user + """ + + @abstractmethod + def is_authorized_pool( + self, + *, + method: ResourceMethod, + details: PoolDetails | None = None, + user: BaseUser | None = None, + ) -> bool: + """ + Return whether the user is authorized to perform a given action on a pool. + + :param method: the method to perform + :param details: optional details about the pool :param user: the user to perform the action on. 
If not provided (or None), it uses the current user """ @@ -159,12 +183,14 @@ def is_authorized_variable( self, *, method: ResourceMethod, + details: VariableDetails | None = None, user: BaseUser | None = None, ) -> bool: """ Return whether the user is authorized to perform a given action on a variable. :param method: the method to perform + :param details: optional details about the variable :param user: the user to perform the action on. If not provided (or None), it uses the current user """ diff --git a/airflow/auth/managers/fab/fab_auth_manager.py b/airflow/auth/managers/fab/fab_auth_manager.py index 9c2c5643b60f5..fc2c8a7d4614a 100644 --- a/airflow/auth/managers/fab/fab_auth_manager.py +++ b/airflow/auth/managers/fab/fab_auth_manager.py @@ -29,7 +29,15 @@ SYNC_PERM_COMMAND, USERS_COMMANDS, ) -from airflow.auth.managers.models.resource_details import ConnectionDetails, DagAccessEntity, DagDetails +from airflow.auth.managers.models.resource_details import ( + ConfigurationDetails, + ConnectionDetails, + DagAccessEntity, + DagDetails, + DatasetDetails, + PoolDetails, + VariableDetails, +) from airflow.cli.cli_config import ( GroupCommand, ) @@ -50,9 +58,15 @@ RESOURCE_DAG_DEPENDENCIES, RESOURCE_DAG_PREFIX, RESOURCE_DAG_RUN, + RESOURCE_DAG_WARNING, RESOURCE_DATASET, + RESOURCE_IMPORT_ERROR, + RESOURCE_PLUGIN, + RESOURCE_POOL, + RESOURCE_PROVIDER, RESOURCE_TASK_INSTANCE, RESOURCE_TASK_LOG, + RESOURCE_TRIGGER, RESOURCE_VARIABLE, RESOURCE_WEBSITE, RESOURCE_XCOM, @@ -65,7 +79,7 @@ CLICommand, ) -_MAP_METHOD_NAME_TO_FAB_ACTION_NAME: dict[ResourceMethod, str] = { +MAP_METHOD_NAME_TO_FAB_ACTION_NAME: dict[ResourceMethod, str] = { "POST": ACTION_CAN_CREATE, "GET": ACTION_CAN_READ, "PUT": ACTION_CAN_EDIT, @@ -77,9 +91,11 @@ DagAccessEntity.CODE: RESOURCE_DAG_CODE, DagAccessEntity.DATASET: RESOURCE_DATASET, DagAccessEntity.DEPENDENCIES: RESOURCE_DAG_DEPENDENCIES, + DagAccessEntity.IMPORT_ERRORS: RESOURCE_IMPORT_ERROR, DagAccessEntity.RUN: RESOURCE_DAG_RUN, 
DagAccessEntity.TASK_INSTANCE: RESOURCE_TASK_INSTANCE, DagAccessEntity.TASK_LOGS: RESOURCE_TASK_LOG, + DagAccessEntity.WARNING: RESOURCE_DAG_WARNING, DagAccessEntity.XCOM: RESOURCE_XCOM, } @@ -139,7 +155,13 @@ def is_logged_in(self) -> bool: """Return whether the user is logged in.""" return not self.get_user().is_anonymous - def is_authorized_configuration(self, *, method: ResourceMethod, user: BaseUser | None = None) -> bool: + def is_authorized_configuration( + self, + *, + method: ResourceMethod, + details: ConfigurationDetails | None = None, + user: BaseUser | None = None, + ) -> bool: return self._is_authorized(method=method, resource_type=RESOURCE_CONFIG, user=user) def is_authorized_cluster_activity(self, *, method: ResourceMethod, user: BaseUser | None = None) -> bool: @@ -149,7 +171,7 @@ def is_authorized_connection( self, *, method: ResourceMethod, - connection_details: ConnectionDetails | None = None, + details: ConnectionDetails | None = None, user: BaseUser | None = None, ) -> bool: return self._is_authorized(method=method, resource_type=RESOURCE_CONNECTION, user=user) @@ -158,8 +180,8 @@ def is_authorized_dag( self, *, method: ResourceMethod, - dag_access_entity: DagAccessEntity | None = None, - dag_details: DagDetails | None = None, + access_entity: DagAccessEntity | None = None, + details: DagDetails | None = None, user: BaseUser | None = None, ) -> bool: """ @@ -171,34 +193,50 @@ def is_authorized_dag( entity (e.g. DAG runs). 2. ``dag_access`` is provided which means the user wants to access a sub entity of the DAG (e.g. DAG runs). - a. If ``method`` is GET, then check the user has READ permissions on the DAG and the sub entity + a. If ``method`` is GET, then check the user has READ permissions on the sub entity b. Else, check the user has EDIT permissions on the DAG and ``method`` on the sub entity :param method: The method to authorize. - :param dag_access_entity: The dag access entity. - :param dag_details: The dag details. 
+ :param access_entity: The dag access entity. + :param details: The dag details. :param user: The user. """ - if not dag_access_entity: + if not access_entity: # Scenario 1 - return self._is_authorized_dag(method=method, dag_details=dag_details, user=user) + return self._is_authorized_dag(method=method, details=details, user=user) else: # Scenario 2 - resource_type = self._get_fab_resource_type(dag_access_entity) - dag_method: ResourceMethod = "GET" if method == "GET" else "PUT" + resource_type = self._get_fab_resource_type(access_entity) - return self._is_authorized_dag( - method=dag_method, dag_details=dag_details, user=user - ) and self._is_authorized(method=method, resource_type=resource_type, user=user) + if method == "GET": + return self._is_authorized(method=method, resource_type=resource_type, user=user) + else: + return self._is_authorized_dag( + method="PUT", details=details, user=user + ) and self._is_authorized(method=method, resource_type=resource_type, user=user) - def is_authorized_dataset(self, *, method: ResourceMethod, user: BaseUser | None = None) -> bool: + def is_authorized_dataset( + self, *, method: ResourceMethod, details: DatasetDetails | None = None, user: BaseUser | None = None + ) -> bool: return self._is_authorized(method=method, resource_type=RESOURCE_DATASET, user=user) - def is_authorized_variable(self, *, method: ResourceMethod, user: BaseUser | None = None) -> bool: + def is_authorized_pool( + self, *, method: ResourceMethod, details: PoolDetails | None = None, user: BaseUser | None = None + ) -> bool: + return self._is_authorized(method=method, resource_type=RESOURCE_POOL, user=user) + + def is_authorized_variable( + self, *, method: ResourceMethod, details: VariableDetails | None = None, user: BaseUser | None = None + ) -> bool: return self._is_authorized(method=method, resource_type=RESOURCE_VARIABLE, user=user) def is_authorized_website(self, *, user: BaseUser | None = None) -> bool: - return self._is_authorized(method="GET", 
resource_type=RESOURCE_WEBSITE, user=user) + return ( + self._is_authorized(method="GET", resource_type=RESOURCE_PLUGIN, user=user) + or self._is_authorized(method="GET", resource_type=RESOURCE_PROVIDER, user=user) + or self._is_authorized(method="GET", resource_type=RESOURCE_TRIGGER, user=user) + or self._is_authorized(method="GET", resource_type=RESOURCE_WEBSITE, user=user) + ) def get_security_manager_override_class(self) -> type: """Return the security manager override.""" @@ -270,14 +308,14 @@ def _is_authorized( def _is_authorized_dag( self, method: ResourceMethod, - dag_details: DagDetails | None = None, + details: DagDetails | None = None, user: BaseUser | None = None, ) -> bool: """ Return whether the user is authorized to perform a given action on a DAG. :param method: the method to perform - :param dag_details: optional details about the DAG + :param details: optional details about the DAG :param user: the user to perform the action on. If not provided (or None), it uses the current user :meta private: @@ -286,9 +324,9 @@ def _is_authorized_dag( if is_global_authorized: return True - if dag_details and dag_details.id: + if details and details.id: # Check whether the user has permissions to access a specific DAG - resource_dag_name = self._resource_name_for_dag(dag_details.id) + resource_dag_name = self._resource_name_for_dag(details.id) return self._is_authorized(method=method, resource_type=resource_dag_name, user=user) return False @@ -302,9 +340,9 @@ def _get_fab_action(method: ResourceMethod) -> str: :meta private: """ - if method not in _MAP_METHOD_NAME_TO_FAB_ACTION_NAME: + if method not in MAP_METHOD_NAME_TO_FAB_ACTION_NAME: raise AirflowException(f"Unknown method: {method}") - return _MAP_METHOD_NAME_TO_FAB_ACTION_NAME[method] + return MAP_METHOD_NAME_TO_FAB_ACTION_NAME[method] @staticmethod def _get_fab_resource_type(dag_access_entity: DagAccessEntity): diff --git a/airflow/auth/managers/fab/security_manager/override.py 
b/airflow/auth/managers/fab/security_manager/override.py index 4b884c303ab63..5ec6d58997b0a 100644 --- a/airflow/auth/managers/fab/security_manager/override.py +++ b/airflow/auth/managers/fab/security_manager/override.py @@ -23,7 +23,7 @@ import uuid import warnings from functools import cached_property -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Container, Iterable import re2 from flask import flash, g, session @@ -40,13 +40,20 @@ from sqlalchemy.exc import MultipleResultsFound from werkzeug.security import generate_password_hash +from airflow.auth.managers.fab.fab_auth_manager import MAP_METHOD_NAME_TO_FAB_ACTION_NAME from airflow.auth.managers.fab.models import Action, Permission, RegisterUser, Resource, Role from airflow.auth.managers.fab.models.anonymous_user import AnonymousUser -from airflow.exceptions import AirflowException +from airflow.exceptions import AirflowException, RemovedInAirflow3Warning +from airflow.models import DagModel +from airflow.security import permissions +from airflow.utils.session import NEW_SESSION, provide_session from airflow.www.security_manager import AirflowSecurityManagerV2 from airflow.www.session import AirflowDatabaseSessionInterface if TYPE_CHECKING: + from sqlalchemy.orm import Session + + from airflow.auth.managers.base_auth_manager import ResourceMethod from airflow.auth.managers.fab.models import User log = logging.getLogger(__name__) @@ -463,6 +470,66 @@ def create_db(self): log.error(const.LOGMSG_ERR_SEC_CREATE_DB, e) exit(1) + def get_readable_dags(self, user) -> Iterable[DagModel]: + """Gets the DAGs readable by authenticated user.""" + warnings.warn( + "`get_readable_dags` has been deprecated. 
Please use `get_readable_dag_ids` instead.", + RemovedInAirflow3Warning, + stacklevel=2, + ) + with warnings.catch_warnings(): + warnings.simplefilter("ignore", RemovedInAirflow3Warning) + return self.get_accessible_dags([permissions.ACTION_CAN_READ], user) + + def get_editable_dags(self, user) -> Iterable[DagModel]: + """Gets the DAGs editable by authenticated user.""" + warnings.warn( + "`get_editable_dags` has been deprecated. Please use `get_editable_dag_ids` instead.", + RemovedInAirflow3Warning, + stacklevel=2, + ) + with warnings.catch_warnings(): + warnings.simplefilter("ignore", RemovedInAirflow3Warning) + return self.get_accessible_dags([permissions.ACTION_CAN_EDIT], user) + + @provide_session + def get_accessible_dags( + self, + user_actions: Container[str] | None, + user, + session: Session = NEW_SESSION, + ) -> Iterable[DagModel]: + warnings.warn( + "`get_accessible_dags` has been deprecated. Please use `get_accessible_dag_ids` instead.", + RemovedInAirflow3Warning, + stacklevel=3, + ) + + dag_ids = self.get_accessible_dag_ids(user, user_actions, session) + return session.scalars(select(DagModel).where(DagModel.dag_id.in_(dag_ids))) + + @provide_session + def get_accessible_dag_ids( + self, + user, + user_actions: Container[str] | None = None, + session: Session = NEW_SESSION, + ) -> set[str]: + warnings.warn( + "`get_accessible_dag_ids` has been deprecated. 
Please use `get_permitted_dag_ids` instead.", + RemovedInAirflow3Warning, + stacklevel=3, + ) + if not user_actions: + user_actions = [permissions.ACTION_CAN_EDIT, permissions.ACTION_CAN_READ] + fab_action_name_to_method_name = {v: k for k, v in MAP_METHOD_NAME_TO_FAB_ACTION_NAME.items()} + user_methods: Container[ResourceMethod] = [ + fab_action_name_to_method_name[action] + for action in fab_action_name_to_method_name + if action in user_actions + ] + return self.get_permitted_dag_ids(user=user, user_methods=user_methods, session=session) + """ ----------- Role entity diff --git a/airflow/auth/managers/models/resource_details.py b/airflow/auth/managers/models/resource_details.py index 51cdc5979392d..9d424bead2aad 100644 --- a/airflow/auth/managers/models/resource_details.py +++ b/airflow/auth/managers/models/resource_details.py @@ -21,18 +21,46 @@ from enum import Enum +@dataclass +class ConfigurationDetails: + """Represents the details of a configuration.""" + + section: str | None = None + + @dataclass class ConnectionDetails: """Represents the details of a connection.""" - conn_id: str + conn_id: str | None = None @dataclass class DagDetails: """Represents the details of a DAG.""" - id: str + id: str | None = None + + +@dataclass +class DatasetDetails: + """Represents the details of a dataset.""" + + uri: str | None = None + + +@dataclass +class PoolDetails: + """Represents the details of a pool.""" + + name: str | None = None + + +@dataclass +class VariableDetails: + """Represents the details of a variable.""" + + key: str | None = None class DagAccessEntity(Enum): @@ -42,7 +70,9 @@ class DagAccessEntity(Enum): CODE = "CODE" DATASET = "DATASET" DEPENDENCIES = "DEPENDENCIES" + IMPORT_ERRORS = "IMPORT_ERRORS" RUN = "RUN" TASK_INSTANCE = "TASK_INSTANCE" TASK_LOGS = "TASK_LOGS" + WARNING = "WARNING" XCOM = "XCOM" diff --git a/airflow/www/auth.py b/airflow/www/auth.py index c943779ab09c5..03708a3bc8828 100644 --- a/airflow/www/auth.py +++ b/airflow/www/auth.py @@ 
-33,7 +33,7 @@ if TYPE_CHECKING: from airflow.auth.managers.base_auth_manager import ResourceMethod - from airflow.models import Connection + from airflow.models.connection import Connection T = TypeVar("T", bound=Callable) @@ -51,7 +51,7 @@ def _has_access_no_details(is_authorized_callback: Callable[[], bool]) -> Callab This works only for resources with no details. This function is used in some ``has_access_`` functions below. - :param is_authorized_callback: callback to execute to figure whether the user authorized to access + :param is_authorized_callback: callback to execute to figure whether the user is authorized to access the resource? """ @@ -116,9 +116,7 @@ def decorated(*args, **kwargs): ] is_authorized = all( [ - get_auth_manager().is_authorized_connection( - method=method, connection_details=connection_details - ) + get_auth_manager().is_authorized_connection(method=method, details=connection_details) for connection_details in connections_details ] ) @@ -167,8 +165,8 @@ def decorated(*args, **kwargs): is_authorized = get_auth_manager().is_authorized_dag( method=method, - dag_access_entity=access_entity, - dag_details=None if not dag_id else DagDetails(id=dag_id), + access_entity=access_entity, + details=None if not dag_id else DagDetails(id=dag_id), ) return _has_access( diff --git a/airflow/www/extensions/init_jinja_globals.py b/airflow/www/extensions/init_jinja_globals.py index ff5481dd468f6..0c3521882a882 100644 --- a/airflow/www/extensions/init_jinja_globals.py +++ b/airflow/www/extensions/init_jinja_globals.py @@ -21,6 +21,7 @@ import pendulum import airflow +from airflow.auth.managers.models.resource_details import DagDetails from airflow.configuration import conf from airflow.settings import IS_K8S_OR_K8SCELERY_EXECUTOR, STATE_COLORS from airflow.utils.net import get_hostname @@ -69,10 +70,17 @@ def prepare_jinja_globals(): "git_version": git_version, "k8s_or_k8scelery_executor": IS_K8S_OR_K8SCELERY_EXECUTOR, "rest_api_enabled": False, - 
"auth_manager": get_auth_manager(), "config_test_connection": conf.get("core", "test_connection", fallback="Disabled"), } + # Extra global specific to auth manager + extra_globals.update( + { + "auth_manager": get_auth_manager(), + "DagDetails": DagDetails, + } + ) + backends = conf.get("api", "auth_backends") if backends and backends[0] != "airflow.api.auth.backend.deny_all": extra_globals["rest_api_enabled"] = True diff --git a/airflow/www/security_manager.py b/airflow/www/security_manager.py index 172857098ac7d..ca2f455a91617 100644 --- a/airflow/www/security_manager.py +++ b/airflow/www/security_manager.py @@ -23,7 +23,7 @@ from sqlalchemy import or_, select from sqlalchemy.orm import joinedload -from airflow.auth.managers.fab.models import Permission, Resource, Role, User +from airflow.auth.managers.fab.models import Permission, Resource, Role from airflow.auth.managers.fab.views.permissions import ( ActionModelView, PermissionPairModelView, @@ -43,6 +43,7 @@ CustomUserInfoEditView, ) from airflow.auth.managers.fab.views.user_stats import CustomUserStatsChartView +from airflow.auth.managers.models.resource_details import DagDetails from airflow.exceptions import AirflowException, RemovedInAirflow3Warning from airflow.models import DagBag, DagModel from airflow.security import permissions @@ -63,6 +64,8 @@ if TYPE_CHECKING: from sqlalchemy.orm import Session + from airflow.auth.managers.base_auth_manager import ResourceMethod + class AirflowSecurityManagerV2(SecurityManager, LoggingMixin): """Custom security manager, which introduces a permission model adapted to Airflow. @@ -267,95 +270,47 @@ def get_user_roles(user=None): user = g.user return user.roles - def get_readable_dags(self, user) -> Iterable[DagModel]: - """Gets the DAGs readable by authenticated user.""" - warnings.warn( - "`get_readable_dags` has been deprecated. 
Please use `get_readable_dag_ids` instead.", - RemovedInAirflow3Warning, - stacklevel=2, - ) - with warnings.catch_warnings(): - warnings.simplefilter("ignore", RemovedInAirflow3Warning) - return self.get_accessible_dags([permissions.ACTION_CAN_READ], user) - - def get_editable_dags(self, user) -> Iterable[DagModel]: - """Gets the DAGs editable by authenticated user.""" - warnings.warn( - "`get_editable_dags` has been deprecated. Please use `get_editable_dag_ids` instead.", - RemovedInAirflow3Warning, - stacklevel=2, - ) - with warnings.catch_warnings(): - warnings.simplefilter("ignore", RemovedInAirflow3Warning) - return self.get_accessible_dags([permissions.ACTION_CAN_EDIT], user) - - @provide_session - def get_accessible_dags( - self, - user_actions: Container[str] | None, - user, - session: Session = NEW_SESSION, - ) -> Iterable[DagModel]: - warnings.warn( - "`get_accessible_dags` has been deprecated. Please use `get_accessible_dag_ids` instead.", - RemovedInAirflow3Warning, - stacklevel=3, - ) - dag_ids = self.get_accessible_dag_ids(user, user_actions, session) - return session.scalars(select(DagModel).where(DagModel.dag_id.in_(dag_ids))) - def get_readable_dag_ids(self, user) -> set[str]: """Gets the DAG IDs readable by authenticated user.""" - return self.get_accessible_dag_ids(user, [permissions.ACTION_CAN_READ]) + return self.get_permitted_dag_ids(user, ["GET"]) def get_editable_dag_ids(self, user) -> set[str]: """Gets the DAG IDs editable by authenticated user.""" - return self.get_accessible_dag_ids(user, [permissions.ACTION_CAN_EDIT]) + return self.get_permitted_dag_ids(user, ["PUT"]) @provide_session - def get_accessible_dag_ids( + def get_permitted_dag_ids( self, user, - user_actions: Container[str] | None = None, + user_methods: Container[ResourceMethod] | None = None, session: Session = NEW_SESSION, ) -> set[str]: """Generic function to get readable or writable DAGs for user.""" - if not user_actions: - user_actions = [permissions.ACTION_CAN_EDIT, 
permissions.ACTION_CAN_READ] + if not user_methods: + user_methods = ["PUT", "GET"] - if not get_auth_manager().is_logged_in(): - roles = user.roles - else: - if (permissions.ACTION_CAN_EDIT in user_actions and self.can_edit_all_dags(user)) or ( - permissions.ACTION_CAN_READ in user_actions and self.can_read_all_dags(user) - ): - return {dag.dag_id for dag in session.execute(select(DagModel.dag_id))} - user_query = session.scalar( - select(User) - .options( - joinedload(User.roles) - .subqueryload(Role.permissions) - .options(joinedload(Permission.action), joinedload(Permission.resource)) + dag_ids = {dag.dag_id for dag in session.execute(select(DagModel.dag_id))} + + if ("GET" in user_methods and get_auth_manager().is_authorized_dag(method="GET", user=user)) or ( + "PUT" in user_methods and get_auth_manager().is_authorized_dag(method="PUT", user=user) + ): + return dag_ids + + return { + dag_id + for dag_id in dag_ids + if ( + "GET" in user_methods + and get_auth_manager().is_authorized_dag( + method="GET", details=DagDetails(id=dag_id), user=user + ) + ) + or ( + "PUT" in user_methods + and get_auth_manager().is_authorized_dag( + method="PUT", details=DagDetails(id=dag_id), user=user ) - .where(User.id == user.id) ) - roles = user_query.roles - - resources = set() - for role in roles: - for permission in role.permissions: - action = permission.action.name - if action in user_actions: - resource = permission.resource.name - if resource == permissions.RESOURCE_DAG: - return {dag.dag_id for dag in session.execute(select(DagModel.dag_id))} - if resource.startswith(permissions.RESOURCE_DAG_PREFIX): - resources.add(resource[len(permissions.RESOURCE_DAG_PREFIX) :]) - else: - resources.add(resource) - return { - dag.dag_id - for dag in session.execute(select(DagModel.dag_id).where(DagModel.dag_id.in_(resources))) } def can_access_some_dags(self, action: str, dag_id: str | None = None) -> bool: @@ -369,24 +324,6 @@ def can_access_some_dags(self, action: str, dag_id: str | 
None = None) -> bool: return any(self.get_readable_dag_ids(user)) return any(self.get_editable_dag_ids(user)) - def can_read_dag(self, dag_id: str, user=None) -> bool: - """Determines whether a user has DAG read access.""" - root_dag_id = self._get_root_dag_id(dag_id) - dag_resource_name = permissions.resource_name_for_dag(root_dag_id) - return self.has_access(permissions.ACTION_CAN_READ, dag_resource_name, user=user) - - def can_edit_dag(self, dag_id: str, user=None) -> bool: - """Determines whether a user has DAG edit access.""" - root_dag_id = self._get_root_dag_id(dag_id) - dag_resource_name = permissions.resource_name_for_dag(root_dag_id) - return self.has_access(permissions.ACTION_CAN_EDIT, dag_resource_name, user=user) - - def can_delete_dag(self, dag_id: str, user=None) -> bool: - """Determines whether a user has DAG delete access.""" - root_dag_id = self._get_root_dag_id(dag_id) - dag_resource_name = permissions.resource_name_for_dag(root_dag_id) - return self.has_access(permissions.ACTION_CAN_DELETE, dag_resource_name, user=user) - def prefixed_dag_id(self, dag_id: str) -> str: """Returns the permission name for a DAG id.""" warnings.warn( @@ -428,36 +365,6 @@ def has_access(self, action_name: str, resource_name: str, user=None) -> bool: return False - def _has_role(self, role_name_or_list: Container, user) -> bool: - """Whether the user has this role name.""" - if not isinstance(role_name_or_list, list): - role_name_or_list = [role_name_or_list] - return any(r.name in role_name_or_list for r in user.roles) - - def has_all_dags_access(self, user) -> bool: - """ - Has all the dag access in any of the 3 cases. - - 1. Role needs to be in (Admin, Viewer, User, Op). - 2. Has can_read action on dags resource. - 3. Has can_edit action on dags resource. 
- """ - if not user: - user = g.user - return ( - self._has_role(["Admin", "Viewer", "Op", "User"], user) - or self.can_read_all_dags(user) - or self.can_edit_all_dags(user) - ) - - def can_edit_all_dags(self, user=None) -> bool: - """Has can_edit action on DAG resource.""" - return self.has_access(permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG, user) - - def can_read_all_dags(self, user=None) -> bool: - """Has can_read action on DAG resource.""" - return self.has_access(permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG, user) - def clean_perms(self) -> None: """FAB leaves faulty permissions that need to be cleaned up.""" self.log.debug("Cleaning faulty perms") diff --git a/airflow/www/templates/airflow/dag.html b/airflow/www/templates/airflow/dag.html index d324199d1fdef..884f0401858ef 100644 --- a/airflow/www/templates/airflow/dag.html +++ b/airflow/www/templates/airflow/dag.html @@ -110,8 +110,8 @@

{% if dag.parent_dag is defined and dag.parent_dag %} SUBDAG: {{ dag.dag_id }} {% else %} - {% set can_edit = appbuilder.sm.can_edit_dag(dag.dag_id) %} - {% if appbuilder.sm.can_edit_dag(dag.dag_id) %} + {% set can_edit = auth_manager.is_authorized_dag(method="PUT", details=DagDetails(id=dag.dag_id)) %} + {% if can_edit %} {% set switch_tooltip = 'Pause/Unpause DAG' %} {% else %} {% set switch_tooltip = 'DAG is Paused' if dag_is_paused else 'DAG is Active' %} diff --git a/airflow/www/views.py b/airflow/www/views.py index f3511fc41c633..40dec21f3a3bf 100644 --- a/airflow/www/views.py +++ b/airflow/www/views.py @@ -81,7 +81,7 @@ set_dag_run_state_to_success, set_state, ) -from airflow.auth.managers.models.resource_details import DagAccessEntity +from airflow.auth.managers.models.resource_details import DagAccessEntity, DagDetails from airflow.configuration import AIRFLOW_CONFIG, conf from airflow.datasets import Dataset from airflow.exceptions import ( @@ -775,7 +775,7 @@ def index(self): end = start + dags_per_page # Get all the dag id the user could access - filter_dag_ids = get_airflow_app().appbuilder.sm.get_accessible_dag_ids(g.user) + filter_dag_ids = get_airflow_app().appbuilder.sm.get_permitted_dag_ids(g.user) with create_session() as session: # read orm_dags from the db @@ -916,9 +916,13 @@ def index(self): dataset_triggered_next_run_info = {} for dag in dags: - dag.can_edit = get_airflow_app().appbuilder.sm.can_edit_dag(dag.dag_id, g.user) + dag.can_edit = get_auth_manager().is_authorized_dag( + method="PUT", details=DagDetails(id=dag.dag_id), user=g.user + ) dag.can_trigger = dag.can_edit and can_create_dag_run - dag.can_delete = get_airflow_app().appbuilder.sm.can_delete_dag(dag.dag_id, g.user) + dag.can_delete = get_auth_manager().is_authorized_dag( + method="DELETE", details=DagDetails(id=dag.dag_id), user=g.user + ) dagtags = session.execute(select(func.distinct(DagTag.name)).order_by(DagTag.name)).all() tags = [ @@ -1062,11 +1066,10 @@ def 
cluster_activity(self): ) @expose("/next_run_datasets_summary", methods=["POST"]) - @auth.has_access_dag("GET") @provide_session def next_run_datasets_summary(self, session: Session = NEW_SESSION): """Next run info for dataset triggered DAGs.""" - allowed_dag_ids = get_airflow_app().appbuilder.sm.get_accessible_dag_ids(g.user) + allowed_dag_ids = get_airflow_app().appbuilder.sm.get_permitted_dag_ids(g.user) if not allowed_dag_ids: return flask.json.jsonify({}) @@ -1101,7 +1104,7 @@ def next_run_datasets_summary(self, session: Session = NEW_SESSION): @provide_session def dag_stats(self, session: Session = NEW_SESSION): """Dag statistics.""" - allowed_dag_ids = get_airflow_app().appbuilder.sm.get_accessible_dag_ids(g.user) + allowed_dag_ids = get_airflow_app().appbuilder.sm.get_permitted_dag_ids(g.user) # Filter by post parameters selected_dag_ids = {unquote(dag_id) for dag_id in request.form.getlist("dag_ids") if dag_id} @@ -1133,7 +1136,7 @@ def dag_stats(self, session: Session = NEW_SESSION): @provide_session def task_stats(self, session: Session = NEW_SESSION): """Task Statistics.""" - allowed_dag_ids = get_airflow_app().appbuilder.sm.get_accessible_dag_ids(g.user) + allowed_dag_ids = get_airflow_app().appbuilder.sm.get_permitted_dag_ids(g.user) if not allowed_dag_ids: return flask.json.jsonify({}) @@ -1234,7 +1237,7 @@ def task_stats(self, session: Session = NEW_SESSION): @provide_session def last_dagruns(self, session: Session = NEW_SESSION): """Last DAG runs.""" - allowed_dag_ids = get_airflow_app().appbuilder.sm.get_accessible_dag_ids(g.user) + allowed_dag_ids = get_airflow_app().appbuilder.sm.get_permitted_dag_ids(g.user) # Filter by post parameters selected_dag_ids = {unquote(dag_id) for dag_id in request.form.getlist("dag_ids") if dag_id} @@ -2341,7 +2344,7 @@ def dagrun_clear(self, *, session: Session = NEW_SESSION): @provide_session def blocked(self, session: Session = NEW_SESSION): """Mark Dag Blocked.""" - allowed_dag_ids = 
get_airflow_app().appbuilder.sm.get_accessible_dag_ids(g.user) + allowed_dag_ids = get_airflow_app().appbuilder.sm.get_permitted_dag_ids(g.user) # Filter by post parameters selected_dag_ids = {unquote(dag_id) for dag_id in request.form.getlist("dag_ids") if dag_id} @@ -3908,9 +3911,11 @@ class DagFilter(BaseFilter): """Filter using DagIDs.""" def apply(self, query, func): - if get_airflow_app().appbuilder.sm.has_all_dags_access(g.user): + if get_auth_manager().is_authorized_dag( + method="GET", user=g.user + ) or get_auth_manager().is_authorized_dag(method="PUT", user=g.user): return query - filter_dag_ids = get_airflow_app().appbuilder.sm.get_accessible_dag_ids(g.user) + filter_dag_ids = get_airflow_app().appbuilder.sm.get_permitted_dag_ids(g.user) return query.where(self.model.dag_id.in_(filter_dag_ids)) @@ -3960,7 +3965,7 @@ class AirflowPrivilegeVerifierModelView(AirflowModelView): @staticmethod def validate_dag_edit_access(item: DagRun | TaskInstance): """Validates whether the user has 'can_edit' access for this specific DAG.""" - if not get_airflow_app().appbuilder.sm.can_edit_dag(item.dag_id): + if not get_auth_manager().is_authorized_dag(method="PUT", details=DagDetails(id=item.dag_id)): raise AirflowException(f"Access denied for dag_id {item.dag_id}") def pre_add(self, item: DagRun | TaskInstance): @@ -4006,7 +4011,7 @@ def check_dag_edit_acl_for_actions( ) for dag_id in dag_ids: - if not get_airflow_app().appbuilder.sm.can_edit_dag(dag_id): + if not get_auth_manager().is_authorized_dag(method="PUT", details=DagDetails(id=dag_id)): flash(f"Access denied for dag_id {dag_id}", "danger") logging.warning("User %s tried to modify %s without having access.", g.user.username, dag_id) return redirect(self.get_default_url()) @@ -5696,7 +5701,6 @@ def action_set_skipped(self, tis): class AutocompleteView(AirflowBaseView): """View to provide autocomplete results.""" - @auth.has_access_dag("GET") @provide_session @expose("/dagmodel/autocomplete") def 
autocomplete(self, session: Session = NEW_SESSION): @@ -5730,7 +5734,7 @@ def autocomplete(self, session: Session = NEW_SESSION): dag_ids_query = dag_ids_query.where(DagModel.is_paused) owners_query = owners_query.where(DagModel.is_paused) - filter_dag_ids = get_airflow_app().appbuilder.sm.get_accessible_dag_ids(g.user) + filter_dag_ids = get_airflow_app().appbuilder.sm.get_permitted_dag_ids(g.user) dag_ids_query = dag_ids_query.where(DagModel.dag_id.in_(filter_dag_ids)) owners_query = owners_query.where(DagModel.dag_id.in_(filter_dag_ids)) @@ -5815,9 +5819,9 @@ def add_user_permissions_to_dag(sender, template, context, **extra): permissions.ACTION_CAN_CREATE, permissions.RESOURCE_DAG_RUN ) - dag.can_edit = get_airflow_app().appbuilder.sm.can_edit_dag(dag.dag_id) + dag.can_edit = get_auth_manager().is_authorized_dag(method="PUT", details=DagDetails(id=dag.dag_id)) dag.can_trigger = dag.can_edit and can_create_dag_run - dag.can_delete = get_airflow_app().appbuilder.sm.can_delete_dag(dag.dag_id) + dag.can_delete = get_auth_manager().is_authorized_dag(method="DELETE", details=DagDetails(id=dag.dag_id)) context["dag"] = dag diff --git a/tests/api_connexion/endpoints/test_dag_endpoint.py b/tests/api_connexion/endpoints/test_dag_endpoint.py index 06ac5bbef85fd..6dc59f2f8c8d2 100644 --- a/tests/api_connexion/endpoints/test_dag_endpoint.py +++ b/tests/api_connexion/endpoints/test_dag_endpoint.py @@ -815,12 +815,13 @@ def test_should_raises_401_unauthenticated(self): assert_401(response) - def test_should_respond_403_unauthorized(self): + def test_should_return_empty_list(self): self._create_dag_models(1) response = self.client.get("api/v1/dags", environ_overrides={"REMOTE_USER": "test_no_permissions"}) - assert response.status_code == 403 + assert response.status_code == 200 + assert {"dags": [], "total_entries": 0} == response.json def test_paused_true_returns_paused_dags(self, url_safe_serializer): self._create_dag_models(1, dag_id_prefix="TEST_DAG_PAUSED", 
is_paused=True) @@ -1539,7 +1540,7 @@ def test_should_raises_401_unauthenticated(self): assert_401(response) - def test_should_respond_403_unauthorized(self): + def test_should_return_empty_list(self): self._create_dag_models(1) response = self.client.patch( "api/v1/dags?dag_id_pattern=~", @@ -1549,7 +1550,8 @@ def test_should_respond_403_unauthorized(self): environ_overrides={"REMOTE_USER": "test_no_permissions"}, ) - assert response.status_code == 403 + assert response.status_code == 200 + assert {"dags": [], "total_entries": 0} == response.json def test_should_respond_200_and_pause_dags(self, url_safe_serializer): file_token = url_safe_serializer.dumps("/tmp/dag_1.py") diff --git a/tests/api_connexion/endpoints/test_log_endpoint.py b/tests/api_connexion/endpoints/test_log_endpoint.py index 422830eb6ed53..e4d61c50607da 100644 --- a/tests/api_connexion/endpoints/test_log_endpoint.py +++ b/tests/api_connexion/endpoints/test_log_endpoint.py @@ -46,9 +46,7 @@ def configured_app(minimal_app_for_api): username="test", role_name="Test", permissions=[ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_INSTANCE), + (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_LOG), ], ) create_user(app, username="test_no_permissions", role_name="TestNoPermissions") diff --git a/tests/auth/managers/fab/test_fab_auth_manager.py b/tests/auth/managers/fab/test_fab_auth_manager.py index 85a13ad17822c..c48ceccfa742f 100644 --- a/tests/auth/managers/fab/test_fab_auth_manager.py +++ b/tests/auth/managers/fab/test_fab_auth_manager.py @@ -293,7 +293,7 @@ def test_is_authorized_dag( user = Mock() user.perms = user_permissions result = auth_manager.is_authorized_dag( - method=method, dag_access_entity=dag_access_entity, dag_details=dag_details, user=user + method=method, access_entity=dag_access_entity, details=dag_details, user=user ) assert result == 
expected_result diff --git a/tests/auth/managers/test_base_auth_manager.py b/tests/auth/managers/test_base_auth_manager.py index 1ff8dcfbbf172..416fa75e2aaf5 100644 --- a/tests/auth/managers/test_base_auth_manager.py +++ b/tests/auth/managers/test_base_auth_manager.py @@ -27,7 +27,15 @@ if TYPE_CHECKING: from airflow.auth.managers.models.base_user import BaseUser - from airflow.auth.managers.models.resource_details import ConnectionDetails, DagAccessEntity, DagDetails + from airflow.auth.managers.models.resource_details import ( + ConfigurationDetails, + ConnectionDetails, + DagAccessEntity, + DagDetails, + DatasetDetails, + PoolDetails, + VariableDetails, + ) class EmptyAuthManager(BaseAuthManager): @@ -40,7 +48,13 @@ def get_user(self) -> BaseUser: def get_user_id(self) -> str: raise NotImplementedError() - def is_authorized_configuration(self, *, method: ResourceMethod, user: BaseUser | None = None) -> bool: + def is_authorized_configuration( + self, + *, + method: ResourceMethod, + details: ConfigurationDetails | None = None, + user: BaseUser | None = None, + ) -> bool: raise NotImplementedError() def is_authorized_cluster_activity(self, *, method: ResourceMethod, user: BaseUser | None = None) -> bool: @@ -50,7 +64,7 @@ def is_authorized_connection( self, *, method: ResourceMethod, - connection_details: ConnectionDetails | None = None, + details: ConnectionDetails | None = None, user: BaseUser | None = None, ) -> bool: raise NotImplementedError() @@ -59,16 +73,25 @@ def is_authorized_dag( self, *, method: ResourceMethod, - dag_access_entity: DagAccessEntity | None = None, - dag_details: DagDetails | None = None, + access_entity: DagAccessEntity | None = None, + details: DagDetails | None = None, user: BaseUser | None = None, ) -> bool: raise NotImplementedError() - def is_authorized_dataset(self, *, method: ResourceMethod, user: BaseUser | None = None) -> bool: + def is_authorized_dataset( + self, *, method: ResourceMethod, details: DatasetDetails | None = 
None, user: BaseUser | None = None + ) -> bool: raise NotImplementedError() - def is_authorized_variable(self, *, method: ResourceMethod, user: BaseUser | None = None) -> bool: + def is_authorized_pool( + self, *, method: ResourceMethod, details: PoolDetails | None = None, user: BaseUser | None = None + ) -> bool: + raise NotImplementedError() + + def is_authorized_variable( + self, *, method: ResourceMethod, details: VariableDetails | None = None, user: BaseUser | None = None + ) -> bool: raise NotImplementedError() def is_authorized_website(self, *, user: BaseUser | None = None) -> bool: diff --git a/tests/www/test_security.py b/tests/www/test_security.py index b70aad536bce4..60157303475c4 100644 --- a/tests/www/test_security.py +++ b/tests/www/test_security.py @@ -33,6 +33,7 @@ from airflow.auth.managers.fab.fab_auth_manager import FabAuthManager from airflow.auth.managers.fab.models import User, assoc_permission_role from airflow.auth.managers.fab.models.anonymous_user import AnonymousUser +from airflow.auth.managers.models.resource_details import DagDetails from airflow.configuration import initialize_config from airflow.exceptions import AirflowException from airflow.models import DagModel @@ -41,6 +42,7 @@ from airflow.security import permissions from airflow.www import app as application from airflow.www.auth import get_access_denied_message +from airflow.www.extensions.init_auth_manager import get_auth_manager from airflow.www.utils import CustomSQLAInterface from tests.test_utils.api_connexion_utils import ( create_user, @@ -118,6 +120,24 @@ def _delete_dag_model(dag_model, session, security_manager): _delete_dag_permissions(dag_model.dag_id, security_manager) +def _can_read_dag(dag_id: str, user) -> bool: + return get_auth_manager().is_authorized_dag(method="GET", details=DagDetails(id=dag_id), user=user) + + +def _can_edit_dag(dag_id: str, user) -> bool: + return get_auth_manager().is_authorized_dag(method="PUT", details=DagDetails(id=dag_id), 
user=user) + + +def _can_delete_dag(dag_id: str, user) -> bool: + return get_auth_manager().is_authorized_dag(method="DELETE", details=DagDetails(id=dag_id), user=user) + + +def _has_all_dags_access(user) -> bool: + return get_auth_manager().is_authorized_dag( + method="GET", user=user + ) or get_auth_manager().is_authorized_dag(method="PUT", user=user) + + @contextlib.contextmanager def _create_dag_model_context(dag_id, session, security_manager): dag = _create_dag_model(dag_id, session, security_manager) @@ -321,7 +341,7 @@ def test_verify_default_anon_user_has_no_accessible_dag_ids( with _create_dag_model_context("test_dag_id", session, security_manager): security_manager.sync_roles() - assert security_manager.get_accessible_dag_ids(user) == set() + assert security_manager.get_permitted_dag_ids(user) == set() def test_verify_default_anon_user_has_no_access_to_specific_dag(app, session, security_manager, has_dag_perm): @@ -334,8 +354,8 @@ def test_verify_default_anon_user_has_no_access_to_specific_dag(app, session, se with _create_dag_model_context(dag_id, session, security_manager): security_manager.sync_roles() - assert security_manager.can_read_dag(dag_id, user) is False - assert security_manager.can_edit_dag(dag_id, user) is False + assert _can_read_dag(dag_id, user) is False + assert _can_edit_dag(dag_id, user) is False assert has_dag_perm(permissions.ACTION_CAN_READ, dag_id, user) is False assert has_dag_perm(permissions.ACTION_CAN_EDIT, dag_id, user) is False @@ -359,7 +379,7 @@ def test_verify_anon_user_with_admin_role_has_all_dag_access( security_manager.sync_roles() - assert security_manager.get_accessible_dag_ids(user) == set(test_dag_ids) + assert security_manager.get_permitted_dag_ids(user) == set(test_dag_ids) def test_verify_anon_user_with_admin_role_has_access_to_each_dag( @@ -379,8 +399,8 @@ def test_verify_anon_user_with_admin_role_has_access_to_each_dag( with _create_dag_model_context(dag_id, session, security_manager): 
security_manager.sync_roles() - assert security_manager.can_read_dag(dag_id, user) is True - assert security_manager.can_edit_dag(dag_id, user) is True + assert _can_read_dag(dag_id, user) is True + assert _can_edit_dag(dag_id, user) is True assert has_dag_perm(permissions.ACTION_CAN_READ, dag_id, user) is True assert has_dag_perm(permissions.ACTION_CAN_EDIT, dag_id, user) is True @@ -487,7 +507,7 @@ def test_get_accessible_dag_ids(mock_is_logged_in, app, security_manager, sessio dag_id, access_control={role_name: permission_action} ) - assert security_manager.get_accessible_dag_ids(user) == {"dag_id"} + assert security_manager.get_permitted_dag_ids(user) == {"dag_id"} @patch.object(FabAuthManager, "is_logged_in") @@ -551,9 +571,9 @@ def test_sync_perm_for_dag_creates_permissions_for_specified_roles(app, security security_manager.sync_perm_for_dag( test_dag_id, access_control={test_role: {"can_read", "can_edit"}} ) - assert security_manager.can_read_dag(test_dag_id, user) - assert security_manager.can_edit_dag(test_dag_id, user) - assert not security_manager.can_delete_dag(test_dag_id, user) + assert _can_read_dag(test_dag_id, user) + assert _can_edit_dag(test_dag_id, user) + assert not _can_delete_dag(test_dag_id, user) def test_sync_perm_for_dag_removes_existing_permissions_if_empty(app, security_manager): @@ -581,18 +601,18 @@ def test_sync_perm_for_dag_removes_existing_permissions_if_empty(app, security_m ] ) - assert security_manager.can_read_dag(test_dag_id, user) - assert security_manager.can_edit_dag(test_dag_id, user) - assert security_manager.can_delete_dag(test_dag_id, user) + assert _can_read_dag(test_dag_id, user) + assert _can_edit_dag(test_dag_id, user) + assert _can_delete_dag(test_dag_id, user) # Need to clear cache on user perms user._perms = None security_manager.sync_perm_for_dag(test_dag_id, access_control={test_role: {}}) - assert not security_manager.can_read_dag(test_dag_id, user) - assert not security_manager.can_edit_dag(test_dag_id, user) 
- assert not security_manager.can_delete_dag(test_dag_id, user) + assert not _can_read_dag(test_dag_id, user) + assert not _can_edit_dag(test_dag_id, user) + assert not _can_delete_dag(test_dag_id, user) def test_sync_perm_for_dag_removes_permissions_from_other_roles(app, security_manager): @@ -621,18 +641,18 @@ def test_sync_perm_for_dag_removes_permissions_from_other_roles(app, security_ma ] ) - assert security_manager.can_read_dag(test_dag_id, user) - assert security_manager.can_edit_dag(test_dag_id, user) - assert security_manager.can_delete_dag(test_dag_id, user) + assert _can_read_dag(test_dag_id, user) + assert _can_edit_dag(test_dag_id, user) + assert _can_delete_dag(test_dag_id, user) # Need to clear cache on user perms user._perms = None security_manager.sync_perm_for_dag(test_dag_id, access_control={"other_role": {"can_read"}}) - assert not security_manager.can_read_dag(test_dag_id, user) - assert not security_manager.can_edit_dag(test_dag_id, user) - assert not security_manager.can_delete_dag(test_dag_id, user) + assert not _can_read_dag(test_dag_id, user) + assert not _can_edit_dag(test_dag_id, user) + assert not _can_delete_dag(test_dag_id, user) def test_sync_perm_for_dag_does_not_prune_roles_when_access_control_unset(app, security_manager): @@ -659,16 +679,16 @@ def test_sync_perm_for_dag_does_not_prune_roles_when_access_control_unset(app, s ] ) - assert security_manager.can_read_dag(test_dag_id, user) - assert security_manager.can_edit_dag(test_dag_id, user) + assert _can_read_dag(test_dag_id, user) + assert _can_edit_dag(test_dag_id, user) # Need to clear cache on user perms user._perms = None security_manager.sync_perm_for_dag(test_dag_id, access_control=None) - assert security_manager.can_read_dag(test_dag_id, user) - assert security_manager.can_edit_dag(test_dag_id, user) + assert _can_read_dag(test_dag_id, user) + assert _can_edit_dag(test_dag_id, user) def test_has_all_dag_access(app, security_manager): @@ -679,7 +699,7 @@ def 
test_has_all_dag_access(app, security_manager): username="user", role_name=role_name, ) as user: - assert security_manager.has_all_dags_access(user) + assert _has_all_dags_access(user) with app.app_context(): with create_user_scope( @@ -688,7 +708,7 @@ def test_has_all_dag_access(app, security_manager): role_name="read_all", permissions=[(permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG)], ) as user: - assert security_manager.has_all_dags_access(user) + assert _has_all_dags_access(user) with app.app_context(): with create_user_scope( @@ -697,7 +717,7 @@ def test_has_all_dag_access(app, security_manager): role_name="edit_all", permissions=[(permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG)], ) as user: - assert security_manager.has_all_dags_access(user) + assert _has_all_dags_access(user) with app.app_context(): with create_user_scope( @@ -706,7 +726,7 @@ def test_has_all_dag_access(app, security_manager): role_name="nada", permissions=[], ) as user: - assert not security_manager.has_all_dags_access(user) + assert not _has_all_dags_access(user) def test_access_control_with_non_existent_role(security_manager): diff --git a/tests/www/views/conftest.py b/tests/www/views/conftest.py index 1bde030388707..b472290c79e65 100644 --- a/tests/www/views/conftest.py +++ b/tests/www/views/conftest.py @@ -173,6 +173,7 @@ def local_context(self): # airflow.www.views.AirflowBaseView.extra_args "macros", "auth_manager", + "DagDetails", ] for key in keys_to_delete: del result[key] From 54becdaadf957bfb188a29aa725d555e7604e5bd Mon Sep 17 00:00:00 2001 From: Vincent Beck Date: Mon, 25 Sep 2023 14:16:39 -0400 Subject: [PATCH 02/31] Fix get_dag_warnings --- airflow/api_connexion/endpoints/dag_warning_endpoint.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/airflow/api_connexion/endpoints/dag_warning_endpoint.py b/airflow/api_connexion/endpoints/dag_warning_endpoint.py index fa1e03ae64a9c..34cf2345e5454 100644 --- 
a/airflow/api_connexion/endpoints/dag_warning_endpoint.py +++ b/airflow/api_connexion/endpoints/dag_warning_endpoint.py @@ -28,11 +28,12 @@ DagWarningCollection, dag_warning_collection_schema, ) -from airflow.auth.managers.models.resource_details import DagAccessEntity +from airflow.auth.managers.models.resource_details import DagAccessEntity, DagDetails from airflow.models.dagwarning import DagWarning as DagWarningModel from airflow.utils.airflow_flask_app import get_airflow_app from airflow.utils.db import get_query_count from airflow.utils.session import NEW_SESSION, provide_session +from airflow.www.extensions.init_auth_manager import get_auth_manager if TYPE_CHECKING: from sqlalchemy.orm import Session @@ -60,7 +61,7 @@ def get_dag_warnings( allowed_filter_attrs = ["dag_id", "warning_type", "message", "timestamp"] query = select(DagWarningModel) if dag_id: - if not get_airflow_app().appbuilder.sm.can_read_dag(dag_id, g.user): + if not get_auth_manager().is_authorized_dag(method="GET", details=DagDetails(id=dag_id), user=g.user): raise PermissionDenied(detail=f"User not allowed to access this DAG: {dag_id}") query = query.where(DagWarningModel.dag_id == dag_id) else: From ed8d27d3c0d7b6bb54266e4b0e72e657d27c7933 Mon Sep 17 00:00:00 2001 From: Vincent Beck Date: Mon, 25 Sep 2023 15:39:37 -0400 Subject: [PATCH 03/31] Remove test --- tests/www/views/test_views_decorators.py | 44 +----------------------- 1 file changed, 1 insertion(+), 43 deletions(-) diff --git a/tests/www/views/test_views_decorators.py b/tests/www/views/test_views_decorators.py index 80eb588f29916..227193aaf85a1 100644 --- a/tests/www/views/test_views_decorators.py +++ b/tests/www/views/test_views_decorators.py @@ -18,15 +18,13 @@ from __future__ import annotations import urllib.parse -from unittest import mock import pytest -from airflow.models import DagBag, DagRun, TaskInstance, Variable +from airflow.models import DagBag, Variable from airflow.utils import timezone from airflow.utils.state 
import State from airflow.utils.types import DagRunType -from airflow.www import app from airflow.www.views import action_has_dag_edit_access from tests.test_utils.db import clear_db_runs, clear_db_variables from tests.test_utils.www import _check_last_log, _check_last_log_masked_variable, check_content_in_response @@ -187,46 +185,6 @@ def test_calendar(admin_client, dagruns): check_content_in_response(expected, resp) -@pytest.mark.parametrize( - "class_type, no_instances, no_unique_dags", - [ - (None, 0, 0), - (TaskInstance, 0, 0), - (TaskInstance, 1, 1), - (TaskInstance, 10, 1), - (TaskInstance, 10, 5), - (DagRun, 0, 0), - (DagRun, 1, 1), - (DagRun, 10, 1), - (DagRun, 10, 9), - ], -) -def test_action_has_dag_edit_access(create_task_instance, class_type, no_instances, no_unique_dags): - unique_dag_ids = [f"test_dag_id_{nr}" for nr in range(no_unique_dags)] - tis: list[TaskInstance] = [ - create_task_instance( - task_id=f"test_task_instance_{nr}", - execution_date=timezone.datetime(2021, 1, 1 + nr), - dag_id=unique_dag_ids[nr % len(unique_dag_ids)], - run_id=f"test_run_id_{nr}", - ) - for nr in range(no_instances) - ] - if class_type is None: - test_items = None - else: - test_items = tis if class_type == TaskInstance else [ti.get_dagrun() for ti in tis] - test_items = test_items[0] if len(test_items) == 1 else test_items - application = app.create_app(testing=True) - with application.app_context(): - with mock.patch.object(application.appbuilder.sm, "can_edit_dag") as mocked_can_edit: - mocked_can_edit.return_value = True - assert not isinstance(test_items, list) or len(test_items) == no_instances - assert some_view_action_which_requires_dag_edit_access(None, test_items) is True - assert mocked_can_edit.call_count == no_unique_dags - clear_db_runs() - - def test_action_has_dag_edit_access_exception(): with pytest.raises(ValueError): some_view_action_which_requires_dag_edit_access(None, "some_incorrect_value") From 4d6acbfceccf5d810265ae9b837f5200d72a337d Mon Sep 
17 00:00:00 2001 From: Vincent Beck Date: Mon, 25 Sep 2023 17:35:10 -0400 Subject: [PATCH 04/31] Fix `_get_root_dag_id` --- airflow/auth/managers/fab/fab_auth_manager.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/airflow/auth/managers/fab/fab_auth_manager.py b/airflow/auth/managers/fab/fab_auth_manager.py index fc2c8a7d4614a..dbf5a7312e64b 100644 --- a/airflow/auth/managers/fab/fab_auth_manager.py +++ b/airflow/auth/managers/fab/fab_auth_manager.py @@ -399,8 +399,7 @@ def _get_root_dag_id(self, dag_id: str) -> str: :meta private: """ if "." in dag_id: - dm = self.security_manager.appbuilder.get_session.scalar( + return self.security_manager.appbuilder.get_session.scalar( select(DagModel.dag_id, DagModel.root_dag_id).where(DagModel.dag_id == dag_id).limit(1) ) - return dm.root_dag_id or dm.dag_id return dag_id From 94f8585d547df6f18c85bcf52f18cf98b17fcdea Mon Sep 17 00:00:00 2001 From: Vincent Beck Date: Tue, 26 Sep 2023 10:09:57 -0400 Subject: [PATCH 05/31] Check GET permissions on DAG level as well --- airflow/auth/managers/fab/fab_auth_manager.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/airflow/auth/managers/fab/fab_auth_manager.py b/airflow/auth/managers/fab/fab_auth_manager.py index dbf5a7312e64b..5458314f1c7ce 100644 --- a/airflow/auth/managers/fab/fab_auth_manager.py +++ b/airflow/auth/managers/fab/fab_auth_manager.py @@ -193,7 +193,7 @@ def is_authorized_dag( entity (e.g. DAG runs). 2. ``dag_access`` is provided which means the user wants to access a sub entity of the DAG (e.g. DAG runs). - a. If ``method`` is GET, then check the user has READ permissions on the sub entity + a. If ``method`` is GET, then check the user has READ permissions on the DAG and the sub entity b. Else, check the user has EDIT permissions on the DAG and ``method`` on the sub entity :param method: The method to authorize. 
@@ -209,7 +209,9 @@ def is_authorized_dag( resource_type = self._get_fab_resource_type(access_entity) if method == "GET": - return self._is_authorized(method=method, resource_type=resource_type, user=user) + return self._is_authorized_dag( + method="GET", details=details, user=user + ) and self._is_authorized(method=method, resource_type=resource_type, user=user) else: return self._is_authorized_dag( method="PUT", details=details, user=user From 19b52891cc1d1e458af0a712322294ec2af8fb98 Mon Sep 17 00:00:00 2001 From: Vincent Beck Date: Tue, 26 Sep 2023 10:12:54 -0400 Subject: [PATCH 06/31] Add test case --- tests/auth/managers/fab/test_fab_auth_manager.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/tests/auth/managers/fab/test_fab_auth_manager.py b/tests/auth/managers/fab/test_fab_auth_manager.py index c48ceccfa742f..91e3477e4eaeb 100644 --- a/tests/auth/managers/fab/test_fab_auth_manager.py +++ b/tests/auth/managers/fab/test_fab_auth_manager.py @@ -253,6 +253,14 @@ def test_is_authorized(self, api_name, method, user_permissions, expected_result [(ACTION_CAN_READ, RESOURCE_DAG), (ACTION_CAN_READ, RESOURCE_DAG_RUN)], True, ), + # Without read permissions on a specific DAG + ( + "GET", + DagAccessEntity.TASK_INSTANCE, + DagDetails(id="test_dag_id"), + [(ACTION_CAN_READ, RESOURCE_TASK_INSTANCE)], + False, + ), # With read permissions on a specific DAG ( "GET", From ccffa78d2fc34d7666c755f0f4755490e19340a0 Mon Sep 17 00:00:00 2001 From: Vincent Beck Date: Tue, 26 Sep 2023 11:47:42 -0400 Subject: [PATCH 07/31] Fix tests --- airflow/api_connexion/security.py | 2 +- airflow/auth/managers/fab/fab_auth_manager.py | 8 ++++---- airflow/www/views.py | 17 +++++++---------- .../endpoints/test_log_endpoint.py | 1 + .../endpoints/test_xcom_endpoint.py | 4 ---- 5 files changed, 13 insertions(+), 19 deletions(-) diff --git a/airflow/api_connexion/security.py b/airflow/api_connexion/security.py index 4784769c079ce..f549149731c24 100644 --- 
a/airflow/api_connexion/security.py +++ b/airflow/api_connexion/security.py @@ -143,7 +143,7 @@ def requires_access_dag( def requires_access_decorator(func: T): @wraps(func) def decorated(*args, **kwargs): - dag_id: str | None = kwargs.get("dag_id") + dag_id: str | None = kwargs.get("dag_id") if kwargs.get("dag_id") != "~" else None return _requires_access( is_authorized_callback=lambda: get_auth_manager().is_authorized_dag( method=method, diff --git a/airflow/auth/managers/fab/fab_auth_manager.py b/airflow/auth/managers/fab/fab_auth_manager.py index 5458314f1c7ce..7b335b71cb0a9 100644 --- a/airflow/auth/managers/fab/fab_auth_manager.py +++ b/airflow/auth/managers/fab/fab_auth_manager.py @@ -193,7 +193,8 @@ def is_authorized_dag( entity (e.g. DAG runs). 2. ``dag_access`` is provided which means the user wants to access a sub entity of the DAG (e.g. DAG runs). - a. If ``method`` is GET, then check the user has READ permissions on the DAG and the sub entity + a. If ``method`` is GET, then check the user has READ permissions on the DAG and the sub entity. + However, if not specific DAG is targeted, just check the sub entity b. Else, check the user has EDIT permissions on the DAG and ``method`` on the sub entity :param method: The method to authorize. 
@@ -209,9 +210,8 @@ def is_authorized_dag( resource_type = self._get_fab_resource_type(access_entity) if method == "GET": - return self._is_authorized_dag( - method="GET", details=details, user=user - ) and self._is_authorized(method=method, resource_type=resource_type, user=user) + dag_level_check = self._is_authorized_dag(method="GET", details=details, user=user) if details and details.id else True + return dag_level_check and self._is_authorized(method=method, resource_type=resource_type, user=user) else: return self._is_authorized_dag( method="PUT", details=details, user=user diff --git a/airflow/www/views.py b/airflow/www/views.py index 40dec21f3a3bf..1620407aec8da 100644 --- a/airflow/www/views.py +++ b/airflow/www/views.py @@ -901,11 +901,9 @@ def index(self): .unique() .all() ) - user_permissions = g.user.perms - can_create_dag_run = ( - permissions.ACTION_CAN_CREATE, - permissions.RESOURCE_DAG_RUN, - ) in user_permissions + can_create_dag_run = get_auth_manager().is_authorized_dag( + method="POST", access_entity=DagAccessEntity.RUN, user=g.user + ) dataset_triggered_dag_ids = {dag.dag_id for dag in dags if dag.schedule_interval == "Dataset"} if dataset_triggered_dag_ids: @@ -934,7 +932,7 @@ def index(self): import_errors = select(errors.ImportError).order_by(errors.ImportError.id) - if (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG) not in user_permissions: + if not get_auth_manager().is_authorized_dag(method="GET"): # if the user doesn't have access to all DAGs, only display errors from visible DAGs import_errors = import_errors.join( DagModel, DagModel.fileloc == errors.ImportError.filename @@ -977,10 +975,9 @@ def _iter_parsed_moved_data_table_names(): # Second segment is a version marker that we don't need to show. 
yield segments[-1], table_name - if ( - permissions.ACTION_CAN_ACCESS_MENU, - permissions.RESOURCE_ADMIN_MENU, - ) in user_permissions and conf.getboolean("webserver", "warn_deployment_exposure"): + if get_auth_manager().is_authorized_configuration(method="GET", user=g.user) and conf.getboolean( + "webserver", "warn_deployment_exposure" + ): robots_file_access_count = ( select(Log) .where(Log.event == "robots") diff --git a/tests/api_connexion/endpoints/test_log_endpoint.py b/tests/api_connexion/endpoints/test_log_endpoint.py index e4d61c50607da..08f6a3a8d1a58 100644 --- a/tests/api_connexion/endpoints/test_log_endpoint.py +++ b/tests/api_connexion/endpoints/test_log_endpoint.py @@ -46,6 +46,7 @@ def configured_app(minimal_app_for_api): username="test", role_name="Test", permissions=[ + (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_LOG), ], ) diff --git a/tests/api_connexion/endpoints/test_xcom_endpoint.py b/tests/api_connexion/endpoints/test_xcom_endpoint.py index cf7716d6448b9..9e175ab488d3e 100644 --- a/tests/api_connexion/endpoints/test_xcom_endpoint.py +++ b/tests/api_connexion/endpoints/test_xcom_endpoint.py @@ -55,8 +55,6 @@ def configured_app(minimal_app_for_api): role_name="Test", permissions=[ (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_INSTANCE), (permissions.ACTION_CAN_READ, permissions.RESOURCE_XCOM), ], ) @@ -65,8 +63,6 @@ def configured_app(minimal_app_for_api): username="test_granular_permissions", role_name="TestGranularDag", permissions=[ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_INSTANCE), (permissions.ACTION_CAN_READ, permissions.RESOURCE_XCOM), ], ) From 218a33697e5493671fb5fbea680d8ca834c44b40 Mon Sep 17 00:00:00 2001 From: Vincent Beck Date: Tue, 26 Sep 2023 11:50:31 
-0400 Subject: [PATCH 08/31] Fix static checks --- airflow/auth/managers/fab/fab_auth_manager.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/airflow/auth/managers/fab/fab_auth_manager.py b/airflow/auth/managers/fab/fab_auth_manager.py index 7b335b71cb0a9..787b75799cb8b 100644 --- a/airflow/auth/managers/fab/fab_auth_manager.py +++ b/airflow/auth/managers/fab/fab_auth_manager.py @@ -210,8 +210,14 @@ def is_authorized_dag( resource_type = self._get_fab_resource_type(access_entity) if method == "GET": - dag_level_check = self._is_authorized_dag(method="GET", details=details, user=user) if details and details.id else True - return dag_level_check and self._is_authorized(method=method, resource_type=resource_type, user=user) + dag_level_check = ( + self._is_authorized_dag(method="GET", details=details, user=user) + if details and details.id + else True + ) + return dag_level_check and self._is_authorized( + method=method, resource_type=resource_type, user=user + ) else: return self._is_authorized_dag( method="PUT", details=details, user=user From 414009b004e122d23890ce61c1ce5be7c86b5b4a Mon Sep 17 00:00:00 2001 From: Vincent Beck Date: Tue, 26 Sep 2023 14:03:26 -0400 Subject: [PATCH 09/31] Refactoring --- airflow/auth/managers/fab/fab_auth_manager.py | 27 +++++++++---------- 1 file changed, 12 insertions(+), 15 deletions(-) diff --git a/airflow/auth/managers/fab/fab_auth_manager.py b/airflow/auth/managers/fab/fab_auth_manager.py index 787b75799cb8b..52262314a956f 100644 --- a/airflow/auth/managers/fab/fab_auth_manager.py +++ b/airflow/auth/managers/fab/fab_auth_manager.py @@ -194,8 +194,9 @@ def is_authorized_dag( 2. ``dag_access`` is provided which means the user wants to access a sub entity of the DAG (e.g. DAG runs). a. If ``method`` is GET, then check the user has READ permissions on the DAG and the sub entity. - However, if not specific DAG is targeted, just check the sub entity - b. 
Else, check the user has EDIT permissions on the DAG and ``method`` on the sub entity + b. Else, check the user has EDIT permissions on the DAG and ``method`` on the sub entity. + + However, if no specific DAG is targeted, just check the sub entity. :param method: The method to authorize. :param access_entity: The dag access entity. @@ -208,20 +209,16 @@ def is_authorized_dag( else: # Scenario 2 resource_type = self._get_fab_resource_type(access_entity) + dag_method: ResourceMethod = "GET" if method == "GET" else "PUT" + dag_level_check = ( + self._is_authorized_dag(method=dag_method, details=details, user=user) + if details and details.id + else True + ) - if method == "GET": - dag_level_check = ( - self._is_authorized_dag(method="GET", details=details, user=user) - if details and details.id - else True - ) - return dag_level_check and self._is_authorized( - method=method, resource_type=resource_type, user=user - ) - else: - return self._is_authorized_dag( - method="PUT", details=details, user=user - ) and self._is_authorized(method=method, resource_type=resource_type, user=user) + return dag_level_check and self._is_authorized( + method=method, resource_type=resource_type, user=user + ) def is_authorized_dataset( self, *, method: ResourceMethod, details: DatasetDetails | None = None, user: BaseUser | None = None From cee7e5fc49c4760de5ff803fe767d4726c381f30 Mon Sep 17 00:00:00 2001 From: Vincent Beck Date: Wed, 27 Sep 2023 10:51:43 -0400 Subject: [PATCH 10/31] Address feedbacks --- airflow/auth/managers/fab/fab_auth_manager.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/airflow/auth/managers/fab/fab_auth_manager.py b/airflow/auth/managers/fab/fab_auth_manager.py index 52262314a956f..4f7eb9814387e 100644 --- a/airflow/auth/managers/fab/fab_auth_manager.py +++ b/airflow/auth/managers/fab/fab_auth_manager.py @@ -210,15 +210,13 @@ def is_authorized_dag( # Scenario 2 resource_type = self._get_fab_resource_type(access_entity) 
dag_method: ResourceMethod = "GET" if method == "GET" else "PUT" - dag_level_check = ( - self._is_authorized_dag(method=dag_method, details=details, user=user) - if details and details.id - else True - ) - return dag_level_check and self._is_authorized( - method=method, resource_type=resource_type, user=user - ) + if (details and details.id) and not self._is_authorized_dag( + method=dag_method, details=details, user=user + ): + return False + + return self._is_authorized(method=method, resource_type=resource_type, user=user) def is_authorized_dataset( self, *, method: ResourceMethod, details: DatasetDetails | None = None, user: BaseUser | None = None From f04bb893bcc9949834f35c064190632b20572b7e Mon Sep 17 00:00:00 2001 From: Vincent Beck Date: Wed, 27 Sep 2023 13:25:04 -0400 Subject: [PATCH 11/31] Fix test --- .../endpoints/test_event_log_endpoint.py | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/tests/api_connexion/endpoints/test_event_log_endpoint.py b/tests/api_connexion/endpoints/test_event_log_endpoint.py index ad154e57fe8ee..d1e8af2facd51 100644 --- a/tests/api_connexion/endpoints/test_event_log_endpoint.py +++ b/tests/api_connexion/endpoints/test_event_log_endpoint.py @@ -36,11 +36,26 @@ def configured_app(minimal_app_for_api): role_name="Test", permissions=[(permissions.ACTION_CAN_READ, permissions.RESOURCE_AUDIT_LOG)], # type: ignore ) + create_user( + app, # type:ignore + username="test_granular", + role_name="TestGranular", + permissions=[(permissions.ACTION_CAN_READ, permissions.RESOURCE_AUDIT_LOG)], # type: ignore + ) + app.appbuilder.sm.sync_perm_for_dag( # type: ignore + "TEST_DAG_ID_1", + access_control={"TestGranular": [permissions.ACTION_CAN_READ]}, + ) + app.appbuilder.sm.sync_perm_for_dag( # type: ignore + "TEST_DAG_ID_2", + access_control={"TestGranular": [permissions.ACTION_CAN_READ]}, + ) create_user(app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore yield app 
delete_user(app, username="test") # type: ignore + delete_user(app, username="test_granular") # type: ignore delete_user(app, username="test_no_permissions") # type: ignore @@ -253,7 +268,7 @@ def test_should_filter_eventlogs_by_allowed_attributes(self, create_log_model, s for attr in ["dag_id", "task_id", "owner", "event"]: attr_value = f"TEST_{attr}_1".upper() response = self.client.get( - f"/api/v1/eventLogs?{attr}={attr_value}", environ_overrides={"REMOTE_USER": "test"} + f"/api/v1/eventLogs?{attr}={attr_value}", environ_overrides={"REMOTE_USER": "test_granular"} ) assert response.status_code == 200 assert {eventlog[attr] for eventlog in response.json["event_logs"]} == {attr_value} From 1b33953c17c0da7f940ae82c0a61b766d7f9441d Mon Sep 17 00:00:00 2001 From: Vincent Beck Date: Thu, 28 Sep 2023 11:13:08 -0400 Subject: [PATCH 12/31] Minor feedbacks --- airflow/api_connexion/security.py | 2 +- airflow/auth/managers/base_auth_manager.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/airflow/api_connexion/security.py b/airflow/api_connexion/security.py index f549149731c24..47c7050e3f27d 100644 --- a/airflow/api_connexion/security.py +++ b/airflow/api_connexion/security.py @@ -75,7 +75,7 @@ def _requires_access(*, is_authorized_callback: Callable[[], bool], func: Callab Define the behavior whether the user is authorized to access the resource. :param is_authorized_callback: callback to execute to figure whether the user is authorized to access - the resource? 
+ the resource :param func: the function to call if the user is authorized :param args: the arguments of ``func`` :param kwargs: the keyword arguments ``func`` diff --git a/airflow/auth/managers/base_auth_manager.py b/airflow/auth/managers/base_auth_manager.py index 71a6a94fdc235..372125f51732b 100644 --- a/airflow/auth/managers/base_auth_manager.py +++ b/airflow/auth/managers/base_auth_manager.py @@ -93,7 +93,7 @@ def is_authorized_configuration( Return whether the user is authorized to perform a given action on configuration. :param method: the method to perform - :param details: optional details about the connection + :param details: optional details about the configuration :param user: the user to perform the action on. If not provided (or None), it uses the current user """ @@ -158,7 +158,7 @@ def is_authorized_dataset( Return whether the user is authorized to perform a given action on a dataset. :param method: the method to perform - :param details: optional details about the variable + :param details: optional details about the dataset :param user: the user to perform the action on. If not provided (or None), it uses the current user """ @@ -174,7 +174,7 @@ def is_authorized_pool( Return whether the user is authorized to perform a given action on a pool. :param method: the method to perform - :param details: optional details about the variable + :param details: optional details about the pool :param user: the user to perform the action on. 
If not provided (or None), it uses the current user """ From 21f92047cce224a6d87dc5bcec4b3d24bc84da21 Mon Sep 17 00:00:00 2001 From: Vincent Beck Date: Fri, 29 Sep 2023 11:25:52 -0400 Subject: [PATCH 13/31] When doing authorization check on DAGs in general, check if the user has permission to access at least one DAG --- .../api_connexion/endpoints/dag_endpoint.py | 7 ++- airflow/api_connexion/security.py | 33 +++++++++--- .../managers/fab/security_manager/override.py | 13 ++++- airflow/www/security_manager.py | 51 +++++++------------ airflow/www/views.py | 16 +++--- .../endpoints/test_dag_endpoint.py | 10 ++-- tests/www/test_security.py | 6 +-- 7 files changed, 74 insertions(+), 62 deletions(-) diff --git a/airflow/api_connexion/endpoints/dag_endpoint.py b/airflow/api_connexion/endpoints/dag_endpoint.py index a67699feb2b79..24ddb84ac6482 100644 --- a/airflow/api_connexion/endpoints/dag_endpoint.py +++ b/airflow/api_connexion/endpoints/dag_endpoint.py @@ -34,7 +34,6 @@ dag_schema, dags_collection_schema, ) -from airflow.api_connexion.security import requires_authentication from airflow.exceptions import AirflowException, DagNotFound from airflow.models.dag import DagModel, DagTag from airflow.utils.airflow_flask_app import get_airflow_app @@ -69,7 +68,7 @@ def get_dag_details(*, dag_id: str) -> APIResponse: return dag_detail_schema.dump(dag) -@requires_authentication +@security.requires_access_dag("GET") @format_parameters({"limit": check_limit}) @provide_session def get_dags( @@ -96,7 +95,7 @@ def get_dags( if dag_id_pattern: dags_query = dags_query.where(DagModel.dag_id.ilike(f"%{dag_id_pattern}%")) - readable_dags = get_airflow_app().appbuilder.sm.get_permitted_dag_ids(g.user) + readable_dags = get_airflow_app().appbuilder.sm.get_permitted_dag_ids(user=g.user) dags_query = dags_query.where(DagModel.dag_id.in_(readable_dags)) if tags: @@ -132,7 +131,7 @@ def patch_dag(*, dag_id: str, update_mask: UpdateMask = None, session: Session = return dag_schema.dump(dag) 
-@security.requires_authentication +@security.requires_access_dag("PUT") @format_parameters({"limit": check_limit}) @provide_session def patch_dags(limit, session, offset=0, only_active=True, tags=None, dag_id_pattern=None, update_mask=None): diff --git a/airflow/api_connexion/security.py b/airflow/api_connexion/security.py index 47c7050e3f27d..7df8d4549129e 100644 --- a/airflow/api_connexion/security.py +++ b/airflow/api_connexion/security.py @@ -70,7 +70,7 @@ def decorated(*args, **kwargs): return requires_access_decorator -def _requires_access(*, is_authorized_callback: Callable[[], bool], func: Callable, args, kwargs): +def _requires_access(*, is_authorized_callback: Callable[[], bool], func: Callable, args, kwargs) -> bool: """ Define the behavior whether the user is authorized to access the resource. @@ -140,16 +140,37 @@ def decorated(*args, **kwargs): def requires_access_dag( method: ResourceMethod, access_entity: DagAccessEntity | None = None ) -> Callable[[T], T]: + appbuilder = get_airflow_app().appbuilder + + def _is_authorized_callback(dag_id: str): + def callback(): + access = get_auth_manager().is_authorized_dag( + method=method, + access_entity=access_entity, + details=DagDetails(id=dag_id), + ) + + # ``access`` means here: + # - if a DAG id is provided (``dag_id`` not None): is the user authorized to access this DAG + # - if no DAG id is provided: is the user authorized to access all DAGs + if dag_id or access: + return access + + # No DAG id is provided and the user is not authorized to access all DAGs + # If method is "GET", return whether the user has read access to any DAGs + # If method is "PUT", return whether the user has edit access to any DAGs + return (method == "GET" and any(appbuilder.sm.get_readable_dag_ids())) or ( + method == "PUT" and any(appbuilder.sm.get_editable_dag_ids()) + ) + + return callback + def requires_access_decorator(func: T): @wraps(func) def decorated(*args, **kwargs): dag_id: str | None = kwargs.get("dag_id") if 
kwargs.get("dag_id") != "~" else None return _requires_access( - is_authorized_callback=lambda: get_auth_manager().is_authorized_dag( - method=method, - access_entity=access_entity, - details=DagDetails(id=dag_id), - ), + is_authorized_callback=_is_authorized_callback(dag_id), func=func, args=args, kwargs=kwargs, diff --git a/airflow/auth/managers/fab/security_manager/override.py b/airflow/auth/managers/fab/security_manager/override.py index e4d9cd2af33c9..7bfaafc0945b7 100644 --- a/airflow/auth/managers/fab/security_manager/override.py +++ b/airflow/auth/managers/fab/security_manager/override.py @@ -567,7 +567,18 @@ def get_accessible_dag_ids( for action in fab_action_name_to_method_name if action in user_actions ] - return self.get_permitted_dag_ids(user=user, user_methods=user_methods, session=session) + return self.get_permitted_dag_ids(user=user, methods=user_methods, session=session) + + def can_access_some_dags(self, action: str, dag_id: str | None = None) -> bool: + """Checks if user has read or write access to some dags.""" + if dag_id and dag_id != "~": + root_dag_id = self._get_root_dag_id(dag_id) + return self.has_access(action, permissions.resource_name_for_dag(root_dag_id)) + + user = g.user + if action == permissions.ACTION_CAN_READ: + return any(self.get_readable_dag_ids(user)) + return any(self.get_editable_dag_ids(user)) """ ----------- diff --git a/airflow/www/security_manager.py b/airflow/www/security_manager.py index 0281e639cf6af..ba80780227587 100644 --- a/airflow/www/security_manager.py +++ b/airflow/www/security_manager.py @@ -271,29 +271,30 @@ def get_user_roles(user=None): user = g.user return user.roles - def get_readable_dag_ids(self, user) -> set[str]: + def get_readable_dag_ids(self, user=None) -> set[str]: """Gets the DAG IDs readable by authenticated user.""" - return self.get_permitted_dag_ids(user, ["GET"]) + return self.get_permitted_dag_ids(methods=["GET"], user=user) - def get_editable_dag_ids(self, user) -> set[str]: + def 
get_editable_dag_ids(self, user=None) -> set[str]: """Gets the DAG IDs editable by authenticated user.""" - return self.get_permitted_dag_ids(user, ["PUT"]) + return self.get_permitted_dag_ids(methods=["PUT"], user=user) @provide_session def get_permitted_dag_ids( self, - user, - user_methods: Container[ResourceMethod] | None = None, + *, + methods: Container[ResourceMethod] | None = None, + user=None, session: Session = NEW_SESSION, ) -> set[str]: """Generic function to get readable or writable DAGs for user.""" - if not user_methods: - user_methods = ["PUT", "GET"] + if not methods: + methods = ["PUT", "GET"] dag_ids = {dag.dag_id for dag in session.execute(select(DagModel.dag_id))} - if ("GET" in user_methods and get_auth_manager().is_authorized_dag(method="GET", user=user)) or ( - "PUT" in user_methods and get_auth_manager().is_authorized_dag(method="PUT", user=user) + if ("GET" in methods and get_auth_manager().is_authorized_dag(method="GET", user=user)) or ( + "PUT" in methods and get_auth_manager().is_authorized_dag(method="PUT", user=user) ): return dag_ids @@ -301,29 +302,22 @@ def get_permitted_dag_ids( dag_id for dag_id in dag_ids if ( - "GET" in user_methods + "GET" in methods and get_auth_manager().is_authorized_dag( method="GET", details=DagDetails(id=dag_id), user=user ) ) or ( - "PUT" in user_methods + "PUT" in methods and get_auth_manager().is_authorized_dag( method="PUT", details=DagDetails(id=dag_id), user=user ) ) } - def can_access_some_dags(self, action: str, dag_id: str | None = None) -> bool: - """Checks if user has read or write access to some dags.""" - if dag_id and dag_id != "~": - root_dag_id = self._get_root_dag_id(dag_id) - return self.has_access(action, permissions.resource_name_for_dag(root_dag_id)) - - user = g.user - if action == permissions.ACTION_CAN_READ: - return any(self.get_readable_dag_ids(user)) - return any(self.get_editable_dag_ids(user)) + def can_access_dags(self, user) -> bool: + """Checks if user has read access to 
some dags.""" + return any(self.get_readable_dag_ids(user)) def prefixed_dag_id(self, dag_id: str) -> str: """Returns the permission name for a DAG id.""" @@ -645,24 +639,13 @@ def create_perm_vm_for_all_dag(self) -> None: def check_authorization( self, perms: Sequence[tuple[str, str]] | None = None, - dag_id: str | None = None, ) -> bool: """Checks that the logged in user has the specified permissions.""" if not perms: return True for perm in perms: - if perm in ( - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_DELETE, permissions.RESOURCE_DAG), - ): - can_access_all_dags = self.has_access(*perm) - if not can_access_all_dags: - action = perm[0] - if not self.can_access_some_dags(action, dag_id): - return False - elif not self.has_access(*perm): + if not self.has_access(*perm): return False return True diff --git a/airflow/www/views.py b/airflow/www/views.py index 41b0040f95c15..a0f484afccfa5 100644 --- a/airflow/www/views.py +++ b/airflow/www/views.py @@ -775,7 +775,7 @@ def index(self): end = start + dags_per_page # Get all the dag id the user could access - filter_dag_ids = get_airflow_app().appbuilder.sm.get_permitted_dag_ids(g.user) + filter_dag_ids = get_airflow_app().appbuilder.sm.get_permitted_dag_ids(user=g.user) with create_session() as session: # read orm_dags from the db @@ -1066,7 +1066,7 @@ def cluster_activity(self): @provide_session def next_run_datasets_summary(self, session: Session = NEW_SESSION): """Next run info for dataset triggered DAGs.""" - allowed_dag_ids = get_airflow_app().appbuilder.sm.get_permitted_dag_ids(g.user) + allowed_dag_ids = get_airflow_app().appbuilder.sm.get_permitted_dag_ids(user=g.user) if not allowed_dag_ids: return flask.json.jsonify({}) @@ -1101,7 +1101,7 @@ def next_run_datasets_summary(self, session: Session = NEW_SESSION): @provide_session def dag_stats(self, session: Session = NEW_SESSION): """Dag statistics.""" - 
allowed_dag_ids = get_airflow_app().appbuilder.sm.get_permitted_dag_ids(g.user) + allowed_dag_ids = get_airflow_app().appbuilder.sm.get_permitted_dag_ids(user=g.user) # Filter by post parameters selected_dag_ids = {unquote(dag_id) for dag_id in request.form.getlist("dag_ids") if dag_id} @@ -1133,7 +1133,7 @@ def dag_stats(self, session: Session = NEW_SESSION): @provide_session def task_stats(self, session: Session = NEW_SESSION): """Task Statistics.""" - allowed_dag_ids = get_airflow_app().appbuilder.sm.get_permitted_dag_ids(g.user) + allowed_dag_ids = get_airflow_app().appbuilder.sm.get_permitted_dag_ids(user=g.user) if not allowed_dag_ids: return flask.json.jsonify({}) @@ -1234,7 +1234,7 @@ def task_stats(self, session: Session = NEW_SESSION): @provide_session def last_dagruns(self, session: Session = NEW_SESSION): """Last DAG runs.""" - allowed_dag_ids = get_airflow_app().appbuilder.sm.get_permitted_dag_ids(g.user) + allowed_dag_ids = get_airflow_app().appbuilder.sm.get_permitted_dag_ids(user=g.user) # Filter by post parameters selected_dag_ids = {unquote(dag_id) for dag_id in request.form.getlist("dag_ids") if dag_id} @@ -2341,7 +2341,7 @@ def dagrun_clear(self, *, session: Session = NEW_SESSION): @provide_session def blocked(self, session: Session = NEW_SESSION): """Mark Dag Blocked.""" - allowed_dag_ids = get_airflow_app().appbuilder.sm.get_permitted_dag_ids(g.user) + allowed_dag_ids = get_airflow_app().appbuilder.sm.get_permitted_dag_ids(user=g.user) # Filter by post parameters selected_dag_ids = {unquote(dag_id) for dag_id in request.form.getlist("dag_ids") if dag_id} @@ -3912,7 +3912,7 @@ def apply(self, query, func): method="GET", user=g.user ) or get_auth_manager().is_authorized_dag(method="PUT", user=g.user): return query - filter_dag_ids = get_airflow_app().appbuilder.sm.get_permitted_dag_ids(g.user) + filter_dag_ids = get_airflow_app().appbuilder.sm.get_permitted_dag_ids(user=g.user) return query.where(self.model.dag_id.in_(filter_dag_ids)) @@ -5731,7 
+5731,7 @@ def autocomplete(self, session: Session = NEW_SESSION): dag_ids_query = dag_ids_query.where(DagModel.is_paused) owners_query = owners_query.where(DagModel.is_paused) - filter_dag_ids = get_airflow_app().appbuilder.sm.get_permitted_dag_ids(g.user) + filter_dag_ids = get_airflow_app().appbuilder.sm.get_permitted_dag_ids(user=g.user) dag_ids_query = dag_ids_query.where(DagModel.dag_id.in_(filter_dag_ids)) owners_query = owners_query.where(DagModel.dag_id.in_(filter_dag_ids)) diff --git a/tests/api_connexion/endpoints/test_dag_endpoint.py b/tests/api_connexion/endpoints/test_dag_endpoint.py index 6dc59f2f8c8d2..06ac5bbef85fd 100644 --- a/tests/api_connexion/endpoints/test_dag_endpoint.py +++ b/tests/api_connexion/endpoints/test_dag_endpoint.py @@ -815,13 +815,12 @@ def test_should_raises_401_unauthenticated(self): assert_401(response) - def test_should_return_empty_list(self): + def test_should_respond_403_unauthorized(self): self._create_dag_models(1) response = self.client.get("api/v1/dags", environ_overrides={"REMOTE_USER": "test_no_permissions"}) - assert response.status_code == 200 - assert {"dags": [], "total_entries": 0} == response.json + assert response.status_code == 403 def test_paused_true_returns_paused_dags(self, url_safe_serializer): self._create_dag_models(1, dag_id_prefix="TEST_DAG_PAUSED", is_paused=True) @@ -1540,7 +1539,7 @@ def test_should_raises_401_unauthenticated(self): assert_401(response) - def test_should_return_empty_list(self): + def test_should_respond_403_unauthorized(self): self._create_dag_models(1) response = self.client.patch( "api/v1/dags?dag_id_pattern=~", @@ -1550,8 +1549,7 @@ def test_should_return_empty_list(self): environ_overrides={"REMOTE_USER": "test_no_permissions"}, ) - assert response.status_code == 200 - assert {"dags": [], "total_entries": 0} == response.json + assert response.status_code == 403 def test_should_respond_200_and_pause_dags(self, url_safe_serializer): file_token = 
url_safe_serializer.dumps("/tmp/dag_1.py") diff --git a/tests/www/test_security.py b/tests/www/test_security.py index 60157303475c4..e24d820f63c7c 100644 --- a/tests/www/test_security.py +++ b/tests/www/test_security.py @@ -341,7 +341,7 @@ def test_verify_default_anon_user_has_no_accessible_dag_ids( with _create_dag_model_context("test_dag_id", session, security_manager): security_manager.sync_roles() - assert security_manager.get_permitted_dag_ids(user) == set() + assert security_manager.get_permitted_dag_ids(user=user) == set() def test_verify_default_anon_user_has_no_access_to_specific_dag(app, session, security_manager, has_dag_perm): @@ -379,7 +379,7 @@ def test_verify_anon_user_with_admin_role_has_all_dag_access( security_manager.sync_roles() - assert security_manager.get_permitted_dag_ids(user) == set(test_dag_ids) + assert security_manager.get_permitted_dag_ids(user=user) == set(test_dag_ids) def test_verify_anon_user_with_admin_role_has_access_to_each_dag( @@ -507,7 +507,7 @@ def test_get_accessible_dag_ids(mock_is_logged_in, app, security_manager, sessio dag_id, access_control={role_name: permission_action} ) - assert security_manager.get_permitted_dag_ids(user) == {"dag_id"} + assert security_manager.get_permitted_dag_ids(user=user) == {"dag_id"} @patch.object(FabAuthManager, "is_logged_in") From 1e6b9af84502c3fbc9d3824c0a6dee5756e90472 Mon Sep 17 00:00:00 2001 From: Vincent Beck Date: Fri, 29 Sep 2023 13:01:56 -0400 Subject: [PATCH 14/31] Fix tests --- airflow/api_connexion/security.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/airflow/api_connexion/security.py b/airflow/api_connexion/security.py index 7df8d4549129e..b22e3a272ba97 100644 --- a/airflow/api_connexion/security.py +++ b/airflow/api_connexion/security.py @@ -61,7 +61,7 @@ def requires_access_decorator(func: T): @wraps(func) def decorated(*args, **kwargs): check_authentication() - if appbuilder.sm.check_authorization(permissions, kwargs.get("dag_id")): + if 
appbuilder.sm.check_authorization(permissions): return func(*args, **kwargs) raise PermissionDenied() From 4a7df9b0233d721019acdb1cab7eea07e52d00c8 Mon Sep 17 00:00:00 2001 From: Vincent Beck Date: Thu, 5 Oct 2023 14:57:14 -0400 Subject: [PATCH 15/31] Address feedbacks --- airflow/api_connexion/endpoints/dag_run_endpoint.py | 2 +- airflow/auth/managers/fab/fab_auth_manager.py | 1 - airflow/auth/managers/models/resource_details.py | 1 - airflow/www/views.py | 3 ++- 4 files changed, 3 insertions(+), 4 deletions(-) diff --git a/airflow/api_connexion/endpoints/dag_run_endpoint.py b/airflow/api_connexion/endpoints/dag_run_endpoint.py index 06b288754fc12..8b821dcd204a6 100644 --- a/airflow/api_connexion/endpoints/dag_run_endpoint.py +++ b/airflow/api_connexion/endpoints/dag_run_endpoint.py @@ -103,7 +103,7 @@ def get_dag_run(*, dag_id: str, dag_run_id: str, session: Session = NEW_SESSION) @security.requires_access_dag("GET", DagAccessEntity.RUN) -@security.requires_access_dag("GET", DagAccessEntity.DATASET) +@security.requires_access_dataset("GET") @provide_session def get_upstream_dataset_events( *, dag_id: str, dag_run_id: str, session: Session = NEW_SESSION diff --git a/airflow/auth/managers/fab/fab_auth_manager.py b/airflow/auth/managers/fab/fab_auth_manager.py index 4f7eb9814387e..2e19b6705f9ff 100644 --- a/airflow/auth/managers/fab/fab_auth_manager.py +++ b/airflow/auth/managers/fab/fab_auth_manager.py @@ -89,7 +89,6 @@ _MAP_DAG_ACCESS_ENTITY_TO_FAB_RESOURCE_TYPE = { DagAccessEntity.AUDIT_LOG: RESOURCE_AUDIT_LOG, DagAccessEntity.CODE: RESOURCE_DAG_CODE, - DagAccessEntity.DATASET: RESOURCE_DATASET, DagAccessEntity.DEPENDENCIES: RESOURCE_DAG_DEPENDENCIES, DagAccessEntity.IMPORT_ERRORS: RESOURCE_IMPORT_ERROR, DagAccessEntity.RUN: RESOURCE_DAG_RUN, diff --git a/airflow/auth/managers/models/resource_details.py b/airflow/auth/managers/models/resource_details.py index 9d424bead2aad..70e3c41457506 100644 --- a/airflow/auth/managers/models/resource_details.py +++ 
b/airflow/auth/managers/models/resource_details.py @@ -68,7 +68,6 @@ class DagAccessEntity(Enum): AUDIT_LOG = "AUDIT_LOG" CODE = "CODE" - DATASET = "DATASET" DEPENDENCIES = "DEPENDENCIES" IMPORT_ERRORS = "IMPORT_ERRORS" RUN = "RUN" diff --git a/airflow/www/views.py b/airflow/www/views.py index a0f484afccfa5..ab78348f44b6f 100644 --- a/airflow/www/views.py +++ b/airflow/www/views.py @@ -3565,7 +3565,8 @@ def historical_metrics_data(self): ) @expose("/object/next_run_datasets/") - @auth.has_access_dag("GET", DagAccessEntity.DATASET) + @auth.has_access_dag("GET", DagAccessEntity.RUN) + @auth.has_access_dataset("GET") def next_run_datasets(self, dag_id): """Returns datasets necessary, and their status, for the next dag run.""" dag = get_airflow_app().dag_bag.get_dag(dag_id) From fd2cb8adecb616afc86a301bb5408d9803107643 Mon Sep 17 00:00:00 2001 From: Vincent Beck Date: Thu, 5 Oct 2023 17:01:37 -0400 Subject: [PATCH 16/31] Apply suggestions --- airflow/www/extensions/init_jinja_globals.py | 8 ++------ airflow/www/security_manager.py | 6 +----- airflow/www/views.py | 6 +++--- 3 files changed, 6 insertions(+), 14 deletions(-) diff --git a/airflow/www/extensions/init_jinja_globals.py b/airflow/www/extensions/init_jinja_globals.py index 0c3521882a882..80e377378ca9c 100644 --- a/airflow/www/extensions/init_jinja_globals.py +++ b/airflow/www/extensions/init_jinja_globals.py @@ -74,12 +74,8 @@ def prepare_jinja_globals(): } # Extra global specific to auth manager - extra_globals.update( - { - "auth_manager": get_auth_manager(), - "DagDetails": DagDetails, - } - ) + extra_globals["auth_manager"] = get_auth_manager() + extra_globals["DagDetails"] = DagDetails backends = conf.get("api", "auth_backends") if backends and backends[0] != "airflow.api.auth.backend.deny_all": diff --git a/airflow/www/security_manager.py b/airflow/www/security_manager.py index ba80780227587..b7887ab767d9e 100644 --- a/airflow/www/security_manager.py +++ b/airflow/www/security_manager.py @@ -644,8 +644,4 
@@ def check_authorization( if not perms: return True - for perm in perms: - if not self.has_access(*perm): - return False - - return True + return all(self.has_access(*perm) for perm in perms) diff --git a/airflow/www/views.py b/airflow/www/views.py index ab78348f44b6f..fb122e901c2cc 100644 --- a/airflow/www/views.py +++ b/airflow/www/views.py @@ -3909,9 +3909,9 @@ class DagFilter(BaseFilter): """Filter using DagIDs.""" def apply(self, query, func): - if get_auth_manager().is_authorized_dag( - method="GET", user=g.user - ) or get_auth_manager().is_authorized_dag(method="PUT", user=g.user): + if get_auth_manager().is_authorized_dag(method="GET", user=g.user): + return query + if get_auth_manager().is_authorized_dag(method="PUT", user=g.user): return query filter_dag_ids = get_airflow_app().appbuilder.sm.get_permitted_dag_ids(user=g.user) return query.where(self.model.dag_id.in_(filter_dag_ids)) From a670ceed8586220373409fd13296b76444fd9a44 Mon Sep 17 00:00:00 2001 From: Vincent Beck Date: Thu, 5 Oct 2023 17:13:16 -0400 Subject: [PATCH 17/31] Apply suggestions --- airflow/www/security_manager.py | 18 ++++++------------ 1 file changed, 6 insertions(+), 12 deletions(-) diff --git a/airflow/www/security_manager.py b/airflow/www/security_manager.py index a4f27d6542cd4..a7fc690df30d2 100644 --- a/airflow/www/security_manager.py +++ b/airflow/www/security_manager.py @@ -299,21 +299,15 @@ def get_permitted_dag_ids( ): return dag_ids + def _is_permitted_dag_id(method: ResourceMethod, methods: Container[ResourceMethod], dag_id: str): + return method in methods and get_auth_manager().is_authorized_dag( + method=method, details=DagDetails(id=dag_id), user=user + ) + return { dag_id for dag_id in dag_ids - if ( - "GET" in methods - and get_auth_manager().is_authorized_dag( - method="GET", details=DagDetails(id=dag_id), user=user - ) - ) - or ( - "PUT" in methods - and get_auth_manager().is_authorized_dag( - method="PUT", details=DagDetails(id=dag_id), user=user - ) - ) + if 
_is_permitted_dag_id("GET", methods, dag_id) or _is_permitted_dag_id("PUT", methods, dag_id) } def can_access_dags(self, user) -> bool: From 9fff28e37152405ff77a701096b3ecc565f3f7db Mon Sep 17 00:00:00 2001 From: Vincent Beck Date: Wed, 11 Oct 2023 14:29:59 -0400 Subject: [PATCH 18/31] Add back check permissions on dag run for task_instance_endpoint --- .../api_connexion/endpoints/task_instance_endpoint.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/airflow/api_connexion/endpoints/task_instance_endpoint.py b/airflow/api_connexion/endpoints/task_instance_endpoint.py index 6aaddf084f3f1..eb37948c682cd 100644 --- a/airflow/api_connexion/endpoints/task_instance_endpoint.py +++ b/airflow/api_connexion/endpoints/task_instance_endpoint.py @@ -61,6 +61,7 @@ T = TypeVar("T") +@security.requires_access_dag("GET", DagAccessEntity.RUN) @security.requires_access_dag("GET", DagAccessEntity.TASK_INSTANCE) @provide_session def get_task_instance( @@ -103,6 +104,7 @@ def get_task_instance( return task_instance_schema.dump(task_instance) +@security.requires_access_dag("GET", DagAccessEntity.RUN) @security.requires_access_dag("GET", DagAccessEntity.TASK_INSTANCE) @provide_session def get_mapped_task_instance( @@ -149,6 +151,7 @@ def get_mapped_task_instance( "updated_at_lte": format_datetime, }, ) +@security.requires_access_dag("GET", DagAccessEntity.RUN) @security.requires_access_dag("GET", DagAccessEntity.TASK_INSTANCE) @provide_session def get_mapped_task_instances( @@ -287,6 +290,7 @@ def _apply_range_filter(query: Select, key: ClauseElement, value_range: tuple[T, "updated_at_lte": format_datetime, }, ) +@security.requires_access_dag("GET", DagAccessEntity.RUN) @security.requires_access_dag("GET", DagAccessEntity.TASK_INSTANCE) @provide_session def get_task_instances( @@ -362,6 +366,7 @@ def get_task_instances( ) +@security.requires_access_dag("GET", DagAccessEntity.RUN) @security.requires_access_dag("GET", DagAccessEntity.TASK_INSTANCE) @provide_session def 
get_task_instances_batch(session: Session = NEW_SESSION) -> APIResponse: @@ -431,6 +436,7 @@ def get_task_instances_batch(session: Session = NEW_SESSION) -> APIResponse: ) +@security.requires_access_dag("PUT", DagAccessEntity.RUN) @security.requires_access_dag("PUT", DagAccessEntity.TASK_INSTANCE) @provide_session def post_clear_task_instances(*, dag_id: str, session: Session = NEW_SESSION) -> APIResponse: @@ -491,6 +497,7 @@ def post_clear_task_instances(*, dag_id: str, session: Session = NEW_SESSION) -> ) +@security.requires_access_dag("PUT", DagAccessEntity.RUN) @security.requires_access_dag("PUT", DagAccessEntity.TASK_INSTANCE) @provide_session def post_set_task_instances_state(*, dag_id: str, session: Session = NEW_SESSION) -> APIResponse: @@ -558,6 +565,7 @@ def set_mapped_task_instance_note( return set_task_instance_note(dag_id=dag_id, dag_run_id=dag_run_id, task_id=task_id, map_index=map_index) +@security.requires_access_dag("PUT", DagAccessEntity.RUN) @security.requires_access_dag("PUT", DagAccessEntity.TASK_INSTANCE) @provide_session def patch_task_instance( @@ -598,6 +606,7 @@ def patch_task_instance( return task_instance_reference_schema.dump(ti) +@security.requires_access_dag("PUT", DagAccessEntity.RUN) @security.requires_access_dag("PUT", DagAccessEntity.TASK_INSTANCE) @provide_session def patch_mapped_task_instance( @@ -609,6 +618,7 @@ def patch_mapped_task_instance( ) +@security.requires_access_dag("PUT", DagAccessEntity.RUN) @security.requires_access_dag("PUT", DagAccessEntity.TASK_INSTANCE) @provide_session def set_task_instance_note( From 8f9f9efb471591ccd0f151470b94c528916bf90a Mon Sep 17 00:00:00 2001 From: Vincent Beck Date: Wed, 11 Oct 2023 17:39:21 -0400 Subject: [PATCH 19/31] Fix permissions --- .../api_connexion/endpoints/task_instance_endpoint.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/airflow/api_connexion/endpoints/task_instance_endpoint.py b/airflow/api_connexion/endpoints/task_instance_endpoint.py 
index eb37948c682cd..ad4cd4381a8eb 100644 --- a/airflow/api_connexion/endpoints/task_instance_endpoint.py +++ b/airflow/api_connexion/endpoints/task_instance_endpoint.py @@ -436,7 +436,7 @@ def get_task_instances_batch(session: Session = NEW_SESSION) -> APIResponse: ) -@security.requires_access_dag("PUT", DagAccessEntity.RUN) +@security.requires_access_dag("GET", DagAccessEntity.RUN) @security.requires_access_dag("PUT", DagAccessEntity.TASK_INSTANCE) @provide_session def post_clear_task_instances(*, dag_id: str, session: Session = NEW_SESSION) -> APIResponse: @@ -497,7 +497,7 @@ def post_clear_task_instances(*, dag_id: str, session: Session = NEW_SESSION) -> ) -@security.requires_access_dag("PUT", DagAccessEntity.RUN) +@security.requires_access_dag("GET", DagAccessEntity.RUN) @security.requires_access_dag("PUT", DagAccessEntity.TASK_INSTANCE) @provide_session def post_set_task_instances_state(*, dag_id: str, session: Session = NEW_SESSION) -> APIResponse: @@ -565,7 +565,7 @@ def set_mapped_task_instance_note( return set_task_instance_note(dag_id=dag_id, dag_run_id=dag_run_id, task_id=task_id, map_index=map_index) -@security.requires_access_dag("PUT", DagAccessEntity.RUN) +@security.requires_access_dag("GET", DagAccessEntity.RUN) @security.requires_access_dag("PUT", DagAccessEntity.TASK_INSTANCE) @provide_session def patch_task_instance( @@ -606,7 +606,7 @@ def patch_task_instance( return task_instance_reference_schema.dump(ti) -@security.requires_access_dag("PUT", DagAccessEntity.RUN) +@security.requires_access_dag("GET", DagAccessEntity.RUN) @security.requires_access_dag("PUT", DagAccessEntity.TASK_INSTANCE) @provide_session def patch_mapped_task_instance( @@ -618,7 +618,7 @@ def patch_mapped_task_instance( ) -@security.requires_access_dag("PUT", DagAccessEntity.RUN) +@security.requires_access_dag("GET", DagAccessEntity.RUN) @security.requires_access_dag("PUT", DagAccessEntity.TASK_INSTANCE) @provide_session def set_task_instance_note( From 
b6e1fbaf33ad42288ff28606923a70a1a4c8be28 Mon Sep 17 00:00:00 2001 From: Vincent Beck Date: Thu, 12 Oct 2023 11:25:34 -0400 Subject: [PATCH 20/31] Pass auth manager to views in `AirflowBaseView` --- airflow/www/extensions/init_jinja_globals.py | 6 ------ airflow/www/templates/airflow/dag.html | 3 +-- airflow/www/views.py | 7 +++++++ tests/www/views/conftest.py | 2 -- 4 files changed, 8 insertions(+), 10 deletions(-) diff --git a/airflow/www/extensions/init_jinja_globals.py b/airflow/www/extensions/init_jinja_globals.py index 80e377378ca9c..a6efa5bb681a4 100644 --- a/airflow/www/extensions/init_jinja_globals.py +++ b/airflow/www/extensions/init_jinja_globals.py @@ -21,12 +21,10 @@ import pendulum import airflow -from airflow.auth.managers.models.resource_details import DagDetails from airflow.configuration import conf from airflow.settings import IS_K8S_OR_K8SCELERY_EXECUTOR, STATE_COLORS from airflow.utils.net import get_hostname from airflow.utils.platform import get_airflow_git_version -from airflow.www.extensions.init_auth_manager import get_auth_manager def init_jinja_globals(app): @@ -73,10 +71,6 @@ def prepare_jinja_globals(): "config_test_connection": conf.get("core", "test_connection", fallback="Disabled"), } - # Extra global specific to auth manager - extra_globals["auth_manager"] = get_auth_manager() - extra_globals["DagDetails"] = DagDetails - backends = conf.get("api", "auth_backends") if backends and backends[0] != "airflow.api.auth.backend.deny_all": extra_globals["rest_api_enabled"] = True diff --git a/airflow/www/templates/airflow/dag.html b/airflow/www/templates/airflow/dag.html index 884f0401858ef..d19635292f9c2 100644 --- a/airflow/www/templates/airflow/dag.html +++ b/airflow/www/templates/airflow/dag.html @@ -110,8 +110,7 @@

{% if dag.parent_dag is defined and dag.parent_dag %} SUBDAG: {{ dag.dag_id }} {% else %} - {% set can_edit = auth_manager.is_authorized_dag(method="PUT", details=DagDetails(id=dag.dag_id)) %} - {% if can_edit %} + {% if can_edit_dag %} {% set switch_tooltip = 'Pause/Unpause DAG' %} {% else %} {% set switch_tooltip = 'DAG is Paused' if dag_is_paused else 'DAG is Active' %} diff --git a/airflow/www/views.py b/airflow/www/views.py index 9204de2314124..2060516b836dc 100644 --- a/airflow/www/views.py +++ b/airflow/www/views.py @@ -699,6 +699,13 @@ def render_template(self, *args, **kwargs): # Add triggerer_job only if we need it if TriggererJobRunner.is_needed(): kwargs["triggerer_job"] = lazy_object_proxy.Proxy(TriggererJobRunner.most_recent_job) + + kwargs["auth_manager"] = get_auth_manager() + if "dag" in kwargs: + kwargs["can_edit_dag"] = get_auth_manager().is_authorized_dag( + method="PUT", details=DagDetails(id=kwargs["dag"].dag_id) + ) + return super().render_template( *args, # Cache this at most once per request, not for the lifetime of the view instance diff --git a/tests/www/views/conftest.py b/tests/www/views/conftest.py index b472290c79e65..c304909dab437 100644 --- a/tests/www/views/conftest.py +++ b/tests/www/views/conftest.py @@ -172,8 +172,6 @@ def local_context(self): "scheduler_job", # airflow.www.views.AirflowBaseView.extra_args "macros", - "auth_manager", - "DagDetails", ] for key in keys_to_delete: del result[key] From 58621d279a79deebfe6e1c8be40b218be440d0b9 Mon Sep 17 00:00:00 2001 From: Vincent Beck Date: Fri, 13 Oct 2023 11:52:43 -0400 Subject: [PATCH 21/31] Add back auth manager to global jinja context --- airflow/www/extensions/init_jinja_globals.py | 4 ++++ airflow/www/templates/airflow/dag.html | 4 ++-- airflow/www/views.py | 1 - tests/www/views/conftest.py | 1 + 4 files changed, 7 insertions(+), 3 deletions(-) diff --git a/airflow/www/extensions/init_jinja_globals.py b/airflow/www/extensions/init_jinja_globals.py index 
a6efa5bb681a4..95cd9b8c26785 100644 --- a/airflow/www/extensions/init_jinja_globals.py +++ b/airflow/www/extensions/init_jinja_globals.py @@ -25,6 +25,7 @@ from airflow.settings import IS_K8S_OR_K8SCELERY_EXECUTOR, STATE_COLORS from airflow.utils.net import get_hostname from airflow.utils.platform import get_airflow_git_version +from airflow.www.extensions.init_auth_manager import get_auth_manager def init_jinja_globals(app): @@ -71,6 +72,9 @@ def prepare_jinja_globals(): "config_test_connection": conf.get("core", "test_connection", fallback="Disabled"), } + # Extra global specific to auth manager + extra_globals["auth_manager"] = get_auth_manager() + backends = conf.get("api", "auth_backends") if backends and backends[0] != "airflow.api.auth.backend.deny_all": extra_globals["rest_api_enabled"] = True diff --git a/airflow/www/templates/airflow/dag.html b/airflow/www/templates/airflow/dag.html index d19635292f9c2..40440d3fd6672 100644 --- a/airflow/www/templates/airflow/dag.html +++ b/airflow/www/templates/airflow/dag.html @@ -115,10 +115,10 @@

{% else %} {% set switch_tooltip = 'DAG is Paused' if dag_is_paused else 'DAG is Active' %} {% endif %} -