From 909669f4053f485ad94223fa51e2b8be2f61a76f Mon Sep 17 00:00:00 2001 From: Peter <34331512+pmaytak@users.noreply.github.com> Date: Tue, 14 May 2024 23:58:05 -0700 Subject: [PATCH] Update JsonWebToken to enable extensibility (#2582) * Extract token reading code into a separate ReadPropertyValue. * Rename. Use IDictionary. * Add a test. Add JsonWebTokens InternalsVisibleTo for tests. --- .../InternalsVisibleTo.cs | 4 + .../Json/JsonWebToken.PayloadClaimSet.cs | 137 +++++++++--------- .../CustomJsonWebToken.cs | 42 ++++++ .../JsonWebTokenTests.cs | 29 +++- 4 files changed, 141 insertions(+), 71 deletions(-) create mode 100644 src/Microsoft.IdentityModel.JsonWebTokens/InternalsVisibleTo.cs create mode 100644 test/Microsoft.IdentityModel.JsonWebTokens.Tests/CustomJsonWebToken.cs diff --git a/src/Microsoft.IdentityModel.JsonWebTokens/InternalsVisibleTo.cs b/src/Microsoft.IdentityModel.JsonWebTokens/InternalsVisibleTo.cs new file mode 100644 index 0000000000..05a39737b8 --- /dev/null +++ b/src/Microsoft.IdentityModel.JsonWebTokens/InternalsVisibleTo.cs @@ -0,0 +1,4 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +[assembly: System.Runtime.CompilerServices.InternalsVisibleTo("Microsoft.IdentityModel.JsonWebTokens.Tests, PublicKey=0024000004800000940000000602000000240000525341310004000001000100b5fc90e7027f67871e773a8fde8938c81dd402ba65b9201d60593e96c492651e889cc13f1415ebb53fac1131ae0bd333c5ee6021672d9718ea31a8aebd0da0072f25d87dba6fc90ffd598ed4da35e44c398c454307e8e33b8426143daec9f596836f97c8f74750e5975c64e2189f45def46b2a2b1247adc3652bf5c308055da9")] diff --git a/src/Microsoft.IdentityModel.JsonWebTokens/Json/JsonWebToken.PayloadClaimSet.cs b/src/Microsoft.IdentityModel.JsonWebTokens/Json/JsonWebToken.PayloadClaimSet.cs index d98a916982..1fb5a7410d 100644 --- a/src/Microsoft.IdentityModel.JsonWebTokens/Json/JsonWebToken.PayloadClaimSet.cs +++ b/src/Microsoft.IdentityModel.JsonWebTokens/Json/JsonWebToken.PayloadClaimSet.cs @@ -18,7 +18,7 @@ internal JsonClaimSet CreatePayloadClaimSet(byte[] bytes, int length) } internal JsonClaimSet CreatePayloadClaimSet(ReadOnlySpan byteSpan) - { + { Utf8JsonReader reader = new(byteSpan); if (!JsonSerializerPrimitives.IsReaderAtTokenType(ref reader, JsonTokenType.StartObject, true)) throw LogHelper.LogExceptionMessage( @@ -37,71 +37,7 @@ internal JsonClaimSet CreatePayloadClaimSet(ReadOnlySpan byteSpan) { if (reader.TokenType == JsonTokenType.PropertyName) { - if (reader.ValueTextEquals(JwtPayloadUtf8Bytes.Aud)) - { - _audiences = []; - reader.Read(); - if (reader.TokenType == JsonTokenType.StartArray) - { - JsonSerializerPrimitives.ReadStringsSkipNulls(ref reader, _audiences, JwtRegisteredClaimNames.Aud, ClassName); - claims[JwtRegisteredClaimNames.Aud] = _audiences; - } - else - { - if (reader.TokenType != JsonTokenType.Null) - { - _audiences.Add(JsonSerializerPrimitives.ReadString(ref reader, JwtRegisteredClaimNames.Aud, ClassName)); - claims[JwtRegisteredClaimNames.Aud] = _audiences[0]; - } - else - { - claims[JwtRegisteredClaimNames.Aud] = _audiences; - } - } - } - else if (reader.ValueTextEquals(JwtPayloadUtf8Bytes.Azp)) - { - _azp = JsonSerializerPrimitives.ReadString(ref reader, JwtRegisteredClaimNames.Azp, ClassName, true); - claims[JwtRegisteredClaimNames.Azp] = _azp; - } - else if (reader.ValueTextEquals(JwtPayloadUtf8Bytes.Exp)) - { - _exp = JsonSerializerPrimitives.ReadLong(ref reader, JwtRegisteredClaimNames.Exp, ClassName, true); - _expDateTime = EpochTime.DateTime(_exp.Value); - claims[JwtRegisteredClaimNames.Exp] = _exp; - } - else if (reader.ValueTextEquals(JwtPayloadUtf8Bytes.Iat)) - { - _iat = JsonSerializerPrimitives.ReadLong(ref reader, JwtRegisteredClaimNames.Iat, ClassName, true); - _iatDateTime = EpochTime.DateTime(_iat.Value); - claims[JwtRegisteredClaimNames.Iat] = _iat; - } - else if (reader.ValueTextEquals(JwtPayloadUtf8Bytes.Iss)) - { - _iss = JsonSerializerPrimitives.ReadString(ref reader, JwtRegisteredClaimNames.Iss, ClassName, true); - claims[JwtRegisteredClaimNames.Iss] = _iss; - } - else if (reader.ValueTextEquals(JwtPayloadUtf8Bytes.Jti)) - { - _jti = JsonSerializerPrimitives.ReadString(ref reader, JwtRegisteredClaimNames.Jti, ClassName, true); - claims[JwtRegisteredClaimNames.Jti] = _jti; - } - else if (reader.ValueTextEquals(JwtPayloadUtf8Bytes.Nbf)) - { - _nbf = JsonSerializerPrimitives.ReadLong(ref reader, JwtRegisteredClaimNames.Nbf, ClassName, true); - _nbfDateTime = EpochTime.DateTime(_nbf.Value); - claims[JwtRegisteredClaimNames.Nbf] = _nbf; - } - else if (reader.ValueTextEquals(JwtPayloadUtf8Bytes.Sub)) - { - _sub = JsonSerializerPrimitives.ReadStringOrNumberAsString(ref reader, JwtRegisteredClaimNames.Sub, ClassName, true); - claims[JwtRegisteredClaimNames.Sub] = _sub; - } - else - { - string propertyName = reader.GetString(); - claims[propertyName] = JsonSerializerPrimitives.ReadPropertyValueAsObject(ref reader, propertyName, JsonClaimSet.ClassName, true); - } + ReadPayloadValue(ref reader, claims); } // We read a JsonTokenType.StartObject above, exiting and positioning reader at next token. else if (JsonSerializerPrimitives.IsReaderAtTokenType(ref reader, JsonTokenType.EndObject, false)) @@ -112,5 +48,74 @@ internal JsonClaimSet CreatePayloadClaimSet(ReadOnlySpan byteSpan) return new JsonClaimSet(claims); } + + private protected virtual void ReadPayloadValue(ref Utf8JsonReader reader, IDictionary claims) + { + if (reader.ValueTextEquals(JwtPayloadUtf8Bytes.Aud)) + { + _audiences = []; + reader.Read(); + if (reader.TokenType == JsonTokenType.StartArray) + { + JsonSerializerPrimitives.ReadStringsSkipNulls(ref reader, _audiences, JwtRegisteredClaimNames.Aud, ClassName); + claims[JwtRegisteredClaimNames.Aud] = _audiences; + } + else + { + if (reader.TokenType != JsonTokenType.Null) + { + _audiences.Add(JsonSerializerPrimitives.ReadString(ref reader, JwtRegisteredClaimNames.Aud, ClassName)); + claims[JwtRegisteredClaimNames.Aud] = _audiences[0]; + } + else + { + claims[JwtRegisteredClaimNames.Aud] = _audiences; + } + } + } + else if (reader.ValueTextEquals(JwtPayloadUtf8Bytes.Azp)) + { + _azp = JsonSerializerPrimitives.ReadString(ref reader, JwtRegisteredClaimNames.Azp, ClassName, true); + claims[JwtRegisteredClaimNames.Azp] = _azp; + } + else if (reader.ValueTextEquals(JwtPayloadUtf8Bytes.Exp)) + { + _exp = JsonSerializerPrimitives.ReadLong(ref reader, JwtRegisteredClaimNames.Exp, ClassName, true); + _expDateTime = EpochTime.DateTime(_exp.Value); + claims[JwtRegisteredClaimNames.Exp] = _exp; + } + else if (reader.ValueTextEquals(JwtPayloadUtf8Bytes.Iat)) + { + _iat = JsonSerializerPrimitives.ReadLong(ref reader, JwtRegisteredClaimNames.Iat, ClassName, true); + _iatDateTime = EpochTime.DateTime(_iat.Value); + claims[JwtRegisteredClaimNames.Iat] = _iat; + } + else if (reader.ValueTextEquals(JwtPayloadUtf8Bytes.Iss)) + { + _iss = JsonSerializerPrimitives.ReadString(ref reader, JwtRegisteredClaimNames.Iss, ClassName, true); + claims[JwtRegisteredClaimNames.Iss] = _iss; + } + else if (reader.ValueTextEquals(JwtPayloadUtf8Bytes.Jti)) + { + _jti = JsonSerializerPrimitives.ReadString(ref reader, JwtRegisteredClaimNames.Jti, ClassName, true); + claims[JwtRegisteredClaimNames.Jti] = _jti; + } + else if (reader.ValueTextEquals(JwtPayloadUtf8Bytes.Nbf)) + { + _nbf = JsonSerializerPrimitives.ReadLong(ref reader, JwtRegisteredClaimNames.Nbf, ClassName, true); + _nbfDateTime = EpochTime.DateTime(_nbf.Value); + claims[JwtRegisteredClaimNames.Nbf] = _nbf; + } + else if (reader.ValueTextEquals(JwtPayloadUtf8Bytes.Sub)) + { + _sub = JsonSerializerPrimitives.ReadStringOrNumberAsString(ref reader, JwtRegisteredClaimNames.Sub, ClassName, true); + claims[JwtRegisteredClaimNames.Sub] = _sub; + } + else + { + string propertyName = reader.GetString(); + claims[propertyName] = JsonSerializerPrimitives.ReadPropertyValueAsObject(ref reader, propertyName, JsonClaimSet.ClassName, true); + } + } } } diff --git a/test/Microsoft.IdentityModel.JsonWebTokens.Tests/CustomJsonWebToken.cs b/test/Microsoft.IdentityModel.JsonWebTokens.Tests/CustomJsonWebToken.cs new file mode 100644 index 0000000000..813b82e4b3 --- /dev/null +++ b/test/Microsoft.IdentityModel.JsonWebTokens.Tests/CustomJsonWebToken.cs @@ -0,0 +1,42 @@ +using System; +using System.Collections.Generic; +using System.Text.Json; +using Microsoft.IdentityModel.Tokens.Json; + +namespace Microsoft.IdentityModel.JsonWebTokens.Tests +{ + public class CustomJsonWebToken : JsonWebToken + { + private const string CustomClaimName = "CustomClaim"; + + public CustomJsonWebToken(string jwtEncodedString) : base(jwtEncodedString) { } + + public CustomJsonWebToken(ReadOnlyMemory encodedTokenMemory) : base(encodedTokenMemory) { } + + public CustomJsonWebToken(string header, string payload) : base(header, payload) { } + + private protected override void ReadPayloadValue(ref Utf8JsonReader reader, IDictionary claims) + { + if (reader.ValueTextEquals(CustomClaimName)) + { + _customClaim = JsonSerializerPrimitives.ReadString(ref reader, CustomClaimName, ClassName, true); + claims[CustomClaimName] = _customClaim; + } + else + { + base.ReadPayloadValue(ref reader, claims); + } + } + + private string _customClaim; + + public string CustomClaim + { + get + { + _customClaim ??= Payload.GetStringValue(CustomClaimName); + return _customClaim; + } + } + } +} diff --git a/test/Microsoft.IdentityModel.JsonWebTokens.Tests/JsonWebTokenTests.cs b/test/Microsoft.IdentityModel.JsonWebTokens.Tests/JsonWebTokenTests.cs index d672f6cac7..08644135d1 100644 --- a/test/Microsoft.IdentityModel.JsonWebTokens.Tests/JsonWebTokenTests.cs +++ b/test/Microsoft.IdentityModel.JsonWebTokens.Tests/JsonWebTokenTests.cs @@ -363,14 +363,14 @@ public static TheoryData CheckAudienceValuesTheoryDat PropertyName = "aud", PropertyValue = new List(), PropertyType = typeof(List), - Json = JsonUtilities.CreateUnsignedToken("aud", new List{ null, null }) + Json = JsonUtilities.CreateUnsignedToken("aud", new List { null, null }) }); theoryData.Add(new GetPayloadValueTheoryData("singleNonNull") { ClaimValue = new List { "audience" }, PropertyName = "aud", - PropertyValue = new List { "audience"}, + PropertyValue = new List { "audience" }, PropertyType = typeof(List), Json = JsonUtilities.CreateUnsignedToken("aud", "audience") }); @@ -379,7 +379,7 @@ public static TheoryData CheckAudienceValuesTheoryDat { ClaimValue = new List { "audience1" }, PropertyName = "aud", - PropertyValue = new List { "audience1"}, + PropertyValue = new List { "audience1" }, PropertyType = typeof(List), Json = JsonUtilities.CreateUnsignedToken("aud", new List { null, "audience1", null }) }); @@ -726,7 +726,7 @@ public static TheoryData GetPayloadSubClaimValueTheor return theoryData; } - + } // This test ensures that accessing claims from the payload works as expected. @@ -981,7 +981,7 @@ public static TheoryData GetPayloadValueTheoryData { PropertyName = "dateTime", PropertyType = typeof(string[]), - PropertyValue = new string[] {dateTime.ToString("o", CultureInfo.InvariantCulture)}, + PropertyValue = new string[] { dateTime.ToString("o", CultureInfo.InvariantCulture) }, Json = JsonUtilities.CreateUnsignedToken("dateTime", dateTime) }); @@ -1718,6 +1718,25 @@ public void StringAndMemoryConstructors_CreateEquivalentTokens(JwtTheoryData the } TestUtilities.AssertFailIfErrors(context); } + + [Fact] + public void DerivedJsonWebToken_IsCreatedCorrectly() + { + var expectedCustomClaim = "customclaim"; + var tokenStr = new JsonWebTokenHandler().CreateToken(new SecurityTokenDescriptor + { + Issuer = Default.Issuer, + Claims = new Dictionary + { + { "CustomClaim", expectedCustomClaim }, + } + }); + + var derivedToken = new CustomJsonWebToken(tokenStr); + + Assert.Equal(expectedCustomClaim, derivedToken.CustomClaim); + Assert.Equal(Default.Issuer, derivedToken.Issuer); + } } public class ParseTimeValuesTheoryData : TheoryDataBase