Skip to content
Open
Original file line number Diff line number Diff line change
Expand Up @@ -575,13 +575,10 @@ def scopes_to_suppress(self, value: list[str]):

class ClaimsAndScopeOverrideDetails(GroupConfigurationBase):
@property
def id_token_generation(self) -> TokenClaimsAndScopeOverrideDetails | None:
id_token_generation_details = self._data.get("idTokenGeneration")
return (
None
if id_token_generation_details is None
else TokenClaimsAndScopeOverrideDetails(id_token_generation_details)
)
def id_token_generation(self) -> TokenClaimsAndScopeOverrideDetails:
if self._data.get("idTokenGeneration") is None:
self._data["idTokenGeneration"] = {}
return TokenClaimsAndScopeOverrideDetails(self._data["idTokenGeneration"])

@id_token_generation.setter
def id_token_generation(self, value: dict[str, Any]):
Expand All @@ -597,13 +594,10 @@ def id_token_generation(self, value: dict[str, Any]):
self._data["idTokenGeneration"] = value

@property
def access_token_generation(self) -> TokenClaimsAndScopeOverrideDetails | None:
access_token_generation_details = self._data.get("accessTokenGeneration")
return (
None
if access_token_generation_details is None
else TokenClaimsAndScopeOverrideDetails(access_token_generation_details)
)
def access_token_generation(self) -> TokenClaimsAndScopeOverrideDetails:
if self._data.get("accessTokenGeneration") is None:
self._data["accessTokenGeneration"] = {}
return TokenClaimsAndScopeOverrideDetails(self._data["accessTokenGeneration"])

@access_token_generation.setter
def access_token_generation(self, value: dict[str, Any]):
Expand All @@ -622,13 +616,17 @@ def access_token_generation(self, value: dict[str, Any]):
class PreTokenGenerationTriggerEventResponse(DictWrapper):
@property
def claims_override_details(self) -> ClaimsOverrideDetails:
return ClaimsOverrideDetails(self.get("claimsOverrideDetails") or {})
if self._data.get("claimsOverrideDetails") is None:
self._data["claimsOverrideDetails"] = {}
return ClaimsOverrideDetails(self._data["claimsOverrideDetails"])


class PreTokenGenerationTriggerV2EventResponse(DictWrapper):
@property
def claims_scope_override_details(self) -> ClaimsAndScopeOverrideDetails:
return ClaimsAndScopeOverrideDetails(self.get("claimsAndScopeOverrideDetails") or {})
if self._data.get("claimsAndScopeOverrideDetails") is None:
self._data["claimsAndScopeOverrideDetails"] = {}
return ClaimsAndScopeOverrideDetails(self._data["claimsAndScopeOverrideDetails"])


class PreTokenGenerationTriggerEvent(BaseTriggerEvent):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ def test_cognito_pre_token_generation_trigger_event():

expected_claims = {"test": "value"}
claims_override_details.claims_to_add_or_override = expected_claims
assert claims_override_details.claims_to_add_or_override["test"] == "value"
assert parsed_event.response.claims_override_details.claims_to_add_or_override["test"] == "value"

claims_override_details.claims_to_suppress = ["email"]
assert claims_override_details.claims_to_suppress[0] == "email"
Expand Down Expand Up @@ -229,6 +229,10 @@ def test_cognito_pre_token_v2_generation_trigger_event():
assert parsed_event.request.scopes == raw_event["request"]["scopes"]

claims_scope_override_details = parsed_event.response.claims_scope_override_details
# Test that accessing id_token_generation and access_token_generation properties initialize empty dicts
assert claims_scope_override_details.id_token_generation.claims_to_add_or_override == {}
assert claims_scope_override_details.access_token_generation.claims_to_add_or_override == {}

claims_scope_override_details.id_token_generation = claims_scope_override_details.access_token_generation = {}
assert claims_scope_override_details.id_token_generation.claims_to_add_or_override == {}
assert claims_scope_override_details.id_token_generation.claims_to_suppress == []
Expand All @@ -246,8 +250,14 @@ def test_cognito_pre_token_v2_generation_trigger_event():
expected_claims = {"test": "value"}
claims_scope_override_details.id_token_generation.claims_to_add_or_override = expected_claims
claims_scope_override_details.access_token_generation.claims_to_add_or_override = expected_claims
assert claims_scope_override_details.id_token_generation.claims_to_add_or_override["test"] == "value"
assert claims_scope_override_details.access_token_generation.claims_to_add_or_override["test"] == "value"
assert (
parsed_event.response.claims_scope_override_details.id_token_generation.claims_to_add_or_override["test"]
== "value"
)
assert (
parsed_event.response.claims_scope_override_details.access_token_generation.claims_to_add_or_override["test"]
== "value"
)

claims_scope_override_details.id_token_generation.claims_to_suppress = (
claims_scope_override_details.access_token_generation.claims_to_suppress
Expand Down