Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ public DeclarationSyntaxRewriter(SemanticModel semanticModel)
visitedNode = visitedParameterSyntax.WithModifiers(node.Modifiers.RemoveAt(thisKeywordIndex));
}

// Strip the parans keyword of any parameter
// Strip the params keyword of any parameter
var paramsKeywordIndex = ((ParameterSyntax)visitedNode).Modifiers.IndexOf(SyntaxKind.ParamsKeyword);
if (paramsKeywordIndex != -1)
{
Expand Down Expand Up @@ -73,6 +73,19 @@ public DeclarationSyntaxRewriter(SemanticModel semanticModel)
return base.VisitIdentifierName(node);
}

public override SyntaxNode? VisitGenericName(GenericNameSyntax node)
{
var typeInfo = _semanticModel.GetTypeInfo(node);
if (typeInfo.Type is not null)
{
return SyntaxFactory.ParseTypeName(
typeInfo.Type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)
).WithTriviaFrom(node);
}

return base.VisitGenericName(node);
}

public override SyntaxNode? VisitQualifiedName(QualifiedNameSyntax node)
{
var typeInfo = _semanticModel.GetTypeInfo(node);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
<Project Sdk="Microsoft.NET.Sdk">

<PropertyGroup>
<TargetFramework>netstandard2.0</TargetFramework>
<TargetFrameworks>netstandard2.0</TargetFrameworks>
<NoWarn>$(NoWarn);NU5128</NoWarn>
<IsPackable>false</IsPackable>
</PropertyGroup>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,6 @@ namespace EntityFrameworkCore.Projectables.Generator
{
public class ProjectableDescriptor
{
public IEnumerable<string>? UsingDirectives { get; set; }

public string? ClassNamespace { get; set; }

public IEnumerable<string>? NestedInClassNames { get; set; }
Expand All @@ -23,6 +21,10 @@ public class ProjectableDescriptor

public string? ClassName { get; set; }

public TypeParameterListSyntax? ClassTypeParameterList { get; set; }

public SyntaxList<TypeParameterConstraintClauseSyntax>? ClassConstraintClauses { get; set; }

public string? MemberName { get; set; }

public string? ReturnTypeName { get; set; }
Expand All @@ -31,8 +33,8 @@ public class ProjectableDescriptor

public TypeParameterListSyntax? TypeParameterList { get; set; }

public IEnumerable<TypeParameterConstraintClauseSyntax>? ConstraintClauses { get; set; }
public SyntaxList<TypeParameterConstraintClauseSyntax>? ConstraintClauses { get; set; }

public SyntaxNode? Body { get; set; }
public ExpressionSyntax? ExpressionBody { get; set; }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -123,23 +123,52 @@ x is IPropertySymbol xProperty &&
ClassNamespace = memberSymbol.ContainingType.ContainingNamespace.IsGlobalNamespace ? null : memberSymbol.ContainingType.ContainingNamespace.ToDisplayString(),
MemberName = memberSymbol.Name,
NestedInClassNames = GetNestedInClassPath(memberSymbol.ContainingType),
ParametersList = SyntaxFactory.ParameterList(),
TypeParameterList = SyntaxFactory.TypeParameterList()
ParametersList = SyntaxFactory.ParameterList()
};

if (memberSymbol.ContainingType is INamedTypeSymbol { IsGenericType: true } containingNamedType)
{
descriptor.ClassTypeParameterList = SyntaxFactory.TypeParameterList();

foreach (var additionalClassTypeParameter in containingNamedType.TypeParameters)
{
descriptor.ClassTypeParameterList = descriptor.ClassTypeParameterList.AddParameters(
SyntaxFactory.TypeParameter(additionalClassTypeParameter.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat))
);

if (!additionalClassTypeParameter.ConstraintTypes.IsDefaultOrEmpty)
{
descriptor.ClassConstraintClauses ??= SyntaxFactory.List<TypeParameterConstraintClauseSyntax>();

descriptor.ClassConstraintClauses = descriptor.ClassConstraintClauses.Value.Add(
SyntaxFactory.TypeParameterConstraintClause(
SyntaxFactory.IdentifierName(additionalClassTypeParameter.Name),
SyntaxFactory.SeparatedList<TypeParameterConstraintSyntax>(
additionalClassTypeParameter
.ConstraintTypes
.Select(c => SyntaxFactory.TypeConstraint(
SyntaxFactory.IdentifierName(c.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat))
))
)
)
);
}

// todo: add additional type constraints
}
}

if (!member.Modifiers.Any(SyntaxKind.StaticKeyword))
{
descriptor.ParametersList = descriptor.ParametersList.AddParameters(
SyntaxFactory.Parameter(
SyntaxFactory.Identifier("@this")
).WithType(
SyntaxFactory.ParseTypeName(
memberSymbol.ContainingType.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)
)
.WithTrailingTrivia(
SyntaxFactory.SyntaxTrivia(SyntaxKind.WhitespaceTrivia, " ")
)
SyntaxFactory.Identifier("@this")
)
.WithType(
SyntaxFactory.ParseTypeName(
memberSymbol.ContainingType.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)
)
)
);
}

