[BUG] Fix incorrect spark metrics (#324)
### What changes were proposed in this pull request?
Fix incorrect Spark shuffle metrics.

### Why are the changes needed?
1. The shuffle-read record count and the shuffle-write record count are not consistent in our internal cluster.
2. The log does not show the correct fetch bytes; it always reports 0, for example:

`22/11/15 13:54:53 INFO RssShuffleDataIterator: Fetch 0 bytes cost 30791 ms and 53 ms to serialize, 347 ms to decompress with unCompressionLength[274815736]`
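
Taken together, the changes below do two things: the iterator logs byte counts it tracks itself (`compressedBytesLength` / `unCompressedBytesLength`) instead of reading the value back from the Spark metrics object, and shuffle reads are reported through a per-reader `TempShuffleReadMetrics` that is merged into the task-level `ShuffleReadMetrics` when reading completes, which keeps the shuffle-read record count consistent with the shuffle-write count. A minimal sketch of that reporting pattern (condensed from the patch, assuming a `TaskContext` named `context`; not the exact code):

```java
// Condensed sketch, not the exact patch: report reads through a per-reader
// TempShuffleReadMetrics and fold it into the task-level metrics on completion.
TempShuffleReadMetrics tempMetrics = context.taskMetrics().createTempShuffleReadMetrics();

// ... inside the fetch loop ...
long compressedDataLength = compressedData.limit() - compressedData.position();
compressedBytesLength += compressedDataLength;         // local counter, only used for the log line
tempMetrics.incRemoteBytesRead(compressedDataLength);  // Spark metric, merged at completion

// ... once all records have been read ...
context.taskMetrics().mergeShuffleReadMetrics();
LOG.info("Fetch " + compressedBytesLength + " bytes ...");
```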

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
1. UTs
2. Online Spark 3 job tests
zuston committed Nov 18, 2022
1 parent 2df7594 commit 79d2f54
Showing 4 changed files with 130 additions and 7 deletions.
@@ -57,7 +57,8 @@ public class RssShuffleDataIterator<K, C> extends AbstractIterator<Product2<K, C
private Input deserializationInput = null;
private DeserializationStream deserializationStream = null;
private ByteBufInputStream byteBufInputStream = null;
private long unCompressionLength = 0;
private long compressedBytesLength = 0;
private long unCompressedBytesLength = 0;
private ByteBuffer uncompressedData;
private Codec codec;

@@ -120,7 +121,9 @@ public boolean hasNext() {
long fetchDuration = System.currentTimeMillis() - startFetch;
shuffleReadMetrics.incFetchWaitTime(fetchDuration);
if (compressedData != null) {
shuffleReadMetrics.incRemoteBytesRead(compressedData.limit() - compressedData.position());
long compressedDataLength = compressedData.limit() - compressedData.position();
compressedBytesLength += compressedDataLength;
shuffleReadMetrics.incRemoteBytesRead(compressedDataLength);

int uncompressedLen = compressedBlock.getUncompressLength();
if (uncompressedData == null || uncompressedData.capacity() < uncompressedLen) {
@@ -129,7 +132,7 @@ public boolean hasNext() {
uncompressedData.clear();
long startDecompress = System.currentTimeMillis();
codec.decompress(compressedData, uncompressedLen, uncompressedData, 0);
unCompressionLength += compressedBlock.getUncompressLength();
unCompressedBytesLength += compressedBlock.getUncompressLength();
long decompressDuration = System.currentTimeMillis() - startDecompress;
decompressTime += decompressDuration;
// create new iterator for shuffle data
@@ -142,9 +145,9 @@ public boolean hasNext() {
// finish reading records, check data consistent
shuffleReadClient.checkProcessedBlockIds();
shuffleReadClient.logStatics();
LOG.info("Fetch " + shuffleReadMetrics.remoteBytesRead() + " bytes cost " + readTime + " ms and "
LOG.info("Fetch " + compressedBytesLength + " bytes cost " + readTime + " ms and "
+ serializeTime + " ms to serialize, " + decompressTime + " ms to decompress with unCompressionLength["
+ unCompressionLength + "]");
+ unCompressedBytesLength + "]");
return false;
}
}
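
Note that the renamed counters (`compressedBytesLength`, `unCompressedBytesLength`) are plain fields local to the iterator and only feed the log line above; the Spark-facing metric is still reported through `shuffleReadMetrics.incRemoteBytesRead(...)`. Logging the local counter is what keeps the "Fetch ... bytes" message correct even when the metrics object handed in does not expose the accumulated value.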
@@ -23,6 +23,8 @@
import org.apache.spark.InterruptibleIterator;
import org.apache.spark.ShuffleDependency;
import org.apache.spark.TaskContext;
import org.apache.spark.executor.ShuffleReadMetrics;
import org.apache.spark.executor.TempShuffleReadMetrics;
import org.apache.spark.serializer.Serializer;
import org.apache.spark.shuffle.RssShuffleHandle;
import org.apache.spark.shuffle.ShuffleReader;
@@ -117,11 +119,12 @@ public Iterator<Product2<K, C>> read() {
ShuffleReadClient shuffleReadClient = ShuffleClientFactory.getInstance().createShuffleReadClient(request);
RssShuffleDataIterator rssShuffleDataIterator = new RssShuffleDataIterator<K, C>(
shuffleDependency.serializer(), shuffleReadClient,
context.taskMetrics().shuffleReadMetrics(), rssConf);
new ReadMetrics(context.taskMetrics().createTempShuffleReadMetrics()), rssConf);
CompletionIterator completionIterator =
CompletionIterator$.MODULE$.apply(rssShuffleDataIterator, new AbstractFunction0<BoxedUnit>() {
@Override
public BoxedUnit apply() {
context.taskMetrics().mergeShuffleReadMetrics();
return rssShuffleDataIterator.cleanup();
}
});
@@ -194,4 +197,27 @@ private String getReadInfo() {
public Configuration getHadoopConf() {
return hadoopConf;
}

static class ReadMetrics extends ShuffleReadMetrics {
private TempShuffleReadMetrics tempShuffleReadMetrics;

ReadMetrics(TempShuffleReadMetrics tempShuffleReadMetric) {
this.tempShuffleReadMetrics = tempShuffleReadMetric;
}

@Override
public void incRemoteBytesRead(long v) {
tempShuffleReadMetrics.incRemoteBytesRead(v);
}

@Override
public void incFetchWaitTime(long v) {
tempShuffleReadMetrics.incFetchWaitTime(v);
}

@Override
public void incRecordsRead(long v) {
tempShuffleReadMetrics.incRecordsRead(v);
}
}
}
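
A note on the `ReadMetrics` wrapper above: it extends `ShuffleReadMetrics` only so it can be passed to the existing `RssShuffleDataIterator` constructor, and it overrides just the three increment methods, forwarding them to the `TempShuffleReadMetrics` created via `createTempShuffleReadMetrics()`. The read-side getters are not overridden, which is presumably why the iterator now logs its own locally tracked byte count. The `mergeShuffleReadMetrics()` call added to the completion callback is what folds the temp values back into the task-level `ShuffleReadMetrics`, making the stage's shuffle-read records consistent with the writer's shuffle-write records.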
@@ -212,7 +212,10 @@ class MultiPartitionIterator<K, C> extends AbstractIterator<Product2<K, C>> {
shuffleDependency.serializer(), shuffleReadClient,
readMetrics, rssConf);
CompletionIterator<Product2<K, C>, RssShuffleDataIterator<K, C>> completionIterator =
CompletionIterator$.MODULE$.apply(iterator, () -> iterator.cleanup());
CompletionIterator$.MODULE$.apply(iterator, () -> {
context.taskMetrics().mergeShuffleReadMetrics();
return iterator.cleanup();
});
iterators.add(completionIterator);
}
iterator = iterators.iterator();
@@ -0,0 +1,91 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.uniffle.test;

import java.lang.reflect.InvocationTargetException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.functions;
import org.apache.spark.status.AppStatusStore;
import org.apache.spark.status.api.v1.StageData;
import org.junit.jupiter.api.Test;
import scala.collection.Seq;

public class WriteAndReadMetricsTest extends SimpleTestBase {

@Test
public void test() throws Exception {
run();
}

@Override
public Map runTest(SparkSession spark, String fileName) throws Exception {
// take a rest to make sure shuffle server is registered
Thread.sleep(3000);

Dataset<Row> df1 = spark.range(0, 100, 1, 10)
.select(functions.when(functions.col("id").$less$eq(50), 1)
.otherwise(functions.col("id")).as("key1"), functions.col("id").as("value1"));
df1.createOrReplaceTempView("table1");

List list = spark.sql("select count(value1) from table1 group by key1").collectAsList();
Map<String, Long> result = new HashMap<>();
result.put("size", Long.valueOf(list.size()));

for (int stageId : spark.sparkContext().statusTracker().getJobInfo(0).get().stageIds()) {
long writeRecords = getFirstStageData(spark, stageId).shuffleWriteRecords();
long readRecords = getFirstStageData(spark, stageId).shuffleReadRecords();
result.put(stageId + "-write-records", writeRecords);
result.put(stageId + "-read-records", readRecords);
}

return result;
}

private StageData getFirstStageData(SparkSession spark, int stageId)
throws NoSuchMethodException, InvocationTargetException, IllegalAccessException {
AppStatusStore statestore = spark.sparkContext().statusStore();
try {
return ((Seq<StageData>)statestore
.getClass()
.getDeclaredMethod(
"stageData",
int.class,
boolean.class
).invoke(statestore, stageId, false)).toList().head();
} catch (Exception e) {
return ((Seq<StageData>)statestore
.getClass()
.getDeclaredMethod(
"stageData",
int.class,
boolean.class,
List.class,
boolean.class,
double[].class
).invoke(
statestore, stageId, false, new ArrayList<>(), true, new double[]{})).toList().head();
}
}
}
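
The test reads per-stage shuffle write/read record counts from `AppStatusStore`. The reflection in `getFirstStageData` is there because `stageData` takes additional parameters in newer Spark 3.x releases, so the two known signatures are tried in turn. As a hypothetical illustration of what the returned map enables (the stage ids and the direct equality check are assumptions for this sketch; the committed test presumably relies on the harness comparing result maps across runs with and without the remote shuffle service):

```java
// Hypothetical consumer of runTest's result map, for a two-stage group-by job:
// the map stage's write-record count should equal the reduce stage's read-record count.
// Stage ids 0 and 1 are assumptions for illustration only.
Map<String, Long> metrics = runTest(spark, "unused");
long written = metrics.getOrDefault("0-write-records", -1L);
long read = metrics.getOrDefault("1-read-records", -1L);
if (written != read) {
  throw new AssertionError("shuffle write/read record counts diverge: " + written + " vs " + read);
}
```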
