Skip to content

Commit

Permalink
fix: propagate retracts before updates in order to have clean indexes (#193)

If an update of a tuple is sent into a join node before a retract of
that same tuple,
the joiner is effectively using stale information and this can lead to
various issues such as NPEs. 
(See the newly introduced BavetRegressionTest.)

To address this, we make three changes to how tuples are propagated:

First, the dirty queues (now called propagation queues) order their
tuples and make sure that retracts are propagated before updates, and
updates before inserts.
StaticPropagationQueue handles most nodes, 
but DynamicPropagationQueue is necessary for GroupNode and IfExistsNode,
which need more sophisticated (and slower) reordering behavior.

Second, nodes are distributed into layers, each layer has nodes that
depend on nodes from previous layers.
Propagations happen per layer; all nodes in a layer propagate, and then
we move on to the next layer.
Within each layer, all nodes first propagate retracts, then all nodes
propagate updates, then all nodes propagate inserts.

Finally, specialized ForEach nodes are introduced for cases with and
without nullity filters.
When a nullity filter is used, the ForEach node is now smart enough to
trigger an insert only when a predicate matches,
and therefore subsequent updates/retracts may avoid propagation
altogether.

Together, these changes significantly clean up and standardize tuple
propagation.
A performance drop of at most 15 % was observed.
Some cases benefit by as much as 5 %, perhaps due to the elimination of
useless join cross-products
that no longer result from join updates over tuples that would be later
retracted.

The average performance drop over all the measured use cases is around 5 %.
Considering that these changes fix incorrect behavior,
the performance impact has to be accepted.

All turtle tests pass locally.

---------

Co-authored-by: Radovan Synek <radovan.synek@gmail.com>
  • Loading branch information
triceo and rsynek committed Aug 14, 2023
1 parent 13bebd2 commit 0b47ed1
Show file tree
Hide file tree
Showing 137 changed files with 2,195 additions and 1,058 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import java.util.Map;
import java.util.Objects;
import java.util.function.Consumer;
import java.util.function.Predicate;

import ai.timefold.solver.constraint.streams.bavet.common.BavetAbstractConstraintStream;
import ai.timefold.solver.constraint.streams.bavet.uni.BavetAbstractUniConstraintStream;
Expand All @@ -13,6 +14,7 @@
import ai.timefold.solver.core.api.score.stream.uni.UniConstraintStream;
import ai.timefold.solver.core.config.solver.EnvironmentMode;
import ai.timefold.solver.core.impl.domain.constraintweight.descriptor.ConstraintConfigurationDescriptor;
import ai.timefold.solver.core.impl.domain.entity.descriptor.EntityDescriptor;
import ai.timefold.solver.core.impl.domain.solution.descriptor.SolutionDescriptor;

public final class BavetConstraintFactory<Solution_>
Expand Down Expand Up @@ -69,16 +71,43 @@ public <Stream_ extends BavetAbstractConstraintStream<Solution_>> Stream_ share(
// from
// ************************************************************************

@Override
public <A> UniConstraintStream<A> forEach(Class<A> sourceClass) {
assertValidFromType(sourceClass);
Predicate<A> nullityFilter = getNullityFilter(sourceClass);
return share(new BavetForEachUniConstraintStream<>(this, sourceClass, nullityFilter, RetrievalSemantics.STANDARD));
}

private <A> Predicate<A> getNullityFilter(Class<A> fromClass) {
EntityDescriptor<Solution_> entityDescriptor = getSolutionDescriptor().findEntityDescriptor(fromClass);
if (entityDescriptor != null && entityDescriptor.hasAnyGenuineVariables()) {
return (Predicate<A>) entityDescriptor.getHasNoNullVariables();
}
return null;
}

@Override
public <A> UniConstraintStream<A> forEachIncludingNullVars(Class<A> sourceClass) {
assertValidFromType(sourceClass);
return share(new BavetForEachUniConstraintStream<>(this, sourceClass, RetrievalSemantics.STANDARD));
return share(new BavetForEachUniConstraintStream<>(this, sourceClass, null, RetrievalSemantics.STANDARD));
}

@Override
public <A> UniConstraintStream<A> from(Class<A> fromClass) {
assertValidFromType(fromClass);
EntityDescriptor<Solution_> entityDescriptor = getSolutionDescriptor().findEntityDescriptor(fromClass);
if (entityDescriptor != null && entityDescriptor.hasAnyGenuineVariables()) {
Predicate<A> predicate = (Predicate<A>) entityDescriptor.getIsInitializedPredicate();
return share(new BavetForEachUniConstraintStream<>(this, fromClass, predicate, RetrievalSemantics.LEGACY));
} else {
return share(new BavetForEachUniConstraintStream<>(this, fromClass, null, RetrievalSemantics.LEGACY));
}
}

@Override
public <A> BavetAbstractUniConstraintStream<Solution_, A> fromUnfiltered(Class<A> fromClass) {
assertValidFromType(fromClass);
return share(new BavetForEachUniConstraintStream<>(this, fromClass, RetrievalSemantics.LEGACY));
return share(new BavetForEachUniConstraintStream<>(this, fromClass, null, RetrievalSemantics.LEGACY));
}

// ************************************************************************
Expand Down
Original file line number Diff line number Diff line change
@@ -1,73 +1,107 @@
package ai.timefold.solver.constraint.streams.bavet;

import java.util.Collections;
import java.util.IdentityHashMap;
import java.util.List;
import java.util.Map;

import ai.timefold.solver.constraint.streams.bavet.common.AbstractNode;
import ai.timefold.solver.constraint.streams.bavet.uni.ForEachUniNode;
import ai.timefold.solver.constraint.streams.bavet.common.PropagationQueue;
import ai.timefold.solver.constraint.streams.bavet.common.Propagator;
import ai.timefold.solver.constraint.streams.bavet.uni.AbstractForEachUniNode;
import ai.timefold.solver.constraint.streams.common.inliner.AbstractScoreInliner;
import ai.timefold.solver.core.api.score.Score;
import ai.timefold.solver.core.api.score.constraint.ConstraintMatchTotal;
import ai.timefold.solver.core.api.score.constraint.Indictment;

final class BavetConstraintSession<Score_ extends Score<Score_>> {
/**
* The type is public to make it easier for Bavet-specific minimal bug reproducers to be created.
* Instances should be created through {@link BavetConstraintStreamScoreDirectorFactory#newSession(boolean, Object)}.
*
* @see PropagationQueue Description of the tuple propagation mechanism.
* @param <Score_>
*/
public final class BavetConstraintSession<Score_ extends Score<Score_>> {

private final AbstractScoreInliner<Score_> scoreInliner;
private final Map<Class<?>, ForEachUniNode<Object>> declaredClassToNodeMap;
private final AbstractNode[] nodes; // Indexed by nodeIndex
private final Map<Class<?>, ForEachUniNode<Object>[]> effectiveClassToNodeArrayMap;
private final Map<Class<?>, List<AbstractForEachUniNode<Object>>> declaredClassToNodeMap;
private final Propagator[][] layeredNodes; // First level is the layer, second determines iteration order.
private final Map<Class<?>, AbstractForEachUniNode<Object>[]> effectiveClassToNodeArrayMap;

public BavetConstraintSession(AbstractScoreInliner<Score_> scoreInliner,
Map<Class<?>, ForEachUniNode<Object>> declaredClassToNodeMap,
AbstractNode[] nodes) {
BavetConstraintSession(AbstractScoreInliner<Score_> scoreInliner) {
this(scoreInliner, Collections.emptyMap(), new Propagator[0][0]);
}

BavetConstraintSession(AbstractScoreInliner<Score_> scoreInliner,
Map<Class<?>, List<AbstractForEachUniNode<Object>>> declaredClassToNodeMap,
Propagator[][] layeredNodes) {
this.scoreInliner = scoreInliner;
this.declaredClassToNodeMap = declaredClassToNodeMap;
this.nodes = nodes;
this.layeredNodes = layeredNodes;
this.effectiveClassToNodeArrayMap = new IdentityHashMap<>(declaredClassToNodeMap.size());
}

public void insert(Object fact) {
Class<?> factClass = fact.getClass();
for (ForEachUniNode<Object> node : findNodes(factClass)) {
var factClass = fact.getClass();
for (var node : findNodes(factClass)) {
node.insert(fact);
}
}

private ForEachUniNode<Object>[] findNodes(Class<?> factClass) {
private AbstractForEachUniNode<Object>[] findNodes(Class<?> factClass) {
// Map.computeIfAbsent() would have created lambdas on the hot path, this will not.
ForEachUniNode<Object>[] nodeArray = effectiveClassToNodeArrayMap.get(factClass);
var nodeArray = effectiveClassToNodeArrayMap.get(factClass);
if (nodeArray == null) {
nodeArray = declaredClassToNodeMap.entrySet()
.stream()
.filter(entry -> entry.getKey().isAssignableFrom(factClass))
.map(Map.Entry::getValue)
.toArray(ForEachUniNode[]::new);
.flatMap(List::stream)
.toArray(AbstractForEachUniNode[]::new);
effectiveClassToNodeArrayMap.put(factClass, nodeArray);
}
return nodeArray;
}

public void update(Object fact) {
Class<?> factClass = fact.getClass();
for (ForEachUniNode<Object> node : findNodes(factClass)) {
var factClass = fact.getClass();
for (var node : findNodes(factClass)) {
node.update(fact);
}
}

public void retract(Object fact) {
Class<?> factClass = fact.getClass();
for (ForEachUniNode<Object> node : findNodes(factClass)) {
var factClass = fact.getClass();
for (var node : findNodes(factClass)) {
node.retract(fact);
}
}

public Score_ calculateScore(int initScore) {
for (AbstractNode node : nodes) {
node.calculateScore();
var layerCount = layeredNodes.length;
for (var layerIndex = 0; layerIndex < layerCount; layerIndex++) {
calculateScoreInLayer(layerIndex);
}
return scoreInliner.extractScore(initScore);
}

private void calculateScoreInLayer(int layerIndex) {
var nodesInLayer = layeredNodes[layerIndex];
var nodeCount = nodesInLayer.length;
if (nodeCount == 1) {
nodesInLayer[0].propagateEverything();
} else {
for (var node : nodesInLayer) {
node.propagateRetracts();
}
for (var node : nodesInLayer) {
node.propagateUpdates();
}
for (var node : nodesInLayer) {
node.propagateInserts();
}
}
}

public AbstractScoreInliner<Score_> getScoreInliner() {
return scoreInliner;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,19 @@
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.SortedMap;
import java.util.TreeMap;

import ai.timefold.solver.constraint.streams.bavet.common.AbstractIfExistsNode;
import ai.timefold.solver.constraint.streams.bavet.common.AbstractJoinNode;
import ai.timefold.solver.constraint.streams.bavet.common.AbstractNode;
import ai.timefold.solver.constraint.streams.bavet.common.BavetAbstractConstraintStream;
import ai.timefold.solver.constraint.streams.bavet.common.BavetIfExistsConstraintStream;
import ai.timefold.solver.constraint.streams.bavet.common.BavetJoinConstraintStream;
import ai.timefold.solver.constraint.streams.bavet.common.NodeBuildHelper;
import ai.timefold.solver.constraint.streams.bavet.uni.ForEachUniNode;
import ai.timefold.solver.constraint.streams.bavet.common.PropagationQueue;
import ai.timefold.solver.constraint.streams.bavet.common.Propagator;
import ai.timefold.solver.constraint.streams.bavet.uni.AbstractForEachUniNode;
import ai.timefold.solver.constraint.streams.common.inliner.AbstractScoreInliner;
import ai.timefold.solver.core.api.score.Score;
import ai.timefold.solver.core.api.score.stream.Constraint;
Expand Down Expand Up @@ -42,40 +50,108 @@ public BavetConstraintSession<Score_> buildSession(boolean constraintMatchEnable
Map<Constraint, Score_> constraintWeightMap = new HashMap<>(constraintList.size());
for (BavetConstraint<Solution_> constraint : constraintList) {
Score_ constraintWeight = constraint.extractConstraintWeight(workingSolution);
// Filter out nodes that only lead to constraints with zero weight.
// Note: Node sharing happens earlier, in BavetConstraintFactory#share(Stream_).
/*
* Filter out nodes that only lead to constraints with zero weight.
* Note: Node sharing happens earlier, in BavetConstraintFactory#share(Stream_).
*/
if (!constraintWeight.equals(zeroScore)) {
// Relies on BavetConstraintFactory#share(Stream_) occurring for all constraint stream instances
// to ensure there are no 2 equal ConstraintStream instances (with different child stream lists).
/*
* Relies on BavetConstraintFactory#share(Stream_) occurring for all constraint stream instances
* to ensure there are no 2 equal ConstraintStream instances (with different child stream lists).
*/
constraint.collectActiveConstraintStreams(constraintStreamSet);
constraintWeightMap.put(constraint, constraintWeight);
}
}
AbstractScoreInliner<Score_> scoreInliner =
AbstractScoreInliner.buildScoreInliner(scoreDefinition, constraintWeightMap, constraintMatchEnabled);
if (constraintStreamSet.isEmpty()) { // All constraints were disabled.
return new BavetConstraintSession<>(scoreInliner);
}
/*
* Build constraintStreamSet in reverse order to create downstream nodes first
* so every node only has final variables (some of which have downstream node method references).
*/
NodeBuildHelper<Score_> buildHelper = new NodeBuildHelper<>(constraintStreamSet, scoreInliner);
// Build constraintStreamSet in reverse order to create downstream nodes first
// so every node only has final variables (some of which have downstream node method references).
List<BavetAbstractConstraintStream<Solution_>> reversedConstraintStreamList = new ArrayList<>(constraintStreamSet);
Collections.reverse(reversedConstraintStreamList);
for (BavetAbstractConstraintStream<Solution_> constraintStream : reversedConstraintStreamList) {
constraintStream.buildNode(buildHelper);
}
List<AbstractNode> nodeList = buildHelper.destroyAndGetNodeList();
Map<Class<?>, ForEachUniNode<Object>> declaredClassToNodeMap = new LinkedHashMap<>();
Map<Class<?>, List<AbstractForEachUniNode<Object>>> declaredClassToNodeMap = new LinkedHashMap<>();
long nextNodeId = 0;
for (AbstractNode node : nodeList) {
/*
* Nodes are iterated first to last, starting with forEach(), the ultimate parent.
* Parents are guaranteed to come before children.
*/
node.setId(nextNodeId++);
if (node instanceof ForEachUniNode) {
ForEachUniNode<Object> forEachUniNode = (ForEachUniNode<Object>) node;
ForEachUniNode<Object> old = declaredClassToNodeMap.put(forEachUniNode.getForEachClass(), forEachUniNode);
if (old != null) {
throw new IllegalStateException("Impossible state: For class (" + forEachUniNode.getForEachClass()
+ ") there are 2 nodes (" + forEachUniNode + ", " + old + ").");
node.setLayerIndex(determineLayerIndex(node, buildHelper));
if (node instanceof AbstractForEachUniNode<?> forEachUniNode) {
Class<?> forEachClass = forEachUniNode.getForEachClass();
List<AbstractForEachUniNode<Object>> forEachUniNodeList =
declaredClassToNodeMap.computeIfAbsent(forEachClass, k -> new ArrayList<>());
if (forEachUniNodeList.size() == 2) {
// Each class can have at most two forEach nodes: one including null vars, the other excluding them.
throw new IllegalStateException("Impossible state: For class (" + forEachClass
+ ") there are already 2 nodes (" + forEachUniNodeList + "), not adding another ("
+ forEachUniNode + ").");
}
forEachUniNodeList.add((AbstractForEachUniNode<Object>) forEachUniNode);
}
}
return new BavetConstraintSession<>(scoreInliner, declaredClassToNodeMap, nodeList.toArray(new AbstractNode[0]));
SortedMap<Long, List<Propagator>> layerMap = new TreeMap<>();
for (AbstractNode node : nodeList) {
layerMap.computeIfAbsent(node.getLayerIndex(), k -> new ArrayList<>())
.add(node.getPropagator());
}
int layerCount = layerMap.size();
Propagator[][] layeredNodes = new Propagator[layerCount][];
for (int i = 0; i < layerCount; i++) {
List<Propagator> layer = layerMap.get((long) i);
layeredNodes[i] = layer.toArray(new Propagator[0]);
}
return new BavetConstraintSession<>(scoreInliner, declaredClassToNodeMap, layeredNodes);
}

/**
* Nodes are propagated in layers.
* See {@link PropagationQueue} and {@link AbstractNode} for details.
* This method determines the layer of each node.
* It does so by reverse-engineering the parent nodes of each node.
* Nodes without parents (forEach nodes) are in layer 0.
* Nodes with parents are one layer above their parents.
* Some nodes have multiple parents, such as {@link AbstractJoinNode} and {@link AbstractIfExistsNode}.
* These are one layer above the highest parent.
* This is done to ensure that, when a child node starts propagating, all its parents have already propagated.
*
* @param node never null
* @param buildHelper never null
* @return at least 0
*/
private long determineLayerIndex(AbstractNode node, NodeBuildHelper<Score_> buildHelper) {
if (node instanceof AbstractForEachUniNode<?>) { // ForEach nodes, and only they, are in layer 0.
return 0;
} else if (node instanceof AbstractJoinNode<?, ?, ?> joinNode) {
var nodeCreator = (BavetJoinConstraintStream<?>) buildHelper.getNodeCreatingStream(joinNode);
var leftParent = nodeCreator.getLeftParent();
var rightParent = nodeCreator.getRightParent();
var leftParentNode = buildHelper.findParentNode(leftParent);
var rightParentNode = buildHelper.findParentNode(rightParent);
return Math.max(leftParentNode.getLayerIndex(), rightParentNode.getLayerIndex()) + 1;
} else if (node instanceof AbstractIfExistsNode<?, ?> ifExistsNode) {
var nodeCreator = (BavetIfExistsConstraintStream<?>) buildHelper.getNodeCreatingStream(ifExistsNode);
var leftParent = nodeCreator.getLeftParent();
var rightParent = nodeCreator.getRightParent();
var leftParentNode = buildHelper.findParentNode(leftParent);
var rightParentNode = buildHelper.findParentNode(rightParent);
return Math.max(leftParentNode.getLayerIndex(), rightParentNode.getLayerIndex()) + 1;
} else {
var nodeCreator = (BavetAbstractConstraintStream<?>) buildHelper.getNodeCreatingStream(node);
var parentNode = buildHelper.findParentNode(nodeCreator.getParent());
return parentNode.getLayerIndex() + 1;
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,19 @@ protected AbstractGroupBiNode(int groupStoreIndex, int undoStoreIndex,
Function<BiTuple<OldA, OldB>, GroupKey_> groupKeyFunction,
BiConstraintCollector<OldA, OldB, ResultContainer_, Result_> collector,
TupleLifecycle<OutTuple_> nextNodesTupleLifecycle, EnvironmentMode environmentMode) {
super(groupStoreIndex, undoStoreIndex, groupKeyFunction,
super(groupStoreIndex, undoStoreIndex,
groupKeyFunction,
collector == null ? null : collector.supplier(),
collector == null ? null : collector.finisher(),
nextNodesTupleLifecycle, environmentMode);
accumulator = collector == null ? null : collector.accumulator();
}

protected AbstractGroupBiNode(int groupStoreIndex, Function<BiTuple<OldA, OldB>, GroupKey_> groupKeyFunction,
TupleLifecycle<OutTuple_> nextNodesTupleLifecycle, EnvironmentMode environmentMode) {
super(groupStoreIndex, groupKeyFunction, nextNodesTupleLifecycle, environmentMode);
protected AbstractGroupBiNode(int groupStoreIndex,
Function<BiTuple<OldA, OldB>, GroupKey_> groupKeyFunction, TupleLifecycle<OutTuple_> nextNodesTupleLifecycle,
EnvironmentMode environmentMode) {
super(groupStoreIndex,
groupKeyFunction, nextNodesTupleLifecycle, environmentMode);
accumulator = null;
}

Expand Down

0 comments on commit 0b47ed1

Please sign in to comment.