From 080c81c9c62954a64afb7918dff682d18cfd3c6b Mon Sep 17 00:00:00 2001 From: vincbeck Date: Thu, 21 May 2026 13:54:46 -0400 Subject: [PATCH] Check team permissions on bulk APIs for Connections, Variables and Pools --- .../airflow/api_fastapi/core_api/security.py | 112 ++++++++---- .../api_fastapi/core_api/test_security.py | 166 +++++++++++++++++- 2 files changed, 242 insertions(+), 36 deletions(-) diff --git a/airflow-core/src/airflow/api_fastapi/core_api/security.py b/airflow-core/src/airflow/api_fastapi/core_api/security.py index a54425326d3fa..74975a27d42a5 100644 --- a/airflow-core/src/airflow/api_fastapi/core_api/security.py +++ b/airflow-core/src/airflow/api_fastapi/core_api/security.py @@ -445,6 +445,7 @@ def inner( for action in request.actions for entity in action.entities if action.action != BulkAction.CREATE + or action.action_on_existence == BulkActionOnExistence.OVERWRITE ] # For each pool, find its associated team (if it exists) pool_name_to_team = Pool.get_name_to_team_name_mapping(existing_pool_names) @@ -460,14 +461,22 @@ def inner( # The list of `IsAuthorizedPoolRequest` will then be sent using `batch_is_authorized_pool` # Each `IsAuthorizedPoolRequest` is similar to calling `is_authorized_pool` for method in methods: - req: IsAuthorizedPoolRequest = { - "method": method, - "details": PoolDetails( - name=pool_name, - team_name=pool_name_to_team.get(pool_name), - ), - } - requests.append(req) + teams_to_check = _collect_teams_for_bulk_entity( + action=action, + entity_team_name=cast("PoolBody", pool).team_name + if action.action != BulkAction.DELETE + else None, + existing_team_name=pool_name_to_team.get(pool_name), + ) + for team_name in teams_to_check: + req: IsAuthorizedPoolRequest = { + "method": method, + "details": PoolDetails( + name=pool_name, + team_name=team_name, + ), + } + requests.append(req) _requires_access( # By calling `batch_is_authorized_pool`, we check the user has access to all pools provided in the request @@ -548,6 +557,7 @@ def inner( for action in request.actions for entity in action.entities if action.action != BulkAction.CREATE + or action.action_on_existence == BulkActionOnExistence.OVERWRITE ] # For each connection, find its associated team (if it exists) conn_id_to_team = Connection.get_conn_id_to_team_name_mapping(existing_connection_ids) @@ -561,18 +571,23 @@ def inner( if action.action == BulkAction.DELETE else cast("ConnectionBody", connection).connection_id ) - # For each pool, build a `IsAuthorizedConnectionRequest` - # The list of `IsAuthorizedConnectionRequest` will then be sent using `batch_is_authorized_connection` - # Each `IsAuthorizedConnectionRequest` is similar to calling `is_authorized_connection` for method in methods: - req: IsAuthorizedConnectionRequest = { - "method": method, - "details": ConnectionDetails( - conn_id=connection_id, - team_name=conn_id_to_team.get(connection_id), - ), - } - requests.append(req) + teams_to_check = _collect_teams_for_bulk_entity( + action=action, + entity_team_name=cast("ConnectionBody", connection).team_name + if action.action != BulkAction.DELETE + else None, + existing_team_name=conn_id_to_team.get(connection_id), + ) + for team_name in teams_to_check: + req: IsAuthorizedConnectionRequest = { + "method": method, + "details": ConnectionDetails( + conn_id=connection_id, + team_name=team_name, + ), + } + requests.append(req) _requires_access( # By calling `batch_is_authorized_connection`, we check the user has access to all connections provided in the request @@ -688,6 +703,7 @@ def inner( for action in request.actions for entity in action.entities if action.action != BulkAction.CREATE + or action.action_on_existence == BulkActionOnExistence.OVERWRITE ] # For each variable, find its associated team (if it exists) var_key_to_team = Variable.get_key_to_team_name_mapping(existing_variable_keys) @@ -701,18 +717,23 @@ def inner( if action.action == BulkAction.DELETE else cast("VariableBody", variable).key ) - # For each variable, build a `IsAuthorizedVariableRequest` - # The list of `IsAuthorizedVariableRequest` will then be sent using `batch_is_authorized_variable` - # Each `IsAuthorizedVariableRequest` is similar to calling `is_authorized_variable` for method in methods: - req: IsAuthorizedVariableRequest = { - "method": method, - "details": VariableDetails( - key=variable_key, - team_name=var_key_to_team.get(variable_key), - ), - } - requests.append(req) + teams_to_check = _collect_teams_for_bulk_entity( + action=action, + entity_team_name=cast("VariableBody", variable).team_name + if action.action != BulkAction.DELETE + else None, + existing_team_name=var_key_to_team.get(variable_key), + ) + for team_name in teams_to_check: + req: IsAuthorizedVariableRequest = { + "method": method, + "details": VariableDetails( + key=variable_key, + team_name=team_name, + ), + } + requests.append(req) _requires_access( # By calling `batch_is_authorized_variable`, we check the user has access to all variables provided in the request @@ -884,3 +905,34 @@ def _get_resource_methods_from_bulk_request( if action.action == BulkAction.CREATE and action.action_on_existence == BulkActionOnExistence.OVERWRITE: resource_methods.append("PUT") return resource_methods + + +def _collect_teams_for_bulk_entity( + action: BulkCreateAction | BulkUpdateAction | BulkDeleteAction, + entity_team_name: str | None, + existing_team_name: str | None, +) -> set[str | None]: + """ + Collect the set of team names to authorize for a single entity in a bulk request. + + For CREATE/UPDATE actions the caller must be authorized for the team_name specified in the + request body (if any). For UPDATE actions (or CREATE with overwrite) on existing resources, + the existing team is also checked so that a user cannot move a resource out of a team they + don't have access to. + """ + if not conf.getboolean("core", "multi_team"): + return {None} + + teams: set[str | None] = set() + + if action.action == BulkAction.DELETE: + teams.add(existing_team_name) + elif action.action == BulkAction.UPDATE: + teams.add(existing_team_name) + teams.add(entity_team_name) + elif action.action == BulkAction.CREATE: + teams.add(entity_team_name) + if action.action_on_existence == BulkActionOnExistence.OVERWRITE: + teams.add(existing_team_name) + + return teams diff --git a/airflow-core/tests/unit/api_fastapi/core_api/test_security.py b/airflow-core/tests/unit/api_fastapi/core_api/test_security.py index ef8a3972909ca..f451a02c24646 100644 --- a/airflow-core/tests/unit/api_fastapi/core_api/test_security.py +++ b/airflow-core/tests/unit/api_fastapi/core_api/test_security.py @@ -811,7 +811,7 @@ def test_requires_access_connection_bulk( auth_manager = Mock() auth_manager.batch_is_authorized_connection.return_value = True mock_get_auth_manager.return_value = auth_manager - mock_get_conn_id_to_team_name_mapping.return_value = {"test1": "team1"} + mock_get_conn_id_to_team_name_mapping.return_value = {} request = BulkBody[ConnectionBody].model_validate( { @@ -844,7 +844,7 @@ def test_requires_access_connection_bulk( requests=[ { "method": "POST", - "details": ConnectionDetails(conn_id="test1", team_name="team1"), + "details": ConnectionDetails(conn_id="test1"), }, { "method": "POST", @@ -866,6 +866,58 @@ def test_requires_access_connection_bulk( user=user, ) + @patch.object(Connection, "get_conn_id_to_team_name_mapping") + @patch("airflow.api_fastapi.core_api.security.get_auth_manager") + def test_requires_access_connection_bulk_multi_team( + self, mock_get_auth_manager, mock_get_conn_id_to_team_name_mapping + ): + """Bulk connection authz checks team_name from both existing resources and request body.""" + auth_manager = Mock() + auth_manager.batch_is_authorized_connection.return_value = True + mock_get_auth_manager.return_value = auth_manager + mock_get_conn_id_to_team_name_mapping.return_value = {"test3": "team1", "test4": "team1"} + + user = Mock() + with conf_vars({("core", "multi_team"): "True"}): + request = BulkBody[ConnectionBody].model_validate( + { + "actions": [ + { + "action": "create", + "entities": [ + {"connection_id": "test1", "conn_type": "test1", "team_name": "team2"}, + {"connection_id": "test2", "conn_type": "test2"}, + ], + }, + { + "action": "delete", + "entities": ["test3"], + }, + { + "action": "create", + "entities": [ + {"connection_id": "test4", "conn_type": "test4", "team_name": "team2"}, + ], + "action_on_existence": "overwrite", + }, + ] + } + ) + requires_access_connection_bulk()(request, user) + + auth_manager.batch_is_authorized_connection.assert_called_once() + actual_requests = auth_manager.batch_is_authorized_connection.call_args.kwargs["requests"] + expected_requests = [ + {"method": "POST", "details": ConnectionDetails(conn_id="test1", team_name="team2")}, + {"method": "POST", "details": ConnectionDetails(conn_id="test2")}, + {"method": "DELETE", "details": ConnectionDetails(conn_id="test3", team_name="team1")}, + {"method": "POST", "details": ConnectionDetails(conn_id="test4", team_name="team1")}, + {"method": "POST", "details": ConnectionDetails(conn_id="test4", team_name="team2")}, + {"method": "PUT", "details": ConnectionDetails(conn_id="test4", team_name="team1")}, + {"method": "PUT", "details": ConnectionDetails(conn_id="test4", team_name="team2")}, + ] + assert sorted(actual_requests, key=str) == sorted(expected_requests, key=str) + @pytest.mark.db_test @pytest.mark.parametrize( "team_name", @@ -982,7 +1034,7 @@ def test_requires_access_variable_bulk(self, mock_get_auth_manager, mock_get_key auth_manager = Mock() auth_manager.batch_is_authorized_variable.return_value = True mock_get_auth_manager.return_value = auth_manager - mock_get_key_to_team_name_mapping.return_value = {"var1": "team1", "dummy": "team2"} + mock_get_key_to_team_name_mapping.return_value = {} request = BulkBody[VariableBody].model_validate( { "actions": [ @@ -1014,7 +1066,7 @@ def test_requires_access_variable_bulk(self, mock_get_auth_manager, mock_get_key requests=[ { "method": "POST", - "details": VariableDetails(key="var1", team_name="team1"), + "details": VariableDetails(key="var1"), }, { "method": "POST", @@ -1036,6 +1088,57 @@ def test_requires_access_variable_bulk(self, mock_get_auth_manager, mock_get_key user=user, ) + @patch.object(Variable, "get_key_to_team_name_mapping") + @patch("airflow.api_fastapi.core_api.security.get_auth_manager") + def test_requires_access_variable_bulk_multi_team( + self, mock_get_auth_manager, mock_get_key_to_team_name_mapping + ): + """Bulk variable authz checks team_name from both existing resources and request body.""" + auth_manager = Mock() + auth_manager.batch_is_authorized_variable.return_value = True + mock_get_auth_manager.return_value = auth_manager + mock_get_key_to_team_name_mapping.return_value = {"var3": "team1", "var4": "team1"} + user = Mock() + with conf_vars({("core", "multi_team"): "True"}): + request = BulkBody[VariableBody].model_validate( + { + "actions": [ + { + "action": "create", + "entities": [ + {"key": "var1", "value": "value1", "team_name": "team2"}, + {"key": "var2", "value": "value2"}, + ], + }, + { + "action": "delete", + "entities": ["var3"], + }, + { + "action": "create", + "entities": [ + {"key": "var4", "value": "value4", "team_name": "team2"}, + ], + "action_on_existence": "overwrite", + }, + ] + } + ) + requires_access_variable_bulk()(request, user) + + auth_manager.batch_is_authorized_variable.assert_called_once() + actual_requests = auth_manager.batch_is_authorized_variable.call_args.kwargs["requests"] + expected_requests = [ + {"method": "POST", "details": VariableDetails(key="var1", team_name="team2")}, + {"method": "POST", "details": VariableDetails(key="var2")}, + {"method": "DELETE", "details": VariableDetails(key="var3", team_name="team1")}, + {"method": "POST", "details": VariableDetails(key="var4", team_name="team1")}, + {"method": "POST", "details": VariableDetails(key="var4", team_name="team2")}, + {"method": "PUT", "details": VariableDetails(key="var4", team_name="team1")}, + {"method": "PUT", "details": VariableDetails(key="var4", team_name="team2")}, + ] + assert sorted(actual_requests, key=str) == sorted(expected_requests, key=str) + @pytest.mark.db_test @pytest.mark.parametrize( "team_name", @@ -1150,7 +1253,7 @@ def test_requires_access_pool_bulk(self, mock_get_auth_manager, mock_get_name_to auth_manager = Mock() auth_manager.batch_is_authorized_pool.return_value = True mock_get_auth_manager.return_value = auth_manager - mock_get_name_to_team_name_mapping.return_value = {"pool1": "team1"} + mock_get_name_to_team_name_mapping.return_value = {} request = BulkBody[PoolBody].model_validate( { "actions": [ @@ -1182,7 +1285,7 @@ def test_requires_access_pool_bulk(self, mock_get_auth_manager, mock_get_name_to requests=[ { "method": "POST", - "details": PoolDetails(name="pool1", team_name="team1"), + "details": PoolDetails(name="pool1"), }, { "method": "POST", @@ -1204,6 +1307,57 @@ def test_requires_access_pool_bulk(self, mock_get_auth_manager, mock_get_name_to user=user, ) + @patch.object(Pool, "get_name_to_team_name_mapping") + @patch("airflow.api_fastapi.core_api.security.get_auth_manager") + def test_requires_access_pool_bulk_multi_team( + self, mock_get_auth_manager, mock_get_name_to_team_name_mapping + ): + """Bulk pool authz checks team_name from both existing resources and request body.""" + auth_manager = Mock() + auth_manager.batch_is_authorized_pool.return_value = True + mock_get_auth_manager.return_value = auth_manager + mock_get_name_to_team_name_mapping.return_value = {"pool3": "team1", "pool4": "team1"} + user = Mock() + with conf_vars({("core", "multi_team"): "True"}): + request = BulkBody[PoolBody].model_validate( + { + "actions": [ + { + "action": "create", + "entities": [ + {"pool": "pool1", "slots": 1, "team_name": "team2"}, + {"pool": "pool2", "slots": 1}, + ], + }, + { + "action": "delete", + "entities": ["pool3"], + }, + { + "action": "create", + "entities": [ + {"pool": "pool4", "slots": 1, "team_name": "team2"}, + ], + "action_on_existence": "overwrite", + }, + ] + } + ) + requires_access_pool_bulk()(request, user) + + auth_manager.batch_is_authorized_pool.assert_called_once() + actual_requests = auth_manager.batch_is_authorized_pool.call_args.kwargs["requests"] + expected_requests = [ + {"method": "POST", "details": PoolDetails(name="pool1", team_name="team2")}, + {"method": "POST", "details": PoolDetails(name="pool2")}, + {"method": "DELETE", "details": PoolDetails(name="pool3", team_name="team1")}, + {"method": "POST", "details": PoolDetails(name="pool4", team_name="team1")}, + {"method": "POST", "details": PoolDetails(name="pool4", team_name="team2")}, + {"method": "PUT", "details": PoolDetails(name="pool4", team_name="team1")}, + {"method": "PUT", "details": PoolDetails(name="pool4", team_name="team2")}, + ] + assert sorted(actual_requests, key=str) == sorted(expected_requests, key=str) + class TestAuthManagerDependency: """Test the auth_manager_from_app dependency function."""