diff --git a/airflow-core/src/airflow/api_fastapi/core_api/datamodels/assets.py b/airflow-core/src/airflow/api_fastapi/core_api/datamodels/assets.py index f463f3ba5e3e5..753ea94f1d9a2 100644 --- a/airflow-core/src/airflow/api_fastapi/core_api/datamodels/assets.py +++ b/airflow-core/src/airflow/api_fastapi/core_api/datamodels/assets.py @@ -179,12 +179,20 @@ class QueuedEventCollectionResponse(BaseModel): total_entries: int +class AssetEventAccessControl(StrictBaseModel): + """Access control settings for asset event consumer team filtering.""" + + consumer_teams: list[str] | None = None + allow_global: bool = True + + class CreateAssetEventsBody(StrictBaseModel): """Create asset events request.""" asset_id: int partition_key: str | None = None extra: dict = Field(default_factory=dict) + access_control: AssetEventAccessControl | None = None @field_validator("extra", mode="after") def set_from_rest_api(cls, v: dict) -> dict: diff --git a/airflow-core/src/airflow/api_fastapi/core_api/openapi/v2-rest-api-generated.yaml b/airflow-core/src/airflow/api_fastapi/core_api/openapi/v2-rest-api-generated.yaml index 6c0ec00832237..42fc01244b31d 100644 --- a/airflow-core/src/airflow/api_fastapi/core_api/openapi/v2-rest-api-generated.yaml +++ b/airflow-core/src/airflow/api_fastapi/core_api/openapi/v2-rest-api-generated.yaml @@ -11434,6 +11434,23 @@ components: - total_entries title: AssetCollectionResponse description: Asset collection response. + AssetEventAccessControl: + properties: + consumer_teams: + anyOf: + - items: + type: string + type: array + - type: 'null' + title: Consumer Teams + allow_global: + type: boolean + title: Allow Global + default: true + additionalProperties: false + type: object + title: AssetEventAccessControl + description: Access control settings for asset event consumer team filtering. AssetEventCollectionResponse: properties: asset_events: @@ -12951,6 +12968,10 @@ components: additionalProperties: true type: object title: Extra + access_control: + anyOf: + - $ref: '#/components/schemas/AssetEventAccessControl' + - type: 'null' additionalProperties: false type: object required: diff --git a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/assets.py b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/assets.py index 5e3afa8b3c00a..9525b1f796dbd 100644 --- a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/assets.py +++ b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/assets.py @@ -375,8 +375,13 @@ def create_asset_event( timestamp = timezone.utcnow() api_user_teams: set[str] = set() + api_allow_consumer_teams: list[str] | None = None + api_allow_global_consumers: bool = True if conf.getboolean("core", "multi_team"): api_user_teams = get_auth_manager().get_authorized_teams(user=user) + if body.access_control: + api_allow_consumer_teams = body.access_control.consumer_teams or None + api_allow_global_consumers = body.access_control.allow_global assets_event = asset_manager.register_asset_change( asset=asset_model, @@ -385,6 +390,8 @@ def create_asset_event( partition_key=body.partition_key, source_is_api=True, api_user_teams=api_user_teams, + api_allow_consumer_teams=api_allow_consumer_teams, + api_allow_global_consumers=api_allow_global_consumers, session=session, ) diff --git a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_assets.py b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_assets.py index 6de32ee53c263..7222444103143 100644 --- a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_assets.py +++ b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_assets.py @@ -47,6 +47,7 @@ from airflow.utils.types import DagRunType from tests_common.test_utils.asserts import assert_queries_count +from tests_common.test_utils.config import conf_vars from tests_common.test_utils.db import ( clear_db_assets, clear_db_dag_bundles, @@ -1395,8 +1396,6 @@ def test_should_update_assets_endpoint(self, test_client, session): class TestPostAssetEventsTeamResolution(TestAssets): """Tests for team-based filtering in create_asset_event.""" - _ROUTE = "airflow.api_fastapi.core_api.routes.public.assets" - def _make_mock_event(self, asset): m = mock.MagicMock( spec=AssetEvent, @@ -1421,27 +1420,20 @@ def _make_mock_event(self, asset): @pytest.mark.parametrize( ("multi_team", "expected_teams"), [ - pytest.param(True, {"team_a", "team_b"}, id="enabled"), - pytest.param(False, set(), id="disabled"), + pytest.param("True", {"team_a", "team_b"}, id="enabled"), + pytest.param("False", set(), id="disabled"), ], ) - def test_team_resolution(self, test_client, session, multi_team, expected_teams): + @mock.patch("airflow.api_fastapi.core_api.routes.public.assets.asset_manager.register_asset_change") + @mock.patch("airflow.api_fastapi.core_api.routes.public.assets.get_auth_manager") + def test_team_resolution( + self, mock_get_auth_manager, mock_register, test_client, session, multi_team, expected_teams + ): (asset,) = self.create_assets(num=1, session=session) - mock_auth_mgr = mock.MagicMock() - mock_auth_mgr.get_authorized_teams.return_value = {"team_a", "team_b"} + mock_get_auth_manager.return_value.get_authorized_teams.return_value = {"team_a", "team_b"} + mock_register.return_value = self._make_mock_event(asset) - with ( - mock.patch( - f"{self._ROUTE}.conf.getboolean", - side_effect=lambda s, k, **kw: multi_team if k == "multi_team" else kw.get("fallback"), - ), - mock.patch(f"{self._ROUTE}.get_auth_manager", return_value=mock_auth_mgr), - mock.patch( - f"{self._ROUTE}.asset_manager.register_asset_change", - spec=True, - return_value=self._make_mock_event(asset), - ) as mock_register, - ): + with conf_vars({("core", "multi_team"): multi_team}): response = test_client.post("/assets/events", json={"asset_id": asset.id, "extra": {}}) assert response.status_code == 200 @@ -1449,6 +1441,62 @@ def test_team_resolution(self, test_client, session, multi_team, expected_teams) assert call_kwargs["source_is_api"] is True assert call_kwargs["api_user_teams"] == expected_teams + @pytest.mark.usefixtures("time_freezer") + @pytest.mark.parametrize( + ("multi_team", "access_control", "expected_consumer_teams", "expected_allow_global"), + [ + pytest.param( + "True", + {"consumer_teams": ["team_ml", "team_data"], "allow_global": False}, + ["team_ml", "team_data"], + False, + id="multi_team_enabled_with_consumer_teams", + ), + pytest.param( + "True", + None, + None, + True, + id="multi_team_enabled_no_access_control", + ), + pytest.param( + "False", + {"consumer_teams": ["team_ml"], "allow_global": False}, + None, + True, + id="multi_team_disabled_access_control_ignored", + ), + ], + ) + @mock.patch("airflow.api_fastapi.core_api.routes.public.assets.asset_manager.register_asset_change") + @mock.patch("airflow.api_fastapi.core_api.routes.public.assets.get_auth_manager") + def test_access_control_consumer_teams( + self, + mock_get_auth_manager, + mock_register, + test_client, + session, + multi_team, + access_control, + expected_consumer_teams, + expected_allow_global, + ): + (asset,) = self.create_assets(num=1, session=session) + mock_get_auth_manager.return_value.get_authorized_teams.return_value = {"team_a"} + mock_register.return_value = self._make_mock_event(asset) + + payload = {"asset_id": asset.id, "extra": {}} + if access_control is not None: + payload["access_control"] = access_control + + with conf_vars({("core", "multi_team"): multi_team}): + response = test_client.post("/assets/events", json=payload) + + assert response.status_code == 200 + call_kwargs = mock_register.call_args.kwargs + assert call_kwargs["api_allow_consumer_teams"] == expected_consumer_teams + assert call_kwargs["api_allow_global_consumers"] == expected_allow_global + @pytest.mark.need_serialized_dag class TestPostAssetMaterialize(TestAssets): diff --git a/airflow-ctl/src/airflowctl/api/datamodels/generated.py b/airflow-ctl/src/airflowctl/api/datamodels/generated.py index fff609bcaea17..b438fc57a5fe5 100644 --- a/airflow-ctl/src/airflowctl/api/datamodels/generated.py +++ b/airflow-ctl/src/airflowctl/api/datamodels/generated.py @@ -49,6 +49,18 @@ class AssetAliasResponse(BaseModel): group: Annotated[str, Field(title="Group")] +class AssetEventAccessControl(BaseModel): + """ + Access control settings for asset event consumer team filtering. + """ + + model_config = ConfigDict( + extra="forbid", + ) + consumer_teams: Annotated[list[str] | None, Field(title="Consumer Teams")] = None + allow_global: Annotated[bool | None, Field(title="Allow Global")] = True + + class AssetStoreWriterKind(str, Enum): """ Identifies what kind of writer last updated an asset store entry. @@ -415,6 +427,7 @@ class CreateAssetEventsBody(BaseModel): asset_id: Annotated[int, Field(title="Asset Id")] partition_key: Annotated[str | None, Field(title="Partition Key")] = None extra: Annotated[dict[str, Any] | None, Field(title="Extra")] = None + access_control: AssetEventAccessControl | None = None class DAGPatchBody(BaseModel):