[SPARK-5602][SQL] Better support for creating DataFrame from local data collection

1. Added methods to create DataFrames from Seq[Product]
2. Added executeTake to avoid running a Spark job on LocalRelations.
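
For reference, a minimal usage sketch of the new API described above; the SparkContext (sc) and the Person case class are illustrative, not part of the commit:

    case class Person(name: String, age: Int)

    val sqlContext = new org.apache.spark.sql.SQLContext(sc)

    // Build a DataFrame directly from a local Seq of case class instances.
    // The data is kept as a LocalRelation, so no RDD is materialized up front.
    val df = sqlContext.createDataFrame(Seq(Person("Alice", 30), Person("Bob", 25)))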

Author: Reynold Xin <rxin@databricks.com>

Closes #4372 from rxin/localDataFrame and squashes the following commits:

f696858 [Reynold Xin] style checker.
839ef7f [Reynold Xin] [SPARK-5602][SQL] Better support for creating DataFrame from local data collection.
rxin committed Feb 5, 2015
1 parent 206f9bc commit 84acd08
Showing 10 changed files with 170 additions and 88 deletions.
@@ -211,7 +211,7 @@ trait ScalaReflection {
*/
def asRelation: LocalRelation = {
val output = attributesFor[A]
LocalRelation(output, data)
LocalRelation.fromProduct(output, data)
}
}
}
@@ -17,31 +17,34 @@

package org.apache.spark.sql.catalyst.plans.logical

import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.analysis
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.types.{StructType, StructField}
import org.apache.spark.sql.types.{DataTypeConversions, StructType, StructField}

object LocalRelation {
def apply(output: Attribute*): LocalRelation = new LocalRelation(output)

def apply(output1: StructField, output: StructField*): LocalRelation = new LocalRelation(
StructType(output1 +: output).toAttributes
)
def apply(output1: StructField, output: StructField*): LocalRelation = {
new LocalRelation(StructType(output1 +: output).toAttributes)
}

def fromProduct(output: Seq[Attribute], data: Seq[Product]): LocalRelation = {
val schema = StructType.fromAttributes(output)
LocalRelation(output, data.map(row => DataTypeConversions.productToRow(row, schema)))
}
}

case class LocalRelation(output: Seq[Attribute], data: Seq[Product] = Nil)
case class LocalRelation(output: Seq[Attribute], data: Seq[Row] = Nil)
extends LeafNode with analysis.MultiInstanceRelation {

// TODO: Validate schema compliance.
def loadData(newData: Seq[Product]) = new LocalRelation(output, data ++ newData)

/**
* Returns an identical copy of this relation with new exprIds for all attributes. Different
* attributes are required when a relation is going to be included multiple times in the same
* query.
*/
override final def newInstance: this.type = {
LocalRelation(output.map(_.newInstance), data).asInstanceOf[this.type]
override final def newInstance(): this.type = {
LocalRelation(output.map(_.newInstance()), data).asInstanceOf[this.type]
}

override protected def stringArgs = Iterator(output)
@@ -19,11 +19,27 @@ package org.apache.spark.sql.types

import java.text.SimpleDateFormat

import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.catalyst.expressions.GenericMutableRow


protected[sql] object DataTypeConversions {

def productToRow(product: Product, schema: StructType): Row = {
val mutableRow = new GenericMutableRow(product.productArity)
val schemaFields = schema.fields.toArray

var i = 0
while (i < mutableRow.length) {
mutableRow(i) =
ScalaReflection.convertToCatalyst(product.productElement(i), schemaFields(i).dataType)
i += 1
}

mutableRow
}

def stringToTime(s: String): java.util.Date = {
if (!s.contains('T')) {
// JDBC escape string
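A rough sketch of how the two additions above fit together at the Catalyst level; the Record case class is illustrative and this assumes the ScalaReflection object exposes attributesFor as in the trait shown earlier:

    import org.apache.spark.sql.catalyst.ScalaReflection
    import org.apache.spark.sql.catalyst.plans.logical.LocalRelation

    case class Record(key: Int, value: String)

    // attributesFor derives the output attributes from the case class;
    // fromProduct then runs DataTypeConversions.productToRow on every element,
    // so the resulting LocalRelation stores Seq[Row] rather than Seq[Product].
    val attributes = ScalaReflection.attributesFor[Record]
    val relation = LocalRelation.fromProduct(
      attributes, Seq(Record(1, "a"), Record(2, "b")))
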
41 changes: 38 additions & 3 deletions sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -33,7 +33,7 @@ import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.optimizer.{DefaultOptimizer, Optimizer}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor
import org.apache.spark.sql.execution._
import org.apache.spark.sql.json._
@@ -163,17 +163,52 @@ class SQLContext(@transient val sparkContext: SparkContext)
/** Removes the specified table from the in-memory cache. */
def uncacheTable(tableName: String): Unit = cacheManager.uncacheTable(tableName)

// scalastyle:off
// Disable style checker so "implicits" object can start with lowercase i
/**
* Implicit methods available in Scala for converting common Scala objects into [[DataFrame]]s.
*/
object implicits {
// scalastyle:on
/**
* Creates a DataFrame from an RDD of case classes.
*
* @group userf
*/
implicit def createDataFrame[A <: Product : TypeTag](rdd: RDD[A]): DataFrame = {
self.createDataFrame(rdd)
}

/**
* Creates a DataFrame from a local Seq of Product.
*/
implicit def createDataFrame[A <: Product : TypeTag](data: Seq[A]): DataFrame = {
self.createDataFrame(data)
}
}

/**
* Creates a DataFrame from an RDD of case classes.
*
* @group userf
*/
implicit def createDataFrame[A <: Product: TypeTag](rdd: RDD[A]): DataFrame = {
// TODO: Remove implicit here.
implicit def createDataFrame[A <: Product : TypeTag](rdd: RDD[A]): DataFrame = {
SparkPlan.currentContext.set(self)
val schema = ScalaReflection.schemaFor[A].dataType.asInstanceOf[StructType]
val attributeSeq = schema.toAttributes
val rowRDD = RDDConversions.productToRowRdd(rdd, schema)
DataFrame(this, LogicalRDD(attributeSeq, rowRDD)(self))
DataFrame(self, LogicalRDD(attributeSeq, rowRDD)(self))
}

/**
* Creates a DataFrame from a local Seq of Product.
*/
def createDataFrame[A <: Product : TypeTag](data: Seq[A]): DataFrame = {
SparkPlan.currentContext.set(self)
val schema = ScalaReflection.schemaFor[A].dataType.asInstanceOf[StructType]
val attributeSeq = schema.toAttributes
DataFrame(self, LocalRelation.fromProduct(attributeSeq, data))
}

/**
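A short sketch of the implicit conversions exposed by the implicits object above; sc and Person are assumed for illustration:

    import org.apache.spark.sql.DataFrame

    val sqlContext = new org.apache.spark.sql.SQLContext(sc)
    import sqlContext.implicits._

    case class Person(name: String, age: Int)

    // An RDD of case classes converts implicitly to a DataFrame...
    val fromRdd: DataFrame = sc.parallelize(Seq(Person("Alice", 30)))

    // ...and so does a plain local Seq of case classes, with no RDD involved.
    val fromSeq: DataFrame = Seq(Person("Bob", 25), Person("Carol", 41))
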
@@ -54,12 +54,13 @@ object RDDConversions {
}
}

/** Logical plan node for scanning data from an RDD. */
case class LogicalRDD(output: Seq[Attribute], rdd: RDD[Row])(sqlContext: SQLContext)
extends LogicalPlan with MultiInstanceRelation {

def children = Nil
override def children = Nil

def newInstance() =
override def newInstance() =
LogicalRDD(output.map(_.newInstance()), rdd)(sqlContext).asInstanceOf[this.type]

override def sameResult(plan: LogicalPlan) = plan match {
@@ -74,39 +74,28 @@ case class LogicalRDD(output: Seq[Attribute], rdd: RDD[Row])(sqlContext: SQLCont
)
}

/** Physical plan node for scanning data from an RDD. */
case class PhysicalRDD(output: Seq[Attribute], rdd: RDD[Row]) extends LeafNode {
override def execute() = rdd
}

@deprecated("Use LogicalRDD", "1.2.0")
case class ExistingRdd(output: Seq[Attribute], rdd: RDD[Row]) extends LeafNode {
override def execute() = rdd
}

@deprecated("Use LogicalRDD", "1.2.0")
case class SparkLogicalPlan(alreadyPlanned: SparkPlan)(@transient sqlContext: SQLContext)
extends LogicalPlan with MultiInstanceRelation {
/** Logical plan node for scanning data from a local collection. */
case class LogicalLocalTable(output: Seq[Attribute], rows: Seq[Row])(sqlContext: SQLContext)
extends LogicalPlan with MultiInstanceRelation {

def output = alreadyPlanned.output
override def children = Nil

override final def newInstance(): this.type = {
SparkLogicalPlan(
alreadyPlanned match {
case ExistingRdd(output, rdd) => ExistingRdd(output.map(_.newInstance()), rdd)
case _ => sys.error("Multiple instance of the same relation detected.")
})(sqlContext).asInstanceOf[this.type]
}
override def newInstance() =
LogicalLocalTable(output.map(_.newInstance()), rows)(sqlContext).asInstanceOf[this.type]

override def sameResult(plan: LogicalPlan) = plan match {
case SparkLogicalPlan(ExistingRdd(_, rdd)) =>
rdd.id == alreadyPlanned.asInstanceOf[ExistingRdd].rdd.id
case LogicalRDD(_, otherRDD) => rows == rows
case _ => false
}

@transient override lazy val statistics = Statistics(
// TODO: Instead of returning a default value here, find a way to return a meaningful size
// estimate for RDDs. See PR 1238 for more discussions.
sizeInBytes = BigInt(sqlContext.conf.defaultSizeInBytes)
// TODO: Improve the statistics estimation.
// This is made small enough so it can be broadcasted.
sizeInBytes = sqlContext.conf.autoBroadcastJoinThreshold - 1
)
}
@@ -0,0 +1,36 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.execution

import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.expressions.Attribute


/**
* Physical plan node for scanning data from a local collection.
*/
case class LocalTableScan(output: Seq[Attribute], rows: Seq[Row]) extends LeafNode {

private lazy val rdd = sqlContext.sparkContext.parallelize(rows)

override def execute() = rdd

override def executeCollect() = rows.toArray

override def executeTake(limit: Int) = rows.take(limit).toArray
}
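
One practical effect of the executeCollect/executeTake overrides above, sketched with assumed names (sqlContext, Person): taking a few rows from a DataFrame backed by a local collection can be answered on the driver without launching a Spark job.

    case class Person(name: String, age: Int)

    val df = sqlContext.createDataFrame(Seq(Person("Alice", 30), Person("Bob", 25)))

    // Limit.executeCollect now delegates to executeTake, and LocalTableScan
    // answers it from the in-memory rows, so no job should be scheduled here.
    val firstRow = df.take(1)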
@@ -27,6 +27,8 @@ import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.catalyst.plans.physical._

import scala.collection.mutable.ArrayBuffer

object SparkPlan {
protected[sql] val currentContext = new ThreadLocal[SQLContext]()
}
@@ -77,8 +79,53 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
/**
* Runs this query returning the result as an array.
*/
def executeCollect(): Array[Row] =
def executeCollect(): Array[Row] = {
execute().map(ScalaReflection.convertRowToScala(_, schema)).collect()
}

/**
* Runs this query returning the first `n` rows as an array.
*
* This is modeled after RDD.take but never runs any job locally on the driver.
*/
def executeTake(n: Int): Array[Row] = {
if (n == 0) {
return new Array[Row](0)
}

val childRDD = execute().map(_.copy())

val buf = new ArrayBuffer[Row]
val totalParts = childRDD.partitions.length
var partsScanned = 0
while (buf.size < n && partsScanned < totalParts) {
// The number of partitions to try in this iteration. It is ok for this number to be
// greater than totalParts because we actually cap it at totalParts in runJob.
var numPartsToTry = 1
if (partsScanned > 0) {
// If we didn't find any rows after the first iteration, just try all partitions next.
// Otherwise, interpolate the number of partitions we need to try, but overestimate it
// by 50%.
if (buf.size == 0) {
numPartsToTry = totalParts - 1
} else {
numPartsToTry = (1.5 * n * partsScanned / buf.size).toInt
}
}
numPartsToTry = math.max(0, numPartsToTry) // guard against negative num of partitions

val left = n - buf.size
val p = partsScanned until math.min(partsScanned + numPartsToTry, totalParts)
val sc = sqlContext.sparkContext
val res =
sc.runJob(childRDD, (it: Iterator[Row]) => it.take(left).toArray, p, allowLocal = false)

res.foreach(buf ++= _.take(n - buf.size))
partsScanned += numPartsToTry
}

buf.toArray.map(ScalaReflection.convertRowToScala(_, this.schema))
}

protected def newProjection(
expressions: Seq[Expression], inputSchema: Seq[Attribute]): Projection = {
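To make the partition-growth heuristic in executeTake above concrete, a small worked example with made-up numbers:

    // Suppose n = 100 rows were requested and the first pass over one
    // partition yielded only 10 rows. The next pass tries
    //   (1.5 * n * partsScanned / buf.size).toInt
    // partitions, i.e. a 50% overestimate of the linear extrapolation.
    val n = 100
    val partsScanned = 1
    val bufSize = 10
    val numPartsToTry = (1.5 * n * partsScanned / bufSize).toInt  // = 15
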
@@ -21,7 +21,7 @@ import org.apache.spark.sql.{SQLContext, Strategy, execution}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.planning._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.plans.logical.{Command, LogicalPlan}
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.columnar.{InMemoryColumnarTableScan, InMemoryRelation}
import org.apache.spark.sql.parquet._
@@ -284,13 +284,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
execution.Aggregate(partial = false, group, agg, planLater(child)) :: Nil
case logical.Sample(fraction, withReplacement, seed, child) =>
execution.Sample(fraction, withReplacement, seed, planLater(child)) :: Nil
case SparkLogicalPlan(alreadyPlanned) => alreadyPlanned :: Nil
case logical.LocalRelation(output, data) =>
val nPartitions = if (data.isEmpty) 1 else numPartitions
PhysicalRDD(
output,
RDDConversions.productToRowRdd(sparkContext.parallelize(data, nPartitions),
StructType.fromAttributes(output))) :: Nil
LocalTableScan(output, data) :: Nil
case logical.Limit(IntegerLiteral(limit), child) =>
execution.Limit(limit, planLater(child)) :: Nil
case Unions(unionChildren) =>
@@ -103,49 +103,7 @@ case class Limit(limit: Int, child: SparkPlan)
override def output = child.output
override def outputPartitioning = SinglePartition

/**
* A custom implementation modeled after the take function on RDDs but which never runs any job
* locally. This is to avoid shipping an entire partition of data in order to retrieve only a few
* rows.
*/
override def executeCollect(): Array[Row] = {
if (limit == 0) {
return new Array[Row](0)
}

val childRDD = child.execute().map(_.copy())

val buf = new ArrayBuffer[Row]
val totalParts = childRDD.partitions.length
var partsScanned = 0
while (buf.size < limit && partsScanned < totalParts) {
// The number of partitions to try in this iteration. It is ok for this number to be
// greater than totalParts because we actually cap it at totalParts in runJob.
var numPartsToTry = 1
if (partsScanned > 0) {
// If we didn't find any rows after the first iteration, just try all partitions next.
// Otherwise, interpolate the number of partitions we need to try, but overestimate it
// by 50%.
if (buf.size == 0) {
numPartsToTry = totalParts - 1
} else {
numPartsToTry = (1.5 * limit * partsScanned / buf.size).toInt
}
}
numPartsToTry = math.max(0, numPartsToTry) // guard against negative num of partitions

val left = limit - buf.size
val p = partsScanned until math.min(partsScanned + numPartsToTry, totalParts)
val sc = sqlContext.sparkContext
val res =
sc.runJob(childRDD, (it: Iterator[Row]) => it.take(left).toArray, p, allowLocal = false)

res.foreach(buf ++= _.take(limit - buf.size))
partsScanned += numPartsToTry
}

buf.toArray.map(ScalaReflection.convertRowToScala(_, this.schema))
}
override def executeCollect(): Array[Row] = child.executeTake(limit)

override def execute() = {
val rdd: RDD[_ <: Product2[Boolean, Row]] = if (sortBasedShuffleOn) {
@@ -58,6 +58,8 @@ case class ExecutedCommand(cmd: RunnableCommand) extends SparkPlan {

override def executeCollect(): Array[Row] = sideEffectResult.toArray

override def executeTake(limit: Int): Array[Row] = sideEffectResult.take(limit).toArray

override def execute(): RDD[Row] = sqlContext.sparkContext.parallelize(sideEffectResult, 1)
}

