-
Notifications
You must be signed in to change notification settings - Fork 28
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add example source reporting partitioning and ordering
- Loading branch information
Showing
8 changed files
with
210 additions
and
0 deletions.
There are no files selected for viewing
22 changes: 22 additions & 0 deletions
22
src/main/scala-spark-3.0/uk/co/gresearch/spark/source/Reporting.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
package uk.co.gresearch.spark.source | ||
|
||
import org.apache.spark.sql.connector.read.{Batch, SupportsReportPartitioning, partitioning} | ||
import org.apache.spark.sql.connector.read.partitioning.{ClusteredDistribution, Distribution} | ||
|
||
/**
 * Mixin for a [[Batch]] that reports its partitioning to Spark (Spark 3.0–3.2 API).
 *
 * When `partitioned` is true, the reported [[Partitioning]] satisfies clustered
 * distributions over the "id" column, allowing Spark to elide a shuffle.
 * NOTE(review): `ordered` is declared but unused here — this Spark API version
 * appears to have no order-reporting hook; kept for interface symmetry with
 * the 3.3 variant — confirm.
 */
trait Reporting extends SupportsReportPartitioning {
  this: Batch =>

  // whether each input partition holds all rows for its "id" values
  def partitioned: Boolean
  // whether rows are sorted within partitions (not reported in this API version)
  def ordered: Boolean

  // calls planInputPartitions() each time it is invoked to obtain the partition count
  def outputPartitioning: partitioning.Partitioning =
    Partitioning(this.planInputPartitions().length, partitioned)
}
|
||
/**
 * Partitioning reported by the example source: `partitions` input partitions,
 * clustered by the "id" column when `partitioned` is set.
 */
case class Partitioning(partitions: Int, partitioned: Boolean) extends partitioning.Partitioning {
  override def numPartitions(): Int = partitions

  // Only a clustered distribution over "id" can be satisfied, and only when
  // the source actually is partitioned by it; everything else requires a shuffle.
  override def satisfy(distribution: Distribution): Boolean = distribution match {
    case clustered: ClusteredDistribution if partitioned =>
      clustered.clusteredColumns.contains("id")
    case _ =>
      false
  }
}
1 change: 1 addition & 0 deletions
1
src/main/scala-spark-3.1/uk/co/gresearch/spark/source/Reporting.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
../../../../../../scala-spark-3.0/uk/co/gresearch/spark/source/Reporting.scala |
1 change: 1 addition & 0 deletions
1
src/main/scala-spark-3.2/uk/co/gresearch/spark/source/Reporting.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
../../../../../../scala-spark-3.0/uk/co/gresearch/spark/source/Reporting.scala |
32 changes: 32 additions & 0 deletions
32
src/main/scala-spark-3.3/uk/co/gresearch/spark/source/Reporting.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
package uk.co.gresearch.spark.source | ||
|
||
import org.apache.spark.sql.connector.expressions.{Expression, NamedReference, Transform} | ||
import org.apache.spark.sql.connector.read.{Batch, SupportsReportPartitioning, partitioning} | ||
import org.apache.spark.sql.connector.read.partitioning.{KeyGroupedPartitioning, UnknownPartitioning} | ||
import uk.co.gresearch.spark.source.Reporting.namedReference | ||
|
||
/**
 * Mixin for a [[Batch]] that reports its partitioning to Spark (Spark 3.3+ API).
 *
 * When `partitioned` is true, reports a [[KeyGroupedPartitioning]] over the "id"
 * column so Spark can elide shuffles; otherwise reports [[UnknownPartitioning]].
 */
trait Reporting extends SupportsReportPartitioning {
  this: Batch =>

  // whether each input partition holds all rows for its "id" values
  def partitioned: Boolean
  // whether rows are sorted within partitions
  def ordered: Boolean

  // partition key expressions: the identity transform over the "id" column
  val partitionKeys: Array[Expression] = Array(namedReference("id"))

