diff --git a/Source/Mockolate.SourceGenerators/Entities/Event.cs b/Source/Mockolate.SourceGenerators/Entities/Event.cs index 0a63ddc5..39ee92bf 100644 --- a/Source/Mockolate.SourceGenerators/Entities/Event.cs +++ b/Source/Mockolate.SourceGenerators/Entities/Event.cs @@ -9,6 +9,8 @@ internal record Event public Event(IEventSymbol eventSymbol, IMethodSymbol delegateInvokeMethod, List? alreadyDefinedEvents, IAssemblySymbol? sourceAssembly = null) { Accessibility = eventSymbol.DeclaredAccessibility; + OverrideAccessibility = Helpers.ResolveOverrideVisibility( + Accessibility, eventSymbol.ContainingAssembly, sourceAssembly); UseOverride = eventSymbol.IsVirtual || eventSymbol.IsAbstract; IsAbstract = eventSymbol.IsAbstract; Name = Helpers.EscapeIfKeyword(eventSymbol.ExplicitInterfaceImplementations.Length > 0 ? eventSymbol.ExplicitInterfaceImplementations[0].Name : eventSymbol.Name); @@ -53,6 +55,7 @@ public Event(IEventSymbol eventSymbol, IMethodSymbol delegateInvokeMethod, List< }; public Accessibility Accessibility { get; } + public string OverrideAccessibility { get; } public string Name { get; } public string? ExplicitImplementation { get; } diff --git a/Source/Mockolate.SourceGenerators/Entities/Method.cs b/Source/Mockolate.SourceGenerators/Entities/Method.cs index 164f57c5..417d58a1 100644 --- a/Source/Mockolate.SourceGenerators/Entities/Method.cs +++ b/Source/Mockolate.SourceGenerators/Entities/Method.cs @@ -9,6 +9,8 @@ internal record Method public Method(IMethodSymbol methodSymbol, List? alreadyDefinedMethods, IAssemblySymbol? sourceAssembly = null) { Accessibility = methodSymbol.DeclaredAccessibility; + OverrideAccessibility = Helpers.ResolveOverrideVisibility( + Accessibility, methodSymbol.ContainingAssembly, sourceAssembly); UseOverride = methodSymbol.IsVirtual || methodSymbol.IsAbstract; IsAbstract = methodSymbol.IsAbstract; IsStatic = methodSymbol.IsStatic; @@ -67,6 +69,7 @@ public Method(IMethodSymbol methodSymbol, List? alreadyDefinedMethods, I }; public Accessibility Accessibility { get; } + public string OverrideAccessibility { get; } public Type ReturnType { get; } public string Name { get; } public string ContainingType { get; } diff --git a/Source/Mockolate.SourceGenerators/Entities/Property.cs b/Source/Mockolate.SourceGenerators/Entities/Property.cs index a16d3400..f92d6dd4 100644 --- a/Source/Mockolate.SourceGenerators/Entities/Property.cs +++ b/Source/Mockolate.SourceGenerators/Entities/Property.cs @@ -9,6 +9,8 @@ internal record Property public Property(IPropertySymbol propertySymbol, List? alreadyDefinedProperties, IAssemblySymbol? sourceAssembly = null) { Accessibility = propertySymbol.DeclaredAccessibility; + OverrideAccessibility = Helpers.ResolveOverrideVisibility( + Accessibility, propertySymbol.ContainingAssembly, sourceAssembly); UseOverride = propertySymbol.IsVirtual || propertySymbol.IsAbstract; string rawName = propertySymbol.ExplicitInterfaceImplementations.Length > 0 ? propertySymbol.ExplicitInterfaceImplementations[0].Name : propertySymbol.Name; Name = propertySymbol.IsIndexer ? rawName : Helpers.EscapeIfKeyword(rawName); @@ -72,6 +74,7 @@ public Property(IPropertySymbol propertySymbol, List? alreadyDefinedPr public EquatableArray? Attributes { get; } public Accessibility Accessibility { get; } + public string OverrideAccessibility { get; } public string Name { get; } public string? ExplicitImplementation { get; } diff --git a/Source/Mockolate.SourceGenerators/Helpers.cs b/Source/Mockolate.SourceGenerators/Helpers.cs index d157d52d..3bfff6ce 100644 --- a/Source/Mockolate.SourceGenerators/Helpers.cs +++ b/Source/Mockolate.SourceGenerators/Helpers.cs @@ -78,10 +78,12 @@ static bool HasIndexedConflict(string @base, string parameterName, int count) } } - // A member (or accessor) declared in another assembly is overridable only if the overriding - // assembly can actually see it. `internal` and `private protected` are invisible across assembly - // boundaries unless the declaring assembly grants InternalsVisibleTo. `protected internal` - // (= protected OR internal) is always reachable via the protected half from a derived class. + /// + /// A member (or accessor) declared in another assembly is overridable only if the overriding + /// assembly can actually see it. `internal` and `private protected` are invisible across assembly + /// boundaries unless the declaring assembly grants InternalsVisibleTo. `protected internal` + /// (= protected OR internal) is always reachable via the protected half from a derived class. + /// public static bool IsOverridableFrom(ISymbol member, IAssemblySymbol? sourceAssembly) { if (sourceAssembly is null || @@ -99,6 +101,41 @@ member.DeclaredAccessibility is not (Accessibility.Internal or Accessibility.Pro return containingAssembly.GivesAccessTo(sourceAssembly); } + /// + /// C# requires an override to match the base member's declared accessibility, with one + /// exception: when overriding a `protected internal` member from an assembly that cannot see + /// `internal` (i.e. neither the same assembly nor an InternalsVisibleTo target), the override + /// must drop the internal half and use plain `protected`. + /// + public static string ResolveOverrideVisibility(Accessibility accessibility, + IAssemblySymbol? containingAssembly, IAssemblySymbol? sourceAssembly) + => accessibility switch + { + Accessibility.Public => "public", + Accessibility.Protected => "protected", + Accessibility.Internal => "internal", + Accessibility.ProtectedAndInternal => "private protected", + Accessibility.ProtectedOrInternal => HasInternalAccess(containingAssembly, sourceAssembly) + ? "protected internal" + : "protected", + _ => "private", + }; + + private static bool HasInternalAccess(IAssemblySymbol? containingAssembly, IAssemblySymbol? sourceAssembly) + { + if (sourceAssembly is null || containingAssembly is null) + { + return false; + } + + if (SymbolEqualityComparer.Default.Equals(containingAssembly, sourceAssembly)) + { + return true; + } + + return containingAssembly.GivesAccessTo(sourceAssembly); + } + public static string ToTypeOrWrapper(this Type type) { if (type.SpecialGenericType == SpecialGenericType.Span) @@ -146,8 +183,8 @@ public static bool NeedsRefStructPipeline(this MethodParameter parameter) } return parameter.RefKind == RefKind.RefReadOnlyParameter - && parameter.Type.IsRefStruct - && parameter.Type.SpecialGenericType is (SpecialGenericType.Span or SpecialGenericType.ReadOnlySpan); + && parameter.Type.IsRefStruct + && parameter.Type.SpecialGenericType is SpecialGenericType.Span or SpecialGenericType.ReadOnlySpan; } extension(ITypeSymbol typeSymbol) diff --git a/Source/Mockolate.SourceGenerators/Sources/Sources.MockClass.cs b/Source/Mockolate.SourceGenerators/Sources/Sources.MockClass.cs index 4383f3b4..9f29ca8a 100644 --- a/Source/Mockolate.SourceGenerators/Sources/Sources.MockClass.cs +++ b/Source/Mockolate.SourceGenerators/Sources/Sources.MockClass.cs @@ -1486,7 +1486,7 @@ private static void AppendMockSubject_ImplementClass_AddEvent(StringBuilder sb, { if (@event.ExplicitImplementation is null) { - sb.Append("\t\t").Append(@event.Accessibility.ToVisibilityString()).Append(' '); + sb.Append("\t\t").Append(@event.OverrideAccessibility).Append(' '); if (@event.IsStatic) { sb.Append("static "); @@ -1625,7 +1625,7 @@ property.IndexerParameters is not null { if (property.ExplicitImplementation is null) { - sb.Append("\t\t").Append(property.Accessibility.ToVisibilityString()).Append(' '); + sb.Append("\t\t").Append(property.OverrideAccessibility).Append(' '); if (property.IsStatic) { sb.Append("static "); @@ -1671,7 +1671,7 @@ property.IndexerParameters is not null sb.Append("\t\t\t"); if (property.Getter.Accessibility != property.Accessibility) { - sb.Append(property.Getter.Accessibility.ToVisibilityString()).Append(' '); + sb.Append(property.Getter.OverrideAccessibility).Append(' '); } sb.AppendLine("get"); @@ -1906,7 +1906,7 @@ property.IndexerParameters is not null sb.Append("\t\t\t"); if (property.Setter.Accessibility != property.Accessibility) { - sb.Append(property.Setter.Accessibility.ToVisibilityString()).Append(' '); + sb.Append(property.Setter.OverrideAccessibility).Append(' '); } sb.AppendLine(property.Setter.IsInitOnly ? "init" : "set"); @@ -2143,7 +2143,7 @@ private static void AppendMockSubject_ImplementClass_AddMethod(StringBuilder sb, sb.Append("\t\t"); if (method.ExplicitImplementation is null) { - sb.Append(method.Accessibility.ToVisibilityString()).Append(' '); + sb.Append(method.OverrideAccessibility).Append(' '); if (method.IsStatic) { sb.Append("static "); diff --git a/Source/Mockolate.SourceGenerators/Sources/Sources.cs b/Source/Mockolate.SourceGenerators/Sources/Sources.cs index 6c891623..1d9fabb3 100644 --- a/Source/Mockolate.SourceGenerators/Sources/Sources.cs +++ b/Source/Mockolate.SourceGenerators/Sources/Sources.cs @@ -671,20 +671,6 @@ static bool IsIdentifierPart(char c) } } - extension(Accessibility accessibility) - { - internal string ToVisibilityString() - => accessibility switch - { - Accessibility.Protected => "protected", - Accessibility.Internal => "internal", - Accessibility.ProtectedOrInternal => "protected", - Accessibility.Public => "public", - Accessibility.ProtectedAndInternal => "private protected", - _ => "private", - }; - } - extension(RefKind refKind) { internal string GetString(bool replaceRefReadonlyWithIn = false) diff --git a/Tests/Mockolate.SourceGenerators.Tests/MockTests.cs b/Tests/Mockolate.SourceGenerators.Tests/MockTests.cs index 28cd65c4..8475e0ef 100644 --- a/Tests/Mockolate.SourceGenerators.Tests/MockTests.cs +++ b/Tests/Mockolate.SourceGenerators.Tests/MockTests.cs @@ -699,6 +699,42 @@ await That(result.Sources["Mock.MyDerivedClass.g.cs"]) .DoesNotContain("override bool Equals"); } + [Fact] + public async Task ShouldPreserveProtectedInternalAccessibilityOnOverriddenMembers() + { + GeneratorResult result = Generator + .Run(""" + using Mockolate; + using System; + + namespace MyCode; + + public class Program + { + public static void Main(string[] args) => _ = MyClass.CreateMock(); + } + + public class MyClass + { + protected internal virtual void ProtectedInternalMethod() { } + protected internal virtual int ProtectedInternalProperty { get; set; } + protected internal virtual event EventHandler? ProtectedInternalEvent; + public virtual int MixedAccessorProperty { get; protected internal set; } + } + """); + + await That(result.Sources).ContainsKey("Mock.MyClass.g.cs"); + string generated = result.Sources["Mock.MyClass.g.cs"]; + await That(generated) + .Contains("protected internal override void ProtectedInternalMethod()").And + .Contains("protected internal override int ProtectedInternalProperty").And + .Contains("protected internal override event global::System.EventHandler? ProtectedInternalEvent").And + .Contains("protected internal set").And + .DoesNotContain("protected override void ProtectedInternalMethod").And + .DoesNotContain("protected override int ProtectedInternalProperty").And + .DoesNotContain("protected override event"); + } + [Fact] public async Task ShouldSupportSpecialTypes() {