Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions Source/Mockolate.SourceGenerators/Entities/Event.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ internal record Event
public Event(IEventSymbol eventSymbol, IMethodSymbol delegateInvokeMethod, List<Event>? alreadyDefinedEvents, IAssemblySymbol? sourceAssembly = null)
{
Accessibility = eventSymbol.DeclaredAccessibility;
OverrideAccessibility = Helpers.ResolveOverrideVisibility(
Accessibility, eventSymbol.ContainingAssembly, sourceAssembly);
UseOverride = eventSymbol.IsVirtual || eventSymbol.IsAbstract;
IsAbstract = eventSymbol.IsAbstract;
Name = Helpers.EscapeIfKeyword(eventSymbol.ExplicitInterfaceImplementations.Length > 0 ? eventSymbol.ExplicitInterfaceImplementations[0].Name : eventSymbol.Name);
Expand Down Expand Up @@ -53,6 +55,7 @@ public Event(IEventSymbol eventSymbol, IMethodSymbol delegateInvokeMethod, List<
};

public Accessibility Accessibility { get; }
public string OverrideAccessibility { get; }
public string Name { get; }
public string? ExplicitImplementation { get; }

Expand Down
3 changes: 3 additions & 0 deletions Source/Mockolate.SourceGenerators/Entities/Method.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ internal record Method
public Method(IMethodSymbol methodSymbol, List<Method>? alreadyDefinedMethods, IAssemblySymbol? sourceAssembly = null)
{
Accessibility = methodSymbol.DeclaredAccessibility;
OverrideAccessibility = Helpers.ResolveOverrideVisibility(
Accessibility, methodSymbol.ContainingAssembly, sourceAssembly);
UseOverride = methodSymbol.IsVirtual || methodSymbol.IsAbstract;
IsAbstract = methodSymbol.IsAbstract;
IsStatic = methodSymbol.IsStatic;
Expand Down Expand Up @@ -67,6 +69,7 @@ public Method(IMethodSymbol methodSymbol, List<Method>? alreadyDefinedMethods, I
};

public Accessibility Accessibility { get; }
public string OverrideAccessibility { get; }
public Type ReturnType { get; }
public string Name { get; }
public string ContainingType { get; }
Expand Down
3 changes: 3 additions & 0 deletions Source/Mockolate.SourceGenerators/Entities/Property.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ internal record Property
public Property(IPropertySymbol propertySymbol, List<Property>? alreadyDefinedProperties, IAssemblySymbol? sourceAssembly = null)
{
Accessibility = propertySymbol.DeclaredAccessibility;
OverrideAccessibility = Helpers.ResolveOverrideVisibility(
Accessibility, propertySymbol.ContainingAssembly, sourceAssembly);
UseOverride = propertySymbol.IsVirtual || propertySymbol.IsAbstract;
string rawName = propertySymbol.ExplicitInterfaceImplementations.Length > 0 ? propertySymbol.ExplicitInterfaceImplementations[0].Name : propertySymbol.Name;
Name = propertySymbol.IsIndexer ? rawName : Helpers.EscapeIfKeyword(rawName);
Expand Down Expand Up @@ -72,6 +74,7 @@ public Property(IPropertySymbol propertySymbol, List<Property>? alreadyDefinedPr
public EquatableArray<Attribute>? Attributes { get; }

public Accessibility Accessibility { get; }
public string OverrideAccessibility { get; }
public string Name { get; }
public string? ExplicitImplementation { get; }

Expand Down
49 changes: 43 additions & 6 deletions Source/Mockolate.SourceGenerators/Helpers.cs
Original file line number Diff line number Diff line change
Expand Up @@ -78,10 +78,12 @@ static bool HasIndexedConflict(string @base, string parameterName, int count)
}
}

// A member (or accessor) declared in another assembly is overridable only if the overriding
// assembly can actually see it. `internal` and `private protected` are invisible across assembly
// boundaries unless the declaring assembly grants InternalsVisibleTo. `protected internal`
// (= protected OR internal) is always reachable via the protected half from a derived class.
/// <summary>
/// A member (or accessor) declared in another assembly is overridable only if the overriding
/// assembly can actually see it. `internal` and `private protected` are invisible across assembly
/// boundaries unless the declaring assembly grants InternalsVisibleTo. `protected internal`
/// (= protected OR internal) is always reachable via the protected half from a derived class.
/// </summary>
public static bool IsOverridableFrom(ISymbol member, IAssemblySymbol? sourceAssembly)
{
if (sourceAssembly is null ||
Expand All @@ -99,6 +101,41 @@ member.DeclaredAccessibility is not (Accessibility.Internal or Accessibility.Pro
return containingAssembly.GivesAccessTo(sourceAssembly);
}

/// <summary>
/// C# requires an override to match the base member's declared accessibility, with one
/// exception: when overriding a `protected internal` member from an assembly that cannot see
/// `internal` (i.e. neither the same assembly nor an InternalsVisibleTo target), the override
/// must drop the internal half and use plain `protected`.
/// </summary>
public static string ResolveOverrideVisibility(Accessibility accessibility,
IAssemblySymbol? containingAssembly, IAssemblySymbol? sourceAssembly)
=> accessibility switch
{
Accessibility.Public => "public",
Accessibility.Protected => "protected",
Accessibility.Internal => "internal",
Accessibility.ProtectedAndInternal => "private protected",
Accessibility.ProtectedOrInternal => HasInternalAccess(containingAssembly, sourceAssembly)
? "protected internal"
: "protected",
_ => "private",
};

private static bool HasInternalAccess(IAssemblySymbol? containingAssembly, IAssemblySymbol? sourceAssembly)
{
if (sourceAssembly is null || containingAssembly is null)
{
return false;
}

if (SymbolEqualityComparer.Default.Equals(containingAssembly, sourceAssembly))
{
return true;
}

return containingAssembly.GivesAccessTo(sourceAssembly);
}

public static string ToTypeOrWrapper(this Type type)
{
if (type.SpecialGenericType == SpecialGenericType.Span)
Expand Down Expand Up @@ -146,8 +183,8 @@ public static bool NeedsRefStructPipeline(this MethodParameter parameter)
}

return parameter.RefKind == RefKind.RefReadOnlyParameter
&& parameter.Type.IsRefStruct
&& parameter.Type.SpecialGenericType is (SpecialGenericType.Span or SpecialGenericType.ReadOnlySpan);
&& parameter.Type.IsRefStruct
&& parameter.Type.SpecialGenericType is SpecialGenericType.Span or SpecialGenericType.ReadOnlySpan;
}

extension(ITypeSymbol typeSymbol)
Expand Down
10 changes: 5 additions & 5 deletions Source/Mockolate.SourceGenerators/Sources/Sources.MockClass.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1486,7 +1486,7 @@ private static void AppendMockSubject_ImplementClass_AddEvent(StringBuilder sb,
{
if (@event.ExplicitImplementation is null)
{
sb.Append("\t\t").Append(@event.Accessibility.ToVisibilityString()).Append(' ');
sb.Append("\t\t").Append(@event.OverrideAccessibility).Append(' ');
if (@event.IsStatic)
{
sb.Append("static ");
Expand Down Expand Up @@ -1625,7 +1625,7 @@ property.IndexerParameters is not null
{
if (property.ExplicitImplementation is null)
{
sb.Append("\t\t").Append(property.Accessibility.ToVisibilityString()).Append(' ');
sb.Append("\t\t").Append(property.OverrideAccessibility).Append(' ');
if (property.IsStatic)
{
sb.Append("static ");
Expand Down Expand Up @@ -1671,7 +1671,7 @@ property.IndexerParameters is not null
sb.Append("\t\t\t");
if (property.Getter.Accessibility != property.Accessibility)
{
sb.Append(property.Getter.Accessibility.ToVisibilityString()).Append(' ');
sb.Append(property.Getter.OverrideAccessibility).Append(' ');
}

sb.AppendLine("get");
Expand Down Expand Up @@ -1906,7 +1906,7 @@ property.IndexerParameters is not null
sb.Append("\t\t\t");
if (property.Setter.Accessibility != property.Accessibility)
{
sb.Append(property.Setter.Accessibility.ToVisibilityString()).Append(' ');
sb.Append(property.Setter.OverrideAccessibility).Append(' ');
}

sb.AppendLine(property.Setter.IsInitOnly ? "init" : "set");
Expand Down Expand Up @@ -2143,7 +2143,7 @@ private static void AppendMockSubject_ImplementClass_AddMethod(StringBuilder sb,
sb.Append("\t\t");
if (method.ExplicitImplementation is null)
{
sb.Append(method.Accessibility.ToVisibilityString()).Append(' ');
sb.Append(method.OverrideAccessibility).Append(' ');
if (method.IsStatic)
{
sb.Append("static ");
Expand Down
14 changes: 0 additions & 14 deletions Source/Mockolate.SourceGenerators/Sources/Sources.cs
Original file line number Diff line number Diff line change
Expand Up @@ -671,20 +671,6 @@ static bool IsIdentifierPart(char c)
}
}

extension(Accessibility accessibility)
{
internal string ToVisibilityString()
=> accessibility switch
{
Accessibility.Protected => "protected",
Accessibility.Internal => "internal",
Accessibility.ProtectedOrInternal => "protected",
Accessibility.Public => "public",
Accessibility.ProtectedAndInternal => "private protected",
_ => "private",
};
}

extension(RefKind refKind)
{
internal string GetString(bool replaceRefReadonlyWithIn = false)
Expand Down
36 changes: 36 additions & 0 deletions Tests/Mockolate.SourceGenerators.Tests/MockTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -699,6 +699,42 @@ await That(result.Sources["Mock.MyDerivedClass.g.cs"])
.DoesNotContain("override bool Equals");
}

[Fact]
public async Task ShouldPreserveProtectedInternalAccessibilityOnOverriddenMembers()
{
GeneratorResult result = Generator
.Run("""
using Mockolate;
using System;

namespace MyCode;

public class Program
{
public static void Main(string[] args) => _ = MyClass.CreateMock();
}

public class MyClass
{
protected internal virtual void ProtectedInternalMethod() { }
protected internal virtual int ProtectedInternalProperty { get; set; }
protected internal virtual event EventHandler? ProtectedInternalEvent;
public virtual int MixedAccessorProperty { get; protected internal set; }
}
""");

await That(result.Sources).ContainsKey("Mock.MyClass.g.cs");
string generated = result.Sources["Mock.MyClass.g.cs"];
await That(generated)
.Contains("protected internal override void ProtectedInternalMethod()").And
.Contains("protected internal override int ProtectedInternalProperty").And
.Contains("protected internal override event global::System.EventHandler? ProtectedInternalEvent").And
.Contains("protected internal set").And
.DoesNotContain("protected override void ProtectedInternalMethod").And
.DoesNotContain("protected override int ProtectedInternalProperty").And
.DoesNotContain("protected override event");
}

[Fact]
public async Task ShouldSupportSpecialTypes()
{
Expand Down
Loading