[SPARK-32492][SQL] Fulfill missing column meta information COLUMN_SIZE/DECIMAL_DIGITS/NUM_PREC_RADIX/ORDINAL_POSITION for thriftserver client tools

### What changes were proposed in this pull request?

This PR fills in the missing column metadata fields in `SparkGetColumnsOperation`, including COLUMN_SIZE, DECIMAL_DIGITS, NUM_PREC_RADIX and ORDINAL_POSITION, and improves the test coverage.

### Why are the changes needed?

Make JDBC client tools happier: they can now show complete column metadata such as column size, decimal digits, radix and ordinal position.
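
For example, a JDBC client talking to the Thrift server can now read these fields through the standard `DatabaseMetaData.getColumns` call. A minimal sketch, assuming the Hive JDBC driver is on the classpath, a local server at `jdbc:hive2://localhost:10000`, and a table like the `spark_get_col_operation` table used in the new test:

```scala
import java.sql.DriverManager

// Hypothetical connection details; adjust URL/user to your deployment.
val conn = DriverManager.getConnection("jdbc:hive2://localhost:10000", "user", "")
try {
  val rs = conn.getMetaData.getColumns(null, "default", "spark_get_col_operation", "%")
  while (rs.next()) {
    // These four fields were reported as null by the Thrift server before this change.
    println(Seq(
      rs.getString("COLUMN_NAME"),
      rs.getInt("COLUMN_SIZE"),
      rs.getInt("DECIMAL_DIGITS"),
      rs.getInt("NUM_PREC_RADIX"),
      rs.getInt("ORDINAL_POSITION")
    ).mkString("\t"))
  }
} finally {
  conn.close()
}
```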

### Does this PR introduce _any_ user-facing change?

Yes.

#### before
![image](https://user-images.githubusercontent.com/8326978/88911764-e78b2180-d290-11ea-8abb-96f137f9c3c4.png)

#### after

![image](https://user-images.githubusercontent.com/8326978/88911709-d04c3400-d290-11ea-90ab-02bda3e628e9.png)

![image](https://user-images.githubusercontent.com/8326978/88912007-39cc4280-d291-11ea-96d6-1ef3abbbddec.png)

### How was this patch tested?

Added unit tests.

Closes #29303 from yaooqinn/SPARK-32492.

Authored-by: Kent Yao <yaooqinn@hotmail.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
yaooqinn authored and cloud-fan committed Aug 3, 2020
1 parent 3deb59d commit 7f5326c
Showing 2 changed files with 178 additions and 8 deletions.
@@ -17,12 +17,10 @@

package org.apache.spark.sql.hive.thriftserver

import java.util.UUID
import java.util.regex.Pattern

import scala.collection.JavaConverters.seqAsJavaListConverter

import org.apache.commons.lang3.exception.ExceptionUtils
import org.apache.hadoop.hive.ql.security.authorization.plugin.{HiveOperationType, HivePrivilegeObject}
import org.apache.hadoop.hive.ql.security.authorization.plugin.HivePrivilegeObject.HivePrivilegeObjectType
import org.apache.hive.service.cli._
@@ -34,7 +32,7 @@ import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.catalog.SessionCatalog
import org.apache.spark.sql.hive.thriftserver.ThriftserverShimUtils.toJavaSQLType
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.types._

/**
* Spark's own SparkGetColumnsOperation
@@ -126,12 +124,52 @@ private[hive] class SparkGetColumnsOperation(
HiveThriftServer2.eventManager.onStatementFinish(statementId)
}

/**
* For boolean, numeric and datetime types, it returns the default size of its catalyst type.
* For struct types, when all elements are fixed-size, the sum of the element sizes is
* returned; for example, struct<X: BIGINT, Y: DOUBLE> yields 8 + 8 = 16.
* For array, map, string and binary types, the column size is variable, so None is returned
* to mark it as unknown.
*/
private def getColumnSize(typ: DataType): Option[Int] = typ match {
case dt @ (BooleanType | _: NumericType | DateType | TimestampType) => Some(dt.defaultSize)
case StructType(fields) =>
val sizeArr = fields.map(f => getColumnSize(f.dataType))
if (sizeArr.contains(None)) {
None
} else {
Some(sizeArr.map(_.get).sum)
}
case other => None
}

/**
* The number of fractional digits for this type, or None for data types where this is not
* applicable.
* For boolean and integral types, the number of decimal digits is 0.
* For floating-point types, we follow the IEEE Standard for Floating-Point Arithmetic
* (IEEE 754): 7 significant decimal digits for float and 15 for double.
* For timestamps, microsecond precision is supported, so 6 is returned.
* For decimals, the scale is returned.
*/
private def getDecimalDigits(typ: DataType) = typ match {
case BooleanType | _: IntegerType => Some(0)
case FloatType => Some(7)
case DoubleType => Some(15)
case d: DecimalType => Some(d.scale)
case TimestampType => Some(6)
case _ => None
}

private def getNumPrecRadix(typ: DataType): Option[Int] = typ match {
case _: NumericType => Some(10)
case _ => None
}

private def addToRowSet(
columnPattern: Pattern,
dbName: String,
tableName: String,
schema: StructType): Unit = {
schema.foreach { column =>
schema.zipWithIndex.foreach { case (column, pos) =>
if (columnPattern != null && !columnPattern.matcher(column.name).matches()) {
} else {
val rowData = Array[AnyRef](
@@ -141,17 +179,17 @@ private[hive] class SparkGetColumnsOperation(
column.name, // COLUMN_NAME
toJavaSQLType(column.dataType.sql).asInstanceOf[AnyRef], // DATA_TYPE
column.dataType.sql, // TYPE_NAME
null, // COLUMN_SIZE
getColumnSize(column.dataType).map(_.asInstanceOf[AnyRef]).orNull, // COLUMN_SIZE
null, // BUFFER_LENGTH, unused
null, // DECIMAL_DIGITS
null, // NUM_PREC_RADIX
getDecimalDigits(column.dataType).map(_.asInstanceOf[AnyRef]).orNull, // DECIMAL_DIGITS
getNumPrecRadix(column.dataType).map(_.asInstanceOf[AnyRef]).orNull, // NUM_PREC_RADIX
(if (column.nullable) 1 else 0).asInstanceOf[AnyRef], // NULLABLE
column.getComment().getOrElse(""), // REMARKS
null, // COLUMN_DEF
null, // SQL_DATA_TYPE
null, // SQL_DATETIME_SUB
null, // CHAR_OCTET_LENGTH
null, // ORDINAL_POSITION
pos.asInstanceOf[AnyRef], // ORDINAL_POSITION
"YES", // IS_NULLABLE
null, // SCOPE_CATALOG
null, // SCOPE_SCHEMA
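As a quick, standalone illustration of the sizing rules documented above (not part of the commit), the reported COLUMN_SIZE and DECIMAL_DIGITS values come straight from Catalyst's type metadata:

```scala
import org.apache.spark.sql.types._

// Fixed-size types report their Catalyst default size as COLUMN_SIZE.
assert(IntegerType.defaultSize == 4)
assert(DoubleType.defaultSize == 8)
assert(TimestampType.defaultSize == 8)

// struct<X: BIGINT, Y: DOUBLE> sums its fixed-size fields: 8 + 8 = 16.
val struct = new StructType().add("X", LongType).add("Y", DoubleType)
assert(struct.fields.map(_.dataType.defaultSize).sum == 16)

// DECIMAL_DIGITS for a decimal column is its scale.
assert(DecimalType(10, 2).scale == 2)
```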
@@ -19,9 +19,12 @@ package org.apache.spark.sql.hive.thriftserver

import java.sql.SQLException

import scala.collection.JavaConverters._

import org.apache.hive.service.cli.HiveSQLException

import org.apache.spark.sql.hive.HiveUtils
import org.apache.spark.sql.types._

trait ThriftServerWithSparkContextSuite extends SharedThriftServer {

@@ -101,6 +104,135 @@ trait ThriftServerWithSparkContextSuite extends SharedThriftServer {
}
}
}

test("check results from get columns operation from thrift server") {
val schemaName = "default"
val tableName = "spark_get_col_operation"
val schema = new StructType()
.add("c0", "boolean", nullable = false, "0")
.add("c1", "tinyint", nullable = true, "1")
.add("c2", "smallint", nullable = false, "2")
.add("c3", "int", nullable = true, "3")
.add("c4", "long", nullable = false, "4")
.add("c5", "float", nullable = true, "5")
.add("c6", "double", nullable = false, "6")
.add("c7", "decimal(38, 20)", nullable = true, "7")
.add("c8", "decimal(10, 2)", nullable = false, "8")
.add("c9", "string", nullable = true, "9")
.add("c10", "array<long>", nullable = false, "10")
.add("c11", "array<string>", nullable = true, "11")
.add("c12", "map<smallint, tinyint>", nullable = false, "12")
.add("c13", "date", nullable = true, "13")
.add("c14", "timestamp", nullable = false, "14")
.add("c15", "struct<X: bigint,Y: double>", nullable = true, "15")
.add("c16", "binary", nullable = false, "16")

val ddl =
s"""
|CREATE TABLE $schemaName.$tableName (
| ${schema.toDDL}
|)
|using parquet""".stripMargin

withCLIServiceClient { client =>
val sessionHandle = client.openSession(user, "")
val confOverlay = new java.util.HashMap[java.lang.String, java.lang.String]
val opHandle = client.executeStatement(sessionHandle, ddl, confOverlay)
var status = client.getOperationStatus(opHandle)
while (!status.getState.isTerminal) {
Thread.sleep(10)
status = client.getOperationStatus(opHandle)
}
val getCol = client.getColumns(sessionHandle, "", schemaName, tableName, null)
val rowSet = client.fetchResults(getCol)
val columns = rowSet.toTRowSet.getColumns

val catalogs = columns.get(0).getStringVal.getValues.asScala
assert(catalogs.forall(_.isEmpty), "catalog name mismatches")

val schemas = columns.get(1).getStringVal.getValues.asScala
assert(schemas.forall(_ == schemaName), "schema name mismatches")

val tableNames = columns.get(2).getStringVal.getValues.asScala
assert(tableNames.forall(_ == tableName), "table name mismatches")

val columnNames = columns.get(3).getStringVal.getValues.asScala
columnNames.zipWithIndex.foreach {
case (v, i) => assert(v === "c" + i, "column name mismatches")
}

val javaTypes = columns.get(4).getI32Val.getValues
import java.sql.Types._
assert(javaTypes.get(0).intValue() === BOOLEAN)
assert(javaTypes.get(1).intValue() === TINYINT)
assert(javaTypes.get(2).intValue() === SMALLINT)
assert(javaTypes.get(3).intValue() === INTEGER)
assert(javaTypes.get(4).intValue() === BIGINT)
assert(javaTypes.get(5).intValue() === FLOAT)
assert(javaTypes.get(6).intValue() === DOUBLE)
assert(javaTypes.get(7).intValue() === DECIMAL)
assert(javaTypes.get(8).intValue() === DECIMAL)
assert(javaTypes.get(9).intValue() === VARCHAR)
assert(javaTypes.get(10).intValue() === ARRAY)
assert(javaTypes.get(11).intValue() === ARRAY)
assert(javaTypes.get(12).intValue() === JAVA_OBJECT)
assert(javaTypes.get(13).intValue() === DATE)
assert(javaTypes.get(14).intValue() === TIMESTAMP)
assert(javaTypes.get(15).intValue() === STRUCT)
assert(javaTypes.get(16).intValue() === BINARY)

val typeNames = columns.get(5).getStringVal.getValues.asScala
typeNames.zip(schema).foreach { case (tName, f) =>
assert(tName === f.dataType.sql)
}

val colSize = columns.get(6).getI32Val.getValues.asScala

colSize.zip(schema).foreach { case (size, f) =>
f.dataType match {
case StringType | BinaryType | _: ArrayType | _: MapType => assert(size === 0)
case o => assert(size === o.defaultSize)
}
}

val decimalDigits = columns.get(8).getI32Val.getValues.asScala
decimalDigits.zip(schema).foreach { case (dd, f) =>
f.dataType match {
case BooleanType | _: IntegerType => assert(dd === 0)
case d: DecimalType => assert(dd === d.scale)
case FloatType => assert(dd === 7)
case DoubleType => assert(dd === 15)
case TimestampType => assert(dd === 6)
case _ => assert(dd === 0) // nulls
}
}

val radixes = columns.get(9).getI32Val.getValues.asScala
radixes.zip(schema).foreach { case (radix, f) =>
f.dataType match {
case _: NumericType => assert(radix === 10)
case _ => assert(radix === 0) // nulls
}
}

val nullables = columns.get(10).getI32Val.getValues.asScala
assert(nullables.forall(_ === 1))

val comments = columns.get(11).getStringVal.getValues.asScala
comments.zip(schema).foreach { case (c, f) => assert(c === f.getComment().get) }

val positions = columns.get(16).getI32Val.getValues.asScala
positions.zipWithIndex.foreach { case (pos, idx) =>
assert(pos === idx, "the client columns are out of order")
}

val isNullables = columns.get(17).getStringVal.getValues.asScala
assert(isNullables.forall(_ === "YES"))

val autoIncs = columns.get(22).getStringVal.getValues.asScala
assert(autoIncs.forall(_ === "NO"))
}
}
}


