diff --git a/client-spark/common/src/main/java/org/apache/spark/shuffle/reader/RssShuffleDataIterator.java b/client-spark/common/src/main/java/org/apache/spark/shuffle/reader/RssShuffleDataIterator.java
index bd8184c5b3..23e0364181 100644
--- a/client-spark/common/src/main/java/org/apache/spark/shuffle/reader/RssShuffleDataIterator.java
+++ b/client-spark/common/src/main/java/org/apache/spark/shuffle/reader/RssShuffleDataIterator.java
@@ -34,6 +34,7 @@
 import scala.Tuple2;
 import scala.collection.AbstractIterator;
 import scala.collection.Iterator;
+import scala.runtime.BoxedUnit;
 
 import org.apache.uniffle.client.api.ShuffleReadClient;
 import org.apache.uniffle.client.response.CompressedShuffleBlock;
@@ -130,9 +131,7 @@ public boolean hasNext() {
       readTime += fetchDuration;
       serializeTime += serializationDuration;
     } else {
-      // finish reading records, close related reader and check data consistent
-      clearDeserializationStream();
-      shuffleReadClient.close();
+      // finish reading records, check data consistent
       shuffleReadClient.checkProcessedBlockIds();
       shuffleReadClient.logStatics();
       LOG.info("Fetch " + shuffleReadMetrics.remoteBytesRead() + " bytes cost " + readTime + " ms and "
@@ -150,6 +149,15 @@ public Product2<K, C> next() {
     return (Product2<K, C>) recordsIterator.next();
   }
 
+  public BoxedUnit cleanup() {
+    clearDeserializationStream();
+    if (shuffleReadClient != null) {
+      shuffleReadClient.close();
+    }
+    shuffleReadClient = null;
+    return BoxedUnit.UNIT;
+  }
+
   @VisibleForTesting
   protected ShuffleReadMetrics getShuffleReadMetrics() {
     return shuffleReadMetrics;
diff --git a/client-spark/common/src/test/java/org/apache/spark/shuffle/reader/RssShuffleDataIteratorTest.java b/client-spark/common/src/test/java/org/apache/spark/shuffle/reader/RssShuffleDataIteratorTest.java
index 50c05033cd..12f5a3183b 100644
--- a/client-spark/common/src/test/java/org/apache/spark/shuffle/reader/RssShuffleDataIteratorTest.java
+++ b/client-spark/common/src/test/java/org/apache/spark/shuffle/reader/RssShuffleDataIteratorTest.java
@@ -24,6 +24,7 @@
 import com.google.common.collect.Lists;
 import com.google.common.collect.Maps;
 
+import org.apache.uniffle.client.api.ShuffleReadClient;
 import org.apache.uniffle.client.impl.ShuffleReadClientImpl;
 import org.apache.uniffle.client.util.ClientUtils;
 import org.apache.uniffle.client.util.DefaultIdHelper;
@@ -46,6 +47,11 @@
 import org.mockito.Mockito;
 import org.roaringbitmap.longlong.Roaring64NavigableMap;
 
+import static org.mockito.Mockito.doNothing;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
+
 public class RssShuffleDataIteratorTest extends AbstractRssReaderTest {
 
   private static final Serializer KRYO_SERIALIZER = new KryoSerializer(new SparkConf(false));
@@ -235,4 +241,13 @@ public void readTest7() throws Exception {
     }
   }
 
+  @Test
+  public void cleanup() throws Exception {
+    ShuffleReadClient mockClient = mock(ShuffleReadClient.class);
+    doNothing().when(mockClient).close();
+    RssShuffleDataIterator dataIterator = new RssShuffleDataIterator(KRYO_SERIALIZER, mockClient, new ShuffleReadMetrics());
+    dataIterator.cleanup();
+    verify(mockClient, times(1)).close();
+  }
+
 }
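Note on the new cleanup() above: it is deliberately safe to call more than once. The first call closes the ShuffleReadClient and nulls the field, so any later call finds null and skips close(). A sketch of an idempotence check in the same Mockito style as the new test (illustrative only, not part of the patch):

    // cleanup() nulls shuffleReadClient after closing it, so a second
    // invocation is a no-op and close() still happens exactly once.
    ShuffleReadClient mockClient = mock(ShuffleReadClient.class);
    RssShuffleDataIterator dataIterator =
        new RssShuffleDataIterator(KRYO_SERIALIZER, mockClient, new ShuffleReadMetrics());
    dataIterator.cleanup();
    dataIterator.cleanup(); // no second close()
    verify(mockClient, times(1)).close();
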
diff --git a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java
index 3130521c9c..ef97bea3e8 100644
--- a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java
+++ b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java
@@ -26,6 +26,7 @@
 import org.apache.spark.serializer.Serializer;
 import org.apache.spark.shuffle.RssShuffleHandle;
 import org.apache.spark.shuffle.ShuffleReader;
+import org.apache.spark.util.CompletionIterator;
 import org.apache.spark.util.CompletionIterator$;
 import org.apache.spark.util.TaskCompletionListener;
 import org.apache.spark.util.collection.ExternalSorter;
@@ -113,6 +114,16 @@ public Iterator<Product2<K, C>> read() {
     RssShuffleDataIterator rssShuffleDataIterator = new RssShuffleDataIterator<K, C>(
         shuffleDependency.serializer(), shuffleReadClient, context.taskMetrics().shuffleReadMetrics());
+    CompletionIterator completionIterator =
+        CompletionIterator$.MODULE$.apply(rssShuffleDataIterator, new AbstractFunction0<BoxedUnit>() {
+          @Override
+          public BoxedUnit apply() {
+            return rssShuffleDataIterator.cleanup();
+          }
+        });
+    context.addTaskCompletionListener(context -> {
+      completionIterator.completion();
+    });
 
     Iterator<Product2<K, C>> resultIter = null;
     Iterator<Product2<K, C>> aggregatedIter = null;
@@ -120,16 +131,15 @@
     if (shuffleDependency.aggregator().isDefined()) {
       if (shuffleDependency.mapSideCombine()) {
         // We are reading values that are already combined
-        aggregatedIter = shuffleDependency.aggregator().get().combineCombinersByKey(
-            rssShuffleDataIterator, context);
+        aggregatedIter = shuffleDependency.aggregator().get().combineCombinersByKey(completionIterator, context);
       } else {
         // We don't know the value type, but also don't care -- the dependency *should*
         // have made sure its compatible w/ this aggregator, which will convert the value
         // type to the combined type C
-        aggregatedIter = shuffleDependency.aggregator().get().combineValuesByKey(rssShuffleDataIterator, context);
+        aggregatedIter = shuffleDependency.aggregator().get().combineValuesByKey(completionIterator, context);
       }
     } else {
-      aggregatedIter = rssShuffleDataIterator;
+      aggregatedIter = completionIterator;
     }
 
     if (shuffleDependency.keyOrdering().isDefined()) {
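Note on the spark2 wiring above: cleanup now fires on two paths. CompletionIterator invokes the callback when the wrapped iterator is fully drained, and the task-completion listener calls completion() explicitly, covering tasks that finish without draining the iterator; the null guard in cleanup() keeps the overlap between the two paths harmless. A hypothetical consumer showing the leak the listener path closes (reader, done() and process() are stand-in names):

    // With cleanup only on the drain path (the old close() inside hasNext()),
    // an early-exiting consumer would never release the ShuffleReadClient.
    Iterator<Product2<K, C>> records = reader.read();  // 'reader' is illustrative
    while (records.hasNext() && !done()) {             // may stop before exhaustion
      process(records.next());
    }
    // task ends -> listener -> completionIterator.completion()
    //           -> RssShuffleDataIterator.cleanup() -> ShuffleReadClient.close()
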
diff --git a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java
index 6d6025aa0e..a565cfe44f 100644
--- a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java
+++ b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java
@@ -30,6 +30,7 @@
 import org.apache.spark.serializer.Serializer;
 import org.apache.spark.shuffle.RssShuffleHandle;
 import org.apache.spark.shuffle.ShuffleReader;
+import org.apache.spark.util.CompletionIterator;
 import org.apache.spark.util.CompletionIterator$;
 import org.apache.spark.util.collection.ExternalSorter;
 import org.roaringbitmap.longlong.Roaring64NavigableMap;
@@ -183,11 +184,11 @@ public Configuration getHadoopConf() {
   }
 
   class MultiPartitionIterator<K, C> extends AbstractIterator<Product2<K, C>> {
-    java.util.Iterator<RssShuffleDataIterator> iterator;
-    RssShuffleDataIterator dataIterator;
+    java.util.Iterator<CompletionIterator<Product2<K, C>, RssShuffleDataIterator<K, C>>> iterator;
+    CompletionIterator<Product2<K, C>, RssShuffleDataIterator<K, C>> dataIterator;
 
     MultiPartitionIterator() {
-      List<RssShuffleDataIterator> iterators = Lists.newArrayList();
+      List<CompletionIterator<Product2<K, C>, RssShuffleDataIterator<K, C>>> iterators = Lists.newArrayList();
       for (int partition = startPartition; partition < endPartition; partition++) {
         if (partitionToExpectBlocks.get(partition).isEmpty()) {
           LOG.info("{} partition is empty partition", partition);
@@ -201,13 +202,21 @@ class MultiPartitionIterator<K, C> extends AbstractIterator<Product2<K, C>> {
         RssShuffleDataIterator<K, C> iterator = new RssShuffleDataIterator<>(
             shuffleDependency.serializer(), shuffleReadClient, readMetrics);
-        iterators.add(iterator);
+        CompletionIterator<Product2<K, C>, RssShuffleDataIterator<K, C>> completionIterator =
+            CompletionIterator$.MODULE$.apply(iterator, () -> iterator.cleanup());
+        iterators.add(completionIterator);
       }
       iterator = iterators.iterator();
       if (iterator.hasNext()) {
         dataIterator = iterator.next();
         iterator.remove();
       }
+      context.addTaskCompletionListener((taskContext) -> {
+        if (dataIterator != null) {
+          dataIterator.completion();
+        }
+        iterator.forEachRemaining(ci -> ci.completion());
+      });
     }
 
     @Override
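Note on the spark3 changes above: each per-partition RssShuffleDataIterator is wrapped in its own CompletionIterator, and the task-completion listener sweeps both the iterator currently being consumed and every queued one that was never started, so each underlying ShuffleReadClient is closed even for partitions the task never reached. The lambda form works here because CompletionIterator$.MODULE$.apply takes the callback as a scala.Function0 (a by-name parameter on the Scala side), which a Java lambda can target under Scala 2.12+; the spark2 module presumably spells out the anonymous AbstractFunction0 for compatibility with older Scala versions, where Function0 is not SAM-convertible from Java. Both forms are equivalent:

    // Lambda form (spark3) and explicit form (spark2) build the same kind of
    // scala.Function0<BoxedUnit> completion callback.
    CompletionIterator$.MODULE$.apply(iterator, () -> iterator.cleanup());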