diff --git a/src/MessagePack.GeneratorCore/CodeAnalysis/Definitions.cs b/src/MessagePack.GeneratorCore/CodeAnalysis/Definitions.cs index 9e3b86899..5ccf8c5b3 100644 --- a/src/MessagePack.GeneratorCore/CodeAnalysis/Definitions.cs +++ b/src/MessagePack.GeneratorCore/CodeAnalysis/Definitions.cs @@ -9,6 +9,11 @@ namespace MessagePackCompiler.CodeAnalysis { + public interface INamespaceInfo + { + string? Namespace { get; } + } + public interface IResolverRegisterInfo { string FullName { get; } @@ -16,41 +21,36 @@ public interface IResolverRegisterInfo string FormatterName { get; } } - public class ObjectSerializationInfo : IResolverRegisterInfo + public class ObjectSerializationInfo : IResolverRegisterInfo, INamespaceInfo { - public string Name { get; set; } + public string Name { get; } - public string FullName { get; set; } + public string FullName { get; } - public string Namespace { get; set; } + public string? Namespace { get; } - public GenericTypeParameterInfo[] GenericTypeParameters { get; set; } + public GenericTypeParameterInfo[] GenericTypeParameters { get; } - public bool IsOpenGenericType { get; set; } + public bool IsOpenGenericType { get; } - public bool IsIntKey { get; set; } + public bool IsIntKey { get; } public bool IsStringKey { get { return !this.IsIntKey; } } - public bool IsClass { get; set; } - - public bool IsStruct - { - get { return !this.IsClass; } - } + public bool IsClass { get; } - public MemberSerializationInfo[] ConstructorParameters { get; set; } + public MemberSerializationInfo[] ConstructorParameters { get; } - public MemberSerializationInfo[] Members { get; set; } + public MemberSerializationInfo[] Members { get; } - public bool HasIMessagePackSerializationCallbackReceiver { get; set; } + public bool HasIMessagePackSerializationCallbackReceiver { get; } - public bool NeedsCastOnBefore { get; set; } + public bool NeedsCastOnBefore { get; } - public bool NeedsCastOnAfter { get; set; } + public bool NeedsCastOnAfter { get; } public string FormatterName => this.Namespace == null ? FormatterNameWithoutNameSpace : this.Namespace + "." + FormatterNameWithoutNameSpace; @@ -79,7 +79,7 @@ public int MaxKey } } - public MemberSerializationInfo GetMember(int index) + public MemberSerializationInfo? GetMember(int index) { return this.Members.FirstOrDefault(x => x.IntKey == index); } @@ -89,6 +89,22 @@ public string GetConstructorString() var args = string.Join(", ", this.ConstructorParameters.Select(x => "__" + x.Name + "__")); return $"{this.FullName}({args})"; } + + public ObjectSerializationInfo(bool isClass, bool isOpenGenericType, GenericTypeParameterInfo[] genericTypeParameterInfos, MemberSerializationInfo[] constructorParameters, bool isIntKey, MemberSerializationInfo[] members, string name, string fullName, string? @namespace, bool hasSerializationConstructor, bool needsCastOnAfter, bool needsCastOnBefore) + { + IsClass = isClass; + IsOpenGenericType = isOpenGenericType; + GenericTypeParameters = genericTypeParameterInfos; + ConstructorParameters = constructorParameters; + IsIntKey = isIntKey; + Members = members; + Name = name; + FullName = fullName; + Namespace = @namespace; + HasIMessagePackSerializationCallbackReceiver = hasSerializationConstructor; + NeedsCastOnAfter = needsCastOnAfter; + NeedsCastOnBefore = needsCastOnBefore; + } } public class GenericTypeParameterInfo @@ -103,33 +119,44 @@ public GenericTypeParameterInfo(string name, string constraints) { Name = name ?? throw new ArgumentNullException(nameof(name)); Constraints = constraints ?? throw new ArgumentNullException(nameof(name)); - HasConstraints = !string.IsNullOrEmpty(constraints); + HasConstraints = constraints != string.Empty; } } public class MemberSerializationInfo { - public bool IsProperty { get; set; } + public bool IsProperty { get; } - public bool IsField { get; set; } + public bool IsWritable { get; } - public bool IsWritable { get; set; } + public bool IsReadable { get; } - public bool IsReadable { get; set; } + public int IntKey { get; } - public int IntKey { get; set; } + public string StringKey { get; } - public string StringKey { get; set; } + public string Type { get; } - public string Type { get; set; } + public string Name { get; } - public string Name { get; set; } + public string ShortTypeName { get; } - public string ShortTypeName { get; set; } + public string? CustomFormatterTypeName { get; } - public string CustomFormatterTypeName { get; set; } + private readonly HashSet primitiveTypes = new (Generator.ShouldUseFormatterResolverHelper.PrimitiveTypes); - private readonly HashSet primitiveTypes = new HashSet(Generator.ShouldUseFormatterResolverHelper.PrimitiveTypes); + public MemberSerializationInfo(bool isProperty, bool isWritable, bool isReadable, int intKey, string stringKey, string name, string type, string shortTypeName, string? customFormatterTypeName) + { + IsProperty = isProperty; + IsWritable = isWritable; + IsReadable = isReadable; + IntKey = intKey; + StringKey = stringKey; + Type = type; + Name = name; + ShortTypeName = shortTypeName; + CustomFormatterTypeName = customFormatterTypeName; + } public string GetSerializeMethodString() { @@ -139,7 +166,7 @@ public string GetSerializeMethodString() } else if (this.primitiveTypes.Contains(this.Type)) { - return $"writer.Write(value.{this.Name})"; + return "writer.Write(value." + this.Name + ")"; } else { @@ -156,7 +183,7 @@ public string GetDeserializeMethodString() else if (this.primitiveTypes.Contains(this.Type)) { string suffix = this.Type == "byte[]" ? "?.ToArray()" : string.Empty; - return $"reader.Read{this.ShortTypeName.Replace("[]", "s")}()" + suffix; + return $"reader.Read{this.ShortTypeName!.Replace("[]", "s")}()" + suffix; } else { @@ -165,55 +192,84 @@ public string GetDeserializeMethodString() } } - public class EnumSerializationInfo : IResolverRegisterInfo + public class EnumSerializationInfo : IResolverRegisterInfo, INamespaceInfo { - public string Namespace { get; set; } + public EnumSerializationInfo(string? @namespace, string name, string fullName, string underlyingType) + { + Namespace = @namespace; + Name = name; + FullName = fullName; + UnderlyingType = underlyingType; + } + + public string? Namespace { get; } - public string Name { get; set; } + public string Name { get; } - public string FullName { get; set; } + public string FullName { get; } - public string UnderlyingType { get; set; } + public string UnderlyingType { get; } public string FormatterName => (this.Namespace == null ? this.Name : this.Namespace + "." + this.Name) + "Formatter"; } public class GenericSerializationInfo : IResolverRegisterInfo, IEquatable { - public string FullName { get; set; } + public string FullName { get; } - public string FormatterName { get; set; } + public string FormatterName { get; } - public bool IsOpenGenericType { get; set; } + public bool IsOpenGenericType { get; } - public bool Equals(GenericSerializationInfo other) + public bool Equals(GenericSerializationInfo? other) { - return this.FullName.Equals(other.FullName); + return this.FullName.Equals(other?.FullName); } public override int GetHashCode() { return this.FullName.GetHashCode(); } + + public GenericSerializationInfo(string fullName, string formatterName, bool isOpenGenericType) + { + FullName = fullName; + FormatterName = formatterName; + IsOpenGenericType = isOpenGenericType; + } } - public class UnionSerializationInfo : IResolverRegisterInfo + public class UnionSerializationInfo : IResolverRegisterInfo, INamespaceInfo { - public string Namespace { get; set; } + public string? Namespace { get; } - public string Name { get; set; } + public string Name { get; } - public string FullName { get; set; } + public string FullName { get; } public string FormatterName => (this.Namespace == null ? this.Name : this.Namespace + "." + this.Name) + "Formatter"; - public UnionSubTypeInfo[] SubTypes { get; set; } + public UnionSubTypeInfo[] SubTypes { get; } + + public UnionSerializationInfo(string? @namespace, string name, string fullName, UnionSubTypeInfo[] subTypes) + { + Namespace = @namespace; + Name = name; + FullName = fullName; + SubTypes = subTypes; + } } public class UnionSubTypeInfo { - public string Type { get; set; } + public UnionSubTypeInfo(int key, string type) + { + Key = key; + Type = type; + } + + public int Key { get; } - public int Key { get; set; } + public string Type { get; } } } diff --git a/src/MessagePack.GeneratorCore/CodeAnalysis/TypeCollector.cs b/src/MessagePack.GeneratorCore/CodeAnalysis/TypeCollector.cs index b65a9f7d3..93cac37b2 100644 --- a/src/MessagePack.GeneratorCore/CodeAnalysis/TypeCollector.cs +++ b/src/MessagePack.GeneratorCore/CodeAnalysis/TypeCollector.cs @@ -5,6 +5,7 @@ using System; using System.Collections.Generic; +using System.Collections.Immutable; using System.Linq; using System.Text; using System.Text.RegularExpressions; @@ -23,14 +24,14 @@ public MessagePackGeneratorResolveFailedException(string message) internal class ReferenceSymbols { #pragma warning disable SA1401 // Fields should be private - internal readonly INamedTypeSymbol Task; - internal readonly INamedTypeSymbol TaskOfT; + internal readonly INamedTypeSymbol? Task; + internal readonly INamedTypeSymbol? TaskOfT; internal readonly INamedTypeSymbol MessagePackObjectAttribute; internal readonly INamedTypeSymbol UnionAttribute; internal readonly INamedTypeSymbol SerializationConstructorAttribute; internal readonly INamedTypeSymbol KeyAttribute; internal readonly INamedTypeSymbol IgnoreAttribute; - internal readonly INamedTypeSymbol IgnoreDataMemberAttribute; + internal readonly INamedTypeSymbol? IgnoreDataMemberAttribute; internal readonly INamedTypeSymbol IMessagePackSerializationCallbackReceiver; internal readonly INamedTypeSymbol MessagePackFormatterAttribute; #pragma warning restore SA1401 // Fields should be private @@ -49,35 +50,20 @@ public ReferenceSymbols(Compilation compilation, Action logger) logger("failed to get metadata of System.Threading.Tasks.Task"); } - MessagePackObjectAttribute = compilation.GetTypeByMetadataName("MessagePack.MessagePackObjectAttribute"); - if (MessagePackObjectAttribute == null) - { - throw new InvalidOperationException("failed to get metadata of MessagePack.MessagePackObjectAttribute"); - } + MessagePackObjectAttribute = compilation.GetTypeByMetadataName("MessagePack.MessagePackObjectAttribute") + ?? throw new InvalidOperationException("failed to get metadata of MessagePack.MessagePackObjectAttribute"); - UnionAttribute = compilation.GetTypeByMetadataName("MessagePack.UnionAttribute"); - if (UnionAttribute == null) - { - throw new InvalidOperationException("failed to get metadata of MessagePack.UnionAttribute"); - } + UnionAttribute = compilation.GetTypeByMetadataName("MessagePack.UnionAttribute") + ?? throw new InvalidOperationException("failed to get metadata of MessagePack.UnionAttribute"); - SerializationConstructorAttribute = compilation.GetTypeByMetadataName("MessagePack.SerializationConstructorAttribute"); - if (SerializationConstructorAttribute == null) - { - throw new InvalidOperationException("failed to get metadata of MessagePack.SerializationConstructorAttribute"); - } + SerializationConstructorAttribute = compilation.GetTypeByMetadataName("MessagePack.SerializationConstructorAttribute") + ?? throw new InvalidOperationException("failed to get metadata of MessagePack.SerializationConstructorAttribute"); - KeyAttribute = compilation.GetTypeByMetadataName("MessagePack.KeyAttribute"); - if (KeyAttribute == null) - { - throw new InvalidOperationException("failed to get metadata of MessagePack.KeyAttribute"); - } + KeyAttribute = compilation.GetTypeByMetadataName("MessagePack.KeyAttribute") + ?? throw new InvalidOperationException("failed to get metadata of MessagePack.KeyAttribute"); - IgnoreAttribute = compilation.GetTypeByMetadataName("MessagePack.IgnoreMemberAttribute"); - if (IgnoreAttribute == null) - { - throw new InvalidOperationException("failed to get metadata of MessagePack.IgnoreMemberAttribute"); - } + IgnoreAttribute = compilation.GetTypeByMetadataName("MessagePack.IgnoreMemberAttribute") + ?? throw new InvalidOperationException("failed to get metadata of MessagePack.IgnoreMemberAttribute"); IgnoreDataMemberAttribute = compilation.GetTypeByMetadataName("System.Runtime.Serialization.IgnoreDataMemberAttribute"); if (IgnoreDataMemberAttribute == null) @@ -85,24 +71,16 @@ public ReferenceSymbols(Compilation compilation, Action logger) logger("failed to get metadata of System.Runtime.Serialization.IgnoreDataMemberAttribute"); } - IMessagePackSerializationCallbackReceiver = compilation.GetTypeByMetadataName("MessagePack.IMessagePackSerializationCallbackReceiver"); - if (IMessagePackSerializationCallbackReceiver == null) - { - throw new InvalidOperationException("failed to get metadata of MessagePack.IMessagePackSerializationCallbackReceiver"); - } + IMessagePackSerializationCallbackReceiver = compilation.GetTypeByMetadataName("MessagePack.IMessagePackSerializationCallbackReceiver") + ?? throw new InvalidOperationException("failed to get metadata of MessagePack.IMessagePackSerializationCallbackReceiver"); - MessagePackFormatterAttribute = compilation.GetTypeByMetadataName("MessagePack.MessagePackFormatterAttribute"); - if (MessagePackFormatterAttribute == null) - { - throw new InvalidOperationException("failed to get metadata of MessagePack.MessagePackFormatterAttribute"); - } + MessagePackFormatterAttribute = compilation.GetTypeByMetadataName("MessagePack.MessagePackFormatterAttribute") + ?? throw new InvalidOperationException("failed to get metadata of MessagePack.MessagePackFormatterAttribute"); } } public class TypeCollector { - private const string CodegeneratorOnlyPreprocessorSymbol = "INCLUDE_ONLY_CODE_GENERATION"; - private static readonly SymbolDisplayFormat BinaryWriteFormat = new SymbolDisplayFormat( genericsOptions: SymbolDisplayGenericsOptions.IncludeTypeParameters, miscellaneousOptions: SymbolDisplayMiscellaneousOptions.ExpandNullable, @@ -114,7 +92,7 @@ public class TypeCollector private readonly bool isForceUseMap; private readonly ReferenceSymbols typeReferences; private readonly INamedTypeSymbol[] targetTypes; - private readonly HashSet embeddedTypes = new HashSet(new string[] + private readonly HashSet embeddedTypes = new (new[] { "short", "int", @@ -185,7 +163,7 @@ public class TypeCollector "System.Reactive.Unit", }); - private readonly Dictionary knownGenericTypes = new Dictionary + private readonly Dictionary knownGenericTypes = new () { #pragma warning disable SA1509 // Opening braces should not be preceded by blank line { "System.Collections.Generic.List<>", "global::MessagePack.Formatters.ListFormatter" }, @@ -261,22 +239,19 @@ public class TypeCollector #pragma warning restore SA1509 // Opening braces should not be preceded by blank line }; - private readonly Action logger; - private readonly bool disallowInternal; - private HashSet externalIgnoreTypeNames; + private readonly HashSet externalIgnoreTypeNames; // visitor workspace: - private HashSet alreadyCollected; - private List collectedObjectInfo; - private List collectedEnumInfo; - private List collectedGenericInfo; - private List collectedUnionInfo; + private readonly HashSet alreadyCollected = new (); + private readonly List collectedObjectInfo = new (); + private readonly List collectedEnumInfo = new (); + private readonly List collectedGenericInfo = new (); + private readonly List collectedUnionInfo = new (); - public TypeCollector(Compilation compilation, bool disallowInternal, bool isForceUseMap, string[] ignoreTypeNames, Action logger) + public TypeCollector(Compilation compilation, bool disallowInternal, bool isForceUseMap, string[]? ignoreTypeNames, Action logger) { - this.logger = logger; this.typeReferences = new ReferenceSymbols(compilation, logger); this.disallowInternal = disallowInternal; this.isForceUseMap = isForceUseMap; @@ -307,11 +282,11 @@ public TypeCollector(Compilation compilation, bool disallowInternal, bool isForc private void ResetWorkspace() { - this.alreadyCollected = new HashSet(); - this.collectedObjectInfo = new List(); - this.collectedEnumInfo = new List(); - this.collectedGenericInfo = new List(); - this.collectedUnionInfo = new List(); + this.alreadyCollected.Clear(); + this.collectedObjectInfo.Clear(); + this.collectedEnumInfo.Clear(); + this.collectedGenericInfo.Clear(); + this.collectedUnionInfo.Clear(); } // EntryPoint @@ -339,19 +314,20 @@ private void CollectCore(ITypeSymbol typeSymbol) return; } - if (this.embeddedTypes.Contains(typeSymbol.ToString())) + var typeSymbolString = typeSymbol.ToString() ?? throw new InvalidOperationException(); + if (this.embeddedTypes.Contains(typeSymbolString)) { return; } - if (this.externalIgnoreTypeNames.Contains(typeSymbol.ToString())) + if (this.externalIgnoreTypeNames.Contains(typeSymbolString)) { return; } - if (typeSymbol.TypeKind == TypeKind.Array) + if (typeSymbol is IArrayTypeSymbol arrayTypeSymbol) { - this.CollectArray(typeSymbol as IArrayTypeSymbol); + this.CollectArray(arrayTypeSymbol); return; } @@ -371,9 +347,9 @@ private void CollectCore(ITypeSymbol typeSymbol) return; } - if (typeSymbol.TypeKind == TypeKind.Enum) + if (type.EnumUnderlyingType != null) { - this.CollectEnum(type); + this.CollectEnum(type, type.EnumUnderlyingType); return; } @@ -401,62 +377,57 @@ private void CollectCore(ITypeSymbol typeSymbol) } this.CollectObject(type); - return; } - private void CollectEnum(INamedTypeSymbol type) + private void CollectEnum(INamedTypeSymbol type, ISymbol enumUnderlyingType) { - var info = new EnumSerializationInfo - { - Name = type.ToDisplayString(ShortTypeNameFormat).Replace(".", "_"), - Namespace = type.ContainingNamespace.IsGlobalNamespace ? null : type.ContainingNamespace.ToDisplayString(), - FullName = type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat), - UnderlyingType = type.EnumUnderlyingType.ToDisplayString(BinaryWriteFormat), - }; - + var info = new EnumSerializationInfo(type.ContainingNamespace.IsGlobalNamespace ? null : type.ContainingNamespace.ToDisplayString(), type.ToDisplayString(ShortTypeNameFormat).Replace(".", "_"), type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat), enumUnderlyingType.ToDisplayString(BinaryWriteFormat)); this.collectedEnumInfo.Add(info); } private void CollectUnion(INamedTypeSymbol type) { - System.Collections.Immutable.ImmutableArray[] unionAttrs = type.GetAttributes().Where(x => x.AttributeClass.ApproximatelyEqual(this.typeReferences.UnionAttribute)).Select(x => x.ConstructorArguments).ToArray(); + ImmutableArray[] unionAttrs = type.GetAttributes().Where(x => x.AttributeClass.ApproximatelyEqual(this.typeReferences.UnionAttribute)).Select(x => x.ConstructorArguments).ToArray(); if (unionAttrs.Length == 0) { throw new MessagePackGeneratorResolveFailedException("Serialization Type must mark UnionAttribute." + " type: " + type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)); } // 0, Int 1, SubType - var info = new UnionSerializationInfo + UnionSubTypeInfo UnionSubTypeInfoSelector(ImmutableArray x) { - Name = type.Name, - Namespace = type.ContainingNamespace.IsGlobalNamespace ? null : type.ContainingNamespace.ToDisplayString(), - FullName = type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat), - SubTypes = unionAttrs.Select(x => new UnionSubTypeInfo + if (!(x[0] is { Value: int key }) || !(x[1] is { Value: ITypeSymbol typeSymbol })) { - Key = (int)x[0].Value, - Type = x[1].Value is ITypeSymbol typeSymbol ? typeSymbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat) : throw new NotSupportedException($"AOT code generation only supports UnionAttribute that uses a Type parameter, but the {type.ToDisplayString(SymbolDisplayFormat.CSharpErrorMessageFormat)} type uses an unsupported parameter."), - }).OrderBy(x => x.Key).ToArray(), - }; + throw new NotSupportedException("AOT code generation only supports UnionAttribute that uses a Type parameter, but the " + type.ToDisplayString(SymbolDisplayFormat.CSharpErrorMessageFormat) + " type uses an unsupported parameter."); + } + + var typeName = typeSymbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); + return new UnionSubTypeInfo(key, typeName); + } + + var info = new UnionSerializationInfo(type.ContainingNamespace.IsGlobalNamespace ? null : type.ContainingNamespace.ToDisplayString(), type.Name, type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat), unionAttrs.Select(UnionSubTypeInfoSelector).OrderBy(x => x.Key).ToArray()); this.collectedUnionInfo.Add(info); } private void CollectGenericUnion(INamedTypeSymbol type) { - System.Collections.Immutable.ImmutableArray[] unionAttrs = type.GetAttributes().Where(x => x.AttributeClass.ApproximatelyEqual(this.typeReferences.UnionAttribute)).Select(x => x.ConstructorArguments).ToArray(); - if (unionAttrs.Length == 0) + var unionAttrs = type.GetAttributes().Where(x => x.AttributeClass.ApproximatelyEqual(this.typeReferences.UnionAttribute)).Select(x => x.ConstructorArguments); + using var enumerator = unionAttrs.GetEnumerator(); + if (!enumerator.MoveNext()) { return; } - var subTypes = unionAttrs.Select(x => x[1].Value).OfType().ToArray(); - foreach (var unionType in subTypes) + do { - if (alreadyCollected.Contains(unionType) == false) + var x = enumerator.Current; + if (x[1] is { Value: INamedTypeSymbol unionType } && alreadyCollected.Contains(unionType) == false) { CollectCore(unionType); } } + while (enumerator.MoveNext()); } private void CollectArray(IArrayTypeSymbol array) @@ -464,36 +435,26 @@ private void CollectArray(IArrayTypeSymbol array) ITypeSymbol elemType = array.ElementType; this.CollectCore(elemType); - var info = new GenericSerializationInfo - { - FullName = array.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat), - IsOpenGenericType = elemType is ITypeParameterSymbol, - }; - + var fullName = array.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); + var elementTypeDisplayName = elemType.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); + string formatterName; if (array.IsSZArray) { - info.FormatterName = $"global::MessagePack.Formatters.ArrayFormatter<{elemType.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)}>"; - } - else if (array.Rank == 2) - { - info.FormatterName = $"global::MessagePack.Formatters.TwoDimensionalArrayFormatter<{elemType.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)}>"; - } - else if (array.Rank == 3) - { - info.FormatterName = $"global::MessagePack.Formatters.ThreeDimensionalArrayFormatter<{elemType.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)}>"; - } - else if (array.Rank == 4) - { - info.FormatterName = $"global::MessagePack.Formatters.FourDimensionalArrayFormatter<{elemType.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)}>"; + formatterName = "global::MessagePack.Formatters.ArrayFormatter<" + elementTypeDisplayName + ">"; } else { - throw new InvalidOperationException("does not supports array dimension, " + info.FullName); + formatterName = array.Rank switch + { + 2 => "global::MessagePack.Formatters.TwoDimensionalArrayFormatter<" + elementTypeDisplayName + ">", + 3 => "global::MessagePack.Formatters.ThreeDimensionalArrayFormatter<" + elementTypeDisplayName + ">", + 4 => "global::MessagePack.Formatters.FourDimensionalArrayFormatter<" + elementTypeDisplayName + ">", + _ => throw new InvalidOperationException("does not supports array dimension, " + fullName), + }; } + var info = new GenericSerializationInfo(fullName, formatterName, elemType is ITypeParameterSymbol); this.collectedGenericInfo.Add(info); - - return; } private void CollectGeneric(INamedTypeSymbol type) @@ -512,20 +473,16 @@ private void CollectGeneric(INamedTypeSymbol type) // nullable if (genericTypeString == "T?") { - this.CollectCore(type.TypeArguments[0]); + var firstTypeArgument = type.TypeArguments[0]; + this.CollectCore(firstTypeArgument); - if (!this.embeddedTypes.Contains(type.TypeArguments[0].ToString())) + if (this.embeddedTypes.Contains(firstTypeArgument.ToString()!)) { - var info = new GenericSerializationInfo - { - FormatterName = $"global::MessagePack.Formatters.NullableFormatter<{type.TypeArguments[0].ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)}>", - FullName = type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat), - IsOpenGenericType = isOpenGenericType, - }; - - this.collectedGenericInfo.Add(info); + return; } + var info = new GenericSerializationInfo(type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat), "global::MessagePack.Formatters.NullableFormatter<" + firstTypeArgument.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat) + ">", isOpenGenericType); + this.collectedGenericInfo.Add(info); return; } @@ -540,43 +497,27 @@ private void CollectGeneric(INamedTypeSymbol type) var typeArgs = string.Join(", ", type.TypeArguments.Select(x => x.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat))); var f = formatter.Replace("TREPLACE", typeArgs); - var info = new GenericSerializationInfo - { - FormatterName = f, - FullName = type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat), - IsOpenGenericType = isOpenGenericType, - }; + var info = new GenericSerializationInfo(type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat), f, isOpenGenericType); this.collectedGenericInfo.Add(info); - if (genericTypeString == "System.Linq.ILookup<,>") + if (genericTypeString != "System.Linq.ILookup<,>") { - formatter = this.knownGenericTypes["System.Linq.IGrouping<,>"]; - f = formatter.Replace("TREPLACE", typeArgs); - - var groupingInfo = new GenericSerializationInfo - { - FormatterName = f, - FullName = $"global::System.Linq.IGrouping<{typeArgs}>", - IsOpenGenericType = isOpenGenericType, - }; - - this.collectedGenericInfo.Add(groupingInfo); + return; + } - formatter = this.knownGenericTypes["System.Collections.Generic.IEnumerable<>"]; - typeArgs = type.TypeArguments[1].ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); - f = formatter.Replace("TREPLACE", typeArgs); + formatter = this.knownGenericTypes["System.Linq.IGrouping<,>"]; + f = formatter.Replace("TREPLACE", typeArgs); - var enumerableInfo = new GenericSerializationInfo - { - FormatterName = f, - FullName = $"global::System.Collections.Generic.IEnumerable<{typeArgs}>", - IsOpenGenericType = isOpenGenericType, - }; + var groupingInfo = new GenericSerializationInfo("global::System.Linq.IGrouping<" + typeArgs + ">", f, isOpenGenericType); + this.collectedGenericInfo.Add(groupingInfo); - this.collectedGenericInfo.Add(enumerableInfo); - } + formatter = this.knownGenericTypes["System.Collections.Generic.IEnumerable<>"]; + typeArgs = type.TypeArguments[1].ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); + f = formatter.Replace("TREPLACE", typeArgs); + var enumerableInfo = new GenericSerializationInfo("global::System.Collections.Generic.IEnumerable<" + typeArgs + ">", f, isOpenGenericType); + this.collectedGenericInfo.Add(enumerableInfo); return; } @@ -609,16 +550,23 @@ private void CollectGeneric(INamedTypeSymbol type) formatterBuilder.Append(type.Name); formatterBuilder.Append("Formatter<"); - formatterBuilder.Append(string.Join(", ", type.TypeArguments.Select(x => x.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)))); - formatterBuilder.Append(">"); - - var genericSerializationInfo = new GenericSerializationInfo + var typeArgumentIterator = type.TypeArguments.GetEnumerator(); { - FullName = type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat), - FormatterName = formatterBuilder.ToString(), - IsOpenGenericType = isOpenGenericType, - }; + if (typeArgumentIterator.MoveNext()) + { + formatterBuilder.Append(typeArgumentIterator.Current.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)); + } + + while (typeArgumentIterator.MoveNext()) + { + formatterBuilder.Append(", "); + formatterBuilder.Append(typeArgumentIterator.Current.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)); + } + } + formatterBuilder.Append('>'); + + var genericSerializationInfo = new GenericSerializationInfo(type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat), formatterBuilder.ToString(), isOpenGenericType); this.collectedGenericInfo.Add(genericSerializationInfo); } @@ -633,17 +581,14 @@ private ObjectSerializationInfo GetObjectInfo(INamedTypeSymbol type) var isClass = !type.IsValueType; var isOpenGenericType = type.IsGenericType; - AttributeData contractAttr = type.GetAttributes().FirstOrDefault(x => x.AttributeClass.ApproximatelyEqual(this.typeReferences.MessagePackObjectAttribute)); - if (contractAttr == null) - { - throw new MessagePackGeneratorResolveFailedException("Serialization Object must mark MessagePackObjectAttribute." + " type: " + type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)); - } + AttributeData contractAttr = type.GetAttributes().FirstOrDefault(x => x.AttributeClass.ApproximatelyEqual(this.typeReferences.MessagePackObjectAttribute)) + ?? throw new MessagePackGeneratorResolveFailedException("Serialization Object must mark MessagePackObjectAttribute." + " type: " + type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)); var isIntKey = true; var intMembers = new Dictionary(); var stringMembers = new Dictionary(); - if (this.isForceUseMap || (bool)contractAttr.ConstructorArguments[0].Value) + if (this.isForceUseMap || (contractAttr.ConstructorArguments[0] is { Value: bool firstConstructorArgument } && firstConstructorArgument)) { // All public members are serialize target except [Ignore] member. isIntKey = false; @@ -652,31 +597,20 @@ private ObjectSerializationInfo GetObjectInfo(INamedTypeSymbol type) foreach (IPropertySymbol item in type.GetAllMembers().OfType().Where(x => !x.IsOverride)) { - if (item.GetAttributes().Any(x => x.AttributeClass.ApproximatelyEqual(this.typeReferences.IgnoreAttribute) || x.AttributeClass.Name == this.typeReferences.IgnoreDataMemberAttribute.Name)) + if (item.GetAttributes().Any(x => (x.AttributeClass.ApproximatelyEqual(this.typeReferences.IgnoreAttribute) || x.AttributeClass?.Name == this.typeReferences.IgnoreDataMemberAttribute?.Name))) { continue; } - var customFormatterAttr = item.GetAttributes().FirstOrDefault(x => x.AttributeClass.ApproximatelyEqual(this.typeReferences.MessagePackFormatterAttribute))?.ConstructorArguments[0].Value as INamedTypeSymbol; - - var member = new MemberSerializationInfo - { - IsReadable = (item.GetMethod != null) && item.GetMethod.DeclaredAccessibility == Accessibility.Public && !item.IsStatic, - IsWritable = (item.SetMethod != null) && item.SetMethod.DeclaredAccessibility == Accessibility.Public && !item.IsStatic, - StringKey = item.Name, - IsProperty = true, - IsField = false, - Name = item.Name, - Type = item.Type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat), - ShortTypeName = item.Type.ToDisplayString(BinaryWriteFormat), - CustomFormatterTypeName = customFormatterAttr?.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat), - }; - if (!member.IsReadable && !member.IsWritable) + var isReadable = item.GetMethod != null && item.GetMethod.DeclaredAccessibility == Accessibility.Public && !item.IsStatic; + var isWritable = item.SetMethod != null && item.SetMethod.DeclaredAccessibility == Accessibility.Public && !item.IsStatic; + if (!isReadable && !isWritable) { continue; } - member.IntKey = hiddenIntKey++; + var customFormatterAttr = item.GetAttributes().FirstOrDefault(x => x.AttributeClass.ApproximatelyEqual(this.typeReferences.MessagePackFormatterAttribute))?.ConstructorArguments[0].Value as INamedTypeSymbol; + var member = new MemberSerializationInfo(true, isWritable, isReadable, hiddenIntKey++, item.Name, item.Name, item.Type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat), item.Type.ToDisplayString(BinaryWriteFormat), customFormatterAttr?.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)); stringMembers.Add(member.StringKey, member); this.CollectCore(item.Type); // recursive collect @@ -684,7 +618,7 @@ private ObjectSerializationInfo GetObjectInfo(INamedTypeSymbol type) foreach (IFieldSymbol item in type.GetAllMembers().OfType()) { - if (item.GetAttributes().Any(x => x.AttributeClass.ApproximatelyEqual(this.typeReferences.IgnoreAttribute) || x.AttributeClass.Name == this.typeReferences.IgnoreDataMemberAttribute.Name)) + if (item.GetAttributes().Any(x => (x.AttributeClass.ApproximatelyEqual(this.typeReferences.IgnoreAttribute) || x.AttributeClass?.Name == this.typeReferences.IgnoreDataMemberAttribute?.Name))) { continue; } @@ -694,26 +628,15 @@ private ObjectSerializationInfo GetObjectInfo(INamedTypeSymbol type) continue; } - var customFormatterAttr = item.GetAttributes().FirstOrDefault(x => x.AttributeClass.ApproximatelyEqual(this.typeReferences.MessagePackFormatterAttribute))?.ConstructorArguments[0].Value as INamedTypeSymbol; - - var member = new MemberSerializationInfo - { - IsReadable = item.DeclaredAccessibility == Accessibility.Public && !item.IsStatic, - IsWritable = item.DeclaredAccessibility == Accessibility.Public && !item.IsReadOnly && !item.IsStatic, - StringKey = item.Name, - IsProperty = false, - IsField = true, - Name = item.Name, - Type = item.Type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat), - ShortTypeName = item.Type.ToDisplayString(BinaryWriteFormat), - CustomFormatterTypeName = customFormatterAttr?.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat), - }; - if (!member.IsReadable && !member.IsWritable) + var isReadable = item.DeclaredAccessibility == Accessibility.Public && !item.IsStatic; + var isWritable = item.DeclaredAccessibility == Accessibility.Public && !item.IsReadOnly && !item.IsStatic; + if (!isReadable && !isWritable) { continue; } - member.IntKey = hiddenIntKey++; + var customFormatterAttr = item.GetAttributes().FirstOrDefault(x => x.AttributeClass.ApproximatelyEqual(this.typeReferences.MessagePackFormatterAttribute))?.ConstructorArguments[0].Value as INamedTypeSymbol; + var member = new MemberSerializationInfo(false, isWritable, isReadable, hiddenIntKey++, item.Name, item.Name, item.Type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat), item.Type.ToDisplayString(BinaryWriteFormat), customFormatterAttr?.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)); stringMembers.Add(member.StringKey, member); this.CollectCore(item.Type); // recursive collect } @@ -731,37 +654,28 @@ private ObjectSerializationInfo GetObjectInfo(INamedTypeSymbol type) continue; // .tt files don't generate good code for this yet: https://github.com/neuecc/MessagePack-CSharp/issues/390 } - if (item.GetAttributes().Any(x => x.AttributeClass.ApproximatelyEqual(this.typeReferences.IgnoreAttribute) || x.AttributeClass.ApproximatelyEqual(this.typeReferences.IgnoreDataMemberAttribute))) + if (item.GetAttributes().Any(x => + { + var typeReferencesIgnoreDataMemberAttribute = this.typeReferences.IgnoreDataMemberAttribute; + return typeReferencesIgnoreDataMemberAttribute != null && (x.AttributeClass.ApproximatelyEqual(this.typeReferences.IgnoreAttribute) || x.AttributeClass.ApproximatelyEqual(typeReferencesIgnoreDataMemberAttribute)); + })) { continue; } - var customFormatterAttr = item.GetAttributes().FirstOrDefault(x => x.AttributeClass.ApproximatelyEqual(this.typeReferences.MessagePackFormatterAttribute))?.ConstructorArguments[0].Value as INamedTypeSymbol; - - var member = new MemberSerializationInfo - { - IsReadable = (item.GetMethod != null) && item.GetMethod.DeclaredAccessibility == Accessibility.Public && !item.IsStatic, - IsWritable = (item.SetMethod != null) && item.SetMethod.DeclaredAccessibility == Accessibility.Public && !item.IsStatic, - IsProperty = true, - IsField = false, - Name = item.Name, - Type = item.Type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat), - ShortTypeName = item.Type.ToDisplayString(BinaryWriteFormat), - CustomFormatterTypeName = customFormatterAttr?.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat), - }; - if (!member.IsReadable && !member.IsWritable) + var isReadable = item.GetMethod != null && item.GetMethod.DeclaredAccessibility == Accessibility.Public && !item.IsStatic; + var isWritable = item.SetMethod != null && item.SetMethod.DeclaredAccessibility == Accessibility.Public && !item.IsStatic; + if (!isReadable && !isWritable) { continue; } - TypedConstant? key = item.GetAttributes().FirstOrDefault(x => x.AttributeClass.ApproximatelyEqual(this.typeReferences.KeyAttribute))?.ConstructorArguments[0]; - if (key == null) - { - throw new MessagePackGeneratorResolveFailedException("all public members must mark KeyAttribute or IgnoreMemberAttribute." + " type: " + type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat) + " member:" + item.Name); - } + var customFormatterAttr = item.GetAttributes().FirstOrDefault(x => x.AttributeClass.ApproximatelyEqual(this.typeReferences.MessagePackFormatterAttribute))?.ConstructorArguments[0].Value as INamedTypeSymbol; + var key = item.GetAttributes().FirstOrDefault(x => x.AttributeClass.ApproximatelyEqual(this.typeReferences.KeyAttribute))?.ConstructorArguments[0] + ?? throw new MessagePackGeneratorResolveFailedException("all public members must mark KeyAttribute or IgnoreMemberAttribute." + " type: " + type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat) + " member:" + item.Name); - var intKey = (key.Value.Value is int) ? (int)key.Value.Value : (int?)null; - var stringKey = (key.Value.Value is string) ? (string)key.Value.Value : (string)null; + var intKey = key is { Value: int intKeyValue } ? intKeyValue : default(int?); + var stringKey = key is { Value: string stringKeyValue } ? stringKeyValue : default; if (intKey == null && stringKey == null) { throw new MessagePackGeneratorResolveFailedException("both IntKey and StringKey are null." + " type: " + type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat) + " member:" + item.Name); @@ -782,23 +696,22 @@ private ObjectSerializationInfo GetObjectInfo(INamedTypeSymbol type) if (isIntKey) { - member.IntKey = (int)intKey; - if (intMembers.ContainsKey(member.IntKey)) + if (intMembers.ContainsKey(intKey!.Value)) { throw new MessagePackGeneratorResolveFailedException("key is duplicated, all members key must be unique." + " type: " + type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat) + " member:" + item.Name); } + var member = new MemberSerializationInfo(true, isWritable, isReadable, intKey!.Value, item.Name, item.Name, item.Type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat), item.Type.ToDisplayString(BinaryWriteFormat), customFormatterAttr?.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)); intMembers.Add(member.IntKey, member); } else { - member.StringKey = (string)stringKey; - if (stringMembers.ContainsKey(member.StringKey)) + if (stringMembers.ContainsKey(stringKey!)) { throw new MessagePackGeneratorResolveFailedException("key is duplicated, all members key must be unique." + " type: " + type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat) + " member:" + item.Name); } - member.IntKey = hiddenIntKey++; + var member = new MemberSerializationInfo(true, isWritable, isReadable, hiddenIntKey++, stringKey!, item.Name, item.Type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat), item.Type.ToDisplayString(BinaryWriteFormat), customFormatterAttr?.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)); stringMembers.Add(member.StringKey, member); } @@ -817,32 +730,19 @@ private ObjectSerializationInfo GetObjectInfo(INamedTypeSymbol type) continue; } - var customFormatterAttr = item.GetAttributes().FirstOrDefault(x => x.AttributeClass.ApproximatelyEqual(this.typeReferences.MessagePackFormatterAttribute))?.ConstructorArguments[0].Value as INamedTypeSymbol; - - var member = new MemberSerializationInfo - { - IsReadable = item.DeclaredAccessibility == Accessibility.Public && !item.IsStatic, - IsWritable = item.DeclaredAccessibility == Accessibility.Public && !item.IsReadOnly && !item.IsStatic, - IsProperty = true, - IsField = false, - Name = item.Name, - Type = item.Type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat), - ShortTypeName = item.Type.ToDisplayString(BinaryWriteFormat), - CustomFormatterTypeName = customFormatterAttr?.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat), - }; - if (!member.IsReadable && !member.IsWritable) + var isReadable = item.DeclaredAccessibility == Accessibility.Public && !item.IsStatic; + var isWritable = item.DeclaredAccessibility == Accessibility.Public && !item.IsReadOnly && !item.IsStatic; + if (!isReadable && !isWritable) { continue; } - TypedConstant? key = item.GetAttributes().FirstOrDefault(x => x.AttributeClass.ApproximatelyEqual(this.typeReferences.KeyAttribute))?.ConstructorArguments[0]; - if (key == null) - { - throw new MessagePackGeneratorResolveFailedException("all public members must mark KeyAttribute or IgnoreMemberAttribute." + " type: " + type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat) + " member:" + item.Name); - } + var customFormatterAttr = item.GetAttributes().FirstOrDefault(x => x.AttributeClass.ApproximatelyEqual(this.typeReferences.MessagePackFormatterAttribute))?.ConstructorArguments[0].Value as INamedTypeSymbol; + var key = item.GetAttributes().FirstOrDefault(x => x.AttributeClass.ApproximatelyEqual(this.typeReferences.KeyAttribute))?.ConstructorArguments[0] + ?? throw new MessagePackGeneratorResolveFailedException("all public members must mark KeyAttribute or IgnoreMemberAttribute." + " type: " + type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat) + " member:" + item.Name); - var intKey = (key.Value.Value is int) ? (int)key.Value.Value : (int?)null; - var stringKey = (key.Value.Value is string) ? (string)key.Value.Value : (string)null; + var intKey = key is { Value: int intKeyValue } ? intKeyValue : default(int?); + var stringKey = key is { Value: string stringKeyValue } ? stringKeyValue : default; if (intKey == null && stringKey == null) { throw new MessagePackGeneratorResolveFailedException("both IntKey and StringKey are null." + " type: " + type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat) + " member:" + item.Name); @@ -863,23 +763,22 @@ private ObjectSerializationInfo GetObjectInfo(INamedTypeSymbol type) if (isIntKey) { - member.IntKey = (int)intKey; - if (intMembers.ContainsKey(member.IntKey)) + if (intMembers.ContainsKey(intKey!.Value)) { throw new MessagePackGeneratorResolveFailedException("key is duplicated, all members key must be unique." + " type: " + type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat) + " member:" + item.Name); } + var member = new MemberSerializationInfo(true, isWritable, isReadable, intKey!.Value, item.Name, item.Name, item.Type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat), item.Type.ToDisplayString(BinaryWriteFormat), customFormatterAttr?.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)); intMembers.Add(member.IntKey, member); } else { - member.StringKey = (string)stringKey; - if (stringMembers.ContainsKey(member.StringKey)) + if (stringMembers.ContainsKey(stringKey!)) { throw new MessagePackGeneratorResolveFailedException("key is duplicated, all members key must be unique." + " type: " + type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat) + " member:" + item.Name); } - member.IntKey = hiddenIntKey++; + var member = new MemberSerializationInfo(true, isWritable, isReadable, hiddenIntKey++, stringKey!, item.Name, item.Type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat), item.Type.ToDisplayString(BinaryWriteFormat), customFormatterAttr?.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)); stringMembers.Add(member.StringKey, member); } @@ -888,8 +787,8 @@ private ObjectSerializationInfo GetObjectInfo(INamedTypeSymbol type) } // GetConstructor - IEnumerator ctorEnumerator = null; - IMethodSymbol ctor = type.Constructors.Where(x => x.DeclaredAccessibility == Accessibility.Public).SingleOrDefault(x => x.GetAttributes().Any(y => y.AttributeClass.ApproximatelyEqual(this.typeReferences.SerializationConstructorAttribute))); + var ctorEnumerator = default(IEnumerator); + var ctor = type.Constructors.Where(x => x.DeclaredAccessibility == Accessibility.Public).SingleOrDefault(x => x.GetAttributes().Any(y => y.AttributeClass != null && y.AttributeClass.ApproximatelyEqual(this.typeReferences.SerializationConstructorAttribute))); if (ctor == null) { ctorEnumerator = type.Constructors.Where(x => x.DeclaredAccessibility == Accessibility.Public).OrderByDescending(x => x.Parameters.Length).GetEnumerator(); @@ -909,17 +808,17 @@ private ObjectSerializationInfo GetObjectInfo(INamedTypeSymbol type) var constructorParameters = new List(); if (ctor != null) { - ILookup> constructorLookupDictionary = stringMembers.ToLookup(x => x.Key, x => x, StringComparer.OrdinalIgnoreCase); + var constructorLookupDictionary = stringMembers.ToLookup(x => x.Key, x => x, StringComparer.OrdinalIgnoreCase); do { constructorParameters.Clear(); var ctorParamIndex = 0; - foreach (IParameterSymbol item in ctor.Parameters) + foreach (IParameterSymbol item in ctor!.Parameters) { MemberSerializationInfo paramMember; if (isIntKey) { - if (intMembers.TryGetValue(ctorParamIndex, out paramMember)) + if (intMembers.TryGetValue(ctorParamIndex, out paramMember!)) { if (item.Type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat) == paramMember.Type && paramMember.IsReadable) { @@ -954,51 +853,46 @@ private ObjectSerializationInfo GetObjectInfo(INamedTypeSymbol type) else { IEnumerable> hasKey = constructorLookupDictionary[item.Name]; - var len = hasKey.Count(); - if (len != 0) + using var enumerator = hasKey.GetEnumerator(); + // hasKey.Count() == 0 + if (!enumerator.MoveNext()) { - if (len != 1) + if (ctorEnumerator == null) { - if (ctorEnumerator != null) - { - ctor = null; - continue; - } - else - { - throw new MessagePackGeneratorResolveFailedException("duplicate matched constructor parameter name:" + type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat) + " parameterName:" + item.Name + " paramterType:" + item.Type.Name); - } + throw new MessagePackGeneratorResolveFailedException("can't find matched constructor parameter, index not found. type:" + type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat) + " parameterName:" + item.Name); } - paramMember = hasKey.First().Value; - if (item.Type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat) == paramMember.Type && paramMember.IsReadable) - { - constructorParameters.Add(paramMember); - } - else + ctor = null; + continue; + } + + var first = enumerator.Current.Value; + // hasKey.Count() != 1 + if (enumerator.MoveNext()) + { + if (ctorEnumerator == null) { - if (ctorEnumerator != null) - { - ctor = null; - continue; - } - else - { - throw new MessagePackGeneratorResolveFailedException("can't find matched constructor parameter, parameterType mismatch. type:" + type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat) + " parameterName:" + item.Name + " paramterType:" + item.Type.Name); - } + throw new MessagePackGeneratorResolveFailedException("duplicate matched constructor parameter name:" + type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat) + " parameterName:" + item.Name + " paramterType:" + item.Type.Name); } + + ctor = null; + continue; + } + + paramMember = first; + if (item.Type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat) == paramMember.Type && paramMember.IsReadable) + { + constructorParameters.Add(paramMember); } else { - if (ctorEnumerator != null) - { - ctor = null; - continue; - } - else + if (ctorEnumerator == null) { - throw new MessagePackGeneratorResolveFailedException("can't find matched constructor parameter, index not found. type:" + type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat) + " parameterName:" + item.Name); + throw new MessagePackGeneratorResolveFailedException("can't find matched constructor parameter, parameterType mismatch. type:" + type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat) + " parameterName:" + item.Name + " paramterType:" + item.Type.Name); } + + ctor = null; + continue; } } @@ -1022,23 +916,7 @@ private ObjectSerializationInfo GetObjectInfo(INamedTypeSymbol type) needsCastOnAfter = !type.GetMembers("OnAfterDeserialize").Any(); } - var info = new ObjectSerializationInfo - { - IsClass = isClass, - IsOpenGenericType = isOpenGenericType, - GenericTypeParameters = isOpenGenericType - ? type.TypeParameters.Select(ToGenericTypeParameterInfo).ToArray() - : Array.Empty(), - ConstructorParameters = constructorParameters.ToArray(), - IsIntKey = isIntKey, - Members = isIntKey ? intMembers.Values.ToArray() : stringMembers.Values.ToArray(), - Name = isOpenGenericType ? GetGenericFormatterClassName(type) : GetMinimallyQualifiedClassName(type), - FullName = type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat), - Namespace = type.ContainingNamespace.IsGlobalNamespace ? null : type.ContainingNamespace.ToDisplayString(), - HasIMessagePackSerializationCallbackReceiver = hasSerializationConstructor, - NeedsCastOnAfter = needsCastOnAfter, - NeedsCastOnBefore = needsCastOnBefore, - }; + var info = new ObjectSerializationInfo(isClass, isOpenGenericType, isOpenGenericType ? type.TypeParameters.Select(ToGenericTypeParameterInfo).ToArray() : Array.Empty(), constructorParameters.ToArray(), isIntKey, isIntKey ? intMembers.Values.ToArray() : stringMembers.Values.ToArray(), isOpenGenericType ? GetGenericFormatterClassName(type) : GetMinimallyQualifiedClassName(type), type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat), type.ContainingNamespace.IsGlobalNamespace ? null : type.ContainingNamespace.ToDisplayString(), hasSerializationConstructor, needsCastOnAfter, needsCastOnBefore); return info; } @@ -1055,26 +933,12 @@ private static GenericTypeParameterInfo ToGenericTypeParameterInfo(ITypeParamete if (typeParameter.HasReferenceTypeConstraint) { - if (typeParameter.ReferenceTypeConstraintNullableAnnotation == NullableAnnotation.Annotated) - { - constraints.Add("class?"); - } - else - { - constraints.Add("class"); - } + constraints.Add(typeParameter.ReferenceTypeConstraintNullableAnnotation == NullableAnnotation.Annotated ? "class?" : "class"); } if (typeParameter.HasValueTypeConstraint) { - if (typeParameter.HasUnmanagedTypeConstraint) - { - constraints.Add("unmanaged"); - } - else - { - constraints.Add("struct"); - } + constraints.Add(typeParameter.HasUnmanagedTypeConstraint ? "unmanaged" : "struct"); } // constraint types (IDisposable, IEnumerable ...) @@ -1109,7 +973,7 @@ private static string GetMinimallyQualifiedClassName(INamedTypeSymbol type) return name; } - private static bool TryGetNextConstructor(IEnumerator ctorEnumerator, ref IMethodSymbol ctor) + private static bool TryGetNextConstructor(IEnumerator? ctorEnumerator, ref IMethodSymbol? ctor) { if (ctorEnumerator == null || ctor != null) { diff --git a/src/MessagePack.GeneratorCore/CodeGenerator.cs b/src/MessagePack.GeneratorCore/CodeGenerator.cs index a48528a04..f96546937 100644 --- a/src/MessagePack.GeneratorCore/CodeGenerator.cs +++ b/src/MessagePack.GeneratorCore/CodeGenerator.cs @@ -18,13 +18,11 @@ public class CodeGenerator { private static readonly Encoding NoBomUtf8 = new UTF8Encoding(false); - private Action logger; - private CancellationToken cancellationToken; + private readonly Action logger; public CodeGenerator(Action logger, CancellationToken cancellationToken) { this.logger = logger; - this.cancellationToken = cancellationToken; } /// @@ -42,21 +40,21 @@ public CodeGenerator(Action logger, CancellationToken cancellationToken) Compilation compilation, string output, string resolverName, - string @namespace, + string? @namespace, bool useMapMode, - string multipleIfDirectiveOutputSymbols, - string[] externalIgnoreTypeNames) + string? multipleIfDirectiveOutputSymbols, + string[]? externalIgnoreTypeNames) { var namespaceDot = string.IsNullOrWhiteSpace(@namespace) ? string.Empty : @namespace + "."; var multipleOutputSymbols = multipleIfDirectiveOutputSymbols?.Split(',') ?? Array.Empty(); var sw = Stopwatch.StartNew(); - foreach (var multioutSymbol in multipleOutputSymbols.Length == 0 ? new[] { string.Empty } : multipleOutputSymbols) + foreach (var multiOutputSymbol in multipleOutputSymbols.Length == 0 ? new[] { string.Empty } : multipleOutputSymbols) { logger("Project Compilation Start:" + compilation.AssemblyName); - var collector = new TypeCollector(compilation, true, useMapMode, externalIgnoreTypeNames, x => Console.WriteLine(x)); + var collector = new TypeCollector(compilation, true, useMapMode, externalIgnoreTypeNames, Console.WriteLine); logger("Project Compilation Complete:" + sw.Elapsed.ToString()); @@ -74,21 +72,21 @@ public CodeGenerator(Action logger, CancellationToken cancellationToken) { // SingleFile Output var fullGeneratedProgramText = GenerateSingleFileSync(resolverName, namespaceDot, objectInfo, enumInfo, unionInfo, genericInfo); - if (multioutSymbol == string.Empty) + if (multiOutputSymbol == string.Empty) { - await OutputAsync(output, fullGeneratedProgramText, cancellationToken); + await OutputAsync(output, fullGeneratedProgramText); } else { - var fname = Path.GetFileNameWithoutExtension(output) + "." + MultiSymbolToSafeFilePath(multioutSymbol) + ".cs"; - var text = $"#if {multioutSymbol}" + Environment.NewLine + fullGeneratedProgramText + Environment.NewLine + "#endif"; - await OutputAsync(Path.Combine(Path.GetDirectoryName(output), fname), text, cancellationToken); + var fname = Path.GetFileNameWithoutExtension(output) + "." + MultiSymbolToSafeFilePath(multiOutputSymbol) + ".cs"; + var text = $"#if {multiOutputSymbol}" + Environment.NewLine + fullGeneratedProgramText + Environment.NewLine + "#endif"; + await OutputAsync(Path.Combine(Path.GetDirectoryName(output) ?? string.Empty, fname), text); } } else { // Multiple File output - await GenerateMultipleFileAsync(output, resolverName, objectInfo, enumInfo, unionInfo, namespaceDot, multioutSymbol, genericInfo); + await GenerateMultipleFileAsync(output, resolverName, objectInfo, enumInfo, unionInfo, namespaceDot, multiOutputSymbol, genericInfo); } if (objectInfo.Length == 0 && enumInfo.Length == 0 && genericInfo.Length == 0 && unionInfo.Length == 0) @@ -117,40 +115,33 @@ public static string GenerateSingleFileSync(string resolverName, string namespac { var (nameSpace, isStringKey) = x.Key; var objectSerializationInfos = x.ToArray(); - var template = isStringKey ? new StringKeyFormatterTemplate() : (IFormatterTemplate)new FormatterTemplate(); - - template.Namespace = namespaceDot + "Formatters" + (nameSpace is null ? string.Empty : "." + nameSpace); - template.ObjectSerializationInfos = objectSerializationInfos; - + var ns = namespaceDot + "Formatters" + (nameSpace is null ? string.Empty : "." + nameSpace); + var template = isStringKey ? new StringKeyFormatterTemplate(ns, objectSerializationInfos) : (IFormatterTemplate)new FormatterTemplate(ns, objectSerializationInfos); return template; }) .ToArray(); + string GetNamespace(IGrouping x) + { + if (x.Key == null) + { + return namespaceDot + "Formatters"; + } + + return namespaceDot + "Formatters." + x.Key; + } + var enumFormatterTemplates = enumInfo .GroupBy(x => x.Namespace) - .Select(x => new EnumTemplate() - { - Namespace = namespaceDot + "Formatters" + ((x.Key == null) ? string.Empty : "." + x.Key), - EnumSerializationInfos = x.ToArray(), - }) + .Select(x => new EnumTemplate(GetNamespace(x), x.ToArray())) .ToArray(); var unionFormatterTemplates = unionInfo .GroupBy(x => x.Namespace) - .Select(x => new UnionTemplate() - { - Namespace = namespaceDot + "Formatters" + ((x.Key == null) ? string.Empty : "." + x.Key), - UnionSerializationInfos = x.ToArray(), - }) + .Select(x => new UnionTemplate(GetNamespace(x), x.ToArray())) .ToArray(); - var resolverTemplate = new ResolverTemplate() - { - Namespace = namespaceDot + "Resolvers", - FormatterNamespace = namespaceDot + "Formatters", - ResolverName = resolverName, - RegisterInfos = genericInfo.Where(x => !x.IsOpenGenericType).Cast().Concat(enumInfo).Concat(unionInfo).Concat(objectInfo.Where(x => !x.IsOpenGenericType)).ToArray(), - }; + var resolverTemplate = new ResolverTemplate(namespaceDot + "Resolvers", namespaceDot + "Formatters", resolverName, genericInfo.Where(x => !x.IsOpenGenericType).Cast().Concat(enumInfo).Concat(unionInfo).Concat(objectInfo.Where(x => !x.IsOpenGenericType)).ToArray()); var sb = new StringBuilder(); sb.AppendLine(resolverTemplate.TransformText()); @@ -180,68 +171,59 @@ public static string GenerateSingleFileSync(string resolverName, string namespac private Task GenerateMultipleFileAsync(string output, string resolverName, ObjectSerializationInfo[] objectInfo, EnumSerializationInfo[] enumInfo, UnionSerializationInfo[] unionInfo, string namespaceDot, string multioutSymbol, GenericSerializationInfo[] genericInfo) { + string GetNamespace(INamespaceInfo x) + { + if (x.Namespace == null) + { + return namespaceDot + "Formatters"; + } + + return namespaceDot + "Formatters." + x.Namespace; + } + var waitingTasks = new Task[objectInfo.Length + enumInfo.Length + unionInfo.Length + 1]; var waitingIndex = 0; foreach (var x in objectInfo) { - var template = x.IsStringKey ? new StringKeyFormatterTemplate() : (IFormatterTemplate)new FormatterTemplate(); - template.Namespace = namespaceDot + "Formatters" + (x.Namespace is null ? string.Empty : "." + x.Namespace); - template.ObjectSerializationInfos = new[] { x }; - + var ns = namespaceDot + "Formatters" + (x.Namespace is null ? string.Empty : "." + x.Namespace); + var template = x.IsStringKey ? new StringKeyFormatterTemplate(ns, new[] { x }) : (IFormatterTemplate)new FormatterTemplate(ns, new[] { x }); var text = template.TransformText(); - waitingTasks[waitingIndex++] = OutputToDirAsync(output, template.Namespace, x.Name + "Formatter", multioutSymbol, text, cancellationToken); + waitingTasks[waitingIndex++] = OutputToDirAsync(output, template.Namespace, x.Name + "Formatter", multioutSymbol, text); } foreach (var x in enumInfo) { - var template = new EnumTemplate() - { - Namespace = namespaceDot + "Formatters" + ((x.Namespace == null) ? string.Empty : "." + x.Namespace), - EnumSerializationInfos = new[] { x }, - }; - + var template = new EnumTemplate(GetNamespace(x), new[] { x }); var text = template.TransformText(); - waitingTasks[waitingIndex++] = OutputToDirAsync(output, template.Namespace, x.Name + "Formatter", multioutSymbol, text, cancellationToken); + waitingTasks[waitingIndex++] = OutputToDirAsync(output, template.Namespace, x.Name + "Formatter", multioutSymbol, text); } foreach (var x in unionInfo) { - var template = new UnionTemplate() - { - Namespace = namespaceDot + "Formatters" + ((x.Namespace == null) ? string.Empty : "." + x.Namespace), - UnionSerializationInfos = new[] { x }, - }; - + var template = new UnionTemplate(GetNamespace(x), new[] { x }); var text = template.TransformText(); - waitingTasks[waitingIndex++] = OutputToDirAsync(output, template.Namespace, x.Name + "Formatter", multioutSymbol, text, cancellationToken); + waitingTasks[waitingIndex++] = OutputToDirAsync(output, template.Namespace, x.Name + "Formatter", multioutSymbol, text); } - var resolverTemplate = new ResolverTemplate() - { - Namespace = namespaceDot + "Resolvers", - FormatterNamespace = namespaceDot + "Formatters", - ResolverName = resolverName, - RegisterInfos = genericInfo.Where(x => !x.IsOpenGenericType).Cast().Concat(enumInfo).Concat(unionInfo).Concat(objectInfo.Where(x => !x.IsOpenGenericType)).ToArray(), - }; - - waitingTasks[waitingIndex] = OutputToDirAsync(output, resolverTemplate.Namespace, resolverTemplate.ResolverName, multioutSymbol, resolverTemplate.TransformText(), cancellationToken); + var resolverTemplate = new ResolverTemplate(namespaceDot + "Resolvers", namespaceDot + "Formatters", resolverName, genericInfo.Where(x => !x.IsOpenGenericType).Cast().Concat(enumInfo).Concat(unionInfo).Concat(objectInfo.Where(x => !x.IsOpenGenericType)).ToArray()); + waitingTasks[waitingIndex] = OutputToDirAsync(output, resolverTemplate.Namespace, resolverTemplate.ResolverName, multioutSymbol, resolverTemplate.TransformText()); return Task.WhenAll(waitingTasks); } - private Task OutputToDirAsync(string dir, string ns, string name, string multipleOutSymbol, string text, CancellationToken cancellationToken) + private Task OutputToDirAsync(string dir, string ns, string name, string multipleOutSymbol, string text) { if (multipleOutSymbol == string.Empty) { - return OutputAsync(Path.Combine(dir, $"{ns}_{name}".Replace(".", "_").Replace("global::", string.Empty) + ".cs"), text, cancellationToken); + return OutputAsync(Path.Combine(dir, $"{ns}_{name}".Replace(".", "_").Replace("global::", string.Empty) + ".cs"), text); } else { text = $"#if {multipleOutSymbol}" + Environment.NewLine + text + Environment.NewLine + "#endif"; - return OutputAsync(Path.Combine(dir, MultiSymbolToSafeFilePath(multipleOutSymbol), $"{ns}_{name}".Replace(".", "_").Replace("global::", string.Empty) + ".cs"), text, cancellationToken); + return OutputAsync(Path.Combine(dir, MultiSymbolToSafeFilePath(multipleOutSymbol), $"{ns}_{name}".Replace(".", "_").Replace("global::", string.Empty) + ".cs"), text); } } - private Task OutputAsync(string path, string text, CancellationToken cancellationToken) + private Task OutputAsync(string path, string text) { path = path.Replace("global::", string.Empty); @@ -249,12 +231,12 @@ private Task OutputAsync(string path, string text, CancellationToken cancellatio logger(prefix + path); var fi = new FileInfo(path); - if (!fi.Directory.Exists) + if (fi.Directory != null && !fi.Directory.Exists) { fi.Directory.Create(); } - System.IO.File.WriteAllText(path, NormalizeNewLines(text), NoBomUtf8); + File.WriteAllText(path, NormalizeNewLines(text), NoBomUtf8); return Task.CompletedTask; } diff --git a/src/MessagePack.GeneratorCore/Generator/IFormatterTemplate.cs b/src/MessagePack.GeneratorCore/Generator/IFormatterTemplate.cs index 188f79954..ea3363a6e 100644 --- a/src/MessagePack.GeneratorCore/Generator/IFormatterTemplate.cs +++ b/src/MessagePack.GeneratorCore/Generator/IFormatterTemplate.cs @@ -7,9 +7,9 @@ namespace MessagePackCompiler.Generator { public interface IFormatterTemplate { - string Namespace { get; set; } + string Namespace { get; } - ObjectSerializationInfo[] ObjectSerializationInfos { get; set; } + ObjectSerializationInfo[] ObjectSerializationInfos { get; } string TransformText(); } diff --git a/src/MessagePack.GeneratorCore/Generator/TemplatePartials.cs b/src/MessagePack.GeneratorCore/Generator/TemplatePartials.cs index 6ab80bf5a..bca5c5876 100644 --- a/src/MessagePack.GeneratorCore/Generator/TemplatePartials.cs +++ b/src/MessagePack.GeneratorCore/Generator/TemplatePartials.cs @@ -7,40 +7,72 @@ namespace MessagePackCompiler.Generator { public partial class FormatterTemplate : IFormatterTemplate { - public string Namespace { get; set; } + public FormatterTemplate(string @namespace, ObjectSerializationInfo[] objectSerializationInfos) + { + Namespace = @namespace; + ObjectSerializationInfos = objectSerializationInfos; + } - public ObjectSerializationInfo[] ObjectSerializationInfos { get; set; } + public string Namespace { get; } + + public ObjectSerializationInfo[] ObjectSerializationInfos { get; } } public partial class StringKeyFormatterTemplate : IFormatterTemplate { - public string Namespace { get; set; } + public StringKeyFormatterTemplate(string @namespace, ObjectSerializationInfo[] objectSerializationInfos) + { + Namespace = @namespace; + ObjectSerializationInfos = objectSerializationInfos; + } + + public string Namespace { get; } - public ObjectSerializationInfo[] ObjectSerializationInfos { get; set; } + public ObjectSerializationInfo[] ObjectSerializationInfos { get; } } public partial class ResolverTemplate { - public string Namespace { get; set; } + public ResolverTemplate(string @namespace, string formatterNamespace, string resolverName, IResolverRegisterInfo[] registerInfos) + { + Namespace = @namespace; + FormatterNamespace = formatterNamespace; + ResolverName = resolverName; + RegisterInfos = registerInfos; + } + + public string Namespace { get; } - public string FormatterNamespace { get; set; } + public string FormatterNamespace { get; } - public string ResolverName { get; set; } = "GeneratedResolver"; + public string ResolverName { get; } - public IResolverRegisterInfo[] RegisterInfos { get; set; } + public IResolverRegisterInfo[] RegisterInfos { get; } } public partial class EnumTemplate { - public string Namespace { get; set; } + public EnumTemplate(string @namespace, EnumSerializationInfo[] enumSerializationInfos) + { + Namespace = @namespace; + EnumSerializationInfos = enumSerializationInfos; + } - public EnumSerializationInfo[] EnumSerializationInfos { get; set; } + public string Namespace { get; } + + public EnumSerializationInfo[] EnumSerializationInfos { get; } } public partial class UnionTemplate { - public string Namespace { get; set; } + public UnionTemplate(string @namespace, UnionSerializationInfo[] unionSerializationInfos) + { + Namespace = @namespace; + UnionSerializationInfos = unionSerializationInfos; + } + + public string Namespace { get; } - public UnionSerializationInfo[] UnionSerializationInfos { get; set; } + public UnionSerializationInfo[] UnionSerializationInfos { get; } } } diff --git a/src/MessagePack.GeneratorCore/MessagePack.GeneratorCore.csproj b/src/MessagePack.GeneratorCore/MessagePack.GeneratorCore.csproj index 4be9ff608..d2d86f281 100644 --- a/src/MessagePack.GeneratorCore/MessagePack.GeneratorCore.csproj +++ b/src/MessagePack.GeneratorCore/MessagePack.GeneratorCore.csproj @@ -6,7 +6,8 @@ True ..\..\opensource.snk - 8 + 9 + enable diff --git a/src/MessagePack.GeneratorCore/Utils/RoslynExtensions.cs b/src/MessagePack.GeneratorCore/Utils/RoslynExtensions.cs index a02a282c0..b8645e11b 100644 --- a/src/MessagePack.GeneratorCore/Utils/RoslynExtensions.cs +++ b/src/MessagePack.GeneratorCore/Utils/RoslynExtensions.cs @@ -1,12 +1,9 @@ // Copyright (c) All contributors. All rights reserved. // Licensed under the MIT license. See LICENSE file in the project root for full license information. -using System; using System.Collections.Generic; using System.Linq; using Microsoft.CodeAnalysis; -using Microsoft.CodeAnalysis.CSharp; -using Microsoft.CodeAnalysis.CSharp.Syntax; namespace MessagePackCompiler { @@ -15,111 +12,14 @@ internal static class RoslynExtensions { public static IEnumerable GetNamedTypeSymbols(this Compilation compilation) { - foreach (var syntaxTree in compilation.SyntaxTrees) + return compilation.SyntaxTrees.SelectMany(syntaxTree => { var semModel = compilation.GetSemanticModel(syntaxTree); - - foreach (var item in syntaxTree.GetRoot() + return syntaxTree.GetRoot() .DescendantNodes() .Select(x => semModel.GetDeclaredSymbol(x)) - .Where(x => x != null)) - { - var namedType = item as INamedTypeSymbol; - if (namedType != null) - { - yield return namedType; - } - } - } - } - - public static IEnumerable EnumerateBaseType(this ITypeSymbol symbol) - { - var t = symbol.BaseType; - while (t != null) - { - yield return t; - t = t.BaseType; - } - } - - public static AttributeData FindAttribute(this IEnumerable attributeDataList, string typeName) - { - return attributeDataList - .Where(x => x.AttributeClass.ToDisplayString() == typeName) - .FirstOrDefault(); - } - - public static AttributeData FindAttributeShortName( - this IEnumerable attributeDataList, - string typeName) - { - return attributeDataList - .Where(x => x.AttributeClass.Name == typeName) - .FirstOrDefault(); - } - - public static AttributeData FindAttributeIncludeBasePropertyShortName( - this IPropertySymbol property, - string typeName) - { - do - { - var data = FindAttributeShortName(property.GetAttributes(), typeName); - if (data != null) - { - return data; - } - - property = property.OverriddenProperty; - } - while (property != null); - - return null; - } - - public static AttributeSyntax FindAttribute( - this BaseTypeDeclarationSyntax typeDeclaration, - SemanticModel model, - string typeName) - { - return typeDeclaration.AttributeLists - .SelectMany(x => x.Attributes) - .Where(x => model.GetTypeInfo(x).Type?.ToDisplayString() == typeName) - .FirstOrDefault(); - } - - public static INamedTypeSymbol FindBaseTargetType(this ITypeSymbol symbol, string typeName) - { - return symbol.EnumerateBaseType() - .Where(x => x.OriginalDefinition?.ToDisplayString() == typeName) - .FirstOrDefault(); - } - - public static object GetSingleNamedArgumentValue(this AttributeData attribute, string key) - { - foreach (var item in attribute.NamedArguments) - { - if (item.Key == key) - { - return item.Value.Value; - } - } - - return null; - } - - public static bool IsNullable(this INamedTypeSymbol symbol) - { - if (symbol.IsGenericType) - { - if (symbol.ConstructUnboundGenericType().ToDisplayString() == "T?") - { - return true; - } - } - - return false; + .OfType(); + }); } public static IEnumerable GetAllMembers(this ITypeSymbol symbol) @@ -136,17 +36,11 @@ public static IEnumerable GetAllMembers(this ITypeSymbol symbol) } } - public static IEnumerable GetAllInterfaceMembers(this ITypeSymbol symbol) - { - return symbol.GetMembers() - .Concat(symbol.AllInterfaces.SelectMany(x => x.GetMembers())); - } - - public static bool ApproximatelyEqual(this INamedTypeSymbol left, INamedTypeSymbol right) + public static bool ApproximatelyEqual(this INamedTypeSymbol? left, INamedTypeSymbol? right) { if (left is IErrorTypeSymbol || right is IErrorTypeSymbol) { - return left.ToDisplayString() == right.ToDisplayString(); + return left?.ToDisplayString() == right?.ToDisplayString(); } else {