diff --git a/src/ComputeSharp.D2D1.CodeFixers/MissingPixelShaderDescriptorOnPixelShaderCodeFixer.cs b/src/ComputeSharp.D2D1.CodeFixers/MissingPixelShaderDescriptorOnPixelShaderCodeFixer.cs
index 7a0a7f9e5..ebe6ca1db 100644
--- a/src/ComputeSharp.D2D1.CodeFixers/MissingPixelShaderDescriptorOnPixelShaderCodeFixer.cs
+++ b/src/ComputeSharp.D2D1.CodeFixers/MissingPixelShaderDescriptorOnPixelShaderCodeFixer.cs
@@ -1,14 +1,15 @@
using System.Collections.Immutable;
using System.Composition;
+using System.Threading;
using System.Threading.Tasks;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CodeActions;
using Microsoft.CodeAnalysis.CodeFixes;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Microsoft.CodeAnalysis.Editing;
+using Microsoft.CodeAnalysis.Simplification;
using Microsoft.CodeAnalysis.Text;
using static ComputeSharp.SourceGeneration.Diagnostics.DiagnosticDescriptors;
-using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory;
namespace CommunityToolkit.Mvvm.CodeFixers;
@@ -43,7 +44,7 @@ public override async Task RegisterCodeFixesAsync(CodeFixContext context)
context.RegisterCodeFix(
CodeAction.Create(
title: "Add [D2DGeneratedPixelShaderDescriptor] attribute",
- createChangedDocument: token => ChangeReturnType(context.Document, root, structDeclaration),
+ createChangedDocument: token => ChangeReturnType(context.Document, root, structDeclaration, token),
equivalenceKey: "Add [D2DGeneratedPixelShaderDescriptor] attribute"),
diagnostic);
}
@@ -55,9 +56,22 @@ public override async Task RegisterCodeFixesAsync(CodeFixContext context)
/// The original document being fixed.
/// The original tree root belonging to the current document.
/// The to update.
+ /// The cancellation token for the operation.
/// An updated document with the applied code fix, and the return type of the method being .
- private static Task ChangeReturnType(Document document, SyntaxNode root, StructDeclarationSyntax structDeclaration)
+ private static async Task ChangeReturnType(Document document, SyntaxNode root, StructDeclarationSyntax structDeclaration, CancellationToken cancellationToken)
{
+ // Get the semantic model (bail if it's not available)
+ if (await document.GetSemanticModelAsync(cancellationToken).ConfigureAwait(false) is not SemanticModel semanticModel)
+ {
+ return document;
+ }
+
+ // Also bail if we can't resolve the [D2DGeneratedPixelShaderDescriptor] attribute symbol (this should really never happen)
+ if (semanticModel.Compilation.GetTypeByMetadataName("ComputeSharp.D2D1.D2DGeneratedPixelShaderDescriptorAttribute") is not INamedTypeSymbol attributeSymbol)
+ {
+ return document;
+ }
+
int index = 0;
// Find the index to use to insert the attribute. We want to make it so that if the struct declaration
@@ -89,15 +103,16 @@ private static Task ChangeReturnType(Document document, SyntaxNode roo
}
}
- // Create the attribute syntax for the new attribute
- AttributeListSyntax newAttributeList = AttributeList(SingletonSeparatedList(Attribute(IdentifierName("D2DGeneratedPixelShaderDescriptor"))));
+ SyntaxGenerator syntaxGenerator = SyntaxGenerator.GetGenerator(document);
- // Create a new syntax node with the new attribute
- SyntaxNode typeSyntax = SyntaxGenerator.GetGenerator(document).InsertAttributes(structDeclaration, index, newAttributeList);
+ // Create the attribute syntax for the new attribute. Also annotate it
+ // to automatically add using directives to the document, if needed.
+ // Then create the attribute syntax and insert it at the right position.
+ SyntaxNode attributeTypeSyntax = syntaxGenerator.TypeExpression(attributeSymbol).WithAdditionalAnnotations(Simplifier.AddImportsAnnotation);
+ SyntaxNode attributeSyntax = syntaxGenerator.Attribute(attributeTypeSyntax);
+ SyntaxNode updatedStructDeclarationSyntax = syntaxGenerator.InsertAttributes(structDeclaration, index, attributeSyntax);
// Replace the node in the document tree
- Document updatedDocument = document.WithSyntaxRoot(root.ReplaceNode(structDeclaration, typeSyntax));
-
- return Task.FromResult(updatedDocument);
+ return document.WithSyntaxRoot(root.ReplaceNode(structDeclaration, updatedStructDeclarationSyntax));
}
}