Expand Down Expand Up @@ -169,14 +198,15 @@ x is IPropertySymbol xProperty &&
var returnType = declarationSyntaxRewriter.Visit(methodDeclarationSyntax.ReturnType);

descriptor.ReturnTypeName = returnType.ToString();
descriptor.Body = expressionSyntaxRewriter.Visit(methodDeclarationSyntax.ExpressionBody.Expression);
descriptor.ExpressionBody = (ExpressionSyntax)expressionSyntaxRewriter.Visit(methodDeclarationSyntax.ExpressionBody.Expression);
foreach (var additionalParameter in ((ParameterListSyntax)declarationSyntaxRewriter.Visit(methodDeclarationSyntax.ParameterList)).Parameters)
{
descriptor.ParametersList = descriptor.ParametersList.AddParameters(additionalParameter);
}

if (methodDeclarationSyntax.TypeParameterList is not null)
{
descriptor.TypeParameterList = SyntaxFactory.TypeParameterList();
foreach (var additionalTypeParameter in ((TypeParameterListSyntax)declarationSyntaxRewriter.Visit(methodDeclarationSyntax.TypeParameterList)).Parameters)
{
descriptor.TypeParameterList = descriptor.TypeParameterList.AddParameters(additionalTypeParameter);
Expand All @@ -185,8 +215,11 @@ x is IPropertySymbol xProperty &&

if (methodDeclarationSyntax.ConstraintClauses.Any())
{
descriptor.ConstraintClauses = methodDeclarationSyntax.ConstraintClauses
.Select(x => (TypeParameterConstraintClauseSyntax)declarationSyntaxRewriter.Visit(x));
descriptor.ConstraintClauses = SyntaxFactory.List(
methodDeclarationSyntax
.ConstraintClauses
.Select(x => (TypeParameterConstraintClauseSyntax)declarationSyntaxRewriter.Visit(x))
);
}
}
else if (memberBody is PropertyDeclarationSyntax propertyDeclarationSyntax)
Expand All @@ -201,20 +234,13 @@ x is IPropertySymbol xProperty &&
var returnType = declarationSyntaxRewriter.Visit(propertyDeclarationSyntax.Type);

descriptor.ReturnTypeName = returnType.ToString();
descriptor.Body = expressionSyntaxRewriter.Visit(propertyDeclarationSyntax.ExpressionBody.Expression);
descriptor.ExpressionBody = (ExpressionSyntax)expressionSyntaxRewriter.Visit(propertyDeclarationSyntax.ExpressionBody.Expression);
}
else
{
return null;
}

descriptor.UsingDirectives =
member.SyntaxTree
.GetRoot()
.DescendantNodes()
.OfType<UsingDirectiveSyntax>()
.Select(x => x.ToString());

return descriptor;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,11 @@
using System.Collections.Immutable;
using System.Diagnostics;
using System.Linq;
using System.Security.Cryptography.X509Certificates;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory;

