Skip to content

Commit

Permalink
Allow lazy iteration for non-reiterables. (#30851)
Browse files Browse the repository at this point in the history
In particular, Runner v2 does not produce Reiterables, which
resulted in the entire stream being read into memory. In this case,
we can leverage the fact that the first 100MB will be cached and
quick to reiterate over.
  • Loading branch information
robertwb committed Apr 5, 2024
1 parent 2e630ac commit f1a47ef
Show file tree
Hide file tree
Showing 2 changed files with 294 additions and 51 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,13 @@
import java.io.InputStream;
import java.io.OutputStream;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.NoSuchElementException;
import java.util.Objects;
import java.util.function.Function;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.CoderException;
import org.apache.beam.sdk.coders.CustomCoder;
Expand All @@ -34,6 +37,8 @@
import org.apache.beam.sdk.values.TupleTag;
import org.apache.beam.sdk.values.TupleTagList;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.AbstractIterator;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.HashMultiset;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterators;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists;
Expand Down Expand Up @@ -95,9 +100,8 @@ public CoGbkResult(
// according to their tag.
final Iterator<RawUnionValue> taggedIter = taggedValues.iterator();
int elementCount = 0;
boolean isReiterator = taggedIter instanceof Reiterator;
while (taggedIter.hasNext()) {
if (isReiterator && elementCount++ >= inMemoryElementCount) {
if (elementCount++ >= inMemoryElementCount) {
// Let the tails be lazy.
break;
}
Expand All @@ -113,6 +117,7 @@ public CoGbkResult(
}

if (!taggedIter.hasNext()) {
// Everything fits into memory, just store it.
valueMap = (List) valuesByTag;
return;
}
Expand All @@ -123,33 +128,51 @@ public CoGbkResult(
LOG.info(
"CoGbkResult has more than {} elements, reiteration (which may be slow) is required.",
inMemoryElementCount);
final Reiterator<RawUnionValue> tail = (Reiterator<RawUnionValue>) taggedIter;

// As we iterate over this re-iterable (e.g. while iterating for one tag) we populate values
// for other observed tags, if any.
ObservingReiterator<RawUnionValue> tip =
new ObservingReiterator<>(
tail,
new ObservingReiterator.Observer<RawUnionValue>() {
@Override
public void observeAt(ObservingReiterator<RawUnionValue> reiterator) {
((TagIterable<?>) valueMap.get(reiterator.peek().getUnionTag())).offer(reiterator);
}
valueMap = new ArrayList<>();
Function<Integer, Iterable<Object>> makeIterable;
if (taggedIter instanceof Reiterator) {
final Reiterator<RawUnionValue> tail = (Reiterator<RawUnionValue>) taggedIter;

// As we iterate over this re-iterable (e.g. while iterating for one tag) we populate values
// for other observed tags, if any.
ObservingReiterator<RawUnionValue> tip =
new ObservingReiterator<>(
tail,
new ObservingReiterator.Observer<RawUnionValue>() {
@Override
public void observeAt(ObservingReiterator<RawUnionValue> reiterator) {
((TagIterable<?>) valueMap.get(reiterator.peek().getUnionTag()))
.offer(reiterator);
}

@Override
public void done() {
// Inform all tags that we have reached the end of the iterable, so anything that
// can be observed has been observed.
for (Iterable<?> iter : valueMap) {
((TagIterable<?>) iter).finish();
@Override
public void done() {
// Inform all tags that we have reached the end of the iterable, so anything that
// can be observed has been observed.
for (Iterable<?> iter : valueMap) {
((TagIterable<?>) iter).finish();
}
}
}
});
});
makeIterable =
unionTag ->
new TagIterable<Object>(valuesByTag.get(unionTag), unionTag, minElementsPerTag, tip);
} else {
// Not reiterable, we have to filter each time, but there are some optimizations we can do...
boolean[] sharedSeenEnd = {false};
makeIterable =
unionTag ->
recordingFilteringIterable(
taggedValues,
unionTag,
Math.max(
inMemoryElementCount / (schema.size() * schema.size()), minElementsPerTag),
valueMap,
sharedSeenEnd);
}

valueMap = new ArrayList<>();
for (int unionTag = 0; unionTag < schema.size(); unionTag++) {
valueMap.add(
new TagIterable<Object>(valuesByTag.get(unionTag), unionTag, minElementsPerTag, tip));
valueMap.add(makeIterable.apply(unionTag));
}
}

Expand Down Expand Up @@ -712,4 +735,187 @@ private boolean maybeAdvance() {
};
}
}

/**
 * Builds a lazy view over the values carrying {@code unionTag}.
 *
 * <p>Nothing is read from {@code taggedIteratable} here; each call to {@code iterator()} on the
 * returned iterable creates a fresh {@link RecordingFilteringIterator} that shares {@code
 * sharedValueMap} and {@code sharedSeenEnd} with every other iterator built from the same
 * arguments.
 *
 * @param taggedIteratable the tagged source values, handed to each new iterator
 * @param unionTag the tag whose values the returned iterable yields
 * @param minElementsPerTag per-tag recording limit forwarded to each iterator
 * @param sharedValueMap per-tag iterables shared between iterators
 * @param sharedSeenEnd single-element flag array shared between iterators
 * @return an iterable producing a new {@link RecordingFilteringIterator} per iteration
 */
private Iterable<Object> recordingFilteringIterable(
    Iterable<RawUnionValue> taggedIteratable,
    int unionTag,
    int minElementsPerTag,
    List<Iterable<?>> sharedValueMap,
    boolean[] sharedSeenEnd) {
  final Iterable<Object> perTagView =
      () -> {
        Iterator<Object> freshPass =
            new RecordingFilteringIterator(
                taggedIteratable, unionTag, minElementsPerTag, sharedValueMap, sharedSeenEnd);
        return freshPass;
      };
  return perTagView;
}

/**
 * This iterator implements the optimization that if there are below a certain number of values
 * for a given tag we cache those values in memory rather than reiterating and filtering each
 * time.
 *
 * <p>This is done lazily by having each iterator keep track of what it has seen locally, and the
 * first one to reach the end updates the shared map. As an added optimization, this iterable is
 * also updated in place if it only had a small number of elements.
 *
 * <p>Unfortunately iterators do not promise deterministic ordering, so we cannot share the work
 * computing the local maps until the iterator is entirely exhausted.
 */
private static class RecordingFilteringIterator extends AbstractIterator<Object> {
  /** Source of all tagged values; re-wrapped in simple filters once the end has been reached. */
  private final Iterable<RawUnionValue> taggedIterable;

  /** The single pass over {@link #taggedIterable} that this iterator consumes. */
  private final Iterator<RawUnionValue> taggedIterator;

  /** The tag whose values this iterator yields. */
  private final int unionTag;

  /** Maximum number of values recorded per tag before recording for that tag is abandoned. */
  private final int minElementsPerTag;

  /**
   * Values recorded by this iterator, indexed by tag. An entry is replaced by {@code null} once
   * more than {@link #minElementsPerTag} values have been seen for that tag.
   */
  private final List<List<Object>> localValueMap;

  /** Per-tag iterables shared with the other iterators; overwritten by the first to finish. */
  private final List<Iterable<?>> sharedValueMap;

  /** Single-element flag, shared between iterators, set once any of them reaches the end. */
  private final boolean[] sharedSeenEnd;

  /** Whether {@link #remaining} has been, can be, or cannot be derived from the shared map. */
  private enum RemainingStatus {
    UNCOMPUTED,
    UNCOMPUTABLE,
    COMPUTED
  }

  private RemainingStatus remainingStatus = RemainingStatus.UNCOMPUTED;

  /** Only meaningful when {@link #remainingStatus} is {@link RemainingStatus#COMPUTED}. */
  private Iterator<Object> remaining;

  /**
   * Starts a new pass over {@code taggedIteratable} and allocates one empty recording list per
   * entry of {@code sharedValueMap}.
   *
   * @param taggedIteratable source of tagged values; {@code iterator()} is called immediately
   * @param unionTag the tag whose values are returned
   * @param minElementsPerTag per-tag limit on locally recorded values
   * @param sharedValueMap per-tag iterables shared between iterators
   * @param sharedSeenEnd single-element flag array shared between iterators
   */
  public RecordingFilteringIterator(
      Iterable<RawUnionValue> taggedIteratable,
      int unionTag,
      int minElementsPerTag,
      List<Iterable<?>> sharedValueMap,
      boolean[] sharedSeenEnd) {
    this.taggedIterable = taggedIteratable;
    this.taggedIterator = taggedIteratable.iterator();
    this.unionTag = unionTag;
    this.minElementsPerTag = minElementsPerTag;
    localValueMap = new ArrayList<>();
    for (int i = 0; i < sharedValueMap.size(); i++) {
      localValueMap.add(new ArrayList<>());
    }
    this.sharedValueMap = sharedValueMap;
    this.sharedSeenEnd = sharedSeenEnd;
  }

  @Override
  protected Object computeNext() {
    if (sharedSeenEnd[0]) {
      if (remainingStatus == RemainingStatus.UNCOMPUTED) {
        // We iterated all the way to the end (likely on another iterator) and may have only found
        // a small number of values associated with this tag. Update this iterator, if possible,
        // to return them directly.
        remainingStatus = computeRemaining(sharedValueMap.get(unionTag));
      }

      if (remainingStatus == RemainingStatus.COMPUTED) {
        if (remaining.hasNext()) {
          return remaining.next();
        } else {
          return endOfData();
        }
      }
      // UNCOMPUTABLE: fall through and keep filtering the underlying iterator.
    }

    // Look for the next value with this tag, keeping track of up to minElementsPerTag values of
    // all other iterables as we go.
    while (taggedIterator.hasNext()) {
      RawUnionValue unionValue = taggedIterator.next();
      if (!sharedSeenEnd[0]) {
        List<Object> valuesForTag = localValueMap.get(unionValue.getUnionTag());
        if (valuesForTag != null) {
          if (valuesForTag.size() < minElementsPerTag) {
            valuesForTag.add(unionValue.getValue());
          } else {
            // Too many values for this tag to cache; stop recording it.
            localValueMap.set(unionValue.getUnionTag(), null);
          }
        }
      }
      if (unionValue.getUnionTag() == unionTag) {
        return unionValue.getValue();
      }
    }

    // We got to the end of the iterable, update the shared set of values with those sets that
    // were small enough to cache.
    if (!sharedSeenEnd[0]) {
      for (int i = 0; i < sharedValueMap.size(); i++) {
        List<Object> localValues = localValueMap.get(i);
        sharedValueMap.set(
            i, localValues != null ? localValues : simpleFilteringIterable(taggedIterable, i));
      }
      sharedSeenEnd[0] = true;
    }

    return endOfData();
  }

  /**
   * Tries to set {@link #remaining} to the values of {@code allValuesIter} that this iterator has
   * not yet recorded for its own tag.
   *
   * @param allValuesIter the shared entry for this iterator's tag
   * @return {@link RemainingStatus#COMPUTED} if {@link #remaining} was set, otherwise {@link
   *     RemainingStatus#UNCOMPUTABLE}
   */
  private RemainingStatus computeRemaining(Iterable<?> allValuesIter) {
    if (!(allValuesIter instanceof Collection)) {
      return RemainingStatus.UNCOMPUTABLE;
    }
    Collection<Object> allValues = (Collection<Object>) allValuesIter;
    List<Object> seenValues = localValueMap.get(unionTag);
    if (seenValues == null) {
      // computeNext() nulls this entry once more than minElementsPerTag values were seen for our
      // own tag, so there is no record to subtract from allValues. Fall back to filtering
      // rather than failing with a NullPointerException.
      return RemainingStatus.UNCOMPUTABLE;
    }
    if (allValues.size() == seenValues.size()) {
      remaining = Collections.emptyIterator();
      return RemainingStatus.COMPUTED;
    } else if (seenValues.size() == 0) {
      remaining = allValues.iterator();
      return RemainingStatus.COMPUTED;
    } else if (seenValues.size() == 1) {
      // Optimize the very common case.
      Iterator<Object> iter = allValues.iterator();
      // hasNext() guards against an empty allValues, which would otherwise throw
      // NoSuchElementException; in that case the removal below fails and we report UNCOMPUTABLE.
      if (iter.hasNext() && Objects.equals(iter.next(), seenValues.get(0))) {
        remaining = iter;
        return RemainingStatus.COMPUTED;
      } else {
        ArrayList<Object> allButOne = Lists.newArrayList(allValues);
        if (allButOne.remove(seenValues.get(0))) {
          remaining = allButOne.iterator();
          return RemainingStatus.COMPUTED;
        } else {
          return RemainingStatus.UNCOMPUTABLE;
        }
      }
    } else {
      try {
        // Multiset difference: every value in allValues either cancels one seen occurrence or is
        // still to be returned.
        HashMultiset<Object> seenValueSet = HashMultiset.create(seenValues);
        List<Object> unseenValues = new ArrayList<>();
        for (Object value : allValues) {
          if (!seenValueSet.remove(value)) {
            unseenValues.add(value);
          }
        }
        if (seenValueSet.isEmpty()) {
          remaining = unseenValues.iterator();
          return RemainingStatus.COMPUTED;
        } else {
          // Semantically equal values didn't hash or compare equal.
          return RemainingStatus.UNCOMPUTABLE;
        }
      } catch (Exception exn) {
        // There's no promise elements have correct hash semantics or properly handle nulls.
        return RemainingStatus.UNCOMPUTABLE;
      }
    }
  }
}

/**
 * Returns an iterable over the values of {@code taggedIterable} whose union tag equals {@code
 * unionTag}.
 *
 * <p>No values are cached: every call to {@code iterator()} starts a new pass over {@code
 * taggedIterable} and skips elements carrying other tags.
 *
 * @param taggedIterable the tagged source values
 * @param unionTag the tag to select
 * @return a filtering view yielding {@code getValue()} of each matching element
 */
private static Iterable<Object> simpleFilteringIterable(
    Iterable<RawUnionValue> taggedIterable, int unionTag) {
  return new Iterable<Object>() {
    @Override
    public Iterator<Object> iterator() {
      return new AbstractIterator<Object>() {
        Iterator<RawUnionValue> source = taggedIterable.iterator();

        @Override
        protected Object computeNext() {
          for (; source.hasNext(); ) {
            RawUnionValue candidate = source.next();
            if (candidate.getUnionTag() != unionTag) {
              // Belongs to a different tag; keep scanning.
              continue;
            }
            return candidate.getValue();
          }
          return endOfData();
        }
      };
    }
  };
}
}

0 comments on commit f1a47ef

Please sign in to comment.