From fa6ef5b5f82eb861fdef6b7f5ebbd3313495f433 Mon Sep 17 00:00:00 2001 From: Sergio Pedri Date: Tue, 16 May 2023 21:19:55 +0200 Subject: [PATCH] Check for cancellation more often in generators --- .../ObservablePropertyGenerator.Execute.cs | 24 +++++++++++++++++ .../ObservablePropertyGenerator.cs | 4 +++ ...rValidateAllPropertiesGenerator.Execute.cs | 8 +++++- ...ValidatorValidateAllPropertiesGenerator.cs | 6 ++++- .../TransitiveMembersGenerator.cs | 7 +++++ .../Input/RelayCommandGenerator.Execute.cs | 27 ++++++++++++++++++- .../Input/RelayCommandGenerator.cs | 4 +++ .../IMessengerRegisterAllGenerator.cs | 8 +++++- 8 files changed, 84 insertions(+), 4 deletions(-) diff --git a/src/CommunityToolkit.Mvvm.SourceGenerators/ComponentModel/ObservablePropertyGenerator.Execute.cs b/src/CommunityToolkit.Mvvm.SourceGenerators/ComponentModel/ObservablePropertyGenerator.Execute.cs index 6f590b80..a1f1c145 100644 --- a/src/CommunityToolkit.Mvvm.SourceGenerators/ComponentModel/ObservablePropertyGenerator.Execute.cs +++ b/src/CommunityToolkit.Mvvm.SourceGenerators/ComponentModel/ObservablePropertyGenerator.Execute.cs @@ -64,6 +64,8 @@ internal static class Execute return false; } + token.ThrowIfCancellationRequested(); + // Get the property type and name string typeNameWithNullabilityAnnotations = fieldSymbol.Type.GetFullyQualifiedNameWithNullabilityAnnotations(); string fieldName = fieldSymbol.Name; @@ -87,6 +89,8 @@ internal static class Execute return false; } + token.ThrowIfCancellationRequested(); + // Check for special cases that are explicitly not allowed if (IsGeneratedPropertyInvalid(propertyName, fieldSymbol.Type)) { @@ -102,6 +106,8 @@ internal static class Execute return false; } + token.ThrowIfCancellationRequested(); + using ImmutableArrayBuilder propertyChangedNames = ImmutableArrayBuilder.Rent(); using ImmutableArrayBuilder propertyChangingNames = ImmutableArrayBuilder.Rent(); using ImmutableArrayBuilder notifiedCommandNames = ImmutableArrayBuilder.Rent(); @@ -114,6 +120,8 @@ internal static class Execute bool hasAnyValidationAttributes = false; bool isOldPropertyValueDirectlyReferenced = IsOldPropertyValueDirectlyReferenced(fieldSymbol, propertyName); + token.ThrowIfCancellationRequested(); + // Get the nullability info for the property GetNullabilityInfo( fieldSymbol, @@ -121,6 +129,8 @@ internal static class Execute out bool isReferenceTypeOrUnconstraindTypeParameter, out bool includeMemberNotNullOnSetAccessor); + token.ThrowIfCancellationRequested(); + // Track the property changing event for the property, if the type supports it if (shouldInvokeOnPropertyChanging) { @@ -137,6 +147,8 @@ internal static class Execute hasOrInheritsClassLevelNotifyPropertyChangedRecipients = true; } + token.ThrowIfCancellationRequested(); + // Get the class-level [NotifyDataErrorInfo] setting, if any if (TryGetNotifyDataErrorInfo(fieldSymbol, out bool isValidationTargetValid)) { @@ -144,9 +156,13 @@ internal static class Execute hasOrInheritsClassLevelNotifyDataErrorInfo = true; } + token.ThrowIfCancellationRequested(); + // Gather attributes info foreach (AttributeData attributeData in fieldSymbol.GetAttributes()) { + token.ThrowIfCancellationRequested(); + // Gather dependent property and command names if (TryGatherDependentPropertyChangedNames(fieldSymbol, attributeData, in propertyChangedNames, in builder) || TryGatherDependentCommandNames(fieldSymbol, attributeData, in notifiedCommandNames, in builder)) @@ -194,6 +210,8 @@ internal static class Execute } } + token.ThrowIfCancellationRequested(); + // Gather explicit forwarded attributes info foreach (AttributeListSyntax attributeList in fieldSyntax.AttributeLists) { @@ -205,6 +223,8 @@ internal static class Execute continue; } + token.ThrowIfCancellationRequested(); + foreach (AttributeSyntax attribute in attributeList.Attributes) { // Roslyn ignores attributes in an attribute list with an invalid target, so we can't get the AttributeData as usual. @@ -250,6 +270,8 @@ internal static class Execute } } + token.ThrowIfCancellationRequested(); + // Log the diagnostic for missing ObservableValidator, if needed if (hasAnyValidationAttributes && !fieldSymbol.ContainingType.InheritsFromFullyQualifiedMetadataName("CommunityToolkit.Mvvm.ComponentModel.ObservableValidator")) @@ -272,6 +294,8 @@ internal static class Execute fieldSymbol.Name); } + token.ThrowIfCancellationRequested(); + propertyInfo = new PropertyInfo( typeNameWithNullabilityAnnotations, fieldName, diff --git a/src/CommunityToolkit.Mvvm.SourceGenerators/ComponentModel/ObservablePropertyGenerator.cs b/src/CommunityToolkit.Mvvm.SourceGenerators/ComponentModel/ObservablePropertyGenerator.cs index a1f4e942..ad7e5fe0 100644 --- a/src/CommunityToolkit.Mvvm.SourceGenerators/ComponentModel/ObservablePropertyGenerator.cs +++ b/src/CommunityToolkit.Mvvm.SourceGenerators/ComponentModel/ObservablePropertyGenerator.cs @@ -42,8 +42,12 @@ public void Initialize(IncrementalGeneratorInitializationContext context) // Get the hierarchy info for the target symbol, and try to gather the property info HierarchyInfo hierarchy = HierarchyInfo.From(fieldSymbol.ContainingType); + token.ThrowIfCancellationRequested(); + _ = Execute.TryGetInfo(fieldDeclaration, fieldSymbol, context.SemanticModel, token, out PropertyInfo? propertyInfo, out ImmutableArray diagnostics); + token.ThrowIfCancellationRequested(); + return (Hierarchy: hierarchy, new Result(propertyInfo, diagnostics)); }) .Where(static item => item.Hierarchy is not null); diff --git a/src/CommunityToolkit.Mvvm.SourceGenerators/ComponentModel/ObservableValidatorValidateAllPropertiesGenerator.Execute.cs b/src/CommunityToolkit.Mvvm.SourceGenerators/ComponentModel/ObservableValidatorValidateAllPropertiesGenerator.Execute.cs index c8e09713..684430db 100644 --- a/src/CommunityToolkit.Mvvm.SourceGenerators/ComponentModel/ObservableValidatorValidateAllPropertiesGenerator.Execute.cs +++ b/src/CommunityToolkit.Mvvm.SourceGenerators/ComponentModel/ObservableValidatorValidateAllPropertiesGenerator.Execute.cs @@ -5,6 +5,7 @@ using System; using System.Collections.Immutable; using System.Linq; +using System.Threading; using CommunityToolkit.Mvvm.SourceGenerators.ComponentModel.Models; using CommunityToolkit.Mvvm.SourceGenerators.Extensions; using CommunityToolkit.Mvvm.SourceGenerators.Helpers; @@ -37,8 +38,9 @@ public static bool IsObservableValidator(INamedTypeSymbol typeSymbol) /// Gets the instance from an input symbol. /// /// The input instance to inspect. + /// The cancellation token for the current operation. /// The resulting instance for . - public static ValidationInfo GetInfo(INamedTypeSymbol typeSymbol) + public static ValidationInfo GetInfo(INamedTypeSymbol typeSymbol, CancellationToken token) { using ImmutableArrayBuilder propertyNames = ImmutableArrayBuilder.Rent(); @@ -49,6 +51,8 @@ public static ValidationInfo GetInfo(INamedTypeSymbol typeSymbol) continue; } + token.ThrowIfCancellationRequested(); + ImmutableArray attributes = memberSymbol.GetAttributes(); // Also include fields that are annotated with [ObservableProperty]. This is necessary because @@ -79,6 +83,8 @@ public static ValidationInfo GetInfo(INamedTypeSymbol typeSymbol) propertyNames.Add(propertyName); } + token.ThrowIfCancellationRequested(); + return new( typeSymbol.GetFullyQualifiedMetadataName(), typeSymbol.GetFullyQualifiedName(), diff --git a/src/CommunityToolkit.Mvvm.SourceGenerators/ComponentModel/ObservableValidatorValidateAllPropertiesGenerator.cs b/src/CommunityToolkit.Mvvm.SourceGenerators/ComponentModel/ObservableValidatorValidateAllPropertiesGenerator.cs index b4a5b153..e3a00ae0 100644 --- a/src/CommunityToolkit.Mvvm.SourceGenerators/ComponentModel/ObservableValidatorValidateAllPropertiesGenerator.cs +++ b/src/CommunityToolkit.Mvvm.SourceGenerators/ComponentModel/ObservableValidatorValidateAllPropertiesGenerator.cs @@ -48,13 +48,17 @@ public void Initialize(IncrementalGeneratorInitializationContext context) return default; } + token.ThrowIfCancellationRequested(); + // Only select types inheriting from ObservableValidator if (!Execute.IsObservableValidator(typeSymbol)) { return default; } - return Execute.GetInfo(typeSymbol); + token.ThrowIfCancellationRequested(); + + return Execute.GetInfo(typeSymbol, token); }) .Where(static item => item is not null)!; diff --git a/src/CommunityToolkit.Mvvm.SourceGenerators/ComponentModel/TransitiveMembersGenerator.cs b/src/CommunityToolkit.Mvvm.SourceGenerators/ComponentModel/TransitiveMembersGenerator.cs index bd1d3749..7375080a 100644 --- a/src/CommunityToolkit.Mvvm.SourceGenerators/ComponentModel/TransitiveMembersGenerator.cs +++ b/src/CommunityToolkit.Mvvm.SourceGenerators/ComponentModel/TransitiveMembersGenerator.cs @@ -77,6 +77,8 @@ public void Initialize(IncrementalGeneratorInitializationContext context) // Gather all generation info, and any diagnostics TInfo? info = ValidateTargetTypeAndGetInfo(typeSymbol, context.Attributes[0], context.SemanticModel.Compilation, out ImmutableArray diagnostics); + token.ThrowIfCancellationRequested(); + // If there are any diagnostics, there's no need to compute the hierarchy info at all, just return them if (diagnostics.Length > 0) { @@ -84,8 +86,13 @@ public void Initialize(IncrementalGeneratorInitializationContext context) } HierarchyInfo hierarchy = HierarchyInfo.From(typeSymbol); + + token.ThrowIfCancellationRequested(); + MetadataInfo metadataInfo = new(typeSymbol.IsSealed, Execute.IsNullabilitySupported(context.SemanticModel.Compilation)); + token.ThrowIfCancellationRequested(); + return new Result<(HierarchyInfo, MetadataInfo?, TInfo?)>((hierarchy, metadataInfo, info), diagnostics); }) .Where(static item => item is not null)!; diff --git a/src/CommunityToolkit.Mvvm.SourceGenerators/Input/RelayCommandGenerator.Execute.cs b/src/CommunityToolkit.Mvvm.SourceGenerators/Input/RelayCommandGenerator.Execute.cs index c2bb7cca..e5b75316 100644 --- a/src/CommunityToolkit.Mvvm.SourceGenerators/Input/RelayCommandGenerator.Execute.cs +++ b/src/CommunityToolkit.Mvvm.SourceGenerators/Input/RelayCommandGenerator.Execute.cs @@ -56,9 +56,13 @@ internal static class Execute goto Failure; } + token.ThrowIfCancellationRequested(); + // Get the command field and property names (string fieldName, string propertyName) = GetGeneratedFieldAndPropertyNames(methodSymbol); + token.ThrowIfCancellationRequested(); + // Get the command type symbols if (!TryMapCommandTypesFromMethod( methodSymbol, @@ -74,6 +78,8 @@ internal static class Execute goto Failure; } + token.ThrowIfCancellationRequested(); + // Check the switch to allow concurrent executions if (!TryGetAllowConcurrentExecutionsSwitch( methodSymbol, @@ -85,6 +91,8 @@ internal static class Execute goto Failure; } + token.ThrowIfCancellationRequested(); + // Check the switch to control exception flow if (!TryGetFlowExceptionsToTaskSchedulerSwitch( methodSymbol, @@ -96,11 +104,14 @@ internal static class Execute goto Failure; } + token.ThrowIfCancellationRequested(); + // Get the CanExecute expression type, if any if (!TryGetCanExecuteExpressionType( methodSymbol, attributeData, commandTypeArguments, + token, in builder, out string? canExecuteMemberName, out CanExecuteExpressionType? canExecuteExpressionType)) @@ -108,6 +119,8 @@ internal static class Execute goto Failure; } + token.ThrowIfCancellationRequested(); + // Get the option to include a cancel command, if any if (!TryGetIncludeCancelCommandSwitch( methodSymbol, @@ -120,6 +133,8 @@ internal static class Execute goto Failure; } + token.ThrowIfCancellationRequested(); + // Get all forwarded attributes (don't stop in case of errors, just ignore faulting attributes) GatherForwardedAttributes( methodSymbol, @@ -129,6 +144,8 @@ internal static class Execute out ImmutableArray fieldAttributes, out ImmutableArray propertyAttributes); + token.ThrowIfCancellationRequested(); + commandInfo = new CommandInfo( methodSymbol.Name, fieldName, @@ -749,6 +766,7 @@ public static (string FieldName, string PropertyName) GetGeneratedFieldAndProper /// The input instance to process. /// The instance for . /// The command type arguments, if any. + /// The cancellation token for the current operation. /// The current collection of gathered diagnostics. /// The resulting can execute member name, if available. /// The resulting expression type, if available. @@ -757,6 +775,7 @@ public static (string FieldName, string PropertyName) GetGeneratedFieldAndProper IMethodSymbol methodSymbol, AttributeData attributeData, ImmutableArray commandTypeArguments, + CancellationToken token, in ImmutableArrayBuilder diagnostics, out string? canExecuteMemberName, out CanExecuteExpressionType? canExecuteExpressionType) @@ -782,7 +801,7 @@ public static (string FieldName, string PropertyName) GetGeneratedFieldAndProper if (canExecuteSymbols.IsEmpty) { // Special case for when the target member is a generated property from [ObservableProperty] - if (TryGetCanExecuteMemberFromGeneratedProperty(memberName, methodSymbol.ContainingType, commandTypeArguments, out canExecuteExpressionType)) + if (TryGetCanExecuteMemberFromGeneratedProperty(memberName, methodSymbol.ContainingType, commandTypeArguments, token, out canExecuteExpressionType)) { canExecuteMemberName = memberName; @@ -892,12 +911,14 @@ public static (string FieldName, string PropertyName) GetGeneratedFieldAndProper /// The member name passed to [RelayCommand(CanExecute = ...)]. /// The containing type for the method annotated with [RelayCommand]. /// The type arguments for the command interface, if any. + /// The cancellation token for the current operation. /// The resulting can execute expression type, if available. /// Whether or not was set and the input symbol was valid. private static bool TryGetCanExecuteMemberFromGeneratedProperty( string memberName, INamedTypeSymbol containingType, ImmutableArray commandTypeArguments, + CancellationToken token, [NotNullWhen(true)] out CanExecuteExpressionType? canExecuteExpressionType) { foreach (ISymbol memberSymbol in containingType.GetAllMembers()) @@ -908,6 +929,8 @@ public static (string FieldName, string PropertyName) GetGeneratedFieldAndProper continue; } + token.ThrowIfCancellationRequested(); + ImmutableArray attributes = memberSymbol.GetAttributes(); // Only filter fields with the [ObservableProperty] attribute @@ -991,6 +1014,8 @@ public static (string FieldName, string PropertyName) GetGeneratedFieldAndProper continue; } + token.ThrowIfCancellationRequested(); + foreach (AttributeSyntax attribute in attributeList.Attributes) { // Get the symbol info for the attribute (once again just like in the [ObservableProperty] generator) diff --git a/src/CommunityToolkit.Mvvm.SourceGenerators/Input/RelayCommandGenerator.cs b/src/CommunityToolkit.Mvvm.SourceGenerators/Input/RelayCommandGenerator.cs index ddf24bf0..6e938d1c 100644 --- a/src/CommunityToolkit.Mvvm.SourceGenerators/Input/RelayCommandGenerator.cs +++ b/src/CommunityToolkit.Mvvm.SourceGenerators/Input/RelayCommandGenerator.cs @@ -40,6 +40,8 @@ public void Initialize(IncrementalGeneratorInitializationContext context) // Get the hierarchy info for the target symbol, and try to gather the command info HierarchyInfo? hierarchy = HierarchyInfo.From(methodSymbol.ContainingType); + token.ThrowIfCancellationRequested(); + _ = Execute.TryGetInfo( methodSymbol, context.Attributes[0], @@ -48,6 +50,8 @@ public void Initialize(IncrementalGeneratorInitializationContext context) out CommandInfo? commandInfo, out ImmutableArray diagnostics); + token.ThrowIfCancellationRequested(); + return (Hierarchy: hierarchy, new Result(commandInfo, diagnostics)); }) .Where(static item => item.Hierarchy is not null)!; diff --git a/src/CommunityToolkit.Mvvm.SourceGenerators/Messaging/IMessengerRegisterAllGenerator.cs b/src/CommunityToolkit.Mvvm.SourceGenerators/Messaging/IMessengerRegisterAllGenerator.cs index a9ee5b06..0a08ddd6 100644 --- a/src/CommunityToolkit.Mvvm.SourceGenerators/Messaging/IMessengerRegisterAllGenerator.cs +++ b/src/CommunityToolkit.Mvvm.SourceGenerators/Messaging/IMessengerRegisterAllGenerator.cs @@ -52,13 +52,19 @@ public void Initialize(IncrementalGeneratorInitializationContext context) ImmutableArray interfaceSymbols = Execute.GetInterfaces(typeSymbol); + token.ThrowIfCancellationRequested(); + // Check that the type implements at least one IRecipient interface if (interfaceSymbols.IsEmpty) { return default; } - return Execute.GetInfo(typeSymbol, interfaceSymbols); + RecipientInfo info = Execute.GetInfo(typeSymbol, interfaceSymbols); + + token.ThrowIfCancellationRequested(); + + return info; }) .Where(static item => item is not null)!;