Skip to content

Commit

Permalink
Rework expressions and expression compilation
Browse files Browse the repository at this point in the history
Collapse multiple expression classes into a few generic ones; unify the code.
Compare expressions for node sharing prior to compilation, thereby improving compilation performance in those cases where there is node sharing.
Eliminate some unnecessary allocations during rule compilation (should help with #191 ).
  • Loading branch information
snikolayev committed Oct 14, 2019
1 parent 9f35970 commit 251b3f2
Show file tree
Hide file tree
Showing 22 changed files with 374 additions and 566 deletions.
61 changes: 0 additions & 61 deletions src/NRules/NRules/AgendaFilters/ActivationCondition.cs

This file was deleted.

14 changes: 7 additions & 7 deletions src/NRules/NRules/AgendaFilters/ActivationExpression.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,25 +5,25 @@

namespace NRules.AgendaFilters
{
internal interface IActivationExpression
internal interface IActivationExpression<out TResult>
{
object Invoke(AgendaContext context, Activation activation);
TResult Invoke(AgendaContext context, Activation activation);
}

internal class ActivationExpression : IActivationExpression
internal class ActivationExpression<TResult> : IActivationExpression<TResult>
{
private readonly LambdaExpression _expression;
private readonly FastDelegate<Func<object[], object>> _compiledExpression;
private readonly FastDelegate<Func<object[], TResult>> _compiledExpression;
private readonly IndexMap _tupleFactMap;

public ActivationExpression(LambdaExpression expression, FastDelegate<Func<object[], object>> compiledExpression, IndexMap tupleFactMap)
public ActivationExpression(LambdaExpression expression, FastDelegate<Func<object[], TResult>> compiledExpression, IndexMap tupleFactMap)
{
_expression = expression;
_compiledExpression = compiledExpression;
_tupleFactMap = tupleFactMap;
}

public object Invoke(AgendaContext context, Activation activation)
public TResult Invoke(AgendaContext context, Activation activation)
{
var tuple = activation.Tuple;
var activationFactMap = activation.FactMap;
Expand All @@ -39,7 +39,7 @@ public object Invoke(AgendaContext context, Activation activation)
}

Exception exception = null;
object result = null;
TResult result = default;
try
{
result = _compiledExpression.Delegate.Invoke(args);
Expand Down
6 changes: 3 additions & 3 deletions src/NRules/NRules/AgendaFilters/KeyChangeAgendaFilter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,11 @@ namespace NRules.AgendaFilters
internal class KeyChangeAgendaFilter : IAgendaFilter
{
private const string KeyName = "ChangeKeys";
private readonly List<IActivationExpression> _keySelectors;
private readonly List<IActivationExpression<object>> _keySelectors;

public KeyChangeAgendaFilter(IEnumerable<IActivationExpression> keySelectors)
public KeyChangeAgendaFilter(IEnumerable<IActivationExpression<object>> keySelectors)
{
_keySelectors = new List<IActivationExpression>(keySelectors);
_keySelectors = new List<IActivationExpression<object>>(keySelectors);
}

public bool Accept(AgendaContext context, Activation activation)
Expand Down
6 changes: 3 additions & 3 deletions src/NRules/NRules/AgendaFilters/PredicateAgendaFilter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@ namespace NRules.AgendaFilters
{
internal class PredicateAgendaFilter : IAgendaFilter
{
private readonly List<IActivationCondition> _conditions;
private readonly List<IActivationExpression<bool>> _conditions;

public PredicateAgendaFilter(IEnumerable<IActivationCondition> conditions)
public PredicateAgendaFilter(IEnumerable<IActivationExpression<bool>> conditions)
{
_conditions = new List<IActivationCondition>(conditions);
_conditions = new List<IActivationExpression<bool>>(conditions);
}

public bool Accept(AgendaContext context, Activation activation)
Expand Down
84 changes: 4 additions & 80 deletions src/NRules/NRules/Aggregators/AggregateExpression.cs
Original file line number Diff line number Diff line change
@@ -1,8 +1,5 @@
using System;
using System.Linq.Expressions;
using NRules.Rete;
using NRules.Rete;
using NRules.RuleModel;
using NRules.Utilities;

namespace NRules.Aggregators
{
Expand All @@ -26,54 +23,12 @@ public interface IAggregateExpression
object Invoke(AggregationContext context, ITuple tuple, IFact fact);
}

internal class AggregateFactExpression : IAggregateExpression
{
private readonly LambdaExpression _expression;
private readonly FastDelegate<Func<object, object>> _compiledExpression;

public AggregateFactExpression(string name, LambdaExpression expression, FastDelegate<Func<object, object>> compiledExpression)
{
Name = name;
_expression = expression;
_compiledExpression = compiledExpression;
}

public string Name { get; }

public object Invoke(AggregationContext context, ITuple tuple, IFact fact)
{
var factValue = fact.Value;
Exception exception = null;
object result = null;
try
{
result = _compiledExpression.Delegate(factValue);
return result;
}
catch (Exception e)
{
exception = e;
bool isHandled = false;
context.EventAggregator.RaiseLhsExpressionFailed(context.Session, e, _expression, factValue, tuple, fact, context.NodeInfo, ref isHandled);
throw new ExpressionEvaluationException(e, _expression, isHandled);
}
finally
{
context.EventAggregator.RaiseLhsExpressionEvaluated(context.Session, exception, _expression, factValue, result, tuple, fact, context.NodeInfo);
}
}
}

internal class AggregateExpression : IAggregateExpression
{
private readonly LambdaExpression _expression;
private readonly IndexMap _factMap;
private readonly FastDelegate<Func<object[], object>> _compiledExpression;
private readonly ILhsExpression<object> _compiledExpression;

public AggregateExpression(string name, LambdaExpression expression, FastDelegate<Func<object[], object>> compiledExpression, IndexMap factMap)
public AggregateExpression(string name, ILhsExpression<object> compiledExpression)
{
_expression = expression;
_factMap = factMap;
_compiledExpression = compiledExpression;
Name = name;
}
Expand All @@ -82,38 +37,7 @@ public AggregateExpression(string name, LambdaExpression expression, FastDelegat

public object Invoke(AggregationContext context, ITuple tuple, IFact fact)
{
var args = new object[_compiledExpression.ArrayArgumentCount];
int index = tuple.Count - 1;
foreach (var tupleFact in tuple.Facts)
{
IndexMap.SetElementAt(args, _factMap[index], tupleFact.Value);
index--;
}
IndexMap.SetElementAt(args, _factMap[tuple.Count], fact.Value);

Exception exception = null;
object result = null;
try
{
result = _compiledExpression.Delegate(args);
return result;
}
catch (Exception e)
{
exception = e;
bool isHandled = false;
context.EventAggregator.RaiseLhsExpressionFailed(context.Session, e, _expression, args, tuple, fact, context.NodeInfo, ref isHandled);
throw new ExpressionEvaluationException(e, _expression, isHandled);
}
finally
{
context.EventAggregator.RaiseLhsExpressionEvaluated(context.Session, exception, _expression, args, result, tuple, fact, context.NodeInfo);
}
}

public override string ToString()
{
return _expression.ToString();
return _compiledExpression.Invoke(context.ExecutionContext, context.NodeInfo, tuple, fact);
}
}
}
11 changes: 4 additions & 7 deletions src/NRules/NRules/Aggregators/AggregationContext.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
using NRules.Diagnostics;
using NRules.Rete;
using NRules.Rete;

namespace NRules.Aggregators
{
Expand All @@ -8,14 +7,12 @@ namespace NRules.Aggregators
/// </summary>
public class AggregationContext
{
internal ISessionInternal Session { get; }
internal IEventAggregator EventAggregator { get; }
internal IExecutionContext ExecutionContext { get; }
internal NodeDebugInfo NodeInfo { get; }

internal AggregationContext(ISessionInternal session, IEventAggregator eventAggregator, NodeDebugInfo nodeInfo)
internal AggregationContext(IExecutionContext executionContext, NodeDebugInfo nodeInfo)
{
Session = session;
EventAggregator = eventAggregator;
ExecutionContext = executionContext;
NodeInfo = nodeInfo;
}
}
Expand Down
6 changes: 3 additions & 3 deletions src/NRules/NRules/Diagnostics/NodeInfo.cs
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ internal static NodeInfo Create(TypeNode node)

internal static NodeInfo Create(SelectionNode node)
{
var conditions = new[] {node.Condition.ToString()};
var conditions = new[] {node.ExpressionElement.Expression.ToString()};
return new NodeInfo(NodeType.Selection, string.Empty, conditions, Empty, Empty);
}

Expand All @@ -59,7 +59,7 @@ internal static NodeInfo Create(AlphaMemoryNode node, IAlphaMemory memory)

internal static NodeInfo Create(JoinNode node)
{
var conditions = node.Conditions.Select(c => c.ToString());
var conditions = node.ExpressionElements.Select(c => c.Expression.ToString());
return new NodeInfo(NodeType.Join, string.Empty, conditions, Empty, Empty);
}

Expand All @@ -86,7 +86,7 @@ internal static NodeInfo Create(ObjectInputAdapter node)

internal static NodeInfo Create(BindingNode node)
{
var expressions = new[] {node.BindingExpression.ToString()};
var expressions = new[] {node.ExpressionElement.Expression.ToString()};
return new NodeInfo(NodeType.Binding, string.Empty, Empty, expressions, Empty);
}

Expand Down
10 changes: 5 additions & 5 deletions src/NRules/NRules/Rete/AggregateNode.cs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ public AggregateNode(ITupleSource leftSource, IObjectSource rightSource, string

public override void PropagateAssert(IExecutionContext context, List<Tuple> tuples)
{
var aggregationContext = new AggregationContext(context.Session, context.EventAggregator, NodeInfo);
var aggregationContext = new AggregationContext(context, NodeInfo);
var joinedSets = JoinedSets(context, tuples);
var aggregation = new Aggregation();
foreach (var set in joinedSets)
Expand All @@ -37,7 +37,7 @@ public override void PropagateAssert(IExecutionContext context, List<Tuple> tupl

public override void PropagateUpdate(IExecutionContext context, List<Tuple> tuples)
{
var aggregationContext = new AggregationContext(context.Session, context.EventAggregator, NodeInfo);
var aggregationContext = new AggregationContext(context, NodeInfo);
var joinedSets = JoinedSets(context, tuples);
var aggregation = new Aggregation();
foreach (var set in joinedSets)
Expand Down Expand Up @@ -79,7 +79,7 @@ public override void PropagateRetract(IExecutionContext context, List<Tuple> tup

public override void PropagateAssert(IExecutionContext context, List<Fact> facts)
{
var aggregationContext = new AggregationContext(context.Session, context.EventAggregator, NodeInfo);
var aggregationContext = new AggregationContext(context, NodeInfo);
var joinedSets = JoinedSets(context, facts);
var aggregation = new Aggregation();
foreach (var set in joinedSets)
Expand All @@ -106,7 +106,7 @@ public override void PropagateAssert(IExecutionContext context, List<Fact> facts

public override void PropagateUpdate(IExecutionContext context, List<Fact> facts)
{
var aggregationContext = new AggregationContext(context.Session, context.EventAggregator, NodeInfo);
var aggregationContext = new AggregationContext(context, NodeInfo);
var joinedSets = JoinedSets(context, facts);
var aggregation = new Aggregation();
foreach (var set in joinedSets)
Expand Down Expand Up @@ -134,7 +134,7 @@ public override void PropagateUpdate(IExecutionContext context, List<Fact> facts

public override void PropagateRetract(IExecutionContext context, List<Fact> facts)
{
var aggregationContext = new AggregationContext(context.Session, context.EventAggregator, NodeInfo);
var aggregationContext = new AggregationContext(context, NodeInfo);
var joinedSets = JoinedSets(context, facts);
var aggregation = new Aggregation();
foreach (var set in joinedSets)
Expand Down

0 comments on commit 251b3f2

Please sign in to comment.