Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -247,3 +247,67 @@ def test_get_oauth_user_info(self, provider, resp, user_info):
sm._decode_and_validate_azure_jwt = Mock(return_value=resp)
sm._get_authentik_token_info = Mock(return_value=resp)
assert sm.get_oauth_user_info(provider, {"id_token": None}) == user_info

def test_get_oauth_user_info_azure_with_groups_config(self):
from flask import Flask

app = Flask(__name__)
app.config["AUTH_OAUTH_ROLE_KEYS"] = {"azure": "groups"}

azure_response = {
"oid": "user-123",
"given_name": "Jane",
"family_name": "Smith",
"email": "jane.smith@example.com",
"groups": ["admin-group", "viewer-group"],
}

with app.app_context():
sm = EmptySecurityManager()
sm.appbuilder = Mock(sm=sm)
sm.oauth_remotes = {}
sm._decode_and_validate_azure_jwt = Mock(return_value=azure_response)

user_info = sm.get_oauth_user_info("azure", {"id_token": "test-token"})

assert user_info["username"] == "user-123"
assert user_info["email"] == "jane.smith@example.com"
assert user_info["role_keys"] == ["admin-group", "viewer-group"]

@mock.patch("airflow.providers.fab.auth_manager.security_manager.override.flash")
@mock.patch("airflow.providers.fab.auth_manager.security_manager.override.has_request_context")
def test_cli_safe_flash_escapes_html_username(
self,
mock_has_request_context,
mock_flash,
):
mock_has_request_context.return_value = True

malicious_username = "<script>alert('xss')</script>"
safe_username = "&lt;script&gt;alert(&#39;xss&#39;)&lt;/script&gt;"

EmptySecurityManager._cli_safe_flash(
f"User {safe_username} already exists",
"warning",
)

flash_arg = mock_flash.call_args[0][0]

assert malicious_username not in str(flash_arg)
assert safe_username in str(flash_arg)

@mock.patch("airflow.providers.fab.auth_manager.security_manager.override.log")
@mock.patch("airflow.providers.fab.auth_manager.security_manager.override.has_request_context")
def test_cli_safe_flash_formats_cli_output(
self,
mock_has_request_context,
mock_log,
):
mock_has_request_context.return_value = False

EmptySecurityManager._cli_safe_flash(
"Hello<br><b>World</b>",
"warning",
)

mock_log.warning.assert_called_once_with("Hello\n*World*")
Loading