[SPARK-5996][SQL] Fix specialized outbound conversions
Author: Michael Armbrust <michael@databricks.com>

Closes #4757 from marmbrus/udtConversions and squashes the following commits:

3714aad [Michael Armbrust] [SPARK-5996][SQL] Fix specialized outbound conversions
marmbrus authored and mengxr committed Feb 25, 2015
1 parent dd077ab commit f84c799
Showing 3 changed files with 20 additions and 5 deletions.
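The operators touched here implement the specialized "outbound" paths (executeCollect / executeTake) and, before this change, returned Catalyst-internal rows directly, so a column backed by a user-defined type (UDT) did not come back as its Scala class. A rough sketch of the symptom, reusing MyDenseVector and the DataFrame from the new test below (the exact pre-fix failure mode is an inference, not quoted from the JIRA):

    // Built from a local Seq, this DataFrame is planned as a LocalTableScan.
    // Before this commit its executeCollect()/executeTake() skipped the outbound
    // conversion, so the UDT column was not returned as a MyDenseVector.
    val df = Seq((1, new MyDenseVector(Array(0.1, 1.0)))).toDF("int", "vec")
    val vec = df.take(1)(0).getAs[MyDenseVector](1)  // works after this fix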
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.execution
 
 import org.apache.spark.sql.Row
+import org.apache.spark.sql.catalyst.ScalaReflection
 import org.apache.spark.sql.catalyst.expressions.Attribute
 
 
@@ -30,7 +31,9 @@ case class LocalTableScan(output: Seq[Attribute], rows: Seq[Row]) extends LeafNode
 
   override def execute() = rdd
 
-  override def executeCollect() = rows.toArray
+  override def executeCollect() =
+    rows.map(ScalaReflection.convertRowToScala(_, schema)).toArray
 
-  override def executeTake(limit: Int) = rows.take(limit).toArray
+  override def executeTake(limit: Int) =
+    rows.map(ScalaReflection.convertRowToScala(_, schema)).take(limit).toArray
 }
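Both overrides now route each row through ScalaReflection.convertRowToScala, which maps Catalyst-internal values back to their external Scala representations according to the schema. A minimal sketch of that idea, simplified and not Spark's actual implementation (it handles only the UDT case and assumes the 1.3-era org.apache.spark.sql.types layout):

    import org.apache.spark.sql.Row
    import org.apache.spark.sql.types.{StructType, UserDefinedType}

    // Simplified: deserialize UDT-encoded columns back to their user-facing classes,
    // pass every other value through unchanged.
    def convertRowToScalaSketch(row: Row, schema: StructType): Row =
      Row.fromSeq(row.toSeq.zip(schema.fields).map {
        case (value, field) => field.dataType match {
          case udt: UserDefinedType[_] => udt.deserialize(value)
          case _ => value
        }
      })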
@@ -134,13 +134,15 @@ case class TakeOrdered(limit: Int, sortOrder: Seq[SortOrder], child: SparkPlan)
 
   val ord = new RowOrdering(sortOrder, child.output)
 
+  private def collectData() = child.execute().map(_.copy()).takeOrdered(limit)(ord)
+
   // TODO: Is this copying for no reason?
-  override def executeCollect() = child.execute().map(_.copy()).takeOrdered(limit)(ord)
-    .map(ScalaReflection.convertRowToScala(_, this.schema))
+  override def executeCollect() =
+    collectData().map(ScalaReflection.convertRowToScala(_, this.schema))
 
   // TODO: Terminal split should be implemented differently from non-terminal split.
   // TODO: Pick num splits based on |limit|.
-  override def execute() = sparkContext.makeRDD(executeCollect(), 1)
+  override def execute() = sparkContext.makeRDD(collectData(), 1)
 }
 
 /**
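The shared takeOrdered call moves into a private collectData() helper so that execute() no longer delegates to executeCollect(): now that executeCollect() converts rows to their Scala representation, routing execute() through it would hand converted objects to downstream operators, which expect Catalyst-internal rows. Only the user-facing executeCollect() applies the conversion.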
@@ -22,6 +22,7 @@ import java.io.File
 import scala.beans.{BeanInfo, BeanProperty}
 
 import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.functions._
 import org.apache.spark.sql.test.TestSQLContext
 import org.apache.spark.sql.test.TestSQLContext.{sparkContext, sql}
 import org.apache.spark.sql.test.TestSQLContext.implicits._
@@ -105,4 +106,13 @@ class UserDefinedTypeSuite extends QueryTest {
     tempDir.delete()
     pointsRDD.repartition(1).saveAsParquetFile(tempDir.getCanonicalPath)
   }
+
+  // Tests to make sure that all operators correctly convert types on the way out.
+  test("Local UDTs") {
+    val df = Seq((1, new MyDenseVector(Array(0.1, 1.0)))).toDF("int", "vec")
+    df.collect()(0).getAs[MyDenseVector](1)
+    df.take(1)(0).getAs[MyDenseVector](1)
+    df.limit(1).groupBy('int).agg(first('vec)).collect()(0).getAs[MyDenseVector](0)
+    df.orderBy('int).limit(1).groupBy('int).agg(first('vec)).collect()(0).getAs[MyDenseVector](0)
+  }
 }
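Each line of the new test is meant to exercise a different outbound path: collect() and take(1) hit the LocalTableScan overrides above, the limit(1).groupBy(...) query forces execution through additional operators before collecting, and the orderBy(...).limit(...) variant should plan as TakeOrdered, covering its converted executeCollect().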
