diff --git a/providers/fab/tests/unit/fab/auth_manager/security_manager/test_override.py b/providers/fab/tests/unit/fab/auth_manager/security_manager/test_override.py index 7b9e9696e5776..fe5117ccd66db 100644 --- a/providers/fab/tests/unit/fab/auth_manager/security_manager/test_override.py +++ b/providers/fab/tests/unit/fab/auth_manager/security_manager/test_override.py @@ -247,3 +247,67 @@ def test_get_oauth_user_info(self, provider, resp, user_info): sm._decode_and_validate_azure_jwt = Mock(return_value=resp) sm._get_authentik_token_info = Mock(return_value=resp) assert sm.get_oauth_user_info(provider, {"id_token": None}) == user_info + + def test_get_oauth_user_info_azure_with_groups_config(self): + from flask import Flask + + app = Flask(__name__) + app.config["AUTH_OAUTH_ROLE_KEYS"] = {"azure": "groups"} + + azure_response = { + "oid": "user-123", + "given_name": "Jane", + "family_name": "Smith", + "email": "jane.smith@example.com", + "groups": ["admin-group", "viewer-group"], + } + + with app.app_context(): + sm = EmptySecurityManager() + sm.appbuilder = Mock(sm=sm) + sm.oauth_remotes = {} + sm._decode_and_validate_azure_jwt = Mock(return_value=azure_response) + + user_info = sm.get_oauth_user_info("azure", {"id_token": "test-token"}) + + assert user_info["username"] == "user-123" + assert user_info["email"] == "jane.smith@example.com" + assert user_info["role_keys"] == ["admin-group", "viewer-group"] + + @mock.patch("airflow.providers.fab.auth_manager.security_manager.override.flash") + @mock.patch("airflow.providers.fab.auth_manager.security_manager.override.has_request_context") + def test_cli_safe_flash_escapes_html_username( + self, + mock_has_request_context, + mock_flash, + ): + mock_has_request_context.return_value = True + + malicious_username = "" + safe_username = "<script>alert('xss')</script>" + + EmptySecurityManager._cli_safe_flash( + f"User {safe_username} already exists", + "warning", + ) + + flash_arg = mock_flash.call_args[0][0] + + assert malicious_username not in str(flash_arg) + assert safe_username in str(flash_arg) + + @mock.patch("airflow.providers.fab.auth_manager.security_manager.override.log") + @mock.patch("airflow.providers.fab.auth_manager.security_manager.override.has_request_context") + def test_cli_safe_flash_formats_cli_output( + self, + mock_has_request_context, + mock_log, + ): + mock_has_request_context.return_value = False + + EmptySecurityManager._cli_safe_flash( + "Hello
World", + "warning", + ) + + mock_log.warning.assert_called_once_with("Hello\n*World*")