diff --git a/aws_lambda_powertools/utilities/data_classes/cognito_user_pool_event.py b/aws_lambda_powertools/utilities/data_classes/cognito_user_pool_event.py index 79c43a8b701..c85515a31d5 100644 --- a/aws_lambda_powertools/utilities/data_classes/cognito_user_pool_event.py +++ b/aws_lambda_powertools/utilities/data_classes/cognito_user_pool_event.py @@ -575,13 +575,10 @@ def scopes_to_suppress(self, value: list[str]): class ClaimsAndScopeOverrideDetails(GroupConfigurationBase): @property - def id_token_generation(self) -> TokenClaimsAndScopeOverrideDetails | None: - id_token_generation_details = self._data.get("idTokenGeneration") - return ( - None - if id_token_generation_details is None - else TokenClaimsAndScopeOverrideDetails(id_token_generation_details) - ) + def id_token_generation(self) -> TokenClaimsAndScopeOverrideDetails: + if self._data.get("idTokenGeneration") is None: + self._data["idTokenGeneration"] = {} + return TokenClaimsAndScopeOverrideDetails(self._data["idTokenGeneration"]) @id_token_generation.setter def id_token_generation(self, value: dict[str, Any]): @@ -597,13 +594,10 @@ def id_token_generation(self, value: dict[str, Any]): self._data["idTokenGeneration"] = value @property - def access_token_generation(self) -> TokenClaimsAndScopeOverrideDetails | None: - access_token_generation_details = self._data.get("accessTokenGeneration") - return ( - None - if access_token_generation_details is None - else TokenClaimsAndScopeOverrideDetails(access_token_generation_details) - ) + def access_token_generation(self) -> TokenClaimsAndScopeOverrideDetails: + if self._data.get("accessTokenGeneration") is None: + self._data["accessTokenGeneration"] = {} + return TokenClaimsAndScopeOverrideDetails(self._data["accessTokenGeneration"]) @access_token_generation.setter def access_token_generation(self, value: dict[str, Any]): @@ -622,13 +616,17 @@ def access_token_generation(self, value: dict[str, Any]): class PreTokenGenerationTriggerEventResponse(DictWrapper): @property def claims_override_details(self) -> ClaimsOverrideDetails: - return ClaimsOverrideDetails(self.get("claimsOverrideDetails") or {}) + if self._data.get("claimsOverrideDetails") is None: + self._data["claimsOverrideDetails"] = {} + return ClaimsOverrideDetails(self._data["claimsOverrideDetails"]) class PreTokenGenerationTriggerV2EventResponse(DictWrapper): @property def claims_scope_override_details(self) -> ClaimsAndScopeOverrideDetails: - return ClaimsAndScopeOverrideDetails(self.get("claimsAndScopeOverrideDetails") or {}) + if self._data.get("claimsAndScopeOverrideDetails") is None: + self._data["claimsAndScopeOverrideDetails"] = {} + return ClaimsAndScopeOverrideDetails(self._data["claimsAndScopeOverrideDetails"]) class PreTokenGenerationTriggerEvent(BaseTriggerEvent): diff --git a/tests/unit/data_classes/required_dependencies/test_cognito_user_pool_event.py b/tests/unit/data_classes/required_dependencies/test_cognito_user_pool_event.py index 41ee52d915e..af84b3f8982 100644 --- a/tests/unit/data_classes/required_dependencies/test_cognito_user_pool_event.py +++ b/tests/unit/data_classes/required_dependencies/test_cognito_user_pool_event.py @@ -187,7 +187,7 @@ def test_cognito_pre_token_generation_trigger_event(): expected_claims = {"test": "value"} claims_override_details.claims_to_add_or_override = expected_claims - assert claims_override_details.claims_to_add_or_override["test"] == "value" + assert parsed_event.response.claims_override_details.claims_to_add_or_override["test"] == "value" claims_override_details.claims_to_suppress = ["email"] assert claims_override_details.claims_to_suppress[0] == "email" @@ -229,6 +229,10 @@ def test_cognito_pre_token_v2_generation_trigger_event(): assert parsed_event.request.scopes == raw_event["request"]["scopes"] claims_scope_override_details = parsed_event.response.claims_scope_override_details + # Test that accessing id_token_generation and access_token_generation properties initialize empty dicts + assert claims_scope_override_details.id_token_generation.claims_to_add_or_override == {} + assert claims_scope_override_details.access_token_generation.claims_to_add_or_override == {} + claims_scope_override_details.id_token_generation = claims_scope_override_details.access_token_generation = {} assert claims_scope_override_details.id_token_generation.claims_to_add_or_override == {} assert claims_scope_override_details.id_token_generation.claims_to_suppress == [] @@ -246,8 +250,14 @@ def test_cognito_pre_token_v2_generation_trigger_event(): expected_claims = {"test": "value"} claims_scope_override_details.id_token_generation.claims_to_add_or_override = expected_claims claims_scope_override_details.access_token_generation.claims_to_add_or_override = expected_claims - assert claims_scope_override_details.id_token_generation.claims_to_add_or_override["test"] == "value" - assert claims_scope_override_details.access_token_generation.claims_to_add_or_override["test"] == "value" + assert ( + parsed_event.response.claims_scope_override_details.id_token_generation.claims_to_add_or_override["test"] + == "value" + ) + assert ( + parsed_event.response.claims_scope_override_details.access_token_generation.claims_to_add_or_override["test"] + == "value" + ) claims_scope_override_details.id_token_generation.claims_to_suppress = ( claims_scope_override_details.access_token_generation.claims_to_suppress