Skip to content

Commit

Permalink
Handle generator dependencies without ILRepack
Browse files Browse the repository at this point in the history
  • Loading branch information
aalmada committed Jul 22, 2021
1 parent da367fb commit 430e623
Show file tree
Hide file tree
Showing 7 changed files with 141 additions and 29 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,8 @@ public class GenerateSourceTests
// "TestData/Results/Count.Span.cs"
//},
{
new[] { "TestData/Source/Count.TestEnumerableWithValueTypeEnumerator.cs" },
"TestData/Results/Count.TestEnumerableWithValueTypeEnumerator.cs"
new[] { "TestData/Source/AsEnumerable.TestEnumerableWithValueTypeEnumerator.cs" },
"TestData/Results/AsEnumerable.TestEnumerableWithValueTypeEnumerator.cs"
},
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@

<ItemGroup>
<ProjectReference Include="..\NetFabric.Hyperlinq.SourceGenerator\NetFabric.Hyperlinq.SourceGenerator.csproj" />
<ProjectReference Include="..\NetFabric.Hyperlinq\NetFabric.Hyperlinq.csproj" />
</ItemGroup>

<ItemGroup>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,7 @@ namespace NetFabric.Hyperlinq
{
static partial class GeneratedExtensionMethods
{
public static int Count(this ArrayExtensions.ArraySegmentValueEnumerable<int> source)
=> source.Count();
}
}
44 changes: 32 additions & 12 deletions NetFabric.Hyperlinq.SourceGenerator/Generator.AsValueEnumerable.cs
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,14 @@ namespace NetFabric.Hyperlinq.SourceGenerator
{
public partial class Generator
{
static ValueEnumerableType? GenerateAsValueEnumerable(Compilation compilation, SemanticModel semanticModel, TypeSymbolsCache typeSymbolsCache, MemberAccessExpressionSyntax expressionSyntax, CodeBuilder builder, HashSet<MethodSignature> generatedMethods, CancellationToken cancellationToken, bool isUnitTest)
static ValueEnumerableType? GenerateAsValueEnumerable(Compilation compilation, SemanticModel semanticModel, TypeSymbolsCache typeSymbolsCache, MemberAccessExpressionSyntax expressionSyntax, CodeBuilder builder, Dictionary<MethodSignature, ValueEnumerableType> generatedMethods, CancellationToken cancellationToken, bool isUnitTest)
{
// Check if the method is already defined in the project source
if (semanticModel.GetSymbolInfo(expressionSyntax, cancellationToken).Symbol is IMethodSymbol methodSymbol and not null)
return new ValueEnumerableType(Name: methodSymbol.ReturnType.Name);
if (semanticModel.GetSymbolInfo(expressionSyntax, cancellationToken).Symbol is IMethodSymbol { } methodSymbol)
return new ValueEnumerableType(
Name: methodSymbol.ReturnType.Name,
IsCollection: methodSymbol.ReturnType.ImplementsInterface(typeSymbolsCache["NetFabric.Hyperlinq.IValueReadOnlyCollection`2"]!, out var _),
IsList: methodSymbol.ReturnType.ImplementsInterface(typeSymbolsCache["NetFabric.Hyperlinq.IValueReadOnlyList`2"]!, out var _));

// Get the type this operator is applied to
var receiverTypeSymbol = semanticModel.GetTypeInfo(expressionSyntax.Expression, cancellationToken).Type;
Expand All @@ -32,28 +35,41 @@ public partial class Generator
|| SymbolEqualityComparer.Default.Equals(receiverTypeSymbol.OriginalDefinition, typeSymbolsCache[typeof(List<>)])
|| SymbolEqualityComparer.Default.Equals(receiverTypeSymbol.OriginalDefinition, typeSymbolsCache[typeof(ImmutableArray<>)])
)
return null; // no need to generate an implementation

var receiverTypeString = receiverTypeSymbol.ToDisplayString();
return null; // Do generate an implementation. The 'using NetFabric.Hyperlinq;' statement should be added instead.

// Receiver type implements IValueEnumerable<,>

if (IsValueEnumerable(receiverTypeSymbol, typeSymbolsCache))
var valueEnumerableType = AsValueEnumerable(expressionSyntax, compilation, semanticModel, typeSymbolsCache, expressionSyntax, builder, generatedMethods, cancellationToken, isUnitTest);
if (valueEnumerableType is not null)
{
// Check if the method is already defined by this generator
var methodSignature = new MethodSignature("AsValueEnumerable", valueEnumerableType.Name);
if (generatedMethods.TryGetValue(methodSignature, out var returnType))
return returnType;

// Receiver instance returns itself
_ = builder
.AppendLine()
.AppendLine("[MethodImpl(MethodImplOptions.AggressiveInlining)]")
.AppendLine($"public static {receiverTypeString} AsValueEnumerable(this {receiverTypeString} source)")
.AppendLine($"public static {valueEnumerableType.Name} AsValueEnumerable(this {valueEnumerableType.Name} source)")
.AppendIdentation().AppendLine($"=> source;");

return new ValueEnumerableType(Name: receiverTypeString);
// A new AsValueEnumerable() method has been generated
generatedMethods.Add(methodSignature, valueEnumerableType);
return valueEnumerableType;
}

// Receiver type is an enumerable

if (receiverTypeSymbol.IsEnumerable(compilation, out var enumerableSymbols))
{
var receiverTypeString = receiverTypeSymbol.ToDisplayString();

// Check if the method is already defined by this generator
var methodSignature = new MethodSignature("AsValueEnumerable", receiverTypeString);
if (generatedMethods.TryGetValue(methodSignature, out var returnType))
return returnType;

// Use an unique identifier to avoid name clashing
var uniqueIdString = isUnitTest
? receiverTypeString.Replace('.', '_').Replace(',', '_').Replace('<', '_').Replace('>', '_').Replace('`', '_')
Expand Down Expand Up @@ -371,9 +387,13 @@ public partial class Generator
}
}

// A new AsValueEnumerable method has been generated
_ = generatedMethods.Add(new MethodSignature("AsValueEnumerable", receiverTypeString));
return new ValueEnumerableType(Name: valueEnumerableTypeName);
// A new AsValueEnumerable() method has been generated
valueEnumerableType = new ValueEnumerableType(
Name: valueEnumerableTypeName,
IsCollection: enumerableImplementsICollection || enumerableImplementsIReadOnlyCollection,
IsList: enumerableImplementsIList || enumerableImplementsIReadOnlyList);
generatedMethods.Add(methodSignature, valueEnumerableType);
return valueEnumerableType;
}
}

