Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add nullability support for NRules #320

Closed
wants to merge 19 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/NRules/NRules.Fluent/Dsl/ContextExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ public static class ContextExtensions
/// <param name="fact">Existing fact to update.</param>
/// <param name="updateAction">Action to apply to the fact.</param>
public static void Update<T>(this IContext context, T fact, Action<T> updateAction)
where T : notnull
{
updateAction(fact);
context.Update(fact);
Expand Down
2 changes: 1 addition & 1 deletion src/NRules/NRules.Fluent/Dsl/IFilterExpression.cs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ public interface IFilterExpression
/// </summary>
/// <param name="keySelectors">Key selector expressions.</param>
/// <returns>Filters expression builder.</returns>
IFilterExpression OnChange(params Expression<Func<object>>[] keySelectors);
IFilterExpression OnChange(params Expression<Func<object?>>[] keySelectors);

/// <summary>
/// Configures the engine to filter rule's matches given a set of predicates.
Expand Down
2 changes: 1 addition & 1 deletion src/NRules/NRules.Fluent/Expressions/FilterExpression.cs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ public FilterExpression(FilterGroupBuilder builder, SymbolStack symbolStack)
_symbolStack = symbolStack;
}

public IFilterExpression OnChange(params Expression<Func<object>>[] keySelectors)
public IFilterExpression OnChange(params Expression<Func<object?>>[] keySelectors)
{
foreach (var keySelector in keySelectors)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ internal class LeftHandSideExpression : ILeftHandSideExpression
{
private readonly GroupBuilder _builder;
private readonly SymbolStack _symbolStack;
private PatternBuilder _currentPatternBuilder;
private PatternBuilder? _currentPatternBuilder;

public LeftHandSideExpression(GroupBuilder builder, SymbolStack symbolStack)
{
Expand All @@ -29,7 +29,7 @@ public ILeftHandSideExpression Match<TFact>(Expression<Func<TFact>> alias, param

public ILeftHandSideExpression Match<TFact>(params Expression<Func<TFact, bool>>[] conditions)
{
var symbol = Expression.Parameter(typeof (TFact));
var symbol = Expression.Parameter(typeof(TFact));
var patternBuilder = _builder.Pattern(symbol.Type, symbol.Name);
patternBuilder.DslConditions(_symbolStack.Scope.Declarations, conditions);
_symbolStack.Scope.Add(patternBuilder.Declaration);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ namespace NRules.Fluent.Expressions;
internal class PatternExpressionRewriter : ExpressionRewriter
{
private readonly Declaration _patternDeclaration;
private ParameterExpression _originalParameter;
private ParameterExpression _normalizedParameter;
private ParameterExpression? _originalParameter;
private ParameterExpression? _normalizedParameter;

public PatternExpressionRewriter(Declaration patternDeclaration, IEnumerable<Declaration> declarations)
: base(declarations)
Expand Down
41 changes: 27 additions & 14 deletions src/NRules/NRules.Fluent/Expressions/QueryExpression.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ internal class QueryExpression : IQuery, IQueryBuilder
private readonly ParameterExpression _symbol;
private readonly SymbolStack _symbolStack;
private readonly GroupBuilder _groupBuilder;
private Func<string, Type, BuildResult> _buildAction;
private Func<string?, Type?, BuildResult>? _buildAction;

public QueryExpression(ParameterExpression symbol, SymbolStack symbolStack, GroupBuilder groupBuilder)
{
Expand Down Expand Up @@ -52,7 +52,7 @@ public void From<TSource>(Expression<Func<TSource>> source)

public void Where<TSource>(Expression<Func<TSource, bool>>[] predicates)
{
var previousBuildAction = _buildAction;
var previousBuildAction = EnsureBuildAction();
_buildAction = (name, type) =>
{
var result = previousBuildAction(name, type);
Expand All @@ -63,7 +63,7 @@ public void Where<TSource>(Expression<Func<TSource, bool>>[] predicates)

public void Select<TSource, TResult>(Expression<Func<TSource, TResult>> selector)
{
var previousBuildAction = _buildAction;
var previousBuildAction = EnsureBuildAction();
_buildAction = (name, _) =>
{
var patternBuilder = new PatternBuilder(typeof(TResult), name);
Expand All @@ -87,7 +87,7 @@ public void Where<TSource>(Expression<Func<TSource, bool>>[] predicates)

public void SelectMany<TSource, TResult>(Expression<Func<TSource, IEnumerable<TResult>>> selector)
{
var previousBuildAction = _buildAction;
var previousBuildAction = EnsureBuildAction();
_buildAction = (name, _) =>
{
var patternBuilder = new PatternBuilder(typeof(TResult), name);
Expand All @@ -111,7 +111,7 @@ public void Where<TSource>(Expression<Func<TSource, bool>>[] predicates)

public void GroupBy<TSource, TKey, TElement>(Expression<Func<TSource, TKey>> keySelector, Expression<Func<TSource, TElement>> elementSelector)
{
var previousBuildAction = _buildAction;
var previousBuildAction = EnsureBuildAction();
_buildAction = (name, _) =>
{
var patternBuilder = new PatternBuilder(typeof(IGrouping<TKey, TElement>), name);
Expand All @@ -136,7 +136,7 @@ public void Where<TSource>(Expression<Func<TSource, bool>>[] predicates)

public void Collect<TSource>()
{
var previousBuildAction = _buildAction;
var previousBuildAction = EnsureBuildAction();
_buildAction = (name, type) =>
{
var patternBuilder = new PatternBuilder(type ?? typeof(IEnumerable<TSource>), name);
Expand All @@ -160,10 +160,14 @@ public void Collect<TSource>()

public void ToLookup<TSource, TKey, TElement>(Expression<Func<TSource, TKey>> keySelector, Expression<Func<TSource, TElement>> elementSelector)
{
var previousBuildAction = _buildAction;
var previousBuildAction = EnsureBuildAction();
_buildAction = (name, _) =>
{
var result = previousBuildAction(name, typeof(IKeyedLookup<TKey, TElement>));
if (result.Source is null || result.Aggregate is null)
{
throw new InvalidOperationException($"{nameof(ToLookup)} cannot be called directly after {nameof(From)} or {nameof(FactQuery)}");
}
var keySelectorExpression = result.Source.DslPatternExpression(_symbolStack.Scope.Declarations, keySelector);
var elementSelectorExpression = result.Source.DslPatternExpression(_symbolStack.Scope.Declarations, elementSelector);
result.Aggregate.ToLookup(keySelectorExpression, elementSelectorExpression);
Expand All @@ -173,10 +177,14 @@ public void Collect<TSource>()

public void OrderBy<TSource, TKey>(Expression<Func<TSource, TKey>> keySelector, SortDirection sortDirection)
{
var previousBuildAction = _buildAction;
var previousBuildAction = EnsureBuildAction();
_buildAction = (name, type) =>
{
var result = previousBuildAction(name, type);
if (result.Source is null || result.Aggregate is null)
{
throw new InvalidOperationException($"{nameof(OrderBy)} cannot be called directly after {nameof(From)} or {nameof(FactQuery)}");
}
var keySelectorExpression = result.Source.DslPatternExpression(_symbolStack.Scope.Declarations, keySelector);
result.Aggregate.OrderBy(keySelectorExpression, sortDirection);
return result;
Expand All @@ -188,9 +196,9 @@ public void Collect<TSource>()
Aggregate<TSource, TResult>(aggregateName, expressions, null);
}

public void Aggregate<TSource, TResult>(string aggregateName, IEnumerable<KeyValuePair<string, LambdaExpression>> expressions, Type customFactoryType)
public void Aggregate<TSource, TResult>(string aggregateName, IEnumerable<KeyValuePair<string, LambdaExpression>> expressions, Type? customFactoryType)
{
var previousBuildAction = _buildAction;
var previousBuildAction = EnsureBuildAction();
_buildAction = (name, _) =>
{
var patternBuilder = new PatternBuilder(typeof(TResult), name);
Expand Down Expand Up @@ -222,11 +230,16 @@ public void Collect<TSource>()

public PatternBuilder Build()
{
var patternBuilder = _buildAction(_symbol.Name, null);
var patternBuilder = EnsureBuildAction()(_symbol.Name, null);
_groupBuilder.Pattern(patternBuilder.Pattern);
return patternBuilder.Pattern;
}


private Func<string?, Type?, BuildResult> EnsureBuildAction()
{
return _buildAction ?? throw new InvalidOperationException($"{nameof(From)} or {nameof(FactQuery)} was not called");
}

private class BuildResult
{
public BuildResult(PatternBuilder pattern, AggregateBuilder aggregate, PatternBuilder source)
Expand All @@ -242,7 +255,7 @@ public BuildResult(PatternBuilder pattern)
}

public PatternBuilder Pattern { get; }
public AggregateBuilder Aggregate { get; }
public PatternBuilder Source { get; }
public AggregateBuilder? Aggregate { get; }
public PatternBuilder? Source { get; }
}
}
2 changes: 1 addition & 1 deletion src/NRules/NRules.Fluent/Expressions/SymbolTable.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ namespace NRules.Fluent.Expressions;
internal class SymbolTable
{
private readonly HashSet<Declaration> _declarations;
private readonly SymbolTable _parentScope;
private readonly SymbolTable? _parentScope;

internal SymbolTable()
{
Expand Down
1 change: 1 addition & 0 deletions src/NRules/NRules.Fluent/NRules.Fluent.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
<SignAssembly>True</SignAssembly>
<AssemblyOriginatorKeyFile>..\..\..\SigningKey.snk</AssemblyOriginatorKeyFile>
<LangVersion>latest</LangVersion>
<Nullable>enable</Nullable>
</PropertyGroup>

<ItemGroup>
Expand Down
19 changes: 7 additions & 12 deletions src/NRules/NRules.Fluent/RuleLoadSpec.cs
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ public interface IRuleLoadSpec
/// <param name="assemblies">Assemblies to load from.</param>
/// <returns>Spec to continue fluent configuration.</returns>
IRuleLoadSpec From(params Assembly[] assemblies);

/// <summary>
/// Specifies to load all rule definitions from a given collection of assemblies.
/// </summary>
Expand Down Expand Up @@ -83,14 +83,14 @@ internal class RuleLoadSpec : IRuleLoadSpec
{
private readonly IRuleActivator _activator;
private readonly RuleTypeScanner _typeScanner = new();
private Func<IRuleMetadata, bool> _filter;
private Func<IRuleMetadata, bool>? _filter;

public RuleLoadSpec(IRuleActivator activator)
{
_activator = activator;
}

public string RuleSetName { get; private set; }
public string? RuleSetName { get; private set; }

public IRuleLoadSpec PrivateTypes(bool include = true)
{
Expand Down Expand Up @@ -136,9 +136,9 @@ public IRuleLoadSpec From(Action<IRuleTypeScanner> scanAction)

public IRuleLoadSpec Where(Func<IRuleMetadata, bool> filter)
{
if (IsFilterSet())
if (_filter != null)
throw new InvalidOperationException("Rule load specification can only have a single 'Where' clause");

_filter = filter;
return this;
}
Expand All @@ -147,7 +147,7 @@ public IRuleLoadSpec To(string ruleSetName)
{
if (RuleSetName != null)
throw new InvalidOperationException("Rule load specification can only have a single 'To' clause");

RuleSetName = ruleSetName;
return this;
}
Expand All @@ -164,7 +164,7 @@ public IEnumerable<IRuleDefinition> Load()
private IEnumerable<Type> GetRuleTypes()
{
var ruleTypes = _typeScanner.GetRuleTypes();
if (IsFilterSet())
if (_filter != null)
{
var metadata = ruleTypes.Select(ruleType => new RuleMetadata(ruleType));
var filteredTypes = metadata.Where(x => _filter(x)).Select(x => x.RuleType);
Expand All @@ -185,9 +185,4 @@ private IEnumerable<Rule> Activate(Type type)
throw new RuleActivationException("Failed to activate rule type", type, e);
}
}

private bool IsFilterSet()
{
return _filter != null;
}
}
19 changes: 12 additions & 7 deletions src/NRules/NRules.Json/Converters/ExpressionConverter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,8 @@ private void WriteLambda(Utf8JsonWriter writer, JsonSerializerOptions options, L

private ConstantExpression ReadConstant(ref Utf8JsonReader reader, JsonSerializerOptions options)
{
var type = reader.ReadProperty<Type>(nameof(ConstantExpression.Type), options);
var type = reader.ReadProperty<Type>(nameof(ConstantExpression.Type), options)
?? throw new JsonException($"Property '{nameof(ConstantExpression.Type)}' should have not null value");
var value = reader.ReadProperty(nameof(ConstantExpression.Value), type, options);
return Expression.Constant(value, type);
}
Expand Down Expand Up @@ -238,10 +239,12 @@ private void WriteMethodCall(Utf8JsonWriter writer, JsonSerializerOptions option

private BinaryExpression ReadBinaryExpression(ref Utf8JsonReader reader, JsonSerializerOptions options, ExpressionType expressionType)
{
var left = reader.ReadProperty<Expression>(nameof(BinaryExpression.Left), options);
var right = reader.ReadProperty<Expression>(nameof(BinaryExpression.Right), options);
var left = reader.ReadProperty<Expression>(nameof(BinaryExpression.Left), options)
?? throw new JsonException($"Property '{nameof(BinaryExpression.Left)}' should have not null value");
var right = reader.ReadProperty<Expression>(nameof(BinaryExpression.Right), options)
?? throw new JsonException($"Property '{nameof(BinaryExpression.Right)}' should have not null value");

MethodInfo method = default;
MethodInfo? method = default;
if (reader.TryReadMethodInfo(options, out var methodRecord))
method = methodRecord.GetMethod(new[] { left.Type, right.Type }, left.Type);

Expand Down Expand Up @@ -313,10 +316,11 @@ private void WriteBinaryExpression(Utf8JsonWriter writer, JsonSerializerOptions

private UnaryExpression ReadUnaryExpression(ref Utf8JsonReader reader, JsonSerializerOptions options, ExpressionType expressionType)
{
var operand = reader.ReadProperty<Expression>(nameof(UnaryExpression.Operand), options);
var operand = reader.ReadProperty<Expression>(nameof(UnaryExpression.Operand), options)
?? throw new JsonException($"Property '{nameof(UnaryExpression.Operand)}' should have not null value");
reader.TryReadProperty<Type>(nameof(UnaryExpression.Type), options, out var type);

MethodInfo method = default;
MethodInfo? method = default;
if (reader.TryReadMethodInfo(options, out var methodRecord))
method = methodRecord.GetMethod(new[] { operand.Type }, operand.Type);

Expand Down Expand Up @@ -379,7 +383,8 @@ private void WriteTypeBinaryExpression(Utf8JsonWriter writer, JsonSerializerOpti

private NewExpression ReadNewExpression(ref Utf8JsonReader reader, JsonSerializerOptions options)
{
var declaringType = reader.ReadProperty<Type>(nameof(NewExpression.Constructor.DeclaringType), options);
var declaringType = reader.ReadProperty<Type>(nameof(NewExpression.Constructor.DeclaringType), options)
?? throw new JsonException($"Property '{nameof(NewExpression.Constructor.DeclaringType)}' should have not null value");
reader.TryReadArrayProperty<Expression>(nameof(NewExpression.Arguments), options, out var arguments);

var ctor = declaringType.GetConstructor(arguments.Select(x => x.Type).ToArray())
Expand Down
11 changes: 7 additions & 4 deletions src/NRules/NRules.Json/Converters/ExpressionElementConverter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,10 @@ internal class NamedExpressionElementConverter : JsonConverter<NamedExpressionEl
public override NamedExpressionElement Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options)
{
reader.ReadStartObject();
var name = reader.ReadStringProperty(nameof(NamedExpressionElement.Name), options);
var expression = reader.ReadProperty<LambdaExpression>(nameof(NamedExpressionElement.Expression), options);
var name = reader.ReadStringProperty(nameof(NamedExpressionElement.Name), options)
?? throw new JsonException($"Property '{nameof(NamedExpressionElement.Name)}' should have not null value");
var expression = reader.ReadProperty<LambdaExpression>(nameof(NamedExpressionElement.Expression), options)
?? throw new JsonException($"Property '{nameof(NamedExpressionElement.Expression)}' should have not null value");
return Element.Expression(name, expression);
}

Expand All @@ -34,15 +36,16 @@ public override ActionElement Read(ref Utf8JsonReader reader, Type typeToConvert
reader.ReadStartObject();
if (!reader.TryReadEnumProperty<ActionTrigger>(nameof(ActionElement.ActionTrigger), options, out var trigger))
trigger = ActionElement.DefaultTrigger;
var expression = reader.ReadProperty<LambdaExpression>(nameof(ActionElement.Expression), options);
var expression = reader.ReadProperty<LambdaExpression>(nameof(ActionElement.Expression), options)
?? throw new JsonException($"Property '{nameof(ActionElement.Expression)}' should have not null value");
return Element.Action(expression, trigger);
}

public override void Write(Utf8JsonWriter writer, ActionElement value, JsonSerializerOptions options)
{
writer.WriteStartObject();
if (value.ActionTrigger != ActionElement.DefaultTrigger)
writer.WriteEnumProperty(nameof(ActionElement.ActionTrigger), value.ActionTrigger, options);
writer.WriteEnumProperty(nameof(ActionElement.ActionTrigger), value.ActionTrigger, options);
writer.WriteProperty(nameof(ActionElement.Expression), value.Expression, options);
writer.WriteEndObject();
}
Expand Down
8 changes: 4 additions & 4 deletions src/NRules/NRules.Json/Converters/MemberBindingExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ namespace NRules.Json.Converters;

internal static class MemberBindingExtensions
{
public static MemberBinding ReadMemberBinding(this ref Utf8JsonReader reader, JsonSerializerOptions options, Type impiedType = null)
public static MemberBinding ReadMemberBinding(this ref Utf8JsonReader reader, JsonSerializerOptions options, Type? impiedType = null)
{
var bindingType = reader.ReadEnumProperty<MemberBindingType>(nameof(MemberBinding.BindingType), options);

Expand All @@ -20,7 +20,7 @@ public static MemberBinding ReadMemberBinding(this ref Utf8JsonReader reader, Js
}
}

public static void WriteMemberBinding(this Utf8JsonWriter writer, MemberBinding value, JsonSerializerOptions options, Type impliedType = null)
public static void WriteMemberBinding(this Utf8JsonWriter writer, MemberBinding value, JsonSerializerOptions options, Type? impliedType = null)
{
writer.WriteEnumProperty(nameof(value.BindingType), value.BindingType, options);

Expand All @@ -34,14 +34,14 @@ public static void WriteMemberBinding(this Utf8JsonWriter writer, MemberBinding
}
}

private static MemberAssignment ReadMemberAssignment(ref Utf8JsonReader reader, JsonSerializerOptions options, Type impliedType)
private static MemberAssignment ReadMemberAssignment(ref Utf8JsonReader reader, JsonSerializerOptions options, Type? impliedType)
{
var member = reader.ReadMemberInfo(options, impliedType);
var expression = reader.ReadProperty<Expression>(nameof(MemberExpression.Expression), options);
return Expression.Bind(member, expression);
}

private static void WriteMemberAssignment(Utf8JsonWriter writer, JsonSerializerOptions options, MemberAssignment value, Type impliedType)
private static void WriteMemberAssignment(Utf8JsonWriter writer, JsonSerializerOptions options, MemberAssignment value, Type? impliedType)
{
writer.WriteMemberInfo(options, value.Member, impliedType);
writer.WriteProperty(nameof(value.Expression), value.Expression, options);
Expand Down
Loading