Skip to content

Commit

Permalink
Performance improvements on the addPattern code path. Avoiding lots of unnecessary Set creations. (#104)
Browse files Browse the repository at this point in the history
  • Loading branch information
jonessha committed Jun 6, 2023
1 parent f432ef4 commit b223ae5
Show file tree
Hide file tree
Showing 12 changed files with 174 additions and 105 deletions.
4 changes: 2 additions & 2 deletions src/main/software/amazon/event/ruler/ByteMachine.java
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,7 @@ private SingleByteTransition deleteMatchStepForWildcard(ByteState byteState, int
ByteTransition skipWildcardTransition = getTransition(byteState, characters[charIndex + 1]);
for (SingleByteTransition eachTrans : skipWildcardTransition.expand()) {
ByteState skipWildcardState = eachTrans.getNextByteState();
if (eachTrans.getMatches().isEmpty() && skipWildcardState != null &&
if (!eachTrans.getMatches().iterator().hasNext() && skipWildcardState != null &&
(skipWildcardState.hasNoTransitions() ||
skipWildcardState.hasOnlySelfReferentialTransition())) {
removeTransition(byteState, characters[charIndex + 1], eachTrans);
Expand Down Expand Up @@ -1926,7 +1926,7 @@ private static void addTransitionForMultiByteSet(ByteState state, InputMultiByte
}
}

private static ByteMatch findMatch(Set<ByteMatch> matches, Patterns pattern) {
private static ByteMatch findMatch(Iterable<ByteMatch> matches, Patterns pattern) {
for (ByteMatch match : matches) {
if (match.getPattern().equals(pattern)) {
return match;
Expand Down
105 changes: 56 additions & 49 deletions src/main/software/amazon/event/ruler/ByteMap.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package software.amazon.event.ruler;

import java.util.Collections;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Map;
Expand Down Expand Up @@ -47,7 +48,7 @@ void addTransitionForAllBytes(final SingleByteTransition transition) {
for (Map.Entry<Integer, ByteTransition> entry : newMap.entrySet()) {
Set<SingleByteTransition> newSingles = new HashSet<>();
ByteTransition storedTransition = entry.getValue();
newSingles.addAll(expand(storedTransition));
expand(storedTransition).forEach(single -> newSingles.add(single));
newSingles.add(transition);
entry.setValue(coalesce(newSingles));
}
Expand All @@ -63,73 +64,58 @@ void removeTransitionForAllBytes(final SingleByteTransition transition) {
for (Map.Entry<Integer, ByteTransition> entry : newMap.entrySet()) {
Set<SingleByteTransition> newSingles = new HashSet<>();
ByteTransition storedTransition = entry.getValue();
newSingles.addAll(expand(storedTransition));
expand(storedTransition).forEach(single -> newSingles.add(single));
newSingles.remove(transition);
entry.setValue(coalesce(newSingles));
}
map = newMap;
}

/**
* Updates one ceiling=>transition mapping and leaves the map in a consistent state.
* It's a two-step process. First of all, we go through the map and find the entry that contains the byte value.
* Then we figure out whether the new byte value is at the bottom and/or the top of the entry (can be both if it's
* a singleton mapping). If it's not at the bottom, we write a new entry representing the part of the entry less
* than the byte value. Then we write a new entry for the byte value. Then if not at the top we write a new entry
* representing the proportion greater than the byte value. Then we merge entries mapping to the same transition.
* One effect is that you can remove the transition at position X by putting (X, null). An earlier implementation
* expanded the map to a ByteMapExtent[256] array of singletons, did the update, then contracted it, but that drove
* the compute/memory cost up so much that addRule and deleteRule were showing up in the profiler. Another earlier
* implementation tried to do the merging at the same time as the entry wrangling and dissolved into an
* incomprehensible pile of special-case code.
* Updates one ceiling=>transition mapping and leaves the map in a consistent state. We go through the map and find
* the entry that contains the byte value. If the new byte value is not at the bottom of the entry, we write a new
* entry representing the part of the entry less than the byte value. Then we write a new entry for the byte value.
* Then we merge entries mapping to the same transition.
*/
private void updateTransition(final byte utf8byte, final SingleByteTransition transition, Operation operation) {
final int index = utf8byte & 0xFF;
final NavigableMap<Integer, ByteTransition> newMap = new TreeMap<>();
ByteTransition target = map.higherEntry(index).getValue();

newMap.putAll(map.headMap(index, true));
Map.Entry<Integer, ByteTransition> targetEntry = map.higherEntry(index);
int targetCeiling = targetEntry.getKey();
ByteTransition target = targetEntry.getValue();

final boolean atBottom = (index == 0) || (!newMap.isEmpty() && newMap.lastKey() == index);
if (!atBottom) {
newMap.put(index, target);
}

Set<SingleByteTransition> singles = new HashSet<>();
if (operation != Operation.PUT) {
singles.addAll(expand(target));
}
if (operation == Operation.REMOVE) {
singles.remove(transition);
ByteTransition coalesced;
if (operation == Operation.PUT) {
coalesced = coalesce(transition);
} else {
singles.add(transition);
Iterable<SingleByteTransition> targetIterable = expand(target);
if (!targetIterable.iterator().hasNext()) {
coalesced = operation == Operation.ADD ? coalesce(transition) : null;
} else {
Set<SingleByteTransition> singles = new HashSet<>();
targetIterable.forEach(single -> singles.add(single));
if (operation == Operation.ADD) {
singles.add(transition);
} else {
singles.remove(transition);
}
coalesced = coalesce(singles);
}
}
ByteTransition coalesced = coalesce(singles);
newMap.put(index + 1, coalesced);

final boolean atTop = index == targetCeiling - 1;
if (!atTop) {
newMap.put(targetCeiling, target);
final boolean atBottom = index == 0 || map.containsKey(index);
if (!atBottom) {
map.put(index, target);
}

newMap.putAll(map.tailMap(targetCeiling, false));
map.put(index + 1, coalesced);

// Merge adjacent mappings with the same transition.
mergeAdjacentInMapIfNeeded(newMap);

// Update map in last step to enforce the happen-before relationship.
// We don't update content in map directly to avoid change being felt by other threads before complete.
map = newMap;
mergeAdjacentInMapIfNeeded(map);
}

/**
* Merge adjacent entries with equal transitions in inputMap.
*
* @param inputMap The map on which we merge adjacent entries with equal transitions.
*/
private void mergeAdjacentInMapIfNeeded(final NavigableMap<Integer, ByteTransition> inputMap) {
private static void mergeAdjacentInMapIfNeeded(final NavigableMap<Integer, ByteTransition> inputMap) {
Iterator<Map.Entry<Integer, ByteTransition>> iterator = inputMap.entrySet().iterator();
while (iterator.hasNext()) {
Map.Entry<Integer, ByteTransition> next1 = iterator.next();
Expand Down Expand Up @@ -169,14 +155,35 @@ ByteTransition getTransitionForAllBytes() {
if (firstByteTransition == null) {
return ByteMachine.EmptyByteTransition.INSTANCE;
}
candidates.addAll(firstByteTransition.expand());
firstByteTransition.expand().forEach(single -> candidates.add(single));

while (iterator.hasNext()) {
ByteTransition nextByteTransition = iterator.next();
if (nextByteTransition == null) {
return ByteMachine.EmptyByteTransition.INSTANCE;
}
candidates.retainAll(nextByteTransition.expand());
Iterable<SingleByteTransition> singles = nextByteTransition.expand();
if (singles instanceof Set) {
candidates.retainAll((Set) singles);
} else if (singles instanceof SingleByteTransition) {
SingleByteTransition single = (SingleByteTransition) singles;
if (candidates.contains(single)) {
if (candidates.size() > 1) {
candidates.clear();
candidates.add(single);
}
} else {
if (!candidates.isEmpty()) {
candidates.clear();
}
}
} else {
// singles should always be a Set or SingleByteTransition. Thus, this "else" is expected to be dead code
// but it is here for logical correctness if anything changes in the future.
Set<SingleByteTransition> set = new HashSet<>();
singles.forEach(single -> set.add(single));
candidates.retainAll(set);
}
if (candidates.isEmpty()) {
return ByteMachine.EmptyByteTransition.INSTANCE;
}
Expand Down Expand Up @@ -218,15 +225,15 @@ private Set<SingleByteTransition> getSingleByteTransitions() {
Set<SingleByteTransition> allTransitions = new HashSet<>();
for (ByteTransition transition : map.values()) {
if (transition != null) {
allTransitions.addAll(expand(transition));
expand(transition).forEach(single -> allTransitions.add(single));
}
}
return allTransitions;
}

private static Set<SingleByteTransition> expand(ByteTransition transition) {
private static Iterable<SingleByteTransition> expand(ByteTransition transition) {
if (transition == null) {
return new HashSet<>();
return Collections.EMPTY_SET;
}
return transition.expand();
}
Expand Down
4 changes: 2 additions & 2 deletions src/main/software/amazon/event/ruler/ByteState.java
Original file line number Diff line number Diff line change
Expand Up @@ -96,14 +96,14 @@ ByteTransition getTransitionForAllBytes() {
}

@Override
Set<ByteTransition> getTransitions() {
Iterable<ByteTransition> getTransitions() {
// Saving the value to avoid reading an updated value
Object transitionStore = this.transitionStore;
if (transitionStore == null) {
return Collections.emptySet();
} else if (transitionStore instanceof SingleByteTransitionEntry) {
SingleByteTransitionEntry entry = (SingleByteTransitionEntry) transitionStore;
return Stream.of(entry.transition).collect(Collectors.toSet());
return entry.transition;
}
ByteMap map = (ByteMap) transitionStore;
return map.getTransitions();
Expand Down
16 changes: 8 additions & 8 deletions src/main/software/amazon/event/ruler/ByteTransition.java
Original file line number Diff line number Diff line change
Expand Up @@ -36,31 +36,31 @@ abstract class ByteTransition implements Cloneable {
/**
* Get all the unique transitions (single or compound) reachable from this transition by any UTF-8 byte value.
*
* @return Set of all transitions reachable from this transition.
* @return Iterable of all transitions reachable from this transition.
*/
abstract Set<ByteTransition> getTransitions();
abstract Iterable<ByteTransition> getTransitions();

/**
* Returns matches that are triggered if this transition is made. This is a convenience function that traverses the
* linked list of matches and returns all of them in a Set.
* linked list of matches and returns all of them in an Iterable.
*
* @return matches that are triggered if this transition is made.
*/
abstract Set<ByteMatch> getMatches();
abstract Iterable<ByteMatch> getMatches();

/**
* Returns all shortcuts that are available if this transition is made.
*
* @return all shortcuts
*/
abstract Set<ShortcutTransition> getShortcuts();
abstract Iterable<ShortcutTransition> getShortcuts();

/**
* Get all transitions represented by this transition (can be more than one if this is a compound transition).
*
* @return A set of all transitions represented by this transition.
* @return An iterable of all transitions represented by this transition.
*/
abstract Set<SingleByteTransition> expand();
abstract Iterable<SingleByteTransition> expand();

/**
* Get a transition that represents all of the next byte states for this transition.
Expand Down Expand Up @@ -93,7 +93,7 @@ boolean isMatchTrans() {
* @return boolean
*/
boolean isEmpty() {
return getMatches().isEmpty() && getNextByteState() == null;
return !getMatches().iterator().hasNext() && getNextByteState() == null;
}

/**
Expand Down
31 changes: 18 additions & 13 deletions src/main/software/amazon/event/ruler/CompoundByteTransition.java
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import javax.annotation.Nullable;
import java.util.Collections;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Objects;
import java.util.Set;
import java.util.TreeSet;
Expand Down Expand Up @@ -62,13 +63,23 @@ private CompoundByteTransition(Set<SingleByteTransition> byteTransitions) {
}
}

static <T extends ByteTransition> T coalesce(Set<SingleByteTransition> singles) {
if (singles.isEmpty()) {
static <T extends ByteTransition> T coalesce(Iterable<SingleByteTransition> singles) {
Iterator<SingleByteTransition> iterator = singles.iterator();
if (!iterator.hasNext()) {
return null;
} else if (singles.size() == 1) {
return (T) singles.iterator().next();
}

SingleByteTransition firstElement = iterator.next();
if (!iterator.hasNext()) {
return (T) firstElement;
} else if (singles instanceof Set) {
return (T) new CompoundByteTransition((Set) singles);
} else {
return (T) new CompoundByteTransition(singles);
// We expect Iterables with more than one element to always be Sets, so this should be dead code, but adding
// it here for future-proofing.
Set<SingleByteTransition> set = new HashSet();
singles.forEach(single -> set.add(single));
return (T) new CompoundByteTransition(set);
}
}

Expand Down Expand Up @@ -129,7 +140,7 @@ ByteTransition getTransitionForNextByteStates() {
public Set<ByteMatch> getMatches() {
Set<ByteMatch> matches = new HashSet<>();
for (SingleByteTransition single : matchableTransitions) {
matches.addAll(single.getMatches());
single.getMatches().forEach(match -> matches.add(match));
}
return matches;
}
Expand All @@ -151,13 +162,7 @@ public ByteTransition getTransition(byte utf8byte) {
for (SingleByteTransition transition : this.byteTransitions) {
ByteTransition nextTransition = transition.getTransition(utf8byte);
if (nextTransition != null) {
if (nextTransition instanceof SingleByteTransition) {
// A little hacky. Could just expand nextTransition like in the else case, but it is measurably more
// performant to avoid the intermediate Set creation.
singles.add((SingleByteTransition) nextTransition);
} else {
nextTransition.expand().forEach(t -> singles.add(t));
}
nextTransition.expand().forEach(t -> singles.add(t));
}
}
return coalesce(singles);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,8 @@ int evaluate(ByteState state) {
int maxSize = 0;

// We'll do a breadth-first-search but it shouldn't matter.
Queue<ByteTransition> transitions = new LinkedList<>(state.getTransitions());
Queue<ByteTransition> transitions = new LinkedList<>();
state.getTransitions().forEach(trans -> transitions.add(trans));
while (!transitions.isEmpty()) {
ByteTransition transition = transitions.remove();
if (visited.contains(transition)) {
Expand All @@ -68,10 +69,9 @@ int evaluate(ByteState state) {
// foo will also match foo*, we also need to include in our size wildcard patterns accessible from foo*.
ByteState nextState = single.getNextByteState();
if (nextState != null) {
Set<SingleByteTransition> transitionsForAllBytes = nextState.getTransitionForAllBytes().expand();
for (SingleByteTransition transitionForAllBytes : transitionsForAllBytes) {
for (SingleByteTransition transitionForAllBytes : nextState.getTransitionForAllBytes().expand()) {
if (!(transitionForAllBytes instanceof ByteMachine.EmptyByteTransition) &&
!(transition.expand().contains(transitionForAllBytes))) {
!contains(transition.expand(), transitionForAllBytes)) {
size += getWildcardPatterns(matchesAccessibleFromEachTransition.get(transitionForAllBytes))
.size();
}
Expand All @@ -89,7 +89,7 @@ int evaluate(ByteState state) {
// that could be accessed with a particular byte value.
ByteTransition nextTransition = transition.getTransitionForNextByteStates();
if (nextTransition != null) {
transitions.addAll(nextTransition.getTransitions());
nextTransition.getTransitions().forEach(trans -> transitions.add(trans));
}
}

Expand Down Expand Up @@ -153,7 +153,7 @@ private Map<SingleByteTransition, Set<ByteMatch>> getMatchesAccessibleFromEachTr
visited.add(transition);

// Add all matches directly accessible from this transition.
matches.addAll(transition.getMatches());
transition.getMatches().forEach(match -> matches.add(match));

// Push the next round of deeper states into the stack. By the time we return back to the current transition
// on the stack, all matches for deeper states will have been computed.
Expand All @@ -172,6 +172,18 @@ private Map<SingleByteTransition, Set<ByteMatch>> getMatchesAccessibleFromEachTr
return result;
}

/**
 * Checks whether the given iterable holds a transition equal to {@code single}.
 *
 * <p>When the iterable is a {@link Set}, the lookup is delegated to {@code Set.contains};
 * otherwise the elements are scanned one by one and compared with {@code equals}.
 *
 * @param iterable the transitions to search
 * @param single the transition to look for
 * @return {@code true} if an equal transition is found, {@code false} otherwise
 */
private static boolean contains(Iterable<SingleByteTransition> iterable, SingleByteTransition single) {
    if (iterable instanceof Set) {
        // Wildcard cast rather than a raw type: contains(Object) does not need the element type.
        return ((Set<?>) iterable).contains(single);
    }
    for (SingleByteTransition eachSingle : iterable) {
        if (single.equals(eachSingle)) {
            return true;
        }
    }
    return false;
}

private static Set<Patterns> getWildcardPatterns(Set<ByteMatch> matches) {
Set<Patterns> patterns = new HashSet<>();
for (ByteMatch match : matches) {
Expand Down
4 changes: 2 additions & 2 deletions src/main/software/amazon/event/ruler/ShortcutTransition.java
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,8 @@ SingleByteTransition setMatch(ByteMatch match) {
}

@Override
public Set<ShortcutTransition> getShortcuts() {
return Collections.singleton(this);
public Iterable<ShortcutTransition> getShortcuts() {
return this;
}

@Override
Expand Down

0 comments on commit b223ae5

Please sign in to comment.