Expand Down
90 changes: 78 additions & 12 deletions NetFabric.Hyperlinq.SourceGenerator/Generator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ public void Execute(GeneratorExecutionContext context)

internal static void GenerateSource(Compilation compilation, TypeSymbolsCache typeSymbolsCache, List<MemberAccessExpressionSyntax> memberAccessExpressions, CodeBuilder builder, CancellationToken cancellationToken, bool isUnitTest = false)
{
var generatedMethods = new HashSet<MethodSignature>();
var generatedMethods = new Dictionary<MethodSignature, ValueEnumerableType>();

_ = builder
.AppendLine("#nullable enable")
Expand All @@ -143,29 +143,95 @@ internal static void GenerateSource(Compilation compilation, TypeSymbolsCache ty
{
foreach (var expressionSyntax in memberAccessExpressions)
{
cancellationToken.ThrowIfCancellationRequested();

var semanticModel = compilation.GetSemanticModel(expressionSyntax.SyntaxTree);

_ = GenerateSource(compilation, semanticModel, typeSymbolsCache, expressionSyntax, builder, generatedMethods, cancellationToken, isUnitTest);
}
}
}

static ValueEnumerableType? GenerateSource(Compilation compilation, SemanticModel semanticModel, TypeSymbolsCache typeSymbolsCache, MemberAccessExpressionSyntax expressionSyntax, CodeBuilder builder, HashSet<MethodSignature> generatedMethods, CancellationToken cancellationToken, bool isUnitTest)
=> expressionSyntax.Name.ToString() switch
static ValueEnumerableType? AsValueEnumerable(MemberAccessExpressionSyntax memberAccessExpressionSyntax, Compilation compilation, SemanticModel semanticModel, TypeSymbolsCache typeSymbolsCache, MemberAccessExpressionSyntax expressionSyntax, CodeBuilder builder, Dictionary<MethodSignature, ValueEnumerableType> generatedMethods, CancellationToken cancellationToken, bool isUnitTest)
{
var typeSymbol = semanticModel.GetTypeInfo(memberAccessExpressionSyntax.Expression, cancellationToken).Type;
if (typeSymbol is null)
return null;

// Check if the receiver type implements IValueEnumerable<,>
if (typeSymbol.ImplementsInterface(typeSymbolsCache["NetFabric.Hyperlinq.IValueEnumerable`2"]!, out var _))
return new ValueEnumerableType(
Name: typeSymbol.ToDisplayString(),
IsCollection: typeSymbol.ImplementsInterface(typeSymbolsCache["NetFabric.Hyperlinq.IValueReadOnlyCollection`2"]!, out var _),
IsList: typeSymbol.ImplementsInterface(typeSymbolsCache["NetFabric.Hyperlinq.IValueReadOnlyList`2"]!, out var _));

// Go up one layer. Generate method is required.
if (expressionSyntax.Expression is InvocationExpressionSyntax { Expression: MemberAccessExpressionSyntax receiverSyntax })
{
"AsValueEnumerable" => GenerateAsValueEnumerable(compilation, semanticModel, typeSymbolsCache, expressionSyntax, builder, generatedMethods, cancellationToken, isUnitTest),
_ => GenerateOperationSource(compilation, semanticModel, expressionSyntax, builder, generatedMethods, cancellationToken, isUnitTest),
};
if (GenerateSource(compilation, semanticModel, typeSymbolsCache, receiverSyntax, builder, generatedMethods, cancellationToken, isUnitTest) is { } valueEnumerableType)
return valueEnumerableType; // Receiver type implements IValueEnumerable<,>
}

// Receiver type does not implement IValueEnumerable<,> so nothing else needs to be done
return null;
}

