Skip to content

Commit

Permalink
Add back system test for AWS auth manager (#38044)
Browse files Browse the repository at this point in the history
  • Loading branch information
vincbeck committed Mar 12, 2024
1 parent 8fc9848 commit a192751
Show file tree
Hide file tree
Showing 6 changed files with 255 additions and 26 deletions.
2 changes: 1 addition & 1 deletion airflow/providers/amazon/aws/auth_manager/views/auth.py
Expand Up @@ -93,7 +93,7 @@ def login_callback(self):
user_id=attributes["id"][0],
groups=attributes["groups"],
username=saml_auth.get_nameid(),
email=attributes["email"][0],
email=attributes["email"][0] if "email" in attributes else None,
)
session["aws_user"] = user

Expand Down
12 changes: 8 additions & 4 deletions tests/conftest.py
Expand Up @@ -1092,11 +1092,15 @@ def refuse_to_run_test_from_wrongly_named_files(request):
dirname: str = request.node.fspath.dirname
filename: str = request.node.fspath.basename
is_system_test: bool = "tests/system/" in dirname
if is_system_test and not request.node.fspath.basename.startswith("example_"):
if is_system_test and not (
request.node.fspath.basename.startswith("example_")
or request.node.fspath.basename.startswith("test_")
):
raise Exception(
f"All test method files in tests/system must start with 'example_'. Seems that {filename} "
f"contains {request.function} that looks like a test case. Please rename the file to "
f"follow the example_* pattern if you want to run the tests in it."
f"All test method files in tests/system must start with 'example_' or 'test_'. "
f"Seems that {filename} contains {request.function} that looks like a test case. "
f"Please rename the file to follow the example_* or test_* pattern if you want to run the tests "
f"in it."
)
if not is_system_test and not request.node.fspath.basename.startswith("test_"):
raise Exception(
Expand Down
33 changes: 15 additions & 18 deletions tests/providers/amazon/aws/auth_manager/views/test_auth.py
Expand Up @@ -48,24 +48,21 @@

@pytest.fixture
def aws_app():
def factory():
with conf_vars(
{
(
"core",
"auth_manager",
): "airflow.providers.amazon.aws.auth_manager.aws_auth_manager.AwsAuthManager",
("aws_auth_manager", "enable"): "True",
("aws_auth_manager", "saml_metadata_url"): SAML_METADATA_URL,
}
):
with patch(
"airflow.providers.amazon.aws.auth_manager.views.auth.OneLogin_Saml2_IdPMetadataParser"
) as mock_parser:
mock_parser.parse_remote.return_value = SAML_METADATA_PARSED
return application.create_app(testing=True)

return factory()
with conf_vars(
{
(
"core",
"auth_manager",
): "airflow.providers.amazon.aws.auth_manager.aws_auth_manager.AwsAuthManager",
("aws_auth_manager", "enable"): "True",
("aws_auth_manager", "saml_metadata_url"): SAML_METADATA_URL,
}
):
with patch(
"airflow.providers.amazon.aws.auth_manager.views.auth.OneLogin_Saml2_IdPMetadataParser"
) as mock_parser:
mock_parser.parse_remote.return_value = SAML_METADATA_PARSED
return application.create_app(testing=True)


@pytest.mark.db_test
Expand Down
16 changes: 16 additions & 0 deletions tests/system/providers/amazon/aws/tests/__init__.py
@@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
210 changes: 210 additions & 0 deletions tests/system/providers/amazon/aws/tests/test_aws_auth_manager.py
@@ -0,0 +1,210 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

from pathlib import Path
from unittest.mock import Mock, patch

import boto3
import pytest

from airflow.www import app as application
from tests.system.providers.amazon.aws.utils import set_env_id
from tests.test_utils.config import conf_vars
from tests.test_utils.www import check_content_in_response

pytest.importorskip("onelogin")

SAML_METADATA_URL = "/saml/metadata"
SAML_METADATA_PARSED = {
"idp": {
"entityId": "https://portal.sso.us-east-1.amazonaws.com/saml/assertion/<assertion>",
"singleSignOnService": {
"url": "https://portal.sso.us-east-1.amazonaws.com/saml/assertion/<assertion>",
"binding": "urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect",
},
"singleLogoutService": {
"url": "https://portal.sso.us-east-1.amazonaws.com/saml/logout/<assertion>",
"binding": "urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect",
},
"x509cert": "<cert>",
},
"security": {"authnRequestsSigned": False},
"sp": {"NameIDFormat": "urn:oasis:names:tc:SAML:2.0:nameid-format:transient"},
}

AVP_POLICY_ADMIN = """
permit (
principal in Airflow::Role::"Admin",
action,
resource
);
"""

env_id_cache: str | None = None
policy_store_id_cache: str | None = None


def create_avp_policy_store(env_id):
description = f"Created by system test TestAwsAuthManager: {env_id}"
client = boto3.client("verifiedpermissions")
response = client.create_policy_store(
validationSettings={"mode": "OFF"},
description=description,
)
policy_store_id = response["policyStoreId"]

schema_path = (
Path(__file__)
.parents[6]
.joinpath("airflow", "providers", "amazon", "aws", "auth_manager", "cli", "schema.json")
.resolve()
)
with open(schema_path) as schema_file:
client.put_schema(
policyStoreId=policy_store_id,
definition={
"cedarJson": schema_file.read(),
},
)

client.update_policy_store(
policyStoreId=policy_store_id,
validationSettings={
"mode": "STRICT",
},
description=description,
)

client.create_policy(
policyStoreId=policy_store_id,
definition={
"static": {"description": "Admin permissions", "statement": AVP_POLICY_ADMIN},
},
)

return policy_store_id


@pytest.fixture
def env_id():
global env_id_cache
if not env_id_cache:
env_id_cache = set_env_id()
return env_id_cache


@pytest.fixture
def region_name():
return boto3.session.Session().region_name


@pytest.fixture
def avp_policy_store_id(env_id):
global policy_store_id_cache
if not policy_store_id_cache:
policy_store_id_cache = create_avp_policy_store(env_id)
return policy_store_id_cache


@pytest.fixture
def base_app(region_name, avp_policy_store_id):
with conf_vars(
{
(
"core",
"auth_manager",
): "airflow.providers.amazon.aws.auth_manager.aws_auth_manager.AwsAuthManager",
("aws_auth_manager", "enable"): "True",
("aws_auth_manager", "region_name"): region_name,
("aws_auth_manager", "saml_metadata_url"): SAML_METADATA_URL,
("aws_auth_manager", "avp_policy_store_id"): avp_policy_store_id,
}
):
with patch(
"airflow.providers.amazon.aws.auth_manager.views.auth.OneLogin_Saml2_IdPMetadataParser"
) as mock_parser, patch(
"airflow.providers.amazon.aws.auth_manager.views.auth.AwsAuthManagerAuthenticationViews._init_saml_auth"
) as mock_init_saml_auth:
mock_parser.parse_remote.return_value = SAML_METADATA_PARSED

yield mock_init_saml_auth


@pytest.fixture
def client_no_permissions(base_app):
auth = Mock()
auth.is_authenticated.return_value = True
auth.get_nameid.return_value = "user_no_permissions"
auth.get_attributes.return_value = {
"id": ["user_no_permissions"],
"groups": [],
"email": ["email"],
}
base_app.return_value = auth
return application.create_app(testing=True)


@pytest.fixture
def client_admin_permissions(base_app):
auth = Mock()
auth.is_authenticated.return_value = True
auth.get_nameid.return_value = "user_admin_permissions"
auth.get_attributes.return_value = {
"id": ["user_admin_permissions"],
"groups": ["Admin"],
}
base_app.return_value = auth
return application.create_app(testing=True)


@pytest.mark.system("amazon")
class TestAwsAuthManager:
"""
Run tests on Airflow using AWS auth manager with real credentials
"""

@classmethod
def teardown_class(cls):
cls.delete_avp_policy_store()

@classmethod
def delete_avp_policy_store(cls):
client = boto3.client("verifiedpermissions")

paginator = client.get_paginator("list_policy_stores")
pages = paginator.paginate()
policy_store_ids = [
store["policyStoreId"]
for page in pages
for store in page["policyStores"]
if "description" in store
and f"Created by system test TestAwsAuthManager: {env_id_cache}" in store["description"]
]

for policy_store_id in policy_store_ids:
client.delete_policy_store(policyStoreId=policy_store_id)

def test_login_no_permissions(self, client_no_permissions):
with client_no_permissions.test_client() as client:
response = client.get("/login_callback", follow_redirects=True)
check_content_in_response("Your user has no roles and/or permissions!", response, 403)

def test_login_admin(self, client_admin_permissions):
with client_admin_permissions.test_client() as client:
response = client.get("/login_callback", follow_redirects=True)
check_content_in_response("<h2>DAGs</h2>", response, 200)
8 changes: 5 additions & 3 deletions tests/system/providers/amazon/aws/utils/__init__.py
Expand Up @@ -43,8 +43,8 @@
DEFAULT_ENV_ID: str = f"{DEFAULT_ENV_ID_PREFIX}{uuid4()!s:.{DEFAULT_ENV_ID_LEN}}"
PURGE_LOGS_INTERVAL_PERIOD = 5

# All test file names will contain this string.
TEST_FILE_IDENTIFIER: str = "example"
# All test file names will contain one of these strings.
TEST_FILE_IDENTIFIERS: list[str] = ["example_", "test_"]

INVALID_ENV_ID_MSG: str = (
"In order to maximize compatibility, the SYSTEM_TESTS_ENV_ID must be an alphanumeric string "
Expand All @@ -68,7 +68,9 @@ def _get_test_name() -> str:
# The exact layer of the stack will depend on if this is called directly
# or from another helper, but the test will always contain the identifier.
test_filename: str = next(
frame.filename for frame in inspect.stack() if TEST_FILE_IDENTIFIER in frame.filename
frame.filename
for frame in inspect.stack()
if any(identifier in frame.filename for identifier in TEST_FILE_IDENTIFIERS)
)
return Path(test_filename).stem

Expand Down

0 comments on commit a192751

Please sign in to comment.