Skip to content

Commit 74ff5cf

Browse files
jiaoqingbopan3793
authored and committed
[KYUUBI #2789] Kyuubi Spark TPC-H Connector - Add tiny scale
### _Why are the changes needed?_ fix #2789 ### _How was this patch tested?_ - [x] Add some test cases that check the changes thoroughly including negative and positive cases if possible - [ ] Add screenshots for manual tests if appropriate - [x] [Run test](https://kyuubi.apache.org/docs/latest/develop_tools/testing.html#running-tests) locally before make a pull request Closes #2791 from jiaoqingbo/kyuubi-2789. Closes #2789 5f05691 [jiaoqingbo] [KYUUBI #2789] Kyuubi Spark TPC-H Connector - Add tiny scale Authored-by: jiaoqingbo <1178404354@qq.com> Signed-off-by: Cheng Pan <chengpan@apache.org>
1 parent a462230 commit 74ff5cf

File tree

7 files changed

+137
-45
lines changed

7 files changed

+137
-45
lines changed

extensions/spark/kyuubi-spark-connector-tpch/src/main/scala/org/apache/kyuubi/spark/connector/tpch/TPCHBatchScan.scala

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,12 +32,12 @@ import org.apache.spark.sql.connector.read._
3232
import org.apache.spark.sql.types._
3333
import org.apache.spark.unsafe.types.UTF8String
3434

35-
case class TPCHTableChuck(table: String, scale: Int, parallelism: Int, index: Int)
35+
case class TPCHTableChuck(table: String, scale: Double, parallelism: Int, index: Int)
3636
extends InputPartition
3737

3838
class TPCHBatchScan(
3939
@transient table: TpchTable[_],
40-
scale: Int,
40+
scale: Double,
4141
schema: StructType) extends ScanBuilder
4242
with SupportsReportStatistics with Batch with Serializable {
4343

@@ -58,7 +58,8 @@ class TPCHBatchScan(
5858
override def toBatch: Batch = this
5959

6060
override def description: String =
61-
s"Scan TPC-H sf$scale.${table.getTableName}, count: ${_numRows}, parallelism: $parallelism"
61+
s"Scan TPC-H ${TPCHSchemaUtils.dbName(scale)}.${table.getTableName}, " +
62+
s"count: ${_numRows}, parallelism: $parallelism"
6263

6364
override def readSchema: StructType = schema
6465

@@ -81,7 +82,7 @@ class TPCHBatchScan(
8182

8283
class TPCHPartitionReader(
8384
table: String,
84-
scale: Int,
85+
scale: Double,
8586
parallelism: Int,
8687
index: Int,
8788
schema: StructType) extends PartitionReader[InternalRow] {

extensions/spark/kyuubi-spark-connector-tpch/src/main/scala/org/apache/kyuubi/spark/connector/tpch/TPCHCatalog.scala

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@ import java.util
2121

2222
import scala.collection.JavaConverters._
2323

24-
import io.trino.tpch.TpchTable
2524
import org.apache.spark.sql.catalyst.analysis.{NoSuchNamespaceException, NoSuchTableException}
2625
import org.apache.spark.sql.connector.catalog.{Identifier, NamespaceChange, SupportsNamespaces, Table => SparkTable, TableCatalog, TableChange}
2726
import org.apache.spark.sql.connector.expressions.Transform
@@ -30,12 +29,9 @@ import org.apache.spark.sql.util.CaseInsensitiveStringMap
3029

3130
class TPCHCatalog extends TableCatalog with SupportsNamespaces {
3231

33-
val tables: Array[String] = TpchTable.getTables.asScala
34-
.map(_.getTableName).toArray
32+
val databases: Array[String] = TPCHSchemaUtils.DATABASES
3533

36-
val scales: Array[Int] = TPCHStatisticsUtils.SCALES
37-
38-
val databases: Array[String] = scales.map("sf" + _)
34+
val tables: Array[String] = TPCHSchemaUtils.BASE_TABLES.map(_.getTableName)
3935

4036
var options: CaseInsensitiveStringMap = _
4137

@@ -55,7 +51,8 @@ class TPCHCatalog extends TableCatalog with SupportsNamespaces {
5551

5652
override def loadTable(ident: Identifier): SparkTable = (ident.namespace, ident.name) match {
5753
case (Array(db), table) if (databases contains db) && tables.contains(table.toLowerCase) =>
58-
new TPCHTable(table.toLowerCase, scales(databases indexOf db), options)
54+
val scale = TPCHSchemaUtils.scale(db)
55+
new TPCHTable(table.toLowerCase, scale, options)
5956
case (_, _) => throw new NoSuchTableException(ident)
6057
}
6158

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.kyuubi.spark.connector.tpch
19+
20+
import java.text.DecimalFormat
21+
22+
import scala.collection.JavaConverters._
23+
24+
import io.trino.tpch.TpchTable
25+
26+
object TPCHSchemaUtils {

  /** Scale factor of the "tiny" dataset, spelled exactly as [[normalize]] renders it. */
  val TINY_SCALE = "0.01"

  // Supported TPC-H scale factors, kept as canonical strings so lookups compare
  // against the normalized rendering of a Double instead of using float equality.
  val SCALES: Array[String] =
    Array(
      "0",
      TINY_SCALE,
      "1",
      "10",
      "30",
      "100",
      "300",
      "1000",
      "3000",
      "10000",
      "30000",
      "100000")

  val TINY_DB_NAME = "tiny"

  // One database per scale: the tiny scale is exposed as "tiny", all others as "sf<scale>".
  // DATABASES(i) always corresponds to SCALES(i).
  val DATABASES: Array[String] = SCALES.map {
    case TINY_SCALE => TINY_DB_NAME
    case scale => s"sf$scale"
  }

  /**
   * Renders a Double scale factor as a canonical string with at most two fraction
   * digits, e.g. 1.0 -> "1", 0.01 -> "0.01".
   *
   * Uses ROOT-locale symbols so the decimal separator is always '.' regardless of
   * the JVM default locale; a bare `new DecimalFormat("#.##")` would render 0.01 as
   * "0,01" under e.g. a French default locale and break every lookup against SCALES.
   * A fresh DecimalFormat per call also sidesteps its lack of thread-safety.
   */
  def normalize(scale: Double): String = {
    import java.text.DecimalFormatSymbols
    import java.util.Locale
    new DecimalFormat("#.##", DecimalFormatSymbols.getInstance(Locale.ROOT)).format(scale)
  }

  /** Resolves a database name ("tiny", "sf1", ...) back to its numeric scale factor. */
  def scale(dbName: String): Double = {
    val i = DATABASES.indexOf(dbName)
    // Fail with a clear message instead of an ArrayIndexOutOfBoundsException on -1.
    require(i >= 0, s"Unknown TPC-H database: $dbName")
    SCALES(i).toDouble
  }

  /** Resolves a numeric scale factor to its database name. */
  def dbName(scale: Double): String = {
    val i = SCALES.indexOf(normalize(scale))
    // Fail with a clear message instead of an ArrayIndexOutOfBoundsException on -1.
    require(i >= 0, s"Unsupported TPC-H scale: $scale")
    DATABASES(i)
  }

  // All TPC-H base tables as provided by the Trino tpch generator.
  val BASE_TABLES: Array[TpchTable[_]] = TpchTable.getTables.asScala.toArray

}

extensions/spark/kyuubi-spark-connector-tpch/src/main/scala/org/apache/kyuubi/spark/connector/tpch/TPCHStatisticsUtils.scala

Lines changed: 24 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -20,37 +20,39 @@ package org.apache.kyuubi.spark.connector.tpch
2020
import io.trino.tpch.TpchTable
2121
import io.trino.tpch.TpchTable._
2222

23+
import org.apache.kyuubi.spark.connector.tpch.TPCHSchemaUtils.{normalize, SCALES}
24+
2325
// https://www.tpc.org/tpc_documents_current_versions/pdf/tpc-h_v3.0.0.pdf
2426
// Page 88 Table 3: Estimated Database Size
2527
object TPCHStatisticsUtils {
2628

27-
val SCALES: Array[Int] = Array(0, 1, 10, 30, 100, 300, 1000, 3000, 10000, 30000, 100000)
28-
29-
def numRows(table: TpchTable[_], scale: Int): Long = {
30-
require(SCALES.contains(scale), s"Unsupported scale $scale")
31-
(table, scale) match {
32-
case (_, 0) => 0L
33-
case (CUSTOMER, scale) => 150000L * scale
34-
case (ORDERS, scale) => 1500000L * scale
35-
case (LINE_ITEM, 1) => 6001215L
36-
case (LINE_ITEM, 10) => 59986052L
37-
case (LINE_ITEM, 30) => 179998372L
38-
case (LINE_ITEM, 100) => 600037902L
39-
case (LINE_ITEM, 300) => 1799989091L
40-
case (LINE_ITEM, 1000) => 5999989709L
41-
case (LINE_ITEM, 3000) => 18000048306L
42-
case (LINE_ITEM, 10000) => 59999994267L
43-
case (LINE_ITEM, 30000) => 179999978268L
44-
case (LINE_ITEM, 100000) => 599999969200L
45-
case (PART, scale) => 200000L * scale
46-
case (PART_SUPPLIER, scale) => 800000L * scale
47-
case (SUPPLIER, scale) => 10000L * scale
29+
/**
 * Estimated row count for `table` at the given scale factor.
 *
 * Figures follow the TPC-H v3.0.0 specification (page 88, table 3): most tables
 * grow linearly with the scale factor, LINEITEM uses the spec's per-scale counts,
 * and NATION/REGION are fixed-size.
 */
def numRows(table: TpchTable[_], scale: Double): Long = {
  // Canonical string form avoids floating-point equality on the scale factor.
  val sf = normalize(scale)
  require(SCALES.contains(sf), s"Unsupported scale $sf")
  // Tables whose cardinality is a fixed multiple of the scale factor.
  def linear(rowsPerSf: Long): Long = (rowsPerSf * sf.toDouble).toLong
  (table, sf) match {
    case (_, "0") => 0L
    case (CUSTOMER, _) => linear(150000L)
    case (ORDERS, _) => linear(1500000L)
    case (LINE_ITEM, "0.01") => 60175L
    case (LINE_ITEM, "1") => 6001215L
    case (LINE_ITEM, "10") => 59986052L
    case (LINE_ITEM, "30") => 179998372L
    case (LINE_ITEM, "100") => 600037902L
    case (LINE_ITEM, "300") => 1799989091L
    case (LINE_ITEM, "1000") => 5999989709L
    case (LINE_ITEM, "3000") => 18000048306L
    case (LINE_ITEM, "10000") => 59999994267L
    case (LINE_ITEM, "30000") => 179999978268L
    case (LINE_ITEM, "100000") => 599999969200L
    case (PART, _) => linear(200000L)
    case (PART_SUPPLIER, _) => linear(800000L)
    case (SUPPLIER, _) => linear(10000L)
    case (NATION, _) => 25L
    case (REGION, _) => 5L
  }
}
5254

53-
def sizeInBytes(table: TpchTable[_], scale: Int): Long =
55+
def sizeInBytes(table: TpchTable[_], scale: Double): Long =
5456
numRows(table, scale) * TABLE_AVG_ROW_BYTES(table)
5557

5658
private val TABLE_AVG_ROW_BYTES: Map[TpchTable[_], Long] = Map(

extensions/spark/kyuubi-spark-connector-tpch/src/main/scala/org/apache/kyuubi/spark/connector/tpch/TPCHTable.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,15 +30,15 @@ import org.apache.spark.sql.connector.read.ScanBuilder
3030
import org.apache.spark.sql.types._
3131
import org.apache.spark.sql.util.CaseInsensitiveStringMap
3232

33-
class TPCHTable(tbl: String, scale: Int, options: CaseInsensitiveStringMap)
33+
class TPCHTable(tbl: String, scale: Double, options: CaseInsensitiveStringMap)
3434
extends SparkTable with SupportsRead {
3535

3636
// When true, use CHAR VARCHAR; otherwise use STRING
3737
val useAnsiStringType: Boolean = options.getBoolean("useAnsiStringType", false)
3838

3939
val tpchTable: TpchTable[_] = TpchTable.getTable(tbl)
4040

41-
override def name: String = s"sf$scale.$tbl"
41+
override def name: String = s"${TPCHSchemaUtils.dbName(scale)}.$tbl"
4242

4343
override def toString: String = s"TPCHTable($name)"
4444

extensions/spark/kyuubi-spark-connector-tpch/src/test/scala/org/apache/kyuubi/spark/connector/tpch/TPCHCatalogSuite.scala

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -45,19 +45,19 @@ class TPCHCatalogSuite extends KyuubiFunSuite {
4545

4646
test("supports namespaces") {
4747
spark.sql("use tpch")
48-
assert(spark.sql(s"SHOW DATABASES").collect().length == 11)
48+
assert(spark.sql(s"SHOW DATABASES").collect().length == 12)
4949
assert(spark.sql(s"SHOW NAMESPACES IN tpch.sf1").collect().length == 0)
5050
}
5151

52-
test("tpch.sf1 count") {
53-
assert(spark.table("tpch.sf1.customer").count === 150000)
54-
assert(spark.table("tpch.sf1.orders").count === 1500000)
55-
assert(spark.table("tpch.sf1.lineitem").count === 6001215)
56-
assert(spark.table("tpch.sf1.part").count === 200000)
57-
assert(spark.table("tpch.sf1.partsupp").count === 800000)
58-
assert(spark.table("tpch.sf1.supplier").count === 10000)
59-
assert(spark.table("tpch.sf1.nation").count === 25)
60-
assert(spark.table("tpch.sf1.region").count === 5)
52+
test("tpch.tiny count") {
53+
assert(spark.table("tpch.tiny.customer").count === 1500)
54+
assert(spark.table("tpch.tiny.orders").count === 15000)
55+
assert(spark.table("tpch.tiny.lineitem").count === 60175)
56+
assert(spark.table("tpch.tiny.part").count === 2000)
57+
assert(spark.table("tpch.tiny.partsupp").count === 8000)
58+
assert(spark.table("tpch.tiny.supplier").count === 100)
59+
assert(spark.table("tpch.tiny.nation").count === 25)
60+
assert(spark.table("tpch.tiny.region").count === 5)
6161
}
6262

6363
test("tpch.sf1 stats") {
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.kyuubi.spark.connector.tpch
19+
20+
import org.apache.kyuubi.KyuubiFunSuite
21+
import org.apache.kyuubi.spark.connector.tpch.TPCHSchemaUtils.normalize
22+
23+
class TPCHSchemaUtilsSuite extends KyuubiFunSuite {

  test("normalize scale") {
    // Each pair is (raw double scale, expected canonical rendering); values within
    // double rounding error of a supported scale must normalize to that scale.
    val expectations = Seq[(Double, String)](
      1 -> "1",
      0.010000000000000000001 -> "0.01",
      1.000000000000000000001 -> "1",
      0.999999999999999999999 -> "1",
      9.999999999999999999999 -> "10")
    expectations.foreach { case (raw, expected) =>
      assert(normalize(raw) === expected)
    }
  }
}

0 commit comments

Comments
 (0)