Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow lazy iteration for non-reiterables. #30851

Merged
merged 2 commits into from
Apr 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,13 @@
import java.io.InputStream;
import java.io.OutputStream;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.NoSuchElementException;
import java.util.Objects;
import java.util.function.Function;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.CoderException;
import org.apache.beam.sdk.coders.CustomCoder;
Expand All @@ -34,6 +37,8 @@
import org.apache.beam.sdk.values.TupleTag;
import org.apache.beam.sdk.values.TupleTagList;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.AbstractIterator;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.HashMultiset;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterators;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists;
Expand Down Expand Up @@ -95,9 +100,8 @@ public CoGbkResult(
// according to their tag.
final Iterator<RawUnionValue> taggedIter = taggedValues.iterator();
int elementCount = 0;
boolean isReiterator = taggedIter instanceof Reiterator;
while (taggedIter.hasNext()) {
if (isReiterator && elementCount++ >= inMemoryElementCount) {
if (elementCount++ >= inMemoryElementCount) {
// Let the tails be lazy.
break;
}
Expand All @@ -113,6 +117,7 @@ public CoGbkResult(
}

if (!taggedIter.hasNext()) {
// Everything fits into memory, just store it.
valueMap = (List) valuesByTag;
return;
}
Expand All @@ -123,33 +128,51 @@ public CoGbkResult(
LOG.info(
"CoGbkResult has more than {} elements, reiteration (which may be slow) is required.",
inMemoryElementCount);
final Reiterator<RawUnionValue> tail = (Reiterator<RawUnionValue>) taggedIter;

// As we iterate over this re-iterable (e.g. while iterating for one tag) we populate values
// for other observed tags, if any.
ObservingReiterator<RawUnionValue> tip =
new ObservingReiterator<>(
tail,
new ObservingReiterator.Observer<RawUnionValue>() {
@Override
public void observeAt(ObservingReiterator<RawUnionValue> reiterator) {
((TagIterable<?>) valueMap.get(reiterator.peek().getUnionTag())).offer(reiterator);
}
valueMap = new ArrayList<>();
Function<Integer, Iterable<Object>> makeIterable;
if (taggedIter instanceof Reiterator) {
final Reiterator<RawUnionValue> tail = (Reiterator<RawUnionValue>) taggedIter;

// As we iterate over this re-iterable (e.g. while iterating for one tag) we populate values
// for other observed tags, if any.
ObservingReiterator<RawUnionValue> tip =
new ObservingReiterator<>(
tail,
new ObservingReiterator.Observer<RawUnionValue>() {
@Override
public void observeAt(ObservingReiterator<RawUnionValue> reiterator) {
((TagIterable<?>) valueMap.get(reiterator.peek().getUnionTag()))
.offer(reiterator);
}

@Override
public void done() {
// Inform all tags that we have reached the end of the iterable, so anything that
// can be observed has been observed.
for (Iterable<?> iter : valueMap) {
((TagIterable<?>) iter).finish();
@Override
public void done() {
// Inform all tags that we have reached the end of the iterable, so anything that
// can be observed has been observed.
for (Iterable<?> iter : valueMap) {
((TagIterable<?>) iter).finish();
}
}
}
});
});
makeIterable =
unionTag ->
new TagIterable<Object>(valuesByTag.get(unionTag), unionTag, minElementsPerTag, tip);
} else {
// Not reiterable, we have to filter each time, but there are some optimizations we can do...
boolean[] sharedSeenEnd = {false};
makeIterable =
unionTag ->
recordingFilteringIterable(
taggedValues,
unionTag,
Math.max(
inMemoryElementCount / (schema.size() * schema.size()), minElementsPerTag),
valueMap,
sharedSeenEnd);
}

valueMap = new ArrayList<>();
for (int unionTag = 0; unionTag < schema.size(); unionTag++) {
valueMap.add(
new TagIterable<Object>(valuesByTag.get(unionTag), unionTag, minElementsPerTag, tip));
valueMap.add(makeIterable.apply(unionTag));
}
}

Expand Down Expand Up @@ -712,4 +735,187 @@ private boolean maybeAdvance() {
};
}
}

/**
 * Builds an iterable over the values carrying {@code unionTag} in a tagged iterable that has to be
 * re-filtered on every pass.
 *
 * <p>Each call to {@code iterator()} on the result creates a new {@link
 * RecordingFilteringIterator} over {@code taggedIteratable}; all of those iterators share {@code
 * sharedValueMap} and {@code sharedSeenEnd}.
 *
 * <p>This helper reads no instance state, so it is static; lambdas that call it therefore do not
 * capture the enclosing instance.
 *
 * @param taggedIteratable the full iterable of tagged values to filter
 * @param unionTag the tag whose values the returned iterable yields
 * @param minElementsPerTag the per-tag cap on values an iterator will record for caching
 * @param sharedValueMap the per-tag iterables, updated in place by the first iterator to reach the
 *     end of {@code taggedIteratable}
 * @param sharedSeenEnd single-element flag, set once some iterator has reached the end
 * @return a lazily filtering iterable for {@code unionTag}
 */
private static Iterable<Object> recordingFilteringIterable(
    Iterable<RawUnionValue> taggedIteratable,
    int unionTag,
    int minElementsPerTag,
    List<Iterable<?>> sharedValueMap,
    boolean[] sharedSeenEnd) {
  return () ->
      new RecordingFilteringIterator(
          taggedIteratable, unionTag, minElementsPerTag, sharedValueMap, sharedSeenEnd);
}

/**
 * This iterator implements the optimization that if there are below a certain number of values
 * for a given tag we cache those values in memory rather than reiterating and filtering each
 * time.
 *
 * <p>This is done lazily by having each iterator keep track of what it has seen locally, and the
 * first one to reach the end updates the shared map. As an added optimization, this iterator is
 * also updated in place if it only had a small number of elements.
 *
 * <p>Unfortunately iterators do not promise deterministic ordering, so we cannot share the work
 * computing the local maps until the iterator is entirely exhausted.
 *
 * <p>No synchronization is performed on the shared map or the shared end-of-input flag.
 */
private static class RecordingFilteringIterator extends AbstractIterator<Object> {
  private final Iterable<RawUnionValue> taggedIterable;
  private final Iterator<RawUnionValue> taggedIterator;
  private final int unionTag;

  private final int minElementsPerTag;

  /**
   * Values this iterator has observed so far, indexed by tag. An entry is replaced by {@code
   * null} once more than {@code minElementsPerTag} values were observed for that tag.
   */
  private final List<List<Object>> localValueMap;

  private final List<Iterable<?>> sharedValueMap;

  /** Single-element flag shared between iterators; true once one of them reached the end. */
  private final boolean[] sharedSeenEnd;

  /** Whether {@link #remaining} could be derived from the shared cache. */
  private enum RemainingStatus {
    /** No attempt has been made yet. */
    UNCOMPUTED,
    /** An attempt was made and failed; keep filtering the underlying iterator. */
    UNCOMPUTABLE,
    /** {@link #remaining} holds exactly the values this iterator still has to return. */
    COMPUTED
  }

  private RemainingStatus remainingStatus = RemainingStatus.UNCOMPUTED;
  private Iterator<Object> remaining;

  /**
   * Creates an iterator over the values tagged {@code unionTag}.
   *
   * @param taggedIteratable the full iterable of tagged values; a new iterator is obtained from
   *     it immediately
   * @param unionTag the tag whose values this iterator returns
   * @param minElementsPerTag the maximum number of values recorded per tag
   * @param sharedValueMap the per-tag iterables shared between iterators; its size determines
   *     the number of tags tracked locally
   * @param sharedSeenEnd single-element flag shared between iterators
   */
  public RecordingFilteringIterator(
      Iterable<RawUnionValue> taggedIteratable,
      int unionTag,
      int minElementsPerTag,
      List<Iterable<?>> sharedValueMap,
      boolean[] sharedSeenEnd) {
    this.taggedIterable = taggedIteratable;
    this.taggedIterator = taggedIteratable.iterator();
    this.unionTag = unionTag;
    this.minElementsPerTag = minElementsPerTag;
    localValueMap = new ArrayList<>();
    for (int i = 0; i < sharedValueMap.size(); i++) {
      localValueMap.add(new ArrayList<>());
    }
    this.sharedValueMap = sharedValueMap;
    this.sharedSeenEnd = sharedSeenEnd;
  }

  @Override
  protected Object computeNext() {
    if (sharedSeenEnd[0]) {
      if (remainingStatus == RemainingStatus.UNCOMPUTED) {
        // We iterated all the way to the end (likely on another iterator) and may have only found
        // a small number of values associated with this tag. Update this iterator, if possible,
        // to return them directly.
        remainingStatus = computeRemaining(sharedValueMap.get(unionTag));
      }

      if (remainingStatus == RemainingStatus.COMPUTED) {
        return remaining.hasNext() ? remaining.next() : endOfData();
      }
      // UNCOMPUTABLE: fall through and keep filtering the underlying iterator.
    }

    // Look for the next value with this tag, keeping track of up to minElementsPerTag values of
    // all other iterables as we go.
    while (taggedIterator.hasNext()) {
      RawUnionValue unionValue = taggedIterator.next();
      if (!sharedSeenEnd[0]) {
        List<Object> valuesForTag = localValueMap.get(unionValue.getUnionTag());
        if (valuesForTag != null) {
          if (valuesForTag.size() < minElementsPerTag) {
            valuesForTag.add(unionValue.getValue());
          } else {
            // Too many values for this tag to cache; stop recording it.
            localValueMap.set(unionValue.getUnionTag(), null);
          }
        }
      }
      if (unionValue.getUnionTag() == unionTag) {
        return unionValue.getValue();
      }
    }

    // We got to the end of the iterable, update the shared set of values with those sets that
    // were small enough to cache.
    if (!sharedSeenEnd[0]) {
      for (int i = 0; i < sharedValueMap.size(); i++) {
        List<Object> localValues = localValueMap.get(i);
        sharedValueMap.set(
            i, localValues != null ? localValues : simpleFilteringIterable(taggedIterable, i));
      }
      sharedSeenEnd[0] = true;
    }

    return endOfData();
  }

  /**
   * Tries to set {@link #remaining} to the cached values for this tag minus the values this
   * iterator has already recorded.
   *
   * @param allValuesIter the shared entry for this tag; only usable when it is a {@link
   *     Collection}
   * @return {@code COMPUTED} if {@link #remaining} was set, otherwise {@code UNCOMPUTABLE}
   */
  private RemainingStatus computeRemaining(Iterable<?> allValuesIter) {
    if (!(allValuesIter instanceof Collection)) {
      return RemainingStatus.UNCOMPUTABLE;
    }
    // Elements are only read as Object, so viewing the wildcard collection this way is safe.
    @SuppressWarnings("unchecked")
    Collection<Object> allValues = (Collection<Object>) allValuesIter;
    List<Object> seenValues = localValueMap.get(unionTag);
    if (seenValues == null || seenValues.size() > allValues.size()) {
      // This iterator observed more values for the tag than the cache holds (or stopped
      // recording them altogether), so the two passes disagree and the cache cannot be trusted
      // for this iterator. Previously this dereferenced null or called next() on an empty
      // iterator.
      return RemainingStatus.UNCOMPUTABLE;
    }
    if (allValues.size() == seenValues.size()) {
      remaining = Collections.emptyIterator();
      return RemainingStatus.COMPUTED;
    } else if (seenValues.size() == 0) {
      remaining = allValues.iterator();
      return RemainingStatus.COMPUTED;
    } else if (seenValues.size() == 1) {
      // Optimize the very common case. allValues is non-empty here because it is strictly
      // larger than seenValues.
      Iterator<Object> iter = allValues.iterator();
      if (Objects.equals(iter.next(), seenValues.get(0))) {
        remaining = iter;
        return RemainingStatus.COMPUTED;
      } else {
        ArrayList<Object> allButOne = Lists.newArrayList(allValues);
        if (allButOne.remove(seenValues.get(0))) {
          remaining = allButOne.iterator();
          return RemainingStatus.COMPUTED;
        } else {
          return RemainingStatus.UNCOMPUTABLE;
        }
      }
    } else {
      try {
        HashMultiset<Object> seenValueSet = HashMultiset.create(seenValues);
        List<Object> unseenValues = new ArrayList<>();
        for (Object value : allValues) {
          if (!seenValueSet.remove(value)) {
            unseenValues.add(value);
          }
        }
        if (seenValueSet.isEmpty()) {
          remaining = unseenValues.iterator();
          return RemainingStatus.COMPUTED;
        } else {
          // Semantically equal values didn't hash or compare equal.
          return RemainingStatus.UNCOMPUTABLE;
        }
      } catch (Exception exn) {
        // There's no promise elements have correct hash semantics or properly handle nulls.
        return RemainingStatus.UNCOMPUTABLE;
      }
    }
  }
}

/**
 * Returns an iterable that, on every iteration, walks {@code taggedIterable} from a fresh
 * iterator and yields only the values whose union tag equals {@code unionTag}.
 */
private static Iterable<Object> simpleFilteringIterable(
    Iterable<RawUnionValue> taggedIterable, int unionTag) {
  return new Iterable<Object>() {
    @Override
    public Iterator<Object> iterator() {
      // One underlying iterator per call, obtained when iteration is requested.
      final Iterator<RawUnionValue> source = taggedIterable.iterator();
      return new AbstractIterator<Object>() {
        @Override
        protected Object computeNext() {
          // Skip entries belonging to other tags until a match turns up or input runs out.
          for (; source.hasNext(); ) {
            RawUnionValue candidate = source.next();
            if (candidate.getUnionTag() != unionTag) {
              continue;
            }
            return candidate.getValue();
          }
          return endOfData();
        }
      };
    }
  };
}
}