static ValueEnumerableType? GenerateOperationSource(Compilation compilation, SemanticModel semanticModel, MemberAccessExpressionSyntax expressionSyntax, CodeBuilder builder, HashSet<MethodSignature> generatedMethods, CancellationToken cancellationToken, bool isUnitTest)
static ValueEnumerableType? GenerateSource(Compilation compilation, SemanticModel semanticModel, TypeSymbolsCache typeSymbolsCache, MemberAccessExpressionSyntax expressionSyntax, CodeBuilder builder, Dictionary<MethodSignature, ValueEnumerableType> generatedMethods, CancellationToken cancellationToken, bool isUnitTest)
=> expressionSyntax.Name.ToString() switch
{
"AsValueEnumerable" => GenerateAsValueEnumerable(compilation, semanticModel, typeSymbolsCache, expressionSyntax, builder, generatedMethods, cancellationToken, isUnitTest),
_ => GenerateOperationSource(compilation, semanticModel, typeSymbolsCache, expressionSyntax, builder, generatedMethods, cancellationToken, isUnitTest),
};

static ValueEnumerableType? GenerateOperationSource(Compilation compilation, SemanticModel semanticModel, TypeSymbolsCache typeSymbolsCache, MemberAccessExpressionSyntax expressionSyntax, CodeBuilder builder, Dictionary<MethodSignature, ValueEnumerableType> generatedMethods, CancellationToken cancellationToken, bool isUnitTest)
{
// Get the type this operator is applied to
var receiverTypeSymbol = semanticModel.GetTypeInfo(expressionSyntax.Expression, cancellationToken).Type;
var valueEnumerableType = AsValueEnumerable(expressionSyntax, compilation, semanticModel, typeSymbolsCache, expressionSyntax, builder, generatedMethods, cancellationToken, isUnitTest);
if (valueEnumerableType is null)
return null;

var symbol = semanticModel.GetSymbolInfo(expressionSyntax, cancellationToken).Symbol;
if (symbol is IMethodSymbol methodSymbol)
{
// Check if the generator already generated this method
var parameters = new string[] { valueEnumerableType.Name }.Concat(methodSymbol.Parameters.Select(parameter => parameter.Type.ToDisplayString())).ToArray();
var methodSignature = new MethodSignature(expressionSyntax.Name.ToString(), parameters);
if (generatedMethods.TryGetValue(methodSignature, out var returnType))
return returnType;

// Generate the extension method
_ = builder
.AppendLine()
.AppendLine("[MethodImpl(MethodImplOptions.AggressiveInlining)]")
.AppendLine($"public static {methodSymbol.ReturnType.ToDisplayString()} {methodSymbol.Name}(this {valueEnumerableType.Name} source)")
.AppendIdentation().AppendLine($"=> source.{methodSymbol.Name}();");

generatedMethods.Add(methodSignature, returnType);
return returnType;
}


// TODO: when 'using System.Linq;' is not used...
if (expressionSyntax.Parent is InvocationExpressionSyntax invocation)
{

var type = semanticModel.GetTypeInfo(invocation.ArgumentList.Arguments[0], cancellationToken).Type;
var symbol2 = semanticModel.GetSymbolInfo(invocation, cancellationToken).Symbol;




// Check if the source already provides the implementation as an instance method

// Check if the source already provides the implementation as an extension method

// Generate the extension method

_ = builder
.AppendLine()
.AppendLine("// TODO");
}

return null;
}

