Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
112 changes: 82 additions & 30 deletions airflow-core/src/airflow/api_fastapi/core_api/security.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,6 +445,7 @@ def inner(
for action in request.actions
for entity in action.entities
if action.action != BulkAction.CREATE
or action.action_on_existence == BulkActionOnExistence.OVERWRITE
]
# For each pool, find its associated team (if it exists)
pool_name_to_team = Pool.get_name_to_team_name_mapping(existing_pool_names)
Expand All @@ -460,14 +461,22 @@ def inner(
# The list of `IsAuthorizedPoolRequest` will then be sent using `batch_is_authorized_pool`
# Each `IsAuthorizedPoolRequest` is similar to calling `is_authorized_pool`
for method in methods:
req: IsAuthorizedPoolRequest = {
"method": method,
"details": PoolDetails(
name=pool_name,
team_name=pool_name_to_team.get(pool_name),
),
}
requests.append(req)
teams_to_check = _collect_teams_for_bulk_entity(
action=action,
entity_team_name=cast("PoolBody", pool).team_name
if action.action != BulkAction.DELETE
else None,
existing_team_name=pool_name_to_team.get(pool_name),
)
for team_name in teams_to_check:
req: IsAuthorizedPoolRequest = {
"method": method,
"details": PoolDetails(
name=pool_name,
team_name=team_name,
),
}
requests.append(req)

_requires_access(
# By calling `batch_is_authorized_pool`, we check the user has access to all pools provided in the request
Expand Down Expand Up @@ -548,6 +557,7 @@ def inner(
for action in request.actions
for entity in action.entities
if action.action != BulkAction.CREATE
or action.action_on_existence == BulkActionOnExistence.OVERWRITE
]
# For each connection, find its associated team (if it exists)
conn_id_to_team = Connection.get_conn_id_to_team_name_mapping(existing_connection_ids)
Expand All @@ -561,18 +571,23 @@ def inner(
if action.action == BulkAction.DELETE
else cast("ConnectionBody", connection).connection_id
)
# For each pool, build a `IsAuthorizedConnectionRequest`
# The list of `IsAuthorizedConnectionRequest` will then be sent using `batch_is_authorized_connection`
# Each `IsAuthorizedConnectionRequest` is similar to calling `is_authorized_connection`
for method in methods:
req: IsAuthorizedConnectionRequest = {
"method": method,
"details": ConnectionDetails(
conn_id=connection_id,
team_name=conn_id_to_team.get(connection_id),
),
}
requests.append(req)
teams_to_check = _collect_teams_for_bulk_entity(
action=action,
entity_team_name=cast("ConnectionBody", connection).team_name
if action.action != BulkAction.DELETE
else None,
existing_team_name=conn_id_to_team.get(connection_id),
)
for team_name in teams_to_check:
req: IsAuthorizedConnectionRequest = {
"method": method,
"details": ConnectionDetails(
conn_id=connection_id,
team_name=team_name,
),
}
requests.append(req)

_requires_access(
# By calling `batch_is_authorized_connection`, we check the user has access to all connections provided in the request
Expand Down Expand Up @@ -688,6 +703,7 @@ def inner(
for action in request.actions
for entity in action.entities
if action.action != BulkAction.CREATE
or action.action_on_existence == BulkActionOnExistence.OVERWRITE
]
# For each variable, find its associated team (if it exists)
var_key_to_team = Variable.get_key_to_team_name_mapping(existing_variable_keys)
Expand All @@ -701,18 +717,23 @@ def inner(
if action.action == BulkAction.DELETE
else cast("VariableBody", variable).key
)
# For each variable, build a `IsAuthorizedVariableRequest`
# The list of `IsAuthorizedVariableRequest` will then be sent using `batch_is_authorized_variable`
# Each `IsAuthorizedVariableRequest` is similar to calling `is_authorized_variable`
for method in methods:
req: IsAuthorizedVariableRequest = {
"method": method,
"details": VariableDetails(
key=variable_key,
team_name=var_key_to_team.get(variable_key),
),
}
requests.append(req)
teams_to_check = _collect_teams_for_bulk_entity(
action=action,
entity_team_name=cast("VariableBody", variable).team_name
if action.action != BulkAction.DELETE
else None,
existing_team_name=var_key_to_team.get(variable_key),
)
for team_name in teams_to_check:
req: IsAuthorizedVariableRequest = {
"method": method,
"details": VariableDetails(
key=variable_key,
team_name=team_name,
),
}
requests.append(req)

_requires_access(
# By calling `batch_is_authorized_variable`, we check the user has access to all variables provided in the request
Expand Down Expand Up @@ -884,3 +905,34 @@ def _get_resource_methods_from_bulk_request(
if action.action == BulkAction.CREATE and action.action_on_existence == BulkActionOnExistence.OVERWRITE:
resource_methods.append("PUT")
return resource_methods


def _collect_teams_for_bulk_entity(
action: BulkCreateAction | BulkUpdateAction | BulkDeleteAction,
entity_team_name: str | None,
existing_team_name: str | None,
) -> set[str | None]:
"""
Collect the set of team names to authorize for a single entity in a bulk request.

For CREATE/UPDATE actions the caller must be authorized for the team_name specified in the
request body (if any). For UPDATE actions (or CREATE with overwrite) on existing resources,
the existing team is also checked so that a user cannot move a resource out of a team they
don't have access to.
"""
if not conf.getboolean("core", "multi_team"):
return {None}

teams: set[str | None] = set()

if action.action == BulkAction.DELETE:
teams.add(existing_team_name)
elif action.action == BulkAction.UPDATE:
teams.add(existing_team_name)
teams.add(entity_team_name)
elif action.action == BulkAction.CREATE:
teams.add(entity_team_name)
if action.action_on_existence == BulkActionOnExistence.OVERWRITE:
teams.add(existing_team_name)
Comment thread
vincbeck marked this conversation as resolved.

return teams
166 changes: 160 additions & 6 deletions airflow-core/tests/unit/api_fastapi/core_api/test_security.py
Original file line number Diff line number Diff line change
Expand Up @@ -811,7 +811,7 @@ def test_requires_access_connection_bulk(
auth_manager = Mock()
auth_manager.batch_is_authorized_connection.return_value = True
mock_get_auth_manager.return_value = auth_manager
mock_get_conn_id_to_team_name_mapping.return_value = {"test1": "team1"}
mock_get_conn_id_to_team_name_mapping.return_value = {}

request = BulkBody[ConnectionBody].model_validate(
{
Expand Down Expand Up @@ -844,7 +844,7 @@ def test_requires_access_connection_bulk(
requests=[
{
"method": "POST",
"details": ConnectionDetails(conn_id="test1", team_name="team1"),
"details": ConnectionDetails(conn_id="test1"),
},
{
"method": "POST",
Expand All @@ -866,6 +866,58 @@ def test_requires_access_connection_bulk(
user=user,
)

@patch.object(Connection, "get_conn_id_to_team_name_mapping")
@patch("airflow.api_fastapi.core_api.security.get_auth_manager")
def test_requires_access_connection_bulk_multi_team(
self, mock_get_auth_manager, mock_get_conn_id_to_team_name_mapping
):
"""Bulk connection authz checks team_name from both existing resources and request body."""
auth_manager = Mock()
auth_manager.batch_is_authorized_connection.return_value = True
mock_get_auth_manager.return_value = auth_manager
mock_get_conn_id_to_team_name_mapping.return_value = {"test3": "team1", "test4": "team1"}

user = Mock()
with conf_vars({("core", "multi_team"): "True"}):
request = BulkBody[ConnectionBody].model_validate(
{
"actions": [
{
"action": "create",
"entities": [
{"connection_id": "test1", "conn_type": "test1", "team_name": "team2"},
{"connection_id": "test2", "conn_type": "test2"},
],
},
{
"action": "delete",
"entities": ["test3"],
},
{
"action": "create",
"entities": [
{"connection_id": "test4", "conn_type": "test4", "team_name": "team2"},
],
"action_on_existence": "overwrite",
},
]
}
)
requires_access_connection_bulk()(request, user)

auth_manager.batch_is_authorized_connection.assert_called_once()
actual_requests = auth_manager.batch_is_authorized_connection.call_args.kwargs["requests"]
expected_requests = [
{"method": "POST", "details": ConnectionDetails(conn_id="test1", team_name="team2")},
{"method": "POST", "details": ConnectionDetails(conn_id="test2")},
{"method": "DELETE", "details": ConnectionDetails(conn_id="test3", team_name="team1")},
{"method": "POST", "details": ConnectionDetails(conn_id="test4", team_name="team1")},
{"method": "POST", "details": ConnectionDetails(conn_id="test4", team_name="team2")},
{"method": "PUT", "details": ConnectionDetails(conn_id="test4", team_name="team1")},
{"method": "PUT", "details": ConnectionDetails(conn_id="test4", team_name="team2")},
]
assert sorted(actual_requests, key=str) == sorted(expected_requests, key=str)

@pytest.mark.db_test
@pytest.mark.parametrize(
"team_name",
Expand Down Expand Up @@ -982,7 +1034,7 @@ def test_requires_access_variable_bulk(self, mock_get_auth_manager, mock_get_key
auth_manager = Mock()
auth_manager.batch_is_authorized_variable.return_value = True
mock_get_auth_manager.return_value = auth_manager
mock_get_key_to_team_name_mapping.return_value = {"var1": "team1", "dummy": "team2"}
mock_get_key_to_team_name_mapping.return_value = {}
request = BulkBody[VariableBody].model_validate(
{
"actions": [
Expand Down Expand Up @@ -1014,7 +1066,7 @@ def test_requires_access_variable_bulk(self, mock_get_auth_manager, mock_get_key
requests=[
{
"method": "POST",
"details": VariableDetails(key="var1", team_name="team1"),
"details": VariableDetails(key="var1"),
},
{
"method": "POST",
Expand All @@ -1036,6 +1088,57 @@ def test_requires_access_variable_bulk(self, mock_get_auth_manager, mock_get_key
user=user,
)

@patch.object(Variable, "get_key_to_team_name_mapping")
@patch("airflow.api_fastapi.core_api.security.get_auth_manager")
def test_requires_access_variable_bulk_multi_team(
self, mock_get_auth_manager, mock_get_key_to_team_name_mapping
):
"""Bulk variable authz checks team_name from both existing resources and request body."""
auth_manager = Mock()
auth_manager.batch_is_authorized_variable.return_value = True
mock_get_auth_manager.return_value = auth_manager
mock_get_key_to_team_name_mapping.return_value = {"var3": "team1", "var4": "team1"}
user = Mock()
with conf_vars({("core", "multi_team"): "True"}):
request = BulkBody[VariableBody].model_validate(
{
"actions": [
{
"action": "create",
"entities": [
{"key": "var1", "value": "value1", "team_name": "team2"},
{"key": "var2", "value": "value2"},
],
},
{
"action": "delete",
"entities": ["var3"],
},
{
"action": "create",
"entities": [
{"key": "var4", "value": "value4", "team_name": "team2"},
],
"action_on_existence": "overwrite",
},
]
}
)
requires_access_variable_bulk()(request, user)

auth_manager.batch_is_authorized_variable.assert_called_once()
actual_requests = auth_manager.batch_is_authorized_variable.call_args.kwargs["requests"]
expected_requests = [
{"method": "POST", "details": VariableDetails(key="var1", team_name="team2")},
{"method": "POST", "details": VariableDetails(key="var2")},
{"method": "DELETE", "details": VariableDetails(key="var3", team_name="team1")},
{"method": "POST", "details": VariableDetails(key="var4", team_name="team1")},
{"method": "POST", "details": VariableDetails(key="var4", team_name="team2")},
{"method": "PUT", "details": VariableDetails(key="var4", team_name="team1")},
{"method": "PUT", "details": VariableDetails(key="var4", team_name="team2")},
]
assert sorted(actual_requests, key=str) == sorted(expected_requests, key=str)

@pytest.mark.db_test
@pytest.mark.parametrize(
"team_name",
Expand Down Expand Up @@ -1150,7 +1253,7 @@ def test_requires_access_pool_bulk(self, mock_get_auth_manager, mock_get_name_to
auth_manager = Mock()
auth_manager.batch_is_authorized_pool.return_value = True
mock_get_auth_manager.return_value = auth_manager
mock_get_name_to_team_name_mapping.return_value = {"pool1": "team1"}
mock_get_name_to_team_name_mapping.return_value = {}
request = BulkBody[PoolBody].model_validate(
{
"actions": [
Expand Down Expand Up @@ -1182,7 +1285,7 @@ def test_requires_access_pool_bulk(self, mock_get_auth_manager, mock_get_name_to
requests=[
{
"method": "POST",
"details": PoolDetails(name="pool1", team_name="team1"),
"details": PoolDetails(name="pool1"),
},
{
"method": "POST",
Expand All @@ -1204,6 +1307,57 @@ def test_requires_access_pool_bulk(self, mock_get_auth_manager, mock_get_name_to
user=user,
)

@patch.object(Pool, "get_name_to_team_name_mapping")
@patch("airflow.api_fastapi.core_api.security.get_auth_manager")
def test_requires_access_pool_bulk_multi_team(
self, mock_get_auth_manager, mock_get_name_to_team_name_mapping
):
"""Bulk pool authz checks team_name from both existing resources and request body."""
auth_manager = Mock()
auth_manager.batch_is_authorized_pool.return_value = True
mock_get_auth_manager.return_value = auth_manager
mock_get_name_to_team_name_mapping.return_value = {"pool3": "team1", "pool4": "team1"}
user = Mock()
with conf_vars({("core", "multi_team"): "True"}):
request = BulkBody[PoolBody].model_validate(
{
"actions": [
{
"action": "create",
"entities": [
{"pool": "pool1", "slots": 1, "team_name": "team2"},
{"pool": "pool2", "slots": 1},
],
},
{
"action": "delete",
"entities": ["pool3"],
},
{
"action": "create",
"entities": [
{"pool": "pool4", "slots": 1, "team_name": "team2"},
],
"action_on_existence": "overwrite",
},
]
}
)
requires_access_pool_bulk()(request, user)

auth_manager.batch_is_authorized_pool.assert_called_once()
actual_requests = auth_manager.batch_is_authorized_pool.call_args.kwargs["requests"]
expected_requests = [
{"method": "POST", "details": PoolDetails(name="pool1", team_name="team2")},
{"method": "POST", "details": PoolDetails(name="pool2")},
{"method": "DELETE", "details": PoolDetails(name="pool3", team_name="team1")},
{"method": "POST", "details": PoolDetails(name="pool4", team_name="team1")},
{"method": "POST", "details": PoolDetails(name="pool4", team_name="team2")},
{"method": "PUT", "details": PoolDetails(name="pool4", team_name="team1")},
{"method": "PUT", "details": PoolDetails(name="pool4", team_name="team2")},
]
assert sorted(actual_requests, key=str) == sorted(expected_requests, key=str)


class TestAuthManagerDependency:
"""Test the auth_manager_from_app dependency function."""
Expand Down
Loading