
Commit 501c9b9

Merge branch 'master' into SPARK-13898

rxin committed Mar 21, 2016
2 parents: 776e1a1 + 9b4e15b

Showing 20 changed files with 229 additions and 202 deletions.
--------------------------------------------------------------------------------
@@ -268,8 +268,8 @@ public MemoryBlock allocatePage(long size, MemoryConsumer consumer) {
logger.warn("Failed to allocate a page ({} bytes), try again.", acquired);
// there is not enough memory actually; it means the actual free memory is smaller than what
// the MemoryManager thought, so we should keep the acquired memory.
acquiredButNotUsed += acquired;
synchronized (this) {
acquiredButNotUsed += acquired;
allocatedPages.clear(pageNumber);
}
// this could trigger spilling to free some pages.
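
Note on the hunk above: `acquiredButNotUsed` and `allocatedPages` are shared, mutable bookkeeping, so the increment has to happen inside the same `synchronized` block that clears the page bit; updating it outside the lock can race with concurrent allocate/free calls. A minimal sketch of the pattern, using stand-in classes rather than Spark's real ones:

    import java.util.BitSet

    class PageBookkeeping {
      private val allocatedPages = new BitSet(64) // page numbers currently reserved
      private var acquiredButNotUsed = 0L         // memory acquired but never backed by a page

      // Called when allocating an already-reserved page fails: keep the memory the
      // manager believes we hold, and release the page number, atomically.
      def recordFailedAllocation(pageNumber: Int, acquired: Long): Unit = synchronized {
        acquiredButNotUsed += acquired
        allocatedPages.clear(pageNumber)
      }
    }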
--------------------------------------------------------------------------------
@@ -288,7 +288,7 @@ class ReplSuite extends SparkFunSuite {
|import org.apache.spark.sql.Encoder
|import org.apache.spark.sql.expressions.Aggregator
|import org.apache.spark.sql.TypedColumn
|val simpleSum = new Aggregator[Int, Int, Int] with Serializable {
|val simpleSum = new Aggregator[Int, Int, Int] {
| def zero: Int = 0 // The initial value.
| def reduce(b: Int, a: Int) = b + a // Add an element to the running total
| def merge(b1: Int, b2: Int) = b1 + b2 // Merge intermediate values.
@@ -347,7 +347,7 @@ class ReplSuite extends SparkFunSuite {
|import org.apache.spark.sql.expressions.Aggregator
|import org.apache.spark.sql.TypedColumn
|/** An `Aggregator` that adds up any numeric type returned by the given function. */
|class SumOf[I, N : Numeric](f: I => N) extends Aggregator[I, N, N] with Serializable {
|class SumOf[I, N : Numeric](f: I => N) extends Aggregator[I, N, N] {
| val numeric = implicitly[Numeric[N]]
| override def zero: N = numeric.zero
| override def reduce(b: N, a: I): N = numeric.plus(b, f(a))
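
Note on the two `with Serializable` removals in this file: they presumably became redundant because the `Aggregator` base class itself extends `Serializable`, so anonymous and named subclasses inherit it. A simplified, REPL-style illustration (not Spark's real class definition):

    abstract class Aggregator[-IN, BUF, OUT] extends Serializable {
      def zero: BUF
      def reduce(b: BUF, a: IN): BUF
      def merge(b1: BUF, b2: BUF): BUF
      def finish(reduction: BUF): OUT
    }

    // No "with Serializable" needed; the subclass is already serializable.
    val simpleSum = new Aggregator[Int, Int, Int] {
      def zero = 0
      def reduce(b: Int, a: Int) = b + a
      def merge(b1: Int, b2: Int) = b1 + b2
      def finish(b: Int) = b
    }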
--------------------------------------------------------------------------------
@@ -249,10 +249,32 @@ class ReplSuite extends SparkFunSuite {
// We need to use local-cluster to test this case.
val output = runInterpreter("local-cluster[1,1,1024]",
"""
|val sqlContext = new org.apache.spark.sql.SQLContext(sc)
|import sqlContext.implicits._
|case class TestCaseClass(value: Int)
|sc.parallelize(1 to 10).map(x => TestCaseClass(x)).toDF().collect()
|
|// Test Dataset Serialization in the REPL
|Seq(TestCaseClass(1)).toDS().collect()
""".stripMargin)
assertDoesNotContain("error:", output)
assertDoesNotContain("Exception", output)
}
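
The new `Seq(TestCaseClass(1)).toDS().collect()` line exercises exactly the problem these changes target: `TestCaseClass` compiles to an inner class of the REPL's line wrapper, so the Dataset encoder needs the wrapper instance in order to construct rows during deserialization. Outside the REPL the outer instance is registered by hand; a hedged REPL-style sketch of that path (`addOuterScope` is the real API shown later in this diff, the surrounding names are illustrative):

    class Wrapper {                    // stands in for the REPL's $iw wrapper
      case class Point(x: Int, y: Int) // inner case class, needs an outer instance
    }
    val w = new Wrapper
    org.apache.spark.sql.catalyst.encoders.OuterScopes.addOuterScope(w)
    // Encoders can now instantiate w.Point when deserializing rows.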

test("Datasets and encoders") {
val output = runInterpreter("local",
"""
|import org.apache.spark.sql.functions._
|import org.apache.spark.sql.Encoder
|import org.apache.spark.sql.expressions.Aggregator
|import org.apache.spark.sql.TypedColumn
|val simpleSum = new Aggregator[Int, Int, Int] {
| def zero: Int = 0 // The initial value.
| def reduce(b: Int, a: Int) = b + a // Add an element to the running total
| def merge(b1: Int, b2: Int) = b1 + b2 // Merge intermediate values.
| def finish(b: Int) = b // Return the final result.
|}.toColumn
|
|val ds = Seq(1, 2, 3, 4).toDS()
|ds.select(simpleSum).collect
""".stripMargin)
assertDoesNotContain("error:", output)
assertDoesNotContain("Exception", output)
@@ -295,6 +317,31 @@ }
}
}

test("Datasets agg type-inference") {
val output = runInterpreter("local",
"""
|import org.apache.spark.sql.functions._
|import org.apache.spark.sql.Encoder
|import org.apache.spark.sql.expressions.Aggregator
|import org.apache.spark.sql.TypedColumn
|/** An `Aggregator` that adds up any numeric type returned by the given function. */
|class SumOf[I, N : Numeric](f: I => N) extends
| org.apache.spark.sql.expressions.Aggregator[I, N, N] {
| val numeric = implicitly[Numeric[N]]
| override def zero: N = numeric.zero
| override def reduce(b: N, a: I): N = numeric.plus(b, f(a))
| override def merge(b1: N, b2: N): N = numeric.plus(b1, b2)
| override def finish(reduction: N): N = reduction
|}
|
|def sum[I, N : Numeric : Encoder](f: I => N): TypedColumn[I, N] = new SumOf(f).toColumn
|val ds = Seq((1, 1, 2L), (1, 2, 3L), (1, 3, 4L), (2, 1, 5L)).toDS()
|ds.groupByKey(_._1).agg(sum(_._2), sum(_._3)).collect()
""".stripMargin)
assertDoesNotContain("error:", output)
assertDoesNotContain("Exception", output)
}
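
The `sum` helper above is where the type inference happens: `I` is pinned down by the Dataset's element type and `N` by the function result, and the two context bounds are just sugar for implicit parameters. Given the imports and the `SumOf` definition from the test, the signature desugars to:

    def sum[I, N](f: I => N)(implicit num: Numeric[N], enc: Encoder[N]): TypedColumn[I, N] =
      new SumOf(f).toColumn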

test("collecting objects of class defined in repl") {
val output = runInterpreter("local[2]",
"""
@@ -317,4 +364,21 @@ class ReplSuite extends SparkFunSuite {
assertDoesNotContain("Exception", output)
assertContains("ret: Array[(Int, Iterable[Foo])] = Array((1,", output)
}

test("line wrapper only initialized once when used as encoder outer scope") {
val output = runInterpreter("local",
"""
|val fileName = "repl-test-" + System.currentTimeMillis
|val tmpDir = System.getProperty("java.io.tmpdir")
|val file = new java.io.File(tmpDir, fileName)
|def createFile(): Unit = file.createNewFile()
|
|createFile();case class TestCaseClass(value: Int)
|sc.parallelize(1 to 10).map(x => TestCaseClass(x)).collect()
|
|file.delete()
""".stripMargin)
assertDoesNotContain("error:", output)
assertDoesNotContain("Exception", output)
}
}
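
The `createFile();case class TestCaseClass(value: Int)` line in the last test is deliberate: both statements land in the same `$iw` wrapper, so the file-creating side effect runs in that wrapper's constructor. If resolving the encoder's outer scope constructed a fresh wrapper instead of reusing the REPL singleton, the constructor (and the side effect) would run a second time. A simplified model of the wrapping, with illustrative names mirroring the real `$read`/`$iw` classes:

    object Repl {
      def createFile(): Unit = println("side effect") // stub for the test's createFile()

      class Read {
        class Iw {
          createFile() // runs when the wrapper is constructed; must happen once
          case class TestCaseClass(value: Int)
        }
        val iw = new Iw
      }

      val INSTANCE = new Read // the singleton the fixed lookup reuses
    }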
--------------------------------------------------------------------------------
@@ -446,7 +446,7 @@ class Analyzer(
} transformUp {
case other => other transformExpressions {
case a: Attribute =>
attributeRewrites.get(a).getOrElse(a).withQualifiers(a.qualifiers)
attributeRewrites.get(a).getOrElse(a).withQualifier(a.qualifier)
}
}
newRight
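
The hunk above tracks an API change from `qualifiers: Seq[String]` to `qualifier: Option[String]` (an attribute now carries at most one qualifier). The rewrite itself swaps an attribute for its freshly generated twin while keeping the original qualifier; in miniature, with stand-ins for the Catalyst classes:

    case class Attr(name: String, exprId: Long, qualifier: Option[String] = None) {
      def withQualifier(q: Option[String]): Attr = copy(qualifier = q)
    }

    val a  = Attr("id", exprId = 1L, qualifier = Some("t1"))
    val a2 = Attr("id", exprId = 2L)      // fresh copy generated for the right side
    val attributeRewrites = Map(a -> a2)

    // New exprId from the rewrite, original qualifier preserved:
    val rewritten = attributeRewrites.get(a).getOrElse(a).withQualifier(a.qualifier)
    // rewritten == Attr("id", 2L, Some("t1"))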
@@ -571,7 +571,7 @@
if n.outerPointer.isEmpty &&
n.cls.isMemberClass &&
!Modifier.isStatic(n.cls.getModifiers) =>
val outer = OuterScopes.outerScopes.get(n.cls.getDeclaringClass.getName)
val outer = OuterScopes.getOuterScope(n.cls)
if (outer == null) {
throw new AnalysisException(
s"Unable to generate an encoder for inner class `${n.cls.getName}` without " +
@@ -1467,8 +1467,7 @@ object CleanupAliases extends Rule[LogicalPlan] {

def trimNonTopLevelAliases(e: Expression): Expression = e match {
case a: Alias =>
Alias(trimAliases(a.child), a.name)(
a.exprId, a.qualifiers, a.explicitMetadata, a.isGenerated)
a.withNewChildren(trimAliases(a.child) :: Nil)
case other => trimAliases(other)
}
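
`withNewChildren` rebuilds the node around the new child and carries every other field (`name`, `exprId`, metadata) over automatically, which is less brittle than re-invoking the `Alias` constructor with each field spelled out, as the old code did. A tiny tree sketch of the contract:

    sealed trait Expr {
      def children: Seq[Expr]
      def withNewChildren(newChildren: Seq[Expr]): Expr
    }

    case class Leaf(value: Int) extends Expr {
      def children: Seq[Expr] = Nil
      def withNewChildren(newChildren: Seq[Expr]): Expr = this
    }

    case class Alias(child: Expr, name: String, exprId: Long) extends Expr {
      def children: Seq[Expr] = child :: Nil
      def withNewChildren(newChildren: Seq[Expr]): Expr =
        copy(child = newChildren.head) // name and exprId carried over untouched
    }

    // Alias(Leaf(1), "x", 42).withNewChildren(Leaf(2) :: Nil) == Alias(Leaf(2), "x", 42)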

--------------------------------------------------------------------------------
@@ -101,13 +101,13 @@ class SimpleCatalog(val conf: CatalystConf) extends Catalog {
if (table == null) {
throw new AnalysisException("Table not found: " + tableName)
}
val tableWithQualifiers = SubqueryAlias(tableName, table)
val qualifiedTable = SubqueryAlias(tableName, table)

// If an alias was specified by the lookup, wrap the plan in a subquery so that attributes are
// properly qualified with this alias.
alias
.map(a => SubqueryAlias(a, tableWithQualifiers))
.getOrElse(tableWithQualifiers)
.map(a => SubqueryAlias(a, qualifiedTable))
.getOrElse(qualifiedTable)
}

override def getTables(databaseName: Option[String]): Seq[(String, Boolean)] = {
@@ -149,11 +149,11 @@ trait OverrideCatalog extends Catalog {
getOverriddenTable(tableIdent) match {
case Some(table) =>
val tableName = getTableName(tableIdent)
val tableWithQualifiers = SubqueryAlias(tableName, table)
val qualifiedTable = SubqueryAlias(tableName, table)

// If an alias was specified by the lookup, wrap the plan in a sub-query so that attributes
// are properly qualified with this alias.
alias.map(a => SubqueryAlias(a, tableWithQualifiers)).getOrElse(tableWithQualifiers)
alias.map(a => SubqueryAlias(a, qualifiedTable)).getOrElse(qualifiedTable)

case None => super.lookupRelation(tableIdent, alias)
}
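
Both catalog implementations share the same shape (only the local name changed in this diff): wrap the plan once under the table's own name, then once more if the query supplied an alias. In miniature, with stand-ins for the Catalyst plan classes:

    sealed trait Plan
    case class Table(name: String) extends Plan
    case class SubqueryAlias(alias: String, child: Plan) extends Plan

    def lookupRelation(table: Plan, tableName: String, alias: Option[String]): Plan = {
      val qualifiedTable = SubqueryAlias(tableName, table)
      // If an alias was specified, wrap again so attributes pick up that qualifier.
      alias.map(a => SubqueryAlias(a, qualifiedTable)).getOrElse(qualifiedTable)
    }

    // lookupRelation(Table("t"), "t", Some("x"))
    //   == SubqueryAlias("x", SubqueryAlias("t", Table("t")))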
--------------------------------------------------------------------------------
@@ -59,12 +59,12 @@ case class UnresolvedAttribute(nameParts: Seq[String]) extends Attribute with Un
override def exprId: ExprId = throw new UnresolvedException(this, "exprId")
override def dataType: DataType = throw new UnresolvedException(this, "dataType")
override def nullable: Boolean = throw new UnresolvedException(this, "nullable")
override def qualifiers: Seq[String] = throw new UnresolvedException(this, "qualifiers")
override def qualifier: Option[String] = throw new UnresolvedException(this, "qualifier")
override lazy val resolved = false

override def newInstance(): UnresolvedAttribute = this
override def withNullability(newNullability: Boolean): UnresolvedAttribute = this
override def withQualifiers(newQualifiers: Seq[String]): UnresolvedAttribute = this
override def withQualifier(newQualifier: Option[String]): UnresolvedAttribute = this
override def withName(newName: String): UnresolvedAttribute = UnresolvedAttribute.quoted(newName)

override def toString: String = s"'$name"
@@ -158,7 +158,7 @@ abstract class Star extends LeafExpression with NamedExpression {
override def exprId: ExprId = throw new UnresolvedException(this, "exprId")
override def dataType: DataType = throw new UnresolvedException(this, "dataType")
override def nullable: Boolean = throw new UnresolvedException(this, "nullable")
override def qualifiers: Seq[String] = throw new UnresolvedException(this, "qualifiers")
override def qualifier: Option[String] = throw new UnresolvedException(this, "qualifier")
override def toAttribute: Attribute = throw new UnresolvedException(this, "toAttribute")
override def newInstance(): NamedExpression = throw new UnresolvedException(this, "newInstance")
override lazy val resolved = false
@@ -188,7 +188,7 @@ case class UnresolvedStar(target: Option[Seq[String]]) extends Star with Unevalu
case None => input.output
// If there is a table, pick out attributes that are part of this table.
case Some(t) => if (t.size == 1) {
input.output.filter(_.qualifiers.exists(resolver(_, t.head)))
input.output.filter(_.qualifier.exists(resolver(_, t.head)))
} else {
List()
}
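
With at most one qualifier per attribute, expanding `t.*` reduces to an `Option.exists` check against the target name. In miniature (`resolver` stands in for the analyzer's configurable name-equality function):

    case class Attr(name: String, qualifier: Option[String])

    val resolver: (String, String) => Boolean = _ equalsIgnoreCase _
    val output = Seq(Attr("a", Some("t")), Attr("b", Some("u")), Attr("c", None))

    // Expanding "t.*": keep attributes whose single qualifier matches "t".
    val expanded = output.filter(_.qualifier.exists(resolver(_, "t")))
    // expanded == Seq(Attr("a", Some("t")))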
@@ -243,7 +243,7 @@ case class MultiAlias(child: Expression, names: Seq[String])

override def nullable: Boolean = throw new UnresolvedException(this, "nullable")

override def qualifiers: Seq[String] = throw new UnresolvedException(this, "qualifiers")
override def qualifier: Option[String] = throw new UnresolvedException(this, "qualifier")

override def toAttribute: Attribute = throw new UnresolvedException(this, "toAttribute")

@@ -298,7 +298,7 @@ case class UnresolvedAlias(child: Expression, aliasName: Option[String] = None)
extends UnaryExpression with NamedExpression with Unevaluable {

override def toAttribute: Attribute = throw new UnresolvedException(this, "toAttribute")
override def qualifiers: Seq[String] = throw new UnresolvedException(this, "qualifiers")
override def qualifier: Option[String] = throw new UnresolvedException(this, "qualifier")
override def exprId: ExprId = throw new UnresolvedException(this, "exprId")
override def nullable: Boolean = throw new UnresolvedException(this, "nullable")
override def dataType: DataType = throw new UnresolvedException(this, "dataType")
--------------------------------------------------------------------------------
@@ -206,10 +206,10 @@ class SessionCatalog(externalCatalog: ExternalCatalog) {
} else {
tempTables.get(name.table)
}
val tableWithQualifiers = SubqueryAlias(name.table, relation)
val qualifiedTable = SubqueryAlias(name.table, relation)
// If an alias was specified by the lookup, wrap the plan in a subquery so that
// attributes are properly qualified with this alias.
alias.map(a => SubqueryAlias(a, tableWithQualifiers)).getOrElse(tableWithQualifiers)
alias.map(a => SubqueryAlias(a, qualifiedTable)).getOrElse(qualifiedTable)
}

/**
--------------------------------------------------------------------------------
@@ -21,14 +21,16 @@ import java.util.concurrent.ConcurrentMap

import com.google.common.collect.MapMaker

import org.apache.spark.util.Utils

object OuterScopes {
@transient
lazy val outerScopes: ConcurrentMap[String, AnyRef] =
new MapMaker().weakValues().makeMap()

/**
* Adds a new outer scope to this context that can be used when instantiating an `inner class`
* during deserialialization. Inner classes are created when a case class is defined in the
* during deserialization. Inner classes are created when a case class is defined in the
* Spark REPL, and registering the outer scope that this class was defined in allows us to
* create new instances on the Spark executors. In normal use, users should not need to call this
* function.
@@ -39,4 +41,47 @@ object OuterScopes {
def addOuterScope(outer: AnyRef): Unit = {
outerScopes.putIfAbsent(outer.getClass.getName, outer)
}

def getOuterScope(innerCls: Class[_]): AnyRef = {
assert(innerCls.isMemberClass)
val outerClassName = innerCls.getDeclaringClass.getName
val outer = outerScopes.get(outerClassName)
if (outer == null) {
outerClassName match {
// If the outer class is generated by the REPL, users don't need to register it, as it has
// only one instance and there is a way to retrieve it: get the `$read` object, call the
// `INSTANCE()` method to get the single instance of class `$read`, then call the `$iw()`
// method multiple times to get the single instance of the innermost `$iw` class.
case REPLClass(baseClassName) =>
val objClass = Utils.classForName(baseClassName + "$")
val objInstance = objClass.getField("MODULE$").get(null)
val baseInstance = objClass.getMethod("INSTANCE").invoke(objInstance)
val baseClass = Utils.classForName(baseClassName)

var getter = iwGetter(baseClass)
var obj = baseInstance
while (getter != null) {
obj = getter.invoke(obj)
getter = iwGetter(getter.getReturnType)
}

outerScopes.putIfAbsent(outerClassName, obj)
obj
case _ => null
}
} else {
outer
}
}

private def iwGetter(cls: Class[_]) = {
try {
cls.getMethod("$iw")
} catch {
case _: NoSuchMethodException => null
}
}

// The format of REPL generated wrapper class's name, e.g. `$line12.$read$$iw$$iw`
private[this] val REPLClass = """^(\$line(?:\d+)\.\$read)(?:\$\$iw)+$""".r
}
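
The `REPLClass` pattern captures the `$lineNN.$read` prefix as group 1 and requires one or more `$$iw` suffixes (`$$iw` is how a nested `$iw` class appears in a JVM binary class name). A quick REPL-style check of the extraction:

    val REPLClass = """^(\$line(?:\d+)\.\$read)(?:\$\$iw)+$""".r

    "$line12.$read$$iw$$iw" match {
      case REPLClass(base) => println(base) // prints: $line12.$read
      case _               => println("not a REPL wrapper class")
    }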
(The remaining changed files were not loaded on this page.)
