[Spark] Support convert Arrow data to RowBatch asynchronously in Spark-Doris-Connector (#3186)

Currently, in the Spark-Doris-Connector, when Spark iterates over rows it must synchronously
convert the Arrow-format data into the row format that Spark requires. To speed up this
conversion, we can add an asynchronous thread in the Connector that is responsible for fetching
the Arrow-format data from BE and converting it into the row format required for Spark computation.

In our test environment, the Doris cluster used 1 FE and 7 BEs (32C + 128G). When using Spark-Doris-Connector
to query a table containing 67 columns and returning 69 million rows of data, the original query
took about 2.5 min; after this improvement it takes about 1.6 min, cutting the time by about 30%.
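
The change amounts to a producer-consumer handoff: a background thread converts Arrow batches into row batches and places them in a bounded blocking queue, while the Spark-side iterator only drains the queue. Below is a minimal, self-contained sketch of that pattern; the `ArrowBatch`/`ConvertedBatch` types and the inline conversion are illustrative stand-ins, not the connector's actual classes.

```scala
import java.util.concurrent.{ArrayBlockingQueue, TimeUnit}
import java.util.concurrent.atomic.AtomicBoolean

object AsyncConvertSketch {
  // Illustrative stand-ins for an Arrow batch fetched from BE and its converted row form.
  final case class ArrowBatch(rows: Seq[String])
  final case class ConvertedBatch(rows: Seq[String])

  def main(args: Array[String]): Unit = {
    val queue = new ArrayBlockingQueue[ConvertedBatch](64) // plays the role of doris.deserialize.queue.size
    val eos = new AtomicBoolean(false)

    // Producer: fetch Arrow batches and convert them off the iterating thread.
    val producer = new Thread(new Runnable {
      override def run(): Unit = {
        val batches = Iterator.tabulate(5)(i => ArrowBatch(Seq(s"row-${i}a", s"row-${i}b")))
        while (batches.hasNext) {
          queue.put(ConvertedBatch(batches.next().rows)) // conversion plus a blocking put
        }
        eos.set(true)
      }
    })
    producer.start()

    // Consumer: the Spark-side iterator only drains already-converted batches.
    while (!eos.get || !queue.isEmpty) {
      val batch = queue.poll(5, TimeUnit.MILLISECONDS) // short wait, like the polling in hasNext
      if (batch != null) batch.rows.foreach(println)
    }
    producer.join()
  }
}
```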
Youngwb committed Mar 26, 2020
1 parent d3eebd4 commit cd46034a84ef5e6ef9e359a7d3afb8f1b2535854
Showing 5 changed files with 129 additions and 53 deletions.
@@ -102,6 +102,8 @@ dorisSparkRDD.collect()
| doris.request.tablet.size | Integer.MAX_VALUE | The number of Doris Tablets corresponding to one RDD Partition.<br />The smaller this value, the more Partitions are generated,<br />which increases parallelism on the Spark side but also puts more pressure on Doris. |
| doris.batch.size | 1024 | The maximum number of rows read from BE at a time.<br />Increasing this value reduces the number of connections established between Spark and Doris,<br />thereby reducing the extra overhead caused by network latency. |
| doris.exec.mem.limit | 2147483648 | Memory limit for a single query, in bytes. Defaults to 2GB. |
| doris.deserialize.arrow.async | false | Whether to asynchronously convert Arrow data into the RowBatch that spark-doris-connector iterates over. |
| doris.deserialize.queue.size | 64 | Size of the internal queue used for asynchronous Arrow conversion; takes effect when doris.deserialize.arrow.async is true. |
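
As an illustration of how the two new options might be enabled from the DataFrame API (the `doris` format name, FE address, and table identifier below are assumed placeholders, not part of this change):

```scala
// Hypothetical reader configuration; only the two doris.deserialize.* options come from this change.
val df = spark.read
  .format("doris")                                  // assumed short name of the connector's data source
  .option("doris.fenodes", "FE_HOST:8030")          // placeholder FE address
  .option("doris.table.identifier", "db.table")     // placeholder table
  .option("doris.deserialize.arrow.async", "true")  // turn on asynchronous Arrow-to-RowBatch conversion
  .option("doris.deserialize.queue.size", "64")     // bounded queue between the async thread and the iterator
  .load()
```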

### SQL and Dataframe Only

@@ -57,4 +57,10 @@ public interface ConfigurationOptions {
long DORIS_EXEC_MEM_LIMIT_DEFAULT = 2147483648L;

String DORIS_VALUE_READER_CLASS = "doris.value.reader.class";

String DORIS_DESERIALIZE_ARROW_ASYNC = "doris.deserialize.arrow.async";
boolean DORIS_DESERIALIZE_ARROW_ASYNC_DEFAULT = false;

String DORIS_DESERIALIZE_QUEUE_SIZE = "doris.deserialize.queue.size";
int DORIS_DESERIALIZE_QUEUE_SIZE_DEFAULT = 64;
}
@@ -70,7 +70,8 @@ public void put(Object o) {
}
}

private int offsetInOneBatch = 0;
// offset for iterating over the rowBatch
private int offsetInRowBatch = 0;
private int rowCountInOneBatch = 0;
private int readRowCount = 0;
private List<Row> rowBatch = new ArrayList<>();
@@ -87,50 +88,40 @@ public RowBatch(TScanBatchResult nextResult, Schema schema) throws DorisExceptio
new ByteArrayInputStream(nextResult.getRows()),
rootAllocator
);
this.offsetInRowBatch = 0;
try {
this.root = arrowStreamReader.getVectorSchemaRoot();
while (arrowStreamReader.loadNextBatch()) {
fieldVectors = root.getFieldVectors();
if (fieldVectors.size() != schema.size()) {
logger.error("Schema size '{}' is not equal to arrow field size '{}'.",
fieldVectors.size(), schema.size());
throw new DorisException("Load Doris data failed, schema size of fetch data is wrong.");
}
if (fieldVectors.size() == 0 || root.getRowCount() == 0) {
logger.debug("One batch in arrow has no data.");
continue;
}
rowCountInOneBatch = root.getRowCount();
// init the rowBatch
for (int i = 0; i < rowCountInOneBatch; ++i) {
rowBatch.add(new Row(fieldVectors.size()));
}
convertArrowToRowBatch();
readRowCount += root.getRowCount();
}
} catch (Exception e) {
logger.error("Read Doris Data failed because: ", e);
close();
throw new DorisException(e.getMessage());
} finally {
close();
}
}

public boolean hasNext() throws DorisException {
if (offsetInOneBatch < rowCountInOneBatch) {
public boolean hasNext() {
if (offsetInRowBatch < readRowCount) {
return true;
}
try {
try {
while (arrowStreamReader.loadNextBatch()) {
fieldVectors = root.getFieldVectors();
readRowCount += root.getRowCount();
if (fieldVectors.size() != schema.size()) {
logger.error("Schema size '{}' is not equal to arrow field size '{}'.",
fieldVectors.size(), schema.size());
throw new DorisException("Load Doris data failed, schema size of fetch data is wrong.");
}
if (fieldVectors.size() == 0 || root.getRowCount() == 0) {
logger.debug("One batch in arrow has no data.");
continue;
}
offsetInOneBatch = 0;
rowCountInOneBatch = root.getRowCount();
// init the rowBatch
for (int i = 0; i < rowCountInOneBatch; ++i) {
rowBatch.add(new Row(fieldVectors.size()));
}
convertArrowToRowBatch();
return true;
}
} catch (IOException e) {
logger.error("Load arrow next batch failed.", e);
throw new DorisException("Cannot load arrow next batch fetching from Doris.");
}
} catch (Exception e) {
close();
throw e;
}
return false;
}

@@ -141,7 +132,7 @@ private void addValueToRow(int rowIndex, Object obj) {
logger.error(errMsg);
throw new NoSuchElementException(errMsg);
}
rowBatch.get(rowIndex).put(obj);
rowBatch.get(readRowCount + rowIndex).put(obj);
}

public void convertArrowToRowBatch() throws DorisException {
@@ -295,11 +286,11 @@ public void convertArrowToRowBatch() throws DorisException {

public List<Object> next() throws DorisException {
if (!hasNext()) {
String errMsg = "Get row offset:" + offsetInOneBatch + " larger than row size: " + rowCountInOneBatch;
String errMsg = "Get row offset:" + offsetInRowBatch + " larger than row size: " + readRowCount;
logger.error(errMsg);
throw new NoSuchElementException(errMsg);
}
return rowBatch.get(offsetInOneBatch++).getCols();
return rowBatch.get(offsetInRowBatch++).getCols();
}

private String typeMismatchMessage(final String sparkType, final Types.MinorType arrowType) {
@@ -19,6 +19,7 @@

public abstract class ErrorMessages {
public static final String PARSE_NUMBER_FAILED_MESSAGE = "Parse '{}' to number failed. Original string is '{}'.";
public static final String PARSE_BOOL_FAILED_MESSAGE = "Parse '{}' to boolean failed. Original string is '{}'.";
public static final String CONNECT_FAILED_MESSAGE = "Connect to doris {} failed.";
public static final String ILLEGAL_ARGUMENT_MESSAGE = "argument '{}' is illegal, value is '{}'.";
public static final String SHOULD_NOT_HAPPEN_MESSAGE = "Should not come here.";
@@ -17,9 +17,11 @@

package org.apache.doris.spark.rdd

import java.util.concurrent.atomic.AtomicBoolean
import java.util.concurrent._

import scala.collection.JavaConversions._
import scala.util.Try

import org.apache.doris.spark.backend.BackendClient
import org.apache.doris.spark.cfg.ConfigurationOptions._
import org.apache.doris.spark.cfg.Settings
@@ -31,9 +33,10 @@ import org.apache.doris.spark.sql.SchemaUtils
import org.apache.doris.spark.util.ErrorMessages
import org.apache.doris.spark.util.ErrorMessages.SHOULD_NOT_HAPPEN_MESSAGE
import org.apache.doris.thrift.{TScanCloseParams, TScanNextBatchParams, TScanOpenParams, TScanOpenResult}

import org.apache.log4j.Logger

import scala.util.control.Breaks

/**
* read data from Doris BE to array.
* @param partition Doris RDD partition
@@ -44,8 +47,30 @@ class ScalaValueReader(partition: PartitionDefinition, settings: Settings) {

protected val client = new BackendClient(new Routing(partition.getBeAddress), settings)
protected var offset = 0
protected var eos: Boolean = false
protected var eos: AtomicBoolean = new AtomicBoolean(false)
protected var rowBatch: RowBatch = _
// flag indicating whether Arrow data is deserialized to RowBatch asynchronously
protected var deserializeArrowToRowBatchAsync: Boolean = Try {
settings.getProperty(DORIS_DESERIALIZE_ARROW_ASYNC, DORIS_DESERIALIZE_ARROW_ASYNC_DEFAULT.toString).toBoolean
} getOrElse {
logger.warn(ErrorMessages.PARSE_BOOL_FAILED_MESSAGE, DORIS_DESERIALIZE_ARROW_ASYNC, settings.getProperty(DORIS_DESERIALIZE_ARROW_ASYNC))
DORIS_DESERIALIZE_ARROW_ASYNC_DEFAULT
}

protected var rowBatchBlockingQueue: BlockingQueue[RowBatch] = {
val blockingQueueSize = Try {
settings.getProperty(DORIS_DESERIALIZE_QUEUE_SIZE, DORIS_DESERIALIZE_QUEUE_SIZE_DEFAULT.toString).toInt
} getOrElse {
logger.warn(ErrorMessages.PARSE_NUMBER_FAILED_MESSAGE, DORIS_DESERIALIZE_QUEUE_SIZE, settings.getProperty(DORIS_DESERIALIZE_QUEUE_SIZE))
DORIS_DESERIALIZE_QUEUE_SIZE_DEFAULT
}

var queue: BlockingQueue[RowBatch] = null
if (deserializeArrowToRowBatchAsync) {
queue = new ArrayBlockingQueue(blockingQueueSize)
}
queue
}

private val openParams: TScanOpenParams = {
val params = new TScanOpenParams
@@ -103,28 +128,79 @@ class ScalaValueReader(partition: PartitionDefinition, settings: Settings) {
protected val schema: Schema =
SchemaUtils.convertToSchema(openResult.getSelected_columns)

protected val asyncThread: Thread = new Thread {
override def run {
val nextBatchParams = new TScanNextBatchParams
nextBatchParams.setContext_id(contextId)
while (!eos.get) {
nextBatchParams.setOffset(offset)
val nextResult = client.getNext(nextBatchParams)
eos.set(nextResult.isEos)
if (!eos.get) {
val rowBatch = new RowBatch(nextResult, schema)
offset += rowBatch.getReadRowCount
rowBatch.close
rowBatchBlockingQueue.put(rowBatch)
}
}
}
}

protected val asyncThreadStarted: Boolean = {
var started = false
if (deserializeArrowToRowBatchAsync) {
asyncThread.start
started = true
}
started
}

logger.debug(s"Open scan result is, contextId: $contextId, schema: $schema.")

/**
* read data and cache it in rowBatch.
* @return true if there is a next value
*/
def hasNext: Boolean = {
if (!eos && (rowBatch == null || !rowBatch.hasNext)) {
if (rowBatch != null) {
offset += rowBatch.getReadRowCount
rowBatch.close
var hasNext = false
if (deserializeArrowToRowBatchAsync && asyncThreadStarted) {
// asynchronous path: consume RowBatches already converted by asyncThread
if (rowBatch == null || !rowBatch.hasNext) {
val loop = new Breaks
loop.breakable {
while (!eos.get || !rowBatchBlockingQueue.isEmpty) {
if (!rowBatchBlockingQueue.isEmpty) {
rowBatch = rowBatchBlockingQueue.take
hasNext = true
loop.break
} else {
// wait for a RowBatch to be put into the queue or for eos to change
Thread.sleep(5)
}
}
}
} else {
hasNext = true
}
val nextBatchParams = new TScanNextBatchParams
nextBatchParams.setContext_id(contextId)
nextBatchParams.setOffset(offset)
val nextResult = client.getNext(nextBatchParams)
eos = nextResult.isEos
if (!eos) {
rowBatch = new RowBatch(nextResult, schema)
} else {
// synchronous path: fetch and convert Arrow data during iteration
if (!eos.get && (rowBatch == null || !rowBatch.hasNext)) {
if (rowBatch != null) {
offset += rowBatch.getReadRowCount
rowBatch.close
}
val nextBatchParams = new TScanNextBatchParams
nextBatchParams.setContext_id(contextId)
nextBatchParams.setOffset(offset)
val nextResult = client.getNext(nextBatchParams)
eos.set(nextResult.isEos)
if (!eos.get) {
rowBatch = new RowBatch(nextResult, schema)
}
}
hasNext = !eos.get
}
!eos
hasNext
}

/**
