Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -179,12 +179,20 @@ class QueuedEventCollectionResponse(BaseModel):
total_entries: int


class AssetEventAccessControl(StrictBaseModel):
"""Access control settings for asset event consumer team filtering."""

consumer_teams: list[str] | None = None
allow_global: bool = True


class CreateAssetEventsBody(StrictBaseModel):
"""Create asset events request."""

asset_id: int
partition_key: str | None = None
extra: dict = Field(default_factory=dict)
access_control: AssetEventAccessControl | None = None

@field_validator("extra", mode="after")
def set_from_rest_api(cls, v: dict) -> dict:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11434,6 +11434,23 @@ components:
- total_entries
title: AssetCollectionResponse
description: Asset collection response.
AssetEventAccessControl:
properties:
consumer_teams:
anyOf:
- items:
type: string
type: array
- type: 'null'
title: Consumer Teams
allow_global:
type: boolean
title: Allow Global
default: true
additionalProperties: false
type: object
title: AssetEventAccessControl
description: Access control settings for asset event consumer team filtering.
AssetEventCollectionResponse:
properties:
asset_events:
Expand Down Expand Up @@ -12951,6 +12968,10 @@ components:
additionalProperties: true
type: object
title: Extra
access_control:
anyOf:
- $ref: '#/components/schemas/AssetEventAccessControl'
- type: 'null'
additionalProperties: false
type: object
required:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -375,8 +375,13 @@ def create_asset_event(
timestamp = timezone.utcnow()

api_user_teams: set[str] = set()
api_allow_consumer_teams: list[str] | None = None
api_allow_global_consumers: bool = True
if conf.getboolean("core", "multi_team"):
api_user_teams = get_auth_manager().get_authorized_teams(user=user)
if body.access_control:
api_allow_consumer_teams = body.access_control.consumer_teams or None
api_allow_global_consumers = body.access_control.allow_global

assets_event = asset_manager.register_asset_change(
asset=asset_model,
Expand All @@ -385,6 +390,8 @@ def create_asset_event(
partition_key=body.partition_key,
source_is_api=True,
api_user_teams=api_user_teams,
api_allow_consumer_teams=api_allow_consumer_teams,
api_allow_global_consumers=api_allow_global_consumers,
session=session,
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
from airflow.utils.types import DagRunType

from tests_common.test_utils.asserts import assert_queries_count
from tests_common.test_utils.config import conf_vars
from tests_common.test_utils.db import (
clear_db_assets,
clear_db_dag_bundles,
Expand Down Expand Up @@ -1395,8 +1396,6 @@ def test_should_update_assets_endpoint(self, test_client, session):
class TestPostAssetEventsTeamResolution(TestAssets):
"""Tests for team-based filtering in create_asset_event."""

_ROUTE = "airflow.api_fastapi.core_api.routes.public.assets"

def _make_mock_event(self, asset):
m = mock.MagicMock(
spec=AssetEvent,
Expand All @@ -1421,34 +1420,83 @@ def _make_mock_event(self, asset):
@pytest.mark.parametrize(
("multi_team", "expected_teams"),
[
pytest.param(True, {"team_a", "team_b"}, id="enabled"),
pytest.param(False, set(), id="disabled"),
pytest.param("True", {"team_a", "team_b"}, id="enabled"),
pytest.param("False", set(), id="disabled"),
],
)
def test_team_resolution(self, test_client, session, multi_team, expected_teams):
@mock.patch("airflow.api_fastapi.core_api.routes.public.assets.asset_manager.register_asset_change")
@mock.patch("airflow.api_fastapi.core_api.routes.public.assets.get_auth_manager")
def test_team_resolution(
self, mock_get_auth_manager, mock_register, test_client, session, multi_team, expected_teams
):
(asset,) = self.create_assets(num=1, session=session)
mock_auth_mgr = mock.MagicMock()
mock_auth_mgr.get_authorized_teams.return_value = {"team_a", "team_b"}
mock_get_auth_manager.return_value.get_authorized_teams.return_value = {"team_a", "team_b"}
mock_register.return_value = self._make_mock_event(asset)

with (
mock.patch(
f"{self._ROUTE}.conf.getboolean",
side_effect=lambda s, k, **kw: multi_team if k == "multi_team" else kw.get("fallback"),
),
mock.patch(f"{self._ROUTE}.get_auth_manager", return_value=mock_auth_mgr),
mock.patch(
f"{self._ROUTE}.asset_manager.register_asset_change",
spec=True,
return_value=self._make_mock_event(asset),
) as mock_register,
):
with conf_vars({("core", "multi_team"): multi_team}):
response = test_client.post("/assets/events", json={"asset_id": asset.id, "extra": {}})

assert response.status_code == 200
call_kwargs = mock_register.call_args.kwargs
assert call_kwargs["source_is_api"] is True
assert call_kwargs["api_user_teams"] == expected_teams

@pytest.mark.usefixtures("time_freezer")
@pytest.mark.parametrize(
("multi_team", "access_control", "expected_consumer_teams", "expected_allow_global"),
[
pytest.param(
"True",
{"consumer_teams": ["team_ml", "team_data"], "allow_global": False},
["team_ml", "team_data"],
False,
id="multi_team_enabled_with_consumer_teams",
),
pytest.param(
"True",
None,
None,
True,
id="multi_team_enabled_no_access_control",
),
pytest.param(
"False",
{"consumer_teams": ["team_ml"], "allow_global": False},
None,
True,
id="multi_team_disabled_access_control_ignored",
),
],
)
@mock.patch("airflow.api_fastapi.core_api.routes.public.assets.asset_manager.register_asset_change")
@mock.patch("airflow.api_fastapi.core_api.routes.public.assets.get_auth_manager")
def test_access_control_consumer_teams(
self,
mock_get_auth_manager,
mock_register,
test_client,
session,
multi_team,
access_control,
expected_consumer_teams,
expected_allow_global,
):
(asset,) = self.create_assets(num=1, session=session)
mock_get_auth_manager.return_value.get_authorized_teams.return_value = {"team_a"}
mock_register.return_value = self._make_mock_event(asset)

payload = {"asset_id": asset.id, "extra": {}}
if access_control is not None:
payload["access_control"] = access_control

with conf_vars({("core", "multi_team"): multi_team}):
response = test_client.post("/assets/events", json=payload)

assert response.status_code == 200
call_kwargs = mock_register.call_args.kwargs
assert call_kwargs["api_allow_consumer_teams"] == expected_consumer_teams
assert call_kwargs["api_allow_global_consumers"] == expected_allow_global


@pytest.mark.need_serialized_dag
class TestPostAssetMaterialize(TestAssets):
Expand Down
13 changes: 13 additions & 0 deletions airflow-ctl/src/airflowctl/api/datamodels/generated.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,18 @@ class AssetAliasResponse(BaseModel):
group: Annotated[str, Field(title="Group")]


class AssetEventAccessControl(BaseModel):
"""
Access control settings for asset event consumer team filtering.
"""

model_config = ConfigDict(
extra="forbid",
)
consumer_teams: Annotated[list[str] | None, Field(title="Consumer Teams")] = None
allow_global: Annotated[bool | None, Field(title="Allow Global")] = True


class AssetStoreWriterKind(str, Enum):
"""
Identifies what kind of writer last updated an asset store entry.
Expand Down Expand Up @@ -415,6 +427,7 @@ class CreateAssetEventsBody(BaseModel):
asset_id: Annotated[int, Field(title="Asset Id")]
partition_key: Annotated[str | None, Field(title="Partition Key")] = None
extra: Annotated[dict[str, Any] | None, Field(title="Extra")] = None
access_control: AssetEventAccessControl | None = None


class DAGPatchBody(BaseModel):
Expand Down
Loading