  def outputPartitioning: partitioning.Partitioning = {
    // obtain the partition count once, then pick the partitioning flavour
    val numPartitions = this.planInputPartitions().length
    if (partitioned) {
      new KeyGroupedPartitioning(partitionKeys, numPartitions)
    } else {
      new UnknownPartitioning(numPartitions)
    }
  }
}
|
||
object Reporting {
  /**
   * Creates an identity [[Transform]] over the given column, used as a partition
   * key expression for [[KeyGroupedPartitioning]].
   *
   * Fix: the transform now reports its wrapped column from `references()` —
   * Spark's `Transform.references` contract is to expose all referenced columns,
   * which it needs to match partition keys against query clustering keys;
   * returning an empty array hid the column.
   *
   * @param columnName name of the partitioning column
   * @return an identity transform referencing `columnName`
   */
  def namedReference(columnName: String): Expression =
    new Transform {
      // the single column this identity transform operates on
      private val reference: NamedReference = new NamedReference {
        override def fieldNames(): Array[String] = Array(columnName)
      }
      override def name(): String = "identity"
      override def references(): Array[NamedReference] = Array(reference)
      override def arguments(): Array[Expression] = Array(reference)
    }
}
1 change: 1 addition & 0 deletions
1
src/main/scala-spark-3.4/uk/co/gresearch/spark/source/Reporting.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
../../../../../../scala-spark-3.3/uk/co/gresearch/spark/source/Reporting.scala |
86 changes: 86 additions & 0 deletions
86
src/main/scala/uk/co/gresearch/spark/source/DefaultSource.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
package uk.co.gresearch.spark.source | ||
|
||
import org.apache.spark.sql.catalyst.InternalRow | ||
import org.apache.spark.sql.catalyst.util.DateTimeUtils | ||
import org.apache.spark.sql.connector.catalog.{SupportsRead, Table, TableCapability, TableProvider} | ||
import org.apache.spark.sql.connector.expressions.Transform | ||
import org.apache.spark.sql.connector.read | ||
import org.apache.spark.sql.connector.read.{Batch, InputPartition, Scan} | ||
import org.apache.spark.sql.sources.DataSourceRegister | ||
import org.apache.spark.sql.types._ | ||
import org.apache.spark.sql.util.CaseInsensitiveStringMap | ||
|
||
import java.sql.Timestamp | ||
import java.util | ||
import scala.collection.JavaConverters._ | ||
|
||
/**
 * Example DataSource V2 provider, registered under the short name "example".
 * The "partitioned" and "ordered" table properties control what the table
 * reports to Spark; both default to false.
 */
class DefaultSource() extends TableProvider with DataSourceRegister {
  override def shortName(): String = "example"

  override def inferPartitioning(options: CaseInsensitiveStringMap): Array[Transform] = Array.empty

  override def inferSchema(options: CaseInsensitiveStringMap): StructType = DefaultSource.schema

  override def getTable(schema: StructType, partitioning: Array[Transform], properties: util.Map[String, String]): Table = {
    val isPartitioned = properties.getOrDefault("partitioned", "false").toBoolean
    val isOrdered = properties.getOrDefault("ordered", "false").toBoolean
    BatchTable(isPartitioned, isOrdered)
  }
}
|
||
/** Static schema and fixture data served by the example source. */
object DefaultSource {
  // This source's API version cannot report ordering to Spark.
  val supportsReportingOrder: Boolean = false
  // Three columns: an integer key, a timestamp, and a double payload.
  val schema: StructType = StructType(Seq(
    StructField("id", IntegerType),
    StructField("time", TimestampType),
    StructField("value", DoubleType),
  ))
  // Base timestamp (microseconds since epoch); row times are offsets from it.
  val ts: Long = DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf("2020-01-01 12:00:00"))
  // Rows keyed by input partition number. Each partition holds all rows of its
  // "id" values (clustered by "id") and is sorted by (id, time).
  val data: Map[Int, Array[InternalRow]] = Map(
    1 -> Array(
      InternalRow(1, ts + 1000000, 1.1),
      InternalRow(1, ts + 2000000, 1.2),
      InternalRow(1, ts + 3000000, 1.3),
      InternalRow(3, ts + 1000000, 3.1),
      InternalRow(3, ts + 2000000, 3.2)
    ),
    2 -> Array(
      InternalRow(2, ts + 1000000, 2.1),
      InternalRow(2, ts + 2000000, 2.2),
      InternalRow(4, ts + 1000000, 4.1),
      InternalRow(4, ts + 2000000, 4.2),
      InternalRow(4, ts + 3000000, 4.3)
    )
  )
  // Number of input partitions the source produces.
  val partitions: Int = data.size
}
|
||
/** Read-only batch table backed by [[DefaultSource]]'s fixture data. */
case class BatchTable(partitioned: Boolean, ordered: Boolean) extends Table with SupportsRead {
  override def name(): String = "table"

  override def schema(): StructType = DefaultSource.schema

  // This example table supports batch reads only.
  override def capabilities(): util.Set[TableCapability] = {
    val caps = Set(TableCapability.BATCH_READ)
    caps.asJava
  }

  // Forward the reporting flags into the scan that will be built.
  override def newScanBuilder(caseInsensitiveStringMap: CaseInsensitiveStringMap): read.ScanBuilder =
    new ScanBuilder(partitioned, ordered)
}
|
||
/** Builds a [[BatchScan]], carrying the partitioning/ordering flags through to it. */
class ScanBuilder(partitioned: Boolean, ordered: Boolean) extends read.ScanBuilder {
  override def build(): Scan = BatchScan(partitioned = partitioned, ordered = ordered)
}
|
||
/** Scan that is also its own Batch, mixing in [[Reporting]] to expose partitioning. */
case class BatchScan(partitioned: Boolean, ordered: Boolean) extends read.Scan with read.Batch with Reporting {
  override def readSchema(): StructType = DefaultSource.schema

  override def toBatch: Batch = this

  // One input partition per key of the fixture data map.
  override def planInputPartitions(): Array[InputPartition] =
    DefaultSource.data.keys.map(id => Partition(id)).toArray

  override def createReaderFactory(): read.PartitionReaderFactory = PartitionReaderFactory()
}
|
||
// Identifies one input partition by its key in DefaultSource.data.
case class Partition(id: Int) extends InputPartition
|
||
/** Creates one [[PartitionReader]] per input partition. */
case class PartitionReaderFactory() extends read.PartitionReaderFactory {
  override def createReader(partition: InputPartition): read.PartitionReader[InternalRow] =
    PartitionReader(partition)
}
|
||
|
||
/**
 * Iterates the fixture rows of a single [[Partition]].
 *
 * Fixes: replaces the unchecked `asInstanceOf[Partition]` downcast with a
 * pattern match (failing with a descriptive error instead of a bare
 * ClassCastException), and adds the missing `override` modifiers on the
 * PartitionReader interface implementations.
 */
