Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use auth manager is_authorized_ APIs to check user permissions in Rest API #34317

Merged
merged 36 commits into from
Oct 17, 2023
Merged
Show file tree
Hide file tree
Changes from 29 commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
d843ac1
Use auth manager APIs `is_authorized_` to check user permissions in R…
vincbeck Sep 25, 2023
54becda
Fix get_dag_warnings
vincbeck Sep 25, 2023
ed8d27d
Remove test
vincbeck Sep 25, 2023
4d6acbf
Fix `_get_root_dag_id`
vincbeck Sep 25, 2023
94f8585
Check GET permissions on DAG level as well
vincbeck Sep 26, 2023
19b5289
Add test case
vincbeck Sep 26, 2023
ccffa78
Fix tests
vincbeck Sep 26, 2023
218a336
Fix static checks
vincbeck Sep 26, 2023
414009b
Refactoring
vincbeck Sep 26, 2023
cee7e5f
Address feedbacks
vincbeck Sep 27, 2023
c13d45e
Merge branch 'main' into vincbeck/use_is_authorized
vincbeck Sep 27, 2023
f04bb89
Fix test
vincbeck Sep 27, 2023
1b33953
Minor feedbacks
vincbeck Sep 28, 2023
21f9204
When doing authorization check on DAGs in general, check if the user …
vincbeck Sep 29, 2023
1e6b9af
Fix tests
vincbeck Sep 29, 2023
4a7df9b
Address feedbacks
vincbeck Oct 5, 2023
fd2cb8a
Apply suggestions
vincbeck Oct 5, 2023
6cd6c65
Merge branch 'main' into vincbeck/use_is_authorized
vincbeck Oct 5, 2023
07da6fa
Merge branch 'main' into vincbeck/use_is_authorized
vincbeck Oct 5, 2023
a670cee
Apply suggestions
vincbeck Oct 5, 2023
9fff28e
Add back check permissions on dag run for task_instance_endpoint
vincbeck Oct 11, 2023
8f9f9ef
Fix permissions
vincbeck Oct 11, 2023
fc9ef42
Merge branch 'main' into vincbeck/use_is_authorized
vincbeck Oct 12, 2023
b6e1fba
Pass auth manager to views in `AirflowBaseView`
vincbeck Oct 12, 2023
58621d2
Add back auth manager to global jinja context
vincbeck Oct 13, 2023
04b1d54
Introduce DagAccessEntity.TASK
vincbeck Oct 12, 2023
c4260c5
Deprecate `requires_access`
vincbeck Oct 12, 2023
7ba129b
Raise exception in `check_authorization`
vincbeck Oct 13, 2023
1f1709b
Fix tests
vincbeck Oct 13, 2023
9334b2c
Merge branch 'main' into vincbeck/use_is_authorized
vincbeck Oct 16, 2023
04812ee
Move `get_permitted_dag_ids` to auth manager
vincbeck Oct 16, 2023
26b05be
Add comment
vincbeck Oct 16, 2023
2f115aa
Fix import
vincbeck Oct 16, 2023
0512542
Fix `get_permitted_dag_ids` in FAB auth manager
vincbeck Oct 16, 2023
bdcab82
Fix get_permitted_dag_ids
vincbeck Oct 17, 2023
dafeb54
Fix tests
vincbeck Oct 17, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions airflow/api_connexion/endpoints/config_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
from airflow.api_connexion.exceptions import NotFound, PermissionDenied
from airflow.api_connexion.schemas.config_schema import Config, ConfigOption, ConfigSection, config_schema
from airflow.configuration import conf
from airflow.security import permissions
from airflow.settings import json

LINE_SEP = "\n" # `\n` cannot appear in f-strings
Expand Down Expand Up @@ -66,7 +65,7 @@ def _config_to_json(config: Config) -> str:
return json.dumps(config_schema.dump(config), indent=4)


