Skip to content

Commit

Permalink
fix: nearby selection pinning support (#614)
Browse files Browse the repository at this point in the history
  • Loading branch information
triceo committed Feb 6, 2024
1 parent 6cdd9a3 commit 5c3bcdf
Show file tree
Hide file tree
Showing 20 changed files with 181 additions and 63 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -646,8 +646,6 @@ public List<Object> extractEntities(Solution_ solution) {

/**
* Returns the {@link PinningStatus} of the entity.
* If {@link PlanningPin} is enabled on the entity, the entity is fully pinned.
* Otherwise if {@link PlanningPinToIndex} is specified, returns the value of it.
*
* @param scoreDirector
* @param entity
Expand Down Expand Up @@ -684,12 +682,15 @@ public int extractFirstUnpinnedIndex(Object entity) {

public record PinningStatus(boolean hasPin, boolean entireEntityPinned, int firstUnpinnedIndex) {

private static final PinningStatus FULLY_PINNED = new PinningStatus(true, true, -1);
private static final PinningStatus UNPINNED = new PinningStatus(false, false, -1);

public static PinningStatus ofUnpinned() {
return new PinningStatus(false, false, -1);
return UNPINNED;
}

public static PinningStatus ofFullyPinned() {
return new PinningStatus(true, true, -1);
return FULLY_PINNED;
}

public static PinningStatus ofPinIndex(int firstUnpinnedIndex) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@
import java.util.Iterator;
import java.util.NoSuchElementException;

import ai.timefold.solver.core.impl.domain.variable.descriptor.ListVariableDescriptor;
import ai.timefold.solver.core.impl.heuristic.move.Move;
import ai.timefold.solver.core.impl.heuristic.selector.Selector;
import ai.timefold.solver.core.impl.heuristic.selector.entity.mimic.MimicReplayingEntitySelector;
import ai.timefold.solver.core.impl.heuristic.selector.list.ElementRef;

/**
* IMPORTANT: The constructor of any subclass of this abstract class, should never call any of its child
Expand Down Expand Up @@ -60,4 +62,27 @@ public String toString() {
}
}

/**
* Some destination iterators, such as nearby destination iterators, may return even elements which are pinned.
* This is because the nearby matrix always picks from all nearby elements, and is unaware of any pinning.
* This means that later we need to filter out the pinned elements, so that moves aren't generated for them.
*
* @param destinationIterator never null
* @param listVariableDescriptor never null
* @return null if no unpinned destination was found, at which point the iterator is exhausted.
*/
public static ElementRef findUnpinnedDestination(Iterator<ElementRef> destinationIterator,
ListVariableDescriptor<?> listVariableDescriptor) {
while (destinationIterator.hasNext()) {
var destination = destinationIterator.next();
var pinningStatus = listVariableDescriptor.getEntityDescriptor().extractPinningStatus(null, destination.entity());
var isPinned = pinningStatus.hasPin()
&& (pinningStatus.entireEntityPinned() || pinningStatus.firstUnpinnedIndex() > destination.index());
if (!isPinned) {
return destination;
}
}
return null;
}

}
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package ai.timefold.solver.core.impl.heuristic.selector.move.generic.list;

import static ai.timefold.solver.core.impl.heuristic.selector.move.generic.list.RandomListChangeIterator.findUnpinnedDestination;

import java.util.Collections;
import java.util.Iterator;

Expand Down Expand Up @@ -56,7 +58,10 @@ protected Move<Solution_> createUpcomingSelection() {
destinationIterator = destinationSelector.iterator();
}

ElementRef destination = destinationIterator.next();
ElementRef destination = findUnpinnedDestination(destinationIterator, listVariableDescriptor);
if (destination == null) {
return noUpcomingSelection();
}

if (upcomingSourceEntity == null && upcomingSourceIndex == null) {
return new ListAssignMove<>(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,16 @@ protected Move<Solution_> createUpcomingSelection() {
}

Object upcomingValue = valueIterator.next();
ElementRef destination = destinationIterator.next();

ElementRef destination = findUnpinnedDestination(destinationIterator, listVariableDescriptor);
if (destination == null) {
return noUpcomingSelection();
}
return new ListChangeMove<>(
listVariableDescriptor,
inverseVariableSupply.getInverseSingleton(upcomingValue),
indexVariableSupply.getIndex(upcomingValue),
destination.entity(),
destination.index());
}

}
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package ai.timefold.solver.core.impl.heuristic.selector.move.generic.list;

import static ai.timefold.solver.core.impl.heuristic.selector.move.generic.list.RandomListChangeIterator.findUnpinnedDestination;

import java.util.Iterator;
import java.util.Random;

Expand Down Expand Up @@ -38,7 +40,10 @@ protected Move<Solution_> createUpcomingSelection() {
}

SubList subList = subListIterator.next();
ElementRef destination = destinationIterator.next();
ElementRef destination = findUnpinnedDestination(destinationIterator, listVariableDescriptor);
if (destination == null) {
return noUpcomingSelection();
}
boolean reversing = selectReversingMoveToo && workingRandom.nextBoolean();

return new SubListChangeMove<>(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,24 +105,35 @@ private Iterator<Node_> getValuesOnSelectedEntitiesIterator(Node_[] pickedValues
@SuppressWarnings("unchecked")
private KOptDescriptor<Node_> pickKOptMove(int k) {
// The code in the paper used 1-index arrays
Node_[] pickedValues = (Node_[]) new Object[2 * k + 1];
Iterator<Node_> originIterator = (Iterator<Node_>) originSelector.iterator();
var pickedValues = (Node_[]) new Object[2 * k + 1];
var originIterator = (Iterator<Node_>) originSelector.iterator();

pickedValues[1] = originIterator.next();
int remainingAttempts = 20;
if (pickedValues[1] == null) {
return null;
}
var remainingAttempts = 20;
while (remainingAttempts > 0
&& getEffectiveListSize(listVariableDescriptor,
inverseVariableSupply.getInverseSingleton(pickedValues[1])) < 2) {
pickedValues[1] = originIterator.next();
remainingAttempts--;
do {
if (!originIterator.hasNext()) {
// Filtered selection due to pinning/unassigned may cause this.
// Filtered selectors only know the upper bound of their size, not their actual size.
// Therefore the iterator may be exhausted before the actual size is reached.
return null;
}
pickedValues[1] = originIterator.next();
remainingAttempts--;
} while ((pickedValues[1] == null));
}

if (remainingAttempts == 0) {
// could not find a value in a list with more than 1 element
return null;
}

EntityOrderInfo entityOrderInfo = EntityOrderInfo.of(pickedValues, inverseVariableSupply, listVariableDescriptor);
var entityOrderInfo = EntityOrderInfo.of(pickedValues, inverseVariableSupply, listVariableDescriptor);
pickedValues[2] = workingRandom.nextBoolean() ? getNodeSuccessor(entityOrderInfo, pickedValues[1])
: getNodePredecessor(entityOrderInfo, pickedValues[1]);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -475,6 +475,28 @@ public void afterListVariableElementUnassigned(ListVariableDescriptor<Solution_>
@Override
public void beforeListVariableChanged(ListVariableDescriptor<Solution_> variableDescriptor, Object entity, int fromIndex,
int toIndex) {
// Pinning is implemented in generic moves, but custom moves need to take it into account as well.
// This fail-fast exists to detect situations where pinned things are being moved, in case of user error.
var entityDescriptor = variableDescriptor.getEntityDescriptor();
var pinningStatus = entityDescriptor.extractPinningStatus(this, entity);
if (pinningStatus.hasPin()) {
if (pinningStatus.entireEntityPinned()) {
throw new IllegalStateException("""
Attempting to change list variable (%s) on an entity (%s) which is fully pinned.
This is most likely a bug in a move.
Maybe you are using an improperly implemented custom move?"""
.formatted(variableDescriptor, entity));
}
int firstUnpinnedIndex = pinningStatus.firstUnpinnedIndex();
if (fromIndex < firstUnpinnedIndex || toIndex < firstUnpinnedIndex) {
throw new IllegalStateException(
"""
Attempting to change list variable (%s) on an entity (%s) in range [%d, %d), but the variable's first unpinned index is (%d).
This is most likely a bug in a move.
Maybe you are using an improperly implemented custom move?"""
.formatted(variableDescriptor, entity, fromIndex, toIndex, firstUnpinnedIndex));
}
}
variableListenerSupport.beforeListVariableChanged(variableDescriptor, entity, fromIndex, toIndex);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
import java.util.List;
import java.util.function.Function;

import ai.timefold.solver.core.impl.domain.entity.descriptor.EntityDescriptor;
import ai.timefold.solver.core.impl.domain.variable.descriptor.ListVariableDescriptor;
import ai.timefold.solver.core.impl.domain.variable.index.IndexVariableDemand;
import ai.timefold.solver.core.impl.domain.variable.index.IndexVariableSupply;
Expand Down Expand Up @@ -102,8 +101,10 @@ void test3OptPinned() {
TestdataListEntity e1 = TestdataListEntity.createWithValues("e1", v1, v2, v3, v4, v5, v6, v7);

var variableDescriptorSpy = Mockito.spy(variableDescriptor);
var entityDescriptor = Mockito.mock(EntityDescriptor.class);
var entityDescriptor = Mockito.spy(TestdataListSolution.buildSolutionDescriptor()
.findEntityDescriptorOrFail(TestdataListEntity.class));
Mockito.when(variableDescriptorSpy.getEntityDescriptor()).thenReturn(entityDescriptor);
Mockito.when(entityDescriptor.supportsPinning()).thenReturn(true);
Mockito.when(entityDescriptor.extractFirstUnpinnedIndex(e1)).thenReturn(1);
Mockito.when(entityDescriptor.isMovable(null, e1)).thenReturn(true);

Expand Down Expand Up @@ -314,8 +315,10 @@ void testMultiEntity3OptPinned() {
TestdataListEntity e2 = TestdataListEntity.createWithValues("e2", v4, v5);

var variableDescriptorSpy = Mockito.spy(variableDescriptor);
var entityDescriptor = Mockito.mock(EntityDescriptor.class);
var entityDescriptor = Mockito.spy(TestdataListSolution.buildSolutionDescriptor()
.findEntityDescriptorOrFail(TestdataListEntity.class));
Mockito.when(variableDescriptorSpy.getEntityDescriptor()).thenReturn(entityDescriptor);
Mockito.when(entityDescriptor.supportsPinning()).thenReturn(true);
Mockito.when(entityDescriptor.extractFirstUnpinnedIndex(e1)).thenReturn(1);
Mockito.when(entityDescriptor.isMovable(null, e1)).thenReturn(true);
Mockito.when(entityDescriptor.isMovable(null, e2)).thenReturn(true);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import static org.mockito.Mockito.verify;

import ai.timefold.solver.core.api.score.buildin.simple.SimpleScore;
import ai.timefold.solver.core.impl.domain.entity.descriptor.EntityDescriptor;
import ai.timefold.solver.core.impl.domain.variable.descriptor.ListVariableDescriptor;
import ai.timefold.solver.core.impl.heuristic.move.AbstractMove;
import ai.timefold.solver.core.impl.score.director.InnerScoreDirector;
Expand Down Expand Up @@ -239,8 +238,10 @@ void doMoveSecondEndsBeforeFirstPinned() {
TestdataListEntity e1 = TestdataListEntity.createWithValues("e1", v8, v7, v3, v4, v5, v6, v2, v1);

var variableDescriptorSpy = Mockito.spy(variableDescriptor);
var entityDescriptor = Mockito.mock(EntityDescriptor.class);
var entityDescriptor = Mockito.spy(TestdataListSolution.buildSolutionDescriptor()
.findEntityDescriptorOrFail(TestdataListEntity.class));
Mockito.when(variableDescriptorSpy.getEntityDescriptor()).thenReturn(entityDescriptor);
Mockito.when(entityDescriptor.supportsPinning()).thenReturn(true);
Mockito.when(entityDescriptor.extractFirstUnpinnedIndex(e1)).thenReturn(1);

// 2-Opt((v6, v2), (v7, v3))
Expand Down Expand Up @@ -281,9 +282,7 @@ void rebase() {
.rebase(destinationScoreDirector));
}

static void assertSameProperties(
Object destinationEntity, int destinationV1, int destinationV2,
TwoOptListMove<?> move) {
static void assertSameProperties(Object destinationEntity, int destinationV1, int destinationV2, TwoOptListMove<?> move) {
assertThat(move.getFirstEntity()).isSameAs(destinationEntity);
assertThat(move.getFirstEdgeEndpoint()).isEqualTo(destinationV1);
assertThat(move.getSecondEdgeEndpoint()).isEqualTo(destinationV2);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
package ai.timefold.solver.core.impl.testdata.domain.list.pinned.index;

import ai.timefold.solver.core.impl.heuristic.selector.common.nearby.NearbyDistanceMeter;
import ai.timefold.solver.core.impl.testdata.domain.TestdataObject;

/**
* For the sake of test readability, planning values (list variable elements) are placed in a 1-dimensional space.
* An element's coordinate is represented by its ({@link TestdataObject#getCode() code}. If the code is not a number,
* it is interpreted as zero.
*/
public class TestdataPinnedWithIndexDistanceMeter
implements NearbyDistanceMeter<TestdataPinnedWithIndexListValue, TestdataObject> {

@Override
public double getNearbyDistance(TestdataPinnedWithIndexListValue origin, TestdataObject destination) {
return Math.abs(coordinate(destination) - coordinate(origin));
}

static int coordinate(TestdataObject o) {
try {
return Integer.parseInt(o.getCode());
} catch (NumberFormatException e) {
return 0;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,10 @@
import ai.timefold.solver.spring.boot.autoconfigure.dummy.MultipleSolutionsSpringTestConfiguration;
import ai.timefold.solver.spring.boot.autoconfigure.dummy.NoEntitySpringTestConfiguration;
import ai.timefold.solver.spring.boot.autoconfigure.dummy.NoSolutionSpringTestConfiguration;
import ai.timefold.solver.spring.boot.autoconfigure.dummy.chained.constraints.easy.DummyChainedSpringEasyScore;
import ai.timefold.solver.spring.boot.autoconfigure.dummy.chained.constraints.incremental.DummyChainedSpringIncrementalScore;
import ai.timefold.solver.spring.boot.autoconfigure.dummy.normal.constraints.easy.DummySpringEasyScore;
import ai.timefold.solver.spring.boot.autoconfigure.dummy.normal.constraints.incremental.DummySpringIncrementalScore;
import ai.timefold.solver.spring.boot.autoconfigure.gizmo.GizmoSpringTestConfiguration;
import ai.timefold.solver.spring.boot.autoconfigure.invalid.entity.InvalidEntitySpringTestConfiguration;
import ai.timefold.solver.spring.boot.autoconfigure.invalid.solution.InvalidSolutionSpringTestConfiguration;
Expand Down Expand Up @@ -602,7 +606,8 @@ void multipleEasyScoreConstraints() {
.withPropertyValues("timefold.solver.termination.best-score-limit=0")
.run(context -> context.getBean("solver1")))
.cause().message().contains(
"Multiple score classes classes", "DummyTestdataChainedSpringEasyScore", "DummyTestdataSpringEasyScore",
"Multiple score classes classes", DummyChainedSpringEasyScore.class.getSimpleName(),
DummySpringEasyScore.class.getSimpleName(),
"that implements EasyScoreCalculator were found in the classpath.");
}

Expand All @@ -613,8 +618,9 @@ void multipleConstraintProviderConstraints() {
.withPropertyValues("timefold.solver.termination.best-score-limit=0")
.run(context -> context.getBean("solver1")))
.cause().message().contains(
"Multiple score classes classes", "TestdataChainedSpringConstraintProvider",
"TestdataSpringConstraintProvider", "that implements ConstraintProvider were found in the classpath.");
"Multiple score classes classes", TestdataChainedSpringConstraintProvider.class.getSimpleName(),
TestdataSpringConstraintProvider.class.getSimpleName(),
"that implements ConstraintProvider were found in the classpath.");
}

@Test
Expand All @@ -624,8 +630,8 @@ void multipleIncrementalScoreConstraints() {
.withPropertyValues("timefold.solver.termination.best-score-limit=0")
.run(context -> context.getBean("solver1")))
.cause().message().contains(
"Multiple score classes classes", "DummyTestdataChainedSpringIncrementalScore",
"DummyTestdataSpringIncrementalScore",
"Multiple score classes classes", DummyChainedSpringIncrementalScore.class.getSimpleName(),
DummySpringIncrementalScore.class.getSimpleName(),
"that implements IncrementalScoreCalculator were found in the classpath.");
}

Expand All @@ -638,8 +644,8 @@ void multipleEasyScoreConstraintsXml_property() {
.run(context -> context.getBean("solver1")))
.cause().message().contains(
"Multiple score classes classes",
"DummyTestdataChainedSpringEasyScore",
"DummyTestdataSpringEasyScore",
DummyChainedSpringEasyScore.class.getSimpleName(),
DummySpringEasyScore.class.getSimpleName(),
"that implements EasyScoreCalculator were found in the classpath");
}

Expand All @@ -651,8 +657,9 @@ void multipleConstraintProviderConstraintsXml_property() {
"timefold.solver.solver-config-xml=ai/timefold/solver/spring/boot/autoconfigure/normalSolverConfig.xml")
.run(context -> context.getBean("solver1")))
.cause().message().contains(
"Multiple score classes classes", "TestdataChainedSpringConstraintProvider",
"TestdataSpringConstraintProvider", "that implements ConstraintProvider were found in the classpath.");
"Multiple score classes classes", TestdataChainedSpringConstraintProvider.class.getSimpleName(),
TestdataSpringConstraintProvider.class.getSimpleName(),
"that implements ConstraintProvider were found in the classpath.");
}

@Test
Expand All @@ -663,8 +670,8 @@ void multipleIncrementalScoreConstraintsXml_property() {
"timefold.solver.solver-config-xml=ai/timefold/solver/spring/boot/autoconfigure/normalSolverConfig.xml")
.run(context -> context.getBean("solver1")))
.cause().message().contains(
"Multiple score classes classes", "DummyTestdataChainedSpringIncrementalScore",
"DummyTestdataSpringIncrementalScore",
"Multiple score classes classes", DummyChainedSpringIncrementalScore.class.getSimpleName(),
DummySpringIncrementalScore.class.getSimpleName(),
"that implements IncrementalScoreCalculator were found in the classpath.");
}

Expand Down
Loading

0 comments on commit 5c3bcdf

Please sign in to comment.