array type support. #75

Merged 4 commits on Mar 14, 2023
spark-doris-connector/pom.xml (1 change: 0 additions & 1 deletion)
@@ -64,7 +64,6 @@
<properties>
<spark.version>3.1.2</spark.version>
<spark.minor.version>3.1</spark.minor.version>
- <scala.version>2.12.8</scala.version>
<scala.version>2.12</scala.version>
<libthrift.version>0.13.0</libthrift.version>
<arrow.version>5.0.0</arrow.version>
RowBatch.java
@@ -38,6 +38,7 @@
import org.apache.arrow.vector.VarBinaryVector;
import org.apache.arrow.vector.VarCharVector;
import org.apache.arrow.vector.VectorSchemaRoot;
+ import org.apache.arrow.vector.complex.ListVector;
import org.apache.arrow.vector.ipc.ArrowStreamReader;
import org.apache.arrow.vector.types.Types;
import org.apache.doris.spark.exception.DorisException;
@@ -87,7 +88,7 @@ public RowBatch(TScanBatchResult nextResult, Schema schema) throws DorisException {
this.arrowStreamReader = new ArrowStreamReader(
new ByteArrayInputStream(nextResult.getRows()),
rootAllocator
- );
+ );
try {
VectorSchemaRoot root = arrowStreamReader.getVectorSchemaRoot();
while (arrowStreamReader.loadNextBatch()) {
@@ -275,6 +276,19 @@ public void convertArrowToRowBatch() throws DorisException {
addValueToRow(rowIndex, value);
}
break;
case "ARRAY":
Preconditions.checkArgument(mt.equals(Types.MinorType.LIST),
typeMismatchMessage(currentType, mt));
ListVector listVector = (ListVector) curFieldVector;
for (int rowIndex = 0; rowIndex < rowCountInOneBatch; rowIndex++) {
if (listVector.isNull(rowIndex)) {
addValueToRow(rowIndex, null);
continue;
}
String value = listVector.getObject(rowIndex).toString();
addValueToRow(rowIndex, value);
}
break;
default:
String errMsg = "Unsupported type " + schema.get(col).getType();
logger.error(errMsg);
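For context on the new ARRAY branch above: Arrow's ListVector.getObject(rowIndex) returns a java.util.List view of the row's elements, and the branch stores its toString() form, so an array cell reaches Spark as one flat string per row. A minimal Scala sketch of that stringification, with an ordinary java.util.List standing in for the Arrow vector (the exact spacing of Arrow's own list toString may differ):

import java.util

// Stand-in for listVector.getObject(rowIndex), which yields a java.util.List.
val cell: util.List[String] = util.Arrays.asList("a", "b", "c")

// Mirrors the new branch: null rows stay null, everything else is stringified.
val value: String = if (cell == null) null else cell.toString // "[a, b, c]"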
SchemaUtils.scala
@@ -17,25 +17,22 @@

package org.apache.doris.spark.sql

- import scala.collection.JavaConverters._
-
+ import scala.collection.JavaConversions._
import org.apache.doris.spark.cfg.Settings
import org.apache.doris.spark.exception.DorisException
import org.apache.doris.spark.rest.RestService
import org.apache.doris.spark.rest.models.{Field, Schema}
import org.apache.doris.thrift.TScanColumnDesc
import org.apache.doris.spark.cfg.ConfigurationOptions.DORIS_READ_FIELD
import org.apache.spark.sql.types._
-
import org.slf4j.LoggerFactory
-
- import scala.collection.mutable

private[spark] object SchemaUtils {
private val logger = LoggerFactory.getLogger(SchemaUtils.getClass.getSimpleName.stripSuffix("$"))

/**
* discover Doris table schema from Doris FE.
+ *
* @param cfg configuration
* @return Spark Catalyst StructType
*/
@@ -46,6 +43,7 @@ private[spark] object SchemaUtils {

/**
* discover Doris table schema from Doris FE.
+ *
* @param cfg configuration
* @return inner schema struct
*/
@@ -55,35 +53,34 @@

/**
* convert inner schema struct to Spark Catalyst StructType
+ *
* @param schema inner schema
* @return Spark Catalyst StructType
*/
def convertToStruct(dorisReadFields: String, schema: Schema): StructType = {
-     var fieldList = new Array[String](schema.size())
-     val fieldSet = new mutable.HashSet[String]()
-     var fields = List[StructField]()
-     if (dorisReadFields != null && dorisReadFields.length > 0) {
-       fieldList = dorisReadFields.split(",")
-       for (field <- fieldList) {
-         fieldSet.add(field)
-       }
-       schema.getProperties.asScala.foreach(f =>
-         if (fieldSet.contains(f.getName)) {
-           fields :+= DataTypes.createStructField(f.getName, getCatalystType(f.getType, f.getPrecision, f.getScale), true)
-         })
+     val fieldList = if (dorisReadFields != null && dorisReadFields.length > 0) {
+       dorisReadFields.split(",")
      } else {
-       schema.getProperties.asScala.foreach(f =>
-         fields :+= DataTypes.createStructField(f.getName, getCatalystType(f.getType, f.getPrecision, f.getScale), true)
-       )
+       Array.empty[String]
      }
-     DataTypes.createStructType(fields.asJava)
+     val fields = schema.getProperties
+       .filter(x => fieldList.contains(x.getName) || fieldList.isEmpty)
+       .map(f =>
+         DataTypes.createStructField(
+           f.getName,
+           getCatalystType(f.getType, f.getPrecision, f.getScale),
+           true
+         )
+       )
+     DataTypes.createStructType(fields)
}
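The rewrite above drops the mutable HashSet bookkeeping in favor of a single filter: a column is kept when it appears in doris.read.fields, or unconditionally when no field list was given. A minimal sketch of that rule, with toy column names standing in for schema.getProperties (the names are illustrative, not from the PR):

val allColumns = Seq("id", "name", "tags")

def selected(dorisReadFields: String): Seq[String] = {
  val fieldList =
    if (dorisReadFields != null && dorisReadFields.nonEmpty) dorisReadFields.split(",").toSeq
    else Seq.empty[String]
  // Same predicate as the new convertToStruct: explicit match, or keep everything when empty.
  allColumns.filter(c => fieldList.contains(c) || fieldList.isEmpty)
}

selected("id,tags") // Seq(id, tags)
selected(null)      // Seq(id, name, tags)

Note that fieldList.isEmpty is re-checked per element; hoisting it out of the filter would be equivalent, and at schema sizes the difference is negligible.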

/**
* translate Doris Type to Spark Catalyst type
+ *
* @param dorisType Doris type
* @param precision decimal precision
- * @param scale decimal scale
+ * @param scale decimal scale
* @return Spark Catalyst type
*/
def getCatalystType(dorisType: String, precision: Int, scale: Int): DataType = {
@@ -112,6 +109,7 @@ private[spark] object SchemaUtils {
case "DECIMAL128I" => DecimalType(precision, scale)
case "TIME" => DataTypes.DoubleType
case "STRING" => DataTypes.StringType
case "ARRAY" => DataTypes.StringType
case "HLL" =>
throw new DorisException("Unsupported type " + dorisType)
case _ =>
@@ -121,6 +119,7 @@

/**
* convert Doris return schema to inner schema struct.
+ *
* @param tscanColumnDescs Doris BE return schema
* @return inner schema struct
*/
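Because getCatalystType maps ARRAY to DataTypes.StringType, an array column surfaces in Spark as text such as "[1, 2, 3]" rather than a native ArrayType. A hedged usage sketch, assuming an existing SparkSession and hypothetical table, FE address, and column names (the doris format name and option keys follow the connector's documented read path):

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._

val spark = SparkSession.builder().appName("doris-array-read").getOrCreate()

val df = spark.read.format("doris")
  .option("doris.table.identifier", "example_db.example_tbl") // hypothetical table
  .option("doris.fenodes", "fe_host:8030")                    // hypothetical FE address
  .option("user", "root")
  .option("password", "")
  .load()

df.printSchema() // the ARRAY column is reported as string

// Re-parse the bracketed text when element-level access is needed ("tags" is hypothetical).
val parsed = df.withColumn("tags_arr", split(regexp_replace(col("tags"), "[\\[\\] ]", ""), ","))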