Add example source reporting partition and order
EnricoMi committed Mar 8, 2023
1 parent f525680 commit fa35206
Showing 8 changed files with 210 additions and 0 deletions.
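For orientation, a minimal usage sketch (mirroring the SourceSuite added below; the package-name format string and the "partitioned" option come from that suite, and an active SparkSession named spark is assumed):

import spark.implicits._

// Read the example source and ask it to report a partitioning on "id".
val df = spark.read
  .format("uk.co.gresearch.spark.source")
  .option("partitioned", "true")
  .load()

// With the partitioning reported, this aggregation can be planned without an Exchange.
df.groupBy($"id").count().explain()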
@@ -0,0 +1,22 @@
package uk.co.gresearch.spark.source

import org.apache.spark.sql.connector.read.{Batch, SupportsReportPartitioning, partitioning}
import org.apache.spark.sql.connector.read.partitioning.{ClusteredDistribution, Distribution}

trait Reporting extends SupportsReportPartitioning {
  this: Batch =>

  def partitioned: Boolean
  def ordered: Boolean

  def outputPartitioning: partitioning.Partitioning =
    Partitioning(this.planInputPartitions().length, partitioned)
}

case class Partitioning(partitions: Int, partitioned: Boolean) extends partitioning.Partitioning {
  override def numPartitions(): Int = partitions
  override def satisfy(distribution: Distribution): Boolean = distribution match {
    case c: ClusteredDistribution => partitioned && c.clusteredColumns.contains("id")
    case _ => false
  }
}
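A small, hypothetical illustration (not part of the commit) of how this Distribution-based API is consulted: the planner asks the reported Partitioning whether it satisfies a required clustering, here on "id", using the classes imported above:

// Hypothetical check: a partitioned source satisfies a clustering on "id",
// so no Exchange is needed for operators grouped on that key.
val reported = Partitioning(partitions = 2, partitioned = true)
val required = new ClusteredDistribution(Array("id"))
assert(reported.satisfy(required))

// An unpartitioned source does not, so Spark would insert a shuffle.
assert(!Partitioning(partitions = 2, partitioned = false).satisfy(required))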
@@ -0,0 +1,32 @@
package uk.co.gresearch.spark.source

import org.apache.spark.sql.connector.expressions.{Expression, NamedReference, Transform}
import org.apache.spark.sql.connector.read.{Batch, SupportsReportPartitioning, partitioning}
import org.apache.spark.sql.connector.read.partitioning.{KeyGroupedPartitioning, UnknownPartitioning}
import uk.co.gresearch.spark.source.Reporting.namedReference

trait Reporting extends SupportsReportPartitioning {
  this: Batch =>

  def partitioned: Boolean
  def ordered: Boolean

  val partitionKeys: Array[Expression] = Array(namedReference("id"))

  def outputPartitioning: partitioning.Partitioning = if (partitioned) {
    new KeyGroupedPartitioning(partitionKeys, this.planInputPartitions().length)
  } else {
    new UnknownPartitioning(this.planInputPartitions().length)
  }
}

object Reporting {
  def namedReference(columnName: String): Expression =
    new Transform {
      override def name(): String = "identity"
      override def references(): Array[NamedReference] = Array.empty
      override def arguments(): Array[Expression] = Array(new NamedReference {
        override def fieldNames(): Array[String] = Array(columnName)
      })
    }
}
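For comparison, a hedged sketch: Spark versions that ship KeyGroupedPartitioning also expose a public Expressions factory, which can build a comparable identity transform over "id" (treating it as equivalent to the hand-rolled Transform above is an assumption, since that one deliberately returns no references):

import org.apache.spark.sql.connector.expressions.Expressions

// Assumed near-equivalent of Reporting.namedReference("id"): an identity
// transform over the "id" column (the factory version also populates references()).
val partitionKey = Expressions.identity("id")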
86 changes: 86 additions & 0 deletions src/main/scala/uk/co/gresearch/spark/source/DefaultSource.scala
@@ -0,0 +1,86 @@
package uk.co.gresearch.spark.source

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.connector.catalog.{SupportsRead, Table, TableCapability, TableProvider}
import org.apache.spark.sql.connector.expressions.Transform
import org.apache.spark.sql.connector.read
import org.apache.spark.sql.connector.read.{Batch, InputPartition, Scan}
import org.apache.spark.sql.sources.DataSourceRegister
import org.apache.spark.sql.types._
import org.apache.spark.sql.util.CaseInsensitiveStringMap

import java.sql.Timestamp
import java.util
import scala.collection.JavaConverters._

class DefaultSource() extends TableProvider with DataSourceRegister {
  override def shortName(): String = "example"
  override def inferPartitioning(options: CaseInsensitiveStringMap): Array[Transform] = Array.empty
  override def inferSchema(options: CaseInsensitiveStringMap): StructType = DefaultSource.schema
  override def getTable(schema: StructType, partitioning: Array[Transform], properties: util.Map[String, String]): Table =
    BatchTable(
      properties.getOrDefault("partitioned", "false").toBoolean,
      properties.getOrDefault("ordered", "false").toBoolean
    )
}

object DefaultSource {
  val supportsReportingOrder: Boolean = false
  val schema: StructType = StructType(Seq(
    StructField("id", IntegerType),
    StructField("time", TimestampType),
    StructField("value", DoubleType),
  ))
  val ts: Long = DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf("2020-01-01 12:00:00"))
  // Two static partitions: ids 1 and 3 live in partition 1, ids 2 and 4 in partition 2.
  // Within each partition, rows are clustered by id and ordered by time.
  val data: Map[Int, Array[InternalRow]] = Map(
    1 -> Array(
      InternalRow(1, ts + 1000000, 1.1),
      InternalRow(1, ts + 2000000, 1.2),
      InternalRow(1, ts + 3000000, 1.3),
      InternalRow(3, ts + 1000000, 3.1),
      InternalRow(3, ts + 2000000, 3.2)
    ),
    2 -> Array(
      InternalRow(2, ts + 1000000, 2.1),
      InternalRow(2, ts + 2000000, 2.2),
      InternalRow(4, ts + 1000000, 4.1),
      InternalRow(4, ts + 2000000, 4.2),
      InternalRow(4, ts + 3000000, 4.3)
    )
  )
  val partitions: Int = data.size
}

case class BatchTable(partitioned: Boolean, ordered: Boolean) extends Table with SupportsRead {
  override def name(): String = "table"
  override def schema(): StructType = DefaultSource.schema
  override def capabilities(): util.Set[TableCapability] = Set(TableCapability.BATCH_READ).asJava
  override def newScanBuilder(caseInsensitiveStringMap: CaseInsensitiveStringMap): read.ScanBuilder =
    new ScanBuilder(partitioned, ordered)
}

class ScanBuilder(partitioned: Boolean, ordered: Boolean) extends read.ScanBuilder {
  override def build(): Scan = BatchScan(partitioned, ordered)
}

case class BatchScan(partitioned: Boolean, ordered: Boolean) extends read.Scan with read.Batch with Reporting {
  override def readSchema(): StructType = DefaultSource.schema
  override def toBatch: Batch = this
  override def planInputPartitions(): Array[InputPartition] = DefaultSource.data.keys.map(Partition).toArray
  override def createReaderFactory(): read.PartitionReaderFactory = PartitionReaderFactory()
}

case class Partition(id: Int) extends InputPartition

case class PartitionReaderFactory() extends read.PartitionReaderFactory {
  override def createReader(partition: InputPartition): read.PartitionReader[InternalRow] = PartitionReader(partition)
}

case class PartitionReader(partition: InputPartition) extends read.PartitionReader[InternalRow] {
  val rows: Iterator[InternalRow] = DefaultSource.data.getOrElse(partition.asInstanceOf[Partition].id, Array.empty[InternalRow]).iterator
  def next: Boolean = rows.hasNext
  def get: InternalRow = rows.next()
  def close(): Unit = { }
}
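A note on how the provider is addressed, sketched under Spark's usual data source lookup rules: the shortName "example" only resolves if DefaultSource is registered for Java's ServiceLoader, which is why the test suite below falls back to the package name (for a format string that is neither a registered alias nor a full class name, Spark appends ".DefaultSource"):

// Sketch of equivalent ways to address the provider (the "example" alias assumes
// a META-INF/services registration that is not shown in this commit).
spark.read.format("uk.co.gresearch.spark.source").load()   // package name; Spark appends ".DefaultSource"
spark.read.format(classOf[DefaultSource].getName).load()   // fully qualified class name
spark.read.format("example").load()                        // shortName via DataSourceRegister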
@@ -28,6 +28,7 @@ trait SparkTestSession extends SQLHelper {
.master("local[1]")
.appName("spark test example")
.config("spark.sql.shuffle.partitions", 2)
.config("spark.sql.adaptive.coalescePartitions.enabled", value = false)
.config("spark.local.dir", ".")
.getOrCreate()
}
66 changes: 66 additions & 0 deletions src/test/scala/uk/co/gresearch/spark/source/SourceSuite.scala
@@ -0,0 +1,66 @@
package uk.co.gresearch.spark.source

import org.apache.spark.sql.{DataFrame, DataFrameReader}
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
import org.apache.spark.sql.execution.exchange.Exchange
import org.apache.spark.sql.execution.{SortExec, SparkPlan}
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.sum
import org.scalatest.funsuite.AnyFunSuite
import uk.co.gresearch.spark.SparkTestSession

class SourceSuite extends AnyFunSuite with SparkTestSession with AdaptiveSparkPlanHelper {
  import spark.implicits._

  private val source = new DefaultSource().getClass.getPackage.getName
  private def df: DataFrameReader = spark.read.format(source)
  private val dfpartitioned = df.option("partitioned", value = true)
  private val dfpartitionedAndSorted = df.option("partitioned", value = true).option("ordered", value = true)
  private val window = Window.partitionBy($"id").orderBy($"time")

  test("show") {
    df.load().show()
  }

  test("groupBy without partition information") {
    assertPlan(
      df.load().groupBy($"id").count(),
      { case e: Exchange => e },
      expected = true
    )
  }

  test("groupBy with partition information") {
    assertPlan(
      dfpartitioned.load().groupBy($"id").count(),
      { case e: Exchange => e },
      expected = false
    )
  }

  test("window function without partition information") {
    val df = this.df.load().select($"id", $"time", sum($"value").over(window))
    assertPlan(df, { case e: Exchange => e }, expected = true)
    assertPlan(df, { case s: SortExec => s }, expected = true)
  }

  test("window function with partition information") {
    val df = this.dfpartitioned.load().select($"id", $"time", sum($"value").over(window))
    assertPlan(df, { case e: Exchange => e }, expected = false)
    assertPlan(df, { case s: SortExec => s }, expected = true)
  }

test("window function with partition and order information") {
assertPlan(
dfpartitionedAndSorted.load().select($"id", $"time", sum($"value").over(window)),
{ case e: Exchange => e; case s: SortExec => s },
expected = !DefaultSource.supportsReportingOrder
)
}

  def assertPlan[T](df: DataFrame, func: PartialFunction[SparkPlan, T], expected: Boolean): Unit = {
    df.explain()
    assert(df.rdd.getNumPartitions === DefaultSource.partitions)
    assert(collectFirst(df.queryExecution.executedPlan)(func).isDefined === expected)
  }
}