@security.requires_access([(permissions.ACTION_CAN_READ, permissions.RESOURCE_CONFIG)])
@security.requires_access_configuration("GET")
def get_config(*, section: str | None = None) -> Response:
"""Get current configuration."""
serializer = {
Expand Down Expand Up @@ -103,8 +102,8 @@ def get_config(*, section: str | None = None) -> Response:
)


@security.requires_access([(permissions.ACTION_CAN_READ, permissions.RESOURCE_CONFIG)])
def get_value(section: str, option: str) -> Response:
@security.requires_access_configuration("GET")
def get_value(*, section: str, option: str) -> Response:
serializer = {
"text/plain": _config_to_text,
"application/json": _config_to_json,
Expand Down
12 changes: 6 additions & 6 deletions airflow/api_connexion/endpoints/connection_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@
RESOURCE_EVENT_PREFIX = "connection"


@security.requires_access([(permissions.ACTION_CAN_DELETE, permissions.RESOURCE_CONNECTION)])
@security.requires_access_connection("DELETE")
@provide_session
@action_logging(
event=action_event_from_permission(
Expand All @@ -73,7 +73,7 @@ def delete_connection(*, connection_id: str, session: Session = NEW_SESSION) ->
return NoContent, HTTPStatus.NO_CONTENT


@security.requires_access([(permissions.ACTION_CAN_READ, permissions.RESOURCE_CONNECTION)])
@security.requires_access_connection("GET")
@provide_session
def get_connection(*, connection_id: str, session: Session = NEW_SESSION) -> APIResponse:
"""Get a connection entry."""
Expand All @@ -86,7 +86,7 @@ def get_connection(*, connection_id: str, session: Session = NEW_SESSION) -> API
return connection_schema.dump(connection)


@security.requires_access([(permissions.ACTION_CAN_READ, permissions.RESOURCE_CONNECTION)])
@security.requires_access_connection("GET")
@format_parameters({"limit": check_limit})
@provide_session
def get_connections(
Expand All @@ -109,7 +109,7 @@ def get_connections(
)


@security.requires_access([(permissions.ACTION_CAN_EDIT, permissions.RESOURCE_CONNECTION)])
@security.requires_access_connection("PUT")
@provide_session
@action_logging(
event=action_event_from_permission(
Expand Down Expand Up @@ -147,7 +147,7 @@ def patch_connection(
return connection_schema.dump(connection)


@security.requires_access([(permissions.ACTION_CAN_CREATE, permissions.RESOURCE_CONNECTION)])
@security.requires_access_connection("POST")
@provide_session
@action_logging(
event=action_event_from_permission(
Expand Down Expand Up @@ -176,7 +176,7 @@ def post_connection(*, session: Session = NEW_SESSION) -> APIResponse:
raise AlreadyExists(detail=f"Connection already exist. ID: {conn_id}")


@security.requires_access([(permissions.ACTION_CAN_CREATE, permissions.RESOURCE_CONNECTION)])
@security.requires_access_connection("POST")
def test_connection() -> APIResponse:
"""
Test an API connection.
Expand Down
15 changes: 7 additions & 8 deletions airflow/api_connexion/endpoints/dag_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@
)
from airflow.exceptions import AirflowException, DagNotFound
from airflow.models.dag import DagModel, DagTag
from airflow.security import permissions
from airflow.utils.airflow_flask_app import get_airflow_app
from airflow.utils.db import get_query_count
from airflow.utils.session import NEW_SESSION, provide_session
Expand All @@ -48,7 +47,7 @@
from airflow.api_connexion.types import APIResponse, UpdateMask


@security.requires_access([(permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG)])
@security.requires_access_dag("GET")
@provide_session
def get_dag(*, dag_id: str, session: Session = NEW_SESSION) -> APIResponse:
"""Get basic information about a DAG."""
Expand All @@ -60,7 +59,7 @@ def get_dag(*, dag_id: str, session: Session = NEW_SESSION) -> APIResponse:
return dag_schema.dump(dag)


@security.requires_access([(permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG)])
@security.requires_access_dag("GET")
def get_dag_details(*, dag_id: str) -> APIResponse:
"""Get details of DAG."""
dag: DAG = get_airflow_app().dag_bag.get_dag(dag_id)
Expand All @@ -69,7 +68,7 @@ def get_dag_details(*, dag_id: str) -> APIResponse:
return dag_detail_schema.dump(dag)


@security.requires_access([(permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG)])
@security.requires_access_dag("GET")
@format_parameters({"limit": check_limit})
@provide_session
def get_dags(
Expand All @@ -96,7 +95,7 @@ def get_dags(
if dag_id_pattern:
dags_query = dags_query.where(DagModel.dag_id.ilike(f"%{dag_id_pattern}%"))

readable_dags = get_airflow_app().appbuilder.sm.get_accessible_dag_ids(g.user)
readable_dags = get_airflow_app().appbuilder.sm.get_permitted_dag_ids(user=g.user)

dags_query = dags_query.where(DagModel.dag_id.in_(readable_dags))
if tags:
Expand All @@ -110,7 +109,7 @@ def get_dags(
return dags_collection_schema.dump(DAGCollection(dags=dags, total_entries=total_entries))


@security.requires_access([(permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG)])
@security.requires_access_dag("PUT")
@provide_session
def patch_dag(*, dag_id: str, update_mask: UpdateMask = None, session: Session = NEW_SESSION) -> APIResponse:
"""Update the specific DAG."""
Expand All @@ -132,7 +131,7 @@ def patch_dag(*, dag_id: str, update_mask: UpdateMask = None, session: Session =
return dag_schema.dump(dag)


@security.requires_access([(permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG)])
@security.requires_access_dag("PUT")
@format_parameters({"limit": check_limit})
@provide_session
def patch_dags(limit, session, offset=0, only_active=True, tags=None, dag_id_pattern=None, update_mask=None):
Expand Down Expand Up @@ -180,7 +179,7 @@ def patch_dags(limit, session, offset=0, only_active=True, tags=None, dag_id_pat
return dags_collection_schema.dump(DAGCollection(dags=dags, total_entries=total_entries))


@security.requires_access([(permissions.ACTION_CAN_DELETE, permissions.RESOURCE_DAG)])
@security.requires_access_dag("DELETE")
vincbeck marked this conversation as resolved.
Show resolved Hide resolved
@provide_session
def delete_dag(dag_id: str, session: Session = NEW_SESSION) -> APIResponse:
"""Delete the specific DAG."""
Expand Down
66 changes: 11 additions & 55 deletions airflow/api_connexion/endpoints/dag_run_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
TaskInstanceReferenceCollection,
task_instance_reference_collection_schema,
)
from airflow.auth.managers.models.resource_details import DagAccessEntity
from airflow.models import DagModel, DagRun
from airflow.security import permissions
from airflow.utils.airflow_flask_app import get_airflow_app
Expand All @@ -76,12 +77,7 @@
RESOURCE_EVENT_PREFIX = "dag_run"


@security.requires_access(
[
(permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG),
(permissions.ACTION_CAN_DELETE, permissions.RESOURCE_DAG_RUN),
],
)
@security.requires_access_dag("DELETE", DagAccessEntity.RUN)
vincbeck marked this conversation as resolved.
Show resolved Hide resolved
@provide_session
def delete_dag_run(*, dag_id: str, dag_run_id: str, session: Session = NEW_SESSION) -> APIResponse:
"""Delete a DAG Run."""
Expand All @@ -93,12 +89,7 @@ def delete_dag_run(*, dag_id: str, dag_run_id: str, session: Session = NEW_SESSI
return NoContent, HTTPStatus.NO_CONTENT


@security.requires_access(
[
(permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG),
(permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN),
],
)
@security.requires_access_dag("GET", DagAccessEntity.RUN)
@provide_session
def get_dag_run(*, dag_id: str, dag_run_id: str, session: Session = NEW_SESSION) -> APIResponse:
"""Get a DAG Run."""
Expand All @@ -111,13 +102,8 @@ def get_dag_run(*, dag_id: str, dag_run_id: str, session: Session = NEW_SESSION)
return dagrun_schema.dump(dag_run)


@security.requires_access(
[
(permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG),
(permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN),
(permissions.ACTION_CAN_READ, permissions.RESOURCE_DATASET),
],
)
@security.requires_access_dag("GET", DagAccessEntity.RUN)
@security.requires_access_dataset("GET")
@provide_session
def get_upstream_dataset_events(
*, dag_id: str, dag_run_id: str, session: Session = NEW_SESSION
Expand Down Expand Up @@ -194,12 +180,7 @@ def _fetch_dag_runs(
return session.scalars(query.offset(offset).limit(limit)).all(), total_entries


@security.requires_access(
[
(permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG),
(permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN),
],
)
@security.requires_access_dag("GET", DagAccessEntity.RUN)
@format_parameters(
{
"start_date_gte": format_datetime,
Expand Down Expand Up @@ -262,12 +243,7 @@ def get_dag_runs(
return dagrun_collection_schema.dump(DAGRunCollection(dag_runs=dag_run, total_entries=total_entries))


@security.requires_access(
[
(permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG),
(permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN),
],
)
@security.requires_access_dag("GET", DagAccessEntity.RUN)
@provide_session
def get_dag_runs_batch(*, session: Session = NEW_SESSION) -> APIResponse:
"""Get list of DAG Runs."""
Expand Down Expand Up @@ -307,12 +283,7 @@ def get_dag_runs_batch(*, session: Session = NEW_SESSION) -> APIResponse:
return dagrun_collection_schema.dump(DAGRunCollection(dag_runs=dag_runs, total_entries=total_entries))


@security.requires_access(
[
(permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG),
(permissions.ACTION_CAN_CREATE, permissions.RESOURCE_DAG_RUN),
],
)
@security.requires_access_dag("POST", DagAccessEntity.RUN)
@provide_session
@action_logging(
event=action_event_from_permission(
Expand Down Expand Up @@ -378,12 +349,7 @@ def post_dag_run(*, dag_id: str, session: Session = NEW_SESSION) -> APIResponse:
raise AlreadyExists(detail=f"DAGRun with DAG ID: '{dag_id}' and DAGRun ID: '{run_id}' already exists")


@security.requires_access(
[
(permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG),
(permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG_RUN),
],
)
@security.requires_access_dag("PUT", DagAccessEntity.RUN)
@provide_session
def update_dag_run_state(*, dag_id: str, dag_run_id: str, session: Session = NEW_SESSION) -> APIResponse:
"""Set a state of a dag run."""
Expand All @@ -410,12 +376,7 @@ def update_dag_run_state(*, dag_id: str, dag_run_id: str, session: Session = NEW
return dagrun_schema.dump(dag_run)


@security.requires_access(
[
(permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG),
(permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG_RUN),
],
)
@security.requires_access_dag("PUT", DagAccessEntity.RUN)
@provide_session
def clear_dag_run(*, dag_id: str, dag_run_id: str, session: Session = NEW_SESSION) -> APIResponse:
"""Clear a dag run."""
Expand Down Expand Up @@ -461,12 +422,7 @@ def clear_dag_run(*, dag_id: str, dag_run_id: str, session: Session = NEW_SESSIO
return dagrun_schema.dump(dag_run)


@security.requires_access(
[
(permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG),
(permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG_RUN),
],
)
@security.requires_access_dag("PUT", DagAccessEntity.RUN)
@provide_session
def set_dag_run_note(*, dag_id: str, dag_run_id: str, session: Session = NEW_SESSION) -> APIResponse:
"""Set the note for a dag run."""
Expand Down
4 changes: 2 additions & 2 deletions airflow/api_connexion/endpoints/dag_source_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,11 @@
from airflow.api_connexion import security
from airflow.api_connexion.exceptions import NotFound
from airflow.api_connexion.schemas.dag_source_schema import dag_source_schema
from airflow.auth.managers.models.resource_details import DagAccessEntity
from airflow.models.dagcode import DagCode
from airflow.security import permissions


@security.requires_access([(permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_CODE)])
@security.requires_access_dag("GET", DagAccessEntity.CODE)
def get_dag_source(*, file_token: str) -> Response:
"""Get source code using file token."""
secret_key = current_app.config["SECRET_KEY"]
Expand Down
7 changes: 4 additions & 3 deletions airflow/api_connexion/endpoints/dag_warning_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,19 +28,20 @@
DagWarningCollection,
dag_warning_collection_schema,
)
from airflow.auth.managers.models.resource_details import DagAccessEntity, DagDetails
from airflow.models.dagwarning import DagWarning as DagWarningModel
from airflow.security import permissions
from airflow.utils.airflow_flask_app import get_airflow_app
from airflow.utils.db import get_query_count
from airflow.utils.session import NEW_SESSION, provide_session
from airflow.www.extensions.init_auth_manager import get_auth_manager

if TYPE_CHECKING:
from sqlalchemy.orm import Session

from airflow.api_connexion.types import APIResponse


@security.requires_access([(permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_WARNING)])
@security.requires_access_dag("GET", DagAccessEntity.WARNING)
@format_parameters({"limit": check_limit})
@provide_session
def get_dag_warnings(
Expand All @@ -60,7 +61,7 @@ def get_dag_warnings(
allowed_filter_attrs = ["dag_id", "warning_type", "message", "timestamp"]
query = select(DagWarningModel)
if dag_id:
if not get_airflow_app().appbuilder.sm.can_read_dag(dag_id, g.user):
if not get_auth_manager().is_authorized_dag(method="GET", details=DagDetails(id=dag_id), user=g.user):
raise PermissionDenied(detail=f"User not allowed to access this DAG: {dag_id}")
query = query.where(DagWarningModel.dag_id == dag_id)
else:
Expand Down
9 changes: 4 additions & 5 deletions airflow/api_connexion/endpoints/dataset_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@
dataset_schema,
)
from airflow.models.dataset import DatasetEvent, DatasetModel
from airflow.security import permissions
from airflow.utils.db import get_query_count
from airflow.utils.session import NEW_SESSION, provide_session

Expand All @@ -42,9 +41,9 @@
from airflow.api_connexion.types import APIResponse


@security.requires_access([(permissions.ACTION_CAN_READ, permissions.RESOURCE_DATASET)])
@security.requires_access_dataset("GET")
@provide_session
def get_dataset(uri: str, session: Session = NEW_SESSION) -> APIResponse:
def get_dataset(*, uri: str, session: Session = NEW_SESSION) -> APIResponse:
"""Get a Dataset."""
dataset = session.scalar(
select(DatasetModel)
Expand All @@ -59,7 +58,7 @@ def get_dataset(uri: str, session: Session = NEW_SESSION) -> APIResponse:
return dataset_schema.dump(dataset)


@security.requires_access([(permissions.ACTION_CAN_READ, permissions.RESOURCE_DATASET)])
@security.requires_access_dataset("GET")
@format_parameters({"limit": check_limit})
@provide_session
def get_datasets(
Expand All @@ -86,7 +85,7 @@ def get_datasets(
return dataset_collection_schema.dump(DatasetCollection(datasets=datasets, total_entries=total_entries))


@security.requires_access([(permissions.ACTION_CAN_READ, permissions.RESOURCE_DATASET)])
@security.requires_access_dataset("GET")
@provide_session
@format_parameters({"limit": check_limit})
def get_dataset_events(
Expand Down
Loading
Loading