/
AssemblyCapabilityKeyAttributeGenerator.cs
114 lines (103 loc) · 6.13 KB
/
AssemblyCapabilityKeyAttributeGenerator.cs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
using System.Collections.Generic;
using System.Linq;
using System.Text;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using OmniSharp.Extensions.JsonRpc.Generators.Cache;
namespace OmniSharp.Extensions.JsonRpc.Generators
{
/// <summary>
/// Source generator (built on <c>CachedSourceGenerator</c>) that emits <c>AssemblyCapabilityKeys.cs</c>, containing one
/// assembly-level <c>AssemblyCapabilityKey</c> attribute per capability type collected by <see cref="SyntaxReceiver"/>.
/// Each generated attribute is <c>typeof(TheType)</c> followed by the arguments of that type's <c>[CapabilityKey(...)]</c>.
/// </summary>
[Generator]
public class AssemblyCapabilityKeyAttributeGenerator : CachedSourceGenerator<AssemblyCapabilityKeyAttributeGenerator.SyntaxReceiver, TypeDeclarationSyntax>
{
    /// <summary>
    /// Builds and adds <c>AssemblyCapabilityKeys.cs</c> when at least one capability type was collected.
    /// </summary>
    /// <param name="context">Generator context; supplies the compilation and receives the generated source.</param>
    /// <param name="syntaxReceiver">Receiver whose <c>FoundNodes</c> and <see cref="SyntaxReceiver.Handlers"/> are combined.</param>
    /// <param name="addCacheSource">Cache callback from the base class (not used by this generator).</param>
    /// <param name="cacheDiagnostic">Cache diagnostic callback from the base class (not used by this generator).</param>
    protected override void Execute(
        GeneratorExecutionContext context, SyntaxReceiver syntaxReceiver, AddCacheSource<TypeDeclarationSyntax> addCacheSource,
        ReportCacheDiagnostic<TypeDeclarationSyntax> cacheDiagnostic
    )
    {
        // The generated file has no namespace of its own, so the copied [CapabilityKey] arguments resolve only
        // through the plain namespace usings gathered from each declaring file (plus this default one).
        var namespaces = new HashSet<string>(System.StringComparer.Ordinal) { "OmniSharp.Extensions.LanguageServer.Protocol" };
        var attributes = syntaxReceiver.FoundNodes
            .Concat(syntaxReceiver.Handlers)
            .Select(
                typeDeclaration => {
                    foreach (var usingDirective in typeDeclaration.SyntaxTree.GetCompilationUnitRoot().Usings)
                    {
                        // Aliased usings are not carried over. `using static` names a type, so re-emitting it as a
                        // namespace using would not compile (CS0138); skip those too.
                        if (usingDirective.Alias != null || usingDirective.StaticKeyword.IsKind(SyntaxKind.StaticKeyword)) continue;
                        // ToString() (not ToFullString()) drops trivia so each namespace de-duplicates to one entry.
                        namespaces.Add(usingDirective.Name.ToString());
                    }

                    var semanticModel = context.Compilation.GetSemanticModel(typeDeclaration.SyntaxTree);
                    var typeSymbol = semanticModel.GetDeclaredSymbol(typeDeclaration)!;
                    var typeArgument = SyntaxFactory.AttributeArgument(
                        SyntaxFactory.TypeOfExpression(SyntaxFactory.ParseName(typeSymbol.ToDisplayString()))
                    );

                    // `[CapabilityKey]` written without parentheses has no argument list; treat it as "no keys"
                    // instead of throwing inside the generator.
                    var keyArguments = typeDeclaration.AttributeLists.GetAttribute("CapabilityKey")?.ArgumentList?.Arguments
                                    ?? SyntaxFactory.SeparatedList<AttributeArgumentSyntax>();

                    return SyntaxFactory.Attribute(
                        SyntaxFactory.IdentifierName("AssemblyCapabilityKey"),
                        SyntaxFactory.AttributeArgumentList(
                            SyntaxFactory.SeparatedList(new[] { typeArgument }.Concat(keyArguments))
                        )
                    );
                }
            )
            .ToArray();

        if (attributes.Length == 0) return;

        var cu = SyntaxFactory.CompilationUnit()
            .WithUsings(
                SyntaxFactory.List(
                    namespaces
                        // Ordinal ordering keeps the generated file identical regardless of the build machine's culture.
                        .OrderBy(z => z, System.StringComparer.Ordinal)
                        .Select(z => SyntaxFactory.UsingDirective(SyntaxFactory.ParseName(z)))
                )
            )
            .AddAttributeLists(
                SyntaxFactory.AttributeList(
                    target: SyntaxFactory.AttributeTargetSpecifier(SyntaxFactory.Token(SyntaxKind.AssemblyKeyword)), SyntaxFactory.SeparatedList(attributes)
                )
            )
            .WithLeadingTrivia(SyntaxFactory.Comment(Preamble.GeneratedByATool))
            .WithTrailingTrivia(SyntaxFactory.CarriageReturnLineFeed);
        context.AddSource("AssemblyCapabilityKeys.cs", cu.NormalizeWhitespace().GetText(Encoding.UTF8));
    }