case class PartitionReader(partition: InputPartition) extends read.PartitionReader[InternalRow] {
  // Rows of this partition; unknown partition ids yield an empty iterator.
  val rows: Iterator[InternalRow] = partition match {
    case Partition(id) => DefaultSource.data.getOrElse(id, Array.empty[InternalRow]).iterator
    case other => throw new IllegalArgumentException(s"Unexpected input partition: $other")
  }

  override def next(): Boolean = rows.hasNext
  override def get(): InternalRow = rows.next()
  override def close(): Unit = { }
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
66 changes: 66 additions & 0 deletions
66
src/test/scala/uk/co/gresearch/spark/source/SourceSuite.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
package uk.co.gresearch.spark.source | ||
|
||
import org.apache.spark.sql.{DataFrame, DataFrameReader} | ||
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper | ||
import org.apache.spark.sql.execution.exchange.Exchange | ||
import org.apache.spark.sql.execution.{SortExec, SparkPlan} | ||
import org.apache.spark.sql.expressions.Window | ||
import org.apache.spark.sql.functions.sum | ||
import org.scalatest.funsuite.AnyFunSuite | ||
import uk.co.gresearch.spark.SparkTestSession | ||
|
||
/**
 * Verifies that the example source's reported partitioning (and, where the Spark
 * version supports it, ordering) lets Spark elide shuffles (Exchange) and sorts
 * (SortExec) from the physical plan.
 */
class SourceSuite extends AnyFunSuite with SparkTestSession with AdaptiveSparkPlanHelper {
  import spark.implicits._

  // Format name is the package of the example DefaultSource.
  private val source = new DefaultSource().getClass.getPackage.getName
  // Fresh reader per call so option settings below do not leak between tests.
  private def df: DataFrameReader = spark.read.format(source)
  // Reader variants: reporting partitioning only, and partitioning plus ordering.
  private val dfpartitioned = df.option("partitioned", value = true)
  private val dfpartitionedAndSorted = df.option("partitioned", value = true).option("ordered", value = true)
  // Window clustered like the source data: by "id", ordered by "time".
  private val window = Window.partitionBy($"id").orderBy($"time")

  // Smoke test: the source loads and renders at all.
  test("show") {
    df.load().show()
  }

  // Without partition info Spark must shuffle before the aggregation.
  test("groupBy without partition information") {
    assertPlan(
      df.load().groupBy($"id").count(),
      { case e: Exchange => e },
      expected = true
    )
  }

  // With partition info the shuffle is elided.
  test("groupBy with partition information") {
    assertPlan(
      dfpartitioned.load().groupBy($"id").count(),
      { case e: Exchange => e },
      expected = false
    )
  }

  // Without partition info the window needs both a shuffle and a sort.
  test("window function without partition information") {
    val df = this.df.load().select($"id", $"time", sum($"value").over(window))
    assertPlan(df, { case e: Exchange => e }, expected = true)
    assertPlan(df, { case s: SortExec => s }, expected = true)
  }

  // Partition info removes the shuffle, but a sort is still required.
  test("window function with partition information") {
    val df = this.dfpartitioned.load().select($"id", $"time", sum($"value").over(window))
    assertPlan(df, { case e: Exchange => e }, expected = false)
    assertPlan(df, { case s: SortExec => s }, expected = true)
  }

  // Shuffle and sort disappear only when the source can report ordering too.
  test("window function with partition and order information") {
    assertPlan(
      dfpartitionedAndSorted.load().select($"id", $"time", sum($"value").over(window)),
      { case e: Exchange => e; case s: SortExec => s },
      expected = !DefaultSource.supportsReportingOrder
    )
  }

  /**
   * Asserts whether the executed plan of `df` contains a node matched by `func`,
   * after checking the RDD has the source's expected partition count.
   */
  def assertPlan[T](df: DataFrame, func: PartialFunction[SparkPlan, T], expected: Boolean): Unit = {
    df.explain()
    assert(df.rdd.getNumPartitions === DefaultSource.partitions)
    assert(collectFirst(df.queryExecution.executedPlan)(func).isDefined === expected)
  }
}