diff --git a/src/ComputeSharp.D2D1.CodeFixers/MissingPixelShaderDescriptorOnPixelShaderCodeFixer.cs b/src/ComputeSharp.D2D1.CodeFixers/MissingPixelShaderDescriptorOnPixelShaderCodeFixer.cs index 7a0a7f9e5..ebe6ca1db 100644 --- a/src/ComputeSharp.D2D1.CodeFixers/MissingPixelShaderDescriptorOnPixelShaderCodeFixer.cs +++ b/src/ComputeSharp.D2D1.CodeFixers/MissingPixelShaderDescriptorOnPixelShaderCodeFixer.cs @@ -1,14 +1,15 @@ using System.Collections.Immutable; using System.Composition; +using System.Threading; using System.Threading.Tasks; using Microsoft.CodeAnalysis; using Microsoft.CodeAnalysis.CodeActions; using Microsoft.CodeAnalysis.CodeFixes; using Microsoft.CodeAnalysis.CSharp.Syntax; using Microsoft.CodeAnalysis.Editing; +using Microsoft.CodeAnalysis.Simplification; using Microsoft.CodeAnalysis.Text; using static ComputeSharp.SourceGeneration.Diagnostics.DiagnosticDescriptors; -using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory; namespace CommunityToolkit.Mvvm.CodeFixers; @@ -43,7 +44,7 @@ public override async Task RegisterCodeFixesAsync(CodeFixContext context) context.RegisterCodeFix( CodeAction.Create( title: "Add [D2DGeneratedPixelShaderDescriptor] attribute", - createChangedDocument: token => ChangeReturnType(context.Document, root, structDeclaration), + createChangedDocument: token => ChangeReturnType(context.Document, root, structDeclaration, token), equivalenceKey: "Add [D2DGeneratedPixelShaderDescriptor] attribute"), diagnostic); } @@ -55,9 +56,22 @@ public override async Task RegisterCodeFixesAsync(CodeFixContext context) /// The original document being fixed. /// The original tree root belonging to the current document. /// The to update. + /// The cancellation token for the operation. /// An updated document with the applied code fix, and the return type of the method being . - private static Task ChangeReturnType(Document document, SyntaxNode root, StructDeclarationSyntax structDeclaration) + private static async Task ChangeReturnType(Document document, SyntaxNode root, StructDeclarationSyntax structDeclaration, CancellationToken cancellationToken) { + // Get the semantic model (bail if it's not available) + if (await document.GetSemanticModelAsync(cancellationToken).ConfigureAwait(false) is not SemanticModel semanticModel) + { + return document; + } + + // Also bail if we can't resolve the [D2DGeneratedPixelShaderDescriptor] attribute symbol (this should really never happen) + if (semanticModel.Compilation.GetTypeByMetadataName("ComputeSharp.D2D1.D2DGeneratedPixelShaderDescriptorAttribute") is not INamedTypeSymbol attributeSymbol) + { + return document; + } + int index = 0; // Find the index to use to insert the attribute. We want to make it so that if the struct declaration @@ -89,15 +103,16 @@ private static Task ChangeReturnType(Document document, SyntaxNode roo } } - // Create the attribute syntax for the new attribute - AttributeListSyntax newAttributeList = AttributeList(SingletonSeparatedList(Attribute(IdentifierName("D2DGeneratedPixelShaderDescriptor")))); + SyntaxGenerator syntaxGenerator = SyntaxGenerator.GetGenerator(document); - // Create a new syntax node with the new attribute - SyntaxNode typeSyntax = SyntaxGenerator.GetGenerator(document).InsertAttributes(structDeclaration, index, newAttributeList); + // Create the attribute syntax for the new attribute. Also annotate it + // to automatically add using directives to the document, if needed. + // Then create the attribute syntax and insert it at the right position. + SyntaxNode attributeTypeSyntax = syntaxGenerator.TypeExpression(attributeSymbol).WithAdditionalAnnotations(Simplifier.AddImportsAnnotation); + SyntaxNode attributeSyntax = syntaxGenerator.Attribute(attributeTypeSyntax); + SyntaxNode updatedStructDeclarationSyntax = syntaxGenerator.InsertAttributes(structDeclaration, index, attributeSyntax); // Replace the node in the document tree - Document updatedDocument = document.WithSyntaxRoot(root.ReplaceNode(structDeclaration, typeSyntax)); - - return Task.FromResult(updatedDocument); + return document.WithSyntaxRoot(root.ReplaceNode(structDeclaration, updatedStructDeclarationSyntax)); } }