Commit

[KYUUBI #310] GetColumns supports DSv2 and keeps its backward compatibility

yaooqinn committed Feb 23, 2021
1 parent bacff6c commit d332be5
Showing 12 changed files with 311 additions and 174 deletions.

GetCatalogs.scala
@@ -20,7 +20,7 @@ package org.apache.kyuubi.engine.spark.operation
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types.StructType

import org.apache.kyuubi.engine.spark.shim.SparkShim
import org.apache.kyuubi.engine.spark.shim.SparkCatalogShim
import org.apache.kyuubi.operation.OperationType
import org.apache.kyuubi.operation.meta.ResultSetSchemaConstant.TABLE_CAT
import org.apache.kyuubi.session.Session
@@ -35,7 +35,7 @@ class GetCatalogs(spark: SparkSession, session: Session)

override protected def runInternal(): Unit = {
try {
iter = SparkShim().getCatalogs(spark).toIterator
iter = SparkCatalogShim().getCatalogs(spark).toIterator
} catch onError()
}
}
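The hunks above only swap the old SparkShim entry point for the renamed SparkCatalogShim. For context, here is a minimal, hypothetical sketch of how such a version-dispatching shim factory is commonly laid out; every name in it (MySparkCatalogShim, MyCatalogShim_v2_4, MyCatalogShim_v3_0) and the catalog name "spark_catalog" are illustrative assumptions, not code from this commit.

// Illustrative sketch only, not part of this commit.
import org.apache.spark.SPARK_VERSION
import org.apache.spark.sql.{Row, SparkSession}

trait MySparkCatalogShim {
  def getCatalogs(spark: SparkSession): Seq[Row]
}

object MySparkCatalogShim {
  // Choose a concrete shim from the running Spark version so a single engine
  // build can serve both Spark 2.4 (session catalog only) and Spark 3.x (DSv2).
  def apply(): MySparkCatalogShim =
    if (SPARK_VERSION.startsWith("2.")) new MyCatalogShim_v2_4 else new MyCatalogShim_v3_0
}

class MyCatalogShim_v2_4 extends MySparkCatalogShim {
  // Spark 2.4 only has the built-in session catalog.
  override def getCatalogs(spark: SparkSession): Seq[Row] = Seq(Row("spark_catalog"))
}

class MyCatalogShim_v3_0 extends MySparkCatalogShim {
  // Spark 3.x may also have catalogs registered via spark.sql.catalog.<name>.
  override def getCatalogs(spark: SparkSession): Seq[Row] = {
    val registered = spark.conf.getAll.keys
      .filter(_.startsWith("spark.sql.catalog."))
      .map(_.stripPrefix("spark.sql.catalog.").takeWhile(_ != '.'))
      .toSeq.distinct
    ("spark_catalog" +: registered).map(Row(_))
  }
}

The real SparkCatalogShim presumably dispatches along these lines, with CatalogShim_v2_4 (shown further below) keeping the pre-DSv2 behavior for backward compatibility.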

GetColumns.scala
@@ -17,14 +17,10 @@

package org.apache.kyuubi.engine.spark.operation

import java.util.regex.Pattern

import scala.collection.mutable.ArrayBuffer

import org.apache.commons.lang3.StringUtils
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.types.{ArrayType, BinaryType, BooleanType, ByteType, CalendarIntervalType, DataType, DateType, DecimalType, DoubleType, FloatType, IntegerType, LongType, MapType, NullType, NumericType, ShortType, StringType, StructField, StructType, TimestampType}
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types._

import org.apache.kyuubi.engine.spark.shim.SparkCatalogShim
import org.apache.kyuubi.operation.OperationType
import org.apache.kyuubi.operation.meta.ResultSetSchemaConstant._
import org.apache.kyuubi.session.Session
@@ -46,95 +46,6 @@ class GetColumns(
s" columnPattern: $columnName]"
}

private def toJavaSQLType(typ: DataType): Int = typ match {
case NullType => java.sql.Types.NULL
case BooleanType => java.sql.Types.BOOLEAN
case ByteType => java.sql.Types.TINYINT
case ShortType => java.sql.Types.SMALLINT
case IntegerType => java.sql.Types.INTEGER
case LongType => java.sql.Types.BIGINT
case FloatType => java.sql.Types.FLOAT
case DoubleType => java.sql.Types.DOUBLE
case StringType => java.sql.Types.VARCHAR
case _: DecimalType => java.sql.Types.DECIMAL
case DateType => java.sql.Types.DATE
case TimestampType => java.sql.Types.TIMESTAMP
case BinaryType => java.sql.Types.BINARY
case _: ArrayType => java.sql.Types.ARRAY
case _: MapType => java.sql.Types.JAVA_OBJECT
case _: StructType => java.sql.Types.STRUCT
case _ => java.sql.Types.OTHER
}

/**
* For boolean, numeric and datetime types, it returns the default size of its catalyst type
* For struct type, when its elements are fixed-size, the summation of all element sizes will be
* returned.
* For array, map, string, and binaries, the column size is variable, return null as unknown.
*/
private def getColumnSize(typ: DataType): Option[Int] = typ match {
case dt @ (BooleanType | _: NumericType | DateType | TimestampType |
CalendarIntervalType | NullType) =>
Some(dt.defaultSize)
case StructType(fields) =>
val sizeArr = fields.map(f => getColumnSize(f.dataType))
if (sizeArr.contains(None)) {
None
} else {
Some(sizeArr.map(_.get).sum)
}
case _ => None
}

/**
* The number of fractional digits for this type.
* Null is returned for data types where this is not applicable.
* For boolean and integrals, the decimal digits is 0
* For floating types, we follow the IEEE Standard for Floating-Point Arithmetic (IEEE 754)
* For timestamp values, we support microseconds
* For decimals, it returns the scale
*/
private def getDecimalDigits(typ: DataType): Option[Int] = typ match {
case BooleanType | _: IntegerType => Some(0)
case FloatType => Some(7)
case DoubleType => Some(15)
case d: DecimalType => Some(d.scale)
case TimestampType => Some(6)
case _ => None
}

private def getNumPrecRadix(typ: DataType): Option[Int] = typ match {
case _: NumericType => Some(10)
case _ => None
}

private def toRow(db: String, table: String, col: StructField, pos: Int): Row = {
Row(
null, // TABLE_CAT
db, // TABLE_SCHEM
table, // TABLE_NAME
col.name, // COLUMN_NAME
toJavaSQLType(col.dataType), // DATA_TYPE
col.dataType.sql, // TYPE_NAME
getColumnSize(col.dataType).orNull, // COLUMN_SIZE
null, // BUFFER_LENGTH
getDecimalDigits(col.dataType).orNull, // DECIMAL_DIGITS
getNumPrecRadix(col.dataType).orNull, // NUM_PREC_RADIX
if (col.nullable) 1 else 0, // NULLABLE
col.getComment().getOrElse(""), // REMARKS
null, // COLUMN_DEF
null, // SQL_DATA_TYPE
null, // SQL_DATETIME_SUB
null, // CHAR_OCTET_LENGTH
pos, // ORDINAL_POSITION
"YES", // IS_NULLABLE
null, // SCOPE_CATALOG
null, // SCOPE_SCHEMA
null, // SCOPE_TABLE
null, // SOURCE_DATA_TYPE
"NO" // IS_AUTO_INCREMENT
)
}
override protected def resultSchema: StructType = {
new StructType()
.add(TABLE_CAT, "string", nullable = true, "Catalog name. NULL if not applicable")
@@ -178,45 +85,12 @@

override protected def runInternal(): Unit = {
try {
val catalog = spark.sessionState.catalog
val schemaPattern = convertSchemaPattern(schemaName)
val schemaPattern = convertSchemaPattern(schemaName, datanucleusFormat = false)
val tablePattern = convertIdentifierPattern(tableName, datanucleusFormat = true)
val columnPattern =
Pattern.compile(convertIdentifierPattern(columnName, datanucleusFormat = false))
val tables: Seq[Row] = catalog.listDatabases(schemaPattern).flatMap { db =>
val identifiers =
catalog.listTables(db, tablePattern, includeLocalTempViews = false)
catalog.getTablesByName(identifiers).flatMap { t =>
t.schema.zipWithIndex
.filter { f => columnPattern.matcher(f._1.name).matches() }
.map { case (f, i) => toRow(t.database, t.identifier.table, f, i)
}
}
}

val gviews = new ArrayBuffer[Row]()
val globalTmpDb = catalog.globalTempViewManager.database
if (StringUtils.isEmpty(schemaName) || schemaName == "*"
|| Pattern.compile(convertSchemaPattern(schemaName, false))
.matcher(globalTmpDb).matches()) {
catalog.globalTempViewManager.listViewNames(tablePattern).foreach { v =>
catalog.globalTempViewManager.get(v).foreach { plan =>
plan.schema.zipWithIndex
.filter { f => columnPattern.matcher(f._1.name).matches() }
.foreach { case (f, i) => gviews += toRow(globalTmpDb, v, f, i) }
}
}
}

val views: Seq[Row] = catalog.listLocalTempViews(tablePattern)
.map(v => (v, catalog.getTempView(v.table).get))
.flatMap { case (v, plan) =>
plan.schema.zipWithIndex
.filter(f => columnPattern.matcher(f._1.name).matches())
.map { case (f, i) => toRow(null, v.table, f, i) }
}

iter = (tables ++ gviews ++ views).toList.iterator
val columnPattern = convertIdentifierPattern(columnName, datanucleusFormat = false)
iter = SparkCatalogShim()
.getColumns(spark, catalogName, schemaPattern, tablePattern, columnPattern)
.toList.iterator
} catch {
onError()
}
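The rewritten runInternal above now delegates to SparkCatalogShim().getColumns after converting the JDBC-style search patterns with convertSchemaPattern and convertIdentifierPattern. Those helpers are defined elsewhere in Kyuubi and are not part of this diff; the hypothetical toCatalogPattern below only illustrates what the datanucleusFormat flag generally distinguishes: rewriting '%' and '_' wildcards into a Hive-metastore-style glob versus a Java regular expression.

// Hypothetical helper, for illustration only; Kyuubi's real convert* helpers
// may differ in details such as escape handling.
def toCatalogPattern(jdbcPattern: String, datanucleusFormat: Boolean): String = {
  val p = Option(jdbcPattern).filter(_.nonEmpty).getOrElse("%")
  if (datanucleusFormat) {
    // Glob style accepted by the catalog listing APIs: '*' matches anything.
    p.replace("%", "*")
  } else {
    // Regex style, e.g. for matching column names with java.util.regex.Pattern.
    def escape(c: Char): String = c match {
      case '%' => ".*"
      case '_' => "."
      case _ if "\\.[]{}()^$|*+?".contains(c) => "\\" + c // escape regex metacharacters
      case _ => c.toString
    }
    p.map(escape).mkString
  }
}

// toCatalogPattern("emp%", datanucleusFormat = true)  returns "emp*"
// toCatalogPattern("emp%", datanucleusFormat = false) returns "emp.*"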
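Earlier in this file's diff, the deleted doc comments spell out how COLUMN_SIZE and DECIMAL_DIGITS are derived: Catalyst default sizes for fixed-width types, summed element sizes for structs, NULL for variable-length types, and the scale for decimals. A short worked example of those rules using public Spark types follows; the shim is expected to preserve these semantics, per the commit title, but that is not visible in this view.

// Quick check of the sizing rules described in the removed comments.
import org.apache.spark.sql.types._

IntegerType.defaultSize                                    // 4
LongType.defaultSize                                       // 8
// struct<i: int, l: bigint> has only fixed-width fields, so sizes are summed:
StructType(Seq(StructField("i", IntegerType), StructField("l", LongType)))
  .map(_.dataType.defaultSize).sum                         // 4 + 8 = 12
// StringType is variable-length, so COLUMN_SIZE would be reported as NULL.
DecimalType(10, 2).scale                                   // DECIMAL_DIGITS = 2
// TimestampType reports DECIMAL_DIGITS = 6 (microsecond precision).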

GetSchemas.scala
@@ -20,7 +20,7 @@ package org.apache.kyuubi.engine.spark.operation
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types.StructType

import org.apache.kyuubi.engine.spark.shim.SparkShim
import org.apache.kyuubi.engine.spark.shim.SparkCatalogShim
import org.apache.kyuubi.operation.OperationType
import org.apache.kyuubi.operation.meta.ResultSetSchemaConstant._
import org.apache.kyuubi.session.Session
@@ -41,7 +41,7 @@ class GetSchemas(spark: SparkSession, session: Session, catalogName: String, sch
override protected def runInternal(): Unit = {
try {
val schemaPattern = convertSchemaPattern(schema, datanucleusFormat = false)
val rows = SparkShim().getSchemas(spark, catalogName, schemaPattern)
val rows = SparkCatalogShim().getSchemas(spark, catalogName, schemaPattern)
iter = rows.toList.toIterator
} catch onError()
}

GetTableTypes.scala
@@ -20,7 +20,7 @@ package org.apache.kyuubi.engine.spark.operation
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.types.StructType

import org.apache.kyuubi.engine.spark.shim.SparkShim
import org.apache.kyuubi.engine.spark.shim.SparkCatalogShim
import org.apache.kyuubi.operation.OperationType
import org.apache.kyuubi.operation.meta.ResultSetSchemaConstant._
import org.apache.kyuubi.session.Session
@@ -33,6 +33,6 @@ class GetTableTypes(spark: SparkSession, session: Session)
}

override protected def runInternal(): Unit = {
iter = SparkShim.sparkTableTypes.map(Row(_)).toList.iterator
iter = SparkCatalogShim.sparkTableTypes.map(Row(_)).toList.iterator
}
}

GetTables.scala
@@ -18,10 +18,9 @@
package org.apache.kyuubi.engine.spark.operation

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.catalog.CatalogTableType
import org.apache.spark.sql.types.StructType

import org.apache.kyuubi.engine.spark.shim.SparkShim
import org.apache.kyuubi.engine.spark.shim.SparkCatalogShim
import org.apache.kyuubi.operation.OperationType
import org.apache.kyuubi.operation.meta.ResultSetSchemaConstant._
import org.apache.kyuubi.session.Session
@@ -63,7 +62,7 @@ class GetTables(
try {
val schemaPattern = convertSchemaPattern(schema, datanucleusFormat = false)
val tablePattern = convertIdentifierPattern(tableName, datanucleusFormat = true)
val sparkShim = SparkShim()
val sparkShim = SparkCatalogShim()
val catalogTablesAndViews =
sparkShim.getCatalogTablesOrViews(spark, catalog, schemaPattern, tablePattern, tableTypes)


SparkSQLOperationManager.scala
@@ -24,7 +24,7 @@ import scala.collection.JavaConverters._
import org.apache.spark.sql.SparkSession

import org.apache.kyuubi.KyuubiSQLException
import org.apache.kyuubi.engine.spark.shim.SparkShim
import org.apache.kyuubi.engine.spark.shim.SparkCatalogShim
import org.apache.kyuubi.operation.{Operation, OperationManager}
import org.apache.kyuubi.session.{Session, SessionHandle}

@@ -91,7 +91,7 @@ class SparkSQLOperationManager private (name: String) extends OperationManager(n
tableTypes: java.util.List[String]): Operation = {
val spark = getSparkSession(session.handle)
val tTypes = if (tableTypes == null || tableTypes.isEmpty) {
SparkShim.sparkTableTypes
SparkCatalogShim.sparkTableTypes
} else {
tableTypes.asScala.toSet
}
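SparkSQLOperationManager is what turns a client's metadata request into one of the operations above. For orientation, a hedged sketch of how a JDBC client would reach the GetColumns path through the standard java.sql.DatabaseMetaData API; the connection URL, port, and credentials are placeholders and assume a Kyuubi server exposing its HiveServer2-compatible endpoint.

// Usage sketch, not part of the commit. Requires a Hive JDBC driver on the classpath.
import java.sql.DriverManager

val conn = DriverManager.getConnection("jdbc:hive2://localhost:10009/", "user", "")
try {
  val md = conn.getMetaData
  // catalog = null, schema pattern = "default", table pattern = "%", column pattern = "%"
  val rs = md.getColumns(null, "default", "%", "%")
  while (rs.next()) {
    val table = rs.getString("TABLE_NAME")
    val column = rs.getString("COLUMN_NAME")
    val typeName = rs.getString("TYPE_NAME")
    println(s"$table.$column: $typeName")
  }
} finally {
  conn.close()
}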

CatalogShim_v2_4.scala (renamed from Shim_v2_4.scala)
@@ -21,9 +21,11 @@ import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.connector.catalog.CatalogPlugin

class Shim_v2_4 extends SparkShim {
class CatalogShim_v2_4 extends SparkCatalogShim {

override def getCatalogs(spark: SparkSession): Seq[Row] = Seq(Row(""))
override def getCatalogs(spark: SparkSession): Seq[Row] = {
Seq(Row(SparkCatalogShim.SESSION_CATALOG))
}

override protected def getCatalog(spark: SparkSession, catalog: String): CatalogPlugin = null

@@ -86,4 +88,80 @@ class Shim_v2_4 extends SparkShim {
spark.sessionState.catalog.listLocalTempViews(tablePattern)
}
}

override def getColumns(
spark: SparkSession,
catalogName: String,
schemaPattern: String,
tablePattern: String,
columnPattern: String): Seq[Row] = {

val cols1 = getColumnsByCatalog(spark, catalogName, schemaPattern, tablePattern, columnPattern)

val cols2 = getColumnsByGlobalTempViewManager(
spark, catalogName, schemaPattern, tablePattern, columnPattern)

val cols3 = getColumnsByLocalTempViews(spark, tablePattern, columnPattern)

cols1 ++ cols2 ++ cols3
}

protected def getColumnsByCatalog(
spark: SparkSession,
catalogName: String,
schemaPattern: String,
tablePattern: String,
columnPattern: String): Seq[Row] = {
val cp = columnPattern.r.pattern
val catalog = spark.sessionState.catalog

val databases = catalog.listDatabases(schemaPattern)

databases.flatMap { db =>
val identifiers = catalog.listTables(db, tablePattern, includeLocalTempViews = true)
catalog.getTablesByName(identifiers).flatMap { t =>
t.schema.zipWithIndex.filter(f => cp.matcher(f._1.name).matches())
.map { case (f, i) => toColumnResult(catalogName, t.database, t.identifier.table, f, i) }
}
}
}

protected def getColumnsByGlobalTempViewManager(
spark: SparkSession,
catalogName: String,
schemaPattern: String,
tablePattern: String,
columnPattern: String): Seq[Row] = {
val cp = columnPattern.r.pattern
val catalog = spark.sessionState.catalog

getGlobalTempViewManager(spark, schemaPattern).flatMap { globalTmpDb =>
catalog.globalTempViewManager.listViewNames(tablePattern).flatMap { v =>
catalog.globalTempViewManager.get(v).map { plan =>
plan.schema.zipWithIndex.filter(f => cp.matcher(f._1.name).matches())
.map { case (f, i) =>
toColumnResult(SparkCatalogShim.SESSION_CATALOG, globalTmpDb, v, f, i)
}
}
}.flatten
}
}

protected def getColumnsByLocalTempViews(
spark: SparkSession,
tablePattern: String,
columnPattern: String): Seq[Row] = {
val cp = columnPattern.r.pattern
val catalog = spark.sessionState.catalog

catalog.listLocalTempViews(tablePattern)
.map(v => (v, catalog.getTempView(v.table).get))
.flatMap { case (v, plan) =>
plan.schema.zipWithIndex
.filter(f => cp.matcher(f._1.name).matches())
.map { case (f, i) =>
toColumnResult(SparkCatalogShim.SESSION_CATALOG, null, v.table, f, i)
}
}
}
}
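CatalogShim_v2_4 keeps the pre-DSv2 behavior: getCatalog returns null and getColumns reads only the session catalog plus global and local temp views. The Spark 3 shim that actually adds DSv2 support is among the files not shown in this view, so the following is only an assumed sketch of what listing columns from a DSv2 catalog can look like with public Spark 3 connector APIs; dsv2Columns and its minimal Row layout are hypothetical.

// Assumed sketch; the real Spark 3 shim implementation is not shown here.
import org.apache.spark.sql.Row
import org.apache.spark.sql.connector.catalog.{SupportsNamespaces, TableCatalog}

def dsv2Columns(
    tableCatalog: TableCatalog with SupportsNamespaces,
    schemaRegex: String,
    tableRegex: String,
    columnRegex: String): Seq[Row] = {
  val sp = schemaRegex.r.pattern
  val tp = tableRegex.r.pattern
  val cp = columnRegex.r.pattern
  tableCatalog.listNamespaces().toSeq
    .filter(ns => sp.matcher(ns.mkString(".")).matches())
    .flatMap { ns =>
      tableCatalog.listTables(ns).toSeq
        .filter(ident => tp.matcher(ident.name()).matches())
        .flatMap { ident =>
          val schema = tableCatalog.loadTable(ident).schema()
          schema.zipWithIndex
            .filter { case (f, _) => cp.matcher(f.name).matches() }
            .map { case (f, i) =>
              // Minimal row: catalog, schema, table, column, ordinal position only.
              Row(tableCatalog.name(), ns.mkString("."), ident.name(), f.name, i)
            }
        }
    }
}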
(Diffs for the remaining changed files are not shown in this view.)
