Skip to content

Commit

Permalink
Add check in AWS auth manager to check if the Amazon Verified Permiss…
Browse files Browse the repository at this point in the history
…ions schema is up to date (#38333)
  • Loading branch information
vincbeck committed Mar 21, 2024
1 parent bce63b2 commit ea951af
Show file tree
Hide file tree
Showing 9 changed files with 142 additions and 112 deletions.
15 changes: 15 additions & 0 deletions airflow/providers/amazon/aws/auth_manager/avp/facade.py
Expand Up @@ -16,7 +16,9 @@
# under the License.
from __future__ import annotations

import json
from functools import cached_property
from pathlib import Path
from typing import TYPE_CHECKING, Sequence, TypedDict

from airflow.configuration import conf
Expand Down Expand Up @@ -222,6 +224,19 @@ def get_batch_is_authorized_single_result(
)
raise AirflowException("Could not find the authorization result.")

def is_policy_store_schema_up_to_date(self) -> bool:
"""Return whether the policy store schema equals the latest version of the schema."""
resp = self.avp_client.get_schema(
policyStoreId=self.avp_policy_store_id,
)
policy_store_schema = json.loads(resp["schema"])

schema_path = Path(__file__).parents[0] / "schema.json"
with open(schema_path) as schema_file:
latest_schema = json.load(schema_file)

return policy_store_schema == latest_schema

@staticmethod
def _get_user_group_entities(user: AwsAuthManagerUser) -> list[dict]:
user_entity = {
Expand Down
10 changes: 10 additions & 0 deletions airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py
Expand Up @@ -88,6 +88,7 @@ class AwsAuthManager(BaseAuthManager):
def __init__(self, appbuilder: AirflowAppBuilder) -> None:
super().__init__(appbuilder)
enable = conf.getboolean(CONF_SECTION_NAME, CONF_ENABLE_KEY)
self._check_avp_schema_version()
if not enable:
raise NotImplementedError(
"The AWS auth manager is currently being built. It is not finalized. It is not intended to be used yet."
Expand Down Expand Up @@ -430,6 +431,15 @@ def _get_menu_item_request(resource_name: str) -> IsAuthorizedRequest:
"entity_id": resource_name,
}

def _check_avp_schema_version(self):
if not self.avp_facade.is_policy_store_schema_up_to_date():
self.log.warning(
"The Amazon Verified Permissions policy store schema is different from the latest version "
"(https://github.com/apache/airflow/blob/main/airflow/providers/amazon/aws/auth_manager/avp/schema.json). "
"Please update it to its latest version. "
"See doc: https://airflow.apache.org/docs/apache-airflow-providers-amazon/stable/auth-manager/setup/amazon-verified-permissions.html#update-the-policy-store-schema."
)


def get_parser() -> argparse.ArgumentParser:
"""Generate documentation; used by Sphinx argparse."""
Expand Down
Expand Up @@ -138,7 +138,7 @@ def _set_schema(client: BaseClient, policy_store_id: str, args) -> None:
print(f"Dry run, not updating the schema of the policy store with ID '{policy_store_id}'.")
return

schema_path = Path(__file__).parents[0].joinpath("schema.json").resolve()
schema_path = Path(__file__).parents[1] / "avp" / "schema.json"
with open(schema_path) as schema_file:
response = client.put_schema(
policyStoreId=policy_store_id,
Expand Down
Expand Up @@ -69,7 +69,7 @@
),
ActionCommand(
name="update-avp-schema",
help="Update Amazon Verified permissions policy store schema to the latest version in 'airflow/providers/amazon/aws/auth_manager/cli/schema.json'",
help="Update Amazon Verified permissions policy store schema to the latest version in 'airflow/providers/amazon/aws/auth_manager/avp/schema.json'",
func=lazy_load_command("airflow.providers.amazon.aws.auth_manager.cli.avp_commands.update_schema"),
args=(ARG_POLICY_STORE_ID, ARG_DRY_RUN, ARG_VERBOSE),
),
Expand Down
197 changes: 94 additions & 103 deletions tests/providers/amazon/aws/auth_manager/avp/test_facade.py
Expand Up @@ -16,6 +16,8 @@
# under the License.
from __future__ import annotations

import json
from pathlib import Path
from typing import TYPE_CHECKING
from unittest.mock import Mock

Expand Down Expand Up @@ -43,6 +45,7 @@ def facade():
with conf_vars(
{
("aws_auth_manager", "region_name"): REGION_NAME,
("aws_auth_manager", "avp_policy_store_id"): AVP_POLICY_STORE_ID,
}
):
yield AwsAuthManagerAmazonVerifiedPermissionsFacade()
Expand All @@ -53,27 +56,17 @@ def test_avp_client(self, facade):
assert hasattr(facade, "avp_client")

def test_avp_policy_store_id(self, facade):
with conf_vars(
{
("aws_auth_manager", "avp_policy_store_id"): AVP_POLICY_STORE_ID,
}
):
assert hasattr(facade, "avp_policy_store_id")
assert hasattr(facade, "avp_policy_store_id")

def test_is_authorized_no_user(self, facade):
method: ResourceMethod = "GET"
entity_type = AvpEntities.VARIABLE

with conf_vars(
{
("aws_auth_manager", "avp_policy_store_id"): AVP_POLICY_STORE_ID,
}
):
result = facade.is_authorized(
method=method,
entity_type=entity_type,
user=None,
)
result = facade.is_authorized(
method=method,
entity_type=entity_type,
user=None,
)

assert result is False

Expand Down Expand Up @@ -178,18 +171,13 @@ def test_is_authorized_successful(
method: ResourceMethod = "GET"
entity_type = AvpEntities.VARIABLE

with conf_vars(
{
("aws_auth_manager", "avp_policy_store_id"): AVP_POLICY_STORE_ID,
}
):
result = facade.is_authorized(
method=method,
entity_type=entity_type,
entity_id=entity_id,
user=user,
context=context,
)
result = facade.is_authorized(
method=method,
entity_type=entity_type,
entity_id=entity_id,
user=user,
context=context,
)

params = prune_dict(
{
Expand All @@ -211,15 +199,8 @@ def test_is_authorized_unsuccessful(self, facade):
mock_is_authorized = Mock(return_value=avp_response)
facade.avp_client.is_authorized = mock_is_authorized

with conf_vars(
{
("aws_auth_manager", "avp_policy_store_id"): AVP_POLICY_STORE_ID,
}
):
with pytest.raises(
AirflowException, match="Error occurred while making an authorization decision."
):
facade.is_authorized(method="GET", entity_type=AvpEntities.VARIABLE, user=test_user)
with pytest.raises(AirflowException, match="Error occurred while making an authorization decision."):
facade.is_authorized(method="GET", entity_type=AvpEntities.VARIABLE, user=test_user)

@pytest.mark.parametrize(
"user, avp_response, expected",
Expand All @@ -245,18 +226,13 @@ def test_batch_is_authorized_successful(self, facade, user, avp_response, expect
mock_batch_is_authorized = Mock(return_value=avp_response)
facade.avp_client.batch_is_authorized = mock_batch_is_authorized

with conf_vars(
{
("aws_auth_manager", "avp_policy_store_id"): AVP_POLICY_STORE_ID,
}
):
result = facade.batch_is_authorized(
requests=[
{"method": "GET", "entity_type": AvpEntities.VARIABLE, "entity_id": "var1"},
{"method": "GET", "entity_type": AvpEntities.VARIABLE, "entity_id": "var1"},
],
user=user,
)
result = facade.batch_is_authorized(
requests=[
{"method": "GET", "entity_type": AvpEntities.VARIABLE, "entity_id": "var1"},
{"method": "GET", "entity_type": AvpEntities.VARIABLE, "entity_id": "var1"},
],
user=user,
)

assert result == expected

Expand All @@ -265,21 +241,16 @@ def test_batch_is_authorized_unsuccessful(self, facade):
mock_batch_is_authorized = Mock(return_value=avp_response)
facade.avp_client.batch_is_authorized = mock_batch_is_authorized

with conf_vars(
{
("aws_auth_manager", "avp_policy_store_id"): AVP_POLICY_STORE_ID,
}
with pytest.raises(
AirflowException, match="Error occurred while making a batch authorization decision."
):
with pytest.raises(
AirflowException, match="Error occurred while making a batch authorization decision."
):
facade.batch_is_authorized(
requests=[
{"method": "GET", "entity_type": AvpEntities.VARIABLE, "entity_id": "var1"},
{"method": "GET", "entity_type": AvpEntities.VARIABLE, "entity_id": "var1"},
],
user=test_user,
)
facade.batch_is_authorized(
requests=[
{"method": "GET", "entity_type": AvpEntities.VARIABLE, "entity_id": "var1"},
{"method": "GET", "entity_type": AvpEntities.VARIABLE, "entity_id": "var1"},
],
user=test_user,
)

def test_get_batch_is_authorized_single_result_successful(self, facade):
single_result = {
Expand All @@ -291,12 +262,30 @@ def test_get_batch_is_authorized_single_result_successful(self, facade):
"decision": "ALLOW",
}

with conf_vars(
{
("aws_auth_manager", "avp_policy_store_id"): AVP_POLICY_STORE_ID,
}
):
result = facade.get_batch_is_authorized_single_result(
result = facade.get_batch_is_authorized_single_result(
batch_is_authorized_results=[
{
"request": {
"principal": {"entityType": "Airflow::User", "entityId": "test_user"},
"action": {"actionType": "Airflow::Action", "actionId": "Variable.GET"},
"resource": {"entityType": "Airflow::Variable", "entityId": "*"},
},
"decision": "ALLOW",
},
single_result,
],
request={
"method": "GET",
"entity_type": AvpEntities.CONNECTION,
},
user=test_user,
)

assert result == single_result

def test_get_batch_is_authorized_single_result_unsuccessful(self, facade):
with pytest.raises(AirflowException, match="Could not find the authorization result."):
facade.get_batch_is_authorized_single_result(
batch_is_authorized_results=[
{
"request": {
Expand All @@ -306,7 +295,14 @@ def test_get_batch_is_authorized_single_result_successful(self, facade):
},
"decision": "ALLOW",
},
single_result,
{
"request": {
"principal": {"entityType": "Airflow::User", "entityId": "test_user"},
"action": {"actionType": "Airflow::Action", "actionId": "Variable.POST"},
"resource": {"entityType": "Airflow::Variable", "entityId": "*"},
},
"decision": "ALLOW",
},
],
request={
"method": "GET",
Expand All @@ -315,37 +311,32 @@ def test_get_batch_is_authorized_single_result_successful(self, facade):
user=test_user,
)

assert result == single_result
def test_is_policy_store_schema_up_to_date_when_schema_up_to_date(self, facade):
schema_path = (
Path(__file__)
.parents[6]
.joinpath("airflow", "providers", "amazon", "aws", "auth_manager", "avp", "schema.json")
.resolve()
)
with open(schema_path) as schema_file:
avp_response = {"schema": schema_file.read()}
mock_get_schema = Mock(return_value=avp_response)
facade.avp_client.get_schema = mock_get_schema

def test_get_batch_is_authorized_single_result_unsuccessful(self, facade):
with conf_vars(
{
("aws_auth_manager", "avp_policy_store_id"): AVP_POLICY_STORE_ID,
}
):
with pytest.raises(AirflowException, match="Could not find the authorization result."):
facade.get_batch_is_authorized_single_result(
batch_is_authorized_results=[
{
"request": {
"principal": {"entityType": "Airflow::User", "entityId": "test_user"},
"action": {"actionType": "Airflow::Action", "actionId": "Variable.GET"},
"resource": {"entityType": "Airflow::Variable", "entityId": "*"},
},
"decision": "ALLOW",
},
{
"request": {
"principal": {"entityType": "Airflow::User", "entityId": "test_user"},
"action": {"actionType": "Airflow::Action", "actionId": "Variable.POST"},
"resource": {"entityType": "Airflow::Variable", "entityId": "*"},
},
"decision": "ALLOW",
},
],
request={
"method": "GET",
"entity_type": AvpEntities.CONNECTION,
},
user=test_user,
)
assert facade.is_policy_store_schema_up_to_date()

def test_is_policy_store_schema_up_to_date_when_schema_is_modified(self, facade):
schema_path = (
Path(__file__)
.parents[6]
.joinpath("airflow", "providers", "amazon", "aws", "auth_manager", "avp", "schema.json")
.resolve()
)
with open(schema_path) as schema_file:
schema = json.loads(schema_file.read())
schema["new_field"] = "new_value"
avp_response = {"schema": json.dumps(schema)}
mock_get_schema = Mock(return_value=avp_response)
facade.avp_client.get_schema = mock_get_schema

assert not facade.is_policy_store_schema_up_to_date()
11 changes: 8 additions & 3 deletions tests/providers/amazon/aws/auth_manager/test_aws_auth_manager.py
Expand Up @@ -86,7 +86,8 @@ def auth_manager():
("aws_auth_manager", "enable"): "True",
}
):
return AwsAuthManager(None)
with patch.object(AwsAuthManager, "_check_avp_schema_version"):
return AwsAuthManager(None)


@pytest.fixture
Expand All @@ -102,7 +103,8 @@ def auth_manager_with_appbuilder():
("aws_auth_manager", "enable"): "True",
}
):
return AwsAuthManager(appbuilder)
with patch.object(AwsAuthManager, "_check_avp_schema_version"):
return AwsAuthManager(appbuilder)


@pytest.fixture
Expand All @@ -128,8 +130,11 @@ def client_admin():
"airflow.providers.amazon.aws.auth_manager.views.auth.OneLogin_Saml2_IdPMetadataParser"
) as mock_parser, patch(
"airflow.providers.amazon.aws.auth_manager.views.auth.AwsAuthManagerAuthenticationViews._init_saml_auth"
) as mock_init_saml_auth:
) as mock_init_saml_auth, patch(
"airflow.providers.amazon.aws.auth_manager.avp.facade.AwsAuthManagerAmazonVerifiedPermissionsFacade.is_policy_store_schema_up_to_date"
) as mock_is_policy_store_schema_up_to_date:
mock_parser.parse_remote.return_value = SAML_METADATA_PARSED
mock_is_policy_store_schema_up_to_date.return_value = True

auth = Mock()
auth.is_authenticated.return_value = True
Expand Down

0 comments on commit ea951af

Please sign in to comment.