namespace EntityFrameworkCore.Projectables.Generator
{
Expand All @@ -19,6 +21,22 @@ public class ProjectionExpressionGenerator : IIncrementalGenerator
{
private const string ProjectablesAttributeName = "EntityFrameworkCore.Projectables.ProjectableAttribute";

static readonly AttributeSyntax _editorBrowsableAttribute =
Attribute(
ParseName("global::System.ComponentModel.EditorBrowsable"),
AttributeArgumentList(
SingletonSeparatedList(
AttributeArgument(
MemberAccessExpression(
SyntaxKind.SimpleMemberAccessExpression,
IdentifierName("global::System.ComponentModel.EditorBrowsableState"),
IdentifierName("Never")
)
)
)
)
);

public void Initialize(IncrementalGeneratorInitializationContext context)
{
// Do a simple filter for members
Expand Down Expand Up @@ -91,76 +109,84 @@ static void Execute(Compilation compilation, ImmutableArray<MemberDeclarationSyn
throw new InvalidOperationException("Expected a memberName here");
}

resultBuilder.Clear();

resultBuilder.AppendLine("// <auto-generated/>");

if (projectable.UsingDirectives is not null)
{
foreach (var usingDirective in projectable.UsingDirectives.Distinct())
{
resultBuilder.AppendLine(usingDirective);
}
}

if (projectable.TargetClassNamespace is not null)
{
var targetClassUsingDirective = $"using {projectable.TargetClassNamespace};";

if (!projectable.UsingDirectives.Contains(targetClassUsingDirective))
{
resultBuilder.AppendLine(targetClassUsingDirective);
}
}
var generatedClassName = ProjectionExpressionClassNameGenerator.GenerateName(projectable.ClassNamespace, projectable.NestedInClassNames, projectable.MemberName);
var generatedFileName = projectable.ClassTypeParameterList is not null ? $"{generatedClassName}-{projectable.ClassTypeParameterList.ChildNodes().Count()}.g.cs" : $"{generatedClassName}.g.cs";

var classSyntax = ClassDeclaration(generatedClassName)
.WithModifiers(TokenList(Token(SyntaxKind.StaticKeyword)))
.WithTypeParameterList(projectable.ClassTypeParameterList)
.WithConstraintClauses(projectable.ClassConstraintClauses ?? List<TypeParameterConstraintClauseSyntax>())
.AddAttributeLists(
AttributeList()
.AddAttributes(_editorBrowsableAttribute)
)
.AddMembers(
MethodDeclaration(
GenericName(
Identifier("global::System.Linq.Expressions.Expression"),
TypeArgumentList(
SingletonSeparatedList(
(TypeSyntax)GenericName(
Identifier("global::System.Func"),
GetLambdaTypeArgumentListSyntax(projectable)
)
)
)
),
"Expression"
)
.WithModifiers(TokenList(Token(SyntaxKind.StaticKeyword)))
.WithTypeParameterList(projectable.TypeParameterList)
.WithConstraintClauses(projectable.ConstraintClauses ?? List<TypeParameterConstraintClauseSyntax>())
.WithBody(
Block(
ReturnStatement(
ParenthesizedLambdaExpression(
projectable.ParametersList ?? ParameterList(),
null,
projectable.ExpressionBody
)
)
)
)
);

if (projectable.ClassNamespace is not null && projectable.ClassNamespace != projectable.TargetClassNamespace)
{
var classUsingDirective = $"using {projectable.ClassNamespace};";
#nullable disable

if (!projectable.UsingDirectives.Contains(classUsingDirective))
{
resultBuilder.AppendLine(classUsingDirective);
}
}
var compilationUnit = CompilationUnit()
.AddMembers(
NamespaceDeclaration(
ParseName("EntityFrameworkCore.Projectables.Generated")
).AddMembers(classSyntax)
)
.WithLeadingTrivia(
TriviaList(
Comment("// <auto-generated/>"),
Trivia(NullableDirectiveTrivia(Token(SyntaxKind.DisableKeyword), true))
)
);

var generatedClassName = ProjectionExpressionClassNameGenerator.GenerateName(projectable.ClassNamespace, projectable.NestedInClassNames, projectable.MemberName);

var lambdaTypeArguments = SyntaxFactory.TypeArgumentList(
SyntaxFactory.SeparatedList(
projectable.ParametersList?.Parameters.Where(p => p.Type is not null).Select(p => p.Type!)
)
);
context.AddSource(generatedFileName, SourceText.From(compilationUnit.NormalizeWhitespace().ToFullString(), Encoding.UTF8));

resultBuilder.Append($@"
namespace EntityFrameworkCore.Projectables.Generated
#nullable disable
{{
[System.ComponentModel.EditorBrowsable(System.ComponentModel.EditorBrowsableState.Never)]
public static class {generatedClassName}
{{
public static System.Linq.Expressions.Expression<System.Func<{(lambdaTypeArguments.Arguments.Any() ? $"{lambdaTypeArguments.Arguments}, " : "")}{projectable.ReturnTypeName}>> Expression{(projectable.TypeParameterList?.Parameters.Any() == true ? projectable.TypeParameterList.ToString() : string.Empty)}()");

if (projectable.ConstraintClauses is not null)
static TypeArgumentListSyntax GetLambdaTypeArgumentListSyntax(ProjectableDescriptor projectable)
{
foreach (var constraintClause in projectable.ConstraintClauses)
var lambdaTypeArguments = TypeArgumentList(
SeparatedList(
// TODO: Document where clause
projectable.ParametersList?.Parameters.Where(p => p.Type is not null).Select(p => p.Type!)
)
);

if (projectable.ReturnTypeName is not null)
{
resultBuilder.Append($@"
{constraintClause}");
lambdaTypeArguments = lambdaTypeArguments.AddArguments(ParseTypeName(projectable.ReturnTypeName));
}
}

resultBuilder.Append($@"
{{
return {projectable.ParametersList} =>
{projectable.Body};
}}
}}
}}");


context.AddSource($"{generatedClassName}.g.cs", SourceText.From(resultBuilder.ToString(), Encoding.UTF8));
return lambdaTypeArguments;
}
}
}

}
}
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
<Project Sdk="Microsoft.NET.Sdk">

<PropertyGroup>
<TargetFramework>$(TargetFrameworkVersion)</TargetFramework>
<TargetFramework>net6.0</TargetFramework>
<PackageReadmeFile>README.md</PackageReadmeFile>
</PropertyGroup>

Expand Down
13 changes: 13 additions & 0 deletions src/EntityFrameworkCore.Projectables/Extensions/TypeExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,19 @@ namespace EntityFrameworkCore.Projectables.Extensions
{
public static class TypeExtensions
{
public static string GetSimplifiedTypeName(this Type type)
{
var name = type.Name;

var backtickIndex = name.IndexOf("`");
if (backtickIndex != -1)
{
name = name.Substring(0, backtickIndex);
}

return name;
}

public static IEnumerable<Type> GetNestedTypePath(this Type type)
{
if (type.IsNested && type.DeclaringType is not null)
Expand Down
Loading