    /// <summary>
    /// Creates the generator; every receiver it creates shares the static <see cref="Cache"/>.
    /// </summary>
    public AssemblyCapabilityKeyAttributeGenerator() : base(() => new SyntaxReceiver(Cache))
    {
    }

    /// <summary>
    /// Cache container handed to each <see cref="SyntaxReceiver"/> this generator creates.
    /// </summary>
    public static CacheContainer<TypeDeclarationSyntax> Cache = new();

    /// <summary>
    /// Collects candidate capability type declarations for <see cref="AssemblyCapabilityKeyAttributeGenerator"/>.
    /// </summary>
    public class SyntaxReceiver : SyntaxReceiverCache<TypeDeclarationSyntax>
    {
        /// <summary>
        /// Type declarations accepted by <see cref="OnVisitNode"/>.
        /// </summary>
        public List<TypeDeclarationSyntax> Handlers { get; } = new();

        /// <summary>
        /// Computes the cache key for a type declaration from its file path, keyword, identifier,
        /// type parameter list, attribute lists and base list.
        /// </summary>
        /// <param name="syntax">The type declaration to key.</param>
        /// <returns>The key produced from the <c>CacheKeyHasher</c>.</returns>
        public override string? GetKey(TypeDeclarationSyntax syntax)
        {
            var hasher = new CacheKeyHasher();
            hasher.Append(syntax.SyntaxTree.FilePath);
            hasher.Append(syntax.Keyword.Text);
            hasher.Append(syntax.Identifier.Text);
            hasher.Append(syntax.TypeParameterList);
            hasher.Append(syntax.AttributeLists);
            hasher.Append(syntax.BaseList);
            return hasher;
        }

        /// <summary>
        /// Adds <paramref name="syntaxNode"/> to <see cref="Handlers"/> when it meets all of the following:
        /// - it is not nested in another type;
        /// - it is a non-generic, non-abstract class or record;
        /// - it carries <c>[CapabilityKey]</c>;
        /// - its base list names <c>ICapability</c>, <c>DynamicCapability</c>, <c>IDynamicCapability</c> or
        ///   <c>LinkSupportCapability</c> as a non-generic simple name.
        /// </summary>
        /// <param name="syntaxNode">The type declaration being visited.</param>
        public override void OnVisitNode(TypeDeclarationSyntax syntaxNode)
        {
            // Nested types are ignored.
            if (syntaxNode.Parent is TypeDeclarationSyntax) return;
            if (syntaxNode is ClassDeclarationSyntax or RecordDeclarationSyntax
             && syntaxNode.Arity == 0
             && !syntaxNode.Modifiers.Any(SyntaxKind.AbstractKeyword)
             && syntaxNode.AttributeLists.ContainsAttribute("CapabilityKey")
             && syntaxNode.BaseList is { } bl && bl.Types.Any(
                    z => z.Type switch {
                        SimpleNameSyntax { Identifier: { Text: "ICapability" or "DynamicCapability" or "IDynamicCapability" or "LinkSupportCapability" }, Arity: 0 } => true,
                        _ => false
                    }
                ))
            {
                Handlers.Add(syntaxNode);
            }
        }

        /// <summary>
        /// Creates a receiver backed by the given cache container.
        /// </summary>
        /// <param name="cache">Cache container passed through to the base receiver.</param>
        public SyntaxReceiver(CacheContainer<TypeDeclarationSyntax> cache) : base(cache)
        {
        }
    }
}
}