From 640c3ef4beadb01af21daaf9a507ae302f3833f8 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Mon, 16 Nov 2015 19:42:25 -0800 Subject: [PATCH 1/9] Allow use of REPL classes in Datasets --- project/SparkBuild.scala | 10 +- .../spark/sql/catalyst/ScalaReflection.scala | 124 ++++++++++-------- .../catalyst/encoders/ExpressionEncoder.scala | 17 ++- .../catalyst/encoders/ProductEncoder.scala | 6 +- .../expressions/codegen/CodeGenerator.scala | 9 +- .../expressions/codegen/CodegenFallback.scala | 2 +- .../codegen/GenerateSafeProjection.scala | 4 +- .../codegen/GenerateUnsafeProjection.scala | 4 +- .../sql/catalyst/expressions/literals.scala | 2 + .../sql/catalyst/expressions/objects.scala | 40 ++++-- .../encoders/ExpressionEncoderSuite.scala | 6 +- .../encoders/ProductEncoderSuite.scala | 4 + .../scala/org/apache/spark/sql/Dataset.scala | 2 +- .../org/apache/spark/sql/GroupedDataset.scala | 6 +- .../org/apache/spark/sql/SQLContext.scala | 11 ++ .../aggregate/TypedAggregateExpression.scala | 4 +- .../sql/execution/datasources/rules.scala | 2 +- 17 files changed, 165 insertions(+), 88 deletions(-) diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 67724c4e9e411..9958ff695bca0 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -158,11 +158,11 @@ object SparkBuild extends PomBuild { javacJVMVersion := "1.7", scalacJVMVersion := "1.7", - javacOptions in Compile ++= Seq( - "-encoding", "UTF-8", - "-source", javacJVMVersion.value, - "-target", javacJVMVersion.value - ), +// javacOptions in Compile ++= Seq( +// "-encoding", "UTF-8", +// "-source", javacJVMVersion.value, +// "-target", javacJVMVersion.value +// ), scalacOptions in Compile ++= Seq( s"-target:jvm-${scalacJVMVersion.value}", diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 0b3dd351e38e8..d6f6ff3a4fe5d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -35,17 +35,6 @@ object ScalaReflection extends ScalaReflection { // class loader of the current thread. override def mirror: universe.Mirror = universe.runtimeMirror(Thread.currentThread().getContextClassLoader) -} - -/** - * Support for generating catalyst schemas for scala objects. - */ -trait ScalaReflection { - /** The universe we work in (runtime or macro) */ - val universe: scala.reflect.api.Universe - - /** The mirror used to access types in the universe */ - def mirror: universe.Mirror import universe._ @@ -53,38 +42,14 @@ trait ScalaReflection { // Since the map values can be mutable, we explicitly import scala.collection.Map at here. import scala.collection.Map - case class Schema(dataType: DataType, nullable: Boolean) - - /** Returns a Sequence of attributes for the given case class type. */ - def attributesFor[T: TypeTag]: Seq[Attribute] = schemaFor[T] match { - case Schema(s: StructType, _) => - s.toAttributes - } - - /** Returns a catalyst DataType and its nullability for the given Scala Type using reflection. */ - def schemaFor[T: TypeTag]: Schema = - ScalaReflectionLock.synchronized { schemaFor(localTypeOf[T]) } - - /** - * Return the Scala Type for `T` in the current classloader mirror. - * - * Use this method instead of the convenience method `universe.typeOf`, which - * assumes that all types can be found in the classloader that loaded scala-reflect classes. 
- * That's not necessarily the case when running using Eclipse launchers or even - * Sbt console or test (without `fork := true`). - * - * @see SPARK-5281 - */ - def localTypeOf[T: TypeTag]: `Type` = typeTag[T].in(mirror).tpe - /** - * Returns the Spark SQL DataType for a given scala type. Where this is not an exact mapping - * to a native type, an ObjectType is returned. Special handling is also used for Arrays including - * those that hold primitive types. - * - * Unlike `schemaFor`, this function doesn't do any massaging of types into the Spark SQL type - * system. As a result, ObjectType will be returned for things like boxed Integers - */ + * Returns the Spark SQL DataType for a given scala type. Where this is not an exact mapping + * to a native type, an ObjectType is returned. Special handling is also used for Arrays including + * those that hold primitive types. + * + * Unlike `schemaFor`, this function doesn't do any massaging of types into the Spark SQL type + * system. As a result, ObjectType will be returned for things like boxed Integers + */ def dataTypeFor(tpe: `Type`): DataType = tpe match { case t if t <:< definitions.IntTpe => IntegerType case t if t <:< definitions.LongTpe => LongType @@ -114,15 +79,17 @@ trait ScalaReflection { } ObjectType(cls) - case other => ObjectType(Utils.classForName(className)) + case other => + val clazz = mirror.runtimeClass(tpe.erasure.typeSymbol.asClass) + ObjectType(clazz) } } /** - * Given a type `T` this function constructs and ObjectType that holds a class of type - * Array[T]. Special handling is performed for primitive types to map them back to their raw - * JVM form instead of the Scala Array that handles auto boxing. - */ + * Given a type `T` this function constructs and ObjectType that holds a class of type + * Array[T]. Special handling is performed for primitive types to map them back to their raw + * JVM form instead of the Scala Array that handles auto boxing. + */ def arrayClassFor(tpe: `Type`): DataType = { val cls = tpe match { case t if t <:< definitions.IntTpe => classOf[Array[Int]] @@ -142,15 +109,15 @@ trait ScalaReflection { } /** - * Returns an expression that can be used to construct an object of type `T` given an input - * row with a compatible schema. Fields of the row will be extracted using UnresolvedAttributes - * of the same name as the constructor arguments. Nested classes will have their fields accessed - * using UnresolvedExtractValue. - * - * When used on a primitive type, the constructor will instead default to extracting the value - * from ordinal 0 (since there are no names to map to). The actual location can be moved by - * calling unbind/bind with a new schema. - */ + * Returns an expression that can be used to construct an object of type `T` given an input + * row with a compatible schema. Fields of the row will be extracted using UnresolvedAttributes + * of the same name as the constructor arguments. Nested classes will have their fields accessed + * using UnresolvedExtractValue. + * + * When used on a primitive type, the constructor will instead default to extracting the value + * from ordinal 0 (since there are no names to map to). The actual location can be moved by + * calling unbind/bind with a new schema. + */ def constructorFor[T : TypeTag]: Expression = constructorFor(typeOf[T], None) private def constructorFor( @@ -437,8 +404,8 @@ trait ScalaReflection { /** Helper for extracting internal fields from a case class. 
*/ protected def extractorFor( - inputObject: Expression, - tpe: `Type`): Expression = ScalaReflectionLock.synchronized { + inputObject: Expression, + tpe: `Type`): Expression = ScalaReflectionLock.synchronized { if (!inputObject.dataType.isInstanceOf[ObjectType]) { inputObject } else { @@ -640,6 +607,47 @@ trait ScalaReflection { } } } +} + +/** + * Support for generating catalyst schemas for scala objects. + */ +trait ScalaReflection { + /** The universe we work in (runtime or macro) */ + val universe: scala.reflect.api.Universe + + /** The mirror used to access types in the universe */ + def mirror: universe.Mirror + + import universe._ + + // The Predef.Map is scala.collection.immutable.Map. + // Since the map values can be mutable, we explicitly import scala.collection.Map at here. + import scala.collection.Map + + case class Schema(dataType: DataType, nullable: Boolean) + + /** Returns a Sequence of attributes for the given case class type. */ + def attributesFor[T: TypeTag]: Seq[Attribute] = schemaFor[T] match { + case Schema(s: StructType, _) => + s.toAttributes + } + + /** Returns a catalyst DataType and its nullability for the given Scala Type using reflection. */ + def schemaFor[T: TypeTag]: Schema = + ScalaReflectionLock.synchronized { schemaFor(localTypeOf[T]) } + + /** + * Return the Scala Type for `T` in the current classloader mirror. + * + * Use this method instead of the convenience method `universe.typeOf`, which + * assumes that all types can be found in the classloader that loaded scala-reflect classes. + * That's not necessarily the case when running using Eclipse launchers or even + * Sbt console or test (without `fork := true`). + * + * @see SPARK-5281 + */ + def localTypeOf[T: TypeTag]: `Type` = typeTag[T].in(mirror).tpe /** Returns a catalyst DataType and its nullability for the given Scala Type using reflection. */ def schemaFor(tpe: `Type`): Schema = ScalaReflectionLock.synchronized { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala index 9a1a8f5cbbdc3..53aca564c7cfe 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.encoders +import java.util.concurrent.ConcurrentMap + import scala.reflect.ClassTag import scala.reflect.runtime.universe.{typeTag, TypeTag} @@ -209,7 +211,9 @@ case class ExpressionEncoder[T]( * Returns a new copy of this encoder, where the expressions used by `fromRow` are resolved to the * given schema. */ - def resolve(schema: Seq[Attribute]): ExpressionEncoder[T] = { + def resolve( + schema: Seq[Attribute], + outerScopes: ConcurrentMap[String, AnyRef]): ExpressionEncoder[T] = { val positionToAttribute = AttributeMap.toIndex(schema) val unbound = fromRowExpression transform { case b: BoundReference => positionToAttribute(b.ordinal) @@ -217,7 +221,16 @@ case class ExpressionEncoder[T]( val plan = Project(Alias(unbound, "")() :: Nil, LocalRelation(schema)) val analyzedPlan = SimpleAnalyzer.execute(plan) - copy(fromRowExpression = analyzedPlan.expressions.head.children.head) + + // In order to construct instances of inner classes (for example those declared in a REPL cell), + // we need an instance of the outer scope. 
This rule substitues those outer objects into + // expressions that are missing them by looking up the name in the SQLContexts `outerScopes` + // registry. + copy(fromRowExpression = analyzedPlan.expressions.head.children.head transform { + case n: NewInstance if n.outerPointer.isEmpty && n.cls.isMemberClass => + val outer = outerScopes.get(n.cls.getDeclaringClass.getName) + n.copy(outerPointer = Some(Literal.fromObject(outer))) + }) } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala index 414adb21168ed..1de4d769d21d5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala @@ -31,6 +31,7 @@ import scala.reflect.ClassTag object ProductEncoder { import ScalaReflection.universe._ + import ScalaReflection.mirror import ScalaReflection.localTypeOf import ScalaReflection.dataTypeFor import ScalaReflection.Schema @@ -420,8 +421,7 @@ object ProductEncoder { } } - val className: String = t.erasure.typeSymbol.asClass.fullName - val cls = Utils.classForName(className) + val cls = mirror.runtimeClass(tpe.erasure.typeSymbol.asClass) val arguments = params.head.zipWithIndex.map { case (p, i) => val fieldName = p.name.toString @@ -429,7 +429,7 @@ object ProductEncoder { val dataType = schemaFor(fieldType).dataType // For tuples, we based grab the inner fields by ordinal instead of name. - if (className startsWith "scala.Tuple") { + if (cls.getName startsWith "scala.Tuple") { constructorFor(fieldType, Some(addToPathOrdinal(i, dataType))) } else { constructorFor(fieldType, Some(addToPath(fieldName))) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 1718cfbd35332..1b078f5c4856a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.expressions.codegen +import org.apache.spark.util.Utils + import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import scala.language.existentials @@ -202,6 +204,11 @@ class CodeGenContext { case _: ArrayType => "ArrayData" case _: MapType => "MapData" case udt: UserDefinedType[_] => javaType(udt.sqlType) +// case ObjectType(cls) if cls.isMemberClass => +// val pkg = cls.getDeclaringClass.getPackage.getName + "." 
+// val name = pkg + cls.getName +// println(name) +// name case ObjectType(cls) if cls.isArray => s"${javaType(ObjectType(cls.getComponentType))}[]" case ObjectType(cls) => cls.getName case _ => "Object" @@ -524,7 +531,7 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin */ private[this] def doCompile(code: String): GeneratedClass = { val evaluator = new ClassBodyEvaluator() - evaluator.setParentClassLoader(getClass.getClassLoader) + evaluator.setParentClassLoader(Utils.getContextOrSparkClassLoader) // Cannot be under package codegen, or fail with java.lang.InstantiationException evaluator.setClassName("org.apache.spark.sql.catalyst.expressions.GeneratedClass") evaluator.setDefaultImports(Array( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala index d51a8dede7f34..a31574c251af5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala @@ -34,7 +34,7 @@ trait CodegenFallback extends Expression { val objectTerm = ctx.freshName("obj") s""" /* expression: ${this} */ - Object $objectTerm = expressions[${ctx.references.size - 1}].eval(${ctx.INPUT_ROW}); + java.lang.Object $objectTerm = expressions[${ctx.references.size - 1}].eval(${ctx.INPUT_ROW}); boolean ${ev.isNull} = $objectTerm == null; ${ctx.javaType(this.dataType)} ${ev.value} = ${ctx.defaultValue(this.dataType)}; if (!${ev.isNull}) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala index f0ed8645d923f..b7926bda3de19 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala @@ -148,7 +148,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] } val allExpressions = ctx.splitExpressions(ctx.INPUT_ROW, expressionCodes) val code = s""" - public Object generate($exprType[] expr) { + public java.lang.Object generate($exprType[] expr) { return new SpecificSafeProjection(expr); } @@ -165,7 +165,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] ${initMutableStates(ctx)} } - public Object apply(Object _i) { + public java.lang.Object apply(java.lang.Object _i) { InternalRow ${ctx.INPUT_ROW} = (InternalRow) _i; $allExpressions return mutableRow; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index 4c17d02a23725..7b6c9373ebe30 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -324,7 +324,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro val eval = createCode(ctx, expressions, subexpressionEliminationEnabled) val code = s""" - public Object generate($exprType[] exprs) { + public 
java.lang.Object generate($exprType[] exprs) { return new SpecificUnsafeProjection(exprs); } @@ -342,7 +342,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro } // Scala.Function1 need this - public Object apply(Object row) { + public java.lang.Object apply(java.lang.Object row) { return apply((InternalRow) row); } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index 455fa2427c26d..c92f0bf860db9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -48,6 +48,8 @@ object Literal { throw new RuntimeException("Unsupported literal type " + v.getClass + " " + v) } + def fromObject(obj: AnyRef): Literal = new Literal(obj, ObjectType(obj.getClass)) + def create(v: Any, dataType: DataType): Literal = { Literal(CatalystTypeConverters.convertToCatalyst(v), dataType) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala index 5cd19de68391c..fa9337c62f705 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.analysis.SimpleAnalyzer +import org.apache.spark.sql.catalyst.encoders.ProductEncoder import org.apache.spark.sql.catalyst.plans.logical.{Project, LocalRelation} import org.apache.spark.sql.catalyst.util.GenericArrayData @@ -176,6 +177,15 @@ case class Invoke( } } +object NewInstance { + def apply( + cls: Class[_], + arguments: Seq[Expression], + propagateNull: Boolean = false, + dataType: DataType): NewInstance = + new NewInstance(cls, arguments, propagateNull, dataType, None) +} + /** * Constructs a new instance of the given class, using the result of evaluating the specified * expressions as arguments. 
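The outer-pointer plumbing around NewInstance rests on a plain JVM rule: a member class can only be constructed from an instance of its enclosing class, which is exactly what the generated `outerObj.new Inner(...)` call expresses. A small, self-contained Scala sketch of that rule (not part of the patch; Outer, Inner and InnerClassDemo are illustrative names):

// Stand-alone illustration of why NewInstance needs an outer pointer for member classes:
// a class declared inside another class keeps a hidden reference to its enclosing instance,
// so it can only be built from one. Reflectively, the enclosing instance is passed as the
// first constructor argument, the analogue of the generated `outer.new Inner(...)` syntax.
class Outer {
  case class Inner(i: Int)
}

object InnerClassDemo {
  def main(args: Array[String]): Unit = {
    val outer = new Outer
    val direct = outer.Inner(1)                      // ordinary construction needs `outer` in scope

    val ctor = direct.getClass.getConstructors.head  // Inner's only constructor: (Outer, int)
    val viaReflection = ctor.newInstance(outer, Int.box(1))
    println(direct == viaReflection)                 // true: same field values
  }
}

A case class defined in a REPL cell is exactly this kind of member class (its enclosing class is the generated line wrapper), which is why deserialization needs access to a registered outer instance.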
@@ -191,8 +201,9 @@ case class Invoke( case class NewInstance( cls: Class[_], arguments: Seq[Expression], - propagateNull: Boolean = true, - dataType: DataType) extends Expression { + propagateNull: Boolean, + dataType: DataType, + outerPointer: Option[Literal]) extends Expression { private val className = cls.getName override def nullable: Boolean = propagateNull @@ -207,30 +218,43 @@ case class NewInstance( val argGen = arguments.map(_.gen(ctx)) val argString = argGen.map(_.value).mkString(", ") + val outer = outerPointer.map(_.gen(ctx)) + + val setup = + s""" + ${argGen.map(_.code).mkString("\n")} + ${outer.map(_.code.mkString("")).getOrElse("")} + """.stripMargin + + val constructorCall = outer.map { gen => + s"""${gen.value}.new ${cls.getSimpleName}($argString)""" + }.getOrElse { + s"new $className($argString)" + } + if (propagateNull) { val objNullCheck = if (ctx.defaultValue(dataType) == "null") { s"${ev.isNull} = ${ev.value} == null;" } else { "" } - val argsNonNull = s"!(${argGen.map(_.isNull).mkString(" || ")})" + s""" - ${argGen.map(_.code).mkString("\n")} + $setup boolean ${ev.isNull} = true; $javaType ${ev.value} = ${ctx.defaultValue(dataType)}; - if ($argsNonNull) { - ${ev.value} = new $className($argString); + ${ev.value} = $constructorCall; ${ev.isNull} = false; } """ } else { s""" - ${argGen.map(_.code).mkString("\n")} + $setup - $javaType ${ev.value} = new $className($argString); + $javaType ${ev.value} = $constructorCall; final boolean ${ev.isNull} = ${ev.value} == null; """ } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala index 9fe64b4cf10e4..f867eb239d652 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala @@ -18,13 +18,17 @@ package org.apache.spark.sql.catalyst.encoders import java.util.Arrays +import java.util.concurrent.ConcurrentMap +import com.google.common.collect.MapMaker import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions.AttributeReference import org.apache.spark.sql.catalyst.util.ArrayData import org.apache.spark.sql.types.ArrayType abstract class ExpressionEncoderSuite extends SparkFunSuite { + val outers: ConcurrentMap[String, AnyRef] = new MapMaker().weakValues().makeMap() + protected def encodeDecodeTest[T]( input: T, encoder: ExpressionEncoder[T], @@ -32,7 +36,7 @@ abstract class ExpressionEncoderSuite extends SparkFunSuite { test(s"encode/decode for $testName: $input") { val row = encoder.toRow(input) val schema = encoder.schema.toAttributes - val boundEncoder = encoder.resolve(schema).bind(schema) + val boundEncoder = encoder.resolve(schema, outers).bind(schema) val convertedBack = try boundEncoder.fromRow(row) catch { case e: Exception => fail( diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoderSuite.scala index bc539d62c537d..1798514c5c38b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoderSuite.scala @@ -53,6 +53,10 @@ case class RepeatedData( case class SpecificCollection(l: List[Int]) class ProductEncoderSuite extends 
ExpressionEncoderSuite { + outers.put(getClass.getName, this) + + case class InnerClass(i: Int) + productTest(InnerClass(1)) productTest(PrimitiveData(1, 1, 1, 1, 1, 1, true)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 4cc3aa2465f2e..ec292c21d1993 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -73,7 +73,7 @@ class Dataset[T] private[sql]( /** The encoder for this [[Dataset]] that has been resolved to its output schema. */ private[sql] val resolvedTEncoder: ExpressionEncoder[T] = - unresolvedTEncoder.resolve(queryExecution.analyzed.output) + unresolvedTEncoder.resolve(queryExecution.analyzed.output, sqlContext.outerScopes) private implicit def classTag = resolvedTEncoder.clsTag diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala index 467cd42b9b8dc..5c608fcf1d3df 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala @@ -53,8 +53,10 @@ class GroupedDataset[K, T] private[sql]( private implicit val unresolvedKEncoder = encoderFor(kEncoder) private implicit val unresolvedTEncoder = encoderFor(tEncoder) - private val resolvedKEncoder = unresolvedKEncoder.resolve(groupingAttributes) - private val resolvedTEncoder = unresolvedTEncoder.resolve(dataAttributes) + private val resolvedKEncoder = + unresolvedKEncoder.resolve(groupingAttributes, sqlContext.outerScopes) + private val resolvedTEncoder = + unresolvedTEncoder.resolve(dataAttributes, sqlContext.outerScopes) private def logicalPlan = queryExecution.analyzed private def sqlContext = queryExecution.sqlContext diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index cd1fdc4edb39d..04a76b3327326 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -19,8 +19,11 @@ package org.apache.spark.sql import java.beans.{BeanInfo, Introspector} import java.util.Properties +import java.util.concurrent.ConcurrentMap import java.util.concurrent.atomic.AtomicReference +import com.google.common.collect.MapMaker + import scala.collection.JavaConverters._ import scala.collection.immutable import scala.reflect.runtime.universe.TypeTag @@ -188,6 +191,14 @@ class SQLContext private[sql]( @transient protected[sql] lazy val functionRegistry: FunctionRegistry = FunctionRegistry.builtin.copy() + @transient + protected [sql] lazy val outerScopes: ConcurrentMap[String, AnyRef] = + new MapMaker().weakValues().makeMap() + + def addOuterScope(outer: AnyRef): Unit = { + outerScopes.putIfAbsent(outer.getClass.getName, outer) + } + @transient protected[sql] lazy val analyzer: Analyzer = new Analyzer(catalog, functionRegistry, conf) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala index 3f2775896bb8c..de3df6b367130 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala @@ -17,6 +17,8 @@ package 
org.apache.spark.sql.execution.aggregate +import com.google.common.collect.MapMaker + import scala.language.existentials import org.apache.spark.Logging @@ -93,7 +95,7 @@ case class TypedAggregateExpression( lazy val boundA = aEncoder.get val bAttributes = bEncoder.schema.toAttributes - lazy val boundB = bEncoder.resolve(bAttributes).bind(bAttributes) + lazy val boundB = bEncoder.resolve(bAttributes, new MapMaker().makeMap()).bind(bAttributes) private def updateBuffer(buffer: MutableRow, value: InternalRow): Unit = { // todo: need a more neat way to assign the value. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala index 1a8e7ab202dc2..1877fd1ad0615 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.execution.datasources import org.apache.spark.sql.catalyst.analysis._ -import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, Cast} +import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.Rule From fe657afcbf850b0e538c2b49b52e981886ec4cd7 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Mon, 16 Nov 2015 23:18:56 -0800 Subject: [PATCH 2/9] cleanup --- project/SparkBuild.scala | 10 ++-- .../spark/sql/catalyst/ScalaReflection.scala | 47 ++++++++++--------- .../expressions/codegen/CodeGenerator.scala | 5 -- .../scala/org/apache/spark/sql/Dataset.scala | 2 +- .../org/apache/spark/sql/SQLContext.scala | 12 +++++ .../aggregate/TypedAggregateExpression.scala | 21 ++++----- .../sql/execution/datasources/rules.scala | 2 +- 7 files changed, 51 insertions(+), 48 deletions(-) diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 9958ff695bca0..67724c4e9e411 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -158,11 +158,11 @@ object SparkBuild extends PomBuild { javacJVMVersion := "1.7", scalacJVMVersion := "1.7", -// javacOptions in Compile ++= Seq( -// "-encoding", "UTF-8", -// "-source", javacJVMVersion.value, -// "-target", javacJVMVersion.value -// ), + javacOptions in Compile ++= Seq( + "-encoding", "UTF-8", + "-source", javacJVMVersion.value, + "-target", javacJVMVersion.value + ), scalacOptions in Compile ++= Seq( s"-target:jvm-${scalacJVMVersion.value}", diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index d6f6ff3a4fe5d..aba53b64701f9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -43,13 +43,13 @@ object ScalaReflection extends ScalaReflection { import scala.collection.Map /** - * Returns the Spark SQL DataType for a given scala type. Where this is not an exact mapping - * to a native type, an ObjectType is returned. Special handling is also used for Arrays including - * those that hold primitive types. - * - * Unlike `schemaFor`, this function doesn't do any massaging of types into the Spark SQL type - * system. As a result, ObjectType will be returned for things like boxed Integers - */ + * Returns the Spark SQL DataType for a given scala type. 
Where this is not an exact mapping + * to a native type, an ObjectType is returned. Special handling is also used for Arrays including + * those that hold primitive types. + * + * Unlike `schemaFor`, this function doesn't do any massaging of types into the Spark SQL type + * system. As a result, ObjectType will be returned for things like boxed Integers + */ def dataTypeFor(tpe: `Type`): DataType = tpe match { case t if t <:< definitions.IntTpe => IntegerType case t if t <:< definitions.LongTpe => LongType @@ -86,10 +86,10 @@ object ScalaReflection extends ScalaReflection { } /** - * Given a type `T` this function constructs and ObjectType that holds a class of type - * Array[T]. Special handling is performed for primitive types to map them back to their raw - * JVM form instead of the Scala Array that handles auto boxing. - */ + * Given a type `T` this function constructs and ObjectType that holds a class of type + * Array[T]. Special handling is performed for primitive types to map them back to their raw + * JVM form instead of the Scala Array that handles auto boxing. + */ def arrayClassFor(tpe: `Type`): DataType = { val cls = tpe match { case t if t <:< definitions.IntTpe => classOf[Array[Int]] @@ -109,15 +109,15 @@ object ScalaReflection extends ScalaReflection { } /** - * Returns an expression that can be used to construct an object of type `T` given an input - * row with a compatible schema. Fields of the row will be extracted using UnresolvedAttributes - * of the same name as the constructor arguments. Nested classes will have their fields accessed - * using UnresolvedExtractValue. - * - * When used on a primitive type, the constructor will instead default to extracting the value - * from ordinal 0 (since there are no names to map to). The actual location can be moved by - * calling unbind/bind with a new schema. - */ + * Returns an expression that can be used to construct an object of type `T` given an input + * row with a compatible schema. Fields of the row will be extracted using UnresolvedAttributes + * of the same name as the constructor arguments. Nested classes will have their fields accessed + * using UnresolvedExtractValue. + * + * When used on a primitive type, the constructor will instead default to extracting the value + * from ordinal 0 (since there are no names to map to). The actual location can be moved by + * calling unbind/bind with a new schema. + */ def constructorFor[T : TypeTag]: Expression = constructorFor(typeOf[T], None) private def constructorFor( @@ -404,8 +404,8 @@ object ScalaReflection extends ScalaReflection { /** Helper for extracting internal fields from a case class. */ protected def extractorFor( - inputObject: Expression, - tpe: `Type`): Expression = ScalaReflectionLock.synchronized { + inputObject: Expression, + tpe: `Type`): Expression = ScalaReflectionLock.synchronized { if (!inputObject.dataType.isInstanceOf[ObjectType]) { inputObject } else { @@ -610,7 +610,8 @@ object ScalaReflection extends ScalaReflection { } /** - * Support for generating catalyst schemas for scala objects. + * Support for generating catalyst schemas for scala objects. Note that unlike its companion + * object, this trait able to work in both the runtime and the compile time (macro) universe. 
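The runtime half of that split depends on resolving classes through a mirror built from the thread's context classloader, as the companion object now does, rather than loading them by name through a fixed loader; that is what lets types defined by the REPL's classloader be found. A minimal stand-alone sketch of that lookup (RuntimeClassLookup and classFor are illustrative names, not part of the patch):

import scala.reflect.runtime.{universe => ru}

// Classloader-aware class lookup: build the mirror from the context classloader, then ask it
// for the runtime Class of a reflected type, instead of calling Class.forName on a name.
object RuntimeClassLookup {
  private def mirror: ru.Mirror =
    ru.runtimeMirror(Thread.currentThread().getContextClassLoader)

  def classFor(tpe: ru.Type): Class[_] =
    mirror.runtimeClass(tpe.erasure.typeSymbol.asClass)
}

// Example: RuntimeClassLookup.classFor(ru.typeOf[Seq[Int]]) yields the Seq class.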
*/ trait ScalaReflection { /** The universe we work in (runtime or macro) */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 1b078f5c4856a..6d644b5f9efe8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -204,11 +204,6 @@ class CodeGenContext { case _: ArrayType => "ArrayData" case _: MapType => "MapData" case udt: UserDefinedType[_] => javaType(udt.sqlType) -// case ObjectType(cls) if cls.isMemberClass => -// val pkg = cls.getDeclaringClass.getPackage.getName + "." -// val name = pkg + cls.getName -// println(name) -// name case ObjectType(cls) if cls.isArray => s"${javaType(ObjectType(cls.getComponentType))}[]" case ObjectType(cls) => cls.getName case _ => "Object" diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index ec292c21d1993..5d31246bfdadb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -368,7 +368,7 @@ class Dataset[T] private[sql]( sqlContext, Project( c1.withInputType( - resolvedTEncoder, + resolvedTEncoder.bind(queryExecution.analyzed.output), queryExecution.analyzed.output).named :: Nil, logicalPlan)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 04a76b3327326..cc16ea68075a8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -195,6 +195,18 @@ class SQLContext private[sql]( protected [sql] lazy val outerScopes: ConcurrentMap[String, AnyRef] = new MapMaker().weakValues().makeMap() + /** + * :: DeveloperApi :: + * Adds a new outer scope to this context that can be used when instantiating an `inner class` + * during deserialialization. Inner classes are created when a case class is defined in the + * Spark REPL and registering the outer scope that this class was defined in allows us to create + * new instances on the spark executors. In normal use, users should not need to call this + * function. + * + * Warning: this function operates on the assumption that there is only ever one instance of any + * given wrapper class. 
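To illustrate when a user would call this hook outside the REPL, here is a hypothetical sketch; Pipeline and Record are invented names, and later commits in this series move the hook from SQLContext to a standalone OuterScopes object, so treat the call site as provisional. A case class declared inside another class is a member class, so the enclosing instance has to be registered before encoders can rebuild Record values on executors:

// Hypothetical user code: Record is an inner class of Pipeline, so the Pipeline instance
// must be registered as an outer scope for the Dataset's encoder to construct Record objects.
class Pipeline(sqlContext: org.apache.spark.sql.SQLContext) extends Serializable {
  sqlContext.addOuterScope(this)

  case class Record(id: Long, value: String)

  import sqlContext.implicits._

  def build(): org.apache.spark.sql.Dataset[Record] =
    Seq(Record(1L, "a"), Record(2L, "b")).toDS()
}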
+ */ + @DeveloperApi def addOuterScope(outer: AnyRef): Unit = { outerScopes.putIfAbsent(outer.getClass.getName, outer) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala index de3df6b367130..6ce41aaf01e27 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql.execution.aggregate -import com.google.common.collect.MapMaker - import scala.language.existentials import org.apache.spark.Logging @@ -54,8 +52,8 @@ object TypedAggregateExpression { */ case class TypedAggregateExpression( aggregator: Aggregator[Any, Any, Any], - aEncoder: Option[ExpressionEncoder[Any]], - bEncoder: ExpressionEncoder[Any], + aEncoder: Option[ExpressionEncoder[Any]], // Should be bound. + bEncoder: ExpressionEncoder[Any], // Should be bound. cEncoder: ExpressionEncoder[Any], children: Seq[Attribute], mutableAggBufferOffset: Int, @@ -94,9 +92,6 @@ case class TypedAggregateExpression( // We let the dataset do the binding for us. lazy val boundA = aEncoder.get - val bAttributes = bEncoder.schema.toAttributes - lazy val boundB = bEncoder.resolve(bAttributes, new MapMaker().makeMap()).bind(bAttributes) - private def updateBuffer(buffer: MutableRow, value: InternalRow): Unit = { // todo: need a more neat way to assign the value. var i = 0 @@ -116,24 +111,24 @@ case class TypedAggregateExpression( override def update(buffer: MutableRow, input: InternalRow): Unit = { val inputA = boundA.fromRow(input) - val currentB = boundB.shift(mutableAggBufferOffset).fromRow(buffer) + val currentB = bEncoder.shift(mutableAggBufferOffset).fromRow(buffer) val merged = aggregator.reduce(currentB, inputA) - val returned = boundB.toRow(merged) + val returned = bEncoder.toRow(merged) updateBuffer(buffer, returned) } override def merge(buffer1: MutableRow, buffer2: InternalRow): Unit = { - val b1 = boundB.shift(mutableAggBufferOffset).fromRow(buffer1) - val b2 = boundB.shift(inputAggBufferOffset).fromRow(buffer2) + val b1 = bEncoder.shift(mutableAggBufferOffset).fromRow(buffer1) + val b2 = bEncoder.shift(inputAggBufferOffset).fromRow(buffer2) val merged = aggregator.merge(b1, b2) - val returned = boundB.toRow(merged) + val returned = bEncoder.toRow(merged) updateBuffer(buffer1, returned) } override def eval(buffer: InternalRow): Any = { - val b = boundB.shift(mutableAggBufferOffset).fromRow(buffer) + val b = bEncoder.shift(mutableAggBufferOffset).fromRow(buffer) val result = cEncoder.toRow(aggregator.finish(b)) dataType match { case _: StructType => result diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala index 1877fd1ad0615..1a8e7ab202dc2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.execution.datasources import org.apache.spark.sql.catalyst.analysis._ -import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, Cast} import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical._ import 
org.apache.spark.sql.catalyst.rules.Rule From cb60d1658811524af9459059fd5af28d866db326 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Tue, 17 Nov 2015 00:25:17 -0800 Subject: [PATCH 3/9] quick and dirty repl hack --- .../src/main/scala/org/apache/spark/repl/SparkIMain.scala | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkIMain.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkIMain.scala index 4ee605fd7f11e..ab5927c10af25 100644 --- a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkIMain.scala +++ b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkIMain.scala @@ -816,7 +816,10 @@ import org.apache.spark.annotation.DeveloperApi * incomplete code, compilation error, or runtime error */ @DeveloperApi - def interpret(line: String): IR.Result = interpret(line, false) + def interpret(line: String): IR.Result = { + val fullLine = if (line contains "class") "sqlContext.addOuterScope(this); " + line else line + interpret(fullLine, false) + } /** * Interpret one line of input. All feedback, including parse errors From 59ca3cebc29a85d2518dfc05fa4b230efd320ceb Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Tue, 17 Nov 2015 17:56:49 -0800 Subject: [PATCH 4/9] some hax --- .../org/apache/spark/repl/SparkIMain.scala | 21 ++++++---- .../org/apache/spark/repl/ReplSuite.scala | 3 +- .../expressions/codegen/CodeGenerator.scala | 5 ++- .../codegen/GenerateMutableProjection.scala | 4 +- .../codegen/GenerateProjection.scala | 10 ++--- .../codegen/GenerateUnsafeRowJoiner.scala | 6 +-- .../scala/org/apache/spark/sql/Dataset.scala | 2 +- .../org/apache/spark/sql/GroupedDataset.scala | 4 +- .../org/apache/spark/sql/SQLContext.scala | 42 ++++++++++--------- 9 files changed, 54 insertions(+), 43 deletions(-) diff --git a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkIMain.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkIMain.scala index ab5927c10af25..a9fa0e3361676 100644 --- a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkIMain.scala +++ b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkIMain.scala @@ -817,7 +817,7 @@ import org.apache.spark.annotation.DeveloperApi */ @DeveloperApi def interpret(line: String): IR.Result = { - val fullLine = if (line contains "class") "sqlContext.addOuterScope(this); " + line else line + val fullLine = if (line contains " class ") "" + line else line interpret(fullLine, false) } @@ -1182,8 +1182,9 @@ import org.apache.spark.annotation.DeveloperApi /** Code to import bound names from previous lines - accessPath is code to * append to objectName to access anything bound by request. 
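For context on the SparkIMain changes that follow: every interpreted line is compiled into a wrapper class, and the generated source is extended below (and adjusted again in a later commit) so that the wrapper instance is registered as the outer scope of anything the line defines. Roughly, with simplified names standing in for lineRep.readName and using the OuterScopes helper this series eventually introduces, the generated code is shaped like this:

// Illustrative shape of the code the interpreter generates for a cell such as
// `case class Foo(i: Int)`; names are simplified. Registering the wrapper instance is what
// later lets encoders construct Foo, a member class of the wrapper, on the executors.
class ReplWrapper extends Serializable {
  org.apache.spark.sql.catalyst.encoders.OuterScopes.addOuterScope(this)
  case class Foo(i: Int)
}

object ReplWrapper {
  val INSTANCE = new ReplWrapper()
}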
*/ - val SparkComputedImports(importsPreamble, importsTrailer, accessPath) = - importsCode(referencedNames.toSet, definedClasses) + val SparkComputedImports(importsPreamble, importsTrailer, accessPath) = { + importsCode(referencedNames.toSet , definedClasses) + } /** Code to access a variable with the specified name */ def fullPath(vname: String) = { @@ -1231,6 +1232,7 @@ import org.apache.spark.annotation.DeveloperApi val postamble = importsTrailer + "\n}" + "\n" + "object " + lineRep.readName + " {\n" + " val INSTANCE = new " + lineRep.readName + "();\n" + + " org.apache.spark.sql.OuterScopes.addOuterScope(INSTANCE);\n" + "}\n" val generate = (m: MemberHandler) => m extraCodeToEvaluate Request.this @@ -1719,10 +1721,15 @@ object SparkIMain { def generate: T => String def postamble: String - def apply(contributors: List[T]): String = stringFromWriter { code => - code println preamble - contributors map generate foreach (code println _) - code println postamble + def apply(contributors: List[T]): String = { + val res = stringFromWriter { code => + code println preamble + contributors map generate foreach (code println _) + code println postamble + } + System.out.println("============") + System.out.println(res.split("\n").zipWithIndex.map{ case (l, i) => s"$i $l" }.mkString("\n")) + res } } diff --git a/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala b/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala index 5674dcd669bee..a6d49d0f4f70d 100644 --- a/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala +++ b/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala @@ -258,10 +258,9 @@ class ReplSuite extends SparkFunSuite { // We need to use local-cluster to test this case. val output = runInterpreter("local-cluster[1,1,1024]", """ - |val sqlContext = new org.apache.spark.sql.SQLContext(sc) - |import sqlContext.implicits._ |case class TestCaseClass(value: Int) |sc.parallelize(1 to 10).map(x => TestCaseClass(x)).toDF().collect() + |println(sc) """.stripMargin) assertDoesNotContain("error:", output) assertDoesNotContain("Exception", output) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 6d644b5f9efe8..b87f9f56960f5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -33,6 +33,7 @@ import org.apache.spark.sql.catalyst.util.{MapData, ArrayData} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.Platform import org.apache.spark.unsafe.types._ +import org.apache.spark.util.Utils /** @@ -206,7 +207,7 @@ class CodeGenContext { case udt: UserDefinedType[_] => javaType(udt.sqlType) case ObjectType(cls) if cls.isArray => s"${javaType(ObjectType(cls.getComponentType))}[]" case ObjectType(cls) => cls.getName - case _ => "Object" + case _ => "java.lang.Object" } /** @@ -518,6 +519,8 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin * Compile the Java source code into a Java class, using Janino. 
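As background for the compile path being touched here, a minimal stand-alone sketch of the Janino ClassBodyEvaluator usage this generator wraps: the ClassBodyEvaluator and setParentClassLoader calls match the surrounding code, cook and getClazz are the standard Janino calls that finish the compilation, and JaninoSketch plus the greet class body are purely illustrative:

import org.codehaus.janino.ClassBodyEvaluator

// Compile a Java class body at runtime with an explicitly chosen parent classloader, the
// same knob the patch points at the context classloader so that classes defined by the
// REPL's loader remain visible to generated code.
object JaninoSketch {
  def main(args: Array[String]): Unit = {
    val evaluator = new ClassBodyEvaluator()
    evaluator.setParentClassLoader(Thread.currentThread().getContextClassLoader)
    evaluator.cook("public String greet() { return \"hello\"; }")

    val instance = evaluator.getClazz.newInstance()
    println(instance.getClass.getMethod("greet").invoke(instance))  // prints "hello"
  }
}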
*/ protected def compile(code: String): GeneratedClass = { + assert(!code.contains(" Object "), + "Avoid using unqualified Object in codegen, use java.lang.Object\n" + code) cache.get(code) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala index 4b66069b5f55a..40189f0877764 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala @@ -82,7 +82,7 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu val allUpdates = ctx.splitExpressions(ctx.INPUT_ROW, updates) val code = s""" - public Object generate($exprType[] expr) { + public java.lang.Object generate($exprType[] expr) { return new SpecificMutableProjection(expr); } @@ -109,7 +109,7 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu return (InternalRow) mutableRow; } - public Object apply(Object _i) { + public java.lang.Object apply(java.lang.Object _i) { InternalRow ${ctx.INPUT_ROW} = (InternalRow) _i; $allProjections // copy all the results into MutableRow diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala index c0d313b2e1301..f229f2000d8e1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala @@ -167,7 +167,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { ${initMutableStates(ctx)} } - public Object apply(Object r) { + public java.lang.Object apply(java.lang.Object r) { // GenerateProjection does not work with UnsafeRows. 
assert(!(r instanceof ${classOf[UnsafeRow].getName})); return new SpecificRow((InternalRow) r); @@ -186,14 +186,14 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { public void setNullAt(int i) { nullBits[i] = true; } public boolean isNullAt(int i) { return nullBits[i]; } - public Object genericGet(int i) { + public java.lang.Object genericGet(int i) { if (isNullAt(i)) return null; switch (i) { $getCases } return null; } - public void update(int i, Object value) { + public void update(int i, java.lang.Object value) { if (value == null) { setNullAt(i); return; @@ -212,7 +212,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { return result; } - public boolean equals(Object other) { + public boolean equals(java.lang.Object other) { if (other instanceof SpecificRow) { SpecificRow row = (SpecificRow) other; $columnChecks @@ -222,7 +222,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { } public InternalRow copy() { - Object[] arr = new Object[${expressions.length}]; + java.lang.Object[] arr = new java.lang.Object[${expressions.length}]; ${copyColumns} return new ${classOf[GenericInternalRow].getName}(arr); } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala index da91ff29537b3..da602d9b4bce1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala @@ -159,7 +159,7 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U // ------------------------ Finally, put everything together --------------------------- // val code = s""" - |public Object generate($exprType[] exprs) { + |public java.lang.Object generate($exprType[] exprs) { | return new SpecificUnsafeRowJoiner(); |} | @@ -176,9 +176,9 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U | buf = new byte[sizeInBytes]; | } | - | final Object obj1 = row1.getBaseObject(); + | final java.lang.Object obj1 = row1.getBaseObject(); | final long offset1 = row1.getBaseOffset(); - | final Object obj2 = row2.getBaseObject(); + | final java.lang.Object obj2 = row2.getBaseObject(); | final long offset2 = row2.getBaseOffset(); | | $copyBitset diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 5d31246bfdadb..c0b09db9b5fbc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -73,7 +73,7 @@ class Dataset[T] private[sql]( /** The encoder for this [[Dataset]] that has been resolved to its output schema. 
*/ private[sql] val resolvedTEncoder: ExpressionEncoder[T] = - unresolvedTEncoder.resolve(queryExecution.analyzed.output, sqlContext.outerScopes) + unresolvedTEncoder.resolve(queryExecution.analyzed.output, OuterScopes.outerScopes) private implicit def classTag = resolvedTEncoder.clsTag diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala index 5c608fcf1d3df..5a813745b6292 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala @@ -54,9 +54,9 @@ class GroupedDataset[K, T] private[sql]( private implicit val unresolvedTEncoder = encoderFor(tEncoder) private val resolvedKEncoder = - unresolvedKEncoder.resolve(groupingAttributes, sqlContext.outerScopes) + unresolvedKEncoder.resolve(groupingAttributes, OuterScopes.outerScopes) private val resolvedTEncoder = - unresolvedTEncoder.resolve(dataAttributes, sqlContext.outerScopes) + unresolvedTEncoder.resolve(dataAttributes, OuterScopes.outerScopes) private def logicalPlan = queryExecution.analyzed private def sqlContext = queryExecution.sqlContext diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index cc16ea68075a8..ae4e7c1285730 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -52,6 +52,28 @@ import org.apache.spark.sql.{execution => sparkexecution} import org.apache.spark.sql.util.ExecutionListenerManager import org.apache.spark.util.Utils +object OuterScopes { + @transient + protected [sql] lazy val outerScopes: ConcurrentMap[String, AnyRef] = + new MapMaker().weakValues().makeMap() + + /** + * :: DeveloperApi :: + * Adds a new outer scope to this context that can be used when instantiating an `inner class` + * during deserialialization. Inner classes are created when a case class is defined in the + * Spark REPL and registering the outer scope that this class was defined in allows us to create + * new instances on the spark executors. In normal use, users should not need to call this + * function. + * + * Warning: this function operates on the assumption that there is only ever one instance of any + * given wrapper class. + */ + @DeveloperApi + def addOuterScope(outer: AnyRef): Unit = { + outerScopes.putIfAbsent(outer.getClass.getName, outer) + } +} + /** * The entry point for working with structured data (rows and columns) in Spark. Allows the * creation of [[DataFrame]] objects as well as the execution of SQL queries. @@ -191,26 +213,6 @@ class SQLContext private[sql]( @transient protected[sql] lazy val functionRegistry: FunctionRegistry = FunctionRegistry.builtin.copy() - @transient - protected [sql] lazy val outerScopes: ConcurrentMap[String, AnyRef] = - new MapMaker().weakValues().makeMap() - - /** - * :: DeveloperApi :: - * Adds a new outer scope to this context that can be used when instantiating an `inner class` - * during deserialialization. Inner classes are created when a case class is defined in the - * Spark REPL and registering the outer scope that this class was defined in allows us to create - * new instances on the spark executors. In normal use, users should not need to call this - * function. - * - * Warning: this function operates on the assumption that there is only ever one instance of any - * given wrapper class. 
- */ - @DeveloperApi - def addOuterScope(outer: AnyRef): Unit = { - outerScopes.putIfAbsent(outer.getClass.getName, outer) - } - @transient protected[sql] lazy val analyzer: Analyzer = new Analyzer(catalog, functionRegistry, conf) { From 35cd4f4e9549be327e44d50b59712c0cc9990cf5 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Tue, 17 Nov 2015 18:42:25 -0800 Subject: [PATCH 5/9] cleanup --- .../org/apache/spark/repl/SparkIMain.scala | 31 ++++++-------- .../org/apache/spark/repl/ReplSuite.scala | 6 ++- .../sql/catalyst/encoders/OuterScopes.scala | 42 +++++++++++++++++++ .../sql/catalyst/expressions/literals.scala | 4 ++ .../sql/catalyst/expressions/objects.scala | 2 + .../encoders/ExpressionEncoderSuite.scala | 1 + .../org/apache/spark/sql/GroupedDataset.scala | 3 +- .../org/apache/spark/sql/SQLContext.scala | 25 ----------- 8 files changed, 68 insertions(+), 46 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/OuterScopes.scala diff --git a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkIMain.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkIMain.scala index a9fa0e3361676..a8cf85800748f 100644 --- a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkIMain.scala +++ b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkIMain.scala @@ -816,10 +816,7 @@ import org.apache.spark.annotation.DeveloperApi * incomplete code, compilation error, or runtime error */ @DeveloperApi - def interpret(line: String): IR.Result = { - val fullLine = if (line contains " class ") "" + line else line - interpret(fullLine, false) - } + def interpret(line: String): IR.Result = interpret(line, false) /** * Interpret one line of input. All feedback, including parse errors @@ -1182,9 +1179,8 @@ import org.apache.spark.annotation.DeveloperApi /** Code to import bound names from previous lines - accessPath is code to * append to objectName to access anything bound by request. */ - val SparkComputedImports(importsPreamble, importsTrailer, accessPath) = { - importsCode(referencedNames.toSet , definedClasses) - } + val SparkComputedImports(importsPreamble, importsTrailer, accessPath) = + importsCode(referencedNames.toSet, definedClasses) /** Code to access a variable with the specified name */ def fullPath(vname: String) = { @@ -1227,12 +1223,16 @@ import org.apache.spark.annotation.DeveloperApi val preamble = """ |class %s extends Serializable { - | %s%s%s + | %s + | %s + | // If we need to construct any objects defined in the REPL on an executor we will need + | // to pass the outer scope to the appropriate encoder. 
+ | org.apache.spark.sql.catalyst.encoders.OuterScopes.addOuterScope(this) + | %s """.stripMargin.format(lineRep.readName, envLines.map(" " + _ + ";\n").mkString, importsPreamble, indentCode(toCompute)) val postamble = importsTrailer + "\n}" + "\n" + "object " + lineRep.readName + " {\n" + " val INSTANCE = new " + lineRep.readName + "();\n" + - " org.apache.spark.sql.OuterScopes.addOuterScope(INSTANCE);\n" + "}\n" val generate = (m: MemberHandler) => m extraCodeToEvaluate Request.this @@ -1721,15 +1721,10 @@ object SparkIMain { def generate: T => String def postamble: String - def apply(contributors: List[T]): String = { - val res = stringFromWriter { code => - code println preamble - contributors map generate foreach (code println _) - code println postamble - } - System.out.println("============") - System.out.println(res.split("\n").zipWithIndex.map{ case (l, i) => s"$i $l" }.mkString("\n")) - res + def apply(contributors: List[T]): String = stringFromWriter { code => + code println preamble + contributors map generate foreach (code println _) + code println postamble } } diff --git a/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala b/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala index a6d49d0f4f70d..0ec17560525ea 100644 --- a/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala +++ b/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala @@ -258,9 +258,13 @@ class ReplSuite extends SparkFunSuite { // We need to use local-cluster to test this case. val output = runInterpreter("local-cluster[1,1,1024]", """ + |val sqlContext = new org.apache.spark.sql.SQLContext(sc) + |import sqlContext.implicits._ |case class TestCaseClass(value: Int) |sc.parallelize(1 to 10).map(x => TestCaseClass(x)).toDF().collect() - |println(sc) + | + |// Test Dataset Serialization in the REPL + |Seq(TestCaseClass(1)).toDS().collect() """.stripMargin) assertDoesNotContain("error:", output) assertDoesNotContain("Exception", output) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/OuterScopes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/OuterScopes.scala new file mode 100644 index 0000000000000..a753b187bcd32 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/OuterScopes.scala @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package org.apache.spark.sql.catalyst.encoders
+
+import java.util.concurrent.ConcurrentMap
+
+import com.google.common.collect.MapMaker
+
+object OuterScopes {
+  @transient
+  lazy val outerScopes: ConcurrentMap[String, AnyRef] =
+    new MapMaker().weakValues().makeMap()
+
+  /**
+   * Adds a new outer scope to this context that can be used when instantiating an `inner class`
+   * during deserialization. Inner classes are created when a case class is defined in the
+   * Spark REPL and registering the outer scope that this class was defined in allows us to create
+   * new instances on the Spark executors. In normal use, users should not need to call this
+   * function.
+   *
+   * Warning: this function operates on the assumption that there is only ever one instance of any
+   * given wrapper class.
+   */
+  def addOuterScope(outer: AnyRef): Unit = {
+    outerScopes.putIfAbsent(outer.getClass.getName, outer)
+  }
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
index c92f0bf860db9..e34fd49be8389 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
@@ -48,6 +48,10 @@ object Literal {
     throw new RuntimeException("Unsupported literal type " + v.getClass + " " + v)
   }
 
+  /**
+   * Constructs a [[Literal]] of [[ObjectType]], for example when you need to pass an object
+   * into code generation.
+   */
   def fromObject(obj: AnyRef): Literal = new Literal(obj, ObjectType(obj.getClass))
 
   def create(v: Any, dataType: DataType): Literal = {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
index fa9337c62f705..453ecbb3a3546 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
@@ -197,6 +197,8 @@ object NewInstance {
  * @param dataType The type of object being constructed, as a Spark SQL datatype. This allows you
  *                 to manually specify the type when the object in question is a valid internal
  *                 representation (i.e. ArrayData) instead of an object.
+ * @param outerPointer If the object being constructed is an inner class, the outerPointer for
+ *                     the containing class must be specified.
*/ case class NewInstance( cls: Class[_], diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala index f867eb239d652..cde0364f3dd9d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala @@ -21,6 +21,7 @@ import java.util.Arrays import java.util.concurrent.ConcurrentMap import com.google.common.collect.MapMaker + import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions.AttributeReference import org.apache.spark.sql.catalyst.util.ArrayData diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala index 5a813745b6292..421136663fafb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala @@ -17,12 +17,11 @@ package org.apache.spark.sql - import scala.collection.JavaConverters._ import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.function._ -import org.apache.spark.sql.catalyst.encoders.{FlatEncoder, ExpressionEncoder, encoderFor} +import org.apache.spark.sql.catalyst.encoders.{OuterScopes, FlatEncoder, ExpressionEncoder, encoderFor} import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.execution.QueryExecution diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index ae4e7c1285730..cd1fdc4edb39d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -19,11 +19,8 @@ package org.apache.spark.sql import java.beans.{BeanInfo, Introspector} import java.util.Properties -import java.util.concurrent.ConcurrentMap import java.util.concurrent.atomic.AtomicReference -import com.google.common.collect.MapMaker - import scala.collection.JavaConverters._ import scala.collection.immutable import scala.reflect.runtime.universe.TypeTag @@ -52,28 +49,6 @@ import org.apache.spark.sql.{execution => sparkexecution} import org.apache.spark.sql.util.ExecutionListenerManager import org.apache.spark.util.Utils -object OuterScopes { - @transient - protected [sql] lazy val outerScopes: ConcurrentMap[String, AnyRef] = - new MapMaker().weakValues().makeMap() - - /** - * :: DeveloperApi :: - * Adds a new outer scope to this context that can be used when instantiating an `inner class` - * during deserialialization. Inner classes are created when a case class is defined in the - * Spark REPL and registering the outer scope that this class was defined in allows us to create - * new instances on the spark executors. In normal use, users should not need to call this - * function. - * - * Warning: this function operates on the assumption that there is only ever one instance of any - * given wrapper class. - */ - @DeveloperApi - def addOuterScope(outer: AnyRef): Unit = { - outerScopes.putIfAbsent(outer.getClass.getName, outer) - } -} - /** * The entry point for working with structured data (rows and columns) in Spark. Allows the * creation of [[DataFrame]] objects as well as the execution of SQL queries. 
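The patch above moves OuterScopes into catalyst and has each generated REPL wrapper register itself via addOuterScope(this). The standalone sketch below (not part of the patch; OuterScopeRegistry, ReplWrapper and Rebuild are hypothetical names) illustrates the mechanism this relies on: a case class defined inside a wrapper compiles to an inner class whose constructor takes the outer instance as a hidden first argument, so deserialization code can only rebuild it if it can look that instance up by the wrapper's class name.

    import java.util.concurrent.ConcurrentHashMap

    // Hypothetical stand-in for OuterScopes: one registered instance per wrapper class name.
    object OuterScopeRegistry {
      private val scopes = new ConcurrentHashMap[String, AnyRef]()

      def addOuterScope(outer: AnyRef): Unit =
        scopes.putIfAbsent(outer.getClass.getName, outer)

      def outerFor(innerClass: Class[_]): AnyRef =
        scopes.get(innerClass.getDeclaringClass.getName)
    }

    // Plays the role of the REPL's generated wrapper class; it registers itself on construction,
    // just like the addOuterScope(this) call added to the interpreter preamble.
    class ReplWrapper {
      OuterScopeRegistry.addOuterScope(this)
      case class TestCaseClass(value: Int) // compiles to an inner class needing an outer pointer
    }

    object Rebuild extends App {
      val wrapper = new ReplWrapper
      val original = new wrapper.TestCaseClass(1)

      // What a deserializer on an executor has to do: construct the inner class reflectively,
      // passing the registered outer instance as the hidden first constructor argument.
      val cls = original.getClass
      val rebuilt = cls.getConstructors.head.newInstance(OuterScopeRegistry.outerFor(cls), Int.box(42))
      println(rebuilt) // prints TestCaseClass(42)
    }

The real OuterScopes keeps its entries with Guava's weak values, so registering a wrapper does not prevent it from being garbage collected once nothing else references it.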
From e38999b582b1f7e1140c184855f9969398ef4382 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Tue, 17 Nov 2015 20:18:44 -0800 Subject: [PATCH 6/9] a horrible hack --- .../org/apache/spark/repl/ReplSuite.scala | 21 +++++++++++++++++++ .../spark/repl/ExecutorClassLoader.scala | 6 ++++++ .../expressions/codegen/CodeGenerator.scala | 9 ++------ 3 files changed, 29 insertions(+), 7 deletions(-) diff --git a/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala b/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala index 0ec17560525ea..081aa03002cc6 100644 --- a/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala +++ b/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala @@ -281,6 +281,27 @@ class ReplSuite extends SparkFunSuite { assertDoesNotContain("java.lang.ClassNotFoundException", output) } + test("Datasets and encoders") { + val output = runInterpreter("local", + """ + |import org.apache.spark.sql.functions._ + |import org.apache.spark.sql.Encoder + |import org.apache.spark.sql.expressions.Aggregator + |import org.apache.spark.sql.TypedColumn + |val simpleSum = new Aggregator[Int, Int, Int] with Serializable { + | def zero: Int = 0 // The initial value. + | def reduce(b: Int, a: Int) = b + a // Add an element to the running total + | def merge(b1: Int, b2: Int) = b1 + b2 // Merge intermediate values. + | def finish(b: Int) = b // Return the final result. + |}.toColumn + | + |val ds = Seq(1, 2, 3, 4).toDS() + |ds.select(simpleSum).collect + """.stripMargin) + assertDoesNotContain("error:", output) + assertDoesNotContain("Exception", output) + } + test("SPARK-2632 importing a method from non serializable class and not using it.") { val output = runInterpreter("local", """ diff --git a/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala b/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala index 3d2d235a00c93..7bd66e6e269b0 100644 --- a/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala +++ b/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala @@ -20,6 +20,8 @@ package org.apache.spark.repl import java.io.{IOException, ByteArrayOutputStream, InputStream} import java.net.{HttpURLConnection, URI, URL, URLEncoder} +import org.apache.spark.unsafe.Platform + import scala.util.control.NonFatal import org.apache.hadoop.fs.{FileSystem, Path} @@ -56,6 +58,10 @@ class ExecutorClassLoader(conf: SparkConf, classUri: String, parent: ClassLoader } override def findClass(name: String): Class[_] = { + // This is a horrible hack to workround an issue that Janino has when operating on a + // REPL classloader :(. 
+ if (name == "Platform") return classOf[Platform] + userClassPathFirst match { case true => findClassLocally(name).getOrElse(parentLoader.loadClass(name)) case false => { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index b87f9f56960f5..1718cfbd35332 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql.catalyst.expressions.codegen -import org.apache.spark.util.Utils - import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import scala.language.existentials @@ -33,7 +31,6 @@ import org.apache.spark.sql.catalyst.util.{MapData, ArrayData} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.Platform import org.apache.spark.unsafe.types._ -import org.apache.spark.util.Utils /** @@ -207,7 +204,7 @@ class CodeGenContext { case udt: UserDefinedType[_] => javaType(udt.sqlType) case ObjectType(cls) if cls.isArray => s"${javaType(ObjectType(cls.getComponentType))}[]" case ObjectType(cls) => cls.getName - case _ => "java.lang.Object" + case _ => "Object" } /** @@ -519,8 +516,6 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin * Compile the Java source code into a Java class, using Janino. */ protected def compile(code: String): GeneratedClass = { - assert(!code.contains(" Object "), - "Avoid using unqualified Object in codegen, use java.lang.Object\n" + code) cache.get(code) } @@ -529,7 +524,7 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin */ private[this] def doCompile(code: String): GeneratedClass = { val evaluator = new ClassBodyEvaluator() - evaluator.setParentClassLoader(Utils.getContextOrSparkClassLoader) + evaluator.setParentClassLoader(getClass.getClassLoader) // Cannot be under package codegen, or fail with java.lang.InstantiationException evaluator.setClassName("org.apache.spark.sql.catalyst.expressions.GeneratedClass") evaluator.setDefaultImports(Array( From 0d7776907696dcb70d379d7f6964669a8af235b6 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Tue, 17 Nov 2015 23:33:22 -0800 Subject: [PATCH 7/9] testing code --- .../org/apache/spark/repl/SparkIMain.scala | 17 +++++++++++------ .../expressions/codegen/CodeGenerator.scala | 9 +++++++-- 2 files changed, 18 insertions(+), 8 deletions(-) diff --git a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkIMain.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkIMain.scala index a8cf85800748f..662e09d39c323 100644 --- a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkIMain.scala +++ b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkIMain.scala @@ -1221,19 +1221,24 @@ import org.apache.spark.annotation.DeveloperApi ) } - val preamble = """ - |class %s extends Serializable { - | %s - | %s + val preamble = s""" + |class ${lineRep.readName} extends Serializable { + | ${envLines.map(" " + _ + ";\n").mkString} + | $importsPreamble + | | // If we need to construct any objects defined in the REPL on an executor we will need | // to pass the outer scope to the appropriate encoder. 
| org.apache.spark.sql.catalyst.encoders.OuterScopes.addOuterScope(this) - | %s - """.stripMargin.format(lineRep.readName, envLines.map(" " + _ + ";\n").mkString, importsPreamble, indentCode(toCompute)) + | ${indentCode(toCompute)} + """.stripMargin + val postamble = importsTrailer + "\n}" + "\n" + "object " + lineRep.readName + " {\n" + " val INSTANCE = new " + lineRep.readName + "();\n" + "}\n" + + System.out.println(preamble + postamble) + val generate = (m: MemberHandler) => m extraCodeToEvaluate Request.this /* diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 1718cfbd35332..2b1ecffe27576 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.expressions.codegen +import org.apache.spark.util.Utils + import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import scala.language.existentials @@ -204,7 +206,7 @@ class CodeGenContext { case udt: UserDefinedType[_] => javaType(udt.sqlType) case ObjectType(cls) if cls.isArray => s"${javaType(ObjectType(cls.getComponentType))}[]" case ObjectType(cls) => cls.getName - case _ => "Object" + case _ => "java.lang.Object" } /** @@ -523,8 +525,11 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin * Compile the Java source code into a Java class, using Janino. */ private[this] def doCompile(code: String): GeneratedClass = { + assert(!code.contains(" Object ", s"java.lang.Object should be used instead in: \n$code")) + assert(!code.contains(" Object[] ", s"java.lang.Object[] should be used instead in: \n$code")) + val evaluator = new ClassBodyEvaluator() - evaluator.setParentClassLoader(getClass.getClassLoader) + evaluator.setParentClassLoader(Utils.getContextOrSparkClassLoader) // Cannot be under package codegen, or fail with java.lang.InstantiationException evaluator.setClassName("org.apache.spark.sql.catalyst.expressions.GeneratedClass") evaluator.setDefaultImports(Array( From 95cec7d413b930b36420724fafd829bef8c732ab Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Wed, 18 Nov 2015 13:45:18 -0800 Subject: [PATCH 8/9] revert scary parts --- .../org/apache/spark/repl/SparkIMain.scala | 18 ++++---------- .../org/apache/spark/repl/ReplSuite.scala | 24 ------------------- .../spark/repl/ExecutorClassLoader.scala | 6 ----- .../expressions/codegen/CodeGenerator.scala | 9 ++----- 4 files changed, 6 insertions(+), 51 deletions(-) diff --git a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkIMain.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkIMain.scala index 662e09d39c323..4ee605fd7f11e 100644 --- a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkIMain.scala +++ b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkIMain.scala @@ -1221,24 +1221,14 @@ import org.apache.spark.annotation.DeveloperApi ) } - val preamble = s""" - |class ${lineRep.readName} extends Serializable { - | ${envLines.map(" " + _ + ";\n").mkString} - | $importsPreamble - | - | // If we need to construct any objects defined in the REPL on an executor we will need - | // to pass the outer scope to the appropriate encoder. 
- | org.apache.spark.sql.catalyst.encoders.OuterScopes.addOuterScope(this) - | ${indentCode(toCompute)} - """.stripMargin - + val preamble = """ + |class %s extends Serializable { + | %s%s%s + """.stripMargin.format(lineRep.readName, envLines.map(" " + _ + ";\n").mkString, importsPreamble, indentCode(toCompute)) val postamble = importsTrailer + "\n}" + "\n" + "object " + lineRep.readName + " {\n" + " val INSTANCE = new " + lineRep.readName + "();\n" + "}\n" - - System.out.println(preamble + postamble) - val generate = (m: MemberHandler) => m extraCodeToEvaluate Request.this /* diff --git a/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala b/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala index 081aa03002cc6..5674dcd669bee 100644 --- a/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala +++ b/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala @@ -262,9 +262,6 @@ class ReplSuite extends SparkFunSuite { |import sqlContext.implicits._ |case class TestCaseClass(value: Int) |sc.parallelize(1 to 10).map(x => TestCaseClass(x)).toDF().collect() - | - |// Test Dataset Serialization in the REPL - |Seq(TestCaseClass(1)).toDS().collect() """.stripMargin) assertDoesNotContain("error:", output) assertDoesNotContain("Exception", output) @@ -281,27 +278,6 @@ class ReplSuite extends SparkFunSuite { assertDoesNotContain("java.lang.ClassNotFoundException", output) } - test("Datasets and encoders") { - val output = runInterpreter("local", - """ - |import org.apache.spark.sql.functions._ - |import org.apache.spark.sql.Encoder - |import org.apache.spark.sql.expressions.Aggregator - |import org.apache.spark.sql.TypedColumn - |val simpleSum = new Aggregator[Int, Int, Int] with Serializable { - | def zero: Int = 0 // The initial value. - | def reduce(b: Int, a: Int) = b + a // Add an element to the running total - | def merge(b1: Int, b2: Int) = b1 + b2 // Merge intermediate values. - | def finish(b: Int) = b // Return the final result. - |}.toColumn - | - |val ds = Seq(1, 2, 3, 4).toDS() - |ds.select(simpleSum).collect - """.stripMargin) - assertDoesNotContain("error:", output) - assertDoesNotContain("Exception", output) - } - test("SPARK-2632 importing a method from non serializable class and not using it.") { val output = runInterpreter("local", """ diff --git a/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala b/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala index 7bd66e6e269b0..3d2d235a00c93 100644 --- a/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala +++ b/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala @@ -20,8 +20,6 @@ package org.apache.spark.repl import java.io.{IOException, ByteArrayOutputStream, InputStream} import java.net.{HttpURLConnection, URI, URL, URLEncoder} -import org.apache.spark.unsafe.Platform - import scala.util.control.NonFatal import org.apache.hadoop.fs.{FileSystem, Path} @@ -58,10 +56,6 @@ class ExecutorClassLoader(conf: SparkConf, classUri: String, parent: ClassLoader } override def findClass(name: String): Class[_] = { - // This is a horrible hack to workround an issue that Janino has when operating on a - // REPL classloader :(. 
- if (name == "Platform") return classOf[Platform] - userClassPathFirst match { case true => findClassLocally(name).getOrElse(parentLoader.loadClass(name)) case false => { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 2b1ecffe27576..1718cfbd35332 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql.catalyst.expressions.codegen -import org.apache.spark.util.Utils - import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import scala.language.existentials @@ -206,7 +204,7 @@ class CodeGenContext { case udt: UserDefinedType[_] => javaType(udt.sqlType) case ObjectType(cls) if cls.isArray => s"${javaType(ObjectType(cls.getComponentType))}[]" case ObjectType(cls) => cls.getName - case _ => "java.lang.Object" + case _ => "Object" } /** @@ -525,11 +523,8 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin * Compile the Java source code into a Java class, using Janino. */ private[this] def doCompile(code: String): GeneratedClass = { - assert(!code.contains(" Object ", s"java.lang.Object should be used instead in: \n$code")) - assert(!code.contains(" Object[] ", s"java.lang.Object[] should be used instead in: \n$code")) - val evaluator = new ClassBodyEvaluator() - evaluator.setParentClassLoader(Utils.getContextOrSparkClassLoader) + evaluator.setParentClassLoader(getClass.getClassLoader) // Cannot be under package codegen, or fail with java.lang.InstantiationException evaluator.setClassName("org.apache.spark.sql.catalyst.expressions.GeneratedClass") evaluator.setDefaultImports(Array( From 1773e0877f19bd39be63f3c954d6723cc8378e3d Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Wed, 18 Nov 2015 13:54:42 -0800 Subject: [PATCH 9/9] add better error --- .../spark/sql/catalyst/encoders/ExpressionEncoder.scala | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala index 422a427556a5c..456b595008479 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala @@ -23,7 +23,7 @@ import scala.reflect.ClassTag import scala.reflect.runtime.universe.{typeTag, TypeTag} import org.apache.spark.util.Utils -import org.apache.spark.sql.Encoder +import org.apache.spark.sql.{AnalysisException, Encoder} import org.apache.spark.sql.catalyst.analysis.{SimpleAnalyzer, UnresolvedExtractValue, UnresolvedAttribute} import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project} import org.apache.spark.sql.catalyst.expressions._ @@ -231,6 +231,13 @@ case class ExpressionEncoder[T]( copy(fromRowExpression = analyzedPlan.expressions.head.children.head transform { case n: NewInstance if n.outerPointer.isEmpty && n.cls.isMemberClass => val outer = outerScopes.get(n.cls.getDeclaringClass.getName) + if (outer == null) { + throw new AnalysisException( + s"Unable to generate an encoder for inner class `${n.cls.getName}` without access " + + s"to the scope 
that this class was defined in. " +
+            "Try moving this class out of its parent class.")
+      }
+
       n.copy(outerPointer = Some(Literal.fromObject(outer)))
     })
   }
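For reference, the user-facing goal of the series is the kind of spark-shell session sketched below. It simply restates the ReplSuite snippets that appear earlier in these patches (1.6-era API, with `sc` provided by the shell); whether the Dataset line works end to end still depends on the REPL wrapper having been registered in OuterScopes, which the preamble change above experiments with.

    val sqlContext = new org.apache.spark.sql.SQLContext(sc)
    import sqlContext.implicits._

    // Defined at the REPL, so it compiles to an inner class of the generated wrapper object.
    case class TestCaseClass(value: Int)

    sc.parallelize(1 to 10).map(x => TestCaseClass(x)).toDF().collect()

    // Needs the wrapper's outer scope to rebuild TestCaseClass instances on executors.
    Seq(TestCaseClass(1)).toDS().collect()

When no outer instance has been registered for an inner class (for example, a case class nested inside an ordinary user class rather than the REPL wrapper), the check added in the last patch now raises an AnalysisException suggesting the class be moved out of its parent, rather than failing later while constructing the object.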