static bool IsValueEnumerable(ITypeSymbol symbol, TypeSymbolsCache typeSymbolsCache)
=> symbol.ImplementsInterface(typeSymbolsCache["NetFabric.Hyperlinq.IValueEnumerable`2"]!, out var _);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,23 @@

<PropertyGroup>
<TargetFramework>netstandard2.0</TargetFramework>
<PackageId>NetFabric.Hyperlinq.SourceGenerator</PackageId>
<Title>NetFabric.Hyperlinq.SourceGenerator</Title>
<Description> High performance LINQ implementation with minimal heap allocations. Supports enumerables, async enumerables, Memory, and Span.</Description>
<Version>3.0.0-beta45</Version>
<PackageIcon>Icon.png</PackageIcon>
<PackageLicenseFile>LICENSE</PackageLicenseFile>
<PackageTags>netfabric, hyperlinq, linq, enumeration, extensions, performance</PackageTags>
<PublishRepositoryUrl>true</PublishRepositoryUrl>
<GeneratePackageOnBuild>true</GeneratePackageOnBuild> <!-- Generates a package at build -->
<IncludeBuildOutput>false</IncludeBuildOutput> <!-- Do not include the generator as a lib dependency -->
</PropertyGroup>


<ItemGroup>
<!-- Take a public dependency on NetFabric.Hyperlinq. Consumers of this generator will get a reference to this package -->
<ProjectReference Include="..\NetFabric.Hyperlinq\NetFabric.Hyperlinq.csproj" />
</ItemGroup>

<ItemGroup>
<!-- Take a private dependency on Ben.TypeDictionary (PrivateAssets=all) Consumers of this generator will not reference it.
Set GeneratePathProperty=true so we can reference the binaries via the PKGBen_TypeDictionary property -->
Expand All @@ -17,6 +30,14 @@

<!-- Package the generator in the analyzer directory of the nuget package -->
<None Include="$(OutputPath)\$(AssemblyName).dll" Pack="true" PackagePath="analyzers/dotnet/cs" Visible="false" />
<None Include="..\Icon.png" Link="Icon.png">
<PackagePath></PackagePath>
<Pack>true</Pack>
</None>
<None Include="..\LICENSE" Link="LICENSE">
<PackagePath></PackagePath>
<Pack>true</Pack>
</None>

<!-- Package the Ben.TypeDictionary dependency alongside the generator assembly -->
<None Include="$(PKGBen_TypeDictionary)\lib\netstandard2.0\*.dll" Pack="true" PackagePath="analyzers/dotnet/cs" Visible="false" />
Expand All @@ -29,6 +50,10 @@
<PrivateAssets>all</PrivateAssets>
<IncludeAssets>runtime; build; native; contentfiles; analyzers</IncludeAssets>
</PackageReference>
<PackageReference Include="Microsoft.SourceLink.GitHub" Version="1.0.0">
<PrivateAssets>all</PrivateAssets>
<IncludeAssets>runtime; build; native; contentfiles; analyzers</IncludeAssets>
</PackageReference>
<PackageReference Include="IsExternalInit" Version="1.0.1">
<PrivateAssets>all</PrivateAssets>
<IncludeAssets>runtime; build; native; contentfiles; analyzers; buildtransitive</IncludeAssets>
Expand Down
2 changes: 1 addition & 1 deletion NetFabric.Hyperlinq.SourceGenerator/ValueEnumerableType.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,5 @@

namespace NetFabric.Hyperlinq.SourceGenerator
{
record ValueEnumerableType(string Name);
record ValueEnumerableType(string Name, bool IsCollection, bool IsList);
}

0 comments on commit 430e623

Please sign in to comment.