[BUGFIX] Fix resource leak when shuffle read #174

Merged · 2 commits · Aug 22, 2022
RssShuffleDataIterator.java
@@ -34,6 +34,7 @@
import scala.Tuple2;
import scala.collection.AbstractIterator;
import scala.collection.Iterator;
import scala.runtime.BoxedUnit;

import org.apache.uniffle.client.api.ShuffleReadClient;
import org.apache.uniffle.client.response.CompressedShuffleBlock;
@@ -130,9 +131,7 @@ public boolean hasNext() {
readTime += fetchDuration;
serializeTime += serializationDuration;
} else {
// finish reading records, close related reader and check data consistent
clearDeserializationStream();
shuffleReadClient.close();
// finish reading records, check data consistent
shuffleReadClient.checkProcessedBlockIds();
shuffleReadClient.logStatics();
LOG.info("Fetch " + shuffleReadMetrics.remoteBytesRead() + " bytes cost " + readTime + " ms and "
@@ -150,6 +149,15 @@ public Product2<K, C> next() {
return (Product2<K, C>) recordsIterator.next();
}

public BoxedUnit cleanup() {
clearDeserializationStream();
if (shuffleReadClient != null) {
shuffleReadClient.close();
}
shuffleReadClient = null;
return BoxedUnit.UNIT;
}

@VisibleForTesting
protected ShuffleReadMetrics getShuffleReadMetrics() {
return shuffleReadMetrics;
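A minimal sketch (not part of this PR) of how the new cleanup() hook is meant to be driven, mirroring the CompletionIterator wiring that the RssShuffleReader changes below introduce. The class and method names (CleanupWiringSketch, wrap) are hypothetical; RssShuffleDataIterator is the class modified above and is assumed to be importable.

import org.apache.spark.TaskContext;
import org.apache.spark.util.CompletionIterator;
import org.apache.spark.util.CompletionIterator$;
import scala.Product2;
import scala.runtime.AbstractFunction0;
import scala.runtime.BoxedUnit;

final class CleanupWiringSketch {

  // Wrap the shuffle data iterator so cleanup() runs when it is exhausted, with a
  // task-completion listener as a fallback for tasks that never drain the iterator.
  static <K, C> CompletionIterator<Product2<K, C>, RssShuffleDataIterator<K, C>> wrap(
      RssShuffleDataIterator<K, C> dataIterator, TaskContext context) {
    CompletionIterator<Product2<K, C>, RssShuffleDataIterator<K, C>> wrapped =
        CompletionIterator$.MODULE$.apply(dataIterator, new AbstractFunction0<BoxedUnit>() {
          @Override
          public BoxedUnit apply() {
            // cleanup() closes the ShuffleReadClient and nulls it out, so close() runs at most once
            return dataIterator.cleanup();
          }
        });
    // Fallback: release the client even if the task finishes before the iterator is exhausted.
    context.addTaskCompletionListener(taskContext -> {
      wrapped.completion();
    });
    return wrapped;
  }
}

Spark's CompletionIterator invokes the completion function the first time hasNext() sees the wrapped iterator exhausted, so in the common case the client is released as soon as the last record is read rather than at task end.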
RssShuffleDataIteratorTest.java
@@ -24,6 +24,7 @@

import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import org.apache.uniffle.client.api.ShuffleReadClient;
import org.apache.uniffle.client.impl.ShuffleReadClientImpl;
import org.apache.uniffle.client.util.ClientUtils;
import org.apache.uniffle.client.util.DefaultIdHelper;
@@ -46,6 +47,11 @@
import org.mockito.Mockito;
import org.roaringbitmap.longlong.Roaring64NavigableMap;

import static org.mockito.Mockito.doNothing;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;

public class RssShuffleDataIteratorTest extends AbstractRssReaderTest {

private static final Serializer KRYO_SERIALIZER = new KryoSerializer(new SparkConf(false));
@@ -235,4 +241,13 @@ public void readTest7() throws Exception {
}
}

@Test
public void cleanup() throws Exception {
ShuffleReadClient mockClient = mock(ShuffleReadClient.class);
doNothing().when(mockClient).close();
RssShuffleDataIterator dataIterator = new RssShuffleDataIterator(KRYO_SERIALIZER, mockClient, new ShuffleReadMetrics());
dataIterator.cleanup();
verify(mockClient, times(1)).close();
}

}
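A possible follow-up assertion (not part of this PR): because cleanup() nulls out shuffleReadClient before returning, a second call must not close the client again. The test name below is hypothetical and reuses the same mocks as the test added above.

@Test
public void cleanupIsIdempotent() throws Exception {
  ShuffleReadClient mockClient = mock(ShuffleReadClient.class);
  doNothing().when(mockClient).close();
  RssShuffleDataIterator dataIterator =
      new RssShuffleDataIterator(KRYO_SERIALIZER, mockClient, new ShuffleReadMetrics());
  // Calling cleanup() twice should still close the underlying client only once,
  // since the first call sets shuffleReadClient to null.
  dataIterator.cleanup();
  dataIterator.cleanup();
  verify(mockClient, times(1)).close();
}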
RssShuffleReader.java (Spark 2 client)
@@ -26,6 +26,7 @@
import org.apache.spark.serializer.Serializer;
import org.apache.spark.shuffle.RssShuffleHandle;
import org.apache.spark.shuffle.ShuffleReader;
import org.apache.spark.util.CompletionIterator;
import org.apache.spark.util.CompletionIterator$;
import org.apache.spark.util.TaskCompletionListener;
import org.apache.spark.util.collection.ExternalSorter;
@@ -113,23 +114,32 @@ public Iterator<Product2<K, C>> read() {
RssShuffleDataIterator rssShuffleDataIterator = new RssShuffleDataIterator<K, C>(
shuffleDependency.serializer(), shuffleReadClient,
context.taskMetrics().shuffleReadMetrics());
CompletionIterator completionIterator =
CompletionIterator$.MODULE$.apply(rssShuffleDataIterator, new AbstractFunction0<BoxedUnit>() {
@Override
public BoxedUnit apply() {
return rssShuffleDataIterator.cleanup();
}
});
context.addTaskCompletionListener(context -> {
completionIterator.completion();
});

Iterator<Product2<K, C>> resultIter = null;
Iterator<Product2<K, C>> aggregatedIter = null;

if (shuffleDependency.aggregator().isDefined()) {
if (shuffleDependency.mapSideCombine()) {
// We are reading values that are already combined
aggregatedIter = shuffleDependency.aggregator().get().combineCombinersByKey(
rssShuffleDataIterator, context);
aggregatedIter = shuffleDependency.aggregator().get().combineCombinersByKey(completionIterator, context);
} else {
// We don't know the value type, but also don't care -- the dependency *should*
// have made sure its compatible w/ this aggregator, which will convert the value
// type to the combined type C
aggregatedIter = shuffleDependency.aggregator().get().combineValuesByKey(rssShuffleDataIterator, context);
aggregatedIter = shuffleDependency.aggregator().get().combineValuesByKey(completionIterator, context);
}
} else {
aggregatedIter = rssShuffleDataIterator;
aggregatedIter = completionIterator;
}

if (shuffleDependency.keyOrdering().isDefined()) {
RssShuffleReader.java (Spark 3 client)
@@ -30,6 +30,7 @@
import org.apache.spark.serializer.Serializer;
import org.apache.spark.shuffle.RssShuffleHandle;
import org.apache.spark.shuffle.ShuffleReader;
import org.apache.spark.util.CompletionIterator;
import org.apache.spark.util.CompletionIterator$;
import org.apache.spark.util.collection.ExternalSorter;
import org.roaringbitmap.longlong.Roaring64NavigableMap;
@@ -183,11 +184,11 @@ public Configuration getHadoopConf() {
}

class MultiPartitionIterator<K, C> extends AbstractIterator<Product2<K, C>> {
java.util.Iterator<RssShuffleDataIterator> iterator;
RssShuffleDataIterator dataIterator;
java.util.Iterator<CompletionIterator<Product2<K, C>, RssShuffleDataIterator<K, C>>> iterator;
CompletionIterator<Product2<K, C>, RssShuffleDataIterator<K, C>> dataIterator;

MultiPartitionIterator() {
List<RssShuffleDataIterator> iterators = Lists.newArrayList();
List<CompletionIterator<Product2<K, C>, RssShuffleDataIterator<K, C>>> iterators = Lists.newArrayList();
for (int partition = startPartition; partition < endPartition; partition++) {
if (partitionToExpectBlocks.get(partition).isEmpty()) {
LOG.info("{} partition is empty partition", partition);
@@ -201,13 +202,21 @@ class MultiPartitionIterator<K, C> extends AbstractIterator<Product2<K, C>> {
RssShuffleDataIterator iterator = new RssShuffleDataIterator<K, C>(
shuffleDependency.serializer(), shuffleReadClient,
readMetrics);
iterators.add(iterator);
CompletionIterator<Product2<K, C>, RssShuffleDataIterator<K, C>> completionIterator =
Copy link
Contributor

@jerqi jerqi Aug 22, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we use AQE, we could have many iterators, if we don't release the resource of them after we use them, we may occur OOM. I means that we use one iterator, and then release the iterator. We shouldn't release all the iterators at the end of task.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

org.apache.spark.util.CompletionIterator will clean up as soon as its wrapped RssShuffleDataIterator.hasNext returns false :

// org.apache.spark.util.CompletionIterator
  def hasNext: Boolean = {
    val r = iter.hasNext  // iter => RssShuffleDataIterator
    if (!r && !completed) {
      completed = true
      // reassign to release resources of highly resource consuming iterators early
      iter = Iterator.empty.asInstanceOf[I]
      completion()
    }
    r
  }

  def completion(): Unit  // completion() => RssShuffleDataIterator.cleanup

After this PR, if there is no special case for the Spark Task, the timing of resource cleanup is still when the RssShuffleDataIterator ends, not when the Spark Task ends.

This is the same behavior as before the PR.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

org.apache.spark.util.CompletionIterator will clean up as soon as its wrapped RssShuffleDataIterator.hasNext returns false :

// org.apache.spark.util.CompletionIterator
  def hasNext: Boolean = {
    val r = iter.hasNext  // iter => RssShuffleDataIterator
    if (!r && !completed) {
      completed = true
      // reassign to release resources of highly resource consuming iterators early
      iter = Iterator.empty.asInstanceOf[I]
      completion()
    }
    r
  }

  def completion(): Unit  // completion() => RssShuffleDataIterator.cleanup

After this PR, if there is no special case for the Spark Task, the timing of resource cleanup is still when the RssShuffleDataIterator ends, not when the Spark Task ends.

This is the same behavior as before the PR.

OK, I got it. Good catch.

CompletionIterator$.MODULE$.apply(iterator, () -> iterator.cleanup());
iterators.add(completionIterator);
}
iterator = iterators.iterator();
if (iterator.hasNext()) {
dataIterator = iterator.next();
iterator.remove();
}
context.addTaskCompletionListener((taskContext) -> {
if (dataIterator != null) {
dataIterator.completion();
}
iterator.forEachRemaining(ci -> ci.completion());
});
}

@Override
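For context, one plausible reading of the leak this PR fixes, inferred from the removed lines in RssShuffleDataIterator.hasNext(): previously shuffleReadClient.close() only ran in the branch where hasNext() found no more data, so a consumer that stops early never released the client. The consumer below is purely illustrative (hypothetical names, not from the PR).

final class EarlyTerminationIllustration {

  // Read at most n records and stop. Before this PR, the ShuffleReadClient behind the
  // iterator was only closed once hasNext() reported exhaustion, so this pattern leaked it.
  // After this PR, the CompletionIterator wrapper or the task-completion listener invokes
  // cleanup(), which closes the client even though the iterator is never drained.
  static <K, C> void readFirstN(RssShuffleDataIterator<K, C> records, int n) {
    int taken = 0;
    while (taken < n && records.hasNext()) {
      records.next();
      taken++;
    }
  }
}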